Jan Disselhoff

PhD Student | JGU Mainz


Explainability using Masks

September 11, 2023

Introduction

The concept of explainability in machine learning is often ill-defined. Most research in this field presents results without an objective metric. This is not due to a lack of desire for objective testing, but rather the absence of an established objective measure.

There are several reasons for this ambiguity.

What is our objective?

The term ‘explainability’ can mean different things to different people. Some interpret explainability as a desire to avoid a ‘black box’ scenario. In other words, they want to understand why a specific prediction was made by a particular network. For instance, if a network incorrectly identifies a cat as a dog, they want to know why. If a network fails to recognize a traffic sign, they want to understand the reason. This interpretation is particularly relevant when networks are applied to real-world data. If a network’s performance suddenly drops, it would be beneficial to understand why and identify potential mitigation strategies, such as adjusting the camera angle, stabilizing the frame, or improving image centering. In this context, explainability implies a desire to better understand the workings of the neural network.

However, others interpret explainability as a desire to understand the process behind a prediction. This interpretation is often more relevant in research applications. For example, if a neural network can predict cancer with 99% accuracy, it would be incredibly valuable to understand how this is achieved. In this context, explainability implies a desire to better understand the underlying phenomena.

Both interpretations are valid, but they can lead to contrasting results. For instance, the first interpretation naturally focuses more on incorrect predictions, while the second focuses on correct predictions.

How do we achieve explainability?

Assuming we have chosen one of these interpretations, what does it mean to have an explanation? One current theoretical approach suggests that we desire a simpler, approximate model. This is often implicitly achieved using linear models.

Given a neural network $f(x;\theta)$, an explanation for the prediction of a sample $x_0$ is given by a simplified function $g(x)$ such that $g(x)\approx f(x;\theta)$ for values of $x$ close to $x_0$.

While this definition does indeed capture some of the behavior of a network, it does not seem to coincide with what we as humans would interpret as an “explanation”.

Nevertheless, most saliency methods are variants of this approach. For example, Grad-CAM can be viewed as a Taylor expansion around a given sample. While sometimes useful, recent papers argue that gradient-based methods fail to capture any relevant information about the network, which makes them misleading under the first interpretation of explainability.

Furthermore, our own research has found that gradient-based methods are unstable when averaged over several trained networks. This suggests that they do not capture information about the underlying phenomena either. (see TODO LINK)

This might be a limitation of the specific techniques, but it might also be a limitation of the underlying approach: perhaps a linear approximation simply cannot give us the information we want.
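To make the local-surrogate definition concrete, here is a minimal sketch of a first-order Taylor explanation $g(x) = f(x_0) + \nabla f(x_0)\cdot(x - x_0)$. The function `f` is a toy stand-in for a trained network, and the helper names (`numerical_gradient`, `linear_explanation`) are made up for illustration. A real network would use automatic differentiation instead of finite differences.

```python
import numpy as np

def f(x):
    """Toy stand-in for a trained network f(x; theta) (an assumption, not a real model)."""
    return np.tanh(x[0] * x[1]) + 0.5 * x[0] ** 2

def numerical_gradient(func, x0, eps=1e-5):
    """Central finite differences; autodiff would replace this for a real network."""
    grad = np.zeros_like(x0, dtype=float)
    for i in range(x0.size):
        step = np.zeros_like(x0, dtype=float)
        step[i] = eps
        grad[i] = (func(x0 + step) - func(x0 - step)) / (2 * eps)
    return grad

def linear_explanation(func, x0):
    """Return the linear surrogate g and the gradient used as a 'saliency' vector."""
    f0 = func(x0)
    grad = numerical_gradient(func, x0)

    def g(x):
        return f0 + grad @ (x - x0)

    return g, grad

x0 = np.array([0.5, -1.0])
g, saliency = linear_explanation(f, x0)

# The surrogate is only trustworthy close to x0.
near = x0 + np.array([0.01, -0.01])
far = x0 + np.array([1.0, 1.0])
err_near = abs(f(near) - g(near))
err_far = abs(f(far) - g(far))
print(f"error near x0: {err_near:.2e}, error far from x0: {err_far:.2e}")
```

The gap between the two errors illustrates why such an explanation is purely local: the gradient describes the network at $x_0$ and says little about its behavior elsewhere.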

Alternatives

Another alternative is given by a masking approach. Instead of inspecting the network's internals, we treat it as a black box and manipulate the image instead. If we remove parts of the image and the prediction stays the same, we can argue that the removed pixels were not necessary for the prediction. So we can ask: What is the smallest set of (contiguous) pixels from the image that we need to correctly predict the class?
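As a rough illustration of this question, the following sketch does a brute-force search for the smallest square window that preserves a black-box prediction. Several things here are assumptions made for the example: `toy_classifier` is a hypothetical stand-in for the network, the search is restricted to square windows rather than arbitrary contiguous pixel sets, and "removed" pixels are set to a constant `fill` value.

```python
import numpy as np

def toy_classifier(image):
    """Hypothetical black box: class 1 if the central 4x4 region is bright, else 0."""
    return int(image[8:12, 8:12].mean() > 0.5)

def smallest_sufficient_window(image, predict, fill=0.0):
    """Return (size, top, left) of the smallest square window that keeps the prediction.

    All pixels outside the window are replaced by `fill`. Returns None if no
    window preserves the original prediction.
    """
    target = predict(image)
    h, w = image.shape
    for size in range(1, min(h, w) + 1):  # try the smallest windows first
        for top in range(h - size + 1):
            for left in range(w - size + 1):
                masked = np.full_like(image, fill)
                masked[top:top + size, left:left + size] = (
                    image[top:top + size, left:left + size]
                )
                if predict(masked) == target:
                    return size, top, left
    return None

rng = np.random.default_rng(0)
image = rng.uniform(0.0, 0.2, size=(20, 20))  # dim background noise
image[8:12, 8:12] = 1.0                       # the bright "object"

result = smallest_sufficient_window(image, toy_classifier)
print(result)  # (3, 8, 8): a 3x3 cutout of the object is already sufficient
```

The exhaustive search is only feasible for tiny images. Note also that the question is only meaningful when the blank image is not already assigned the target class, since otherwise a single pixel trivially suffices.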

The masking approach asks which “cutout” of the image is responsible for the predicted class. In my opinion this is better than saliency methods, but it still has a problem: deep neural networks are infamous for being sensitive to even small perturbations, and removing many pixels pushes the image out of the distribution we are interested in. A simple alternative is to replace the removed pixels with a neutral background. In that case we can write the approach in the following way:

Given an input $X\in D$, find a projection $P_X$ of $X$ such that other samples $Y$ from the domain $D$ with $P_X=P_Y$ have strong class correlations.

In other words, all images from a dataset of natural images that contain a similar set of pixels should have the same class.
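One way to read this criterion is as a purity score: among all samples whose projection matches $P_X$, how many share the same class? The sketch below assumes that the projection simply keeps the pixels under a boolean mask, and that "equal" is relaxed to "within a tolerance `tol`", since exact equality rarely holds for real images. The dataset and the names `project` and `class_purity` are hypothetical and exist only for this example.

```python
import numpy as np

def project(image, mask):
    """P_X: keep only the pixels selected by the boolean mask."""
    return image[mask]

def class_purity(x, mask, images, labels, tol=0.1):
    """Return (purity, n_matches) for the projection of x under mask.

    A sample Y matches if max |P_Y - P_X| <= tol. Purity is the share of the
    majority class among matching samples (1.0 means perfect class correlation).
    """
    p_x = project(x, mask)
    matches = np.array(
        [np.max(np.abs(project(y, mask) - p_x)) <= tol for y in images]
    )
    if not matches.any():
        return None, 0
    counts = np.bincount(labels[matches])
    return counts.max() / counts.sum(), int(matches.sum())

# Toy dataset: class 1 contains a bright patch, both classes share a grey corner.
rng = np.random.default_rng(0)
n = 40
labels = np.arange(n) % 2
images = rng.uniform(0.0, 0.05, size=(n, 8, 8))
images[labels == 1, 2:4, 2:4] = 1.0  # class-specific object
images[:, 6:8, 6:8] = 0.5            # region present in every image

x = images[1]  # a class-1 sample
object_mask = np.zeros((8, 8), dtype=bool)
object_mask[2:4, 2:4] = True
shared_mask = np.zeros((8, 8), dtype=bool)
shared_mask[6:8, 6:8] = True

purity_obj, n_obj = class_purity(x, object_mask, images, labels)
purity_shared, n_shared = class_purity(x, shared_mask, images, labels)
print(purity_obj, n_obj)        # 1.0 20
print(purity_shared, n_shared)  # 0.5 40
```

In this toy setting, the mask over the object yields a projection shared only by class-1 images, while the mask over the common corner matches every image and carries no class information.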