Gradient Training
Fit a cuvis-ai pipeline with GradientTrainer: backpropagation through its trainable parameters, driven by PyTorch Lightning.
Goal
Produce a saved, fully trained pipeline and a matching trainrun.yaml that can be replayed with restore-trainrun for reproducible re-runs.
Prerequisites
- A pipeline with at least one node carrying trainable parameters (Deep SVDD, AdaCLIP, learned Channel Selector, …).
- A pipeline that has already been statistically initialised (see Statistical Training). Gradient training is Phase 2 of the two-phase model.
- A datamodule producing the data shape your pipeline expects (typically labelled or self-supervised, depending on the loss node).
- Loss and metric nodes wired into the pipeline.
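Recipe
```python
from cuvis_ai_core.trainer import GradientTrainer
from cuvis_ai_core.config import OptimizerConfig, SchedulerConfig

# `pipeline` is a statistically initialised pipeline and `datamodule`
# yields the batches it expects (see Prerequisites).
trainer = GradientTrainer(
    max_epochs=50,
    optimizer=OptimizerConfig(name="adam", lr=1e-3),
    # Cosine schedule; t_max equals max_epochs here.
    scheduler=SchedulerConfig(name="cosine", t_max=50),
    callbacks=["early_stopping", "model_checkpoint"],
)

# Phase 2: backpropagate through the trainable parameters.
trainer.fit(pipeline=pipeline, datamodule=datamodule)

# Persist the trained pipeline and the config needed to replay the run.
pipeline.save("artifacts/trained_pipeline.yaml")
trainer.save_trainrun("artifacts/trainrun.yaml")
```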
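The recipe pairs lr=1e-3 with a "cosine" scheduler and t_max=50. Assuming that scheduler follows the standard cosine-annealing rule (the closed form used by PyTorch's CosineAnnealingLR, here with a minimum learning rate of 0), the per-epoch learning rate can be sketched as below. Whether cuvis-ai's SchedulerConfig uses exactly this form, and whether it steps per epoch, is an assumption; cosine_lr and eta_min are illustrative names, not part of cuvis-ai.

```python
import math


def cosine_lr(epoch: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Cosine annealing: decays base_lr towards eta_min over t_max epochs."""
    # eta_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * epoch / t_max)) / 2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Values matching the recipe: lr=1e-3, t_max=50.
for epoch in (0, 10, 25, 40, 49):
    print(f"epoch {epoch:2d}: lr = {cosine_lr(epoch, 1e-3, 50):.2e}")
```

With t_max equal to max_epochs, the rate starts at 1e-3, is halved by epoch 25 and approaches eta_min as the run ends, so the final epochs make only small updates.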
What happens under the hood
- The trainer wraps the pipeline in a LightningModule.
- For each batch:
  - nodes whose stages include FORWARD run a forward pass,
  - nodes whose stages include LOSS compute the loss,
  - the loss is backpropagated and the optimizer steps,
  - nodes whose stages include METRIC log validation metrics.
- Callbacks (early stopping, model checkpoint) fire at epoch boundaries.
- At the end, save_trainrun() writes a YAML file capturing the entire training config so the run can be reproduced.
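The per-batch flow above can be illustrated with a toy, dependency-free loop. Everything in it is hypothetical: the Stage enum, the Node records, the one-parameter scale node, the hand-derived MSE gradient standing in for autograd, plain SGD standing in for Adam, and the patience-based early-stopping check. It is a sketch of stage-filtered control flow under those assumptions, not cuvis-ai's or Lightning's implementation.

```python
import random
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable


class Stage(Enum):
    """Hypothetical stand-in for the pipeline's execution stages."""
    FORWARD = auto()
    LOSS = auto()
    METRIC = auto()


@dataclass
class Node:
    name: str
    stages: frozenset
    run: Callable[[dict], None]  # reads and writes a shared per-batch context


params = {"w": 0.0}  # the single trainable parameter of this toy pipeline


def scale(ctx):  # FORWARD: y_hat = w * x
    ctx["y_hat"] = [params["w"] * x for x in ctx["x"]]


def mse(ctx):  # LOSS: mean squared error and its analytic gradient w.r.t. w
    err = [p - t for p, t in zip(ctx["y_hat"], ctx["y"])]
    ctx["loss"] = sum(e * e for e in err) / len(err)
    ctx["grad_w"] = sum(2 * e * x for e, x in zip(err, ctx["x"])) / len(err)


def mae(ctx):  # METRIC: mean absolute error on a validation batch
    err = [abs(p - t) for p, t in zip(ctx["y_hat"], ctx["y"])]
    ctx["mae"] = sum(err) / len(err)


PIPELINE = [
    Node("scale", frozenset({Stage.FORWARD}), scale),
    Node("mse", frozenset({Stage.LOSS}), mse),
    Node("mae", frozenset({Stage.METRIC}), mae),
]


def run_stage(stage, ctx):
    # Execute only the nodes whose stages include the requested stage
    for node in PIPELINE:
        if stage in node.stages:
            node.run(ctx)


def make_batches(rng, n_batches, size=8, true_w=3.0):
    batches = []
    for _ in range(n_batches):
        xs = [rng.uniform(-1.0, 1.0) for _ in range(size)]
        batches.append({"x": xs, "y": [true_w * x for x in xs]})
    return batches


def fit(max_epochs=50, lr=0.5, patience=3, seed=0):
    rng = random.Random(seed)
    train, val = make_batches(rng, 16), make_batches(rng, 4)
    params["w"] = 0.0  # reset so repeated calls start from the same state
    best, bad_epochs, history = float("inf"), 0, []
    for _ in range(max_epochs):
        for batch in train:
            ctx = dict(batch)
            run_stage(Stage.FORWARD, ctx)
            run_stage(Stage.LOSS, ctx)
            params["w"] -= lr * ctx["grad_w"]  # plain SGD step
        val_mae = 0.0
        for batch in val:
            ctx = dict(batch)
            run_stage(Stage.FORWARD, ctx)
            run_stage(Stage.METRIC, ctx)
            val_mae += ctx["mae"] / len(val)
        history.append(val_mae)
        # Epoch boundary: patience-based early stopping on validation MAE
        if val_mae < best - 1e-6:
            best, bad_epochs = val_mae, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
    return history


history = fit()
print(f"epochs run: {len(history)}, w = {params['w']:.6f}")
```

fit() returns the per-epoch validation MAE. On this noise-free data w converges to 3.0 and early stopping ends the run well before max_epochs. Because all randomness comes from seed and the parameter is reset on each call, two calls with the same arguments produce identical histories, which is the kind of reproducibility a trainrun file is meant to capture.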
Common variations
- Resume from a checkpoint: load both the pipeline YAML and the Lightning checkpoint, then call trainer.fit(pipeline, datamodule, ckpt_path=...).
- Multi-stage freezing: drive unfreezing via callbacks (e.g. unfreeze the channel selector after epoch 10).
- Sweep configurations: pair with Hydra sweeps to run a grid of training runs.
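Two of these variations reduce to simple bookkeeping, sketched here in plain Python under stated assumptions. UnfreezeAtEpoch is a hypothetical callback (not a cuvis-ai or Lightning class) that marks a node trainable from a chosen epoch onward; expand_grid enumerates the Cartesian product a grid sweep would run. The node names and config keys are made up for illustration.

```python
import itertools


class UnfreezeAtEpoch:
    """Hypothetical callback: keeps one node frozen until a given epoch."""

    def __init__(self, node_name: str, unfreeze_epoch: int):
        self.node_name = node_name
        self.unfreeze_epoch = unfreeze_epoch

    def on_epoch_start(self, epoch: int, trainable: dict) -> None:
        # Invoked at every epoch boundary; epochs are 0-based
        trainable[self.node_name] = epoch >= self.unfreeze_epoch


def expand_grid(space: dict) -> list:
    """Return one config dict per point in the Cartesian product of the sweep space."""
    keys = list(space)
    return [dict(zip(keys, combo)) for combo in itertools.product(*space.values())]


# The channel selector stays frozen for epochs 0-9 and trains from epoch 10.
trainable = {"channel_selector": False, "detector": True}
callback = UnfreezeAtEpoch("channel_selector", unfreeze_epoch=10)
for epoch in range(12):
    callback.on_epoch_start(epoch, trainable)
    print(epoch, trainable["channel_selector"])

# A 2 x 2 grid over two (hypothetical) config keys gives four runs.
runs = expand_grid({"optimizer.lr": [1e-3, 1e-4], "max_epochs": [50, 100]})
for run in runs:
    print(run)
```

With Hydra, the same four-run grid would typically be requested as comma-separated values in multirun mode, e.g. python train.py --multirun optimizer.lr=1e-3,1e-4 max_epochs=50,100, where train.py and the keys are placeholders.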
Related
- Concepts → Training — two-phase model behind the trainer.
- Concepts → Execution stages — which nodes run when.
- Monitoring & Visualization — TensorBoard, callbacks, runtime visualisation.
- Profiling — find bottlenecks in long training runs.