Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 46 additions & 3 deletions pypfopt/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def _plot_io(**kwargs):
plt.show()


def plot_covariance(cov_matrix, plot_correlation=False, show_tickers=True, **kwargs):
def plot_covariance(
cov_matrix, plot_correlation=False, show_tickers=True, show_values=False, **kwargs
):
"""
Generate a basic plot of the covariance (or correlation) matrix, given a
covariance matrix.
Expand All @@ -80,6 +82,9 @@ def plot_covariance(cov_matrix, plot_correlation=False, show_tickers=True, **kwa
show_tickers : bool, optional
whether to use tickers as labels (not recommended for large portfolios),
defaults to True
show_values : bool, optional
if True, annotate each cell with the numeric value formatted to
two decimal places. Defaults to False.

Returns
-------
Expand All @@ -92,18 +97,56 @@ def plot_covariance(cov_matrix, plot_correlation=False, show_tickers=True, **kwa
matrix = risk_models.cov_to_corr(cov_matrix)
else:
matrix = cov_matrix

fig, ax = plt.subplots()

cax = ax.imshow(matrix)
fig.colorbar(cax)

# if show_tickers:
# ax.set_xticks(np.arange(0, matrix.shape[0], 1))
# ax.set_xticklabels(matrix.index)
# ax.set_yticks(np.arange(0, matrix.shape[0], 1))
# ax.set_yticklabels(matrix.index)
# plt.xticks(rotation=90)
if show_tickers:
ax.set_xticks(np.arange(0, matrix.shape[0], 1))
ax.set_xticklabels(matrix.index)
# Handle both DataFrame and ndarray for tick labels
if hasattr(matrix, "index"):
labels = matrix.index
else:
# For numpy array, create generic labels
labels = [f"Asset {i + 1}" for i in range(matrix.shape[0])]

ax.set_xticklabels(labels)
ax.set_yticks(np.arange(0, matrix.shape[0], 1))
ax.set_yticklabels(matrix.index)
ax.set_yticklabels(labels)
plt.xticks(rotation=90)

# Optional: overlay numeric values on each cell
if show_values:
is_dataframe = hasattr(matrix, "iloc")
n_rows, n_cols = matrix.shape

for i in range(n_rows):
for j in range(n_cols):
if is_dataframe:
val = matrix.iloc[i, j]
else:
val = matrix[i, j]

text_str = f"{val:.2f}"

ax.text(
j,
i,
text_str,
ha="center",
va="center",
color="w",
fontsize=plt.rcParams.get("font.size", 10) * 0.9,
)

_plot_io(**kwargs)

return ax
Expand Down
4 changes: 2 additions & 2 deletions tests/test_efficient_frontier.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def test_min_vol_pair_constraint():
ef.min_volatility()
old_sum = ef.weights[:2].sum()
ef = setup_efficient_frontier()
ef.add_constraint(lambda w: (w[1] + w[0] <= old_sum / 2))
ef.add_constraint(lambda w: w[1] + w[0] <= old_sum / 2)
ef.min_volatility()
new_sum = ef.weights[:2].sum()
assert new_sum <= old_sum / 2 + 1e-4
Expand All @@ -620,7 +620,7 @@ def test_max_sharpe_pair_constraint():
old_sum = ef.weights[:2].sum()

ef = setup_efficient_frontier()
ef.add_constraint(lambda w: (w[1] + w[0] <= old_sum / 2))
ef.add_constraint(lambda w: w[1] + w[0] <= old_sum / 2)
ef.max_sharpe(risk_free_rate=0.02)
new_sum = ef.weights[:2].sum()
assert new_sum <= old_sum / 2 + 1e-4
Expand Down
64 changes: 64 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,67 @@ def test_plot_efficient_frontier():
ef = setup_efficient_frontier()
ef.min_volatility()
optimal_ret, optimal_risk, _ = ef.portfolio_performance(risk_free_rate=0.02)


@pytest.mark.skipif(
not _check_soft_dependencies(["matplotlib"], severity="none"),
reason="skip test if matplotlib is not installed in environment",
)
def test_plot_covariance_show_values():
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Simple 3x3 covariance matrix
cov_data = np.array(
[[0.04, 0.01, 0.002], [0.01, 0.09, 0.003], [0.002, 0.003, 0.16]]
)
tickers = ["A", "B", "C"]
df = pd.DataFrame(cov_data, index=tickers, columns=tickers)

def count_texts(ax):
return len([obj for obj in ax.findobj() if obj.__class__.__name__ == "Text"])

# Test with ndarray input, show_values=False (baseline)
plt.figure()
ax = plotting.plot_covariance(cov_data, showfig=False)
baseline_texts = count_texts(ax)
plt.clf()
plt.close()

# Test with ndarray input, show_values=True
plt.figure()
ax = plotting.plot_covariance(cov_data, show_values=True, showfig=False)
with_values_texts = count_texts(ax)
plt.clf()
plt.close()

# Expect more text annotations when show_values=True
assert with_values_texts > baseline_texts

# Test with DataFrame input, show_values=False
plt.figure()
ax = plotting.plot_covariance(df, showfig=False)
baseline_texts_df = count_texts(ax)
plt.clf()
plt.close()

# Test with DataFrame input, show_values=True
plt.figure()
ax = plotting.plot_covariance(df, show_values=True, showfig=False)
with_values_texts_df = count_texts(ax)
plt.clf()
plt.close()

assert with_values_texts_df > baseline_texts_df

# Ensure saving still works
with tempfile.TemporaryDirectory() as tmpdir:
fname = f"{tmpdir}/cov_plot.png"
ax = plotting.plot_covariance(
df, show_values=True, filename=fname, showfig=False
)
assert os.path.exists(fname)
assert os.path.getsize(fname) > 0
plt.clf()
plt.close()
Loading