Calculating counterfactuals with decision trees


In earlier posts we explored the problem of estimating counterfactual outcomes, one of the central problems in causal inference, and learned that, with a few tweaks, simple decision trees can be a great tool for solving it. In this post, I'll walk you through the usage of DecisionTreeCounterfactual, one of the main models in the cfml_tools module, and show that it neatly solves the toy causal inference problem from the fklearn library. You can find the full code for this example here.

Data: make_confounded_data from fklearn

Nubank’s fklearn module provides a nice causal inference problem generator, so we’re going to use the same data generating process and example from its documentation.

# getting confounded data from fklearn
from fklearn.data.datasets import make_confounded_data
df_rnd, df_obs, df_cf = make_confounded_data(50000)
print(df_to_markdown(df_obs.head(5)))  # df_to_markdown: small display helper from the full example code
| sex | age | severity | medication | recovery |
|-----|-----|----------|------------|----------|
| 0   | 34  | 0.7      | 1          | 126      |
| 1   | 24  | 0.72     | 1          | 123      |
| 1   | 38  | 0.86     | 1          | 255      |
| 1   | 35  | 0.77     | 1          | 227      |
| 0   | 22  | 0.078    | 0          | 15       |

We have five columns: sex, age, severity, medication and recovery. We want to estimate the impact of medication on recovery. So, our target variable is recovery, our treatment variable is medication, and the rest are our explanatory variables.

A good counterfactual model will tell us what the recovery time would be for each individual under both decisions: taking or not taking the medication. The model should be robust to confounders: variables that affect both the probability of someone taking the medication and the recovery outcome. For instance, people with higher severity may be more likely to take the medicine. If not properly taken into account, this confounder may lead us to conclude that the medication makes recovery worse: people who took the medication may show worse recovery times (but their condition was already more severe). fklearn's documentation shows the data generating process in detail, highlighting the confounders in the data. The effect we're looking for is $exp(-1) = 0.368$.
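To make the confounding story concrete, here is a minimal illustrative process in the same spirit (made-up coefficients, a sketch rather than fklearn's exact recipe): severity raises both the chance of taking the medication and the recovery time, and the medication enters multiplicatively so that $y_{W=0}/y_{W=1} = exp(-1)$ within matched groups, matching the effect quoted above.

import numpy as np

# illustrative confounded process (a sketch; fklearn's exact formulas differ)
rng = np.random.default_rng(0)
n = 50000
sex = rng.binomial(1, 0.5, n)
age = rng.gamma(8, 4, n)
severity = rng.beta(2, 4, n)

# confounding: higher severity -> higher probability of taking the medication
medication = rng.binomial(1, np.clip(0.2 + 0.8 * severity, 0, 1))

# treatment enters multiplicatively, so y(W=0) / y(W=1) = exp(-1) ~= 0.368
recovery = np.ceil(np.exp(2 + 0.5 * sex + 0.03 * age + 2 * severity + medication))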

The make_confounded_data function outputs three data frames: df_rnd, where treatment assignment is random; df_obs, where treatment assignment is confounded; and df_cf, the counterfactual dataframe, containing the counterfactual outcome for all individuals.
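Before moving on, a quick sanity peek at what the generator returns (the three dataframes should share the same columns):

# checking shapes and schema of the generated dataframes
print(df_rnd.shape, df_obs.shape, df_cf.shape)
print(df_cf.columns.tolist())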

Let us try to solve this problem using DecisionTreeCounterfactual!

How DecisionTreeCounterfactual works

In causal inference, we aim to answer what would happen if we made a different decision in the past. This is quite hard because we cannot make two decisions simultaneously, or go back in time and check what would happen if we did things differently. However, what we can do is observe what happened to people who are similar to ourselves and made different choices. We do this all the time using family members, work colleagues, and friends as references.

But what does it mean to be similar? And, most importantly, can similarity be learned? The answer is yes! When we run a decision tree, beyond solving a classification or regression problem, we are dividing our data into clusters of similar elements, given the features that most explain our target. Thus, a decision tree works like a researcher deciding which variables to control for in order to get the best effect estimate!

DecisionTreeCounterfactual leverages this supervised clustering approach: it checks how changes in the treatment variable reflect in changes in the target, within clusters determined by the explanatory variables that most impact the target. If there are no unobserved confounders, we can be confident that the treatment variable really caused the changes in the target, since everything else is controlled.
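To make this concrete, here is a minimal sketch of the idea using a plain scikit-learn tree (leaf_counterfactuals is a hypothetical helper written for illustration, not cfml_tools internals): fit a tree from X to y, read each sample's leaf with .apply(), and average the outcome per (leaf, W) cell.

import pandas as pd
from sklearn.tree import DecisionTreeRegressor

def leaf_counterfactuals(X, W, y, min_samples_leaf=100):
    """Sketch: use tree leaves as clusters and compare outcomes across W."""
    tree = DecisionTreeRegressor(min_samples_leaf=min_samples_leaf).fit(X, y)
    leaves = tree.apply(X)  # leaf id for every sample
    cells = pd.DataFrame({'leaf': leaves, 'W': W.to_numpy(), 'y': y.to_numpy()})
    # mean outcome and sample count per (leaf, treatment) cell
    return cells.groupby(['leaf', 'W'])['y'].agg(['mean', 'count'])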

Let us solve fklearn's causal inference problem so we can walk through the method.

Easy mode: solving df_rnd

We call solving df_rnd "easy mode" because treatment assignment is random: with no confounding, we can estimate counterfactuals without worrying about it. Nevertheless, it provides a good sanity check for DecisionTreeCounterfactual.

We first organize data in X (explanatory variables), W (treatment variable) and y (target) format, needed to fit DecisionTreeCounterfactual.

# organizing data into X, W and y
X = df_rnd[['sex','age','severity']]
W = df_rnd['medication']
y = df_rnd['recovery']

We then import the class and instantiate it.

# importing cfml-tools
from cfml_tools import DecisionTreeCounterfactual
dtcf = DecisionTreeCounterfactual(save_explanatory=True)

I advise you to read the docstring to learn about the parameters; it will make the tutorial easier to follow!

Before fitting and getting counterfactuals, a good sanity check is to run 5-fold CV, testing the generalization power of the underlying tree model:

# validating model using 5-fold CV
cv_scores = dtcf.get_cross_val_scores(X, y)
print(cv_scores)

[0.54723148 0.57086291 0.56644823 0.56601209 0.543017]

Here, we get R2 scores around 0.55, which seem reasonable. There's no hard baseline, though: you just need to be confident that the model can capture and generalize the relationships between the explanatory variables and the target. Nevertheless, here are some tips: if your CV metric is too high (R2 very close to 1.00, for instance), it may mean that the treatment variable has no effect on the outcomes, or that its effect is "masked" by correlated proxies among the explanatory variables. If your CV metric is too low (R2 close to 0, for instance), it does not mean that the model is useless: the outcome may be explained mostly by the treatment variable.
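For a quick one-line summary of the folds (assuming cv_scores comes back as a NumPy array, as the printed output suggests):

# summarizing the fold scores
print(f'mean R2: {cv_scores.mean():.3f} +/- {cv_scores.std():.3f}')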

We proceed to fit the model using X, W and y.

# fitting data to our model
dtcf.fit(X, W, y)

Calling .fit() builds a decision tree that solves the regression problem from X to y. But we actually use this tree as a supervised clustering algorithm: each leaf determines a cluster of similar elements, given the explanatory variables that most impact the target. Thus, we can calculate counterfactuals at the cluster level by comparing the outcomes of each cluster's elements under different values of W. .fit() is done when we have a table of counterfactuals indexed by the tree's leaves:

# showing counterfactual training table
print(df_to_markdown(dtcf.leaf_counterfactual_df.reset_index().head(6)))
| leaf | W | y  | count |
|------|---|----|-------|
| 7    | 0 | 11 | 73    |
| 7    | 1 | 30 | 62    |
| 9    | 0 | 12 | 51    |
| 9    | 1 | 31 | 54    |
| 10   | 0 | 13 | 69    |
| 10   | 1 | 34 | 102   |

We then predict counterfactuals for all our individuals. Calling .predict() returns a dataframe, stored in the counterfactuals variable, with predictions for both W = 0 and W = 1. The counterfactuals are obtained by running the samples through the tree, checking which cluster each sample has been assigned to, and querying the leaf_counterfactual_df built during .fit() for the outcome under each value of W.

# let us predict counterfactuals for these guys
counterfactuals = dtcf.predict(X)
counterfactuals.head()
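One way to picture the lookup that .predict() performs (a sketch built on the leaf table shown above, not necessarily the library's exact implementation):

# sketch: pivot the leaf table so each leaf carries one estimate per treatment value
leaf_table = dtcf.leaf_counterfactual_df.reset_index()
y_hat_by_leaf = leaf_table.pivot(index='leaf', columns='W', values='y')

# a sample assigned to leaf 7 receives y_hat_by_leaf.loc[7, 0] and y_hat_by_leaf.loc[7, 1]
print(y_hat_by_leaf.head())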

Then, we can compute treatment effects by using the counterfactual information:

# treatment effects: per-individual ratio between predicted outcomes under W = 0 and W = 1
treatment_effects = counterfactuals['y_hat'][0]/counterfactuals['y_hat'][1]

And compare estimated effects vs real effects:
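A minimal version of this comparison could look like the following (assuming matplotlib; the true effect $exp(-1) = 0.368$ is marked for reference):

# histogram of estimated effects vs. the true effect exp(-1)
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(8, 4), dpi=120)
ax.hist(treatment_effects.dropna(), bins=50)
ax.axvline(np.exp(-1), color='red', linestyle='--', label='true effect (0.368)')
ax.set_xlabel('estimated effect (y_hat[0] / y_hat[1])')
ax.legend()
plt.show()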

Cool! As we can see, the model nicely estimated the true effect.

This may seem rather "black-boxy": how can we trust the counterfactual predictions? We can use DecisionTreeCounterfactual's .explain() method! For a given test sample, it returns a table of comparable individuals along with their treatment assignments and outcomes.

# our test sample
test_sample = X.iloc[[0]]
print(df_to_markdown(test_sample))
| sex | age | severity |
|-----|-----|----------|
| 0   | 16  | 0.047    |
# running explanation
comparables_table = dtcf.explain(test_sample)

# showing comparables table
print(df_to_markdown(comparables_table.groupby('W').head(5).sort_values('W').reset_index()))
| index | sex | age | severity | W | y  |
|-------|-----|-----|----------|---|----|
| 1444  | 0   | 15  | 0.049    | 0 | 13 |
| 2180  | 0   | 15  | 0.083    | 0 | 10 |
| 2379  | 0   | 15  | 0.045    | 0 | 13 |
| 3388  | 0   | 16  | 0.078    | 0 | 10 |
| 4036  | 0   | 15  | 0.056    | 0 | 18 |
| 0     | 0   | 16  | 0.047    | 1 | 31 |
| 71    | 0   | 16  | 0.089    | 1 | 35 |
| 157   | 0   | 14  | 0.07     | 1 | 30 |
| 1096  | 0   | 15  | 0.048    | 1 | 28 |
| 1412  | 0   | 14  | 0.093    | 1 | 34 |

For a better visualization, you can check the following boxplot as well:

# boxplots comparing treated and untreated comparables across variables
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1, 4, figsize=(16, 5), dpi=150)
comparables_table.boxplot('age', 'W', ax=ax[0])
comparables_table.boxplot('sex', 'W', ax=ax[1])
comparables_table.boxplot('severity', 'W', ax=ax[2])
comparables_table.boxplot('y', 'W', ax=ax[3])

As the boxplots show, the treated and untreated groups are very similar with respect to the explanatory variables, so we can be confident that the difference in outcomes is due only to the difference in treatment. Looking at the results, the treatment effect becomes crystal clear.

Let us now go for the “hard mode”, solving a counterfactual estimation problem with confounding.

Hard mode: solving df_obs

Now we go for the "hard mode" and try to solve df_obs. This time we have confounding, which means that treatment assignment is not uniform. Nevertheless, we run DecisionTreeCounterfactual just like before.

Organizing data in X, W and y format again:

# organizing data into X, W and y
X = df_obs[['sex','age','severity']]
W = df_obs['medication']
y = df_obs['recovery']

Validating the model, as before:

# importing cfml-tools
from cfml_tools import DecisionTreeCounterfactual
dtcf = DecisionTreeCounterfactual(save_explanatory=True)

# validating model using 5-fold CV
cv_scores = dtcf.get_cross_val_scores(X, y)
print(cv_scores)

[0.90593652 0.9394594 0.94191483 0.93571656 0.93803323]

Here things get a little different. Remember that a high R2 could mean the treatment variable has little effect on the outcome? Since treatment assignment is now correlated with the other variables, they "steal" importance from the treatment, and this shows up as a higher R2 in the confounded case.

We proceed to fit the model using X, W and y.

# fitting data to our model
dtcf.fit(X, W, y)

We then predict the counterfactuals for all our individuals. As before, we get a dataframe in the counterfactuals variable, with predicted outcomes for both W = 0 and W = 1.

In this case, we will see some NaNs. That's because, for some individuals, there are not enough treated or untreated neighbors to estimate the counterfactuals; this is controlled by the min_sample_effect parameter. When this parameter is high, we are conservative: we get more NaNs, but less variance in the counterfactual estimates.

# let us predict counterfactuals for these guys
counterfactuals = dtcf.predict(X)
counterfactuals.head()

Let us now compare the true effect with the estimated one, for all samples for which we could infer a counterfactual (no NaNs). Will the model get a good estimate in this case?
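To compare only where both counterfactuals exist, we can drop the NaNs first (a sketch, assuming the y_hat column layout shown earlier); the histogram code from the easy mode can then be reused:

# keeping only individuals with estimates for both W = 0 and W = 1
valid = counterfactuals['y_hat'].dropna()
treatment_effects = valid[0] / valid[1]
print(f'kept {len(valid)} of {len(counterfactuals)} samples, median effect: {treatment_effects.median():.3f}')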

Nice! The model estimated the effect very well again. Note that we have fewer samples in the histogram, due to the NaNs. Nevertheless, it is a cool result and shows that DecisionTreeCounterfactual can work with confounded data.

Let us explain the counterfactual of the first prediction, which is NaN for W = 1.

# our test sample
test_sample = X.iloc[[0]]

# running explanation
comparables_table = dtcf.explain(test_sample)

# counting treated vs. untreated comparables (reindex guards against a missing group)
val_counts = comparables_table['W'].value_counts().reindex([0, 1], fill_value=0)
print(f'Number of treated: {val_counts[1]}, number of untreated: {val_counts[0]}')

Number of treated: 0, number of untreated: 142

In this case, there’s no treated comparables for us to draw inferences from. This explains why we cannot predict the outcome for this individual given W = 1! You can control the mininum number of samples required to perform a valid inference with the parameter min_sample_effect.

On the other hand, our third test sample has a healthy number of treated vs. untreated samples, so we can infer counterfactuals:

# our test sample
test_sample = X.iloc[[2]]

# running explanation
comparables_table = dtcf.explain(test_sample)

# counting treated vs. untreated comparables
val_counts = comparables_table['W'].value_counts().reindex([0, 1], fill_value=0)
print(f'Number of treated: {val_counts[1]}, number of untreated: {val_counts[0]}')

Number of treated: 120, number of untreated: 57

I hope you liked the tutorial and will use cfml_tools for your causal inference problems soon!