
Recent causal inference tools

In this post, I’ll use simulated data to see how a few recently developed causal inference packages work.

In other posts, I have discussed methods that involve machine learning, such as TMLE and causal forests. This time we’ll talk about nonparametric, influence-function-based methods and double machine learning.

Identification

First, suppose we have a “target” parameter in mind, which is some function of an unknown distribution \(P^*\), called a functional, say \(\psi^*(P^*)\). For example, the ATE, \(E(Y^1- Y^0)\).

The first goal is to write \(\psi^*(P^*)\) as \(\psi(P)\), where \(P\) is the distribution of the observed data and \(\psi\) is some known functional; this step is called identification. Once identified, the parameter can be estimated by \(\psi(\hat P)\).

For example, suppose we are interested in the ATE, \(E[Y^1 - Y^0]\). Identification works by deriving \(E(Y^a | X) = E(Y | X, A=a)\); that is, we go from the target parameter, which is counterfactual, to something estimable from the observed variables in the sample. This can only be done under assumptions such as ignorability, positivity, and consistency.
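
Written out, the derivation uses iterated expectation, ignorability, and consistency, in that order:

\[ E(Y^a) = E\{ E(Y^a | X) \} = E\{ E(Y^a | X, A=a) \} = E\{ E(Y | X, A=a) \}, \]

with positivity guaranteeing that the inner conditional expectation is well defined for every \(X\).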

In this case, \(\psi^* = E(Y^a)\), and \(\psi(P) = E \{ E(Y | X, A=a) \}\).

Nonparametric estimation in a nutshell

After identification, we need to find a way to construct a good estimator \(\hat \psi\) of \(\psi (P)\).

Suppose we have a sample of observations \(Z\), which in the causal inference setting can be \(Z = (A, W, Y)\) (treatment, covariates, and outcome).

We can use a parametric model, assuming we know what kind of distribution \(P\) is: we assume \(P=P_\theta\), and MLE then gives \(\hat\theta\) and hence \(\psi(P_{\hat\theta})\). This is what we do in parametric modeling. But it can be too restrictive; we usually have no way of knowing whether \(P_\theta\) is adequate. In other words, the model can be misspecified, and then the target parameter is estimated with bias.

If we don’t make any assumptions about the distribution, then we are in the nonparametric world. The natural way to approximate \(P\) is the empirical distribution \(\hat P\) (or flexible, e.g. machine learning, estimates of the relevant parts of \(P\)). Then \(\psi(\hat P)\) is the “plug-in” estimator; for example, we can use the sample mean to estimate the population mean. But plug-in estimators are generally not the best choice, because with flexible nuisance estimates they are typically not \(\sqrt n\)-consistent: there is usually a “plug-in bias”.

Another choice is a nonparametric estimator based on the influence function. The nice thing about estimators based on the efficient influence function is that they are often \(\sqrt n\)-consistent, meaning they converge to the truth at the \(\sqrt n\) rate: \(\hat \psi - \psi\) not only goes to 0 as \(n\) goes to infinity, it does so as fast as \(1/ \sqrt n\) goes to 0. Why does this matter? Because we have limited sample sizes, and since inference is based on asymptotics, the faster the error shrinks, the smaller the sample size we need.

If we have an estimator such that \[ \sqrt n (\hat \psi - \psi) = \sqrt n P_n(\phi) + o_P(1) \rightarrow N(0, var(\phi))\]

then this estimator is \(\sqrt n\)-consistent and asymptotically normal, and \(\phi\) is its influence function. The efficient influence function (EIF) is the one whose variance attains the lower bound (analogous to the Cramer-Rao lower bound in the parametric case).

For one piece of the ATE, the treated-arm mean with target parameter \(\psi = E \{ E(Y | X, A=1) \} = E(Y^1)\), the efficient influence function is \[\phi (Z; P) = \frac{A}{\pi (X)} \{ Y - \mu (X) \} + \mu(X) - \psi\] where \(\pi (X) = P(A=1 | X)\) is the propensity score and \(\mu (X) = E(Y |X, A=1)\) is the outcome regression.

The idea of an IF-based estimator is to use a so-called “one-step correction” to remove the plug-in bias. The bias is approximated by first deriving the influence function; the influence function then serves as the slope of the path from the plug-in estimate to the true parameter. This path can be non-linear, so the one-step correction may not be perfect, but it should be closer to the truth than the original plug-in estimate.
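
In symbols, the one-step estimator adds the sample average of the estimated influence function to the plug-in estimate:

\[ \hat \psi_{os} = \psi(\hat P) + P_n \{ \phi(Z; \hat P) \}. \]

Because the efficient influence function above already contains the \(-\psi\) term, the plug-in part cancels and, for our parameter, this reduces to the AIPW form shown next.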

For example, an IF-based bias-corrected estimator of \(E(Y^1)\) is \[\hat \psi = P_n \left[ \frac{A}{\hat \pi (X)} \{ Y - \hat \mu (X) \} + \hat \mu(X) \right]\] where \(P_n\) denotes the sample mean. This is the familiar AIPW estimator, which is known to be doubly robust. Since we know the influence function, we can also calculate its variance; the \(-\psi\) term is a constant, so it does not affect the variance, and the asymptotic variance of \(\hat \psi\) is \(var(\phi)/n\). In practice we estimate the nuisance parameters with any method, compute the influence-function values for each observation, take their sample mean as the point estimate, and use their sample variance (divided by \(n\)) for the standard error.
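
For the ATE itself we subtract the analogous term for the control arm:

\[ \hat \psi_{ATE} = P_n \left[ \frac{A}{\hat \pi (X)} \{ Y - \hat \mu_1 (X) \} + \hat \mu_1(X) - \frac{1-A}{1-\hat \pi (X)} \{ Y - \hat \mu_0 (X) \} - \hat \mu_0(X) \right], \]

where \(\hat \mu_a(X)\) estimates \(E(Y | X, A=a)\). This is exactly what the manual calculation below computes as IF.1 - IF.0.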

In this estimator, both \(\hat \mu\) and \(\hat \pi\) can be fit with any method, including machine learning.

The IF-based estimator only works if either \(\hat P\) (and hence \(\phi(\hat P)\)) is not too complex (this route relies on empirical process theory, e.g. Donsker conditions, which many flexible estimators do not satisfy), or we separate the data used to build \(\hat P\) from the data averaged by \(P_n\) to prevent overfitting. The latter is done by sample splitting, which is usually preferred because it needs fewer assumptions than the empirical process route. The idea is to split the sample, estimate the nuisance parameters (such as \(\mu\) and \(\pi\)) in one part, and construct the estimator in the other (often swapping roles and averaging, i.e. cross-fitting). With this construction, the EIF-based estimator can be shown to be \(\sqrt n\)-consistent and asymptotically normal, with asymptotic variance equal to the variance of the influence function.
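
As a concrete illustration, here is a minimal K-fold cross-fitting sketch for the AIPW estimator, with plain logistic regressions standing in for the nuisance learners (the function is my own toy version; “npcausal”’s ate() below does all of this, with SuperLearner, for you).

# minimal cross-fitted AIPW sketch; logistic regressions as nuisance fits
crossfit_aipw <- function(Y, A, X, K = 2) {
  n <- length(Y)
  fold <- sample(rep(1:K, length.out = n))
  phi <- numeric(n)
  for (k in 1:K) {
    train <- fold != k
    test  <- fold == k
    # nuisance models fit on the training folds only
    mu_fit <- glm(Y ~ ., data = data.frame(Y = Y, A = A, X)[train, ], family = binomial)
    pi_fit <- glm(A ~ ., data = data.frame(A = A, X)[train, ], family = binomial)
    # evaluated on the held-out fold
    mu1 <- predict(mu_fit, newdata = data.frame(A = 1, X)[test, ], type = "response")
    mu0 <- predict(mu_fit, newdata = data.frame(A = 0, X)[test, ], type = "response")
    ps  <- predict(pi_fit, newdata = X[test, , drop = FALSE], type = "response")
    # uncentered influence-function values for E(Y^1) - E(Y^0)
    phi[test] <- A[test] / ps * (Y[test] - mu1) + mu1 -
      ((1 - A[test]) / (1 - ps) * (Y[test] - mu0) + mu0)
  }
  est <- mean(phi)
  se  <- sd(phi) / sqrt(n)
  c(est = est, ci.ll = est - 1.96 * se, ci.ul = est + 1.96 * se)
}
# e.g. crossfit_aipw(data$Y, data$A, W) with the simulated data below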

Unfortunately, influence functions can be hard to derive for new problems. Edward Kennedy has done a lot of work on influence functions in causal inference (http://www.ehkennedy.com/). Here we use his “npcausal” package.

simulation 1

The simulated data is from https://migariane.github.io/TMLE.nb.html

library(npcausal)
library(boot)
library(MASS) 
library(SuperLearner)
library(tmle)
library(tidyverse)
library(ranger)
set.seed(123)
n=1000  
w1 <- rbinom(n, size=1, prob=0.5)
w2 <- rbinom(n, size=1, prob=0.65)
w3 <- round(runif(n, min=0, max=4), digits=3)
w4 <- round(runif(n, min=0, max=5), digits=3)
A  <- rbinom(n, size=1, prob= plogis(-0.4 + 0.2*w2 + 0.15*w3 + 0.2*w4 + 0.15*w2*w4))
Y.1 <- rbinom(n, size=1, prob= plogis(-1 + 1 -0.1*w1 + 0.3*w2 + 0.25*w3 + 0.2*w4 + 0.15*w2*w4))
Y.0 <- rbinom(n, size=1, prob= plogis(-1 + 0 -0.1*w1 + 0.3*w2 + 0.25*w3 + 0.2*w4 + 0.15*w2*w4))
Y <- Y.1*A + Y.0*(1 - A)
W<-data.frame(cbind(w1,w2,w3,w4))
data <-  data.frame(w1, w2, w3, w4, A, Y, Y.1, Y.0)
True_Psi <- mean(data$Y.1-data$Y.0);
cat(" True_Psi:", True_Psi)
##  True_Psi: 0.203

We first specify a library of learners for SuperLearner; “npcausal” also comes with a set of default learners.

SL.library <- c("SL.earth","SL.gam","SL.glm","SL.glmnet","SL.glm.interaction", "SL.mean","SL.ranger", "SL.xgboost")

plug-in estimator from outcome regression

SL.outcome<- SuperLearner(Y=data$Y, X=data %>% select(w1,w2,w3,w4,A),
                                     SL.library=SL.library, family="binomial")
SL.outcome.obs<- predict(SL.outcome, X=data %>% select(w1,w2,w3,w4,A))$pred
# predict the PO Y^1
SL.outcome.1<- predict(SL.outcome, newdata=data %>% select(w1,w2,w3,w4,A) %>% mutate(A=1))$pred
# predict the PO Y^0
SL.outcome.0<- predict(SL.outcome, newdata=data %>% select(w1,w2,w3,w4,A) %>% mutate(A=0))$pred

Here we fit the outcome regression with machine learning, then predict under treatment and under control to get \(\hat Y^1\) and \(\hat Y^0\). The sample average of the difference is our plug-in estimator.

SL.plugin<-mean(SL.outcome.1-SL.outcome.0)
SL.plugin
## [1] 0.2072614

It’s not too far from the truth in this case.

Q=data.frame(QAW=SL.outcome.obs, Q0W=SL.outcome.0,Q1W=SL.outcome.1)

AIPW with machine learning

We can use “npcausal” to do AIPW, or we can manually calculate it.

First we estimate the propensity score. Then, based on both \(\hat \mu\) and \(\hat \pi\), we compute the AIPW estimator, with a Wald confidence interval from the influence function.

set.seed(123)
SL.g<- SuperLearner(Y=data$A, X=data %>% select(w1,w2,w3,w4),
                    SL.library=SL.library, family="binomial")
g1W <- SL.g$SL.predict
g0W<- 1- g1W
IF.1<-((data$A/g1W)*(data$Y-Q$Q1W)+Q$Q1W)      # uncentered IF values for E(Y^1)
IF.0<-(((1-data$A)/g0W)*(data$Y-Q$Q0W)+Q$Q0W)  # uncentered IF values for E(Y^0)
IF<-IF.1-IF.0                                  # IF values for the ATE
aipw.1<-mean(IF.1);aipw.0<-mean(IF.0)
aipw.manual<-aipw.1-aipw.0
ci.lb<-mean(IF)-qnorm(.975)*sd(IF)/sqrt(length(IF))
ci.ub<-mean(IF)+qnorm(.975)*sd(IF)/sqrt(length(IF))
res.manual.aipw<-c(aipw.manual,ci.lb, ci.ub)
res.manual.aipw
## [1] 0.2121542 0.1514396 0.2728689

Now we use “npcausal”.

library(npcausal)
aipw<- ate(y=Y, a=A, x=W, nsplits=1, sl.lib=SL.library)
##      parameter       est         se     ci.ll     ci.ul pval
## 1      E{Y(0)} 0.5740347 0.02704671 0.5210231 0.6270462    0
## 2      E{Y(1)} 0.7860537 0.01599838 0.7546969 0.8174106    0
## 3 E{Y(1)-Y(0)} 0.2120191 0.03109855 0.1510659 0.2729722    0
aipw$res
##      parameter       est         se     ci.ll     ci.ul pval
## 1      E{Y(0)} 0.5740347 0.02704671 0.5210231 0.6270462    0
## 2      E{Y(1)} 0.7860537 0.01599838 0.7546969 0.8174106    0
## 3 E{Y(1)-Y(0)} 0.2120191 0.03109855 0.1510659 0.2729722    0

The results from “npcausal” are almost the same as the manual results. Note that here we did not do sample splitting (nsplits=1), so the nuisance models and the final average use the same data; sample splitting prevents overfitting of the nuisance estimates, which is needed for valid inference with flexible learners. “npcausal” does this for you if you specify the “nsplits” option.

aipw2<- ate(y=Y, a=A, x=W, nsplits=10)
##      parameter       est         se     ci.ll     ci.ul pval
## 1      E{Y(0)} 0.5720028 0.02882336 0.5155090 0.6284966    0
## 2      E{Y(1)} 0.7846837 0.01628671 0.7527618 0.8166057    0
## 3 E{Y(1)-Y(0)} 0.2126809 0.03276098 0.1484694 0.2768925    0

TMLE

We can use TMLE for ATE.

library(tmle)
TMLE<- tmle(Y=data$Y,A=data$A,W=W, family="binomial", Q.SL.library=SL.library, g.SL.library=SL.library)

TMLE$estimates$ATE
## $psi
## [1] 0.2125673
## 
## $var.psi
## [1] 0.001036017
## 
## $CI
## [1] 0.1494804 0.2756543
## 
## $pvalue
## [1] 3.999638e-11

DoubleML

The R package DoubleML implements the double/debiased machine learning framework of Chernozhukov et al. (2018). It provides functionality to estimate parameters in causal models with machine learning methods. The double machine learning framework consists of three key ingredients: Neyman orthogonality, high-quality machine learning estimation, and sample splitting.

They consider a partially linear model:

\[ y_i = \theta d_i + g_0(x_i) + \eta_i \]

\[ d_i = m_0(x_i) + v_i \]

This model is quite general, except that it does not allow interactions between \(d\) and \(x\), and therefore no heterogeneous treatment effects across \(x\). But “DoubleML” implements more than the partially linear model; it also allows for heterogeneous treatment effects through models such as the interactive regression model (a sketch appears after the PLR fit below).

The basic idea of DoubleML for the PLR is to use any machine learning algorithm to estimate the outcome equation (\(l_0(X) = E(Y | X)\)) and the treatment equation (\(m_0(X) = E(D | X)\)), and then compute the residuals \(\hat W=Y-\hat l_0(X)\) and \(\hat V = D - \hat m_0(X)\).

Then regress \(\hat W\) on \(\hat V\); by the Frisch-Waugh-Lovell (FWL) theorem, the coefficient is \(\hat \theta\).
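
To make this concrete, here is a rough in-sample version of the residual-on-residual regression on the simulation 1 data. It has no cross-fitting or orthogonal score, so it only illustrates the partialling-out idea, not DoubleML itself.

library(ranger)
# nuisance regressions: l_0(X) = E(Y | X) and m_0(X) = E(D | X)
fit_y <- ranger(Y ~ w1 + w2 + w3 + w4, data = data)
fit_d <- ranger(A ~ w1 + w2 + w3 + w4, data = data)
res_y <- data$Y - predict(fit_y, data = data)$predictions
res_d <- data$A - predict(fit_d, data = data)$predictions
# FWL: the slope of the outcome residuals on the treatment residuals is theta-hat
coef(lm(res_y ~ 0 + res_d))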

An important component here is to specify a Neyman-orthogonal score function. In the case of PLR, it can be the product of the two residuals:

\[\psi (W; \theta, \eta) = (Y-D\theta -g(X))(D-m(X)) \]

The estimator \(\hat \theta\) solves the estimating equation that sets the sample mean of this score function to 0.

And the variance of this score function (suitably scaled) is used to estimate the variance of the DoubleML estimator.
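
For the score written above, setting its sample mean to zero and solving for \(\theta\) gives a closed form, along with a sandwich-type variance:

\[ \hat\theta = \frac{P_n\{ (Y - \hat g(X)) (D - \hat m(X)) \}}{P_n\{ D (D - \hat m(X)) \}}, \qquad \widehat{var}(\hat\theta) = \frac{P_n\{ \psi(W; \hat\theta, \hat\eta)^2 \}}{n \, [ P_n\{ D (D - \hat m(X)) \} ]^2 }. \]

(DoubleML’s default “partialling out” score for the PLR is a minor variant built from the residuals \(Y - \hat l(X)\) and \(D - \hat m(X)\); the logic is the same, with the sample averages computed fold by fold.)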

library(DoubleML)
dml_data = DoubleMLData$new(data,
                             y_col = "Y",
                             d_cols = "A",
                             x_cols = c("w1","w2","w3","w4"))
print(dml_data)
## ================= DoubleMLData Object ==================
## 
## 
## ------------------ Data summary      ------------------
## Outcome variable: Y
## Treatment variable(s): A
## Covariates: w1, w2, w3, w4
## Instrument(s): 
## No. Observations: 1000
library(mlr3)
library(mlr3learners)
# surpress messages from mlr3 package during fitting
lgr::get_logger("mlr3")$set_threshold("warn")

learner = lrn("regr.ranger", num.trees=500,  max.depth=5, min.node.size=2)
ml_g = learner$clone()
ml_m = learner$clone()

learner = lrn("regr.glmnet")
ml_g_sim = learner$clone()
ml_m_sim = learner$clone()

set.seed(3141)
obj_dml_plr = DoubleMLPLR$new(dml_data, ml_g=ml_g, ml_m=ml_m)
obj_dml_plr$fit()
print(obj_dml_plr)
## ================= DoubleMLPLR Object ==================
## 
## 
## ------------------ Data summary      ------------------
## Outcome variable: Y
## Treatment variable(s): A
## Covariates: w1, w2, w3, w4
## Instrument(s): 
## No. Observations: 1000
## 
## ------------------ Score & algorithm ------------------
## Score function: partialling out
## DML algorithm: dml2
## 
## ------------------ Machine learner   ------------------
## ml_g: regr.ranger
## ml_m: regr.ranger
## 
## ------------------ Resampling        ------------------
## No. folds: 5
## No. repeated sample splits: 1
## Apply cross-fitting: TRUE
## 
## ------------------ Fit summary       ------------------
##  Estimates and significance testing of the effect of target variables
##   Estimate. Std. Error t value Pr(>|t|)    
## A   0.23271    0.03224   7.219 5.25e-13 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
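
The PLR estimate here (about 0.233) sits a bit above the true ATE of 0.203, which is not too surprising since the data were generated from a logistic model with a w2*w4 interaction rather than a partially linear one. DoubleML’s interactive regression model (IRM) targets the ATE directly and allows effect heterogeneity. Below is a minimal sketch, assuming the DoubleMLIRM interface mirrors DoubleMLPLR; the learner choices are illustrative, and the treatment learner must be a classifier.

# interactive regression model (IRM) for the ATE; sketch only
ml_g_irm = lrn("regr.ranger", num.trees=500, max.depth=5, min.node.size=2)
ml_m_irm = lrn("classif.ranger", predict_type="prob", num.trees=500, max.depth=5, min.node.size=2)
set.seed(3141)
obj_dml_irm = DoubleMLIRM$new(dml_data, ml_g=ml_g_irm, ml_m=ml_m_irm, score="ATE")
obj_dml_irm$fit()
print(obj_dml_irm)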

simulation 2

This simulation is based on the “AIPW” package, which includes a simulated dataset, “eager_sim_rct”. \(A\) denotes the binary treatment assignment and \(Y\) the binary outcome. \(W_g\) represents the covariate that affects the treatment assignment, which here is the eligibility stratum indicator, sampled with replacement from the EAGeR Trial. Similarly, \(W_Q\) is a set of baseline prognostic covariates, also sampled with replacement from the EAGeR Trial: the number of prior pregnancy losses, age, number of months trying to conceive prior to randomization, body mass index (weight (kg) / height (m)\(^2\)), and mean arterial blood pressure (denoted \(W_1, \dots, W_5\), respectively).

Data and the true risk difference:

library(AIPW)
data("eager_sim_rct")
glimpse(eager_sim_rct)
## Rows: 1,228
## Columns: 8
## $ sim_Y             <int> 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0…
## $ sim_Tx            <int> 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0…
## $ eligibility       <dbl> 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0…
## $ loss_num          <dbl> 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 2, 2, 1…
## $ age               <dbl> 21.93, 30.65, 30.96, 32.24, 31.49, 27.99, 31.60, 25.…
## $ time_try_pregnant <dbl> 2, 1, 0, 9, 1, 10, 8, 6, 1, 0, 1, 1, 3, 3, 0, 10, 10…
## $ BMI               <dbl> 19.51963, 25.56474, 20.75446, 21.65728, 24.27480, 19…
## $ meanAP            <dbl> 73.33333, 79.00000, 79.33333, 63.83333, 97.44444, 85…
attributes(eager_sim_rct)$true_rd
## [1] 0.1324569
data <- eager_sim_rct %>%
  rename(Y=sim_Y, A=sim_Tx) %>%
  tibble()

W <- data %>%
  select(-Y,-A)

plug-in estimator from outcome regression

SL.outcome<- SuperLearner(Y=data$Y, X=data %>% select(-Y),
                                     SL.library=SL.library, family="binomial")
SL.outcome.obs<- predict(SL.outcome, X=data %>% select(-Y))$pred
# predict the PO Y^1
SL.outcome.1<- predict(SL.outcome, newdata=data %>% select(-Y) %>% mutate(A=1))$pred
# predict the PO Y^0
SL.outcome.0<- predict(SL.outcome, newdata=data %>% select(-Y) %>% mutate(A=0))$pred

As before, we fit the outcome regression, predict under treatment and under control to get \(\hat Y^1\) and \(\hat Y^0\), and average the difference to get the plug-in estimator.

SL.plugin<-mean(SL.outcome.1-SL.outcome.0)
SL.plugin
## [1] 0.1206068
Q=data.frame(QAW=SL.outcome.obs, Q0W=SL.outcome.0,Q1W=SL.outcome.1)

AIPW with machine learning

set.seed(123)
SL.g<- SuperLearner(Y=data$A, X=data %>% select(-Y, -A),
                    SL.library=SL.library, family="binomial")
g1W <- SL.g$SL.predict
g0W<- 1- g1W
IF.1<-((data$A/g1W)*(data$Y-Q$Q1W)+Q$Q1W)
IF.0<-(((1-data$A)/g0W)*(data$Y-Q$Q0W)+Q$Q0W)
IF<-IF.1-IF.0
aipw.1<-mean(IF.1);aipw.0<-mean(IF.0)
aipw.manual<-aipw.1-aipw.0
ci.lb<-mean(IF)-qnorm(.975)*sd(IF)/sqrt(length(IF))
ci.ub<-mean(IF)+qnorm(.975)*sd(IF)/sqrt(length(IF))
res.manual.aipw<-c(aipw.manual,ci.lb, ci.ub)
res.manual.aipw
## [1] 0.13093060 0.06941035 0.19245086

Now we use “npcausal”.

library(npcausal)
aipw<- ate(y=data$Y, a=data$A, x=W, nsplits=1, sl.lib=SL.library)
##      parameter       est         se      ci.ll     ci.ul pval
## 1      E{Y(0)} 0.4463680 0.02281026 0.40165986 0.4910761    0
## 2      E{Y(1)} 0.5738675 0.02101274 0.53268257 0.6150525    0
## 3 E{Y(1)-Y(0)} 0.1274996 0.03086141 0.06701122 0.1879879    0

As before, the results from “npcausal” closely match the manual calculation. With nsplits=1 there is again no sample splitting, so we rerun with nsplits=10.

aipw2<- ate(y=data$Y, a=data$A, x=W, nsplits=10)
##      parameter       est         se      ci.ll     ci.ul pval
## 1      E{Y(0)} 0.4437742 0.02523783 0.39430808 0.4932404    0
## 2      E{Y(1)} 0.5755262 0.02312439 0.53020242 0.6208500    0
## 3 E{Y(1)-Y(0)} 0.1317520 0.03411870 0.06487934 0.1986247    0

DoubleML

library(DoubleML)
dml_data = DoubleMLData$new(as.data.frame(data),
                             y_col = "Y",
                             d_cols = "A",
                             x_cols = c("eligibility" , "loss_num","age", "time_try_pregnant","BMI","meanAP"))
print(dml_data)
## ================= DoubleMLData Object ==================
## 
## 
## ------------------ Data summary      ------------------
## Outcome variable: Y
## Treatment variable(s): A
## Covariates: eligibility, loss_num, age, time_try_pregnant, BMI, meanAP
## Instrument(s): 
## No. Observations: 1228
library(mlr3)
library(mlr3learners)
# surpress messages from mlr3 package during fitting
lgr::get_logger("mlr3")$set_threshold("warn")

learner = lrn("regr.ranger", num.trees=500,  max.depth=5, min.node.size=2)
ml_g = learner$clone()
ml_m = learner$clone()

learner = lrn("regr.glmnet")
ml_g_sim = learner$clone()
ml_m_sim = learner$clone()

set.seed(3141)
obj_dml_plr = DoubleMLPLR$new(dml_data, ml_g=ml_g, ml_m=ml_m)
obj_dml_plr$fit()
print(obj_dml_plr)
## ================= DoubleMLPLR Object ==================
## 
## 
## ------------------ Data summary      ------------------
## Outcome variable: Y
## Treatment variable(s): A
## Covariates: eligibility, loss_num, age, time_try_pregnant, BMI, meanAP
## Instrument(s): 
## No. Observations: 1228
## 
## ------------------ Score & algorithm ------------------
## Score function: partialling out
## DML algorithm: dml2
## 
## ------------------ Machine learner   ------------------
## ml_g: regr.ranger
## ml_m: regr.ranger
## 
## ------------------ Resampling        ------------------
## No. folds: 5
## No. repeated sample splits: 1
## Apply cross-fitting: TRUE
## 
## ------------------ Fit summary       ------------------
##  Estimates and significance testing of the effect of target variables
##   Estimate. Std. Error t value Pr(>|t|)    
## A   0.12505    0.03292   3.798 0.000146 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1