Status: Needs Review

This page has not been reviewed for accuracy and completeness. Content may be outdated or contain errors.


Selector Nodes

Selector nodes cover the current fixed, supervised, and trainable band-selection workflows, plus false-RGB generation.

Channel Selectors

channel_selector

Channel selector nodes for HSI to RGB conversion.

This module provides port-based nodes for selecting spectral channels from hyperspectral cubes and composing RGB images for downstream processing (e.g., with AdaCLIP).

Selectors gate/reweight individual channels independently: output[c] = weight[c] * input[c] (diagonal operation, preserves channel count).

For cross-channel linear projection (full matrix, reduces channel count), see the cuvis_ai.node.channel_mixer module.
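The difference between the two operations can be sketched in NumPy. This is an illustration only: the array names and random data are made up, and the real nodes operate on torch tensors.

```python
import numpy as np

rng = np.random.default_rng(0)
cube = rng.random((1, 4, 4, 5)).astype(np.float32)  # BHWC with C=5 bands

# Selector: one weight per channel, applied element-wise (diagonal operation).
weight = np.array([1.0, 0.0, 0.5, 0.0, 2.0], dtype=np.float32)
selected = cube * weight  # broadcasts over the last axis, shape stays (1, 4, 4, 5)

# Mixer: full [C_out, C_in] matrix, every output channel combines all inputs.
mix = rng.random((3, 5)).astype(np.float32)
mixed = cube @ mix.T  # shape becomes (1, 4, 4, 3)

# The element-wise form equals multiplication by a diagonal matrix.
same = np.allclose(selected, cube @ np.diag(weight))
```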

Normalization design

All channel selectors share a common RGB normalization strategy in ChannelSelectorBase, controlled by NormMode:

  • Percentile bounds (not absolute min/max): SpectralRadiance data contains outlier pixels whose absolute max can be 10x the median, compressing 99% of the image into the bottom of the brightness range. Using the 0.5th / 99.5th percentile clips these outliers and preserves visual dynamic range.

  • Per-channel [3] bounds: Separate min/max per R/G/B channel preserves colour balance. A single scalar bound would distort hue if one channel has a wider range than the others.

  • Three modes (NormMode):
      • running (default): warmup plus percentile accumulation with optional freeze. The first warmup frames use per-frame normalization (visually good immediately) while accumulating global bounds. After warmup, accumulated bounds are used. By default, accumulation is frozen after 20 frames to prevent late outliers from changing brightness; set freeze_running_bounds_after_frames=None to keep legacy unbounded accumulation.
      • statistical: pre-computed global percentiles via StatisticalTrainer. Use when exact global stats matter and a full first pass is acceptable.
      • per_frame: each frame normalized independently; no inter-frame state. Use for unrelated images or single-frame pipelines.

  • Why warmup + accumulation (not EMA): Exponential moving averages have recency bias — for long videos the early-frame statistics are forgotten. Min/max accumulation bounds only ever expand (min-of-lows, max-of-highs) during the accumulation window, giving stable normalization without recency drift. The warmup period ensures the first few frames look natural before enough data has been accumulated.
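The running mode described above can be sketched as a small NumPy state machine. RunningNormalizer and percentile_bounds are hypothetical names for illustration, not the library's classes, and the exact order of update and freeze checks in ChannelSelectorBase may differ.

```python
import numpy as np

Q_LOW, Q_HIGH = 0.005, 0.995  # 0.5th / 99.5th percentile, as described above


def percentile_bounds(rgb):
    """Per-channel (lo, hi) bounds of an [H, W, 3] frame, each of shape [3]."""
    flat = rgb.reshape(-1, 3)
    return np.quantile(flat, Q_LOW, axis=0), np.quantile(flat, Q_HIGH, axis=0)


class RunningNormalizer:
    """Hypothetical stand-in for the 'running' mode; not the library class."""

    def __init__(self, warmup_frames=10, freeze_after=20):
        self.warmup_frames = warmup_frames
        self.freeze_after = freeze_after  # None = never freeze
        self.lo = None
        self.hi = None
        self.count = 0

    def __call__(self, rgb):
        frame_lo, frame_hi = percentile_bounds(rgb)
        frozen = self.freeze_after is not None and self.count >= self.freeze_after
        if not frozen:
            # Bounds only ever widen: min-of-lows, max-of-highs.
            self.lo = frame_lo if self.lo is None else np.minimum(self.lo, frame_lo)
            self.hi = frame_hi if self.hi is None else np.maximum(self.hi, frame_hi)
        # During warmup use this frame's own bounds, afterwards the accumulated ones.
        in_warmup = self.count < self.warmup_frames
        lo, hi = (frame_lo, frame_hi) if in_warmup else (self.lo, self.hi)
        self.count += 1
        return np.clip((rgb - lo) / np.maximum(hi - lo, 1e-8), 0.0, 1.0)


rng = np.random.default_rng(0)
norm = RunningNormalizer(warmup_frames=2, freeze_after=3)
for _ in range(3):
    out = norm(rng.random((8, 8, 3)))
frozen_hi = norm.hi.copy()
out = norm(rng.random((8, 8, 3)) * 100.0)  # late outlier frame, bounds stay put
```

Because the bounds are frozen after three frames in this toy run, the bright fourth frame is clipped rather than allowed to darken everything that follows.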

NormMode

Bases: StrEnum

RGB normalization mode for channel selectors.
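A minimal sketch of such an enum follows. The member values are inferred from the mode names used in the text (running, statistical, per_frame) and are an assumption. The sketch subclasses (str, Enum) so it also runs on Python versions without enum.StrEnum; unlike a real StrEnum, str() on these members does not return the bare value.

```python
from enum import Enum


class NormMode(str, Enum):
    """Sketch only: member values are inferred from the mode names in the text."""

    RUNNING = "running"
    STATISTICAL = "statistical"
    PER_FRAME = "per_frame"


# String and enum inputs resolve to the same member.
mode_from_str = NormMode("per_frame")
mode_from_enum = NormMode(NormMode.PER_FRAME)
```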

ChannelSelectorBase

ChannelSelectorBase(
    norm_mode=RUNNING,
    apply_gamma=True,
    freeze_running_bounds_after_frames=20,
    running_warmup_frames=_WARMUP_FRAMES,
    **kwargs,
)

Bases: Node

Base class for hyperspectral band selection strategies.

This base class defines the common input/output ports for band selection nodes and provides shared percentile-based RGB normalization (see Normalization design above for the rationale).

Subclasses should implement forward() and _compute_raw_rgb() (the latter is used by statistical_initialization and _running_normalize).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| norm_mode | str \| NormMode | RGB normalization mode. Default NormMode.RUNNING. | RUNNING |
| apply_gamma | bool | Apply sRGB gamma curve after normalization. Default True. Lifts midtones so linear [0, 1] values appear natural on standard displays. | True |
| freeze_running_bounds_after_frames | int \| None | When norm_mode='running', stop updating running_min/running_max after this many forward calls. None keeps legacy behavior (never freeze). Default 20. | 20 |
| running_warmup_frames | int | Number of initial running frames to normalize per-frame while collecting bounds. Set to 0 for fully stable live rendering from the first frame. Default 10. | _WARMUP_FRAMES |

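To show what the apply_gamma option does to midtones, here is the standard piecewise sRGB transfer function in NumPy. Whether the library uses exactly this piecewise curve or a simpler power law is an assumption; srgb_gamma is an illustrative name.

```python
import numpy as np


def srgb_gamma(linear):
    """Standard sRGB transfer function for linear values in [0, 1]."""
    linear = np.clip(np.asarray(linear, dtype=np.float64), 0.0, 1.0)
    low = 12.92 * linear
    high = 1.055 * np.power(linear, 1.0 / 2.4) - 0.055
    return np.where(linear <= 0.0031308, low, high)


# Black, a mid-grey linear value, and white.
encoded = srgb_gamma([0.0, 0.18, 1.0])
```

The mid-grey input of 0.18 comes out well above 0.18, which is the midtone lift the parameter description refers to.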
Ports

INPUT_SPECS

  • cube: float32, shape (-1, -1, -1, -1). Hyperspectral cube in BHWC format.
  • wavelengths: float32, shape (-1,). Wavelength array in nanometers.

OUTPUT_SPECS

  • rgb_image: float32, shape (-1, -1, -1, 3). Composed RGB image in BHWC format (0-1 range).
  • band_info: dict. Metadata about selected bands.

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    norm_mode: str | NormMode = NormMode.RUNNING,
    apply_gamma: bool = True,
    freeze_running_bounds_after_frames: int | None = 20,
    running_warmup_frames: int = _WARMUP_FRAMES,
    **kwargs: Any,
) -> None:
    if freeze_running_bounds_after_frames is not None:
        if (
            isinstance(freeze_running_bounds_after_frames, bool)
            or not isinstance(freeze_running_bounds_after_frames, int)
            or freeze_running_bounds_after_frames < 1
        ):
            raise ValueError(
                "freeze_running_bounds_after_frames must be an integer >= 1 or None"
            )
    if (
        isinstance(running_warmup_frames, bool)
        or not isinstance(running_warmup_frames, int)
        or running_warmup_frames < 0
    ):
        raise ValueError("running_warmup_frames must be an integer >= 0")
    super().__init__(
        norm_mode=str(norm_mode) if isinstance(norm_mode, NormMode) else norm_mode,
        apply_gamma=apply_gamma,
        freeze_running_bounds_after_frames=freeze_running_bounds_after_frames,
        running_warmup_frames=running_warmup_frames,
        **kwargs,
    )
    self.norm_mode = NormMode(norm_mode)
    self.apply_gamma = apply_gamma
    self.freeze_running_bounds_after_frames = freeze_running_bounds_after_frames
    self.running_warmup_frames = running_warmup_frames

    # Per-channel [3] running bounds for normalization.
    self.register_buffer("running_min", torch.full((3,), float("nan")))
    self.register_buffer("running_max", torch.full((3,), float("nan")))
    self._norm_frame_count = 0
    self._statistically_initialized = False

    if self.norm_mode == NormMode.STATISTICAL:
        self._requires_initial_fit_override = True
statistical_initialization
statistical_initialization(input_stream)

Compute global percentile bounds across the entire dataset.

Uses _compute_raw_rgb() to convert each batch, then accumulates per-channel percentile bounds (min-of-lows, max-of-highs).
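The accumulation pattern can be tried without torch. The following NumPy sketch folds per-frame percentile bounds into global ones; accumulate_global_bounds is a hypothetical helper, and the default quantiles are taken from the 0.5th / 99.5th percentiles named in the design notes.

```python
import numpy as np


def accumulate_global_bounds(frames, q_low=0.005, q_high=0.995):
    """Fold per-frame percentile bounds into global per-channel bounds."""
    global_lo = None
    global_hi = None
    for rgb in frames:  # each frame: [..., 3] raw RGB
        flat = np.asarray(rgb, dtype=np.float64).reshape(-1, 3)
        lo = np.quantile(flat, q_low, axis=0)
        hi = np.quantile(flat, q_high, axis=0)
        global_lo = lo if global_lo is None else np.minimum(global_lo, lo)
        global_hi = hi if global_hi is None else np.maximum(global_hi, hi)
    if global_lo is None:
        raise RuntimeError("received no data")
    return global_lo, global_hi


dark = np.linspace(0.0, 1.0, 300).reshape(10, 10, 3)
bright = dark * 4.0
lo, hi = accumulate_global_bounds([dark, bright])
```

The low bounds come from the dark frame and the high bounds from the bright one, so the result spans both.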

Source code in cuvis_ai/node/channel_selector.py
def statistical_initialization(self, input_stream: InputStream) -> None:
    """Compute global percentile bounds across the entire dataset.

    Uses ``_compute_raw_rgb()`` to convert each batch, then accumulates
    per-channel percentile bounds (min-of-lows, max-of-highs).
    """
    for batch_data in input_stream:
        raw_rgb = self._compute_raw_rgb(batch_data["cube"], batch_data["wavelengths"])
        flat = raw_rgb.reshape(-1, 3).float()  # quantile() requires float/double
        frame_lo = torch.quantile(flat, self._NORM_QUANTILE_LOW, dim=0)
        frame_hi = torch.quantile(flat, self._NORM_QUANTILE_HIGH, dim=0)

        if torch.isnan(self.running_min).any():
            self.running_min.copy_(frame_lo)
            self.running_max.copy_(frame_hi)
        else:
            torch.minimum(self.running_min, frame_lo, out=self.running_min)
            torch.maximum(self.running_max, frame_hi, out=self.running_max)

    if torch.isnan(self.running_min).any():
        raise RuntimeError(f"{type(self).__name__}.statistical_initialization received no data")
    self._statistically_initialized = True

NDVISelector

NDVISelector(
    nir_nm=827.0,
    red_nm=668.0,
    colormap_min=-0.7,
    colormap_max=0.5,
    eps=1e-06,
    **kwargs,
)

Bases: _NormalizedDifferenceIndexBase

Normalized Difference Vegetation Index renderer.

Computes:

(CUBE(nir_nm) - CUBE(red_nm)) / (CUBE(nir_nm) + CUBE(red_nm))

Bands are resolved by nearest available sensor wavelength. The raw NDVI map is returned via index_image, and rgb_image contains a colour-mapped render of it. The scalar NDVI image is mapped with the HSV-style colormap used by the Blood_OXY plugin XML.
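A minimal NumPy sketch of the index computation, under two assumptions that this page does not confirm: eps is added to the denominator, and the index is clipped to [colormap_min, colormap_max] before colour mapping. The function names are illustrative and the HSV lookup itself is left out.

```python
import numpy as np


def nearest_band(wavelengths, target_nm):
    """Index of the sensor band closest to target_nm."""
    return int(np.argmin(np.abs(np.asarray(wavelengths) - target_nm)))


def ndvi(cube, wavelengths, nir_nm=827.0, red_nm=668.0, eps=1e-6):
    """NDVI per pixel for a [B, H, W, C] cube; eps placement is an assumption."""
    nir = cube[..., nearest_band(wavelengths, nir_nm)]
    red = cube[..., nearest_band(wavelengths, red_nm)]
    return (nir - red) / (nir + red + eps)


def to_unit_range(index, colormap_min=-0.7, colormap_max=0.5):
    """Clip the index to [colormap_min, colormap_max] and rescale to [0, 1]."""
    scaled = (index - colormap_min) / (colormap_max - colormap_min)
    return np.clip(scaled, 0.0, 1.0)


wavelengths = np.array([450.0, 550.0, 670.0, 830.0])
cube = np.zeros((1, 2, 2, 4))
cube[..., 2] = 0.1  # red reflectance
cube[..., 3] = 0.5  # NIR reflectance
index = ndvi(cube, wavelengths)
position = to_unit_range(index)
```

With the default range, any NDVI above 0.5 saturates at the top of the colormap, so dense vegetation all maps to the same colour.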

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    nir_nm: float = 827.0,
    red_nm: float = 668.0,
    colormap_min: float = -0.7,
    colormap_max: float = 0.5,
    eps: float = 1.0e-6,
    **kwargs: Any,
) -> None:
    if colormap_max <= colormap_min:
        raise ValueError("colormap_max must be greater than colormap_min")
    kwargs.setdefault("norm_mode", NormMode.PER_FRAME)
    kwargs.setdefault("apply_gamma", False)
    super().__init__(
        primary_nm=nir_nm,
        secondary_nm=red_nm,
        eps=eps,
        nir_nm=float(nir_nm),
        red_nm=float(red_nm),
        colormap_min=float(colormap_min),
        colormap_max=float(colormap_max),
        **kwargs,
    )
    self.nir_nm = float(nir_nm)
    self.red_nm = float(red_nm)
    self.colormap = "hsv"
    self.colormap_min = float(colormap_min)
    self.colormap_max = float(colormap_max)
    self._colormap_range = self.colormap_max - self.colormap_min
index_name property
index_name

Canonical NDVI strategy name.

primary_label property
primary_label

NDVI primary operand label.

secondary_label property
secondary_label

NDVI secondary operand label.

forward
forward(cube, wavelengths, context=None, **_)

Compute NDVI plus colour-mapped RGB output.

Source code in cuvis_ai/node/channel_selector.py
def forward(
    self,
    cube: torch.Tensor,
    wavelengths: Any,
    context: Context | None = None,  # noqa: ARG002
    **_: Any,
) -> dict[str, Any]:
    """Compute NDVI plus colour-mapped RGB output."""
    result = super().forward(cube=cube, wavelengths=wavelengths, context=context, **_)
    result["band_info"].update(
        {
            "rendering": f"{self.colormap}_colormap",
            "colormap": self.colormap,
            "colormap_min": self.colormap_min,
            "colormap_max": self.colormap_max,
        }
    )
    return result

FixedWavelengthSelector

FixedWavelengthSelector(
    target_wavelengths=(650.0, 550.0, 450.0), **kwargs
)

Bases: ChannelSelectorBase

Fixed wavelength band selection (e.g., 650, 550, 450 nm).

Selects bands nearest to the specified target wavelengths for R, G, B channels. This is the simplest band selection strategy that produces "true color-ish" images.
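Nearest-band lookup amounts to an argmin over absolute wavelength differences. A small NumPy sketch on a hypothetical sensor with 10 nm spacing (the base class normalization is not shown):

```python
import numpy as np

wavelengths = np.linspace(400.0, 1000.0, 61)  # hypothetical sensor: 10 nm spacing
targets = (650.0, 550.0, 450.0)  # R, G, B

# Nearest band per target: argmin of |wavelength - target| along the band axis.
distances = np.abs(wavelengths[None, :] - np.asarray(targets)[:, None])
indices = distances.argmin(axis=1).tolist()

cube = np.random.default_rng(0).random((1, 4, 4, wavelengths.size))
raw_rgb = cube[..., indices]  # fancy indexing keeps BHW and yields 3 channels
```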

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| target_wavelengths | tuple[float, float, float] | Target wavelengths for R, G, B channels in nanometers. Default: (650.0, 550.0, 450.0) | (650.0, 550.0, 450.0) |

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    target_wavelengths: tuple[float, float, float] = (650.0, 550.0, 450.0),
    **kwargs,
) -> None:
    super().__init__(target_wavelengths=target_wavelengths, **kwargs)
    self.target_wavelengths = target_wavelengths
forward
forward(cube, wavelengths, context=None, **_)

Select bands and compose RGB image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cube | Tensor | Hyperspectral cube [B, H, W, C]. | required |
| wavelengths | Tensor | Wavelength array [C]. | required |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Any] | Dictionary with "rgb_image" and "band_info" keys. |

Source code in cuvis_ai/node/channel_selector.py
def forward(
    self,
    cube: torch.Tensor,
    wavelengths: Any,
    context: Context | None = None,  # noqa: ARG002
    **_: Any,
) -> dict[str, Any]:
    """Select bands and compose RGB image.

    Parameters
    ----------
    cube : torch.Tensor
        Hyperspectral cube [B, H, W, C].
    wavelengths : torch.Tensor
        Wavelength array [C].

    Returns
    -------
    dict[str, Any]
        Dictionary with "rgb_image" and "band_info" keys.
    """
    wavelengths_np = np.asarray(wavelengths, dtype=np.float32)

    # Find nearest bands
    indices = [self._nearest_band_index(wavelengths_np, nm) for nm in self.target_wavelengths]

    # Compose RGB (includes normalization via _normalize_rgb)
    rgb = self._compose_rgb(cube, indices)

    band_info = {
        "strategy": "baseline_false_rgb",
        "band_indices": indices,
        "band_wavelengths_nm": [float(wavelengths_np[i]) for i in indices],
        "target_wavelengths_nm": list(self.target_wavelengths),
    }

    return {"rgb_image": rgb, "band_info": band_info}

RangeAverageFalseRGBSelector

RangeAverageFalseRGBSelector(
    red_range=(580.0, 650.0),
    green_range=(500.0, 580.0),
    blue_range=(420.0, 500.0),
    **kwargs,
)

Bases: ChannelSelectorBase

Range-based false RGB selection by averaging bands per channel.

For each output channel (R/G/B), all spectral bands within the configured wavelength range are averaged per pixel. Channels with no matching bands are filled with zeros.
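The per-range averaging can be written as a single [3, C] weight matrix, which is roughly what the cached _avg_weights buffer suggests. The helper below is a hypothetical NumPy sketch, not the library's _ensure_weights.

```python
import numpy as np


def range_average_weights(wavelengths, ranges):
    """Build a [3, C] averaging matrix; rows with no bands in range stay all-zero."""
    wavelengths = np.asarray(wavelengths, dtype=np.float64)
    weights = np.zeros((len(ranges), wavelengths.size))
    for row, (lo, hi) in enumerate(ranges):
        mask = (wavelengths >= lo) & (wavelengths <= hi)  # inclusive on both ends
        if mask.any():
            weights[row, mask] = 1.0 / mask.sum()
    return weights


wavelengths = np.array([600.0, 620.0, 640.0, 520.0, 560.0, 900.0])
ranges = [(580.0, 650.0), (500.0, 580.0), (420.0, 500.0)]  # R, G, B defaults
weights = range_average_weights(wavelengths, ranges)

cube = np.ones((1, 2, 2, 6)) * np.arange(1.0, 7.0)  # band k holds the value k + 1
raw_rgb = cube @ weights.T  # [B, H, W, C] x [C, 3] -> [B, H, W, 3]
```

Here no band falls in the blue range, so the blue channel is zero, matching the missing_channels behaviour described above.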

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| red_range | tuple[float, float] | Inclusive wavelength range for red channel in nanometers. | (580.0, 650.0) |
| green_range | tuple[float, float] | Inclusive wavelength range for green channel in nanometers. | (500.0, 580.0) |
| blue_range | tuple[float, float] | Inclusive wavelength range for blue channel in nanometers. | (420.0, 500.0) |

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    red_range: tuple[float, float] = (580.0, 650.0),
    green_range: tuple[float, float] = (500.0, 580.0),
    blue_range: tuple[float, float] = (420.0, 500.0),
    **kwargs: Any,
) -> None:
    for name, rng in {
        "red_range": red_range,
        "green_range": green_range,
        "blue_range": blue_range,
    }.items():
        if len(rng) != 2 or rng[0] > rng[1]:
            raise ValueError(f"{name} must be (min_nm, max_nm) with min_nm <= max_nm")

    super().__init__(
        red_range=red_range, green_range=green_range, blue_range=blue_range, **kwargs
    )
    self.red_range = red_range
    self.green_range = green_range
    self.blue_range = blue_range

    # Static channel range boundaries [3, 2]; buffer so .to(device) moves it.
    self.register_buffer(
        "_ranges",
        torch.tensor(
            [
                [red_range[0], red_range[1]],
                [green_range[0], green_range[1]],
                [blue_range[0], blue_range[1]],
            ],
            dtype=torch.float32,
        ),
    )
    # Wavelength-dependent channel weights; lazily computed on first forward.
    self.register_buffer("_avg_weights", None, persistent=False)
    self.register_buffer("_avg_mask", None, persistent=False)
    self._cached_wl_key: tuple[float, ...] | None = None
forward
forward(cube, wavelengths, context=None, **_)

Average spectral bands inside RGB ranges and compose normalized RGB.

Source code in cuvis_ai/node/channel_selector.py
def forward(
    self,
    cube: torch.Tensor,
    wavelengths: Any,
    context: Context | None = None,  # noqa: ARG002
    **_: Any,
) -> dict[str, Any]:
    """Average spectral bands inside RGB ranges and compose normalized RGB."""
    self._ensure_weights(wavelengths, cube.device)
    wavelengths_t = self._prepare_wavelengths_tensor(wavelengths, cube.device)

    # Vectorized channel averaging:
    # cube [B,H,W,C] and weights [3,C] -> rgb [B,H,W,3]
    rgb = self._compute_raw_rgb(cube, wavelengths)
    rgb = self._normalize_rgb(rgb)

    channel_indices = [
        torch.where(self._avg_mask[i])[0].tolist() for i in range(self._avg_mask.shape[0])
    ]
    channel_names = ["red", "green", "blue"]
    missing_channels = [
        channel_names[i] for i, indices in enumerate(channel_indices) if len(indices) == 0
    ]

    band_info = {
        "strategy": "range_average_false_rgb",
        "band_indices": channel_indices,  # [R, G, B]
        "band_wavelengths_nm": [wavelengths_t[idxs].tolist() for idxs in channel_indices],
        "ranges_nm": {
            "red": [float(self.red_range[0]), float(self.red_range[1])],
            "green": [float(self.green_range[0]), float(self.green_range[1])],
            "blue": [float(self.blue_range[0]), float(self.blue_range[1])],
        },
        "aggregation": "mean",
        "missing_channels": missing_channels,
    }
    return {"rgb_image": rgb, "band_info": band_info}

FastRGBSelector

FastRGBSelector(
    red_range=(580.0, 650.0),
    green_range=(500.0, 580.0),
    blue_range=(420.0, 500.0),
    normalization_strength=0.75,
    **kwargs,
)

Bases: ChannelSelectorBase

cuvis-next parity FastRGB renderer.

This selector mirrors the cuvis fast_rgb user-plugin behavior:

  • Per-channel contiguous spectral range averaging.
  • Dynamic per-frame normalization by global RGB mean when enabled.
  • Static reflectance-style scaling when normalization is disabled.
  • 8-bit quantization before returning float RGB in [0, 1].
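The last two steps can be illustrated in NumPy, with the caveat that the plugin's actual scaling formula is not given on this page. mean_scaled and its target_mean are purely hypothetical stand-ins, and the role of normalization_strength is not modelled. Only the 8-bit round trip is generic.

```python
import numpy as np


def quantize_8bit(rgb):
    """Round [0, 1] floats to the 256 levels of an 8-bit image, then return floats."""
    levels = np.round(np.clip(rgb, 0.0, 1.0) * 255.0).astype(np.uint8)
    return levels.astype(np.float32) / 255.0


def mean_scaled(raw_rgb, target_mean=0.5):
    """Hypothetical per-frame scaling: one factor for all channels, from the global mean."""
    factor = target_mean / max(float(raw_rgb.mean()), 1e-8)
    return raw_rgb * factor, factor


raw = np.full((1, 2, 2, 3), 0.1, dtype=np.float32)
scaled, factor = mean_scaled(raw)
rgb = quantize_8bit(scaled)
```

Using a single factor for all three channels keeps the colour balance, and the quantization step makes the float output identical to what an 8-bit image would hold.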
Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    red_range: tuple[float, float] = (580.0, 650.0),
    green_range: tuple[float, float] = (500.0, 580.0),
    blue_range: tuple[float, float] = (420.0, 500.0),
    normalization_strength: float = 0.75,
    **kwargs: Any,
) -> None:
    for name, rng in {
        "red_range": red_range,
        "green_range": green_range,
        "blue_range": blue_range,
    }.items():
        if len(rng) != 2 or rng[0] > rng[1]:
            raise ValueError(f"{name} must be (min_nm, max_nm) with min_nm <= max_nm")

    # FastRGB has its own scaling path; disable base normalization/gamma.
    kwargs.pop("norm_mode", None)
    kwargs.pop("apply_gamma", None)
    super().__init__(
        norm_mode=NormMode.PER_FRAME,
        apply_gamma=False,
        red_range=red_range,
        green_range=green_range,
        blue_range=blue_range,
        normalization_strength=float(normalization_strength),
        **kwargs,
    )
    self.red_range = red_range
    self.green_range = green_range
    self.blue_range = blue_range
    self.normalization_strength = float(normalization_strength)

    self.register_buffer(
        "_ranges",
        torch.tensor(
            [
                [red_range[0], red_range[1]],
                [green_range[0], green_range[1]],
                [blue_range[0], blue_range[1]],
            ],
            dtype=torch.float32,
        ),
    )
    self.register_buffer("_channel_bounds", None, persistent=False)
    self.register_buffer("_channel_valid", None, persistent=False)
    self._cached_wl_key: tuple[float, ...] | None = None
forward
forward(cube, wavelengths, context=None, **_)

Render fast_rgb output with cuvis-next parity scaling.

Source code in cuvis_ai/node/channel_selector.py
def forward(
    self,
    cube: torch.Tensor,
    wavelengths: Any,
    context: Context | None = None,  # noqa: ARG002
    **_: Any,
) -> dict[str, Any]:
    """Render fast_rgb output with cuvis-next parity scaling."""
    wavelengths_t = self._prepare_wavelengths_tensor(wavelengths, device=cube.device)
    raw_rgb = self._compute_raw_rgb(cube, wavelengths_t)
    rgb, factor = self._fast_rgb_scale(raw_rgb)

    channel_indices: list[list[int]] = []
    missing_channels: list[str] = []
    channel_names = ["red", "green", "blue"]
    for c in range(3):
        if bool(self._channel_valid[c].item()):
            low = int(self._channel_bounds[c, 0].item())
            high = int(self._channel_bounds[c, 1].item())
            channel_indices.append(list(range(low, high + 1)))
        else:
            channel_indices.append([])
            missing_channels.append(channel_names[c])

    band_info = {
        "strategy": "fast_rgb",
        "band_indices": channel_indices,
        "band_wavelengths_nm": [wavelengths_t[idxs].tolist() for idxs in channel_indices],
        "ranges_nm": {
            "red": [float(self.red_range[0]), float(self.red_range[1])],
            "green": [float(self.green_range[0]), float(self.green_range[1])],
            "blue": [float(self.blue_range[0]), float(self.blue_range[1])],
        },
        "aggregation": "mean",
        "normalization_strength": float(self.normalization_strength),
        "applied_scale_factor": float(factor),
        "missing_channels": missing_channels,
    }
    return {"rgb_image": rgb, "band_info": band_info}

HighContrastSelector

HighContrastSelector(
    windows=((440, 500), (500, 580), (610, 700)),
    alpha=0.1,
    **kwargs,
)

Bases: ChannelSelectorBase

Data-driven band selection using spatial variance + Laplacian energy.

For each wavelength window, selects the band with the highest score:

score = variance + alpha * Laplacian_energy

This produces "high contrast" images that may work better for visual anomaly detection.
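The scoring rule can be reproduced with plain NumPy. In this sketch a hand-written 5-point stencil with edge-replicated borders stands in for scipy.ndimage.laplace, which the library source uses; it has not been checked against SciPy here. The helper names are illustrative, and the empty-window fallback is omitted.

```python
import numpy as np


def laplacian_5pt(band):
    """5-point Laplacian with edge-replicated borders (stand-in for scipy.ndimage.laplace)."""
    p = np.pad(band, 1, mode="edge")
    return p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:] - 4.0 * band


def contrast_score(band, alpha=0.1):
    """score = variance + alpha * mean absolute Laplacian."""
    return float(np.var(band)) + alpha * float(np.mean(np.abs(laplacian_5pt(band))))


def best_band_in_window(frame, wavelengths, window, alpha=0.1):
    """Pick the highest-scoring band of an [H, W, C] frame inside [start, end] nm."""
    start, end = window
    candidates = np.where((wavelengths >= start) & (wavelengths <= end))[0]
    scores = [contrast_score(frame[..., i], alpha) for i in candidates]
    return int(candidates[int(np.argmax(scores))])


wavelengths = np.array([450.0, 470.0, 490.0])
frame = np.zeros((6, 6, 3))
frame[::2, ::2, 1] = 1.0  # only the 470 nm band carries spatial structure
best = best_band_in_window(frame, wavelengths, (440.0, 500.0))
```

Flat bands score zero on both terms, so the one band with texture wins its window.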

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| windows | Sequence[tuple[float, float]] | Wavelength windows for Blue, Green, Red channels. Default: ((440, 500), (500, 580), (610, 700)) for visible spectrum. | ((440, 500), (500, 580), (610, 700)) |
| alpha | float | Weight for Laplacian energy term. Default: 0.1 | 0.1 |

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    windows: Sequence[tuple[float, float]] = ((440, 500), (500, 580), (610, 700)),
    alpha: float = 0.1,
    **kwargs,
) -> None:
    super().__init__(windows=windows, alpha=alpha, **kwargs)
    self.windows = list(windows)
    self.alpha = alpha
forward
forward(cube, wavelengths, context=None, **_)

Select high-contrast bands and compose RGB image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cube | Tensor | Hyperspectral cube [B, H, W, C]. | required |
| wavelengths | Tensor | Wavelength array [C]. | required |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Any] | Dictionary with "rgb_image" and "band_info" keys. |

Source code in cuvis_ai/node/channel_selector.py
def forward(
    self,
    cube: torch.Tensor,
    wavelengths: Any,
    context: Context | None = None,  # noqa: ARG002
    **_: Any,
) -> dict[str, Any]:
    """Select high-contrast bands and compose RGB image.

    Parameters
    ----------
    cube : torch.Tensor
        Hyperspectral cube [B, H, W, C].
    wavelengths : torch.Tensor
        Wavelength array [C].

    Returns
    -------
    dict[str, Any]
        Dictionary with "rgb_image" and "band_info" keys.
    """
    wavelengths_np = np.asarray(wavelengths, dtype=np.float32)
    # Use first batch item for band selection
    cube_np = cube[0].cpu().numpy()

    selected_indices = []
    for start, end in self.windows:
        mask = (wavelengths_np >= start) & (wavelengths_np <= end)
        window_indices = np.where(mask)[0]

        if len(window_indices) == 0:
            # Fallback to nearest single wavelength
            nearest = self._nearest_band_index(wavelengths_np, (start + end) / 2.0)
            selected_indices.append(int(nearest))
            continue

        scores = []
        for idx in window_indices:
            band = cube_np[..., idx]
            variance = float(np.var(band))
            lap_energy = float(np.mean(np.abs(laplace(band))))
            scores.append(variance + self.alpha * lap_energy)

        best_idx = int(window_indices[int(np.argmax(scores))])
        selected_indices.append(best_idx)

    rgb = self._compose_rgb(cube, selected_indices)

    band_info = {
        "strategy": "high_contrast",
        "band_indices": selected_indices,
        "band_wavelengths_nm": [float(wavelengths_np[i]) for i in selected_indices],
        "windows_nm": [[float(s), float(e)] for s, e in self.windows],
        "alpha": self.alpha,
    }

    return {"rgb_image": rgb, "band_info": band_info}

CIRSelector

CIRSelector(
    nir_nm=860.0, red_nm=670.0, green_nm=560.0, **kwargs
)

Bases: ChannelSelectorBase

Color Infrared (CIR) false color composition.

Maps NIR to Red, Red to Green, Green to Blue for false-color composites. This is useful for highlighting vegetation and certain anomalies.
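As a small illustration of the channel mapping, consider a single vegetation-like pixel. The reflectance values and the four-band sensor are invented for the example.

```python
import numpy as np

wavelengths = np.array([450.0, 560.0, 670.0, 860.0])
targets = {"R": 860.0, "G": 670.0, "B": 560.0}  # NIR -> R, Red -> G, Green -> B

indices = [int(np.argmin(np.abs(wavelengths - nm))) for nm in targets.values()]

# Hypothetical vegetation pixel: low red reflectance, high NIR reflectance.
pixel = np.array([0.04, 0.10, 0.05, 0.50])
cir = pixel[indices]
```

The strong NIR response lands in the red output channel, which is why healthy vegetation appears red in CIR composites.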

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| nir_nm | float | Near-infrared wavelength in nm. Default: 860.0 | 860.0 |
| red_nm | float | Red wavelength in nm. Default: 670.0 | 670.0 |
| green_nm | float | Green wavelength in nm. Default: 560.0 | 560.0 |

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    nir_nm: float = 860.0,
    red_nm: float = 670.0,
    green_nm: float = 560.0,
    **kwargs,
) -> None:
    super().__init__(nir_nm=nir_nm, red_nm=red_nm, green_nm=green_nm, **kwargs)
    self.nir_nm = nir_nm
    self.red_nm = red_nm
    self.green_nm = green_nm
forward
forward(cube, wavelengths, context=None, **_)

Select CIR bands and compose false-color image.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cube | Tensor | Hyperspectral cube [B, H, W, C]. | required |
| wavelengths | Tensor | Wavelength array [C]. | required |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Any] | Dictionary with "rgb_image" and "band_info" keys. |

Source code in cuvis_ai/node/channel_selector.py
def forward(
    self,
    cube: torch.Tensor,
    wavelengths: Any,
    context: Context | None = None,  # noqa: ARG002
    **_: Any,
) -> dict[str, Any]:
    """Select CIR bands and compose false-color image.

    Parameters
    ----------
    cube : torch.Tensor
        Hyperspectral cube [B, H, W, C].
    wavelengths : torch.Tensor
        Wavelength array [C].

    Returns
    -------
    dict[str, Any]
        Dictionary with "rgb_image" and "band_info" keys.
    """
    wavelengths_np = np.asarray(wavelengths, dtype=np.float32).ravel()
    nir_idx, red_idx, green_idx = self._resolve_band_indices(wavelengths_np)
    indices = [nir_idx, red_idx, green_idx]
    rgb = self._normalize_rgb(self._compute_raw_rgb(cube, wavelengths_np))

    band_info = {
        "strategy": "cir_false_color",
        "band_indices": indices,
        "band_wavelengths_nm": [float(wavelengths_np[i]) for i in indices],
        "target_wavelengths_nm": [self.nir_nm, self.red_nm, self.green_nm],
        "channel_mapping": {"R": "NIR", "G": "Red", "B": "Green"},
    }

    return {"rgb_image": rgb, "band_info": band_info}

CIETristimulusFalseRGBSelector

CIETristimulusFalseRGBSelector(**kwargs)

Bases: ChannelSelectorBase

CIE 1931 tristimulus-based false RGB rendering.

Converts a hyperspectral cube to sRGB by integrating each pixel's spectrum with the CIE 1931 2-degree standard observer color matching functions (x_bar, y_bar, z_bar), applying a D65 white point normalization, and converting from CIE XYZ to linear sRGB.

Normalization and sRGB gamma are handled by ChannelSelectorBase (see apply_gamma parameter inherited from the base class).

This produces the most physically grounded false RGB and lands closest to the distribution SAM3's Perception Encoder expects.

For wavelengths outside the visible range (approx. >780 nm), the CMFs are zero, so NIR bands do not contribute to the output.
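The final XYZ to linear sRGB step can be checked in isolation. The matrix below is the widely published four-decimal version and the white point is the usual D65 value; the library's _XYZ_TO_SRGB constant may carry more digits. The CMF integration step is omitted because it needs the tabulated observer data.

```python
import numpy as np

# Widely published XYZ -> linear sRGB matrix (D65); the library constant may differ in precision.
XYZ_TO_SRGB = np.array(
    [
        [3.2406, -1.5372, -0.4986],
        [-0.9689, 1.8758, 0.0415],
        [0.0557, -0.2040, 1.0570],
    ]
)
D65_WHITE_XYZ = np.array([0.95047, 1.00000, 1.08883])


def xyz_to_linear_srgb(xyz):
    """Apply the 3x3 matrix along the last axis of an [..., 3] XYZ array."""
    return np.asarray(xyz) @ XYZ_TO_SRGB.T


white_rgb = xyz_to_linear_srgb(D65_WHITE_XYZ)
image_rgb = xyz_to_linear_srgb(np.tile(D65_WHITE_XYZ, (1, 2, 2, 1)))
```

The D65 white point maps to approximately (1, 1, 1), which is the sanity check usually applied to this matrix.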

Source code in cuvis_ai/node/channel_selector.py
def __init__(self, **kwargs: Any) -> None:
    super().__init__(**kwargs)

    # Static XYZ -> linear sRGB matrix; buffer so .to(device) moves it.
    self.register_buffer(
        "_xyz_to_srgb_matrix",
        torch.from_numpy(self._XYZ_TO_SRGB.astype(np.float32)),
    )
    # Wavelength-dependent CMF integration weights; lazily computed on first forward.
    self.register_buffer("_cmf_weights", None, persistent=False)
    self._cached_wl_key: tuple[float, ...] | None = None
    self._cached_n_visible: int = 0
forward
forward(cube, wavelengths, context=None, **_)

Convert HSI cube to sRGB via CIE 1931 tristimulus integration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cube | Tensor | Hyperspectral cube [B, H, W, C]. | required |
| wavelengths | Tensor \| ndarray | Wavelength array [C] in nanometers. | required |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Any] | Dictionary with "rgb_image" [B, H, W, 3] and "band_info". |

Source code in cuvis_ai/node/channel_selector.py
def forward(
    self,
    cube: torch.Tensor,
    wavelengths: Any,
    context: Context | None = None,  # noqa: ARG002
    **_: Any,
) -> dict[str, Any]:
    """Convert HSI cube to sRGB via CIE 1931 tristimulus integration.

    Parameters
    ----------
    cube : torch.Tensor
        Hyperspectral cube [B, H, W, C].
    wavelengths : torch.Tensor | np.ndarray
        Wavelength array [C] in nanometers.

    Returns
    -------
    dict[str, Any]
        Dictionary with "rgb_image" [B, H, W, 3] and "band_info".
    """
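    wavelengths_np = np.asarray(wavelengths, dtype=np.float64)
    # Check dimensionality before ravel(): ravel() always returns a 1-D array,
    # so a 0-D (scalar) input could never be detected afterwards.
    if wavelengths_np.ndim == 0:
        raise ValueError("wavelengths must be a 1-D array")
    wavelengths_np = wavelengths_np.ravel()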

    # Compute unnormalized linear sRGB, then normalize + gamma via base class.
    rgb = self._normalize_rgb(self._compute_raw_rgb(cube, wavelengths))

    band_info = {
        "strategy": "cie_tristimulus",
        "illuminant": "D65",
        "apply_gamma": self.apply_gamma,
        "sensor_bands_total": len(wavelengths_np),
        "sensor_bands_visible": self._cached_n_visible,
        "wavelength_range_nm": [float(wavelengths_np[0]), float(wavelengths_np[-1])],
    }

    return {"rgb_image": rgb, "band_info": band_info}

CameraEmulationFalseRGBSelector

CameraEmulationFalseRGBSelector(
    r_peak=610.0,
    g_peak=540.0,
    b_peak=460.0,
    r_sigma=40.0,
    g_sigma=35.0,
    b_sigma=30.0,
    **kwargs,
)

Bases: ChannelSelectorBase

Camera-emulation false RGB using smooth Gaussian sensitivity curves.

Defines three broad, smooth Gaussian weighting curves over the spectral bands that mimic R/G/B camera sensitivity (peaks at configurable wavelengths). The weight matrix W is [3, num_bands], applied as rgb = W @ spectrum. Non-negativity is enforced by construction.

This is simple, stable, and requires no training. Good middle ground between single-band selection and learned mapping.
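A NumPy sketch of the weight matrix follows. The row normalization is an assumption made here so that a flat spectrum maps to equal R, G, B; the page does not say whether the library normalizes its weights. gaussian_weights is an illustrative name.

```python
import numpy as np


def gaussian_weights(wavelengths, peaks, sigmas, normalize=True):
    """[3, C] Gaussian sensitivity matrix; rows follow the R, G, B peak order."""
    wl = np.asarray(wavelengths, dtype=np.float64)[None, :]
    mu = np.asarray(peaks, dtype=np.float64)[:, None]
    sd = np.asarray(sigmas, dtype=np.float64)[:, None]
    weights = np.exp(-0.5 * ((wl - mu) / sd) ** 2)  # non-negative by construction
    if normalize:
        # Assumption: rows sum to 1 so a flat spectrum maps to equal R, G, B.
        weights = weights / weights.sum(axis=1, keepdims=True)
    return weights


wavelengths = np.arange(400.0, 701.0, 10.0)
W = gaussian_weights(wavelengths, (610.0, 540.0, 460.0), (40.0, 35.0, 30.0))

cube = np.ones((1, 2, 2, wavelengths.size))  # flat spectrum
rgb = cube @ W.T  # per pixel: rgb = W @ spectrum
```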

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| r_peak | float | Red channel peak wavelength in nm. Default: 610.0 | 610.0 |
| g_peak | float | Green channel peak wavelength in nm. Default: 540.0 | 540.0 |
| b_peak | float | Blue channel peak wavelength in nm. Default: 460.0 | 460.0 |
| r_sigma | float | Red channel Gaussian sigma in nm. Default: 40.0 | 40.0 |
| g_sigma | float | Green channel Gaussian sigma in nm. Default: 35.0 | 35.0 |
| b_sigma | float | Blue channel Gaussian sigma in nm. Default: 30.0 | 30.0 |

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    r_peak: float = 610.0,
    g_peak: float = 540.0,
    b_peak: float = 460.0,
    r_sigma: float = 40.0,
    g_sigma: float = 35.0,
    b_sigma: float = 30.0,
    **kwargs: Any,
) -> None:
    super().__init__(
        r_peak=r_peak,
        g_peak=g_peak,
        b_peak=b_peak,
        r_sigma=r_sigma,
        g_sigma=g_sigma,
        b_sigma=b_sigma,
        **kwargs,
    )
    self.peaks = (r_peak, g_peak, b_peak)
    self.sigmas = (r_sigma, g_sigma, b_sigma)

    # Wavelength-dependent Gaussian weights; lazily computed on first forward.
    self.register_buffer("_channel_weights", None, persistent=False)
    self._cached_wl_key: tuple[float, ...] | None = None
forward
forward(cube, wavelengths, context=None, **_)

Convert HSI cube to false RGB using Gaussian camera sensitivity.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cube | Tensor | Hyperspectral cube [B, H, W, C]. | required |
| wavelengths | Tensor \| ndarray | Wavelength array [C] in nanometers. | required |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Any] | Dictionary with "rgb_image" [B, H, W, 3] and "band_info". |

Source code in cuvis_ai/node/channel_selector.py
def forward(
    self,
    cube: torch.Tensor,
    wavelengths: Any,
    context: Context | None = None,  # noqa: ARG002
    **_: Any,
) -> dict[str, Any]:
    """Convert HSI cube to false RGB using Gaussian camera sensitivity.

    Parameters
    ----------
    cube : torch.Tensor
        Hyperspectral cube [B, H, W, C].
    wavelengths : torch.Tensor | np.ndarray
        Wavelength array [C] in nanometers.

    Returns
    -------
    dict[str, Any]
        Dictionary with "rgb_image" [B, H, W, 3] and "band_info".
    """
    wavelengths_np = np.asarray(wavelengths, dtype=np.float64).ravel()

    rgb = self._compute_raw_rgb(cube, wavelengths)
    rgb = self._normalize_rgb(rgb)

    band_info = {
        "strategy": "camera_emulation",
        "peaks_nm": {"R": self.peaks[0], "G": self.peaks[1], "B": self.peaks[2]},
        "sigmas_nm": {"R": self.sigmas[0], "G": self.sigmas[1], "B": self.sigmas[2]},
        "sensor_bands_total": len(wavelengths_np),
    }

    return {"rgb_image": rgb, "band_info": band_info}

SupervisedSelectorBase

SupervisedSelectorBase(
    num_spectral_bands,
    score_weights=(1.0, 1.0, 1.0),
    lambda_penalty=0.5,
    **kwargs,
)

Bases: ChannelSelectorBase

Base class for supervised band selection strategies.

This class adds an optional mask input port and implements common logic for statistical initialization via :meth:statistical_initialization.

The mask is assumed to be binary (0/1), where 1 denotes the positive class (e.g. stone) and 0 denotes the negative class (e.g. lentil/background).

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    num_spectral_bands: int,
    score_weights: tuple[float, float, float] = (1.0, 1.0, 1.0),
    lambda_penalty: float = 0.5,
    **kwargs: Any,
) -> None:
    # Call super().__init__ FIRST so Serializable captures hparams correctly
    super().__init__(
        num_spectral_bands=num_spectral_bands,
        score_weights=score_weights,
        lambda_penalty=lambda_penalty,
        **kwargs,
    )
    # Then set instance attributes
    self.num_spectral_bands = num_spectral_bands
    self.score_weights = score_weights
    self.lambda_penalty = lambda_penalty
    # Initialize buffers with correct shapes (not empty)
    # selected_indices: always 3 for RGB
    # score buffers: num_spectral_bands
    self.register_buffer("selected_indices", torch.zeros(3, dtype=torch.long), persistent=True)
    self.register_buffer(
        "band_scores", torch.zeros(num_spectral_bands, dtype=torch.float32), persistent=True
    )
    self.register_buffer(
        "fisher_scores", torch.zeros(num_spectral_bands, dtype=torch.float32), persistent=True
    )
    self.register_buffer(
        "auc_scores", torch.zeros(num_spectral_bands, dtype=torch.float32), persistent=True
    )
    self.register_buffer(
        "mi_scores", torch.zeros(num_spectral_bands, dtype=torch.float32), persistent=True
    )
    # Use standard instance attribute for initialization tracking
    self._statistically_initialized = False

requires_initial_fit property
requires_initial_fit

Whether this node requires statistical initialization from training data.

Returns:

Type Description
bool

Always True for supervised band selectors.

statistical_initialization
statistical_initialization(input_stream)

Initialize band selection using supervised scoring.

Computes Fisher, AUC, and MI scores for each band, delegates to :meth:_select_bands for strategy-specific selection, and stores the 3 selected bands.

Parameters:

Name Type Description Default
input_stream InputStream

Training data stream with cube, mask, and wavelengths.

required

Raises:

Type Description
ValueError

If band selection doesn't return exactly 3 bands.

Source code in cuvis_ai/node/channel_selector.py
def statistical_initialization(self, input_stream: InputStream) -> None:
    """Initialize band selection using supervised scoring.

    Computes Fisher, AUC, and MI scores for each band, delegates to
    :meth:`_select_bands` for strategy-specific selection, and stores
    the 3 selected bands.

    Parameters
    ----------
    input_stream : InputStream
        Training data stream with cube, mask, and wavelengths.

    Raises
    ------
    ValueError
        If band selection doesn't return exactly 3 bands.
    """
    cubes, masks, wavelengths = self._collect_training_data(input_stream)
    band_scores, fisher_scores, auc_scores, mi_scores = _compute_band_scores_supervised(
        cubes,
        masks,
        wavelengths,
        self.score_weights,
    )
    corr_matrix = _compute_band_correlation_matrix(cubes, len(wavelengths))
    selected_indices = self._select_bands(band_scores, wavelengths, corr_matrix)
    if len(selected_indices) != 3:
        raise ValueError(f"{type(self).__name__} expected 3 bands, got {len(selected_indices)}")
    self._store_scores_and_indices(
        band_scores, fisher_scores, auc_scores, mi_scores, selected_indices
    )
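
To give a rough idea of what a per-band supervised score looks like, the sketch below assumes the common two-class Fisher criterion (squared mean difference over summed class variances). The internals of `_compute_band_scores_supervised` are not shown on this page, so `fisher_scores` is a hypothetical stand-in, and the AUC and MI terms are left out.

```python
import numpy as np

def fisher_scores(cube, mask, eps=1e-8):
    """Per-band Fisher score from a [H, W, C] cube and a binary [H, W] mask."""
    pixels = cube.reshape(-1, cube.shape[-1])      # [N, C]
    labels = mask.reshape(-1).astype(bool)         # [N], True = positive class
    pos, neg = pixels[labels], pixels[~labels]
    between = (pos.mean(axis=0) - neg.mean(axis=0)) ** 2
    within = pos.var(axis=0) + neg.var(axis=0)
    return between / (within + eps)

# Synthetic data: only band 2 separates the two classes.
rng = np.random.default_rng(0)
mask = np.zeros((8, 8), dtype=np.int64)
mask[:, 4:] = 1
cube = rng.normal(0.0, 1.0, size=(8, 8, 5))
cube[..., 2] += 5.0 * mask
scores = fisher_scores(cube, mask)
best_band = int(np.argmax(scores))
```
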

forward
forward(cube, wavelengths, mask=None, context=None, **_)

Generate false-color RGB from selected bands.

Parameters:

Name Type Description Default
cube Tensor

Hyperspectral cube [B, H, W, C].

required
wavelengths ndarray

Wavelengths for each channel [C].

required
mask Tensor

Ground truth mask (unused in forward, required for initialization).

None
context Context

Pipeline execution context (unused).

None
**_ Any

Additional unused keyword arguments.

{}

Returns:

Type Description
dict[str, Any]

Dictionary with "rgb_image" [B, H, W, 3] and "band_info" metadata.

Raises:

Type Description
RuntimeError

If the node has not been statistically initialized.

Source code in cuvis_ai/node/channel_selector.py
def forward(
    self,
    cube: torch.Tensor,
    wavelengths: np.ndarray,
    mask: torch.Tensor | None = None,  # noqa: ARG002
    context: Context | None = None,  # noqa: ARG002
    **_: Any,
) -> dict[str, Any]:
    """Generate false-color RGB from selected bands.

    Parameters
    ----------
    cube : torch.Tensor
        Hyperspectral cube [B, H, W, C].
    wavelengths : np.ndarray
        Wavelengths for each channel [C].
    mask : torch.Tensor, optional
        Ground truth mask (unused in forward, required for initialization).
    context : Context, optional
        Pipeline execution context (unused).
    **_ : Any
        Additional unused keyword arguments.

    Returns
    -------
    dict[str, Any]
        Dictionary with "rgb_image" [B, H, W, 3] and "band_info" metadata.

    Raises
    ------
    RuntimeError
        If the node has not been statistically initialized.
    """
    if not self._statistically_initialized or self.selected_indices.numel() != 3:
        raise RuntimeError(f"{type(self).__name__} not fitted")

    wavelengths_np = np.asarray(wavelengths, dtype=np.float32)
    indices = self.selected_indices.tolist()
    rgb = self._compose_rgb(cube, indices)

    band_info = {
        "strategy": self._strategy_name,
        "band_indices": indices,
        "band_wavelengths_nm": [float(wavelengths_np[i]) for i in indices],
        "score_weights": list(self.score_weights),
        "lambda_penalty": float(self.lambda_penalty),
        **self._extra_band_info(wavelengths_np),
    }
    return {"rgb_image": rgb, "band_info": band_info}
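
The core of the forward pass is a gather of three channels plus a wavelength lookup. The sketch below shows that step in NumPy under the assumption that `_compose_rgb` indexes the last axis. The helper `compose_rgb` is hypothetical, and any normalization the real method applies is not reproduced.

```python
import numpy as np

def compose_rgb(cube, indices):
    """Stack three selected bands of a [B, H, W, C] cube into [B, H, W, 3]."""
    if len(indices) != 3:
        raise ValueError(f"expected 3 band indices, got {len(indices)}")
    return cube[..., list(indices)]

# Illustrative sensor with 10 nm spacing; the indices are arbitrary examples.
wavelengths = np.linspace(400.0, 1000.0, 61)
cube = np.random.default_rng(0).random((2, 4, 4, 61))
indices = [45, 27, 12]
rgb = compose_rgb(cube, indices)
band_wavelengths_nm = [float(wavelengths[i]) for i in indices]
```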

SupervisedCIRSelector

SupervisedCIRSelector(
    windows=(
        (840.0, 910.0),
        (650.0, 720.0),
        (500.0, 570.0),
    ),
    score_weights=(1.0, 1.0, 1.0),
    lambda_penalty=0.5,
    **kwargs,
)

Bases: SupervisedSelectorBase

Supervised CIR/NIR band selection with window constraints.

Windows are typically set to:

- NIR: 840-910 nm
- Red: 650-720 nm
- Green: 500-570 nm

The selector chooses one band per window using a supervised score (Fisher + AUC + MI) with an mRMR-style redundancy penalty.

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    windows: Sequence[tuple[float, float]] = ((840.0, 910.0), (650.0, 720.0), (500.0, 570.0)),
    score_weights: tuple[float, float, float] = (1.0, 1.0, 1.0),
    lambda_penalty: float = 0.5,
    **kwargs: Any,
) -> None:
    super().__init__(
        score_weights=score_weights,
        lambda_penalty=lambda_penalty,
        windows=list(windows),
        **kwargs,
    )
    self.windows = list(windows)
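
One way to read "one band per window with an mRMR-style redundancy penalty" is a greedy pass over the windows. The sketch below is an assumption about the general technique, not the library's `_select_bands`: the function name, the use of the maximum absolute correlation as the redundancy term, and the synthetic inputs are all illustrative.

```python
import numpy as np

def select_one_band_per_window(scores, wavelengths, corr, windows, lambda_penalty=0.5):
    """Greedy per-window pick: score minus lambda * max |corr| with earlier picks."""
    selected = []
    for lo, hi in windows:
        candidates = np.flatnonzero((wavelengths >= lo) & (wavelengths <= hi))
        if candidates.size == 0:
            raise ValueError(f"no band inside window {lo}-{hi} nm")
        adjusted = scores[candidates].astype(np.float64)
        if selected:
            redundancy = np.abs(corr[np.ix_(candidates, selected)]).max(axis=1)
            adjusted = adjusted - lambda_penalty * redundancy
        selected.append(int(candidates[np.argmax(adjusted)]))
    return selected

wavelengths = np.arange(450.0, 951.0, 10.0)            # 51 bands, 450-950 nm
rng = np.random.default_rng(0)
scores = rng.random(wavelengths.size)                  # placeholder band scores
pixels = rng.random((200, wavelengths.size))
corr = np.corrcoef(pixels, rowvar=False)               # [C, C] band correlation
windows = ((840.0, 910.0), (650.0, 720.0), (500.0, 570.0))
picked = select_one_band_per_window(scores, wavelengths, corr, windows)
```

The returned indices follow the window order, so with the default windows they map to NIR, red, and green.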

SupervisedWindowedSelector

SupervisedWindowedSelector(
    windows=(
        (440.0, 500.0),
        (500.0, 580.0),
        (610.0, 700.0),
    ),
    score_weights=(1.0, 1.0, 1.0),
    lambda_penalty=0.5,
    **kwargs,
)

Bases: SupervisedSelectorBase

Supervised band selection constrained to visible RGB windows.

Similar to :class:HighContrastSelector, but uses label-driven scores. Default windows:

- Blue: 440-500 nm
- Green: 500-580 nm
- Red: 610-700 nm

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    windows: Sequence[tuple[float, float]] = ((440.0, 500.0), (500.0, 580.0), (610.0, 700.0)),
    score_weights: tuple[float, float, float] = (1.0, 1.0, 1.0),
    lambda_penalty: float = 0.5,
    **kwargs: Any,
) -> None:
    super().__init__(
        score_weights=score_weights,
        lambda_penalty=lambda_penalty,
        windows=list(windows),
        **kwargs,
    )
    self.windows = list(windows)

SupervisedFullSpectrumSelector

SupervisedFullSpectrumSelector(
    score_weights=(1.0, 1.0, 1.0),
    lambda_penalty=0.5,
    **kwargs,
)

Bases: SupervisedSelectorBase

Supervised selection without window constraints.

Picks the top-3 discriminative bands globally with an mRMR-style redundancy penalty applied over the full spectrum.

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    score_weights: tuple[float, float, float] = (1.0, 1.0, 1.0),
    lambda_penalty: float = 0.5,
    **kwargs: Any,
) -> None:
    super().__init__(score_weights=score_weights, lambda_penalty=lambda_penalty, **kwargs)
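
The effect of the redundancy penalty is easiest to see on a toy example where two high-scoring bands are near duplicates. The following sketch assumes a greedy selection that subtracts `lambda_penalty` times the largest absolute correlation with already chosen bands. `greedy_top_k` and the numbers are hypothetical and only illustrate the idea.

```python
import numpy as np

def greedy_top_k(scores, corr, k=3, lambda_penalty=0.5):
    """Pick k bands greedily, penalising correlation with earlier picks."""
    selected = [int(np.argmax(scores))]
    while len(selected) < k:
        redundancy = np.abs(corr[:, selected]).max(axis=1)
        adjusted = scores - lambda_penalty * redundancy
        adjusted[selected] = -np.inf           # never pick the same band twice
        selected.append(int(np.argmax(adjusted)))
    return selected

scores = np.array([0.90, 0.88, 0.60, 0.50])
corr = np.array([
    [1.00, 0.99, 0.10, 0.00],
    [0.99, 1.00, 0.10, 0.00],
    [0.10, 0.10, 1.00, 0.20],
    [0.00, 0.00, 0.20, 1.00],
])
picked = greedy_top_k(scores, corr)   # [0, 2, 3]
```

A plain top-3 by score would return bands 0, 1, and 2. With the penalty, band 1 is skipped because it is almost perfectly correlated with band 0.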

SoftChannelSelector

SoftChannelSelector(
    n_select,
    input_channels,
    init_method="uniform",
    temperature_init=5.0,
    temperature_min=0.1,
    temperature_decay=0.9,
    hard=False,
    eps=1e-06,
    **kwargs,
)

Bases: Node

Soft channel selector with temperature-based Gumbel-Softmax selection.

This is a selector node — it gates/reweights individual channels independently: output[c] = weight[c] * input[c] (diagonal operation, preserves channel count).

For cross-channel linear projection that reduces channel count, see :class:cuvis_ai.node.channel_mixer.ConcreteChannelMixer or :class:cuvis_ai.node.channel_mixer.LearnableChannelMixer.

This node learns to select a subset of input channels using differentiable channel selection with temperature annealing. Supports:

  • Statistical initialization (uniform or importance-based)
  • Gradient-based optimization with temperature scheduling
  • Entropy and diversity regularization
  • Hard selection at inference time

Parameters:

Name Type Description Default
n_select int

Number of channels to select

required
input_channels int

Number of input channels

required
init_method Literal['uniform', 'variance']

Initialization method for channel weights (default: "uniform")

"uniform"
temperature_init float

Initial temperature for Gumbel-Softmax (default: 5.0)

5.0
temperature_min float

Minimum temperature (default: 0.1)

0.1
temperature_decay float

Temperature decay factor per epoch (default: 0.9)

0.9
hard bool

If True, use hard selection at inference (default: False)

False
eps float

Small constant for numerical stability (default: 1e-6)

1e-06

Attributes:

Name Type Description
channel_logits Parameter or Tensor

Unnormalized channel importance scores [n_channels]

temperature float

Current temperature for Gumbel-Softmax

Source code in cuvis_ai/node/channel_selector.py
def __init__(
    self,
    n_select: int,
    input_channels: int,
    init_method: Literal["uniform", "variance"] = "uniform",
    temperature_init: float = 5.0,
    temperature_min: float = 0.1,
    temperature_decay: float = 0.9,
    hard: bool = False,
    eps: float = 1e-6,
    **kwargs,
) -> None:
    self.n_select = n_select
    self.input_channels = input_channels
    self.init_method = init_method
    self.temperature_init = temperature_init
    self.temperature_min = temperature_min
    self.temperature_decay = temperature_decay
    self.hard = hard
    self.eps = eps

    super().__init__(
        n_select=n_select,
        input_channels=input_channels,
        init_method=init_method,
        temperature_init=temperature_init,
        temperature_min=temperature_min,
        temperature_decay=temperature_decay,
        hard=hard,
        eps=eps,
        **kwargs,
    )

    # Temperature tracking (not a parameter, managed externally)
    self.temperature = temperature_init
    self._n_channels = input_channels

    # Validate selection size
    if self.n_select > self._n_channels:
        raise ValueError(
            f"Cannot select {self.n_select} channels from {self._n_channels} available channels"  # nosec B608
        )

    # Initialize channel logits based on method - always as buffer
    if self.init_method == "uniform":
        # Uniform initialization
        logits = torch.zeros(self._n_channels)
    elif self.init_method == "variance":
        # Small random placeholder; replaced by log-variance logits if
        # statistical_initialization is called
        logits = torch.randn(self._n_channels) * 0.01
    else:
        raise ValueError(f"Unknown init_method: {self.init_method}")

    # Store as buffer initially
    self.register_buffer("channel_logits", logits)

    self._statistically_initialized = False

statistical_initialization
statistical_initialization(input_stream)

Initialize channel selection weights from data.

Parameters:

Name Type Description Default
input_stream InputStream

Iterator yielding dicts that match INPUT_SPECS (port-based format). Expected format: {"data": tensor}, where tensor has layout [B, H, W, C].

required

Source code in cuvis_ai/node/channel_selector.py
def statistical_initialization(self, input_stream: InputStream) -> None:
    """Initialize channel selection weights from data.

    Parameters
    ----------
    input_stream : InputStream
        Iterator yielding dicts that match INPUT_SPECS (port-based format).
        Expected format: {"data": tensor}, where tensor has layout [B, H, W, C].
    """
    # Collect statistics from first batch to determine n_channels
    first_batch = next(iter(input_stream))
    x = first_batch["data"]

    if x is None:
        raise ValueError("No data provided for selector initialization")

    self._n_channels = x.shape[-1]

    if self.n_select > self._n_channels:
        raise ValueError(
            f"Cannot select {self.n_select} channels from {self._n_channels} available channels"  # nosec B608
        )

    # Initialize channel logits based on method
    if self.init_method == "uniform":
        # Uniform initialization
        logits = torch.zeros(self._n_channels)
    elif self.init_method == "variance":
        # Importance-based initialization using channel variance
        acc = WelfordAccumulator(self._n_channels)
        acc.update(x.reshape(-1, x.shape[-1]))
        for batch_data in input_stream:
            x_batch = batch_data["data"]
            if x_batch is not None:
                acc.update(x_batch.reshape(-1, x_batch.shape[-1]))

        variance = acc.var  # [C]

        # Use log variance as initial logits (high variance = high importance)
        logits = torch.log(variance + self.eps)
    else:
        raise ValueError(f"Unknown init_method: {self.init_method}")

    # Store as buffer
    self.channel_logits.data[:] = logits.clone()
    self._statistically_initialized = True
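
The variance branch streams batches through an accumulator and then takes the log of the per-channel variance. The sketch below mimics that flow in NumPy. `RunningVariance` is a hypothetical stand-in for `WelfordAccumulator`, whose actual interface and variance convention (population vs. sample) are not documented on this page.

```python
import numpy as np

class RunningVariance:
    """Hypothetical streaming per-channel variance accumulator."""

    def __init__(self, n_channels):
        self.count = 0
        self.mean = np.zeros(n_channels)
        self.m2 = np.zeros(n_channels)   # running sum of squared deviations

    def update(self, batch):
        """Merge a [N, C] batch using the pairwise (Chan et al.) update."""
        n = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)
        delta = batch_mean - self.mean
        total = self.count + n
        self.mean = self.mean + delta * n / total
        self.m2 = self.m2 + batch_m2 + delta**2 * self.count * n / total
        self.count = total

    @property
    def var(self):
        return self.m2 / self.count      # assumption: population variance

rng = np.random.default_rng(0)
batches = [rng.normal(0.0, [0.1, 1.0, 3.0], size=(2, 4, 4, 3)) for _ in range(5)]
acc = RunningVariance(3)
for x in batches:
    acc.update(x.reshape(-1, x.shape[-1]))   # flatten [B, H, W, C] to [N, C]
logits = np.log(acc.var + 1e-6)              # high variance gives a high logit
```
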

update_temperature
update_temperature(epoch=None, step=None)

Update temperature with decay schedule.

Parameters:

Name Type Description Default
epoch int

Current epoch number (used for per-epoch decay)

None
step int

Current training step. Accepted for finer-grained schedules, but not used by the current implementation.

None

Source code in cuvis_ai/node/channel_selector.py
def update_temperature(self, epoch: int | None = None, step: int | None = None) -> None:
    """Update temperature with decay schedule.

    Parameters
    ----------
    epoch : int, optional
        Current epoch number (used for per-epoch decay)
    step : int, optional
        Current training step. Accepted for finer-grained schedules, but
        not used by the current implementation.
    """
    if epoch is not None:
        # Exponential decay per epoch
        self.temperature = max(
            self.temperature_min, self.temperature_init * (self.temperature_decay**epoch)
        )
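
The schedule above is a closed-form function of the epoch, so it can be evaluated without a node instance. The helper `temperature_at` below is a hypothetical standalone copy of the same formula using the constructor defaults.

```python
def temperature_at(epoch, temperature_init=5.0, temperature_min=0.1, temperature_decay=0.9):
    """Exponential decay from temperature_init, clamped at temperature_min."""
    return max(temperature_min, temperature_init * temperature_decay**epoch)

# Temperature at a few epochs with the default settings.
schedule = {epoch: temperature_at(epoch) for epoch in (0, 10, 20, 30, 40)}

# First epoch at which the floor is reached with the defaults.
first_floor_epoch = next(e for e in range(200) if temperature_at(e) == 0.1)
```

Since the value is computed from `temperature_init` rather than from the current temperature, calling the method twice with the same epoch gives the same result. With the defaults the floor of 0.1 is reached at epoch 38.
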

get_selection_weights
get_selection_weights(hard=None)

Get current channel selection weights.

Parameters:

Name Type Description Default
hard bool

If True, use hard selection (top-k). If None, uses self.hard.

None

Returns:

Type Description
Tensor

Selection weights [n_channels] summing to n_select

Source code in cuvis_ai/node/channel_selector.py
def get_selection_weights(self, hard: bool | None = None) -> Tensor:
    """Get current channel selection weights.

    Parameters
    ----------
    hard : bool, optional
        If True, use hard selection (top-k). If None, uses self.hard.

    Returns
    -------
    Tensor
        Selection weights [n_channels] summing to n_select
    """
    if hard is None:
        hard = self.hard and not self.training

    if hard:
        # Hard selection: top-k channels
        _, top_indices = torch.topk(self.channel_logits, self.n_select)
        weights = torch.zeros_like(self.channel_logits)
        weights[top_indices] = 1.0
    else:
        # Soft selection: temperature-scaled softmax over the logits
        # (no Gumbel noise is sampled in this method)
        # First, compute selection probabilities
        probs = F.softmax(self.channel_logits / self.temperature, dim=-1)

        # Scale to sum to n_select instead of 1
        weights = probs * self.n_select

    return weights
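
The two branches can be compared on a small logit vector. This is a NumPy re-statement of the logic shown above, written as a hypothetical free function `selection_weights`. It is meant as an illustration rather than a drop-in replacement.

```python
import numpy as np

def selection_weights(logits, n_select, temperature=1.0, hard=False):
    """Top-k one-hot weights (hard) or softmax scaled to sum to n_select (soft)."""
    logits = np.asarray(logits, dtype=np.float64)
    if hard:
        weights = np.zeros_like(logits)
        weights[np.argsort(logits)[-n_select:]] = 1.0
        return weights
    scaled = logits / temperature
    scaled -= scaled.max()                       # stabilise the exponentials
    probs = np.exp(scaled) / np.exp(scaled).sum()
    return probs * n_select

logits = np.array([2.0, 0.5, 1.5, -1.0, 0.0])
soft_hot = selection_weights(logits, n_select=3, temperature=5.0)
soft_cold = selection_weights(logits, n_select=3, temperature=0.1)
hard = selection_weights(logits, n_select=3, hard=True)
```

Both variants sum to `n_select`, but they do not converge to the same vector. At low temperature a single softmax concentrates nearly all of its mass on the largest logit, while the hard branch assigns weight 1.0 to each of the top `n_select` channels.
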

forward
forward(data, **_)

Apply soft channel selection to input.

Parameters:

Name Type Description Default
data Tensor

Input tensor [B, H, W, C]

required

Returns:

Type Description
dict[str, Tensor]

Dictionary with "selected" key containing the reweighted channels and "weights" key containing the selection weights. Both keys are always present.

Source code in cuvis_ai/node/channel_selector.py
def forward(self, data: Tensor, **_: Any) -> dict[str, Tensor]:
    """Apply soft channel selection to input.

    Parameters
    ----------
    data : Tensor
        Input tensor [B, H, W, C]

    Returns
    -------
    dict[str, Tensor]
        Dictionary with "selected" key containing the reweighted channels
        and "weights" key containing the selection weights (always present)
    """
    # Get selection weights
    weights = self.get_selection_weights()

    # Apply channel-wise weighting: [B, H, W, C] * [C]
    selected = data * weights.view(1, 1, 1, -1)

    # Prepare output dictionary - weights always exposed for loss/metric nodes
    outputs = {"selected": selected, "weights": weights}

    return outputs
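
The broadcasting step keeps the channel axis intact: channels with zero weight are zeroed, not removed. A minimal NumPy illustration with made-up weights follows. The final indexing line is an optional extra step that the node itself does not perform.

```python
import numpy as np

data = np.ones((2, 4, 4, 5))                          # [B, H, W, C]
weights = np.array([1.0, 0.0, 1.0, 0.0, 1.0])         # e.g. hard top-3 weights
selected = data * weights.reshape(1, 1, 1, -1)        # still [B, H, W, C]

# Optional: physically drop the zero-weight channels afterwards.
kept = selected[..., weights > 0]                     # [B, H, W, 3]
```
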

TopKIndices

TopKIndices(k, **kwargs)

Bases: Node

Utility node that surfaces the top-k channel indices from selector weights.

This node extracts the indices of the top-k weighted channels from a selector's weight vector. Useful for introspection and reporting which channels were selected.

Parameters:

Name Type Description Default
k int

Number of top indices to return

required

Attributes:

Name Type Description
k int

Number of top indices to return

Source code in cuvis_ai/node/channel_selector.py
def __init__(self, k: int, **kwargs: Any) -> None:
    self.k = int(k)

    # Extract Node base parameters from kwargs to avoid duplication
    name = kwargs.pop("name", None)
    execution_stages = kwargs.pop("execution_stages", None)

    super().__init__(
        name=name,
        execution_stages=execution_stages,
        k=self.k,
        **kwargs,
    )

forward
forward(weights, **_)

Return the indices of the top-k weighted channels.

Parameters:

Name Type Description Default
weights Tensor

Channel selection weights [n_channels]

required

Returns:

Type Description
dict[str, Tensor]

Dictionary with "indices" key containing top-k indices

Source code in cuvis_ai/node/channel_selector.py
def forward(self, weights: torch.Tensor, **_: Any) -> dict[str, torch.Tensor]:
    """Return the indices of the top-k weighted channels.

    Parameters
    ----------
    weights : torch.Tensor
        Channel selection weights [n_channels]

    Returns
    -------
    dict[str, torch.Tensor]
        Dictionary with "indices" key containing top-k indices
    """
    top_k = min(self.k, weights.shape[-1]) if weights.numel() else 0
    if top_k == 0:
        return {"indices": torch.zeros(0, dtype=torch.int64, device=weights.device)}

    _, indices = torch.topk(weights, top_k)
    return {"indices": indices}
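
The clamping and empty-input handling can be exercised with a small NumPy equivalent. `top_k_indices` is a hypothetical helper that mirrors the logic above. It sorts in descending order of weight, which matches the default sorted output of `torch.topk`, although tie ordering may differ.

```python
import numpy as np

def top_k_indices(weights, k):
    """Indices of the k largest weights, with k clamped to the vector length."""
    weights = np.asarray(weights)
    top_k = min(k, weights.shape[-1]) if weights.size else 0
    if top_k == 0:
        return np.zeros(0, dtype=np.int64)
    order = np.argsort(-weights, kind="stable")   # descending by weight
    return order[:top_k].astype(np.int64)

two_best = top_k_indices([0.1, 0.7, 0.2, 0.9], k=2)     # [3, 1]
clamped = top_k_indices([0.1, 0.7, 0.2, 0.9], k=10)     # all four indices
empty = top_k_indices([], k=3)                          # empty index array
```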