2323
2424__all__ = ("WholeSkyPlot" ,)
2525
26+ import importlib .resources as importResources
2627from typing import Mapping , Optional
2728
29+ import lsst .analysis .tools
2830import matplotlib .patheffects as pathEffects
29- import matplotlib .pyplot as plt
3031import numpy as np
32+ import yaml
3133from lsst .pex .config import ChoiceField , Field , ListField
34+ from lsst .utils .plotting import (
35+ accent_color ,
36+ divergent_cmap ,
37+ make_figure ,
38+ set_rubin_plotstyle ,
39+ stars_cmap ,
40+ stars_color ,
41+ )
42+ from matplotlib import gridspec
3243from matplotlib .collections import PatchCollection
3344from matplotlib .colors import CenteredNorm
3445from matplotlib .figure import Figure
3748from ...interfaces import KeyedData , KeyedDataSchema , PlotAction , Scalar , Vector
3849from ...math import nanSigmaMad
3950from ...utils import getTractCorners
40- from .plotUtils import addPlotInfo , mkColormap
51+ from .plotUtils import addPlotInfo
4152
4253
4354class WholeSkyPlot (PlotAction ):
@@ -58,8 +69,6 @@ class WholeSkyPlot(PlotAction):
5869 xLimits = ListField [float ](doc = "Plotting limits for the x axis." , default = [- 5.0 , 365.0 ])
5970 yLimits = ListField [float ](doc = "Plotting limits for the y axis." , default = [- 10.0 , 60.0 ])
6071 autoAxesLimits = Field [bool ](doc = "Find axes limits automatically." , default = True )
61- dpi = Field [int ](doc = "DPI size of the figure." , default = 500 )
62- figureSize = ListField [float ](doc = "Size of the figure." , default = [9.0 , 3.5 ])
6372 colorBarMin = Field [float ](doc = "The minimum value of the color bar." , optional = True )
6473 colorBarMax = Field [float ](doc = "The minimum value of the color bar." , optional = True )
6574 colorBarRange = Field [float ](
@@ -84,6 +93,11 @@ class WholeSkyPlot(PlotAction):
8493 showNaNs = Field [bool ](doc = "Show the NaNs on the plot." , default = True )
8594 labelTracts = Field [bool ](doc = "Label the tracts." , default = False )
8695
96+ addThreshold = Field [bool ](
97+ doc = "Read in the predefined threshold and add it to the histogram." ,
98+ default = True ,
99+ )
100+
87101 def getInputSchema (self , ** kwargs ) -> KeyedDataSchema :
88102 base = []
89103 base .append (("z" , Vector ))
@@ -161,15 +175,18 @@ def _getMaxOutlierVals(self, multiplier: float, tracts: list, values: list, outl
161175 text : `str`
162176 A string containing the 5 tracts with the largest outlier values.
163177 """
164- text = f"Tracts with |value| > { multiplier } " + r"$\sigma_{MAD}$" + ": "
178+ if self .addThreshold :
179+ text = "Tracts with value outside thresholds: "
180+ else :
181+ text = f"Tracts with |value| > { multiplier } " + r"$\sigma_{MAD}$" + ": "
165182 if len (outlierInds ) > 0 :
166183 outlierValues = np .array (values )[outlierInds ]
167184 outlierTracts = np .array (tracts )[outlierInds ]
168185 # Sort values in descending (-) absolute value order discounting
169186 # NaNs.
170187 maxInds = np .argsort (- np .abs (outlierValues ))
171188 # Show up to five values on the plot.
172- for ind in maxInds [:5 ]:
189+ for ind in maxInds [:10 ]:
173190 val = outlierValues [ind ]
174191 tract = outlierTracts [ind ]
175192 text += f"{ tract } , { val :.3} ; "
@@ -225,6 +242,10 @@ def makePlot(
225242 if plotInfo is None :
226243 plotInfo = {}
227244
245+ if self .addThreshold :
246+ metricThresholdFile = importResources .read_text (lsst .analysis .tools , "metricInformation.yaml" )
247+ metricDefs = yaml .safe_load (metricThresholdFile )
248+
228249 # Prevent Bands in the plot info showing a list of bands.
229250 # If bands is a list, it implies that parameterizedBand=False,
230251 # and that the metric is not band-specific.
@@ -236,11 +257,13 @@ def makePlot(
236257 match self .colorMapType :
237258 case "sequential" :
238259 if colorMap is None :
239- colorMap = mkColormap (["#F5F5F5" , "#5AB4AC" , "#284D48" ])
260+ colorMap = stars_cmap ()
261+ outlierColor = "red"
240262 norm = None
241263 case "divergent" :
242264 if colorMap is None :
243- colorMap = mkColormap (["#9A6E3A" , "#C6A267" , "#A9A9A9" , "#4F938B" , "#2C665A" ])
265+ colorMap = divergent_cmap ()
266+ outlierColor = "fuchsia"
244267 norm = CenteredNorm ()
245268
246269 # Create patches using the corners of each tract.
@@ -262,23 +285,19 @@ def makePlot(
262285 mid_decs .append ((corners [0 ][1 ] + corners [2 ][1 ]) / 2 )
263286
264287 # Setup figure.
265- fig , ax = plt .subplots (1 , 1 , figsize = self .figureSize , dpi = self .dpi )
266- if self .autoAxesLimits :
267- xlim , ylim = self ._getAxesLimits (ras , decs )
268- else :
269- xlim , ylim = self .xLimits , self .yLimits
270- ax .set_xlim (xlim )
271- ax .set_ylim (ylim )
272- ax .set_xlabel (self .xAxisLabel )
273- ax .set_ylabel (self .yAxisLabel )
274- ax .invert_xaxis ()
275-
288+ fig = make_figure (dpi = 300 , figsize = (12 , 3.5 ))
289+ set_rubin_plotstyle ()
290+ gs = gridspec .GridSpec (1 , 4 )
291+ ax = fig .add_subplot (gs [:3 ])
276292 # Add colored patches showing tract metric values.
277293 patchCollection = PatchCollection (patches , cmap = colorMap , norm = norm )
278294 ax .add_collection (patchCollection )
279295
280296 # Define color bar range.
281- med = np .nanmedian (colBarVals )
297+ if np .sum (np .isfinite (colBarVals )) > 0 :
298+ med = np .nanmedian (colBarVals )
299+ else :
300+ med = np .nan
282301 sigmaMad = nanSigmaMad (colBarVals )
283302 if self .colorBarMin is not None :
284303 vmin = np .float64 (self .colorBarMin )
@@ -289,8 +308,15 @@ def makePlot(
289308 else :
290309 vmax = med + self .colorBarRange * sigmaMad
291310
292- # Note tracts with metrics outside (vmin, vmax) as outliers.
293- outlierInds = np .where ((colBarVals < vmin ) | (colBarVals > vmax ))[0 ]
311+ dataName = self .zAxisLabel .format_map (kwargs )
312+ colBarVals = np .array (colBarVals )
313+ if self .addThreshold and dataName in metricDefs :
314+ lowThreshold = metricDefs [dataName ]["lowThreshold" ]
315+ highThreshold = metricDefs [dataName ]["highThreshold" ]
316+ outlierInds = np .where ((colBarVals < lowThreshold ) | (colBarVals > highThreshold ))[0 ]
317+ else :
318+ # Note tracts with metrics outside (vmin, vmax) as outliers.
319+ outlierInds = np .where ((colBarVals < vmin ) | (colBarVals > vmax ))[0 ]
294320
295321 # Initialize legend handles.
296322 handles = []
@@ -306,15 +332,15 @@ def makePlot(
306332 cmap = colorMap ,
307333 norm = norm ,
308334 facecolors = "none" ,
309- edgecolors = "k" ,
335+ edgecolors = outlierColor ,
310336 linewidths = 0.5 ,
311337 zorder = 100 ,
312338 )
313339 ax .add_collection (outlierPatchCollection )
314340 # Add legend information.
315341 outlierPatch = Patch (
316342 facecolor = "none" ,
317- edgecolor = "k" ,
343+ edgecolor = outlierColor ,
318344 linewidth = 0.5 ,
319345 label = "Outlier" ,
320346 )
@@ -365,13 +391,50 @@ def makePlot(
365391 zorder = 100 ,
366392 )
367393
394+ ax .set_aspect ("equals" )
395+ axPos = ax .get_position ()
396+ ax1 = fig .add_axes ([0.73 , 0.25 , 0.20 , 0.47 ])
397+
398+ if np .sum (np .isfinite (data ["z" ])) > 0 :
399+ ax1 .hist (data ["z" ], bins = len (data ["z" ] / 10 ), color = stars_color (), histtype = "step" )
400+ else :
401+ ax1 .text (0.5 , 0.5 , "Data all NaN/Inf" )
402+ ax1 .set_xlabel ("Metric Values" )
403+ ax1 .set_ylabel ("Number" )
404+ ax1 .yaxis .set_label_position ("right" )
405+ ax1 .yaxis .tick_right ()
406+
407+ if self .addThreshold and dataName in metricDefs :
408+ ax1 .axvline (lowThreshold , color = accent_color ())
409+ ax1 .axvline (highThreshold , color = accent_color ())
410+
411+ widthThreshold = highThreshold - lowThreshold
412+ upperLim = highThreshold + 0.5 * widthThreshold
413+ lowerLim = lowThreshold - 0.5 * widthThreshold
414+ ax1 .set_xlim (lowerLim , upperLim )
415+ numOutside = np .sum (((data ["z" ] > upperLim ) | (data ["z" ] < lowerLim )))
416+ ax1 .set_title ("Outside plot limits: " + str (numOutside ))
417+
418+ else :
419+ if vmin != vmax and np .isfinite (vmin ) and np .isfinite (vmax ):
420+ ax1 .set_xlim (vmin , vmax )
421+
422+ if self .autoAxesLimits :
423+ xlim , ylim = self ._getAxesLimits (ras , decs )
424+ else :
425+ xlim , ylim = self .xLimits , self .yLimits
426+ ax .set_xlim (xlim )
427+ ax .set_ylim (ylim )
428+ ax .set_xlabel (self .xAxisLabel )
429+ ax .set_ylabel (self .yAxisLabel )
430+ ax .invert_xaxis ()
431+
368432 if self .showOutliers :
369433 # Add text boxes to show the number of tracts, number of NaNs,
370434 # median, sigma MAD, and the five largest outlier values.
371435 outlierText = self ._getMaxOutlierVals (self .colorBarRange , tracts , colBarVals , outlierInds )
372-
373436 # Make vertical text spacing readable for different figure sizes.
374- multiplier = 3.5 / self . figureSize [1 ]
437+ multiplier = 3.5 / fig . get_size_inches () [1 ]
375438 verticalSpacing = 0.028 * multiplier
376439 fig .text (
377440 0.01 ,
@@ -402,25 +465,27 @@ def makePlot(
402465 fig .text (0.01 , 0.01 , outlierText , transform = fig .transFigure , fontsize = 8 , alpha = 0.7 )
403466
404467 # Truncate the color range to (vmin, vmax).
405- colorBarVals = np .clip (np .array (colBarVals ), vmin , vmax )
406- patchCollection .set_array (colorBarVals )
468+ if vmin != vmax and np .isfinite (vmin ) and np .isfinite (vmax ):
469+ colBarVals = np .clip (np .array (colBarVals ), vmin , vmax )
470+ patchCollection .set_array (colBarVals )
407471 # Make the color bar with a metric label.
408- cbar = plt .colorbar (
472+ axPos = ax .get_position ()
473+ cax = fig .add_axes ([0.084 , axPos .y1 + 0.02 , 0.62 , 0.07 ])
474+ fig .colorbar (
409475 patchCollection ,
410- ax = ax ,
476+ cax = cax ,
411477 shrink = 0.7 ,
412478 extend = "both" ,
413479 location = "top" ,
414480 orientation = "horizontal" ,
415481 )
416- cbarText = self .zAxisLabel .format_map (kwargs )
417- if "zUnit" in data and data ["zUnit" ] != "" :
418- cbarText += f" ({ data ['zUnit' ]} )"
419- text = cbar .ax .text (
482+ cbarText = "Metric Values"
483+
484+ text = cax .text (
420485 0.5 ,
421486 0.5 ,
422487 cbarText ,
423- transform = cbar . ax .transAxes ,
488+ transform = cax .transAxes ,
424489 ha = "center" ,
425490 va = "center" ,
426491 fontsize = 10 ,
@@ -431,9 +496,11 @@ def makePlot(
431496 # Finalize plot appearance.
432497 ax .grid ()
433498 ax .set_axisbelow (True )
434- ax .set_aspect ("equal" )
435- fig = plt .gcf ()
436499 fig = addPlotInfo (fig , plotInfo )
437- plt .subplots_adjust (left = 0.08 , right = 0.97 , top = 0.8 , bottom = 0.17 , wspace = 0.35 )
500+ fig .subplots_adjust (left = 0.08 , right = 0.92 , top = 0.8 , bottom = 0.17 , wspace = 0.05 )
501+ titleText = self .zAxisLabel .format_map (kwargs )
502+ if "zUnit" in data and data ["zUnit" ] != "" :
503+ titleText += f" ({ data ['zUnit' ]} )"
504+ fig .suptitle ("Metric: " + titleText , fontsize = 20 )
438505
439506 return fig
0 commit comments