# -----------------------------------------------------------------------------.
# Copyright (c) 2021-2026 DISDRODB developers
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------.
"""DISDRODB Plotting Tools."""
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import xarray as xr
from matplotlib.colors import BoundaryNorm, ListedColormap, LogNorm, Normalize
from matplotlib.gridspec import GridSpec
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from matplotlib.ticker import MaxNLocator
from disdrodb.constants import DIAMETER_DIMENSION, VELOCITY_DIMENSION
from disdrodb.l2.empirical_dsd import get_drop_average_velocity
from disdrodb.l2.processing import get_mask_contour, get_spectrum_mask_boundary
# TODO FIX: XARRAY PCOLORMESH IS CURRENTLY INACCURATE
# IMPROVEMENTS
# plot_filtering_boundary()
# plot_spectrum(add theoretical_fall_velocity,
# add_tolerance_fraction=0.5,
# add_tolerance_=0.5)
# Add plot_raw_and_filtered_spectra(animation=True, legend_variables on first axes)
# TODO: plot_l0_quicklook
## - plot_l0_quicklook # remap weather codes to hydrometeor_type, find R if available
####-------------------------------------------------------------------------------------------------------
[docs]
def get_precipitation_legend_style(name):
    """Return the legend colors and labels of a categorical precipitation variable.

    Parameters
    ----------
    name : str
        Name of the categorical variable. Supported values are
        ``"rain_type"``, ``"precipitation_type"`` and ``"hydrometeor_type"``.

    Returns
    -------
    tuple
        ``(colors, labels)``: two dictionaries keyed by the integer category value.
        New dictionaries are created at every call.

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported variables.
    """
    # Each entry is (category value, color, label)
    if name == "rain_type":
        entries = [
            (0, "white", "No precipitation"),
            (1, "dodgerblue", "Stratiform"),
            (2, "orangered", "Convective"),
        ]
    elif name == "precipitation_type":
        entries = [
            (-2, "lightgray", "Undefined"),
            (-1, "white", "No precipitation"),
            (0, "#0050b5", "Rainfall"),  # deep blue
            (1, "#00bcd4", "Snowfall"),  # cyan
            (2, "#8e44ad", "Mixed phase"),  # purple
        ]
    elif name == "hydrometeor_type":
        entries = [
            (-2, "white", "No hydrometeor"),
            (-1, "lightgray", "Undefined"),
            (0, "white", "No precipitation"),
            # Liquid family (blue gradient)
            (1, "#4a90e2", "Drizzle"),  # light blue
            (2, "#0057d9", "Drizzle + Rain"),  # strong blue
            (3, "#002f8a", "Rain"),  # deep royal blue
            # Mixed
            (4, "#9b59b6", "Mixed"),  # purple
            # Frozen family (teal / aqua gradient, shifted hue)
            (5, "#1abc6c", "Snow"),  # green-teal
            (6, "#9be7c4", "Snow grains"),  # pale mint
            # Ice family
            (7, "#f39c12", "Ice pellets"),
            (8, "#e67e22", "Graupel"),
            # Severe
            (9, "#c0392b", "Hail"),
        ]
    else:
        raise ValueError(f"Unsupported categorical variable: {name}")

    colors = {value: color for value, color, _ in entries}
    labels = {value: label for value, _, label in entries}
    return colors, labels
####-------------------------------------------------------------------------------------------------------
#### N(D) and n(D) visualizations
# Candidate N(D) / n(D) variable names, grouped by the L0/L1/L2 prefix of the list name
L0_DSD_VARIABLES = ["raw_drop_concentration"]
L1_DSD_VARIABLES = ["raw_particle_number_concentration", "raw_particle_counts"]
L2_DSD_VARIABLES = [
    "drop_number_concentration",
    "drop_counts",
    "raw_drop_number_concentration",
    "raw_drop_counts",
]
# Search order of the candidates: L2 first, then L1, then L0
DSD_VARIABLES = L2_DSD_VARIABLES + L1_DSD_VARIABLES + L0_DSD_VARIABLES
# Axis / colorbar labels (matplotlib mathtext) of the n(D,V), n(D) and N(D) variables.
# The unit suffixes are shared between the entries.
_DSD_COUNTS_UNITS = "[#]"
_DSD_CONCENTRATION_UNITS = "[# $m^{-3} mm^{-1}$]"
DSD_LABEL_DICT = {
    # L0
    "raw_spectrum": "$n_{raw}(D,V)$ " + _DSD_COUNTS_UNITS,  # FUTURE
    # FUTURE: raw_particle_number_concentration
    "raw_drop_concentration": "$N_{raw}(D)$ " + _DSD_CONCENTRATION_UNITS,
    # L1
    # raw_spectrum
    "raw_particle_counts": "$n_{raw}(D)$ " + _DSD_COUNTS_UNITS,
    "raw_particle_number_concentration": "$N_{raw}(D)$ " + _DSD_CONCENTRATION_UNITS,
    # L2
    "raw_drop_number": "$n_{raw}(D,V)$ " + _DSD_COUNTS_UNITS,  # FUTURE: raw spectrum
    "drop_number": "$n(D,V)$ " + _DSD_COUNTS_UNITS,  # FUTURE: spectrum
    "raw_drop_counts": "$n_{raw}(D)$ " + _DSD_COUNTS_UNITS,
    "drop_counts": "$n(D)$ " + _DSD_COUNTS_UNITS,
    "raw_drop_number_concentration": "$N_{raw}(D)$ " + _DSD_CONCENTRATION_UNITS,
    "drop_number_concentration": "$N(D)$ " + _DSD_CONCENTRATION_UNITS,
}
# Plot titles of the n(D,V), n(D) and N(D) variables.
# Every variable that has a predefined axis label must also have a title here,
# because both are looked up with the same variable name.
DSD_TITLE_DICT = {
    # L0
    "raw_spectrum": "Raw spectrum n(D,V)",
    "raw_drop_concentration": "Raw particle number concentration N(D)",  # FUTURE: raw_particle_number_concentration
    # L1
    "raw_particle_counts": "Raw particle counts n(D)",
    "raw_particle_number_concentration": "Raw particle number concentration N(D)",
    # L2
    "raw_drop_number": "Raw spectrum n(D,V)",
    "drop_number": "Spectrum n(D,V)",
    "raw_drop_counts": "Raw drop counts n(D)",
    "drop_counts": "Drop counts n(D)",
    "raw_drop_number_concentration": "Raw drop number concentration N(D)",
    # Former (misspelled) key, kept so that existing lookups keep working
    "raw_number_concentration": "Raw drop number concentration N(D)",
    "drop_number_concentration": "Drop number concentration N(D)",
}
def _get_dsd_labels(variable_name, da=None):
    """Get appropriate labels based on the variable type.

    Parameters
    ----------
    variable_name : str or None
        Name of the variable. If None, the name of ``da`` is used.
    da : xr.DataArray, optional
        DataArray used to retrieve the name (when ``variable_name`` is None)
        and the ``units`` attribute (when the variable has no predefined label).

    Returns
    -------
    dict
        Dictionary with 'ylabel', 'cbar_label', and 'title' keys.

    Raises
    ------
    ValueError
        If neither ``variable_name`` nor the name of ``da`` is available.
    """
    # Fall back to the DataArray name when no variable name is given
    if variable_name is None:
        variable_name = getattr(da, "name", None)
    if variable_name is None:
        raise ValueError("Variable name is not provided and DataArray has no name. Cannot determine labels.")

    # Human-readable name used by the generic labels
    pretty_name = str(variable_name).replace("_", " ").capitalize()
    generic_title = f"{pretty_name} distribution"

    # Retrieve for DISDRODB N(D) or n(D) variables
    if variable_name in DSD_LABEL_DICT:
        label = DSD_LABEL_DICT[variable_name]
        return {
            "ylabel": label,
            "cbar_label": label,
            # A missing title entry must not break plotting: use the generic title
            "title": DSD_TITLE_DICT.get(variable_name, generic_title),
        }

    # Generic fallback - try to extract units from DataArray attributes
    units = None
    if da is not None and hasattr(da, "attrs") and "units" in da.attrs:
        units = da.attrs["units"]
        # Handle dimensionless units
        if units in ("1", ""):
            units = "#"

    label = f"{pretty_name} [{units}]" if units else pretty_name
    return {
        "ylabel": label,
        "cbar_label": label,
        "title": generic_title,
    }
def _single_plot_dsd_distribution(
    data,
    diameter,
    diameter_bin_width,
    variable_name="drop_number_concentration",
    ax=None,
    yscale="linear",
):
    """Draw a single N(D) or n(D) distribution as a bar chart.

    Parameters
    ----------
    data : array-like
        Bar heights, one value per diameter bin.
    diameter : array-like
        Positions of the bars along the x axis.
    diameter_bin_width : array-like
        Width of each bar.
    variable_name : str, optional
        Variable name used to look up the title and the y-axis label.
    ax : matplotlib axes, optional
        Axes to draw on. A new figure is created when not provided.
    yscale : str, optional
        Scale of the y axis. The default is "linear".

    Returns
    -------
    matplotlib axes
        The axes containing the bar chart.
    """
    # Create a new figure only when the caller does not provide an axes
    if ax is None:
        ax = plt.subplots(1, 1)[1]

    label_info = _get_dsd_labels(variable_name, da=data)

    bar_style = {
        "width": diameter_bin_width,
        "edgecolor": "darkgray",
        "color": "lightgray",
        "label": "Data",
    }
    ax.bar(diameter, data, **bar_style)

    # Start the x axis at the left edge of the first bar
    ax.set_xlim(diameter[0] - diameter_bin_width[0] / 2, None)

    ax.set_title(label_info["title"])
    ax.set_xlabel("Drop diameter (mm)")
    ax.set_ylabel(label_info["ylabel"])
    ax.set_yscale(yscale)
    return ax
def _check_has_diameter_dims(da, diameter_dim):
    """Check that the DataArray has the diameter dimension and the bin width coordinate.

    Parameters
    ----------
    da : xr.DataArray
        DataArray to check.
    diameter_dim : str
        Name of the diameter dimension that must be present in ``da.dims``.

    Returns
    -------
    xr.DataArray
        The input DataArray, unchanged.

    Raises
    ------
    ValueError
        If ``diameter_dim`` is not a dimension of ``da`` or if the
        'diameter_bin_width' coordinate is missing.
    """
    if diameter_dim not in da.dims:
        # Report the dimension that was actually requested, not the module default
        raise ValueError(f"The DataArray must have dimension '{diameter_dim}'.")
    if "diameter_bin_width" not in da.coords:
        raise ValueError("The DataArray must have coordinate 'diameter_bin_width'.")
    return da
[docs]
def get_dataset_dsd_variable_name(ds, variables=None):
    """Return the name of the first N(D) or n(D) variable present in the xarray.Dataset.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to search. Only membership tests (``name in ds``) are performed.
    variables : list of str, optional
        Candidate variable names, in order of priority.
        If None, ``DSD_VARIABLES`` is used.

    Returns
    -------
    str
        The first candidate name found in ``ds``.

    Raises
    ------
    ValueError
        If none of the candidate variables is present in ``ds``.
    """
    candidates = DSD_VARIABLES if variables is None else variables
    # The first candidate present in the dataset wins
    variable = next((var for var in candidates if var in ds), None)
    if variable is None:
        raise ValueError(
            "No n(D) or N(D) variable found in the dataset. "
            f"Searched for any of {candidates} variables. "
            "Please specify the variable explicitly.",
        )
    return variable
[docs]
def get_dsd_variable(xr_obj, variable=None, diameter_dim=DIAMETER_DIMENSION):
    """Return the N(D) or n(D) DataArray.

    If the selected variable is 'raw_drop_number' and it has the velocity
    dimension, n(D) is derived by summing n(D,V) over the velocity dimension
    and the result is named 'raw_particle_counts'.

    Parameters
    ----------
    xr_obj : xr.Dataset or xr.DataArray
        Input xarray object.
        If a DataArray is provided, it is used directly.
    variable : str, optional
        Variable name to extract from the Dataset.
        If None and ``xr_obj`` is a Dataset, the first variable found among
        ``DSD_VARIABLES`` followed by 'raw_drop_number' is used.
    diameter_dim : str, optional
        Name of the diameter dimension the DataArray must have.

    Returns
    -------
    xr.DataArray
        The N(D) or n(D) DataArray. The variable name is available as ``da.name``.

    Raises
    ------
    TypeError
        If ``xr_obj`` is not an xarray object.
    ValueError
        If the variable is not in the Dataset, if the DataArray has the velocity
        dimension, or if the diameter dimension / bin width coordinate is missing.
    """
    is_dataset = isinstance(xr_obj, xr.Dataset)
    if not (is_dataset or isinstance(xr_obj, xr.DataArray)):
        raise TypeError("Expecting xarray object as input.")

    if is_dataset:
        # Identify the variable to extract from the Dataset
        if variable is None:
            variable = get_dataset_dsd_variable_name(xr_obj, variables=[*DSD_VARIABLES, "raw_drop_number"])
        elif variable not in xr_obj:
            raise ValueError(f"The dataset does not include {variable=}.")
        da = xr_obj[variable]
    else:
        # A DataArray is used as it is
        da = xr_obj

    has_velocity_dim = VELOCITY_DIMENSION in da.dims
    if has_velocity_dim and variable == "raw_drop_number":
        # Deal with raw_drop_number (in future raw_spectrum)
        # --> Compute raw_particle_counts on the fly
        da = da.sum(dim=VELOCITY_DIMENSION)
        da.name = "raw_particle_counts"
    elif has_velocity_dim:
        raise ValueError("N(D) must not have the velocity dimension.")

    return _check_has_diameter_dims(da, diameter_dim=diameter_dim)
[docs]
def plot_dsd(
    xr_obj,
    variable=None,
    cmap=None,
    norm=None,
    yscale="linear",
    ax=None,
    velocity_method="theoretical_velocity",
):
    """Plot the drop number concentration N(D) or the drop counts n(D).

    When the data have a time dimension, a time-diameter image is drawn.
    Otherwise a single distribution is drawn as a bar chart.

    Parameters
    ----------
    xr_obj : xr.Dataset or xr.DataArray
        Input xarray object containing the drop data.
    variable : str, optional
        Variable name to plot. If None and ``xr_obj`` is a Dataset,
        the variable is searched among the candidate N(D) / n(D) variables.
        If ``xr_obj`` is a DataArray, it is plotted directly.
    cmap : matplotlib colormap, optional
        Colormap of the time series plot. Defaults to a copy of "Spectral_r".
    norm : matplotlib normalization, optional
        Normalization of the colormap. By default a logarithmic normalization
        starting at the data minimum (but not below 0.1) is used, or a linear
        one if no valid minimum is available.
    yscale : str, optional
        Scale of the y axis of the single-distribution bar chart
        ('linear' or 'log'). The default is 'linear'.
    ax : matplotlib axes, optional
        Axes to plot on.
    velocity_method : str, optional
        If the object has a velocity_method dimension, the method to select.
        The default is "theoretical_velocity".

    Returns
    -------
    matplotlib axes or plot object
        The axes of the bar chart when there is no time dimension,
        otherwise the object returned by the xarray pcolormesh call.

    Raises
    ------
    ValueError
        If there is no data to plot.
    """
    # Keep a single velocity method when the dimension exists (e.g. for L2E products)
    if "velocity_method" in xr_obj.dims:
        xr_obj = xr_obj.sel(velocity_method=velocity_method)

    # Retrieve N(D) or n(D) and load it
    da_dsd = get_dsd_variable(xr_obj, variable=variable).compute()

    # Check not empty object
    if da_dsd.size == 0:
        raise ValueError("No data to plot.")

    # Retrieve labels
    variable_name = da_dsd.name
    labels = _get_dsd_labels(variable_name, da=da_dsd)

    # Without time dimension: single distribution
    if "time" not in da_dsd.dims:
        # Diameter coordinates are taken from the Dataset when one was provided
        coords_source = xr_obj if isinstance(xr_obj, xr.Dataset) else da_dsd
        return _single_plot_dsd_distribution(
            data=da_dsd.isel(velocity_method=0, missing_dims="ignore"),
            diameter=coords_source["diameter_bin_center"],
            diameter_bin_width=coords_source["diameter_bin_width"],
            variable_name=variable_name,
            yscale=yscale,
            ax=ax,
        )

    # Regularize input if sample_interval is available to ensure consistent time steps
    if "sample_interval" in da_dsd.coords:
        da_dsd = da_dsd.disdrodb.regularize()

    # Mask values that are not strictly positive
    da_dsd = da_dsd.where(da_dsd > 0)

    # Define cmap and norm
    if cmap is None:
        cmap = plt.get_cmap("Spectral_r").copy()
    if norm is None:
        vmin = np.maximum(da_dsd.min().item(), 1e-1)  # 0 is set to np.nan before
        if np.isnan(vmin):
            norm = Normalize()
        else:
            norm = LogNorm(vmin, None)

    # Plot N(D) or drop counts
    mesh = da_dsd.plot.pcolormesh(
        x="time",
        norm=norm,
        cmap=cmap,
        extend="max",
        cbar_kwargs={"label": labels["cbar_label"]},
        ax=ax,
    )
    plot_ax = mesh.axes
    plot_ax.set_title(labels["title"])
    plot_ax.set_ylabel("Drop diameter (mm)")

    # Compact time ticks and labels
    tick_locator = mdates.AutoDateLocator(minticks=4, maxticks=8)
    plot_ax.xaxis.set_major_locator(tick_locator)
    plot_ax.xaxis.set_major_formatter(mdates.ConciseDateFormatter(tick_locator))

    # Rotate the tick labels and avoid clipping them
    figure = plot_ax.figure
    figure.autofmt_xdate(rotation=30, ha="right")
    figure.tight_layout()
    return mesh
[docs]
def plot_l1_dsd_quicklook(xr_obj, precipitation_type="precipitation_type", **kwargs):
    """Create the default L1 DSD quicklook.

    Thin wrapper around ``plot_dsd_quicklook`` that uses the
    'precipitation_type' variable for the classification strip by default.
    All other keyword arguments are forwarded unchanged.

    Returns
    -------
    The figure returned by ``plot_dsd_quicklook``.
    """
    return plot_dsd_quicklook(xr_obj, precipitation_type=precipitation_type, **kwargs)
[docs]
def plot_l2_dsd_quicklook(
    xr_obj,
    precipitation_type="rain_type",
    secondary_var="R",
    secondary_label=None,
    secondary_ylim=None,
    secondary_hlines=None,
    secondary_yscale=None,
    secondary_linestyle=":",
    **kwargs,
):
    """Create the default L2 DSD quicklook.

    Wrapper around ``plot_dsd_quicklook`` that sets defaults for a Dataset input:

    - if 'R' and 'P' are available, the maximum of R and the sum of P are
      displayed in the bottom-right string;
    - if the secondary variable is 'R', log-scale defaults are used for its axis;
    - if 'rain_type' is requested but missing, it is derived from 'Dm' and 'Nw'
      when both are available, otherwise the classification strip is disabled.

    Parameters
    ----------
    xr_obj : xr.Dataset or xr.DataArray
        Input xarray object. A Dataset is not modified.
    precipitation_type : str or None, optional
        Name of the categorical variable shown as classification strip.
    secondary_var : str or None, optional
        Name of the variable displayed on the secondary y axis.
    secondary_label, secondary_ylim, secondary_hlines, secondary_yscale, secondary_linestyle
        Options of the secondary axis forwarded to ``plot_dsd_quicklook``.
        ``secondary_yscale=None`` means "log" for 'R' and "linear" otherwise.
    **kwargs
        Additional keyword arguments forwarded to ``plot_dsd_quicklook``.

    Returns
    -------
    The figure returned by ``plot_dsd_quicklook``.

    Raises
    ------
    ValueError
        If the requested ``precipitation_type`` is not a Dataset variable.
    """
    from disdrodb.l2.empirical_dsd import bringi_nw_dm_classification

    bottom_right_str = None
    if isinstance(xr_obj, xr.Dataset):
        # Compute event Rmax, Ptot, and define legend string
        if "R" in xr_obj and "P" in xr_obj:
            r_max = xr_obj["R"].max().item()
            p_tot = xr_obj["P"].sum().item()
            bottom_right_str = f"$R_{{MAX}}$={r_max:.1f} mm/h $P_{{TOT}}$={p_tot:.1f} mm"

        # Define secondary variable defaults
        if secondary_var == "R" and secondary_var in xr_obj:
            secondary_ylim = secondary_ylim if secondary_ylim is not None else (0.1, 200)
            secondary_yscale = secondary_yscale if secondary_yscale is not None else "log"
            secondary_label = secondary_label if secondary_label is not None else r"R [$mm hr^{-1}$]"
            secondary_hlines = secondary_hlines if secondary_hlines is not None else (1, 10, 100)

        # Compute precipitation type
        if precipitation_type is not None:
            if precipitation_type == "rain_type" and precipitation_type not in xr_obj:
                if "Dm" in xr_obj and "Nw" in xr_obj:
                    # assign() returns a new Dataset: the caller's object is left untouched
                    xr_obj = xr_obj.assign(
                        rain_type=bringi_nw_dm_classification(xr_obj["Dm"], xr_obj["Nw"]),
                    )
                else:
                    # Classification cannot be derived: disable the strip
                    precipitation_type = None
            # Validate only if the strip has not been disabled just above
            if precipitation_type is not None and precipitation_type not in xr_obj:
                raise ValueError(f"Specified precipitation_type {precipitation_type} is not at Dataset variable.")

    # The quicklook expects a valid matplotlib scale name
    if secondary_yscale is None:
        secondary_yscale = "linear"

    # Plot quicklook
    return plot_dsd_quicklook(
        xr_obj,
        precipitation_type=precipitation_type,
        secondary_var=secondary_var,
        secondary_label=secondary_label,
        secondary_ylim=secondary_ylim,
        secondary_yscale=secondary_yscale,
        secondary_linestyle=secondary_linestyle,
        secondary_hlines=secondary_hlines,
        bottom_right_str=bottom_right_str,
        **kwargs,
    )
[docs]
def plot_dsd_quicklook(
    xr_obj,
    # Plot layout
    hours_per_slice=3,
    max_rows=6,
    aligned=True,
    verbose=False,
    # Spectrum options
    variable=None,
    cbar_label=None,
    cmap=None,
    norm=None,
    # Diameter axis options
    d_dim=DIAMETER_DIMENSION,
    d_lim=(0.3, 5.5),
    d_label="Diameter [mm]",
    # Colorbar options
    cbar_as_legend=True,
    cbar_xpos=0.73,
    cbar_width=0.25,
    # Bottom string
    bottom_right_str=None,
    bottom_right_str_fontsize=8,
    # Secondary time-series overlay options
    secondary_var=None,
    secondary_ylim=None,
    secondary_yscale="linear",
    secondary_color="black",
    secondary_alpha=1,
    secondary_linewidth=1,
    secondary_linestyle="-",
    secondary_label=None,
    secondary_hlines=None,
    # Precipitation type options
    precipitation_type=None,
    precipitation_legend_fontsize=7,
    precipitation_legend_colors=None,
    precipitation_legend_labels=None,
    precipitation_legend_ncol=None,
    precipitation_legend_height=0.3,  # inches
    # Diameter axis variables
    add_dm=True,
    add_sigma_m=True,
    sigma_label=r"$2\sigma_m$",
    sigma_linewidth=0.5,
    dm_linewidth=0.5,
    # Figure options
    dpi=300,
):
    """Display a multi-row quicklook of the N(D) (or n(D)) time series.

    The time series is split in consecutive slices of ``hours_per_slice`` hours
    and each slice is drawn on its own row (at most ``max_rows`` rows).

    Parameters
    ----------
    xr_obj : xr.Dataset or xr.DataArray
        Input xarray object. A Dataset is not modified. If a DataArray is provided,
        the precipitation strip, the secondary axis and the Dm / sigma_m overlays
        are disabled.
    hours_per_slice : int, optional
        Duration in hours of each row.
    max_rows : int, optional
        Maximum number of rows (slices) to plot.
    aligned : bool, optional
        If True, the slice edges are aligned to multiples of ``hours_per_slice``
        (e.g. 00, 03, 06, ... for 3 hours). Otherwise the slices start at the
        first timestep.
    verbose : bool, optional
        If True, print a summary of the plotted slices.
    variable : str, optional
        Variable to plot. If None, it is searched in the Dataset.
    cbar_label : str, optional
        Colorbar label. By default the predefined label of the variable,
        or "<variable> [<units>]".
    cmap : matplotlib colormap, optional
        Defaults to a copy of "Spectral_r" with transparent "under" color.
    norm : matplotlib normalization, optional
        Defaults to ``LogNorm(vmin=1, vmax=10_000)``.
    d_dim, d_lim, d_label
        Diameter dimension name, y-axis limits and y-axis label.
    cbar_as_legend : bool, optional
        If True, the colorbar is drawn as an inset of the last row.
        This is enforced when at most 2 rows are plotted.
    cbar_xpos, cbar_width : float, optional
        Horizontal position and width of the inset colorbar (axes coordinates).
    bottom_right_str, bottom_right_str_fontsize
        Optional text displayed at the bottom right of the figure and its font size.
    secondary_var : str, optional
        Dataset variable drawn on a secondary y axis of every row.
    secondary_ylim, secondary_yscale, secondary_color, secondary_alpha,
    secondary_linewidth, secondary_linestyle, secondary_label, secondary_hlines
        Limits, scale, line style, label and horizontal reference lines of the
        secondary axis. These are validated only when ``secondary_var`` is set.
    precipitation_type : str, optional
        Categorical Dataset variable drawn as a colored strip at the top of every
        row, together with a legend row at the bottom of the figure.
    precipitation_legend_colors, precipitation_legend_labels : dict, optional
        Colors and labels keyed by category value. Each one defaults,
        independently, to the style of ``get_precipitation_legend_style``.
    precipitation_legend_fontsize, precipitation_legend_ncol, precipitation_legend_height
        Font size, number of columns and height (inches) of the legend row.
    add_dm, add_sigma_m : bool, optional
        Overlay 'Dm' and two times 'sigma_m' when the variables are available.
    sigma_label, sigma_linewidth, dm_linewidth
        Legend label and line widths of the overlays.
    dpi : int, optional
        Figure resolution.

    Returns
    -------
    matplotlib.figure.Figure
        The quicklook figure.

    Raises
    ------
    TypeError
        If the input or an option has an invalid type.
    ValueError
        If a requested variable is missing or fewer than 2 timesteps are available.

    Notes
    -----
    The function updates the global matplotlib ``rcParams`` (font sizes) and does
    not restore them.
    """
    from pycolorbar.utils.mpl_legend import add_fancybox, get_tightbbox_position

    # Figure settings
    plt.rcParams.update(
        {
            "font.size": 8,
            "axes.labelsize": 8,
            "axes.titlesize": 8,
            "xtick.labelsize": 7,
            "ytick.labelsize": 7,
        },
    )
    legend_fontsize = 8
    title_fontsize = 8
    side_title_fontsize = 8
    axis_label_fontsize = 8
    cbar_fontsize = 9
    cbar_labelpad = 6
    cbar_ypos = 0.7
    time_ticklabel_pad = 2

    # ------------------------------------------------------------------------.
    # Ensure to create Dataset if input is DataArray
    if isinstance(xr_obj, xr.Dataset):
        ds = xr_obj
    elif isinstance(xr_obj, xr.DataArray):
        variable = xr_obj.name
        variable = "unknown DSD" if variable is None else variable
        ds = xr_obj.to_dataset(name=variable)
        # Set options to False
        precipitation_type = None
        secondary_var = None
        add_dm = False
        add_sigma_m = False
    else:
        raise TypeError("Expecting xarray.Dataset or xarray.DataArray as input.")

    # ------------------------------------------------------------------------.
    # Validate generic secondary axis and precipitation classification options
    if secondary_var is not None:
        if not isinstance(secondary_var, str):
            raise TypeError("secondary_var must be a string or None.")
        if secondary_var not in ds:
            raise ValueError(f"{secondary_var} not found in dataset.")
        if secondary_ylim is not None and not (isinstance(secondary_ylim, (tuple, list)) and len(secondary_ylim) == 2):
            raise ValueError("secondary_ylim must be a tuple/list of length 2.")
        if secondary_hlines is not None and not isinstance(secondary_hlines, (list, tuple)):
            raise TypeError("secondary_hlines must be a list or tuple.")
    if precipitation_type is not None:
        if not isinstance(precipitation_type, str):
            raise TypeError("precipitation_type must be a string or None.")
        if precipitation_type not in ds:
            raise ValueError(f"{precipitation_type} not found in dataset.")

    # ------------------------------------------------------------------------.
    # Disable overlays when required variables are missing
    if "Dm" not in ds:
        add_dm = False
    if "sigma_m" not in ds:
        add_sigma_m = False

    # ------------------------------------------------------------------------.
    # Select velocity_method (the first one available)
    if "velocity_method" in ds.dims:
        velocity_method = ds["velocity_method"].to_numpy()[0]
        print(f"Selecting velocity_method '{velocity_method}'")
        ds = ds.sel(velocity_method=velocity_method)

    # ------------------------------------------------------------------------.
    # Derive N(D) variable
    da_dsd = get_dsd_variable(ds, variable=variable, diameter_dim=d_dim)
    variable = da_dsd.name
    # n(D) might have been computed on-the-fly from n(D,V).
    # assign() returns a new Dataset, so the caller's object is not modified.
    ds = ds.assign({variable: da_dsd})

    # ------------------------------------------------------------------------.
    # Define precipitation type classification (colors and legend)
    add_precipitation_legend = precipitation_type is not None
    if add_precipitation_legend:
        # Fill colors and labels independently, so that a user-provided value
        # is never overwritten and a missing one is always defined
        if precipitation_legend_colors is None or precipitation_legend_labels is None:
            default_colors, default_labels = get_precipitation_legend_style(precipitation_type)
            if precipitation_legend_colors is None:
                precipitation_legend_colors = default_colors
            if precipitation_legend_labels is None:
                precipitation_legend_labels = default_labels
        if precipitation_legend_ncol is None:
            precipitation_legend_ncol = len(precipitation_legend_labels)

    # ------------------------------------------------------------------------.
    # Colormap & normalization
    if cmap is None:
        cmap = plt.get_cmap("Spectral_r").copy()
        cmap.set_under("none")
    if norm is None:
        norm = LogNorm(vmin=1, vmax=10_000)

    # ------------------------------------------------------------------------.
    # Define cbar label
    if cbar_label is None:
        if variable in DSD_LABEL_DICT:
            cbar_label = DSD_LABEL_DICT[variable]
        else:
            units = ds[variable].attrs.get("units", "-")
            cbar_label = f"{variable} [{units}]"

    # ------------------------------------------------------------------------.
    # Calculate event duration
    duration = ds.disdrodb.end_time - ds.disdrodb.start_time

    # ------------------------------------------------------------------------.
    # Define temporal slices
    # - Align to closest <hours_per_slice> time
    # - For hours_per_slice=3 --> 00, 03, 06, ...
    time = ds["time"].to_index()
    t_start = time[0]
    t_end = time[-1]
    slice_freq = f"{hours_per_slice}h"
    if aligned:
        aligned_start = t_start.floor(slice_freq)
        aligned_end = t_end.ceil(slice_freq)
        # Create time bins
        time_bins = pd.date_range(
            start=aligned_start,
            end=aligned_end,
            freq=slice_freq,
        )
    else:
        # Create time bins
        time_bins = pd.date_range(
            start=t_start,
            end=t_end + pd.Timedelta(slice_freq),
            freq=slice_freq,
        )
    n_total_slices = len(time_bins) - 1
    n_slices = min(n_total_slices, max_rows)

    # ------------------------------------------------------------------------.
    # Print info on event quicklook
    if verbose:
        print("=== N(D) Event Quicklook ===")
        print(f"Dataset time span : {t_start} → {t_end}")
        print(f"Slice length : {hours_per_slice} h")
        print(f"Plotted slices : {n_slices}/{n_total_slices}")
        if n_total_slices > max_rows:
            last_plotted_end = time_bins[max_rows]
            print(f"Unplotted period : {last_plotted_end} → {t_end}")

    # ------------------------------------------------------------------------.
    # Regularize dataset to match bin start_time and end_time
    ds = ds.disdrodb.regularize(
        start_time=time_bins[0],
        end_time=time_bins[-1],
        fill_value=np.nan,
    )
    # Check at least 2 timesteps are available
    if ds.sizes["time"] < 2:
        raise ValueError("Dataset must have at least 2 time steps for quicklook.")

    # Enforce legend colorbar for n_slices 1 and 2
    if n_slices <= 2:
        cbar_as_legend = True

    ####-----------------------------------------------------------------------.
    #### Define figure with GridSpec
    # - If cbar_as_legend=False: reserve extra row for colorbar
    # - If cbar_as_legend=True: no extra row needed
    # - If add_precipitation_legend True, add extra row for precipitation classification legend
    # Define number of extra rows
    extra_rows = 1 if (not cbar_as_legend) else 0
    extra_rows += 1 if add_precipitation_legend else 0

    # Define figure size
    fig_width = 6.9
    subplot_height = 1.9
    fig_height = subplot_height * n_slices + (precipitation_legend_height if add_precipitation_legend else 0)
    figsize = (fig_width, fig_height)

    # Define height ratios
    hspace = 0.15
    height_ratios = [1] * n_slices
    if not cbar_as_legend:
        cbar_height_ratio = 0.2 if n_slices == 3 else 0.15  # more subplots = relatively smaller colorbar row
        height_ratios.append(cbar_height_ratio)
    if add_precipitation_legend:
        height_ratio_precipitation_legend = precipitation_legend_height / subplot_height
        height_ratios.append(height_ratio_precipitation_legend)
        hspace = 0.2

    # Create figure
    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs = GridSpec(
        nrows=n_slices + extra_rows,
        ncols=1,
        figure=fig,
        height_ratios=height_ratios,
        hspace=hspace,
    )
    axes = [fig.add_subplot(gs[i, 0]) for i in range(n_slices)]
    ax_sec_for_legend = None

    # ------------------------------------------------------------------------.
    # Define colormap and norm of the precipitation classification strip.
    # They only depend on the legend colors, so they are built once for all rows.
    if add_precipitation_legend:
        # - Sort category values
        values = np.array(sorted(precipitation_legend_colors.keys()))
        # - Extract colors in same order
        colors = [precipitation_legend_colors[v] for v in values]
        # - Create colormap
        cmap_precip = ListedColormap(colors)
        # - Build automatic boundaries (midpoints between consecutive values)
        #   Works even for negative or non-consecutive values
        boundaries = np.concatenate(
            [
                [values[0] - 0.5],
                (values[:-1] + values[1:]) / 2,
                [values[-1] + 0.5],
            ],
        )
        norm_precip = BoundaryNorm(boundaries, cmap_precip.N)

    # ------------------------------------------------------------------------.
    #### - Plot each slice
    for i, ax in enumerate(axes):
        # Extract dataset slice
        t0 = time_bins[i]
        t1 = time_bins[i + 1]
        ds_slice = ds.sel(time=slice(t0, t1))
        da_dsd = ds_slice[variable]

        # Plot N(D)
        # - The mesh of the last row is reused below to draw the colorbar
        p = da_dsd.plot.pcolormesh(
            ax=ax,
            x="time",
            y=d_dim,
            norm=norm,
            cmap=cmap,
            shading="auto",
            add_colorbar=False,
        )

        #### - Overlay Dm
        if add_dm:
            ds_slice["Dm"].plot(
                ax=ax,
                x="time",
                color="black",
                linestyle="-",
                linewidth=dm_linewidth,
                label="$D_m$",
            )

        #### - Overlay sigma_m
        if add_sigma_m:
            (ds_slice["sigma_m"] * 2).plot(
                ax=ax,
                x="time",
                color="black",
                linestyle="--",
                linewidth=sigma_linewidth,
                label=sigma_label,
            )

        # Remove xarray default title (set by the plot calls above)
        ax.set_title("")

        # Add axis labels - remove ylabel from all individual axes
        ax.set_xlabel("")
        ax.set_ylabel("")

        #### - Add generic secondary time series on a twin axis
        if secondary_var is not None:
            ax_sec = ax.twinx()
            ds_slice[secondary_var].plot(
                ax=ax_sec,
                x="time",
                color=secondary_color,
                alpha=secondary_alpha,
                linewidth=secondary_linewidth,
                linestyle=secondary_linestyle,
                label=secondary_var,
            )
            # Always remove xarray default title
            ax_sec.set_title("")
            ax_sec.set_yscale(secondary_yscale)
            # Display plain numbers instead of scientific notation
            if secondary_yscale == "log":
                yticks = ax_sec.get_yticks()
                ytick_labels = [f"{t:g}" for t in yticks]
                ax_sec.set_yticks(yticks)
                ax_sec.set_yticklabels(ytick_labels)
            if secondary_ylim is not None:
                ax_sec.set_ylim(secondary_ylim)
            # Remove ylabel from all individual axes
            ax_sec.set_ylabel("")
            ax_sec.tick_params(axis="y", labelcolor=secondary_color)
            if secondary_hlines is not None:
                for y in secondary_hlines:
                    ax_sec.axhline(
                        y=y,
                        color="gray",
                        alpha=0.2,
                        linewidth=1,
                        linestyle="-",
                        zorder=0,
                    )
            # Keep the first twin axis to collect the legend handles
            if ax_sec_for_legend is None:
                ax_sec_for_legend = ax_sec

        ax.set_ylim(*d_lim)
        ax.tick_params(axis="x", pad=time_ticklabel_pad)
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        #### - Format time axis
        locator = mdates.AutoDateLocator(minticks=4, maxticks=8)
        formatter = mdates.ConciseDateFormatter(locator)
        ax.xaxis.set_major_locator(locator)
        ax.xaxis.set_major_formatter(formatter)
        ax.xaxis.get_offset_text().set_visible(False)

        #### - Add precipitation classification strip at the top of the subplot as inset axes
        if add_precipitation_legend:
            # Get classification values for this time slice
            precip_type = ds_slice[precipitation_type]
            # Create inset axes at the top (sharing x-axis with main plot)
            # [x0, y0, width, height] in axes coordinates
            ax_precip = ax.inset_axes([0, 0.95, 1, 0.05], sharex=ax)
            # Plot 1-pixel-high strip
            # - Cell edges are placed halfway between consecutive timesteps
            t = precip_type["time"].to_numpy()
            dt = np.diff(t)  # timedelta64
            t_edges = np.concatenate(
                [
                    [t[0] - dt[0] / 2],
                    t[:-1] + dt / 2,
                    [t[-1] + dt[-1] / 2],
                ],
            )
            ax_precip.pcolormesh(
                t_edges,
                [0, 1],
                precip_type.to_numpy()[np.newaxis, :],
                cmap=cmap_precip,
                norm=norm_precip,
                shading="flat",
            )
            # Add 'axis' line
            ax_precip.axhline(
                y=0,
                color="black",
                linewidth=1,
                alpha=1.0,
            )
            # Remove ticks and ticklabels
            ax_precip.set_yticks([])
            ax_precip.xaxis.set_visible(False)
            for spine in ax_precip.spines.values():
                spine.set_visible(False)

    # Add time xlabel (only when there is no precipitation legend row below)
    if precipitation_type is None:
        axes[n_slices - 1].set_xlabel("Time (UTC)", fontsize=axis_label_fontsize)

    # ------------------------------------------------------------------------.
    #### - Add title
    # Format duration as "XHYYMIN" or "YMIN"
    total_minutes = int(duration.total_seconds() / 60)
    hours = total_minutes // 60
    minutes = total_minutes % 60
    duration_str = f"{hours}H{minutes:02d}MIN" if hours > 0 else f"{minutes}MIN"

    # Format title based on whether dates are the same
    t_start_dt = time_bins[0]
    t_end_dt = time_bins[n_slices]
    if t_start_dt.date() == t_end_dt.date():
        # Same date: show "YYYY-MM-DD HH:MM - HH:MM UTC"
        title_str = f"{t_start_dt.strftime('%Y-%m-%d %H:%M')} - {t_end_dt.strftime('%H:%M')} UTC"
    else:
        # Different dates: show full datetime for both
        title_str = f"{t_start_dt.strftime('%Y-%m-%d %H:%M')} - {t_end_dt.strftime('%Y-%m-%d %H:%M')} UTC"

    # Set center title
    axes[0].set_title(
        title_str,
        fontsize=title_fontsize,
        fontweight="bold",
        loc="center",
    )
    # Add right title with event duration
    axes[0].set_title(
        duration_str,
        fontsize=side_title_fontsize,
        loc="right",
    )

    # ------------------------------------------------------------------------.
    #### - Add centered y-labels in the middle of the figure (closer to axes)
    fig.text(
        0.08,
        0.5,
        d_label,
        va="center",
        rotation="vertical",
        fontsize=axis_label_fontsize,
    )
    if secondary_var is not None:
        if secondary_label is None:
            units = ds[secondary_var].attrs.get("units", "")
            secondary_label = f"{secondary_var} [{units}]" if units else secondary_var
        fig.text(
            0.945,
            0.5,
            secondary_label,
            va="center",
            rotation="vertical",
            fontsize=axis_label_fontsize,
            color=secondary_color,
        )

    # ------------------------------------------------------------------------.
    #### - Add legend
    # Collect legend handles from both axes
    handles, labels = axes[0].get_legend_handles_labels()
    if secondary_var is not None and ax_sec_for_legend is not None:
        handles_sec, labels_sec = ax_sec_for_legend.get_legend_handles_labels()
        handles += handles_sec
        labels += labels_sec
    if secondary_var is not None or add_sigma_m or add_dm:
        axes[0].legend(
            handles,
            labels,
            loc="upper left",
            bbox_to_anchor=(0, 0.98),
            fontsize=legend_fontsize,
            frameon=True,
            fancybox=False,
            edgecolor="black",
        )

    # ------------------------------------------------------------------------.
    #### - Add colorbar
    if cbar_as_legend:
        # Add colorbar as a legend in the last subplot with background box
        cax = axes[-1].inset_axes([cbar_xpos, cbar_ypos, cbar_width, 0.10])  # [x, y, width, height] in axes coords
        # Raise z-order so the colorbar is on top and fancybox behind
        fancybox_zorder = cax.get_zorder() + 1
        cax.set_zorder(cax.get_zorder() + 2)
        cbar = fig.colorbar(
            p,
            cax=cax,
            orientation="horizontal",
            extend="max",
        )
        # Move label above colorbar
        cbar.ax.set_xlabel(cbar_label, fontsize=cbar_fontsize, labelpad=cbar_labelpad)
        cbar.ax.xaxis.set_label_position("top")
        cbar.ax.tick_params(labelsize=9)
        # Add white box behind colorbar
        fancy_bbox = get_tightbbox_position(cax)
        add_fancybox(
            ax=axes[-1],
            bbox=fancy_bbox,
            pad=0,
            fc="white",
            ec="none",
            lw=0.5,
            alpha=0.9,
            shape="square",
            zorder=fancybox_zorder,
        )
    else:
        # Add colorbar as separate subplot at bottom
        cbar_pad = 0.05 * (3 / n_slices)
        cbar_fraction = 0.03 * (3 / n_slices)
        cbar = fig.colorbar(
            p,
            ax=axes,
            orientation="horizontal",
            pad=cbar_pad,
            fraction=cbar_fraction,
            extend="max",
        )
        cbar.set_label(cbar_label, fontsize=cbar_fontsize)

    #### Add classification legend
    if add_precipitation_legend:
        legend_patches = [
            Patch(facecolor=precipitation_legend_colors[k], edgecolor="black", label=precipitation_legend_labels[k])
            for k in sorted(precipitation_legend_colors.keys())
        ]
        # Select last row in GridSpec
        precip_ax = fig.add_subplot(gs[-1, 0])
        precip_ax.axis("off")
        precip_ax.legend(
            handles=legend_patches,
            loc="center left",
            bbox_to_anchor=(-0.01, 0.5),
            ncol=precipitation_legend_ncol,
            frameon=False,
            fontsize=precipitation_legend_fontsize,
            columnspacing=1.2,
            handlelength=0.8,
        )

    #### - Add bottom right legend
    if bottom_right_str is not None:
        if add_precipitation_legend:
            # Reuse precipitation legend axis for perfect alignment
            precip_ax.text(
                0.99,
                0.5,
                bottom_right_str,
                ha="right",
                va="center",
                fontsize=bottom_right_str_fontsize,
                transform=precip_ax.transAxes,
            )
        else:
            # Positioned just below the last row (coordinates of the last main axes)
            fig.axes[-1].text(
                0.99,
                -0.15,
                bottom_right_str,
                ha="right",
                va="top",
                fontsize=bottom_right_str_fontsize,
                transform=axes[-1].transAxes,
            )

    # Return figure
    return fig
####-------------------------------------------------------------------------------------------------------
#### Spectra visualizations
def _check_has_diameter_and_velocity_dims(da):
    """Check that the DataArray has both the diameter and the velocity dimensions.

    Returns the input DataArray unchanged and raises a ValueError if one of the
    two dimensions is missing.
    """
    has_both_dims = DIAMETER_DIMENSION in da.dims and VELOCITY_DIMENSION in da.dims
    if has_both_dims:
        return da
    raise ValueError(f"The DataArray must have both '{DIAMETER_DIMENSION}' and '{VELOCITY_DIMENSION}' dimensions.")
def _get_spectrum_variable(xr_obj, variable):
    """Return the 2D (diameter, velocity) spectrum DataArray.

    Parameters
    ----------
    xr_obj : xr.Dataset or xr.DataArray
        Input xarray object. A DataArray is used directly.
    variable : str
        Name of the variable to extract when ``xr_obj`` is a Dataset.

    Returns
    -------
    xr.DataArray
        DataArray with both the diameter and the velocity dimensions.

    Raises
    ------
    TypeError
        If ``xr_obj`` is not an xarray object.
    ValueError
        If the velocity dimension is missing, if the variable is not in the
        Dataset, or if the DataArray lacks the diameter or velocity dimension.
    """
    is_dataset = isinstance(xr_obj, xr.Dataset)
    if not (is_dataset or isinstance(xr_obj, xr.DataArray)):
        raise TypeError("Expecting xarray object as input.")
    if VELOCITY_DIMENSION not in xr_obj.dims:
        raise ValueError("2D spectrum not available.")

    # A DataArray only needs the dimensions check
    if not is_dataset:
        return _check_has_diameter_and_velocity_dims(xr_obj)

    # Extract the requested variable from the Dataset
    if variable not in xr_obj:
        raise ValueError(f"The dataset do not include {variable=}.")
    return _check_has_diameter_and_velocity_dims(xr_obj[variable])
[docs]
def plot_spectrum_evolution(
    ds,
    legend_variables=None,
    legend_ncol=1,
    xlim=None,
    ylim=None,
    plot_hc_rain_mask_boundary=False,
    **plot_kwargs,
):
    """Plot the evolution of disdrodb spectra over time.

    One figure is created and shown (``plt.show()``) for each timestep.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a non-empty 'time' dimension, containing the spectrum
        variable plotted by plot_spectrum(). The input dataset is not modified.
    legend_variables : list of str, optional
        Dataset variables to display in the legend.
        Each variable must reduce to a scalar once a timestep is selected.
        Variables with only integer-like finite values are printed without
        decimals, the others with 2 decimals.
    legend_ncol : int, optional
        Number of legend entries per row (horizontal layout).
    xlim, ylim : tuple, optional
        Axis limits applied to the axes of each figure.
    plot_hc_rain_mask_boundary : bool, optional
        If True, the boundary returned by get_spectrum_mask_boundary() is computed
        once on the full dataset and drawn as a dashed black line on each figure.
        The default is False.
    **plot_kwargs : dict
        Additional keyword arguments passed to plot_spectrum().

    Raises
    ------
    ValueError
        If the dataset has no 'time' dimension or no timesteps.
    KeyError
        If one of ``legend_variables`` is not in the dataset.
    """
    # Check timestep available
    if "time" not in ds.dims or ds.sizes["time"] == 0:
        raise ValueError("No timesteps available.")

    # Define legend formatting
    # --> Define decimals per variable
    decimals = {}
    if legend_variables is not None:
        # Work on a shallow copy so that loading the legend variables
        # in memory does not modify the dataset of the caller
        ds = ds.copy()
        for var in legend_variables:
            if var not in ds:
                raise KeyError(f"Variable '{var}' not found in dataset")
            ds[var] = ds[var].compute()  # Ensure variable is loaded in memory
            values = ds[var].to_numpy()
            # Remove NaNs and infinite values
            values = values[np.isfinite(values)]
            # Integer-like --> 0 decimals, otherwise 2 decimals
            decimals[var] = 0 if np.allclose(values, np.round(values)) else 2

    # Precompute hc rain mask contour
    if plot_hc_rain_mask_boundary:
        contour = get_spectrum_mask_boundary(
            ds,
            above_velocity_tolerance=2,
            below_velocity_fraction=None,
            below_velocity_tolerance=3,
            maintain_drops_smaller_than=1,  # 1, # 2
            maintain_drops_slower_than=2.5,  # 2.5, # 3
            maintain_smallest_drops=False,
            fall_velocity_model="Beard1976",
        )

    # Loop over time
    for i in range(ds.sizes["time"]):
        ds_i = ds.isel(time=i)

        # Define figure
        _, ax = plt.subplots()

        # Plot spectrum
        # - The boundary is drawn here from the precomputed contour
        plot_spectrum(ds_i, ax=ax, plot_hc_rain_mask_boundary=False, **plot_kwargs)
        if plot_hc_rain_mask_boundary:
            plot_contour(contour, ax=ax, color="black", linestyle="--")

        # Set title
        title_str = pd.to_datetime(ds_i["time"].to_numpy()).strftime("%Y-%m-%d %H:%M:%S")
        ax.set_title(title_str)

        # Add legend
        if legend_variables is not None:
            handles = []
            labels = []
            for var in legend_variables:
                value = ds_i[var].item()
                dec = decimals[var]
                label = f"{var}: NaN" if np.isnan(value) else f"{var}: {value:.{dec}f}"
                # Invisible handle: only the text is displayed
                handles.append(Line2D([], [], linestyle="none"))
                labels.append(label)
            ax.legend(
                handles,
                labels,
                loc="upper left",
                ncol=legend_ncol,
                frameon=True,
                handlelength=0,
                handletextpad=0.0,
                columnspacing=1.2,
            )

        # Set limits if specified
        # - Act on the spectrum axes explicitly instead of the pyplot current axes
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
        plt.show()
[docs]
def plot_spectrum(
    xr_obj,
    variable="raw_drop_number",
    ax=None,
    cmap=None,
    norm=None,
    extend="max",
    add_colorbar=True,
    cbar_kwargs=None,
    title=None,
    plot_hc_rain_mask_boundary=False,
    **plot_kwargs,
):
    """Plot the spectrum.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Input xarray object. If Dataset, the variable to plot must be specified.
        If DataArray, it must have both diameter and velocity dimensions.
    variable : str
        Name of the variable to plot if xr_obj is a Dataset.
    ax : matplotlib.axes.Axes, optional
        Axes passed to xarray's plot.pcolormesh method.
    cmap : matplotlib.colors.Colormap, optional
        Colormap to use. If None, uses 'Spectral_r' with 'under' set to 'none'.
    norm : matplotlib.colors.Normalize, optional
        Normalization for colormap. If None, uses LogNorm with vmin=1 when the
        sum of the spectrum is positive, otherwise no norm is passed to xarray.
    extend : str, optional
        Whether to draw arrows on the colorbar to indicate out-of-range values.
        Valid options are 'neither', 'min', 'max', 'both'.
        Default is 'max'.
    add_colorbar : bool, optional
        Whether to add a colorbar. Default is True.
    cbar_kwargs : dict, optional
        Additional keyword arguments for colorbar. If None, uses {'label': 'Number of particles'}.
        It is ignored if add_colorbar=False.
    title : str, optional
        Title of the plot. If not provided, it defaults to the time range
        when the spectrum is summed over the time dimension, and to an empty
        string otherwise. It is not used when a FacetGrid is created.
    plot_hc_rain_mask_boundary : bool, optional
        If True, the boundary returned by get_spectrum_mask_boundary() for ``xr_obj``
        is drawn as a dashed black line. The default is False.
    **plot_kwargs : dict
        Additional keyword arguments passed to xarray's plot.pcolormesh method.

    Returns
    -------
    object
        The object returned by xarray's plot.pcolormesh method.

    Raises
    ------
    ValueError
        If the spectrum is empty.

    Notes
    -----
    If the input DataArray has a time dimension, it is summed over time before plotting
    unless FacetGrid options (e.g., col, row) are specified in plot_kwargs.
    If FacetGrid options are used, the plot will create a grid of subplots for each time slice.
    To create a FacetGrid plot, use:
    ds.isel(time=slice(0, 9)).disdrodb.plot_spectrum(col="time", col_wrap=3)
    """
    # Retrieve spectrum
    drop_number = _get_spectrum_variable(xr_obj, variable)

    # Check if FacetGrid
    is_facetgrid = "col" in plot_kwargs or "row" in plot_kwargs

    # Check not empty object
    if drop_number.size == 0:
        raise ValueError("No data to plot.")

    # Define start_time and end_time if time dimension is still present after squeezing
    drop_number = drop_number.squeeze()
    if "time" in drop_number.dims:
        start_time = pd.to_datetime(drop_number.disdrodb.start_time).strftime("%Y-%m-%d %H:%M:%S")
        end_time = pd.to_datetime(drop_number.disdrodb.end_time).strftime("%Y-%m-%d %H:%M:%S")
    else:
        start_time = None
        end_time = None

    # Sum over time dimension if still present
    # - Unless FacetGrid options in plot_kwargs
    if "time" in drop_number.dims and not is_facetgrid:
        drop_number = drop_number.sum(dim="time")
        if title is None:
            title = f"{start_time} - {end_time}" if start_time is not None else ""
    elif title is None:
        title = f"{start_time}" if start_time is not None else ""

    # Define default cbar_kwargs if not specified
    if cbar_kwargs is None:
        cbar_kwargs = {"label": "Number of particles"}

    # Define cmap and norm
    if cmap is None:
        cmap = plt.get_cmap("Spectral_r").copy()
        cmap.set_under("none")
    if norm is None:
        # A LogNorm is defined only if there is something to display
        norm = LogNorm(vmin=1, vmax=None) if drop_number.sum() > 0 else None

    # Remove cbar_kwargs if add_colorbar=False
    if not add_colorbar:
        cbar_kwargs = None

    # Plot
    p = drop_number.plot.pcolormesh(
        ax=ax,
        x=DIAMETER_DIMENSION,
        y=VELOCITY_DIMENSION,
        cmap=cmap,
        extend=extend,
        norm=norm,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )

    # Add the mask boundary if asked
    if plot_hc_rain_mask_boundary:
        contour = get_spectrum_mask_boundary(
            xr_obj,
            above_velocity_tolerance=2,
            below_velocity_fraction=None,
            below_velocity_tolerance=3,
            maintain_drops_smaller_than=1,  # 1, # 2
            maintain_drops_slower_than=2.5,  # 2.5, # 3
            maintain_smallest_drops=False,
            fall_velocity_model="Beard1976",
        )
        plot_contour(contour, ax=p.axes, color="black", linestyle="--")

    # Set axis labels (and title for single-axes plots)
    if not is_facetgrid:
        p.axes.set_xlabel("Diameter [mm]")
        p.axes.set_ylabel("Fall velocity [m/s]")
        p.axes.set_title(title)
    else:
        p.set_axis_labels("Diameter [mm]", "Fall velocity [m/s]")
    return p
[docs]
def plot_raw_and_filtered_spectra(
    ds,
    cmap=None,
    norm=None,
    extend="max",
    add_theoretical_average_velocity=True,
    add_measured_average_velocity=True,
    figsize=(6.9, 3.2),
    dpi=300,
):
    """Plot raw and filtered drop spectrum side by side.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with the 'raw_drop_number' and 'drop_number' spectra.
        The 'fall_velocity' variable is required if add_theoretical_average_velocity=True.
        If a 'time' dimension is present, the spectra are summed over time and
        the fall velocity is averaged over time.
    cmap : matplotlib.colors.Colormap, optional
        Colormap passed to plot_spectrum().
    norm : matplotlib.colors.Normalize, optional
        Normalization shared by the two panels.
        If None, uses LogNorm between 1 and the maximum of the raw spectrum.
        If that maximum is lower than 1, the default norm of plot_spectrum() is used.
    extend : str, optional
        Colorbar extend option passed to plot_spectrum(). Default is 'max'.
    add_theoretical_average_velocity : bool, optional
        Whether to overlay ds['fall_velocity'] as a dashed line. Default is True.
    add_measured_average_velocity : bool, optional
        Whether to overlay the average velocity computed from the filtered
        spectrum as a dotted line. Default is True.
    figsize : tuple, optional
        Figure size. Default is (6.9, 3.2).
    dpi : int, optional
        Figure resolution. Default is 300.

    Returns
    -------
    matplotlib.figure.Figure
        Figure with the raw spectrum on the left and the filtered spectrum on the right.
    """
    # Retrieve spectrum arrays
    drop_number = _get_spectrum_variable(ds, variable="drop_number")
    if "time" in drop_number.dims:
        drop_number = drop_number.sum(dim="time")
    drop_number = drop_number.compute()

    raw_drop_number = _get_spectrum_variable(ds, variable="raw_drop_number")
    if "time" in raw_drop_number.dims:
        raw_drop_number = raw_drop_number.sum(dim="time")
    raw_drop_number = raw_drop_number.compute()

    # Compute theoretical and measured average velocity if asked
    if add_theoretical_average_velocity:
        theoretical_average_velocity = ds["fall_velocity"]
        if "time" in theoretical_average_velocity.dims:
            theoretical_average_velocity = theoretical_average_velocity.mean(dim="time")
    plot_measured_velocity = add_measured_average_velocity and VELOCITY_DIMENSION in drop_number.dims
    if plot_measured_velocity:
        measured_average_velocity = get_drop_average_velocity(drop_number)

    # Define norm if not specified
    # - LogNorm requires vmin <= vmax. If the raw spectrum has no counts,
    #   norm is left to None and plot_spectrum defines its own default.
    if norm is None:
        vmax = float(raw_drop_number.max())
        norm = LogNorm(vmin=1, vmax=vmax) if vmax >= 1 else None

    # Initialize figure
    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs = GridSpec(1, 2, width_ratios=[1, 1.15], wspace=0.05)  # More space for ax2 (colorbar)
    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1])

    # Plot raw_drop_number
    plot_spectrum(raw_drop_number, ax=ax1, cmap=cmap, norm=norm, extend=extend, add_colorbar=False, title="")

    # Add velocities if asked
    if add_theoretical_average_velocity:
        theoretical_average_velocity.plot(ax=ax1, c="k", linestyle="dashed")
    if plot_measured_velocity:
        measured_average_velocity.plot(ax=ax1, c="k", linestyle="dotted")

    # Improve plot appearance
    ax1.set_xlabel("Diameter [mm]")
    ax1.set_ylabel("Fall velocity [m/s]")
    ax1.set_title("Raw Spectrum")

    # Plot drop_number
    plot_spectrum(drop_number, ax=ax2, cmap=cmap, norm=norm, extend=extend, add_colorbar=True, title="")

    # Add velocities if asked
    # - Labels are defined only here because the legend is drawn on ax2
    if add_theoretical_average_velocity:
        theoretical_average_velocity.plot(ax=ax2, c="k", linestyle="dashed", label="Theoretical velocity")
    if plot_measured_velocity:
        measured_average_velocity.plot(ax=ax2, c="k", linestyle="dotted", label="Measured average velocity")

    # Improve plot appearance
    # - The y axis is shared visually with the left panel
    ax2.set_yticks([])
    ax2.set_yticklabels([])
    ax2.set_xlabel("Diameter [mm]")
    ax2.set_ylabel("")
    ax2.set_title("Filtered Spectrum")

    # Add legend
    if add_theoretical_average_velocity or add_measured_average_velocity:
        ax2.legend(loc="lower right", frameon=False)
    return fig
####---------------------------------------------------------------------------.
#### Mask utilities
[docs]
def plot_contour(contour, ax=None, **kwargs):
    """Plot contour [X,Y].

    Parameters
    ----------
    contour : sequence of array
        Contour segments. Each segment is a 2D array whose first column
        holds the x values and whose second column holds the y values.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs : dict
        Keyword arguments passed to ``ax.plot`` for every segment.
        If 'label' is given, it is attached only to the last segment.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    # The label is passed only with the last segment (single legend entry)
    label = kwargs.pop("label", None)
    for position, segment in enumerate(contour):
        is_labelled_segment = label is not None and position == len(contour) - 1
        if is_labelled_segment:
            kwargs["label"] = label
        x, y = segment[:, 0], segment[:, 1]
        ax.plot(x, y, **kwargs)
[docs]
def plot_mask_contour(mask, ax=None, **kwargs):
    """Plot mask contour.

    The contour segments are retrieved with get_mask_contour() and drawn
    with plot_contour().

    Parameters
    ----------
    mask : object
        Mask passed to get_mask_contour().
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. If None, a new figure and axes are created.
    **kwargs : dict
        Keyword arguments passed to ``ax.plot`` for every contour segment.
        If 'label' is given, it is attached only to the last segment.
    """
    # Extract the contour first, then delegate the drawing of the segments
    mask_contour = get_mask_contour(mask)
    plot_contour(mask_contour, ax=ax, **kwargs)
####-------------------------------------------------------------------------------------------------------
#### Confusion matrix
[docs]
def plot_confusion_matrix(cm, labels, cmap="Blues", norm=None, xlabel="", ylabel="", title="", add_colorbar=False):
    """Plot confusion matrix.

    The figure is displayed with ``plt.show()``. Nothing is returned.

    Parameters
    ----------
    cm : array-like
        2D confusion matrix with the counts.
        Cells equal to 0 are left blank and are not annotated.
    labels : sequence of str
        Tick labels used on both axes.
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap passed to ``imshow``. Default is 'Blues'.
    norm : matplotlib.colors.Normalize, optional
        Normalization passed to ``imshow``.
    xlabel, ylabel, title : str, optional
        Axis labels and title of the plot.
    add_colorbar : bool, optional
        Whether to add a colorbar labelled 'Counts'. Default is False.
    """
    # Accept any array-like (nested lists, ...)
    cm = np.asarray(cm)

    _, ax = plt.subplots(figsize=(6, 6))

    # Display cells without counts as NaN (better than a small constant)
    cm_img = cm.astype(float)
    cm_img[cm_img == 0] = np.nan
    im = ax.imshow(
        cm_img,
        cmap=cmap,
        norm=norm,
    )

    # Add grid lines between cells
    ax.set_xticks(np.arange(cm.shape[1] + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(cm.shape[0] + 1) - 0.5, minor=True)
    ax.grid(which="minor", color="gray", linestyle="-", linewidth=0.5)
    ax.tick_params(which="minor", bottom=False, left=False)

    # Ticks
    ax.set_xticks(np.arange(len(labels)))
    ax.set_yticks(np.arange(len(labels)))
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_yticklabels(labels)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # Add counts in each cell with positive counts
    positive_cells = np.argwhere(cm > 0)  # row-major order
    if positive_cells.size > 0:
        # Text is white on the darkest cells (counts above 1/5 of the maximum)
        # - Computed once instead of once per cell
        white_text_threshold = np.nanmax(cm) / 5
        for i, j in positive_cells:
            value = cm[i, j]
            ax.text(
                j,
                i,
                f"{int(value)}",
                ha="center",
                va="center",
                color="black" if value < white_text_threshold else "white",
                fontsize=9,
            )

    # Colorbar
    if add_colorbar:
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label("Counts")
    plt.tight_layout()
    plt.show()
####-------------------------------------------------------------------------------------------------------
#### DenseLines
[docs]
def normalize_array(arr, method="max"):
    """Normalize a NumPy array according to the chosen method.

    NaN values are ignored when computing the statistics used for the normalization.

    Parameters
    ----------
    arr : numpy.ndarray
        Input array. It is converted to a float array.
    method : str
        Normalization method. Options:
        - 'max'   : Divide by the maximum value (unchanged if the maximum is 0).
        - 'minmax': Scale to [0, 1] range (zeros if the array is constant).
        - 'zscore': Standardize to mean 0, std 1 (zeros if the std is 0).
        - 'log'   : Apply log10 transform. If min <= 0, the array is first
          shifted so that its minimum becomes 1e-12.
        - 'none'  : No normalization (return original array).

    Returns
    -------
    numpy.ndarray
        Normalized array.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported options.
    """
    arr = np.asarray(arr, dtype=float)

    if method == "max":
        max_val = np.nanmax(arr)
        return arr / max_val if max_val != 0 else arr

    if method == "minmax":
        min_val = np.nanmin(arr)
        max_val = np.nanmax(arr)
        return (arr - min_val) / (max_val - min_val) if max_val != min_val else np.zeros_like(arr)

    if method == "zscore":
        mean_val = np.nanmean(arr)
        std_val = np.nanstd(arr)
        return (arr - mean_val) / std_val if std_val != 0 else np.zeros_like(arr)

    if method == "log":
        min_val = np.nanmin(arr)
        # Shift only when required to avoid log(0) or log of negative values,
        # so that strictly positive arrays get a plain log10 transform
        if min_val <= 0:
            arr = arr - min_val + 1e-12
        return np.log10(arr)

    if method == "none":
        return arr

    raise ValueError(f"Unknown normalization method: {method}")
def _np_to_rgba_alpha(arr, cmap="viridis", cmap_norm=None, scaling="linear"):
    """Convert a numpy array to an RGBA array with alpha based on array value.

    Parameters
    ----------
    arr : numpy.ndarray
        2D array of counts or frequencies.
    cmap : str or Colormap, optional
        Matplotlib colormap to use for RGB channels.
    cmap_norm: matplotlib.colors.Norm
        Norm to be used to scale data before assigning cmap colors.
        The default is Normalize(vmin, vmax), with vmin and vmax set to
        the minimum and maximum non-NaN values of the array.
    scaling : str, optional
        Scaling type for alpha mapping:
        - "linear"   : min-max normalization
        - "log"      : logarithmic normalization (positive values only)
        - "sqrt"     : square-root (power-law with exponent=0.5)
        - "exp"      : exponential scaling
        - "quantile" : rank-based scaling
        - "none"     : full opacity (alpha=1)

    Returns
    -------
    rgba : 3D numpy array (ny, nx, 4)
        RGBA array. The alpha channel is clipped to [0, 1] and
        set to 0 where the input array is NaN.

    Raises
    ------
    ValueError
        If the array is not 2D or if ``scaling`` is not a supported option.
    """
    # Ensure numpy array
    arr = np.asarray(arr, dtype=float)

    # Ensure 2D array
    if arr.ndim != 2:
        raise ValueError(f"Expecting a 2D array. Got an array with {arr.ndim} dimensions.")

    # Define mask with NaN pixel
    mask_na = np.isnan(arr)

    # Define colormap norm
    if cmap_norm is None:
        cmap_norm = Normalize(vmin=np.nanmin(arr), vmax=np.nanmax(arr))

    # Define alpha
    if scaling == "linear":
        norm = Normalize(vmin=np.nanmin(arr), vmax=np.nanmax(arr))
        alpha = norm(arr)
    elif scaling == "log":
        vals = np.where(arr > 0, arr, np.nan)  # mask non-positive
        norm = LogNorm(vmin=np.nanmin(vals), vmax=np.nanmax(vals))
        alpha = norm(arr)
        alpha = np.nan_to_num(alpha, nan=0.0)
    elif scaling == "sqrt":
        alpha = np.sqrt(np.clip(arr, 0, None) / np.nanmax(arr))
    elif scaling == "exp":
        normed = np.clip(arr / np.nanmax(arr), 0, 1)
        alpha = np.expm1(normed) / np.expm1(1)
    elif scaling == "quantile":
        flat = arr.ravel()
        ranks = np.argsort(np.argsort(flat))  # rankdata without scipy
        # Guard against a zero denominator when the array has a single element
        alpha = ranks / max(flat.size - 1, 1)
        alpha = alpha.reshape(arr.shape)
    elif scaling == "none":
        alpha = np.ones_like(arr, dtype=float)
    else:
        raise ValueError(f"Unknown scaling type: {scaling}")

    # Map values to colors
    cmap = plt.get_cmap(cmap).copy()
    rgba = cmap(cmap_norm(arr))

    # Set alpha channel
    alpha[mask_na] = 0  # where input was NaN
    rgba[..., -1] = np.clip(alpha, 0, 1)
    return rgba
[docs]
def to_rgba(obj, cmap="viridis", norm=None, scaling="none"):
    """Map a xarray DataArray (or numpy array) to RGBA with optional alpha-scaling.

    Parameters
    ----------
    obj : xarray.DataArray or numpy.ndarray
        Array of values to convert.
    cmap : str or Colormap, optional
        Matplotlib colormap used for the RGB channels.
    norm : matplotlib.colors.Normalize, optional
        Norm used to scale the values before assigning the colors.
    scaling : str, optional
        Scaling type for the alpha channel. The default 'none' gives full opacity.

    Returns
    -------
    xarray.DataArray or numpy.ndarray
        RGBA values. For a DataArray input, a DataArray with an additional
        trailing 'rgba' dimension of size 4 is returned.
    """
    # Numpy input --> numpy output
    if not isinstance(obj, xr.DataArray):
        return _np_to_rgba_alpha(obj, cmap=cmap, cmap_norm=norm, scaling=scaling)

    # xarray input --> prepare a template with a trailing 'rgba' dimension
    template = obj.copy().expand_dims({"rgba": 4}).transpose(..., "rgba")

    # Compute RGBA values on the underlying numpy array and fill the template
    rgba = _np_to_rgba_alpha(obj.to_numpy(), cmap=cmap, cmap_norm=norm, scaling=scaling)
    template.data = rgba
    return template
[docs]
def max_blend_images(ds_rgb, dim):
    """Max blend a RGBA DataArray across a samples dimensions.

    For each pixel, the RGBA values of the sample with the largest alpha are kept.

    Parameters
    ----------
    ds_rgb : xarray.DataArray
        RGBA DataArray. The last dimension (once ``dim`` is moved in first
        position) must hold the 4 RGBA channels, with alpha as last channel.
    dim : str
        Name of the dimension to blend across.

    Returns
    -------
    xarray.DataArray
        Blended RGBA DataArray without the ``dim`` dimension.
    """
    # Move the dimension to blend in first position: (N, ..., 4)
    samples_first = ds_rgb.transpose(dim, ...)
    rgba_stack = samples_first.data

    # For each pixel, find the sample with the largest alpha
    alpha_channel = rgba_stack[..., 3]
    winner = np.argmax(alpha_channel, axis=0)

    # Replicate the winning index over the 4 channels: (1, ..., 4)
    selector = np.repeat(winner[None, ..., None], 4, axis=-1)
    blended = np.take_along_axis(rgba_stack, selector, axis=0)[0]

    # Use the first sample as template for the output DataArray
    result = samples_first.isel({dim: 0}).copy()
    result.data = blended
    return result
def _create_denseline_grid(indices, ny, nx, nsamples):
# Assign 1 when line pass in a bin
valid = (indices >= 0) & (indices < ny)
s_idx, x_idx = np.nonzero(valid)
y_idx = indices[valid]
# ----------------------------------------------
### Vectorized code with high memory footprint because of 3D array
# # Create 3D array with hits
# grid_3d = np.zeros((nsamples, ny, nx), dtype=np.int64)
# grid_3d[s_idx, y_idx, x_idx] = 1
# # Normalize by columns
# col_sums = grid_3d.sum(axis=1, keepdims=True)
# col_sums[col_sums == 0] = 1 # Avoid division by zero
# grid_3d = grid_3d / col_sums
# # Sum over samples
# grid = grid_3d.sum(axis=0)
# # Free memory
# del grid_3d
# ----------------------------------------------
## Vectorized alternative with much lower memory footprint
# Count hits per (sample, y, x)
grid = np.zeros((ny, nx), dtype=np.float64)
# Compute per-sample-per-column counts
col_counts = np.zeros((nsamples, nx), dtype=np.int64)
np.add.at(col_counts, (s_idx, x_idx), 1)
# Define weights to normalize contributions, avoiding division by zero
# - Weight = 1 / (# hits per column, per sample)
col_counts[col_counts == 0] = 1
weights = 1.0 / col_counts[s_idx, x_idx]
# Accumulate weighted contributions
np.add.at(grid, (y_idx, x_idx), weights)
# Return 2D grid
return grid
def _compute_block_size(ny, nx, dtype=np.float64, safety_margin=2e9):
    """Compute maximum block size given available memory.

    Parameters
    ----------
    ny, nx : int
        Shape of the final grid.
    dtype : numpy dtype, optional
        Data type of the final grid. Default is float64.
    safety_margin : float, optional
        Memory (in bytes) subtracted from the available memory. Default is 2e9.

    Returns
    -------
    int
        Number of samples per block (at least 1).
    """
    # Memory that can be used once the safety margin is left aside
    usable_memory = psutil.virtual_memory().available - safety_margin

    # Constant cost for the final grid
    grid_cost = ny * nx * np.dtype(dtype).itemsize

    # Per-sample cost (worst case, includes col_counts + indices + weights)
    sample_cost = nx * 40

    n_samples = (usable_memory - grid_cost) // sample_cost
    return max(1, int(n_samples))
[docs]
def compute_dense_lines(
    da: xr.DataArray,
    coord: str,
    x_bins: list,
    y_bins: list,
    normalization="max",
):
    """
    Compute a 2D density-of-lines histogram from an xarray.DataArray.

    Parameters
    ----------
    da : xarray.DataArray
        Input data array. It must have a name.
        One of its dimensions (named by ``coord``) is taken
        as the horizontal coordinate. All other dimensions are collapsed into
        "series", so that each combination of the remaining dimension values
        produces one 1D line along ``coord``.
    coord : str
        The name of the coordinate/dimension of the DataArray to bin over.
        If it is a non-dimension coordinate, it must be 1D.
    x_bins : array-like
        Bin edges of the coordinate/dimension with shape (nx+1,).
        Must be strictly increasing.
        The number of x-bins will be ``nx = len(x_bins) - 1``.
        The values are not re-binned along ``coord``: the j-th position along
        the ``coord`` dimension is assigned to the j-th x-bin.
    y_bins : array-like
        Bin edges for the DataArray values with shape (ny+1,).
        Must be strictly increasing.
        The number of y-bins will be ``ny = len(y_bins) - 1``.
    normalization : str, optional
        Normalization method passed to normalize_array().
        If 'none', returns the raw histogram.
        By default, the function normalize the histogram by its global maximum ('max').
        Log-normalization ('log') is also available.

    Returns
    -------
    xarray.DataArray
        2D histogram of shape ``(ny, nx)``. Dimensions are ``(da.name, coord)``, where:
        - ``coord``: the bin-center coordinate of ``x_bins`` (length ``nx``)
        - ``da.name``: the bin-center coordinate of ``y_bins`` (length ``ny``)
        Each element ``out.values[i, j]`` is the count (or normalized count) of how
        many series have, at the j-th position along ``coord``, a value in the interval
        ``y_bins[i] < data_value <= y_bins[i+1]``.
        NaN values and values outside the bins are not counted.
        Values which are not larger than 0 after normalization are set to NaN.

    Raises
    ------
    ValueError
        If the DataArray has no name, if the bins are not 1D strictly increasing
        arrays with at least two edges, or if ``coord`` is not a dimension
        or a 1D coordinate of the DataArray.

    References
    ----------
    Moritz, D., Fisher, D. (2018).
    Visualizing a Million Time Series with the Density Line Chart
    https://doi.org/10.48550/arXiv.1808.06019
    """
    # Check DataArray name
    if da.name is None or da.name == "":
        raise ValueError("The DataArray must have a name.")

    # Validate x_bins and y_bins
    x_bins = np.asarray(x_bins)
    y_bins = np.asarray(y_bins)
    if x_bins.ndim != 1 or x_bins.size < 2:
        raise ValueError("`x_bins` must be a 1D array with at least two edges.")
    if y_bins.ndim != 1 or y_bins.size < 2:
        raise ValueError("`y_bins` must be a 1D array with at least two edges.")
    if not np.all(np.diff(x_bins) > 0):
        raise ValueError("`x_bins` must be strictly increasing.")
    if not np.all(np.diff(y_bins) > 0):
        raise ValueError("`y_bins` must be strictly increasing.")

    # Verify that `coord` exists as either a dimension or a coordinate
    if coord not in (list(da.coords) + list(da.dims)):
        raise ValueError(f"'{coord}' is not a dimension or coordinate of the DataArray.")
    if coord not in da.dims:
        if da[coord].ndim != 1:
            raise ValueError(f"Coordinate '{coord}' must be 1D. Instead has dimensions {da[coord].dims}")
        x_dim = da[coord].dims[0]
    else:
        x_dim = coord

    # Define x bins center
    x_values = (x_bins[0:-1] + x_bins[1:]) / 2

    # Extract the array (samples, x)
    # NOTE(review): the x dimension is presumably expected to have len(x_bins) - 1 elements - confirm
    other_dims = [d for d in da.dims if d != x_dim]
    if len(other_dims) == 1:
        arr = da.transpose(*other_dims, x_dim).to_numpy()
    else:
        arr = da.stack({"sample": other_dims}).transpose("sample", x_dim).to_numpy()  # noqa PD013

    # Define y bins center
    y_center = (y_bins[0:-1] + y_bins[1:]) / 2

    # Define the shape of the 2D count grid (ny, nx)
    # - ny corresponds to the value of the timeseries at nx points
    nx = len(x_bins) - 1
    ny = len(y_bins) - 1
    nsamples = arr.shape[0]

    # For each (series, x-index), find which y-bin it falls into:
    # - np.searchsorted(y_bins, value) gives the insertion index in y_bins;
    # --> subtracting 1 yields the bin index.
    # If a value is not larger than y_bins[0], searchsorted returns 0, so idx = -1
    # If a value is larger than y_bins[-1] or is NaN, the index value will be ny
    # --> Indices outside [0, ny) are discarded by _create_denseline_grid
    indices = np.searchsorted(y_bins, arr) - 1  # (samples, nx)

    # Compute unnormalized DenseLines grid by blocks to avoid running out of memory
    # - Define block size based on available RAM memory
    # - Accumulate the block grids on a single 2D grid instead of stacking them
    block = _compute_block_size(ny=ny, nx=nx, dtype=np.float64, safety_margin=4e9)
    grid = np.zeros((ny, nx), dtype=np.float64)
    for block_start_idx in range(0, nsamples, block):
        block_end_idx = min(block_start_idx + block, nsamples)
        block_indices = indices[block_start_idx:block_end_idx, :]
        block_nsamples = block_end_idx - block_start_idx
        grid += _create_denseline_grid(indices=block_indices, ny=ny, nx=nx, nsamples=block_nsamples)

    # Normalize grid
    grid = normalize_array(grid, method=normalization)

    # Create DataArray
    name = da.name
    out = xr.DataArray(grid, dims=[name, coord], coords={coord: (coord, x_values), name: (name, y_center)})

    # Mask values which are not larger than 0 with NaN
    out = out.where(out > 0)

    # Return 2D histogram
    return out