ImputeVAE

Overview

ImputeVAE implements a variational autoencoder (VAE) for genotype imputation. The encoder maps each sample to a Gaussian latent distribution, parameterized by a mean and a log-variance. A latent vector is sampled from it with the reparameterization trick, and the decoder reconstructs genotype logits from that vector. Training combines a masked focal reconstruction loss with a KL divergence penalty.

Model formulation

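Let \(X \in \{-1, 0, 1, 2\}^{N \times L}\) be the genotype matrix for \(N\) samples and \(L\) loci, with genotypes encoded as 0/1/2 and missing entries as -1, and let \(x\) denote a single sample (one row of \(X\)). The encoder outputs a Gaussian distribution in latent space: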

\[\mu = f_{\mu}(x), \qquad \log \sigma^2 = f_{\sigma}(x)\]

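Sampling uses the reparameterization trick (Kingma & Welling, 2013). With \(\sigma = \exp(\tfrac{1}{2}\log \sigma^2)\), the draw is a differentiable function of \(\mu\) and \(\sigma\):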

\[z = \mu + \epsilon \cdot \sigma, \qquad \epsilon \sim \mathcal{N}(0, I)\]
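The sampling step can be illustrated with a minimal NumPy sketch. It assumes the encoder has already produced `mu` and `logvar` arrays of shape (samples, latent dimensions); the function name `reparameterize` is hypothetical, and this is not PG-SUI's own code. In a real model the same arithmetic runs inside an autodiff framework so that gradients reach the encoder.

```python
import numpy as np

def reparameterize(mu, logvar, rng):
    """Return z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(0.5 * logvar)."""
    sigma = np.exp(0.5 * logvar)         # log-variance -> standard deviation
    eps = rng.standard_normal(mu.shape)  # noise does not depend on mu or sigma
    return mu + eps * sigma

rng = np.random.default_rng(0)
mu = np.zeros((4, 2))           # 4 samples, 2 latent dimensions
logvar = np.full((4, 2), -2.0)  # sigma = exp(-1), about 0.37
z = reparameterize(mu, logvar, rng)
```

Because all randomness is confined to `eps`, `z` is a deterministic function of `mu` and `logvar`, which is what allows backpropagation through the sampling step.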

The decoder maps \(z\) to genotype logits \(\hat{x}\), and the total loss is:

\[\mathcal{L} = \mathcal{L}_{\text{recon}} + \beta \, D_{\text{KL}}(q(z \mid x)\,\|\,p(z))\]

where \(\mathcal{L}_{\text{recon}}\) is the masked focal cross-entropy over observed entries, \(p(z) = \mathcal{N}(0, I)\) is the standard normal prior, and \(\beta\) is the KL weight.
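Both loss terms can be sketched in NumPy under stated assumptions: logits have shape (N, L, 3) for the three genotype classes, a boolean `mask` marks the scored entries, the focal term takes the usual form \(-(1 - p_y)^{\gamma} \log p_y\), and the KL term uses the closed form for a diagonal Gaussian against \(\mathcal{N}(0, I)\). The reductions (mean over scored entries; KL summed over latent dimensions, then averaged over samples) and all function names are hypothetical choices for illustration, and the sketch omits the reconstruction scaling PG-SUI applies (see the next paragraph).

```python
import numpy as np

def log_softmax(logits):
    """Log-probabilities over the last axis, stabilized by subtracting the max."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def masked_focal_ce(logits, targets, mask, gamma=2.0, class_weights=None):
    """Mean focal cross-entropy over the entries where `mask` is True.

    logits: (N, L, 3) decoder outputs for genotype classes 0/1/2.
    targets: (N, L) integer labels; entries outside `mask` may hold -1.
    mask: (N, L) boolean array selecting the scored entries.
    """
    logp = log_softmax(logits)
    safe = np.where(mask, targets, 0)  # placeholder label so -1 never indexes
    logp_true = np.take_along_axis(logp, safe[..., None], axis=-1)[..., 0]
    p_true = np.exp(logp_true)
    loss = -((1.0 - p_true) ** gamma) * logp_true  # focal modulation of cross-entropy
    if class_weights is not None:
        loss = loss * np.asarray(class_weights)[safe]  # per-class reweighting
    return loss[mask].mean()

def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over samples."""
    per_sample = 0.5 * np.sum(mu**2 + np.exp(logvar) - 1.0 - logvar, axis=-1)
    return per_sample.mean()

def vae_loss(logits, targets, mask, mu, logvar, beta=1.0, gamma=2.0):
    """Unscaled total loss: focal reconstruction plus beta-weighted KL."""
    return masked_focal_ce(logits, targets, mask, gamma) + beta * kl_to_standard_normal(mu, logvar)
```

With `gamma=0` the focal term reduces to ordinary cross-entropy; larger values down-weight entries the model already predicts confidently, which is the usual motivation for focal loss when some genotype classes are rare.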

PG-SUI scales the reconstruction term by the average number of masked loci per sample so that the KL term does not dominate on large matrices. It can also anneal both \(\beta\) and the focal-loss exponent \(\gamma\) during training.
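The annealing and scaling can be sketched in plain Python. The linear ramp is an assumption chosen for illustration (PG-SUI's actual schedule shapes are set through its config and may differ), and `linear_anneal` and `scaled_total_loss` are hypothetical names. The scaling shows one reading of the rule above: a per-entry mean reconstruction loss is multiplied by the average number of masked loci per sample, putting it on a per-sample scale comparable to a KL term summed over latent dimensions.

```python
def linear_anneal(epoch, start, end, n_epochs):
    """Move linearly from `start` to `end` over `n_epochs`, then hold at `end`."""
    if n_epochs <= 0:
        return end
    frac = min(epoch / n_epochs, 1.0)
    return start + frac * (end - start)

def scaled_total_loss(recon_mean, kl, beta, avg_masked_loci):
    """Combine a per-entry mean reconstruction loss with a beta-weighted KL term.

    Multiplying by the average number of masked loci per sample puts the
    reconstruction term on a per-sample scale, so the KL term does not dominate.
    """
    return recon_mean * avg_masked_loci + beta * kl

# Example: ramp the KL weight toward 1.25 and the focal gamma toward 2.0.
betas = [linear_anneal(e, 0.0, 1.25, 4) for e in range(6)]
gammas = [linear_anneal(e, 0.0, 2.0, 4) for e in range(6)]
# betas == [0.0, 0.3125, 0.625, 0.9375, 1.25, 1.25]
```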

Algorithm summary

  1. Encode genotypes to 0/1/2, simulate missingness once on the full matrix, and build masks for original and simulated missingness (reused across splits).

  2. Train the encoder-decoder with a masked focal reconstruction loss plus a KL divergence term weighted by vae.kl_beta. The reconstruction term is scaled by the average number of masked loci per sample, and vae.kl_beta and train.gamma can optionally follow schedules.

  3. Optimize with AdamW under a learning-rate schedule that warms up and then decays along a cosine curve, using validation loss for early stopping. Metrics are scored on simulated-missing entries only.

  4. transform() predicts genotype probabilities, fills each missing entry with its MAP (most probable) label, and decodes the result to IUPAC codes.
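Steps 1 and 3 can be sketched in plain Python and NumPy. Several choices here are assumptions rather than PG-SUI's documented behavior: missingness is simulated uniformly at random at a fixed rate, warmup is linear, and early stopping uses a simple patience counter. The names `build_masks`, `sim_fraction`, `warmup_cosine_lr`, and `EarlyStopping` are hypothetical. The optimizer itself is omitted; in PyTorch, AdamW is available as `torch.optim.AdamW`.

```python
import math
import numpy as np

def build_masks(X, sim_fraction, rng):
    """Simulate missingness once on the full matrix.

    X: (N, L) int array, genotypes 0/1/2 and -1 for missing.
    Returns (X_sim, orig_missing, sim_missing). Simulated entries are drawn
    only from originally observed cells, so the two masks never overlap.
    """
    orig_missing = X < 0
    sim_missing = ~orig_missing & (rng.random(X.shape) < sim_fraction)
    X_sim = X.copy()
    X_sim[sim_missing] = -1  # hide the simulated entries from the model
    return X_sim, orig_missing, sim_missing

def warmup_cosine_lr(step, total_steps, warmup_steps, base_lr, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

class EarlyStopping:
    """Signal a stop after `patience` epochs without validation improvement."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Build masks once, then slice them by sample index for each split.
rng = np.random.default_rng(42)
X = rng.integers(0, 3, size=(10, 50))
X[rng.random(X.shape) < 0.05] = -1  # some original missingness
X_sim, orig_missing, sim_missing = build_masks(X, 0.1, rng)
perm = rng.permutation(X.shape[0])
train_idx, val_idx = perm[:8], perm[8:]
val_sim_mask = sim_missing[val_idx]  # validation metrics use simulated-missing cells only
```

Drawing the masks once and slicing them per split keeps the scored entries fixed across epochs, so changes in validation loss reflect the model rather than a reshuffled mask.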

Configuration highlights

ImputeVAE uses pgsui.data_processing.containers.VAEConfig, which extends the autoencoder config with a vae section:

  • vae.kl_beta controls the KL divergence weight.

  • vae.kl_beta_schedule enables optional KL annealing.

  • train.gamma (the focal-loss exponent) and the train.weights_* options control the reconstruction loss.

  • train.gamma_schedule enables optional focal-loss gamma annealing.

See Optuna Hyperparameter Tuning for Optuna-driven tuning details.

Usage

from snpio import VCFReader
from pgsui import ImputeVAE
from pgsui.data_processing.containers import VAEConfig

# Load genotypes and population assignments.
gdata = VCFReader("cohort.vcf.gz", popmapfile="pops.popmap")

# Start from a preset and raise the KL weight.
cfg = VAEConfig.from_preset("balanced")
cfg.vae.kl_beta = 1.25

# Train, then impute missing genotypes (returned as IUPAC codes).
model = ImputeVAE(genotype_data=gdata, config=cfg)
model.fit()
genotypes_iupac = model.transform()
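The `transform()` call fills missing genotypes with MAP labels and decodes them to IUPAC codes (step 4 of the algorithm summary). A NumPy sketch of that idea follows; it is not PG-SUI's implementation. The 0/1/2 meaning (0 = homozygous reference, 1 = heterozygous, 2 = homozygous alternate), the per-locus `ref`/`alt` lists, and the names `fill_map` and `to_iupac` are assumptions for illustration. The heterozygote letters are the standard IUPAC nucleotide ambiguity codes.

```python
import numpy as np

# Heterozygote IUPAC codes for each unordered pair of nucleotides.
IUPAC_HET = {
    frozenset("AG"): "R", frozenset("CT"): "Y", frozenset("AC"): "M",
    frozenset("GT"): "K", frozenset("CG"): "S", frozenset("AT"): "W",
}

def fill_map(X, probs):
    """Fill -1 entries of X with the most probable class in probs (N, L, 3)."""
    map_labels = probs.argmax(axis=-1)
    return np.where(X < 0, map_labels, X)  # observed entries are left unchanged

def to_iupac(X_filled, ref, alt):
    """Decode 0/1/2 genotypes to IUPAC characters using per-locus ref/alt alleles.

    Assumes 0 = homozygous ref, 1 = heterozygous, 2 = homozygous alt.
    """
    out = np.empty(X_filled.shape, dtype="<U1")
    for j, (r, a) in enumerate(zip(ref, alt)):
        lookup = np.array([r, IUPAC_HET[frozenset(r + a)], a])
        out[:, j] = lookup[X_filled[:, j]]
    return out

X = np.array([[0, -1], [2, 1]])
probs = np.full((2, 2, 3), 1 / 3)
probs[0, 1] = [0.1, 0.7, 0.2]  # missing cell: class 1 (heterozygous) is most probable
iupac = to_iupac(fill_map(X, probs), ref=["A", "C"], alt=["G", "T"])
```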

References

Kingma, D. P., & Welling, M. (2013). Auto-Encoding Variational Bayes. arXiv preprint arXiv:1312.6114.