ImputeVAE
Overview
ImputeVAE implements a variational autoencoder (VAE) for genotype
imputation. The encoder predicts a latent distribution (mean and log-variance),
samples latent vectors via the reparameterization trick, and the decoder
reconstructs genotype logits. Training combines masked focal reconstruction
loss with a KL divergence penalty.
Model formulation
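Let \(X \in \mathbb{R}^{N \times L}\) be the genotype matrix encoded as 0/1/2 (missing = -1). The encoder outputs a Gaussian distribution in latent space:
\[ q_\phi(z \mid x) = \mathcal{N}\big(\mu(x), \operatorname{diag}(\sigma^2(x))\big) \]
Sampling uses the reparameterization trick:
\[ z = \mu(x) + \sigma(x) \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \]
The decoder predicts logits \(\hat{x}\) and the total loss is:
\[ \mathcal{L} = \mathcal{L}_{\text{recon}} + \beta \, D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, \mathcal{N}(0, I)\big) \]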
where \(\mathcal{L}_{\text{recon}}\) is masked focal cross-entropy over observed entries, and \(\beta\) is the KL weight.
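The sampling step and the KL term can be sketched in a few lines of NumPy. This is a minimal illustration, assuming a diagonal Gaussian posterior and a standard normal prior \(\mathcal{N}(0, I)\); the function names are hypothetical and are not part of the PG-SUI API.

```python
import numpy as np


def reparameterize(mu, logvar, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, I).

    The randomness lives in eps, so z stays a deterministic
    function of mu and logvar (which is what lets gradients flow
    through the sampling step in an autodiff framework).
    """
    sigma = np.exp(0.5 * logvar)  # log-variance -> standard deviation
    eps = rng.standard_normal(mu.shape)
    return mu + sigma * eps


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)) per sample.

    Assumes a standard normal prior; sums over latent dimensions.
    """
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=1)


rng = np.random.default_rng(0)
mu = np.zeros((4, 2))      # 4 samples, 2 latent dimensions
logvar = np.zeros((4, 2))  # unit variance
z = reparameterize(mu, logvar, rng)
kl = kl_to_standard_normal(mu, logvar)  # zero when posterior equals prior
```

When the posterior matches the prior (zero mean, unit variance), the KL term vanishes; moving the mean away from zero or changing the variance increases it.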
PG-SUI scales the reconstruction term by the average number of masked loci per sample to keep the KL term from dominating on large matrices, and can anneal both \(\beta\) and focal-loss gamma during training.
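One possible reading of this scaling is sketched below in NumPy: the focal cross-entropy is averaged over the entries selected by a boolean mask, multiplied by the average number of such entries per sample, and then added to the \(\beta\)-weighted KL term. The exact normalization used by PG-SUI is not stated here, so treat the function names, argument layout, and scaling rule as assumptions.

```python
import numpy as np


def masked_focal_ce(logits, targets, mask, gamma=2.0):
    """Sum of focal cross-entropy over entries where mask is True.

    logits: (N, L, 3) class scores; targets: (N, L) ints in {0, 1, 2}
    (any value is allowed where mask is False); mask: (N, L) bool.
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Replace excluded targets (e.g. -1) with a valid index before gathering.
    safe_targets = np.where(mask, targets, 0)
    log_pt = np.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]
    pt = np.exp(log_pt)
    focal = -((1.0 - pt) ** gamma) * log_pt  # down-weights easy entries
    return float((focal * mask).sum())


def total_loss(logits, targets, mask, mu, logvar, beta=1.0, gamma=2.0):
    """Scaled reconstruction loss plus beta-weighted KL (assumed scaling)."""
    n_entries = int(mask.sum())
    recon_mean = masked_focal_ce(logits, targets, mask, gamma) / max(n_entries, 1)
    avg_per_sample = n_entries / mask.shape[0]  # average masked loci per sample
    recon = recon_mean * avg_per_sample
    kl = np.mean(-0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=1))
    return recon + beta * kl


# Toy inputs: 2 samples, 3 loci, uniform logits, posterior equal to the prior.
logits = np.zeros((2, 3, 3))
targets = np.array([[0, 1, 2], [2, -1, 0]])
mask = targets >= 0
loss = total_loss(logits, targets, mask, np.zeros((2, 4)), np.zeros((2, 4)), beta=1.25)
```

Multiplying the mean reconstruction loss by the per-sample entry count puts it on a per-sample scale comparable to the per-sample KL sum, which is the balance the paragraph above describes.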
Algorithm summary
Encode genotypes to 0/1/2, simulate missingness once on the full matrix, and build masks for original and simulated missingness (reused across splits).
Train the encoder-decoder with masked focal reconstruction loss plus KL divergence (weighted by vae.kl_beta), scaling reconstruction by the average masked loci per sample and optionally scheduling vae.kl_beta and train.gamma.
Optimize with AdamW and a warmup-to-cosine learning rate schedule while monitoring validation loss for early stopping; metrics are scored on simulated-missing entries only.
transform() predicts genotype probabilities and fills missing entries with MAP labels before decoding to IUPAC outputs.
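The masking and MAP-filling steps above can be illustrated with a small NumPy sketch. It uses random probabilities as a stand-in for the trained decoder, omits IUPAC decoding, and its helper names (simulate_missing, fill_map) are hypothetical rather than PG-SUI functions.

```python
import numpy as np


def simulate_missing(X, rate, rng):
    """Hide a random fraction of observed entries.

    Returns the masked copy plus boolean masks for original and
    simulated missingness (the two masks never overlap).
    """
    orig_missing = X == -1
    sim_missing = (~orig_missing) & (rng.random(X.shape) < rate)
    X_masked = np.where(sim_missing, -1, X)
    return X_masked, orig_missing, sim_missing


def fill_map(X, probs, fill_mask):
    """Replace entries under fill_mask with the most probable class."""
    return np.where(fill_mask, probs.argmax(axis=-1), X)


rng = np.random.default_rng(42)
X = rng.integers(0, 3, size=(6, 8))  # genotypes encoded as 0/1/2
X[0, 0] = X[3, 5] = -1               # two originally missing calls

X_masked, orig_missing, sim_missing = simulate_missing(X, 0.2, rng)

# Stand-in for decoder output: one probability vector per entry, shape (6, 8, 3).
probs = rng.dirichlet(np.ones(3), size=X.shape)

# Fill only the originally missing entries with MAP labels.
X_imputed = fill_map(X, probs, orig_missing)

# Score on simulated-missing entries only, where the truth is known.
accuracy = float("nan")
if sim_missing.any():
    predicted = probs.argmax(axis=-1)
    accuracy = float((predicted[sim_missing] == X[sim_missing]).mean())
```

Scoring only where missingness was simulated keeps the evaluation honest: those are the entries whose true genotypes are known but were hidden from the model.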
Configuration highlights
ImputeVAE uses pgsui.data_processing.containers.VAEConfig, which
extends the autoencoder config with a vae section:
vae.kl_beta controls the KL divergence weight.
vae.kl_beta_schedule enables optional KL annealing.
train.gamma and train.weights_* control reconstruction loss behavior.
train.gamma_schedule enables optional focal-loss gamma annealing.
See Optuna Hyperparameter Tuning for Optuna-driven tuning details.
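To make the annealing options concrete, the sketch below shows one common schedule shape: a linear ramp from a starting value to a target over a fixed number of epochs. The actual schedule used when vae.kl_beta_schedule or train.gamma_schedule is enabled is not documented here, so the function, its parameters, and the start and end values are assumptions for illustration only.

```python
def linear_anneal(epoch, ramp_epochs, start, end):
    """Linearly interpolate from start to end over ramp_epochs, then hold.

    Hypothetical helper: not part of the PG-SUI configuration API.
    """
    if ramp_epochs <= 0:
        return end
    frac = min(max(epoch / ramp_epochs, 0.0), 1.0)  # clamp to [0, 1]
    return start + frac * (end - start)


# Example: ramp the KL weight from 0 to 1.25 and gamma from 0 to 2 over 10 epochs.
betas = [linear_anneal(e, 10, 0.0, 1.25) for e in range(15)]
gammas = [linear_anneal(e, 10, 0.0, 2.0) for e in range(15)]
```

Starting with a small KL weight lets the decoder learn to reconstruct before the latent space is pulled toward the prior; ramping gamma similarly delays the emphasis on hard-to-classify entries.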
Usage
from snpio import VCFReader
from pgsui import ImputeVAE
from pgsui.data_processing.containers import VAEConfig
gdata = VCFReader("cohort.vcf.gz", popmapfile="pops.popmap")
cfg = VAEConfig.from_preset("balanced")
cfg.vae.kl_beta = 1.25
model = ImputeVAE(genotype_data=gdata, config=cfg)
model.fit()
genotypes_iupac = model.transform()
References
Kingma, D. P., & Welling, M. (2013). Auto-Encoding Variational Bayes. arXiv preprint arXiv:1312.6114.