import logging
import warnings
from pathlib import Path
from typing import Dict, List, Literal, Mapping, Optional, Sequence, cast
import matplotlib as mpl
# Use Agg backend for headless plotting
mpl.use("Agg")
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import seaborn as sns
import torch
from optuna.exceptions import ExperimentalWarning
from scipy.spatial.distance import jensenshannon
from sklearn.metrics import (
ConfusionMatrixDisplay,
auc,
average_precision_score,
confusion_matrix,
precision_recall_curve,
roc_curve,
)
from sklearn.preprocessing import label_binarize
from snpio import SNPioMultiQC
from snpio.utils.logging import LoggerManager
from pgsui.utils import misc
from pgsui.utils.logging_utils import configure_logger
# Quiet Matplotlib/fontTools INFO logging when saving PDF/SVG
# fontTools and Matplotlib's font manager emit a burst of INFO records every
# time a PDF/SVG figure is written; cap them at WARNING and stop the records
# from bubbling up to the root logger.
for name in (
    "fontTools",
    "fontTools.subset",
    "fontTools.ttLib",
    "matplotlib.font_manager",
):
    lg = logging.getLogger(name)
    lg.propagate = False
    lg.setLevel(logging.WARNING)
[docs]
class Plotting:
"""Class for plotting imputer scoring and results.
This class is used to plot the performance metrics of imputation models. It can plot ROC and Precision-Recall curves, model history, and the distribution of genotypes in the dataset.
Example:
>>> from pgsui import Plotting
>>> plotter = Plotting(model_name="ImputeVAE", prefix="pgsui_test", plot_format="png")
>>> plotter.plot_metrics(y_true, y_pred_proba, metrics)
>>> plotter.plot_history(history)
>>> plotter.plot_confusion_matrix(y_true_1d, y_pred_1d)
>>> plotter.plot_tuning(study, model_name, optimize_dir, target_name="Objective Value")
>>> plotter.plot_gt_distribution(df)
Attributes:
model_name (str): Name of the model.
prefix (str): Prefix for the output directory.
plot_format (Literal["pdf", "png", "jpeg", "jpg", "svg"]): Format for the plots ('pdf', 'png', 'jpeg', 'jpg', 'svg').
plot_fontsize (int): Font size for the plots.
plot_dpi (int): Dots per inch for the plots.
title_fontsize (int): Font size for the plot titles.
show_plots (bool): Whether to display the plots inline or during execution.
output_dir (Path): Directory where plots will be saved.
logger (logging.Logger): Logger instance for logging messages.
"""
[docs]
def __init__(
    self,
    model_name: str,
    *,
    prefix: str = "pgsui",
    plot_format: Literal["pdf", "png", "jpeg", "jpg", "svg"] = "pdf",
    plot_fontsize: int = 18,
    plot_dpi: int = 300,
    title_fontsize: int = 20,
    despine: bool = True,
    show_plots: bool = False,
    verbose: int = 0,
    debug: bool = False,
    multiqc: bool = False,
    multiqc_section: Optional[str] = None,
) -> None:
    """Initialize the Plotting object.

    Configures logging, stores the plot settings, applies them to the global
    Matplotlib rcParams and creates the output directory
    ``<prefix>_output/<category>/plots/<model_name>``.

    Args:
        model_name (str): Name of the model. Must be one of the recognized unsupervised, deterministic or supervised imputers.
        prefix (str, optional): Prefix for the output directory. Defaults to 'pgsui'.
        plot_format (Literal["pdf", "png", "jpeg", "jpg", "svg"]): Format for the plots. A leading '.' is stripped. Defaults to 'pdf'.
        plot_fontsize (int): Font size for the plots. Defaults to 18.
        plot_dpi (int): Dots per inch for the plots. Defaults to 300.
        title_fontsize (int): Font size for the plot titles. Defaults to 20.
        despine (bool): Whether to remove the top and right spines from the plots. Defaults to True.
        show_plots (bool): Whether to display the plots. Defaults to False.
        verbose (int): Verbosity level for logging. Defaults to 0.
        debug (bool): Whether to enable debug mode. Defaults to False.
        multiqc (bool): Whether to queue plots for a MultiQC HTML report. Defaults to False.
        multiqc_section (Optional[str]): Section name to use in MultiQC. Defaults to 'PG-SUI (<model_name>)'.

    Raises:
        ValueError: If ``model_name`` is not a recognized model.
    """
    logman = LoggerManager(
        name=__name__, prefix=prefix, verbose=bool(verbose), debug=bool(debug)
    )
    self.logger = configure_logger(
        logman.get_logger(), verbose=bool(verbose), debug=bool(debug)
    )

    # Resolve the output category first so that an unknown model fails
    # before any global Matplotlib state is touched.
    unsuper = {"ImputeVAE", "ImputeNLPCA", "ImputeAutoencoder", "ImputeUBP"}
    det = {
        "ImputeRefAllele",
        "ImputeMostFrequent",
        "ImputeMostFrequentPerPop",
        "ImputePhylo",
    }
    sup = {"ImputeRandomForest", "ImputeHistGradientBoosting"}

    if model_name in unsuper:
        plot_dir = "Unsupervised"
    elif model_name in det:
        plot_dir = "Deterministic"
    elif model_name in sup:
        plot_dir = "Supervised"
    else:
        msg = f"model_name '{model_name}' not recognized."
        self.logger.error(msg)
        raise ValueError(msg)

    self.model_name = model_name
    self.prefix = prefix
    # Accept ".png" as well as "png".
    self.plot_format = plot_format.lstrip(".")
    self.plot_fontsize = plot_fontsize
    self.plot_dpi = plot_dpi
    self.title_fontsize = title_fontsize
    self.show_plots = show_plots

    # MultiQC configuration
    self.use_multiqc: bool = bool(multiqc)
    self.multiqc_section: str = (
        multiqc_section if multiqc_section is not None else f"PG-SUI ({model_name})"
    )

    # ``axes.spines.<side>`` is a *visibility* flag, so removing the spines
    # (despine=True) means setting it to False.
    show_spines = not despine

    self.param_dict = {
        "axes.labelsize": self.plot_fontsize,
        "axes.titlesize": self.title_fontsize,
        "axes.spines.top": show_spines,
        "axes.spines.right": show_spines,
        "xtick.labelsize": self.plot_fontsize,
        "ytick.labelsize": self.plot_fontsize,
        "legend.fontsize": self.plot_fontsize,
        "legend.facecolor": "white",
        "figure.titlesize": self.title_fontsize,
        "figure.dpi": self.plot_dpi,
        "figure.facecolor": "white",
        "axes.linewidth": 2.0,
        "lines.linewidth": 2.0,
        "font.size": self.plot_fontsize,
        "savefig.bbox": "tight",
        "savefig.facecolor": "white",
        "savefig.dpi": self.plot_dpi,
        # TrueType (Type 42) fonts keep text editable in PDF/PS output.
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }

    # NOTE: this mutates process-wide Matplotlib settings.
    mpl.rcParams.update(self.param_dict)

    self.output_dir = Path(f"{self.prefix}_output", plot_dir) / "plots" / model_name
    self.output_dir.mkdir(parents=True, exist_ok=True)
# --------------------------------------------------------------------- #
# Core plotting methods #
# --------------------------------------------------------------------- #
[docs]
def plot_tuning(
    self,
    study: optuna.study.Study,
    model_name: str,
    optimize_dir: Path,
    target_name: str = "Objective Value",
) -> None:
    """Write the standard Optuna diagnostic figures for a finished study.

    Four figures are saved into ``optimize_dir`` in ``<plot_format>`` format:
    optimization history, empirical distribution function, parameter
    importances and the trial timeline. When the history plot rejects the
    study with a ``ValueError`` the study is treated as multi-objective and
    the first objective value is plotted instead. If MultiQC is enabled the
    tuning results are queued for the report as well.

    Args:
        study (optuna.study.Study): Optuna study object.
        model_name (str): Name of the model (used in plot titles).
        optimize_dir (Path): Directory to save the optimization plots.
        target_name (str): Name of the target value. Defaults to 'Objective Value'.
    """
    optuna_mpl = optuna.visualization.matplotlib

    with warnings.catch_warnings():
        for category in (UserWarning, FutureWarning, ExperimentalWarning):
            warnings.filterwarnings("ignore", category=category)

        target_name = target_name.title()

        # --- Optimization history (also detects multi-objective studies) ---
        try:
            ax = optuna_mpl.plot_optimization_history(study, target_name=target_name)
        except ValueError:
            ax = optuna_mpl.plot_optimization_history(
                study, target_name=target_name, target=lambda t: t.values[0]
            )
            is_multi_obj = True
        else:
            is_multi_obj = False

        # Extra keyword needed by the remaining plots for multi-objective
        # studies: always report the first objective.
        target_kwargs = {"target": lambda t: t.values[0]} if is_multi_obj else {}
        default_legend_size = mpl.rcParamsDefault["legend.fontsize"]

        ax.set_title(f"{model_name} Optimization History")
        ax.set_xlabel("Trial")
        ax.set_ylabel(target_name)
        ax.legend(
            loc="best", shadow=True, fancybox=True, fontsize=default_legend_size
        )

        od = optimize_dir
        fn = od / f"optuna_optimization_history.{self.plot_format}"
        if not fn.parent.exists():
            fn.parent.mkdir(parents=True, exist_ok=True)

        plt.savefig(fn)
        plt.close()

        # --- Empirical distribution function ---
        ax = optuna_mpl.plot_edf(study, target_name=target_name, **target_kwargs)
        ax.set_title(f"{model_name} Empirical Distribution Function (EDF)")
        ax.set_xlabel(target_name)
        ax.set_ylabel(f"{model_name} Cumulative Probability")
        ax.legend(
            loc="best", shadow=True, fancybox=True, fontsize=default_legend_size
        )
        plt.savefig(fn.with_stem("optuna_edf_plot"))
        plt.close()

        # --- Parameter importances ---
        ax = optuna_mpl.plot_param_importances(
            study, target_name=target_name, **target_kwargs
        )
        ax.set_xlabel("Importance")
        ax.set_ylabel("Parameter")
        ax.legend(loc="best", shadow=True, fancybox=True)
        plt.savefig(fn.with_stem("optuna_param_importances_plot"))
        plt.close()

        # --- Trial timeline ---
        ax = optuna_mpl.plot_timeline(study)
        ax.set_title(f"{model_name} Timeline Plot")
        ax.set_xlabel("Datetime")
        ax.set_ylabel("Trial")
        plt.savefig(fn.with_stem("optuna_timeline_plot"))
        plt.close()

        # Optuna's plotting helpers change the active style; restore ours.
        sns.set_style("ticks", rc=self.param_dict)
        mpl.rcParams.update(self.param_dict)
        plt.rcParams.update(self.param_dict)

    # ---- MultiQC: Optuna tuning line graph + best-params table --------
    if self._multiqc_enabled():
        try:
            self._queue_multiqc_tuning(
                study=study, model_name=model_name, target_name=target_name
            )
        except Exception as exc:  # pragma: no cover - defensive
            self.logger.warning(f"Failed to queue MultiQC tuning plots: {exc}")
[docs]
def plot_metrics(
    self,
    y_true: np.ndarray,
    y_pred_proba: np.ndarray,
    metrics: Dict[str, float],
    label_names: Optional[Sequence[str]] = None,
    prefix: str = "",
) -> None:
    """Plot multi-class ROC-AUC and Precision-Recall curves.

    Per-class (one-vs-rest) and macro-averaged ROC and Precision-Recall curves
    are drawn side by side and saved to disk as a ``<plot_format>`` file. If
    MultiQC is enabled the metrics and curves are queued for the report too.

    Args:
        y_true (np.ndarray): 1D array of true integer labels in [0, n_classes-1].
        y_pred_proba (np.ndarray): (n_samples, n_classes) array of predicted probabilities.
        metrics (Dict[str, float]): Dict of summary metrics to annotate the figure.
        label_names (Optional[Sequence[str]]): Optional sequence of class names (length must equal n_classes).
            If provided, legends will use these names instead of 'Class i'. A sequence of the wrong length is ignored with a warning.
        prefix (str): Optional prefix for the output filename.

    Raises:
        ValueError: If ``y_pred_proba`` is not 2D, has fewer than two classes, its row count differs from ``len(y_true)``, or ``y_true`` contains labels outside [0, n_classes-1].
    """
    y_pred_proba = np.asarray(y_pred_proba)
    if y_pred_proba.ndim != 2:
        msg = "plot_metrics: y_pred_proba must be a 2D (n_samples, n_classes) array."
        self.logger.error(msg)
        raise ValueError(msg)

    num_classes = y_pred_proba.shape[1]
    if num_classes < 2:
        msg = "plot_metrics: num_classes must be >= 2 for ROC/PR curves."
        self.logger.error(msg)
        raise ValueError(msg)

    # Validate/normalize label names
    if label_names is not None and len(label_names) != num_classes:
        self.logger.warning(
            f"plot_metrics: len(label_names)={len(label_names)} "
            f"!= n_classes={num_classes}. Ignoring label_names."
        )
        label_names = None
    if label_names is None:
        label_names = [f"Class {i}" for i in range(num_classes)]

    y_true = np.asarray(y_true, dtype=int).ravel()

    if y_true.shape[0] != y_pred_proba.shape[0]:
        msg = (
            f"plot_metrics: y_true has {y_true.shape[0]} samples but "
            f"y_pred_proba has {y_pred_proba.shape[0]} rows."
        )
        self.logger.error(msg)
        raise ValueError(msg)

    # Negative labels (e.g. a -1 missing-data code) would silently index
    # np.eye from the end and be counted as the wrong class, so reject any
    # out-of-range label explicitly.
    if y_true.size and (y_true.min() < 0 or y_true.max() >= num_classes):
        msg = (
            "plot_metrics: y_true must contain integer labels in "
            f"[0, {num_classes - 1}]."
        )
        self.logger.error(msg)
        raise ValueError(msg)

    # Dense one-hot encoding of shape (N, num_classes). np.eye indexing keeps
    # the shape consistent for the binary case, unlike label_binarize.
    y_true_bin = np.eye(num_classes, dtype=int)[y_true]

    # Containers keyed by class index, plus "macro" for the averages.
    fpr, tpr, roc_auc_vals = {}, {}, {}
    precision, recall, average_precision_vals = {}, {}, {}

    # Per-class (one-vs-rest) ROC & PR
    for i in range(num_classes):
        truth_i, score_i = y_true_bin[:, i], y_pred_proba[:, i]
        fpr[i], tpr[i], _ = roc_curve(truth_i, score_i)
        roc_auc_vals[i] = auc(fpr[i], tpr[i])
        precision[i], recall[i], _ = precision_recall_curve(truth_i, score_i)
        average_precision_vals[i] = average_precision_score(truth_i, score_i)

    # Macro-average ROC: mean of the per-class curves on a shared FPR grid.
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(num_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(num_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= num_classes
    fpr["macro"], tpr["macro"] = all_fpr, mean_tpr
    roc_auc_vals["macro"] = auc(fpr["macro"], tpr["macro"])

    # Macro-average PR on a shared recall grid.
    all_recall = np.unique(np.concatenate([recall[i] for i in range(num_classes)]))
    mean_precision = np.zeros_like(all_recall)
    for i in range(num_classes):
        # precision_recall_curve returns recall in decreasing order, while
        # np.interp needs increasing x values, hence the reversal.
        mean_precision += np.interp(all_recall, recall[i][::-1], precision[i][::-1])
    mean_precision /= num_classes
    # Store the macro PR curve alongside the per-class ones, mirroring
    # fpr/tpr, so downstream consumers can find it under the "macro" key.
    recall["macro"], precision["macro"] = all_recall, mean_precision
    average_precision_vals["macro"] = average_precision_score(
        y_true_bin, y_pred_proba, average="macro"
    )

    # Plot
    big_font = self.plot_fontsize + 10
    legend_kwargs = dict(
        loc="upper center",
        bbox_to_anchor=(0.5, -0.15),
        fancybox=True,
        shadow=True,
        ncol=2,
        fontsize=self.plot_fontsize + 5,
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    roc_ax, pr_ax = axes[0], axes[1]

    # ROC
    roc_ax.plot(
        fpr["macro"],
        tpr["macro"],
        label=f"Macro (ROC-AUC={roc_auc_vals['macro']:.2f})",
        linestyle="--",
        linewidth=4,
    )
    for i, class_name in enumerate(label_names):
        roc_ax.plot(
            fpr[i],
            tpr[i],
            label=f"{class_name} (ROC-AUC={roc_auc_vals[i]:.2f})",
        )
    roc_ax.plot(
        [0, 1], [0, 1], linestyle="--", color="black", label="Random Baseline"
    )
    roc_ax.set_xlabel("False Positive Rate", fontsize=big_font)
    roc_ax.set_ylabel("True Positive Rate", fontsize=big_font)
    roc_ax.tick_params(axis="both", which="major", labelsize=big_font)
    roc_ax.set_title("Multi-class ROC-AUC Curves", fontsize=big_font)
    roc_ax.legend(**legend_kwargs)

    # Precision-recall
    pr_ax.plot(
        all_recall,
        mean_precision,
        label=f"Macro (AP={average_precision_vals['macro']:.2f})",
        linestyle="--",
        linewidth=4,
    )
    for i, class_name in enumerate(label_names):
        pr_ax.plot(
            recall[i],
            precision[i],
            label=f"{class_name} (AP={average_precision_vals[i]:.2f})",
        )
    # NOTE(review): the anti-diagonal is kept from the original figure; a
    # no-skill PR baseline is normally a horizontal line at the class
    # prevalence — confirm which is intended.
    pr_ax.plot(
        [0, 1], [1, 0], linestyle="--", color="black", label="Random Baseline"
    )
    pr_ax.set_xlabel("Recall", fontsize=big_font)
    pr_ax.set_ylabel("Precision", fontsize=big_font)
    pr_ax.tick_params(axis="both", which="major", labelsize=big_font)
    pr_ax.set_title("Multi-class Precision-Recall Curves", fontsize=big_font)
    pr_ax.legend(**legend_kwargs)

    # Title & save
    fig.suptitle(
        "\n".join([f"{k}: {v:.2f}" for k, v in metrics.items()]),
        fontsize=self.title_fontsize + 8,
        y=1.45,
    )

    prefix_for_name = f"{prefix}_" if prefix != "" else ""
    out_name = (
        f"{self.model_name}_{prefix_for_name}roc_pr_curves.{self.plot_format}"
    )
    fig.savefig(self.output_dir / out_name, bbox_inches="tight")

    if self.show_plots:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            plt.show()
    plt.close(fig)

    # ---- MultiQC: metrics table + per-class AUC/AP heatmap ------------
    if self._multiqc_enabled():
        try:
            self._queue_multiqc_metrics(
                metrics=metrics,
                roc_auc=roc_auc_vals,
                average_precision=average_precision_vals,
                label_names=label_names,
                panel_prefix=prefix,
            )
        except Exception as exc:  # pragma: no cover - defensive
            self.logger.warning(f"Failed to queue MultiQC metrics plots: {exc}")

        try:
            self._queue_multiqc_roc_curves(
                fpr=fpr,
                tpr=tpr,
                label_names=label_names,
                panel_prefix=prefix,
            )
            self._queue_multiqc_pr_curves(
                precision=precision,
                recall=recall,
                label_names=label_names,
                panel_prefix=prefix,
            )
        except Exception as exc:  # pragma: no cover - defensive
            self.logger.warning(f"Failed to queue MultiQC ROC/PR curves: {exc}")
def _series_from_history(self, vals: list[float]) -> pd.Series:
"""Convert to float series and coerce non-finite to NaN."""
s = pd.Series(vals, dtype="float64")
s[~np.isfinite(s.to_numpy())] = np.nan
return s
[docs]
def plot_history(
    self, history: dict[str, list[float]] | dict[str, dict[str, list[float]]]
) -> None:
    """Plot model history traces. Will be saved to file.

    This method plots the training and validation loss per epoch of the deep
    learning models. The plot is saved to disk as a ``<plot_format>`` file and,
    if MultiQC is enabled, queued for the report.

    Args:
        history (dict[str, list[float]] | dict[str, dict[str, list[float]]]): Loss traces.
            For ImputeUBP: ``{"Phase2": {"Train": [...], "Val": [...]}, "Phase3": {...}}``.
            For the other models: ``{"Train": [...], "Val": [...]}``.
            Non-finite loss values are plotted as gaps.

    Raises:
        ValueError: self.model_name is not one of 'ImputeUBP', 'ImputeNLPCA', 'ImputeVAE' or 'ImputeAutoencoder'.
        ValueError: history object passed to 'plot_history' is empty.
        TypeError: history must be a dict containing {'Train', 'Val'} or {'Phase2', 'Phase3'}.
        ValueError: For ImputeUBP, history must contain 'Phase2' and 'Phase3' keys; for the other models it must contain 'Train' and 'Val' keys.
    """
    if self.model_name not in {
        "ImputeUBP",
        "ImputeNLPCA",
        "ImputeVAE",
        "ImputeAutoencoder",
    }:
        msg = f"model_name must be 'ImputeUBP', 'ImputeNLPCA', 'ImputeVAE' or 'ImputeAutoencoder', but got: {self.model_name}."
        self.logger.error(msg)
        raise ValueError(msg)

    if not history:
        msg = "history object passed to 'plot_history' is empty."
        self.logger.error(msg)
        raise ValueError(msg)

    if not isinstance(history, dict) or not (
        {"Train", "Val"} <= history.keys() or {"Phase2", "Phase3"} <= history.keys()
    ):
        msg = "history must be a dict containing {'Train', 'Val'} or {'Phase2', 'Phase3'}."
        self.logger.error(msg)
        raise TypeError(msg)

    def _draw_loss_panel(ax, losses, title: str) -> None:
        """Draw the Train/Validation loss curves of ``losses`` on ``ax``."""
        frame = pd.DataFrame(
            {
                "Train": self._series_from_history(losses["Train"]),
                "Validation": self._series_from_history(losses["Val"]),
            }
        )
        # Long format with the epoch kept as the index for the x axis.
        long_frame = frame.melt(
            var_name="Dataset", value_name="Loss", ignore_index=False
        )
        sns.lineplot(
            data=long_frame,
            x=long_frame.index,
            y="Loss",
            hue="Dataset",
            hue_order=["Train", "Validation"],
            palette="Set2",
            dashes=False,
            lw=3,
            linestyle="-",
            ax=ax,
            legend=True,
        )
        ax.set_title(title, fontsize=self.title_fontsize)
        ax.set_xlabel("Epoch", fontsize=self.plot_fontsize)
        ax.set_ylabel("Loss", fontsize=self.plot_fontsize)
        ax.tick_params(axis="both", which="major", labelsize=self.plot_fontsize)
        ax.legend(
            loc="best",
            fontsize=self.plot_fontsize,
            title_fontsize=self.plot_fontsize,
            fancybox=True,
            shadow=True,
        )
        sns.move_legend(ax, title="", loc="best")
        sns.despine(ax=ax)

    if self.model_name == "ImputeUBP":
        if not ("Phase2" in history and "Phase3" in history):
            msg = "For ImputeUBP, history must contain 'Phase2' and 'Phase3' keys."
            self.logger.error(msg)
            raise ValueError(msg)

        # One stacked panel per training phase, sharing the epoch axis.
        fig, axes = plt.subplots(2, 1, figsize=(6, 6), sharex=True)
        for ax, phase_key, phase_no in zip(axes, ("Phase2", "Phase3"), (2, 3)):
            _draw_loss_panel(
                ax,
                history[phase_key],
                f"{self.model_name} Phase {phase_no} - Loss per Epoch",
            )
        fig.suptitle(
            f"{self.model_name} Loss per Epoch", fontsize=self.title_fontsize
        )
    else:
        # A phase-structured history passes the check above but cannot be
        # plotted for single-phase models; fail with a clear message.
        if not ("Train" in history and "Val" in history):
            msg = f"For {self.model_name}, history must contain 'Train' and 'Val' keys."
            self.logger.error(msg)
            raise ValueError(msg)

        fig, ax = plt.subplots(1, 1, figsize=(6, 6))
        _draw_loss_panel(ax, history, f"{self.model_name} - Loss per Epoch")

    fn = self.output_dir / f"{self.model_name.lower()}_history_plot.{self.plot_format}"
    fig.savefig(fn)

    if self.show_plots:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            plt.show()
    plt.close(fig)

    # ---- MultiQC: training-loss vs epoch linegraphs -------------------
    if self._multiqc_enabled():
        try:
            self._queue_multiqc_history(history=history)
        except Exception as exc:  # pragma: no cover
            self.logger.warning(f"Failed to queue MultiQC history plot: {exc}")
[docs]
def plot_confusion_matrix(
    self,
    y_true_1d: np.ndarray | pd.DataFrame | List[str | int] | torch.Tensor,
    y_pred_1d: np.ndarray | pd.DataFrame | List[str | int] | torch.Tensor,
    label_names: Sequence[str] | Dict[str, int] | None = None,
    prefix: str = "",
) -> None:
    """Plot a confusion matrix with optional class labels.

    This method plots a confusion matrix using true and predicted labels. The
    plot is saved to disk as a ``<plot_format>`` file and, if MultiQC is
    enabled, queued for the report. Cell colors use a logarithmic scale.

    Args:
        y_true_1d (np.ndarray | pd.DataFrame | list | torch.Tensor): 1D array of true integer labels in [0, n_classes-1].
        y_pred_1d (np.ndarray | pd.DataFrame | list | torch.Tensor): 1D array of predicted integer labels in [0, n_classes-1].
        label_names (Sequence[str] | Dict[str, int] | None): Optional class names. A sequence is assumed to be ordered by class index (0..n-1); a mapping is interpreted as ``{name: class_index}`` and ordered by index. Its length defines the number of classes.
        prefix (str): Optional prefix for the output filename.

    Raises:
        TypeError: If either input cannot be converted to a 1D array.
        ValueError: If either input is empty.

    Notes:
        - If `label_names` is None, the display labels default to the class values inferred from `y_true_1d ∪ y_pred_1d`.
    """
    y_true_1d = misc.validate_input_type(y_true_1d, return_type="array")
    y_pred_1d = misc.validate_input_type(y_pred_1d, return_type="array")

    if not isinstance(y_true_1d, np.ndarray) or y_true_1d.ndim != 1:
        msg = "y_true_1d must be a 1D array-like of true labels."
        self.logger.error(msg)
        raise TypeError(msg)

    if not isinstance(y_pred_1d, np.ndarray) or y_pred_1d.ndim != 1:
        msg = "y_pred_1d must be a 1D array-like of predicted labels."
        self.logger.error(msg)
        raise TypeError(msg)

    if y_true_1d.size == 0 or y_pred_1d.size == 0:
        msg = "y_true_1d and y_pred_1d must not be empty."
        self.logger.error(msg)
        raise ValueError(msg)

    # Determine class count/order
    if label_names is not None:
        if isinstance(label_names, Mapping):
            # {name: index} -> names ordered by their class index.
            ordered_names = sorted(label_names, key=label_names.__getitem__)
        else:
            ordered_names = list(label_names)
        display_labels = [str(name) for name in ordered_names]
        labels = np.arange(len(display_labels))  # our y_* are ints 0..n-1
    else:
        # Infer labels from data to keep matrix tight
        labels = np.unique(np.concatenate([y_true_1d, y_pred_1d]))
        display_labels = labels  # sklearn will convert to strings

    # Defined on both paths; previously unbound when label_names was None.
    n_classes = len(labels)

    # Upper bound for the color scale: no cell can exceed the largest
    # per-class count among the true or predicted labels.
    _, true_counts = np.unique(y_true_1d, return_counts=True)
    _, pred_counts = np.unique(y_pred_1d, return_counts=True)
    vmax = int(max(true_counts.max(), pred_counts.max()))

    # Small matrices get enlarged text so the figure stays legible.
    large_text = n_classes <= 3

    display_kwargs = dict(
        y_true=y_true_1d,
        y_pred=y_pred_1d,
        labels=labels,
        display_labels=display_labels,
        cmap="viridis",
        colorbar=True,
        im_kw={"norm": colors.LogNorm(vmin=1, vmax=vmax)},
    )
    if large_text:
        display_kwargs["text_kw"] = {"fontsize": 28}

    fig, ax = plt.subplots(1, 1, figsize=(15, 15))
    disp = ConfusionMatrixDisplay.from_predictions(ax=ax, **display_kwargs)

    if large_text:
        ax.set_xlabel("Predicted Label", fontsize=28)
        ax.set_ylabel("True Label", fontsize=28)
        ax.tick_params(axis="both", which="major", labelsize=28)
        ax.set_title(f"{self.model_name} Confusion Matrix", fontsize=28)
        if disp.im_.colorbar is not None:
            disp.im_.colorbar.ax.tick_params(labelsize=24)
    else:
        ax.set_xlabel("Predicted Label")
        ax.set_ylabel("True Label")
        ax.tick_params(axis="both", which="major")
        ax.set_title(f"{self.model_name} Confusion Matrix")

    # Build a stable panel id before mutating prefix
    panel_suffix = f"{prefix}_" if prefix else ""
    panel_id = f"{self.model_name.lower()}_{panel_suffix}confusion_matrix"

    if prefix != "" and not prefix.endswith("_"):
        prefix = f"{prefix}_"

    out_name = (
        f"{self.model_name.lower()}_{prefix}confusion_matrix.{self.plot_format}"
    )
    fig.savefig(self.output_dir / out_name, bbox_inches="tight")

    if self.show_plots:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            plt.show()
    plt.close(fig)

    # ---- MultiQC: confusion-matrix heatmap ----------------------------
    if self._multiqc_enabled():
        try:
            self._queue_multiqc_confusion(
                y_true=y_true_1d,
                y_pred=y_pred_1d,
                labels=labels,
                display_labels=display_labels,
                panel_id=panel_id,
            )
        except Exception as exc:  # pragma: no cover
            self.logger.warning(f"Failed to queue MultiQC confusion matrix: {exc}")
[docs]
def plot_gt_distribution(
    self,
    X: np.ndarray | pd.DataFrame | list | torch.Tensor,
    X_compare: np.ndarray | pd.DataFrame | list | torch.Tensor | None = None,
    is_imputed: bool = False,
) -> None:
    """Plot genotype distribution, optionally comparing two datasets.

    Plots the percentage of each genotype. If `X_compare` is provided, it plots
    side-by-side bars and annotates the Jensen-Shannon distance between the
    two distributions (computed over called genotypes only, i.e. excluding
    the missing-data category "N"). The plot is saved to disk as a
    ``<plot_format>`` file and, if MultiQC is enabled, queued for the report.

    Args:
        X (np.ndarray | pd.DataFrame | list | torch.Tensor): Primary genotype matrix (usually the imputed/final one). May be string-encoded (IUPAC) or numeric.
        X_compare (np.ndarray | pd.DataFrame | list | torch.Tensor | None): Optional baseline genotype matrix to compare against
            (e.g., the original dataset with missing values).
        is_imputed (bool): Labeling flag. If True, X is labeled "Imputed" and X_compare "Original".
    """

    # --- Helper to process raw input into a normalized Series ---
    def _process_input(data) -> tuple[pd.Series, str, list]:
        """Flatten ``data`` to a Series of string genotype labels.

        Returns the labels, the x-axis label and the preferred category
        order (missing-data category "N" first).
        """
        if isinstance(data, pd.DataFrame):
            arr = data.to_numpy()
        elif torch.is_tensor(data):
            arr = data.detach().cpu().numpy()
        else:
            arr = np.asarray(data)

        s = pd.Series(arr.ravel())

        # Detect string vs numeric encodings
        if s.dtype.kind in ("O", "U", "S"):
            # "NAN"/"NONE" are what NaN/None become after str().upper().
            s = (
                s.astype(str)
                .str.upper()
                .replace(
                    {"-": "N", ".": "N", "?": "N", "NAN": "N", "NONE": "N"}
                )
            )
            lbl = "Genotype (IUPAC)"
            canonical = ["A", "C", "T", "G"]
            ambig = sorted(["M", "R", "W", "S", "Y", "K", "V", "H", "D", "B"])
            order = ["N"] + canonical + ambig
        else:
            values = s.astype(float).to_numpy()
            # Missing data: -1 or any non-finite value.
            missing = ~np.isfinite(values) | (values == -1)
            # Convert the called genotypes to integer strings ("0", "1", ...)
            # first and only then insert "N"; mixing "N" into the float data
            # before the int cast would leave labels such as "0.0".
            codes = pd.Series(np.where(missing, 0.0, values), index=s.index)
            s = codes.astype(int).astype(str).mask(missing, "N")
            lbl = "Genotype (Integer-encoded)"
            order = ["N", "0", "1", "2", "3"]

        return s, lbl, order

    is_comparison = X_compare is not None

    # --- Process Datasets ---
    label_main = "Imputed" if is_imputed else "Dataset A"
    label_compare = "Original" if is_imputed else "Dataset B"

    s_main, x_label, base_order = _process_input(X)
    datasets = {label_main: s_main}

    if is_comparison:
        s_comp, _, _ = _process_input(X_compare)
        datasets[label_compare] = s_comp

    # --- Unified Ordering ---
    # Collect all unique keys from all datasets to ensure alignment
    all_uniques = set()
    for s in datasets.values():
        all_uniques.update(s.unique())

    full_order = base_order + sorted(all_uniques - set(base_order))

    # --- Build Frequency Table ---
    prob_vectors = []
    plot_df_list = []

    for name, series in datasets.items():
        # Get counts reindexed to the full unified order
        counts = series.value_counts().reindex(full_order, fill_value=0)

        # Probabilities over called genotypes ("N" is always first in the
        # order). None marks a dataset without any called genotype.
        called = counts.iloc[1:]
        called_total = called.sum()
        prob_vectors.append(
            (called / called_total).to_numpy() if called_total > 0 else None
        )

        # Prepare plotting DF
        _df = counts.rename_axis("Genotype").reset_index(name="Count")
        grand_total = _df["Count"].sum()
        _df["Percent"] = (
            _df["Count"] / grand_total * 100 if grand_total > 0 else 0.0
        )
        _df["Dataset"] = name
        plot_df_list.append(_df)

    df_final = pd.concat(plot_df_list, ignore_index=True)

    # --- Calculate Distance (if comparing) ---
    dist_text = ""
    if len(prob_vectors) == 2:
        if all(p is not None for p in prob_vectors):
            # Jensen-Shannon Distance is sqrt(JSD). Base 2 gives range [0, 1]
            js_dist = jensenshannon(prob_vectors[0], prob_vectors[1], base=2)
            dist_text = f"JS Dist: {js_dist:.2f}"
        else:
            self.logger.warning(
                "Jensen-Shannon distance is undefined: at least one dataset "
                "has no called genotypes."
            )

    # --- Plot ---
    fig, ax = plt.subplots(figsize=(10, 6) if is_comparison else (8, 5))

    # If comparing, use 'Dataset' as hue.
    # If not, map color to 'Genotype'.
    hue_col = "Dataset" if is_comparison else "Genotype"
    palette = "Set2" if is_comparison else "Set1"

    sns.barplot(
        data=df_final,
        x="Genotype",
        y="Percent",
        hue=hue_col,
        order=full_order,
        errorbar=None,
        ax=ax,
        palette=palette,
        edgecolor="black",  # Add border to distinguish side-by-side bars
        linewidth=0.5,
    )

    sns.set_style("ticks", rc=self.param_dict)
    sns.despine(ax=ax)

    big_font = self.plot_fontsize + 10
    ax.set_xlabel(x_label, fontsize=big_font)
    ax.set_ylabel("Percent", fontsize=big_font)
    ax.tick_params(axis="both", which="major", labelsize=big_font)

    if is_comparison:
        title = "Genotype Distribution Comparison"
    elif is_imputed:
        title = "Imputed Genotype Counts"
    else:
        title = "Genotype Counts"
    ax.set_title(title, fontsize=self.plot_fontsize + 12)

    # Add distance statistic to plot
    if dist_text:
        ax.annotate(
            dist_text,
            xy=(0.95, 0.95),
            xycoords="axes fraction",
            ha="right",
            va="top",
            fontsize=big_font,
        )

    if is_comparison:
        ax.legend(
            title="Dataset",
            fontsize=big_font,
            title_fontsize=big_font,
            loc="best",
        )
    else:
        # The hue duplicates the x axis here, so a legend adds nothing.
        # Hide one only if seaborn created it, rather than creating an
        # empty legend just to hide it.
        leg = ax.get_legend()
        if leg is not None:
            leg.set_visible(False)

    fig.tight_layout()

    if is_comparison:
        suffix = "comparison"
    elif is_imputed:
        suffix = "imputed"
    else:
        suffix = "original"
    fn = self.output_dir / f"gt_distributions_{suffix}.{self.plot_format}"
    # Honor the configured resolution instead of a hard-coded value.
    fig.savefig(fn, dpi=self.plot_dpi)

    if self.show_plots:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            plt.show()
    plt.close(fig)

    # ---- MultiQC -----------------------
    if self._multiqc_enabled():
        try:
            # The combined (long-format) table is passed for both the
            # single-dataset and the comparison case.
            self._queue_multiqc_gt_distribution(
                df=df_final,
                is_imputed=is_imputed,
                is_comparison=is_comparison,
            )
        except Exception as exc:
            self.logger.warning(
                f"Failed to queue MultiQC genotype distribution: {exc}"
            )
# --------------------------------------------------------------------- #
# MultiQC helper methods #
# --------------------------------------------------------------------- #
def _multiqc_enabled(self) -> bool:
"""Return True if MultiQC integration is active."""
return bool(self.use_multiqc)
def _queue_multiqc_tuning(
self,
*,
study: optuna.study.Study,
model_name: str,
target_name: str,
) -> None:
"""Queue Optuna tuning results for MultiQC.
Args:
study (optuna.study.Study): Optuna study object.
model_name (str): Name of the model.
target_name (str): Name of the target value.
"""
if not self._multiqc_enabled():
return
# trial number vs objective value line graph
try:
df_trials = study.trials_dataframe(attrs=("number", "value"))
except Exception as exc: # pragma: no cover
self.logger.warning(
f"Could not extract trials_dataframe for MultiQC: {exc}"
)
return
if df_trials.empty or "value" not in df_trials:
return
history_data: Dict[str, Dict[int, float]] = {
model_name: {
int(row["number"]): float(row["value"])
for _, row in df_trials.iterrows()
if row["value"] is not None
}
}
if not history_data[model_name]:
return
SNPioMultiQC.queue_linegraph(
data=cast(Dict[str, Dict[int, int]], history_data),
panel_id=f"{self.model_name}_optuna_history",
section=self.multiqc_section,
title=f"{self.model_name} Optuna Optimization History",
index_label="Trial",
description=f"Optuna optimization history for {self.model_name} "
f"(target={target_name}).",
)
# best-params table
try:
best_value = study.best_value
best_params = study.best_params
except Exception:
return
if best_params:
# Build a single dict so static type checkers don't infer a
# mismatched dtype for the Series and complain about assigning
# a float value after creation.
best_param_data: Dict[str, float | int | str] = {
**{str(k): cast(float | int | str, v) for k, v in best_params.items()},
"objective": float(best_value),
}
series = pd.Series(best_param_data, name="Best Value")
SNPioMultiQC.queue_table(
df=series,
panel_id=f"{self.model_name}_optuna_best_params",
section=self.multiqc_section,
title=f"{self.model_name} Best Optuna Parameters",
index_label="Parameter",
description="Best Optuna hyperparameters and objective value.",
)
def _queue_multiqc_roc_curves(
    self,
    *,
    fpr: dict,
    tpr: dict,
    label_names: Sequence[str],
    panel_prefix: str,
) -> None:
    """Queue ROC curves as a MultiQC line graph.

    Up to the first three integer-keyed classes of ``fpr`` (ascending key order) are reported, plus the ``"macro"`` average when both ``fpr`` and ``tpr`` contain it. Classes missing from ``tpr`` are skipped. Nothing is queued when MultiQC is disabled or when no curve could be built.

    Args:
        fpr (dict): False positive rates keyed by class index (int) and, optionally, ``"macro"``.
        tpr (dict): True positive rates keyed like ``fpr``.
        label_names (Sequence[str]): Class names indexed by class key. Keys at or beyond its length are labelled ``"Class <idx>"``.
        panel_prefix (str): Optional prefix inserted into the panel ID; an empty string omits it.
    """
    if not self._multiqc_enabled():
        return

    def _curve_to_mapping(
        x_vals: Sequence[float], y_vals: Sequence[float]
    ) -> Dict[float, float]:
        """Return {x: y} mapping expected by MultiQC linegraphs."""
        # Repeated x values collapse to the last y seen (dict semantics).
        return {float(x): float(y) for x, y in zip(x_vals, y_vals)}

    data: Dict[str, Dict[float, float]] = {}

    # Only report the first three classes (MultiQC plot readability)
    # plus the macro average.
    class_keys = sorted(k for k in fpr if isinstance(k, int))
    for idx in class_keys[:3]:
        # Keys are drawn from ``fpr``; skip classes that have no matching
        # ``tpr`` entry instead of raising KeyError (mirrors the guard in
        # the Precision-Recall counterpart).
        if idx not in tpr:
            continue
        label = label_names[idx] if idx < len(label_names) else f"Class {idx}"
        data[label] = _curve_to_mapping(fpr[idx], tpr[idx])

    agg = "macro"
    if agg in fpr and agg in tpr:
        pretty_name = f"{agg.title()} Average"
        data[pretty_name] = _curve_to_mapping(fpr[agg], tpr[agg])

    if not data:
        return

    # The cast only satisfies the declared parameter type of
    # ``queue_linegraph``; the float keys/values are passed through as-is.
    curve_data = cast(Dict[str, Dict[int, int]], data)
    SNPioMultiQC.queue_linegraph(
        data=curve_data,
        panel_id=(
            f"{self.model_name}_{panel_prefix}_roc_curves"
            if panel_prefix
            else f"{self.model_name}_roc_curves"
        ),
        section=self.multiqc_section,
        title=f"{self.model_name} ROC Curves",
        index_label="False Positive Rate",
        description="Multi-class ROC curves for PG-SUI predictions.",
    )
def _queue_multiqc_pr_curves(
    self,
    *,
    precision: dict,
    recall: dict,
    label_names: Sequence[str],
    panel_prefix: str,
) -> None:
    """Queue Precision-Recall curves as a MultiQC line graph.

    Up to the first three integer-keyed classes of ``recall`` (ascending key order) are reported, plus the ``"macro"`` average when both inputs contain it. Nothing is queued when MultiQC is disabled or when no curve could be built.

    Args:
        precision (dict): Precision values keyed by class index (int) and, optionally, ``"macro"``.
        recall (dict): Recall values keyed like ``precision``.
        label_names (Sequence[str]): Class names indexed by class key. Keys at or beyond its length are labelled ``"Class <idx>"``.
        panel_prefix (str): Optional prefix inserted into the panel ID; an empty string omits it.
    """
    if not self._multiqc_enabled():
        return

    curves: Dict[str, Dict[float, float]] = {}

    def _add_curve(name: str, key: object) -> None:
        """Store the {recall: precision} mapping for ``key`` under ``name``."""
        curves[name] = {
            float(rec): float(prec)
            for rec, prec in zip(recall[key], precision[key])
        }

    # Limit the per-class series to three to keep the MultiQC plot readable.
    int_keys = [key for key in recall.keys() if isinstance(key, int)]
    int_keys.sort()
    for class_idx in int_keys[:3]:
        if class_idx in precision and class_idx in recall:
            if class_idx < len(label_names):
                series_name = label_names[class_idx]
            else:
                series_name = f"Class {class_idx}"
            _add_curve(series_name, class_idx)

    # Aggregate curve: only the macro average is reported.
    average_key = "macro"
    if average_key in precision and average_key in recall:
        _add_curve(f"{average_key.title()} Average", average_key)

    if not curves:
        return

    if panel_prefix:
        graph_id = f"{self.model_name}_{panel_prefix}_pr_curves"
    else:
        graph_id = f"{self.model_name}_pr_curves"

    SNPioMultiQC.queue_linegraph(
        data=cast(Dict[str, Dict[int, int]], curves),
        panel_id=graph_id,
        section=self.multiqc_section,
        title=f"{self.model_name} Precision-Recall Curves",
        index_label="Recall",
        description="Multi-class Precision-Recall curves for PG-SUI predictions.",
    )
def _queue_multiqc_metrics(
self,
*,
metrics: Dict[str, float],
roc_auc: Dict[object, float],
average_precision: Dict[object, float],
label_names: Sequence[str],
panel_prefix: str,
) -> None:
"""Queue summary metrics and per-class AUC/AP for MultiQC.
Args:
metrics (Dict[str, float]): Summary metrics (accuracy, F1, etc.).
roc_auc (Dict[object, float]): Per-class and aggregate ROC-AUC values.
average_precision (Dict[object, float]): Per-class and aggregate average precision values.
label_names (Sequence[str]): Class names.
panel_prefix (str): Optional prefix for panel IDs.
"""
if not self._multiqc_enabled():
return
# Summary metrics table (accuracy, F1, etc.)
if metrics:
series = pd.Series(metrics, name="Value")
SNPioMultiQC.queue_table(
df=series,
panel_id=f"{self.model_name}_summary_metrics",
section=self.multiqc_section,
title=f"{self.model_name} Summary Metrics",
index_label="Metric",
description="Global evaluation metrics produced by PG-SUI.",
)
# Per-class ROC-AUC and AP heatmap
rows: List[Dict[str, float | str]] = []
# integer keys are classes; others are 'micro', 'macro'
class_keys = [k for k in roc_auc.keys() if isinstance(k, int)]
class_keys_sorted = sorted(class_keys)
for i in class_keys_sorted:
class_name = label_names[i] if i < len(label_names) else f"Class {i}"
rows.append(
{
"Class": str(class_name),
"ROC_AUC": float(roc_auc.get(i, np.nan)),
"AveragePrecision": float(average_precision.get(i, np.nan)),
}
)
agg = "macro"
if agg in roc_auc:
rows.append(
{
"Class": agg,
"ROC_AUC": float(roc_auc.get(agg, np.nan)),
"AveragePrecision": float(average_precision.get(agg, np.nan)),
}
)
if not rows:
return
df = pd.DataFrame(rows).set_index("Class")
suffix = f"{panel_prefix}_" if panel_prefix else ""
panel_id = f"{self.model_name}_{suffix}roc_pr_summary"
SNPioMultiQC.queue_heatmap(
df=df,
panel_id=panel_id,
section=self.multiqc_section,
title=f"{self.model_name} ROC-AUC and Average Precision",
index_label="Class",
description=(
"Per-class ROC-AUC and average precision for PG-SUI predictions (including micro/macro averages where available)."
),
)
def _queue_multiqc_history(
self,
*,
history: Mapping[str, List[float] | Dict[str, List[float]] | None] | None,
) -> None:
"""Queue training history (loss vs epoch) for MultiQC.
Args:
history (Dict[str, List[float]] | None): Dictionary with lists of history objects. Keys should be "Train" and "Validation".
"""
if not self._multiqc_enabled() or history is None:
return
data: Dict[str, Dict[int, int]] = {}
if self.model_name != "ImputeUBP":
if not isinstance(history, dict) or "Train" not in history:
return
train_history = self._series_from_history(history["Train"]) # type: ignore
val_history = self._series_from_history(history.get("Val", [])) # type: ignore
data["Train"] = {
epoch: val
for epoch, val in enumerate(train_history.to_numpy(), start=1)
}
data["Val"] = {
epoch: val for epoch, val in enumerate(val_history.to_numpy(), start=1)
}
else:
for phase in range(2, 4):
phase_key = f"Phase{phase}"
if (
not isinstance(history, dict)
or phase_key not in history
or not isinstance(history[phase_key], dict)
):
continue
train_history = self._series_from_history(
history[phase_key].get("Train", []) # type: ignore
)
val_history = self._series_from_history(
history[phase_key].get("Val", []) # type: ignore
)
data[f"Phase {phase} Train"] = {
epoch: val
for epoch, val in enumerate(train_history.to_numpy(), start=1)
}
data[f"Phase {phase} Val"] = {
epoch: val
for epoch, val in enumerate(val_history.to_numpy(), start=1)
}
if not data:
return
SNPioMultiQC.queue_linegraph(
data=data,
panel_id=f"{self.model_name}_training_history",
section=self.multiqc_section,
title=f"{self.model_name} Training Loss per Epoch",
index_label="Epoch",
description="Training loss trajectory by epoch as recorded by PG-SUI.",
)
def _queue_multiqc_confusion(
self,
*,
y_true: np.ndarray,
y_pred: np.ndarray,
labels: np.ndarray,
display_labels: List[str] | np.ndarray,
panel_id: str,
) -> None:
"""Queue confusion-matrix heatmap for MultiQC.
Args:
y_true (np.ndarray): 1D array of true integer labels.
y_pred (np.ndarray): 1D array of predicted integer labels.
labels (np.ndarray): Array of label indices to index the confusion matrix.
display_labels (List[str] | np.ndarray): Labels to display on axes.
panel_id (str): Panel ID for MultiQC.
"""
if not self._multiqc_enabled():
return
cm = confusion_matrix(y_true, y_pred, labels=labels)
df_cm = pd.DataFrame(cm, index=display_labels, columns=display_labels)
SNPioMultiQC.queue_heatmap(
df=df_cm,
panel_id=panel_id,
section=self.multiqc_section,
title=f"{self.model_name} Confusion Matrix",
index_label="True Label",
description=(
"Confusion matrix for PG-SUI predictions. Rows correspond to true "
"labels; columns correspond to predicted labels."
),
)
def _queue_multiqc_gt_distribution(
    self,
    *,
    df: pd.DataFrame,
    is_imputed: bool,
    is_comparison: bool = False,
) -> None:
    """Queue genotype-distribution barplot for MultiQC.

    Nothing is queued when MultiQC is disabled or when ``df`` lacks a required column.

    Args:
        df (pd.DataFrame): DataFrame with 'Genotype' and 'Percent' columns. When ``is_comparison`` is True a 'Dataset' column holding the values 'Original' and 'Imputed' is required as well.
        is_imputed (bool): Whether these genotypes are imputed. Only used when ``is_comparison`` is False, where it selects the panel ID suffix and title.
        is_comparison (bool): If True, queue a single side-by-side barplot of the original and imputed distributions instead of a single-dataset barplot. Defaults to False.
    """
    if not self._multiqc_enabled():
        return

    # The comparison branch also selects rows by "Dataset"; validate it up
    # front so a missing column is skipped like the others rather than
    # raising KeyError.
    required_columns = {"Genotype", "Percent"}
    if is_comparison:
        required_columns.add("Dataset")
    if not required_columns.issubset(df.columns):
        return

    if is_comparison:
        df_imputed = df[df["Dataset"] == "Imputed"]
        df_original = df[df["Dataset"] == "Original"]
        series1 = df_original.set_index("Genotype")["Percent"]
        series2 = df_imputed.set_index("Genotype")["Percent"]

        # Rows follow the original dataset's genotypes; genotypes missing
        # from the imputed rows are filled with 0.
        # NOTE(review): genotypes that appear only in the imputed rows are
        # dropped by this index restriction -- confirm that is intended.
        df_final = pd.DataFrame(
            {"Original": series1, "Imputed": series2},
            index=series1.index,
            columns=["Original", "Imputed"],
        ).fillna(0)

        SNPioMultiQC.queue_barplot(
            df=df_final,
            panel_id=f"{self.model_name}_gt_distribution_comparison",
            section=self.multiqc_section,
            title=f"{self.model_name} Genotype Distribution Comparison",
            index_label="Genotype",
            value_label="Percent",
            description=(
                "Genotype frequency distribution (percent per genotype) computed by PG-SUI for both original and imputed datasets."
            ),
        )
    else:
        series = df.set_index("Genotype")["Percent"]
        suffix = "imputed" if is_imputed else "original"
        dataset = "Imputed" if is_imputed else "Original"
        title = f"{self.model_name} {dataset} Genotype Distribution"

        SNPioMultiQC.queue_barplot(
            df=series,
            panel_id=f"{self.model_name}_gt_distribution_{suffix}",
            section=self.multiqc_section,
            title=title,
            index_label="Genotype",
            value_label="Percent",
            description=(
                "Genotype frequency distribution (percent of calls per genotype) computed by PG-SUI."
            ),
        )