# Source code for pgsui.data_processing.transformers
# Standard library imports
import copy
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Optional, Tuple
# Third-party imports
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from pgsui.utils.misc import validate_input_type
if TYPE_CHECKING:
from snpio import TreeParser
[docs]
class SimGenotypeDataTransformer:
"""Simulates missing genotypes at the locus level on a 2D integer matrix.
This transformer masks a proportion of known genotypes in the input matrix X, setting them to a specified missing value. The masking can be done randomly or based on inverse genotype frequencies, with an option to boost the likelihood of masking heterozygous genotypes.
Args:
prop_missing (float): Proportion of *known* loci to mask (0..1).
strategy (Literal): Strategy name.
missing_val (int): Missing code value (default: -9).
seed (int | None): RNG seed.
logger (logging.Logger | None): Logger for messages.
het_boost (float): Multiplier for heterozygotes in inv-genotype mode.
"""
def __init__(
self,
*,
prop_missing: float = 0.1,
strategy: Literal["random", "random_inv_genotype"] = "random",
missing_val: int = -1,
seed: int | None = None,
logger: logging.Logger | None = None,
het_boost: float = 1.0,
):
self.prop_missing = float(prop_missing)
self.strategy = strategy
self.missing_val = int(missing_val)
self.seed = seed
self.rng = np.random.default_rng(seed)
self.het_boost = float(het_boost)
self.logger = logger or logging.getLogger(__name__)
[docs]
def fit(self, X, y=None) -> "SimGenotypeDataTransformer":
"""Stateless.
Args:
X (np.ndarray): (n_samples, n_features), integer codes {0..9} or <0 as missing.
y: Ignored.
"""
return self
[docs]
def transform(self, X: np.ndarray) -> tuple[np.ndarray, dict]:
"""Apply missing-data simulation on a 2D genotype matrix.
Args:
X (np.ndarray): (n_samples, n_features), integer codes {0..9} or <0 as missing.
Returns:
tuple[np.ndarray, dict]: (X_masked, masks) where masks has keys: 'original': original missing (boolean 2D). 'simulated': loci masked here (boolean 2D). 'all': union of original + simulated (boolean 2D)
"""
if X.ndim != 2:
msg = f"X must be 2D, got shape {X.shape}"
self.logger.error(msg)
raise ValueError(msg)
X = np.asarray(X)
original_mask = X < 0
sim_mask = self._simulate_missing_mask(X, original_mask)
sim_mask = sim_mask & (~original_mask)
sim_mask = self._validate_mask(sim_mask)
all_mask = original_mask | sim_mask
Xt = X.copy()
Xt[all_mask] = self.missing_val
masks = {"original": original_mask, "simulated": sim_mask, "all": all_mask}
return Xt, masks
# ---- strategies ----
def _simulate_missing_mask(
self, X: np.ndarray, original_mask: np.ndarray
) -> np.ndarray:
"""Simulate missingness mask based on the chosen strategy.
Args:
X (np.ndarray): Input genotype matrix.
original_mask (np.ndarray): Boolean mask of original missing values.
Returns:
np.ndarray: Simulated missing mask.
"""
if self.strategy == "random":
return self._simulate_random(original_mask)
elif self.strategy == "random_inv_genotype":
return self._simulate_inv_genotype(X, original_mask)
msg = "strategy must be one of {'random','random_inv_genotype'}"
self.logger.error(msg)
raise ValueError(msg)
def _simulate_random(self, original_mask: np.ndarray) -> np.ndarray:
rows, cols = np.where(~original_mask)
n_known = len(rows)
mask = np.zeros_like(original_mask, dtype=bool)
if n_known == 0:
return mask
n_to_mask = int(np.floor(self.prop_missing * n_known))
if n_to_mask <= 0:
return mask
idx = self.rng.choice(n_known, size=n_to_mask, replace=False)
mask[rows[idx], cols[idx]] = True
return mask
def _simulate_inv_genotype(
self, X: np.ndarray, original_mask: np.ndarray
) -> np.ndarray:
"""Simulate missingness mask inversely proportional to genotype frequencies.
Args:
X (np.ndarray): Input genotype matrix.
original_mask (np.ndarray): Boolean mask of original missing values.
Returns:
np.ndarray: Simulated missing mask. 0..3: homozygous (0,1,2,3). 4..9: heterozygous (0/1,0/2,0/3,1/2,1/3,2/3).
"""
rows, cols = np.where(~original_mask)
n_known = len(rows)
mask = np.zeros_like(original_mask, dtype=bool)
if n_known == 0:
return mask
# Global genotype frequencies (0..9) from all known
vals = X[~original_mask].astype(int)
vals = vals[(vals >= 0) & (vals < 10)]
if vals.size == 0:
return self._simulate_random(original_mask)
cnt = np.bincount(vals, minlength=10).astype(float)
freqs = cnt / (cnt.sum() + 1e-12)
# Candidate weights
geno_known = X[rows, cols].astype(int) # (n_known,)
inv = 1.0 / (freqs[geno_known] + 1e-12)
# Optional het boost (heterozygous codes are 4..9)
if self.het_boost != 1.0:
is_het = (geno_known >= 4) & (geno_known <= 9)
inv = inv * np.where(is_het, self.het_boost, 1.0)
n_to_mask = int(np.floor(self.prop_missing * n_known))
if n_to_mask <= 0:
return mask
probs = inv / (inv.sum() + 1e-12)
idx = self.rng.choice(n_known, size=n_to_mask, replace=False, p=probs)
mask[rows[idx], cols[idx]] = True
return mask
def _validate_mask(self, mask: np.ndarray) -> np.ndarray:
"""Avoid fully-masked rows/columns.
Args:
mask (np.ndarray): Input boolean mask.
Returns:
np.ndarray: Validated mask.
"""
rng = self.rng
# columns
full_cols = np.where(mask.all(axis=0))[0]
for c in full_cols:
r = int(rng.integers(0, mask.shape[0]))
mask[r, c] = False
# rows
full_rows = np.where(mask.all(axis=1))[0]
for r in full_rows:
c = int(rng.integers(0, mask.shape[1]))
mask[r, c] = False
return mask
[docs]
class SimMissingTransformer(BaseEstimator, TransformerMixin):
    """Simulate missing data on genotypes encoded as 0/1/2 integers.

    This transformer is designed to work with genotype data that has been
    preprocessed into a suitable format. It simulates missing data according
    to various strategies, allowing for the testing and evaluation of
    imputation methods. The simulated missing data can be controlled in terms
    of proportion and distribution across samples and loci.

    Args:
        genotype_data (GenotypeData object): GenotypeData instance. Its
            ``samples`` attribute maps tree tip labels to matrix rows.
        tree_parser (TreeParser | None): TreeParser instance with a loaded
            tree. Required for "nonrandom" and "nonrandom_weighted".
        prop_missing (float, optional): Proportion of missing data desired in
            output. Must be in the interval [0, 1]. Defaults to 0.1.
        strategy (Literal["nonrandom", "nonrandom_weighted",
            "random_weighted", "random_weighted_inv", "random"]): Strategy for
            simulating missing data.
            "random": each eligible genotype is masked independently with
            probability ``prop_missing``.
            "random_weighted": masks genotypes at random with probabilities
            derived from their observed genotype frequencies in each column
            (more common genotypes are more likely to be masked with the
            default "sqrt" transform).
            "random_weighted_inv": as above but using inverse frequencies
            (rarer genotypes are more likely to be masked).
            "nonrandom": uses the supplied tree to place missing data on clades
            sampled uniformly from internal and tip nodes (root excluded),
            producing phylogenetically clustered missingness.
            "nonrandom_weighted": as "nonrandom", but clades are sampled with
            probabilities proportional to their branch lengths.
            Defaults to "random".
        missing_val (int, optional): Value that represents missing data in the
            input; ``transform`` also writes it into masked positions of 2D
            input. Defaults to -9.
        mask_missing (bool, optional): If True, originally-missing positions
            are never selected by the simulation. If False they are eligible
            too (the "random" strategy marks all of them in ``mask_``);
            ``sim_missing_mask_`` always excludes them. Defaults to True.
        verbose (bool, optional): Verbosity level. Defaults to 0.
        tol (float | None): Tolerance, as a fraction of the eligible cells,
            within which the "nonrandom" strategies must hit the target count;
            the effective tolerance is never below one cell. If None,
            ``1 / (n_samples * n_loci)`` is used.
        max_tries (int | None): Maximum number of clade-placement attempts
            for the "nonrandom" strategies. If None,
            ``max(10_000, 10 * n_samples)`` is used.
        seed (int | None): RNG seed.
        logger (logging.Logger | None): Logger for messages.

    Attributes:
        mask_ (numpy.ndarray): Boolean mask produced by the simulation.
        original_missing_mask_ (numpy.ndarray): Array with boolean mask for
            original missing locations.
        sim_missing_mask_ (numpy.ndarray): Array with boolean mask for
            simulated missing locations, excluding the original ones.
        all_missing_mask_ (numpy.ndarray): Array with boolean mask for all
            missing locations, including both simulated and original.
    """

    def __init__(
        self,
        genotype_data,
        *,
        tree_parser: Optional["TreeParser"] = None,
        prop_missing=0.1,
        strategy="random",
        missing_val=-9,
        mask_missing=True,
        verbose=0,
        tol=None,
        max_tries=None,
        seed: Optional[int] = None,
        logger: logging.Logger | None = None,
    ) -> None:
        """Initialize the SimMissingTransformer.

        Every parameter is stored unchanged as an attribute of the same name
        (see the class docstring for their meaning); ``rng`` is created from
        ``seed``.
        """
        self.genotype_data = genotype_data
        self.tree_parser = tree_parser
        self.prop_missing = prop_missing
        self.strategy = strategy
        self.missing_val = missing_val
        self.mask_missing = mask_missing
        self.verbose = verbose
        self.tol = tol
        self.max_tries = max_tries
        self.seed = seed
        # default_rng(None) seeds from fresh OS entropy, i.e. an unseeded run.
        self.rng = np.random.default_rng(seed)
        self.logger = logger or logging.getLogger(__name__)

    def fit(self, X: np.ndarray, y=None) -> "SimMissingTransformer":
        """Fit to input data X by simulating missing data.

        Missing data will be simulated in varying ways depending on the
        ``strategy`` setting. Afterwards ``mask_``, ``original_missing_mask_``,
        ``all_missing_mask_`` and ``sim_missing_mask_`` are populated.

        Args:
            X (np.ndarray): Data with which to simulate missing data. Cells
                equal to ``missing_val`` (or NaN) are treated as originally
                missing; simulation is applied to eligible entries depending on
                ``mask_missing``.
            y: Ignored.

        Returns:
            SimMissingTransformer: The fitted transformer.

        Raises:
            TypeError: ``tree_parser`` is None (or has no ``tree``) when using
                strategy="nonrandom" or "nonrandom_weighted".
            ValueError: Invalid ``strategy`` parameter provided.
        """
        # astype() always returns a new array, so the NaN substitution below
        # never touches the caller's data.
        X = np.asarray(validate_input_type(X, return_type="array")).astype(np.float32)
        self.logger.debug(
            f"Adding {self.prop_missing} missing data per column using strategy: {self.strategy}"
        )
        if not np.isnan(self.missing_val):
            X[X == self.missing_val] = np.nan
        self.original_missing_mask_ = np.isnan(X)

        if self.strategy == "random":
            self.mask_ = self._random_mask(X)
            self._validate_mask(use_non_original_only=True)
        elif self.strategy in ("random_weighted", "random_weighted_inv"):
            self.mask_ = self.random_weighted_missing_data(
                X,
                inv=self.strategy == "random_weighted_inv",
                target_rate=self.prop_missing,
                mask_missing=self.mask_missing,
            )
        elif self.strategy.startswith("nonrandom"):
            self.mask_ = self._nonrandom_mask(X)
            self._validate_mask(use_non_original_only=self.mask_missing)
        else:
            msg = f"Invalid SimMissingTransformer.strategy value: {self.strategy}"
            self.logger.error(msg)
            raise ValueError(msg)

        # Final column check. It runs BEFORE the derived masks are computed so
        # they always agree with the (possibly adjusted) mask_.
        self._validate_mask(use_non_original_only=self.mask_missing)

        if self.mask_missing:
            # With mask_missing=True the simulation itself must not select any
            # originally-missing position.
            overlap = self.mask_ & self.original_missing_mask_
            if bool(overlap.any()):
                n = int(overlap.sum())
                msg = f"SimMissingTransformer produced {n} simulated-missing positions that overlap original missing values while mask_missing=True. This violates the no-overlap contract."
                self.logger.error(msg)
                raise ValueError(msg)

        # Get all missing values.
        self.all_missing_mask_ = np.logical_or(self.mask_, self.original_missing_mask_)
        # Positions that were observed originally but are now missing.
        self.sim_missing_mask_ = np.logical_and(
            self.all_missing_mask_, ~self.original_missing_mask_
        )
        return self

    def _random_mask(self, X: np.ndarray) -> np.ndarray:
        """Mask each originally-present cell with probability ``prop_missing``.

        When ``mask_missing`` is False, originally-missing cells are flagged
        too (``sim_missing_mask_`` excludes them later).

        Args:
            X (np.ndarray): Float input with NaN at originally-missing cells.

        Returns:
            np.ndarray: Boolean mask with the shape of ``X``.
        """
        present = ~self.original_missing_mask_
        mask = np.zeros_like(X, dtype=bool)
        # One draw per cell; only draws for present cells are used.
        draws = self.rng.random(X.shape)
        mask[present] = draws[present] < self.prop_missing
        if not self.mask_missing:
            mask[~present] = True
        return mask

    def _nonrandom_mask(self, X: np.ndarray) -> np.ndarray:
        """Place missing data on sampled tree clades until near the target.

        Each attempt samples a clade, picks a locus (preferring loci whose
        quota is not used up), and masks (or, when over target, unmasks) the
        clade's eligible rows in that locus. An attempt is kept only if it does
        not move the count further from the target.

        Args:
            X (np.ndarray): Float input with NaN at originally-missing cells.

        Returns:
            np.ndarray: Boolean mask with the shape of ``X``.

        Raises:
            ValueError: Unknown "nonrandom*" strategy name.
            TypeError: ``tree_parser`` is None or has no ``tree``.
        """
        if self.strategy not in {"nonrandom", "nonrandom_weighted"}:
            msg = f"strategy must be one of {{'nonrandom','nonrandom_weighted'}}, got: {self.strategy}"
            self.logger.error(msg)
            raise ValueError(msg)
        if self.tree_parser is None or not hasattr(self.tree_parser, "tree"):
            msg = "SimMissingTransformer.tree cannot be NoneType when strategy='nonrandom' or strategy='nonrandom_weighted'"
            self.logger.error(msg)
            raise TypeError(msg)

        weighted = self.strategy == "nonrandom_weighted"
        mask = np.zeros_like(X, dtype=bool)
        n_samples, n_loci = mask.shape

        # Cells the simulation is allowed to touch.
        present = (
            ~self.original_missing_mask_
            if self.mask_missing
            else np.ones_like(mask, dtype=bool)
        )
        total_eligible = int(present.sum())
        if total_eligible == 0:
            return mask

        target = int(round(self.prop_missing * total_eligible))
        # Absolute tolerance in cells; never below one cell.
        tol = int(
            max(
                1,
                (self.tol if self.tol is not None else 1.0 / mask.size)
                * total_eligible,
            )
        )
        max_outer = (
            self.max_tries
            if self.max_tries is not None
            else max(10_000, n_samples * 10)
        )

        placed = 0  # invariant: placed == int(mask.sum())
        if abs(placed - target) <= tol:
            return mask

        # Collect candidate nodes once instead of re-traversing the tree on
        # every attempt (traversal does not consume the RNG).
        try:
            keys, weights = self._collect_tree_nodes(
                internal_only=False, tips_only=False, skip_root=True
            )
        except ValueError:
            # No eligible nodes: every attempt would fail, nothing is placed.
            return mask
        sample_set = set(self.genotype_data.samples)
        # map tip labels -> row indices
        name_to_idx = {name: i for i, name in enumerate(self.genotype_data.samples)}

        best_delta = abs(placed - target)
        # simple per-locus quota to distribute hits
        col_quota = np.full(
            n_loci,
            max(1, int(np.ceil(target / max(1, n_loci)))),
            dtype=int,
        )

        tries = 0
        while tries < max_outer and abs(placed - target) > tol:
            tries += 1
            try:
                tips = self._draw_tips(
                    keys, weights, sample_set, weighted=weighted, rng=self.rng
                )
            except ValueError:
                # no sampled clade intersects the samples; try again
                continue

            # Convert to row indices; skip labels not in matrix
            rows = [name_to_idx[t] for t in tips if t in name_to_idx]
            if not rows:
                continue

            cols_left = np.flatnonzero(col_quota > 0)
            if cols_left.size == 0:
                cols_left = np.arange(n_loci)
            j = int(self.rng.choice(cols_left))

            # only edit eligible cells in this column
            eligible_rows = np.fromiter(
                (r for r in rows if present[r, j]), dtype=int
            )
            if eligible_rows.size == 0:
                continue

            adding = placed < target
            prev_col = mask[:, j].copy()
            if adding:
                mask[eligible_rows, j] = True
                # avoid a fully missing column among eligible cells
                observed = np.flatnonzero(present[:, j])
                if mask[observed, j].all():
                    mask[int(self.rng.choice(observed)), j] = False
            else:
                # remove within the same clade and column
                col_idxs = eligible_rows[mask[eligible_rows, j]]
                if col_idxs.size == 0:
                    continue
                need = min(col_idxs.size, max(1, placed - target))
                to_clear = self.rng.choice(col_idxs, size=need, replace=False)
                mask[to_clear, j] = False

            # Only column j changed, so update the count incrementally.
            new_placed = placed + int(mask[:, j].sum()) - int(prev_col.sum())
            delta = abs(new_placed - target)
            if delta <= best_delta:
                best_delta = delta
                placed = new_placed
                if adding:
                    col_quota[j] = max(0, col_quota[j] - 1)
            else:
                mask[:, j] = prev_col
        return mask

    def transform(self, X: np.ndarray) -> np.ndarray:
        """Return a copy of ``X`` with the positions in ``mask_`` set missing.

        Requires ``mask_`` (set by ``fit`` or ``read_mask``).

        Args:
            X (np.ndarray): 2D (n_samples, n_loci) data, or 3D one-hot data
                (n_samples, n_loci, n_channels). It should have already been
                imputed with one of the non-machine learning simple imputers.

        Returns:
            np.ndarray: float32 data with masked 2D positions set to
            ``missing_val`` or masked 3D positions zeroed in every channel.

        Raises:
            ValueError: If ``X`` is neither 2D nor 3D.
        """
        X = np.asarray(validate_input_type(X, return_type="array")).astype("float32")
        return self._mask_snps(X)

    def sqrt_transform(self, proportions: np.ndarray) -> np.ndarray:
        """Apply the square root transformation to an array of proportions.

        Args:
            proportions (np.ndarray): An array of proportions.

        Returns:
            np.ndarray: The transformed proportions.
        """
        return np.sqrt(proportions)

    def random_weighted_missing_data(
        self,
        X: np.ndarray,
        transform_fn: Literal["sqrt", "exp"] = "sqrt",
        power: float = 0.5,
        inv: bool = False,
        rng: Optional[np.random.Generator] = None,
        target_rate: float | None = None,
        *,
        mask_missing: bool = True,
    ) -> np.ndarray:
        """Simulate missing data weighted by per-column genotype frequencies.

        For each column, genotype classes among eligible cells get weights
        ``transform(base) ** power`` (normalized), where ``base`` is the class
        frequency, or its inverse when ``inv`` is True. Each cell is then
        masked independently with its class weight as probability. Columns
        with a single class are never masked, and one observed cell is always
        kept unmasked per column.

        Args:
            X (np.ndarray): Float genotype matrix with NaN at missing cells.
            transform_fn (Literal["sqrt", "exp"]): ``"sqrt"`` applies
                ``sqrt(x)``; ``"exp"`` applies ``exp(-x)``.
            power (float): Exponent to raise transformed probabilities.
            inv (bool): If True, use inverse genotype frequencies. If False,
                use direct frequencies to weight missingness.
            rng (Optional[np.random.Generator]): Optional NumPy Generator for
                reproducibility; defaults to ``self.rng``.
            target_rate (float | None): If provided, scales each column's
                probabilities so their mean over observed cells is this rate
                (then clipped to [0, 1]).
            mask_missing (bool): If True, only observed (non-NaN) cells are
                eligible.

        Returns:
            np.ndarray: Boolean mask with the shape of ``X``.

        Raises:
            ValueError: If ``transform_fn`` is not "sqrt" or "exp".
        """
        rng = rng if rng is not None else self.rng
        tf = transform_fn.lower()
        if tf not in {"sqrt", "exp"}:
            msg = f"transform_fn must be 'sqrt' or 'exp', got: {transform_fn}"
            self.logger.error(msg)
            raise ValueError(msg)

        eps = 1e-12

        def _tf(arr: np.ndarray) -> np.ndarray:
            arr = np.clip(arr, eps, None)
            return np.sqrt(arr) if tf == "sqrt" else np.exp(-arr)

        n_samples, n_snps = X.shape
        out_mask = np.zeros((n_samples, n_snps), dtype=bool)
        for j in range(n_snps):
            col = X[:, j]
            present = ~np.isnan(col)
            eligible = present if mask_missing else np.ones_like(present, dtype=bool)
            if not np.any(eligible):
                continue

            classes, counts = np.unique(col[eligible], return_counts=True)
            if classes.size == 1:  # never wipe entire column
                continue

            p = counts.astype(float) / counts.sum()
            base = 1.0 / np.clip(p, eps, None) if inv else p
            w = np.clip(_tf(base), 0.0, None) ** power
            s = w.sum()
            if s <= 0 or not np.isfinite(s):
                w = np.full_like(w, 1.0 / w.size, dtype=float)
            else:
                w = w / s

            probs = np.zeros(n_samples, dtype=float)
            for c, pw in zip(classes, w):
                probs[eligible & (col == c)] = pw

            # Skip rescaling when there are no observed cells (empty mean).
            if target_rate is not None and present.any():
                mean_p = probs[present].mean()
                if mean_p > 0:
                    probs *= float(target_rate) / mean_p
            probs = np.clip(probs, 0.0, 1.0)

            out_mask[:, j] = rng.random(n_samples) < probs
            if mask_missing:
                out_mask[~present, j] = False  # never alter already-missing

            # guard against wiping this column (observed entries only); an
            # empty selection has nothing to protect.
            col_after = out_mask[present, j]
            if col_after.size and col_after.all():
                k = rng.integers(0, col_after.size)
                out_mask[np.flatnonzero(present)[k], j] = False
        return out_mask

    def _collect_tree_nodes(
        self,
        internal_only: bool = False,
        tips_only: bool = False,
        skip_root: bool = True,
    ) -> tuple[list, np.ndarray]:
        """Collect sampleable tree nodes with their branch-length weights.

        Args:
            internal_only (bool): Keep only internal nodes.
            tips_only (bool): Keep only tip nodes.
            skip_root (bool): Exclude the root.

        Returns:
            tuple[list, np.ndarray]: Node keys (``node.idx`` when available,
            otherwise the node object) and matching branch lengths, with
            missing, non-finite or negative lengths set to 0.

        Raises:
            ValueError: If both filters are True or no node is eligible.
            TypeError: If ``tree_parser`` is None or has no ``tree``.
        """
        if tips_only and internal_only:
            msg = "tips_only and internal_only cannot both be True"
            self.logger.error(msg)
            raise ValueError(msg)
        if self.tree_parser is None or not hasattr(self.tree_parser, "tree"):
            msg = "SimMissingTransformer.tree cannot be NoneType when strategy='nonrandom' or strategy='nonrandom_weighted'"
            self.logger.error(msg)
            raise TypeError(msg)

        tree = self.tree_parser.tree
        node_dict: dict[int | object, float] = {}
        for node in tree.treenode.traverse("preorder"):
            # Root detection: prefer is_root(), then a missing parent, and
            # only then idx == nnodes - 1.
            if hasattr(node, "is_root"):
                is_root = bool(node.is_root())
            elif getattr(node, "up", None) is None:
                is_root = True
            elif hasattr(tree, "nnodes") and hasattr(node, "idx"):
                is_root = node.idx == tree.nnodes - 1
            else:
                is_root = False

            if skip_root and is_root:
                continue
            if tips_only and not node.is_leaf():
                continue
            if internal_only and node.is_leaf():
                continue

            dist = float(getattr(node, "dist", 0.0) or 0.0)
            # Negative or non-finite lengths get zero weight; negative weights
            # would make weighted sampling raise on every draw.
            if not np.isfinite(dist) or dist < 0.0:
                dist = 0.0
            # Use node.idx if stable, else the node object as key
            node_dict[getattr(node, "idx", node)] = dist

        if not node_dict:
            msg = "No eligible nodes found to sample from the tree."
            self.logger.error(msg)
            raise ValueError(msg)

        keys = list(node_dict)
        weights = np.asarray(list(node_dict.values()), dtype=float)
        return keys, weights

    def _draw_tips(
        self,
        keys: list,
        weights: np.ndarray,
        sample_set: set,
        *,
        weighted: bool,
        rng: np.random.Generator,
    ) -> list[str]:
        """Draw nodes until one has descendant tips present in ``sample_set``.

        Args:
            keys (list): Node keys from ``_collect_tree_nodes``.
            weights (np.ndarray): Matching branch-length weights.
            sample_set (set): Sample IDs present in the matrix.
            weighted (bool): Sample proportionally to ``weights`` (falls back
                to uniform when they sum to 0).
            rng (np.random.Generator): Random generator.

        Returns:
            list[str]: Tip labels under the drawn node that are in
            ``sample_set``.

        Raises:
            ValueError: If no draw yields matching tips within
                ``max(1, 3 * len(keys))`` attempts.
        """
        total = weights.sum()
        p = weights / total if (weighted and total > 0.0) else None
        tree = self.tree_parser.tree
        last_error: Optional[Exception] = None

        for _ in range(max(1, len(keys) * 3)):
            # Draw an index instead of the key itself so node objects are
            # never coerced into a NumPy array.
            chosen_key = keys[int(rng.choice(len(keys), p=p))]

            # 1. Resolve chosen_key to a Node object
            try:
                if isinstance(chosen_key, (int, np.integer)):
                    node = tree[int(chosen_key)]
                else:
                    node = chosen_key
            except Exception as e:  # best effort: try another node
                last_error = e
                continue

            # 2. Retrieve leaves for this specific node
            if not hasattr(node, "get_leaves"):
                last_error = TypeError(
                    f"Object {type(node)} does not have a get_leaves method."
                )
                continue
            try:
                tips = [leaf.name for leaf in node.get_leaves()]  # type: ignore
            except Exception as e:  # best effort: try another node
                last_error = e
                continue

            # Filter to sample IDs present in the matrix
            tips = [t for t in tips if t in sample_set]
            if tips:
                return tips

        msg = (
            "No sampled clades contain tips present in genotype_data.samples. "
            "Check that tree tip names match the genotype_data samples."
        )
        self.logger.error(msg)
        if last_error:
            raise ValueError(msg) from last_error
        raise ValueError(msg)

    def _sample_tree(
        self,
        internal_only: bool = False,
        tips_only: bool = False,
        skip_root: bool = True,
        weighted: bool = False,
        rng: Optional[np.random.Generator] = None,
    ) -> list[str]:
        """Sample a node and return descendant tip labels.

        Sampling can be restricted to internal nodes or tips, can exclude the
        root, and can be weighted by branch length.

        Args:
            internal_only (bool): Sample only internal nodes.
            tips_only (bool): Sample only tip nodes.
            skip_root (bool): Exclude the root from sampling.
            weighted (bool): Weight node sampling by branch length.
            rng (Optional[np.random.Generator]): Optional NumPy Generator for
                reproducibility; defaults to ``self.rng``.

        Returns:
            list[str]: Tip labels under the sampled node that are present in
            ``genotype_data.samples``.

        Raises:
            ValueError: If no eligible nodes exist, no clade matches the
                samples, or both tips_only and internal_only are True.
            TypeError: If ``tree_parser`` is None or has no ``tree``.
        """
        rng = rng if rng is not None else self.rng
        keys, weights = self._collect_tree_nodes(
            internal_only=internal_only, tips_only=tips_only, skip_root=skip_root
        )
        return self._draw_tips(
            keys,
            weights,
            set(self.genotype_data.samples),
            weighted=weighted,
            rng=rng,
        )

    def _validate_mask(self, use_non_original_only: bool = False) -> None:
        """Ensure no column of ``mask_`` is entirely masked on considered rows.

        For every column whose considered entries are all masked, one of them
        (chosen with ``self.rng``) is unmasked in place.

        Args:
            use_non_original_only (bool): If True, only consider
                non-original-missing entries when validating. Defaults to False.
        """
        m = self.mask_
        for j in range(m.shape[1]):
            if use_non_original_only:
                obs = ~self.original_missing_mask_[:, j]
            else:
                obs = np.ones(m.shape[0], dtype=bool)
            if not np.any(obs):
                continue
            if m[obs, j].all():
                # clear one random considered index
                idxs = np.flatnonzero(obs)
                k = self.rng.integers(0, idxs.size)
                m[idxs[k], j] = False

    def _mask_snps(self, X: np.ndarray) -> np.ndarray:
        """Return a copy of ``X`` with ``mask_`` positions set to missing.

        2D input gets ``missing_val``; 3D (one-hot) input gets all channels
        zeroed at masked positions.

        Raises:
            ValueError: If ``X`` is neither 2D nor 3D.
        """
        if X.ndim == 3:
            # One-hot encoded: zero-out all channels at masked positions
            mask_val = np.zeros((X.shape[-1],), dtype=X.dtype)
        elif X.ndim == 2:
            # 012-encoded.
            mask_val = float(self.missing_val)
        else:
            msg = f"Invalid shape of input X: {X.shape}"
            self.logger.error(msg)
            raise ValueError(msg)
        Xt = X.copy()
        Xt[self.mask_ != 0] = mask_val
        return Xt

    def write_mask(self, filename_prefix: str):
        """Write mask to file.

        Saves ``mask_`` to ``<prefix>_mask.npy`` and
        ``original_missing_mask_`` to ``<prefix>_original_missing_mask.npy``
        with ``np.save``.

        Args:
            filename_prefix (str): Prefix for the filenames to write to.
        """
        np.save(filename_prefix + "_mask.npy", self.mask_)
        np.save(
            filename_prefix + "_original_missing_mask.npy",
            self.original_missing_mask_,
        )

    def read_mask(
        self, filename_prefix: str
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Read mask from file and rebuild the derived masks.

        Args:
            filename_prefix (str): Prefix for the filenames to read from.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: The read masks.
            (mask, original_missing_mask, all_missing_mask).

        Raises:
            FileNotFoundError: If either mask file does not exist.
        """
        # Check if files exist
        if not Path(filename_prefix + "_mask.npy").is_file():
            msg = filename_prefix + "_mask.npy" + " does not exist."
            self.logger.error(msg)
            raise FileNotFoundError(msg)
        if not Path(filename_prefix + "_original_missing_mask.npy").is_file():
            msg = filename_prefix + "_original_missing_mask.npy" + " does not exist."
            self.logger.error(msg)
            raise FileNotFoundError(msg)

        # Load mask from file
        self.mask_ = np.load(filename_prefix + "_mask.npy")
        self.original_missing_mask_ = np.load(
            filename_prefix + "_original_missing_mask.npy"
        )
        # Recompute every derived mask so no attribute from a previous fit
        # is left stale.
        self.all_missing_mask_ = np.logical_or(self.mask_, self.original_missing_mask_)
        self.sim_missing_mask_ = np.logical_and(
            self.all_missing_mask_, ~self.original_missing_mask_
        )
        return self.mask_, self.original_missing_mask_, self.all_missing_mask_

    @property
    def missing_count(self) -> int:
        """Count of masked genotypes in ``mask_``.

        Returns:
            int: Integer count of masked alleles.
        """
        return int(np.sum(self.mask_))

    @property
    def prop_missing_real(self) -> float:
        """Proportion of genotypes masked in ``mask_``.

        Returns:
            float: Total number of masked alleles divided by SNP matrix size.
        """
        return float(np.sum(self.mask_) / self.mask_.size)