Skip to content

Commit 5e54f55

Browse files
authored
Refactors plot_ref_sld to plot_ref_sld_helper, (#33)
1 parent a708a2f commit 5e54f55

File tree

4 files changed

+70
-36
lines changed

4 files changed

+70
-36
lines changed

RAT/events.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ def get_event_callback(event_type: EventTypes) -> List[Callable[[Union[str, Plot
2626
event_type : EventTypes
2727
The event type.
2828
29-
Retuns
30-
------
29+
Returns
30+
-------
3131
callback : Callable[[Union[str, PlotEventData, ProgressEventData]], None]
3232
The callback for the event type.
3333
"""

RAT/utils/plotting.py

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Plots using the matplotlib library
33
"""
4+
from typing import Optional
45
import matplotlib.pyplot as plt
56
import numpy as np
67
from RAT.rat_core import PlotEventData, makeSLDProfileXY
@@ -11,15 +12,15 @@ class Figure:
1112
Creates a plotting figure.
1213
"""
1314

14-
def __init__(self, row: int = 1, col: int = 2):
15+
def __init__(self, row: int = 1, col: int = 1):
1516
"""
1617
Initializes the figure and the subplots.
1718
1819
Parameters
1920
----------
20-
row : int
21+
row : int, default: 1
2122
The number of rows in subplot
22-
col : int
23+
col : int, default: 1
2324
The number of columns in subplot
2425
"""
2526
self._fig, self._ax = \
@@ -55,7 +56,8 @@ def _close(self, _):
5556
self._close_clicked = True
5657

5758

58-
def plot_errorbars(ax, x, y, err, onesided, color):
59+
def plot_errorbars(ax: 'matplotlib.axes._axes.Axes', x: np.ndarray, y: np.ndarray, err: np.ndarray,
60+
one_sided: bool, color: str):
5961
"""
6062
Plots the error bars.
6163
@@ -69,12 +71,12 @@ def plot_errorbars(ax, x, y, err, onesided, color):
6971
The shifted data y axis data
7072
err : np.ndarray
7173
The shifted data e data
72-
onesided : bool
74+
one_sided : bool
7375
A boolean to indicate whether to draw one sided errorbars
7476
color : str
7577
The hex representing the color of the errorbars
7678
"""
77-
y_error = [[0]*len(err), err] if onesided else err
79+
y_error = [[0]*len(err), err] if one_sided else err
7880
ax.errorbar(x=x,
7981
y=y,
8082
yerr=y_error,
@@ -85,7 +87,7 @@ def plot_errorbars(ax, x, y, err, onesided, color):
8587
ax.scatter(x=x, y=y, s=3, marker="o", color=color)
8688

8789

88-
def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
90+
def plot_ref_sld_helper(data: PlotEventData, fig: Optional[Figure] = None, delay: bool = True):
8991
"""
9092
Clears the previous plots and updates the ref and SLD plots.
9193
@@ -94,9 +96,9 @@ def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
9496
data : PlotEventData
9597
The plot event data that contains all the information
9698
to generate the ref and sld plots
97-
fig : Figure
99+
fig : Figure, optional
98100
The figure class that has two subplots
99-
delay : bool
101+
delay : bool, default: True
100102
Controls whether to delay 0.005s after plot is created
101103
102104
Returns
@@ -105,7 +107,7 @@ def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
105107
The figure class that has two subplots
106108
"""
107109
if fig is None:
108-
fig = Figure()
110+
fig = Figure(1, 2)
109111
elif fig._ax.shape != (2,):
110112
fig._fig.clf()
111113
fig._ax = fig._fig.subplots(1, 2)
@@ -121,9 +123,6 @@ def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
121123
data.shiftedData,
122124
data.sldProfiles,
123125
data.resampledLayers)):
124-
125-
sld, layer = map(lambda x: x[0], (sld, layer))
126-
127126
# Calculate the divisor
128127
div = 1 if i == 0 else 2**(4*(i+1))
129128

@@ -154,25 +153,29 @@ def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
154153
plot_errorbars(ref_plot, sd_x_s, sd_y_s, sd_e_s, True, color)
155154

156155
# Plot the slds on plot (1,2)
157-
for j in range(1, sld.shape[1]):
158-
sld_plot.plot(sld[:, 0],
159-
sld[:, j],
156+
for j in range(len(sld)):
157+
sld_plot.plot(sld[j][:, 0],
158+
sld[j][:, 1],
160159
label=f'sld {i+1}',
161-
color=color,
162-
linewidth=2)
160+
linewidth=1)
163161

164162
if data.resample[i] == 1 or data.modelType == 'custom xy':
165-
new = makeSLDProfileXY(layer[0, 1],
166-
layer[-1, 1],
167-
data.subRoughs[i],
168-
layer,
169-
len(layer),
170-
1.0)
171-
172-
sld_plot.plot([row[0]-49 for row in new],
173-
[row[1] for row in new],
174-
color=color,
175-
linewidth=1)
163+
layers = data.resampledLayers[i][0]
164+
for j in range(len(data.resampledLayers[i])):
165+
layer = data.resampledLayers[i][j]
166+
if layers.shape[1] == 4:
167+
layer = np.delete(layer, 2, 1)
168+
new_profile = makeSLDProfileXY(layers[0, 1], # Bulk In
169+
layers[-1, 1], # Bulk Out
170+
data.subRoughs[i], # roughness
171+
layer,
172+
len(layer),
173+
1.0)
174+
175+
sld_plot.plot([row[0]-49 for row in new_profile],
176+
[row[1] for row in new_profile],
177+
color=color,
178+
linewidth=1)
176179

177180
# Format the axis
178181
ref_plot.set_yscale('log')
@@ -191,3 +194,33 @@ def plot_ref_sld(data: PlotEventData, fig: Figure = None, delay: bool = True):
191194
plt.pause(0.005)
192195

193196
return fig
197+
198+
199+
def plot_ref_sld(problem, results, block: bool = False):
200+
"""
201+
Plots the reflectivity and SLD profiles.
202+
203+
Parameters
204+
----------
205+
problem : ProblemDefinition
206+
An instance of the ProblemDefinition class
207+
results : OutputResult
208+
The result from the calculation
209+
block : bool, default: False
210+
Indicates the plot should block until it is closed
211+
"""
212+
data = PlotEventData()
213+
214+
data.reflectivity = results.reflectivity
215+
data.shiftedData = results.shiftedData
216+
data.sldProfiles = results.sldProfiles
217+
data.resampledLayers = results.resampledLayers
218+
data.dataPresent = problem.dataPresent
219+
data.subRoughs = results.contrastParams.subRoughs
220+
data.resample = problem.resample
221+
222+
figure = Figure(1, 2)
223+
224+
plot_ref_sld_helper(data, figure)
225+
if block:
226+
figure.wait_for_close()

tests/test_plotting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from unittest.mock import MagicMock
77
import matplotlib.pyplot as plt
88
from RAT.rat_core import PlotEventData
9-
from RAT.utils.plotting import Figure, plot_ref_sld
9+
from RAT.utils.plotting import Figure, plot_ref_sld_helper
1010

1111

1212
TEST_DIR_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)),
@@ -40,7 +40,7 @@ def fig() -> Figure:
4040
"""
4141
plt.close('all')
4242
figure = Figure(1, 3)
43-
fig = plot_ref_sld(fig=figure, data=data())
43+
fig = plot_ref_sld_helper(fig=figure, data=data())
4444
return fig
4545

4646

@@ -155,7 +155,7 @@ def test_sld_profile_function_call(mock: MagicMock) -> None:
155155
Tests the makeSLDProfileXY function called with
156156
correct args.
157157
"""
158-
plot_ref_sld(data())
158+
plot_ref_sld_helper(data())
159159

160160
assert mock.call_count == 3
161161
assert mock.call_args_list[0].args[0] == 2.07e-06

tests/test_wrappers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66

77
def test_matlab_wrapper() -> None:
8-
with pytest.raises(ImportError):
9-
RAT.wrappers.MatlabWrapper('demo.m')
8+
with mock.patch.dict('sys.modules', {'matlab': mock.MagicMock(side_effect=ImportError)}):
9+
with pytest.raises(ImportError):
10+
RAT.wrappers.MatlabWrapper('demo.m')
1011
mocked_matlab_module = mock.MagicMock()
1112
mocked_engine = mock.MagicMock()
1213
mocked_matlab_module.engine.start_matlab.return_value = mocked_engine

0 commit comments

Comments
 (0)