Skip to content

Commit c9b2dea

Browse files
committed
Add progress callback for panel plot helper
1 parent f848758 commit c9b2dea

File tree

1 file changed

+24
-9
lines changed

1 file changed

+24
-9
lines changed

ratapi/utils/plotting.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,10 @@ def plot_contour(
982982

983983

984984
def panel_plot_helper(
985-
plot_func: Callable, indices: list[int], fig: matplotlib.figure.Figure | None = None
985+
plot_func: Callable,
986+
indices: list[int],
987+
fig: matplotlib.figure.Figure | None = None,
988+
progress_callback: Callable[[int, int], None] | None = None,
986989
) -> matplotlib.figure.Figure:
987990
"""Generate a panel-based plot from a single plot function.
988991
@@ -994,6 +997,9 @@ def panel_plot_helper(
994997
The list of indices to pass into ``plot_func``.
995998
fig : matplotlib.figure.Figure, optional
996999
The figure object to use for plot.
1000+
progress_callback: Union[Callable[[int, int], None], None]
1001+
Callback function for providing progress during plot creation
1002+
First argument is current completed sub plot and second is total number of sub plots
9971003
9981004
Returns
9991005
-------
@@ -1005,21 +1011,21 @@ def panel_plot_helper(
10051011
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))
10061012

10071013
if fig is None:
1008-
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
1014+
fig = plt.subplots(nrows, ncols, figsize=(11, 10), subplot_kw={"visible": False})[0]
10091015
else:
10101016
fig.clf()
1011-
fig.subplots(nrows, ncols)
1017+
fig.subplots(nrows, ncols, subplot_kw={"visible": False})
10121018
axs = fig.get_axes()
1013-
1019+
current_plot = 0
10141020
for plot_num, index in enumerate(indices):
10151021
axs[plot_num].tick_params(which="both", labelsize="medium")
10161022
axs[plot_num].xaxis.offsetText.set_fontsize("small")
10171023
axs[plot_num].yaxis.offsetText.set_fontsize("small")
1024+
axs[plot_num].set_visible(True)
10181025
plot_func(axs[plot_num], index)
1019-
1020-
# blank unused plots
1021-
for i in range(nplots, len(axs)):
1022-
axs[i].set_visible(False)
1026+
if progress_callback is not None:
1027+
current_plot += 1
1028+
progress_callback(current_plot, nplots)
10231029

10241030
fig.tight_layout()
10251031
return fig
@@ -1036,6 +1042,7 @@ def plot_hists(
10361042
block: bool = False,
10371043
fig: matplotlib.figure.Figure | None = None,
10381044
return_fig: bool = False,
1045+
progress_callback: Callable[[int, int], None] | None = None,
10391046
**hist_settings,
10401047
):
10411048
"""Plot marginalised posteriors for several parameters from a Bayesian analysis.
@@ -1072,6 +1079,9 @@ def plot_hists(
10721079
The figure object to use for plot.
10731080
return_fig: bool, default False
10741081
If True, return the figure as an object instead of showing it.
1082+
progress_callback: Union[Callable[[int, int], None], None]
1083+
Callback function for providing progress during plot creation
1084+
First argument is current completed sub plot and second is total number of sub plots
10751085
hist_settings :
10761086
Settings passed to `np.histogram`. By default, the settings
10771087
passed are `bins = 25` and `density = True`.
@@ -1130,6 +1140,7 @@ def validate_dens_type(dens_type: str | None, param: str):
11301140
),
11311141
params,
11321142
fig,
1143+
progress_callback,
11331144
)
11341145
if return_fig:
11351146
return fig
@@ -1144,6 +1155,7 @@ def plot_chain(
11441155
block: bool = False,
11451156
fig: matplotlib.figure.Figure | None = None,
11461157
return_fig: bool = False,
1158+
progress_callback: Callable[[int, int], None] | None = None,
11471159
):
11481160
"""Plot the MCMC chain for each parameter of a Bayesian analysis.
11491161
@@ -1162,6 +1174,9 @@ def plot_chain(
11621174
The figure object to use for plot.
11631175
return_fig: bool, default False
11641176
If True, return the figure as an object instead of showing it.
1177+
progress_callback: Union[Callable[[int, int], None], None]
1178+
Callback function for providing progress during plot creation
1179+
First argument is current completed sub plot and second is total number of sub plots
11651180
11661181
Returns
11671182
-------
@@ -1187,7 +1202,7 @@ def plot_one_chain(axes: Axes, i: int):
11871202
axes.plot(range(0, nsimulations, skip), chain[:, i][0:nsimulations:skip])
11881203
axes.set_title(results.fitNames[i], fontsize="small")
11891204

1190-
fig = panel_plot_helper(plot_one_chain, params, fig=fig)
1205+
fig = panel_plot_helper(plot_one_chain, params, fig, progress_callback)
11911206
if return_fig:
11921207
return fig
11931208
plt.show(block=block)

0 commit comments

Comments
 (0)