"""Pydantic schemas for the nine Lance catalog tables.
Design notes:
* ``extra="allow"`` is deliberate. A row read with the model from version N
can still round-trip unknown columns added in version N+1 — additive Lance
schema evolution stays a config change, not a code change. Opposite of the
secrets models in ``constellation-utils`` (``extra="forbid"``).
* ``frozen=True``: declared fields are immutable once constructed. Catalog
mutations go through writes of new rows. Note: ``__pydantic_extra__`` is a
regular dict and is **not** frozen — we treat that as a permitted escape
hatch for ad-hoc extra hydration during ingestion, and lock the behavior
in a test so a Pydantic upgrade can't change it silently.
* All datetime fields use :data:`UTCDatetime`: aware-only on input,
normalized to UTC after validation. Tradeoff — original tz offset is
discarded; rationale is fewer foot-guns when comparing timestamps across
rows in the lifecycle layer (e.g. ``last_accessed_at`` GC).
* Map-style fields are constrained to scalar values (and lists of scalars).
Lance/Arrow ``MapType`` requires a single value type per column, and the
architecture promises "frequently-used keys can be promoted to typed
columns later." Both promises break with nested dicts or arbitrary objects.
* Time-origin convention: per-recording time fields (``domain_intervals``
tuple bounds, ``event_time``, ``TimeWindow.start_seconds``) are in the
recording's native time domain. **Negative values are legal** for
re-aligned timelines (stimulus-onset-relative events, pre-roll baseline).
Only the relative invariant ``end > start`` is enforced per interval;
intervals must additionally be sorted ascending and non-overlapping.
* ID fields use :data:`CatalogID`: non-empty, ``[A-Za-z0-9_-]+``. Stricter
content-hash / UUID4 enforcement is deferred to a shared util / writer.
* ``EmbeddingRow.vector`` is the highest-cardinality hot path; per-element
Pydantic validation of a 1024+-dim ``list[float]`` is ~4 ms/row. We type
the field as ``Any`` and run a single ``BeforeValidator`` that checks
length and ``v[0]``. Static type degrades to ``Any``; runtime contract
is preserved.
"""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from typing import Annotated, Any, ClassVar
from pydantic import (
AfterValidator,
AwareDatetime,
BaseModel,
BeforeValidator,
ConfigDict,
Field,
StringConstraints,
model_validator,
)
# Public API surface. Kept sorted case-insensitively; note StorageFormat is
# re-exported here but defined elsewhere in this module (enum of storage
# formats referenced by ModalityRow.format).
__all__ = [
    "BenchmarkResultRow",
    "BenchmarkSuiteRow",
    "CatalogID",
    "CatalogRow",
    "CheckpointRow",
    "EmbeddingRow",
    "EmbeddingSource",
    "EventRow",
    "ID_PATTERN",
    "IngestionStatus",
    "MetadataDict",
    "MetadataValue",
    "ModalityName",
    "ModalityRow",
    "NonEmptyString",
    "ParticipantRow",
    "RecordingRow",
    "ScalarMetadataValue",
    "StorageFormat",
    "StorageURI",
    "TimeWindow",
    "URI_PATTERN",
    "UTCDatetime",
    "VirgoAssetRow",
]
# ---------------------------------------------------------------------------
# Datetimes — aware-only, UTC-normalized
# ---------------------------------------------------------------------------
[docs]
def _to_utc(dt: datetime) -> datetime:
"""Normalize an aware datetime to UTC.
``AwareDatetime`` rejects naive inputs upstream, so ``dt.tzinfo`` is
guaranteed non-None here. The ``astimezone`` call is a no-op when input
is already UTC.
"""
return dt.astimezone(UTC)
#: Aware datetime, normalized to UTC after validation. Naive datetimes are
#: rejected by ``AwareDatetime`` before ``_to_utc`` runs. The original tz
#: offset is discarded — this is intentional, see module docstring.
UTCDatetime = Annotated[AwareDatetime, AfterValidator(_to_utc)]
# ---------------------------------------------------------------------------
# Map-column value constraint
# ---------------------------------------------------------------------------
#: A scalar permitted in a Map-typed catalog column. ``bool`` is included
#: here as a legitimate flag value (contrast with vector elements, where it's
#: excluded as a numeric foot-gun).
ScalarMetadataValue = str | int | float | bool | None
#: A scalar OR a list of scalars. Nested dicts and arbitrary Python objects
#: are rejected at the type system level — they break Lance MapType.
#: Mixed-type lists still pass this union; homogeneity is enforced separately
#: by the MetadataDict validator below.
MetadataValue = ScalarMetadataValue | list[ScalarMetadataValue]
[docs]
def _check_homogeneous_lists(value: dict[str, Any]) -> dict[str, Any]:
"""Reject mixed-type lists in a metadata map.
``MetadataValue`` already type-restricts list elements to scalars, but a
list with mixed scalar types (``[1, "two", True]``) still passes. Lance
``MapType`` requires a single Arrow type per column, so we walk the
values and reject heterogeneous lists at construction time. Bool is a
subclass of int — collapse them to the same bucket so ``[True, False, 0]``
isn't flagged as mixed.
"""
for key, v in value.items():
if not isinstance(v, list) or len(v) <= 1:
continue
# Treat bool/int as the same bucket; otherwise type(...) is enough.
types = {int if isinstance(item, bool) else type(item) for item in v}
if len(types) > 1:
raise ValueError(
f"metadata[{key!r}] has mixed-type list {v!r}; Lance MapType "
"requires homogeneous lists"
)
return value
#: A flat, queryable map for a Lance Map-typed column. Values must be a
#: single scalar type or a homogeneous list of scalars (enforced by
#: ``_check_homogeneous_lists`` after the union type check).
MetadataDict = Annotated[dict[str, MetadataValue], AfterValidator(_check_homogeneous_lists)]
# ---------------------------------------------------------------------------
# IDs and URIs
# ---------------------------------------------------------------------------
#: Catalog ID character set. Public so consumers can validate IDs without
#: instantiating a row (CLI input, ad-hoc scripts). Stricter format checks
#: (UUID4 hex, content-hash hex) deferred to a shared util / writer.
ID_PATTERN = r"^[A-Za-z0-9_-]+$"
CatalogID = Annotated[str, StringConstraints(min_length=1, pattern=ID_PATTERN)]
#: Object-store URI scheme constraint. ``{r2, s3, gcs}`` cover ENG-873's
#: planned backends; ``{file}`` covers Polaris/local SSD. ``lance://`` is
#: deliberately omitted — the architecture doc has no Lance URI form yet, and
#: adding a scheme later is a non-breaking change. Public so consumers can
#: validate URIs without instantiating a row.
#: NOTE(review): the ``\S+`` path part accepts any non-whitespace run;
#: presumably deeper path/bucket validation belongs to the writer — confirm.
URI_PATTERN = r"^(r2|s3|gcs|file)://\S+$"
StorageURI = Annotated[str, StringConstraints(min_length=1, pattern=URI_PATTERN)]
#: Non-empty short string — pipeline names, version strings, model ids, etc.
#: Reuse this everywhere ``Field(min_length=1)`` was inlined; centralizing
#: lets a future tightening (max_length, charset) change one place.
NonEmptyString = Annotated[str, StringConstraints(min_length=1)]
#: A modality name (``eeg``, ``video_webcam``, ``video_webcam_left``).
#: Identical type used in :class:`ModalityRow` and :class:`EmbeddingSource`;
#: aliasing keeps the join column consistent under future tightening.
ModalityName = NonEmptyString
# ---------------------------------------------------------------------------
# Closed enums
# ---------------------------------------------------------------------------
[docs]
class IngestionStatus(StrEnum):
"""Per-modality ingestion lifecycle (architecture v0.4 two-store model).
``raw`` — modality registered against a cold-bucket raw file (whole
object addressable via ``ModalityRow.raw_storage_uri``); Virgo's
ingestion node has not yet run, so ``domain_intervals``,
``channel_spec``, and ``format`` may be null. ``storage_uri`` mirrors
``raw_storage_uri`` until ingestion completes.
``processed`` — Virgo's ingestion node has converted the raw file to
a canonical format (Zarr / MP4 + Lance frame index). ``format`` is a
canonical (non-``RAW_*``) value, ``domain_intervals`` and
``channel_spec`` are populated, and ``storage_uri`` points at the
processed object on the hot bucket. ``raw_storage_uri`` is preserved
so re-ingestion is always possible.
"""
RAW = "raw"
PROCESSED = "processed"
# ---------------------------------------------------------------------------
# Sub-models
# ---------------------------------------------------------------------------
[docs]
class TimeWindow(BaseModel):
    """Half-open ``[start, end)`` window in the recording's native time domain.

    ``start_seconds`` carries no absolute lower bound: re-aligned timelines
    (stimulus onset, pre-roll baseline) legitimately produce negative
    times. The only enforced invariant is the relative one, ``end > start``.

    Deferred benchmark note: Lance filter pushdown is weaker on nested
    struct columns than on top-level ones. Should queries over
    ``VirgoAssetRow.time_window`` or ``EmbeddingRow.source.time_window``
    become a bottleneck, the writer can flatten internally and keep this
    Pydantic type as a purely logical view.
    """

    model_config = ConfigDict(frozen=True, extra="allow")

    start_seconds: float
    end_seconds: float

    @model_validator(mode="after")
    def _end_after_start(self) -> TimeWindow:
        # Keep the comparison in this exact (<=) form: rewriting it as
        # "start < end" would flip behavior when either bound is NaN.
        if self.end_seconds <= self.start_seconds:
            raise ValueError(
                f"end_seconds ({self.end_seconds}) must exceed start_seconds ({self.start_seconds})"
            )
        return self
[docs]
class EmbeddingSource(BaseModel):
    """Source of an embedding row: which (recording, modality, window) it
    covers. Same nested-struct caveat as :class:`TimeWindow`.
    """

    model_config = ConfigDict(frozen=True, extra="allow")

    # References RecordingRow.recording_hash.
    recording_hash: CatalogID
    # Same alias as ModalityRow.modality — the shared join column.
    modality: ModalityName
    # Covered span, in the recording's native time domain.
    time_window: TimeWindow
# ---------------------------------------------------------------------------
# Vector hot path
# ---------------------------------------------------------------------------
[docs]
def _validate_vector_fast(v: Any) -> list[float]:
"""Fast-path vector validator: skips per-element Pydantic checks (~4 ms/1024-dim).
Accepts any sized indexable sequence — list, tuple, numpy.ndarray, torch.Tensor —
without importing numpy/torch. Only checks v[0]; heterogeneous tails slip through
and are caught by Arrow at write time. Raises ValueError (not TypeError) so Pydantic
wraps as ValidationError.
"""
if isinstance(v, (str, bytes, bytearray)):
# str/bytes have __len__ + __getitem__ but obviously aren't vectors.
raise ValueError(f"vector must be a numeric sequence, got {type(v).__name__}")
try:
length = len(v)
except TypeError as exc:
raise ValueError(f"vector must be a sized sequence, got {type(v).__name__}") from exc
if length == 0:
raise ValueError("vector must be non-empty")
try:
sample = v[0]
except (TypeError, IndexError, KeyError) as exc:
raise ValueError(f"vector must be indexable, got {type(v).__name__}") from exc
# bool is a subclass of int — exclude explicitly so [True, False] doesn't
# sneak through. Bools are legitimate metadata values; in a numeric vector
# they're a foot-gun.
if isinstance(sample, bool) or not isinstance(sample, (int, float)):
# numpy/torch scalar elements: try .item() to reach a Python scalar.
item = getattr(sample, "item", None)
sample = item() if callable(item) else sample
if isinstance(sample, bool) or not isinstance(sample, (int, float)):
raise ValueError(
f"vector elements must be numeric (int|float, not bool), "
f"got {type(sample).__name__}"
)
if isinstance(v, list):
return v
tolist = getattr(v, "tolist", None)
return tolist() if callable(tolist) else list(v)
#: Embedding vector. Static type ``Any`` for perf; runtime contract is
#: a non-empty ``list`` of numerics (no bools), enforced up front by
#: ``_validate_vector_fast`` instead of per-element Pydantic validation.
Vector = Annotated[Any, BeforeValidator(_validate_vector_fast)]
# ---------------------------------------------------------------------------
# Base
# ---------------------------------------------------------------------------
[docs]
class CatalogRow(BaseModel):
    """Base for all Lance catalog table rows.

    Every concrete subclass MUST declare ``__primary_key__``: a non-empty
    tuple of its own field names that uniquely identify a row within its
    table. A missing declaration — or one naming an unknown field — raises
    ``TypeError`` the moment the subclass is defined. The future Lance
    writer reads this attribute to enforce uniqueness.
    """

    model_config = ConfigDict(
        frozen=True,
        extra="allow",
        populate_by_name=True,
        str_strip_whitespace=True,
    )

    __primary_key__: ClassVar[tuple[str, ...]] = ()

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
        super().__pydantic_init_subclass__(**kwargs)
        # Read cls.__dict__ directly (not attribute lookup) so an inherited
        # tuple can never satisfy the requirement — each subclass must
        # declare its own primary key.
        pk = cls.__dict__.get("__primary_key__")
        if not pk:
            raise TypeError(f"{cls.__name__} must declare a non-empty __primary_key__ ClassVar")
        for field_name in pk:
            if field_name in cls.__pydantic_fields__:
                continue
            raise TypeError(
                f"{cls.__name__}.__primary_key__ references unknown field {field_name!r}"
            )
# ---------------------------------------------------------------------------
# Row models
# ---------------------------------------------------------------------------
[docs]
class ParticipantRow(CatalogRow):
    """One row per enrolled participant.

    Primary key: ``participant_id`` — any unique catalog ID. By convention a
    short slug like ``p042``; not enforced.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("participant_id",)

    participant_id: CatalogID
    # Enrollment timestamp; aware on input, normalized to UTC.
    enrolled_at: UTCDatetime
    # Flat scalar / homogeneous-list map (Lance MapType constraint).
    metadata: MetadataDict = Field(default_factory=dict)
[docs]
class RecordingRow(CatalogRow):
    """One row per recording.

    Primary key: ``recording_hash`` — any unique catalog ID; content-hash
    convention enforced later via a shared util.

    ``participant_ids`` is a list (architecture v0.4): a recording can
    cover multiple participants (multi-subject experiments, dyad sessions,
    crowd recordings). Single-participant recordings carry a one-element
    list. The list must be non-empty.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("recording_hash",)

    recording_hash: CatalogID
    # Non-empty; one entry per participant covered by the recording.
    participant_ids: list[CatalogID] = Field(min_length=1)
    # Wall-clock recording start; aware on input, normalized to UTC.
    start_time: UTCDatetime
    # Must be strictly positive.
    duration: timedelta = Field(gt=timedelta(0))
    #: Per-rig hardware manifest as a Lance MapType column: rig name, build /
    #: firmware versions, sensor serial numbers. Same scalar / homogeneous-list
    #: constraint as ``metadata`` — encode per-channel structure as parallel
    #: lists, not list-of-dicts.
    device_info: MetadataDict = Field(default_factory=dict)
    metadata: MetadataDict = Field(default_factory=dict)
[docs]
class ModalityRow(CatalogRow):
    """One row per stream within a recording.

    Primary key: composite ``(recording_hash, modality)``. The Lance writer
    will enforce uniqueness; this row schema only declares the contract.

    Architecture v0.4 splits a modality's lifecycle into two states tracked
    by :attr:`ingestion_status`:

    * ``raw`` — registered against a cold-bucket raw file. ``format`` may
      be a ``RAW_*`` value or null (unknown at registration time);
      ``domain_intervals`` and ``channel_spec`` are typically null.
      ``storage_uri`` mirrors :attr:`raw_storage_uri` (the immutable
      cold-bucket pointer).
    * ``processed`` — Virgo's ingestion node has converted the raw file to
      a canonical format. ``format`` must be a non-``RAW_*`` value;
      ``domain_intervals`` and ``channel_spec`` must be populated;
      ``storage_uri`` points at the processed object on the hot bucket.
      :attr:`raw_storage_uri` is preserved across the transition so
      re-ingestion is always possible.

    :attr:`storage_uri` always points to the *current authoritative* object
    (raw URI initially, swapped to Zarr/MP4 after Virgo ingestion).
    :attr:`raw_storage_uri` is the permanent cold-bucket pointer.

    :attr:`domain_intervals` is a list of ``(start, end)`` tuples in the
    recording's native time domain, handling non-continuous recordings,
    irregular series, and gaps. Each interval enforces ``end > start`` and
    intervals must be sorted ascending and non-overlapping. Null while
    ``ingestion_status="raw"`` (Virgo populates it during ingestion using
    ``temporaldata`` domain-computation utilities).

    ``channel_spec`` uses :data:`MetadataDict` — per-channel structured
    metadata (e.g. polarity, reference) should be encoded as parallel lists
    (``{"channel_names": [...], "polarity": [...], "reference": [...]}``),
    not list-of-dicts; that's required for Lance Map queryability.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("recording_hash", "modality")

    recording_hash: CatalogID
    modality: ModalityName
    ingestion_status: IngestionStatus = IngestionStatus.RAW
    storage_uri: StorageURI
    raw_storage_uri: StorageURI
    # StorageFormat is defined elsewhere in this module; RAW_* members mark
    # unprocessed formats (see _validate_ingestion_status_coherence).
    format: StorageFormat | None = None
    sampling_rate: float | None = Field(default=None, gt=0)
    domain_intervals: list[tuple[float, float]] | None = None
    channel_spec: MetadataDict | None = None
    metadata: MetadataDict = Field(default_factory=dict)

    @model_validator(mode="after")
    def _validate_domain_intervals(self) -> ModalityRow:
        # Null is legal pre-ingestion; coherence with ingestion_status is
        # enforced by the next validator, not here.
        if self.domain_intervals is None:
            return self
        prev_end: float | None = None
        for idx, interval in enumerate(self.domain_intervals):
            start, end = interval
            if end <= start:
                raise ValueError(
                    f"domain_intervals[{idx}] has end ({end}) <= start ({start}); "
                    "each interval must satisfy end > start"
                )
            # start == prev_end is permitted: half-open intervals may touch.
            if prev_end is not None and start < prev_end:
                raise ValueError(
                    f"domain_intervals[{idx}] starts at {start} before previous "
                    f"interval ended at {prev_end}; intervals must be sorted "
                    "ascending and non-overlapping"
                )
            prev_end = end
        return self

    @model_validator(mode="after")
    def _validate_ingestion_status_coherence(self) -> ModalityRow:
        # raw: format may be RAW_* or null; domain_intervals/channel_spec may be null.
        # processed: format must be canonical (non-RAW_*); domain_intervals
        # must be non-empty; channel_spec must be populated (may be empty dict).
        if self.ingestion_status is IngestionStatus.RAW:
            if self.format is not None and not self.format.value.startswith("raw_"):
                raise ValueError(
                    f"ingestion_status=raw requires format to be a RAW_* value or null, "
                    f"got {self.format.name}"
                )
            return self
        # processed
        if self.format is None or self.format.value.startswith("raw_"):
            raise ValueError(
                "ingestion_status=processed requires a canonical format "
                "(ZARR, LANCE, MP4_INDEX, PARQUET); "
                f"got {self.format.name if self.format is not None else 'None'}"
            )
        if not self.domain_intervals:
            raise ValueError(
                "ingestion_status=processed requires domain_intervals to be a "
                "non-empty list (Virgo's ingestion node populates this)"
            )
        if self.channel_spec is None:
            raise ValueError(
                "ingestion_status=processed requires channel_spec to be populated "
                "(may be an empty dict for modalities without per-channel structure)"
            )
        return self
[docs]
class EventRow(CatalogRow):
    """System prompts, user responses, and any time-stamped event.

    Primary key: ``event_id`` — any unique catalog ID.

    ``event_time`` is in the recording's native time domain; negative values
    are legal for re-aligned timelines or pre-roll events.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("event_id",)

    event_id: CatalogID
    # References RecordingRow.recording_hash.
    recording_hash: CatalogID
    # Seconds in the recording's native time domain; may be negative.
    event_time: float
    # Free-form, non-empty event category.
    event_type: NonEmptyString
    prompt: str | None = None
    response: str | None = None
    metadata: MetadataDict = Field(default_factory=dict)
[docs]
class VirgoAssetRow(CatalogRow):
    """One row per Virgo output, with full provenance.

    Primary key: ``asset_id`` — any unique catalog ID.

    ``last_accessed_at`` is updated by Ursa's query layer on every read and
    drives lifecycle GC (architecture doc §3.6).
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("asset_id",)

    asset_id: CatalogID
    recording_hash: CatalogID
    pipeline_name: NonEmptyString
    pipeline_version: NonEmptyString
    #: Three-axis content hash (data + code + config). Same character set as
    #: any other catalog ID — pipeline authors use the standard hash util.
    cache_key: CatalogID
    code_version: NonEmptyString
    #: Canonical Pydantic-dump hash of the producing transform's config.
    config_hash: CatalogID
    # Span of the source recording this asset covers.
    time_window: TimeWindow
    created_at: UTCDatetime
    # Touched on every read; input to lifecycle GC.
    last_accessed_at: UTCDatetime
    storage_uri: StorageURI
[docs]
class CheckpointRow(CatalogRow):
    """One row per Orion model checkpoint.

    Primary key: ``checkpoint_id`` — any unique catalog ID.

    ``run_id`` is an opaque ClearML task ID; runs are tracked in ClearML's
    own database, not in the Ursa catalog. ``parent_checkpoint_id`` records
    the checkpoint this one was resumed from, enabling resume chains and full
    lineage traversal. The full list of recordings consumed lives at
    ``storage_uri/data_hashes/manifest.json``.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("checkpoint_id",)

    checkpoint_id: CatalogID
    # Opaque ClearML task ID (see class docstring).
    run_id: CatalogID
    # Training step at which the checkpoint was taken; non-negative.
    step: int = Field(ge=0)
    model_id: NonEmptyString
    code_version: NonEmptyString
    storage_uri: StorageURI
    created_at: UTCDatetime
    # Null for the first checkpoint of a fresh run.
    parent_checkpoint_id: CatalogID | None = None
    metadata: MetadataDict = Field(default_factory=dict)
[docs]
class BenchmarkSuiteRow(CatalogRow):
    """One row per versioned benchmark suite configuration.

    Primary key: composite ``(suite_name, suite_version)``, matching the
    identity used by :class:`BenchmarkResultRow` as its FK. Suite configs are
    standalone — no FK to recordings or participants. ``storage_uri`` points
    to the held-out query spec and metric definitions on R2.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("suite_name", "suite_version")

    suite_name: NonEmptyString
    # Monotonically versioned per suite_name; must be positive.
    suite_version: int = Field(gt=0)
    storage_uri: StorageURI
    created_at: UTCDatetime
    metadata: MetadataDict = Field(default_factory=dict)
[docs]
class BenchmarkResultRow(CatalogRow):
    """One row per benchmark evaluation result.

    Primary key: ``result_id`` — content-addressed hash of the six identity
    fields ``(suite_name, suite_version, checkpoint_id, dataset_hash,
    partial_subset, partial_seed)``. The six fields are stored as queryable
    columns so callers can look up results without pre-computing the hash.

    ``dataset_hash`` is an opaque string; computation convention is a future
    feature. ``partial_subset`` ∈ (0, 1] and ``partial_seed`` distinguish
    full evals from in-training partial benchmarks (see architecture §5.7).
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("result_id",)

    result_id: CatalogID
    # FK pair into BenchmarkSuiteRow (suite_name, suite_version).
    suite_name: NonEmptyString
    suite_version: int = Field(gt=0)
    checkpoint_id: CatalogID
    # Opaque for now; hashing convention deferred (see class docstring).
    dataset_hash: CatalogID
    # Fraction of the suite evaluated: 1.0 = full eval.
    partial_subset: float = Field(default=1.0, gt=0.0, le=1.0)
    # Subset-sampling seed; null for full evals.
    partial_seed: int | None = None
    storage_uri: StorageURI
    computed_at: UTCDatetime
    metadata: MetadataDict = Field(default_factory=dict)
[docs]
class EmbeddingRow(CatalogRow):
    """Vector embedding over a (recording, modality, window) tuple.

    Primary key: ``embedding_id`` — any unique catalog ID.

    ``vector`` is typed as ``Any`` for runtime performance; runtime contract
    is non-empty list[float] (no bools), enforced by
    :func:`_validate_vector_fast`. Per-model fixed-dim enforcement lives in
    the Lance writer, materialized as ``FixedSizeList[float, dim]`` keyed by
    ``model_id``.
    """

    __primary_key__: ClassVar[tuple[str, ...]] = ("embedding_id",)

    embedding_id: CatalogID
    # (recording, modality, window) provenance struct.
    source: EmbeddingSource
    # Hot path: single fast validator instead of per-element checks.
    vector: Vector
    # Identifies the embedding model; keys the writer's fixed-dim schema.
    model_id: NonEmptyString