Stata 19 CATE command

Stata 19 new CATE command

Fernando Rios-Avila has a new blog post about the new cate command in Stata 19, which implements conditional average treatment effect (CATE) estimation.

https://friosavila.github.io/app_metrics/app_metrics12.html

clear all
webuse assets3
global catecovars age educ i.(incomecat pension married twoearn ira ownhome)
cate po (assets $catecovars) (e401k), group(incomecat) nolog
(Excerpt from Chernozhukov and Hansen (2004))



Conditional average treatment effects     Number of observations       = 9,913
Estimator:       Partialing out           Number of folds in cross-fit =    10
Outcome model:   Linear lasso             Number of outcome controls   =    17
Treatment model: Logit lasso              Number of treatment controls =    17
CATE model:      Random forest            Number of CATE variables     =    17

------------------------------------------------------------------------------
             |               Robust
      assets | Coefficient  std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
GATE         |
   incomecat |
          0  |   3999.888   989.8742     4.04   0.000      2059.77    5940.006
          1  |   1424.931   1668.081     0.85   0.393    -1844.447     4694.31
          2  |   5092.801   1344.769     3.79   0.000     2457.102    7728.499
          3  |   8700.512   2272.126     3.83   0.000     4247.226     13153.8
          4  |   20350.97   4717.093     4.31   0.000     11105.64     29596.3
-------------+----------------------------------------------------------------
ATE          |
       e401k |
  (Eligible  |
         vs  |
Not elig..)  |    7914.24    1150.79     6.88   0.000     5658.734    10169.75
-------------+----------------------------------------------------------------
POmean       |
       e401k |
Not eligi..  |   14012.54   834.0963    16.80   0.000     12377.74    15647.34
------------------------------------------------------------------------------

He said “We can achieve comparable results using standard regression with full interactions:”

clear all
webuse assets3
global catecovars age educ i.(incomecat pension married twoearn ira ownhome)

* Run quietly: the fully interacted model has too many coefficients to display.
qui:reg assets i.incomecat##i.e401k##c.($catecovars), robust

* Average Treatment Effect (ATE)
margins, at(e401k=(0 1)) noestimcheck contrast(atcontrast(r))

* Group Average Treatment Effects (GATEs)
margins, at(e401k=(0 1)) noestimcheck contrast(atcontrast(r)) over(incomecat)
(Excerpt from Chernozhukov and Hansen (2004))




Contrasts of predictive margins                          Number of obs = 9,913
Model VCE: Robust

Expression: Linear prediction, predict()
1._at: e401k = 0
2._at: e401k = 1

------------------------------------------------
             |         df           F        P>F
-------------+----------------------------------
         _at |          1       44.72     0.0000
             |
 Denominator |       9833
------------------------------------------------

--------------------------------------------------------------
             |            Delta-method
             |   Contrast   std. err.     [95% conf. interval]
-------------+------------------------------------------------
         _at |
   (2 vs 1)  |   7642.407   1142.797      5402.291    9882.524
--------------------------------------------------------------


Contrasts of predictive margins                          Number of obs = 9,913
Model VCE: Robust

Expression: Linear prediction, predict()
Over:       incomecat
1._at: 0.incomecat
           e401k = 0
       1.incomecat
           e401k = 0
       2.incomecat
           e401k = 0
       3.incomecat
           e401k = 0
       4.incomecat
           e401k = 0
2._at: 0.incomecat
           e401k = 1
       1.incomecat
           e401k = 1
       2.incomecat
           e401k = 1
       3.incomecat
           e401k = 1
       4.incomecat
           e401k = 1

-------------------------------------------------
              |         df           F        P>F
--------------+----------------------------------
_at@incomecat |
  (2 vs 1) 0  |          1       16.08     0.0001
  (2 vs 1) 1  |          1        0.67     0.4123
  (2 vs 1) 2  |          1       16.03     0.0001
  (2 vs 1) 3  |          1       13.77     0.0002
  (2 vs 1) 4  |          1       17.26     0.0000
       Joint  |          5       12.76     0.0000
              |
  Denominator |       9833
-------------------------------------------------

---------------------------------------------------------------
              |            Delta-method
              |   Contrast   std. err.     [95% conf. interval]
--------------+------------------------------------------------
_at@incomecat |
  (2 vs 1) 0  |   3594.182   896.3298      1837.191    5351.172
  (2 vs 1) 1  |     1283.3   1565.151     -1784.718    4351.317
  (2 vs 1) 2  |   5056.144   1262.728      2580.938     7531.35
  (2 vs 1) 3  |   8610.622   2320.156      4062.641     13158.6
  (2 vs 1) 4  |   19665.61   4733.551      10386.88    28944.34
---------------------------------------------------------------

Numerically the two sets of estimates are close, but I don’t think they come from the same model, or even from the same idea. The cate command with the “po” option estimates the CATE with the “partialling out” approach.

What Fernando did is a “regression adjustment” type of model: it models the outcome and allows the treatment to interact with the covariates. The “partialling out” approach, by contrast, residualizes both the outcome and the treatment on the covariates, and is sometimes called double ML.

Let’s see how to replicate “teffects ra” with “reg”:

clear all
webuse assets3
global catecovars age educ i.(incomecat pension married twoearn ira ownhome)

teffects ra (assets age educ i.(incomecat pension married twoearn ira ownhome))(e401k)

qui: reg assets i.e401k##c.(age educ) i.e401k##i.(incomecat pension married twoearn ira ownhome)
margins, at(e401k=(0 1)) noestimcheck contrast(atcontrast(r))
(Excerpt from Chernozhukov and Hansen (2004))



Iteration 0:  EE criterion = 1.291e-21  
Iteration 1:  EE criterion = 1.046e-23  

Treatment-effects estimation                    Number of obs     =      9,913
Estimator      : regression adjustment
Outcome model  : linear
Treatment model: none
------------------------------------------------------------------------------
             |               Robust
      assets | Coefficient  std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ATE          |
       e401k |
  (Eligible  |
         vs  |
Not elig..)  |   7929.242   1216.273     6.52   0.000      5545.39    10313.09
-------------+----------------------------------------------------------------
POmean       |
       e401k |
Not eligi..  |   14094.12   863.7641    16.32   0.000     12401.18    15787.07
------------------------------------------------------------------------------



Contrasts of predictive margins                          Number of obs = 9,913
Model VCE: OLS

Expression: Linear prediction, predict()
1._at: e401k = 0
2._at: e401k = 1

------------------------------------------------
             |         df           F        P>F
-------------+----------------------------------
         _at |          1       34.39     0.0000
             |
 Denominator |       9889
------------------------------------------------

--------------------------------------------------------------
             |            Delta-method
             |   Contrast   std. err.     [95% conf. interval]
-------------+------------------------------------------------
         _at |
   (2 vs 1)  |   7929.242   1352.153      5278.746    10579.74
--------------------------------------------------------------
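
The two commands above return the same point estimate (7929.242). A minimal base-R sketch of why, assuming simulated data rather than assets3 (the names `sim`, `x`, `d`, and `y` are made up for illustration): regression adjustment fits the outcome model separately in each treatment arm and averages the difference in predictions over the whole sample, which is the same as predicting from one fully interacted regression with the treatment set to 1 and then to 0. The standard errors in the output above differ, at least in part because the reg call uses the default OLS VCE while teffects ra reports robust standard errors.

```r
# Regression adjustment by hand on simulated data (not assets3)
set.seed(42)
n <- 1000
x <- rnorm(n)
d <- rbinom(n, 1, plogis(0.5 * x))           # treatment depends on x
y <- 1 + 2 * d + x + 1.5 * d * x + rnorm(n)  # effect of d varies with x
sim <- data.frame(y, d, x)

# (a) Separate outcome regressions in each treatment arm,
#     then average the predicted difference over all observations
fit0 <- lm(y ~ x, data = sim, subset = d == 0)
fit1 <- lm(y ~ x, data = sim, subset = d == 1)
ate_separate <- mean(predict(fit1, newdata = sim) - predict(fit0, newdata = sim))

# (b) One fully interacted regression, predicting with d set to 1 and to 0
fit_int <- lm(y ~ d * x, data = sim)
ate_interacted <- mean(predict(fit_int, newdata = transform(sim, d = 1)) -
                       predict(fit_int, newdata = transform(sim, d = 0)))

c(separate = ate_separate, interacted = ate_interacted)
```

The two numbers agree up to floating-point error, which mirrors the teffects ra and reg plus margins comparison.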

DoubleML

The R package DoubleML implements the double/debiased machine learning framework of Chernozhukov et al. (2018). It provides functionality to estimate parameters in causal models using machine learning methods. The double machine learning framework consists of three key ingredients: Neyman orthogonality, high-quality machine learning estimation, and sample splitting.

They consider a partially linear model:

\[ y_i = \theta d_i + g_0(x_i) + \eta_i \]

\[ d_i = m_0(x_i) + v_i \]

This model is quite general, except that it does not allow \(d\) to interact with \(x\), so there is no heterogeneous treatment effect across \(x\). But DoubleML implements more than the partially linear model: it also allows for heterogeneous treatment effects, in models such as the interactive regression model.
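
For reference, the interactive regression model can be sketched as follows. The notation follows the DoubleML documentation as I understand it, with \(U\) and \(V\) as error terms that do not appear elsewhere in this post:

\[ Y = g_0(D, X) + U, \qquad E(U \mid X, D) = 0 \]

\[ D = m_0(X) + V, \qquad E(V \mid X) = 0 \]

\[ \theta_0 = E[g_0(1, X) - g_0(0, X)] \]

Because \(g_0\) takes \(D\) as an argument, the effect of a binary treatment, \(g_0(1, x) - g_0(0, x)\), can vary with \(x\); the ATE \(\theta_0\) averages it over the distribution of the covariates.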

The basic idea of DoubleML is to use any machine learning algorithm to estimate the outcome equation (\(l_0(X) = E(Y | X)\)) and the treatment equation (\(m_0(X) = E(D | X)\)), and then to compute the residuals \(\hat W=Y-\hat l_0(X)\) and \(\hat V = D - \hat m_0(X)\).

Then regress \(\hat W\) on \(\hat V\). Following the logic of the Frisch-Waugh-Lovell (FWL) theorem, the slope of this regression is \(\hat \theta\).
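
The FWL logic can be checked directly in the linear case. The sketch below uses simulated data (hypothetical variables, not assets3) and plain OLS for both nuisance regressions; DoubleML replaces these two regressions with machine learning fits and adds sample splitting.

```r
# FWL check on simulated data (hypothetical variables, not assets3)
set.seed(1)
n <- 500
x <- rnorm(n)
d <- 0.5 * x + rnorm(n)    # treatment depends on x
y <- 2 * d + x + rnorm(n)  # true theta = 2

# Coefficient on d from the full regression of y on d and x
theta_full <- unname(coef(lm(y ~ d + x))["d"])

# Partial x out of both y and d, then regress residual on residual
w_hat <- resid(lm(y ~ x))  # Y - l_hat(X)
v_hat <- resid(lm(d ~ x))  # D - m_hat(X)
theta_fwl <- unname(coef(lm(w_hat ~ v_hat))["v_hat"])

c(full = theta_full, fwl = theta_fwl)
```

Both coefficients are identical up to floating-point error, which is the FWL theorem at work.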

An important component here is to specify a Neyman-orthogonal score function. In the case of the partially linear regression (PLR) model, it can be the product of two residuals, the outcome residual and the treatment residual:

\[\psi (W; \theta, \eta) = (Y-D\theta -g(X))(D-m(X)) \]

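Here \(W = (Y, D, X)\) denotes the observed data and \(\eta = (g, m)\) the nuisance functions (not the residual \(\hat W\) or the error term \(\eta_i\) above). The estimator \(\hat \theta\) solves the equation that sets the sample mean of this score function to 0.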
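
When the score is linear in \(\theta\), that equation has a closed-form solution. As an illustration, take the “partialling out” variant of the PLR score, which DoubleML uses by default and which matches the residual-on-residual regression described earlier (with \(l(X) = E(Y | X)\)):

\[\psi (W; \theta, \eta) = (Y - l(X) - \theta (D - m(X)))(D-m(X)) \]

Setting its sample mean to zero at the estimated nuisance functions gives

\[ \hat \theta = \frac{\sum_i (D_i - \hat m(X_i))(Y_i - \hat l(X_i))}{\sum_i (D_i - \hat m(X_i))^2} \]

which is the no-intercept OLS slope from regressing the outcome residual on the treatment residual.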

The variance of this score function, scaled by the squared derivative of its mean with respect to \(\theta\), is used to estimate the variance of the DoubleML estimator.
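
The pieces above (sample splitting, residualizing, solving the score, and the score-based variance) can be put together in a few lines of base R. This is only a sketch on simulated data, assuming cubic-polynomial OLS as a stand-in for the lasso or random forest learners and the partialling-out form of the score, \(\psi = (\hat W - \theta \hat V)\hat V\); it is not how DoubleML or Stata implement the estimator internally.

```r
# Cross-fitted partialling out on simulated data (hypothetical, not assets3)
set.seed(123)
n <- 2000
x <- rnorm(n)
d <- 0.5 * x + rnorm(n)
y <- 2 * d + sin(x) + rnorm(n)  # true theta = 2
sim <- data.frame(y, d, x)

K <- 5
fold <- sample(rep(1:K, length.out = n))  # random fold assignment
w_hat <- v_hat <- numeric(n)

for (k in 1:K) {
  train <- sim[fold != k, ]
  test  <- sim[fold == k, ]
  # Nuisance fits on the other folds (cubic OLS stands in for an ML learner)
  fit_l <- lm(y ~ poly(x, 3), data = train)  # outcome model E(Y | X)
  fit_m <- lm(d ~ poly(x, 3), data = train)  # treatment model E(D | X)
  # Out-of-fold residuals
  w_hat[fold == k] <- test$y - predict(fit_l, newdata = test)
  v_hat[fold == k] <- test$d - predict(fit_m, newdata = test)
}

# Solve mean(psi) = 0 for theta: residual-on-residual slope
theta_hat <- sum(v_hat * w_hat) / sum(v_hat^2)

# Score-based standard error
psi <- (w_hat - theta_hat * v_hat) * v_hat
J   <- mean(v_hat^2)
se  <- sqrt(mean(psi^2) / J^2 / n)

c(theta = theta_hat, se = se)
```

With a simulated true effect of 2, `theta_hat` should land close to 2. The standard error is \(\sqrt{\hat\sigma^2 / n}\), with \(\hat\sigma^2 = \hat J^{-2}\, \frac{1}{n}\sum_i \hat\psi_i^2\) and \(\hat J = \frac{1}{n}\sum_i \hat V_i^2\).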

Now we try to replicate the CATE estimation in Stata using the R package DoubleML. I have not found a way to include factor variables in the lasso regression in R, so I included them as numeric variables.

The example for Stata’s cate command uses lasso for both the outcome and the treatment regressions, and then a random forest on the individual treatment effects to obtain the ATE. Here I try to replicate it in R.

library(haven)
library(DoubleML)
library(mlr3)
library(mlr3learners)
library(data.table)
library(dplyr)
# Strip Stata value labels, because DoubleML does not like them
data <- read_dta("https://www.stata-press.com/data/r19/assets3.dta") |>
  zap_labels()

# Not run: convert the categorical covariates to factors
# data <- data |>
#   mutate(incomecat = as.factor(incomecat),
#          pension = as.factor(pension),
#          married = as.factor(married),
#          twoearn = as.factor(twoearn),
#          ira = as.factor(ira),
#          ownhome = as.factor(ownhome))

# DoubleMLData expects a data.table
assets3 <- data.table::setDT(data)
dml_data = DoubleMLData$new(data = assets3,
                             y_col = "assets",
                             d_cols = "e401k",
                             x_cols = c("age", "educ", "incomecat", "pension",
                                        "married", "twoearn", "ira", "ownhome"))

print(dml_data)
================= DoubleMLData Object ==================


------------------ Data summary      ------------------
Outcome variable: assets
Treatment variable(s): e401k
Covariates: age, educ, incomecat, pension, married, twoearn, ira, ownhome
Instrument(s): 
Selection variable: 
No. Observations: 9913
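
# Suppress messages from the mlr3 package during fitting
lgr::get_logger("mlr3")$set_threshold("warn")
# glmnet with a fixed penalty (lambda = 1), used for both nuisance models
learner = lrn("regr.glmnet", lambda=1)
ml_l = learner$clone()  # outcome model E(Y | X)
ml_m = learner$clone()  # treatment model E(D | X)
set.seed(123)
obj_dml_plr = DoubleMLPLR$new(dml_data, ml_l=ml_l, ml_m=ml_m, n_folds=10)
obj_dml_plr$fit()
obj_dml_plr$summary()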
Estimates and significance testing of the effect of target variables
      Estimate. Std. Error t value Pr(>|t|)    
e401k      7388       1293   5.714  1.1e-08 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

If we use random forest on the two models:

# suppress messages from the mlr3 package during fitting
lgr::get_logger("mlr3")$set_threshold("warn")
learner = lrn("regr.ranger", num.trees=2000,  max.depth=5, min.node.size=2)
ml_l = learner$clone()
ml_m = learner$clone()
# learner = lrn("regr.glmnet")
# ml_g_sim = learner$clone()
# ml_m_sim = learner$clone()
set.seed(123)
obj_dml_plr = DoubleMLPLR$new(dml_data, ml_l=ml_l, ml_m=ml_m, n_folds=5)
obj_dml_plr$fit()
obj_dml_plr$summary()
Estimates and significance testing of the effect of target variables
      Estimate. Std. Error t value Pr(>|t|)    
e401k      9474       1376   6.883 5.88e-12 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

We can do the same in Stata:

clear all
webuse assets3
global catecovars age educ incomecat pension married twoearn ira ownhome
cate po (assets age educ incomecat pension married twoearn ira ownhome) (e401k),  omethod(rforest) tmethod(rforest) nolog  xfolds(5)
(Excerpt from Chernozhukov and Hansen (2004))



Conditional average treatment effects     Number of observations       = 9,913
Estimator:       Partialing out           Number of folds in cross-fit =     5
Outcome model:   Random forest            Number of outcome controls   =     8
Treatment model: Random forest            Number of treatment controls =     8
CATE model:      Random forest            Number of CATE variables     =     8

------------------------------------------------------------------------------
             |               Robust
      assets | Coefficient  std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ATE          |
       e401k |
  (Eligible  |
         vs  |
Not elig..)  |   8300.384   1174.524     7.07   0.000     5998.359    10602.41
-------------+----------------------------------------------------------------
POmean       |
       e401k |
Not eligi..  |   14002.82   840.8776    16.65   0.000     12354.73    15650.91
------------------------------------------------------------------------------

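The results are somewhat similar, but not exactly the same. The two packages probably differ in how they fit the random forests in the two models (for instance, the R code sets num.trees, max.depth, and min.node.size explicitly, while the Stata call uses its defaults), in ways that I am not aware of.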

AIPW

Stata’s cate also has an “aipw” (augmented inverse-probability weighting) estimator.

clear all
webuse assets3
global catecovars age educ incomecat pension married twoearn ira ownhome
cate aipw (assets $catecovars) (e401k),  nolog
(Excerpt from Chernozhukov and Hansen (2004))



Conditional average treatment effects     Number of observations       = 9,913
Estimator:       Augmented IPW            Number of folds in cross-fit =    10
Outcome model:   Linear lasso             Number of outcome controls   =     8
Treatment model: Logit lasso              Number of treatment controls =     8
CATE model:      Random forest            Number of CATE variables     =     8

------------------------------------------------------------------------------
             |               Robust
      assets | Coefficient  std. err.      z    P>|z|     [95% conf. interval]
-------------+----------------------------------------------------------------
ATE          |
       e401k |
  (Eligible  |
         vs  |
Not elig..)  |   7053.901   1181.863     5.97   0.000     4737.492     9370.31
-------------+----------------------------------------------------------------
POmean       |
       e401k |
Not eligi..  |   14146.76   858.6184    16.48   0.000     12463.89    15829.62
------------------------------------------------------------------------------
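
To make the AIPW estimator concrete, here is a minimal base-R sketch on simulated data (hypothetical variables, not assets3). It assumes a plain logit for the propensity score \(e(X)\) and OLS for the outcome models \(\mu_0(X)\) and \(\mu_1(X)\), with no lasso and no cross-fitting, so it illustrates the formula rather than what cate aipw does internally. The ATE is the sample mean of \(\phi = \mu_1(X) - \mu_0(X) + A(Y - \mu_1(X))/e(X) - (1-A)(Y - \mu_0(X))/(1 - e(X))\).

```r
# AIPW by hand on simulated data (hypothetical, not assets3)
set.seed(2024)
n <- 2000
x <- rnorm(n)
a <- rbinom(n, 1, plogis(0.5 * x))  # binary treatment
y <- 1 + 2 * a + x + rnorm(n)       # true ATE = 2
sim <- data.frame(y, a, x)

# Propensity score from a logit model
ps <- fitted(glm(a ~ x, family = binomial, data = sim))

# Outcome models fitted separately in each arm, predicted for everyone
mu0 <- predict(lm(y ~ x, data = sim, subset = a == 0), newdata = sim)
mu1 <- predict(lm(y ~ x, data = sim, subset = a == 1), newdata = sim)

# AIPW score: regression-adjustment term plus weighted residual corrections
phi <- (mu1 - mu0) + a * (y - mu1) / ps - (1 - a) * (y - mu0) / (1 - ps)

ate_hat <- mean(phi)
se_hat  <- sd(phi) / sqrt(n)
c(ate = ate_hat, se = se_hat)
```

The first term of `phi` is the regression-adjustment estimate; the two weighted residual terms correct it, which is what makes the estimator doubly robust.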

I’ll try to replicate the Stata AIPW estimate in R with the “npcausal” package:

library(npcausal)
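# Alternative SuperLearner libraries (not used here):
# SL.library <- c("SL.earth","SL.gam","SL.glmnet","SL.glm.interaction", "SL.mean","SL.ranger", "SL.xgboost")
# SL.library <- c("SL.glmnet")

# Lasso wrapper for SuperLearner: glmnet with alpha = 1
SL.lasso <- function(...) {
  SL.glmnet(..., alpha = 1)
}
SL.library <- c("SL.lasso")

Y <- data$assets
A <- data$e401k
W <- data[, c("age", "educ", "incomecat", "pension", "married", "twoearn", "ira", "ownhome")]
# Doubly robust (AIPW) estimate of the ATE with nsplits = 10 sample splits
aipw <- ate(y = Y, a = A, x = W, nsplits = 10, sl.lib = SL.library)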
     parameter       est        se     ci.ll     ci.ul pval
1      E{Y(0)} 14138.295  851.5868 12469.185 15807.405    0
2      E{Y(1)} 21131.594  893.7780 19379.789 22883.399    0
3 E{Y(1)-Y(0)}  6993.299 1179.0464  4682.368  9304.229    0

Note that, to match the Stata results, I included only the lasso (glmnet) learner in the SuperLearner library. The results are not exactly the same, but they are close.

Or we can use a similar package with more functionality: “AIPW”.

library(AIPW)
library(SuperLearner)

#SuperLearner libraries for outcome (Q) and exposure models (g)
#sl.lib <- c("SL.glmnet")

AIPW_sl <- AIPW$new(Y= Y,
                     A= A,
                     W= W, 
                     Q.SL.library = SL.library,
                     g.SL.library = SL.library,
                     k_split = 10,
                     verbose=FALSE)

AIPW_sl$fit()
AIPW_sl$summary()
AIPW_sl$result
                  Estimate        SE   95% LCL   95% UCL    N
Mean of Exposure 21626.890  874.3117 19913.239 23340.541 3682
Mean of Control  14055.914  847.7744 12394.276 15717.552 6231
Mean Difference   7570.976 1164.1694  5289.204  9852.748 9913
# library(ggplot2)
# AIPW_sl$plot.p_score()
# AIPW_sl$plot.ip_weights()