Skip to content

Commit 37c0696

Browse files
authored
Adds fixed figure size and font size to make corner plot less awkward (#159)
1 parent 8bd5363 commit 37c0696

File tree

2 files changed

+59
-10
lines changed

2 files changed

+59
-10
lines changed

RATapi/utils/plotting.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Plot results using the matplotlib library."""
22

33
import copy
4+
import types
45
from functools import partial, wraps
56
from math import ceil, floor, sqrt
67
from statistics import stdev
7-
from textwrap import fill
88
from typing import Callable, Literal, Optional, Union
99

1010
import matplotlib
1111
import matplotlib.pyplot as plt
12+
import matplotlib.transforms as mtransforms
1213
import numpy as np
1314
from matplotlib.axes._axes import Axes
1415
from scipy.ndimage import gaussian_filter1d
@@ -668,11 +669,15 @@ def plot_corner(
668669

669670
num_params = len(params)
670671

671-
fig, axes = plt.subplots(num_params, num_params, figsize=(2 * num_params, 2 * num_params))
672+
fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10))
672673
# i is row, j is column
673674
for i, row_param in enumerate(params):
674675
for j, col_param in enumerate(params):
675676
current_axes: Axes = axes[i][j]
677+
current_axes.tick_params(which="both", labelsize="medium")
678+
current_axes.xaxis.offsetText.set_fontsize("small")
679+
current_axes.yaxis.offsetText.set_fontsize("small")
680+
676681
if i == j: # diagonal: histograms
677682
plot_one_hist(results, param=row_param, smooth=smooth, axes=current_axes, **hist_kwargs)
678683
elif i > j: # lower triangle: 2d histograms
@@ -687,10 +692,12 @@ def plot_corner(
687692
if i != len(params) - 1:
688693
current_axes.get_xaxis().set_visible(False)
689694
# make labels invisible as titles cover that
695+
current_axes.yaxis._update_offset_text_position = types.MethodType(
696+
_y_update_offset_text_position, current_axes.yaxis
697+
)
698+
current_axes.yaxis.offset_text_position = "center"
690699
current_axes.set_ylabel("")
691700
current_axes.set_xlabel("")
692-
693-
fig.tight_layout()
694701
if return_fig:
695702
return fig
696703
plt.show(block=block)
@@ -776,7 +783,7 @@ def plot_one_hist(
776783
color="white",
777784
)
778785

779-
axes.set_title(fill(results.fitNames[param], 20)) # use `fill` to wrap long titles
786+
axes.set_title(results.fitNames[param], loc="left", fontsize="medium")
780787

781788
if estimated_density:
782789
dx = bins[1] - bins[0]
@@ -806,6 +813,47 @@ def plot_one_hist(
806813
plt.show(block=block)
807814

808815

816+
def _y_update_offset_text_position(axis, _bboxes, bboxes2):
817+
"""Update the position of the Y axis offset text using the provided bounding boxes.
818+
819+
Adapted from https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334.
820+
821+
Parameters
822+
----------
823+
axis : matplotlib.axis.YAxis
824+
Y axis to update.
825+
_bboxes : List
826+
list of bounding boxes
827+
bboxes2 : List
828+
list of bounding boxes
829+
"""
830+
x, y = axis.offsetText.get_position()
831+
832+
if axis.offset_text_position == "left":
833+
# y in axes coords, x in display coords
834+
axis.offsetText.set_transform(
835+
mtransforms.blended_transform_factory(axis.axes.transAxes, mtransforms.IdentityTransform())
836+
)
837+
838+
top = axis.axes.bbox.ymax
839+
y = top + axis.OFFSETTEXTPAD * axis.figure.dpi / 72.0
840+
841+
else:
842+
# x & y in display coords
843+
axis.offsetText.set_transform(mtransforms.IdentityTransform())
844+
845+
# Northwest of upper-right corner of right-hand extent of tick labels
846+
if bboxes2:
847+
bbox = mtransforms.Bbox.union(bboxes2)
848+
else:
849+
bbox = axis.axes.bbox
850+
center = bbox.ymin + (bbox.ymax - bbox.ymin) / 2
851+
x = bbox.xmin - axis.OFFSETTEXTPAD * axis.figure.dpi / 72.0
852+
y = center
853+
x_offset = 110
854+
axis.offsetText.set_position((x - x_offset, y))
855+
856+
809857
@assert_bayesian("Contour")
810858
def plot_contour(
811859
results: RATapi.outputs.BayesResults,
@@ -899,7 +947,7 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
899947
"""
900948
nplots = len(indices)
901949
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))
902-
fig = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2 * nrows))[0]
950+
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
903951
axs = fig.get_axes()
904952

905953
for plot_num, index in enumerate(indices):

tests/test_plotting.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import pickle
33
from math import ceil, sqrt
4-
from textwrap import fill
54
from unittest.mock import MagicMock, patch
65

76
import matplotlib.pyplot as plt
@@ -293,7 +292,7 @@ def test_hist(dream_results, param, hist_settings, est_dens):
293292

294293
# assert title is as expected
295294
# also tests string to index conversion
296-
assert ax.get_title() == fill(dream_results.fitNames[param] if isinstance(param, int) else param, 20)
295+
assert ax.get_title(loc="left") == dream_results.fitNames[param] if isinstance(param, int) else param
297296

298297
# assert range is default, unless given
299298
# this tests non-default hist_settings propagates correctly
@@ -377,8 +376,10 @@ def test_corner(dream_results, params):
377376
assert current_axes.get_xbound() == axes[-1][j].get_xbound()
378377
elif i == j:
379378
# check title is correct
380-
assert current_axes.get_title() == fill(
381-
dream_results.fitNames[params[i]] if isinstance(params[i], int) else params[i], 20
379+
assert (
380+
current_axes.get_title(loc="left") == dream_results.fitNames[params[i]]
381+
if isinstance(params[i], int)
382+
else params[i]
382383
)
383384

384385
plt.close(fig)

0 commit comments

Comments
 (0)