Skip to content

Commit a57541a

Browse files
authored
Reverts removal of bliiting (#177)
1 parent 9ce6e9f commit a57541a

File tree

2 files changed

+207
-1
lines changed

2 files changed

+207
-1
lines changed

ratapi/utils/plotting.py

Lines changed: 204 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,205 @@ def plot_ref_sld(
356356
plt.show(block=block)
357357

358358

359+
class BlittingSupport:
360+
"""Create a SLD plot that uses blitting to get faster draws.
361+
362+
The blit plot stores the background from an
363+
initial draw then updates the foreground (lines and error bars) if the background is not changed.
364+
365+
Parameters
366+
----------
367+
data : PlotEventData
368+
The plot event data that contains all the information
369+
to generate the ref and sld plots
370+
fig : matplotlib.pyplot.figure, optional
371+
The figure class that has two subplots
372+
linear_x : bool, default: False
373+
Controls whether the x-axis on reflectivity plot uses the linear scale
374+
q4 : bool, default: False
375+
Controls whether Q^4 is plotted on the reflectivity plot
376+
show_error_bar : bool, default: True
377+
Controls whether the error bars are shown
378+
show_grid : bool, default: False
379+
Controls whether the grid is shown
380+
show_legend : bool, default: True
381+
Controls whether the legend is shown
382+
shift_value : float, default: 100
383+
A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts
384+
"""
385+
386+
def __init__(
387+
self,
388+
data,
389+
fig=None,
390+
linear_x: bool = False,
391+
q4: bool = False,
392+
show_error_bar: bool = True,
393+
show_grid: bool = False,
394+
show_legend: bool = True,
395+
shift_value: float = 100,
396+
):
397+
self.figure = fig
398+
self.linear_x = linear_x
399+
self.q4 = q4
400+
self.show_error_bar = show_error_bar
401+
self.show_grid = show_grid
402+
self.show_legend = show_legend
403+
self.shift_value = shift_value
404+
self.update_plot(data)
405+
self.event_id = self.figure.canvas.mpl_connect("resize_event", self.resizeEvent)
406+
407+
def __del__(self):
408+
self.figure.canvas.mpl_disconnect(self.event_id)
409+
410+
def resizeEvent(self, _event):
411+
"""Ensure the background is updated after a resize event."""
412+
self.__background_changed = True
413+
414+
def update(self, data):
415+
"""Update the foreground, if background has not changed otherwise it updates full plot.
416+
417+
Parameters
418+
----------
419+
data : PlotEventData
420+
The plot event data that contains all the information
421+
to generate the ref and sld plots
422+
"""
423+
if self.__background_changed:
424+
self.update_plot(data)
425+
else:
426+
self.update_foreground(data)
427+
428+
def __setattr__(self, name, value):
429+
old_value = getattr(self, name, None)
430+
if value == old_value:
431+
return
432+
433+
super().__setattr__(name, value)
434+
if name in ["figure", "linear_x", "q4", "show_error_bar", "show_grid", "show_legend", "shift_value"]:
435+
self.__background_changed = True
436+
437+
def set_animated(self, is_animated: bool):
438+
"""Set the animated property of foreground plot elements.
439+
440+
Parameters
441+
----------
442+
is_animated : bool
443+
Indicates if the animated property should be set.
444+
"""
445+
for line in self.figure.axes[0].lines:
446+
line.set_animated(is_animated)
447+
for line in self.figure.axes[1].lines:
448+
line.set_animated(is_animated)
449+
for container in self.figure.axes[0].containers:
450+
container[2][0].set_animated(is_animated)
451+
452+
def adjust_error_bar(self, error_bar_container, x, y, y_error):
453+
"""Adjust the error bar data.
454+
455+
Parameters
456+
----------
457+
error_bar_container : Tuple
458+
Tuple containing the artist of the errorbar i.e. (data line, cap lines, bar lines)
459+
x : np.ndarray
460+
The shifted data x axis data
461+
y : np.ndarray
462+
The shifted data y axis data
463+
y_error : np.ndarray
464+
The shifted data y axis error data
465+
"""
466+
line, _, (bars_y,) = error_bar_container
467+
468+
line.set_data(x, y)
469+
x_base = x
470+
y_base = y
471+
472+
y_error_top = y_base + y_error
473+
y_error_bottom = y_base - y_error
474+
475+
new_segments_y = [np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom)]
476+
bars_y.set_segments(new_segments_y)
477+
478+
def update_plot(self, data):
479+
"""Update the full plot.
480+
481+
Parameters
482+
----------
483+
data : PlotEventData
484+
The plot event data that contains all the information
485+
to generate the ref and sld plots
486+
"""
487+
if self.figure is not None:
488+
self.figure.clf()
489+
self.figure = ratapi.plotting.plot_ref_sld_helper(
490+
data,
491+
self.figure,
492+
linear_x=self.linear_x,
493+
q4=self.q4,
494+
show_error_bar=self.show_error_bar,
495+
show_grid=self.show_grid,
496+
show_legend=self.show_legend,
497+
animated=True,
498+
)
499+
self.figure.tight_layout(pad=1)
500+
self.figure.canvas.draw()
501+
self.bg = self.figure.canvas.copy_from_bbox(self.figure.bbox)
502+
for line in self.figure.axes[0].lines:
503+
self.figure.axes[0].draw_artist(line)
504+
for line in self.figure.axes[1].lines:
505+
self.figure.axes[1].draw_artist(line)
506+
for container in self.figure.axes[0].containers:
507+
self.figure.axes[0].draw_artist(container[2][0])
508+
self.figure.canvas.blit(self.figure.bbox)
509+
self.set_animated(False)
510+
self.__background_changed = False
511+
512+
def update_foreground(self, data):
513+
"""Update the plot foreground only.
514+
515+
Parameters
516+
----------
517+
data : PlotEventData
518+
The plot event data that contains all the information
519+
to generate the ref and sld plots
520+
"""
521+
self.set_animated(True)
522+
self.figure.canvas.restore_region(self.bg)
523+
plot_data = ratapi.plotting._extract_plot_data(data, self.q4, self.show_error_bar, self.shift_value)
524+
525+
offset = 2 if self.show_error_bar else 1
526+
for i in range(
527+
0,
528+
len(self.figure.axes[0].lines),
529+
):
530+
self.figure.axes[0].lines[i].set_data(plot_data["ref"][i // offset][0], plot_data["ref"][i // offset][1])
531+
self.figure.axes[0].draw_artist(self.figure.axes[0].lines[i])
532+
533+
i = 0
534+
for j in range(len(plot_data["sld"])):
535+
for sld in plot_data["sld"][j]:
536+
self.figure.axes[1].lines[i].set_data(sld[0], sld[1])
537+
self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i])
538+
i += 1
539+
540+
if plot_data["sld_resample"]:
541+
for resampled in plot_data["sld_resample"][j]:
542+
self.figure.axes[1].lines[i].set_data(resampled[0], resampled[1])
543+
self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i])
544+
i += 1
545+
546+
for i, container in enumerate(self.figure.axes[0].containers):
547+
self.adjust_error_bar(
548+
container, plot_data["error"][i][0], plot_data["error"][i][1], plot_data["error"][i][2]
549+
)
550+
self.figure.axes[0].draw_artist(container[2][0])
551+
self.figure.axes[0].draw_artist(container[0])
552+
553+
self.figure.canvas.blit(self.figure.bbox)
554+
self.figure.canvas.flush_events()
555+
self.set_animated(False)
556+
557+
359558
class LivePlot:
360559
"""Create a plot that gets updates from the plot event during a calculation.
361560
@@ -369,6 +568,7 @@ class LivePlot:
369568
def __init__(self, block=False):
370569
self.block = block
371570
self.closed = False
571+
self.blit_plot = None
372572

373573
def __enter__(self):
374574
self.figure = plt.subplots(1, 2)[0]
@@ -394,7 +594,10 @@ def plotEvent(self, event):
394594
395595
"""
396596
if not self.closed and self.figure.number in plt.get_fignums():
397-
plot_ref_sld_helper(event, self.figure)
597+
if self.blit_plot is None:
598+
self.blit_plot = BlittingSupport(event, self.figure)
599+
else:
600+
self.blit_plot.update(event)
398601

399602
def __exit__(self, _exc_type, _exc_val, _traceback):
400603
ratapi.events.clear(ratapi.events.EventTypes.Plot, self.plotEvent)

tests/test_orso_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def prist():
3636
],
3737
)
3838
@pytest.mark.parametrize("absorption", [True, False])
39+
@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available")
3940
def test_orso_model_to_rat(model, absorption):
4041
"""Test that orso_model_to_rat gives the expected parameters, layers and model."""
4142

@@ -72,6 +73,7 @@ def test_orso_model_to_rat(model, absorption):
7273
"prist5_10K_m_025.Rqz.ort",
7374
],
7475
)
76+
@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available")
7577
def test_load_ort_data(test_data):
7678
"""Test that .ort data is loaded correctly."""
7779
# manually get the test data for comparison
@@ -104,6 +106,7 @@ def test_load_ort_data(test_data):
104106
["prist5_10K_m_025.Rqz.ort", "prist.json"],
105107
],
106108
)
109+
@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available")
107110
def test_load_ort_project(test_data, expected_data):
108111
"""Test that a project with model data is loaded correctly."""
109112
ort_data = ORSOProject(Path(TEST_DIR_PATH, test_data))

0 commit comments

Comments
 (0)