Forest Embeddings Counterfactual


In earlier posts we explored the problem of estimating counterfactual outcomes, one of the central problems in causal inference, and learned that, with a few tweaks, simple decision trees can be a great tool for solving it. In this post, I’ll walk you through the usage of ForestEmbeddingsCounterfactual, one of the main models in the cfml_tools module, and show that it solves the toy causal inference problem from the fklearn library. You can find the full code for this example here.

Data: make_confounded_data from fklearn

Nubank’s fklearn module provides a nice causal inference problem generator, so we’re going to use the same data generating process and example from its documentation.

# getting confounded data from fklearn
from fklearn.data.datasets import make_confounded_data
df_rnd, df_obs, df_cf = make_confounded_data(50000)
print(df_to_markdown(df_obs.head(5)))
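
| sex | age | severity | medication | recovery |
| --- | --- | --- | --- | --- |
| 0 | 34 | 0.7 | 1 | 126 |
| 1 | 24 | 0.72 | 1 | 123 |
| 1 | 38 | 0.86 | 1 | 255 |
| 1 | 35 | 0.77 | 1 | 227 |
| 0 | 22 | 0.078 | 0 | 15 |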
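
The snippets in this post call df_to_markdown, which is not defined here. Assuming it is just a small helper that renders a dataframe as a markdown table, a minimal sketch could look like the one below. The implementation and the two-significant-digit float format are guesses chosen to resemble the printed tables, not the original helper.

```python
import pandas as pd


def df_to_markdown(df: pd.DataFrame, float_format: str = "{:.2g}") -> str:
    """Render a DataFrame as a pipe-delimited markdown table (index omitted)."""

    def fmt(value) -> str:
        # Format floats compactly; leave everything else as plain text.
        if isinstance(value, float):
            return float_format.format(value)
        return str(value)

    header = "| " + " | ".join(str(col) for col in df.columns) + " |"
    divider = "|" + "|".join(" --- " for _ in df.columns) + "|"
    rows = [
        "| " + " | ".join(fmt(v) for v in row) + " |"
        for row in df.itertuples(index=False, name=None)
    ]
    return "\n".join([header, divider] + rows)


# Small hypothetical frame, just to exercise the helper.
example = pd.DataFrame({"sex": [0, 1], "age": [34, 24], "severity": [0.7, 0.72]})
print(df_to_markdown(example))
```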

We have five columns: sex, age, severity, medication and recovery. We want to estimate the impact of medication on recovery. So, our target variable is recovery, our treatment variable is medication and the rest are our explanatory variables.

A good counterfactual model will tell us what the recovery time would be for each individual under both decisions: taking or not taking the medication. The model should be robust to confounders, variables that affect the probability of someone taking the medication or the effect of taking it. For instance, people with higher severity may be more likely to take the medicine. If not properly taken into account, this confounder may lead us to conclude that the medication makes recovery worse: people who took the medication may have worse recovery times, but only because their condition was already more severe. fklearn’s documentation shows the data generating process in detail, highlighting the confounders in the data. The effect we’re looking for is $\exp(-1) \approx 0.368$.
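
To make the confounding argument concrete, here is a small simulation. It uses a toy generator written for this illustration (it is not fklearn’s data generating process): sicker people are more likely to be medicated, and medication multiplies the expected recovery time by exp(-1). A naive treated-vs-untreated comparison misses the effect badly, while comparing people of similar severity gets close to it.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000

# Toy generator (an assumption for illustration, not fklearn's process):
# higher severity raises the chance of being medicated, and medication
# multiplies the expected recovery time by exp(-1).
severity = rng.uniform(0.0, 1.0, size=n)
medication = (severity + 0.2 * rng.normal(size=n) > 0.5).astype(int)
recovery = rng.poisson(np.exp(2.0 + 2.0 * severity - 1.0 * medication))

treated = medication == 1

# Naive comparison: ignores that treated people were sicker to begin with.
naive_ratio = recovery[treated].mean() / recovery[~treated].mean()

# Controlled comparison: compare treated vs untreated inside narrow
# severity bins, keeping only bins that contain enough of both groups.
bins = np.minimum((severity * 50).astype(int), 49)
ratios = []
for b in range(50):
    in_bin = bins == b
    n_treated = np.sum(in_bin & treated)
    n_control = np.sum(in_bin & ~treated)
    if n_treated >= 30 and n_control >= 30:
        ratio = recovery[in_bin & treated].mean() / recovery[in_bin & ~treated].mean()
        ratios.append(ratio)
stratified_ratio = float(np.mean(ratios))

print(f"true effect:      {np.exp(-1):.3f}")
print(f"naive ratio:      {naive_ratio:.3f}")
print(f"stratified ratio: {stratified_ratio:.3f}")
```

Note that bins without enough treated and untreated people are simply skipped: where the two groups do not overlap, there is nothing to compare.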

The make_confounded_data function outputs three dataframes: df_rnd, where treatment assignment is random; df_obs, where treatment assignment is confounded; and df_cf, the counterfactual dataframe, which contains the counterfactual outcome for all the individuals.

Let us try to solve this problem using ForestEmbeddingsCounterfactual!

How ForestEmbeddingsCounterfactual works

In causal inference, we aim to answer what would happen if we made a different decision in the past. This is quite hard because we cannot make two decisions simultaneously, or go back in time and check what would happen if we did things differently. However, what we can do is observe what happened to people who are similar to ourselves and made different choices. We do this all the time using family members, work colleagues, and friends as references.

But what does it mean to be similar and, most importantly, can similarity be learned? The answer is YES! For instance, when we run a decision tree, more than solving a classification or a regression problem, we’re dividing our data into clusters of similar elements given which features best explain our target. And if we repeat this process, such as when building a Random Forest, we’ll note that some samples are more likely to end up together in the same leaf than others. Thus, we can measure similarity by counting in how many trees of the forest two elements ended up together in the same leaf!
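
The leaf co-occurrence idea can be sketched directly with scikit-learn. The snippet below is only an illustration on made-up data, not the cfml_tools implementation: it fits an ExtraTreesRegressor, asks each tree which leaf three query points land in, and measures similarity as the fraction of trees where two points share a leaf.

```python
import numpy as np
from sklearn.ensemble import ExtraTreesRegressor

rng = np.random.default_rng(0)

# Toy data: two explanatory variables that both drive the target.
X = rng.uniform(0.0, 1.0, size=(500, 2))
y = 10.0 * X[:, 0] + 5.0 * X[:, 1] + rng.normal(scale=0.5, size=500)

forest = ExtraTreesRegressor(n_estimators=100, max_depth=5, random_state=0)
forest.fit(X, y)

# Three query points: a and b are close, c is far away from both.
queries = np.array([[0.10, 0.10], [0.12, 0.10], [0.90, 0.90]])

# apply() returns, for each sample, the index of the leaf it lands in
# for every tree: shape (n_samples, n_trees).
leaves = forest.apply(queries)


def leaf_similarity(i: int, j: int) -> float:
    """Fraction of trees in which samples i and j share a leaf."""
    return float(np.mean(leaves[i] == leaves[j]))


print("sim(a, b):", leaf_similarity(0, 1))
print("sim(a, c):", leaf_similarity(0, 2))
```

Because the trees are trained to predict the target, points end up together only when they are close along the directions that matter for the outcome, which is exactly the notion of similarity we want for finding comparables.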

ForestEmbeddingsCounterfactual leverages the embedding created by this leaf co-occurrence similarity metric to search for elements that are similar on the explanatory variables and check how changes in the treatment variable translate into changes in the target. If there are no unobserved confounders, we can be confident that the treatment variable really caused the changes in the target, since everything else is controlled for.

Let us solve fklearn’s causal inference problem so we can walk through the method.

Easy mode: solving df_rnd

We call solving df_rnd “easy mode” because there’s no confounding, making it easy to estimate counterfactuals without paying attention to it. Nevertheless, it provides a good sanity check for ForestEmbeddingsCounterfactual.

We first organize data in X (explanatory variables), W (treatment variable) and y (target) format, needed to fit ForestEmbeddingsCounterfactual.

# organizing data into X, W and y
X = df_rnd[['sex','age','severity']]
W = df_rnd['medication']
y = df_rnd['recovery']

We then import the class and instantiate it.

# importing cfml-tools
from cfml_tools import ForestEmbeddingsCounterfactual
fecf = ForestEmbeddingsCounterfactual(save_explanatory=True)

I advise you to read the docstring to know about the parameters and make the tutorial easier to follow! Before fitting and getting counterfactuals, a good sanity check is doing 5-fold CV, to test the generalization power of the underlying forest model:

# validating model using 5-fold CV
cv_scores = fecf.get_cross_val_scores(X, y)
print(cv_scores)

[0.55879863 0.5832598 0.58331632 0.58258708 0.56651886]

Here, the R2 scores fall between roughly 0.56 and 0.58, which seems reasonable. However, there’s actually no baseline here: you just need to be confident that the model can capture and generalize relationships between explanatory variables and the target variable. Nevertheless, here are some tips:

If your CV metric is too high (R2 very close to 1.00, for instance), it may mean that the treatment variable has no effect on the outcomes, or that its effect is “masked” by correlated proxies in the explanatory variables.

If your CV metric is too low (e.g. R2 close to 0), it does not mean that the model isn’t useful: the outcome may be explained only by the treatment variable. In this case, since the underlying model is an ExtraTreesRegressor, the forest embedding would work as an unsupervised embedding, like sklearn’s RandomTreesEmbedding.

We proceed to fit the model using X, W and y.

# fitting data to our model
fecf.fit(X, W, y)

Calling .fit() builds the forest and creates a nearest neighbor index using leaf co-occurrence as a similarity metric. For more details on that, check my earlier post about forest embeddings.

We then predict the counterfactuals for all our individuals. By calling .predict(), we get the dataframe in the counterfactuals variable, which stores predictions for both W = 0 and W = 1. The counterfactuals are obtained by querying the nearest neighbor index built on .fit() for n_neighbors and calculating the average outcome given different values of W.

# let us predict counterfactuals for these guys
counterfactuals = fecf.predict(X)
counterfactuals.head()
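
To make the neighbor-averaging step concrete, here is a brute-force sketch on a tiny, hypothetical leaf matrix. The real model queries a nearest neighbor index rather than scanning every sample, and the function and variable names below are invented for this illustration.

```python
import numpy as np

# Hypothetical leaf assignments: one row per training sample, one column per
# tree; each entry is the id of the leaf that sample falls into in that tree.
train_leaves = np.array([
    [1, 4, 2],
    [1, 4, 3],
    [1, 5, 2],
    [7, 8, 9],
    [1, 4, 2],
    [7, 4, 2],
])
train_w = np.array([0, 1, 0, 1, 1, 0])                    # treatment
train_y = np.array([10.0, 30.0, 12.0, 90.0, 34.0, 14.0])  # outcome


def predict_counterfactuals(query_leaves, n_neighbors):
    """Average neighbor outcomes separately for W = 0 and W = 1."""
    # Similarity = number of trees where the query shares a leaf with a sample.
    shared = (train_leaves == query_leaves).sum(axis=1)
    # Stable sort so that ties keep their original order.
    neighbors = np.argsort(-shared, kind="stable")[:n_neighbors]
    w, y = train_w[neighbors], train_y[neighbors]
    return {int(t): float(y[w == t].mean()) for t in (0, 1)}


y_hat = predict_counterfactuals(np.array([1, 4, 2]), n_neighbors=5)
print(y_hat)  # {0: 12.0, 1: 32.0}
```

The sample that shares no leaf with the query (outcome 90) is left out of the neighborhood, so it does not distort either average.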

Then, we can compute treatment effects as follows:

# treatment effects
treatment_effects = counterfactuals['y_hat'][0]/counterfactuals['y_hat'][1]
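
The indexing above reads as if the predictions come in a dataframe with two-level columns: the statistic first, then the value of W. Assuming that layout (it is inferred from the expression, not checked against the library), this small example with made-up numbers shows what the ratio computes:

```python
import numpy as np
import pandas as pd

# Hypothetical predictions laid out the way the indexing above implies:
# first column level is the statistic, second level is the value of W.
columns = pd.MultiIndex.from_product([["y_hat"], [0, 1]])
counterfactuals = pd.DataFrame(
    [[12.0, 32.0], [10.0, 25.0], [np.nan, 200.0]],
    columns=columns,
)

# Same expression as in the post: predicted outcome under W = 0 divided by
# predicted outcome under W = 1, one value per individual.
treatment_effects = counterfactuals["y_hat"][0] / counterfactuals["y_hat"][1]

# A missing counterfactual (NaN) propagates to the effect and can be dropped.
print(treatment_effects.dropna())
```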

And compare estimated effects vs real effects:

Cool! As we can see, the model nicely estimated the true effect.

But how can we be sure that the model is performing well? Let us do a quick diagnostic using a visualization of our forest embedding.

Diagnosis and criticism

A good diagnostic is to look at a 2D representation of our forest embedding, using UMAP:

# getting embedding from data
reduced_embed = fecf.get_umap_embedding(X)

Under the hood, UMAP takes our leaf-based similarity metric and creates a 2D representation that tries to preserve it. This shows natural clusters in our data, and surfaces what the model learned in terms of similarity and representation.

The plots show how our explanatory variables are distributed across the embedding. Sex breaks the embedding into two separate clusters, while severity is distributed in a left-right gradient inside each cluster and age follows an up-down gradient. The gradients are smooth enough that we can be confident that individuals in each local neighborhood are very similar. Thus, the only thing that could explain differences in the outcomes of members of a local neighborhood is dispersion in the treatment variable!

Let us have a look at the outcomes and treatment assignments:

As we can see, treated and not treated individuals are uniformly scattered across the map, and the granular pattern in the right-hand side plot tells us that lack of treatment degrades outcomes for all local neighborhoods. That’s how we compute counterfactuals: comparing how the treatment variable impacts outcomes for similar individuals!

If you still need to dig deeper, ForestEmbeddingsCounterfactual implements a .explain() method so you can see which comparables the model used to calculate counterfactuals for a given sample.

# our test sample
test_sample = X.iloc[[0]]
print(df_to_markdown(test_sample))

| sex | age | severity |
| --- | --- | --- |
| 0 | 16 | 0.047 |

# running explanation
comparables_table = fecf.explain(test_sample)

# showing comparables table
print(df_to_markdown(comparables_table.groupby('W').head(5).sort_values('W').reset_index()))
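
| index | sex | age | severity | W | y |
| --- | --- | --- | --- | --- | --- |
| 11594 | 0 | 16 | 0.048 | 0 | 18 |
| 990 | 0 | 16 | 0.048 | 0 | 15 |
| 35909 | 0 | 17 | 0.048 | 0 | 8 |
| 19406 | 0 | 16 | 0.039 | 0 | 10 |
| 23348 | 0 | 17 | 0.049 | 0 | 10 |
| 0 | 0 | 16 | 0.047 | 1 | 31 |
| 44725 | 0 | 16 | 0.049 | 1 | 39 |
| 43051 | 0 | 16 | 0.042 | 1 | 30 |
| 31859 | 0 | 17 | 0.046 | 1 | 43 |
| 28498 | 0 | 16 | 0.034 | 1 | 37 |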

As you can see, the model found a lot of “twins” to the test sample with different treatment assignments and outcomes. By looking at the table it becomes crystal clear that the treatment improves outcomes.

Cool, right? The uniform assignment case is good to get the intuition about embeddings and usage of ForestEmbeddingsCounterfactual. Let us move to the case where treatment assignments are not uniformly distributed, making the counterfactual estimation harder (but still doable!).

Hard mode: solving df_obs

Now, we go for the “hard mode” and try to solve df_obs. Here we have confounding, which means that treatment assignment will not be uniform. Nevertheless, we run ForestEmbeddingsCounterfactual like before!

Organizing data in X, W and y format again:

# organizing data into X, W and y
X = df_obs[['sex','age','severity']]
W = df_obs['medication']
y = df_obs['recovery']

Validating the model, as before:

# importing cfml-tools
from cfml_tools import ForestEmbeddingsCounterfactual
fecf = ForestEmbeddingsCounterfactual(save_explanatory=True)

# validating model using 5-fold CV
cv_scores = fecf.get_cross_val_scores(X, y)
print(cv_scores)
[0.9543888  0.95643801 0.95538205 0.95339682 0.92864522]

Here it gets a little bit different. Remember that a high R2 could mean that the treatment variable has little effect on the outcome? As the treatment assignment is correlated with the other variables, they “steal” importance from the treatment and our R2 gets higher in the confounded case. This will become even clearer when we look at the UMAP embedding.
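
This effect is easy to reproduce on toy data. The sketch below does not use fklearn’s generator or cfml_tools: it cross-validates a plain ExtraTreesRegressor that sees a single explanatory variable but never the treatment, once with random assignment and once with assignment driven by that variable. All names and coefficients are made up for the illustration.

```python
import numpy as np
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
n = 2_000

# Toy data (not fklearn's generator): one explanatory variable and a binary
# treatment that multiplies the expected outcome by e.
severity = rng.uniform(0.0, 1.0, size=n)
X = severity.reshape(-1, 1)


def mean_cv_r2(treatment):
    """5-fold CV R2 of a forest that sees X but never the treatment."""
    y = rng.poisson(np.exp(2.0 + severity + treatment))
    model = ExtraTreesRegressor(n_estimators=50, min_samples_leaf=20, random_state=0)
    return cross_val_score(model, X, y, cv=5, scoring="r2").mean()


w_random = rng.integers(0, 2, size=n)
w_confounded = (severity + 0.1 * rng.normal(size=n) > 0.5).astype(int)

r2_random = mean_cv_r2(w_random)
r2_confounded = mean_cv_r2(w_confounded)
print(f"R2 with random assignment:     {r2_random:.2f}")
print(f"R2 with confounded assignment: {r2_confounded:.2f}")
```

When assignment is random, the treatment acts as noise the forest cannot see, so R2 stays low. When assignment follows severity, the forest indirectly predicts the treatment too, and R2 goes up.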

We proceed to fit the model using X, W and y.

# fitting data to our model
fecf.fit(X, W, y)

We then predict the counterfactuals for all our individuals.

In this case, we can see some NaNs. That’s because some individuals do not have enough treated or untreated neighbors to estimate the counterfactuals, controlled by the parameter min_sample_effect. When this parameter is high, we are conservative, getting more NaNs but less variance in counterfactual estimation.

# let us predict counterfactuals for these guys
counterfactuals = fecf.predict(X)
counterfactuals.head()
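
Here is a rough sketch of how such a threshold produces NaNs. It assumes min_sample_effect is a minimum number of neighbors required per treatment value, which is how the description above reads; the function and the numbers are hypothetical, not taken from cfml_tools.

```python
import math


def counterfactual_means(neighbor_w, neighbor_y, min_sample_effect):
    """Mean neighbor outcome per treatment value, or NaN if too few samples."""
    estimates = {}
    for treatment in (0, 1):
        outcomes = [y for w, y in zip(neighbor_w, neighbor_y) if w == treatment]
        if len(outcomes) >= min_sample_effect:
            estimates[treatment] = sum(outcomes) / len(outcomes)
        else:
            # Not enough comparables with this treatment: refuse to estimate.
            estimates[treatment] = math.nan
    return estimates


# Hypothetical neighborhood in a region where almost everyone was treated.
w = [1, 1, 1, 1, 0, 1, 1, 1]
y = [197, 207, 211, 163, 60, 189, 201, 176]

loose = counterfactual_means(w, y, min_sample_effect=1)
strict = counterfactual_means(w, y, min_sample_effect=3)
print(loose)   # {0: 60.0, 1: 192.0}
print(strict)  # {0: nan, 1: 192.0}
```

With the loose setting, the untreated estimate rests on a single neighbor; the strict setting prefers to return NaN rather than trust it.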

Comparing true effect with estimated:

Nice! The model estimated the effect very well again. Note that we have fewer samples in the histogram, due to NaNs. Nevertheless, it is a cool result and shows that ForestEmbeddingsCounterfactual can work with confounded data.

Diagnosis and criticism

Now things get interesting. Let us check how our 2D embedding changes with the confounding effect:

# getting embedding from data
reduced_embed = fecf.get_umap_embedding(X)

The distribution of explanatory variables across the embedding gets a little bit different. Sex still breaks the embedding into two separate clusters, but severity is now distributed in an up-down gradient inside each cluster and age follows a left-right gradient. The gradients are still smooth enough that we can be confident that individuals in each local neighborhood are very similar.

Let us have a look at the outcomes and treatment assignments:

The effect of confounding gets very clear at this point. At the upper left-hand side, we can see that, except for a small region at the center of our clusters, there’s no mix of treated and untreated individuals. This lack of mixing makes it difficult to estimate counterfactuals, as there are no similar individuals with different treatment assignments. This makes a lot of the effects invalid, as we see in the lower left-hand side. However, the lack of mixing also made the outcomes more predictable, as we can see in the right-hand side (the embedding is very homogeneous with respect to the outcome).

If you backtrack a little bit, you’ll notice that we can only predict counterfactuals for people of average severity! That’s why I really like using embeddings for explaining models: they are a quick and visual diagnosis of what your model learned, and can be used to extract knowledge from it for a lot of purposes including causal inference.

And again, if you want to dive deeper just use .explain(). In this case, we query for an individual with high severity and our comparables only have people who were treated, making inference not feasible:

# our test sample
test_sample = X.query('severity > 0.95').iloc[[0]]
print(df_to_markdown(test_sample))

| sex | age | severity |
| --- | --- | --- |
| 0 | 43 | 0.95 |

# running explanation
comparables_table = fecf.explain(test_sample)

# showing comparables table
print(comparables_table['W'].value_counts())
print(df_to_markdown(comparables_table.groupby('W').head(5).sort_values('W').reset_index()))
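
| index | sex | age | severity | W | y |
| --- | --- | --- | --- | --- | --- |
| 26 | 0 | 43 | 0.95 | 1 | 197 |
| 6526 | 0 | 43 | 0.95 | 1 | 207 |
| 36582 | 0 | 43 | 0.95 | 1 | 211 |
| 21053 | 0 | 43 | 0.96 | 1 | 163 |
| 38484 | 0 | 44 | 0.95 | 1 | 189 |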

I hope you liked the tutorial and will use cfml_tools for your causal inference problems soon!