"""PyArrow schemas for the nine catalog tables.
Each schema is hand-written rather than auto-generated from the Pydantic
row class. Reasons:
* Lance/Arrow type choices (``timestamp("us", tz="UTC")`` vs the default
``timestamp("ns")``; ``duration("us")`` for ``timedelta``;
``list_(float64)`` for embedding vectors) are explicit and reviewable.
* ``MetadataDict``-typed columns serialize to JSON strings in M2 — see
ENG-1066 for the M3 promotion to Lance ``MapType`` + hot-key columns.
* Nested submodels (``TimeWindow``, ``EmbeddingSource``) materialize as
``pa.struct`` with a fixed field set; extras are rejected at write
time because Arrow struct columns are fixed-schema.
A drift test (``tests/catalog/test_catalog_arrow_coverage.py``) asserts
every Pydantic field is covered, so adding a field to a row class
without touching this module fails CI.
"""
from __future__ import annotations
import pyarrow as pa
from ursa.catalog.schemas import (
BenchmarkResultRow,
BenchmarkSuiteRow,
CheckpointRow,
EmbeddingRow,
EventRow,
ModalityRow,
ParticipantRow,
RecordingRow,
VirgoAssetRow,
)
from ursa.layout import (
TABLE_BENCHMARK_RESULTS,
TABLE_BENCHMARK_SUITES,
TABLE_CHECKPOINTS,
TABLE_EMBEDDINGS,
TABLE_EVENTS,
TABLE_MODALITIES,
TABLE_PARTICIPANTS,
TABLE_RECORDINGS,
TABLE_VIRGO_ASSETS,
)
__all__ = [
    "ARROW_SCHEMAS",
    "EXTRAS_COLUMN",
    "JSON_METADATA_COLUMNS",
    "NULLABLE_JSON_METADATA_COLUMNS",
    "PRIMARY_KEYS",
    "ROW_CLASSES",
]

# Column that round-trips top-level Pydantic extras through Lance: one JSON
# string per row, with ``"{}"`` standing in when a row carries no extras.
EXTRAS_COLUMN = "__pydantic_extra__"

# Shared Arrow types, named once so the table schemas below stay scannable.
_STR = pa.string()
_I64 = pa.int64()
_F64 = pa.float64()
_TS_UTC = pa.timestamp("us", tz="UTC")  # explicit us/UTC, not Arrow's ns default
_DUR = pa.duration("us")  # for ``timedelta`` columns
_VECTOR = pa.list_(_F64)  # embedding vectors
# MetadataDict columns serialized as JSON strings, per table. M2 keeps these
# as strings so the catalog ships without a MapType migration; ENG-1066
# promotes hot keys to typed top-level columns and the rest to Lance MapType.
_METADATA_ONLY = frozenset({"metadata"})
JSON_METADATA_COLUMNS: dict[str, frozenset[str]] = {
    TABLE_PARTICIPANTS: _METADATA_ONLY,
    TABLE_RECORDINGS: frozenset({"device_info", "metadata"}),
    TABLE_MODALITIES: frozenset({"channel_spec", "metadata"}),
    TABLE_EVENTS: _METADATA_ONLY,
    TABLE_EMBEDDINGS: frozenset(),
    TABLE_VIRGO_ASSETS: frozenset(),
    TABLE_CHECKPOINTS: _METADATA_ONLY,
    TABLE_BENCHMARK_SUITES: _METADATA_ONLY,
    TABLE_BENCHMARK_RESULTS: _METADATA_ONLY,
}

# The subset of JSON metadata columns that is also Pydantic-nullable.
# Architecture v0.4: ``ModalityRow.channel_spec`` stays null while
# ``ingestion_status="raw"`` and is filled in by Virgo during ingestion;
# every other JSON column defaults to ``{}`` and remains non-null.
NULLABLE_JSON_METADATA_COLUMNS: dict[str, frozenset[str]] = {
    TABLE_MODALITIES: frozenset({"channel_spec"}),
}
def _required(name: str, dtype: pa.DataType) -> pa.Field:
    """Shorthand for a non-nullable struct-member field."""
    return pa.field(name, dtype, nullable=False)


# ``TimeWindow`` submodel as a fixed-field struct; extras on nested
# submodels are rejected at write time (see module docstring).
_TIME_WINDOW = pa.struct(
    [
        _required("start_seconds", _F64),
        _required("end_seconds", _F64),
    ]
)

# ``EmbeddingSource`` submodel: recording + modality + time window.
_EMBEDDING_SOURCE = pa.struct(
    [
        _required("recording_hash", _STR),
        _required("modality", _STR),
        _required("time_window", _TIME_WINDOW),
    ]
)

# Per-modality temporal coverage (architecture v0.4): half-open
# ``[start, end)`` intervals in the recording's native time domain.
# ``list<struct<start, end>>`` rather than ``list<list<float>>`` so each
# interval carries named fields and Arrow's nested-struct filter pushdown
# can apply once range queries land (M3).
_DOMAIN_INTERVAL = pa.struct(
    [
        _required("start", _F64),
        _required("end", _F64),
    ]
)
_DOMAIN_INTERVALS = pa.list_(_DOMAIN_INTERVAL)

# A recording can cover multiple participants (architecture v0.4), so
# ``RecordingRow.participant_ids`` is a list of catalog IDs; membership
# filters need array-contains predicates (see ``ursa.catalog._filters``).
_PARTICIPANT_IDS = pa.list_(_STR)
def _extras_field() -> pa.Field:
    """Return the shared extras field appended to every table schema.

    ``EXTRAS_COLUMN`` holds each row's top-level Pydantic extras as a JSON
    string; non-nullable because extras-free rows store ``"{}"``.
    """
    # Bug fix: every schema below calls this helper, but it was never
    # defined, so importing the module raised NameError.
    return pa.field(EXTRAS_COLUMN, _STR, nullable=False)


# Table name -> Arrow schema for each of the nine catalog tables. Field
# nullability mirrors the Pydantic row classes; every schema ends with the
# shared extras column so top-level Pydantic extras round-trip through Lance.
ARROW_SCHEMAS: dict[str, pa.Schema] = {
    TABLE_PARTICIPANTS: pa.schema(
        [
            pa.field("participant_id", _STR, nullable=False),
            pa.field("enrolled_at", _TS_UTC, nullable=False),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_RECORDINGS: pa.schema(
        [
            pa.field("recording_hash", _STR, nullable=False),
            pa.field("participant_ids", _PARTICIPANT_IDS, nullable=False),
            pa.field("start_time", _TS_UTC, nullable=False),
            pa.field("duration", _DUR, nullable=False),
            pa.field("device_info", _STR, nullable=False),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_MODALITIES: pa.schema(
        [
            pa.field("recording_hash", _STR, nullable=False),
            pa.field("modality", _STR, nullable=False),
            pa.field("ingestion_status", _STR, nullable=False),
            pa.field("storage_uri", _STR, nullable=False),
            pa.field("raw_storage_uri", _STR, nullable=False),
            # Nullable in v0.4: format / sampling_rate / domain_intervals /
            # channel_spec are populated by Virgo's ingestion node, not at
            # registration.
            pa.field("format", _STR, nullable=True),
            pa.field("sampling_rate", _F64, nullable=True),
            pa.field("domain_intervals", _DOMAIN_INTERVALS, nullable=True),
            pa.field("channel_spec", _STR, nullable=True),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_EVENTS: pa.schema(
        [
            pa.field("event_id", _STR, nullable=False),
            pa.field("recording_hash", _STR, nullable=False),
            pa.field("event_time", _F64, nullable=False),
            pa.field("event_type", _STR, nullable=False),
            pa.field("prompt", _STR, nullable=True),
            pa.field("response", _STR, nullable=True),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_EMBEDDINGS: pa.schema(
        [
            pa.field("embedding_id", _STR, nullable=False),
            pa.field("source", _EMBEDDING_SOURCE, nullable=False),
            pa.field("vector", _VECTOR, nullable=False),
            pa.field("model_id", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_VIRGO_ASSETS: pa.schema(
        [
            pa.field("asset_id", _STR, nullable=False),
            pa.field("recording_hash", _STR, nullable=False),
            pa.field("pipeline_name", _STR, nullable=False),
            pa.field("pipeline_version", _STR, nullable=False),
            pa.field("cache_key", _STR, nullable=False),
            pa.field("code_version", _STR, nullable=False),
            pa.field("config_hash", _STR, nullable=False),
            pa.field("time_window", _TIME_WINDOW, nullable=False),
            pa.field("created_at", _TS_UTC, nullable=False),
            pa.field("last_accessed_at", _TS_UTC, nullable=False),
            pa.field("storage_uri", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_CHECKPOINTS: pa.schema(
        [
            pa.field("checkpoint_id", _STR, nullable=False),
            pa.field("run_id", _STR, nullable=False),
            pa.field("step", _I64, nullable=False),
            pa.field("model_id", _STR, nullable=False),
            pa.field("code_version", _STR, nullable=False),
            pa.field("storage_uri", _STR, nullable=False),
            pa.field("created_at", _TS_UTC, nullable=False),
            # Null for root checkpoints with no lineage parent.
            pa.field("parent_checkpoint_id", _STR, nullable=True),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_BENCHMARK_SUITES: pa.schema(
        [
            pa.field("suite_name", _STR, nullable=False),
            pa.field("suite_version", _I64, nullable=False),
            pa.field("storage_uri", _STR, nullable=False),
            pa.field("created_at", _TS_UTC, nullable=False),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_BENCHMARK_RESULTS: pa.schema(
        [
            pa.field("result_id", _STR, nullable=False),
            pa.field("suite_name", _STR, nullable=False),
            pa.field("suite_version", _I64, nullable=False),
            pa.field("checkpoint_id", _STR, nullable=False),
            pa.field("dataset_hash", _STR, nullable=False),
            pa.field("partial_subset", _F64, nullable=False),
            pa.field("partial_seed", _I64, nullable=True),
            pa.field("storage_uri", _STR, nullable=False),
            pa.field("computed_at", _TS_UTC, nullable=False),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
}
# Table name -> Pydantic row class backing that table.
ROW_CLASSES: dict[str, type] = {
    TABLE_PARTICIPANTS: ParticipantRow,
    TABLE_RECORDINGS: RecordingRow,
    TABLE_MODALITIES: ModalityRow,
    TABLE_EVENTS: EventRow,
    TABLE_EMBEDDINGS: EmbeddingRow,
    TABLE_VIRGO_ASSETS: VirgoAssetRow,
    TABLE_CHECKPOINTS: CheckpointRow,
    TABLE_BENCHMARK_SUITES: BenchmarkSuiteRow,
    TABLE_BENCHMARK_RESULTS: BenchmarkResultRow,
}

# Primary-key columns per table, read off each row class's
# ``__primary_key__`` declaration so the catalog and the row models
# cannot drift apart.
PRIMARY_KEYS: dict[str, list[str]] = {
    table: list(row_cls.__primary_key__)  # type: ignore[attr-defined]
    for table, row_cls in ROW_CLASSES.items()
}