Skip to content

Commit 1353b78

Browse files
committed
Adds fixed figure size and font size to make corner plot less awkward
1 parent 700f6b7 commit 1353b78

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

RATapi/utils/plotting.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from functools import partial, wraps
55
from math import ceil, floor, sqrt
66
from statistics import stdev
7-
from textwrap import fill
87
from typing import Callable, Literal, Optional, Union
98

109
import matplotlib
@@ -668,11 +667,15 @@ def plot_corner(
668667

669668
num_params = len(params)
670669

671-
fig, axes = plt.subplots(num_params, num_params, figsize=(2 * num_params, 2 * num_params))
670+
fig, axes = plt.subplots(num_params, num_params, figsize=(14, 10))
672671
# i is row, j is column
673672
for i, row_param in enumerate(params):
674673
for j, col_param in enumerate(params):
675674
current_axes: Axes = axes[i][j]
675+
current_axes.tick_params(which="both", labelsize="medium")
676+
current_axes.xaxis.offsetText.set_fontsize("small")
677+
current_axes.yaxis.offsetText.set_fontsize("small")
678+
current_axes.yaxis.offsetText.set_x(-1.5)
676679
if i == j: # diagonal: histograms
677680
plot_one_hist(results, param=row_param, smooth=smooth, axes=current_axes, **hist_kwargs)
678681
elif i > j: # lower triangle: 2d histograms
@@ -689,8 +692,6 @@ def plot_corner(
689692
# make labels invisible as titles cover that
690693
current_axes.set_ylabel("")
691694
current_axes.set_xlabel("")
692-
693-
fig.tight_layout()
694695
if return_fig:
695696
return fig
696697
plt.show(block=block)
@@ -776,7 +777,7 @@ def plot_one_hist(
776777
color="white",
777778
)
778779

779-
axes.set_title(fill(results.fitNames[param], 20)) # use `fill` to wrap long titles
780+
axes.set_title(results.fitNames[param], loc="left", fontsize="medium")
780781

781782
if estimated_density:
782783
dx = bins[1] - bins[0]
@@ -899,7 +900,7 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig
899900
"""
900901
nplots = len(indices)
901902
nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots))
902-
fig = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2 * nrows))[0]
903+
fig = plt.subplots(nrows, ncols, figsize=(14, 10))[0]
903904
axs = fig.get_axes()
904905

905906
for plot_num, index in enumerate(indices):

tests/test_plotting.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import pickle
33
from math import ceil, sqrt
4-
from textwrap import fill
54
from unittest.mock import MagicMock, patch
65

76
import matplotlib.pyplot as plt
@@ -293,7 +292,7 @@ def test_hist(dream_results, param, hist_settings, est_dens):
293292

294293
# assert title is as expected
295294
# also tests string to index conversion
296-
assert ax.get_title() == fill(dream_results.fitNames[param] if isinstance(param, int) else param, 20)
295+
assert ax.get_title(loc="left") == dream_results.fitNames[param] if isinstance(param, int) else param
297296

298297
# assert range is default, unless given
299298
# this tests non-default hist_settings propagates correctly
@@ -377,8 +376,10 @@ def test_corner(dream_results, params):
377376
assert current_axes.get_xbound() == axes[-1][j].get_xbound()
378377
elif i == j:
379378
# check title is correct
380-
assert current_axes.get_title() == fill(
381-
dream_results.fitNames[params[i]] if isinstance(params[i], int) else params[i], 20
379+
assert (
380+
current_axes.get_title(loc="left") == dream_results.fitNames[params[i]]
381+
if isinstance(params[i], int)
382+
else params[i]
382383
)
383384

384385
plt.close(fig)

0 commit comments

Comments
 (0)