Skip to content

Commit d4fb0bb

Browse files
authored
Added plot legend and updates tests (#59)
1 parent 70bcb3f commit d4fb0bb

File tree

8 files changed

+61
-20
lines changed

8 files changed

+61
-20
lines changed

.github/workflows/run_ruff.yml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
name: Ruff
22

3-
on:
4-
push:
5-
branches: [ main ]
6-
pull_request:
7-
branches: [ main ]
3+
on: workflow_call
84

95
jobs:
106
ruff:

.github/workflows/run_tests.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ concurrency:
1111
cancel-in-progress: true
1212

1313
jobs:
14-
test:
14+
run_ruff:
15+
uses: ./.github/workflows/run_ruff.yml
1516

17+
test:
18+
needs: [run_ruff]
1619
strategy:
1720
fail-fast: false
1821
matrix:

RATapi/inputs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ def make_cells(project: RATapi.Project) -> Cells:
387387
]
388388

389389
cells.f20 = [param.name for param in project.domain_ratios]
390+
cells.f21 = [contrast.name for contrast in project.contrasts]
390391

391392
return cells
392393

RATapi/utils/plotting.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,14 +74,14 @@ def plot_ref_sld_helper(data: PlotEventData, fig: Optional[plt.figure] = None, d
7474
ref_plot.cla()
7575
sld_plot.cla()
7676

77-
for i, (r, sd, sld, layer) in enumerate(
78-
zip(data.reflectivity, data.shiftedData, data.sldProfiles, data.resampledLayers),
77+
for i, (r, sd, sld, name) in enumerate(
78+
zip(data.reflectivity, data.shiftedData, data.sldProfiles, data.contrastNames),
7979
):
8080
# Calculate the divisor
8181
div = 1 if i == 0 else 2 ** (4 * (i + 1))
8282

8383
# Plot the reflectivity on plot (1,1)
84-
ref_plot.plot(r[:, 0], r[:, 1] / div, label=f"ref {i+1}", linewidth=2)
84+
ref_plot.plot(r[:, 0], r[:, 1] / div, label=name, linewidth=2)
8585
color = ref_plot.get_lines()[-1].get_color()
8686

8787
if data.dataPresent[i]:
@@ -100,7 +100,8 @@ def plot_ref_sld_helper(data: PlotEventData, fig: Optional[plt.figure] = None, d
100100

101101
# Plot the slds on plot (1,2)
102102
for j in range(len(sld)):
103-
sld_plot.plot(sld[j][:, 0], sld[j][:, 1], label=f"sld {i+1}", linewidth=1)
103+
label = name if len(sld) == 1 else f"{name} Domain {j+1}"
104+
sld_plot.plot(sld[j][:, 0], sld[j][:, 1], label=label, linewidth=1)
104105

105106
if data.resample[i] == 1 or data.modelType == "custom xy":
106107
layers = data.resampledLayers[i][0]
@@ -172,6 +173,7 @@ def plot_ref_sld(
172173
data.dataPresent = RATapi.inputs.make_data_present(project)
173174
data.subRoughs = results.contrastParams.subRoughs
174175
data.resample = RATapi.inputs.make_resample(project)
176+
data.contrastNames = [contrast.name for contrast in project.contrasts]
175177

176178
figure = plt.subplots(1, 2)[0]
177179

cpp/rat.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ struct PlotEventData
171171
py::array_t<double> resample;
172172
py::array_t<double> dataPresent;
173173
std::string modelType;
174+
py::list contrastNames;
174175
};
175176

176177
class EventBridge
@@ -271,6 +272,13 @@ class EventBridge
271272
eventData.resampledLayers = unpackDataToCell(pEvent->data->nContrast, (pEvent->data->nLayers2 == NULL) ? 1 : 2,
272273
pEvent->data->layers, pEvent->data->nLayers,
273274
pEvent->data->layers2, pEvent->data->nLayers2, true);
275+
276+
int offset = 0;
277+
for (int i = 0; i < pEvent->data->nContrast; i++){
278+
eventData.contrastNames.append(std::string(pEvent->data->contrastNames + offset,
279+
pEvent->data->nContrastNames[i]));
280+
offset += pEvent->data->nContrastNames[i];
281+
}
274282
this->callback(event.type, eventData);
275283
}
276284
};
@@ -445,6 +453,7 @@ struct Cells {
445453
py::list f18;
446454
py::list f19;
447455
py::list f20;
456+
py::list f21;
448457
};
449458

450459
struct ProblemDefinition {
@@ -835,7 +844,7 @@ RAT::cell_7 createCell7(const Cells& cells)
835844
cells_struct.f2 = customCaller("Cells.f2", pyListToRatCellWrap3, cells.f2);
836845
cells_struct.f3 = customCaller("Cells.f3", pyListToRatCellWrap2, cells.f3);
837846
cells_struct.f4 = customCaller("Cells.f4", pyListToRatCellWrap2, cells.f4);
838-
cells_struct.f5 = customCaller("Cells.f5", pyListToRatCellWrap4, cells.f5); //improve this error
847+
cells_struct.f5 = customCaller("Cells.f5", pyListToRatCellWrap4, cells.f5);
839848
cells_struct.f6 = customCaller("Cells.f6", pyListToRatCellWrap5, cells.f6);
840849
cells_struct.f7 = customCaller("Cells.f7", pyListToRatCellWrap6, cells.f7);
841850
cells_struct.f8 = customCaller("Cells.f8", pyListToRatCellWrap6, cells.f8);
@@ -851,6 +860,7 @@ RAT::cell_7 createCell7(const Cells& cells)
851860
cells_struct.f18 = customCaller("Cells.f18", pyListToRatCellWrap2, cells.f18);
852861
cells_struct.f19 = customCaller("Cells.f19", pyListToRatCellWrap4, cells.f19);
853862
cells_struct.f20 = customCaller("Cells.f20", pyListToRatCellWrap6, cells.f20);
863+
cells_struct.f21 = customCaller("Cells.f21", pyListToRatCellWrap6, cells.f21);
854864

855865
return cells_struct;
856866
}
@@ -1257,7 +1267,8 @@ PYBIND11_MODULE(rat_core, m) {
12571267
.def_readwrite("subRoughs", &PlotEventData::subRoughs)
12581268
.def_readwrite("resample", &PlotEventData::resample)
12591269
.def_readwrite("dataPresent", &PlotEventData::dataPresent)
1260-
.def_readwrite("modelType", &PlotEventData::modelType);
1270+
.def_readwrite("modelType", &PlotEventData::modelType)
1271+
.def_readwrite("contrastNames", &PlotEventData::contrastNames);
12611272

12621273
py::class_<ProgressEventData>(m, "ProgressEventData")
12631274
.def(py::init<>())
@@ -1402,7 +1413,8 @@ PYBIND11_MODULE(rat_core, m) {
14021413
.def_readwrite("f17", &Cells::f17)
14031414
.def_readwrite("f18", &Cells::f18)
14041415
.def_readwrite("f19", &Cells::f19)
1405-
.def_readwrite("f20", &Cells::f20);
1416+
.def_readwrite("f20", &Cells::f20)
1417+
.def_readwrite("f21", &Cells::f21);
14061418

14071419
py::class_<Control>(m, "Control")
14081420
.def(py::init<>())

tests/test_inputs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,7 @@ def standard_layers_cells():
281281
cells.f18 = []
282282
cells.f19 = []
283283
cells.f20 = []
284+
cells.f21 = ["D2O"]
284285

285286
return cells
286287

@@ -309,6 +310,7 @@ def domains_cells():
309310
cells.f18 = [[0, 1], [0, 1]]
310311
cells.f19 = [[1], [1]]
311312
cells.f20 = ["Domain Ratio 1"]
313+
cells.f21 = ["D2O"]
312314

313315
return cells
314316

@@ -337,6 +339,7 @@ def custom_xy_cells():
337339
cells.f18 = []
338340
cells.f19 = []
339341
cells.f20 = []
342+
cells.f21 = ["D2O"]
340343

341344
return cells
342345

tests/test_plotting.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
def data() -> PlotEventData:
16-
"""Creates the fixture for the tests."""
16+
"""Creates the data for the tests."""
1717
data_path = os.path.join(TEST_DIR_PATH, "plotting_data.pickle")
1818
with open(data_path, "rb") as f:
1919
loaded_data = pickle.load(f)
@@ -27,17 +27,29 @@ def data() -> PlotEventData:
2727
data.reflectivity = loaded_data["reflectivity"]
2828
data.shiftedData = loaded_data["shiftedData"]
2929
data.sldProfiles = loaded_data["sldProfiles"]
30+
data.contrastNames = ["D2O", "SMW", "H2O"]
31+
3032
return data
3133

3234

33-
@pytest.fixture
34-
def fig() -> plt.figure:
35+
def domains_data() -> PlotEventData:
36+
"""Creates the fake domains data for the tests."""
37+
domains_data = data()
38+
for sld_list in domains_data.sldProfiles:
39+
sld_list.append(sld_list[0])
40+
41+
return domains_data
42+
43+
44+
@pytest.fixture(params=[False])
45+
def fig(request) -> plt.figure:
3546
"""Creates the fixture for the tests."""
3647
plt.close("all")
3748
figure = plt.subplots(1, 3)[0]
38-
return plot_ref_sld_helper(fig=figure, data=data())
49+
return plot_ref_sld_helper(fig=figure, data=domains_data() if request.param else data())
3950

4051

52+
@pytest.mark.parametrize("fig", [False, True], indirect=True)
4153
def test_figure_axis_formating(fig: plt.figure) -> None:
4254
"""Tests the axis formating of the figure."""
4355
ref_plot = fig.axes[0]
@@ -50,13 +62,24 @@ def test_figure_axis_formating(fig: plt.figure) -> None:
5062
assert ref_plot.get_xscale() == "log"
5163
assert ref_plot.get_ylabel() == "Reflectivity"
5264
assert ref_plot.get_yscale() == "log"
53-
assert [label._text for label in ref_plot.get_legend().texts] == ["ref 1", "ref 2", "ref 3"]
65+
assert [label._text for label in ref_plot.get_legend().texts] == ["D2O", "SMW", "H2O"]
5466

5567
assert sld_plot.get_xlabel() == "$Z (\u00c5)$"
5668
assert sld_plot.get_xscale() == "linear"
5769
assert sld_plot.get_ylabel() == "$SLD (\u00c5^{-2})$"
5870
assert sld_plot.get_yscale() == "linear"
59-
assert [label._text for label in sld_plot.get_legend().texts] == ["sld 1", "sld 2", "sld 3"]
71+
labels = [label._text for label in sld_plot.get_legend().texts]
72+
if len(labels) == 3:
73+
assert labels == ["D2O", "SMW", "H2O"]
74+
else:
75+
assert labels == [
76+
"D2O Domain 1",
77+
"D2O Domain 2",
78+
"SMW Domain 1",
79+
"SMW Domain 2",
80+
"H2O Domain 1",
81+
"H2O Domain 2",
82+
]
6083

6184

6285
def test_figure_color_formating(fig: plt.figure) -> None:
@@ -157,3 +180,4 @@ def test_plot_ref_sld(mock: MagicMock, input_project, reflectivity_calculation_r
157180
assert data.dataPresent.size == 0
158181
assert (data.subRoughs == reflectivity_calculation_results.contrastParams.subRoughs).all()
159182
assert data.resample.size == 0
183+
assert len(data.contrastNames) == 0

0 commit comments

Comments
 (0)