"""Abstract base classes for the dataset hierarchy.
Hierarchy
---------
Dataset (ABC)
├── LabeledDataset (ABC)
└── UnlabeledDataset (ABC)
Concrete subclasses live in image_dataset.py and audio_dataset.py.
"""
from __future__ import annotations
import random
from abc import ABC, abstractmethod
from typing import Any
from src.utils import check_range, check_type
class Dataset(ABC):
    """Abstract base class for all datasets.

    Subclasses must implement :meth:`_scan_files`, :meth:`_load_file`,
    and :meth:`__getitem__`.

    Attributes (private):
        _root: Root folder path where data files are stored.
        _lazy: Whether to load data lazily (on access) or eagerly (at init).
        _file_paths: Paths discovered by :meth:`_scan_files`.
        _data: Pre-loaded data objects when *eager*; ``None`` when lazy.
    """

    def __init__(self, root: str, lazy: bool = True) -> None:
        """Initialise the dataset.

        Args:
            root: Path to the root folder containing the data files.
            lazy: If ``True`` (default), data is loaded on demand.
                If ``False``, all data is loaded into memory immediately.

        Raises:
            TypeError: If *root* is not a ``str`` or *lazy* is not a ``bool``.
        """
        check_type(root, str, "root")
        check_type(lazy, bool, "lazy")
        self._root: str = root
        self._lazy: bool = lazy
        self._file_paths: list[str] = []
        self._data: list[Any] | None = None
        # Subclass hook: populates self._file_paths.
        self._scan_files()
        if not self._lazy:
            # Eager mode: materialise every data point up front.
            self._data = [self._load_file(p) for p in self._file_paths]

    # ------------------------------------------------------------------
    # Abstract interface
    # ------------------------------------------------------------------
    @abstractmethod
    def _scan_files(self) -> None:
        """Populate ``self._file_paths`` with the paths to every data file."""

    @abstractmethod
    def _load_file(self, path: str) -> Any:
        """Load a single data point from *path* and return it."""

    @abstractmethod
    def __getitem__(self, index: int) -> Any:
        """Return the data point (and label, if applicable) at *index*."""

    # ------------------------------------------------------------------
    # Concrete methods
    # ------------------------------------------------------------------
    def __len__(self) -> int:
        """Return the number of data points in the dataset."""
        return len(self._file_paths)

    def split(self, train_ratio: float) -> tuple[Dataset, Dataset]:
        """Split the dataset into training and test subsets.

        The dataset is shuffled randomly before splitting so that the
        distribution of examples is approximately balanced in both subsets.

        Args:
            train_ratio: Fraction of data points to include in the training
                set. Must be in the open interval (0, 1).

        Returns:
            A ``(train_dataset, test_dataset)`` tuple, each of the same
            concrete class as ``self``.

        Raises:
            TypeError: If *train_ratio* is not a ``float``.
            ValueError: If *train_ratio* is not strictly between 0 and 1.
        """
        check_type(train_ratio, float, "train_ratio")
        check_range(train_ratio, 0.0, 1.0, "train_ratio")
        indices = list(range(len(self)))
        random.shuffle(indices)
        split_idx = int(len(indices) * train_ratio)
        train_indices = indices[:split_idx]
        test_indices = indices[split_idx:]
        return self._create_subset(
            train_indices), self._create_subset(test_indices)

    def _create_subset(self, indices: list[int]) -> Dataset:
        """Create a new dataset instance containing only the specified indices.

        Uses :meth:`object.__new__` to bypass ``__init__`` and then calls
        :meth:`_init_subset` so that subclasses can extend the copy logic.

        Args:
            indices: Positions of the data points to include.

        Returns:
            A new dataset of the same concrete class.
        """
        # __new__ skips __init__ (and thus a redundant disk re-scan); the
        # subset is populated directly from the parent's attributes.
        new_ds: Dataset = object.__new__(type(self))
        self._init_subset(new_ds, indices)
        return new_ds

    def _init_subset(self, new_ds: Dataset, indices: list[int]) -> None:
        """Populate *new_ds* with the data corresponding to *indices*.

        Subclasses that add extra per-item attributes (e.g. ``_labels``)
        should call ``super()._init_subset(new_ds, indices)`` and then
        slice their own attributes accordingly.

        Args:
            new_ds: The freshly-created (uninitialised) dataset object.
            indices: Positions of the data points to copy.
        """
        new_ds._root = self._root
        new_ds._lazy = self._lazy
        new_ds._file_paths = [self._file_paths[i] for i in indices]
        # Only slice the eager cache when it exists; lazy datasets keep None.
        new_ds._data = (
            [self._data[i] for i in indices]
            if self._data is not None else None
        )

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------
    @property
    def root(self) -> str:
        """Root folder path."""
        return self._root

    @property
    def lazy(self) -> bool:
        """Whether the dataset uses lazy loading."""
        return self._lazy
class LabeledDataset(Dataset, ABC):
    """Abstract base class for datasets that carry per-sample labels.

    Adds a ``_labels`` list that is populated by :meth:`_load_labels` and
    kept parallel to ``_file_paths``.

    Concrete subclasses must still implement :meth:`_scan_files` and
    :meth:`_load_file`. They should call ``_load_labels()`` **after**
    ``super().__init__()`` (which calls ``_scan_files``).
    """

    def __init__(self, root: str, lazy: bool = True) -> None:
        # _labels is assigned before super().__init__ so the attribute
        # exists even if the base initialiser triggers subclass code.
        self._labels: list[Any] = []
        super().__init__(root, lazy)
        # _scan_files() has been called; now populate labels.
        self._load_labels()

    # ------------------------------------------------------------------
    # Abstract interface
    # ------------------------------------------------------------------
    @abstractmethod
    def _load_labels(self) -> None:
        """Populate ``self._labels`` with one label per file path."""

    # ------------------------------------------------------------------
    # Concrete implementation of __getitem__
    # ------------------------------------------------------------------
    def __getitem__(self, index: int) -> tuple[Any, Any]:
        """Return ``(data, label)`` for the data point at *index*.

        Args:
            index: Zero-based index of the data point.

        Returns:
            A ``(data, label)`` tuple.

        Raises:
            IndexError: If *index* is out of range.
        """
        # Negative indices are rejected rather than wrapped.
        if index < 0 or index >= len(self):
            raise IndexError(
                f"Index {index} is out of range "
                f"for dataset of size {len(self)}."
            )
        data = (
            self._load_file(self._file_paths[index])
            if self._lazy
            else self._data[index]  # type: ignore[index]
        )
        return data, self._labels[index]

    def _init_subset(self, new_ds: Dataset, indices: list[int]) -> None:
        """Extend the base copy logic to also slice ``_labels``."""
        super()._init_subset(new_ds, indices)
        # new_ds is guaranteed to be a LabeledDataset subclass.
        new_ds._labels = [self._labels[i]
                          for i in indices]  # type: ignore[attr-defined]

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------
    @property
    def labels(self) -> list[Any]:
        """Per-sample labels, parallel to the file-path list."""
        return self._labels
class UnlabeledDataset(Dataset, ABC):
    """Abstract base class for datasets without labels.

    Provides a concrete :meth:`__getitem__` that returns only the data.
    """

    def __getitem__(self, index: int) -> Any:
        """Return the data point at *index* (no label).

        Args:
            index: Zero-based index of the data point.

        Returns:
            The loaded data object.

        Raises:
            IndexError: If *index* is out of range.
        """
        # Negative indices are rejected rather than wrapped.
        if index < 0 or index >= len(self):
            raise IndexError(
                f"Index {index} is out of range "
                f"for dataset of size {len(self)}."
            )
        if self._lazy:
            # Lazy mode: load from disk on every access (no caching).
            return self._load_file(self._file_paths[index])
        return self._data[index]  # type: ignore[index]