Skip to content

Commit 8694c7b

Browse files
committed
Add histograms to wholeSkyPlot, bonus speed ups, metric thresholds and error catching
1 parent fec249c commit 8694c7b

File tree

1 file changed

+105
-38
lines changed

1 file changed

+105
-38
lines changed

python/lsst/analysis/tools/actions/plot/wholeSkyPlot.py

Lines changed: 105 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,23 @@
2323

2424
__all__ = ("WholeSkyPlot",)
2525

26+
import importlib.resources as importResources
2627
from typing import Mapping, Optional
2728

29+
import lsst.analysis.tools
2830
import matplotlib.patheffects as pathEffects
29-
import matplotlib.pyplot as plt
3031
import numpy as np
32+
import yaml
3133
from lsst.pex.config import ChoiceField, Field, ListField
34+
from lsst.utils.plotting import (
35+
accent_color,
36+
divergent_cmap,
37+
make_figure,
38+
set_rubin_plotstyle,
39+
stars_cmap,
40+
stars_color,
41+
)
42+
from matplotlib import gridspec
3243
from matplotlib.collections import PatchCollection
3344
from matplotlib.colors import CenteredNorm
3445
from matplotlib.figure import Figure
@@ -37,7 +48,7 @@
3748
from ...interfaces import KeyedData, KeyedDataSchema, PlotAction, Scalar, Vector
3849
from ...math import nanSigmaMad
3950
from ...utils import getTractCorners
40-
from .plotUtils import addPlotInfo, mkColormap
51+
from .plotUtils import addPlotInfo
4152

4253

4354
class WholeSkyPlot(PlotAction):
@@ -58,8 +69,6 @@ class WholeSkyPlot(PlotAction):
5869
xLimits = ListField[float](doc="Plotting limits for the x axis.", default=[-5.0, 365.0])
5970
yLimits = ListField[float](doc="Plotting limits for the y axis.", default=[-10.0, 60.0])
6071
autoAxesLimits = Field[bool](doc="Find axes limits automatically.", default=True)
61-
dpi = Field[int](doc="DPI size of the figure.", default=500)
62-
figureSize = ListField[float](doc="Size of the figure.", default=[9.0, 3.5])
6372
colorBarMin = Field[float](doc="The minimum value of the color bar.", optional=True)
6473
colorBarMax = Field[float](doc="The minimum value of the color bar.", optional=True)
6574
colorBarRange = Field[float](
@@ -84,6 +93,11 @@ class WholeSkyPlot(PlotAction):
8493
showNaNs = Field[bool](doc="Show the NaNs on the plot.", default=True)
8594
labelTracts = Field[bool](doc="Label the tracts.", default=False)
8695

96+
addThreshold = Field[bool](
97+
doc="Read in the predefined threshold and add it to the histogram.",
98+
default=True,
99+
)
100+
87101
def getInputSchema(self, **kwargs) -> KeyedDataSchema:
88102
base = []
89103
base.append(("z", Vector))
@@ -161,15 +175,18 @@ def _getMaxOutlierVals(self, multiplier: float, tracts: list, values: list, outl
161175
text : `str`
162176
A string containing the 5 tracts with the largest outlier values.
163177
"""
164-
text = f"Tracts with |value| > {multiplier}" + r"$\sigma_{MAD}$" + ": "
178+
if self.addThreshold:
179+
text = "Tracts with value outside thresholds: "
180+
else:
181+
text = f"Tracts with |value| > {multiplier}" + r"$\sigma_{MAD}$" + ": "
165182
if len(outlierInds) > 0:
166183
outlierValues = np.array(values)[outlierInds]
167184
outlierTracts = np.array(tracts)[outlierInds]
168185
# Sort values in descending (-) absolute value order discounting
169186
# NaNs.
170187
maxInds = np.argsort(-np.abs(outlierValues))
171188
# Show up to five values on the plot.
172-
for ind in maxInds[:5]:
189+
for ind in maxInds[:10]:
173190
val = outlierValues[ind]
174191
tract = outlierTracts[ind]
175192
text += f"{tract}, {val:.3}; "
@@ -225,6 +242,10 @@ def makePlot(
225242
if plotInfo is None:
226243
plotInfo = {}
227244

245+
if self.addThreshold:
246+
metricThresholdFile = importResources.read_text(lsst.analysis.tools, "metricInformation.yaml")
247+
metricDefs = yaml.safe_load(metricThresholdFile)
248+
228249
# Prevent Bands in the plot info showing a list of bands.
229250
# If bands is a list, it implies that parameterizedBand=False,
230251
# and that the metric is not band-specific.
@@ -236,11 +257,13 @@ def makePlot(
236257
match self.colorMapType:
237258
case "sequential":
238259
if colorMap is None:
239-
colorMap = mkColormap(["#F5F5F5", "#5AB4AC", "#284D48"])
260+
colorMap = stars_cmap()
261+
outlierColor = "red"
240262
norm = None
241263
case "divergent":
242264
if colorMap is None:
243-
colorMap = mkColormap(["#9A6E3A", "#C6A267", "#A9A9A9", "#4F938B", "#2C665A"])
265+
colorMap = divergent_cmap()
266+
outlierColor = "fuchsia"
244267
norm = CenteredNorm()
245268

246269
# Create patches using the corners of each tract.
@@ -262,23 +285,19 @@ def makePlot(
262285
mid_decs.append((corners[0][1] + corners[2][1]) / 2)
263286

264287
# Setup figure.
265-
fig, ax = plt.subplots(1, 1, figsize=self.figureSize, dpi=self.dpi)
266-
if self.autoAxesLimits:
267-
xlim, ylim = self._getAxesLimits(ras, decs)
268-
else:
269-
xlim, ylim = self.xLimits, self.yLimits
270-
ax.set_xlim(xlim)
271-
ax.set_ylim(ylim)
272-
ax.set_xlabel(self.xAxisLabel)
273-
ax.set_ylabel(self.yAxisLabel)
274-
ax.invert_xaxis()
275-
288+
fig = make_figure(dpi=300, figsize=(12, 3.5))
289+
set_rubin_plotstyle()
290+
gs = gridspec.GridSpec(1, 4)
291+
ax = fig.add_subplot(gs[:3])
276292
# Add colored patches showing tract metric values.
277293
patchCollection = PatchCollection(patches, cmap=colorMap, norm=norm)
278294
ax.add_collection(patchCollection)
279295

280296
# Define color bar range.
281-
med = np.nanmedian(colBarVals)
297+
if np.sum(np.isfinite(colBarVals)) > 0:
298+
med = np.nanmedian(colBarVals)
299+
else:
300+
med = np.nan
282301
sigmaMad = nanSigmaMad(colBarVals)
283302
if self.colorBarMin is not None:
284303
vmin = np.float64(self.colorBarMin)
@@ -289,8 +308,15 @@ def makePlot(
289308
else:
290309
vmax = med + self.colorBarRange * sigmaMad
291310

292-
# Note tracts with metrics outside (vmin, vmax) as outliers.
293-
outlierInds = np.where((colBarVals < vmin) | (colBarVals > vmax))[0]
311+
dataName = self.zAxisLabel.format_map(kwargs)
312+
colBarVals = np.array(colBarVals)
313+
if self.addThreshold and dataName in metricDefs:
314+
lowThreshold = metricDefs[dataName]["lowThreshold"]
315+
highThreshold = metricDefs[dataName]["highThreshold"]
316+
outlierInds = np.where((colBarVals < lowThreshold) | (colBarVals > highThreshold))[0]
317+
else:
318+
# Note tracts with metrics outside (vmin, vmax) as outliers.
319+
outlierInds = np.where((colBarVals < vmin) | (colBarVals > vmax))[0]
294320

295321
# Initialize legend handles.
296322
handles = []
@@ -306,15 +332,15 @@ def makePlot(
306332
cmap=colorMap,
307333
norm=norm,
308334
facecolors="none",
309-
edgecolors="k",
335+
edgecolors=outlierColor,
310336
linewidths=0.5,
311337
zorder=100,
312338
)
313339
ax.add_collection(outlierPatchCollection)
314340
# Add legend information.
315341
outlierPatch = Patch(
316342
facecolor="none",
317-
edgecolor="k",
343+
edgecolor=outlierColor,
318344
linewidth=0.5,
319345
label="Outlier",
320346
)
@@ -365,13 +391,50 @@ def makePlot(
365391
zorder=100,
366392
)
367393

394+
ax.set_aspect("equals")
395+
axPos = ax.get_position()
396+
ax1 = fig.add_axes([0.73, 0.25, 0.20, 0.47])
397+
398+
if np.sum(np.isfinite(data["z"])) > 0:
399+
ax1.hist(data["z"], bins=len(data["z"] / 10), color=stars_color(), histtype="step")
400+
else:
401+
ax1.text(0.5, 0.5, "Data all NaN/Inf")
402+
ax1.set_xlabel("Metric Values")
403+
ax1.set_ylabel("Number")
404+
ax1.yaxis.set_label_position("right")
405+
ax1.yaxis.tick_right()
406+
407+
if self.addThreshold and dataName in metricDefs:
408+
ax1.axvline(lowThreshold, color=accent_color())
409+
ax1.axvline(highThreshold, color=accent_color())
410+
411+
widthThreshold = highThreshold - lowThreshold
412+
upperLim = highThreshold + 0.5 * widthThreshold
413+
lowerLim = lowThreshold - 0.5 * widthThreshold
414+
ax1.set_xlim(lowerLim, upperLim)
415+
numOutside = np.sum(((data["z"] > upperLim) | (data["z"] < lowerLim)))
416+
ax1.set_title("Outside plot limits: " + str(numOutside))
417+
418+
else:
419+
if vmin != vmax and np.isfinite(vmin) and np.isfinite(vmax):
420+
ax1.set_xlim(vmin, vmax)
421+
422+
if self.autoAxesLimits:
423+
xlim, ylim = self._getAxesLimits(ras, decs)
424+
else:
425+
xlim, ylim = self.xLimits, self.yLimits
426+
ax.set_xlim(xlim)
427+
ax.set_ylim(ylim)
428+
ax.set_xlabel(self.xAxisLabel)
429+
ax.set_ylabel(self.yAxisLabel)
430+
ax.invert_xaxis()
431+
368432
if self.showOutliers:
369433
# Add text boxes to show the number of tracts, number of NaNs,
370434
# median, sigma MAD, and the five largest outlier values.
371435
outlierText = self._getMaxOutlierVals(self.colorBarRange, tracts, colBarVals, outlierInds)
372-
373436
# Make vertical text spacing readable for different figure sizes.
374-
multiplier = 3.5 / self.figureSize[1]
437+
multiplier = 3.5 / fig.get_size_inches()[1]
375438
verticalSpacing = 0.028 * multiplier
376439
fig.text(
377440
0.01,
@@ -402,25 +465,27 @@ def makePlot(
402465
fig.text(0.01, 0.01, outlierText, transform=fig.transFigure, fontsize=8, alpha=0.7)
403466

404467
# Truncate the color range to (vmin, vmax).
405-
colorBarVals = np.clip(np.array(colBarVals), vmin, vmax)
406-
patchCollection.set_array(colorBarVals)
468+
if vmin != vmax and np.isfinite(vmin) and np.isfinite(vmax):
469+
colBarVals = np.clip(np.array(colBarVals), vmin, vmax)
470+
patchCollection.set_array(colBarVals)
407471
# Make the color bar with a metric label.
408-
cbar = plt.colorbar(
472+
axPos = ax.get_position()
473+
cax = fig.add_axes([0.084, axPos.y1 + 0.02, 0.62, 0.07])
474+
fig.colorbar(
409475
patchCollection,
410-
ax=ax,
476+
cax=cax,
411477
shrink=0.7,
412478
extend="both",
413479
location="top",
414480
orientation="horizontal",
415481
)
416-
cbarText = self.zAxisLabel.format_map(kwargs)
417-
if "zUnit" in data and data["zUnit"] != "":
418-
cbarText += f" ({data['zUnit']})"
419-
text = cbar.ax.text(
482+
cbarText = "Metric Values"
483+
484+
text = cax.text(
420485
0.5,
421486
0.5,
422487
cbarText,
423-
transform=cbar.ax.transAxes,
488+
transform=cax.transAxes,
424489
ha="center",
425490
va="center",
426491
fontsize=10,
@@ -431,9 +496,11 @@ def makePlot(
431496
# Finalize plot appearance.
432497
ax.grid()
433498
ax.set_axisbelow(True)
434-
ax.set_aspect("equal")
435-
fig = plt.gcf()
436499
fig = addPlotInfo(fig, plotInfo)
437-
plt.subplots_adjust(left=0.08, right=0.97, top=0.8, bottom=0.17, wspace=0.35)
500+
fig.subplots_adjust(left=0.08, right=0.92, top=0.8, bottom=0.17, wspace=0.05)
501+
titleText = self.zAxisLabel.format_map(kwargs)
502+
if "zUnit" in data and data["zUnit"] != "":
503+
titleText += f" ({data['zUnit']})"
504+
fig.suptitle("Metric: " + titleText, fontsize=20)
438505

439506
return fig

0 commit comments

Comments
 (0)