The teller is a model-agnostic tool for Machine Learning (ML) explainability. For more details on this tool, you can visit the following introductory links: here, here, here or here.

In this post, we use teller to examine model interactions, again on the Boston Housing dataset. Here, an interaction means: how does the response variable (the variable to be explained) change when explanatory variable 1 and explanatory variable 2 both increase by 1? On this dataset, for example, it means understanding how the median value of owner-occupied homes (in $1000’s) changes when the index of accessibility to radial highways and the average number of rooms per dwelling both increase by 1 simultaneously.
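
To make this definition concrete, here is a minimal sketch of a mixed finite difference: the part of the change in a prediction that appears only when two features move together, once their separate effects are subtracted. The names mixed_difference and toy_predict, the step size and the toy model are illustrative assumptions. This is one reading of the definition above, not teller's implementation, whose internals are not described in this post.

```python
import numpy as np


def mixed_difference(predict, X, i, j, step=1.0):
    """Per-observation interaction between features i and j (hypothetical helper).

    Computes [f(x + step*e_i + step*e_j) - f(x + step*e_i)
              - f(x + step*e_j) + f(x)] / step**2 for every row x of X.
    """
    X = np.asarray(X, dtype=float)
    X_i = X.copy()
    X_i[:, i] += step      # only feature i moves
    X_j = X.copy()
    X_j[:, j] += step      # only feature j moves
    X_ij = X_i.copy()
    X_ij[:, j] += step     # both features move
    return (predict(X_ij) - predict(X_i) - predict(X_j) + predict(X)) / step**2


def toy_predict(X):
    """Toy model (assumption): y = 2*x0 + 3*x0*x1, so the interaction is 3."""
    return 2.0 * X[:, 0] + 3.0 * X[:, 0] * X[:, 1]


rng = np.random.default_rng(123)
X_demo = rng.normal(size=(5, 2))
effects = mixed_difference(toy_predict, X_demo, 0, 1)
print(effects.mean())
```

Any fitted model's predict method could be passed in place of toy_predict. The additive term 2*x0 cancels out, so only the joint effect of the two features remains.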

Install the package and import data

Currently, the teller’s development version can be installed from GitHub with:

!pip install git+https://github.com/Techtonique/teller.git

Model training and explanations

import numpy as np
import pandas as pd
import teller as tr

from sklearn import datasets
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split


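# import data
# note: load_boston was removed in scikit-learn 1.2, so this call needs an older version
boston = datasets.load_boston()
X = np.delete(boston.data, 11, 1)  # drop column 11 ("B")
y = boston.target
# remaining feature names, followed by the response name
col_names = np.append(np.delete(boston.feature_names, 11), 'MEDV')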


# split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, 
                                                    random_state=123)
print(X_train.shape)
print(X_test.shape)


# fit a random forest regression model
regr = RandomForestRegressor(n_estimators=1000, random_state=123)
regr.fit(X_train, y_train)


# creating the explainer
expr = tr.Explainer(obj=regr)

Interactions with index of accessibility to radial highways (RAD):

varx = "RAD"
expr.fit(X_test, y_test, X_names=col_names[:-1], 
         y_name=col_names[-1], 
         col_inters = varx, method="inters")
print(expr.summary())
Interactions with RAD: 
            Estimate   Std. Error   95% lbound   95% ubound  Pr(>|t|)     
CRIM               0  2.22045e-16  4.40477e-16 -4.40477e-16         1    -
ZN      -3.37992e-11  2.22045e-16 -3.37988e-11 -3.37997e-11         0  ***
INDUS   -7.06339e-11  4.33823e-11  1.54248e-11 -1.56693e-10  0.106603    -
CHAS               0  2.22045e-16  4.40477e-16 -4.40477e-16         1    -
NOX     -3.80587e-10  4.33823e-10  4.80001e-10 -1.24117e-09  0.382413    -
RM           4.05396  3.23892e-11      4.05396      4.05396         0  ***
AGE      4.08624e-12  3.47213e-12   1.0974e-11 -2.80153e-12  0.242015    -
DIS                0  2.22045e-16  4.40477e-16 -4.40477e-16         1    -
RAD                0  2.22045e-16  4.40477e-16 -4.40477e-16         1    -
TAX        -0.032699  2.22045e-16    -0.032699    -0.032699         0  ***
PTRATIO    -0.889095     0.889095      0.87463     -2.65282    0.3197    -
LSTAT    5.03253e-11  2.22045e-16  5.03258e-11  5.03249e-11         0  ***

Interactions with average number of rooms per dwelling (RM):

varx = "RM"
expr.fit(X_test, y_test, X_names=col_names[:-1], 
         y_name=col_names[-1], 
         col_inters = varx, method="inters")
print(expr.summary())
Interactions with RM: 
            Estimate   Std. Error   95% lbound   95% ubound      Pr(>|t|)     
CRIM               0  2.22045e-16  4.40477e-16 -4.40477e-16             1    -
ZN      -7.41298e-12  2.22045e-16 -7.41254e-12 -7.41342e-12             0  ***
INDUS   -3.79269e-11    1.963e-11  1.01374e-12 -7.68675e-11     0.0561504    .
CHAS               0  2.22045e-16  4.40477e-16 -4.40477e-16             1    -
NOX     -3.13992e-10  2.12499e-10  1.07549e-10 -7.35532e-10      0.142622    -
RM                 0  2.22045e-16  4.40477e-16 -4.40477e-16             1    -
AGE        -0.521402  1.65828e-12    -0.521402    -0.521402             0  ***
DIS       5.2504e-11  2.22045e-16  5.25044e-11  5.25036e-11             0  ***
RAD          4.05396  3.23892e-11      4.05396      4.05396             0  ***
TAX      9.58314e-13  2.22045e-16  9.58755e-13  9.57874e-13  9.44012e-268  ***
PTRATIO  6.77272e-12  7.06388e-12  2.07856e-11 -7.24011e-12      0.339958    -
LSTAT        2.70129  3.97481e-11      2.70129      2.70129             0  ***
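
In both tables, the RM/RAD pair has the same estimate (4.05396), whether RAD or RM is chosen as the reference variable. The sketch below shows one plausible way to turn per-observation interaction effects into the Estimate, Std. Error and 95% bound columns. It is an assumption for illustration only: the helper summarize_effects, the made-up input values and the normal approximation (z = 1.96) are not taken from teller, which reports a t-test.

```python
import math

import numpy as np


def summarize_effects(effects, z=1.96):
    """Mean effect, its standard error and an approximate 95% interval.

    Hypothetical helper: uses a normal approximation rather than a t-test.
    """
    effects = np.asarray(effects, dtype=float)
    n = effects.size
    estimate = effects.mean()
    std_error = effects.std(ddof=1) / math.sqrt(n)  # sample std / sqrt(n)
    return {
        "estimate": estimate,
        "std_error": std_error,
        "lower": estimate - z * std_error,
        "upper": estimate + z * std_error,
    }


# made-up per-observation effects, for illustration only
demo_effects = np.array([3.8, 4.1, 4.3, 3.9, 4.2])
summary = summarize_effects(demo_effects)
print(summary)
```

An interval that excludes 0 points to an interaction the model treats as non-negligible on the examined data. A standard error close to machine precision means the per-observation effects are almost identical.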

A notebook containing these results can be found here. Contributions and remarks are welcome as usual; you can submit a pull request on GitHub.

Note: I am currently looking for a gig. You can hire me on Malt or send me an email: thierry dot moudiki at pm dot me. I can do descriptive statistics, data preparation, feature engineering, model calibration, training and validation, and model outputs’ interpretation. I am fluent in Python, R, SQL, Microsoft Excel, Visual Basic (among others) and French. My résumé? Here!
