agmm.KernelLayerMMDGMM
- class agmm.KernelLayerMMDGMM(learner, adversary_g, g_features, n_centers, kernel, centers=None, sigmas=None, trainable=True)
AGMM (Adversarial Generalized Method of Moments) estimator with a kernel layer, using a Maximum Mean Discrepancy (MMD) objective.
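To make the MMD part more concrete, here is a minimal sketch of one common kernel moment criterion. It assumes residuals psi_i = y_i - h(x_i) from a learner h and a Gaussian kernel k evaluated on adversary features g(z_i), and computes (1/n^2) * sum_i sum_j psi_i * psi_j * k(g(z_i), g(z_j)). The function names `gaussian_kernel` and `mmd_moment_loss` and the fixed Gaussian kernel are illustrative assumptions, not this class's internals.

```python
import numpy as np


def gaussian_kernel(a, b, sigma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of a (n x d) and b (m x d)."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))


def mmd_moment_loss(residuals, features, sigma=1.0):
    """Hypothetical MMD-style criterion psi^T K psi / n^2.

    residuals: shape (n,), stand-in for y - h(x)
    features:  shape (n, d), stand-in for the adversary features g(Z)
    """
    n = residuals.shape[0]
    K = gaussian_kernel(features, features, sigma)
    return float(residuals @ K @ residuals) / n ** 2


# Illustrative usage on random stand-in data.
rng = np.random.default_rng(0)
gz = rng.normal(size=(50, 3))  # plays the role of g(Z)
psi = rng.normal(size=50)      # plays the role of y - h(x)
loss = mmd_moment_loss(psi, gz)
```

Because the Gaussian kernel matrix is positive semi-definite, this criterion is non-negative, and it is exactly zero when every residual vanishes.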
- Parameters
learner – a PyTorch neural network module (torch.nn.Module) for the learner.
adversary_g – a PyTorch neural network module for the g function of the adversary.
g_features – the number of output features of g, i.e. the dimension of the g(Z) space.
n_centers – the number of centers to use in the kernel layer.
kernel – the kernel function.
centers – NumPy array containing the initial values of the centers in the g(Z) space.
sigmas – NumPy array containing the initial value of sigma (the kernel scale) for each center.
trainable – whether to train the centers and the sigmas.
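The parameters above describe a layer that compares g(Z) against a set of learned centers. The following is a minimal sketch of such a layer in PyTorch, under stated assumptions: the class name `GaussianKernelLayer` is hypothetical, it hard-codes a Gaussian (RBF) kernel instead of accepting the `kernel` argument, and the initialization used when `centers` or `sigmas` is None is a guess, not the library's behavior. It shows how `g_features`, `n_centers`, `centers`, `sigmas`, and `trainable` could map onto module parameters.

```python
import numpy as np
import torch
import torch.nn as nn


class GaussianKernelLayer(nn.Module):
    """Hypothetical kernel layer: maps g(Z) to Gaussian kernel values at each center."""

    def __init__(self, g_features, n_centers, centers=None, sigmas=None, trainable=True):
        super().__init__()
        if centers is None:
            # Assumption: random initial centers in [-1, 1] when none are given.
            centers = np.random.uniform(-1.0, 1.0, size=(n_centers, g_features))
        if sigmas is None:
            # Assumption: unit scale for every center when none are given.
            sigmas = np.ones(n_centers)
        # requires_grad decides whether an optimizer will update these tensors.
        self.centers = nn.Parameter(
            torch.as_tensor(centers, dtype=torch.float32), requires_grad=trainable)
        self.sigmas = nn.Parameter(
            torch.as_tensor(sigmas, dtype=torch.float32), requires_grad=trainable)

    def forward(self, gz):
        # gz: (batch, g_features); squared distance to each center: (batch, n_centers)
        sq_dists = ((gz[:, None, :] - self.centers[None, :, :]) ** 2).sum(dim=-1)
        # Each center has its own scale; sigmas broadcasts over the batch dimension.
        return torch.exp(-sq_dists / (2.0 * self.sigmas ** 2))


# Illustrative usage: 8 samples of 3-dimensional g(Z) features, 5 frozen centers.
layer = GaussianKernelLayer(g_features=3, n_centers=5, trainable=False)
phi = layer(torch.randn(8, 3))  # phi has shape (8, 5)
```

Storing the centers and sigmas as `nn.Parameter` objects with a `requires_grad` flag is one simple way to make the `trainable` switch meaningful: the tensors stay registered on the module either way, but only trainable ones receive gradients.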