Skip to content

Commit 6ebfcf1

Browse files
authored
Add Selector Expressions (#21)
1 parent 32d10a8 commit 6ebfcf1

File tree

5 files changed

+223
-77
lines changed

5 files changed

+223
-77
lines changed

README.md

Lines changed: 68 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ To create your first [Report](https://ramppdev.github.io/ablate/modules/report.h
5959
For example, the built in [Mock](https://ramppdev.github.io/ablate/modules/sources.html#mock-source) can be used to simulate runs:
6060

6161
```python
62-
import ablate
62+
from ablate import sources
6363

64-
source = ablate.sources.Mock(
64+
source = Mock(
6565
grid={"model": ["vgg", "resnet"], "lr": [0.01, 0.001]},
6666
num_seeds=2,
6767
)
@@ -73,10 +73,12 @@ Next, the runs can be loaded and processed using functional-style queries to e.g
7373
group by seed, aggregate the results by mean, and finally collect all results into a single list:
7474

7575
```python
76+
from ablate.queries import Query, Metric, Param
77+
7678
runs = (
77-
ablate.queries.Query(source.load())
78-
.sort(ablate.queries.Metric("accuracy", direction="max"))
79-
.groupdiff(ablate.queries.Param("seed"))
79+
Query(source.load())
80+
.sort(Metric("accuracy", direction="max"))
81+
.groupdiff(Param("seed"))
8082
.aggregate("mean")
8183
.all()
8284
)
@@ -87,16 +89,19 @@ Now that the runs are loaded and processed, a [Report](https://ramppdev.github.i
8789
comprising multiple blocks can be created to structure the content:
8890

8991
```python
90-
report = ablate.Report(runs)
91-
report.add(ablate.blocks.H1("Model Performance"))
92+
from ablate import Report
93+
from ablate.blocks import H1, Table
94+
95+
report = Report(runs)
96+
report.add(H1("Model Performance"))
9297
report.add(
93-
ablate.blocks.Table(
98+
Table(
9499
columns=[
95-
ablate.queries.Param("model", label="Model"),
96-
ablate.queries.Param("lr", label="Learning Rate"),
97-
ablate.queries.Metric("accuracy", direction="max", label="Accuracy"),
98-
ablate.queries.Metric("f1", direction="max", label="F1 Score"),
99-
ablate.queries.Metric("loss", direction="min", label="Loss"),
100+
Param("model", label="Model"),
101+
Param("lr", label="Learning Rate"),
102+
Metric("accuracy", direction="max", label="Accuracy"),
103+
Metric("f1", direction="max", label="F1 Score"),
104+
Metric("loss", direction="min", label="Loss"),
100105
]
101106
)
102107
)
@@ -105,7 +110,9 @@ report.add(
105110
Finally, the report can be exported to a desired format such as [Markdown](https://ramppdev.github.io/ablate/modules/exporters.html#ablate.exporters.Markdown):
106111

107112
```python
108-
ablate.exporters.Markdown().export(report)
113+
from ablate.exporters import Markdown
114+
115+
Markdown().export(report)
109116
```
110117

111118
This will produce a `report.md` file with the following content:
@@ -127,24 +134,53 @@ To compose multiple sources, they can be added together using the `+` operator
127134
as they represent lists of [Run](https://ramppdev.github.io/ablate/modules/core.html#ablate.core.types.Run) objects:
128135

129136
```python
130-
runs1 = ablate.sources.Mock(...).load()
131-
runs2 = ablate.sources.Mock(...).load()
137+
runs1 = Mock(...).load()
138+
runs2 = Mock(...).load()
132139

133140
all_runs = runs1 + runs2 # combines both sources into a single list of runs
134141
```
135142

143+
### Selector Expressions
144+
145+
_ablate_ selectors are lightweight expressions that access attributes of experiment runs, such as parameters, metrics, or IDs.
146+
They support standard Python comparison operators and can be composed using logical operators to define complex query logic:
147+
148+
```python
149+
accuracy = Metric("accuracy", direction="max")
150+
loss = Metric("loss", direction="min")
151+
152+
runs = (
153+
Query(source.load())
154+
.filter((accuracy > 0.9) & (loss < 0.1))
155+
.all()
156+
)
157+
```
158+
159+
Selectors return callable predicates, so they can be used in any query operation that requires a condition.
160+
All standard comparisons are supported: `==`, `!=`, `<`, `<=`, `>`, `>=`.
161+
Logical operators `&` (and), `|` (or), and `~` (not) can be used to combine expressions:
162+
163+
```python
164+
from ablate.queries import Id
165+
166+
select = (Param("model") == "resnet") | (Param("lr") < 0.001) # select resnet or LR below 0.001
167+
168+
exclude = ~(Id() == "run-42") # exclude a specific run by ID
169+
170+
runs = Query(source.load()).filter(select & exclude).all()
171+
172+
```
173+
136174
### Functional Queries
137175

138176
_ablate_ queries are functionally pure such that intermediate results are not modified and can be reused:
139177

140178
```python
141-
runs = ablate.sources.Mock(...).load()
179+
runs = Mock(...).load()
142180

143-
sorted_runs = Query(runs).sort(ablate.queries.Metric("accuracy", direction="max"))
181+
sorted_runs = Query(runs).sort(Metric("accuracy", direction="max"))
144182

145-
filtered_runs = sorted_runs.filter(
146-
ablate.queries.Metric("accuracy", direction="max") > 0.9
147-
)
183+
filtered_runs = sorted_runs.filter(Metric("accuracy", direction="max") > 0.9)
148184

149185
sorted_runs.all() # still contains all runs sorted by accuracy
150186
filtered_runs.all() # only contains runs with accuracy > 0.9
@@ -157,25 +193,25 @@ To create more complex reports, blocks can be populated with a custom list of ru
157193

158194
```python
159195
report = ablate.Report(sorted_runs.all())
160-
report.add(ablate.blocks.H1("Report with Sorted Runs and Filtered Runs"))
161-
report.add(ablate.blocks.H2("Sorted Runs"))
196+
report.add(H1("Report with Sorted Runs and Filtered Runs"))
197+
report.add(H2("Sorted Runs"))
162198
report.add(
163-
ablate.blocks.Table(
199+
Table(
164200
columns=[
165-
ablate.queries.Param("model", label="Model"),
166-
ablate.queries.Param("lr", label="Learning Rate"),
167-
ablate.queries.Metric("accuracy", direction="max", label="Accuracy"),
201+
Param("model", label="Model"),
202+
Param("lr", label="Learning Rate"),
203+
Metric("accuracy", direction="max", label="Accuracy"),
168204
]
169205
)
170206
)
171-
report.add(ablate.blocks.H2("Filtered Runs"))
207+
report.add(H2("Filtered Runs"))
172208
report.add(
173-
ablate.blocks.Table(
209+
Table(
174210
runs = filtered_runs.all(), # use filtered runs only for this block
175211
columns=[
176-
ablate.queries.Param("model", label="Model"),
177-
ablate.queries.Param("lr", label="Learning Rate"),
178-
ablate.queries.Metric("accuracy", direction="max", label="Accuracy"),
212+
Param("model", label="Model"),
213+
Param("lr", label="Learning Rate"),
214+
Metric("accuracy", direction="max", label="Accuracy"),
179215
]
180216
)
181217
)

README.rst

Lines changed: 73 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ For example, the built in :class:`~ablate.sources.Mock` can be used to simulate
8080
.. code-block:: python
8181
:linenos:
8282
83-
import ablate
83+
from ablate.sources import Mock
8484
85-
source = ablate.sources.Mock(
85+
source = Mock(
8686
grid={"model": ["vgg", "resnet"], "lr": [0.01, 0.001]},
8787
num_seeds=2,
8888
)
@@ -95,10 +95,12 @@ group by seed, aggregate the results by mean, and finally collect all results in
9595
.. code-block:: python
9696
:linenos:
9797
98+
from ablate.queries import Metric, Param, Query
99+
98100
runs = (
99-
ablate.queries.Query(source.load())
100-
.sort(ablate.queries.Metric("accuracy", direction="max"))
101-
.groupdiff(ablate.queries.Param("seed"))
101+
Query(source.load())
102+
.sort(Metric("accuracy", direction="max"))
103+
.groupdiff(Param("seed"))
102104
.aggregate("mean")
103105
.all()
104106
)
@@ -109,16 +111,19 @@ can be created to structure the content:
109111
.. code-block:: python
110112
:linenos:
111113
112-
report = ablate.Report(runs)
113-
report.add(ablate.blocks.H1("Model Performance"))
114+
from ablate import Report
115+
from ablate.blocks import H1, Table
116+
117+
report = Report(runs)
118+
report.add(H1("Model Performance"))
114119
report.add(
115-
ablate.blocks.Table(
120+
Table(
116121
columns=[
117-
ablate.queries.Param("model", label="Model"),
118-
ablate.queries.Param("lr", label="Learning Rate"),
119-
ablate.queries.Metric("accuracy", direction="max", label="Accuracy"),
120-
ablate.queries.Metric("f1", direction="max", label="F1 Score"),
121-
ablate.queries.Metric("loss", direction="min", label="Loss"),
122+
Param("model", label="Model"),
123+
Param("lr", label="Learning Rate"),
124+
Metric("accuracy", direction="max", label="Accuracy"),
125+
Metric("f1", direction="max", label="F1 Score"),
126+
Metric("loss", direction="min", label="Loss"),
122127
]
123128
)
124129
)
@@ -128,7 +133,9 @@ Finally, the report can be exported to a desired format such as :class:`~ablate.
128133
.. code-block:: python
129134
:linenos:
130135
131-
ablate.exporters.Markdown().export(report)
136+
from ablate.exporters import Markdown
137+
138+
Markdown().export(report)
132139
133140
This will produce a :file:`report.md` file with the following content:
134141

@@ -153,12 +160,47 @@ as they represent lists of :class:`~ablate.core.types.Run` objects:
153160
.. code-block:: python
154161
:linenos:
155162
156-
runs1 = ablate.sources.Mock(...).load()
157-
runs2 = ablate.sources.Mock(...).load()
163+
runs1 = Mock(...).load()
164+
runs2 = Mock(...).load()
158165
159166
all_runs = runs1 + runs2 # combines both sources into a single list of runs
160167
161168
169+
Selector Expressions
170+
~~~~~~~~~~~~~~~~~~~~
171+
172+
`ablate` selectors are lightweight expressions that access attributes of experiment runs, such as parameters, metrics, or IDs.
173+
They support standard Python comparison operators and can be composed using logical operators to define complex query logic:
174+
175+
.. code-block:: python
176+
:linenos:
177+
178+
accuracy = Metric("accuracy", direction="max")
179+
loss = Metric("loss", direction="min")
180+
181+
runs = (
182+
Query(source.load())
183+
.filter((accuracy > 0.9) & (loss < 0.1))
184+
.all()
185+
)
186+
187+
188+
Selectors return callable predicates, so they can be used in any query operation that requires a condition.
189+
All standard comparisons are supported: :attr:`==`, :attr:`!=`, :attr:`<`, :attr:`<=`, :attr:`>`, :attr:`>=`.
190+
Logical operators :attr:`&` (and), :attr:`|` (or), and :attr:`~~` (not) can be used to combine expressions:
191+
192+
.. code-block:: python
193+
:linenos:
194+
195+
from ablate.queries import Id
196+
197+
select = (Param("model") == "resnet") | (Param("lr") < 0.001) # select resnet or LR below 0.001
198+
199+
exclude = ~(Id() == "run-42") # exclude a specific run by ID
200+
201+
runs = Query(source.load()).filter(select & exclude).all()
202+
203+
162204
Functional Queries
163205
~~~~~~~~~~~~~~~~~~
164206

@@ -167,13 +209,11 @@ Functional Queries
167209
.. code-block:: python
168210
:linenos:
169211
170-
runs = ablate.sources.Mock(...).load()
212+
runs = Mock(...).load()
171213
172-
sorted_runs = Query(runs).sort(ablate.queries.Metric("accuracy", direction="max"))
214+
sorted_runs = Query(runs).sort(Metric("accuracy", direction="max"))
173215
174-
filtered_runs = sorted_runs.filter(
175-
ablate.queries.Metric("accuracy", direction="max") > 0.9
176-
)
216+
filtered_runs = sorted_runs.filter(Metric("accuracy", direction="max") > 0.9)
177217
178218
sorted_runs.all() # still contains all runs sorted by accuracy
179219
filtered_runs.all() # only contains runs with accuracy > 0.9
@@ -189,29 +229,30 @@ To create more complex reports, blocks can be populated with a custom list of ru
189229
:linenos:
190230
191231
report = ablate.Report(sorted_runs.all())
192-
report.add(ablate.blocks.H1("Report with Sorted Runs and Filtered Runs"))
193-
report.add(ablate.blocks.H2("Sorted Runs"))
232+
report.add(H1("Report with Sorted Runs and Filtered Runs"))
233+
report.add(H2("Sorted Runs"))
194234
report.add(
195-
ablate.blocks.Table(
235+
Table(
196236
columns=[
197-
ablate.queries.Param("model", label="Model"),
198-
ablate.queries.Param("lr", label="Learning Rate"),
199-
ablate.queries.Metric("accuracy", direction="max", label="Accuracy"),
237+
Param("model", label="Model"),
238+
Param("lr", label="Learning Rate"),
239+
Metric("accuracy", direction="max", label="Accuracy"),
200240
]
201241
)
202242
)
203-
report.add(ablate.blocks.H2("Filtered Runs"))
243+
report.add(H2("Filtered Runs"))
204244
report.add(
205-
ablate.blocks.Table(
245+
Table(
206246
runs = filtered_runs.all(), # use filtered runs only for this block
207247
columns=[
208-
ablate.queries.Param("model", label="Model"),
209-
ablate.queries.Param("lr", label="Learning Rate"),
210-
ablate.queries.Metric("accuracy", direction="max", label="Accuracy"),
248+
Param("model", label="Model"),
249+
Param("lr", label="Learning Rate"),
250+
Metric("accuracy", direction="max", label="Accuracy"),
211251
]
212252
)
213253
)
214254
255+
215256
Extending `ablate`
216257
------------------
217258

0 commit comments

Comments
 (0)