agmm.KernelLayerMMDGMM

class agmm.KernelLayerMMDGMM(learner, adversary_g, g_features, n_centers, kernel, centers=None, sigmas=None, trainable=True)

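Adversarial Generalized Method of Moments (AGMM) estimator in which the adversary applies a kernel layer to the output of g, using Maximum Mean Discrepancy (MMD).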
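MMD compares two distributions through the mean embeddings of their samples under a kernel. As a generic illustration only (not the internal loss of this class), the sketch below computes the biased squared-MMD estimate with a Gaussian kernel in NumPy. The helper names `gaussian_kernel` and `mmd_squared` and the fixed bandwidth `sigma` are hypothetical.

```python
import numpy as np


def gaussian_kernel(a, b, sigma=1.0):
    """Gaussian (RBF) kernel matrix between the rows of a and the rows of b."""
    # Squared Euclidean distance between every pair of rows, shape (len(a), len(b)).
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))


def mmd_squared(x, y, sigma=1.0):
    """Biased (V-statistic) estimate of the squared MMD between samples x and y."""
    k_xx = gaussian_kernel(x, x, sigma).mean()
    k_yy = gaussian_kernel(y, y, sigma).mean()
    k_xy = gaussian_kernel(x, y, sigma).mean()
    return k_xx + k_yy - 2.0 * k_xy


rng = np.random.default_rng(0)
# Two samples from the same distribution vs. a mean-shifted pair.
same = mmd_squared(rng.normal(size=(200, 2)), rng.normal(size=(200, 2)))
shifted = mmd_squared(rng.normal(size=(200, 2)), rng.normal(loc=2.0, size=(200, 2)))
print(same, shifted)  # the shifted pair gives a clearly larger value
```

Identical samples give an MMD of zero, and the estimate grows as the two distributions move apart.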

Parameters
  • learner – a PyTorch neural network module (torch.nn.Module) for the learner.

  • adversary_g – a PyTorch neural network module for the g function of the adversary.

  • g_features – the number of output features of g.

  • n_centers – the number of centers in the kernel layer.

  • kernel – the kernel function used by the kernel layer.

  • centers – NumPy array with the initial values of the centers in the g(Z) space (default None).

  • sigmas – NumPy array with the initial value of sigma for each center (default None).

  • trainable – if True (the default), the centers and sigmas are trained; otherwise they stay fixed at their initial values.
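To make the roles of `g_features`, `n_centers`, `centers`, `sigmas`, and `trainable` concrete, here is a minimal NumPy sketch of what a kernel layer over g(Z) could compute. It is not the agmm implementation: the class name `KernelLayerSketch` and its `trainable_parameters` method are hypothetical, a Gaussian kernel with a per-center sigma stands in for the `kernel` argument, the array shapes `(n_centers, g_features)` and `(n_centers,)` are assumed, and the fallback initialisation used when `centers` or `sigmas` is None is made up.

```python
import numpy as np


class KernelLayerSketch:
    """Hypothetical kernel layer: maps g(Z) to one kernel feature per center."""

    def __init__(self, g_features, n_centers, centers=None, sigmas=None,
                 trainable=True, seed=0):
        rng = np.random.default_rng(seed)
        # Assumed fallbacks when no initial values are given (not taken from agmm).
        self.centers = (np.asarray(centers, dtype=float) if centers is not None
                        else rng.normal(size=(n_centers, g_features)))
        self.sigmas = (np.asarray(sigmas, dtype=float) if sigmas is not None
                       else np.ones(n_centers))
        # Assumed shapes: one center per row, one sigma per center.
        assert self.centers.shape == (n_centers, g_features)
        assert self.sigmas.shape == (n_centers,)
        self.trainable = trainable

    def trainable_parameters(self):
        # In PyTorch this would correspond to nn.Parameter(..., requires_grad=trainable).
        return [self.centers, self.sigmas] if self.trainable else []

    def __call__(self, g_z):
        # Squared distance from each g(z) row to each center, shape (n, n_centers).
        sq_dists = ((g_z[:, None, :] - self.centers[None, :, :]) ** 2).sum(axis=-1)
        # Gaussian kernel with a separate bandwidth for each center.
        return np.exp(-sq_dists / (2.0 * self.sigmas[None, :] ** 2))


layer = KernelLayerSketch(g_features=3, n_centers=5,
                          centers=np.zeros((5, 3)), sigmas=np.full(5, 0.5))
features = layer(np.zeros((4, 3)))
print(features.shape)  # (4, 5)
```

Each row of the output holds one kernel value per center; with `trainable=False` the sketch reports no trainable parameters, mirroring the idea that the centers and sigmas stay fixed.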