# Source code for src.batch_loader

"""BatchLoader: iterator-based batching over any Dataset."""

from __future__ import annotations

import math
import random
from collections.abc import Iterator
from typing import Any

from src.dataset import Dataset
from src.utils import check_type


class BatchLoader:
    """Iterate over a ``Dataset`` in fixed-size batches.

    Only index lists are materialized up front; each sample is fetched
    from the dataset (which may itself be lazy) at the moment its batch
    is yielded.

    Args:
        dataset: Any ``Dataset`` instance (labeled or unlabeled).
        batch_size: Number of samples per batch.
        shuffle: If ``True``, indices are randomly shuffled before
            batching at the start of every iteration. If ``False``,
            data is returned in its original order.
        drop_last: If ``True``, the last batch is discarded when it
            contains fewer than ``batch_size`` samples. If ``False``
            (default), the last batch is kept even if it is smaller.

    Raises:
        TypeError: If any argument has an unexpected type.
        ValueError: If ``batch_size`` is less than 1.

    Example::

        loader = BatchLoader(dataset, batch_size=32, shuffle=True)
        for batch in loader:
            # batch is a list of (data, label) tuples or data items
            ...
    """

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        shuffle: bool = False,
        drop_last: bool = False,
    ) -> None:
        # Validate types first (same order callers may rely on for the
        # TypeError they receive), then the value constraint.
        check_type(dataset, Dataset, "dataset")
        check_type(batch_size, int, "batch_size")
        check_type(shuffle, bool, "shuffle")
        check_type(drop_last, bool, "drop_last")
        if batch_size < 1:
            raise ValueError(
                f"'batch_size' must be at least 1, got {batch_size}.")

        self._dataset: Dataset = dataset
        self._batch_size: int = batch_size
        self._shuffle: bool = shuffle
        self._drop_last: bool = drop_last

    # ------------------------------------------------------------------
    # Iterator protocol
    # ------------------------------------------------------------------
    def __iter__(self) -> Iterator[list[Any]]:
        """Yield batches of data, loading each item only when reached.

        Yields:
            A list of items returned by ``dataset[i]`` — either raw data
            (unlabeled) or ``(data, label)`` tuples (labeled).
        """
        order = list(range(len(self._dataset)))
        if self._shuffle:
            # Fresh permutation on every pass through the loader.
            random.shuffle(order)
        step = self._batch_size
        for offset in range(0, len(order), step):
            chunk = order[offset:offset + step]
            if self._drop_last and len(chunk) < step:
                # Trailing partial batch is discarded when requested.
                break
            yield [self._dataset[idx] for idx in chunk]

    # ------------------------------------------------------------------
    # Sized protocol
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Return the number of batches produced per iteration.

        Accounts for ``drop_last``: if the last batch is incomplete and
        ``drop_last`` is ``True``, it is not counted.
        """
        total = len(self._dataset)
        # Floor division drops an incomplete tail; ceiling keeps it.
        if self._drop_last:
            return total // self._batch_size
        return math.ceil(total / self._batch_size)

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------
    @property
    def dataset(self) -> Dataset:
        """The underlying dataset."""
        return self._dataset

    @property
    def batch_size(self) -> int:
        """Number of samples per batch."""
        return self._batch_size

    @property
    def shuffle(self) -> bool:
        """Whether indices are shuffled at the start of each iteration."""
        return self._shuffle

    @property
    def drop_last(self) -> bool:
        """Whether the last incomplete batch is dropped."""
        return self._drop_last