44from functools import partial , wraps
55from math import ceil , floor , sqrt
66from statistics import stdev
7- from textwrap import fill
87from typing import Callable , Literal , Optional , Union
98
109import matplotlib
@@ -668,11 +667,15 @@ def plot_corner(
668667
669668 num_params = len (params )
670669
671- fig , axes = plt .subplots (num_params , num_params , figsize = (2 * num_params , 2 * num_params ))
670+ fig , axes = plt .subplots (num_params , num_params , figsize = (14 , 10 ))
672671 # i is row, j is column
673672 for i , row_param in enumerate (params ):
674673 for j , col_param in enumerate (params ):
675674 current_axes : Axes = axes [i ][j ]
675+ current_axes .tick_params (which = "both" , labelsize = "medium" )
676+ current_axes .xaxis .offsetText .set_fontsize ("small" )
677+ current_axes .yaxis .offsetText .set_fontsize ("small" )
678+ current_axes .yaxis .offsetText .set_x (- 1.5 )
676679 if i == j : # diagonal: histograms
677680 plot_one_hist (results , param = row_param , smooth = smooth , axes = current_axes , ** hist_kwargs )
678681 elif i > j : # lower triangle: 2d histograms
@@ -689,8 +692,6 @@ def plot_corner(
689692 # make labels invisible as titles cover that
690693 current_axes .set_ylabel ("" )
691694 current_axes .set_xlabel ("" )
692-
693- fig .tight_layout ()
694695 if return_fig :
695696 return fig
696697 plt .show (block = block )
@@ -776,7 +777,7 @@ def plot_one_hist(
776777 color = "white" ,
777778 )
778779
779- axes .set_title (fill ( results .fitNames [param ], 20 )) # use `fill` to wrap long titles
780+ axes .set_title (results .fitNames [param ], loc = "left" , fontsize = "medium" )
780781
781782 if estimated_density :
782783 dx = bins [1 ] - bins [0 ]
@@ -899,7 +900,7 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
899900 """
900901 nplots = len (indices )
901902 nrows , ncols = ceil (sqrt (nplots )), round (sqrt (nplots ))
902- fig = plt .subplots (nrows , ncols , figsize = (2.5 * ncols , 2 * nrows ))[0 ]
903+ fig = plt .subplots (nrows , ncols , figsize = (14 , 10 ))[0 ]
903904 axs = fig .get_axes ()
904905
905906 for plot_num , index in enumerate (indices ):
0 commit comments