6. Training

CNVRock training has a single entry point, models/train.py. It loads a YAML config, builds the dataset, trains the VAE, runs inference and HMM segmentation, calls per-gene CNVs (chromosomal and plasmid), and writes the evaluation outputs.

Entry point

python models/train.py models/experiments/32/config.yaml

SLURM wrapper

hpc/train_gpu.sh requests a GPU node, activates the conda environment and changes into models/ before invoking train.py, so the config path is passed relative to models/. Submit with:

sbatch hpc/train_gpu.sh experiments/32/config.yaml

Runtime on an A40 is ~4 min for 5K samples and ~6 min for 10K samples (150 epochs at 128 samples/batch, i.e. ~40 batches per epoch at 5K and ~80 at 10K).
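The batch counts in the runtime note follow from ceiling division. A minimal sketch, assuming every sample is used for training in each epoch (no held-out split) and all 150 epochs run; both helper names are hypothetical:

```python
import math


def batches_per_epoch(n_samples: int, batch_size: int = 128) -> int:
    """Number of mini-batches needed to cover n_samples once (last batch may be partial)."""
    return math.ceil(n_samples / batch_size)


def total_steps(n_samples: int, epochs: int = 150, batch_size: int = 128) -> int:
    """Optimiser steps for a full run, assuming no early stopping."""
    return epochs * batches_per_epoch(n_samples, batch_size)
```

For example, batches_per_epoch(5000) gives 40 and batches_per_epoch(10000) gives 79.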

Config schema

Each experiment's config lives at models/experiments/{N}/config.yaml. The configs for experiments 32–36 share identical architecture, HMM, CNV-caller and threshold parameters; only store_path, plasmid_store_path and out_dir vary across the scaling tiers. The config for experiment 32 (5K tier) is shown below.

# Modules (resolved via importlib at runtime)
architecture: "06_conv_vae"
hmm:          "02_gaussian_hmm"
cnv:          "06_gene_cnv_caller"
evaluation:   "04_kpsc_evaluation"

# Data (varies per tier)
store_path:           "../../../data/inputs/KpSC-expansion-5k-mq20-1000bp-npy"
plasmid_store_path:   "../../../data/inputs/KpSC-expansion-5k-mq20-plasmid-1000bp-npy"
out_dir:              "../../../data/results/32_kpsc_expansion_5k"

# Ground truth
kpsc_gt_path:           "../../../assets/amrfinder_gt_expansion.tsv"
kpsc_kleborate_gt_path: "../../../assets/kpsc_expansion_kleborate_gt_runlevel.tsv"
kpsc_meta_path:         "../../../assets/kpsc_expansion_metadata_runlevel.tsv"

# Plasmid genes
plasmid_gene_coords_path: "../../../assets/plasmid_refs/plasmid_gene_coords.tsv"
pcn_absent_threshold:     0.20
pcn_amp_threshold:        1.50

# VAE
latent_dim:    10
batch_size:    128
epochs:        150
lr:            1.0e-4
weight_decay:  1.0e-5
max_beta:      1.0
warmup_epochs: 20
patience:      20

# HMM
hmm_n_states:          6
hmm_self_transition:   0.80
hmm_low_cov_threshold: 10

# Chromosomal CNV caller
cnv_min_cn1_proportion:         0.55
cnv_min_confidence:             0.50
cnv_flank_padding:              100000
cnv_crr_amp_threshold:          1.75
cnv_crr_gate_threshold:         1.75
cnv_crr_min_bins_fallback:      3
cnv_min_gene_coverage_fraction: 0.50

# Evaluation
eval_min_group_n: 10
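Two details of this schema matter when loading it outside train.py. The ../../../ prefixes reach the repository root only when taken relative to the config file's own directory (models/experiments/32/ sits three levels below the root), and the module names start with digits, so an ordinary import statement cannot load them. The sketch below assumes paths are resolved against the config file's directory and that the cnv and evaluation modules live in packages named cnv/ and evaluation/; resolve_paths, resolve_modules and is_path_key are hypothetical helpers, not CNVRock's code. They take an already-parsed dict, e.g. the output of PyYAML's yaml.safe_load.

```python
import importlib
import os
from pathlib import Path

# Config key -> package searched for the named module.
# models/architecture/ and models/hmm/ appear in this section; the "cnv" and
# "evaluation" package names are assumptions.
MODULE_PACKAGES = {
    "architecture": "architecture",
    "hmm": "hmm",
    "cnv": "cnv",
    "evaluation": "evaluation",
}


def is_path_key(key: str) -> bool:
    """Hypothetical rule for spotting path-valued keys in the schema above."""
    return key.endswith("_path") or key == "out_dir"


def resolve_paths(cfg: dict, config_file) -> dict:
    """Copy of cfg with relative paths anchored at the config file's directory
    (assumed convention) and normalised lexically; other values pass through."""
    base = Path(config_file).absolute().parent
    return {
        key: Path(os.path.normpath(base / value)) if is_path_key(key) else value
        for key, value in cfg.items()
    }


def resolve_modules(cfg: dict, packages: dict = MODULE_PACKAGES) -> dict:
    """Import e.g. 'architecture.06_conv_vae'. Digit-leading names are not valid
    identifiers for an import statement, but importlib.import_module accepts them."""
    return {key: importlib.import_module(f"{pkg}.{cfg[key]}") for key, pkg in packages.items()}
```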

Architecture: 1D Conv-VAE

models/architecture/06_conv_vae.py. The encoder passes each sample's 5,334-bin depth vector through three 1D convolutions and a dense head, producing a 10-dimensional latent (latent_dim). The decoder mirrors the encoder, using transposed convolutions to map back to bin space.
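Because repeated stride-2 halving of 5,334 does not divide evenly at every stage, a mirrored decoder only lands back on exactly 5,334 bins if each transposed convolution gets the right output_padding. The sketch below works through the lengths using the PyTorch Conv1d/ConvTranspose1d length formulas (dilation 1); the kernel size 5, stride 2 and padding 2 per layer are hypothetical values, not read from 06_conv_vae.py, and plan_shapes is an illustrative helper.

```python
# Hypothetical (kernel, stride, padding) per encoder layer.
LAYERS = [(5, 2, 2), (5, 2, 2), (5, 2, 2)]


def conv1d_out(length: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a 1D convolution with dilation 1 (torch.nn.Conv1d formula)."""
    return (length + 2 * padding - kernel) // stride + 1


def conv_transpose1d_out(length: int, kernel: int, stride: int, padding: int,
                         output_padding: int = 0) -> int:
    """Output length of a 1D transposed convolution with dilation 1 (torch.nn.ConvTranspose1d)."""
    return (length - 1) * stride - 2 * padding + kernel + output_padding


def plan_shapes(n_bins: int, layers=LAYERS):
    """Encoder lengths, plus the output_padding each decoder layer (in decoder
    order) needs to reproduce the matching encoder length exactly."""
    lengths = [n_bins]
    for kernel, stride, padding in layers:
        lengths.append(conv1d_out(lengths[-1], kernel, stride, padding))
    paddings = []
    for i in reversed(range(len(layers))):
        kernel, stride, padding = layers[i]
        op = lengths[i] - conv_transpose1d_out(lengths[i + 1], kernel, stride, padding)
        if not 0 <= op < stride:  # PyTorch requires output_padding < stride
            raise ValueError(f"layer {i}: cannot reach length {lengths[i]}")
        paddings.append(op)
    return lengths, paddings
```

With these settings the encoder lengths are 5,334 → 2,667 → 1,334 → 667, and the three decoder layers need output_padding of 1, 0 and 1 respectively.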

Training minimises a weighted ELBO whose KL weight β is warmed up from 0 to max_beta = 1.0 over the first warmup_epochs = 20 epochs, plus an auxiliary CNV-pattern alignment loss at weight 1.0 with its own 30-epoch warmup. The auxiliary loss pulls the latent space toward biologically meaningful structure and prevents the VAE from collapsing to a representation that encodes only global read depth.
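The loss schedule can be sketched in NumPy as below. This is an illustration under stated assumptions, not the code in train.py: both warmups are taken to be linear ramps, the reconstruction term is plain squared error (any weighting in the weighted ELBO is omitted), and the CNV-pattern alignment loss, whose form is not described here, is passed in as a precomputed scalar aux. linear_warmup, gaussian_kl and total_loss are hypothetical names.

```python
import numpy as np


def linear_warmup(epoch: int, warmup_epochs: int, max_value: float = 1.0) -> float:
    """Assumed linear ramp: 0 at epoch 0, max_value from epoch == warmup_epochs on."""
    if warmup_epochs <= 0:
        return max_value
    return max_value * min(1.0, epoch / warmup_epochs)


def gaussian_kl(mu: np.ndarray, logvar: np.ndarray) -> np.ndarray:
    """KL(q(z|x) || N(0, I)) per sample, summed over latent dimensions."""
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=1)


def total_loss(x, x_hat, mu, logvar, aux, epoch,
               max_beta=1.0, warmup_epochs=20, aux_weight=1.0, aux_warmup=30) -> float:
    """Batch-mean (reconstruction + beta * KL) plus the warmed-up auxiliary term."""
    recon = np.sum((x - x_hat) ** 2, axis=1)  # squared error per sample (assumed likelihood)
    beta = linear_warmup(epoch, warmup_epochs, max_beta)
    lam = linear_warmup(epoch, aux_warmup, aux_weight)
    return float(np.mean(recon + beta * gaussian_kl(mu, logvar)) + lam * aux)
```

At epoch 10, for instance, linear_warmup(10, 20) gives β = 0.5, while the auxiliary weight is 1.0 × 10/30 ≈ 0.33.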

Segmenter: Gaussian HMM

models/hmm/02_gaussian_hmm.py. Each sample's reconstruction from inference is re-normalised and segmented with a 6-state Gaussian HMM (state means initialised at CN ∈ {0, 0.5, 1, 1.5, 2, 3}, self-transition probability 0.8). Low-coverage bins (fewer than 10 reads, hmm_low_cov_threshold) are masked and re-imputed from neighbouring bins.
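A minimal NumPy sketch of the masking and decoding steps, under several assumptions not stated above: re-imputation is linear interpolation between unmasked neighbours, re-normalisation is median scaling, all states share an emission standard deviation sigma = 0.15, transitions between different states are uniform, and the path is Viterbi-decoded with the state means held fixed (the real module may instead fit means and variances, e.g. by EM). mask_and_impute, renormalise and viterbi_cn are hypothetical names.

```python
import numpy as np

# CN levels used to initialise the state means (from the text above).
CN_STATES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0])


def mask_and_impute(depth, reads, min_reads=10):
    """Replace bins with fewer than min_reads reads by linear interpolation between
    the nearest unmasked bins (edge bins copy the nearest valid value)."""
    depth = np.asarray(depth, dtype=float).copy()
    ok = np.asarray(reads) >= min_reads
    if not ok.any():
        raise ValueError("no bin passes the coverage threshold")
    idx = np.arange(depth.size)
    depth[~ok] = np.interp(idx[~ok], idx[ok], depth[ok])
    return depth


def renormalise(depth):
    """Hypothetical re-normalisation: scale so the median bin sits at CN 1."""
    return depth / np.median(depth)


def viterbi_cn(signal, means=CN_STATES, self_p=0.8, sigma=0.15):
    """Most likely CN per bin under a Gaussian HMM with fixed means, shared sigma
    and uniform off-diagonal transitions."""
    signal = np.asarray(signal, dtype=float)
    k, n = len(means), signal.size
    log_trans = np.full((k, k), np.log((1.0 - self_p) / (k - 1)))
    np.fill_diagonal(log_trans, np.log(self_p))
    # Gaussian log-density up to a state-independent constant.
    log_emit = -0.5 * ((signal[:, None] - means[None, :]) / sigma) ** 2
    score = np.log(np.full(k, 1.0 / k)) + log_emit[0]
    back = np.zeros((n, k), dtype=int)
    for t in range(1, n):
        cand = score[:, None] + log_trans  # cand[i, j]: best path moving i -> j
        back[t] = cand.argmax(axis=0)
        score = cand.max(axis=0) + log_emit[t]
    path = np.empty(n, dtype=int)
    path[-1] = score.argmax()
    for t in range(n - 1, 0, -1):  # trace back the best predecessor of each state
        path[t - 1] = back[t, path[t]]
    return means[path]
```

With self-transition 0.8 and six states, each switch costs log(0.8/0.04) ≈ 3.0 nats relative to staying in the current state, which is what suppresses single-bin state flips.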

Output artefacts

After train.py completes, out_dir/ contains:

checkpoint.pth                model + optimiser state at best epoch
training_log.tsv              per-epoch loss curves
reconstructions.npy           (n_samples, n_bins) imputed depth
latents.npy                   (n_samples, 10) latent embeddings
segments.parquet              per-sample CN segments
gene_calls.tsv                per-sample chromosomal-gene CN calls
plasmid_gene_calls.tsv        per-sample plasmid-gene CN calls
evaluation.txt                MCC/FNR/PPV by gene and by sequence type (ST)

evaluation.txt is the headline artefact for the manuscript; see 7. Evaluation.
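Before moving on to evaluation, it can help to sanity-check a finished run. The sketch below is a hypothetical helper, not part of CNVRock: it verifies that every artefact listed above exists in out_dir and that reconstructions.npy and latents.npy agree on n_samples and latent_dim.

```python
from pathlib import Path

import numpy as np

# Artefact names as listed in this section.
EXPECTED_ARTEFACTS = [
    "checkpoint.pth", "training_log.tsv", "reconstructions.npy", "latents.npy",
    "segments.parquet", "gene_calls.tsv", "plasmid_gene_calls.tsv", "evaluation.txt",
]


def check_outputs(out_dir, latent_dim: int = 10) -> tuple:
    """Raise if an artefact is missing or the arrays disagree on shape;
    otherwise return (n_samples, n_bins)."""
    out = Path(out_dir)
    missing = [name for name in EXPECTED_ARTEFACTS if not (out / name).is_file()]
    if missing:
        raise FileNotFoundError(f"{out}: missing {missing}")
    # mmap_mode avoids reading the full reconstruction matrix into memory.
    recon = np.load(out / "reconstructions.npy", mmap_mode="r")
    latents = np.load(out / "latents.npy", mmap_mode="r")
    if recon.ndim != 2 or latents.shape != (recon.shape[0], latent_dim):
        raise ValueError(f"shape mismatch: reconstructions {recon.shape}, latents {latents.shape}")
    return tuple(recon.shape)
```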