11"""Plot results using the matplotlib library."""
22
33import copy
4+ import types
45from functools import partial , wraps
56from math import ceil , floor , sqrt
67from statistics import stdev
78from typing import Callable , Literal , Optional , Union
89
910import matplotlib
1011import matplotlib .pyplot as plt
12+ import matplotlib .transforms as mtransforms
1113import numpy as np
1214from matplotlib .axes ._axes import Axes
1315from scipy .ndimage import gaussian_filter1d
@@ -667,15 +669,15 @@ def plot_corner(
667669
668670 num_params = len (params )
669671
670- fig , axes = plt .subplots (num_params , num_params , figsize = (14 , 10 ))
672+ fig , axes = plt .subplots (num_params , num_params , figsize = (11 , 10 ))
671673 # i is row, j is column
672674 for i , row_param in enumerate (params ):
673675 for j , col_param in enumerate (params ):
674676 current_axes : Axes = axes [i ][j ]
675677 current_axes .tick_params (which = "both" , labelsize = "medium" )
676678 current_axes .xaxis .offsetText .set_fontsize ("small" )
677679 current_axes .yaxis .offsetText .set_fontsize ("small" )
678- current_axes . yaxis . offsetText . set_x ( - 1.5 )
680+
679681 if i == j : # diagonal: histograms
680682 plot_one_hist (results , param = row_param , smooth = smooth , axes = current_axes , ** hist_kwargs )
681683 elif i > j : # lower triangle: 2d histograms
@@ -690,6 +692,10 @@ def plot_corner(
690692 if i != len (params ) - 1 :
691693 current_axes .get_xaxis ().set_visible (False )
692694 # make labels invisible as titles cover that
695+ current_axes .yaxis ._update_offset_text_position = types .MethodType (
696+ _y_update_offset_text_position , current_axes .yaxis
697+ )
698+ current_axes .yaxis .offset_text_position = "center"
693699 current_axes .set_ylabel ("" )
694700 current_axes .set_xlabel ("" )
695701 if return_fig :
@@ -807,6 +813,47 @@ def plot_one_hist(
807813 plt .show (block = block )
808814
809815
816+ def _y_update_offset_text_position (axis , _bboxes , bboxes2 ):
817+ """Update the position of the Y axis offset text using the provided bounding boxes.
818+
819+ Adapted from https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334.
820+
821+ Parameters
822+ ----------
823+ axis : matplotlib.axis.YAxis
824+ Y axis to update.
825+ _bboxes : List
826+ list of bounding boxes
827+ bboxes2 : List
828+ list of bounding boxes
829+ """
830+ x , y = axis .offsetText .get_position ()
831+
832+ if axis .offset_text_position == "left" :
833+ # y in axes coords, x in display coords
834+ axis .offsetText .set_transform (
835+ mtransforms .blended_transform_factory (axis .axes .transAxes , mtransforms .IdentityTransform ())
836+ )
837+
838+ top = axis .axes .bbox .ymax
839+ y = top + axis .OFFSETTEXTPAD * axis .figure .dpi / 72.0
840+
841+ else :
842+ # x & y in display coords
843+ axis .offsetText .set_transform (mtransforms .IdentityTransform ())
844+
845+ # Northwest of upper-right corner of right-hand extent of tick labels
846+ if bboxes2 :
847+ bbox = mtransforms .Bbox .union (bboxes2 )
848+ else :
849+ bbox = axis .axes .bbox
850+ center = bbox .ymin + (bbox .ymax - bbox .ymin ) / 2
851+ x = bbox .xmin - axis .OFFSETTEXTPAD * axis .figure .dpi / 72.0
852+ y = center
853+ x_offset = 110
854+ axis .offsetText .set_position ((x - x_offset , y ))
855+
856+
810857@assert_bayesian ("Contour" )
811858def plot_contour (
812859 results : RATapi .outputs .BayesResults ,
@@ -900,7 +947,7 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
900947 """
901948 nplots = len (indices )
902949 nrows , ncols = ceil (sqrt (nplots )), round (sqrt (nplots ))
903- fig = plt .subplots (nrows , ncols , figsize = (14 , 10 ))[0 ]
950+ fig = plt .subplots (nrows , ncols , figsize = (11 , 10 ))[0 ]
904951 axs = fig .get_axes ()
905952
906953 for plot_num , index in enumerate (indices ):
0 commit comments