Add interpretability to Convolutional Neural Networks through individually trained layers
- Yang Cheng
- Oct 24, 2022
- 2 min read
Updated: Oct 30, 2022
I found a paper that trains individual nodes at the end of the network. So why not try the same with layers?

It starts with the concept bottleneck model. In that approach, one model first predicts the properties/traits (concepts) of the target, and a second model then uses those predicted concepts to make the final prediction. Because the intermediate concepts are human-readable, doctors can better trust the model's predictions.
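
To make the two-stage idea concrete, here is a minimal PyTorch sketch of a concept bottleneck: a backbone predicts concept logits, and a small head maps the concepts to the final label. The class name, layer sizes, and loss combination are my illustrative assumptions, not the exact setup from the paper.

```python
import torch
import torch.nn as nn

class ConceptBottleneck(nn.Module):
    """Minimal concept bottleneck sketch: x -> concepts -> label."""
    def __init__(self, n_concepts: int, n_classes: int):
        super().__init__()
        # Backbone predicts the intermediate, human-readable concepts.
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(32, n_concepts),
        )
        # A second, small model maps the predicted concepts to the final label.
        self.head = nn.Linear(n_concepts, n_classes)

    def forward(self, x):
        concept_logits = self.backbone(x)
        class_logits = self.head(torch.sigmoid(concept_logits))
        return concept_logits, class_logits

# Training typically combines a concept loss and a task loss, e.g.
# loss = bce(concept_logits, concept_labels) + ce(class_logits, y)
```
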

The amount of concept information a single node can hold is limited. Why not capture more information and use it to predict the final result? The first version of the improved model looks like this:

This first model is a variant of the multitask NAM (Neural Additive Model). NAMs add nonlinearity to individual nodes to improve interpretability, and I adapted their model for multi-task classification (a rough sketch is below). However, its performance is much poorer than the concept bottleneck model, which might be why their paper only includes one to three tasks.
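
The sketch below shows what such an adaptation could look like, assuming the standard NAM structure of one small nonlinear subnet per input node whose outputs are summed; the multi-task part simply gives each subnet one output per task. The hidden size and class name are placeholders, not the configuration I actually trained.

```python
import torch
import torch.nn as nn

class MultiTaskNAM(nn.Module):
    """Sketch of a NAM adapted to multi-task classification:
    one nonlinear subnet per input node, outputs summed per task."""
    def __init__(self, n_features: int, n_tasks: int, hidden: int = 16):
        super().__init__()
        self.feature_nets = nn.ModuleList([
            nn.Sequential(nn.Linear(1, hidden), nn.ReLU(), nn.Linear(hidden, n_tasks))
            for _ in range(n_features)
        ])
        self.bias = nn.Parameter(torch.zeros(n_tasks))

    def forward(self, x):  # x: (batch, n_features)
        # Each subnet sees only its own feature, so its contribution stays interpretable.
        contributions = [net(x[:, i:i + 1]) for i, net in enumerate(self.feature_nets)]
        return torch.stack(contributions, dim=0).sum(dim=0) + self.bias
```
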
The model that works is the second one, which I call the concept-channel concept bottleneck model (the picture at the top). In this model, each concept of the concept bottleneck model is connected to a channel. One CNN layer may have 32 to 128 channels, and I adapted one of those layers so that every channel is used to predict one property of the subject. For example, one channel predicts yellow feet, while another channel is trained on wing color. This model achieves similar, if not better, performance than the original concept bottleneck model. The intuition is that, compared with fully connected nodes whose inputs are effectively randomly permuted, layers may be able to capture more of the spatial information in the data (see the sketch below).
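
Here is a minimal sketch of that channel-per-concept idea: one conv layer is given exactly as many output channels as there are concepts, each channel's map is pooled to a scalar that is supervised with that concept's label, and the pooled values feed the final classifier. The architecture and names are illustrative assumptions, not my exact training setup.

```python
import torch
import torch.nn as nn

class ConceptChannelBottleneck(nn.Module):
    """Sketch of the concept-channel idea: each channel of one conv layer
    is supervised to predict a single concept (e.g. channel 0 -> yellow feet)."""
    def __init__(self, n_concepts: int, n_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
            # Concept layer: one output channel per concept.
            nn.Conv2d(64, n_concepts, 3, padding=1),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)  # one scalar per channel
        self.head = nn.Linear(n_concepts, n_classes)

    def forward(self, x):
        maps = self.features(x)                      # (B, n_concepts, H, W)
        concept_logits = self.pool(maps).flatten(1)  # channel i -> concept i
        class_logits = self.head(torch.sigmoid(concept_logits))
        return concept_logits, class_logits

# Unlike a flattened concept node, each channel keeps a spatial map,
# so concept i can in principle be localized in the image.
```
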
I expected to be able to extract some information from the channels and generate pictures like this:

Sadly, I couldn't generate anything meaningful from my model. It doesn't improve the saliency maps either, so it doesn't perform better in terms of interpretability. However, I still believe this model has potential worth exploring. Just as I started my experiment, I found that someone had already pursued a similar idea -- the Interpretable Convolutional Neural Network. Unlike my approach, which naively trains individual channels, they implemented a special loss function for each channel and achieved better visualization results.