Source code for distributed.worker

from __future__ import annotations

import asyncio
import bisect
import builtins
import contextlib
import contextvars
import errno
import logging
import math
import os
import pathlib
import random
import sys
import threading
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
    Callable,
    Collection,
    Container,
    Hashable,
    Iterable,
    Mapping,
    MutableMapping,
)
from concurrent.futures import Executor
from contextlib import suppress
from datetime import timedelta
from functools import wraps
from inspect import isawaitable
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Literal,
    TextIO,
    TypedDict,
    TypeVar,
    cast,
)

from tlz import keymap, pluck
from tornado.ioloop import IOLoop

import dask
from dask._task_spec import GraphNode
from dask.system import CPU_COUNT
from dask.typing import Key
from dask.utils import (
    format_bytes,
    funcname,
    key_split,
    parse_bytes,
    parse_timedelta,
    tmpdir,
    typename,
)

from distributed import preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
from distributed.comm import resolve_address as comm_resolve_address
from distributed.comm.addressing import address_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
    ConnectionPool,
    ErrorMessage,
    OKMessage,
    PooledRPCCall,
    Status,
    coerce_to_address,
    context_meter_to_server_digest,
    error_message,
    pingpong,
)
from distributed.core import rpc as RPCType
from distributed.core import send_recv
from distributed.diagnostics import nvml, rmm
from distributed.diagnostics.plugin import WorkerPlugin, _get_plugin_name
from distributed.diskutils import WorkSpace
from distributed.exceptions import Reschedule
from distributed.gc import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.http import get_handlers
from distributed.metrics import context_meter, thread_time, time
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.pubsub import PubSubWorkerExtension
from distributed.security import Security
from distributed.sizeof import safe_sizeof as sizeof
from distributed.spans import CONTEXTS_WITH_SPAN_ID, SpansWorkerExtension
from distributed.threadpoolexecutor import ThreadPoolExecutor
from distributed.threadpoolexecutor import secede as tpe_secede
from distributed.utils import (
    TimeoutError,
    get_ip,
    has_arg,
    in_async_call,
    iscoroutinefunction,
    json_load_robust,
    log_errors,
    offload,
    parse_ports,
    recursive_to_dict,
    run_in_executor_with_context,
    set_thread_state,
    silence_logging_cmgr,
    thread_state,
    wait_for,
)
from distributed.utils_comm import gather_from_workers, retry_operation
from distributed.versions import get_versions
from distributed.worker_memory import (
    DeprecatedMemoryManagerAttribute,
    DeprecatedMemoryMonitor,
    WorkerDataParameter,
    WorkerMemoryManager,
)
from distributed.worker_state_machine import (
    AcquireReplicasEvent,
    BaseWorker,
    CancelComputeEvent,
    ComputeTaskEvent,
    DeprecatedWorkerStateAttribute,
    ExecuteFailureEvent,
    ExecuteSuccessEvent,
    FindMissingEvent,
    FreeKeysEvent,
    GatherDepBusyEvent,
    GatherDepFailureEvent,
    GatherDepNetworkFailureEvent,
    GatherDepSuccessEvent,
    PauseEvent,
    RefreshWhoHasEvent,
    RemoveReplicasEvent,
    RemoveWorkerEvent,
    RescheduleEvent,
    RetryBusyWorkerEvent,
    SecedeEvent,
    StateMachineEvent,
    StealRequestEvent,
    TaskState,
    UnpauseEvent,
    UpdateDataEvent,
    WorkerState,
)

if TYPE_CHECKING:
    # FIXME import from typing (needs Python >=3.10)
    from typing_extensions import ParamSpec

    # Circular imports
    from distributed.client import Client
    from distributed.nanny import Nanny

    P = ParamSpec("P")
    T = TypeVar("T")

logger = logging.getLogger(__name__)

LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")

DEFAULT_EXTENSIONS: dict[str, type] = {
    "pubsub": PubSubWorkerExtension,
    "spans": SpansWorkerExtension,
}

DEFAULT_METRICS: dict[str, Callable[[Worker], Any]] = {}

DEFAULT_STARTUP_INFORMATION: dict[str, Callable[[Worker], Any]] = {}

WORKER_ANY_RUNNING = {
    Status.running,
    Status.paused,
    Status.closing_gracefully,
}


class RunTaskSuccess(OKMessage):
    op: Literal["task-finished"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int


class RunTaskFailure(ErrorMessage):
    op: Literal["task-erred"]
    result: object
    nbytes: int
    type: type
    start: float
    stop: float
    thread: int
    actual_exception: BaseException | Exception


class GetDataBusy(TypedDict):
    status: Literal["busy"]


class GetDataSuccess(TypedDict):
    status: Literal["OK"]
    data: dict[Key, object]


def fail_hard(method: Callable[P, T]) -> Callable[P, T]:
    """
    Decorator to close the worker if this method encounters an exception.
    """
    reason = f"worker-{method.__name__}-fail-hard"
    if iscoroutinefunction(method):

        @wraps(method)
        async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> Any:
            try:
                return await method(self, *args, **kwargs)  # type: ignore
            except Exception as e:
                if self.status not in (Status.closed, Status.closing):
                    self.log_event("worker-fail-hard", error_message(e))
                    logger.exception(e)
                await _force_close(self, reason)
                raise

    else:

        @wraps(method)
        def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
            try:
                return method(self, *args, **kwargs)
            except Exception as e:
                if self.status not in (Status.closed, Status.closing):
                    self.log_event("worker-fail-hard", error_message(e))
                    logger.exception(e)
                self.loop.add_callback(_force_close, self, reason)
                raise

    return wrapper  # type: ignore


async def _force_close(self, reason: str):
    """
    Used with the fail_hard decorator defined above

    1.  Wait for a worker to close
    2.  If it doesn't, log and kill the process
    """
    try:
        await wait_for(
            self.close(nanny=False, executor_wait=False, reason=reason),
            30,
        )
    except (KeyboardInterrupt, SystemExit):  # pragma: nocover
        raise
    except BaseException:  # pragma: nocover
        # Worker is in a very broken state if closing fails. We need to shut down
        # immediately, to ensure things don't get even worse and this worker potentially
        # deadlocks the cluster.
        from distributed import Scheduler

        if Scheduler._instances:
            # We're likely in a unit test. Don't kill the whole test suite!
            raise

        logger.critical(
            "Error trying close worker in response to broken internal state. "
            "Forcibly exiting worker NOW",
            exc_info=True,
        )
        # use `os._exit` instead of `sys.exit` because of uncertainty
        # around propagating `SystemExit` from asyncio callbacks
        os._exit(1)


[docs]class Worker(BaseWorker, ServerNode): """Worker node in a Dask distributed cluster Workers perform two functions: 1. **Serve data** from a local dictionary 2. **Perform computation** on that data and on data from peers Workers keep the scheduler informed of their data and use that scheduler to gather data from other workers when necessary to perform a computation. You can start a worker with the ``dask worker`` command line application:: $ dask worker scheduler-ip:port Use the ``--help`` flag to see more options:: $ dask worker --help The rest of this docstring is about the internal state that the worker uses to manage and track internal computations. **State** **Informational State** These attributes don't change significantly during execution. * **nthreads:** ``int``: Number of nthreads used by this worker process * **executors:** ``dict[str, concurrent.futures.Executor]``: Executors used to perform computation. Always contains the default executor. * **local_directory:** ``path``: Path on local machine to store temporary files * **scheduler:** ``PooledRPCCall``: Location of scheduler. See ``.ip/.port`` attributes. * **name:** ``string``: Alias * **services:** ``{str: Server}``: Auxiliary web servers running on this worker * **service_ports:** ``{str: port}``: * **transfer_outgoing_count_limit**: ``int`` The maximum number of concurrent outgoing data transfers. See also :attr:`distributed.worker_state_machine.WorkerState.transfer_incoming_count_limit`. * **batched_stream**: ``BatchedSend`` A batched stream along which we communicate to the scheduler * **log**: ``[(message)]`` A structured and queryable log. See ``Worker.story`` **Volatile State** These attributes track the progress of tasks that this worker is trying to complete. In the descriptions below a ``key`` is the name of a task that we want to compute and ``dep`` is the name of a piece of dependent data that we want to collect from others. * **threads**: ``{key: int}`` The ID of the thread on which the task ran * **active_threads**: ``{int: key}`` The keys currently running on active threads * **state**: ``WorkerState`` Encapsulated state machine. See :class:`~distributed.worker_state_machine.BaseWorker` and :class:`~distributed.worker_state_machine.WorkerState` Parameters ---------- scheduler_ip: str, optional scheduler_port: int, optional scheduler_file: str, optional host: str, optional data: MutableMapping, type, None The object to use for storage, builds a disk-backed LRU dict by default. If a callable to construct the storage object is provided, it will receive the worker's attr:``local_directory`` as an argument if the calling signature has an argument named ``worker_local_directory``. nthreads: int, optional local_directory: str, optional Directory where we place local resources name: str, optional memory_limit: int, float, string Number of bytes of memory that this worker should use. Set to zero for no limit. Set to 'auto' to calculate as system.MEMORY_LIMIT * min(1, nthreads / total_cores) Use strings or numbers like 5GB or 5e9 memory_target_fraction: float or False Fraction of memory to try to stay beneath (default: read from config key distributed.worker.memory.target) memory_spill_fraction: float or False Fraction of memory at which we start spilling to disk (default: read from config key distributed.worker.memory.spill) memory_pause_fraction: float or False Fraction of memory at which we stop running new tasks (default: read from config key distributed.worker.memory.pause) max_spill: int, string or False Limit of number of bytes to be spilled on disk. (default: read from config key distributed.worker.memory.max-spill) executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload" The executor(s) to use. Depending on the type, it has the following meanings: - Executor instance: The default executor. - Dict[str, Executor]: mapping names to Executor instances. If the "default" key isn't in the dict, a "default" executor will be created using ``ThreadPoolExecutor(nthreads)``. - Str: The string "offload", which refer to the same thread pool used for offloading communications. This results in the same thread being used for deserialization and computation. resources: dict Resources that this worker has like ``{'GPU': 2}`` nanny: str Address on which to contact nanny, if it exists lifetime: str Amount of time like "1 hour" after which we gracefully shut down the worker. This defaults to None, meaning no explicit shutdown time. lifetime_stagger: str Amount of time like "5 minutes" to stagger the lifetime value The actual lifetime will be selected uniformly at random between lifetime +/- lifetime_stagger lifetime_restart: bool Whether or not to restart a worker after it has reached its lifetime Default False kwargs: optional Additional parameters to ServerNode constructor Examples -------- Use the command line to start a worker:: $ dask scheduler Start scheduler at 127.0.0.1:8786 $ dask worker 127.0.0.1:8786 Start worker at: 127.0.0.1:1234 Registered with scheduler at: 127.0.0.1:8786 See Also -------- distributed.scheduler.Scheduler distributed.nanny.Nanny """ _instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet() _initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet() nanny: Nanny | None _lock: threading.Lock transfer_outgoing_count_limit: int threads: dict[Key, int] # {ts.key: thread ID} active_threads_lock: threading.Lock active_threads: dict[int, Key] # {thread ID: ts.key} active_keys: set[Key] profile_keys: defaultdict[str, dict[str, Any]] profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]] profile_recent: dict[str, Any] profile_history: deque[tuple[float, dict[str, Any]]] transfer_incoming_log: deque[dict[str, Any]] transfer_outgoing_log: deque[dict[str, Any]] #: Total number of data transfers to other workers since the worker was started transfer_outgoing_count_total: int #: Total size of data transfers to other workers (including in-progress and failed transfers) transfer_outgoing_bytes_total: int #: Current total size of open data transfers to other workers transfer_outgoing_bytes: int #: Current number of open data transfers to other workers transfer_outgoing_count: int bandwidth: float latency: float profile_cycle_interval: float workspace: WorkSpace _client: Client | None bandwidth_workers: defaultdict[str, tuple[float, int]] bandwidth_types: defaultdict[type, tuple[float, int]] preloads: preloading.PreloadManager contact_address: str | None _start_port: int | str | Collection[int] | None = None _start_host: str | None _interface: str | None _protocol: str _dashboard_address: str | None _dashboard: bool _http_prefix: str death_timeout: float | None lifetime: float | None lifetime_stagger: float | None lifetime_restart: bool extensions: dict security: Security connection_args: dict[str, Any] loop: IOLoop executors: dict[str, Executor] batched_stream: BatchedSend name: Any scheduler_delay: float stream_comms: dict[str, BatchedSend] heartbeat_interval: float services: dict[str, Any] = {} service_specs: dict[str, Any] metrics: dict[str, Callable[[Worker], Any]] startup_information: dict[str, Callable[[Worker], Any]] low_level_profiler: bool scheduler: PooledRPCCall execution_state: dict[str, Any] plugins: dict[str, WorkerPlugin] _pending_plugins: tuple[WorkerPlugin, ...] def __init__( self, scheduler_ip: str | None = None, scheduler_port: int | None = None, *, scheduler_file: str | None = None, nthreads: int | None = None, loop: IOLoop | None = None, # Deprecated local_directory: str | None = None, services: dict | None = None, name: Any | None = None, reconnect: bool | None = None, executor: Executor | dict[str, Executor] | Literal["offload"] | None = None, resources: dict[str, float] | None = None, silence_logs: int | None = None, death_timeout: Any | None = None, preload: list[str] | None = None, preload_argv: list[str] | list[list[str]] | None = None, security: Security | dict[str, Any] | None = None, contact_address: str | None = None, heartbeat_interval: Any = "1s", extensions: dict[str, type] | None = None, metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS, startup_information: Mapping[ str, Callable[[Worker], Any] ] = DEFAULT_STARTUP_INFORMATION, interface: str | None = None, host: str | None = None, port: int | str | Collection[int] | None = None, protocol: str | None = None, dashboard_address: str | None = None, dashboard: bool = False, http_prefix: str = "/", nanny: Nanny | None = None, plugins: tuple[WorkerPlugin, ...] = (), low_level_profiler: bool | None = None, validate: bool | None = None, profile_cycle_interval=None, lifetime: Any | None = None, lifetime_stagger: Any | None = None, lifetime_restart: bool | None = None, transition_counter_max: int | Literal[False] = False, ################################### # Parameters to WorkerMemoryManager memory_limit: str | float = "auto", # Allow overriding the dict-like that stores the task outputs. # This is meant for power users only. See WorkerMemoryManager for details. data: WorkerDataParameter = None, # Deprecated parameters; please use dask config instead. memory_target_fraction: float | Literal[False] | None = None, memory_spill_fraction: float | Literal[False] | None = None, memory_pause_fraction: float | Literal[False] | None = None, ################################### # Parameters to Server scheduler_sni: str | None = None, WorkerStateClass: type = WorkerState, **kwargs, ): if reconnect is not None: if reconnect: raise ValueError( "The `reconnect=True` option for `Worker` has been removed. " "To improve cluster stability, workers now always shut down in the face of network disconnects. " "For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350." ) else: warnings.warn( "The `reconnect` argument to `Worker` is deprecated, and will be removed in a future release. " "Worker reconnection is now always disabled, so passing `reconnect=False` is unnecessary. " "See https://github.com/dask/distributed/issues/6350 for details.", DeprecationWarning, stacklevel=2, ) if loop is not None: warnings.warn( "The `loop` argument to `Worker` is ignored, and will be removed in a future release. " "The Worker always binds to the current loop", DeprecationWarning, stacklevel=2, ) self.__exit_stack = stack = contextlib.ExitStack() self.nanny = nanny self._lock = threading.Lock() transfer_incoming_count_limit = dask.config.get( "distributed.worker.connections.outgoing" ) self.transfer_outgoing_count_limit = dask.config.get( "distributed.worker.connections.incoming" ) transfer_message_bytes_limit = parse_bytes( dask.config.get("distributed.worker.transfer.message-bytes-limit") ) self.threads = {} self.active_threads_lock = threading.Lock() self.active_threads = {} self.active_keys = set() self.profile_keys = defaultdict(profile.create) maxlen = dask.config.get("distributed.admin.low-level-log-length") self.profile_keys_history = deque(maxlen=maxlen) self.profile_history = deque(maxlen=maxlen) self.profile_recent = profile.create() if validate is None: validate = dask.config.get("distributed.worker.validate") self.transfer_incoming_log = deque(maxlen=maxlen) self.transfer_outgoing_log = deque(maxlen=maxlen) self.transfer_outgoing_count_total = 0 self.transfer_outgoing_bytes_total = 0 self.transfer_outgoing_bytes = 0 self.transfer_outgoing_count = 0 self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) self.bandwidth_workers = defaultdict( lambda: (0, 0) ) # bw/count recent transfers self.bandwidth_types = defaultdict(lambda: (0, 0)) # bw/count recent transfers self.latency = 0.001 self._client = None if profile_cycle_interval is None: profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle") profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms") assert profile_cycle_interval self._setup_logging(logger) self.death_timeout = parse_timedelta(death_timeout) self.contact_address = contact_address self._start_port = port self._start_host = host if host: # Helpful error message if IPv6 specified incorrectly _, host_address = parse_address(host) if host_address.count(":") > 1 and not host_address.startswith("["): raise ValueError( "Host address with IPv6 must be bracketed like '[::1]'; " f"got {host_address}" ) self._interface = interface nthreads = nthreads or CPU_COUNT if resources is None: resources = dask.config.get("distributed.worker.resources") assert isinstance(resources, dict) self.extensions = {} if silence_logs: stack.enter_context(silence_logging_cmgr(level=silence_logs)) if isinstance(security, dict): security = Security(**security) self.security = security or Security() assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args("worker") self.loop = self.io_loop = IOLoop.current() if scheduler_sni: self.connection_args["server_hostname"] = scheduler_sni # Common executors always available self.executors = { "offload": utils._offload_executor, "actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"), } # Find the default executor if executor == "offload": self.executors["default"] = self.executors["offload"] elif isinstance(executor, dict): self.executors.update(executor) elif executor is not None: self.executors["default"] = executor if "default" not in self.executors: self.executors["default"] = ThreadPoolExecutor( nthreads, thread_name_prefix="Dask-Default-Threads" ) self.batched_stream = BatchedSend(interval="2ms", loop=self.loop) self.name = name self.scheduler_delay = 0 self.stream_comms = {} self.plugins = {} self._pending_plugins = plugins self.services = {} self.service_specs = services or {} self._dashboard_address = dashboard_address self._dashboard = dashboard self._http_prefix = http_prefix self.metrics = dict(metrics) if metrics else {} self.startup_information = ( dict(startup_information) if startup_information else {} ) if low_level_profiler is None: low_level_profiler = dask.config.get("distributed.worker.profile.low-level") self.low_level_profiler = low_level_profiler handlers = { "gather": self.gather, "run": self.run, "run_coroutine": self.run_coroutine, "get_data": self.get_data, "update_data": self.update_data, "free_keys": self._handle_remote_stimulus(FreeKeysEvent), "terminate": self.close, "ping": pingpong, "upload_file": self.upload_file, "call_stack": self.get_call_stack, "profile": self.get_profile, "profile_metadata": self.get_profile_metadata, "get_logs": self.get_logs, "keys": self.keys, "versions": self.versions, "actor_execute": self.actor_execute, "actor_attribute": self.actor_attribute, "plugin-add": self.plugin_add, "plugin-remove": self.plugin_remove, "get_monitor_info": self.get_monitor_info, "benchmark_disk": self.benchmark_disk, "benchmark_memory": self.benchmark_memory, "benchmark_network": self.benchmark_network, "get_story": self.get_story, } stream_handlers = { "close": self.close, "cancel-compute": self._handle_remote_stimulus(CancelComputeEvent), "acquire-replicas": self._handle_remote_stimulus(AcquireReplicasEvent), "compute-task": self._handle_remote_stimulus(ComputeTaskEvent), "free-keys": self._handle_remote_stimulus(FreeKeysEvent), "remove-replicas": self._handle_remote_stimulus(RemoveReplicasEvent), "steal-request": self._handle_remote_stimulus(StealRequestEvent), "refresh-who-has": self._handle_remote_stimulus(RefreshWhoHasEvent), "worker-status-change": self.handle_worker_status_change, "remove-worker": self._handle_remove_worker, } ServerNode.__init__( self, handlers=handlers, stream_handlers=stream_handlers, connection_args=self.connection_args, local_directory=local_directory, **kwargs, ) if not preload: preload = dask.config.get("distributed.worker.preload") if not preload_argv: preload_argv = dask.config.get("distributed.worker.preload-argv") assert preload is not None assert preload_argv is not None self.preloads = preloading.process_preloads( self, preload, preload_argv, file_dir=self.local_directory ) if scheduler_file: cfg = json_load_robust(scheduler_file, timeout=self.death_timeout) scheduler_addr = cfg["address"] elif scheduler_ip is None and dask.config.get("scheduler-address", None): scheduler_addr = dask.config.get("scheduler-address") elif scheduler_port is None: scheduler_addr = coerce_to_address(scheduler_ip) else: scheduler_addr = coerce_to_address((scheduler_ip, scheduler_port)) if protocol is None: protocol_address = scheduler_addr.split("://") if len(protocol_address) == 2: protocol = protocol_address[0] assert protocol self._protocol = protocol self.memory_manager = WorkerMemoryManager( self, data=data, nthreads=nthreads, memory_limit=memory_limit, memory_target_fraction=memory_target_fraction, memory_spill_fraction=memory_spill_fraction, memory_pause_fraction=memory_pause_fraction, ) transfer_incoming_bytes_limit = math.inf transfer_incoming_bytes_fraction = dask.config.get( "distributed.worker.memory.transfer" ) if ( self.memory_manager.memory_limit is not None and transfer_incoming_bytes_fraction is not False ): transfer_incoming_bytes_limit = int( self.memory_manager.memory_limit * transfer_incoming_bytes_fraction ) state = WorkerStateClass( nthreads=nthreads, data=self.memory_manager.data, threads=self.threads, plugins=self.plugins, resources=resources, transfer_incoming_count_limit=transfer_incoming_count_limit, validate=validate, transition_counter_max=transition_counter_max, transfer_incoming_bytes_limit=transfer_incoming_bytes_limit, transfer_message_bytes_limit=transfer_message_bytes_limit, ) BaseWorker.__init__(self, state) self.scheduler = self.rpc(scheduler_addr) self.execution_state = { "scheduler": self.scheduler.address, "ioloop": self.loop, "worker": self, } self.heartbeat_interval = parse_timedelta(heartbeat_interval, default="ms") pc = PeriodicCallback(self.heartbeat, self.heartbeat_interval * 1000) self.periodic_callbacks["heartbeat"] = pc pc = PeriodicCallback(lambda: self.batched_send({"op": "keep-alive"}), 60000) self.periodic_callbacks["keep-alive"] = pc pc = PeriodicCallback(self.find_missing, 1000) self.periodic_callbacks["find-missing"] = pc self._address = contact_address if extensions is None: extensions = DEFAULT_EXTENSIONS self.extensions = { name: extension(self) for name, extension in extensions.items() } setproctitle("dask worker [not started]") if dask.config.get("distributed.worker.profile.enabled"): profile_trigger_interval = parse_timedelta( dask.config.get("distributed.worker.profile.interval"), default="ms" ) pc = PeriodicCallback(self.trigger_profile, profile_trigger_interval * 1000) self.periodic_callbacks["profile"] = pc pc = PeriodicCallback(self.cycle_profile, profile_cycle_interval * 1000) self.periodic_callbacks["profile-cycle"] = pc if lifetime is None: lifetime = dask.config.get("distributed.worker.lifetime.duration") lifetime = parse_timedelta(lifetime) if lifetime_stagger is None: lifetime_stagger = dask.config.get("distributed.worker.lifetime.stagger") lifetime_stagger = parse_timedelta(lifetime_stagger) if lifetime_restart is None: lifetime_restart = dask.config.get("distributed.worker.lifetime.restart") self.lifetime_restart = lifetime_restart if lifetime: lifetime += (random.random() * 2 - 1) * lifetime_stagger self.io_loop.call_later( lifetime, self.close_gracefully, reason="worker-lifetime-reached" ) self.lifetime = lifetime Worker._instances.add(self) ################ # Memory manager ################ memory_manager: WorkerMemoryManager @property def data(self) -> MutableMapping[Key, object]: """{task key: task payload} of all completed tasks, whether they were computed on this Worker or computed somewhere else and then transferred here over the network. When using the default configuration, this is a zict buffer that automatically spills to disk whenever the target threshold is exceeded. If spilling is disabled, it is a plain dict instead. It could also be a user-defined arbitrary dict-like passed when initialising the Worker or the Nanny. Worker logic should treat this opaquely and stick to the MutableMapping API. .. note:: This same collection is also available at ``self.state.data`` and ``self.memory_manager.data``. """ return self.memory_manager.data # Deprecated attributes moved to self.memory_manager.<name> memory_limit = DeprecatedMemoryManagerAttribute() memory_target_fraction = DeprecatedMemoryManagerAttribute() memory_spill_fraction = DeprecatedMemoryManagerAttribute() memory_pause_fraction = DeprecatedMemoryManagerAttribute() memory_monitor = DeprecatedMemoryMonitor() ########################### # State machine accessors # ########################### # Deprecated attributes moved to self.state.<name> actors = DeprecatedWorkerStateAttribute() available_resources = DeprecatedWorkerStateAttribute() busy_workers = DeprecatedWorkerStateAttribute() comm_nbytes = DeprecatedWorkerStateAttribute(target="transfer_incoming_bytes") comm_threshold_bytes = DeprecatedWorkerStateAttribute( target="transfer_incoming_bytes_throttle_threshold" ) constrained = DeprecatedWorkerStateAttribute() data_needed_per_worker = DeprecatedWorkerStateAttribute(target="data_needed") executed_count = DeprecatedWorkerStateAttribute() executing_count = DeprecatedWorkerStateAttribute() generation = DeprecatedWorkerStateAttribute() has_what = DeprecatedWorkerStateAttribute() incoming_count = DeprecatedWorkerStateAttribute( target="transfer_incoming_count_total" ) in_flight_tasks = DeprecatedWorkerStateAttribute(target="in_flight_tasks_count") in_flight_workers = DeprecatedWorkerStateAttribute() log = DeprecatedWorkerStateAttribute() long_running = DeprecatedWorkerStateAttribute() nthreads = DeprecatedWorkerStateAttribute() stimulus_log = DeprecatedWorkerStateAttribute() stimulus_story = DeprecatedWorkerStateAttribute() story = DeprecatedWorkerStateAttribute() ready = DeprecatedWorkerStateAttribute() tasks = DeprecatedWorkerStateAttribute() target_message_size = DeprecatedWorkerStateAttribute( target="transfer_message_bytes_limit" ) total_out_connections = DeprecatedWorkerStateAttribute( target="transfer_incoming_count_limit" ) total_resources = DeprecatedWorkerStateAttribute() transition_counter = DeprecatedWorkerStateAttribute() transition_counter_max = DeprecatedWorkerStateAttribute() validate = DeprecatedWorkerStateAttribute() validate_task = DeprecatedWorkerStateAttribute() @property def data_needed(self) -> set[TaskState]: warnings.warn( "The `Worker.data_needed` attribute has been removed; " "use `Worker.state.data_needed[address]`", FutureWarning, ) return {ts for tss in self.state.data_needed.values() for ts in tss} @property def waiting_for_data_count(self) -> int: warnings.warn( "The `Worker.waiting_for_data_count` attribute has been removed; " "use `len(Worker.state.waiting)`", FutureWarning, ) return len(self.state.waiting) ################## # Administrative # ################## def __repr__(self) -> str: name = f", name: {self.name}" if self.name != self.address_safe else "" return ( f"<{self.__class__.__name__} {self.address_safe!r}{name}, " f"status: {self.status.name}, " f"stored: {len(self.data)}, " f"running: {self.state.executing_count}/{self.state.nthreads}, " f"ready: {len(self.state.ready)}, " f"comm: {self.state.in_flight_tasks_count}, " f"waiting: {len(self.state.waiting)}>" ) @property def logs(self): return self._deque_handler.deque
[docs] def log_event(self, topic: str | Collection[str], msg: Any) -> None: """Log an event under a given topic Parameters ---------- topic : str, list[str] Name of the topic under which to log an event. To log the same event under multiple topics, pass a list of topic names. msg Event message to log. Note this must be msgpack serializable. See also -------- Client.log_event """ if not _is_dumpable(msg): raise TypeError( f"Message must be msgpack serializable. Got {type(msg)=} instead." ) full_msg = { "op": "log-event", "topic": topic, "msg": msg, } if self.thread_id == threading.get_ident(): self.batched_send(full_msg) else: self.loop.add_callback(self.batched_send, full_msg)
@property def worker_address(self): """For API compatibility with Nanny""" return self.address @property def executor(self): return self.executors["default"] @ServerNode.status.setter # type: ignore def status(self, value: Status) -> None: """Override Server.status to notify the Scheduler of status changes. Also handles pausing/unpausing. """ prev_status = self.status ServerNode.status.__set__(self, value) # type: ignore stimulus_id = f"worker-status-change-{time()}" self._send_worker_status_change(stimulus_id) if prev_status == Status.running and value != Status.running: self.handle_stimulus(PauseEvent(stimulus_id=stimulus_id)) elif value == Status.running and prev_status in ( Status.paused, Status.closing_gracefully, ): self.handle_stimulus(UnpauseEvent(stimulus_id=stimulus_id)) def _send_worker_status_change(self, stimulus_id: str) -> None: self.batched_send( { "op": "worker-status-change", "status": self._status.name, "stimulus_id": stimulus_id, }, ) async def get_metrics(self) -> dict: try: spilled_memory, spilled_disk = self.data.spilled_total # type: ignore except AttributeError: # spilling is disabled spilled_memory, spilled_disk = 0, 0 # Send Fine Performance Metrics # Swap the dictionary to avoid updates while we iterate over it digests_total_since_heartbeat = self.digests_total_since_heartbeat self.digests_total_since_heartbeat = defaultdict(int) spans_ext: SpansWorkerExtension | None = self.extensions.get("spans") if spans_ext: # Send metrics with disaggregated span_id spans_ext.collect_digests(digests_total_since_heartbeat) # Send metrics with squashed span_id # Don't cast int metrics to float digests: defaultdict[Hashable, float] = defaultdict(int) for k, v in digests_total_since_heartbeat.items(): if isinstance(k, tuple) and k[0] in CONTEXTS_WITH_SPAN_ID: k = k[:1] + k[2:] digests[k] += v out: dict = dict( task_counts=self.state.task_counter.current_count(by_prefix=False), bandwidth={ "total": self.bandwidth, "workers": dict(self.bandwidth_workers), "types": keymap(typename, self.bandwidth_types), }, digests_total_since_heartbeat=dict(digests), managed_bytes=self.state.nbytes, spilled_bytes={ "memory": spilled_memory, "disk": spilled_disk, }, transfer={ "incoming_bytes": self.state.transfer_incoming_bytes, "incoming_count": self.state.transfer_incoming_count, "incoming_count_total": self.state.transfer_incoming_count_total, "outgoing_bytes": self.transfer_outgoing_bytes, "outgoing_count": self.transfer_outgoing_count, "outgoing_count_total": self.transfer_outgoing_count_total, }, event_loop_interval=self._tick_interval_observed, ) monitor_recent = self.monitor.recent() # Convert {foo.bar: 123} to {foo: {bar: 123}} for k, v in monitor_recent.items(): if "." in k: k0, _, k1 = k.partition(".") out.setdefault(k0, {})[k1] = v else: out[k] = v for k, metric in self.metrics.items(): try: result = metric(self) if isawaitable(result): result = await result # In case of collision, prefer core metrics out.setdefault(k, result) except Exception: # TODO: log error once pass return out async def get_startup_information(self): result = {} for k, f in self.startup_information.items(): try: v = f(self) if isawaitable(v): v = await v result[k] = v except Exception: # TODO: log error once pass return result def identity(self): return { "type": type(self).__name__, "id": self.id, "scheduler": self.scheduler.address, "nthreads": self.state.nthreads, "memory_limit": self.memory_manager.memory_limit, } def _to_dict(self, *, exclude: Container[str] = ()) -> dict: """Dictionary representation for debugging purposes. Not type stable and not intended for roundtrips. See also -------- Worker.identity Client.dump_cluster_state distributed.utils.recursive_to_dict """ info = super()._to_dict(exclude=exclude) extra = { "status": self.status, "logs": self.get_logs(), "config": dask.config.config, "transfer_incoming_log": self.transfer_incoming_log, "transfer_outgoing_log": self.transfer_outgoing_log, } extra = {k: v for k, v in extra.items() if k not in exclude} info.update(extra) info.update(self.state._to_dict(exclude=exclude)) info.update(self.memory_manager._to_dict(exclude=exclude)) return recursive_to_dict(info, exclude=exclude) ##################### # External Services # #####################
[docs] def batched_send(self, msg: dict[str, Any]) -> None: """Implements BaseWorker abstract method. Send a fire-and-forget message to the scheduler through bulk comms. If we're not currently connected to the scheduler, the message will be silently dropped! See also -------- distributed.worker_state_machine.BaseWorker.batched_send """ if ( self.batched_stream and self.batched_stream.comm and not self.batched_stream.comm.closed() ): self.batched_stream.send(msg)
async def _register_with_scheduler(self) -> None: self.periodic_callbacks["keep-alive"].stop() self.periodic_callbacks["heartbeat"].stop() start = time() if self.contact_address is None: self.contact_address = self.address logger.info("-" * 49) # Worker reconnection is not supported assert not self.data assert not self.state.tasks while True: try: _start = time() comm = await connect(self.scheduler.address, **self.connection_args) comm.name = "Worker->Scheduler" comm._server = weakref.ref(self) await comm.write( dict( op="register-worker", reply=False, address=self.contact_address, status=self.status.name, nthreads=self.state.nthreads, name=self.name, now=time(), resources=self.state.total_resources, memory_limit=self.memory_manager.memory_limit, local_directory=self.local_directory, services=self.service_ports, nanny=self.nanny, pid=os.getpid(), versions=get_versions(), metrics=await self.get_metrics(), extra=await self.get_startup_information(), stimulus_id=f"worker-connect-{time()}", server_id=self.id, ), serializers=["msgpack"], ) future = comm.read(deserializers=["msgpack"]) response = await future if response.get("warning"): logger.warning(response["warning"]) _end = time() middle = (_start + _end) / 2 self._update_latency(_end - start) self.scheduler_delay = response["time"] - middle break except OSError: logger.info("Waiting to connect to: %26s", self.scheduler.address) await asyncio.sleep(0.1) except TimeoutError: # pragma: no cover logger.info("Timed out when connecting to scheduler") if response["status"] != "OK": await comm.close() msg = response["message"] if "message" in response else repr(response) logger.error(f"Unable to connect to scheduler: {msg}") raise ValueError(f"Unexpected response from register: {response!r}") self.batched_stream.start(comm) self.status = Status.running await asyncio.gather( *( self.plugin_add(name=name, plugin=plugin) for name, plugin in response["worker-plugins"].items() ), ) logger.info(" Registered to: %26s", self.scheduler.address) logger.info("-" * 49) self.periodic_callbacks["keep-alive"].start() self.periodic_callbacks["heartbeat"].start() self.loop.add_callback(self.handle_scheduler, comm) def _update_latency(self, latency: float) -> None: self.latency = latency * 0.05 + self.latency * 0.95 self.digest_metric("latency", latency) async def heartbeat(self) -> None: logger.debug("Heartbeat: %s", self.address) try: start = time() response = await retry_operation( self.scheduler.heartbeat_worker, address=self.contact_address, now=start, metrics=await self.get_metrics(), executing={ key: start - cast(float, self.state.tasks[key].start_time) for key in self.active_keys if key in self.state.tasks }, extensions={ name: extension.heartbeat() for name, extension in self.extensions.items() if hasattr(extension, "heartbeat") }, ) end = time() middle = (start + end) / 2 self._update_latency(end - start) if response["status"] == "missing": # Scheduler thought we left. # This is a common race condition when the scheduler calls # remove_worker(); there can be a heartbeat between when the scheduler # removes the worker on its side and when the {"op": "close"} command # arrives through batched comms to the worker. logger.warning("Scheduler was unaware of this worker; shutting down.") # We close here just for safety's sake - the {op: close} should # arrive soon anyway. await self.close(reason="worker-heartbeat-missing") return self.scheduler_delay = response["time"] - middle self.periodic_callbacks["heartbeat"].callback_time = ( response["heartbeat-interval"] * 1000 ) self.bandwidth_workers.clear() self.bandwidth_types.clear() except OSError: logger.exception("Failed to communicate with scheduler during heartbeat.") except Exception: logger.exception("Unexpected exception during heartbeat. Closing worker.") await self.close(reason="worker-heartbeat-error") raise @fail_hard async def handle_scheduler(self, comm: Comm) -> None: try: await self.handle_stream(comm) finally: await self.close(reason="worker-handle-scheduler-connection-broken") def keys(self) -> list[Key]: return list(self.data)
[docs] async def gather(self, who_has: dict[Key, list[str]]) -> dict[Key, object]: """Endpoint used by Scheduler.rebalance() and Scheduler.replicate()""" missing_keys = [k for k in who_has if k not in self.data] failed_keys = [] missing_workers: set[str] = set() stimulus_id = f"gather-{time()}" while missing_keys: to_gather = {} for k in missing_keys: workers = set(who_has[k]) - missing_workers if workers: to_gather[k] = workers else: failed_keys.append(k) if not to_gather: break ( data, missing_keys, new_failed_keys, new_missing_workers, ) = await gather_from_workers( who_has=to_gather, rpc=self.rpc, who=self.address ) self.update_data(data, stimulus_id=stimulus_id) del data failed_keys += new_failed_keys missing_workers.update(new_missing_workers) if missing_keys: who_has = await retry_operation( self.scheduler.who_has, keys=missing_keys ) if failed_keys: logger.error("Could not find data: %s", failed_keys) return {"status": "partial-fail", "keys": list(failed_keys)} else: return {"status": "OK"}
def get_monitor_info(self, recent: bool = False, start: int = 0) -> dict[str, Any]: result = dict( range_query=( self.monitor.recent() if recent else self.monitor.range_query(start=start) ), count=self.monitor.count, last_time=self.monitor.last_time, ) if nvml.device_get_count() > 0: result["gpu_name"] = self.monitor.gpu_name result["gpu_memory_total"] = self.monitor.gpu_memory_total return result ############# # Lifecycle # #############
[docs] async def start_unsafe(self): await super().start_unsafe() enable_gc_diagnosis() ports = parse_ports(self._start_port) for port in ports: start_address = address_from_user_args( host=self._start_host, port=port, interface=self._interface, protocol=self._protocol, security=self.security, ) kwargs = self.security.get_listen_args("worker") if self._protocol in ("tcp", "tls"): kwargs = kwargs.copy() kwargs["default_host"] = get_ip( get_address_host(self.scheduler.address) ) try: await self.listen(start_address, **kwargs) except OSError as e: if len(ports) > 1 and e.errno == errno.EADDRINUSE: continue else: raise else: self._start_address = start_address break else: raise ValueError( f"Could not start Worker on host {self._start_host} " f"with port {self._start_port}" ) # Start HTTP server associated with this Worker node routes = get_handlers( server=self, modules=dask.config.get("distributed.worker.http.routes"), prefix=self._http_prefix, ) self.start_http_server(routes, self._dashboard_address) if self._dashboard: try: import distributed.dashboard.worker except ImportError: logger.debug("To start diagnostics web server please install Bokeh") else: distributed.dashboard.worker.connect( self.http_application, self.http_server, self, prefix=self._http_prefix, ) self.ip = get_address_host(self.address) if self.name is None: self.name = self.address await self.preloads.start() # Services listen on all addresses # Note Nanny is not a "real" service, just some metadata # passed in service_ports... self.start_services(self.ip) try: listening_address = "%s%s:%d" % (self.listener.prefix, self.ip, self.port) except Exception: listening_address = f"{self.listener.prefix}{self.ip}" logger.info(" Start worker at: %26s", self.address) logger.info(" Listening to: %26s", listening_address) if self.name != self.address_safe: # only if name was not None logger.info(" Worker name: %26s", self.name) for k, v in self.service_ports.items(): logger.info(" {:>16} at: {:>26}".format(k, self.ip + ":" + str(v))) logger.info("Waiting to connect to: %26s", self.scheduler.address) logger.info("-" * 49) logger.info(" Threads: %26d", self.state.nthreads) if self.memory_manager.memory_limit: logger.info( " Memory: %26s", format_bytes(self.memory_manager.memory_limit), ) logger.info(" Local Directory: %26s", self.local_directory) setproctitle("dask worker [%s]" % self.address) plugins_msgs = await asyncio.gather( *( self.plugin_add(plugin=plugin, catch_errors=False) for plugin in self._pending_plugins ), return_exceptions=True, ) plugins_exceptions = [msg for msg in plugins_msgs if isinstance(msg, Exception)] if len(plugins_exceptions) >= 1: if len(plugins_exceptions) > 1: logger.error( "Multiple plugin exceptions raised. All exceptions will be logged, the first is raised." ) for exc in plugins_exceptions: logger.error(repr(exc)) raise plugins_exceptions[0] self._pending_plugins = () self.state.address = self.address await self._register_with_scheduler() self.start_periodic_callbacks() return self
[docs] @log_errors async def close( # type: ignore self, timeout: float = 30, executor_wait: bool = True, nanny: bool = True, reason: str = "worker-close", ) -> str | None: """Close the worker Close asynchronous operations running on the worker, stop all executors and comms. If requested, this also closes the nanny. Parameters ---------- timeout Timeout in seconds for shutting down individual instructions executor_wait If True, shut down executors synchronously, otherwise asynchronously nanny If True, close the nanny reason Reason for closing the worker Returns ------- str | None None if worker already in closing state or failed, "OK" otherwise """ # FIXME: The worker should not be allowed to close the nanny. Ownership # is the other way round. If an external caller wants to close # nanny+worker, the nanny must be notified first. ==> Remove kwarg # nanny, see also Scheduler.retire_workers if self.status in (Status.closed, Status.closing, Status.failed): logger.debug( "Attempted to close worker that is already %s. Reason: %s", self.status, reason, ) await self.finished() return None if self.status == Status.init: # If the worker is still in startup/init and is started by a nanny, # this means the nanny itself is not up, yet. If the Nanny isn't up, # yet, it's server will not accept any incoming RPC requests and # will block until the startup is finished. # Therefore, this worker trying to communicate with the Nanny during # startup is not possible and we cannot close it. # In this case, the Nanny will automatically close after inspecting # the worker status nanny = False disable_gc_diagnosis() try: self.log_event(self.address, {"action": "closing-worker", "reason": reason}) except Exception: # This can happen when the Server is not up yet logger.exception("Failed to log closing event") try: logger.info("Stopping worker at %s. Reason: %s", self.address, reason) except ValueError: # address not available if already closed logger.info("Stopping worker. Reason: %s", reason) if self.status not in WORKER_ANY_RUNNING: logger.info("Closed worker has not yet started: %s", self.status) if not executor_wait: logger.info("Not waiting on executor to close") # This also informs the scheduler about the status update self.status = Status.closing setproctitle("dask worker [closing]") if nanny and self.nanny: with self.rpc(self.nanny) as r: await r.close_gracefully(reason=reason) # Cancel async instructions await BaseWorker.close(self, timeout=timeout) await asyncio.gather(*(self.plugin_remove(name) for name in self.plugins)) for extension in self.extensions.values(): if hasattr(extension, "close"): result = extension.close() if isawaitable(result): await result self.stop_services() await self.preloads.teardown() for pc in self.periodic_callbacks.values(): pc.stop() if self._client: # If this worker is the last one alive, clean up the worker # initialized clients if not any( w for w in Worker._instances if w != self and w.status in WORKER_ANY_RUNNING ): for c in Worker._initialized_clients: # Regardless of what the client was initialized with # we'll require the result as a future. This is # necessary since the heuristics of asynchronous are not # reliable and we might deadlock here c._asynchronous = True if c.asynchronous: await c.close() else: # There is still the chance that even with us # telling the client to be async, itself will decide # otherwise c.close() await self._stop_listeners() await self.rpc.close() # Give some time for a UCX scheduler to complete closing endpoints # before closing self.batched_stream, otherwise the local endpoint # may be closed too early and errors be raised on the scheduler when # trying to send closing message. if self._protocol == "ucx": # pragma: no cover await asyncio.sleep(0.2) self.batched_send({"op": "close-stream"}) if self.batched_stream: with suppress(TimeoutError): await self.batched_stream.close(timedelta(seconds=timeout)) for executor in self.executors.values(): if executor is utils._offload_executor: continue # Never shutdown the offload executor def _close(executor, wait): if isinstance(executor, ThreadPoolExecutor): executor._work_queue.queue.clear() executor.shutdown(wait=wait, timeout=timeout) else: executor.shutdown(wait=wait) # Waiting for the shutdown can block the event loop causing # weird deadlocks particularly if the task that is executing in # the thread is waiting for a server reply, e.g. when using # worker clients, semaphores, etc. # Are we shutting down the process? if self._is_finalizing() or not threading.main_thread().is_alive(): # If we're shutting down there is no need to wait for daemon # threads to finish _close(executor=executor, wait=False) else: try: await asyncio.to_thread( _close, executor=executor, wait=executor_wait ) except RuntimeError: logger.error( "Could not close executor %r by dispatching to thread. Trying synchronously.", executor, exc_info=True, ) _close( executor=executor, wait=executor_wait ) # Just run it directly self.stop() self.status = Status.closed setproctitle("dask worker [closed]") await ServerNode.close(self) self.__exit_stack.__exit__(None, None, None) return "OK"
[docs] async def close_gracefully( self, restart=None, reason: str = "worker-close-gracefully" ): """Gracefully shut down a worker This first informs the scheduler that we're shutting down, and asks it to move our data elsewhere. Afterwards, we close as normal """ if self.status in (Status.closing, Status.closing_gracefully): await self.finished() if self.status == Status.closed: return logger.info("Closing worker gracefully: %s. Reason: %s", self.address, reason) # Wait for all tasks to leave the worker and don't accept any new ones. # Scheduler.retire_workers will set the status to closing_gracefully and push it # back to this worker. await self.scheduler.retire_workers( workers=[self.address], close_workers=False, remove=True, stimulus_id=f"worker-close-gracefully-{time()}", ) if restart is None: restart = self.lifetime_restart await self.close(nanny=not restart, reason=reason)
async def wait_until_closed(self): warnings.warn("wait_until_closed has moved to finished()") await self.finished() assert self.status == Status.closed ################ # Worker Peers # ################ def send_to_worker(self, address, msg): if address not in self.stream_comms: bcomm = BatchedSend(interval="1ms", loop=self.loop) self.stream_comms[address] = bcomm async def batched_send_connect(): comm = await connect( address, **self.connection_args # TODO, serialization ) comm.name = "Worker->Worker" await comm.write({"op": "connection_stream"}) bcomm.start(comm) self._ongoing_background_tasks.call_soon(batched_send_connect) self.stream_comms[address].send(msg) @context_meter_to_server_digest("get-data") async def get_data( self, comm: Comm, keys: Collection[str], who: str | None = None, serializers: list[str] | None = None, ) -> GetDataBusy | Literal[Status.dont_reply]: max_connections = self.transfer_outgoing_count_limit # Allow same-host connections more liberally if get_address_host(comm.peer_address) == get_address_host(self.address): max_connections = max_connections * 2 if self.status == Status.paused: max_connections = 1 throttle_msg = ( " Throttling outgoing data transfers because worker is paused." ) else: throttle_msg = "" if ( max_connections is not False and self.transfer_outgoing_count >= max_connections ): logger.debug( "Worker %s has too many open connections to respond to data request " "from %s (%d/%d).%s", self.address, who, self.transfer_outgoing_count, max_connections, throttle_msg, ) return {"status": "busy"} self.transfer_outgoing_count += 1 self.transfer_outgoing_count_total += 1 # This may potentially take many seconds if it involves unspilling data = {k: self.data[k] for k in keys if k in self.data} if len(data) < len(keys): for k in set(keys) - data.keys(): if k in self.state.actors: from distributed.actor import Actor data[k] = Actor( type(self.state.actors[k]), self.address, k, worker=self ) msg = {"status": "OK", "data": {k: to_serialize(v) for k, v in data.items()}} # Note: `if k in self.data` above guarantees that # k is in self.state.tasks too and that nbytes is non-None bytes_per_task = {k: self.state.tasks[k].nbytes or 0 for k in data} total_bytes = sum(bytes_per_task.values()) self.transfer_outgoing_bytes += total_bytes self.transfer_outgoing_bytes_total += total_bytes try: with context_meter.meter("network", func=time) as m: compressed = await comm.write(msg, serializers=serializers) response = await comm.read(deserializers=serializers) assert response == "OK", response except OSError: logger.exception( "failed during get data with %s -> %s", self.address, who, ) comm.abort() raise finally: self.transfer_outgoing_bytes -= total_bytes self.transfer_outgoing_count -= 1 # Not the same as m.delta, which doesn't include time spent # serializing/deserializing duration = max(0.001, m.stop - m.start) self.transfer_outgoing_log.append( { "start": m.start + self.scheduler_delay, "stop": m.stop + self.scheduler_delay, "middle": (m.start + m.stop) / 2, "duration": duration, "who": who, "keys": bytes_per_task, "total": total_bytes, "compressed": compressed, "bandwidth": total_bytes / duration, } ) return Status.dont_reply ################### # Local Execution # ################### def update_data( self, data: dict[Key, object], stimulus_id: str | None = None, ) -> dict[str, Any]: if stimulus_id is None: stimulus_id = f"update-data-{time()}" self.handle_stimulus(UpdateDataEvent(data=data, stimulus_id=stimulus_id)) return {"nbytes": {k: sizeof(v) for k, v in data.items()}, "status": "OK"} async def set_resources(self, **resources: float) -> None: for r, quantity in resources.items(): if r in self.state.total_resources: self.state.available_resources[r] += ( quantity - self.state.total_resources[r] ) else: self.state.available_resources[r] = quantity self.state.total_resources[r] = quantity await retry_operation( self.scheduler.set_resources, resources=self.state.total_resources, worker=self.contact_address, ) @log_errors async def plugin_add( self, plugin: WorkerPlugin | bytes, name: str | None = None, catch_errors: bool = True, ) -> ErrorMessage | OKMessage: if isinstance(plugin, bytes): plugin = pickle.loads(plugin) if not isinstance(plugin, WorkerPlugin): warnings.warn( "Registering duck-typed plugins has been deprecated. " "Please make sure your plugin subclasses `WorkerPlugin`.", DeprecationWarning, stacklevel=2, ) plugin = cast(WorkerPlugin, plugin) if name is None: name = _get_plugin_name(plugin) assert name if name in self.plugins: await self.plugin_remove(name=name) self.plugins[name] = plugin logger.info("Starting Worker plugin %s", name) if hasattr(plugin, "setup"): try: result = plugin.setup(worker=self) if isawaitable(result): await result except Exception as e: logger.exception("Worker plugin %s failed to setup", name) if not catch_errors: raise return error_message(e) return {"status": "OK"} @log_errors async def plugin_remove(self, name: str) -> ErrorMessage | OKMessage: logger.info(f"Removing Worker plugin {name}") try: plugin = self.plugins.pop(name) if hasattr(plugin, "teardown"): result = plugin.teardown(worker=self) if isawaitable(result): await result except Exception as e: logger.exception("Worker plugin %s failed to teardown", name) return error_message(e) return {"status": "OK"} def handle_worker_status_change(self, status: str, stimulus_id: str) -> None: new_status = Status.lookup[status] # type: ignore if ( new_status == Status.closing_gracefully and self._status not in WORKER_ANY_RUNNING ): logger.error( "Invalid Worker.status transition: %s -> %s", self._status, new_status ) # Reiterate the current status to the scheduler to restore sync self._send_worker_status_change(stimulus_id) else: # Update status and send confirmation to the Scheduler (see status.setter) self.status = new_status ################### # Task Management # ################### def _handle_remote_stimulus( self, cls: type[StateMachineEvent] ) -> Callable[..., None]: def _(**kwargs): event = cls(**kwargs) self.handle_stimulus(event) _.__name__ = f"_handle_remote_stimulus({cls.__name__})" return _
[docs] @fail_hard def handle_stimulus(self, *stims: StateMachineEvent) -> None: """Override BaseWorker method for added validation See also -------- distributed.worker_state_machine.BaseWorker.handle_stimulus distributed.worker_state_machine.WorkerState.handle_stimulus """ try: super().handle_stimulus(*stims) except Exception as e: if hasattr(e, "to_event"): topic, msg = e.to_event() self.log_event(topic, msg) raise
def stateof(self, key: str) -> dict[str, Any]: ts = self.state.tasks[key] return { "executing": ts.state == "executing", "waiting_for_data": bool(ts.waiting_for_data), "heap": ts in self.state.ready or ts in self.state.constrained, "data": key in self.data, } async def get_story(self, keys_or_stimuli: Iterable[str]) -> list[tuple]: return self.state.story(*keys_or_stimuli) ########################## # Dependencies gathering # ########################## def _get_cause(self, keys: Iterable[Key]) -> TaskState: """For diagnostics, we want to attach a transfer to a single task. This task is typically the next to be executed but since we're fetching tasks for potentially many dependents, an exact match is not possible. Additionally, if a key was fetched through acquire-replicas, dependents may not be known at all. Returns ------- The task to attach startstops of this transfer to """ cause = None for key in keys: ts = self.state.tasks[key] if ts.dependents: return next(iter(ts.dependents)) cause = ts assert cause # Always at least one key return cause def _update_metrics_received_data( self, start: float, stop: float, data: dict[Key, object], cause: TaskState, worker: str, ) -> None: total_bytes = sum(self.state.tasks[key].get_nbytes() for key in data) cause.startstops.append( { "action": "transfer", "start": start + self.scheduler_delay, "stop": stop + self.scheduler_delay, "source": worker, } ) duration = max(0.001, stop - start) bandwidth = total_bytes / duration self.transfer_incoming_log.append( { "start": start + self.scheduler_delay, "stop": stop + self.scheduler_delay, "middle": (start + stop) / 2.0 + self.scheduler_delay, "duration": duration, "keys": {key: self.state.tasks[key].nbytes for key in data}, "total": total_bytes, "bandwidth": bandwidth, "who": worker, } ) if total_bytes > 1_000_000: self.bandwidth = self.bandwidth * 0.95 + bandwidth * 0.05 bw, cnt = self.bandwidth_workers[worker] self.bandwidth_workers[worker] = (bw + bandwidth, cnt + 1) types = set(map(type, data.values())) if len(types) == 1: [typ] = types bw, cnt = self.bandwidth_types[typ] self.bandwidth_types[typ] = (bw + bandwidth, cnt + 1) self.digest_metric("transfer-bandwidth", total_bytes / duration) self.digest_metric("transfer-duration", duration) self.counters["transfer-count"].add(len(data))
[docs] @fail_hard async def gather_dep( self, worker: str, to_gather: Collection[Key], total_nbytes: int, *, stimulus_id: str, ) -> StateMachineEvent: """Implements BaseWorker abstract method See also -------- distributed.worker_state_machine.BaseWorker.gather_dep """ if self.status not in WORKER_ANY_RUNNING: # This is only for the sake of coherence of the WorkerState; # it should never actually reach the scheduler. return GatherDepFailureEvent.from_exception( RuntimeError("Worker is shutting down"), worker=worker, total_nbytes=total_nbytes, stimulus_id=f"worker-closing-{time()}", ) self.state.log.append(("request-dep", worker, to_gather, stimulus_id, time())) logger.debug("Request %d keys from %s", len(to_gather), worker) try: with context_meter.meter("network", func=time) as m: response = await get_data_from_worker( rpc=self.rpc, keys=to_gather, worker=worker, who=self.address ) if response["status"] == "busy": self.state.log.append( ("gather-dep-busy", worker, to_gather, stimulus_id, time()) ) return GatherDepBusyEvent( worker=worker, total_nbytes=total_nbytes, stimulus_id=f"gather-dep-busy-{time()}", ) assert response["status"] == "OK" cause = self._get_cause(to_gather) self._update_metrics_received_data( start=m.start, stop=m.stop, data=response["data"], cause=cause, worker=worker, ) self.state.log.append( ("receive-dep", worker, set(response["data"]), stimulus_id, time()) ) return GatherDepSuccessEvent( worker=worker, total_nbytes=total_nbytes, data=response["data"], stimulus_id=f"gather-dep-success-{time()}", ) except OSError: logger.exception("Worker stream died during communication: %s", worker) self.state.log.append( ("gather-dep-failed", worker, to_gather, stimulus_id, time()) ) return GatherDepNetworkFailureEvent( worker=worker, total_nbytes=total_nbytes, stimulus_id=f"gather-dep-network-failure-{time()}", ) except Exception as e: # e.g. data failed to deserialize # FIXME this will deadlock the cluster # https://github.com/dask/distributed/issues/6705 logger.exception(e) self.state.log.append( ("gather-dep-failed", worker, to_gather, stimulus_id, time()) ) if self.batched_stream and LOG_PDB: import pdb pdb.set_trace() return GatherDepFailureEvent.from_exception( e, worker=worker, total_nbytes=total_nbytes, stimulus_id=f"gather-dep-failed-{time()}", )
[docs] async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent: """Wait some time, then take a peer worker out of busy state. Implements BaseWorker abstract method. See Also -------- distributed.worker_state_machine.BaseWorker.retry_busy_worker_later """ await asyncio.sleep(0.15) return RetryBusyWorkerEvent( worker=worker, stimulus_id=f"retry-busy-worker-{time()}" )
[docs] def digest_metric(self, name: Hashable, value: float) -> None: """Implement BaseWorker.digest_metric by calling Server.digest_metric""" ServerNode.digest_metric(self, name, value)
@log_errors def find_missing(self) -> None: self.handle_stimulus(FindMissingEvent(stimulus_id=f"find-missing-{time()}")) # This is quite arbitrary but the heartbeat has scaling implemented self.periodic_callbacks["find-missing"].callback_time = self.periodic_callbacks[ "heartbeat" ].callback_time ################ # Execute Task # ################ def run(self, comm, function, args=(), wait=True, kwargs=None): return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) def run_coroutine(self, comm, function, args=(), kwargs=None, wait=True): return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait) async def actor_execute( self, actor=None, function=None, args=(), kwargs: dict | None = None, ) -> dict[str, Any]: kwargs = kwargs or {} separate_thread = kwargs.pop("separate_thread", True) key = actor actor = self.state.actors[key] func = getattr(actor, function) name = key_split(key) + "." + function try: if iscoroutinefunction(func): token = _worker_cvar.set(self) try: result = await func(*args, **kwargs) finally: _worker_cvar.reset(token) elif separate_thread: result = await self.loop.run_in_executor( self.executors["actor"], _run_actor, func, args, kwargs, self.execution_state, name, self.active_threads, self.active_threads_lock, ) else: token = _worker_cvar.set(self) try: result = func(*args, **kwargs) finally: _worker_cvar.reset(token) return {"status": "OK", "result": to_serialize(result)} except Exception as ex: return {"status": "error", "exception": to_serialize(ex)} def actor_attribute(self, actor=None, attribute=None) -> dict[str, Any]: try: value = getattr(self.state.actors[actor], attribute) return {"status": "OK", "result": to_serialize(value)} except Exception as ex: return {"status": "error", "exception": to_serialize(ex)}
[docs] @fail_hard async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent: """Execute a task. Implements BaseWorker abstract method. See also -------- distributed.worker_state_machine.BaseWorker.execute """ if self.status not in WORKER_ANY_RUNNING: # This is just for internal coherence of the WorkerState; the reschedule # message should not ever reach the Scheduler. # It is still OK if it does though. return RescheduleEvent(key=key, stimulus_id=f"worker-closing-{time()}") # The key *must* be in the worker state thanks to the cancelled state ts = self.state.tasks[key] run_id = ts.run_id try: if self.state.validate: assert not ts.waiting_for_data assert ts.state in ("executing", "cancelled", "resumed"), ts assert ts.run_spec is not None start = time() data: dict[Key, Any] = {} for dep in ts.dependencies: dkey = dep.key actors = self.state.actors if actors and dkey in actors: from distributed.actor import Actor # TODO: create local actor data[dkey] = Actor(type(actors[dkey]), self.address, dkey, self) else: data[dkey] = self.data[dkey] stop = time() if stop - start > 0.005: ts.startstops.append( {"action": "disk-read", "start": start, "stop": stop} ) assert ts.annotations is not None executor = ts.annotations.get("executor", "default") try: e = self.executors[executor] except KeyError: raise ValueError( f"Invalid executor {executor!r}; " f"expected one of: {sorted(self.executors)}" ) self.active_keys.add(key) # Propagate span (see distributed.spans). This is useful when spawning # more tasks using worker_client() and for logging. span_ctx = ( dask.annotate(span=ts.annotations["span"]) if "span" in ts.annotations else contextlib.nullcontext() ) span_ctx.__enter__() run_spec = ts.run_spec try: ts.start_time = time() if ts.run_spec.is_coro: token = _worker_cvar.set(self) try: result = await _run_task_async( ts.run_spec, data, self.scheduler_delay, ) finally: _worker_cvar.reset(token) elif "ThreadPoolExecutor" in str(type(e)): # The 'executor' time metric should be almost zero most of the time, # e.g. thread synchronization overhead only, since thread-noncpu and # thread-cpu inside the thread detract from it. However, it may # become substantial in case of misalignment between the size of the # thread pool and the number of running tasks in the worker stater # machine (e.g. https://github.com/dask/distributed/issues/5882) with context_meter.meter("executor"): result = await run_in_executor_with_context( e, _run_task, ts.run_spec, data, self.execution_state, key, self.active_threads, self.active_threads_lock, self.scheduler_delay, ) else: # Can't capture contextvars across processes. If this is a # ProcessPoolExecutor, the 'executor' time metric will show the # whole runtime inside the executor. with context_meter.meter("executor"): result = await self.loop.run_in_executor( e, _run_task_simple, ts.run_spec, data, self.scheduler_delay, ) finally: self.active_keys.discard(key) span_ctx.__exit__(None, None, None) self.threads[key] = result["thread"] if result["op"] == "task-finished": if self.digests is not None: duration = max(0, result["stop"] - result["start"]) self.digests["task-duration"].add(duration) return ExecuteSuccessEvent( key=key, run_id=run_id, value=result["result"], start=result["start"], stop=result["stop"], nbytes=result["nbytes"], type=result["type"], stimulus_id=f"task-finished-{time()}", ) task_exc = result["actual_exception"] if isinstance(task_exc, Reschedule): return RescheduleEvent(key=ts.key, stimulus_id=f"reschedule-{time()}") if ( self.status == Status.closing and isinstance(task_exc, asyncio.CancelledError) and run_spec.is_coro ): # `Worker.cancel` will cause async user tasks to raise `CancelledError`. # Since we cancelled those tasks, we shouldn't treat them as failures. # This is just a heuristic; it's _possible_ the task happened to # fail independently with `CancelledError`. logger.info( f"Async task {key!r} cancelled during worker close; rescheduling." ) return RescheduleEvent( key=ts.key, stimulus_id=f"cancelled-by-worker-close-{time()}" ) if ts.state in ("executing", "long-running", "resumed"): logger.error( "Compute Failed\n" "Key: %s\n" "State: %s\n" "Task: %s\n" "Exception: %r\n" "Traceback: %r\n", key, ts.state, repr(run_spec)[:1000], result["exception_text"], result["traceback_text"], ) return ExecuteFailureEvent.from_exception( result, key=key, run_id=run_id, start=result["start"], stop=result["stop"], stimulus_id=f"task-erred-{time()}", ) except Exception as exc: # Some legitimate use cases that will make us reach this point: # - User specified an invalid executor; # - Task transitioned to cancelled or resumed(fetch) before the start of # execute() and its dependencies were released. This caused # _prepare_args_for_execution() to raise KeyError; # - A dependency was unspilled but failed to deserialize due to a bug in # user-defined or third party classes. if ts.state in ("executing", "long-running"): logger.error( f"Exception during execution of task {key!r}", exc_info=True, ) return ExecuteFailureEvent.from_exception( exc, key=key, run_id=run_id, stimulus_id=f"execute-unknown-error-{time()}", )
################## # Administrative # ################## def cycle_profile(self) -> None: now = time() + self.scheduler_delay prof, self.profile_recent = self.profile_recent, profile.create() self.profile_history.append((now, prof)) self.profile_keys_history.append((now, dict(self.profile_keys))) self.profile_keys.clear()
[docs] def trigger_profile(self) -> None: """ Get a frame from all actively computing threads Merge these frames into existing profile counts """ if not self.active_threads: # hope that this is thread-atomic? return start = time() with self.active_threads_lock: active_threads = self.active_threads.copy() frames = sys._current_frames() frames = {ident: frames[ident] for ident in active_threads} llframes = {} if self.low_level_profiler: llframes = {ident: profile.ll_get_stack(ident) for ident in active_threads} for ident, frame in frames.items(): if frame is not None: key = key_split(active_threads[ident]) llframe = llframes.get(ident) state = profile.process( frame, True, self.profile_recent, stop="distributed/worker.py" ) profile.llprocess(llframe, None, state) profile.process( frame, True, self.profile_keys[key], stop="distributed/worker.py" ) stop = time() self.digest_metric("profile-duration", stop - start)
async def get_profile( self, start=None, stop=None, key=None, server: bool = False, ): now = time() + self.scheduler_delay if server: history = self.io_loop.profile # type: ignore[attr-defined] elif key is None: history = self.profile_history else: history = [(t, d[key]) for t, d in self.profile_keys_history if key in d] if start is None: istart = 0 else: istart = bisect.bisect_left(history, (start,)) if stop is None: istop = None else: istop = bisect.bisect_right(history, (stop,)) + 1 if istop >= len(history): istop = None # include end if istart == 0 and istop is None: history = list(history) else: iistop = len(history) if istop is None else istop history = [history[i] for i in range(istart, iistop)] prof = profile.merge(*pluck(1, history)) if not history: return profile.create() if istop is None and (start is None or start < now): if key is None: recent = self.profile_recent else: recent = self.profile_keys[key] prof = profile.merge(prof, recent) return prof async def get_profile_metadata( self, start: float = 0, stop: float | None = None ) -> dict[str, Any]: add_recent = stop is None now = time() + self.scheduler_delay stop = stop or now result = { "counts": [ (t, d["count"]) for t, d in self.profile_history if start < t < stop ], "keys": [ (t, {k: d["count"] for k, d in v.items()}) for t, v in self.profile_keys_history if start < t < stop ], } if add_recent: result["counts"].append((now, self.profile_recent["count"])) result["keys"].append( (now, {k: v["count"] for k, v in self.profile_keys.items()}) ) return result def get_call_stack(self, keys: Collection[Key] | None = None) -> dict[Key, Any]: with self.active_threads_lock: sys_frames = sys._current_frames() frames = {key: sys_frames[tid] for tid, key in self.active_threads.items()} if keys is not None: frames = {key: frames[key] for key in keys if key in frames} return {key: profile.call_stack(frame) for key, frame in frames.items()} async def benchmark_disk(self) -> dict[str, float]: return await self.loop.run_in_executor( self.executor, benchmark_disk, self.local_directory ) async def benchmark_memory(self) -> dict[str, float]: return await self.loop.run_in_executor(self.executor, benchmark_memory) async def benchmark_network(self, address: str) -> dict[str, float]: return await benchmark_network(rpc=self.rpc, address=address) ####################################### # Worker Clients (advanced workloads) # ####################################### @property def client(self) -> Client: with self._lock: if self._client: return self._client else: return self._get_client() def _get_client(self, timeout: float | None = None) -> Client: """Get local client attached to this worker If no such client exists, create one See Also -------- get_client """ if timeout is None: timeout = dask.config.get("distributed.comm.timeouts.connect") timeout = parse_timedelta(timeout, "s") try: from distributed.client import default_client client = default_client() except ValueError: # no clients found, need to make a new one pass else: # must be lazy import otherwise cyclic import from distributed.deploy.cluster import Cluster if ( client.scheduler and client.scheduler.address == self.scheduler.address # The below conditions should only happen in case a second # cluster is alive, e.g. if a submitted task spawned its onwn # LocalCluster, see gh4565 or ( isinstance(client._start_arg, str) and client._start_arg == self.scheduler.address or isinstance(client._start_arg, Cluster) and client._start_arg.scheduler_address == self.scheduler.address ) ): self._client = client if not self._client: from distributed.client import Client asynchronous = in_async_call(self.loop) self._client = Client( self.scheduler, loop=self.loop, security=self.security, set_as_default=True, asynchronous=asynchronous, direct_to_workers=True, name="worker", timeout=timeout, ) Worker._initialized_clients.add(self._client) if not asynchronous: assert self._client.status == "running" self.log_event( "worker-get-client", { "client": self._client.id, "timeout": timeout, }, ) return self._client
[docs] def get_current_task(self) -> Key: """Get the key of the task we are currently running This only makes sense to run within a task Examples -------- >>> from dask.distributed import get_worker >>> def f(): ... return get_worker().get_current_task() >>> future = client.submit(f) # doctest: +SKIP >>> future.result() # doctest: +SKIP 'f-1234' See Also -------- get_worker """ return self.active_threads[threading.get_ident()]
def _handle_remove_worker(self, worker: str, stimulus_id: str) -> None: self.rpc.remove(worker) self.handle_stimulus(RemoveWorkerEvent(worker=worker, stimulus_id=stimulus_id)) def validate_state(self) -> None: try: self.state.validate_state() except Exception as e: logger.error("Validate state failed", exc_info=e) logger.exception(e) if LOG_PDB: import pdb pdb.set_trace() if hasattr(e, "to_event"): topic, msg = e.to_event() self.log_event(topic, msg) raise @property def incoming_transfer_log(self): warnings.warn( "The `Worker.incoming_transfer_log` attribute has been renamed to " "`Worker.transfer_incoming_log`", DeprecationWarning, stacklevel=2, ) return self.transfer_incoming_log @property def outgoing_count(self): warnings.warn( "The `Worker.outgoing_count` attribute has been renamed to " "`Worker.transfer_outgoing_count_total`", DeprecationWarning, stacklevel=2, ) return self.transfer_outgoing_count_total @property def outgoing_current_count(self): warnings.warn( "The `Worker.outgoing_current_count` attribute has been renamed to " "`Worker.transfer_outgoing_count`", DeprecationWarning, stacklevel=2, ) return self.transfer_outgoing_count @property def outgoing_transfer_log(self): warnings.warn( "The `Worker.outgoing_transfer_log` attribute has been renamed to " "`Worker.transfer_outgoing_log`", DeprecationWarning, stacklevel=2, ) return self.transfer_outgoing_log @property def total_in_connections(self): warnings.warn( "The `Worker.total_in_connections` attribute has been renamed to " "`Worker.transfer_outgoing_count_limit`", DeprecationWarning, stacklevel=2, ) return self.transfer_outgoing_count_limit
_worker_cvar: contextvars.ContextVar[Worker] = contextvars.ContextVar("_worker_cvar")
[docs]def get_worker() -> Worker: """Get the worker currently running this task Examples -------- >>> def f(): ... worker = get_worker() # The worker on which this task is running ... return worker.address >>> future = client.submit(f) # doctest: +SKIP >>> future.result() # doctest: +SKIP 'tcp://127.0.0.1:47373' See Also -------- get_client worker_client """ try: return _worker_cvar.get() except LookupError: raise ValueError("No worker found") from None
[docs]def get_client(address=None, timeout=None, resolve_address=True) -> Client: """Get a client while within a task. This client connects to the same scheduler to which the worker is connected Parameters ---------- address : str, optional The address of the scheduler to connect to. Defaults to the scheduler the worker is connected to. timeout : int or str Timeout (in seconds) for getting the Client. Defaults to the ``distributed.comm.timeouts.connect`` configuration value. resolve_address : bool, default True Whether to resolve `address` to its canonical form. Returns ------- Client Examples -------- >>> def f(): ... client = get_client(timeout="10s") ... futures = client.map(lambda x: x + 1, range(10)) # spawn many tasks ... results = client.gather(futures) ... return sum(results) >>> future = client.submit(f) # doctest: +SKIP >>> future.result() # doctest: +SKIP 55 See Also -------- get_worker worker_client secede """ if timeout is None: timeout = dask.config.get("distributed.comm.timeouts.connect") timeout = parse_timedelta(timeout, "s") if address and resolve_address: address = comm_resolve_address(address) try: worker = get_worker() except ValueError: # could not find worker pass else: if not address or worker.scheduler.address == address: return worker._get_client(timeout=timeout) from distributed.client import Client try: client = Client.current() # TODO: assumes the same scheduler except ValueError: client = None if client and (not address or client.scheduler.address == address): return client elif address: return Client(address, timeout=timeout) else: raise ValueError("No global client found and no address provided")
[docs]def secede(): """ Have this task secede from the worker's thread pool This opens up a new scheduling slot and a new thread for a new task. This enables the client to schedule tasks on this node, which is especially useful while waiting for other jobs to finish (e.g., with ``client.gather``). Examples -------- >>> def mytask(x): ... # do some work ... client = get_client() ... futures = client.map(...) # do some remote work ... secede() # while that work happens, remove ourself from the pool ... return client.gather(futures) # return gathered results See Also -------- get_client get_worker """ worker = get_worker() tpe_secede() # have this thread secede from the thread pool duration = time() - thread_state.start_time worker.loop.add_callback( worker.handle_stimulus, SecedeEvent( key=thread_state.key, compute_duration=duration, stimulus_id=f"secede-{time()}", ), )
async def get_data_from_worker( rpc: ConnectionPool, keys: Collection[Key], worker: str, *, who: str | None = None, serializers: list[str] | None = None, deserializers: list[str] | None = None, ) -> GetDataBusy | GetDataSuccess: """Get keys from worker The worker has a two step handshake to acknowledge when data has been fully delivered. This function implements that handshake. See Also -------- Worker.get_data Worker.gather_dep utils_comm.gather_data_from_workers """ if serializers is None: serializers = rpc.serializers if deserializers is None: deserializers = rpc.deserializers comm = await rpc.connect(worker) comm.name = "Ephemeral Worker->Worker for gather" try: response = await send_recv( comm, serializers=serializers, deserializers=deserializers, op="get_data", keys=keys, who=who, ) try: status = response["status"] except KeyError: # pragma: no cover raise ValueError("Unexpected response", response) else: if status == "OK": await comm.write("OK") return response finally: rpc.reuse(worker, comm) cache_dumps: LRU[Callable[..., Any], bytes] = LRU(maxsize=100) _cache_lock = threading.Lock() def dumps_function(func) -> bytes: """Dump a function to bytes, cache functions""" try: with _cache_lock: result = cache_dumps[func] except KeyError: result = pickle.dumps(func) if len(result) < 100000: with _cache_lock: cache_dumps[func] = result except TypeError: # Unhashable function result = pickle.dumps(func) return result def _run_task( task: GraphNode, data: dict, execution_state: dict, key: Key, active_threads: dict, active_threads_lock: threading.Lock, time_delay: float, ) -> RunTaskSuccess | RunTaskFailure: """Run a function, collect information Returns ------- msg: dictionary with status, result/error, timings, etc.. """ ident = threading.get_ident() with active_threads_lock: active_threads[ident] = key with set_thread_state( start_time=time(), execution_state=execution_state, key=key, ): token = _worker_cvar.set(execution_state["worker"]) try: msg = _run_task_simple(task, data, time_delay) finally: _worker_cvar.reset(token) with active_threads_lock: del active_threads[ident] return msg def _run_task_simple( task: GraphNode, data: dict, time_delay: float, ) -> RunTaskSuccess | RunTaskFailure: """Run a function, collect information Returns ------- msg: dictionary with status, result/error, timings, etc.. """ # meter("thread-cpu").delta # Difference in thread_time() before and after function call, minus user calls # to context_meter inside the function. Published to Server.digests as # {("execute", <prefix>, "thread-cpu", "seconds"): <value>} # m.delta # Difference in wall time before and after function call, minus thread-cpu, # minus user calls to context_meter. Published to Server.digests as # {("execute", <prefix>, "thread-noncpu", "seconds"): <value>} # m.stop - m.start # Difference in wall time before and after function call, without subtracting # anything. This is used in scheduler heuristics, e.g. task stealing. with ( context_meter.meter("thread-noncpu", func=time) as m, context_meter.meter("thread-cpu", func=thread_time), ): try: result = task(data) except (SystemExit, KeyboardInterrupt): # Special-case these, just like asyncio does all over the place. They will # pass through `fail_hard` and `_handle_stimulus_from_task`, and eventually # be caught by special-case logic in asyncio: # https://github.com/python/cpython/blob/v3.9.4/Lib/asyncio/events.py#L81-L82 # Any other `BaseException` types would ultimately be ignored by asyncio if # raised here, after messing up the worker state machine along their way. raise except BaseException as e: # noqa: B036 # Users _shouldn't_ use `BaseException`s, but if they do, we can assume they # aren't a reason to shut down the whole system (since we allow the # system-shutting-down `SystemExit` and `KeyboardInterrupt` to pass through) msg: RunTaskFailure = error_message(e) # type: ignore msg["op"] = "task-erred" msg["actual_exception"] = e else: msg: RunTaskSuccess = { # type: ignore "op": "task-finished", "status": "OK", "result": result, "nbytes": sizeof(result), "type": type(result) if result is not None else None, } msg["start"] = m.start + time_delay msg["stop"] = m.stop + time_delay msg["thread"] = threading.get_ident() return msg async def _run_task_async( task: GraphNode, data: dict, time_delay: float, ) -> RunTaskSuccess | RunTaskFailure: """Run a function, collect information Returns ------- msg: dictionary with status, result/error, timings, etc.. """ with context_meter.meter("thread-noncpu", func=time) as m: try: result = await task(data) except (SystemExit, KeyboardInterrupt): # Special-case these, just like asyncio does all over the place. They will # pass through `fail_hard` and `_handle_stimulus_from_task`, and eventually # be caught by special-case logic in asyncio: # https://github.com/python/cpython/blob/v3.9.4/Lib/asyncio/events.py#L81-L82 # Any other `BaseException` types would ultimately be ignored by asyncio if # raised here, after messing up the worker state machine along their way. raise except BaseException as e: # noqa: B036 # NOTE: this includes `CancelledError`! Since it's a user task, that's _not_ # a reason to shut down the worker. # Users _shouldn't_ use `BaseException`s, but if they do, we can assume they # aren't a reason to shut down the whole system (since we allow the # system-shutting-down `SystemExit` and `KeyboardInterrupt` to pass through) msg: RunTaskFailure = error_message(e) # type: ignore msg["op"] = "task-erred" msg["actual_exception"] = e else: msg: RunTaskSuccess = { # type: ignore "op": "task-finished", "status": "OK", "result": result, "nbytes": sizeof(result), "type": type(result) if result is not None else None, } msg["start"] = m.start + time_delay msg["stop"] = m.stop + time_delay msg["thread"] = threading.get_ident() return msg def _run_actor( func: Callable, args: tuple, kwargs: dict, execution_state: dict, key: Key, active_threads: dict, active_threads_lock: threading.Lock, ) -> Any: """Run a function, collect information Returns ------- msg: dictionary with status, result/error, timings, etc.. """ ident = threading.get_ident() with active_threads_lock: active_threads[ident] = key with set_thread_state( start_time=time(), execution_state=execution_state, key=key, actor=True, ): token = _worker_cvar.set(execution_state["worker"]) try: result = func(*args, **kwargs) finally: _worker_cvar.reset(token) with active_threads_lock: del active_threads[ident] return result def get_msg_safe_str(msg): """Make a worker msg, which contains args and kwargs, safe to cast to str: allowing for some arguments to raise exceptions during conversion and ignoring them. """ class Repr: def __init__(self, f, val): self._f = f self._val = val def __repr__(self): return self._f(self._val) msg = msg.copy() if "args" in msg: msg["args"] = Repr(convert_args_to_str, msg["args"]) if "kwargs" in msg: msg["kwargs"] = Repr(convert_kwargs_to_str, msg["kwargs"]) return msg def convert_args_to_str(args, max_len: int | None = None) -> str: """Convert args to a string, allowing for some arguments to raise exceptions during conversion and ignoring them. """ length = 0 strs = ["" for i in range(len(args))] for i, arg in enumerate(args): try: sarg = repr(arg) except Exception: sarg = "< could not convert arg to str >" strs[i] = sarg length += len(sarg) + 2 if max_len is not None and length > max_len: return "({}".format(", ".join(strs[: i + 1]))[:max_len] else: return "({})".format(", ".join(strs)) def convert_kwargs_to_str(kwargs: dict, max_len: int | None = None) -> str: """Convert kwargs to a string, allowing for some arguments to raise exceptions during conversion and ignoring them. """ length = 0 strs = ["" for i in range(len(kwargs))] for i, (argname, arg) in enumerate(kwargs.items()): try: sarg = repr(arg) except Exception: sarg = "< could not convert arg to str >" skwarg = repr(argname) + ": " + sarg strs[i] = skwarg length += len(skwarg) + 2 if max_len is not None and length > max_len: return "{{{}".format(", ".join(strs[: i + 1]))[:max_len] else: return "{{{}}}".format(", ".join(strs)) async def run(server, comm, function, args=(), kwargs=None, wait=True): kwargs = kwargs or {} function = pickle.loads(function) is_coro = iscoroutinefunction(function) assert wait or is_coro, "Combination not supported" if args: args = pickle.loads(args) if kwargs: kwargs = pickle.loads(kwargs) if has_arg(function, "dask_worker"): kwargs["dask_worker"] = server if has_arg(function, "dask_scheduler"): kwargs["dask_scheduler"] = server logger.info("Run out-of-band function %r", funcname(function)) try: if not is_coro: result = function(*args, **kwargs) else: if wait: result = await function(*args, **kwargs) else: server._ongoing_background_tasks.call_soon(function, *args, **kwargs) result = None except Exception as e: logger.warning( "Run Failed\nFunction: %s\nargs: %s\nkwargs: %s\n", str(funcname(function))[:1000], convert_args_to_str(args, max_len=1000), convert_kwargs_to_str(kwargs, max_len=1000), exc_info=True, ) response = error_message(e) else: response = {"status": "OK", "result": to_serialize(result)} return response _global_workers = Worker._instances def add_gpu_metrics(): async def gpu_metric(worker): result = await offload(nvml.real_time) return result DEFAULT_METRICS["gpu"] = gpu_metric def gpu_startup(worker): return nvml.one_time() DEFAULT_STARTUP_INFORMATION["gpu"] = gpu_startup try: import rmm as _rmm except Exception: pass else: async def rmm_metric(worker): result = await offload(rmm.real_time) return result DEFAULT_METRICS["rmm"] = rmm_metric del _rmm # avoid importing cuDF unless explicitly enabled if dask.config.get("distributed.diagnostics.cudf"): try: import cudf as _cudf # noqa: F401 except Exception: pass else: from distributed.diagnostics import cudf async def cudf_metric(worker): result = await offload(cudf.real_time) return result DEFAULT_METRICS["cudf"] = cudf_metric del _cudf
[docs]def print( *args, sep: str | None = " ", end: str | None = "\n", file: TextIO | None = None, flush: bool = False, ) -> None: """ A drop-in replacement of the built-in ``print`` function for remote printing from workers to clients. If called from outside a dask worker, its arguments are passed directly to ``builtins.print()``. If called by code running on a worker, then in addition to printing locally, any clients connected (possibly remotely) to the scheduler managing this worker will receive an event instructing them to print the same output to their own standard output or standard error streams. For example, the user can perform simple debugging of remote computations by including calls to this ``print`` function in the submitted code and inspecting the output in a local Jupyter notebook or interpreter session. All arguments behave the same as those of ``builtins.print()``, with the exception that the ``file`` keyword argument, if specified, must either be ``sys.stdout`` or ``sys.stderr``; arbitrary file-like objects are not allowed. All non-keyword arguments are converted to strings using ``str()`` and written to the stream, separated by ``sep`` and followed by ``end``. Both ``sep`` and ``end`` must be strings; they can also be ``None``, which means to use the default values. If no objects are given, ``print()`` will just write ``end``. Parameters ---------- sep : str, optional String inserted between values, default a space. end : str, optional String appended after the last value, default a newline. file : ``sys.stdout`` or ``sys.stderr``, optional Defaults to the current sys.stdout. flush : bool, default False Whether to forcibly flush the stream. Examples -------- >>> from dask.distributed import Client, print >>> client = distributed.Client(...) >>> def worker_function(): ... print("Hello from worker!") >>> client.submit(worker_function) <Future: finished, type: NoneType, key: worker_function-...> Hello from worker! """ try: worker = get_worker() except ValueError: pass else: # We are in a worker: prepare all of the print args and kwargs to be # serialized over the wire to the client. msg = { # According to the Python stdlib docs, builtin print() simply calls # str() on each positional argument, so we do the same here. "args": tuple(map(str, args)), "sep": sep, "end": end, "flush": flush, } if file == sys.stdout: msg["file"] = 1 # type: ignore elif file == sys.stderr: msg["file"] = 2 # type: ignore elif file is not None: raise TypeError( f"Remote printing to arbitrary file objects is not supported. file " f"kwarg must be one of None, sys.stdout, or sys.stderr; got: {file!r}" ) worker.log_event("print", msg) builtins.print(*args, sep=sep, end=end, file=file, flush=flush)
[docs]def warn( message: str | Warning, category: type[Warning] | None = UserWarning, stacklevel: int = 1, source: Any = None, ) -> None: """ A drop-in replacement of the built-in ``warnings.warn()`` function for issuing warnings remotely from workers to clients. If called from outside a dask worker, its arguments are passed directly to ``warnings.warn()``. If called by code running on a worker, then in addition to emitting a warning locally, any clients connected (possibly remotely) to the scheduler managing this worker will receive an event instructing them to emit the same warning (subject to their own local filters, etc.). When implementing computations that may run on a worker, the user can call this ``warn`` function to ensure that any remote client sessions will see their warnings, for example in a Jupyter output cell. While all of the arguments are respected by the locally emitted warning (with same meanings as in ``warnings.warn()``), ``stacklevel`` and ``source`` are ignored by clients because they would not be meaningful in the client's thread. Examples -------- >>> from dask.distributed import Client, warn >>> client = Client() >>> def do_warn(): ... warn("A warning from a worker.") >>> client.submit(do_warn).result() /path/to/distributed/client.py:678: UserWarning: A warning from a worker. """ try: worker = get_worker() except ValueError: # pragma: no cover pass else: # We are in a worker: log a warn event with args serialized to the # client. We have to pickle message and category into bytes ourselves # because msgpack cannot handle them. The expectations is that these are # always small objects. worker.log_event( "warn", { "message": pickle.dumps(message), "category": pickle.dumps(category), # We ignore stacklevel because it will be meaningless in the # client's thread/process. # We ignore source because we don't want to serialize arbitrary # objects. }, ) # add 1 to stacklevel so that, at least in the worker's local stderr, we'll # see the source line that called us warnings.warn(message, category, stacklevel + 1, source)
def benchmark_disk( rootdir: str | None = None, sizes: Iterable[str] = ("1 kiB", "100 kiB", "1 MiB", "10 MiB", "100 MiB"), duration="1 s", ) -> dict[str, float]: """ Benchmark disk bandwidth Returns ------- out: dict Maps sizes of outputs to measured bandwidths """ duration = parse_timedelta(duration) out = {} for size_str in sizes: with tmpdir(dir=rootdir) as dir: dir = pathlib.Path(dir) names = list(map(str, range(100))) size = parse_bytes(size_str) data = random.randbytes(size) start = time() total = 0 while time() < start + duration: with open(dir / random.choice(names), mode="ab") as f: f.write(data) f.flush() os.fsync(f.fileno()) total += size out[size_str] = total / (time() - start) return out def benchmark_memory( sizes: Iterable[str] = ("2 kiB", "10 kiB", "100 kiB", "1 MiB", "10 MiB"), duration="200 ms", ) -> dict[str, float]: """ Benchmark memory bandwidth Returns ------- out: dict Maps sizes of outputs to measured bandwidths """ duration = parse_timedelta(duration) out = {} for size_str in sizes: size = parse_bytes(size_str) data = random.randbytes(size) start = time() total = 0 while time() < start + duration: _ = data[:-1] del _ total += size out[size_str] = total / (time() - start) return out async def benchmark_network( address: str, rpc: ConnectionPool | Callable[[str], RPCType], sizes: Iterable[str] = ("1 kiB", "10 kiB", "100 kiB", "1 MiB", "10 MiB", "50 MiB"), duration="1 s", ) -> dict[str, float]: """ Benchmark network communications to another worker Returns ------- out: dict Maps sizes of outputs to measured bandwidths """ duration = parse_timedelta(duration) out = {} async with rpc(address) as r: for size_str in sizes: size = parse_bytes(size_str) data = to_serialize(random.randbytes(size)) start = time() total = 0 while time() < start + duration: await r.echo(data=data) total += size * 2 out[size_str] = total / (time() - start) return out