# Source code for ursa.catalog.schemas

"""Pydantic schemas for the nine Lance catalog tables.

Design notes:

* ``extra="allow"`` is deliberate. A row read with the model from version N
  can still round-trip unknown columns added in version N+1 — additive Lance
  schema evolution stays a config change, not a code change. Opposite of the
  secrets models in ``constellation-utils`` (``extra="forbid"``).
* ``frozen=True``: declared fields are immutable once constructed. Catalog
  mutations go through writes of new rows. Note: ``__pydantic_extra__`` is a
  regular dict and is **not** frozen — we treat that as a permitted escape
  hatch for ad-hoc extra hydration during ingestion, and lock the behavior
  in a test so a Pydantic upgrade can't change it silently.
* All datetime fields use :data:`UTCDatetime`: aware-only on input,
  normalized to UTC after validation. Tradeoff — original tz offset is
  discarded; rationale is fewer foot-guns when comparing timestamps across
  rows in the lifecycle layer (e.g. ``last_accessed_at`` GC).
* Map-style fields are constrained to scalar values (and lists of scalars).
  Lance/Arrow ``MapType`` requires a single value type per column, and the
  architecture promises "frequently-used keys can be promoted to typed
  columns later." Both promises break with nested dicts or arbitrary objects.
* Time-origin convention: per-recording time fields (``domain_intervals``
  tuple bounds, ``event_time``, ``TimeWindow.start_seconds``) are in the
  recording's native time domain. **Negative values are legal** for
  re-aligned timelines (stimulus-onset-relative events, pre-roll baseline).
  Only the relative invariant ``end > start`` is enforced per interval;
  intervals must additionally be sorted ascending and non-overlapping.
* ID fields use :data:`CatalogID`: non-empty, ``[A-Za-z0-9_-]+``. Stricter
  content-hash / UUID4 enforcement is deferred to a shared util / writer.
* ``EmbeddingRow.vector`` is the highest-cardinality hot path; per-element
  Pydantic validation of a 1024+-dim ``list[float]`` is ~4 ms/row. We type
  the field as ``Any`` and run a single ``BeforeValidator`` that checks
  length and ``v[0]``. Static type degrades to ``Any``; runtime contract
  is preserved.
"""

from __future__ import annotations

from datetime import UTC, datetime, timedelta
from enum import StrEnum
from typing import Annotated, Any, ClassVar

from pydantic import (
    AfterValidator,
    AwareDatetime,
    BaseModel,
    BeforeValidator,
    ConfigDict,
    Field,
    StringConstraints,
    model_validator,
)

# Public API. Kept sorted; includes every exported alias — note ``Vector``
# (the EmbeddingRow.vector annotation) is part of the public surface.
__all__ = [
    "BenchmarkResultRow",
    "BenchmarkSuiteRow",
    "CatalogID",
    "CatalogRow",
    "CheckpointRow",
    "EmbeddingRow",
    "EmbeddingSource",
    "EventRow",
    "ID_PATTERN",
    "IngestionStatus",
    "MetadataDict",
    "MetadataValue",
    "ModalityName",
    "ModalityRow",
    "NonEmptyString",
    "ParticipantRow",
    "RecordingRow",
    "ScalarMetadataValue",
    "StorageFormat",
    "StorageURI",
    "TimeWindow",
    "URI_PATTERN",
    "UTCDatetime",
    "Vector",
    "VirgoAssetRow",
]

# ---------------------------------------------------------------------------
# Datetimes — aware-only, UTC-normalized
# ---------------------------------------------------------------------------


[docs] def _to_utc(dt: datetime) -> datetime: """Normalize an aware datetime to UTC. ``AwareDatetime`` rejects naive inputs upstream, so ``dt.tzinfo`` is guaranteed non-None here. The ``astimezone`` call is a no-op when input is already UTC. """ return dt.astimezone(UTC)
#: Aware datetime, normalized to UTC after validation. Original tz offset is #: discarded — this is intentional, see module docstring. UTCDatetime = Annotated[AwareDatetime, AfterValidator(_to_utc)] # --------------------------------------------------------------------------- # Map-column value constraint # --------------------------------------------------------------------------- #: A scalar permitted in a Map-typed catalog column. ``bool`` is included #: here as a legitimate flag value (contrast with vector elements, where it's #: excluded as a numeric foot-gun). ScalarMetadataValue = str | int | float | bool | None #: A scalar OR a list of scalars. Nested dicts and arbitrary Python objects #: are rejected at the type system level — they break Lance MapType. MetadataValue = ScalarMetadataValue | list[ScalarMetadataValue]
[docs] def _check_homogeneous_lists(value: dict[str, Any]) -> dict[str, Any]: """Reject mixed-type lists in a metadata map. ``MetadataValue`` already type-restricts list elements to scalars, but a list with mixed scalar types (``[1, "two", True]``) still passes. Lance ``MapType`` requires a single Arrow type per column, so we walk the values and reject heterogeneous lists at construction time. Bool is a subclass of int — collapse them to the same bucket so ``[True, False, 0]`` isn't flagged as mixed. """ for key, v in value.items(): if not isinstance(v, list) or len(v) <= 1: continue # Treat bool/int as the same bucket; otherwise type(...) is enough. types = {int if isinstance(item, bool) else type(item) for item in v} if len(types) > 1: raise ValueError( f"metadata[{key!r}] has mixed-type list {v!r}; Lance MapType " "requires homogeneous lists" ) return value
#: A flat, queryable map for a Lance Map-typed column. Values must be a #: single scalar type or a homogeneous list of scalars. MetadataDict = Annotated[dict[str, MetadataValue], AfterValidator(_check_homogeneous_lists)] # --------------------------------------------------------------------------- # IDs and URIs # --------------------------------------------------------------------------- #: Catalog ID character set. Public so consumers can validate IDs without #: instantiating a row (CLI input, ad-hoc scripts). Stricter format checks #: (UUID4 hex, content-hash hex) deferred to a shared util / writer. ID_PATTERN = r"^[A-Za-z0-9_-]+$" CatalogID = Annotated[str, StringConstraints(min_length=1, pattern=ID_PATTERN)] #: Object-store URI scheme constraint. ``{r2, s3, gcs}`` cover ENG-873's #: planned backends; ``{file}`` covers Polaris/local SSD. ``lance://`` is #: deliberately omitted — the architecture doc has no Lance URI form yet, and #: adding a scheme later is a non-breaking change. Public so consumers can #: validate URIs without instantiating a row. URI_PATTERN = r"^(r2|s3|gcs|file)://\S+$" StorageURI = Annotated[str, StringConstraints(min_length=1, pattern=URI_PATTERN)] #: Non-empty short string — pipeline names, version strings, model ids, etc. #: Reuse this everywhere ``Field(min_length=1)`` was inlined; centralizing #: lets a future tightening (max_length, charset) change one place. NonEmptyString = Annotated[str, StringConstraints(min_length=1)] #: A modality name (``eeg``, ``video_webcam``, ``video_webcam_left``). #: Identical type used in :class:`ModalityRow` and :class:`EmbeddingSource`; #: aliasing keeps the join column consistent under future tightening. ModalityName = NonEmptyString # --------------------------------------------------------------------------- # Closed enums # ---------------------------------------------------------------------------
class StorageFormat(StrEnum):
    """How a modality's bytes are laid out in object storage.

    Two tiers:

    **Canonical** — Ursa-managed, produced by ingestion or Virgo. These are
    the permanent storage formats that ``ursa.query`` reads directly.

    **Raw** (``RAW_*``) — data-engine-native segment files, registered during
    Phase 1a ingestion before Virgo has processed them. ``ModalityRow``
    entries with raw formats have ``storage_uri`` pointing at the data-engine
    raw prefix. Virgo promotes raw modalities to canonical formats; old raw
    rows are retired by lifecycle GC.

    See :mod:`ursa.layout` for key conventions.
    """

    # --- Canonical ----------------------------------------------------------
    ZARR = "zarr"  # Zarr array — regular continuous streams (EEG, biometrics)
    LANCE = "lance"  # Lance table — irregular event streams + catalog
    MP4_INDEX = "mp4_index"  # mp4 + Lance frame-index — video
    PARQUET = "parquet"  # Parquet file — event streams before Lance promotion

    # --- Raw (data-engine native) -------------------------------------------
    RAW_BINARY = "raw_binary"  # EEG float64 row-major .bin segment files
    RAW_CSV = "raw_csv"  # CSV streams: gaze, keyboard, mouse, watch, notes, battery
    RAW_JSONL = "raw_jsonl"  # JSONL streams: browser rrweb, location/environment
    RAW_AUDIO = "raw_audio"  # WAV PCM segment files (microphone)
    RAW_VIDEO = "raw_video"  # MKV/H.264 segment files (camera, screen, Pupil Labs)
class IngestionStatus(StrEnum):
    """Per-modality ingestion lifecycle (architecture v0.4 two-store model).

    ``raw`` — modality registered against a cold-bucket raw file (whole
    object addressable via ``ModalityRow.raw_storage_uri``); Virgo's
    ingestion node has not yet run, so ``domain_intervals``,
    ``channel_spec``, and ``format`` may be null. ``storage_uri`` mirrors
    ``raw_storage_uri`` until ingestion completes.

    ``processed`` — Virgo's ingestion node has converted the raw file to a
    canonical format (Zarr / MP4 + Lance frame index). ``format`` is a
    canonical (non-``RAW_*``) value, ``domain_intervals`` and
    ``channel_spec`` are populated, and ``storage_uri`` points at the
    processed object on the hot bucket. ``raw_storage_uri`` is preserved so
    re-ingestion is always possible.
    """

    RAW = "raw"
    PROCESSED = "processed"
# --------------------------------------------------------------------------- # Sub-models # ---------------------------------------------------------------------------
class TimeWindow(BaseModel):
    """Half-open ``[start, end)`` window in the recording's native time domain.

    ``start_seconds`` carries no absolute lower bound: re-aligned timelines
    (stimulus onset, pre-roll baseline) legitimately yield negative times.
    The only enforced invariant is the relative one, ``end > start``.

    Note (deferred benchmark): nested struct columns in Lance have weaker
    filter pushdown than top-level columns. Should query latency on
    ``VirgoAssetRow.time_window`` or ``EmbeddingRow.source.time_window``
    become a bottleneck, the writer can flatten internally and treat this
    Pydantic type as a logical view.
    """

    model_config = ConfigDict(frozen=True, extra="allow")

    start_seconds: float
    end_seconds: float

    @model_validator(mode="after")
    def _end_after_start(self) -> TimeWindow:
        # Guard-clause form: the happy path returns immediately.
        if self.start_seconds < self.end_seconds:
            return self
        raise ValueError(
            f"end_seconds ({self.end_seconds}) must exceed start_seconds ({self.start_seconds})"
        )
class EmbeddingSource(BaseModel):
    """Source of an embedding row: which (recording, modality, window) it covers.

    Same nested-struct caveat as :class:`TimeWindow`.
    """

    model_config = ConfigDict(frozen=True, extra="allow")

    recording_hash: CatalogID
    modality: ModalityName
    time_window: TimeWindow
# --------------------------------------------------------------------------- # Vector hot path # ---------------------------------------------------------------------------
[docs] def _validate_vector_fast(v: Any) -> list[float]: """Fast-path vector validator: skips per-element Pydantic checks (~4 ms/1024-dim). Accepts any sized indexable sequence — list, tuple, numpy.ndarray, torch.Tensor — without importing numpy/torch. Only checks v[0]; heterogeneous tails slip through and are caught by Arrow at write time. Raises ValueError (not TypeError) so Pydantic wraps as ValidationError. """ if isinstance(v, (str, bytes, bytearray)): # str/bytes have __len__ + __getitem__ but obviously aren't vectors. raise ValueError(f"vector must be a numeric sequence, got {type(v).__name__}") try: length = len(v) except TypeError as exc: raise ValueError(f"vector must be a sized sequence, got {type(v).__name__}") from exc if length == 0: raise ValueError("vector must be non-empty") try: sample = v[0] except (TypeError, IndexError, KeyError) as exc: raise ValueError(f"vector must be indexable, got {type(v).__name__}") from exc # bool is a subclass of int — exclude explicitly so [True, False] doesn't # sneak through. Bools are legitimate metadata values; in a numeric vector # they're a foot-gun. if isinstance(sample, bool) or not isinstance(sample, (int, float)): # numpy/torch scalar elements: try .item() to reach a Python scalar. item = getattr(sample, "item", None) sample = item() if callable(item) else sample if isinstance(sample, bool) or not isinstance(sample, (int, float)): raise ValueError( f"vector elements must be numeric (int|float, not bool), " f"got {type(sample).__name__}" ) if isinstance(v, list): return v tolist = getattr(v, "tolist", None) return tolist() if callable(tolist) else list(v)
#: Embedding vector. Static type ``Any`` for perf; runtime contract is #: ``list[float]`` non-empty, no bools. Vector = Annotated[Any, BeforeValidator(_validate_vector_fast)] # --------------------------------------------------------------------------- # Base # ---------------------------------------------------------------------------
class CatalogRow(BaseModel):
    """Base class shared by every Lance catalog table row.

    Each concrete subclass MUST set ``__primary_key__`` to a non-empty tuple
    of its own field names that uniquely identify a row within its table;
    omitting it (or naming an unknown field) raises ``TypeError`` at
    class-definition time. The future Lance writer reads this attribute to
    enforce uniqueness.
    """

    model_config = ConfigDict(
        frozen=True,
        extra="allow",
        populate_by_name=True,
        str_strip_whitespace=True,
    )

    __primary_key__: ClassVar[tuple[str, ...]] = ()

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        super().__pydantic_init_subclass__(**kwargs)
        # Look only at the subclass's own namespace: inheriting the parent's
        # key (or the empty base default) doesn't count as declaring one.
        declared = cls.__dict__.get("__primary_key__")
        if not declared:
            raise TypeError(f"{cls.__name__} must declare a non-empty __primary_key__ ClassVar")
        known_fields = cls.__pydantic_fields__
        for name in declared:
            if name in known_fields:
                continue
            raise TypeError(
                f"{cls.__name__}.__primary_key__ references unknown field {name!r}"
            )
# --------------------------------------------------------------------------- # Row models # ---------------------------------------------------------------------------
class ParticipantRow(CatalogRow):
    """One row per enrolled participant.

    Primary key: ``participant_id`` — any unique catalog ID. By convention a
    short slug like ``p042``; not enforced.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("participant_id",)

    participant_id: CatalogID
    enrolled_at: UTCDatetime
    metadata: MetadataDict = Field(default_factory=dict)
class RecordingRow(CatalogRow):
    """One row per recording.

    Primary key: ``recording_hash`` — any unique catalog ID; content-hash
    convention enforced later via a shared util.

    ``participant_ids`` is a list (architecture v0.4): a recording can cover
    multiple participants (multi-subject experiments, dyad sessions, crowd
    recordings). Single-participant recordings carry a one-element list. The
    list must be non-empty.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("recording_hash",)

    recording_hash: CatalogID
    participant_ids: list[CatalogID] = Field(min_length=1)
    start_time: UTCDatetime
    duration: timedelta = Field(gt=timedelta(0))
    #: Per-rig hardware manifest as a Lance MapType column: rig name, build /
    #: firmware versions, sensor serial numbers. Same scalar / homogeneous-list
    #: constraint as ``metadata`` — encode per-channel structure as parallel
    #: lists, not list-of-dicts.
    device_info: MetadataDict = Field(default_factory=dict)
    metadata: MetadataDict = Field(default_factory=dict)
class ModalityRow(CatalogRow):
    """One row per stream within a recording.

    Primary key: composite ``(recording_hash, modality)``. The Lance writer
    will enforce uniqueness; this row schema only declares the contract.

    Architecture v0.4 splits a modality's lifecycle into two states tracked
    by :attr:`ingestion_status`:

    * ``raw`` — registered against a cold-bucket raw file. ``format`` may be
      a ``RAW_*`` value or null (unknown at registration time);
      ``domain_intervals`` and ``channel_spec`` are typically null.
      ``storage_uri`` mirrors :attr:`raw_storage_uri` (the immutable
      cold-bucket pointer).
    * ``processed`` — Virgo's ingestion node has converted the raw file to a
      canonical format. ``format`` must be a non-``RAW_*`` value;
      ``domain_intervals`` and ``channel_spec`` must be populated;
      ``storage_uri`` points at the processed object on the hot bucket.
      :attr:`raw_storage_uri` is preserved across the transition so
      re-ingestion is always possible.

    :attr:`storage_uri` always points to the *current authoritative* object
    (raw URI initially, swapped to Zarr/MP4 after Virgo ingestion);
    :attr:`raw_storage_uri` is the permanent cold-bucket pointer.

    :attr:`domain_intervals` is a list of ``(start, end)`` tuples in the
    recording's native time domain, handling non-continuous recordings,
    irregular series, and gaps. Each interval enforces ``end > start`` and
    intervals must be sorted ascending and non-overlapping. Null while
    ``ingestion_status="raw"`` (Virgo populates it during ingestion using
    ``temporaldata`` domain-computation utilities).

    ``channel_spec`` uses :data:`MetadataDict` — per-channel structured
    metadata (e.g. polarity, reference) should be encoded as parallel lists
    (``{"channel_names": [...], "polarity": [...], "reference": [...]}``),
    not list-of-dicts; that's required for Lance Map queryability.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("recording_hash", "modality")

    recording_hash: CatalogID
    modality: ModalityName
    ingestion_status: IngestionStatus = IngestionStatus.RAW
    storage_uri: StorageURI
    raw_storage_uri: StorageURI
    format: StorageFormat | None = None
    sampling_rate: float | None = Field(default=None, gt=0)
    domain_intervals: list[tuple[float, float]] | None = None
    channel_spec: MetadataDict | None = None
    metadata: MetadataDict = Field(default_factory=dict)

    @model_validator(mode="after")
    def _validate_domain_intervals(self) -> ModalityRow:
        intervals = self.domain_intervals
        if intervals is None:
            return self
        # Single pass: each interval must be well-formed (end > start) and
        # must not start before the previous interval ended (half-open
        # intervals, so touching endpoints are fine).
        previous_end: float | None = None
        for idx, (start, end) in enumerate(intervals):
            if end <= start:
                raise ValueError(
                    f"domain_intervals[{idx}] has end ({end}) <= start ({start}); "
                    "each interval must satisfy end > start"
                )
            if previous_end is not None and start < previous_end:
                raise ValueError(
                    f"domain_intervals[{idx}] starts at {start} before previous "
                    f"interval ended at {previous_end}; intervals must be sorted "
                    "ascending and non-overlapping"
                )
            previous_end = end
        return self

    @model_validator(mode="after")
    def _validate_ingestion_status_coherence(self) -> ModalityRow:
        # raw: format may be RAW_* or null; domain_intervals/channel_spec may be null.
        # processed: format must be canonical (non-RAW_*); domain_intervals
        # must be non-empty; channel_spec must be populated (may be empty dict).
        fmt = self.format
        is_raw_format = fmt is not None and fmt.value.startswith("raw_")
        if self.ingestion_status is IngestionStatus.RAW:
            if fmt is not None and not is_raw_format:
                raise ValueError(
                    f"ingestion_status=raw requires format to be a RAW_* value or null, "
                    f"got {fmt.name}"
                )
            return self
        # processed
        if fmt is None or is_raw_format:
            raise ValueError(
                "ingestion_status=processed requires a canonical format "
                "(ZARR, LANCE, MP4_INDEX, PARQUET); "
                f"got {fmt.name if fmt is not None else 'None'}"
            )
        if not self.domain_intervals:
            raise ValueError(
                "ingestion_status=processed requires domain_intervals to be a "
                "non-empty list (Virgo's ingestion node populates this)"
            )
        if self.channel_spec is None:
            raise ValueError(
                "ingestion_status=processed requires channel_spec to be populated "
                "(may be an empty dict for modalities without per-channel structure)"
            )
        return self
class EventRow(CatalogRow):
    """System prompts, user responses, and any time-stamped event.

    Primary key: ``event_id`` — any unique catalog ID.

    ``event_time`` is in the recording's native time domain; negative values
    are legal for re-aligned timelines or pre-roll events.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("event_id",)

    event_id: CatalogID
    recording_hash: CatalogID
    event_time: float
    event_type: NonEmptyString
    prompt: str | None = None
    response: str | None = None
    metadata: MetadataDict = Field(default_factory=dict)
class VirgoAssetRow(CatalogRow):
    """One row per Virgo output, with full provenance.

    Primary key: ``asset_id`` — any unique catalog ID.

    ``last_accessed_at`` is updated by Ursa's query layer on every read and
    drives lifecycle GC (architecture doc §3.6).
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("asset_id",)

    asset_id: CatalogID
    recording_hash: CatalogID
    pipeline_name: NonEmptyString
    pipeline_version: NonEmptyString
    #: Three-axis content hash (data + code + config). Same character set as
    #: any other catalog ID — pipeline authors use the standard hash util.
    cache_key: CatalogID
    code_version: NonEmptyString
    #: Canonical Pydantic-dump hash of the producing transform's config.
    config_hash: CatalogID
    time_window: TimeWindow
    created_at: UTCDatetime
    last_accessed_at: UTCDatetime
    storage_uri: StorageURI
class CheckpointRow(CatalogRow):
    """One row per Orion model checkpoint.

    Primary key: ``checkpoint_id`` — any unique catalog ID.

    ``run_id`` is an opaque ClearML task ID; runs are tracked in ClearML's
    own database, not in the Ursa catalog. ``parent_checkpoint_id`` records
    the checkpoint this one was resumed from, enabling resume chains and
    full lineage traversal. The full list of recordings consumed lives at
    ``storage_uri/data_hashes/manifest.json``.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("checkpoint_id",)

    checkpoint_id: CatalogID
    run_id: CatalogID
    step: int = Field(ge=0)
    model_id: NonEmptyString
    code_version: NonEmptyString
    storage_uri: StorageURI
    created_at: UTCDatetime
    parent_checkpoint_id: CatalogID | None = None
    metadata: MetadataDict = Field(default_factory=dict)
class BenchmarkSuiteRow(CatalogRow):
    """One row per versioned benchmark suite configuration.

    Primary key: composite ``(suite_name, suite_version)``, matching the
    identity used by :class:`BenchmarkResultRow` as its FK. Suite configs
    are standalone — no FK to recordings or participants. ``storage_uri``
    points to the held-out query spec and metric definitions on R2.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("suite_name", "suite_version")

    suite_name: NonEmptyString
    suite_version: int = Field(gt=0)
    storage_uri: StorageURI
    created_at: UTCDatetime
    metadata: MetadataDict = Field(default_factory=dict)
class BenchmarkResultRow(CatalogRow):
    """One row per benchmark evaluation result.

    Primary key: ``result_id`` — content-addressed hash of the six identity
    fields ``(suite_name, suite_version, checkpoint_id, dataset_hash,
    partial_subset, partial_seed)``. The six fields are stored as queryable
    columns so callers can look up results without pre-computing the hash.

    ``dataset_hash`` is an opaque string; computation convention is a future
    feature. ``partial_subset`` ∈ (0, 1] and ``partial_seed`` distinguish
    full evals from in-training partial benchmarks (see architecture §5.7).
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("result_id",)

    result_id: CatalogID
    suite_name: NonEmptyString
    suite_version: int = Field(gt=0)
    checkpoint_id: CatalogID
    dataset_hash: CatalogID
    partial_subset: float = Field(default=1.0, gt=0.0, le=1.0)
    partial_seed: int | None = None
    storage_uri: StorageURI
    computed_at: UTCDatetime
    metadata: MetadataDict = Field(default_factory=dict)
class EmbeddingRow(CatalogRow):
    """Vector embedding over a (recording, modality, window) tuple.

    Primary key: ``embedding_id`` — any unique catalog ID.

    ``vector`` is typed as ``Any`` for runtime performance; runtime contract
    is non-empty list[float] (no bools), enforced by
    :func:`_validate_vector_fast`. Per-model fixed-dim enforcement lives in
    the Lance writer, materialized as ``FixedSizeList[float, dim]`` keyed by
    ``model_id``.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("embedding_id",)

    embedding_id: CatalogID
    source: EmbeddingSource
    vector: Vector
    model_id: NonEmptyString