Skip to content

Commit 6bdce54

Browse files
authored
Adds plotting feature using matplotlib (#24)
1 parent f635416 commit 6bdce54

File tree

7 files changed

+400
-3
lines changed

7 files changed

+400
-3
lines changed

.github/workflows/run_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515

1616
strategy:
1717
matrix:
18-
os: [ubuntu-latest, windows-latest, macos-latest]
18+
os: [ubuntu-latest, windows-latest, macos-13]
1919
version: ["3.9", "3.x"]
2020
defaults:
2121
run:

RAT/utils/plotting.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""
2+
Plots using the matplotlib library
3+
"""
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
from RAT.rat_core import PlotEventData, makeSLDProfileXY
7+
8+
9+
class Figure:
10+
"""
11+
Creates a plotting figure.
12+
"""
13+
14+
def __init__(self, row: int = 1, col: int = 2):
15+
"""
16+
Initializes the figure and the subplots.
17+
18+
Parameters
19+
----------
20+
row : int
21+
The number of rows in subplot
22+
col : int
23+
The number of columns in subplot
24+
"""
25+
self._fig, self._ax = \
26+
plt.subplots(row, col, num="Reflectivity Algorithms Toolbox (RAT)")
27+
plt.show(block=False)
28+
self._esc_pressed = False
29+
self._close_clicked = False
30+
self._fig.canvas.mpl_connect("key_press_event",
31+
self._process_button_press)
32+
self._fig.canvas.mpl_connect('close_event',
33+
self._close)
34+
35+
def wait_for_close(self):
36+
"""
37+
Waits for the user to close the figure
38+
using the esc key.
39+
"""
40+
while not (self._esc_pressed or self._close_clicked):
41+
plt.waitforbuttonpress(timeout=0.005)
42+
plt.close(self._fig)
43+
44+
def _process_button_press(self, event):
45+
"""
46+
Process the key_press_event.
47+
"""
48+
if event.key == 'escape':
49+
self._esc_pressed = True
50+
51+
def _close(self, _):
52+
"""
53+
Process the close_event.
54+
"""
55+
self._close_clicked = True
56+
57+
58+
def plot_errorbars(ax, x, y, err, onesided, color):
59+
"""
60+
Plots the error bars.
61+
62+
Parameters
63+
----------
64+
ax : matplotlib.axes._axes.Axes
65+
The axis on which to draw errorbars
66+
x : np.ndarray
67+
The shifted data x axis data
68+
y : np.ndarray
69+
The shifted data y axis data
70+
err : np.ndarray
71+
The shifted data e data
72+
onesided : bool
73+
A boolean to indicate whether to draw one sided errorbars
74+
color : str
75+
The hex representing the color of the errorbars
76+
"""
77+
y_error = [[0]*len(err), err] if onesided else err
78+
ax.errorbar(x=x,
79+
y=y,
80+
yerr=y_error,
81+
fmt='none',
82+
ecolor=color,
83+
elinewidth=1,
84+
capsize=0)
85+
ax.scatter(x=x, y=y, s=3, marker="o", color=color)
86+
87+
88+
def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
89+
"""
90+
Clears the previous plots and updates the ref and SLD plots.
91+
92+
Parameters
93+
----------
94+
data : PlotEventData
95+
The plot event data that contains all the information
96+
to generate the ref and sld plots
97+
fig : Figure
98+
The figure class that has two subplots
99+
delay : bool
100+
Controls whether to delay 0.005s after plot is created
101+
102+
Returns
103+
-------
104+
fig : Figure
105+
The figure class that has two subplots
106+
"""
107+
if fig is None:
108+
fig = Figure()
109+
elif fig._ax.shape != (2,):
110+
fig._fig.clf()
111+
fig._ax = fig._fig.subplots(1, 2)
112+
113+
ref_plot = fig._ax[0]
114+
sld_plot = fig._ax[1]
115+
116+
# Clears the previous plots
117+
ref_plot.cla()
118+
sld_plot.cla()
119+
120+
for i, (r, sd, sld, layer) in enumerate(zip(data.reflectivity,
121+
data.shiftedData,
122+
data.sldProfiles,
123+
data.resampledLayers)):
124+
125+
r, sd, sld, layer = map(lambda x: x[0], (r, sd, sld, layer))
126+
127+
# Calculate the divisor
128+
div = 1 if i == 0 else 2**(4*(i+1))
129+
130+
# Plot the reflectivity on plot (1,1)
131+
ref_plot.plot(r[:, 0],
132+
r[:, 1]/div,
133+
label=f'ref {i+1}',
134+
linewidth=2)
135+
color = ref_plot.get_lines()[-1].get_color()
136+
137+
if data.dataPresent[i]:
138+
sd_x = sd[:, 0]
139+
sd_y, sd_e = map(lambda x: x/div, (sd[:, 1], sd[:, 2]))
140+
141+
# Plot the errorbars
142+
indices_removed = np.flip(np.nonzero(sd_y - sd_e < 0)[0])
143+
sd_x_r, sd_y_r, sd_e_r = map(lambda x:
144+
np.delete(x, indices_removed),
145+
(sd_x, sd_y, sd_e))
146+
plot_errorbars(ref_plot, sd_x_r, sd_y_r, sd_e_r, False, color)
147+
148+
# Plot one sided errorbars
149+
indices_selected = [x for x in indices_removed
150+
if x not in np.nonzero(sd_y < 0)[0]]
151+
sd_x_s, sd_y_s, sd_e_s = map(lambda x:
152+
[x[i] for i in indices_selected],
153+
(sd_x, sd_y, sd_e))
154+
plot_errorbars(ref_plot, sd_x_s, sd_y_s, sd_e_s, True, color)
155+
156+
# Plot the slds on plot (1,2)
157+
for j in range(1, sld.shape[1]):
158+
sld_plot.plot(sld[:, 0],
159+
sld[:, j],
160+
label=f'sld {i+1}',
161+
color=color,
162+
linewidth=2)
163+
164+
if data.resample[i] == 1 or data.modelType == 'custom xy':
165+
new = makeSLDProfileXY(layer[0, 1],
166+
layer[-1, 1],
167+
data.subRoughs[i],
168+
layer,
169+
len(layer),
170+
1.0)
171+
172+
sld_plot.plot([row[0]-49 for row in new],
173+
[row[1] for row in new],
174+
color=color,
175+
linewidth=1)
176+
177+
# Format the axis
178+
ref_plot.set_yscale('log')
179+
ref_plot.set_xscale('log')
180+
ref_plot.set_xlabel('Qz')
181+
ref_plot.set_ylabel('Ref')
182+
ref_plot.legend()
183+
ref_plot.grid()
184+
185+
sld_plot.set_xlabel('Z')
186+
sld_plot.set_ylabel('SLD')
187+
sld_plot.legend()
188+
sld_plot.grid()
189+
190+
if delay:
191+
plt.pause(0.005)
192+
193+
return fig

cpp/rat.cpp

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ setup_pybind11(cfg)
1616
#include "RAT/RATMain_initialize.h"
1717
#include "RAT/RATMain_terminate.h"
1818
#include "RAT/RATMain_types.h"
19+
#include "RAT/makeSLDProfileXY.h"
1920
#include "RAT/classHandle.hpp"
2021
#include "RAT/dylib.hpp"
2122
#include "RAT/events/eventManager.h"
@@ -1165,6 +1166,27 @@ py::tuple RATMain(const ProblemDefinition& problem_def, const Cells& cells, cons
11651166
bayesResultsFromStruct8T(bayesResults));
11661167
}
11671168

1169+
py::array_t<real_T> makeSLDProfileXY(real_T bulk_in,
1170+
real_T bulk_out,
1171+
real_T ssub,
1172+
const py::array_t<real_T> &layers,
1173+
real_T number_of_layers,
1174+
real_T repeats)
1175+
{
1176+
coder::array<real_T, 2U> out;
1177+
coder::array<real_T, 2U> layers_array = pyArrayToRatArray2d(layers);
1178+
RAT::makeSLDProfileXY(bulk_in,
1179+
bulk_out,
1180+
ssub,
1181+
layers_array,
1182+
number_of_layers,
1183+
repeats,
1184+
out);
1185+
1186+
return pyArrayFromRatArray2d(out);
1187+
1188+
}
1189+
11681190
class Module
11691191
{
11701192
public:
@@ -1434,5 +1456,7 @@ PYBIND11_MODULE(rat_core, m) {
14341456
.def_readwrite("fitLimits", &ProblemDefinition::fitLimits)
14351457
.def_readwrite("otherLimits", &ProblemDefinition::otherLimits);
14361458

1437-
m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation.");
1459+
m.def("RATMain", &RATMain, "Entry point for the main reflectivity computation.");
1460+
1461+
m.def("makeSLDProfileXY", &makeSLDProfileXY, "Creates the profiles for the SLD plots");
14381462
}

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ pybind11 >= 2.4
44
pydantic >= 2.4.2, <= 2.6.4
55
pytest >= 7.4.0
66
pytest-cov >= 4.1.0
7+
matplotlib >= 3.8.3
78
StrEnum >= 0.4.15; python_version < '3.11'

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def build_libraries(self, libraries):
159159
libraries = [libevent],
160160
ext_modules = ext_modules,
161161
python_requires = '>=3.9',
162-
install_requires = ['numpy >= 1.20', 'prettytable >= 3.9.0', 'pydantic >= 2.4.2, <= 2.6.4'],
162+
install_requires = ['numpy >= 1.20', 'prettytable >= 3.9.0', 'pydantic >= 2.4.2, <= 2.6.4', 'matplotlib >= 3.8.3'],
163163
extras_require = {':python_version < "3.11"': ['StrEnum >= 0.4.15'],
164164
'Dev': ['pytest>=7.4.0', 'pytest-cov>=4.1.0'],
165165
'Matlab_latest': ['matlabengine'],
27 KB
Binary file not shown.

0 commit comments

Comments
 (0)