Status: Needs Review
This page has not been reviewed for accuracy and completeness. Content may be outdated or contain errors.
Tutorial: Learnable Channel Selection with Two-Phase Training¶
Learn how to build an anomaly detection pipeline with learnable channel selection using two-phase training (statistical initialization + gradient optimization).
Overview¶
What You'll Learn:
- Building on the RX Statistical tutorial with learnable channel selection
- Understanding Gumbel-Softmax for differentiable channel selection
- Two-phase training: statistical initialization → gradient optimization
- Loss composition with regularizers (BCE + entropy + diversity)
- Using callbacks for early stopping and model checkpointing
- Visualizing selected channels and training progress
Prerequisites:
- RX Statistical Tutorial - Complete this first!
- Understanding of Two-Phase Training
- Familiarity with Execution Stages
- Basic knowledge of gradient-based optimization
Time: ~30 minutes
Perfect for: Users who want to learn gradient-based training, channel selection, and advanced training strategies in CUVIS.AI.
Background¶
Why Channel Selection?¶
Hyperspectral data often contains dozens of spectral channels (61 in the lentils dataset used here), but:
- Not all channels are informative for a given task
- Redundant channels add noise and computation cost
- Channel selection can improve performance and interpretability
Manual selection (e.g., RGB channels at 650, 550, 450 nm) works but is suboptimal.
Learnable selection lets the model automatically discover the best channels for anomaly detection.
Gumbel-Softmax for Channel Selection¶
The challenge: Channel selection is discrete (select channel i or not), which isn't differentiable.
Solution: Gumbel-Softmax relaxation allows gradient-based optimization:
- weights: Learnable parameters per channel
- temperature: Controls discreteness (high → soft, low → hard)
- softmax: Makes weights differentiable
Temperature Annealing: Start with high temperature (smooth distribution), gradually decrease to make selection sharper.
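To make the mechanism concrete, here is a minimal PyTorch sketch of the Gumbel-Softmax relaxation. This illustrates the technique only; it is not the actual SoftChannelSelector implementation, which additionally handles n_select outputs and temperature annealing internally:

import torch
import torch.nn.functional as F

def gumbel_softmax_selection(logits: torch.Tensor, temperature: float, hard: bool = False) -> torch.Tensor:
    """Relaxed one-hot selection over channels."""
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform(0, 1)
    gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    # High temperature -> near-uniform weights; low temperature -> near one-hot
    soft = F.softmax((logits + gumbel) / temperature, dim=-1)
    if hard:
        # Straight-through estimator: discrete forward pass, soft gradients
        one_hot = F.one_hot(soft.argmax(dim=-1), soft.shape[-1]).to(soft.dtype)
        return one_hot + soft - soft.detach()
    return soft

Selecting n_select channels amounts to drawing one such sample per output channel and using the resulting weights to mix input channels. (PyTorch also ships torch.nn.functional.gumbel_softmax with the same semantics.)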
Two-Phase Training Strategy¶
Phase 1: Statistical Initialization
- Initialize normalizer statistics (min/max)
- Initialize selector weights (variance-based or uniform)
- Initialize RX detector (background covariance)
- All nodes frozen (no gradients)
Phase 2: Gradient Training
- Unfreeze: Selector, RX, LogitHead
- Keep frozen: Normalizer (fixed statistics)
- Optimize with BCE loss + regularizers
- Use callbacks (early stopping, checkpointing)
This strategy combines the stability of statistical initialization with the power of gradient optimization.
Step 1: Setup (Building on RX Tutorial)¶
Data Loading¶
This tutorial extends the RX Statistical Tutorial. We'll use the same data setup:
from pathlib import Path
from cuvis_ai_core.data.datasets import SingleCu3sDataModule
from cuvis_ai_core.pipeline.pipeline import CuvisPipeline
from cuvis_ai.node.data import LentilsAnomalyDataNode
# Setup datamodule
datamodule = SingleCu3sDataModule(
data_dir="data/lentils",
batch_size=4,
num_workers=0,
)
datamodule.setup(stage="fit")
# Create pipeline
pipeline = CuvisPipeline("Channel_Selector")
# Data node (same as RX tutorial)
data_node = LentilsAnomalyDataNode(normal_class_ids=[0, 1])
Step 2: Build Pipeline with Selector¶
Normalization¶
from cuvis_ai.node.normalization import MinMaxNormalizer
normalizer = MinMaxNormalizer(
eps=1.0e-6,
use_running_stats=True, # Will be initialized in Phase 1
)
Add Soft Channel Selector (NEW!)¶
from cuvis_ai.node.selector import SoftChannelSelector
selector = SoftChannelSelector(
n_select=3, # Select 3 channels (for RGB-like output)
input_channels=61, # Input hyperspectral channels
init_method="variance", # Initialize by channel variance
temperature_init=5.0, # Start with soft selection
temperature_min=0.1, # End with sharp selection
temperature_decay=0.9, # Decay rate per epoch
hard=False, # Soft selection during training
eps=1.0e-6,
)
Key Parameters:
- n_select=3: Number of output channels (like RGB)
- init_method="variance": Initialize weights by channel variance
- Temperature schedule: 5.0 → 0.1 with decay 0.9 per epoch; assuming one decay step per epoch, temperature = max(5.0 × 0.9^epoch, 0.1)
- Epoch 0: temperature = 5.0 (soft selection)
- Epoch 10: temperature ≈ 1.74 (medium)
- Epoch 20: temperature ≈ 0.61 (sharper)
- Epoch ~38 onward: temperature clamped at 0.1 (nearly discrete)
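To preview the schedule numerically (a sketch under the one-decay-step-per-epoch assumption above):

# Preview the annealing schedule: temperature = max(init * decay^epoch, min)
for epoch in range(0, 41, 10):
    temperature = max(5.0 * 0.9**epoch, 0.1)
    print(f"epoch {epoch:2d}: temperature = {temperature:.2f}")
# epoch  0: temperature = 5.00
# epoch 10: temperature = 1.74
# epoch 20: temperature = 0.61
# epoch 30: temperature = 0.21
# epoch 40: temperature = 0.10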
Add RX Detector¶
Important: RX now operates on 3 selected channels instead of 61!
from cuvis_ai.anomaly.rx_detector import RXGlobal
from cuvis_ai.node.conversion import ScoreToLogit
rx = RXGlobal(
num_channels=3,  # matches the selector's n_select output channels
eps=1.0e-6,
)
logit_head = ScoreToLogit(init_scale=1.0, init_bias=0.0)
Add Decision Node¶
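The decider binarizes the logits into anomaly decisions. It uses the same configuration as the complete script at the end of this tutorial:

from cuvis_ai.deciders.binary_decider import BinaryDecider
decider = BinaryDecider(threshold=0.5)  # predictions above 0.5 -> anomaly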
Step 3: Add Loss Functions and Regularizers¶
Main Loss: BCE with Logits¶
from cuvis_ai.node.losses import AnomalyBCEWithLogits
bce_loss = AnomalyBCEWithLogits(
name="bce",
weight=10.0, # High weight for main task
pos_weight=None, # No class imbalance weighting
)
Regularizer 1: Entropy Loss¶
Encourages uniform channel selection distribution (prevents mode collapse):
from cuvis_ai.node.losses import SelectorEntropyRegularizer
entropy_loss = SelectorEntropyRegularizer(
name="entropy",
weight=0.1, # Regularization strength
target_entropy=None, # Use maximum entropy as target
)
Why? Without entropy loss, selector might focus on only 1-2 channels.
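As a sketch of the quantity being regularized (not necessarily the exact formulation inside SelectorEntropyRegularizer): the entropy of the selection distribution is maximal for uniform weights and falls toward zero as mass collapses onto a single channel:

import torch

def selection_entropy(weights: torch.Tensor) -> torch.Tensor:
    # Normalize the channel weights into a probability distribution
    p = torch.softmax(weights, dim=-1)
    # H(p) = -sum_i p_i * log(p_i); equals log(C) for uniform p over C channels
    return -(p * (p + 1e-12).log()).sum(dim=-1)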
Regularizer 2: Diversity Loss¶
Encourages selected channels to be diverse (not redundant):
from cuvis_ai.node.losses import SelectorDiversityRegularizer
diversity_loss = SelectorDiversityRegularizer(
name="diversity",
weight=0.01, # Weak regularization
)
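One common way to express such a penalty, shown purely as an illustration (the actual SelectorDiversityRegularizer formulation may differ), is to penalize overlap between the per-output selection distributions:

import torch

def diversity_penalty(selection_probs: torch.Tensor) -> torch.Tensor:
    # selection_probs: [n_select, n_channels], one distribution per output channel
    overlap = selection_probs @ selection_probs.T  # pairwise overlap, [n_select, n_select]
    # Keep only off-diagonal terms: overlap between *different* output channels
    off_diag = overlap - torch.diag(torch.diagonal(overlap))
    return off_diag.sum() / 2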
Total Loss:
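With the weights configured above, the combined objective is conceptually:

# Illustrative only; the loss nodes apply their own weights internally
total_loss = 10.0 * bce + 0.1 * entropy + 0.01 * diversity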
Step 4: Add Metrics and Visualization¶
Metrics¶
from cuvis_ai.node.metrics import AnomalyDetectionMetrics
metrics_anomaly = AnomalyDetectionMetrics(name="metrics_anomaly")
Visualization Nodes¶
from cuvis_ai.node.visualizations import AnomalyMask, CubeRGBVisualizer
# Anomaly mask visualization (same as RX tutorial)
viz_mask = AnomalyMask(name="mask", channel=30, up_to=5)
# NEW: RGB visualization of selected channels
viz_rgb = CubeRGBVisualizer(
name="rgb",
up_to=5, # Visualize first 5 samples
)
TensorBoard Monitoring¶
from cuvis_ai.node.monitor import TensorBoardMonitorNode
tensorboard_node = TensorBoardMonitorNode(
output_dir="runs/tensorboard",
run_name=pipeline.name,
)
Step 5: Connect the Pipeline¶
pipeline.connect(
# Processing flow: data → normalize → selector → RX → logits → decisions
(data_node.outputs.cube, normalizer.data),
(normalizer.normalized, selector.data),
(selector.selected, rx.data),
(rx.scores, logit_head.scores),
(logit_head.logits, decider.logits),
# Loss flow
(logit_head.logits, bce_loss.predictions),
(data_node.outputs.mask, bce_loss.targets),
(selector.weights, entropy_loss.weights),
(selector.weights, diversity_loss.weights),
# Metric flow
(decider.decisions, metrics_anomaly.decisions),
(data_node.outputs.mask, metrics_anomaly.targets),
(metrics_anomaly.metrics, tensorboard_node.metrics),
# Visualization flow
(decider.decisions, viz_mask.decisions),
(data_node.outputs.mask, viz_mask.mask),
(data_node.outputs.cube, viz_mask.cube),
(data_node.outputs.cube, viz_rgb.cube),
(selector.weights, viz_rgb.weights), # Visualize channel selection
(data_node.outputs.wavelengths, viz_rgb.wavelengths),
(viz_mask.artifacts, tensorboard_node.artifacts),
(viz_rgb.artifacts, tensorboard_node.artifacts),
)
Pipeline Flow:
graph LR
A[Data] -->|cube| B[Normalizer]
B -->|normalized| C[SoftChannelSelector]
C -->|selected 3ch| D[RXGlobal]
D -->|scores| E[ScoreToLogit]
E -->|logits| F[BinaryDecider]
E -->|logits| G[BCELoss]
A -->|mask| G
C -->|weights| H[EntropyLoss]
C -->|weights| I[DiversityLoss]
F -->|decisions| J[Metrics]
A -->|mask| J
J -->|metrics| K[TensorBoard]
F -->|decisions| L[VisualizeMask]
A -->|cube| M[VisualizeRGB]
C -->|weights| M
Step 6: Phase 1 - Statistical Initialization¶
Initialize nodes without gradients:
from cuvis_ai_core.training import StatisticalTrainer
from loguru import logger
logger.info("Phase 1: Statistical initialization...")
stat_trainer = StatisticalTrainer(
pipeline=pipeline,
datamodule=datamodule,
)
stat_trainer.fit()
What Gets Initialized:
1. MinMaxNormalizer: Collects global min/max
2. SoftChannelSelector: Initializes weights by channel variance
3. RXGlobal: Computes covariance on the 3 selected channels
Console Output:
[INFO] Phase 1: Statistical initialization...
[INFO] MinMaxNormalizer: min=-0.123, max=1.987
[INFO] SoftChannelSelector: Initialized weights by variance
[INFO] RXGlobal: Computed background statistics (3 selected channels)
[INFO] Statistical initialization complete
Step 7: Unfreeze Nodes for Gradient Training¶
Selectively unfreeze nodes for optimization:
logger.info("Phase 2: Unfreezing selector and RX for gradient training...")
unfreeze_node_names = [selector.name, rx.name, logit_head.name]
pipeline.unfreeze_nodes_by_name(unfreeze_node_names)
logger.info(f"Unfrozen nodes: {unfreeze_node_names}")
Freeze Status After Unfreezing:
- ✅ Unfrozen (trainable): SoftChannelSelector, RXGlobal, ScoreToLogit
- 🔒 Frozen (fixed): MinMaxNormalizer, Data Node
- Loss and metric nodes don't have weights
Step 8: Configure Training with Callbacks¶
Setup Training Configuration¶
from cuvis_ai_core.training.config import (
TrainingConfig,
CallbacksConfig,
EarlyStoppingConfig,
ModelCheckpointConfig,
SchedulerConfig,
)
from omegaconf import OmegaConf
# Load training config from Hydra (cfg is the DictConfig passed to the Hydra entry point; see the complete script below)
training_cfg = TrainingConfig.from_dict(
OmegaConf.to_container(cfg.training, resolve=True)
)
Add Early Stopping¶
Stop training if BCE loss doesn't improve:
if training_cfg.trainer.callbacks is None:
training_cfg.trainer.callbacks = CallbacksConfig()
training_cfg.trainer.callbacks.early_stopping.append(
EarlyStoppingConfig(
monitor="train/bce", # Monitor training BCE loss
mode="min", # Lower is better
patience=20, # Stop after 20 epochs without improvement
)
)
Add Learning Rate Scheduler¶
Reduce LR when validation IoU plateaus:
training_cfg.scheduler = SchedulerConfig(
name="reduce_on_plateau",
monitor="metrics_anomaly/iou", # Monitor validation IoU
mode="max", # Higher is better
factor=0.5, # Reduce LR by 50%
patience=5, # Wait 5 epochs before reducing
)
Add Model Checkpointing¶
Save top-3 models by validation IoU:
training_cfg.trainer.callbacks.checkpoint = ModelCheckpointConfig(
dirpath="outputs/checkpoints",
monitor="metrics_anomaly/iou", # Track validation IoU
mode="max", # Save highest IoU models
save_top_k=3, # Keep top 3 checkpoints
save_last=True, # Also save last checkpoint
filename="{epoch:02d}", # Filename pattern
verbose=True,
)
Step 9: Phase 2 - Gradient Training¶
Run gradient-based optimization:
from cuvis_ai_core.training import GradientTrainer
logger.info("Phase 3: Gradient-based channel selection optimization...")
grad_trainer = GradientTrainer(
pipeline=pipeline,
datamodule=datamodule,
loss_nodes=[bce_loss], # Can add entropy_loss, diversity_loss here
metric_nodes=[metrics_anomaly],
trainer_config=training_cfg.trainer,
optimizer_config=training_cfg.optimizer,
monitors=[tensorboard_node],
)
grad_trainer.fit()
Training Loop:
1. Forward pass through unfrozen nodes
2. Compute losses: BCE (plus entropy and diversity if added to loss_nodes)
3. Backpropagation and weight updates
4. Temperature decay for selector
5. Validation every N epochs
6. Checkpointing and early stopping
Console Output:
[INFO] Phase 2: Gradient training...
Epoch 0: train/bce=0.623, metrics_anomaly/iou=0.671, lr=0.001
Epoch 1: train/bce=0.512, metrics_anomaly/iou=0.709, lr=0.001
Epoch 5: train/bce=0.387, metrics_anomaly/iou=0.784, lr=0.001
[INFO] Reducing learning rate to 0.0005
Epoch 10: train/bce=0.298, metrics_anomaly/iou=0.812, lr=0.0005
[INFO] Saved checkpoint: epoch=10, iou=0.812
...
[INFO] Early stopping triggered (patience=20)
[INFO] Best checkpoint: epoch=23, iou=0.847
Step 10: Evaluation¶
Validation on Best Checkpoint¶
logger.info("Running validation with best checkpoint...")
val_results = grad_trainer.validate()
logger.info(f"Validation results: {val_results}")
Test Evaluation¶
logger.info("Running test evaluation with best checkpoint...")
test_results = grad_trainer.test()
logger.info(f"Test results: {test_results}")
Expected Results:
- IoU: 0.80-0.85 (improved from the RX baseline of ~0.72)
- Precision: 0.85-0.90
- Recall: 0.80-0.87
Step 11: Analyze Selected Channels¶
Inspect Selector Weights¶
# Get final channel weights
selector_weights = selector.get_selection_weights()
print(f"Selector weights shape: {selector_weights.shape}") # [61]
print(f"Top 3 channels: {selector_weights.topk(3).indices}")
# Example output:
# Top 3 channels: tensor([18, 42, 55])
# These correspond to specific wavelengths
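To translate channel indices into wavelengths, index into the per-channel wavelength array (a sketch; `wavelengths` here is a hypothetical variable holding the array exposed on data_node.outputs.wavelengths):

# `wavelengths`: per-channel wavelength array of length 61 (hypothetical variable)
top_channels = selector_weights.topk(3).indices
print(f"Selected wavelengths: {wavelengths[top_channels]}")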
Visualize Channel Selection¶
TensorBoard will show:
- RGB visualization: False-color image from selected channels
- Weight plot: Bar chart of channel weights
- Weight evolution: How weights changed during training
Step 12: Save Pipeline and TrainRun Config¶
Save Trained Pipeline¶
from cuvis_ai_core.training.config import PipelineMetadata
results_dir = Path("outputs/trained_models")
pipeline_output_path = results_dir / "Channel_Selector.yaml"
pipeline.save_to_file(
str(pipeline_output_path),
metadata=PipelineMetadata(
name=pipeline.name,
description="Channel Selector with two-phase training",
tags=["gradient", "statistical", "channel_selector", "rx"],
author="cuvis.ai",
),
)
Save TrainRun Config¶
from cuvis_ai_core.training.config import TrainRunConfig
trainrun_config = TrainRunConfig(
name="channel_selector",
pipeline=pipeline.serialize(),
data=datamodule_config,  # the data config used to build the datamodule (e.g., cfg.data)
training=training_cfg,
output_dir="outputs",
loss_nodes=["bce"],
metric_nodes=["metrics_anomaly"],
freeze_nodes=[],
unfreeze_nodes=unfreeze_node_names, # Track which nodes were trained
)
trainrun_config.save_to_file("outputs/trained_models/channel_selector_trainrun.yaml")
Complete Example Script¶
Full runnable script (examples/channel_selector.py):
from pathlib import Path
import hydra
from cuvis_ai_core.data.datasets import SingleCu3sDataModule
from cuvis_ai_core.pipeline.pipeline import CuvisPipeline
from cuvis_ai_core.training import GradientTrainer, StatisticalTrainer
from cuvis_ai_core.training.config import (
CallbacksConfig,
EarlyStoppingConfig,
ModelCheckpointConfig,
PipelineMetadata,
SchedulerConfig,
TrainingConfig,
TrainRunConfig,
)
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from cuvis_ai.anomaly.rx_detector import RXGlobal
from cuvis_ai.node.conversion import ScoreToLogit
from cuvis_ai.deciders.binary_decider import BinaryDecider
from cuvis_ai.node.data import LentilsAnomalyDataNode
from cuvis_ai.node.losses import (
AnomalyBCEWithLogits,
SelectorDiversityRegularizer,
SelectorEntropyRegularizer,
)
from cuvis_ai.node.metrics import AnomalyDetectionMetrics
from cuvis_ai.node.monitor import TensorBoardMonitorNode
from cuvis_ai.node.normalization import MinMaxNormalizer
from cuvis_ai.node.selector import SoftChannelSelector
from cuvis_ai.node.visualizations import AnomalyMask, CubeRGBVisualizer
@hydra.main(config_path="../configs/", config_name="trainrun/channel_selector", version_base=None)
def main(cfg: DictConfig) -> None:
"""Channel Selector with Two-Phase Training."""
logger.info("=== Channel Selector for Anomaly Detection ===")
output_dir = Path(cfg.output_dir)
# Stage 1: Setup datamodule
datamodule = SingleCu3sDataModule(**cfg.data)
datamodule.setup(stage="fit")
# Stage 2: Build pipeline with selector
pipeline = CuvisPipeline("Channel_Selector")
data_node = LentilsAnomalyDataNode(normal_class_ids=[0, 1])
normalizer = MinMaxNormalizer(eps=1.0e-6, use_running_stats=True)
selector = SoftChannelSelector(
n_select=3,
input_channels=61,
init_method="variance",
temperature_init=5.0,
temperature_min=0.1,
temperature_decay=0.9,
)
rx = RXGlobal(num_channels=3, eps=1.0e-6)  # matches selector's n_select
logit_head = ScoreToLogit(init_scale=1.0, init_bias=0.0)
decider = BinaryDecider(threshold=0.5)
bce_loss = AnomalyBCEWithLogits(name="bce", weight=10.0)
entropy_loss = SelectorEntropyRegularizer(name="entropy", weight=0.1)
diversity_loss = SelectorDiversityRegularizer(name="diversity", weight=0.01)
metrics_anomaly = AnomalyDetectionMetrics(name="metrics_anomaly")
viz_mask = AnomalyMask(name="mask", channel=30, up_to=5)
viz_rgb = CubeRGBVisualizer(name="rgb", up_to=5)
tensorboard_node = TensorBoardMonitorNode(
output_dir=str(output_dir / ".." / "tensorboard"),
run_name=pipeline.name,
)
# Stage 3: Connect pipeline
pipeline.connect(
(data_node.outputs.cube, normalizer.data),
(normalizer.normalized, selector.data),
(selector.selected, rx.data),
(rx.scores, logit_head.scores),
(logit_head.logits, bce_loss.predictions),
(data_node.outputs.mask, bce_loss.targets),
(selector.weights, entropy_loss.weights),
(selector.weights, diversity_loss.weights),
(logit_head.logits, decider.logits),
(decider.decisions, metrics_anomaly.decisions),
(data_node.outputs.mask, metrics_anomaly.targets),
(metrics_anomaly.metrics, tensorboard_node.metrics),
(decider.decisions, viz_mask.decisions),
(data_node.outputs.mask, viz_mask.mask),
(data_node.outputs.cube, viz_mask.cube),
(data_node.outputs.cube, viz_rgb.cube),
(selector.weights, viz_rgb.weights),
(data_node.outputs.wavelengths, viz_rgb.wavelengths),
(viz_mask.artifacts, tensorboard_node.artifacts),
(viz_rgb.artifacts, tensorboard_node.artifacts),
)
# Configure training
training_cfg = TrainingConfig.from_dict(OmegaConf.to_container(cfg.training, resolve=True))
if training_cfg.trainer.callbacks is None:
training_cfg.trainer.callbacks = CallbacksConfig()
training_cfg.trainer.callbacks.early_stopping.append(
EarlyStoppingConfig(monitor="train/bce", mode="min", patience=20)
)
training_cfg.scheduler = SchedulerConfig(
name="reduce_on_plateau",
monitor="metrics_anomaly/iou",
mode="max",
factor=0.5,
patience=5,
)
training_cfg.trainer.callbacks.checkpoint = ModelCheckpointConfig(
dirpath=str(output_dir / "checkpoints"),
monitor="metrics_anomaly/iou",
mode="max",
save_top_k=3,
save_last=True,
)
# Phase 1: Statistical initialization
logger.info("Phase 1: Statistical initialization...")
stat_trainer = StatisticalTrainer(pipeline=pipeline, datamodule=datamodule)
stat_trainer.fit()
# Unfreeze nodes for gradient training
logger.info("Unfreezing selector and RX...")
unfreeze_node_names = [selector.name, rx.name, logit_head.name]
pipeline.unfreeze_nodes_by_name(unfreeze_node_names)
# Phase 2: Gradient training
logger.info("Phase 2: Gradient training...")
grad_trainer = GradientTrainer(
pipeline=pipeline,
datamodule=datamodule,
loss_nodes=[bce_loss],
metric_nodes=[metrics_anomaly],
trainer_config=training_cfg.trainer,
optimizer_config=training_cfg.optimizer,
monitors=[tensorboard_node],
)
grad_trainer.fit()
# Evaluation
grad_trainer.validate()
grad_trainer.test()
# Save pipeline
pipeline.save_to_file(
str(output_dir / "trained_models" / "Channel_Selector.yaml"),
metadata=PipelineMetadata(
name=pipeline.name,
description="Channel Selector with two-phase training",
tags=["gradient", "channel_selector"],
author="cuvis.ai",
),
)
if __name__ == "__main__":
main()
Run the example:
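Assuming you are in the repository root:

python examples/channel_selector.py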
Troubleshooting¶
Issue: Channel Collapse (All Weight on 1-2 Channels)¶
Symptoms: Selector weights like [0.95, 0.03, 0.02, ...]
Solution: Increase entropy regularization:
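For example (0.3 is a starting point to tune, up from the 0.1 used above):

entropy_loss.weight = 0.3  # stronger pull toward a uniform selection distribution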
Issue: Training Unstable (Loss Oscillates)¶
Solution 1: Lower learning rate:
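For example, assuming your optimizer config exposes a learning-rate field named lr (check the actual OptimizerConfig schema):

training_cfg.optimizer.lr = 1.0e-4  # hypothetical field name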
Solution 2: Slower temperature decay:
selector = SoftChannelSelector(
temperature_init=10.0, # Higher init
temperature_decay=0.95, # Slower decay
)
Issue: No Improvement Over RX Baseline¶
Possible Causes:
1. Insufficient training epochs
2. Regularization too strong
3. Temperature decays too quickly
Solutions:
# Increase training budget
training_cfg.trainer.max_epochs = 50
# Reduce regularization
entropy_loss.weight = 0.05
diversity_loss.weight = 0.005
# Slower temperature schedule
selector.temperature_decay = 0.92
Issue: Checkpoints Not Saving¶
Cause: Checkpoint directory doesn't exist or validation not running
Solution:
# Ensure directory exists
output_dir.mkdir(parents=True, exist_ok=True)
# Verify validation runs
training_cfg.trainer.check_val_every_n_epoch = 1 # Validate every epoch
Next Steps¶
Build on this tutorial:
- Deep SVDD Tutorial - Replace RX with deep learning detector
- AdaCLIP Workflow - Try other channel selectors (DRCNN, Concrete)
- gRPC Workflow - Deploy trained pipeline as a service
Explore related concepts:
- Two-Phase Training Deep Dive - Detailed training strategy
- Execution Stages - Control node execution timing
- Gradient Trainer - Optimizer configuration
Explore related nodes:
- SoftChannelSelector - Selector details
- Loss Nodes - All available loss functions
- Callbacks - Advanced training control
Summary¶
In this tutorial, you learned:
✅ How to add learnable channel selection to an anomaly detection pipeline
✅ Two-phase training strategy: statistical init → gradient optimization
✅ Loss composition with regularizers
✅ Using callbacks for early stopping and checkpointing
✅ Temperature annealing for discrete selection
✅ Analyzing and visualizing selected channels
Key Takeaways:
- Channel selection improves both performance and interpretability
- Gumbel-Softmax enables gradient-based discrete optimization
- Two-phase training combines stability (statistical) with power (gradient)
- Regularizers (entropy, diversity) prevent mode collapse
- Callbacks automate training management (early stopping, LR scheduling)
Now you're ready to explore deep learning-based anomaly detection!