Source code for disdrodb.utils.dask
# -----------------------------------------------------------------------------.
# Copyright (c) 2021-2026 DISDRODB developers
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------.
"""Utilities for Dask Distributed Computations."""
import logging
import os
import numpy as np
[docs]
def check_parallel_validity(parallel):
    """Check validity of the ``parallel`` option given the active Dask settings.

    Parameters
    ----------
    parallel : bool
        Parallel option requested by the user.

    Returns
    -------
    bool
        ``False`` when the configured Dask scheduler cannot be used to open
        files in parallel:

        - a synchronous or threaded scheduler (including the ``"sync"`` and
          ``"threading"`` aliases);
        - a distributed client whose workers all report the same PID;
        - a distributed client with process workers using more than one thread.

        Otherwise (no scheduler configured, or any other scheduler), the
        input ``parallel`` value is returned unchanged.
    """
    import dask

    scheduler = dask.config.get("scheduler", None)
    if scheduler is None:
        return parallel

    # Synchronous and threaded schedulers, including the aliases accepted by
    # dask ("sync", "threading"), run everything in a single process.
    if scheduler in ["sync", "synchronous", "single-threaded", "threads", "threading"]:
        return False

    if scheduler in ["distributed", "dask.distributed"]:
        from dask.distributed import default_client

        client = default_client()

        # If ThreadWorker (processes=False), all workers share a single pid
        pids = list(client.run(os.getpid).values())
        if len(np.unique(pids)) == 1:
            return False

        # If ProcessWorker, require a single thread per worker to avoid locks
        info = client.scheduler_info()
        nthreads_per_process = np.array([v["nthreads"] for v in info["workers"].values()])
        if not np.all(nthreads_per_process == 1):
            print(
                "To open netCDFs in parallel with dask distributed (processes=True), please set threads_per_worker=1 !",
            )
            return False

    # Otherwise let the user choose
    return parallel
[docs]
def initialize_dask_cluster(minimum_memory=None):
    """Initialize a local Dask distributed cluster and connect a client to it.

    The cluster uses process-based workers with one thread each and no Dask
    memory limit. The number of workers is taken from the ``num_workers``
    Dask configuration value (e.g. the ``DASK_NUM_WORKERS`` environment
    variable). If that value is unset, it defaults to the number of CPUs
    minus 2, with a minimum of 1.

    Parameters
    ----------
    minimum_memory : str or int, optional
        Minimum memory required per worker (e.g. ``"2GB"``), parsed with
        ``dask.utils.parse_bytes``. If specified, the number of workers is
        capped at ``total_system_memory // minimum_memory`` (at least 1).

    Returns
    -------
    cluster : dask.distributed.LocalCluster
        The created local cluster.
    client : dask.distributed.Client
        Client connected to ``cluster``.
    """
    import dask
    import psutil

    # Alternative Dask configurations (currently disabled). If enabled, they
    # must be set before dask.distributed is imported below.
    # dask.config.set({'distributed.worker.multiprocessing-method': 'forkserver'})
    # dask.config.set({"distributed.worker.multiprocessing-method": "spawn"})
    # dask.config.set({"logging.distributed": "error"})

    # Import dask.distributed after setting the config
    from dask.distributed import Client, LocalCluster
    from dask.utils import parse_bytes

    # Disable HDF5 file locking to avoid getting stuck when reading HDF files
    os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

    # Retrieve the number of processes to run
    # --> If DASK_NUM_WORKERS is not set, use all CPUs minus 2.
    # os.cpu_count() may return None, and on machines with <= 2 CPUs the raw
    # subtraction would give a non-positive worker count: keep at least 1.
    available_workers = max(1, (os.cpu_count() or 1) - 2)
    num_workers = dask.config.get("num_workers", available_workers)

    # If memory limit specified, ensure correct amount of workers
    if minimum_memory is not None:
        # Compute available memory (in bytes)
        total_memory = psutil.virtual_memory().total
        # Get minimum memory per worker (in bytes)
        minimum_memory = parse_bytes(minimum_memory)
        # Determine number of workers constrained by memory
        maximum_workers_allowed = max(1, total_memory // minimum_memory)
        # Respect both CPU and memory requirements
        num_workers = min(maximum_workers_allowed, num_workers)

    # Create dask.distributed local cluster
    cluster = LocalCluster(
        n_workers=num_workers,
        threads_per_worker=1,
        processes=True,
        memory_limit=0,  # disable Dask's per-worker memory management
        silence_logs=logging.ERROR,
    )
    client = Client(cluster)
    return cluster, client
[docs]
def close_dask_cluster(cluster, client):
    """Close a Dask client and its cluster while temporarily muting the root logger.

    The client is closed before the cluster, which is the reverse of the
    creation order. The cluster is closed even if closing the client raises.
    While closing, the root logger level is raised above ``CRITICAL``, and the
    original level is always restored afterwards.

    Parameters
    ----------
    cluster : dask.distributed.LocalCluster
        Cluster to shut down.
    client : dask.distributed.Client
        Client connected to ``cluster``.
    """
    logger = logging.getLogger()
    # Backup current log level
    original_level = logger.level
    # Suppress all logs emitted through loggers inheriting the root level.
    # Goal: avoid 'distributed.worker - ERROR - Failed to communicate with
    # scheduler during heartbeat.' during shutdown.
    # NOTE(review): loggers with their own explicit level are not affected.
    logger.setLevel(logging.CRITICAL + 1)
    try:
        try:
            client.close()
        finally:
            # Always close the cluster, even if closing the client failed
            cluster.close()
    finally:
        # Restore the original log level
        logger.setLevel(original_level)
def _batch_iterable(iterable, n):
"""Yield successive n-sized chunks from iterable."""
for i in range(0, len(iterable), n):
yield iterable[i : i + n]
[docs]
def execute_tasks_safely(list_tasks, parallel: bool, logs_dir: str, max_tasks_per_batch=5_000):
    """
    Execute Dask tasks and skip failed ones.

    Parameters
    ----------
    list_tasks : list
        List of dask delayed objects (``parallel=True``) or of already
        computed results (``parallel=False``).
    parallel : bool
        Whether to execute in parallel with Dask or not.
        If ``False``, ``list_tasks`` is returned unchanged.
    logs_dir : str
        Directory where ``FAILED_DASK_TASKS.log`` is written if some tasks fail.
        It is created in parallel mode if it does not exist.
    max_tasks_per_batch : int or None, optional
        Maximum number of tasks to submit to `client.compute()` at once.
        If ``None`` or ``0``, all tasks are submitted at once.
        The default is 5000. Dask struggles if more than 10_000 tasks are submitted.

    Returns
    -------
    list_logs : list
        Results of the successful tasks, in submission order. Failed tasks
        are dropped. If any task failed, the path to ``FAILED_DASK_TASKS.log``
        is appended once at the end of the list.

    Raises
    ------
    ValueError
        If ``max_tasks_per_batch`` is negative, or if no Dask Distributed
        Client is available in parallel mode.
    """
    if not parallel:
        # Non-parallel mode: list_tasks already holds the results
        return list_tasks

    # A negative batch size would produce no batch at all and silently drop
    # every task: reject it before any side effect.
    if max_tasks_per_batch is not None and max_tasks_per_batch < 0:
        raise ValueError(
            f"'max_tasks_per_batch' must be None or a non-negative integer. Got {max_tasks_per_batch}.",
        )

    # Imported here so that the non-parallel path does not require dask.distributed
    from dask.distributed import get_client

    # Ensure we have a Dask client
    try:
        client = get_client()
    except ValueError as err:
        raise ValueError("No Dask Distributed Client found.") from err

    # Ensure logs_dir exists
    os.makedirs(logs_dir, exist_ok=True)

    # Batch execution (a falsy batch size means a single batch with all tasks)
    task_batches = _batch_iterable(list_tasks, max_tasks_per_batch) if max_tasks_per_batch else [list_tasks]

    all_results = []
    failed_futures = []
    for batch in task_batches:
        # Compute tasks (all concurrently)
        # - Runs tasks == num_workers * threads_per_worker (which is 1 for DISDRODB)
        # - If errors occur in some tasks, skip them
        futures = client.compute(batch)
        results = client.gather(futures, errors="skip")
        # Identify and collect failed futures
        failed_futures.extend(fut for fut in futures if fut.status != "finished")
        # Collect results from successful tasks
        all_results.extend(results)

    # If no tasks failed, return results
    if not failed_futures:
        return all_results

    # Otherwise write a log file listing the failed tasks
    failed_log_path = os.path.join(logs_dir, "FAILED_DASK_TASKS.log")
    with open(failed_log_path, "w", encoding="utf-8") as f:
        for fut in failed_futures:
            err = fut.exception()
            f.write(f"ERROR - DASK TASK FAILURE - Task {fut.key} failed: {err}\n")

    # Append the dask failure log to the list of results
    all_results.append(failed_log_path)
    return all_results