"""
This module provides implementations of adversarial generalized method of moments (AGMM) estimators using neural networks.
Classes:
_BaseAGMM: Base class for AGMM models.
_BaseSupLossAGMM: Base class for AGMM models with supervised loss.
AGMM: Adversarial Generalized Method of Moments estimator.
KernelLayerMMDGMM: AGMM with kernel layer using Maximum Mean Discrepancy.
CentroidMMDGMM: AGMM with centroid-based Maximum Mean Discrepancy.
KernelLossAGMM: AGMM with kernel loss.
MMDGMM: AGMM with Maximum Mean Discrepancy.
"""
# Licensed under the MIT License.
import os
import numpy as np
import tempfile
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from nnpiv.neuralnet.oadam import OAdam
from nnpiv.neuralnet.rbflayer import RBF
# TODO: This epsilon is used only because pytorch 1.5 has an instability in torch.cdist
# when the input distance is close to zero, due to instability of the square root in
# automatic differentiation. It should be removed once pytorch fixes the instability.
# It can be set to 0 if using pytorch 1.4.0.
EPSILON = 1e-2
# Module-wide default device: CUDA when torch reports it as available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def add_weight_decay(net, l2_value, skip_list=()):
    """Build optimizer parameter groups with selective weight decay.

    Parameters
    ----------
    net : module whose ``named_parameters()`` are split into groups
    l2_value : weight decay assigned to the penalized group
    skip_list : parameter names that must never be penalized

    Returns
    -------
    A two-element list of parameter-group dicts: first the un-penalized
    group (weight decay 0), then the group penalized with ``l2_value``.
    Parameters with ``requires_grad=False`` appear in neither group.
    """
    penalized, exempt = [], []
    for name, param in net.named_parameters():
        if param.requires_grad:
            # One-dimensional tensors, biases and explicitly skipped names are exempt.
            is_exempt = (param.dim() == 1
                         or name.endswith(".bias")
                         or name in skip_list)
            (exempt if is_exempt else penalized).append(param)
    return [
        {'params': exempt, 'weight_decay': 0.},
        {'params': penalized, 'weight_decay': l2_value},
    ]
def _kernel(x, y, basis_func, sigma):
    """Evaluate ``basis_func`` on the scaled pairwise distances between ``x`` and ``y``.

    ``y`` is shifted by the module-level ``EPSILON`` before ``torch.cdist`` is
    called, and the resulting distance matrix is multiplied by ``|sigma|``
    (with the usual broadcasting) before being passed to ``basis_func``.
    """
    shifted_y = y + EPSILON
    scaled_dist = torch.cdist(x, shifted_y) * torch.abs(sigma)
    return basis_func(scaled_dist)
class _BaseAGMM:
    """
    Base class for AGMM models.

    Subclasses are expected to set ``self.learner``, ``self.adversary`` and
    ``self.skip_list`` before calling ``_pretrain``.

    Methods:
        _pretrain: Prepares the variables required to begin training.
        predict: Predicts outcomes using the fitted AGMM model.
    """

    def _pretrain(self, Z, T, Y,
                  learner_l2, adversary_l2, adversary_norm_reg,
                  learner_lr, adversary_lr, n_epochs, bs, train_learner_every, train_adversary_every,
                  warm_start, logger, model_dir, device, verbose, add_sample_inds=False):
        """ Prepares the variables required to begin training.

        Parameters
        ----------
        Z, T, Y : instruments, treatments and outcome; tensors are detached and moved to CPU
        learner_l2, adversary_l2 : weight decay for the learner and the adversary optimizers
        adversary_norm_reg : unused here; accepted for signature compatibility
        learner_lr, adversary_lr : learning rates of the two OAdam optimizers
        n_epochs : number of epochs, stored on the instance for later use by ``predict``
        bs : batch size of the training data loader
        train_learner_every, train_adversary_every : unused here; accepted for signature compatibility
        warm_start : if False, every submodule exposing ``reset_parameters`` is re-initialized
        logger : if not None, a ``SummaryWriter`` is created and stored in ``self.writer``
        model_dir : folder inside which a temporary directory for the epoch checkpoints is created
        device : training device; if None, the device of ``Z`` when it is a tensor, else ``DEVICE``
        verbose : verbosity level, stored in ``self.verbose``
        add_sample_inds : if True, the dataset also yields the row index of each sample

        Returns
        -------
        The (CPU) versions of ``Z``, ``T`` and ``Y``.
        """
        self.verbose = verbose

        if device is None:
            device = Z.device if isinstance(Z, torch.Tensor) else DEVICE
        self.device = device

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
        os.makedirs(model_dir, exist_ok=True)
        # Checkpoints live in a temporary sub-directory that is removed together
        # with the TemporaryDirectory object held on the instance.
        self.tempdir = tempfile.TemporaryDirectory(dir=model_dir)
        self.model_dir = self.tempdir.name

        self.n_epochs = n_epochs

        def to_cpu(x):
            return x.detach().cpu() if isinstance(x, torch.Tensor) else x

        Z, T, Y = map(to_cpu, (Z, T, Y))

        if add_sample_inds:
            self.train_ds = TensorDataset(
                Z, T, Y, torch.tensor(np.arange(Y.shape[0])))
        else:
            self.train_ds = TensorDataset(Z, T, Y)

        # The data stays on CPU, so pinning only pays off when training on CUDA
        pin = isinstance(device, torch.device) and device.type == "cuda"
        self.train_dl = DataLoader(self.train_ds, batch_size=bs, shuffle=True,
                                   pin_memory=pin)

        self.learner = self.learner.to(self.device)
        self.adversary = self.adversary.to(self.device)

        if not warm_start:
            def reset(module):
                if hasattr(module, 'reset_parameters'):
                    module.reset_parameters()

            self.learner.apply(reset)
            self.adversary.apply(reset)

        beta1 = 0.
        self.optimizerD = OAdam(add_weight_decay(self.learner, learner_l2),
                                lr=learner_lr, betas=(beta1, .01))
        self.optimizerG = OAdam(
            add_weight_decay(self.adversary, adversary_l2,
                             skip_list=self.skip_list),
            lr=adversary_lr, betas=(beta1, .01))

        if logger is not None:
            self.writer = SummaryWriter()

        return Z, T, Y

    def _predict_epoch(self, inputs, epoch):
        """Load the learner saved after ``epoch`` and evaluate it on ``inputs``."""
        path = os.path.join(self.model_dir, f"epoch{epoch}")
        # NOTE: weights_only=False unpickles a full module object; only use this on
        # checkpoints written by fit, never on files from an untrusted source.
        net = torch.load(path, map_location=DEVICE,
                         weights_only=False).to(DEVICE).eval()
        with torch.no_grad():
            return net(inputs).detach().cpu().numpy()

    def predict(self, T, model='avg', burn_in=0, alpha=None):
        """
        Parameters
        ----------
        T : treatments
        model : one of ('avg', 'final') or an integer epoch index; whether to use an
            average of the per-epoch models, the final one, or the one of a given epoch
        burn_in : discard the first "burn_in" epochs when doing averaging
        alpha : if not None but a float, then it also returns the a/2 and 1-a/2, percentile of
            the predictions across different epochs (proxy for a confidence interval)

        Returns
        -------
        A numpy array of predictions, or, when ``model='avg'`` and ``alpha`` is given,
        the tuple (mean, lower percentile, upper percentile).

        Raises
        ------
        ValueError : if ``model`` is not supported, or if ``burn_in`` leaves no epoch to average
        """
        # Convert the input once instead of once per loaded checkpoint
        if isinstance(T, torch.Tensor) and T.device == DEVICE:
            inputs = T
        else:
            inputs = torch.as_tensor(T, dtype=torch.float32, device=DEVICE)

        if isinstance(model, (int, np.integer)):
            return self._predict_epoch(inputs, model)

        if model == 'final':
            return self._predict_epoch(inputs, self.n_epochs - 1)

        if model == 'avg':
            epochs = range(burn_in, self.n_epochs)
            if len(epochs) == 0:
                raise ValueError(
                    f"burn_in={burn_in} leaves no epochs to average "
                    f"(n_epochs={self.n_epochs})")
            preds = np.array([self._predict_epoch(inputs, i) for i in epochs])
            mean = np.mean(preds, axis=0)
            if alpha is None:
                return mean
            lower = np.percentile(preds, 100 * alpha / 2, axis=0)
            upper = np.percentile(preds, 100 * (1 - alpha / 2), axis=0)
            return mean, lower, upper

        raise ValueError(
            f"model must be 'avg', 'final' or an integer epoch, got {model!r}")
class _BaseSupLossAGMM(_BaseAGMM):
    """
    Base class for AGMM models with supervised loss.

    Methods:
        fit: Fits the AGMM model with supervised loss to the provided data.
    """

    def fit(self, Z, T, Y,
            learner_l2=1e-3, adversary_l2=1e-4, adversary_norm_reg=1e-3,
            learner_lr=0.001, adversary_lr=0.001, n_epochs=100, bs=100, train_learner_every=1, train_adversary_every=1,
            ols_weight=0., warm_start=False, logger=None, model_dir='.', device=None, verbose=0):
        """
        Parameters
        ----------
        Z : instruments
        T : treatments
        Y : outcome
        learner_l2, adversary_l2 : l2_regularization of parameters of learner and adversary
        adversary_norm_reg : adversary norm regularization weight
        learner_lr : learning rate of the optimizer for learner
        adversary_lr : learning rate of the optimizer for adversary
        n_epochs : how many passes over the data
        bs : batch size
        train_learner_every : the learner is updated on every batch whose index is a multiple of this
        train_adversary_every : the adversary is updated on every batch whose index is a multiple of this
        ols_weight : weight on OLS (square loss) objective
        warm_start : if False then network parameters are initialized at the beginning, otherwise we start
            from their current weights
        logger : a function that takes as input (learner, adversary, epoch, writer) and is called after every epoch
            Supposed to be used to log the state of the learning.
        model_dir : folder where to store the learned models after every epoch
        device : device to train on, forwarded to ``_pretrain``
        verbose : if greater than 0, the epoch number is printed at the start of every epoch

        Returns
        -------
        self
        """
        Z, T, Y = self._pretrain(
            Z, T, Y,
            learner_l2, adversary_l2, adversary_norm_reg,
            learner_lr, adversary_lr, n_epochs, bs,
            train_learner_every, train_adversary_every,
            warm_start, logger, model_dir, device, verbose)

        has_logger = logger is not None

        for epoch in range(n_epochs):
            if self.verbose > 0:
                print("Epoch #", epoch, sep="")

            for step, (inst, treat, outcome) in enumerate(self.train_dl):
                inst = inst.to(self.device, non_blocking=True)
                treat = treat.to(self.device, non_blocking=True)
                outcome = outcome.to(self.device, non_blocking=True)

                # Learner update: minimize the moment violation found by the adversary
                if step % train_learner_every == 0:
                    self.learner.train()
                    residual = outcome - self.learner(treat)
                    critic = self.adversary(inst)
                    D_loss = torch.mean(residual * critic) \
                        + ols_weight * torch.mean(residual**2)
                    self.optimizerD.zero_grad()
                    D_loss.backward()
                    self.optimizerD.step()
                    self.learner.eval()

                # Adversary update: maximize the moment violation, with penalties
                if step % train_adversary_every == 0:
                    self.adversary.train()
                    residual = outcome - self.learner(treat)
                    if self.adversary_reg:
                        critic, penalty = self.adversary(inst, reg=True)
                    else:
                        critic, penalty = self.adversary(inst), 0
                    G_loss = - torch.mean(residual * critic) \
                        + torch.mean(critic**2)
                    G_loss = G_loss + adversary_norm_reg * penalty
                    self.optimizerG.zero_grad()
                    G_loss.backward()
                    self.optimizerG.step()
                    self.adversary.eval()

            # One checkpoint of the whole learner per epoch
            checkpoint = os.path.join(self.model_dir, "epoch{}".format(epoch))
            torch.save(self.learner, checkpoint)

            if has_logger:
                logger(self.learner, self.adversary, epoch, self.writer)

        if has_logger:
            self.writer.flush()
            self.writer.close()

        return self
class AGMM(_BaseSupLossAGMM):
    """
    Adversarial Generalized Method of Moments estimator.

    Parameters:
        learner : a pytorch neural net module for the learner.
        adversary : a pytorch neural net module for the adversary.
    """

    def __init__(self, learner, adversary):
        self.learner = learner
        self.adversary = adversary
        # whether we have a norm penalty for the adversary
        self.adversary_reg = False
        # which adversary parameters to not ell2 penalize
        self.skip_list = []
class KernelLayerMMDGMM(_BaseSupLossAGMM):
    """
    AGMM with kernel layer using Maximum Mean Discrepancy.

    Parameters:
        learner : a pytorch neural net module for the learner.
        adversary_g : a pytorch neural net module for the g function of the adversary.
        g_features : the number of output features of g.
        n_centers : the number of centers to use in the kernel layer.
        kernel : the kernel function.
        centers : numpy array containing the initial value of the centers in the g(Z) space.
        sigmas : numpy array containing the initial value of the sigma for each center.
        trainable : whether to train the centers and the sigmas.
    """

    def __init__(self, learner, adversary_g, g_features,
                 n_centers, kernel, centers=None, sigmas=None, trainable=True):

        class Adversary(torch.nn.Module):
            """Test function ``beta(rbf(g(z)))`` with an optional norm penalty."""

            def __init__(self, g, g_features, n_centers, basis_func,
                         centers=None, sigmas=None, trainable=True):
                super().__init__()
                self.g = g
                self.rbf = RBF(g_features, n_centers, basis_func,
                               centres=centers, sigmas=sigmas, trainable=trainable)
                self.beta = nn.Linear(n_centers, 1)

            def forward(self, x, reg=False):
                """Return the test function values, plus the norm penalty when ``reg`` is True."""
                test = self.beta(self.rbf(self.g(x)))
                if not reg:
                    return test
                beta = self.beta.weight
                # Kernel matrix of the centers against themselves, symmetrized
                K = self.rbf(self.rbf.centres + EPSILON)
                K = (K + K.T) / 2
                rkhs_norm = (beta @ K @ beta.T)[0][0]
                return test, rkhs_norm

        self.learner = learner
        self.adversary = Adversary(adversary_g, g_features, n_centers,
                                   kernel, centers=centers, sigmas=sigmas, trainable=trainable)
        # whether we have a norm penalty for the adversary
        self.adversary_reg = True
        # which adversary parameters to not ell2 penalize
        self.skip_list = ['rbf.centres', 'beta.weight']
class CentroidMMDGMM(_BaseSupLossAGMM):
    """
    AGMM with centroid-based Maximum Mean Discrepancy.

    Parameters:
        learner : a pytorch neural net module for the learner.
        adversary_g : a pytorch neural net module for the g function of the adversary.
        kernel : the kernel function.
        centers : numpy array containing the initial value of the centers in the Z space.
        sigma : precision of the kernel; either a scalar or an array, which is
            reshaped to a single row and broadcast against the kernel matrices.
    """

    def __init__(self, learner, adversary_g,
                 kernel, centers, sigma):

        class Adversary(torch.nn.Module):
            """Test function ``beta(K(g(z), g(centers)))`` with an optional norm penalty."""

            def __init__(self, g, basis_func, centers, sigma):
                super().__init__()
                self.g = g
                # The centers are stored as a frozen parameter so they follow .to(device)
                self.centers = nn.Parameter(
                    torch.Tensor(centers), requires_grad=False)
                self.basis_func = basis_func
                if hasattr(sigma, '__len__'):
                    self.init_sigma = sigma.reshape(1, -1)
                else:
                    self.init_sigma = sigma
                self.sigma = nn.Parameter(self._initial_sigma())
                self.beta = nn.Linear(centers.shape[0], 1)
                self.reset_parameters()

            def _initial_sigma(self):
                """Return a fresh tensor holding the initial value of sigma."""
                if hasattr(self.init_sigma, '__len__'):
                    return torch.Tensor(self.init_sigma)
                # float() keeps an int or numpy scalar sigma in the default floating
                # dtype; an integer tensor cannot be a trainable parameter.
                return torch.tensor(float(self.init_sigma))

            def reset_parameters(self):
                """Restore sigma to the value it was constructed with."""
                self.sigma.data = self._initial_sigma().to(self.sigma.device)

            def forward(self, x, reg=False):
                """Return the test function values, plus the norm penalty when ``reg`` is True."""
                x1, x2 = self.g(x), self.g(self.centers)
                K12 = _kernel(x1, x2, self.basis_func, self.sigma)
                test = self.beta(K12)
                if reg:
                    K22 = _kernel(x2, x2, self.basis_func, self.sigma)
                    # (K22 + K22.T) / 2 symmetrizes the kernel matrix of the centers
                    rkhs_reg = (self.beta.weight @ (K22 + K22.T) @
                                self.beta.weight.T)[0][0] / 2
                    return test, rkhs_reg
                return test

        self.learner = learner
        self.adversary = Adversary(
            adversary_g, kernel, centers, sigma=sigma)
        # whether we have a norm penalty for the adversary
        self.adversary_reg = True
        # which adversary parameters to not ell2 penalize
        self.skip_list = ['beta.weight']
class KernelLossAGMM(_BaseAGMM):
    """
    AGMM with kernel loss.

    Parameters:
        learner : a pytorch neural net module for the learner.
        adversary_g : a pytorch neural net module for the g function of the adversary.
        kernel : the kernel function.
        sigma : precision of the kernel; either a scalar or an array, which is
            reshaped to a single row and broadcast against the kernel matrix.
    """

    def __init__(self, learner, adversary_g, kernel, sigma):

        class Adversary(torch.nn.Module):
            """Kernel ``K(g(z1), g(z2))`` with a trainable precision sigma."""

            def __init__(self, g, basis_func, sigma):
                super().__init__()
                self.g = g
                self.basis_func = basis_func
                if hasattr(sigma, '__len__'):
                    self.init_sigma = sigma.reshape(1, -1)
                else:
                    self.init_sigma = sigma
                self.sigma = nn.Parameter(self._initial_sigma())
                self.reset_parameters()

            def _initial_sigma(self):
                """Return a fresh tensor holding the initial value of sigma."""
                if hasattr(self.init_sigma, '__len__'):
                    return torch.Tensor(self.init_sigma)
                # float() keeps an int or numpy scalar sigma in the default floating
                # dtype; an integer tensor cannot be a trainable parameter.
                return torch.tensor(float(self.init_sigma))

            def reset_parameters(self):
                """Restore sigma to the value it was constructed with."""
                self.sigma.data = self._initial_sigma().to(self.sigma.device)

            def forward(self, x1, x2):
                return _kernel(self.g(x1), self.g(x2), self.basis_func, self.sigma)

        self.learner = learner
        self.adversary = Adversary(adversary_g, kernel, sigma)
        # which adversary parameters to not ell2 penalize
        self.skip_list = []

    def fit(self, Z, T, Y,
            learner_l2=1e-3, adversary_l2=1e-4,
            learner_lr=0.001, adversary_lr=0.001, n_epochs=100, bs=100, train_learner_every=1, train_adversary_every=1,
            ols_weight=0.0, warm_start=False, logger=None, model_dir='.', device=None, verbose=0):
        """
        Parameters
        ----------
        Z : instruments
        T : treatments
        Y : outcome
        learner_l2, adversary_l2 : l2_regularization of parameters of learner and adversary
        learner_lr : learning rate of the optimizer for learner
        adversary_lr : learning rate of the optimizer for adversary
        n_epochs : how many passes over the data
        bs : batch size
        train_learner_every : the learner is updated on every batch whose index is a multiple of this
        train_adversary_every : the adversary is updated on every batch whose index is a multiple of this
        ols_weight : weight on OLS (square loss) objective
        warm_start : whether to keep the current weights instead of resetting them
        logger : a function that takes as input (learner, adversary, epoch, writer) and is called after every epoch
            Supposed to be used to log the state of the learning.
        model_dir : folder where to store the learned models after every epoch
        device : device to train on, forwarded to ``_pretrain``
        verbose : if greater than 0, the epoch number is printed at the start of every epoch

        Returns
        -------
        self
        """
        Z, T, Y = self._pretrain(Z, T, Y,
                                 learner_l2, adversary_l2, 0,
                                 learner_lr, adversary_lr, n_epochs, bs, train_learner_every, train_adversary_every,
                                 warm_start, logger, model_dir, device, verbose)

        # Second, independently shuffled pass over the same data: every step pairs
        # a batch from each loader to evaluate the kernel between two samples.
        train_dl2 = DataLoader(self.train_ds, batch_size=bs, shuffle=True)

        for epoch in range(n_epochs):
            # Honor the verbose flag, as _BaseSupLossAGMM.fit does
            if self.verbose > 0:
                print("Epoch #", epoch, sep="")

            for it, (batch1, batch2) in enumerate(zip(self.train_dl, train_dl2)):
                zb1, xb1, yb1 = (t.to(self.device, non_blocking=True)
                                 for t in batch1)
                zb2, xb2, yb2 = (t.to(self.device, non_blocking=True)
                                 for t in batch2)

                if it % train_learner_every == 0:
                    self.learner.train()
                    psi1 = yb1 - self.learner(xb1)
                    psi2 = yb2 - self.learner(xb2)
                    kernel = self.adversary(zb1, zb2)
                    # NOTE(review): normalizes by bs**2 even when the last batch of an
                    # epoch holds fewer than bs samples - confirm this is intended.
                    D_loss = psi1.T @ kernel @ psi2 / (bs**2)
                    D_loss += ols_weight * \
                        (torch.mean(psi1**2) + torch.mean(psi2**2)) / 2
                    self.optimizerD.zero_grad()
                    D_loss.backward()
                    self.optimizerD.step()
                    self.learner.eval()

                if it % train_adversary_every == 0:
                    self.adversary.train()
                    psi1 = yb1 - self.learner(xb1)
                    psi2 = yb2 - self.learner(xb2)
                    kernel = self.adversary(zb1, zb2)
                    G_loss = - psi1.T @ kernel @ psi2 / (bs**2)
                    self.optimizerG.zero_grad()
                    G_loss.backward()
                    self.optimizerG.step()
                    self.adversary.eval()

            torch.save(self.learner, os.path.join(
                self.model_dir, "epoch{}".format(epoch)))

            if logger is not None:
                logger(self.learner, self.adversary, epoch, self.writer)

        if logger is not None:
            self.writer.flush()
            self.writer.close()

        return self
class MMDGMM(_BaseAGMM):
    """
    AGMM with Maximum Mean Discrepancy.

    Parameters:
        learner : a pytorch neural net module for the learner.
        adversary_g : a pytorch neural net module for the g function of the adversary.
        n_samples : number of samples; the adversary keeps one coefficient per
            sample, looked up by the row index of the sample in the training data.
        kernel : the kernel function.
        sigma : precision of the kernel; either a scalar shared by all samples or
            an array that is reshaped to a single row and indexed by sample.
    """

    def __init__(self, learner, adversary_g, n_samples, kernel, sigma):

        class Adversary(torch.nn.Module):
            """Kernel expansion test function with one coefficient per training sample."""

            def __init__(self, g, n_samples, basis_func, sigma):
                super().__init__()
                self.g = g
                self.basis_func = basis_func
                if hasattr(sigma, '__len__'):
                    self.init_sigma = sigma.reshape(1, -1)
                else:
                    self.init_sigma = sigma
                self.sigma = nn.Parameter(self._initial_sigma())
                # Allocated uninitialized; reset_parameters fills it in
                self.beta = nn.Parameter(torch.Tensor(n_samples, 1))
                self.reset_parameters()

            def _initial_sigma(self):
                """Return a fresh tensor holding the initial value of sigma."""
                if hasattr(self.init_sigma, '__len__'):
                    return torch.Tensor(self.init_sigma)
                # float() keeps an int or numpy scalar sigma in the default floating
                # dtype; an integer tensor cannot be a trainable parameter.
                return torch.tensor(float(self.init_sigma))

            def reset_parameters(self):
                """Restore sigma to its initial value and redraw beta uniformly."""
                self.sigma.data = self._initial_sigma().to(self.sigma.device)
                stdv = 1. / np.sqrt(self.beta.size(0))
                nn.init.uniform_(self.beta, -stdv, stdv)

            def _sigma_for(self, ids):
                """Return the precisions of the samples ``ids`` (the shared value for a scalar sigma)."""
                if self.sigma.dim() == 0:
                    # A scalar sigma cannot be indexed per sample; it applies to all of them
                    return self.sigma
                return self.sigma[:, ids]

            def _sym_kernel(self, xa, xb, ida, idb):
                """Average the kernel of (xa, xb) with the transposed kernel of (xb, xa)."""
                K = _kernel(xa, xb, self.basis_func, self._sigma_for(idb)) / 2
                K += _kernel(xb, xa, self.basis_func,
                             self._sigma_for(ida)).T / 2
                return K

            def forward(self, x1, x2, x3, id1, id2, id3, reg=False):
                """Evaluate the test function on ``x1``.

                ``x2`` and ``x3`` are subsamples of the instruments and ``id1``, ``id2``,
                ``id3`` the row indices of the three batches. Returns the test
                function values and, when ``reg`` is True, also the two penalty terms.
                """
                x1, x2 = self.g(x1), self.g(x2)
                K12 = self._sym_kernel(x1, x2, id1, id2)
                # Rescale the subsample sum to the size of the full coefficient vector
                ratio2 = self.beta.size(0) / id2.shape[0]
                test = K12 @ self.beta[id2] * ratio2
                if reg:
                    x3 = self.g(x3)
                    K31 = self._sym_kernel(x3, x1, id3, id1)
                    K32 = self._sym_kernel(x3, x2, id3, id2)
                    ratio3 = self.beta.size(0) / id3.shape[0]
                    rkhs_reg = (self.beta[id3].T @ K32 @ self.beta[id2] *
                                ratio3 * ratio2)[0][0]
                    u = self.beta[id3].T @ K31 * ratio3
                    l2_reg = (u @ test)[0][0] / x1.size(0)
                    return test, rkhs_reg, l2_reg
                return test

        self.learner = learner
        self.adversary = Adversary(adversary_g, n_samples, kernel, sigma)
        # which adversary parameters to not ell2 penalize
        self.skip_list = ['beta']

    def fit(self, Z, T, Y,
            learner_l2=1e-3, adversary_l2=1e-4, adversary_norm_reg=1e-3,
            learner_lr=0.001, adversary_lr=0.001, n_epochs=100, bs1=100, bs2=100, bs3=100, train_learner_every=1, train_adversary_every=1,
            ols_weight=0.0, warm_start=False, logger=None, model_dir='.', device=None, verbose=0):
        """
        Parameters
        ----------
        Z : instruments
        T : treatments
        Y : outcome
        learner_l2, adversary_l2 : l2_regularization of parameters of learner and adversary
        adversary_norm_reg : weight of the first penalty term returned by the adversary
        learner_lr : learning rate of the optimizer for learner
        adversary_lr : learning rate of the optimizer for adversary
        n_epochs : how many passes over the data
        bs1 : batch size of the training data loader
        bs2, bs3 : sizes of the two subsamples of instruments drawn, without
            replacement, at every training step
        train_learner_every : the learner is updated on every batch whose index is a multiple of this
        train_adversary_every : the adversary is updated on every batch whose index is a multiple of this
        ols_weight : weight on OLS (square loss) objective
        warm_start : whether to keep the current weights instead of resetting them
        logger : a function that takes as input (learner, adversary, epoch, writer) and is called after every epoch
            Supposed to be used to log the state of the learning.
        model_dir : folder where to store the learned models after every epoch
        device : device to train on, forwarded to ``_pretrain``
        verbose : if greater than 0, the epoch number is printed at the start of every epoch

        Returns
        -------
        self
        """
        Z, T, Y = self._pretrain(Z, T, Y,
                                 learner_l2, adversary_l2, adversary_norm_reg,
                                 learner_lr, adversary_lr, n_epochs, bs1, train_learner_every, train_adversary_every,
                                 warm_start, logger, model_dir, device, verbose, add_sample_inds=True)

        sample_inds = np.arange(Y.shape[0]).astype(int)

        for epoch in range(n_epochs):
            # Honor the verbose flag, as _BaseSupLossAGMM.fit does
            if self.verbose > 0:
                print("Epoch #", epoch, sep="")

            for it, (zb1, xb1, yb1, idb1) in enumerate(self.train_dl):
                zb1 = zb1.to(self.device, non_blocking=True)
                xb1 = xb1.to(self.device, non_blocking=True)
                yb1 = yb1.to(self.device, non_blocking=True)
                idb1 = idb1.to(self.device)

                # Two fresh random subsamples of the instruments per step
                idb2 = np.random.choice(sample_inds, bs2, replace=False)
                zb2 = Z[idb2].to(self.device)
                idb3 = np.random.choice(sample_inds, bs3, replace=False)
                zb3 = Z[idb3].to(self.device)

                if it % train_learner_every == 0:
                    self.learner.train()
                    psi = yb1 - self.learner(xb1)
                    test = self.adversary(zb1, zb2, zb3, idb1, idb2, idb3)
                    D_loss = torch.mean(psi * test)
                    D_loss += ols_weight * torch.mean(psi**2)
                    self.optimizerD.zero_grad()
                    D_loss.backward()
                    self.optimizerD.step()
                    self.learner.eval()

                if it % train_adversary_every == 0:
                    self.adversary.train()
                    psi = yb1 - self.learner(xb1)
                    test, rkhs_reg, l2_reg = self.adversary(
                        zb1, zb2, zb3, idb1, idb2, idb3, reg=True)
                    G_loss = - torch.mean(psi * test)
                    G_loss += adversary_norm_reg * rkhs_reg
                    G_loss += l2_reg
                    self.optimizerG.zero_grad()
                    G_loss.backward()
                    self.optimizerG.step()
                    self.adversary.eval()

            # Save into self.model_dir (the temporary directory set up by _pretrain):
            # that is where predict looks for the per-epoch checkpoints.
            torch.save(self.learner, os.path.join(
                self.model_dir, "epoch{}".format(epoch)))

            if logger is not None:
                logger(self.learner, self.adversary, epoch, self.writer)

        if logger is not None:
            self.writer.flush()
            self.writer.close()

        return self