March 27, 2023

A complete guide on Computer Vision XAI libraries

Introduction

An increasing prerequisite for applying a machine learning model is confirming its validity using explainable AI (XAI) techniques. Writing the code responsible for this part of a project is time-consuming and complicated. Fortunately for machine learning engineers, there are several open-source libraries whose main purpose is to explain trained models using numerous algorithms.

This post will look at the most popular libraries dedicated to computer vision and identify their advantages and disadvantages. 

Methodology

I used four criteria to select the libraries. First, the library must be written for the Python language. Python is the most popular language for machine learning, and its ecosystem of neural network libraries is used by a large community of engineers and researchers. Second, it must relate to computer vision tasks. Third, I took into account the libraries most popular in terms of the number of GitHub stars. The last criterion was active maintenance of the repository: the most recent commit should be at most one year old at the time of writing this article. The exception here is the M3d-CAM library, whose last commit dates back almost two years.

List of described repositories sorted by descending number of GitHub ★:

  1. pytorch-grad-cam
  2. captum
  3. pytorch-cnn-visualizations
  4. torch-cam
  5. innvestigate
  6. tf-explain
  7. OmniXAI
  8. xplique
  9. M3d-Cam

Have your coffee ready, as this will be a blog post of the longer kind. Enjoy the reading!

If you want to see the library comparison quickly, check the tables summarizing their features in the next section of the article.

Comparison

[Comparison tables: repository features, algorithms, metrics]

pytorch-grad-cam 6.7k★

source: https://github.com/jacobgil/pytorch-grad-cam

This repository focuses mainly on gradient-based class activation map (CAM) algorithms for PyTorch computer vision models. It contains 11 different algorithms and has been tested on distinct architectures: CNNs and Vision Transformers. According to the authors, it can be used for multiple computer vision tasks, such as classification, object detection, semantic segmentation, and many more. An attractive feature is attribute heatmap smoothing, which gives the user a smoother image. Another advantage is that the repository contains trust metrics for the model's explanations.

Algorithms

The library contains attribute methods at the level of the entire model, at the level of a single layer, and the deep feature factorization method. The repository has metrics for XAI methods. 
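To illustrate the core idea behind the CAM family of methods this library implements, here is a minimal pure-Python sketch of the Grad-CAM weighting step. This is not the library's actual API; the tiny 2×2 feature maps and gradients are toy values, assumed only for the example:

```python
# Minimal Grad-CAM sketch: weight each feature map by the mean of its
# gradients, sum the weighted maps, apply ReLU, then normalize to [0, 1].

def grad_cam(feature_maps, gradients):
    # Channel weights: global-average-pool the gradients per channel.
    weights = [sum(sum(row) for row in g) / (len(g) * len(g[0]))
               for g in gradients]
    h, w = len(feature_maps[0]), len(feature_maps[0][0])
    # Weighted sum of feature maps, followed by ReLU.
    cam = [[max(0.0, sum(wk * fm[i][j] for wk, fm in zip(weights, feature_maps)))
            for j in range(w)] for i in range(h)]
    # Normalize to [0, 1] so the result can be rendered as a heatmap.
    peak = max(max(row) for row in cam) or 1.0
    return [[v / peak for v in row] for row in cam]

# Toy 2x2 feature maps stand in for a real convolutional layer's output.
feature_maps = [[[1.0, 0.0], [0.0, 2.0]], [[0.0, 1.0], [1.0, 0.0]]]
gradients    = [[[0.5, 0.5], [0.5, 0.5]], [[-1.0, -1.0], [-1.0, -1.0]]]
heatmap = grad_cam(feature_maps, gradients)
```

Real implementations operate on full-resolution tensors and then upsample the CAM to the input image size, but the weighting logic is the same.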

Visualization

In the methods used for visualization, the developers focused on overlaying the heatmap on top of the original image. There are also examples of visualization of the heatmap attributes for semantic segmentation and object detection tasks.
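The overlay idea itself is simple alpha compositing. The sketch below is a conceptual illustration, not pytorch-grad-cam's actual code; the grayscale pixel values are hypothetical:

```python
def overlay_heatmap(image, heatmap, alpha=0.5):
    """Blend a normalized heatmap (values in [0, 1]) onto a grayscale image.

    Each output pixel is a convex combination of the original intensity
    and the heatmap value scaled to the 0-255 range.
    """
    return [[round((1 - alpha) * image[i][j] + alpha * heatmap[i][j] * 255)
             for j in range(len(image[0]))] for i in range(len(image))]

image   = [[100, 200], [50, 0]]
heatmap = [[0.0, 1.0], [0.5, 0.0]]
blended = overlay_heatmap(image, heatmap, alpha=0.5)
```

In practice the heatmap is first mapped through a colormap (e.g. jet) before blending, which is what produces the familiar red-to-blue attribution images.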

Documentation

The library contains quite detailed documentation with numerous examples and tutorials. Example results of the algorithms are presented, which allows you to gain a general understanding of what the user can expect. You can find examples here. However, there is a lack of a standard API description where the user can check the method’s parameters.

Captum 3.7k★

source: https://github.com/pytorch/captum/blob/master/website/static/img/captum_insights_screenshot.png

In their own words, the library's creators write about their project: "Captum is a model interpretability and understanding library for PyTorch." This sentence sums up what this library is for. It implements its algorithms in a task-independent way, so they can be used for NLP, computer vision, and other tasks.

Algorithms

The library contains attribute methods at the level of the entire model, at the level of a single layer and a single neuron. In addition, it includes perturbation and noise tunnel methods to calculate the model’s attributes.
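The noise tunnel idea (SmoothGrad and its variants) can be sketched independently of Captum's API: compute attributions for several noisy copies of the input and average them, which smooths out noisy gradients. The `attribution_fn` below is a hypothetical stand-in for a real gradient-based attribution method:

```python
import random

def noise_tunnel(attribution_fn, x, n_samples=8, stdev=0.1, seed=0):
    """Average attributions over noisy copies of the input (SmoothGrad idea)."""
    rng = random.Random(seed)
    total = [0.0] * len(x)
    for _ in range(n_samples):
        # Perturb the input with Gaussian noise, then attribute.
        noisy = [v + rng.gauss(0.0, stdev) for v in x]
        attr = attribution_fn(noisy)
        total = [t + a for t, a in zip(total, attr)]
    return [t / n_samples for t in total]

# Hypothetical attribution: the gradient of f(x) = sum(x_i ** 2) is 2 * x_i.
smoothed = noise_tunnel(lambda x: [2 * v for v in x], [1.0, -1.0])
```

With small noise, the averaged attribution stays close to the clean one while suppressing gradient artifacts on real networks.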

Visualization

The library's authors describe a widget for interactive visualization of model explanations called "Captum Insights". Unfortunately, I was not able to run the example they provided after installing the package, both via PyPI and via the `setup.py` script from a cloned repository.

In the methods used for visualization, the developers focused on visualizing attributes next to the original image instead of overlaying the heatmap on top of the original image.

Documentation

The authors took care to provide clear documentation and examples of how to use the library. The first tab presents details of each algorithm, its description, and a summary comparison of the algorithms. The tab with tutorials contains a rather rich collection of examples of using the library for different domains (CV, NLP) and tasks (image classification and semantic segmentation). The examples cover attribute explanation methods and techniques, such as adversarial attacks, influential examples, and concept-based explanations. The library also includes several metrics and methods to check the robustness of the model. And there is also an API reference.

pytorch-cnn-visualizations 3.7k★

source: https://github.com/utkuozbulak/pytorch-cnn-visualizations

This repository contains a set of examples demonstrating the capabilities of different explainable algorithms for CNN networks. Importantly, this is not a library: if you would like to use it in your project, you would have to write, based on the provided code, your own functions and classes performing similar operations. Another issue is the dependencies, which are quite archaic: torch version 0.4.1, and so on.

Interface

No API, just a bunch of scripts with examples.

Algorithms

The library contains attribute methods at the level of the entire model and the level of a single layer. Additionally, the repository contains algorithms for feature visualization.

Visualization

In the scripts used as examples, the artifacts are stored as attribute images or heat maps overlaid on the original image.

Documentation

No documentation.

torch-cam 1.2k★

source: https://github.com/frgfm/torch-cam

Another library focused on CAM methods for computer vision. It is easy to use and implements 9 different algorithms.

Algorithms

The library contains attribute methods at the level of the entire model and the level of a single layer.

Visualization

There is one method to visualize the attributes of the model. It is a function to display an attribute heatmap overlaying the original image.

Documentation

The repository has documentation published using Github Pages. The documentation includes instructions for installing the library, code examples to run explanatory algorithms, tutorials, and a description of the API.

innvestigate 1.1k★

source: https://github.com/albermax/innvestigate/blob/master/examples/images/analysis_grid.png

It is a library designed to facilitate the analysis of machine learning models for Keras/TensorFlow2.

Algorithms

The library contains attribute methods at the level of the entire model. They can be used only for image classification tasks. In addition to the algorithms found in almost every library presented here, it contains multiple variants of the LRP algorithm.

Visualization

There are a few methods to visualize the model's attributes, including a heatmap.

Documentation

The repository has published only API reference but contains multiple tutorials in the `examples` directory.

tf-explain 964★

source: https://github.com/sicara/tf-explain

A library implementing interpretability methods for TensorFlow 2 as trainer callbacks. It has 7 explanation algorithms implemented. The documentation includes convenient examples of how to use the explainers, along with the resulting attribution images. What sets this library apart from the others is its integration with the training process and the ability to save artifacts of the explanation algorithms in a training monitoring tool, in this case TensorBoard. This feature can be beneficial for diagnosing problems during network training.

Interface

In this case, the library also defines a simple, unified, and parameterizable interface to all explanation algorithms and metrics.

A second interface is also defined for integration with the training process. Each class implementing the explanation algorithm is used as a callback (inherits from `tensorflow.keras.callbacks.Callback`) to the trainer.
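The callback integration pattern can be sketched without TensorFlow: a trainer invokes each callback at the end of every epoch, and an explainer callback would run its algorithm there and log the heatmap. The class and method names below are illustrative, not tf-explain's actual API:

```python
class Callback:
    """Minimal stand-in for a Keras-style callback base class."""
    def on_epoch_end(self, epoch, logs):
        pass

class ExplainerCallback(Callback):
    """Hypothetical callback that records when an explanation was produced."""
    def __init__(self):
        self.explained_epochs = []

    def on_epoch_end(self, epoch, logs):
        # A real callback would run e.g. Grad-CAM on a validation batch here
        # and write the heatmap to TensorBoard.
        self.explained_epochs.append(epoch)

def train(n_epochs, callbacks):
    """Toy training loop that only exercises the callback hook."""
    for epoch in range(n_epochs):
        logs = {"loss": 1.0 / (epoch + 1)}  # dummy training step
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)

explainer = ExplainerCallback()
train(3, [explainer])
```

Because the explainer inherits the trainer's callback interface, it plugs into an existing training loop without any changes to the training code itself.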

Algorithms

The library contains attribute methods at the level of the entire model. They can be used only for image classification tasks.

Visualization

There are two methods to visualize the model's attributes: as an attribute image or as an attributes heatmap overlaying the original image.

Documentation

It has documentation published with Read the Docs, with basic usage examples and a core API description, but unfortunately it has bugs and the API reference section is empty.

OmniXAI 517★

source: https://github.com/salesforce/OmniXAI/blob/main/docs/_static/ml_pipeline.png

This library provides XAI algorithms to explain machine learning models written in PyTorch and TensorFlow. It works for tabular data, text, images, and time series on classification and regression tasks.

Interface

In this case, the API is more complicated. The library provides several wrapper classes for handling different data types, such as tabular data, text, images, etc. Each explainer class for a given data type accepts plenty of parameters: a list of explainer names, the task type, a preprocessing function, and additional parameters.
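The wrapper-per-data-type design can be sketched generically: a factory maps a data type plus a list of explainer names onto explainer functions, applying the user's preprocessing first. The class, registry, and explainer names below are illustrative, not OmniXAI's actual API:

```python
# Hypothetical registry mapping (data_type, explainer_name) -> function.
REGISTRY = {
    ("image", "gradcam"): lambda x: f"gradcam({x})",
    ("image", "lime"):    lambda x: f"lime({x})",
    ("tabular", "shap"):  lambda x: f"shap({x})",
}

class Explainer:
    """Dispatch several named explainers for one data type."""
    def __init__(self, data_type, explainers, preprocess=lambda x: x):
        self.fns = [REGISTRY[(data_type, name)] for name in explainers]
        self.preprocess = preprocess

    def explain(self, x):
        x = self.preprocess(x)        # shared preprocessing step
        return [fn(x) for fn in self.fns]

exp = Explainer("image", ["gradcam", "lime"], preprocess=str.lower)
results = exp.explain("IMG")
```

The upside of this design is that one call can produce explanations from several algorithms at once; the downside, as noted above, is a constructor with many parameters.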

Algorithms

The repository contains explanatory algorithms for many types of data, but in this blog post, we focus only on computer vision algorithms. In this domain, we can find 8 different algorithms.

Visualization

The repository contains a Dash app for interactive visualization of explanations. The authors based the visualizations on the `plotly` package. They support displaying attributes both as a standalone image and as an overlay on the original image.

Documentation

The repository has documentation published using ReadtheDocs. The documentation includes a guide through library design, architecture, installation process, examples, and API reference.

Xplique 348★

source: https://github.com/deel-ai/xplique

This is an option for projects using the TensorFlow framework. The words of the library's authors describe well what they wanted to achieve: "Xplique (...) is a Python language toolkit dedicated to explainability, currently based on Tensorflow. The goal of this library is to bring together the state of the art in Explainable AI to help you understand your complex neural network models." In addition to feature attributes and concept-based explanations, the library also includes functionality for visualizing features learned by the model.

Algorithms

The library contains attribute methods at the level of the entire model and the level of a single layer. In addition, it includes metrics estimating how much we can trust the methods that explain the model's decisions.
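A common family of such trust metrics is deletion-style fidelity: zero out the most-attributed features first and watch how quickly the model's score drops; a faithful attribution should cause a steep early drop. The sketch below uses a toy linear "model" and is a conceptual illustration, not Xplique's actual implementation:

```python
def deletion_scores(model, x, attribution):
    """Score the input as the most-attributed features are zeroed one by one.

    Returns the model score before any deletion and after each step.
    """
    # Delete features in order of decreasing attribution.
    order = sorted(range(len(x)), key=lambda i: attribution[i], reverse=True)
    current = list(x)
    scores = [model(current)]
    for i in order:
        current[i] = 0.0
        scores.append(model(current))
    return scores

# Toy linear model: the score is a weighted sum of the features.
weights = [3.0, 1.0, 2.0]
model = lambda x: sum(w * v for w, v in zip(weights, x))
x = [1.0, 1.0, 1.0]
attribution = [3.0, 1.0, 2.0]  # a perfect attribution for this model
scores = deletion_scores(model, x, attribution)
```

The area under this score curve (lower is better for deletion) gives a single number for comparing attribution methods on the same model.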

Visualization

In the methods used for visualization, the developers focused on overlaying the heatmap on top of the original image.

Documentation

The authors provided clear documentation and plenty of examples of how to use the library. The developers focused on a detailed description of the API, which includes information on how the various algorithms and metrics work.

A quite comprehensive collection of interactive examples in Google Colaboratory deserves appreciation. Good job! 👏 The examples introduce the use of the library for visualizing attributes, features, concepts, and calculating metrics. What's more, each metric and method has its dedicated example, in which creators describe the algorithm itself, the API and possible parameters, FAQs, and share insights on how each parameter affects the algorithm's performance. This is condensed knowledge of the details of how to use the algorithm in the right way.

M3d-Cam 216★

source: https://github.com/MECLabTUDA/M3d-Cam

This is a PyTorch library that explains a model's decisions on 3D/2D medical images for both classification and segmentation tasks.

Algorithms

The library contains 4 algorithms: guided backpropagation, GradCAM, Guided GradCAM, and GradCAM++.

Visualization

There are two methods to visualize the model's attributes: as an attribute image or as an attributes heatmap overlaying the original image.

Documentation

The repository has published documentation, but it includes only the API reference. The README contains links to a few tutorials on Google Colab.

Summary

The libraries mentioned and described vary greatly. Some are general-purpose, while others specialize in one type of task, algorithm, or data type. Currently, there is no universal library that contains all the algorithms and metrics at once, so when choosing a library, you need to decide which features your current project requires. This also indicates that there is room for new libraries to fill the gap.

Few libraries provide metrics that determine the degree to which we can trust explanatory algorithms. There is a lack of integration with experiment trackers, which would allow users to monitor the training process of the network. Moreover, almost all libraries are limited to the simple case of image classification. Support for the segmentation task is lacking, and support for the object detection task is severely limited.