Skip to content

Commit 927baaf

Browse files
authored
Adds live plot context manager and update events (#49)
* Adds live plot context manager and update events * Addresses review comments
1 parent b38bd3a commit 927baaf

File tree

8 files changed

+430
-398
lines changed

8 files changed

+430
-398
lines changed

RATapi/events.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,26 @@ def register(event_type: EventTypes, callback: Callable[[Union[str, PlotEventDat
5757
__event_callbacks[event_type].add(callback)
5858

5959

60-
def clear() -> None:
61-
"""Clears all event callbacks."""
62-
__event_impl.clear()
63-
for key in __event_callbacks:
64-
__event_callbacks[key] = set()
60+
def clear(key=None, callback=None) -> None:
61+
"""Clears all event callbacks or specific callback.
62+
63+
Parameters
64+
----------
65+
callback : Callable[[Union[str, PlotEventData, ProgressEventData]], None]
66+
The callback for when the event is triggered.
67+
68+
"""
69+
if key is None and callback is None:
70+
for key in __event_callbacks:
71+
__event_callbacks[key] = set()
72+
elif key is not None and callback is not None:
73+
__event_callbacks[key].remove(callback)
74+
75+
for value in __event_callbacks.values():
76+
if value:
77+
break
78+
else:
79+
__event_impl.clear()
6580

6681

6782
dir_path = os.path.dirname(os.path.realpath(__file__))

RATapi/utils/plotting.py

Lines changed: 69 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,45 +12,6 @@
1212
from RATapi.rat_core import PlotEventData, makeSLDProfileXY
1313

1414

15-
class Figure:
16-
"""Creates a plotting figure."""
17-
18-
def __init__(self, row: int = 1, col: int = 1):
19-
"""Initializes the figure and the subplots.
20-
21-
Parameters
22-
----------
23-
row : int, default: 1
24-
The number of rows in subplot
25-
col : int, default: 1
26-
The number of columns in subplot
27-
28-
"""
29-
self._fig, self._ax = plt.subplots(row, col, num="Reflectivity Algorithms Toolbox (RAT)")
30-
plt.show(block=False)
31-
self._esc_pressed = False
32-
self._close_clicked = False
33-
self._fig.canvas.mpl_connect("key_press_event", self._process_button_press)
34-
self._fig.canvas.mpl_connect("close_event", self._close)
35-
36-
def wait_for_close(self):
37-
"""Waits for the user to close the figure
38-
using the esc key.
39-
"""
40-
while not (self._esc_pressed or self._close_clicked):
41-
plt.waitforbuttonpress(timeout=0.005)
42-
plt.close(self._fig)
43-
44-
def _process_button_press(self, event):
45-
"""Process the key_press_event."""
46-
if event.key == "escape":
47-
self._esc_pressed = True
48-
49-
def _close(self, _):
50-
"""Process the close_event."""
51-
self._close_clicked = True
52-
53-
5415
def plot_errorbars(ax: Axes, x: np.ndarray, y: np.ndarray, err: np.ndarray, one_sided: bool, color: str):
5516
"""Plots the error bars.
5617
@@ -75,33 +36,39 @@ def plot_errorbars(ax: Axes, x: np.ndarray, y: np.ndarray, err: np.ndarray, one_
7536
ax.scatter(x=x, y=y, s=3, marker="o", color=color)
7637

7738

78-
def plot_ref_sld_helper(data: PlotEventData, fig: Optional[Figure] = None, delay: bool = True):
39+
def plot_ref_sld_helper(data: PlotEventData, fig: Optional[plt.figure] = None, delay: bool = True):
7940
"""Clears the previous plots and updates the ref and SLD plots.
8041
8142
Parameters
8243
----------
8344
data : PlotEventData
8445
The plot event data that contains all the information
8546
to generate the ref and sld plots
86-
fig : Figure, optional
47+
fig : matplotlib.pyplot.figure, optional
8748
The figure class that has two subplots
8849
delay : bool, default: True
8950
Controls whether to delay 0.005s after plot is created
9051
9152
Returns
9253
-------
93-
fig : Figure
54+
fig : matplotlib.pyplot.figure
9455
The figure class that has two subplots
9556
9657
"""
97-
if fig is None:
98-
fig = Figure(1, 2)
99-
elif fig._ax.shape != (2,):
100-
fig._fig.clf()
101-
fig._ax = fig._fig.subplots(1, 2)
58+
preserve_zoom = False
10259

103-
ref_plot = fig._ax[0]
104-
sld_plot = fig._ax[1]
60+
if fig is None:
61+
fig = plt.subplots(1, 2)[0]
62+
elif len(fig.axes) != 2:
63+
fig.clf()
64+
fig.subplots(1, 2)
65+
fig.subplots_adjust(wspace=0.3)
66+
67+
ref_plot = fig.axes[0]
68+
sld_plot = fig.axes[1]
69+
if ref_plot.lines and fig.canvas.toolbar is not None:
70+
preserve_zoom = True
71+
fig.canvas.toolbar.push_current()
10572

10673
# Clears the previous plots
10774
ref_plot.cla()
@@ -160,16 +127,18 @@ def plot_ref_sld_helper(data: PlotEventData, fig: Optional[Figure] = None, delay
160127
# Format the axis
161128
ref_plot.set_yscale("log")
162129
ref_plot.set_xscale("log")
163-
ref_plot.set_xlabel("Qz")
164-
ref_plot.set_ylabel("Ref")
130+
ref_plot.set_xlabel("$Q_{z} (\u00c5^{-1})$")
131+
ref_plot.set_ylabel("Reflectivity")
165132
ref_plot.legend()
166133
ref_plot.grid()
167134

168-
sld_plot.set_xlabel("Z")
169-
sld_plot.set_ylabel("SLD")
135+
sld_plot.set_xlabel("$Z (\u00c5)$")
136+
sld_plot.set_ylabel("$SLD (\u00c5^{-2})$")
170137
sld_plot.legend()
171138
sld_plot.grid()
172139

140+
if preserve_zoom:
141+
fig.canvas.toolbar.back()
173142
if delay:
174143
plt.pause(0.005)
175144

@@ -204,8 +173,52 @@ def plot_ref_sld(
204173
data.subRoughs = results.contrastParams.subRoughs
205174
data.resample = RATapi.inputs.make_resample(project)
206175

207-
figure = Figure(1, 2)
176+
figure = plt.subplots(1, 2)[0]
208177

209178
plot_ref_sld_helper(data, figure)
210-
if block:
211-
figure.wait_for_close()
179+
180+
plt.show(block=block)
181+
182+
183+
class LivePlot:
184+
"""Creates a plot that gets updates from the plot event during a
185+
calculation
186+
187+
Parameters
188+
----------
189+
block : bool, default: False
190+
Indicates the plot should block until it is closed
191+
192+
"""
193+
194+
def __init__(self, block=False):
195+
self.block = block
196+
self.closed = False
197+
198+
def __enter__(self):
199+
self.figure = plt.subplots(1, 2)[0]
200+
self.figure.canvas.mpl_connect("close_event", self._setCloseState)
201+
self.figure.show()
202+
RATapi.events.register(RATapi.events.EventTypes.Plot, self.plotEvent)
203+
204+
return self.figure
205+
206+
def _setCloseState(self, _):
207+
"""Close event handler"""
208+
self.closed = True
209+
210+
def plotEvent(self, event):
211+
"""Callback for the plot event.
212+
213+
Parameters
214+
----------
215+
event: PlotEventData
216+
The plot event data.
217+
"""
218+
if not self.closed and self.figure.number in plt.get_fignums():
219+
plot_ref_sld_helper(event, self.figure)
220+
221+
def __exit__(self, _exc_type, _exc_val, _traceback):
222+
RATapi.events.clear(RATapi.events.EventTypes.Plot, self.plotEvent)
223+
if not self.closed and self.figure.number in plt.get_fignums():
224+
plt.show(block=self.block)

cpp/RAT

cpp/rat.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -192,27 +192,34 @@ class EventBridge
192192
};
193193

194194
py::list unpackDataToCell(int rows, int cols, double* data, double* nData,
195-
double* data2, double* nData2, int dataCol)
195+
double* data2, double* nData2, bool isOutput2D=false)
196196
{
197197
py::list allResults;
198-
int dims[2] = {0, dataCol};
198+
int dims[2] = {0, 0};
199199
int offset = 0;
200200
for (int i = 0; i < rows; i++){
201-
py::list rowList;
202-
dims[0] = (int)nData[i] / dataCol;
201+
dims[0] = (int)nData[2*i];
202+
dims[1] = (int)nData[2*i+1];
203203
auto result = py::array_t<double, py::array::f_style>({dims[0], dims[1]});
204204
std::memcpy(result.request().ptr, data + offset, result.nbytes());
205205
offset += result.size();
206-
rowList.append(result);
207-
allResults.append(rowList);
206+
if (isOutput2D){
207+
py::list rowList;
208+
rowList.append(result);
209+
allResults.append(rowList);
210+
}
211+
else{
212+
allResults.append(result);
213+
}
208214
}
209215

210216
if (data2 != NULL && nData2 != NULL)
211217
{
212218
// This is used to unpack the domains data into the second column
213219
offset = 0;
214220
for ( int i = 0; i < rows; i++){
215-
dims[0] = (int)nData2[i] / dataCol;
221+
dims[0] = (int)nData2[2*i];
222+
dims[1] = (int)nData2[2*i+1];
216223
auto result = py::array_t<double, py::array::f_style>({dims[0], dims[1]});
217224
std::memcpy(result.request().ptr, data2 + offset, result.nbytes());
218225
offset += result.size();
@@ -252,18 +259,18 @@ class EventBridge
252259
std::memcpy(eventData.dataPresent.request().ptr, pEvent->data->dataPresent, eventData.dataPresent.nbytes());
253260

254261
eventData.reflectivity = unpackDataToCell(pEvent->data->nContrast, 1,
255-
pEvent->data->reflect, pEvent->data->nReflect, NULL, NULL, 2);
262+
pEvent->data->reflect, pEvent->data->nReflect, NULL, NULL);
256263

257264
eventData.shiftedData = unpackDataToCell(pEvent->data->nContrast, 1,
258-
pEvent->data->shiftedData, pEvent->data->nShiftedData, NULL, NULL, 3);
265+
pEvent->data->shiftedData, pEvent->data->nShiftedData, NULL, NULL);
259266

260267
eventData.sldProfiles = unpackDataToCell(pEvent->data->nContrast, (pEvent->data->nSldProfiles2 == NULL) ? 1 : 2,
261268
pEvent->data->sldProfiles, pEvent->data->nSldProfiles,
262-
pEvent->data->sldProfiles2, pEvent->data->nSldProfiles2, 2);
263-
269+
pEvent->data->sldProfiles2, pEvent->data->nSldProfiles2, true);
270+
264271
eventData.resampledLayers = unpackDataToCell(pEvent->data->nContrast, (pEvent->data->nLayers2 == NULL) ? 1 : 2,
265-
pEvent->data->layers, pEvent->data->nLayers,
266-
pEvent->data->layers2, pEvent->data->nLayers, 2);
272+
pEvent->data->layers, pEvent->data->nLayers,
273+
pEvent->data->layers2, pEvent->data->nLayers2, true);
267274
this->callback(event.type, eventData);
268275
}
269276
};

0 commit comments

Comments
 (0)