Commit cb5df08

Add Queries for Runs Processing (#6)
1 parent 35b63b2 commit cb5df08

9 files changed: +943 -1 lines changed

.github/workflows/CI.yaml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10", "3.11", "3.12", "3.13"]
+        python-version: ["3.10"] # add later: "3.11", "3.12", "3.13"
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4

ablate/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -1 +1,6 @@
+from . import queries, sources
+
+
+__all__ = ["queries", "sources"]
+
 __version__ = "0.1.0"

ablate/queries/__init__.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from .grouped_query import GroupedQuery
+from .query import Query
+from .selectors import (
+    AbstractMetric,
+    AbstractParam,
+    AbstractSelector,
+    Id,
+    Metric,
+    Param,
+    TemporalMetric,
+)
+
+
+__all__ = [
+    "AbstractMetric",
+    "AbstractParam",
+    "AbstractSelector",
+    "GroupedQuery",
+    "Id",
+    "Metric",
+    "Param",
+    "Query",
+    "TemporalMetric",
+]
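
Side note (editorial, not part of the diff): because the subpackage re-exports these names in __all__, downstream code can import the query classes and selectors from a single place. A minimal sketch, assuming the package is installed as ablate:

    # All names listed in __all__ above are importable directly from ablate.queries.
    from ablate.queries import GroupedQuery, Metric, Query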

ablate/queries/grouped_query.py

Lines changed: 283 additions & 0 deletions
@@ -0,0 +1,283 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from copy import deepcopy
+from typing import TYPE_CHECKING, Callable, Dict, List, Literal
+
+from ablate.core.types import GroupedRun, Run
+
+
+if TYPE_CHECKING:  # pragma: no cover
+    from .query import Query  # noqa: TC004
+    from .selectors import AbstractMetric
+
+
+class GroupedQuery:
+    def __init__(self, groups: List[GroupedRun]) -> None:
+        """Query interface for manipulating grouped runs in a functional way.
+
+        All methods operate on a shallow copy of the runs in the query, so the original
+        runs are not modified and are assumed to be immutable.
+
+        Args:
+            groups: A list of grouped runs to be queried.
+        """
+        self._grouped = groups
+
+    def filter(self, fn: Callable[[GroupedRun], bool]) -> GroupedQuery:
+        """Filter the grouped runs in the grouped query based on a predicate function.
+
+        Args:
+            fn: Predicate function that takes in a grouped run and returns a boolean
+                value.
+
+        Returns:
+            A new grouped query with the grouped runs that satisfy the predicate
+            function.
+        """
+        return GroupedQuery([g for g in self._grouped[:] if fn(g)])
+
+    def map(self, fn: Callable[[GroupedRun], GroupedRun]) -> GroupedQuery:
+        """Apply a function to each grouped run in the grouped query.
+
+        This function is intended to be used for modifying the grouped runs in the
+        grouped query. The function should return a new grouped run object, as the
+        original grouped run is not modified.
+
+        Args:
+            fn: Function that takes in a grouped run and returns a new grouped run
+                object.
+
+        Returns:
+            A new grouped query with the modified grouped runs.
+        """
+        return GroupedQuery([fn(deepcopy(g)) for g in self._grouped])
+
+    def sort(self, key: AbstractMetric, ascending: bool = False) -> GroupedQuery:
+        """Sort the runs inside each grouped run in the grouped query based on a metric.
+
+        Args:
+            key: Metric to sort the grouped runs by.
+            ascending: Whether to sort in ascending order.
+                Defaults to False (descending order).
+
+        Returns:
+            A new grouped query with the grouped runs sorted by the specified metric.
+        """
+        return GroupedQuery(
+            [
+                GroupedRun(
+                    key=g.key,
+                    value=g.value,
+                    runs=sorted(g.runs, key=key, reverse=not ascending),
+                )
+                for g in self._grouped
+            ]
+        )
+
+    def head(self, n: int) -> Query:
+        """Get the first n runs inside each grouped run.
+
+        Args:
+            n: Number of runs to return per group.
+
+        Returns:
+            A new query with the first n runs from each grouped run.
+        """
+        return GroupedQuery(
+            [
+                GroupedRun(key=g.key, value=g.value, runs=g.runs[:n])
+                for g in self._grouped
+            ]
+        )._to_query()
+
+    def tail(self, n: int) -> Query:
+        """Get the last n runs inside each grouped run.
+
+        Args:
+            n: Number of runs to return per group.
+
+        Returns:
+            A new query with the last n runs from each grouped run.
+        """
+        return GroupedQuery(
+            [
+                GroupedRun(key=g.key, value=g.value, runs=g.runs[-n:])
+                for g in self._grouped
+            ]
+        )._to_query()
+
+    def topk(self, metric: AbstractMetric, k: int) -> Query:
+        """Get the top k runs inside each grouped run based on a metric.
+
+        Args:
+            metric: Metric to sort the runs by.
+            k: Number of top runs to return per group.
+
+        Returns:
+            A new query with the top k runs from each grouped run based on the
+            specified metric.
+        """
+        return GroupedQuery(
+            [
+                GroupedRun(
+                    key=g.key,
+                    value=g.value,
+                    runs=sorted(g.runs, key=metric, reverse=metric.direction == "min")[
+                        :k
+                    ],
+                )
+                for g in self._grouped
+            ]
+        )._to_query()
+
+    def bottomk(self, metric: AbstractMetric, k: int) -> Query:
+        """Get the bottom k runs inside each grouped run based on a metric.
+
+        Args:
+            metric: Metric to sort the runs by.
+            k: Number of bottom runs to return per group.
+
+        Returns:
+            A new query with the bottom k runs from each grouped run based on the
+            specified metric.
+        """
+        return GroupedQuery(
+            [
+                GroupedRun(
+                    key=g.key,
+                    value=g.value,
+                    runs=sorted(g.runs, key=metric, reverse=metric.direction == "max")[
+                        :k
+                    ],
+                )
+                for g in self._grouped
+            ]
+        )._to_query()
+
+    def aggregate(
+        self,
+        method: Literal["first", "last", "best", "worst", "mean"],
+        over: AbstractMetric,
+    ) -> Query:
+        """Aggregate each group of runs using a specified method.
+
+        Supported methods include:
+        - "first": Selects the first run from each group.
+        - "last": Selects the last run from each group.
+        - "best": Selects the run with the best value based on the given metric.
+        - "worst": Selects the run with the worst value based on the given metric.
+        - "mean": Computes the mean run across all runs in each group, including
+            averaged metrics and temporal data, and collapsed metadata.
+        Args:
+            method: Aggregation strategy to apply per group.
+            over: The metric used for comparison when using the "best" or "worst" methods.
+
+        Raises:
+            ValueError: If an unsupported aggregation method is provided.
+
+        Returns:
+            A new query with the aggregated runs from each group.
+        """
+        from .query import Query
+
+        match method:
+            case "first":
+                return self.head(1)
+            case "last":
+                return self.tail(1)
+            case "best":
+                return self.topk(over, 1)
+            case "worst":
+                return self.bottomk(over, 1)
+            case "mean":
+                return Query([self._mean_run(g) for g in self._grouped])
+            case _:
+                raise ValueError(
+                    f"Unsupported aggregation method: '{method}'. Must be "
+                    "'first', 'last', 'best', 'worst', or 'mean'."
+                )
+
+    @staticmethod
+    def _mean_run(group: GroupedRun) -> Run:
+        def _mean(values: List[float]) -> float:
+            return sum(values) / len(values) if values else float("nan")
+
+        def _mean_temporal(runs: List[Run]) -> Dict[str, List[tuple[int, float]]]:
+            all_keys = set().union(*(r.temporal.keys() for r in runs))
+            step_accumulator: Dict[str, Dict[int, List[float]]] = {}
+
+            for key in all_keys:
+                step_values = defaultdict(list)
+                for run in runs:
+                    for step, val in run.temporal.get(key, []):
+                        step_values[step].append(val)
+                step_accumulator[key] = step_values
+
+            return {
+                key: sorted(
+                    (step, sum(vals) / len(vals)) for step, vals in step_values.items()
+                )
+                for key, step_values in step_accumulator.items()
+            }
+
+        def _common_metadata(attr: str) -> Dict[str, str]:
+            key_sets = [getattr(r, attr).keys() for r in group.runs]
+            common_keys = set.intersection(*map(set, key_sets))
+            result = {}
+            for k in common_keys:
+                values = {str(getattr(r, attr)[k]) for r in group.runs}
+                result[k] = next(iter(values)) if len(values) == 1 else "#"
+            return result
+
+        all_metrics = [r.metrics for r in group.runs]
+        all_keys = set().union(*all_metrics)
+        mean_metrics = {
+            k: _mean([m[k] for m in all_metrics if k in m]) for k in all_keys
+        }
+
+        return Run(
+            id=f"grouped:{group.key}:{group.value}",
+            params=_common_metadata("params"),
+            metrics=mean_metrics,
+            temporal=_mean_temporal(group.runs),
+        )
+
+    def _to_query(self) -> Query:
+        from .query import Query
+
+        return Query([run for group in self._grouped for run in group.runs])
+
+    def all(self) -> List[Run]:
+        """Collect all runs in the grouped query by flattening the grouped runs.
+
+        Returns:
+            A list of all runs in the grouped query.
+        """
+        return deepcopy(self._to_query()._runs)
+
+    def copy(self) -> GroupedQuery:
+        """Obtain a shallow copy of the grouped query.
+
+        Returns:
+            A new grouped query with the same grouped runs as the original grouped
+            query.
+        """
+        return GroupedQuery(self._grouped[:])
+
+    def deepcopy(self) -> GroupedQuery:
+        """Obtain a deep copy of the grouped query.
+
+        Returns:
+            A new grouped query with deep copies of the grouped runs in the original
+            grouped query.
+        """
+        return GroupedQuery(deepcopy(self._grouped))
+
+    def __len__(self) -> int:
+        """Get the number of grouped runs in the grouped query.
+
+        Returns:
+            The number of grouped runs in the grouped query.
+        """
+        return len(self._grouped)
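
For orientation, a minimal usage sketch of the class above (editorial, not part of the commit). The GroupedRun and Run keyword arguments are taken from how this file constructs them, but their exact definitions live in ablate.core.types, which is not shown in this diff; the small _ValAcc class is a hypothetical stand-in for the real Metric selector from selectors.py (also not shown), since GroupedQuery only relies on a sort key being callable on a Run and exposing a direction attribute. The run values and group key are arbitrary example data.

    # Editorial sketch with assumed constructor signatures; see caveats above.
    from ablate.core.types import GroupedRun, Run
    from ablate.queries import GroupedQuery


    class _ValAcc:
        """Hypothetical stand-in for a Metric selector: callable on a Run,
        with a `direction` attribute."""

        direction = "max"  # assumed convention: higher validation accuracy is better

        def __call__(self, run: Run) -> float:
            return run.metrics["val/acc"]


    groups = [
        GroupedRun(
            key="params.model",
            value="resnet18",
            runs=[
                Run(id="a", params={"model": "resnet18"}, metrics={"val/acc": 0.91}, temporal={}),
                Run(id="b", params={"model": "resnet18"}, metrics={"val/acc": 0.89}, temporal={}),
            ],
        ),
    ]

    gq = GroupedQuery(groups)
    top_per_group = gq.topk(_ValAcc(), k=1)  # Query with one run per group (see topk docstring)
    mean_per_group = gq.aggregate(method="mean", over=_ValAcc())  # Query with one synthetic mean run per group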
