23 changes: 23 additions & 0 deletions colormaps/gsl_wind_speed.yaml
@@ -0,0 +1,23 @@
# Colors used for wind speed as defined by NOAA-GSL's pygraf utility
plot_under: False
colors:
- '#fef8fe'
- '#f8d3f9'
- '#f1a5f3'
- '#e074f0'
- '#0045ff'
- '#0099ff'
- '#00ceff'
- '#00e8ff'
- '#00ffe6'
- '#67d300'
- '#7ffa06'
- '#b4ff36'
- '#eaff12'
- '#ffe500'
- '#ffc808'
- '#ff8608'
- '#ff3300'
- '#ff0039'
- '#f704fc'
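
# Note (assumption, not verified against the plotting code): custom colormaps in this
# directory are presumably selected by file name in the plot options, e.g.
#   colormap: "gsl_wind_speed"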

29 changes: 29 additions & 0 deletions custom_functions.py
@@ -4,6 +4,7 @@
"""
import logging
import uxarray as ux
import numpy as np

logger = logging.getLogger(__name__)

@@ -62,10 +63,38 @@ def vert_min(field: ux.UxDataArray, dim: str = "nVertLevels") -> ux.UxDataArray:
return vertmin


def sum_of_magnitudes(field1: ux.UxDataArray, field2: ux.UxDataArray) -> ux.UxDataArray:
"""
Take two vector components (usually u and v wind components) and return the magnitude of the resulting vector, sqrt(field1**2 + field2**2)
"""

return np.sqrt(np.square(field1) + np.square(field2))
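
# Illustrative usage sketch (not part of this change; "ds", "u10" and "v10" are
# assumed names): with orthogonal wind components this yields the wind speed,
#   speed = sum_of_magnitudes(ds["u10"], ds["v10"])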

def max_all_times(field: ux.UxDataArray, dim: str = "Time") -> ux.UxDataArray:
"""
Return the maximum value across all input times for a given point.
"""
# Compute the maximum along the Time dimension
result = field.max(dim=dim, keep_attrs=True)

return result

def min_all_times(field: ux.UxDataArray, dim: str = "Time") -> ux.UxDataArray:
"""
Return the minimum value across all input times for a given point.
"""
# Compute the minimum along the Time dimension
result = field.min(dim=dim, keep_attrs=True)

return result
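
# Illustrative composition (not part of this change; the field names are assumptions):
# the run maximum of 10-m wind speed over all forecast times,
#   max_speed = max_all_times(sum_of_magnitudes(ds["u10"], ds["v10"]))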

DERIVED_FUNCTIONS = {
"diff_prev_timestep": diff_prev_timestep,
"sum_fields": sum_fields,
"vert_max": vert_max,
"vert_min": vert_min,
"sum_of_magnitudes": sum_of_magnitudes,
"max_all_times": max_all_times,
"min_all_times": min_all_times,
}
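
# Sketch of how the registry is typically consumed (assumption; the exact calling
# convention lives in the driver code, not shown in this diff): a derived-field entry
# names a function and the driver dispatches through the dict,
#   func = DERIVED_FUNCTIONS["max_all_times"]
#   derived = func(field)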

22 changes: 18 additions & 4 deletions default_options.yaml
@@ -109,16 +109,18 @@ plot:
# {fnme} = Name of file (minus extension) being read for plotted data
# {date} = The date of plotted data, in %Y-%m-%d format
# {time} = The time of plotted data, in %H:%M:%S format
# {maxval} = The maximum value in the plotted data
# {minval} = The minimum value in the plotted data

filename: '{var}_{lev}.png'
format: null
title:
text: 'Plot of {varln}, level {lev} for MPAS forecast, {date} {time}'
fontsize: 8
exists: rename
dpi: 300
figheight: 4
figwidth: 8
dpi: 200
figheight: 3
figwidth: 6
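
# Example (illustrative only): the new {maxval}/{minval} placeholders substitute in the
# same way as the patterns above, e.g. in the plot title:
#   title:
#     text: 'Plot of {varln}, level {lev}, {date} {time} (max {maxval}, min {minval})'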

# colormap:
# Color scheme to use for output plots. Options can either be standard Matplotlib colormaps (reference
@@ -130,10 +132,22 @@
#
colormap: "viridis"

# pixel_ratio:
# This controls the number of pixels sampled in the rasterization process; higher numbers result
# in higher-quality plots at the cost of plotting speed.
pixel_ratio: 1
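
# Example (illustrative): for a sharper but slower plot, raise this value, e.g.
#   pixel_ratio: 2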

# polycollection:
# NOT RECOMMENDED
# This is the legacy plotting method that converts the unstructured grid to a set of polygons.
# This can be orders of magnitude slower than the default raster method and so is not
# recommended for large domains; if you need more detail in your plot, increase the
# "pixel_ratio" setting instead.
# periodic_bdy:
# For periodic domains (including global), the plot routines will omit the boundary cells by default. To plot
# For periodic domains (including global), the polycollection routines will omit the boundary cells by default. To plot
# all data, including boundaries, set this option to True, but note that it will slow down plotting substantially.
#
polycollection: False
periodic_bdy: False
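
# Example (illustrative): to force the legacy polygon plotting on a small periodic
# domain and include the boundary cells:
#   polycollection: True
#   periodic_bdy: True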

# vmin, vmax:
3 changes: 1 addition & 2 deletions environment.yml
@@ -8,7 +8,6 @@ dependencies:
- numpy=1.26*
- matplotlib
- netcdf4
- xarray=2025.9.0
- cartopy
- uwtools=2.9*
- uxarray=2025.05*
- uxarray=2025.11*
73 changes: 72 additions & 1 deletion plot_functions.py
@@ -7,6 +7,7 @@
import os
import traceback

import numpy as np
import cartopy.crs as ccrs

logger = logging.getLogger(__name__)
@@ -47,6 +48,10 @@ def set_patterns_and_outfile(valid, var, lev, filepath, field, ftime, plotdict):
#filename minus extension
fnme=os.path.splitext(filename)[0]

# max and min values for plotted field
maxval=float(field.max().compute())
minval=float(field.min().compute())

pattern_dict = {
"var": var,
"lev": lev,
@@ -56,7 +61,9 @@ def set_patterns_and_outfile(valid, var, lev, filepath, field, ftime, plotdict):
"fnme": fnme,
"proj": plotdict["projection"]["projection"],
"date": "no_Time_dimension",
"time": "no_Time_dimension"
"time": "no_Time_dimension",
"maxval": f"{maxval:.2f}",
"minval": f"{minval:.2f}",
}
if field.attrs.get("units"):
pattern_dict.update({
@@ -220,3 +227,67 @@ def set_map_projection(confproj) -> ccrs.Projection:

raise ValueError(f"Invalid projection {proj} specified; valid options are:\n{valid}")


def get_data_extent_raster(raster, lon_bounds=(-180, 180), lat_bounds=(-90, 90)):
"""
Computes the data extent from an image raster for automatic zooming to the data domain

Parameters
----------
raster : np.ndarray
2D raster array with NaNs outside valid region
lon_bounds : tuple(float, float)
Longitude range corresponding to full raster width
lat_bounds : tuple(float, float)
Latitude range corresponding to full raster height

Returns
-------
extent : list [lon_min, lon_max, lat_min, lat_max]
"""
valid = ~np.isnan(raster)
if not np.any(valid):
# no data at all
return lon_bounds + lat_bounds

# pixel indices of valid data
ys, xs = np.where(valid)

# convert indices to lon/lat using proportional scaling
nrows, ncols = raster.shape
lon_min, lon_max = lon_bounds
lat_min, lat_max = lat_bounds

x_min = lon_min + (xs.min() / ncols) * (lon_max - lon_min)
x_max = lon_min + (xs.max() / ncols) * (lon_max - lon_min)
y_min = lat_max - (ys.max() / nrows) * (lat_max - lat_min)
y_max = lat_max - (ys.min() / nrows) * (lat_max - lat_min)

pad_fraction = 0.05
dx = (x_max - x_min) * pad_fraction
dy = (y_max - y_min) * pad_fraction
# y dimension is flipped for some reason
return [x_min - dx, x_max + dx, -y_max - dy, -y_min + dy]
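
# Illustrative usage sketch (not part of this change; "raster" and "ax" are assumed to
# come from the calling plot routine): apply the computed extent to a cartopy GeoAxes,
#   extent = get_data_extent_raster(raster)
#   ax.set_extent(extent, crs=ccrs.PlateCarree())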


def get_data_extent(uxda, pad_fraction=0.05):
"""Return (lon_min, lon_max, lat_min, lat_max) in degrees, with buffer."""
try:
if "n_face" in uxda.dims:
lons = getattr(uxda.uxgrid, "node_lon", None)
lats = getattr(uxda.uxgrid, "node_lat", None)
else:
lons = uxda.lon
lats = uxda.lat

lon_min = np.nanmin(lons)
lon_max = np.nanmax(lons)
lat_min = np.nanmin(lats)
lat_max = np.nanmax(lats)

dx = (lon_max - lon_min) * pad_fraction
dy = (lat_max - lat_min) * pad_fraction

return [lon_min - dx, lon_max + dx, lat_min - dy, lat_max + dy]
except Exception as e:
raise RuntimeError(f"Could not determine lat/lon bounds: {e}") from e
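
# Illustrative usage sketch (not part of this change; "uxda" and "ax" are assumed
# names): zoom a cartopy GeoAxes to the data domain of an unstructured field,
#   ax.set_extent(get_data_extent(uxda), crs=ccrs.PlateCarree())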