Semiparametrics: DML (Mediated) with Neural Nets - AGMM (Sequential) and AGMM2L2 (Simultaneous)
# ---- Limit BLAS/OpenMP threads BEFORE importing heavy libs ----
import os as os
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("VECLIB_MAXIMUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
# ---- Standard libs ----
import sys
import time
import platform
from pathlib import Path
# ---- Third-party ----
import numpy as np
from threadpoolctl import threadpool_limits
# Keep native libraries (BLAS/OpenMP) to 1 thread
threadpool_limits(1)
# ---- Local repo imports (adjust path if needed) ----
sys.path.append(str(Path.cwd() / "../../simulations"))
import dgps_mediated as dgps
import torch
import torch.nn as nn
from nnpiv.neuralnet.agmm import AGMM
from nnpiv.neuralnet.agmm2 import AGMM2L2
from nnpiv.semiparametrics import DML_mediated
# -----------------------
# Reproducibility helpers
# -----------------------
def seed_everything(seed: int = 123) -> None:
"""Set seeds for reproducibility."""
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# Keep torch to 1 thread too
try:
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
except Exception:
pass
seed_everything(123)
# -----------------------
# Resource print utility
# -----------------------
def print_resources():
"""Print basic compute resource info (CPU, GPU, library versions)."""
cpu_cores = os.cpu_count()
pyver = sys.version.split()[0]
npver = np.__version__
torchver = torch.__version__
if torch.cuda.is_available():
try:
gpu_name = torch.cuda.get_device_name(0)
except Exception:
gpu_name = "Unknown GPU"
gpu_info = f"CUDA: available — {gpu_name}"
else:
gpu_info = "CUDA: not available"
print("=== Compute resources ===")
print(f"Python: {pyver}")
print(f"NumPy: {npver}")
print(f"PyTorch: {torchver}")
print(f"CPU cores: {cpu_cores}")
print(gpu_info)
print("Thread caps (env):")
for k in ["OMP_NUM_THREADS","OPENBLAS_NUM_THREADS","MKL_NUM_THREADS",
"VECLIB_MAXIMUM_THREADS","NUMEXPR_NUM_THREADS"]:
print(f" {k}={os.environ.get(k, 'unset')}")
print(f"Platform: {platform.platform()}")
print("=========================\n")
# -----------------------
# Result formatter
# -----------------------
def summarize_dml_result(name: str, result, elapsed: float):
"""
Accepts result from .dml() and prints θ, SE, 95% CI when available.
Compatible with returns like (theta, var, ci) or (theta, var, ci, cov).
"""
if isinstance(result, tuple):
if len(result) == 3:
theta, var, ci = result
cov = None
elif len(result) == 4:
theta, var, ci, cov = result
else:
print(f"[{name}] time={elapsed:.2f}s — result={result}")
return
else:
print(f"[{name}] time={elapsed:.2f}s — result={result}")
return
theta = np.atleast_1d(theta).astype(float)
var = np.atleast_1d(var).astype(float)
se = np.sqrt(var)
ci = np.array(ci, dtype=float) if ci is not None else None
def fmt_arr(a):
return f"{float(a[0]):.4f}" if a.size == 1 else np.array2string(a, precision=4)
print(f"[{name}] time={elapsed:.2f}s")
print(f" theta: {fmt_arr(theta)}")
print(f" SE : {fmt_arr(se)}")
if ci is not None:
if ci.ndim == 1 and ci.size == 2:
print(f" 95% CI: [{ci[0]:.4f}, {ci[1]:.4f}]")
else:
print(f" 95% CI: {np.array2string(ci, precision=4)}")
if 'cov' in locals() and cov is not None:
print(f" (cov shape: {cov.shape})")
print("")
# -----------------------
# Print resources
# -----------------------
print_resources()
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
=== Compute resources ===
Python: 3.10.18
NumPy: 2.2.6
PyTorch: 2.5.0
CPU cores: 112
CUDA: not available
Thread caps (env):
OMP_NUM_THREADS=1
OPENBLAS_NUM_THREADS=1
MKL_NUM_THREADS=1
VECLIB_MAXIMUM_THREADS=1
NUMEXPR_NUM_THREADS=1
Platform: Linux-4.18.0-553.44.1.el8_10.x86_64-x86_64-with-glibc2.28
=========================
# =========================================================
# Data generation
# =========================================================
# Function dictionary (for reference):
# {'abs': 0, '2dpoly': 1, 'sigmoid': 2,
# 'sin': 3, 'frequent_sin': 4, 'abs_sqrt': 5, 'step': 6, '3dpoly': 7,
# 'linear': 8, 'rand_pw': 9, 'abspos': 10, 'sqrpos': 11, 'band': 12,
# 'invband': 13, 'steplinear': 14, 'pwlinear': 15, 'exponential': 16}
fn_number = 0
tau_fn = dgps.get_tau_fn(fn_number)
tauinv_fn = dgps.get_tauinv_fn(fn_number) # kept for parity with your code
W, Z, X, M, D, Y, tau_fn = dgps.get_data(2000, tau_fn)
# Ground-truth value for the target estimand (for log reference)
TRUE_PARAM = 4.05
print(f"=== Ground truth (for log reference) ===\nTrue parameter for E[Y(1,M(0))] ≈ {TRUE_PARAM:.2f}\n")
# =========================================================
# NN architecture helpers (dropout & width are configurable)
# =========================================================
p = 0.10
n_hidden = 100
def _get_learner(n_t: int) -> nn.Module:
return nn.Sequential(
nn.Dropout(p=p), nn.Linear(n_t, n_hidden), nn.LeakyReLU(),
nn.Dropout(p=p), nn.Linear(n_hidden, 1)
)
def _get_adversary(n_z: int) -> nn.Module:
return nn.Sequential(
nn.Dropout(p=p), nn.Linear(n_z, n_hidden), nn.LeakyReLU(),
nn.Dropout(p=p), nn.Linear(n_hidden, 1)
)
# =========================================================
# Model builders (dimensions inferred from data)
# =========================================================
def build_agmm_pair_for_mediated(M, X, W, Z):
"""
Build two AGMM models with correct input dims for the mediated setup.
Stage 1 (bridge on treated arm):
T = [M, X, W], Z = [M, X, Z]
Stage 2:
T = [X, W], Z = [X, Z]
"""
T1_dim = M.shape[1] + X.shape[1] + W.shape[1]
Z1_dim = M.shape[1] + X.shape[1] + Z.shape[1]
T2_dim = X.shape[1] + W.shape[1]
Z2_dim = X.shape[1] + Z.shape[1]
m1 = AGMM(_get_learner(T1_dim), _get_adversary(Z1_dim))
m2 = AGMM(_get_learner(T2_dim), _get_adversary(Z2_dim))
return m1, m2
def build_agmm2_for_mediated(M, X, W, Z):
"""For model1 (outcome bridge)."""
A_dim = M.shape[1] + X.shape[1] + W.shape[1]
B_dim = X.shape[1] + W.shape[1]
E_dim = M.shape[1] + X.shape[1] + Z.shape[1]
C_dim = X.shape[1] + Z.shape[1]
return AGMM2L2(
learnerh=_get_learner(B_dim),
learnerg=_get_learner(A_dim),
adversary1=_get_adversary(E_dim),
adversary2=_get_adversary(C_dim),
)
def build_agmm2_for_mediated_q1(M, X, W, Z):
"""For model_q1 (q-bridge)."""
A_prime_dim = X.shape[1] + W.shape[1] # (this goes to learnerg)
B_prime_dim = M.shape[1] + X.shape[1] + W.shape[1] # (this goes to learnerh)
D_prime_dim = X.shape[1] + Z.shape[1] # (this goes to adversary1)
C_prime_dim = M.shape[1] + X.shape[1] + Z.shape[1] # (this goes to adversary2)
return AGMM2L2(
learnerh=_get_learner(B_prime_dim),
learnerg=_get_learner(A_prime_dim),
adversary1=_get_adversary(D_prime_dim),
adversary2=_get_adversary(C_prime_dim),
)
=== Ground truth (for log reference) ===
True parameter for E[Y(1,M(0))] ≈ 4.05
# =========================================================
# 1) Sequential estimator (MR) with AGMM
# =========================================================
m1, m2 = build_agmm_pair_for_mediated(M, X, W, Z)
fitargs_seq = {
"n_epochs": 300, "bs": 100,
"learner_lr": 1e-4, "adversary_lr": 1e-4,
"learner_l2": 1e-3, "adversary_l2": 1e-4,
"adversary_norm_reg": 1e-3,
"device": DEVICE,
}
dml_agmm = DML_mediated(
Y, D, M, W, Z, X,
estimator="MR",
estimand="E[Y(1,M(0))]",
nn_1=[True, True], # use torch path for both bridge stages
nn_q1=[True, True], # and for q-models
model1=[m1, m2],
modelq1=[m2, m1], # your original ordering
n_folds=5, n_rep=1,
fitargs1=[fitargs_seq, fitargs_seq],
fitargsq1=[fitargs_seq, fitargs_seq],
opts={"lin_degree": 1, "burnin": 200},
)
t0 = time.perf_counter()
res_seq = dml_agmm.dml()
t1 = time.perf_counter()
summarize_dml_result("Sequential (MR) with AGMM", res_seq, t1 - t0)
# =========================================================
# 2) Simultaneous estimator (MR) with AGMM2L2
# =========================================================
agmm2_model_1 = build_agmm2_for_mediated(M, X, W, Z)
agmm2_model_q1 = build_agmm2_for_mediated_q1(M, X, W, Z)
fitargs_sim = {
"n_epochs": 600, "bs": 100,
"learner_lr": 1e-4, "adversary_lr": 1e-4,
"learner_l2": 1e-3, "adversary_l2": 1e-4,
"device": DEVICE,
}
opts_sim = {"burnin": 400}
dml2_agmm = DML_mediated(
Y, D, M, W, Z, X,
estimator="MR",
estimand="E[Y(1,M(0))]",
model1=agmm2_model_1, nn_1=True,
modelq1=agmm2_model_q1, nn_q1=True,
fitargs1=fitargs_sim,
fitargsq1=fitargs_sim,
n_folds=5, n_rep=1, opts=opts_sim,
)
t0 = time.perf_counter()
res_sim = dml2_agmm.dml()
t1 = time.perf_counter()
summarize_dml_result("Simultaneous (MR) with AGMM2L2", res_sim, t1 - t0)
Rep: 1
100%|██████████| 5/5 [03:17<00:00, 39.59s/it]
[Sequential (MR) with AGMM] time=197.93s
theta: 4.0745
SE : 5.2253
95% CI: [3.8455, 4.3035]
Rep: 1
100%|██████████| 5/5 [11:24<00:00, 136.81s/it]
[Simultaneous (MR) with AGMM2L2] time=684.06s
theta: 4.1246
SE : 5.2737
95% CI: [3.8935, 4.3557]