20. Data augmentation API

This page documents the data augmentation API. For usage patterns, see Data augmentation how-to.

20.1 What it is for

The data augmentation brick defines training-time, stochastic transforms for multiple modalities with deterministic contexts. [1][2]

20.2 Examples

List available ops:

from modssc.data_augmentation import available_ops

print(available_ops(modality="vision"))

Build and apply a pipeline:

import numpy as np
from modssc.data_augmentation import AugmentationContext, AugmentationPlan, StepConfig, build_pipeline

plan = AugmentationPlan(steps=(StepConfig(op_id="tabular.gaussian_noise", params={"std": 0.1}),))
pipeline = build_pipeline(plan)

x = np.zeros((4,), dtype=np.float32)
ctx = AugmentationContext(seed=0, sample_id=0, epoch=0)
print(pipeline(x, ctx=ctx))

The registry and plan schema are defined in src/modssc/data_augmentation/registry.py and src/modssc/data_augmentation/plan.py. [3][4]

20.3 API reference

ModSSC data augmentation brick.

This brick provides training-time (stochastic) transformations for multiple modalities (vision, text, tabular, audio, graph). It is designed to be:

  • Deterministic when requested (seed + epoch + sample_id => same output; see the sketch after this list)
  • Backend-aware (NumPy by default; supports torch tensors without requiring torch at import)
  • Composable through a small plan/pipeline system
  • Extensible via a registry (contributors can add new operations without touching core code)
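
As a quick check of the determinism guarantee (first bullet above), the sketch below reuses the tabular.gaussian_noise op from the examples on this page: applying the same pipeline twice with an identical context yields identical outputs, while a different sample_id generally yields a different draw.

import numpy as np
from modssc.data_augmentation import AugmentationContext, AugmentationPlan, StepConfig, build_pipeline

pipeline = build_pipeline(
    AugmentationPlan(steps=(StepConfig(op_id="tabular.gaussian_noise", params={"std": 0.1}),))
)

x = np.zeros((4,), dtype=np.float32)
ctx = AugmentationContext(seed=0, sample_id=0, epoch=0)

# Same context => same RNG stream => identical outputs.
assert np.allclose(pipeline(x, ctx=ctx), pipeline(x, ctx=ctx))

# A different sample_id draws from an independent stream (almost surely different here).
print(pipeline(x, ctx=AugmentationContext(seed=0, sample_id=1, epoch=0)))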

20.3.1 Notes

This is intentionally separate from modssc.preprocess, which is meant for offline and/or cacheable feature engineering (including embeddings with pretrained models). Augmentations are applied on the fly during training loops (a future brick/orchestrator).

20.4 AugmentationContext dataclass

Deterministic context for augmentation.

20.4.0.1 Parameters

  • seed: Global seed for the experiment.
  • sample_id: A stable identifier for the sample (e.g. dataset index).
  • epoch: Current training epoch (or 0 for stateless usage).
  • backend: Backend preference. "auto" uses torch if the input is a torch tensor.
  • modality: Optional modality hint (used for validation only).

Source code in src/modssc/data_augmentation/types.py
@dataclass(frozen=True)
class AugmentationContext:
    """Deterministic context for augmentation.

    Parameters
    ----------
    seed:
        Global seed for the experiment.
    sample_id:
        A stable identifier for the sample (e.g. dataset index).
    epoch:
        Current training epoch (or 0 for stateless usage).
    backend:
        Backend preference. "auto" uses torch if the input is a torch tensor.
    modality:
        Optional modality hint (used for validation only).
    """

    seed: int
    sample_id: int = 0
    epoch: int = 0
    backend: Backend = "auto"
    modality: Modality | None = None
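
In a training loop, one context is typically built per (epoch, sample) pair so that each sample gets its own reproducible randomness. A minimal sketch, where the loop bounds and the way sample ids are obtained are hypothetical placeholders rather than part of the API:

from modssc.data_augmentation import AugmentationContext

seed = 0
contexts = [
    AugmentationContext(seed=seed, sample_id=sample_id, epoch=epoch)
    for epoch in range(2)
    for sample_id in range(4)  # e.g. dataset indices
]
# each ctx is then passed as pipeline(x, ctx=ctx) for the corresponding sample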

20.5 AugmentationPipeline dataclass

A compiled augmentation pipeline.

Source code in src/modssc/data_augmentation/api.py
@dataclass(frozen=True)
class AugmentationPipeline:
    """A compiled augmentation pipeline."""

    plan: AugmentationPlan
    ops: tuple[Any, ...]

    def apply(self, x: Any, *, ctx: AugmentationContext) -> Any:
        rng = make_context_rng(ctx)
        out = x
        for op in self.ops:
            out = op(out, rng=rng, ctx=ctx)
        return out

    def __call__(self, x: Any, *, ctx: AugmentationContext) -> Any:
        return self.apply(x, ctx=ctx)

20.6 AugmentationPlan dataclass

A sequence of augmentation steps.

20.6.0.1 Notes

Unlike preprocessing, augmentation is usually applied online (during training). Plans are still useful to describe pipelines declaratively and reproducibly.

Source code in src/modssc/data_augmentation/plan.py
@dataclass(frozen=True)
class AugmentationPlan:
    """A sequence of augmentation steps.

    Notes
    -----
    Unlike preprocessing, augmentation is usually applied *online* (during training).
    Plans are still useful to describe pipelines declaratively and reproducibly.
    """

    steps: tuple[StepConfig, ...]
    modality: Modality | None = None
    description: str | None = None
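
Because plans are frozen dataclasses with only declarative fields, a pipeline can be described once and rebuilt reproducibly. A small sketch, assuming tabular.gaussian_noise declares the tabular (or "any") modality; the std values and description are arbitrary illustrations:

from modssc.data_augmentation import AugmentationPlan, StepConfig, build_pipeline

plan = AugmentationPlan(
    steps=(
        StepConfig(op_id="tabular.gaussian_noise", params={"std": 0.05}),
        StepConfig(op_id="tabular.gaussian_noise", params={"std": 0.1}),
    ),
    modality="tabular",
    description="two-stage additive noise",
)
pipeline = build_pipeline(plan)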

20.7 AugmentationStrategy dataclass

Weak/strong strategy container (useful for FixMatch-style algorithms).

Source code in src/modssc/data_augmentation/api.py
@dataclass(frozen=True)
class AugmentationStrategy:
    """Weak/strong strategy container (useful for FixMatch-style algorithms)."""

    weak: AugmentationPipeline
    strong: AugmentationPipeline

    def apply(self, x: Any, *, ctx: AugmentationContext) -> tuple[Any, Any]:
        xw = self.weak.apply(x, ctx=ctx)
        xs = self.strong.apply(x, ctx=ctx)
        return xw, xs
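
A weak/strong pair can be assembled from two pipelines built over the same op with different strengths. A minimal sketch; the std values are arbitrary illustrations, and the package-level import of AugmentationStrategy is an assumption (it is defined in src/modssc/data_augmentation/api.py):

import numpy as np
from modssc.data_augmentation import (
    AugmentationContext,
    AugmentationPlan,
    AugmentationStrategy,  # import path assumed; defined in api.py
    StepConfig,
    build_pipeline,
)

weak = build_pipeline(AugmentationPlan(steps=(StepConfig(op_id="tabular.gaussian_noise", params={"std": 0.01}),)))
strong = build_pipeline(AugmentationPlan(steps=(StepConfig(op_id="tabular.gaussian_noise", params={"std": 0.5}),)))

strategy = AugmentationStrategy(weak=weak, strong=strong)
xw, xs = strategy.apply(np.zeros((4,), dtype=np.float32), ctx=AugmentationContext(seed=0))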

20.8 GraphSample dataclass

Minimal graph container for augmentation.

This is intentionally small and compatible with common graph toolkits: it mirrors the key fields of PyG's Data (x, edge_index, edge_weight).

20.8.0.1 Notes

edge_index is expected to be shaped (2, E) or (E, 2). Augmentations will normalize it to (2, E) internally.

Source code in src/modssc/data_augmentation/types.py
@dataclass(frozen=True)
class GraphSample:
    """Minimal graph container for augmentation.

    This is intentionally small and compatible with common graph toolkits:
    it mirrors the key fields of PyG's ``Data`` (``x``, ``edge_index``, ``edge_weight``).

    Notes
    -----
    ``edge_index`` is expected to be shaped ``(2, E)`` or ``(E, 2)``. Augmentations will
    normalize it to ``(2, E)`` internally.
    """

    x: ArrayLike
    edge_index: ArrayLike
    edge_weight: ArrayLike | None = None
    meta: dict[str, Any] = field(default_factory=dict)

    def num_nodes(self) -> int:
        x = np.asarray(self.x)
        return int(x.shape[0])
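
For graph data, a sample is wrapped together with its connectivity. A small sketch using a three-node, two-edge toy graph; the package-level import of GraphSample is an assumption (it is defined in src/modssc/data_augmentation/types.py):

import numpy as np
from modssc.data_augmentation import GraphSample  # import path assumed; defined in types.py

g = GraphSample(
    x=np.zeros((3, 8), dtype=np.float32),                     # 3 nodes with 8 features each
    edge_index=np.array([[0, 1], [1, 2]], dtype=np.int64).T,  # shaped (2, E) with E = 2
)
print(g.num_nodes())  # 3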

20.9 StepConfig dataclass

A single augmentation step.

20.9.0.1 Parameters

  • op_id: Registry id of the augmentation operation (e.g. "vision.random_horizontal_flip").
  • params: Keyword parameters forwarded to the op constructor.

Source code in src/modssc/data_augmentation/plan.py
@dataclass(frozen=True)
class StepConfig:
    """A single augmentation step.

    Parameters
    ----------
    op_id:
        Registry id of the augmentation operation (e.g. ``"vision.random_horizontal_flip"``).
    params:
        Keyword parameters forwarded to the op constructor.
    """

    op_id: str
    params: dict[str, Any] = field(default_factory=dict)

20.10 available_ops(*, modality=None)

List registered operation ids.

Source code in src/modssc/data_augmentation/registry.py
def available_ops(*, modality: Modality | None = None) -> list[str]:
    """List registered operation ids."""
    if modality is None:
        return sorted(_OPS.keys())
    out: list[str] = []
    for k, cls in _OPS.items():
        m = getattr(cls, "modality", "any")
        if m == modality or m == "any":
            out.append(k)
    return sorted(out)
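
With no modality argument every registered op id is returned; with a modality, ops registered for that modality plus ops whose modality is "any" are kept:

from modssc.data_augmentation import available_ops

print(available_ops())                    # all registered op ids, sorted
print(available_ops(modality="tabular"))  # tabular ops plus modality-"any" ops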

20.11 build_pipeline(plan)

Compile an AugmentationPlan into an executable pipeline.

Source code in src/modssc/data_augmentation/api.py
def build_pipeline(plan: AugmentationPlan) -> AugmentationPipeline:
    """Compile an :class:`AugmentationPlan` into an executable pipeline."""
    if not isinstance(plan.steps, tuple):
        raise DataAugmentationValidationError("plan.steps must be a tuple of StepConfig")
    ops = []
    for step in plan.steps:
        if not isinstance(step, StepConfig):
            raise DataAugmentationValidationError("Each plan step must be a StepConfig")
        op = _get_op(step.op_id, **(step.params or {}))
        if plan.modality is not None:
            op_modality = getattr(op, "modality", "any")
            if op_modality not in ("any", plan.modality):
                raise DataAugmentationValidationError(
                    f"Op {step.op_id!r} has modality {op_modality!r} but plan expects {plan.modality!r}"
                )
        ops.append(op)
    return AugmentationPipeline(plan=plan, ops=tuple(ops))
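
Validation happens at build time, so malformed plans fail before training starts. For example, every step must be a StepConfig instance; the sketch below deliberately passes a plain dict to trigger the error:

from modssc.data_augmentation import AugmentationPlan, build_pipeline

bad = AugmentationPlan(steps=({"op_id": "tabular.gaussian_noise"},))
try:
    build_pipeline(bad)
except Exception as err:  # DataAugmentationValidationError in the source above
    print(type(err).__name__, err)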

20.12 make_context_rng(ctx)

Build a deterministic RNG for a given augmentation context.

Source code in src/modssc/data_augmentation/api.py
def make_context_rng(ctx: AugmentationContext) -> np.random.Generator:
    """Build a deterministic RNG for a given augmentation context."""
    return make_numpy_rng(seed=int(ctx.seed), epoch=int(ctx.epoch), sample_id=int(ctx.sample_id))
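
Two identical contexts produce identical RNG streams, which is what makes pipeline outputs reproducible. A minimal sketch, assuming make_context_rng is imported from the module shown above (it may also be re-exported at the package level):

import numpy as np
from modssc.data_augmentation import AugmentationContext
from modssc.data_augmentation.api import make_context_rng  # import path assumed from the source location above

ctx = AugmentationContext(seed=0, sample_id=7, epoch=3)
a = make_context_rng(ctx).normal(size=3)
b = make_context_rng(ctx).normal(size=3)
assert np.allclose(a, b)  # same context => same stream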

20.13 register_op(op_id)

Decorator to register an augmentation operation class.

Source code in src/modssc/data_augmentation/registry.py
def register_op(op_id: str) -> Callable[[type[AugmentationOp]], type[AugmentationOp]]:
    """Decorator to register an augmentation operation class."""

    def _decorator(cls: type[AugmentationOp]) -> type[AugmentationOp]:
        if op_id in _OPS:
            raise DataAugmentationValidationError(f"Duplicate op_id: {op_id}")
        # basic sanity: the class should expose op_id/modality defaults
        if not hasattr(cls, "op_id") and not hasattr(cls, "modality"):
            raise DataAugmentationValidationError(
                f"Op class {cls.__name__} must define 'op_id' and 'modality'."
            )
        _OPS[op_id] = cls
        return cls

    return _decorator
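
A contributor-defined op needs the op_id/modality class attributes checked by the decorator, a constructor accepting the StepConfig params, and a call signature matching how AugmentationPipeline.apply invokes ops (op(out, rng=rng, ctx=ctx)). The class below is a hypothetical sketch, not an op shipped with the library, and the real AugmentationOp protocol may require more than is shown here:

import numpy as np
from modssc.data_augmentation.registry import register_op  # import path assumed from the source location above

@register_op("tabular.sign_flip")  # hypothetical op id, for illustration only
class SignFlip:
    op_id = "tabular.sign_flip"
    modality = "tabular"

    def __init__(self, p: float = 0.5):
        self.p = p

    def __call__(self, x, *, rng, ctx):
        # Flip the sign of each entry independently with probability p.
        arr = np.asarray(x)
        mask = rng.random(arr.shape) < self.p
        return np.where(mask, -arr, arr)

Once registered, the new id appears in available_ops(modality="tabular") and can be referenced from a StepConfig like any built-in op.
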
Sources
  1. src/modssc/data_augmentation/api.py
  2. src/modssc/data_augmentation/types.py
  3. src/modssc/data_augmentation/registry.py
  4. src/modssc/data_augmentation/plan.py