“Many machine decisions are still poorly understood,” says IEEE Fellow Cuntai Guan. Many papers even suggest a rigid dichotomy between accuracy and interpretability.
Explainable AI (XAI) attempts to bridge this divide, but as we explain below, XAI justifies decisions without interpreting the model directly. This means practitioners in applications such as finance and medicine are forced into a dilemma: pick an uninterpretable, accurate model or an inaccurate, interpretable model.
Defining explainability or interpretability for computer vision is challenging: What does it even mean to explain a classification for high-dimensional inputs like images? As we discuss below, two popular definitions involve saliency maps and decision trees, but both approaches have their weaknesses.
Many XAI methods produce heatmaps known as saliency maps, which highlight important input pixels that influence the prediction. However, saliency maps focus on the input and neglect to explain how the model makes decisions.
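To make this concrete, here is a minimal sketch of one common saliency technique, vanilla input gradients: the importance of a pixel is taken to be the magnitude of the gradient of the predicted class score with respect to that pixel. The model and preprocessing below are placeholders for illustration, not the specific method used by any paper discussed here.

```python
import torch
from torchvision import models

# Placeholder model: a standard ImageNet-pretrained ResNet-18.
model = models.resnet18(weights="IMAGENET1K_V1").eval()

def gradient_saliency(model, image, target_class):
    """Per-pixel importance: |d(class score) / d(pixel)|, max over color channels."""
    image = image.clone().requires_grad_(True)   # shape (1, 3, H, W)
    score = model(image)[0, target_class]        # scalar logit for the chosen class
    score.backward()                             # gradient flows back to the input
    return image.grad.abs().max(dim=1).values    # (1, H, W) heatmap

# Usage: saliency = gradient_saliency(model, preprocessed_image, predicted_class)
```

Note that the heatmap only says *which* pixels influenced the score, not *how* the model combined them into a decision.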
To illustrate why saliency maps do not fully explain how the model predicts, consider the example below: the saliency maps are identical, but the predictions differ. Why? Both saliency maps highlight the correct object, yet one prediction is wrong. How? Answering these questions could help us improve the model, but as shown below, saliency maps fail to explain the model’s decision process.
Another approach is to replace neural networks with interpretable models. Before deep learning, decision trees were the gold standard for accuracy and interpretability. Below, we illustrate the interpretability of decision trees, which comes from breaking each prediction into a sequence of decisions.
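As a small illustration of that interpretability (on a toy tabular dataset, not images, and unrelated to the paper), a trained decision tree is literally a set of threshold rules that can be printed and read:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Fit a shallow tree on a toy tabular dataset; every prediction it makes is a
# short chain of human-readable threshold tests.
data = load_iris()
tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(data.data, data.target)

# Prints the learned rules as indented if/else thresholds over named features.
print(export_text(tree, feature_names=list(data.feature_names)))
```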
For accuracy, however, decision trees lag behind neural networks by up to 40% accuracy on image classification datasets². Neural-network-and-decision-tree hybrids also underperform, failing to match neural networks even on CIFAR10, a dataset of tiny 32x32 images like the one below.
We challenge this false dichotomy by building models that are both interpretable and accurate. Our key insight is to combine neural networks with decision trees, preserving high-level interpretability while using neural networks for low-level decisions, as shown below. We call these models Neural-Backed Decision Trees (NBDTs) and show they can match neural network accuracy while preserving the interpretability of a decision tree.
NBDTs are as interpretable as decision trees. Unlike neural networks today, NBDTs can output the intermediate decisions behind a prediction. For example, given an image, a neural network may output only Dog. An NBDT can output Dog along with the intermediate decisions Animal, Chordate, Carnivore (below).
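As a toy illustration of what "intermediate decisions" means here (the hierarchy below is hand-written for this example, not the hierarchy induced in the paper): once every class is a leaf of a hierarchy, reporting the intermediate decisions amounts to reading off the path from the root to the predicted leaf.

```python
# Hand-written toy hierarchy (child -> parent), for illustration only.
parent = {
    "Dog": "Carnivore", "Cat": "Carnivore",
    "Carnivore": "Chordate", "Chordate": "Animal", "Animal": None,
}

def decision_path(leaf, parent):
    """Walk from the predicted leaf up to the root, then reverse the path."""
    path, node = [], leaf
    while node is not None:
        path.append(node)
        node = parent[node]
    return list(reversed(path))

print(decision_path("Dog", parent))  # ['Animal', 'Chordate', 'Carnivore', 'Dog']
```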
NBDTs achieve neural network accuracy. Unlike any other decision-tree-based method, NBDTs match neural network accuracy (< 1% difference) on 3 image classification datasets³. NBDTs also achieve accuracy within 2% of neural networks on ImageNet, one of the largest image classification datasets with 1.2 million 224x224 images.
Furthermore, NBDTs set new state-of-the-art accuracies for interpretable models. The NBDT’s ImageNet accuracy of 75.30% outperforms the best competing decision-tree-based method by roughly 14 percentage points. To contextualize this accuracy gain: a similar ~14% gain for non-interpretable neural networks took 3 years of research.
The most insightful justifications are for objects the model has never seen before. For example, consider an NBDT (below), and run inference on a Zebra. Although this model has never seen a Zebra, the intermediate decisions shown below are correct: Zebras are both Animals and Ungulates (hoofed animals). The ability to see justifications for individual predictions is essential for unseen objects.
Furthermore, we find that with NBDTs, interpretability improves with accuracy. This is contrary to the dichotomy in the introduction: NBDTs not only achieve both accuracy and interpretability; they also tie the two together, so that improving one improves the other.
For example, the lower-accuracy ResNet⁶ hierarchy (left) makes less sense, grouping Frog, Cat, and Airplane together. This is “less sensible,” as it is difficult to find an obvious visual feature shared by all three classes. By contrast, the higher-accuracy WideResNet hierarchy (right) makes more sense, cleanly separating Animal from Vehicle. Thus, the higher the accuracy, the more interpretable the NBDT.
With low-dimensional tabular data, decision rules in a decision tree are simple to interpret: e.g., if the dish contains a bun, then pick the right child, as shown below. However, decision rules are not as straightforward for inputs like high-dimensional images.
As we qualitatively find in the paper (Sec 5.3), the model’s decision rules are based not only on object type but also on context, shape, and color.
To interpret decision rules quantitatively, we leverage an existing hierarchy of nouns called WordNet⁷; with this hierarchy, we can find the most specific shared meaning between classes. For example, given the classes Cat and Dog, WordNet would provide Mammal. In our paper (Sec 5.2) and pictured below, we quantitatively verify these WordNet hypotheses.
Note that in small datasets with 10 classes, such as CIFAR10, we can find WordNet hypotheses for all nodes. However, in large datasets with 1000 classes, such as ImageNet, we can only find WordNet hypotheses for a subset of nodes.
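For readers who want to reproduce the lookup, here is a small sketch using NLTK’s WordNet interface (our own choice of tooling, not necessarily the paper’s; the exact synset returned depends on which word senses you pick):

```python
# Requires: import nltk; nltk.download('wordnet')
from nltk.corpus import wordnet as wn

cat, dog = wn.synset("cat.n.01"), wn.synset("dog.n.01")

# Most specific shared ancestor(s) of the two senses; the result sits at a
# carnivore/mammal level of the hierarchy, depending on the senses chosen.
print(cat.lowest_common_hypernyms(dog))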
The training and inference process for a Neural-Backed Decision Tree can be broken down into four steps.
Construct a hierarchy for the decision tree. This hierarchy determines which sets of classes the NBDT must decide between. We refer to this hierarchy as an Induced Hierarchy.
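Concretely, the paper induces this hierarchy from a pretrained network’s final fully-connected layer: each class’s weight row serves as a class representative, and the rows are agglomeratively clustered into a tree. The sketch below uses placeholder weights and SciPy’s clustering utilities, as one plausible way to carry this out:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, to_tree

# Placeholder for the pretrained network's final layer weights (model.fc.weight):
# one row per class. Here: 10 classes, 512-dimensional features.
fc_weights = np.random.randn(10, 512)
fc_weights /= np.linalg.norm(fc_weights, axis=1, keepdims=True)  # normalize rows

# Agglomeratively cluster the class representatives into a binary tree whose
# leaves are the classes; internal nodes become the NBDT's intermediate decisions.
Z = linkage(fc_weights, method="ward")
root = to_tree(Z)
print(root.get_count())  # 10 leaves, one per class
```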
This hierarchy yields a loss function, which we call the Tree Supervision Loss. Train the original neural network, without any modifications, using this new loss.
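A simplified sketch of the idea follows; the hierarchy encoding and helper below are our own illustration, not the released implementation. Keep the ordinary cross-entropy, and add a cross-entropy term at each internal node over that node’s children, where a child’s logit is the average of the logits of the leaves routed to it.

```python
import torch
import torch.nn.functional as F

def tree_supervision_loss(logits, target, nodes, weight=1.0):
    """Standard cross-entropy plus a per-node term over each internal node's children.

    logits: (batch, num_classes) from the unmodified network.
    target: (batch,) ground-truth class indices.
    nodes:  each internal node lists the leaf classes under each of its children,
            e.g. {"children": [[0, 1], [2, 3]]} for a 4-class split.
    """
    loss = F.cross_entropy(logits, target)  # original objective, unchanged
    for node in nodes:
        # A child's logit is the mean of the logits of the leaves routed to it.
        child_logits = torch.stack(
            [logits[:, leaves].mean(dim=1) for leaves in node["children"]], dim=1)
        # The "correct" child is the one containing each sample's true class.
        leaf_to_child = {leaf: i for i, leaves in enumerate(node["children"])
                         for leaf in leaves}
        child_target = torch.tensor([leaf_to_child[int(t)] for t in target],
                                    device=logits.device)
        loss = loss + weight * F.cross_entropy(child_logits, child_target)
    return loss
```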
Start inference by passing the sample through the neural network backbone. The backbone is all neural network layers before the final fully-connected layer.
Finish inference by running the final fully-connected layer as a sequence of decision rules, which we call Embedded Decision Rules. These decisions culminate in the final prediction.
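Putting the last two steps together, here is a hedged sketch of hard inference; the tree encoding and helper names are illustrative, not the released API. A node’s representative is the mean of the final-layer weight rows of the leaves under it, so each decision is just an inner product between the backbone features and the child representatives.

```python
import torch

# Toy hierarchy over 4 classes {0: cat, 1: dog, 2: car, 3: plane}; each internal
# node names its children, each leaf stores its class index.
tree = {"name": "root", "children": [
    {"name": "animal",  "children": [{"leaf": 0}, {"leaf": 1}]},
    {"name": "vehicle", "children": [{"leaf": 2}, {"leaf": 3}]},
]}

def leaves(node):
    """All leaf class indices under a node."""
    return [node["leaf"]] if "leaf" in node else \
           [l for child in node["children"] for l in leaves(child)]

def hard_inference(features, fc_weight, node, path=()):
    """Walk the tree: at each node, score children by <features, child representative>."""
    if "leaf" in node:
        return list(path) + [node["leaf"]]
    # A child's representative is the mean of its leaves' fully-connected weight rows.
    scores = [features @ fc_weight[leaves(child)].mean(dim=0)
              for child in node["children"]]
    best = int(torch.stack(scores).argmax())
    return hard_inference(features, fc_weight, node["children"][best],
                          path + (node["name"],))

# Usage with placeholders: features = backbone(image).flatten(); fc_weight = model.fc.weight
features, fc_weight = torch.randn(512), torch.randn(4, 512)
print(hard_inference(features, fc_weight, tree))  # e.g. ['root', 'vehicle', 3]
```

Soft inference differs only in how the decisions are aggregated: instead of committing to the best child at each node, it multiplies the child probabilities along every root-to-leaf path and predicts the leaf with the highest path probability.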
Explainable AI does not fully explain how the neural network reaches a prediction: Existing methods explain the image’s impact on model predictions but do not explain the decision process. Decision trees address this, but unfortunately, images⁷ are kryptonite for decision tree accuracy.