Skip to content

scikit-learn Integration

Note

Requires scikit-learn: pip install pyspectrakit[sklearn]

spectrakit.sklearn.SpectralTransformer

Bases: BaseEstimator, TransformerMixin

scikit-learn transformer wrapping any SpectraKit function.

Wraps a SpectraKit processing function (e.g., baseline_als, normalize_snv, smooth_savgol) as a scikit-learn compatible transformer that can be used in sklearn.pipeline.Pipeline.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `func` | `Callable[..., ndarray]` | A SpectraKit processing function with signature `func(intensities, **kwargs) -> np.ndarray`. | *required* |
| `**kwargs` | `Any` | Keyword arguments passed to `func` on each `transform()` call. | `{}` |

Examples:

>>> from sklearn.pipeline import Pipeline as SkPipeline
>>> from spectrakit import baseline_als, normalize_snv
>>> from spectrakit.sklearn import SpectralTransformer
>>>
>>> pipe = SkPipeline([
...     ("baseline", SpectralTransformer(baseline_als, lam=1e6)),
...     ("normalize", SpectralTransformer(normalize_snv)),
... ])
>>> X_processed = pipe.fit_transform(X_raw)
Source code in src/spectrakit/sklearn/transformers.py
class SpectralTransformer(BaseEstimator, TransformerMixin):  # type: ignore[misc]
    """scikit-learn transformer wrapping any SpectraKit function.

    Wraps a SpectraKit processing function (e.g., ``baseline_als``,
    ``normalize_snv``, ``smooth_savgol``) as a scikit-learn compatible
    transformer that can be used in ``sklearn.pipeline.Pipeline``.

    The transformer is stateless: ``fit`` learns nothing and ``transform``
    simply calls ``func(X, **kwargs)``. The stored keyword arguments are
    reported by ``get_params`` as top-level parameters next to ``func``,
    and ``set_params`` updates them by name.

    Args:
        func: A SpectraKit processing function with signature
            ``func(intensities, **kwargs) -> np.ndarray``.
        **kwargs: Keyword arguments passed to ``func`` on each
            ``transform()`` call.

    Examples:
        >>> from sklearn.pipeline import Pipeline as SkPipeline
        >>> from spectrakit import baseline_als, normalize_snv
        >>> from spectrakit.sklearn import SpectralTransformer
        >>>
        >>> pipe = SkPipeline([
        ...     ("baseline", SpectralTransformer(baseline_als, lam=1e6)),
        ...     ("normalize", SpectralTransformer(normalize_snv)),
        ... ])
        >>> X_processed = pipe.fit_transform(X_raw)
    """

    def __init__(self, func: Callable[..., np.ndarray], **kwargs: Any) -> None:
        """Store the wrapped function and its keyword arguments.

        Raises:
            DependencyError: If scikit-learn is not installed.
        """
        if not _HAS_SKLEARN:
            # NOTE(review): this hint says ``spectrakit[sklearn]`` while the
            # published docs say ``pip install pyspectrakit[sklearn]`` —
            # confirm which distribution name is correct.
            raise DependencyError(
                "scikit-learn is required for SpectralTransformer. "
                "Install with: pip install spectrakit[sklearn]"
            )
        self.func = func
        self.kwargs = kwargs

    def fit(
        self,
        X: np.ndarray,  # noqa: N803 — sklearn convention
        y: Any = None,
    ) -> SpectralTransformer:
        """No-op fit (stateless transformer).

        Args:
            X: Training data (ignored).
            y: Training labels (ignored).

        Returns:
            Self.
        """
        return self

    def transform(self, X: np.ndarray) -> np.ndarray:  # noqa: N803
        """Apply the wrapped SpectraKit function.

        Calls ``self.func(X, **self.kwargs)`` and returns its result as-is;
        no validation or reshaping happens here.

        Args:
            X: Input spectral data, expected shape ``(N, W)``.

        Returns:
            Whatever ``func`` returns. For per-spectrum SpectraKit functions
            this is presumably an array shaped like ``X``, but the shape is
            not checked here.
        """
        return self.func(X, **self.kwargs)

    def __sklearn_is_fitted__(self) -> bool:
        """Report the transformer as always fitted (scikit-learn interface).

        There is no learned state, so ``transform`` works without a prior
        ``fit``. scikit-learn's ``check_is_fitted`` consults this hook; without
        it, it searches for trailing-underscore attributes, finds none, and
        would treat this estimator (and a ``Pipeline`` ending in it) as
        unfitted.
        """
        return True

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """Get transformer parameters (sklearn interface).

        Args:
            deep: Accepted for scikit-learn API compatibility but ignored;
                no nested parameters are ever returned.

        Returns:
            A new dict mapping ``"func"`` to the wrapped function plus one
            entry per stored keyword argument.
        """
        params: dict[str, Any] = {"func": self.func}
        params.update(self.kwargs)
        return params

    def set_params(self, **params: Any) -> SpectralTransformer:
        """Set transformer parameters (sklearn interface).

        Args:
            **params: A ``func`` entry replaces the wrapped function; every
                other entry is stored in ``self.kwargs`` (overwriting an
                existing key or adding a new one) and will be forwarded to
                ``func`` on the next ``transform`` call.

        Returns:
            Self.
        """
        if "func" in params:
            self.func = params.pop("func")
        # Mutates the existing kwargs dict in place.
        self.kwargs.update(params)
        return self

    def __repr__(self) -> str:
        """Return ``SpectralTransformer(<func name>[, key=value, ...])``."""
        func_name = getattr(self.func, "__name__", str(self.func))
        param_str = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items())
        if param_str:
            return f"SpectralTransformer({func_name}, {param_str})"
        return f"SpectralTransformer({func_name})"

fit

fit(X: ndarray, y: Any = None) -> SpectralTransformer

No-op fit (stateless transformer).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `ndarray` | Training data (ignored). | *required* |
| `y` | `Any` | Training labels (ignored). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `SpectralTransformer` | Self. |

Source code in src/spectrakit/sklearn/transformers.py
def fit(
    self,
    X: np.ndarray,  # noqa: N803 — sklearn convention
    y: Any = None,
) -> SpectralTransformer:
    """Fit step of the scikit-learn API; intentionally does nothing.

    The wrapped function carries no learned state, so neither argument is
    looked at.

    Args:
        X: Training data; not inspected.
        y: Training targets; not inspected.

    Returns:
        This same instance, unchanged.
    """
    del X, y  # unused on purpose: stateless transformer
    return self

get_params

get_params(deep: bool = True) -> dict[str, Any]

Get transformer parameters (sklearn interface).

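Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `deep` | `bool` | Accepted for scikit-learn API compatibility but ignored; no nested parameters are returned. | `True` |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Any]` | A new dict containing `func` plus every stored keyword argument. |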

Source code in src/spectrakit/sklearn/transformers.py
def get_params(self, deep: bool = True) -> dict[str, Any]:
    """Get transformer parameters (sklearn interface).

    Args:
        deep: Accepted for scikit-learn API compatibility but ignored;
            no nested parameters are ever returned.

    Returns:
        A new dict mapping ``"func"`` to the wrapped function plus one
        entry per stored keyword argument.
    """
    # Build a fresh dict each call so callers cannot mutate self.kwargs.
    params: dict[str, Any] = {"func": self.func}
    params.update(self.kwargs)
    return params

set_params

set_params(**params: Any) -> SpectralTransformer

Set transformer parameters (sklearn interface).

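Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `**params` | `Any` | Parameters to set. A `func` entry replaces the wrapped function; any other key is stored as a keyword argument for `func`, overwriting an existing value or adding a new one. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `SpectralTransformer` | Self. |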

Source code in src/spectrakit/sklearn/transformers.py
def set_params(self, **params: Any) -> SpectralTransformer:
    """Update the wrapped function and/or its keyword arguments (sklearn interface).

    Args:
        **params: A ``func`` entry replaces the wrapped function; every
            other entry is written into ``self.kwargs``, overwriting an
            existing key or adding a new one.

    Returns:
        Self.
    """
    for name, value in params.items():
        if name == "func":
            self.func = value
        else:
            self.kwargs[name] = value
    return self

transform

transform(X: ndarray) -> np.ndarray

Apply the wrapped SpectraKit function.

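Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `ndarray` | Input spectral data, shape `(N, W)`. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ndarray` | The output of the wrapped function, expected to have the same shape as `X` (the transformer does not check this). |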

Source code in src/spectrakit/sklearn/transformers.py
def transform(self, X: np.ndarray) -> np.ndarray:  # noqa: N803
    """Apply the wrapped SpectraKit function.

    Calls ``self.func(X, **self.kwargs)`` and returns its result as-is;
    no validation or reshaping happens here.

    Args:
        X: Input spectral data, expected shape ``(N, W)``.

    Returns:
        Whatever ``func`` returns. For per-spectrum SpectraKit functions
        this is presumably an array shaped like ``X``, but the shape is
        not checked here.
    """
    return self.func(X, **self.kwargs)