Skip to content

Commit bf63817

Browse files
committed
Addresses review comments
1 parent 1353b78 commit bf63817

File tree

1 file changed

+50
-3
lines changed

1 file changed

+50
-3
lines changed

RATapi/utils/plotting.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Plot results using the matplotlib library."""
22

33
import copy
4+
import types
45
from functools import partial, wraps
56
from math import ceil, floor, sqrt
67
from statistics import stdev
78
from typing import Callable, Literal, Optional, Union
89

910
import matplotlib
1011
import matplotlib.pyplot as plt
12+
import matplotlib.transforms as mtransforms
1113
import numpy as np
1214
from matplotlib.axes._axes import Axes
1315
from scipy.ndimage import gaussian_filter1d
@@ -667,15 +669,15 @@ def plot_corner(
667669

668670
num_params = len(params)
669671

670-
fig, axes = plt.subplots(num_params, num_params, figsize=(14, 10))
672+
fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10))
671673
# i is row, j is column
672674
for i, row_param in enumerate(params):
673675
for j, col_param in enumerate(params):
674676
current_axes: Axes = axes[i][j]
675677
current_axes.tick_params(which="both", labelsize="medium")
676678
current_axes.xaxis.offsetText.set_fontsize("small")
677679
current_axes.yaxis.offsetText.set_fontsize("small")
678-
current_axes.yaxis.offsetText.set_x(-1.5)
680+
679681
if i == j: # diagonal: histograms
680682
plot_one_hist(results, param=row_param, smooth=smooth, axes=current_axes, **hist_kwargs)
681683
elif i > j: # lower triangle: 2d histograms
@@ -690,6 +692,10 @@ def plot_corner(
690692
if i != len(params) - 1:
691693
current_axes.get_xaxis().set_visible(False)
692694
# make labels invisible as titles cover that
695+
current_axes.yaxis._update_offset_text_position = types.MethodType(
696+
_y_update_offset_text_position, current_axes.yaxis
697+
)
698+
current_axes.yaxis.offset_text_position = "center"
693699
current_axes.set_ylabel("")
694700
current_axes.set_xlabel("")
695701
if return_fig:
@@ -807,6 +813,47 @@ def plot_one_hist(
807813
plt.show(block=block)
808814

809815

816+
def _y_update_offset_text_position(axis, _bboxes, bboxes2):
817+
"""Update the position of the Y axis offset text using the provided bounding boxes.
818+
819+
Adapted from https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334.
820+
821+
Parameters
822+
----------
823+
axis : matplotlib.axis.YAxis
824+
Y axis to update.
825+
_bboxes : List
826+
list of bounding boxes
827+
bboxes2 : List
828+
list of bounding boxes
829+
"""
830+
x, y = axis.offsetText.get_position()
831+
832+
if axis.offset_text_position == "left":
833+
# y in axes coords, x in display coords
834+
axis.offsetText.set_transform(
835+
mtransforms.blended_transform_factory(axis.axes.transAxes, mtransforms.IdentityTransform())
836+
)
837+
838+
top = axis.axes.bbox.ymax
839+
y = top + axis.OFFSETTEXTPAD * axis.figure.dpi / 72.0
840+
841+
else:
842+
# x & y in display coords
843+
axis.offsetText.set_transform(mtransforms.IdentityTransform())
844+
845+
# Northwest of upper-right corner of right-hand extent of tick labels
846+
if bboxes2:
847+
bbox = mtransforms.Bbox.union(bboxes2)
848+
else:
849+
bbox = axis.axes.bbox
850+
center = bbox.ymin + (bbox.ymax - bbox.ymin) / 2
851+
x = bbox.xmin - axis.OFFSETTEXTPAD * axis.figure.dpi / 72.0
852+
y = center
853+
x_offset = 110
854+
axis.offsetText.set_position((x - x_offset, y))
855+
856+
810857
@assert_bayesian("Contour")
811858
def plot_contour(
812859
results: RATapi.outputs.BayesResults,
@@ -900,7 +947,7 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
900947
"""
901948
nplots = len(indices)
902949
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))
903-
fig = plt.subplots(nrows, ncols, figsize=(14, 10))[0]
950+
fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0]
904951
axs = fig.get_axes()
905952

906953
for plot_num, index in enumerate(indices):

0 commit comments

Comments
 (0)