nnpiv.neuralnet.CentroidMMDGMM
- class nnpiv.neuralnet.CentroidMMDGMM(learner, adversary_g, kernel, centers, sigma)
Adversarial generalized method of moments (AGMM) estimator with a centroid-based Maximum Mean Discrepancy (MMD) criterion.
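To show what a "centroid-based" MMD can measure, the sketch below computes the largest average of residual times f(z) over adversaries of the form f(z) = sum_j beta_j k(z, c_j), i.e. kernel functions anchored at a finite set of centers c_j, with RKHS norm at most 1. By Cauchy-Schwarz this supremum equals sqrt(m^T K_cc^{-1} m), where m_j is the mean of r_i k(z_i, c_j). This is an illustration under stated assumptions, not the loss that nnpiv computes internally: the Gaussian form exp(-sigma * ||a - b||^2) (with sigma read as a precision) and the helper names `gaussian_kernel` and `centroid_mmd` are hypothetical, and a small ridge term is added for numerical stability.

```python
import numpy as np


def gaussian_kernel(a, b, sigma):
    """Assumed Gaussian kernel exp(-sigma * ||a - b||^2); sigma acts as a precision."""
    sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sigma * sq_dists)


def centroid_mmd(residuals, z, centers, sigma, ridge=1e-6):
    """Hypothetical helper: sup of mean(r * f(z)) over f = sum_j beta_j k(., c_j)
    with ||f||_RKHS <= 1, which equals sqrt(m^T K_cc^{-1} m)."""
    k_zc = gaussian_kernel(z, centers, sigma)       # (n, m) kernel features at the centers
    moments = k_zc.T @ residuals / len(residuals)   # m_j = mean_i r_i k(z_i, c_j)
    # Gram matrix of the centers, regularized so the solve is well posed.
    k_cc = gaussian_kernel(centers, centers, sigma) + ridge * np.eye(len(centers))
    return float(np.sqrt(moments @ np.linalg.solve(k_cc, moments)))


# Example: residuals y - h(x) evaluated against instruments z.
rng = np.random.default_rng(0)
z = rng.normal(size=(200, 1))
residuals = rng.normal(size=200)
centers = np.array([[-1.0], [0.0], [1.0]])
stat = centroid_mmd(residuals, z, centers, sigma=1.0)
```

A value near zero means the residuals are nearly uncorrelated with every kernel function centered at the chosen points; moving or adding centers changes which departures from the moment condition the statistic can detect.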
- Parameters
learner – a PyTorch neural net module for the learner.
adversary_g – a PyTorch neural net module for the g function of the adversary.
kernel – the kernel function.
centers – NumPy array containing the initial values of the centers in the Z space.
sigma – float giving the precision of the kernel.
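The `centers` and `sigma` arguments have to be chosen before construction. The sketch below shows one common way to prepare them, under stated assumptions: initial centers are drawn as a random subset of the observed instruments Z (k-means centroids are another frequent choice), and sigma follows the median heuristic. Assuming a kernel of the form exp(-sigma * ||a - b||^2), the median heuristic gives sigma = 1 / (median squared pairwise distance), up to a constant factor that differs between conventions. The helpers `initial_centers` and `median_heuristic_precision` are hypothetical and not part of nnpiv, and nothing here reflects nnpiv's own defaults.

```python
import numpy as np


def initial_centers(z, n_centers, seed=0):
    """Hypothetical helper: pick n_centers distinct rows of z as starting centers."""
    z = np.asarray(z, dtype=float)
    if z.ndim == 1:
        z = z.reshape(-1, 1)  # treat a 1-d instrument as a single column
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(z), size=n_centers, replace=False)
    return z[idx]


def median_heuristic_precision(points):
    """Hypothetical helper: sigma = 1 / median squared distance over distinct pairs."""
    diffs = points[:, None, :] - points[None, :, :]
    sq_dists = (diffs ** 2).sum(axis=-1)
    upper = sq_dists[np.triu_indices(len(points), k=1)]  # each pair counted once
    return float(1.0 / np.median(upper))


# Example with a two-dimensional instrument.
rng = np.random.default_rng(1)
z = rng.normal(size=(500, 2))
centers = initial_centers(z, n_centers=20)
sigma = median_heuristic_precision(centers)
```

The result is a plain NumPy array of shape (n_centers, d_z), matching the documented type of `centers`; whether the median-heuristic value is a sensible `sigma` depends on how the supplied `kernel` interprets its precision argument.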