Skip to content
Merged
56 changes: 56 additions & 0 deletions ocf_data_sampler/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,13 +309,69 @@ class SolarPosition(TimeWindowMixin):
"""Solar position configuration model."""


class T0Embedding(Base):
"""Configuration for the t0 time embedding."""

periods: list[str] = Field(
default=[],
description="""List of periods to embed (e.g., "1h", "Nh", "1y", "Ny")""",
)

embeddings: list[str] = Field(
default=[],
description="List of embeddings to use for each period.",
)

@field_validator("periods")
def validate_periods(cls, periods: list[str]) -> list[str]:
"""Validate 'periods'."""
for period in periods:

if not isinstance(period, str):
raise ValueError(f"Each period must be a string, found {type(period)}")

unit = period[-1]
if unit not in ["h", "y"]:
raise ValueError(f"""Unit {unit} needs to in ["h","y"]""")

if not period[:-1].isdigit():
raise ValueError(f"{period[:-1]} not recognised as an integer")

if unit=="y" and not int(period[:-1])>0:
raise ValueError(f"When using unit y the period (={period[:-1]}) must be > 0")

if unit=="h" and not (1<=int(period[:-1])<=24):
raise ValueError(
f"When using unit h the period (={period[:-1]}) must be in interval [1, 24]",
)

return periods

@field_validator("embeddings")
def validate_embeddings(cls, embeddings: list[str]) -> list[str]:
"""Validator for 'embeddings'."""
for embedding in embeddings:
if embedding not in ["cyclic", "linear"]:
raise ValueError(f"Embedding ({embedding}) must be cyclic or linear")
return embeddings

@model_validator(mode="after")
def check_periods_and_embeddings_len(self) -> "T0Embedding":
"""Validate each period has an embedding."""
if len(self.periods)!=len(self.embeddings):
raise ValueError("The number of periods and embeddings must match")
return self


class InputData(Base):
"""Input data model."""

satellite: Satellite | None = None
nwp: MultiNWP | None = None
generation: Generation | None = None
solar_position: SolarPosition | None = None
t0_embedding: T0Embedding | None = None


class Configuration(Base):
"""Configuration model for the dataset."""
Expand Down
2 changes: 1 addition & 1 deletion ocf_data_sampler/numpy_sample/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Conversion from Xarray to NumpySample"""

from .datetime_features import encode_datetimes
from .datetime_features import encode_datetimes, get_t0_embedding
from .generation import convert_generation_to_numpy_sample, GenerationSampleKey
from .nwp import convert_nwp_to_numpy_sample, NWPSampleKey
from .satellite import convert_satellite_to_numpy_sample, SatelliteSampleKey
Expand Down
40 changes: 40 additions & 0 deletions ocf_data_sampler/numpy_sample/datetime_features.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Functions to create trigonometric date and time inputs."""

from typing import Literal

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -27,3 +29,41 @@ def encode_datetimes(datetimes: pd.DatetimeIndex) -> NumpySample:
"time_sin": np.sin(time_in_radians),
"time_cos": np.cos(time_in_radians),
}


def get_t0_embedding(
t0: pd.Timestamp,
periods: list[str],
embeddings: list[Literal["cyclic", "linear"]],
) -> dict[str, np.ndarray]:
"""Creates dictionary of t0 time embeddings.

Args:
t0: The time to create sin-cos embeddings for
periods: List of periods to encode (e.g., "1h", "Nh", "1y", "Ny")
embeddings: How to represent each of these periods. Either "cyclic" or "linear". When cyclic
the period is sin-cos embedded, else it is 0-1 scaled as fraction through the period.
Note that using "cyclic" adds 2 elements to the output vector to embed a period whilst
"linear" adds only 1 element.
"""
features = []

for period_str, embedding in zip(periods, embeddings, strict=True):

if period_str.endswith("h"):
period_hours = int(period_str.removesuffix("h"))
frac = (t0.hour + t0.minute / 60) / period_hours

elif period_str.endswith("y"):
period_years = int(period_str.removesuffix("y"))
days_in_year = 366 if t0.is_leap_year else 365
frac = (((t0.dayofyear-1) / days_in_year) + t0.year % period_years) / period_years

if embedding=="cyclic":
radians = 2 * np.pi * frac
features.extend([np.sin(radians), np.cos(radians)])

elif embedding=="linear":
features.append(frac)

return {"t0_embedding": np.array(features, dtype=np.float32)}
48 changes: 28 additions & 20 deletions ocf_data_sampler/torch_datasets/pvnet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
convert_nwp_to_numpy_sample,
convert_satellite_to_numpy_sample,
encode_datetimes,
get_t0_embedding,
make_sun_position_numpy_sample,
)
from ocf_data_sampler.numpy_sample.collate import stack_np_samples_into_batch
Expand Down Expand Up @@ -243,30 +244,37 @@ def process_and_combine_datasets(
},
)

# Add datetime features
datetimes = pd.DatetimeIndex(da_generation.time_utc.values)
datetime_features = encode_datetimes(datetimes=datetimes)
# Add datetime features
generation_config = self.config.input_data.generation
datetimes = pd.date_range(
t0 + minutes(generation_config.time_resolution_minutes),
t0 + minutes(generation_config.interval_end_minutes),
freq=minutes(generation_config.time_resolution_minutes),
)
numpy_modalities.append(encode_datetimes(datetimes=datetimes))

numpy_modalities.append(datetime_features)
if self.config.input_data.t0_embedding is not None:
emb_conf = self.config.input_data.t0_embedding
numpy_modalities.append(get_t0_embedding(t0, emb_conf.periods, emb_conf.embeddings))

# Only add solar position if explicitly configured
if self.config.input_data.solar_position is not None:
solar_config = self.config.input_data.solar_position
# Only add solar position if explicitly configured
if self.config.input_data.solar_position is not None:
solar_config = self.config.input_data.solar_position

# Create datetime range for solar position calculation
datetimes = pd.date_range(
t0 + minutes(solar_config.interval_start_minutes),
t0 + minutes(solar_config.interval_end_minutes),
freq=minutes(solar_config.time_resolution_minutes),
)
# Create datetime range for solar position calculation
datetimes = pd.date_range(
t0 + minutes(solar_config.interval_start_minutes),
t0 + minutes(solar_config.interval_end_minutes),
freq=minutes(solar_config.time_resolution_minutes),
)

numpy_modalities.append(
make_sun_position_numpy_sample(
datetimes,
da_generation.longitude.values,
da_generation.latitude.values,
),
)
numpy_modalities.append(
make_sun_position_numpy_sample(
datetimes,
da_generation.longitude.values,
da_generation.latitude.values,
),
)

# Combine all the modalities and fill NaNs
combined_sample = merge_dicts(numpy_modalities)
Expand Down
62 changes: 61 additions & 1 deletion tests/numpy_sample/test_datetime_features.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import numpy as np
import pandas as pd

from ocf_data_sampler.numpy_sample.datetime_features import encode_datetimes
from ocf_data_sampler.numpy_sample.datetime_features import (
encode_datetimes,
get_t0_embedding,
)


def test_encode_datetimes():
Expand All @@ -16,3 +19,60 @@ def test_encode_datetimes():
# Values should be between -1 and 1
for key in ("date_sin", "date_cos", "time_sin", "time_cos"):
assert np.all(np.abs(features[key]) <= 1)


def test_get_t0_embedding():

def check(t0s, period_strs, embeddings, xs, period_floats):
# Test the results are expected for each t0 time
for x, t0 in zip(xs, t0s, strict=False):
results = get_t0_embedding(t0, period_strs, embeddings)["t0_embedding"]

expected_results = []
for p, emb in zip(period_floats, embeddings, strict=False):
if emb=="cyclic":
expected_results.extend([np.sin(2*np.pi*(x / p)), np.cos(2*np.pi*(x / p))])
elif emb=="linear":
expected_results.append(x / p)
else:
raise ValueError

expected_results = np.array(expected_results)

assert len(expected_results)==len(results)

if not np.allclose(results, expected_results, atol=1e-6):
raise ValueError(f"{results}!={expected_results}")

# Define some t0 times and periods to check
t0s = pd.date_range("2024-01-01 00:00", "2024-01-01 06:00")
period_strs = ["1h", "1h", "2h", "6h"]
embeddings = ["linear", "cyclic", "cyclic", "cyclic"]

# Equivalent times and periods in float form
xs = np.linspace(0, 6, num=len(t0s))
period_floats = [1, 1, 2, 6]

check(t0s, period_strs, embeddings, xs, period_floats)

# Repeat the check focusing on year periods rather than hours
t0s = pd.to_datetime(
[
"2020-01-01 00:00", "2020-01-01 23:30", "2020-01-02 00:00",
"2020-06-10 00:00", "2021-01-01 00:00", "2021-01-02 00:00",
],
)
period_strs = ["1y", "2y"]
embeddings = ["cyclic", "cyclic"]


# Equivalent times and periods in float form
# Note:
# - When doing year encoding we don't consider time of day
# - 2020 is a leap year but 2021 is not
# - 2020-06-10 is the 162nd day of that year
xs = np.array([0, 0, 1/366, 161/366, 1, 1+1/365], dtype=np.float32)
period_floats = [1, 2]

check(t0s, period_strs, embeddings, xs, period_floats)

4 changes: 4 additions & 0 deletions tests/test_data/configs/pvnet_test_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,7 @@ input_data:
interval_start_minutes: -15
interval_end_minutes: 15
time_resolution_minutes: 5

t0_embedding:
periods: ["1h", "24h", "1y"]
embeddings: ["cyclic", "cyclic", "cyclic"]
4 changes: 3 additions & 1 deletion tests/torch_datasets/test_pvnet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_pvnet_dataset(pvnet_config_filename):
assert isinstance(sample, dict)

# Specific keys should always be present
required_keys = ["nwp", "satellite_actual", "generation", "t0"]
required_keys = ["nwp", "satellite_actual", "generation", "t0", "t0_embedding"]
for key in required_keys:
assert key in sample

Expand Down Expand Up @@ -63,6 +63,8 @@ def test_pvnet_dataset(pvnet_config_filename):
assert sample["nwp"]["ukv"]["nwp"].shape == (4, 1, 2, 2)
# 3 hours of 30 minute data (inclusive)
assert sample["generation"].shape == (7,)
# The config uses 3 periods each of which generates a sin and cos embedding
assert sample["t0_embedding"].shape == (6,)


def test_pvnet_dataset_sites(pvnet_site_config_filename):
Expand Down