Tutorial: Implementing New Imputation Models
This guide shows how to add a new imputation model that conforms to PG-SUI’s refactored architecture:
- Typed dataclass configuration (*Config) with presets and optional dot-key overrides
- Instantiate → fit() → transform() contract (no arguments to fit/transform)
- BaseNNImputer wrapper that manages IO, encoding/decoding, training loops, tuning, and plotting
- Objective function compatible with Optuna tuning
- Top-level API imports (from pgsui import ...), consistent with existing models (see ImputeAutoencoder and AutoencoderModel)
Prerequisites
- PyTorch fundamentals (torch.nn.Module, optimizers, schedulers)
- SNP genotype encodings (0/1/2 with -9 or -1 for missing; IUPAC decode/encode via SNPio)
- Familiarity with the BaseNNImputer life cycle and helper methods
Model Development Overview
A new model requires three parts:
1. A config dataclass: MyNewModelConfig (typed, preset-aware)
2. A PyTorch module: MyNewAutoencoder (or encoder-decoder) with a clearly shaped output
3. An imputer wrapper: ImputeMyNewModel(BaseNNImputer) that plugs your module into PG-SUI’s training/tuning/eval pipeline
Note
Follow this contract strictly:
imputer = ImputeX(genotype_data=gd, config=SomeConfig.from_preset("balanced")); imputer.fit(); X_imp = imputer.transform()
Step 0 — Create a Config Dataclass
Define a typed config that mirrors the structure used by existing models (io, model, train, tune, plot, sim) and provides from_preset plus optional dot-key overrides.
from __future__ import annotations

from dataclasses import dataclass, field, asdict
from typing import Any, Dict, Literal

from pgsui.data_processing.containers import (
    IOConfig, TrainConfig, TuneConfig, PlotConfig, ModelConfig, SimConfig
)
from pgsui.data_processing.config import apply_dot_overrides


@dataclass
class MyNewModelConfig:
    """Top-level configuration for ImputeMyNewModel."""

    io: IOConfig = field(default_factory=IOConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    tune: TuneConfig = field(default_factory=TuneConfig)
    plot: PlotConfig = field(default_factory=PlotConfig)
    sim: SimConfig = field(default_factory=SimConfig)

    @classmethod
    def from_preset(
        cls, preset: Literal["fast", "balanced", "thorough"] = "balanced"
    ) -> "MyNewModelConfig":
        if preset not in {"fast", "balanced", "thorough"}:
            raise ValueError(f"Unknown preset: {preset}")
        cfg = cls()

        # Common defaults
        cfg.io.verbose = True
        cfg.train.validation_split = 0.2
        cfg.model.activation = "relu"
        cfg.model.layer_schedule = "pyramid"
        cfg.model.latent_init = "random"

        if preset == "fast":
            cfg.model.latent_dim = 4
            cfg.model.num_hidden_layers = 1
            cfg.model.dropout_rate = 0.10
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 1e-3
            cfg.tune.enabled = False
        elif preset == "balanced":
            cfg.model.latent_dim = 8
            cfg.model.num_hidden_layers = 2
            cfg.model.dropout_rate = 0.20
            cfg.train.batch_size = 128
            cfg.train.learning_rate = 8e-4
            cfg.tune.enabled = True
            cfg.tune.n_trials = 100
        else:  # thorough
            cfg.model.latent_dim = 16
            cfg.model.num_hidden_layers = 3
            cfg.model.dropout_rate = 0.30
            cfg.train.batch_size = 64
            cfg.train.learning_rate = 6e-4
            cfg.tune.enabled = True
            cfg.tune.n_trials = 250
        return cfg

    def apply_overrides(self, overrides: Dict[str, Any] | None) -> "MyNewModelConfig":
        if not overrides:
            return self
        return apply_dot_overrides(self, overrides)

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)
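To make the dot-key override idea concrete, here is a minimal standard-library sketch of how a key such as "train.batch_size" could be resolved against nested dataclasses. It is not PG-SUI's apply_dot_overrides; the names TrainSketch, ConfigSketch, and apply_dot_overrides_sketch are hypothetical stand-ins used only for illustration.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class TrainSketch:  # hypothetical stand-in for TrainConfig
    batch_size: int = 128
    learning_rate: float = 1e-3


@dataclass
class ConfigSketch:  # hypothetical stand-in for MyNewModelConfig
    train: TrainSketch = field(default_factory=TrainSketch)


def apply_dot_overrides_sketch(cfg: Any, overrides: Dict[str, Any]) -> Any:
    """Set nested attributes addressed by dot-separated keys (in place)."""
    for dotted_key, value in overrides.items():
        *parents, leaf = dotted_key.split(".")
        target = cfg
        for name in parents:
            # Walk down the nesting; a bad path raises AttributeError here.
            target = getattr(target, name)
        if not hasattr(target, leaf):
            raise AttributeError(f"Unknown config key: {dotted_key}")
        setattr(target, leaf, value)
    return cfg


cfg = apply_dot_overrides_sketch(ConfigSketch(), {"train.batch_size": 64})
print(cfg.train.batch_size, cfg.train.learning_rate)
```

In this sketch an unknown key raises AttributeError instead of silently creating a new attribute, which tends to surface typos in override keys early. Whether the real helper behaves the same way should be checked against its source.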
Step 1 — Implement the PyTorch Module
The module predicts per-SNP class logits. Match shapes used by your wrapper (e.g., (batch, n_features, n_classes)). Keep the API similar to AutoencoderModel.
from typing import List, Literal

import numpy as np
import torch
import torch.nn as nn


class MyNewDecoder(nn.Module):
    """Minimal decoder predicting per-SNP logits."""

    def __init__(
        self,
        n_features: int,
        num_classes: int = 3,
        latent_dim: int = 8,
        hidden_layer_sizes: List[int] | np.ndarray = (128, 64),
        dropout_rate: float = 0.2,
        activation: Literal["relu", "elu", "selu", "leaky_relu"] = "relu",
    ):
        super().__init__()
        if isinstance(hidden_layer_sizes, np.ndarray):
            hidden_layer_sizes = hidden_layer_sizes.tolist()
        act = self._resolve_activation(activation)

        layers: list[nn.Module] = []
        in_dim = latent_dim
        for h in hidden_layer_sizes:
            layers += [nn.Linear(in_dim, h), nn.BatchNorm1d(h), nn.Dropout(dropout_rate), act]
            in_dim = h
        # Final projection: one logit per (SNP, class) pair
        layers += [nn.Linear(in_dim, n_features * num_classes)]

        self.decoder = nn.Sequential(*layers)
        self.n_features = n_features
        self.num_classes = num_classes

    def _resolve_activation(self, name: str) -> nn.Module:
        return {"relu": nn.ReLU(), "elu": nn.ELU(), "selu": nn.SELU(), "leaky_relu": nn.LeakyReLU()}[name]

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # Output shape: (batch, n_features, n_classes)
        logits = self.decoder(z)
        return logits.view(-1, self.n_features, self.num_classes)
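The reshape at the end of forward() determines how downstream code reads the logits. The NumPy sketch below mimics that step with random numbers in place of a trained network (the array sizes are arbitrary assumptions) and shows that, with the class axis last, an argmax over the final axis yields one predicted label per SNP.

```python
import numpy as np

batch, n_features, n_classes = 4, 5, 3  # arbitrary illustrative sizes
rng = np.random.default_rng(0)

# Stand-in for the final linear layer's output: one flat vector per sample.
flat_logits = rng.normal(size=(batch, n_features * n_classes))

# Same reshape as forward():
# (batch, n_features * n_classes) -> (batch, n_features, n_classes)
logits = flat_logits.reshape(-1, n_features, n_classes)

# The class axis is last, so argmax over axis=-1 gives one label per SNP.
labels = logits.argmax(axis=-1)

print(logits.shape, labels.shape)
```

Because the reshape is row-major, the logits for SNP j occupy positions j * n_classes through j * n_classes + n_classes - 1 of the flat vector. torch.Tensor.view on a contiguous tensor follows the same ordering.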
Step 2 — Write the Imputer Wrapper
Mirror the pattern in ImputeAutoencoder:
- Normalize a config (dataclass, dict, or YAML path) → concrete config
- Initialize logging via LoggerManager
- Prepare data via GenotypeEncoder (0/1/2; -1 for missing)
- Build/train/evaluate the model
- Provide fit(self) and transform(self) with no arguments
from __future__ import annotations

from typing import Any, Dict

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.exceptions import NotFittedError
from sklearn.model_selection import train_test_split
from snpio.analysis.genotype_encoder import GenotypeEncoder
from snpio.utils.logging import LoggerManager

from pgsui.impute.unsupervised.base import BaseNNImputer
from pgsui.data_processing.config import load_yaml_to_dataclass, apply_dot_overrides


# -- Config normalization helper -------------------------------------------
def ensure_my_config(config: MyNewModelConfig | dict | str | None) -> MyNewModelConfig:
    if config is None:
        return MyNewModelConfig.from_preset("balanced")
    if isinstance(config, MyNewModelConfig):
        return config
    if isinstance(config, str):
        return load_yaml_to_dataclass(
            config, MyNewModelConfig, preset_builder=MyNewModelConfig.from_preset
        )
    if isinstance(config, dict):
        base = MyNewModelConfig.from_preset(config.get("preset", "balanced"))
        # "preset" selects the base config; it is not a field of the dataclass,
        # so keep it out of the overrides.
        overrides = {k: v for k, v in config.items() if k != "preset"}
        return apply_dot_overrides(base, _flatten_dict(overrides))
    raise TypeError("config must be a MyNewModelConfig, dict, YAML path, or None.")


def _flatten_dict(d: dict, prefix: str = "", out: dict | None = None) -> dict:
    # Test for None explicitly: "out or {}" would replace a still-empty
    # accumulator during recursion and silently drop nested keys.
    if out is None:
        out = {}
    for k, v in d.items():
        kk = f"{prefix}.{k}" if prefix else k
        if isinstance(v, dict):
            _flatten_dict(v, kk, out)
        else:
            out[kk] = v
    return out
class ImputeMyNewModel(BaseNNImputer):
    """Impute missing 0/1/2 genotypes using MyNewDecoder."""

    def __init__(
        self,
        genotype_data,
        *,
        config: MyNewModelConfig | dict | str | None = None,
        overrides: dict | None = None,
    ):
        self.model_name = "ImputeMyNewModel"
        self.genotype_data = genotype_data

        cfg = ensure_my_config(config)
        if overrides:
            cfg = apply_dot_overrides(cfg, overrides)
        self.cfg = cfg

        logman = LoggerManager(
            __name__,
            prefix=self.cfg.io.prefix,
            debug=self.cfg.io.debug,
            verbose=self.cfg.io.verbose,
        )
        self.logger = logman.get_logger()

        super().__init__(
            model_name=self.model_name,
            genotype_data=self.genotype_data,
            prefix=self.cfg.io.prefix,
            device=self.cfg.train.device,
            verbose=self.cfg.io.verbose,
            debug=self.cfg.io.debug,
        )

        self.Model = MyNewDecoder
        self.pgenc = GenotypeEncoder(genotype_data)
        self.seed = self.cfg.io.seed
        self.rng = np.random.default_rng(self.seed)

        # Cache common attrs from config
        self.latent_dim = self.cfg.model.latent_dim
        self.dropout_rate = self.cfg.model.dropout_rate
        self.num_hidden_layers = self.cfg.model.num_hidden_layers
        self.activation = self.cfg.model.activation
        self.batch_size = self.cfg.train.batch_size
        self.learning_rate = self.cfg.train.learning_rate
        self.validation_split = self.cfg.train.validation_split
        self.epochs = self.cfg.train.max_epochs
        self.early_stop_gen = self.cfg.train.early_stop_gen
        self.min_epochs = self.cfg.train.min_epochs

        # Tuning flags
        self.tune = self.cfg.tune.enabled
        self.n_trials = self.cfg.tune.n_trials
        self.tune_metric = self.cfg.tune.metrics

        # Plotting
        self.plot_dpi = self.cfg.plot.dpi
        self.show_plots = self.cfg.plot.show

        # Filled in at fit()
        self.num_classes_ = None
        self.num_features_ = None
        self.class_weights_ = None
        self.best_params_: Dict[str, Any] = {}
        self.is_fit_ = False
    def fit(self) -> "ImputeMyNewModel":
        """Train the model on 0/1/2 data (with -1 for missing)."""
        self.logger.info(f"Fitting {self.model_name}...")

        # Prepare 0/1/2 matrix and mark missing as -1
        X = self.pgenc.genotypes_012.astype(np.int64, copy=True)
        X[X < 0] = -1
        n_samples, self.num_features_ = X.shape

        # Determine classes (diploid: 3; haploid collapses to 2)
        is_haploid = self.pgenc.is_haploid
        self.num_classes_ = 2 if is_haploid else 3
        if is_haploid:
            X[X == 2] = 1  # map {0,2} -> {0,1}

        # Use base class helper for splitting
        self.train_idx_, self.val_idx_, self.test_idx_ = self._train_val_test_split(X)
        X_train = X[self.train_idx_]
        X_val = X[self.val_idx_]

        # Hidden sizes from config helper on Base
        input_dim = self.num_features_ * self.num_classes_
        hidden = self._compute_hidden_layer_sizes(
            n_inputs=input_dim,
            n_outputs=self.num_classes_,
            n_samples=len(self.train_idx_),
            n_hidden=self.num_hidden_layers,
            latent_dim=self.latent_dim,
            alpha=getattr(self.cfg.model, "layer_scaling_factor", 4.0),
            schedule=self.cfg.model.layer_schedule,
        )

        self.best_params_ = {
            "n_features": self.num_features_,
            "num_classes": self.num_classes_,
            "latent_dim": self.latent_dim,
            "hidden_layer_sizes": hidden,
            "dropout_rate": self.dropout_rate,
            "activation": self.activation,
        }

        # Build model and train
        model = self.build_model(self.Model, self.best_params_)
        model.apply(self.initialize_weights)

        # NOTE: Use appropriate data loader logic here (masked vs full)
        train_loader = self._get_data_loaders(
            X_train, X_train, mask=np.ones_like(X_train, dtype=bool),
            batch_size=self.batch_size, shuffle=True,
        )
        val_loader = self._get_data_loaders(
            X_val, X_val, mask=np.ones_like(X_val, dtype=bool),
            batch_size=self.batch_size, shuffle=False,
        )

        loss, self.model_, self.history_ = self._train_and_validate_model(
            model=model,
            X_train=X_train,
            y_train=X_train,
            val_loader=val_loader,
            lr=self.learning_rate,
            l1_penalty=getattr(self.cfg.train, "l1_penalty", 0.0),
            params=self.best_params_,
        )

        self.is_fit_ = True
        return self
    def transform(self) -> np.ndarray:
        """Impute the full dataset and return IUPAC strings."""
        if not self.is_fit_:
            raise NotFittedError("Call fit() before transform().")

        X_all = self.pgenc.genotypes_012.astype(np.int64, copy=True)
        X_all[X_all < 0] = -1

        labels, _ = self._predict(self.model_, X_all)

        # Fill only the missing cells; observed genotypes are kept as-is
        miss = X_all == -1
        X_imp = X_all.copy()
        X_imp[miss] = labels[miss]
        return self.pgenc.decode_012(X_imp)
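The merge step in transform() can be checked in isolation. The sketch below uses a small hand-written genotype matrix and made-up predictions (both are assumptions for illustration, not model output) to show that only cells equal to -1 are replaced.

```python
import numpy as np

# Toy 0/1/2 matrix with -1 marking missing calls.
X_all = np.array([[0, -1, 2],
                  [1, 2, -1]])

# Hypothetical predicted labels for every cell, including observed ones.
labels = np.array([[0, 1, 1],
                   [0, 2, 0]])

miss = X_all == -1          # boolean mask of missing cells
X_imp = X_all.copy()
X_imp[miss] = labels[miss]  # overwrite the missing cells only

print(X_imp)
```

The predictions at positions (0, 2) and (1, 0) disagree with the observed genotypes, yet the observed values survive because the mask restricts the assignment to missing cells.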
Step 3 — (Optional) Hyperparameter Tuning
If your model supports Optuna tuning, mirror the ImputeAutoencoder pattern:
- _objective(self, trial) samples hyperparameters → trains quickly → returns a scalar metric (e.g., f1)
- _sample_hyperparameters(self, trial) returns a dictionary with both raw choices and a model_params payload
- _set_best_params(self, best_params) converts the winning trial into the final model_params
Keep the validation logic inside the wrapper so you can reuse PG-SUI’s scorers/plotters.
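As a rough illustration of the "raw choices plus model_params payload" shape, the sketch below writes the sampling step as a free function. The search ranges, the fixed hidden width of 64, and the names StandInTrial and sample_hyperparameters are assumptions, not PG-SUI's actual implementation. StandInTrial exists only so the block runs without Optuna installed; a real optuna.trial.Trial provides suggest_int, suggest_float, and suggest_categorical methods that accept these same positional arguments.

```python
import random
from typing import Any, Dict, Sequence


class StandInTrial:
    """Hypothetical stand-in mimicking an Optuna trial's suggest_* methods."""

    def __init__(self, seed: int = 0) -> None:
        self._rng = random.Random(seed)

    def suggest_int(self, name: str, low: int, high: int) -> int:
        return self._rng.randint(low, high)  # inclusive bounds

    def suggest_float(self, name: str, low: float, high: float) -> float:
        return self._rng.uniform(low, high)

    def suggest_categorical(self, name: str, choices: Sequence[Any]) -> Any:
        return self._rng.choice(list(choices))


def sample_hyperparameters(trial: Any, n_features: int, num_classes: int) -> Dict[str, Any]:
    """Return raw choices plus a model_params dict ready for the module."""
    raw: Dict[str, Any] = {
        "latent_dim": trial.suggest_int("latent_dim", 2, 32),
        "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 4),
        "dropout_rate": trial.suggest_float("dropout_rate", 0.0, 0.5),
        "activation": trial.suggest_categorical(
            "activation", ["relu", "elu", "selu", "leaky_relu"]
        ),
    }
    # Assumption: a fixed width per hidden layer keeps the sketch short.
    hidden = [64] * raw["num_hidden_layers"]
    raw["model_params"] = {
        "n_features": n_features,
        "num_classes": num_classes,
        "latent_dim": raw["latent_dim"],
        "hidden_layer_sizes": hidden,
        "dropout_rate": raw["dropout_rate"],
        "activation": raw["activation"],
    }
    return raw


params = sample_hyperparameters(StandInTrial(seed=0), n_features=10, num_classes=3)
print(sorted(params["model_params"]))
```

The model_params keys match the keyword arguments of MyNewDecoder from Step 1, so the payload can be passed straight to the module constructor inside an objective.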
Registering the Model (Top-Level API & CLI)
Expose your model and config in the top-level package so users can import them cleanly:
# pgsui/__init__.py
from .impute.unsupervised.imputers.my_new_model import ImputeMyNewModel
from .data_processing.containers import MyNewModelConfig
__all__ = [..., "ImputeMyNewModel", "MyNewModelConfig"]
Add the class name to the CLI’s --models registry (mirroring how the other models are discovered).
Usage Examples
Python
from snpio import VCFReader
from pgsui import ImputeMyNewModel, MyNewModelConfig
gd = VCFReader("example.vcf.gz", popmapfile="example.popmap", prefix="demo")
cfg = MyNewModelConfig.from_preset("balanced").apply_overrides({"io.prefix": "mymodel_demo"})
model = ImputeMyNewModel(genotype_data=gd, config=cfg)
model.fit()
X_imp = model.transform()
CLI
pg-sui \
--input example.vcf.gz \
--popmap example.popmap \
--models ImputeMyNewModel \
--preset balanced \
--sim-strategy random_weighted \
--sim-prop 0.25 \
--set io.prefix=mymodel_demo
The CLI propagates --sim-strategy/--sim-prop choices to every selected model, so new implementations participate automatically; --disable-simulate-missing is intended for supervised/deterministic runs (unsupervised models require simulated masking).
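To show what simulated masking means in practice, the sketch below hides a fixed proportion of the observed genotypes so they can later be scored against their true values. It uses uniform random selection for simplicity and does not reproduce the random_weighted strategy, whose weighting is not described here. The function name simulate_missing is hypothetical.

```python
import numpy as np


def simulate_missing(X: np.ndarray, prop: float, seed: int = 0):
    """Mask a proportion of observed (non -1) cells uniformly at random."""
    rng = np.random.default_rng(seed)
    observed = np.argwhere(X != -1)        # (row, col) of observed cells
    n_mask = int(prop * len(observed))     # number of cells to hide
    picks = rng.choice(len(observed), size=n_mask, replace=False)
    chosen = observed[picks]

    sim_mask = np.zeros(X.shape, dtype=bool)
    sim_mask[chosen[:, 0], chosen[:, 1]] = True

    X_masked = X.copy()
    X_masked[sim_mask] = -1                # hidden cells now look missing
    return X_masked, sim_mask


X = np.array([[0, 1, 2, -1],
              [2, -1, 0, 1],
              [1, 1, 0, 2]])
X_masked, sim_mask = simulate_missing(X, prop=0.25)
print(int(sim_mask.sum()), "cells masked")
```

Cells that were already missing are never selected, so sim_mask marks exactly the positions where a held-out true genotype is available for evaluation.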
Design Notes & Best Practices
- Shapes: Keep module output (batch, n_features, n_classes) to simplify loss computation and downstream reports.
- Missing values: Standardize to -1 before loss; use ignore_index=-1 in CE/focal variants.
- Class imbalance: Use PG-SUI’s weighting helpers (temperature/alpha/normalize) or your own, but store weights on device.
- Reproducibility: Read seeds from io.seed; pass to NumPy, PyTorch, and Optuna where applicable.
- Presets: Ensure from_preset tunes depth/width/epochs proportionally so fast is actually fast and thorough explores more.
- Plots/metrics: Reuse the Base helper methods so your model automatically participates in radar/PR/confusion outputs.
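The missing-value note can be illustrated with a plain NumPy version of cross-entropy that skips targets equal to -1. This is a sketch of the idea behind ignore_index, not PG-SUI's loss; the function name masked_cross_entropy is hypothetical, and class weighting is left out.

```python
import numpy as np


def masked_cross_entropy(logits: np.ndarray, targets: np.ndarray, ignore_index: int = -1) -> float:
    """Mean cross-entropy over targets != ignore_index.

    logits: (batch, n_features, n_classes); targets: (batch, n_features).
    """
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

    valid = targets != ignore_index
    # Replace ignored targets with a dummy class so indexing stays in range.
    safe_targets = np.where(valid, targets, 0)
    picked = np.take_along_axis(log_probs, safe_targets[..., None], axis=-1)[..., 0]

    # Average over valid positions only; ignored cells contribute nothing.
    return float(-(picked * valid).sum() / max(int(valid.sum()), 1))


logits = np.zeros((2, 3, 3))               # uniform predictions
targets = np.array([[0, -1, 2],
                    [1, 1, -1]])
loss = masked_cross_entropy(logits, targets)
print(loss)
```

With uniform logits over three classes, every valid position contributes log(3), so the mean equals log(3) regardless of how many cells are ignored. Changing the logits at an ignored position leaves the loss unchanged.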
FAQ
Q: Do I pass arrays to fit or transform?
A: No. Like ImputeAutoencoder, you pass genotype_data at construction; then call fit() and transform() with no arguments.
Q: Can my module use a full autoencoder (encoder+decoder)?
A: Yes. Expose a consistent forward that returns per-SNP logits and adapt the wrapper’s latent handling accordingly.
Q: How do I add Optuna tuning quickly?
A: Implement _objective, _sample_hyperparameters, and _set_best_params following the other models.