21. Sampling API

This page documents the sampling API. For workflows, see Sampling how-to.

21.1 What it is for

The sampling brick builds deterministic train/val/test splits with labeled/unlabeled partitions and stores them on disk. [1][2]

21.2 Examples

Create a sampling plan and sample a dataset:

from modssc.data_loader import load_dataset
from modssc.sampling import HoldoutSplitSpec, LabelingSpec, SamplingPlan, sample

ds = load_dataset("toy", download=True)
# Hold out 20% of train for validation; keep the test split empty for this toy run.
plan = SamplingPlan(split=HoldoutSplitSpec(test_fraction=0.0, val_fraction=0.2), labeling=LabelingSpec())
res, _ = sample(ds, plan=plan, seed=0, dataset_fingerprint=str(ds.meta["dataset_fingerprint"]))
print(res.stats)

Save and load a split:

from modssc.sampling import load_split, save_split

# Persist the result; save_split returns the directory it wrote to.
out_dir = save_split(res, out_dir="splits/toy", overwrite=True)
loaded = load_split(out_dir)
print(loaded.split_fingerprint)

Plan and storage helpers are defined in src/modssc/sampling/plan.py and src/modssc/sampling/storage.py. [3][2]

21.3 API reference

Sampling and splitting for semi-supervised experiments.

This module takes a canonical dataset from modssc.data_loader and produces reproducible experimental splits (holdout, k-fold) plus labeled/unlabeled partitions.

It does NOT download datasets. Use modssc.data_loader for that.

21.4 ImbalanceSpec dataclass

Optional class imbalance scenario.

Kinds:

  • none
  • subsample_max_per_class: cap each class to max_per_class (applies to train or labeled)
  • long_tail: exponential decay per class rank (applies to train or labeled)

apply_to:

  • train: modify train_idx before labeling
  • labeled: modify the labeled subset after labeling (removed labeled samples become unlabeled)
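As a sketch, a long-tail scenario applied before labeling might look like this (assuming ImbalanceSpec is re-exported from modssc.sampling like the other specs; alpha=0.5 is an arbitrary illustrative decay value):

from modssc.sampling import ImbalanceSpec

spec = ImbalanceSpec(
    kind="long_tail",    # exponential decay per class rank
    apply_to="train",    # thin out train_idx before labeling
    alpha=0.5,           # illustrative decay rate
    min_per_class=5,     # keep at least a handful of samples per class
)
print(spec.as_dict())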

Source code in src/modssc/sampling/plan.py
@dataclass(frozen=True)
class ImbalanceSpec:
    """Optional class imbalance scenario.

    Kinds:
    - none
    - subsample_max_per_class: cap each class to max_per_class (applies to train or labeled)
    - long_tail: exponential decay per class rank (applies to train or labeled)

    apply_to:
    - train: modify train_idx before labeling
    - labeled: modify labeled subset after labeling (removed labeled become unlabeled)
    """

    kind: Literal["none", "subsample_max_per_class", "long_tail"] = "none"
    apply_to: Literal["train", "labeled"] = "train"
    max_per_class: int | None = None
    alpha: float | None = None
    min_per_class: int = 1

    def as_dict(self) -> dict[str, Any]:
        return {
            "kind": self.kind,
            "apply_to": self.apply_to,
            "max_per_class": None if self.max_per_class is None else int(self.max_per_class),
            "alpha": None if self.alpha is None else float(self.alpha),
            "min_per_class": int(self.min_per_class),
        }

    @classmethod
    def from_dict(cls, d: Mapping[str, Any]) -> ImbalanceSpec:
        _assert_known_keys(
            d,
            {"kind", "apply_to", "max_per_class", "alpha", "min_per_class"},
            "imbalance",
        )
        kind = str(d.get("kind", "none"))
        if kind not in ("none", "subsample_max_per_class", "long_tail"):
            raise ValueError(f"Unknown imbalance kind: {kind!r}")
        apply_to = str(d.get("apply_to", "train"))
        if apply_to not in ("train", "labeled"):
            raise ValueError(f"Unknown imbalance apply_to: {apply_to!r}")
        return cls(
            kind=kind,  # type: ignore[arg-type]
            apply_to=apply_to,  # type: ignore[arg-type]
            max_per_class=d.get("max_per_class", None),
            alpha=d.get("alpha", None),
            min_per_class=int(d.get("min_per_class", 1)),
        )

21.5 LabelingSpec dataclass

How to select labeled samples within the train partition.

Modes:

  • fraction: value in (0, 1], selects that fraction of train samples
  • count: value is an integer count of labeled samples
  • per_class: value is an integer count per class

If fixed_indices is provided, it is used directly (validated) and the mode is ignored.
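For example, a fixed labeling budget with a class-balanced draw (a minimal sketch; min_per_class is assumed here to enforce a per-class floor):

from modssc.sampling import LabelingSpec

labeling = LabelingSpec(
    mode="count",          # value is an absolute number of labeled samples
    value=50,
    strategy="balanced",   # spread the budget evenly across classes
    min_per_class=2,
)
print(labeling.as_dict())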

Source code in src/modssc/sampling/plan.py
@dataclass(frozen=True)
class LabelingSpec:
    """How to select labeled samples within the train partition.

    Modes:
    - fraction: value in (0, 1], selects that fraction of train samples
    - count: value is an integer count of labeled samples
    - per_class: value is an integer count per class

    If fixed_indices is provided, it is used directly (validated) and the mode is ignored.
    """

    mode: Literal["fraction", "count", "per_class"] = "fraction"
    value: float | int = 0.1
    per_class: bool = False
    min_per_class: int = 1
    strategy: Literal["proportional", "balanced"] = "proportional"
    fixed_indices: Sequence[int] | None = None

    def as_dict(self) -> dict[str, Any]:
        return {
            "mode": self.mode,
            "value": float(self.value) if self.mode == "fraction" else int(self.value),
            "per_class": bool(self.per_class),
            "min_per_class": int(self.min_per_class),
            "strategy": self.strategy,
            "fixed_indices": None
            if self.fixed_indices is None
            else [int(i) for i in self.fixed_indices],
        }

    @classmethod
    def from_dict(cls, d: Mapping[str, Any]) -> LabelingSpec:
        _assert_known_keys(
            d,
            {"mode", "value", "per_class", "min_per_class", "strategy", "fixed_indices"},
            "labeling",
        )
        mode = str(d.get("mode", "fraction"))
        if mode not in ("fraction", "count", "per_class"):
            raise ValueError(f"Unknown labeling mode: {mode!r}")
        value = d.get("value", 0.1)
        value = float(value) if mode == "fraction" else int(value)
        fixed_indices = d.get("fixed_indices", None)
        if fixed_indices is not None:
            if isinstance(fixed_indices, (str, bytes)) or not isinstance(fixed_indices, Sequence):
                raise ValueError("labeling.fixed_indices must be a sequence of integers")
            fixed_indices = [int(i) for i in fixed_indices]
        strategy = str(d.get("strategy", "proportional"))
        if strategy not in ("proportional", "balanced"):
            raise ValueError(f"Unknown labeling strategy: {strategy!r}")
        return cls(
            mode=mode,  # type: ignore[arg-type]
            value=value,
            per_class=bool(d.get("per_class", False)),
            min_per_class=int(d.get("min_per_class", 1)),
            strategy=strategy,  # type: ignore[arg-type]
            fixed_indices=fixed_indices,
        )

21.6 SamplingError

Bases: RuntimeError

Base error for sampling.

Source code in src/modssc/sampling/errors.py
class SamplingError(RuntimeError):
    """Base error for sampling."""

21.7 SamplingPlan dataclass

Full sampling plan combining split, labeling, imbalance, and policy specifications.
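Plans round-trip through plain dicts via as_dict/from_dict, which makes them easy to keep in config files. A minimal sketch (the holdout keys follow the HoldoutSplitSpec constructor used in the examples above; all values are illustrative):

from modssc.sampling import SamplingPlan

plan = SamplingPlan.from_dict(
    {
        "split": {"kind": "holdout", "test_fraction": 0.2, "val_fraction": 0.1},
        "labeling": {"mode": "per_class", "value": 10, "strategy": "balanced"},
        "imbalance": {"kind": "none"},
        "policy": {"respect_official_test": True},
    }
)
print(plan.as_dict())  # fully serializable; this dict feeds the split fingerprint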

Source code in src/modssc/sampling/plan.py
@dataclass(frozen=True)
class SamplingPlan:
    """Full sampling plan."""

    split: SplitSpec = field(default_factory=HoldoutSplitSpec)
    labeling: LabelingSpec = field(default_factory=LabelingSpec)
    imbalance: ImbalanceSpec = field(default_factory=ImbalanceSpec)
    policy: SamplingPolicy = field(default_factory=SamplingPolicy)

    def as_dict(self) -> dict[str, Any]:
        return {
            "split": self.split.as_dict(),
            "labeling": self.labeling.as_dict(),
            "imbalance": self.imbalance.as_dict(),
            "policy": {
                "respect_official_test": bool(self.policy.respect_official_test),
                "use_official_graph_masks": bool(self.policy.use_official_graph_masks),
                "allow_override_official": bool(self.policy.allow_override_official),
            },
        }

    @classmethod
    def from_dict(cls, d: Mapping[str, Any]) -> SamplingPlan:
        _assert_known_keys(d, {"split", "labeling", "imbalance", "policy"}, "plan")
        split_obj = _ensure_mapping(d.get("split", {}), "split")
        split_kind = str(split_obj.get("kind", "holdout"))
        if split_kind == "kfold":
            split = KFoldSplitSpec.from_dict(split_obj)
        elif split_kind == "holdout":
            split = HoldoutSplitSpec.from_dict(split_obj)
        else:
            raise ValueError(f"Unknown split kind: {split_kind!r}")

        labeling_obj = _ensure_mapping(d.get("labeling", {}), "labeling")
        labeling = LabelingSpec.from_dict(labeling_obj)

        imbalance_obj = _ensure_mapping(d.get("imbalance", {}), "imbalance")
        imbalance = ImbalanceSpec.from_dict(imbalance_obj)

        policy_obj = _ensure_mapping(d.get("policy", {}), "policy")
        policy = SamplingPolicy.from_dict(policy_obj)

        return cls(split=split, labeling=labeling, imbalance=imbalance, policy=policy)

21.8 SamplingPolicy dataclass

Policy for handling official provider splits.

  • respect_official_test: if dataset.test exists, keep it as the test set
  • use_official_graph_masks: if graph dataset provides masks, use them as train/val/test masks
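As a sketch, disabling the official test split so the plan's own holdout takes over (assuming SamplingPolicy is re-exported from modssc.sampling; the interplay of the two flags is inferred from their names):

from modssc.sampling import SamplingPolicy

policy = SamplingPolicy.from_dict(
    {
        "respect_official_test": False,   # ignore dataset.test if present
        "allow_override_official": True,  # explicitly permit the override
    }
)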

Source code in src/modssc/sampling/plan.py
@dataclass(frozen=True)
class SamplingPolicy:
    """Policy for handling official provider splits.

    - respect_official_test: if dataset.test exists, keep it as the test set
    - use_official_graph_masks: if graph dataset provides masks, use them as train/val/test masks
    """

    respect_official_test: bool = True
    use_official_graph_masks: bool = True
    allow_override_official: bool = False

    @classmethod
    def from_dict(cls, d: Mapping[str, Any]) -> SamplingPolicy:
        _assert_known_keys(
            d,
            {"respect_official_test", "use_official_graph_masks", "allow_override_official"},
            "policy",
        )
        return cls(
            respect_official_test=bool(d.get("respect_official_test", True)),
            use_official_graph_masks=bool(d.get("use_official_graph_masks", True)),
            allow_override_official=bool(d.get("allow_override_official", False)),
        )

21.9 SamplingResult dataclass

Sampling result with indices (inductive) or masks (graph transductive).

Indices keys (typical):

  • train, val, test
  • train_labeled, train_unlabeled

Refs indicate the base split each index array refers to (see the sketch after this list):

  • "train" means indices are relative to dataset.train
  • "test" means indices are relative to dataset.test
  • "nodes" means indices refer to graph nodes

Masks keys (graph): train, val, test, labeled, unlabeled
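A sketch of consuming a result while honoring refs, reusing res from the examples above (only attributes documented here are used):

# The *_idx properties work for both inductive results and graph masks.
labeled_idx = res.labeled_idx
unlabeled_idx = res.unlabeled_idx

# Raw index arrays must be interpreted against their base split.
test_base = res.refs.get("test", "train")  # "train", "test", or "nodes"
if test_base == "test":
    # these indices select rows within dataset.test, not dataset.train
    test_rows = res.indices["test"]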

Source code in src/modssc/sampling/result.py
@dataclass(frozen=True)
class SamplingResult:
    """Sampling result with indices (inductive) or masks (graph transductive).

    Indices keys (typical):
    - train, val, test
    - train_labeled, train_unlabeled

    Refs indicate the base split each index array refers to:
    - "train" means indices are relative to dataset.train
    - "test" means indices are relative to dataset.test
    - "nodes" means graph nodes

    Masks keys (graph):
    - train, val, test, labeled, unlabeled
    """

    schema_version: int
    created_at: str
    dataset_fingerprint: str
    split_fingerprint: str
    plan: Mapping[str, Any]

    indices: Mapping[str, np.ndarray] = field(default_factory=dict)
    refs: Mapping[str, str] = field(default_factory=dict)
    masks: Mapping[str, np.ndarray] = field(default_factory=dict)

    stats: Mapping[str, Any] = field(default_factory=dict)

    def is_graph(self) -> bool:
        return bool(self.masks)

    def validate(self, *, n_train: int, n_test: int | None, n_nodes: int | None) -> None:
        if self.is_graph():
            self._validate_graph(n_nodes=n_nodes)
        else:
            self._validate_inductive(n_train=n_train, n_test=n_test)

    def _validate_inductive(self, *, n_train: int, n_test: int | None) -> None:
        def _check_idx(name: str, base: str, idx: np.ndarray) -> None:
            if idx.dtype.kind not in ("i", "u"):
                raise SamplingValidationError(f"{name} indices must be integers, got {idx.dtype}")
            if idx.size == 0:
                return
            max_ok = n_train if base == "train" else n_test
            if max_ok is None:
                raise SamplingValidationError("n_test is required when base='test'")
            if idx.min() < 0 or idx.max() >= max_ok:
                raise SamplingValidationError(f"{name} has out-of-range indices for base={base!r}")
            if np.unique(idx).size != idx.size:
                raise SamplingValidationError(f"{name} contains duplicate indices")

        for name, idx in self.indices.items():
            base = self.refs.get(name, "train")
            _check_idx(name, base, idx)

        # core invariants
        train = self.indices.get("train")
        val = self.indices.get("val")
        test = self.indices.get("test")
        if train is None:
            raise SamplingValidationError("Missing train indices")
        if val is None:
            raise SamplingValidationError("Missing val indices")
        if test is None:
            raise SamplingValidationError("Missing test indices (may be empty array)")

        if self.refs.get("train", "train") != "train" or self.refs.get("val", "train") != "train":
            raise SamplingValidationError("train and val indices must be relative to dataset.train")

        # disjointness of train/val in same base
        if np.intersect1d(train, val).size:
            raise SamplingValidationError("train and val overlap")

        # test disjointness only if same base
        if self.refs.get("test", "train") == "train" and (
            np.intersect1d(train, test).size or np.intersect1d(val, test).size
        ):
            raise SamplingValidationError("test overlaps with train or val")

        labeled = self.indices.get("train_labeled")
        unlabeled = self.indices.get("train_unlabeled")
        if labeled is None or unlabeled is None:
            raise SamplingValidationError("Missing labeled/unlabeled indices")
        if np.intersect1d(labeled, unlabeled).size:
            raise SamplingValidationError("labeled and unlabeled overlap")
        if not np.array_equal(np.sort(np.concatenate([labeled, unlabeled])), np.sort(train)):
            raise SamplingValidationError("labeled + unlabeled must cover train exactly")

    def _validate_graph(self, *, n_nodes: int | None) -> None:
        if n_nodes is None:
            raise SamplingValidationError("n_nodes is required to validate graph masks")
        required = {"train", "val", "test", "labeled", "unlabeled"}
        if set(self.masks.keys()) != required:
            raise SamplingValidationError(f"Graph masks must have keys {sorted(required)}")

        for k, m in self.masks.items():
            if m.dtype != bool:
                raise SamplingValidationError(f"Mask {k!r} must be bool, got {m.dtype}")
            if m.shape != (n_nodes,):
                raise SamplingValidationError(
                    f"Mask {k!r} must have shape ({n_nodes},), got {m.shape}"
                )

        train = self.masks["train"]
        val = self.masks["val"]
        test = self.masks["test"]
        labeled = self.masks["labeled"]
        unlabeled = self.masks["unlabeled"]

        if (labeled & ~train).any():
            raise SamplingValidationError("labeled_mask must be subset of train_mask")
        if not np.array_equal(unlabeled, train & ~labeled):
            raise SamplingValidationError("unlabeled_mask must equal train_mask & ~labeled_mask")

        # train/val/test should be disjoint
        if (train & val).any() or (train & test).any() or (val & test).any():
            raise SamplingValidationError("train/val/test masks must be disjoint")

    @property
    def train_idx(self) -> np.ndarray:
        if self.is_graph():
            return np.where(self.masks["train"])[0]
        return self.indices.get("train", np.array([]))

    @property
    def val_idx(self) -> np.ndarray:
        if self.is_graph():
            return np.where(self.masks["val"])[0]
        return self.indices.get("val", np.array([]))

    @property
    def test_idx(self) -> np.ndarray:
        if self.is_graph():
            return np.where(self.masks["test"])[0]
        return self.indices.get("test", np.array([]))

    @property
    def labeled_idx(self) -> np.ndarray:
        if self.is_graph():
            return np.where(self.masks["labeled"])[0]
        return self.indices.get("train_labeled", np.array([]))

    @property
    def unlabeled_idx(self) -> np.ndarray:
        if self.is_graph():
            return np.where(self.masks["unlabeled"])[0]
        return self.indices.get("train_unlabeled", np.array([]))

21.10 SamplingValidationError

Bases: SamplingError

Raised when a sampled split violates invariants.
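SamplingResult.validate raises this error when an invariant is broken. A sketch of re-validating a loaded split (assuming the error class is re-exported from modssc.sampling; the counts must come from the dataset the split was built from):

from modssc.sampling import SamplingValidationError, load_split

loaded = load_split("splits/toy")
try:
    # n_train=1000 is illustrative; use the real size of dataset.train.
    # n_test/n_nodes may be None when there is no test base / no graph.
    loaded.validate(n_train=1000, n_test=None, n_nodes=None)
except SamplingValidationError as exc:
    print(f"split violates invariants: {exc}")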

Source code in src/modssc/sampling/errors.py
class SamplingValidationError(SamplingError):
    """Raised when a sampled split violates invariants."""

21.11 sample(dataset, *, plan, seed, dataset_fingerprint=None, dataset_id=None, cache_root=None, save=False, overwrite=False)

Sample a canonical dataset into a reproducible experimental split.

Returns (result, path). Path is not None if save=True.
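With save=True, the result is also written under the sampling cache and the directory is returned (a sketch reusing ds and plan from the examples above):

res, path = sample(
    ds,
    plan=plan,
    seed=0,
    dataset_fingerprint=str(ds.meta["dataset_fingerprint"]),
    dataset_id="toy",
    save=True,       # persist to the cache location chosen by the library
    overwrite=True,  # replace an existing directory for this fingerprint
)
print(path)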

Source code in src/modssc/sampling/api.py
def sample(
    dataset: Any,
    *,
    plan: SamplingPlan,
    seed: int,
    dataset_fingerprint: str | None = None,
    dataset_id: str | None = None,
    cache_root: Path | None = None,
    save: bool = False,
    overwrite: bool = False,
) -> tuple[SamplingResult, Path | None]:
    """Sample a canonical dataset into a reproducible experimental split.

    Returns (result, path). Path is not None if save=True.
    """
    start = perf_counter()
    ds_fp = _resolve_dataset_fingerprint(dataset, dataset_fingerprint)

    seed_split = derive_seed(seed, "split")
    seed_label = derive_seed(seed, "labeling")
    seed_imb = derive_seed(seed, "imbalance")

    plan_dict = plan.as_dict()
    split_fingerprint = stable_hash(
        {
            "schema_version": SCHEMA_VERSION,
            "dataset_fingerprint": ds_fp,
            "plan": plan_dict,
            "seed": int(seed),
        }
    )

    created_at = datetime.now(timezone.utc).isoformat()

    # detect graph
    is_graph = (
        getattr(getattr(dataset, "train", None), "edges", None) is not None
        or getattr(getattr(dataset, "train", None), "masks", None) is not None
    )
    logger.info(
        "Sampling start: dataset_id=%s dataset_fingerprint=%s seed=%s graph=%s split=%s",
        dataset_id,
        ds_fp,
        seed,
        bool(is_graph),
        plan.split.kind,
    )
    logger.debug("Sampling plan: %s", plan_dict)

    if is_graph:
        result = _sample_graph(
            dataset,
            plan=plan,
            seed_split=seed_split,
            seed_label=seed_label,
            seed_imb=seed_imb,
            dataset_fingerprint=ds_fp,
            split_fingerprint=split_fingerprint,
            created_at=created_at,
            plan_dict=plan_dict,
        )
    else:
        result = _sample_inductive(
            dataset,
            plan=plan,
            seed_split=seed_split,
            seed_label=seed_label,
            seed_imb=seed_imb,
            dataset_fingerprint=ds_fp,
            split_fingerprint=split_fingerprint,
            created_at=created_at,
            plan_dict=plan_dict,
        )

    out_path: Path | None = None
    if save:
        out_path = split_dir_for(
            dataset_fingerprint=ds_fp, split_fingerprint=split_fingerprint, root=cache_root
        )
        save_split(result, out_path, overwrite=overwrite)
    duration = perf_counter() - start
    logger.info(
        "Sampling done: train=%s val=%s test=%s labeled=%s unlabeled=%s duration_s=%.3f",
        int(result.train_idx.shape[0]),
        int(result.val_idx.shape[0]),
        int(result.test_idx.shape[0]),
        int(result.labeled_idx.shape[0]),
        int(result.unlabeled_idx.shape[0]),
        duration,
    )
    logger.debug("Sampling stats: %s", dict(result.stats))
    _warn_on_sampling_stats(result)
    return result, out_path

Sources
  1. src/modssc/sampling/api.py
  2. src/modssc/sampling/storage.py
  3. src/modssc/sampling/plan.py