Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 26 additions & 13 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import pytest

from tests.utils import generate_sample_embeddings_for_run
from timdex_dataset_api import TIMDEXDataset
from timdex_dataset_api.embeddings import (
METADATA_SELECT_FILTER_COLUMNS,
TIMDEX_DATASET_EMBEDDINGS_SCHEMA,
Expand Down Expand Up @@ -302,9 +301,7 @@ def test_current_embeddings_view_single_run(timdex_dataset_for_embeddings_views)

# write embeddings for run "apple-1"
td.embeddings.write(generate_sample_embeddings_for_run(td, run_id="apple-1"))

# NOTE: at time of test creation, this manual reload is required
td = TIMDEXDataset(td.location)
Comment on lines -305 to -307
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This and other similar lines: based on previous refactor, missed at the time, we can lean on the improved .refresh() method!

td.refresh()

# query current_embeddings for apple source using read_dataframe
result = td.embeddings.read_dataframe(table="current_embeddings", source="apple")
Expand All @@ -320,9 +317,7 @@ def test_current_embeddings_view_multiple_runs(timdex_dataset_for_embeddings_vie
# write embeddings for runs "orange-1" and "orange-2"
td.embeddings.write(generate_sample_embeddings_for_run(td, run_id="orange-1"))
td.embeddings.write(generate_sample_embeddings_for_run(td, run_id="orange-2"))

# NOTE: at time of test creation, this manual reload is required
td = TIMDEXDataset(td.location)
td.refresh()

# query current_embeddings for orange source using read_dataframe
result = td.embeddings.read_dataframe(table="current_embeddings", source="orange")
Expand Down Expand Up @@ -363,9 +358,7 @@ def test_current_embeddings_view_handles_duplicate_run_embeddings(
td, run_id="lemon-2", embedding_timestamp="2025-08-03T00:00:00+00:00"
)
)

# NOTE: at time of test creation, this manual reload is required
td = TIMDEXDataset(td.location)
td.refresh()

# check all embeddings for lemon-2 to verify both writes exist
all_lemon_2 = td.embeddings.read_dataframe(table="embeddings", run_id="lemon-2")
Expand Down Expand Up @@ -416,9 +409,7 @@ def test_embeddings_view_includes_all_embeddings(timdex_dataset_for_embeddings_v
td, run_id="lemon-2", embedding_timestamp="2025-08-03T00:00:00+00:00"
)
)

# NOTE: at time of test creation, this manual reload is required
td = TIMDEXDataset(td.location)
td.refresh()

# query all embeddings for lemon source
result = td.embeddings.read_dataframe(table="embeddings", source="lemon")
Expand All @@ -435,3 +426,25 @@ def test_embeddings_view_includes_all_embeddings(timdex_dataset_for_embeddings_v
lemon_2_embeddings = result[result["run_id"] == "lemon-2"]
assert len(lemon_2_embeddings) == 10 # 5 from each write
assert (lemon_2_embeddings["run_date"] == date(2025, 8, 2)).all()


def test_embeddings_read_batches_iter_returns_empty_when_embeddings_missing(
    timdex_dataset_empty, caplog
):
    """read_batches_iter yields nothing and logs a warning when the embeddings
    table does not yet exist in the DuckDB context (brand-new/empty dataset)."""
    result = list(timdex_dataset_empty.embeddings.read_batches_iter())
    # empty iterator, not an exception: the embeddings path degrades gracefully
    assert result == []
    # warning text must match the message emitted by embeddings.read_batches_iter
    assert (
        "Table 'embeddings' not found in DuckDB context. Embeddings may not yet exist "
        "or TIMDEXDataset.refresh() may be required." in caplog.text
    )


def test_embeddings_read_batches_iter_returns_empty_for_invalid_table(
    timdex_embeddings_with_runs, caplog
):
    """read_batches_iter raises ValueError for a nonexistent table name."""
    # NOTE(review): the function name says "returns_empty" but the test asserts a
    # raised ValueError — consider renaming to ..._raises_for_invalid_table.
    with pytest.raises(
        ValueError,
        match="Invalid table: 'nonexistent'",
    ):
        list(timdex_embeddings_with_runs.read_batches_iter(table="nonexistent"))
27 changes: 26 additions & 1 deletion tests/test_read.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ruff: noqa: D205, D209, PLR2004

import re

import pandas as pd
import pyarrow as pa
Expand Down Expand Up @@ -280,3 +280,28 @@ def test_read_batches_iter_limit_returns_n_rows(timdex_dataset_multi_source):
batches = timdex_dataset_multi_source.read_batches_iter(limit=10)
table = pa.Table.from_batches(batches)
assert len(table) == 10


def test_read_batches_iter_returns_empty_when_metadata_missing(
    timdex_dataset_empty, caplog
):
    """read_batches_iter raises ValueError when the 'records' table is absent
    from the DuckDB context (e.g. a new dataset whose metadata was never built).

    NOTE(review): the name says "returns_empty" but the records path raises,
    unlike the embeddings variant which logs and yields nothing; `caplog` is
    unused here and could be dropped.
    """
    with pytest.raises(
        ValueError,
        match=re.escape(
            "Table 'records' not found in DuckDB context. If this is a new dataset, "
            "either records do not yet exist or a "
            "TIMDEXDataset.metadata.rebuild_dataset_metadata() may be required."
        ),
    ):
        list(timdex_dataset_empty.read_batches_iter())


def test_read_batches_iter_returns_empty_for_invalid_table(
    timdex_dataset_multi_source, caplog
):
    """read_batches_iter raises ValueError for a nonexistent table name."""
    # NOTE(review): function name says "returns_empty" but the test asserts a
    # raised ValueError — consider renaming; `caplog` is unused here.
    with pytest.raises(
        ValueError,
        match="Invalid table: 'nonexistent'",
    ):
        list(timdex_dataset_multi_source.read_batches_iter(table="nonexistent"))
2 changes: 1 addition & 1 deletion timdex_dataset_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
from timdex_dataset_api.record import DatasetRecord

__version__ = "3.9.0"
__version__ = "3.10.0"

__all__ = [
"DatasetEmbedding",
Expand Down
29 changes: 21 additions & 8 deletions timdex_dataset_api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,17 +443,30 @@ def read_batches_iter(
"""
start_time = time.perf_counter()

# ensure valid table
if table not in ["records", "current_records"]:
raise ValueError(f"Invalid table: '{table}'")

# ensure table exists
try:
self.get_sa_table("metadata", table)
except ValueError as exc:
raise ValueError(
f"Table '{table}' not found in DuckDB context. If this is a new "
"dataset, either records do not yet exist or a "
"TIMDEXDataset.metadata.rebuild_dataset_metadata() may be required."
) from exc

temp_table_name = "read_meta_chunk"
total_yield_count = 0

for i, meta_chunk_df in enumerate(
self._iter_meta_chunks(
table,
limit=limit,
where=where,
**filters,
)
):
meta_chunks = self._iter_meta_chunks(
table,
limit=limit,
where=where,
**filters,
)
for i, meta_chunk_df in enumerate(meta_chunks):
batch_time = time.perf_counter()
batch_yield_count = len(meta_chunk_df)
total_yield_count += batch_yield_count
Expand Down
13 changes: 13 additions & 0 deletions timdex_dataset_api/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,19 @@ def read_batches_iter(
"""
start_time = time.perf_counter()

if table not in ["embeddings", "current_embeddings", "current_run_embeddings"]:
raise ValueError(f"Invalid table: '{table}'")

# ensure table exists
try:
self.timdex_dataset.get_sa_table("data", table)
except ValueError:
logger.warning(
f"Table '{table}' not found in DuckDB context. Embeddings may not yet "
"exist or TIMDEXDataset.refresh() may be required."
)
return

data_query = self._build_query(
table,
columns,
Expand Down