# -----------------------------------------------------------------------------.
# Copyright (c) 2021-2023 DISDRODB developers
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------.
"""DISDRODB Plotting Tools."""
import matplotlib.pyplot as plt
import numpy as np
import psutil
import xarray as xr
from matplotlib.colors import LogNorm, Normalize
from matplotlib.gridspec import GridSpec
from disdrodb.constants import DIAMETER_DIMENSION, VELOCITY_DIMENSION
from disdrodb.l2.empirical_dsd import get_drop_average_velocity
####-------------------------------------------------------------------------------------------------------
#### N(D) visualizations
def _single_plot_nd_distribution(drop_number_concentration, diameter, diameter_bin_width):
    """Draw a single N(D) distribution as a bar plot on a new figure.

    Parameters
    ----------
    drop_number_concentration : array_like
        N(D) values, used as bar heights (one per diameter bin).
    diameter : array_like
        Diameter bin centers in mm, used as bar positions.
    diameter_bin_width : array_like
        Diameter bin widths in mm, used as bar widths.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the bar plot.
    """
    # Only the axes are returned to the caller, so the figure handle is discarded
    _, ax = plt.subplots(1, 1)
    ax.bar(
        diameter,
        drop_number_concentration,
        width=diameter_bin_width,
        edgecolor="darkgray",
        color="lightgray",
        label="Data",
    )
    ax.set_title("Drop number concentration (N(D))")
    ax.set_xlabel("Drop diameter (mm)")
    ax.set_ylabel("N(D) [m-3 mm-1]")
    return ax
[docs]
def plot_nd(ds, var="drop_number_concentration", cmap=None, norm=None):
    """Plot drop number concentration N(D) timeseries.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing ``var``. When it has no ``time`` dimension it must
        also provide ``diameter_bin_center`` and ``diameter_bin_width``, and a
        single distribution is drawn as bars (``cmap`` and ``norm`` are then
        not used).
    var : str, optional
        Name of the N(D) variable. Default is ``"drop_number_concentration"``.
    cmap : Colormap, optional
        Colormap of the timeseries plot. Defaults to a copy of ``"Spectral_r"``.
    norm : matplotlib.colors.Normalize, optional
        Colormap normalization. Defaults to ``LogNorm`` whose ``vmin`` is the
        smallest positive value of ``var``.

    Returns
    -------
    matplotlib.axes.Axes or object
        The bar-plot axes when ``ds`` has no ``time`` dimension, otherwise the
        object returned by xarray's ``plot.pcolormesh``.

    Raises
    ------
    ValueError
        If ``var`` is not a variable of ``ds``.
    """
    if var not in ds:
        raise ValueError(f"{var} is not a xarray Dataset variable!")

    # Without a time dimension there is a single distribution: draw it as bars
    if "time" not in ds.dims:
        return _single_plot_nd_distribution(
            drop_number_concentration=ds[var],
            diameter=ds["diameter_bin_center"],
            diameter_bin_width=ds["diameter_bin_width"],
        )

    # Load the variable and regularize it through the disdrodb accessor
    # NOTE(review): presumably regularizes the time axis — confirm in the accessor
    ds_nd = ds[[var]].compute().disdrodb.regularize()
    # Keep only positive values (others become NaN and are not drawn)
    ds_nd = ds_nd.where(ds_nd[var] > 0)
    da_nd = ds_nd[var]

    # Default colormap, then a log norm starting at the smallest positive value
    if cmap is None:
        cmap = plt.get_cmap("Spectral_r").copy()
    smallest_value = da_nd.min().item()
    if norm is None:
        norm = LogNorm(smallest_value, None)

    mesh = da_nd.plot.pcolormesh(x="time", norm=norm, cmap=cmap)
    mesh.axes.set_title("Drop number concentration (N(D))")
    mesh.axes.set_ylabel("Drop diameter (mm)")
    return mesh
####-------------------------------------------------------------------------------------------------------
#### Spectra visualizations
def _check_has_diameter_and_velocity_dims(da):
    """Return ``da`` unchanged if it has both the diameter and velocity dimensions.

    Raises
    ------
    ValueError
        If ``DIAMETER_DIMENSION`` or ``VELOCITY_DIMENSION`` is missing from ``da.dims``.
    """
    required_dims = (DIAMETER_DIMENSION, VELOCITY_DIMENSION)
    if all(dim in da.dims for dim in required_dims):
        return da
    raise ValueError(f"The DataArray must have both '{DIAMETER_DIMENSION}' and '{VELOCITY_DIMENSION}' dimensions.")
def _get_spectrum_variable(xr_obj, variable):
    """Return the spectrum DataArray held by an xarray object.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Input object. If a Dataset, ``variable`` is extracted from it.
    variable : str
        Name of the spectrum variable (only used when ``xr_obj`` is a Dataset).

    Returns
    -------
    xarray.DataArray
        The spectrum, validated to have both diameter and velocity dimensions.

    Raises
    ------
    TypeError
        If ``xr_obj`` is neither an xarray Dataset nor a DataArray.
    ValueError
        If ``variable`` is missing from the Dataset, or if the spectrum lacks
        the diameter or velocity dimension.
    """
    if isinstance(xr_obj, xr.Dataset):
        if variable not in xr_obj:
            raise ValueError(f"The dataset do not include {variable=}.")
        spectrum = xr_obj[variable]
    elif isinstance(xr_obj, xr.DataArray):
        spectrum = xr_obj
    else:
        raise TypeError("Expecting xarray object as input.")
    return _check_has_diameter_and_velocity_dims(spectrum)
[docs]
def plot_spectrum(
    xr_obj,
    variable="raw_drop_number",
    ax=None,
    cmap=None,
    norm=None,
    extend="max",
    add_colorbar=True,
    cbar_kwargs=None,
    title="Drop Spectrum",
    **plot_kwargs,
):
    """Plot the drop spectrum (counts per diameter and fall velocity bin).

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Input xarray object. If Dataset, the variable to plot must be specified.
        If DataArray, it must have both diameter and velocity dimensions.
    variable : str
        Name of the variable to plot if xr_obj is a Dataset.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, uses current axes or creates a new one.
    cmap : Colormap, optional
        Colormap to use. If None, uses 'Spectral_r' with 'under' set to 'none'.
    norm : matplotlib.colors.Normalize, optional
        Normalization for colormap. If None, uses LogNorm with vmin=1
        (or xarray's default normalization if the spectrum sums to 0).
    extend : {'neither', 'both', 'min', 'max'}, optional
        Whether to draw arrows on the colorbar to indicate out-of-range values.
        Default is 'max'.
    add_colorbar : bool, optional
        Whether to add a colorbar. Default is True.
    cbar_kwargs : dict, optional
        Additional keyword arguments for colorbar. If None, uses {'label': 'Number of particles'}.
        Ignored when ``add_colorbar`` is False.
    title : str, optional
        Title of the plot. Default is 'Drop Spectrum'. Not used for FacetGrid plots.
    **plot_kwargs : dict
        Additional keyword arguments passed to xarray's plot.pcolormesh method.

    Returns
    -------
    object
        The object returned by xarray's ``plot.pcolormesh``: a FacetGrid when
        ``col`` or ``row`` is given in ``plot_kwargs``, otherwise a single mesh
        artist exposing ``.axes``.

    Notes
    -----
    - If the input DataArray has a time dimension, it is summed over time before plotting
      unless FacetGrid options (e.g., col, row) are specified in plot_kwargs.
    - If FacetGrid options are used, the plot will create a grid of subplots for each time slice.
      To create a FacetGrid plot, use:
      ds.isel(time=slice(0, 9)).disdrodb.plot_spectrum(col="time", col_wrap=3)
    """
    # Retrieve spectrum
    drop_number = _get_spectrum_variable(xr_obj, variable)

    # Check if FacetGrid
    is_facetgrid = "col" in plot_kwargs or "row" in plot_kwargs

    # Sum over time dimension if still present
    # - Unless FacetGrid options in plot_kwargs
    if "time" in drop_number.dims and not is_facetgrid:
        drop_number = drop_number.sum(dim="time")

    # Define default cbar_kwargs if not specified
    if cbar_kwargs is None:
        cbar_kwargs = {"label": "Number of particles"}

    # Define cmap and norm
    # - Bins below vmin (e.g. 0 counts with LogNorm) are drawn transparent
    if cmap is None:
        cmap = plt.get_cmap("Spectral_r").copy()
        cmap.set_under("none")

    # - A LogNorm is only meaningful if there is at least one counted drop
    if norm is None:
        norm = LogNorm(vmin=1, vmax=None) if drop_number.sum() > 0 else None

    # Remove cbar_kwargs if add_colorbar=False (xarray rejects cbar_kwargs without colorbar)
    if not add_colorbar:
        cbar_kwargs = None

    # Plot
    p = drop_number.plot.pcolormesh(
        ax=ax,
        x=DIAMETER_DIMENSION,
        y=VELOCITY_DIMENSION,
        cmap=cmap,
        extend=extend,
        norm=norm,
        add_colorbar=add_colorbar,
        cbar_kwargs=cbar_kwargs,
        **plot_kwargs,
    )
    if not is_facetgrid:
        p.axes.set_xlabel("Diameter [mm]")
        p.axes.set_ylabel("Fall velocity [m/s]")
        p.axes.set_title(title)
    else:
        p.set_axis_labels("Diameter [mm]", "Fall velocity [m/s]")
    return p
[docs]
def plot_raw_and_filtered_spectra(
    ds,
    cmap=None,
    norm=None,
    extend="max",
    add_theoretical_average_velocity=True,
    add_measured_average_velocity=True,
    figsize=(8, 4),
    dpi=300,
):
    """Plot raw and filtered drop spectra side by side.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with the ``raw_drop_number`` and ``drop_number`` spectra (both
        with diameter and velocity dimensions) and, if
        ``add_theoretical_average_velocity`` is True, ``fall_velocity``.
        Spectra are summed over ``time`` if that dimension is present.
    cmap : Colormap, optional
        Colormap forwarded to ``plot_spectrum`` for both panels.
    norm : matplotlib.colors.Normalize, optional
        Normalization shared by both panels. If None, uses
        ``LogNorm(1, max(raw_drop_number))``.
    extend : {'neither', 'both', 'min', 'max'}, optional
        Colorbar extension forwarded to ``plot_spectrum``. Default is 'max'.
    add_theoretical_average_velocity : bool, optional
        Overlay ``ds["fall_velocity"]`` (averaged over time if present) as a
        dashed line on both panels. Default is True.
    add_measured_average_velocity : bool, optional
        Overlay the average velocity computed by ``get_drop_average_velocity``
        from the filtered spectrum as a dotted line on both panels. Default is True.
    figsize : tuple, optional
        Figure size passed to ``plt.figure``. Default is (8, 4).
    dpi : int, optional
        Figure resolution passed to ``plt.figure``. Default is 300.

    Returns
    -------
    matplotlib.figure.Figure
        The figure with the raw (left) and filtered (right) spectra.
    """
    # Retrieve spectrum arrays
    drop_number = _get_spectrum_variable(ds, variable="drop_number")
    if "time" in drop_number.dims:
        drop_number = drop_number.sum(dim="time")
    drop_number = drop_number.compute()

    raw_drop_number = _get_spectrum_variable(ds, variable="raw_drop_number")
    if "time" in raw_drop_number.dims:
        raw_drop_number = raw_drop_number.sum(dim="time")
    raw_drop_number = raw_drop_number.compute()

    # Compute theoretical and measured average velocity if asked
    if add_theoretical_average_velocity:
        theoretical_average_velocity = ds["fall_velocity"]
        if "time" in theoretical_average_velocity.dims:
            theoretical_average_velocity = theoretical_average_velocity.mean(dim="time")
    if add_measured_average_velocity:
        measured_average_velocity = get_drop_average_velocity(drop_number)

    # Define a norm shared by both panels so that colors are comparable
    if norm is None:
        norm = LogNorm(1, raw_drop_number.max().item())

    # Initialize figure
    fig = plt.figure(figsize=figsize, dpi=dpi)
    gs = GridSpec(1, 2, width_ratios=[1, 1.15], wspace=0.05)  # More space for ax2 (it holds the colorbar)
    ax1 = fig.add_subplot(gs[0])
    ax2 = fig.add_subplot(gs[1])

    # Plot raw_drop_number
    plot_spectrum(raw_drop_number, ax=ax1, cmap=cmap, norm=norm, extend=extend, add_colorbar=False, title="")

    # Add velocities if asked
    if add_theoretical_average_velocity:
        theoretical_average_velocity.plot(ax=ax1, c="k", linestyle="dashed")
    if add_measured_average_velocity:
        measured_average_velocity.plot(ax=ax1, c="k", linestyle="dotted")

    # Improve plot appearance
    ax1.set_xlabel("Diameter [mm]")
    ax1.set_ylabel("Fall velocity [m/s]")
    ax1.set_title("Raw Spectrum")

    # Plot drop_number
    plot_spectrum(drop_number, ax=ax2, cmap=cmap, norm=norm, extend=extend, add_colorbar=True, title="")

    # Add velocities if asked (labels are only set here, where the legend is drawn)
    if add_theoretical_average_velocity:
        theoretical_average_velocity.plot(ax=ax2, c="k", linestyle="dashed", label="Theoretical velocity")
    if add_measured_average_velocity:
        measured_average_velocity.plot(ax=ax2, c="k", linestyle="dotted", label="Measured average velocity")

    # Improve plot appearance (y axis is shown only on the left panel)
    ax2.set_yticks([])
    ax2.set_yticklabels([])
    ax2.set_xlabel("Diameter [mm]")
    ax2.set_ylabel("")
    ax2.set_title("Filtered Spectrum")

    # Add legend
    if add_theoretical_average_velocity or add_measured_average_velocity:
        ax2.legend(loc="lower right", frameon=False)
    return fig
####-------------------------------------------------------------------------------------------------------
#### DenseLines
[docs]
def normalize_array(arr, method="max"):
    """Normalize a NumPy array according to the chosen method.

    NaN values are ignored when computing the normalization statistics and
    are propagated to the output.

    Parameters
    ----------
    arr : array_like
        Input array. It is converted to a float NumPy array.
    method : str
        Normalization method. Options:

        - 'max' : Divide by the maximum value (unchanged if the maximum is 0).
        - 'minmax': Scale to [0, 1] range (zeros if the array is constant).
        - 'zscore': Standardize to mean 0, std 1 (zeros if the std is 0).
        - 'log' : Apply log10 transform. If the minimum is <= 0, the array is
          first shifted so that its minimum becomes 1e-12.
        - 'none' : No normalization (return the float array).

    Returns
    -------
    np.ndarray
        Normalized array. An empty input is returned unchanged.

    Raises
    ------
    ValueError
        If ``method`` is not one of the options above.
    """
    if method not in ("max", "minmax", "zscore", "log", "none"):
        raise ValueError(f"Unknown normalization method: {method}")

    arr = np.asarray(arr, dtype=float)
    # nan-reductions raise on empty arrays: there is nothing to normalize anyway
    if arr.size == 0 or method == "none":
        return arr

    if method == "max":
        max_val = np.nanmax(arr)
        return arr / max_val if max_val != 0 else arr

    if method == "minmax":
        min_val = np.nanmin(arr)
        max_val = np.nanmax(arr)
        return (arr - min_val) / (max_val - min_val) if max_val != min_val else np.zeros_like(arr)

    if method == "zscore":
        mean_val = np.nanmean(arr)
        std_val = np.nanstd(arr)
        return (arr - mean_val) / std_val if std_val != 0 else np.zeros_like(arr)

    # method == "log"
    min_val = np.nanmin(arr)
    if min_val <= 0:
        # Shift only when needed, to avoid log10 of 0 or of negative values
        arr = arr - min_val + 1e-12
    return np.log10(arr)
def _np_to_rgba_alpha(arr, cmap="viridis", cmap_norm=None, scaling="linear"):
    """Convert a numpy array to an RGBA array with alpha based on array value.

    Parameters
    ----------
    arr : numpy.ndarray
        Array of counts or frequencies. NaN values are rendered fully transparent.
    cmap : str or Colormap, optional
        Matplotlib colormap to use for RGB channels.
    cmap_norm : matplotlib.colors.Normalize, optional
        Norm used to scale data before assigning cmap colors.
        The default is ``Normalize(vmin=nanmin(arr), vmax=nanmax(arr))``.
    scaling : str, optional
        Scaling type for alpha mapping:

        - "linear" : min-max normalization
        - "log" : logarithmic normalization (positive values only; others are transparent)
        - "sqrt" : square-root (power-law with exponent=0.5)
        - "exp" : exponential scaling
        - "quantile" : rank-based scaling computed over the non-NaN values
        - "none" : full opacity (alpha=1)

        With "log", "sqrt" and "exp", an array without positive values is fully transparent.

    Returns
    -------
    rgba : numpy.ndarray
        RGBA array of shape ``arr.shape + (4,)`` (``(ny, nx, 4)`` for a 2D input).

    Raises
    ------
    ValueError
        If ``scaling`` is not one of the options above.
    """
    # Ensure numpy array
    arr = np.asarray(arr, dtype=float)

    # Define mask with NaN pixel
    mask_na = np.isnan(arr)

    # Define colormap norm
    if cmap_norm is None:
        cmap_norm = Normalize(vmin=np.nanmin(arr), vmax=np.nanmax(arr))

    # Define alpha
    if scaling == "linear":
        norm = Normalize(vmin=np.nanmin(arr), vmax=np.nanmax(arr))
        alpha = np.ma.filled(norm(arr), 0.0)
    elif scaling == "log":
        positive = arr[arr > 0]
        if positive.size == 0:
            alpha = np.zeros_like(arr)
        else:
            norm = LogNorm(vmin=positive.min(), vmax=positive.max())
            # LogNorm masks non-positive values: make them transparent
            alpha = np.nan_to_num(np.ma.filled(norm(arr), 0.0), nan=0.0)
    elif scaling in ("sqrt", "exp"):
        max_val = np.nanmax(arr)
        if not max_val > 0:
            # Avoid 0/0 (NaN alpha) and sign flips when there is no positive value
            alpha = np.zeros_like(arr)
        elif scaling == "sqrt":
            alpha = np.sqrt(np.clip(arr, 0, None) / max_val)
        else:
            alpha = np.expm1(np.clip(arr / max_val, 0, 1)) / np.expm1(1)
    elif scaling == "quantile":
        # Rank only the valid values so that NaN pixels do not take the top ranks
        alpha = np.zeros_like(arr)
        valid = ~mask_na
        n_valid = int(np.count_nonzero(valid))
        if n_valid == 1:
            alpha[valid] = 1.0
        elif n_valid > 1:
            ranks = np.argsort(np.argsort(arr[valid]))  # rankdata without scipy
            alpha[valid] = ranks / (n_valid - 1)
    elif scaling == "none":
        alpha = np.ones_like(arr, dtype=float)
    else:
        raise ValueError(f"Unknown scaling type: {scaling}")

    # Map values to colors
    rgba = plt.get_cmap(cmap)(cmap_norm(arr))

    # Set alpha channel (transparent where input was NaN)
    alpha = np.where(mask_na, 0.0, alpha)
    rgba[..., -1] = np.clip(alpha, 0, 1)
    return rgba
[docs]
def to_rgba(obj, cmap="viridis", norm=None, scaling="none"):
    """Map a xarray DataArray (or numpy array) to RGBA with optional alpha-scaling.

    Parameters
    ----------
    obj : xarray.DataArray or array_like
        Values to colorize.
    cmap : str or Colormap, optional
        Colormap used for the RGB channels. Default is "viridis".
    norm : matplotlib.colors.Normalize, optional
        Norm applied before the colormap (see ``_np_to_rgba_alpha``).
    scaling : str, optional
        Alpha scaling method (see ``_np_to_rgba_alpha``). Default is "none".

    Returns
    -------
    xarray.DataArray or numpy.ndarray
        For a DataArray input, a copy with an extra trailing ``rgba`` dimension
        of length 4; otherwise the RGBA numpy array.
    """
    if not isinstance(obj, xr.DataArray):
        return _np_to_rgba_alpha(obj, cmap=cmap, cmap_norm=norm, scaling=scaling)

    # Template with the same dims/coords plus a trailing "rgba" dimension
    rgba_template = obj.copy().expand_dims({"rgba": 4}).transpose(..., "rgba")
    rgba_template.data = _np_to_rgba_alpha(obj.to_numpy(), cmap=cmap, cmap_norm=norm, scaling=scaling)
    return rgba_template
[docs]
def max_blend_images(ds_rgb, dim):
    """Max blend a RGBA DataArray across a samples dimension.

    For every pixel, the RGBA value of the sample with the largest alpha is
    kept (the first such sample when several share the maximum alpha).

    Parameters
    ----------
    ds_rgb : xarray.DataArray
        RGBA DataArray. The last axis is assumed to hold the 4 RGBA channels
        (as produced by ``to_rgba``); alpha is read at index 3 of that axis.
    dim : str
        Name of the dimension to blend over.

    Returns
    -------
    xarray.DataArray
        Blended RGBA DataArray without ``dim``.
    """
    # Move the blending dimension to the front: (N, ..., 4)
    samples_first = ds_rgb.transpose(dim, ...)
    values = samples_first.data

    # Index of the sample with maximum alpha at every pixel
    best = np.argmax(values[..., 3], axis=0)
    # Repeat the index over the 4 channels so whole RGBA tuples are picked
    best_rgba = np.repeat(best[np.newaxis, ..., np.newaxis], 4, axis=-1)
    blended = np.take_along_axis(values, best_rgba, axis=0)[0]

    # Use the first sample as a template for dims and coords
    result = samples_first.isel({dim: 0}).copy()
    result.data = blended
    return result
def _create_denseline_grid(indices, ny, nx, nsamples):
# Assign 1 when line pass in a bin
valid = (indices >= 0) & (indices < ny)
s_idx, x_idx = np.nonzero(valid)
y_idx = indices[valid]
# ----------------------------------------------
### Vectorized code with high memory footprint because of 3D array
# # Create 3D array with hits
# grid_3d = np.zeros((nsamples, ny, nx), dtype=np.int64)
# grid_3d[s_idx, y_idx, x_idx] = 1
# # Normalize by columns
# col_sums = grid_3d.sum(axis=1, keepdims=True)
# col_sums[col_sums == 0] = 1 # Avoid division by zero
# grid_3d = grid_3d / col_sums
# # Sum over samples
# grid = grid_3d.sum(axis=0)
# # Free memory
# del grid_3d
# ----------------------------------------------
## Vectorized alternative with much lower memory footprint
# Count hits per (sample, y, x)
grid = np.zeros((ny, nx), dtype=np.float64)
# Compute per-sample-per-column counts
col_counts = np.zeros((nsamples, nx), dtype=np.int64)
np.add.at(col_counts, (s_idx, x_idx), 1)
# Define weights to normalize contributions, avoiding division by zero
# - Weight = 1 / (# hits per column, per sample)
col_counts[col_counts == 0] = 1
weights = 1.0 / col_counts[s_idx, x_idx]
# Accumulate weighted contributions
np.add.at(grid, (y_idx, x_idx), weights)
# Return 2D grid
return grid
def _compute_block_size(ny, nx, dtype=np.float64, safety_margin=2e9):
"""Compute maximum block size given available memory."""
avail_mem = psutil.virtual_memory().available - safety_margin
# Constant cost for final grid
base = ny * nx * np.dtype(dtype).itemsize
# Per-sample cost (worst case, includes col_counts + indices + weights)
per_sample = nx * 40
max_block = (avail_mem - base) // per_sample
return max(1, int(max_block))
[docs]
def compute_dense_lines(
da: xr.DataArray,
coord: str,
x_bins: list,
y_bins: list,
normalization="max",
):
"""
Compute a 2D density-of-lines histogram from an xarray.DataArray.
Parameters
----------
da : xarray.DataArray
Input data array. One of its dimensions (named by ``coord``) is taken
as the horizontal coordinate. All other dimensions are collapsed into
“series,” so that each combination of the remaining dimension values
produces one 1D line along ``coord``.
coord : str
The name of the coordinate/dimension of the DataArray to bin over.
``da.coords[coord]`` must be a 1D numeric array (monotonic is recommended).
x_bins : array_like of shape (nx+1,)
Bin edges to bin the coordinate/dimension.
Must be monotonically increasing.
The number of x-bins will be ``nx = len(x_bins) - 1``.
y_bins : array_like of shape (ny+1,)
Bin edges for the DataArray values.
Must be monotonically increasing.
The number of y-bins will be ``ny = len(y_bins) - 1``.
normalization : bool, optional
If 'none', returns the raw histogram.
By default, the function normalize the histogram by its global maximum ('max').
Log-normalization ('log') is also available.
Returns
-------
xr.DataArray
2D histogram of shape ``(ny, nx)``. Dimensions are ``('y', 'x')``, where:
- ``x``: the bin-center coordinate of ``x_bins`` (length ``nx``)
- ``y``: the bin-center coordinate of ``y_bins`` (length ``ny``)
Each element ``out.values[y_i, x_j]`` is the count (or normalized count) of how
many “series-values” from ``da`` fell into the rectangular bin
``x_bins[j] ≤ x_value < x_bins[j+1]`` and
``y_bins[i] ≤ data_value < y_bins[i+1]``.
References
----------
Moritz, D., Fisher, D. (2018).
Visualizing a Million Time Series with the Density Line Chart
https://doi.org/10.48550/arXiv.1808.06019
"""
# Check DataArray name
if da.name is None or da.name == "":
raise ValueError("The DataArray must have a name.")
# Validate x_bins and y_bins
x_bins = np.asarray(x_bins)
y_bins = np.asarray(y_bins)
if x_bins.ndim != 1 or x_bins.size < 2:
raise ValueError("`x_bins` must be a 1D array with at least two edges.")
if y_bins.ndim != 1 or y_bins.size < 2:
raise ValueError("`y_bins` must be a 1D array with at least two edges.")
if not np.all(np.diff(x_bins) > 0):
raise ValueError("`x_bins` must be strictly increasing.")
if not np.all(np.diff(y_bins) > 0):
raise ValueError("`y_bins` must be strictly increasing.")
# Verify that `coord` exists as either a dimension or a coordinate
if coord not in (list(da.coords) + list(da.dims)):
raise ValueError(f"'{coord}' is not a dimension or coordinate of the DataArray.")
if coord not in da.dims:
if da[coord].ndim != 1:
raise ValueError(f"Coordinate '{coord}' must be 1D. Instead has dimensions {da[coord].dims}")
x_dim = da[coord].dims[0]
else:
x_dim = coord
# Extract the coordinate array
x_values = (x_bins[0:-1] + x_bins[1:]) / 2
# Extract the array (samples, x)
other_dims = [d for d in da.dims if d != x_dim]
if len(other_dims) == 1:
arr = da.transpose(*other_dims, x_dim).to_numpy()
else:
arr = da.stack({"sample": other_dims}).transpose("sample", x_dim).to_numpy()
# Define y bins center
y_center = (y_bins[0:-1] + y_bins[1:]) / 2
# Prepare the 2D count grid of shape (ny, nx)
# - ny correspond tot he value of the timeseries at nx points
nx = len(x_bins) - 1
ny = len(y_bins) - 1
nsamples = arr.shape[0]
# For each (series, x-index), find which y-bin it falls into:
# - np.searchsorted(y_bins, value) gives the insertion index in y_bins;
# --> subtracting 1 yields the bin index.
# If a value is not in y_bins, searchsorted returns 0, so idx = -1
# If a valueis NaN, the indices value will be ny
indices = np.searchsorted(y_bins, arr) - 1 # (samples, nx)
# Compute unormalized DenseLines grid
# grid = _create_denseline_grid(
# indices=indices,
# ny=ny,
# nx=nx,
# nsamples=nsamples
# )
# Compute unormalized DenseLines grid by blocks to avoid running out of memory
# - Define block size based on available RAM memory
block = _compute_block_size(ny=ny, nx=nx, dtype=np.float64, safety_margin=4e9)
list_grid = []
for i in range(0, nsamples, block):
block_start_idx = i
block_end_idx = min(i + block, nsamples)
block_indices = indices[block_start_idx:block_end_idx, :]
block_nsamples = block_end_idx - block_start_idx
block_grid = _create_denseline_grid(indices=block_indices, ny=ny, nx=nx, nsamples=block_nsamples)
list_grid.append(block_grid)
grid_3d = np.stack(list_grid, axis=0)
# Finalize sum over samples
grid = grid_3d.sum(axis=0)
# Normalize grid
grid = normalize_array(grid, method=normalization)
# Create DataArray
name = da.name
out = xr.DataArray(grid, dims=[name, coord], coords={coord: (coord, x_values), name: (name, y_center)})
# Mask values which are 0 with NaN
out = out.where(out > 0)
# Return 2D histogram
return out