library(tabnet)
library(dplyr)
library(purrr)
library(rsample)
library(yardstick)
library(ggplot2)
library(patchwork)
set.seed(202402)Interprestability is a lightweight evolution of Tabnet network design that provides, among other, a stability score of the interpretation mask provided through the tabnet_explain() function.
In this vignette, we will improve on the workflow for the Ames dataset introduced in the “Training a Tabnet model from missing-values dataset” vignette.
The Interprestability score associated with tabnet_explain() results will help us select more stable models.
The Interprestability score is a metric of mask stability between models: a score above 0.9 indicates very high stability, between 0.7 and 0.9 high stability, between 0.5 and 0.7 moderate stability, and between 0.3 and 0.5 low stability of the interpretation of the model.
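These bands can be wrapped in a small helper for readability. This is a hypothetical convenience function, not part of {tabnet}, and the “very low” label for scores under 0.3 is our own assumption:
interprestability_band <- function(score) {
  # Map a score to the stability bands described above; the "very low"
  # label below 0.3 is an assumption, not documented behavior.
  cut(score,
      breaks = c(-Inf, 0.3, 0.5, 0.7, 0.9, Inf),
      labels = c("very low", "low", "moderate", "high", "very high"))
}
interprestability_band(0.85)
#> [1] high
#> Levels: very low low moderate high very high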
The {tabnet} implementation compares the explainability parameters between the last 5 model checkpoints, so it is up to you to make those last 5 checkpoints a good proxy of the final model.
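The exact computation is internal to {tabnet}, but you can picture it as an average pairwise similarity between the per-checkpoint feature-importance vectors. Here is a minimal conceptual sketch of that idea, assuming cosine similarity (our assumption, not the package’s actual formula):
mask_stability <- function(importance_list) {
  # importance_list: one numeric feature-importance vector per checkpoint.
  # Average the cosine similarity over all pairs of checkpoints
  # (illustrative only; the real {tabnet} computation may differ).
  pairs <- combn(length(importance_list), 2)
  sims <- apply(pairs, 2, function(ij) {
    a <- importance_list[[ij[1]]]
    b <- importance_list[[ij[2]]]
    sum(a * b) / (sqrt(sum(a^2)) * sqrt(sum(b^2)))
  })
  mean(sims)
}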
Let’s experiment with this in a training scenario on the ames dataset:
We will work here with the ames_missing dataset, a transformation of the ames dataset built in the “Training a Tabnet model from missing-values dataset” vignette.
data("ames_missing", package = "tabnet")
ames_split <- initial_split(ames_missing, strata = Sale_Price, prop = 0.8)
ames_train <- training(ames_split)
ames_test <- testing(ames_split)
cat_emb_dim <- map_dbl(
  ames_train %>% select_if(is.factor),
  ~ max(floor(log(nlevels(.x))), 1)
)
tabnet_config <- tabnet_config(
  cat_emb_dim = cat_emb_dim,
  verbose = FALSE,
  early_stopping_patience = 12L,
  early_stopping_tolerance = 0.01,
  valid_split = 0.2
)
train_tabnet <- list(tabnet_fit(
  Sale_Price ~ ., data = ames_train,
  epochs = 100, checkpoint_epoch = 101,
  config = tabnet_config, learn_rate = 5e-2
))
InterpretabNet models differ from plain TabNet models in three ways:
- a 3-layer MLP is added as an adaptation layer in between the TabNet steps,
- the encoder layer uses a Multibranch Weighted Linear-Unit, implemented as nn_mb_wlu(),
- the mask type switches to entmax.
You don’t have to set these yourself, as the handy interpretabnet_config() function comes as a replacement for tabnet_config() and switches them all at once.
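As an aside, if you only wanted the entmax masks on a plain TabNet, that single switch is also available through the standard configuration (we assume here that the other two changes are only exposed via interpretabnet_config()):
# Entmax masks alone on a regular TabNet; the MLP adaptation layer and
# the nn_mb_wlu() activation are left to interpretabnet_config().
entmax_config <- tabnet_config(mask_type = "entmax")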
For the training loop, you will notice that InterpretabNet models require many more epochs to converge, due to the MLP network.
interpretabnet_config <- interpretabnet_config(
  cat_emb_dim = cat_emb_dim,
  verbose = FALSE,
  early_stopping_patience = 12L,
  early_stopping_tolerance = 0.01,
  valid_split = 0.2,
  learn_rate = 5e-2,
  lr_scheduler = "step",
  lr_decay = 0.7,
  step_size = 5
)
train_tabnet <- c(
  train_tabnet,
  map(1:3,
      ~ tabnet_fit(Sale_Price ~ ., data = ames_train,
                   epochs = 150, checkpoint_epoch = 151,
                   config = interpretabnet_config),
      .progress = TRUE)
)
autoplot(train_tabnet[[1]]) +
autoplot(train_tabnet[[2]]) +
autoplot(train_tabnet[[3]]) +
autoplot(train_tabnet[[4]]) +
plot_layout(axes = "collect", guides = "collect") &
theme(legend.position = "bottom")
With a small learning rate, we will extend the fitted models for a few epochs, with a checkpoint at every epoch, in order to measure the Interprestability.
models_checkpointed <- map(
  train_tabnet,
  ~ tabnet_fit(Sale_Price ~ ., data = ames_train,
               tabnet_model = .x, epochs = 6, valid_split = 0.2,
               checkpoint_epoch = 1, learn_rate = 1e-2),
  .progress = TRUE
)
We can now have a closer look at their training-convergence plots:
autoplot(models_checkpointed[[1]]) +
autoplot(models_checkpointed[[2]]) +
autoplot(models_checkpointed[[3]]) +
autoplot(models_checkpointed[[4]]) +
plot_layout(axes = "collect", guides = "collect") &
theme(legend.position = "bottom")
Each tabnet_explain() result now carries the Interprestability score of its model:
explain_lst <- map(models_checkpointed, tabnet_explain, new_data = ames_train)
interprestability <- map_dbl(explain_lst, "interprestability")
interprestability
#> [1] 0.9947170 0.9948593 0.9942322 0.9847693
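All four models land in the very-high stability band here. Since the goal is to select more stable models, a natural follow-up is to pick the highest-scoring one and check its accuracy on the held-out test set. A sketch using the already-loaded {yardstick}, assuming the usual .pred column returned by predict() on {tabnet} regression models:
# Pick the model with the most stable interpretation masks...
best_id <- which.max(interprestability)
# ...and evaluate it on the held-out test set.
bind_cols(
  predict(models_checkpointed[[best_id]], ames_test),
  ames_test %>% select(Sale_Price)
) %>%
  rmse(truth = Sale_Price, estimate = .pred)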
autoplot(explain_lst[[1]], quantile = 0.99) +
autoplot(explain_lst[[2]], quantile = 0.99) +
autoplot(explain_lst[[3]], quantile = 0.99) +
autoplot(explain_lst[[4]], quantile = 0.99) +
plot_layout(axes = "collect", guides = "collect") &
theme(legend.position = "bottom")
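Beyond the plots, the raw mask of the most stable model can be aggregated into a single per-feature importance. A short sketch, assuming the M_explain element of the explain object holds the per-observation importances (reusing best_id from above):
# Average each feature's mask value over observations and keep the
# 10 dominant features of the most stable model.
best_mask <- explain_lst[[best_id]]$M_explain
sort(colMeans(as.matrix(best_mask)), decreasing = TRUE)[1:10]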