11"""Plot results using the matplotlib library."""
22
33import copy
4+ import types
45from functools import partial , wraps
56from math import ceil , floor , sqrt
67from statistics import stdev
7- from textwrap import fill
88from typing import Callable , Literal , Optional , Union
99
1010import matplotlib
1111import matplotlib .pyplot as plt
12+ import matplotlib .transforms as mtransforms
1213import numpy as np
1314from matplotlib .axes ._axes import Axes
1415from scipy .ndimage import gaussian_filter1d
@@ -668,11 +669,15 @@ def plot_corner(
668669
669670 num_params = len (params )
670671
671- fig , axes = plt .subplots (num_params , num_params , figsize = (2 * num_params , 2 * num_params ))
672+ fig , axes = plt .subplots (num_params , num_params , figsize = (11 , 10 ))
672673 # i is row, j is column
673674 for i , row_param in enumerate (params ):
674675 for j , col_param in enumerate (params ):
675676 current_axes : Axes = axes [i ][j ]
677+ current_axes .tick_params (which = "both" , labelsize = "medium" )
678+ current_axes .xaxis .offsetText .set_fontsize ("small" )
679+ current_axes .yaxis .offsetText .set_fontsize ("small" )
680+
676681 if i == j : # diagonal: histograms
677682 plot_one_hist (results , param = row_param , smooth = smooth , axes = current_axes , ** hist_kwargs )
678683 elif i > j : # lower triangle: 2d histograms
@@ -687,10 +692,12 @@ def plot_corner(
687692 if i != len (params ) - 1 :
688693 current_axes .get_xaxis ().set_visible (False )
689694 # make labels invisible as titles cover that
695+ current_axes .yaxis ._update_offset_text_position = types .MethodType (
696+ _y_update_offset_text_position , current_axes .yaxis
697+ )
698+ current_axes .yaxis .offset_text_position = "center"
690699 current_axes .set_ylabel ("" )
691700 current_axes .set_xlabel ("" )
692-
693- fig .tight_layout ()
694701 if return_fig :
695702 return fig
696703 plt .show (block = block )
@@ -776,7 +783,7 @@ def plot_one_hist(
776783 color = "white" ,
777784 )
778785
779- axes .set_title (fill ( results .fitNames [param ], 20 )) # use `fill` to wrap long titles
786+ axes .set_title (results .fitNames [param ], loc = "left" , fontsize = "medium" )
780787
781788 if estimated_density :
782789 dx = bins [1 ] - bins [0 ]
@@ -806,6 +813,47 @@ def plot_one_hist(
806813 plt .show (block = block )
807814
808815
816+ def _y_update_offset_text_position (axis , _bboxes , bboxes2 ):
817+ """Update the position of the Y axis offset text using the provided bounding boxes.
818+
819+ Adapted from https://github.com/matplotlib/matplotlib/issues/4476#issuecomment-105627334.
820+
821+ Parameters
822+ ----------
823+ axis : matplotlib.axis.YAxis
824+ Y axis to update.
825+ _bboxes : List
826+ list of bounding boxes
827+ bboxes2 : List
828+ list of bounding boxes
829+ """
830+ x , y = axis .offsetText .get_position ()
831+
832+ if axis .offset_text_position == "left" :
833+ # y in axes coords, x in display coords
834+ axis .offsetText .set_transform (
835+ mtransforms .blended_transform_factory (axis .axes .transAxes , mtransforms .IdentityTransform ())
836+ )
837+
838+ top = axis .axes .bbox .ymax
839+ y = top + axis .OFFSETTEXTPAD * axis .figure .dpi / 72.0
840+
841+ else :
842+ # x & y in display coords
843+ axis .offsetText .set_transform (mtransforms .IdentityTransform ())
844+
845+ # Northwest of upper-right corner of right-hand extent of tick labels
846+ if bboxes2 :
847+ bbox = mtransforms .Bbox .union (bboxes2 )
848+ else :
849+ bbox = axis .axes .bbox
850+ center = bbox .ymin + (bbox .ymax - bbox .ymin ) / 2
851+ x = bbox .xmin - axis .OFFSETTEXTPAD * axis .figure .dpi / 72.0
852+ y = center
853+ x_offset = 110
854+ axis .offsetText .set_position ((x - x_offset , y ))
855+
856+
809857@assert_bayesian ("Contour" )
810858def plot_contour (
811859 results : RATapi .outputs .BayesResults ,
@@ -899,7 +947,7 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
899947 """
900948 nplots = len (indices )
901949 nrows , ncols = ceil (sqrt (nplots )), round (sqrt (nplots ))
902- fig = plt .subplots (nrows , ncols , figsize = (2.5 * ncols , 2 * nrows ))[0 ]
950+ fig = plt .subplots (nrows , ncols , figsize = (11 , 10 ))[0 ]
903951 axs = fig .get_axes ()
904952
905953 for plot_num , index in enumerate (indices ):
0 commit comments