9 changes: 9 additions & 0 deletions conftest.py
@@ -118,6 +118,15 @@ def cmip6_data_catalog(sample_data_dir) -> pd.DataFrame:
return adapter.find_local_datasets(sample_data_dir / "CMIP6")


@pytest.fixture(scope="session")
def cmip6_data_catalog_drs(sample_data_dir) -> pd.DataFrame:
config = Config.default()
config.cmip6_parser = "drs"

adapter = CMIP6DatasetAdapter(config=config)
return adapter.find_local_datasets(sample_data_dir / "CMIP6")


@pytest.fixture(scope="session")
def obs4mips_data_catalog(sample_data_dir) -> pd.DataFrame:
adapter = Obs4MIPsDatasetAdapter()
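For context, a minimal sketch of what the new DRS-backed fixture exercises; the sample-data path below is an assumption for illustration (the real tests use the `sample_data_dir` fixture):

```python
from pathlib import Path

from climate_ref.config import Config
from climate_ref.datasets.cmip6 import CMIP6DatasetAdapter

config = Config.default()
config.cmip6_parser = "drs"  # select the lightweight, path-only DRS parser
adapter = CMIP6DatasetAdapter(config=config)

# Assumed sample-data location for illustration; the tests use the sample_data_dir fixture
sample_data_dir = Path("tests/test-data/sample-data")
catalog = adapter.find_local_datasets(sample_data_dir / "CMIP6")
```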
4 changes: 2 additions & 2 deletions packages/climate-ref/conftest.py
@@ -66,13 +66,13 @@ def db_seeded_template(tmp_path_session, cmip6_data_catalog, obs4mips_data_catal
adapter = CMIP6DatasetAdapter()
with database.session.begin():
for instance_id, data_catalog_dataset in cmip6_data_catalog.groupby(adapter.slug_column):
adapter.register_dataset(config, database, data_catalog_dataset)
adapter.register_dataset(database, data_catalog_dataset)

# Seed the obs4MIPs sample datasets
adapter_obs = Obs4MIPsDatasetAdapter()
with database.session.begin():
for instance_id, data_catalog_dataset in obs4mips_data_catalog.groupby(adapter_obs.slug_column):
adapter_obs.register_dataset(config, database, data_catalog_dataset)
adapter_obs.register_dataset(database, data_catalog_dataset)

with database.session.begin():
_register_provider(database, example_provider)
2 changes: 1 addition & 1 deletion packages/climate-ref/src/climate_ref/cli/datasets.py
@@ -156,7 +156,7 @@ def ingest( # noqa: PLR0913
logger.info(f"Would save dataset {instance_id} to the database")
continue
else:
adapter.register_dataset(config, db, data_catalog_dataset)
adapter.register_dataset(db, data_catalog_dataset)

if solve:
solve_required_executions(
64 changes: 64 additions & 0 deletions packages/climate-ref/src/climate_ref/data_catalog.py
@@ -0,0 +1,64 @@
import pandas as pd
from attrs import define
from loguru import logger

from climate_ref.database import Database
from climate_ref.datasets.base import DatasetAdapter
from climate_ref.datasets.mixins import FinaliseableDatasetAdapterMixin


@define
class DataCatalog:
"""
Data catalog for managing datasets in the database.

This class provides an abstraction layer for interacting with a database-backed data catalog.
"""

database: Database
adapter: DatasetAdapter
_df: pd.DataFrame | None = None

def finalise(self, subset: pd.DataFrame) -> pd.DataFrame:
"""
Finalise the datasets in the provided subset.

This is a no-op if the adapter does not support finalisation.
"""
if not isinstance(self.adapter, FinaliseableDatasetAdapterMixin):
return subset

if "finalised" in subset.columns and not subset["finalised"].all():
subset_to_finalise = subset[~subset["finalised"]].copy()
logger.info(f"Finalising {len(subset_to_finalise)} datasets")
finalised_datasets = self.adapter.finalise_datasets(subset_to_finalise)

if len(finalised_datasets) < len(subset_to_finalise):
logger.warning(
f"Finalised {len(finalised_datasets)} datasets, but expected {len(subset_to_finalise)}. "
"Some datasets may not have been finalised."
)

# Merge the finalised datasets back into the original subset/data catalog
subset.update(finalised_datasets, overwrite=True)
subset = subset.infer_objects()

# Update the database with the finalised datasets
for instance_id, data_catalog_dataset in finalised_datasets.groupby(self.adapter.slug_column):
logger.debug(f"Processing dataset {instance_id}")
with self.database.session.begin():
self.adapter.register_dataset(self.database, data_catalog_dataset)
if self._df is not None:
self._df.update(subset_to_finalise, overwrite=True)
self._df = self._df.infer_objects()

return subset

def to_frame(self) -> pd.DataFrame:
"""
Load the data catalog into a DataFrame.
"""
if self._df is None:
logger.info("Loading data catalog from database")
self._df = self.adapter.load_catalog(self.database)
return self._df
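A rough usage sketch of the new `DataCatalog` wrapper. Only `DataCatalog`, `to_frame`, and `finalise` come from this file; the `Database` is taken as a parameter and the filter column is illustrative:

```python
import pandas as pd

from climate_ref.data_catalog import DataCatalog
from climate_ref.database import Database
from climate_ref.datasets.cmip6 import CMIP6DatasetAdapter


def finalise_needed_rows(database: Database, adapter: CMIP6DatasetAdapter) -> pd.DataFrame:
    """Illustrative flow: load the lazily-parsed catalog, filter it, finalise the remainder."""
    catalog = DataCatalog(database=database, adapter=adapter)

    # Loaded from the database on first access and cached on the instance
    df = catalog.to_frame()

    # Filter down to the rows an execution actually needs (column name is illustrative)
    subset = df[df["experiment_id"] == "historical"]

    # Rows with finalised == False are re-parsed by the adapter, merged back into the
    # subset and the cached catalog, and re-registered in the database
    return catalog.finalise(subset)
```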
37 changes: 37 additions & 0 deletions packages/climate-ref/src/climate_ref/database.py
@@ -268,3 +268,40 @@ def get_or_create(
instance = model(**params)
self.session.add(instance)
return instance, True

def create_or_update(
self, model: type[Table], values: dict[str, Any] | None = None, **kwargs: Any
) -> tuple[Table, bool]:
"""
Create or update an instance of a model

This doesn't commit the transaction,
so you will need to call `session.commit()` after this method
or use a transaction context manager.

Parameters
----------
model
The model to create or update
values
Values to set on the instance when creating or updating it
kwargs
The filter parameters to use when querying for an instance

Returns
-------
:
A tuple containing the instance and a boolean indicating if the instance was created
"""
instance = self.session.query(model).filter_by(**kwargs).first()
if instance and values:
for key, value in values.items():
setattr(instance, key, value)
created = False
else:
params = {**kwargs, **(values or {})}
instance = model(**params)
self.session.add(instance)
created = True

return instance, created
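For illustration, a sketch of how `create_or_update` is meant to be used, mirroring `register_dataset` in `datasets/base.py` below; the slug and metadata values are placeholders:

```python
from climate_ref.database import Database
from climate_ref.models.dataset import CMIP6Dataset


def upsert_dataset(db: Database) -> None:
    """Illustrative upsert: filter by slug, create the row if missing, update it otherwise."""
    with db.session.begin():
        dataset, created = db.create_or_update(
            CMIP6Dataset,                        # concrete Dataset model used elsewhere in this PR
            values={"institution_id": "CSIRO"},  # applied on create and on update (placeholder value)
            slug="CMIP6.CMIP.CSIRO.ACCESS-ESM1-5.historical.r1i1p1f1.Amon.tas.gn.v20191115",
        )
        # Flush so dataset.id is populated before any dependent rows are added,
        # as register_dataset does after its create_or_update call
        db.session.flush()
```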
41 changes: 19 additions & 22 deletions packages/climate-ref/src/climate_ref/datasets/base.py
@@ -1,11 +1,10 @@
from pathlib import Path
from typing import Any, Protocol, cast
from typing import Any, Protocol, cast, runtime_checkable

import pandas as pd
from loguru import logger
from sqlalchemy.orm import joinedload

from climate_ref.config import Config
from climate_ref.database import Database
from climate_ref.datasets.utils import validate_path
from climate_ref.models.dataset import Dataset, DatasetFile
@@ -25,9 +24,6 @@ def _log_duplicate_metadata(
invalid_dataset_nunique = invalid_datasets.loc[instance_id]
invalid_dataset_columns = invalid_dataset_nunique[invalid_dataset_nunique.gt(1)].index.tolist()

# Include time_range in the list of invalid columns to make debugging easier
invalid_dataset_columns.append("time_range")

data_catalog_subset = data_catalog[data_catalog[slug_column] == instance_id]

logger.error(
@@ -60,6 +56,7 @@ def __call__(self, file: str, **kwargs: Any) -> dict[str, Any]:
...


@runtime_checkable
class DatasetAdapter(Protocol):
"""
An adapter to provide a common interface for different dataset types
@@ -169,16 +166,15 @@ def validate_data_catalog(self, data_catalog: pd.DataFrame, skip_invalid: bool =

return data_catalog

def register_dataset(
self, config: Config, db: Database, data_catalog_dataset: pd.DataFrame
) -> Dataset | None:
def register_dataset(self, db: Database, data_catalog_dataset: pd.DataFrame) -> Dataset | None:
"""
Register a dataset in the database using the data catalog

If a dataset with the same slug already exists in the database,
this will update it with the new metadata and files.

Parameters
----------
config
Configuration object
db
Database instance
data_catalog_dataset
@@ -198,21 +194,23 @@ def register_dataset(
slug = unique_slugs[0]

dataset_metadata = data_catalog_dataset[list(self.dataset_specific_metadata)].iloc[0].to_dict()
dataset, created = db.get_or_create(DatasetModel, defaults=dataset_metadata, slug=slug)
if not created:
logger.warning(f"{dataset} already exists in the database. Skipping")
return None
dataset, created = db.create_or_update(DatasetModel, values=dataset_metadata, slug=slug)
db.session.flush()

if not created:
# If the dataset already exists, then we are updating it
# We need to check if the dataset type matches
logger.info(f"Updating existing dataset {slug}")
for dataset_file in data_catalog_dataset.to_dict(orient="records"):
path = validate_path(dataset_file.pop("path"))

db.session.add(
DatasetFile(
path=str(path),
dataset_id=dataset.id,
start_time=dataset_file.pop("start_time"),
end_time=dataset_file.pop("end_time"),
)
db.create_or_update(
DatasetFile,
values={
"dataset_id": dataset.id,
**{k: dataset_file[k] for k in self.file_specific_metadata if k in dataset_file},
},
path=str(path),
)
return dataset

@@ -237,7 +235,6 @@ def _get_dataset_files(self, db: Database, limit: int | None = None) -> pd.DataF
{
**{k: getattr(file, k) for k in self.file_specific_metadata},
**{k: getattr(file.dataset, k) for k in self.dataset_specific_metadata},
"finalised": file.dataset.finalised,
}
for file in result
],
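Two behavioural notes from this hunk, sketched below: the protocol is now `@runtime_checkable`, and `register_dataset(db, group)` drops the `Config` argument and upserts instead of warning and skipping when the slug already exists. The ingestion loop mirrors the pattern used in the conftest and CLI changes above:

```python
import pandas as pd

from climate_ref.database import Database
from climate_ref.datasets.base import DatasetAdapter
from climate_ref.datasets.cmip6 import CMIP6DatasetAdapter


def ingest_catalog(db: Database, catalog: pd.DataFrame) -> None:
    """Illustrative ingestion loop using the new two-argument register_dataset."""
    adapter = CMIP6DatasetAdapter()

    # DatasetAdapter is now @runtime_checkable, so structural isinstance checks work
    assert isinstance(adapter, DatasetAdapter)

    with db.session.begin():
        for instance_id, group in catalog.groupby(adapter.slug_column):
            # No Config argument any more; an existing slug is updated rather than skipped
            adapter.register_dataset(db, group)
```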
86 changes: 62 additions & 24 deletions packages/climate-ref/src/climate_ref/datasets/cmip6.py
@@ -12,6 +12,7 @@
from climate_ref.config import Config
from climate_ref.datasets.base import DatasetAdapter, DatasetParsingFunction
from climate_ref.datasets.cmip6_parsers import parse_cmip6_complete, parse_cmip6_drs
from climate_ref.datasets.mixins import FinaliseableDatasetAdapterMixin
from climate_ref.models.dataset import CMIP6Dataset


@@ -62,6 +63,10 @@ def _fix_parent_variant_label(group: pd.DataFrame) -> pd.DataFrame:
if "branch_time_in_parent" in data_catalog:
data_catalog["branch_time_in_parent"] = _clean_branch_time(data_catalog["branch_time_in_parent"])

if "init_year" in data_catalog:
# Convert init_year to numeric, coercing errors to NaN
data_catalog["init_year"] = pd.to_numeric(data_catalog["init_year"], errors="coerce")

return data_catalog


@@ -71,7 +76,7 @@ def _clean_branch_time(branch_time: pd.Series[str]) -> pd.Series[float]:
return pd.to_numeric(branch_time.astype(str).str.replace("D", ""), errors="coerce")


class CMIP6DatasetAdapter(DatasetAdapter):
class CMIP6DatasetAdapter(FinaliseableDatasetAdapterMixin, DatasetAdapter):
"""
Adapter for CMIP6 datasets
"""
Expand All @@ -89,6 +94,7 @@ class CMIP6DatasetAdapter(DatasetAdapter):
"frequency",
"grid",
"grid_label",
"init_year",
"institution_id",
"nominal_resolution",
"parent_activity_id",
@@ -155,6 +161,32 @@ def get_parsing_function(self) -> DatasetParsingFunction:
logger.info(f"Using DRS CMIP6 parser (config value: {parser_type})")
return parse_cmip6_drs

def _clean_dataframe(self, datasets: pd.DataFrame) -> pd.DataFrame:
# Convert the start_time and end_time columns to datetime objects
# We don't know the calendar used in the dataset (TODO: Check what ecgtools does)
datasets["start_time"] = _parse_datetime(datasets["start_time"])
datasets["end_time"] = _parse_datetime(datasets["end_time"])

drs_items = [
*self.dataset_id_metadata,
self.version_metadata,
]
datasets["instance_id"] = datasets.apply(
lambda row: "CMIP6." + ".".join([row[item] for item in drs_items]), axis=1
)

# Add in any missing metadata columns
missing_columns = set(self.dataset_specific_metadata + self.file_specific_metadata) - set(
datasets.columns
)
if missing_columns:
for column in missing_columns:
datasets[column] = pd.NA

# Temporary fix for some datasets
# TODO: Replace with a standalone package that contains metadata fixes for CMIP6 datasets
return _apply_fixes(datasets)

def find_local_datasets(self, file_or_directory: Path) -> pd.DataFrame:
"""
Generate a data catalog from the specified file or directory
@@ -185,31 +217,37 @@ def find_local_datasets(self, file_or_directory: Path) -> pd.DataFrame:
joblib_parallel_kwargs={"n_jobs": self.n_jobs},
).build(parsing_func=parsing_function)

datasets: pd.DataFrame = builder.df.drop(["init_year"], axis=1)
return self._clean_dataframe(builder.df)

# Convert the start_time and end_time columns to datetime objects
# We don't know the calendar used in the dataset (TODO: Check what ecgtools does)
datasets["start_time"] = _parse_datetime(datasets["start_time"])
datasets["end_time"] = _parse_datetime(datasets["end_time"])
def finalise_datasets(self, datasets: pd.DataFrame) -> pd.DataFrame:
"""
Finalise a subset of datasets by applying the complete parser

drs_items = [
*self.dataset_id_metadata,
self.version_metadata,
]
datasets["instance_id"] = datasets.apply(
lambda row: "CMIP6." + ".".join([row[item] for item in drs_items]), axis=1
)
This is used to lazily parse the datasets after they have been filtered.

# Add in any missing metadata columns
missing_columns = set(self.dataset_specific_metadata + self.file_specific_metadata) - set(
datasets.columns
)
if missing_columns:
for column in missing_columns:
datasets[column] = pd.NA
Parameters
----------
datasets
DataFrame of datasets to finalise

# Temporary fix for some datasets
# TODO: Replace with a standalone package that contains metadata fixes for CMIP6 datasets
datasets = _apply_fixes(datasets)
Returns
-------
:
DataFrame of finalised datasets
"""
if "path" not in datasets.columns:
raise ValueError("The 'path' column is required to finalise the datasets")

finalised_rows = []
for index, row in datasets.iterrows():
parsed_row = parse_cmip6_complete(row["path"])
if "INVALID_ASSET" in parsed_row:
logger.warning(f"Failed to finalise dataset at {row['path']}: {parsed_row['INVALID_ASSET']}")
continue
finalised_rows.append({"index": index, **parsed_row})

# We need to preserve the original index to be able to update the original dataframe
finalised_df = pd.DataFrame(finalised_rows).set_index("index")
finalised_df.index.name = None

return datasets
return self._clean_dataframe(finalised_df)
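Putting the two halves of the adapter change together, a sketch of the intended lazy workflow; the data directory and the filter column are assumptions for illustration:

```python
from pathlib import Path

from climate_ref.config import Config
from climate_ref.datasets.cmip6 import CMIP6DatasetAdapter

config = Config.default()
config.cmip6_parser = "drs"                  # cheap, path-only parsing up front
adapter = CMIP6DatasetAdapter(config=config)

# Build a lightweight catalog; metadata the DRS parser cannot supply is filled with pd.NA
catalog = adapter.find_local_datasets(Path("/data/CMIP6"))  # assumed data location

# Once the catalog is filtered down, re-parse only those files with parse_cmip6_complete
subset = catalog[catalog["source_id"] == "ACCESS-ESM1-5"]  # illustrative filter
finalised = adapter.finalise_datasets(subset)              # requires the "path" column
```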
packages/climate-ref/src/climate_ref/datasets/cmip6_parsers.py
@@ -127,10 +127,6 @@ def parse_cmip6_complete(file: str, **kwargs: Any) -> dict[str, Any]:
info["init_year"] = init_year
info["start_time"] = start_time
info["end_time"] = end_time
if not (start_time and end_time):
info["time_range"] = None
else:
info["time_range"] = f"{start_time}-{end_time}"
info["path"] = str(file)
info["version"] = extract_attr_with_regex(str(file), regex=r"v\d{4}\d{2}\d{2}|v\d{1}") or "v0"

@@ -180,7 +176,7 @@ def parse_cmip6_drs(file: str, **kwargs: Any) -> dict[str, Any]:

if info.get("time_range"):
# Parse the time range if it exists
start_time, end_time = _parse_daterange(info["time_range"])
start_time, end_time = _parse_daterange(info.pop("time_range"))
info["start_time"] = start_time
info["end_time"] = end_time

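The net effect of both parser hunks is that `time_range` no longer appears in the parsed record; only `start_time` and `end_time` survive. A hypothetical stand-in for that conversion is sketched below (the real splitting and calendar handling live in `_parse_daterange`, which this diff does not show):

```python
# Hypothetical illustration only; not the library's _parse_daterange implementation.
def split_time_range(time_range: str) -> tuple[str, str]:
    start_time, end_time = time_range.split("-", maxsplit=1)
    return start_time, end_time


info = {"time_range": "185001-201412"}
info["start_time"], info["end_time"] = split_time_range(info.pop("time_range"))
# info now carries start_time/end_time and no time_range key, matching the new parser output
```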