# Source code for ursa.catalog._arrow

"""PyArrow schemas for the nine catalog tables.

Each schema is hand-written rather than auto-generated from the Pydantic
row class. Reasons:

* Lance/Arrow type choices (``timestamp("us", tz="UTC")`` vs the default
  ``timestamp("ns")``; ``duration("us")`` for ``timedelta``;
  ``list_(float64)`` for embedding vectors) are explicit and reviewable.
* ``MetadataDict``-typed columns serialize to JSON strings in M2 — see
  ENG-1066 for the M3 promotion to Lance ``MapType`` + hot-key columns.
* Nested submodels (``TimeWindow``, ``EmbeddingSource``) materialize as
  ``pa.struct`` with a fixed field set; extras are rejected at write
  time because Arrow struct columns are fixed-schema.

A drift test (``tests/catalog/test_catalog_arrow_coverage.py``) asserts
every Pydantic field is covered, so adding a field to a row class
without touching this module fails CI.
"""

from __future__ import annotations

import pyarrow as pa

from ursa.catalog.schemas import (
    BenchmarkResultRow,
    BenchmarkSuiteRow,
    CheckpointRow,
    EmbeddingRow,
    EventRow,
    ModalityRow,
    ParticipantRow,
    RecordingRow,
    VirgoAssetRow,
)
from ursa.layout import (
    TABLE_BENCHMARK_RESULTS,
    TABLE_BENCHMARK_SUITES,
    TABLE_CHECKPOINTS,
    TABLE_EMBEDDINGS,
    TABLE_EVENTS,
    TABLE_MODALITIES,
    TABLE_PARTICIPANTS,
    TABLE_RECORDINGS,
    TABLE_VIRGO_ASSETS,
)

# Public names of this module: the per-table lookup tables (schemas, row
# classes, primary keys, JSON-encoded column sets) and the extras column name.
# The ``_``-prefixed Arrow type aliases and helpers are intentionally private.
__all__ = [
    "ARROW_SCHEMAS",
    "EXTRAS_COLUMN",
    "JSON_METADATA_COLUMNS",
    "NULLABLE_JSON_METADATA_COLUMNS",
    "PRIMARY_KEYS",
    "ROW_CLASSES",
]


# Name of the column that carries top-level Pydantic extras through Lance.
# Every row stores its extras as a JSON string; a row without extras stores
# ``"{}"``.
EXTRAS_COLUMN = "__pydantic_extra__"


# Shared Arrow types, built once so the per-table schemas stay short.
# Scalar types.
_STR = pa.string()
_F64 = pa.float64()
_I64 = pa.int64()
# Temporal types: microsecond resolution; timestamps are UTC-aware.
_TS_UTC = pa.timestamp(unit="us", tz="UTC")
_DUR = pa.duration(unit="us")
# Embedding vectors: a variable-length list of float64 values.
_VECTOR = pa.list_(_F64)

# Columns that hold a JSON-encoded ``MetadataDict``, listed per table. M2
# keeps them as plain strings so the catalog can ship without a MapType
# migration; ENG-1066 moves hot keys into typed top-level columns and the
# remainder into Lance MapType. A table with no such column maps to an
# empty set, so every catalog table has an entry.
JSON_METADATA_COLUMNS: dict[str, frozenset[str]] = {
    TABLE_PARTICIPANTS: frozenset(("metadata",)),
    TABLE_RECORDINGS: frozenset(("device_info", "metadata")),
    TABLE_MODALITIES: frozenset(("channel_spec", "metadata")),
    TABLE_EVENTS: frozenset(("metadata",)),
    TABLE_EMBEDDINGS: frozenset(()),
    TABLE_VIRGO_ASSETS: frozenset(()),
    TABLE_CHECKPOINTS: frozenset(("metadata",)),
    TABLE_BENCHMARK_SUITES: frozenset(("metadata",)),
    TABLE_BENCHMARK_RESULTS: frozenset(("metadata",)),
}


# JSON-encoded MetadataDict columns that are *also* Pydantic-nullable.
# Architecture v0.4: ``ModalityRow.channel_spec`` is null while
# ``ingestion_status="raw"`` (Virgo populates it during ingestion). Every
# other JSON column defaults to ``{}`` and stays non-null.
#
# Keyed by the same nine tables as ``JSON_METADATA_COLUMNS`` (an empty set
# when a table has no nullable JSON column), so indexing with any catalog
# table name works instead of raising ``KeyError`` for all but modalities.
# Each value is a subset of that table's ``JSON_METADATA_COLUMNS`` entry.
NULLABLE_JSON_METADATA_COLUMNS: dict[str, frozenset[str]] = {
    TABLE_PARTICIPANTS: frozenset(),
    TABLE_RECORDINGS: frozenset(),
    TABLE_MODALITIES: frozenset({"channel_spec"}),
    TABLE_EVENTS: frozenset(),
    TABLE_EMBEDDINGS: frozenset(),
    TABLE_VIRGO_ASSETS: frozenset(),
    TABLE_CHECKPOINTS: frozenset(),
    TABLE_BENCHMARK_SUITES: frozenset(),
    TABLE_BENCHMARK_RESULTS: frozenset(),
}


# Arrow struct for the ``TimeWindow`` submodel: two required float64 members,
# ``start_seconds`` then ``end_seconds``. The field set is fixed, so extras on
# the submodel cannot be stored.
_TIME_WINDOW = pa.struct(
    [
        pa.field(bound_name, _F64, nullable=False)
        for bound_name in ("start_seconds", "end_seconds")
    ]
)


# Arrow struct for the ``EmbeddingSource`` submodel: the ordered members
# ``recording_hash`` (string), ``modality`` (string) and ``time_window``
# (the ``TimeWindow`` struct). Every member is non-nullable.
_EMBEDDING_SOURCE = pa.struct(
    [
        pa.field(member_name, member_type, nullable=False)
        for member_name, member_type in (
            ("recording_hash", _STR),
            ("modality", _STR),
            ("time_window", _TIME_WINDOW),
        )
    ]
)


# Temporal coverage of a single modality (architecture v0.4): a list of
# half-open ``[start, end)`` intervals in the recording's native time domain.
# Stored as ``list<struct<start, end>>`` rather than ``list<list<float>>`` so
# each interval has named fields, which lets Arrow's nested-struct filter
# pushdown serve range queries once those land (M3).
_DOMAIN_INTERVAL = pa.struct(
    [pa.field(edge_name, _F64, nullable=False) for edge_name in ("start", "end")]
)
_DOMAIN_INTERVALS = pa.list_(_DOMAIN_INTERVAL)

# ``RecordingRow.participant_ids`` (architecture v0.4) is a list of catalog
# IDs because one recording may cover several participants. Filtering on
# membership therefore needs array-contains predicates; see
# ``ursa.catalog._filters``.
_PARTICIPANT_IDS = pa.list_(_STR)


def _extras_field() -> pa.Field:
    """Return the field that round-trips top-level Pydantic extras.

    Every table schema in ``ARROW_SCHEMAS`` ends with this column. It is a
    non-nullable string named ``EXTRAS_COLUMN`` that holds the row's extras
    as JSON (``"{}"`` when the row has none).
    """
    return pa.field(EXTRAS_COLUMN, _STR, nullable=False)


# Arrow schema for each catalog table, keyed by table name. Columns listed in
# ``JSON_METADATA_COLUMNS`` are stored as JSON strings (``_STR``).
ARROW_SCHEMAS: dict[str, pa.Schema] = {
    TABLE_PARTICIPANTS: pa.schema(
        [
            pa.field("participant_id", _STR, nullable=False),
            pa.field("enrolled_at", _TS_UTC, nullable=False),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_RECORDINGS: pa.schema(
        [
            pa.field("recording_hash", _STR, nullable=False),
            pa.field("participant_ids", _PARTICIPANT_IDS, nullable=False),
            pa.field("start_time", _TS_UTC, nullable=False),
            pa.field("duration", _DUR, nullable=False),
            pa.field("device_info", _STR, nullable=False),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_MODALITIES: pa.schema(
        [
            pa.field("recording_hash", _STR, nullable=False),
            pa.field("modality", _STR, nullable=False),
            pa.field("ingestion_status", _STR, nullable=False),
            pa.field("storage_uri", _STR, nullable=False),
            pa.field("raw_storage_uri", _STR, nullable=False),
            # Nullable in v0.4: format / channel_spec / domain_intervals
            # are populated by Virgo's ingestion node, not at registration.
            # NOTE(review): sampling_rate is nullable too but was not listed
            # in this comment; confirm it is also filled in at ingestion.
            pa.field("format", _STR, nullable=True),
            pa.field("sampling_rate", _F64, nullable=True),
            pa.field("domain_intervals", _DOMAIN_INTERVALS, nullable=True),
            pa.field("channel_spec", _STR, nullable=True),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_EVENTS: pa.schema(
        [
            pa.field("event_id", _STR, nullable=False),
            pa.field("recording_hash", _STR, nullable=False),
            pa.field("event_time", _F64, nullable=False),
            pa.field("event_type", _STR, nullable=False),
            pa.field("prompt", _STR, nullable=True),
            pa.field("response", _STR, nullable=True),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_EMBEDDINGS: pa.schema(
        [
            pa.field("embedding_id", _STR, nullable=False),
            pa.field("source", _EMBEDDING_SOURCE, nullable=False),
            pa.field("vector", _VECTOR, nullable=False),
            pa.field("model_id", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_VIRGO_ASSETS: pa.schema(
        [
            pa.field("asset_id", _STR, nullable=False),
            pa.field("recording_hash", _STR, nullable=False),
            pa.field("pipeline_name", _STR, nullable=False),
            pa.field("pipeline_version", _STR, nullable=False),
            pa.field("cache_key", _STR, nullable=False),
            pa.field("code_version", _STR, nullable=False),
            pa.field("config_hash", _STR, nullable=False),
            pa.field("time_window", _TIME_WINDOW, nullable=False),
            pa.field("created_at", _TS_UTC, nullable=False),
            pa.field("last_accessed_at", _TS_UTC, nullable=False),
            pa.field("storage_uri", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_CHECKPOINTS: pa.schema(
        [
            pa.field("checkpoint_id", _STR, nullable=False),
            pa.field("run_id", _STR, nullable=False),
            pa.field("step", _I64, nullable=False),
            pa.field("model_id", _STR, nullable=False),
            pa.field("code_version", _STR, nullable=False),
            pa.field("storage_uri", _STR, nullable=False),
            pa.field("created_at", _TS_UTC, nullable=False),
            pa.field("parent_checkpoint_id", _STR, nullable=True),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_BENCHMARK_SUITES: pa.schema(
        [
            pa.field("suite_name", _STR, nullable=False),
            pa.field("suite_version", _I64, nullable=False),
            pa.field("storage_uri", _STR, nullable=False),
            pa.field("created_at", _TS_UTC, nullable=False),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
    TABLE_BENCHMARK_RESULTS: pa.schema(
        [
            pa.field("result_id", _STR, nullable=False),
            pa.field("suite_name", _STR, nullable=False),
            pa.field("suite_version", _I64, nullable=False),
            pa.field("checkpoint_id", _STR, nullable=False),
            pa.field("dataset_hash", _STR, nullable=False),
            pa.field("partial_subset", _F64, nullable=False),
            pa.field("partial_seed", _I64, nullable=True),
            pa.field("storage_uri", _STR, nullable=False),
            pa.field("computed_at", _TS_UTC, nullable=False),
            pa.field("metadata", _STR, nullable=False),
            _extras_field(),
        ]
    ),
}

# Pydantic row class for each catalog table, keyed by table name.
ROW_CLASSES = {
    TABLE_PARTICIPANTS: ParticipantRow,
    TABLE_RECORDINGS: RecordingRow,
    TABLE_MODALITIES: ModalityRow,
    TABLE_EVENTS: EventRow,
    TABLE_EMBEDDINGS: EmbeddingRow,
    TABLE_VIRGO_ASSETS: VirgoAssetRow,
    TABLE_CHECKPOINTS: CheckpointRow,
    TABLE_BENCHMARK_SUITES: BenchmarkSuiteRow,
    TABLE_BENCHMARK_RESULTS: BenchmarkResultRow,
}

# Primary-key column names per table, copied into a fresh list from each row
# class's ``__primary_key__`` attribute.
PRIMARY_KEYS: dict[str, list[str]] = {
    name: list(cls.__primary_key__)  # type: ignore[attr-defined]
    for name, cls in ROW_CLASSES.items()
}