Source code for src.utils

"""Utility functions shared across dataset classes."""

import csv
import numpy as np
from PIL import Image


[docs] def check_type(value: object, expected_type: type | tuple, name: str) -> None: """Raise TypeError if value is not an instance of expected_type. Args: value: The value to check. expected_type: The expected type or tuple of types. name: The parameter name (used in the error message). Raises: TypeError: If value is not of the expected type. """ if not isinstance(value, expected_type): expected = ( expected_type.__name__ if isinstance(expected_type, type) else " | ".join(t.__name__ for t in expected_type) ) raise TypeError( f"'{name}' must be of type {expected}, got {type(value).__name__}." )
def check_range(value: float, min_val: float, max_val: float, name: str) -> None:
    """Validate that a number lies inside the closed interval [min_val, max_val].

    Args:
        value: The numeric value to check.
        min_val: Lower bound; a value equal to it is accepted.
        max_val: Upper bound; a value equal to it is accepted.
        name: The parameter name, quoted in the error message.

    Raises:
        ValueError: If value is outside the allowed range.
    """
    # Keep the chained comparison: it is False for NaN, so NaN is rejected.
    within_bounds = min_val <= value <= max_val
    if within_bounds:
        return

    message = f"'{name}' must be between {min_val} and {max_val}, got {value}."
    raise ValueError(message)


def parse_labels_csv(labels_file: str) -> dict[str, str]:
    """Parse a CSV file mapping filenames to labels.

    The CSV is expected to have no header row. Each row contains:

    - column 0: filename (basename only, e.g. ``image.jpg``)
    - column 1: label (string; callers may cast to int/float as needed)

    Both fields are stripped of surrounding whitespace and any further
    columns are ignored. If a filename occurs more than once, the label from
    the last such row is kept. A leading UTF-8 byte-order mark, if present,
    is discarded rather than becoming part of the first filename.

    Args:
        labels_file: Absolute or relative path to the CSV file.

    Returns:
        A dict mapping each filename to its label string.

    Raises:
        TypeError: If labels_file is not a string.
        FileNotFoundError: If labels_file does not exist.
        ValueError: If the CSV contains a row with fewer than two columns.
    """
    check_type(labels_file, str, "labels_file")

    mapping: dict[str, str] = {}
    # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading BOM, which
    # would otherwise be glued to the first key as "\ufeff" (str.strip() does
    # not remove it). newline="" is what the csv module requires of its input.
    with open(labels_file, newline="", encoding="utf-8-sig") as fh:
        for row_num, row in enumerate(csv.reader(fh), start=1):
            if len(row) < 2:
                raise ValueError(
                    f"Row {row_num} in '{labels_file}'"
                    " has fewer than 2 columns."
                )
            mapping[row[0].strip()] = row[1].strip()
    return mapping


def load_image(path: str) -> np.ndarray:
    """Read an image file and return its pixels as an RGB numpy array.

    The image is opened with Pillow, converted to the ``"RGB"`` mode, and
    copied into a numpy array before the file is closed.

    Args:
        path: Path to the image file (.jpg, .jpeg, or .png).

    Returns:
        A numpy array of shape (H, W, 3) with dtype uint8.

    Raises:
        TypeError: If path is not a string.
        FileNotFoundError: If the file does not exist.
        OSError: If Pillow cannot open the file.
    """
    check_type(path, str, "path")

    with Image.open(path) as source_image:
        # convert() returns a new, fully loaded image, so the array does not
        # depend on the file handle that the context manager closes.
        pixels = np.array(source_image.convert("RGB"))
    return pixels