Source code for distributed.worker_state_machine

from __future__ import annotations

import abc
import asyncio
import heapq
import logging
import math
import operator
import random
import sys
import warnings
import weakref
from collections import Counter, defaultdict, deque
from collections.abc import (
    Awaitable,
    Callable,
    Collection,
    Container,
    Hashable,
    Iterator,
    Mapping,
    MutableMapping,
    Set,
)
from copy import copy
from dataclasses import dataclass, field
from functools import lru_cache, partial, singledispatchmethod, wraps
from itertools import chain
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict, Union, cast

from tlz import peekn

import dask
from dask.typing import Key
from dask.utils import key_split, parse_bytes, typename

from distributed._stories import worker_story
from distributed.collections import HeapSet
from distributed.comm import get_address_host
from distributed.core import ErrorMessage, error_message
from distributed.metrics import DelayedMetricsLedger, monotonic, time
from distributed.protocol import pickle
from distributed.protocol.serialize import Serialize, ToPickle
from distributed.sizeof import safe_sizeof as sizeof
from distributed.utils import recursive_to_dict

logger = logging.getLogger("distributed.worker.state_machine")

if TYPE_CHECKING:
    # TODO import from typing (ParamSpec and TypeAlias requires Python >=3.10)
    # TODO import from typing (NotRequired requires Python >=3.11)
    from typing_extensions import NotRequired, ParamSpec, TypeAlias

    P = ParamSpec("P")

    # Circular imports
    from distributed.diagnostics.plugin import WorkerPlugin
    from distributed.scheduler import T_runspec
    from distributed.worker import Worker

# Not to be confused with distributed.scheduler.TaskStateState
TaskStateState: TypeAlias = Literal[
    "cancelled",
    "constrained",
    "error",
    "executing",
    "fetch",
    "flight",
    "forgotten",
    "long-running",
    "memory",
    "missing",
    "ready",
    "released",
    "rescheduled",
    "resumed",
    "waiting",
]

# TaskState.state subsets
PROCESSING: Set[TaskStateState] = {
    "waiting",
    "ready",
    "constrained",
    "executing",
    "long-running",
}
READY: Set[TaskStateState] = {"ready", "constrained"}
# Valid states for a task that is found in TaskState.waiting_for_data
WAITING_FOR_DATA: Set[TaskStateState] = {
    "constrained",
    "executing",
    "fetch",
    "flight",
    "long-running",
    "missing",
    "ready",
    "resumed",
    "waiting",
}

NO_VALUE = "--no-value-sentinel--"

RUN_ID_SENTINEL = -1


class StartStop(TypedDict):
    action: Literal["compute", "transfer", "disk-read", "disk-write", "deserialize"]
    start: float
    stop: float
    source: NotRequired[str]


class InvalidTransition(Exception):
    def __init__(
        self,
        key: Key,
        start: TaskStateState,
        finish: TaskStateState,
        story: list[tuple],
    ):
        self.key = key
        self.start = start
        self.finish = finish
        self.story = story

    def __reduce__(self) -> tuple[Callable, tuple]:
        return type(self), (self.key, self.start, self.finish, self.story)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}: {self.key!r} :: {self.start}->{self.finish}"
            + "\n"
            + "  Story:\n    "
            + "\n    ".join(map(str, self.story))
        )

    __str__ = __repr__

    def to_event(self) -> tuple[str, dict[str, Any]]:
        return (
            "invalid-worker-transition",
            {
                "key": self.key,
                "start": self.start,
                "finish": self.finish,
                "story": self.story,
            },
        )


class TransitionCounterMaxExceeded(InvalidTransition):
    def to_event(self) -> tuple[str, dict[str, Any]]:
        topic, msg = super().to_event()
        return "transition-counter-max-exceeded", msg


class InvalidTaskState(Exception):
    def __init__(
        self,
        key: Key,
        state: TaskStateState,
        story: list[tuple],
    ):
        self.key = key
        self.state = state
        self.story = story

    def __reduce__(self) -> tuple[Callable, tuple]:
        return type(self), (self.key, self.state, self.story)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}: {self.key!r} :: {self.state}"
            + "\n"
            + "  Story:\n    "
            + "\n    ".join(map(str, self.story))
        )

    __str__ = __repr__

    def to_event(self) -> tuple[str, dict[str, Any]]:
        return (
            "invalid-worker-task-state",
            {
                "key": self.key,
                "state": self.state,
                "story": self.story,
            },
        )


class RecommendationsConflict(Exception):
    """Two or more recommendations for the same task suggested different finish states"""


@lru_cache
def _default_data_size() -> int:
    return parse_bytes(dask.config.get("distributed.scheduler.default-data-size"))


# Note: can't specify __slots__ manually to enable slots in Python <3.10 in a @dataclass
# that defines any default values
DC_SLOTS = {"slots": True} if sys.version_info >= (3, 10) else {}


[docs]@dataclass(repr=False, eq=False, **DC_SLOTS) class TaskState: """Holds volatile state relating to an individual Dask task. Not to be confused with :class:`distributed.scheduler.TaskState`, which holds similar information on the scheduler side. """ #: Task key. Mandatory. key: Key #: Task prefix (leftmost part of the key) prefix: str = field(init=False) #: Task run ID. run_id: int = RUN_ID_SENTINEL #: A tuple containing the ``function``, ``args``, ``kwargs`` and ``task`` #: associated with this `TaskState` instance. This defaults to ``None`` and can #: remain empty if it is a dependency that this worker will receive from another #: worker. run_spec: T_runspec | None = None #: The data needed by this key to run dependencies: set[TaskState] = field(default_factory=set) #: The keys that use this dependency dependents: set[TaskState] = field(default_factory=set) #: Subset of dependencies that are not in memory waiting_for_data: set[TaskState] = field(default_factory=set) #: Subset of dependents that are not in memory waiters: set[TaskState] = field(default_factory=set) #: The current state of the task state: TaskStateState = "released" #: The previous state of the task. It is not None iff :attr:`state` in #: (cancelled, resumed). previous: Literal["executing", "long-running", "flight", None] = None #: The next state of the task. It is not None iff :attr:`state` == resumed. next: Literal["fetch", "waiting", None] = None #: Expected duration of the task duration: float | None = None #: The priority this task given by the scheduler. Determines run order. priority: tuple[int, ...] | None = None #: Addresses of workers that we believe have this data who_has: set[str] = field(default_factory=set) #: The worker that current task data is coming from if task is in flight coming_from: str | None = None #: Abstract resources required to run a task resource_restrictions: dict[str, float] = field(default_factory=dict) #: The exception caused by running a task if it erred (serialized) exception: Serialize | None = None #: The traceback caused by running a task if it erred (serialized) traceback: Serialize | None = None #: string representation of exception exception_text: str = "" #: string representation of traceback traceback_text: str = "" #: The type of a particular piece of data type: type | None = None #: The number of times a dependency has not been where we expected it suspicious_count: int = 0 #: Log of transfer, load, and compute times for a task startstops: list[StartStop] = field(default_factory=list) #: Time at which task begins running start_time: float | None = None #: Time at which task finishes running stop_time: float | None = None #: Metadata related to the task. #: Stored metadata should be msgpack serializable (e.g. int, string, list, dict). metadata: dict = field(default_factory=dict) #: The size of the value of the task, if in memory nbytes: int | None = None #: Arbitrary task annotations annotations: dict | None = None #: unique span id (see ``distributed.spans``). #: Matches ``distributed.scheduler.TaskState.group.span_id``. span_id: str | None = None #: True if the :meth:`~WorkerBase.execute` or :meth:`~WorkerBase.gather_dep` #: coroutine servicing this task completed; False otherwise. This flag changes #: the behaviour of transitions out of the ``executing``, ``flight`` etc. states. done: bool = False _instances: ClassVar[weakref.WeakSet[TaskState]] = weakref.WeakSet() # Support for weakrefs to a class with __slots__ __weakref__: Any = field(init=False) def __post_init__(self) -> None: TaskState._instances.add(self) self.prefix = key_split(self.key) def __repr__(self) -> str: if self.state == "cancelled": state = f"cancelled({self.previous})" elif self.state == "resumed": state = f"resumed({self.previous}->{self.next})" else: state = self.state return f"<TaskState {self.key!r} {state}>" def __hash__(self) -> int: """Override dataclass __hash__, reverting to the default behaviour hash(o) == id(o). Note that we also defined @dataclass(eq=False), which reverts to the default behaviour (a == b) == (a is b). On first thought, it would make sense to use TaskState.key for equality and hashing. However, a task may be forgotten and a new TaskState object with the same key may be created in its place later on. In the Worker state, you should never have multiple TaskState objects with the same key; see WorkerState.validate_state for relevant checks. We can't assert the same thing in __eq__ though, as multiple objects with the same key may appear in TaskState._instances for a brief period of time. """ return id(self) def get_nbytes(self) -> int: nbytes = self.nbytes return nbytes if nbytes is not None else _default_data_size() def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict: """Dictionary representation for debugging purposes. Not type stable and not intended for roundtrips. See also -------- Client.dump_cluster_state distributed.utils.recursive_to_dict Notes ----- This class uses ``_to_dict_no_nest`` instead of ``_to_dict``. When a task references another task, just print the task repr. All tasks should neatly appear under Worker.tasks. This also prevents a RecursionError during particularly heavy loads, which have been observed to happen whenever there's an acyclic dependency chain of ~200+ tasks. """ out = recursive_to_dict(self, exclude=exclude, members=True) # Remove all Nones, empty containers, and derived attributes return {k: v for k, v in out.items() if v and k != "prefix"}
[docs]@dataclass class Instruction: """Command from the worker state machine to the Worker, in response to an event""" __slots__ = ("stimulus_id",) stimulus_id: str
[docs] @classmethod def match(cls, **kwargs: Any) -> _InstructionMatch: """Generate a partial match to compare against an Instruction instance. The typical usage is to compare a list of instructions returned by :meth:`WorkerState.handle_stimulus` or in :attr:`WorkerState.stimulus_log` vs. an expected list of matches. Examples -------- .. code-block:: python instructions = ws.handle_stimulus(...) assert instructions == [ TaskFinishedMsg.match(key="x"), ... ] """ return _InstructionMatch(cls, **kwargs)
def __eq__(self, other: object) -> bool: if isinstance(other, _InstructionMatch): return other == self else: # Revert to default dataclass behaviour return super().__eq__(other)
class _InstructionMatch: """Utility class, to be used to test an instructions list. See :meth:`Instruction.match`. """ cls: type[Instruction] kwargs: dict[str, Any] def __init__(self, cls: type[Instruction], **kwargs: Any): self.cls = cls self.kwargs = kwargs def __repr__(self) -> str: cls_str = self.cls.__name__ kwargs_str = ", ".join(f"{k}={v}" for k, v in self.kwargs.items()) return f"{cls_str}({kwargs_str}) (partial match)" def __eq__(self, other: object) -> bool: if type(other) is not self.cls: return False return all(getattr(other, k) == v for k, v in self.kwargs.items()) @dataclass class GatherDep(Instruction): __slots__ = ("worker", "to_gather", "total_nbytes") worker: str to_gather: set[Key] total_nbytes: int @dataclass class Execute(Instruction): __slots__ = ("key",) key: Key @dataclass class RetryBusyWorkerLater(Instruction): __slots__ = ("worker",) worker: str class SendMessageToScheduler(Instruction): #: Matches a key in Scheduler.stream_handlers op: ClassVar[str] __slots__ = () def to_dict(self) -> dict[str, Any]: """Convert object to dict so that it can be serialized with msgpack""" d = {k: getattr(self, k) for k in self.__annotations__} d["op"] = self.op d["stimulus_id"] = self.stimulus_id return d @dataclass class TaskFinishedMsg(SendMessageToScheduler): op = "task-finished" key: Key run_id: int nbytes: int | None type: bytes # serialized class typename: str metadata: dict thread: int | None startstops: list[StartStop] __slots__ = tuple(__annotations__) def to_dict(self) -> dict[str, Any]: d = super().to_dict() d["status"] = "OK" return d @dataclass class TaskErredMsg(SendMessageToScheduler): op = "task-erred" key: Key run_id: int exception: Serialize traceback: Serialize | None exception_text: str traceback_text: str thread: int | None startstops: list[StartStop] __slots__ = tuple(__annotations__) def to_dict(self) -> dict[str, Any]: d = super().to_dict() d["status"] = "error" return d @staticmethod def from_task( ts: TaskState, run_id: int, stimulus_id: str, thread: int | None = None ) -> TaskErredMsg: assert ts.exception return TaskErredMsg( key=ts.key, run_id=run_id, exception=ts.exception, traceback=ts.traceback, exception_text=ts.exception_text, traceback_text=ts.traceback_text, thread=thread, startstops=ts.startstops, stimulus_id=stimulus_id, ) @dataclass class ReleaseWorkerDataMsg(SendMessageToScheduler): op = "release-worker-data" __slots__ = ("key",) key: Key # Not to be confused with RescheduleEvent below or the distributed.Reschedule Exception @dataclass class RescheduleMsg(SendMessageToScheduler): op = "reschedule" __slots__ = ("key",) key: Key @dataclass class LongRunningMsg(SendMessageToScheduler): op = "long-running" __slots__ = ("key", "compute_duration") key: Key compute_duration: float | None @dataclass class AddKeysMsg(SendMessageToScheduler): op = "add-keys" __slots__ = ("keys",) keys: Collection[Key] @dataclass class RequestRefreshWhoHasMsg(SendMessageToScheduler): """Worker -> Scheduler asynchronous request for updated who_has information. Not to be confused with the scheduler.who_has synchronous RPC call, which is used by the Client. See also -------- RefreshWhoHasEvent distributed.scheduler.Scheduler.request_refresh_who_has distributed.client.Client.who_has distributed.scheduler.Scheduler.get_who_has """ op = "request-refresh-who-has" __slots__ = ("keys",) keys: Collection[Key] @dataclass class StealResponseMsg(SendMessageToScheduler): """Worker->Scheduler response to ``{op: steal-request}`` See also -------- StealRequestEvent """ op = "steal-response" __slots__ = ("key", "state") key: Key state: TaskStateState | None
[docs]@dataclass class StateMachineEvent: """Base abstract class for all stimuli that can modify the worker state""" __slots__ = ("stimulus_id", "handled") #: Unique ID of the event stimulus_id: str #: timestamp of when the event was handled by the worker # TODO Switch to @dataclass(slots=True), uncomment the line below, and remove the # __new__ method (requires Python >=3.10) # handled: float | None = field(init=False, default=None) _classes: ClassVar[dict[str, type[StateMachineEvent]]] = {} def __new__(cls, *args: Any, **kwargs: Any) -> StateMachineEvent: """Hack to initialize the ``handled`` attribute in Python <3.10""" self = object.__new__(cls) self.handled = None return self def __init_subclass__(cls) -> None: StateMachineEvent._classes[cls.__name__] = cls
[docs] def to_loggable(self, *, handled: float) -> StateMachineEvent: """Produce a variant version of self that is small enough to be stored in memory in the medium term and contains meaningful information for debugging """ self.handled: float | None = handled return self
def _to_dict(self, *, exclude: Container[str] = ()) -> dict: """Dictionary representation for debugging purposes. See also -------- distributed.utils.recursive_to_dict """ info = {"cls": type(self).__name__} for k in dir(self): if k in exclude or k.startswith("_"): continue if isinstance(getattr(type(self), k, None), property): continue v = getattr(self, k) if not callable(v): info[k] = v return recursive_to_dict(info, exclude=exclude)
[docs] @staticmethod def from_dict(d: dict) -> StateMachineEvent: """Convert the output of ``recursive_to_dict`` back into the original object. The output object is meaningful for the purpose of rebuilding the state machine, but not necessarily identical to the original. """ kwargs = d.copy() cls = StateMachineEvent._classes[kwargs.pop("cls")] handled = kwargs.pop("handled") inst = cls(**kwargs) inst.handled = handled inst._after_from_dict() return inst
def _after_from_dict(self) -> None: """Optional post-processing after an instance is created by ``from_dict``"""
@dataclass class PauseEvent(StateMachineEvent): __slots__ = () @dataclass class UnpauseEvent(StateMachineEvent): __slots__ = () @dataclass class RetryBusyWorkerEvent(StateMachineEvent): __slots__ = ("worker",) worker: str @dataclass class GatherDepDoneEvent(StateMachineEvent): """:class:`GatherDep` instruction terminated (abstract base class)""" __slots__ = ("worker", "total_nbytes") worker: str total_nbytes: int # Must be the same as in GatherDep instruction @dataclass class GatherDepSuccessEvent(GatherDepDoneEvent): """:class:`GatherDep` instruction terminated: remote worker fetched successfully """ __slots__ = ("data",) data: dict[Key, object] # There may be fewer keys than in GatherDep def to_loggable(self, *, handled: float) -> StateMachineEvent: out = copy(self) out.handled = handled out.data = {k: None for k in self.data} return out def _after_from_dict(self) -> None: self.data = {k: None for k in self.data} @dataclass class GatherDepBusyEvent(GatherDepDoneEvent): """:class:`GatherDep` instruction terminated: remote worker is busy """ __slots__ = () @dataclass class GatherDepNetworkFailureEvent(GatherDepDoneEvent): """:class:`GatherDep` instruction terminated: network failure while trying to communicate with remote worker """ __slots__ = () @dataclass class GatherDepFailureEvent(GatherDepDoneEvent): """class:`GatherDep` instruction terminated: generic error raised (not a network failure); e.g. data failed to deserialize. """ exception: Serialize traceback: Serialize | None exception_text: str traceback_text: str __slots__ = tuple(__annotations__) def _after_from_dict(self) -> None: self.exception = Serialize(Exception()) self.traceback = None @classmethod def from_exception( cls, err: BaseException, *, worker: str, total_nbytes: int, stimulus_id: str, ) -> GatherDepFailureEvent: msg = error_message(err) return cls( worker=worker, total_nbytes=total_nbytes, exception=msg["exception"], traceback=msg["traceback"], exception_text=msg["exception_text"], traceback_text=msg["traceback_text"], stimulus_id=stimulus_id, ) @dataclass class RemoveWorkerEvent(StateMachineEvent): worker: str __slots__ = ("worker",) @dataclass class ComputeTaskEvent(StateMachineEvent): key: Key run_id: int who_has: dict[Key, Collection[str]] nbytes: dict[Key, int] priority: tuple[int, ...] duration: float run_spec: T_runspec | None resource_restrictions: dict[str, float] actor: bool annotations: dict span_id: str | None __slots__ = tuple(__annotations__) def __post_init__(self) -> None: # Fixes after msgpack decode if isinstance(self.priority, list): # type: ignore[unreachable] self.priority = tuple(self.priority) # type: ignore[unreachable] if isinstance(self.run_spec, ToPickle): # FIXME Sometimes the protocol is not unpacking this # E.g. distributed/tests/test_client.py::test_async_with self.run_spec = self.run_spec.data # type: ignore[unreachable] def _to_dict(self, *, exclude: Container[str] = ()) -> dict: return StateMachineEvent._to_dict(self._clean(), exclude=exclude) def _clean(self) -> StateMachineEvent: out = copy(self) out.run_spec = None return out def to_loggable(self, *, handled: float) -> StateMachineEvent: out = self._clean() out.handled = handled return out def _after_from_dict(self) -> None: self.run_spec = None @classmethod def _f(cls) -> None: return # pragma: nocover @classmethod def dummy_runspec(cls) -> tuple[Callable, tuple, dict]: return (cls._f, (), {}) @staticmethod def dummy( key: Key, *, run_id: int = 0, who_has: dict[Key, Collection[str]] | None = None, nbytes: dict[Key, int] | None = None, priority: tuple[int, ...] = (0,), duration: float = 1.0, resource_restrictions: dict[str, float] | None = None, actor: bool = False, annotations: dict | None = None, stimulus_id: str, ) -> ComputeTaskEvent: """Build a dummy event, with most attributes set to a reasonable default. This is a convenience method to be used in unit testing only. """ return ComputeTaskEvent( key=key, run_id=run_id, who_has=who_has or {}, nbytes=nbytes or {k: 1 for k in who_has or ()}, priority=priority, duration=duration, run_spec=ComputeTaskEvent.dummy_runspec(), resource_restrictions=resource_restrictions or {}, actor=actor, annotations=annotations or {}, span_id=None, stimulus_id=stimulus_id, ) @dataclass class ExecuteDoneEvent(StateMachineEvent): """Abstract base event for all the possible outcomes of a :class:`Compute` instruction """ key: Key __slots__ = ("key",) @dataclass class ExecuteSuccessEvent(ExecuteDoneEvent): run_id: int # FIXME: Utilize the run ID in all ExecuteDoneEvents value: object start: float stop: float nbytes: int type: type | None __slots__ = tuple(__annotations__) def to_loggable(self, *, handled: float) -> StateMachineEvent: out = copy(self) out.handled = handled out.value = None return out def _to_dict(self, *, exclude: Container[str] = ()) -> dict: d = super()._to_dict(exclude=exclude) # This is excluded by the parent class as it is a callable if "type" not in exclude: d["type"] = str(self.type) return d def _after_from_dict(self) -> None: self.value = None self.type = None @staticmethod def dummy( key: Key, value: object = None, *, run_id: int = 1, nbytes: int = 1, stimulus_id: str, ) -> ExecuteSuccessEvent: """Build a dummy event, with most attributes set to a reasonable default. This is a convenience method to be used in unit testing only. """ return ExecuteSuccessEvent( key=key, run_id=run_id, value=value, start=0.0, stop=1.0, nbytes=nbytes, type=None, stimulus_id=stimulus_id, ) @dataclass class ExecuteFailureEvent(ExecuteDoneEvent): run_id: int # FIXME: Utilize the run ID in all ExecuteDoneEvents start: float | None stop: float | None exception: Serialize traceback: Serialize | None exception_text: str traceback_text: str __slots__ = tuple(__annotations__) def _after_from_dict(self) -> None: self.exception = Serialize(Exception()) self.traceback = None @classmethod def from_exception( cls, err_or_msg: BaseException | ErrorMessage, *, key: Key, run_id: int, start: float | None = None, stop: float | None = None, stimulus_id: str, ) -> ExecuteFailureEvent: if isinstance(err_or_msg, dict): msg = err_or_msg else: msg = error_message(err_or_msg) return cls( key=key, run_id=run_id, start=start, stop=stop, exception=msg["exception"], traceback=msg["traceback"], exception_text=msg["exception_text"], traceback_text=msg["traceback_text"], stimulus_id=stimulus_id, ) @staticmethod def dummy( key: Key, *, run_id: int = 1, stimulus_id: str, ) -> ExecuteFailureEvent: """Build a dummy event, with most attributes set to a reasonable default. This is a convenience method to be used in unit testing only. """ return ExecuteFailureEvent( key=key, run_id=run_id, start=None, stop=None, exception=Serialize(None), traceback=None, exception_text="", traceback_text="", stimulus_id=stimulus_id, ) # Not to be confused with RescheduleMsg above or the distributed.Reschedule Exception @dataclass class RescheduleEvent(ExecuteDoneEvent): __slots__ = () @staticmethod def dummy(key: Key, *, stimulus_id: str) -> RescheduleEvent: """Build an event. This method exists for compatibility with the other ExecuteDoneEvent subclasses. """ return RescheduleEvent(key=key, stimulus_id=stimulus_id) @dataclass class CancelComputeEvent(StateMachineEvent): __slots__ = ("key",) key: Key @dataclass class FindMissingEvent(StateMachineEvent): __slots__ = () @dataclass class RefreshWhoHasEvent(StateMachineEvent): """Scheduler -> Worker message containing updated who_has information. See also -------- RequestRefreshWhoHasMsg """ __slots__ = ("who_has",) # {key: [worker address, ...]} who_has: dict[Key, Collection[str]] @dataclass class AcquireReplicasEvent(StateMachineEvent): __slots__ = ("who_has", "nbytes") who_has: dict[Key, Collection[str]] nbytes: dict[Key, int] @dataclass class RemoveReplicasEvent(StateMachineEvent): __slots__ = ("keys",) keys: Collection[Key] @dataclass class FreeKeysEvent(StateMachineEvent): __slots__ = ("keys",) keys: Collection[Key] @dataclass class StealRequestEvent(StateMachineEvent): """Event that requests a worker to release a key because it's now being computed somewhere else. See also -------- StealResponseMsg """ __slots__ = ("key",) key: Key @dataclass class UpdateDataEvent(StateMachineEvent): __slots__ = ("data",) data: dict[Key, object] def to_loggable(self, *, handled: float) -> StateMachineEvent: out = copy(self) out.handled = handled out.data = dict.fromkeys(self.data) return out @dataclass class SecedeEvent(StateMachineEvent): __slots__ = ("key", "compute_duration") key: Key compute_duration: float # {TaskState -> finish: TaskStateState | (finish: TaskStateState, transition *args)} # Not to be confused with distributed.scheduler.Recs Recs: TypeAlias = dict[TaskState, Union[TaskStateState, tuple]] Instructions: TypeAlias = list[Instruction] RecsInstrs: TypeAlias = tuple[Recs, Instructions] def merge_recs_instructions(*args: RecsInstrs) -> RecsInstrs: """Merge multiple (recommendations, instructions) tuples. Collisions in recommendations are only allowed if identical. """ recs: Recs = {} instr: Instructions = [] for recs_i, instr_i in args: for ts, finish in recs_i.items(): if ts in recs and recs[ts] != finish: raise RecommendationsConflict( f"Mismatched recommendations for {ts.key!r}: {recs[ts]} vs. {finish}" ) recs[ts] = finish instr += instr_i return recs, instr
[docs]class WorkerState: """State machine encapsulating the lifetime of all tasks on a worker. Not to be confused with :class:`distributed.scheduler.WorkerState`. .. note:: The data attributes of this class are implementation details and may be changed without a deprecation cycle. .. warning:: The attributes of this class are all heavily correlated with each other. *Do not* modify them directly, *ever*, as it is extremely easy to obtain a broken state this way, which in turn will likely result in cluster-wide deadlocks. The state should be exclusively mutated through :meth:`handle_stimulus`. """ #: Worker <IP address>:<port>. This is used in decision-making by the state machine, #: e.g. to determine if a peer worker is running on the same host or not. #: This attribute may not be known when the WorkerState is initialised. It *must* be #: set before the first call to :meth:`handle_stimulus`. address: str #: ``{key: TaskState}``. The tasks currently executing on this worker (and any #: dependencies of those tasks) tasks: dict[Key, TaskState] #: ``{ts.key: thread ID}``. This collection is shared by reference between #: :class:`~distributed.worker.Worker` and this class. While the WorkerState is #: thread-agnostic, it still needs access to this information in some cases. #: This collection is populated by :meth:`distributed.worker.Worker.execute`. #: It does not *need* to be populated for the WorkerState to work. threads: dict[Key, int] #: In-memory tasks data. This collection is shared by reference between #: :class:`~distributed.worker.Worker`, #: :class:`~distributed.worker_memory.WorkerMemoryManager`, and this class. data: MutableMapping[Key, object] #: ``{name: worker plugin}``. This collection is shared by reference between #: :class:`~distributed.worker.Worker` and this class. The Worker managed adding and #: removing plugins, while the WorkerState invokes the ``WorkerPlugin.transition`` #: method, is available. plugins: dict[str, WorkerPlugin] #: Priority heap of tasks that are ready to run and have no resource constrains. #: Mutually exclusive with :attr:`constrained`. ready: HeapSet[TaskState] #: Priority heap of tasks that are ready to run, but are waiting on abstract #: resources like GPUs. Mutually exclusive with :attr:`ready`. #: See :attr:`available_resources` and :doc:`resources`. constrained: HeapSet[TaskState] #: Number of tasks that can be executing in parallel. #: At any given time, :meth:`executing_count` <= nthreads. nthreads: int #: True if the state machine should start executing more tasks and fetch #: dependencies whenever a slot is available. This property must be kept aligned #: with the Worker: ``WorkerState.running == (Worker.status is Status.running)``. running: bool #: Tasks that are currently waiting for data waiting: set[TaskState] #: ``{worker address: {ts.key, ...}``. #: The data that we care about that we think a worker has has_what: defaultdict[str, set[Key]] #: The tasks which still require data in order to execute and are in memory on at #: least another worker, prioritized as per-worker heaps. All and only tasks with #: ``TaskState.state == 'fetch'`` are in this collection. A :class:`TaskState` with #: multiple entries in :attr:`~TaskState.who_has` will appear multiple times here. data_needed: defaultdict[str, HeapSet[TaskState]] #: Total number of tasks in fetch state. If a task is in more than one data_needed #: heap, it's only counted once. fetch_count: int #: Number of bytes to gather from the same worker in a single call to #: :meth:`BaseWorker.gather_dep`. Multiple small tasks that can be gathered from the #: same worker will be batched in a single instruction as long as their combined #: size doesn't exceed this value. If the first task to be gathered exceeds this #: limit, it will still be gathered to ensure progress. Hence, this limit is not #: absolute. transfer_message_bytes_limit: float #: All and only tasks with ``TaskState.state == 'missing'``. missing_dep_flight: set[TaskState] #: Tasks that are coming to us in current peer-to-peer connections. #: #: This set includes exclusively tasks with :attr:`~TaskState.state` == 'flight' as #: well as tasks with :attr:`~TaskState.state` in ('cancelled', 'resumed') and #: :attr:`~TaskState.previous` == 'flight`. #: #: See also :meth:`in_flight_tasks_count`. in_flight_tasks: set[TaskState] #: ``{worker address: {ts.key, ...}}`` #: The workers from which we are currently gathering data and the dependencies we #: expect from those connections. Workers in this dict won't be asked for additional #: dependencies until the current query returns. in_flight_workers: dict[str, set[Key]] #: Current total size of open data transfers from other workers transfer_incoming_bytes: int #: Maximum number of concurrent incoming data transfers from other workers. #: See also :attr:`distributed.worker.Worker.transfer_outgoing_count_limit`. transfer_incoming_count_limit: int #: Total number of data transfers from other workers since the worker was started. transfer_incoming_count_total: int #: Ignore :attr:`transfer_incoming_count_limit` as long as :attr:`transfer_incoming_bytes` is #: less than this value. transfer_incoming_bytes_throttle_threshold: int #: Peer workers that recently returned a busy status. Workers in this set won't be #: asked for additional dependencies for some time. busy_workers: set[str] #: Counter that decreases every time the compute-task handler is invoked by the #: Scheduler. It is appended to :attr:`TaskState.priority` and acts as a #: tie-breaker between tasks that have the same priority on the Scheduler, #: determining a last-in-first-out order between them. generation: int #: ``{resource name: amount}``. Total resources available for task execution. #: See :doc: `resources`. total_resources: dict[str, float] #: ``{resource name: amount}``. Current resources that aren't being currently #: consumed by task execution. Always less or equal to :attr:`total_resources`. #: See :doc:`resources`. available_resources: dict[str, float] #: Set of tasks that are currently running. #: #: This set includes exclusively tasks with :attr:`~TaskState.state` == 'executing' #: as well as tasks with :attr:`~TaskState.state` in ('cancelled', 'resumed') and #: :attr:`~TaskState.previous` == 'executing`. #: #: See also :meth:`executing_count` and :attr:`long_running`. executing: set[TaskState] #: Set of tasks that are currently running and have called #: :func:`~distributed.secede`, so they no longer count towards the maximum number #: of concurrent tasks (nthreads). #: These tasks do not appear in the :attr:`executing` set. #: #: This set includes exclusively tasks with #: :attr:`~TaskState.state` == 'long-running' as well as tasks with #: :attr:`~TaskState.state` in ('cancelled', 'resumed') and #: :attr:`~TaskState.previous` == 'long-running`. long_running: set[TaskState] #: A number of tasks that this worker has run in its lifetime; this includes failed #: and cancelled tasks. See also :meth:`executing_count`. executed_count: int #: Total size of all tasks in memory nbytes: int #: Actor tasks. See :doc:`actors`. actors: dict[Key, object] #: Transition log: ``[(..., stimulus_id: str | None, timestamp: float), ...]`` #: The number of stimuli logged is capped. #: See also :meth:`story` and :attr:`stimulus_log`. log: deque[tuple] #: Log of all stimuli received by :meth:`handle_stimulus`. #: The number of events logged is capped. #: See also :attr:`log` and :meth:`stimulus_story`. stimulus_log: deque[StateMachineEvent] #: If True, enable expensive internal consistency check. #: Typically disabled in production. validate: bool #: Current number of tasks and cumulative elapsed time in each state, #: both broken down by :attr:`prefix` task_counter: TaskCounter #: Total number of state transitions so far. #: See also :attr:`log` and :attr:`transition_counter_max`. transition_counter: int #: Raise an error if the :attr:`transition_counter` ever reaches this value. #: This is meant for debugging only, to catch infinite recursion loops. #: In production, it should always be set to False. transition_counter_max: int | Literal[False] #: Limit of bytes for incoming data transfers; this is used for throttling. transfer_incoming_bytes_limit: float #: Statically-seeded random state, used to guarantee determinism whenever a #: pseudo-random choice is required rng: random.Random __slots__ = tuple(__annotations__) def __init__( self, *, nthreads: int = 1, address: str | None = None, data: MutableMapping[Key, object] | None = None, threads: dict[Key, int] | None = None, plugins: dict[str, WorkerPlugin] | None = None, resources: Mapping[str, float] | None = None, transfer_incoming_count_limit: int = 9999, validate: bool = True, transition_counter_max: int | Literal[False] = False, transfer_incoming_bytes_limit: float = math.inf, transfer_message_bytes_limit: float = math.inf, ): self.nthreads = nthreads # address may not be known yet when the State Machine is initialised. # Raise AttributeError if a method tries reading it before it's been set. if address: self.address = address # These collections are normally passed by reference by the Worker. # For the sake of convenience, create independent ones during unit tests. self.data = data if data is not None else {} self.threads = threads if threads is not None else {} self.plugins = plugins if plugins is not None else {} self.total_resources = dict(resources) if resources is not None else {} self.available_resources = self.total_resources.copy() self.validate = validate self.tasks = {} self.running = True self.waiting = set() self.has_what = defaultdict(set) self.data_needed = defaultdict( partial(HeapSet[TaskState], key=operator.attrgetter("priority")) ) self.fetch_count = 0 self.in_flight_workers = {} self.busy_workers = set() self.transfer_incoming_count_limit = transfer_incoming_count_limit self.transfer_incoming_count_total = 0 self.transfer_incoming_bytes_throttle_threshold = int(10e6) self.transfer_incoming_bytes = 0 self.missing_dep_flight = set() self.generation = 0 self.ready = HeapSet(key=operator.attrgetter("priority")) self.constrained = HeapSet(key=operator.attrgetter("priority")) self.executing = set() self.in_flight_tasks = set() self.nbytes = 0 self.executed_count = 0 self.long_running = set() self.transfer_message_bytes_limit = transfer_message_bytes_limit maxlen = dask.config.get("distributed.admin.low-level-log-length") self.log = deque(maxlen=maxlen) self.stimulus_log = deque(maxlen=maxlen) self.task_counter = TaskCounter() self.transition_counter = 0 self.transition_counter_max = transition_counter_max self.transfer_incoming_bytes_limit = transfer_incoming_bytes_limit self.actors = {} self.rng = random.Random(0)
[docs] def handle_stimulus(self, *stims: StateMachineEvent) -> Instructions: """Process one or more external events, transition relevant tasks to new states, and return a list of instructions to be executed as a consequence. See also -------- BaseWorker.handle_stimulus """ instructions = [] handled = time() for stim in stims: if not isinstance(stim, FindMissingEvent): self.stimulus_log.append(stim.to_loggable(handled=handled)) recs, instr = self._handle_event(stim) instructions += instr instructions += self._transitions(recs, stimulus_id=stim.stimulus_id) return instructions
############# # Accessors # ############# @property def executing_count(self) -> int: """Count of tasks currently executing on this worker and counting towards the maximum number of threads. It includes cancelled tasks, but does not include long running (a.k.a. seceded) tasks. See also -------- WorkerState.executing WorkerState.executed_count WorkerState.nthreads WorkerState.all_running_tasks """ return len(self.executing) @property def all_running_tasks(self) -> set[TaskState]: """All tasks that are currently occupying a thread. They may or may not count towards the maximum number of threads. These are: - ts.status in (executing, long-running) - ts.status in (cancelled, resumed) and ts.previous in (executing, long-running) See also -------- WorkerState.executing_count """ # Note: cancelled and resumed tasks are still in either of these sets return self.executing | self.long_running @property def in_flight_tasks_count(self) -> int: """Number of tasks currently being replicated from other workers to this one. See also -------- WorkerState.in_flight_tasks """ return len(self.in_flight_tasks) @property def transfer_incoming_count(self) -> int: """Current number of open data transfers from other workers. See also -------- WorkerState.in_flight_workers """ return len(self.in_flight_workers) ######################### # Shared helper methods # ######################### def _ensure_task_exists( self, key: Key, *, priority: tuple[int, ...], stimulus_id: str ) -> TaskState: try: ts = self.tasks[key] logger.debug("Data task %s already known (stimulus_id=%s)", ts, stimulus_id) except KeyError: self.tasks[key] = ts = TaskState(key) self.task_counter.new_task(ts) if not ts.priority: assert priority ts.priority = priority self.log.append((key, "ensure-task-exists", ts.state, stimulus_id, time())) return ts def _update_who_has(self, who_has: Mapping[Key, Collection[str]]) -> None: for key, workers in who_has.items(): ts = self.tasks.get(key) if not ts: # The worker sent a refresh-who-has request to the scheduler but, by the # time the answer comes back, some of the keys have been forgotten. continue workers = set(workers) if self.address in workers: workers.remove(self.address) # This can only happen if rebalance() recently asked to release a key, # but the RPC call hasn't returned yet. rebalance() is flagged as not # being safe to run while the cluster is not at rest and has already # been penned in to be redesigned on top of the AMM. # It is not necessary to send a message back to the # scheduler here, because it is guaranteed that there's already a # release-worker-data message in transit to it. if ts.state != "memory": logger.debug( # pragma: nocover "Scheduler claims worker %s holds data for task %s, " "which is not true.", self.address, ts, ) if ts.who_has == workers: continue for worker in ts.who_has - workers: self.has_what[worker].remove(key) if ts.state == "fetch": self.data_needed[worker].remove(ts) for worker in workers - ts.who_has: self.has_what[worker].add(key) if ts.state == "fetch": self.data_needed[worker].add(ts) ts.who_has = workers def _purge_state(self, ts: TaskState) -> None: """Ensure that TaskState attributes are reset to a neutral default and Worker-level state associated to the provided key is cleared (e.g. who_has) This is idempotent """ logger.debug("Purge task: %s", ts) # Do not use self.data.pop(key, None), as it could unspill the data from disk! if ts.key in self.data: del self.data[ts.key] self.actors.pop(ts.key, None) self.threads.pop(ts.key, None) for worker in ts.who_has: self.has_what[worker].discard(ts.key) self.data_needed[worker].discard(ts) ts.who_has.clear() for d in ts.dependencies: ts.waiting_for_data.discard(d) d.waiters.discard(ts) ts.waiting_for_data.clear() ts.nbytes = None ts.previous = None ts.next = None ts.done = False ts.coming_from = None ts.exception = None ts.traceback = None ts.traceback_text = "" ts.traceback_text = "" self.missing_dep_flight.discard(ts) self.ready.discard(ts) self.constrained.discard(ts) self.executing.discard(ts) self.long_running.discard(ts) self.in_flight_tasks.discard(ts) self.waiting.discard(ts) def _should_throttle_incoming_transfers(self) -> bool: """Decides whether the WorkerState should throttle data transfers from other workers. Returns ------- * True if the number of incoming data transfers reached its limit and the size of incoming data transfers reached the minimum threshold for throttling * True if the size of incoming data transfers reached its limit * False otherwise """ reached_count_limit = ( self.transfer_incoming_count >= self.transfer_incoming_count_limit ) reached_throttle_threshold = ( self.transfer_incoming_bytes >= self.transfer_incoming_bytes_throttle_threshold ) reached_bytes_limit = ( self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit ) return reached_count_limit and reached_throttle_threshold or reached_bytes_limit def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: """Transition tasks from fetch to flight, until there are no more tasks in fetch state or a threshold has been reached. """ if not self.running or not self.data_needed: return {}, [] if self._should_throttle_incoming_transfers(): return {}, [] recommendations: Recs = {} instructions: Instructions = [] for worker, available_tasks in self._select_workers_for_gather(): assert worker != self.address to_gather_tasks, message_nbytes = self._select_keys_for_gather( available_tasks ) # We always load at least one task assert to_gather_tasks or self.transfer_incoming_bytes # ...but that task might be selected in the previous iteration of the loop if not to_gather_tasks: break to_gather_keys = {ts.key for ts in to_gather_tasks} logger.debug( "Gathering %d tasks from %s; %d more remain. " "Pending workers: %d; connections: %d/%d; busy: %d", len(to_gather_tasks), worker, len(available_tasks), len(self.data_needed), self.transfer_incoming_count, self.transfer_incoming_count_limit, len(self.busy_workers), ) self.log.append( ("gather-dependencies", worker, to_gather_keys, stimulus_id, time()) ) for ts in to_gather_tasks: if self.validate: assert ts.state == "fetch" assert worker in ts.who_has assert ts not in recommendations recommendations[ts] = ("flight", worker) # A single invocation of _ensure_communicating may generate up to one # GatherDep instruction per worker. Multiple tasks from the same worker may # be batched in the same instruction by _select_keys_for_gather. But once # a worker has been selected for a GatherDep and added to in_flight_workers, # it won't be selected again until the gather completes. instructions.append( GatherDep( worker=worker, to_gather=to_gather_keys, total_nbytes=message_nbytes, stimulus_id=stimulus_id, ) ) self.in_flight_workers[worker] = to_gather_keys self.transfer_incoming_count_total += 1 self.transfer_incoming_bytes += message_nbytes if self._should_throttle_incoming_transfers(): break return recommendations, instructions def _select_workers_for_gather(self) -> Iterator[tuple[str, HeapSet[TaskState]]]: """Helper of _ensure_communicating. Yield the peer workers and tasks in data_needed, sorted by: 1. By highest-priority task available across all workers 2. If tied, first by local peer workers, then remote. Note that, if a task is replicated across multiple host, it may go in a tie with itself. 3. If still tied, by number of tasks available to be fetched from the host (see note below) 4. If still tied, by a random element. This is statically seeded to guarantee reproducibility. FIXME https://github.com/dask/distributed/issues/6620 You won't get determinism when a single task is replicated on multiple workers, because TaskState.who_has changes order at every interpreter restart. Omit workers that are either busy or in flight. Remove peer workers with no tasks from data_needed. Note ---- Instead of number of tasks, we could've measured total nbytes and/or number of tasks that only exist on the worker. Raw number of tasks is cruder but simpler. """ host = get_address_host(self.address) heap = [] for worker, tasks in list(self.data_needed.items()): if not tasks: del self.data_needed[worker] continue if worker in self.in_flight_workers or worker in self.busy_workers: continue heap.append( ( tasks.peek().priority, get_address_host(worker) != host, # False < True -len(tasks), self.rng.random(), worker, tasks, ) ) heapq.heapify(heap) while heap: _, is_remote, ntasks_neg, rnd, worker, tasks = heapq.heappop(heap) # The number of tasks and possibly the top priority task may have changed # since the last sort, since _select_keys_for_gather may have removed tasks # that are also replicated on a higher-priority worker. if not tasks: del self.data_needed[worker] elif -ntasks_neg != len(tasks): heapq.heappush( heap, (tasks.peek().priority, is_remote, -len(tasks), rnd, worker, tasks), ) else: yield worker, tasks if not tasks: # _select_keys_for_gather just emptied it del self.data_needed[worker] def _select_keys_for_gather( self, available: HeapSet[TaskState] ) -> tuple[list[TaskState], int]: """Helper of _ensure_communicating. Fetch all tasks that are replicated on the target worker within a single message, up to transfer_message_bytes_limit or until we reach the limit for the size of incoming data transfers. """ to_gather: list[TaskState] = [] message_nbytes = 0 while available: ts = available.peek() if self._task_exceeds_transfer_limits(ts, message_nbytes): break for worker in ts.who_has: # This also effectively pops from available self.data_needed[worker].remove(ts) to_gather.append(ts) message_nbytes += ts.get_nbytes() return to_gather, message_nbytes def _task_exceeds_transfer_limits(self, ts: TaskState, message_nbytes: int) -> bool: """Would asking to gather this task exceed transfer limits? Parameters ---------- ts Candidate task for gathering message_nbytes Total number of bytes already scheduled for gathering in this message Returns ------- exceeds_limit True if gathering the task would exceed limits, False otherwise (in which case the task can be gathered). """ if self.transfer_incoming_bytes == 0 and message_nbytes == 0: # When there is no other traffic, the top-priority task is fetched # regardless of its size to ensure progress return False incoming_bytes_allowance = ( self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes ) # If message_nbytes == 0, i.e., this is the first task to gather in this # message, ignore `self.transfer_message_bytes_limit` for the top-priority # task to ensure progress. Otherwise: if message_nbytes != 0: incoming_bytes_allowance = ( min( incoming_bytes_allowance, self.transfer_message_bytes_limit, ) - message_nbytes ) return ts.get_nbytes() > incoming_bytes_allowance def _ensure_computing(self) -> RecsInstrs: if not self.running: return {}, [] recs: Recs = {} while len(self.executing) < self.nthreads: ts = self._next_ready_task() if not ts: break if self.validate: assert ts.state in READY assert ts not in recs recs[ts] = "executing" self._acquire_resources(ts) self.executing.add(ts) return recs, [] def _next_ready_task(self) -> TaskState | None: """Pop the top-priority task from self.ready or self.constrained""" if self.ready and self.constrained: tsr = self.ready.peek() tsc = self.constrained.peek() assert tsr.priority assert tsc.priority if tsc.priority < tsr.priority and self._resource_restrictions_satisfied( tsc ): return self.constrained.pop() else: return self.ready.pop() elif self.ready: return self.ready.pop() elif self.constrained: tsc = self.constrained.peek() if self._resource_restrictions_satisfied(tsc): return self.constrained.pop() return None def _get_task_finished_msg( self, ts: TaskState, run_id: int, stimulus_id: str, ) -> TaskFinishedMsg: if self.validate: assert ts.state == "memory" assert ts.key in self.data or ts.key in self.actors assert ts.type is not None assert ts.nbytes is not None try: type_serialized = pickle.dumps(ts.type) except Exception: # Some types fail pickling (example: _thread.lock objects); # send their name as a best effort. type_serialized = pickle.dumps(typename(ts.type)) return TaskFinishedMsg( key=ts.key, run_id=run_id, nbytes=ts.nbytes, type=type_serialized, typename=typename(ts.type), metadata=ts.metadata, thread=self.threads.get(ts.key), startstops=ts.startstops, stimulus_id=stimulus_id, ) ############### # Transitions # ############### def _transition_generic_fetch(self, ts: TaskState, stimulus_id: str) -> RecsInstrs: if not ts.who_has: return {ts: "missing"}, [] ts.state = "fetch" ts.done = False self.fetch_count += 1 assert ts.priority for w in ts.who_has: self.data_needed[w].add(ts) return {}, [] def _transition_missing_waiting( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: self.missing_dep_flight.discard(ts) self._purge_state(ts) return self._transition_released_waiting(ts, stimulus_id=stimulus_id) def _transition_missing_fetch( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: if self.validate: assert ts.state == "missing" if not ts.who_has: return {}, [] self.missing_dep_flight.discard(ts) return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) def _transition_missing_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: self.missing_dep_flight.discard(ts) recs, instructions = self._transition_generic_released( ts, stimulus_id=stimulus_id ) assert ts.key in self.tasks return recs, instructions def _transition_flight_missing( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: assert ts.done return self._transition_generic_missing(ts, stimulus_id=stimulus_id) def _transition_generic_missing( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: if self.validate: assert not ts.who_has ts.state = "missing" self.missing_dep_flight.add(ts) ts.done = False return {}, [] def _transition_released_fetch( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: if self.validate: assert ts.state == "released" return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) def _transition_generic_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: self._purge_state(ts) recs: Recs = {} for dependency in ts.dependencies: if ( not dependency.waiters and dependency.state not in READY | PROCESSING | {"memory"} ): recs[dependency] = "released" ts.state = "released" if not ts.dependents: recs[ts] = "forgotten" return recs, [] def _transition_released_waiting( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: if self.validate: assert all(d.key in self.tasks for d in ts.dependencies) recommendations: Recs = {} ts.waiting_for_data.clear() for dep_ts in ts.dependencies: if dep_ts.state != "memory": ts.waiting_for_data.add(dep_ts) dep_ts.waiters.add(ts) recommendations[dep_ts] = "fetch" if not ts.waiting_for_data: recommendations[ts] = "ready" ts.state = "waiting" self.waiting.add(ts) return recommendations, [] def _transition_fetch_flight( self, ts: TaskState, worker: str, *, stimulus_id: str ) -> RecsInstrs: if self.validate: assert ts.state == "fetch" assert ts.who_has # The task has already been removed by _ensure_communicating for w in ts.who_has: assert ts not in self.data_needed[w] ts.done = False ts.state = "flight" ts.coming_from = worker self.in_flight_tasks.add(ts) self.fetch_count -= 1 return {}, [] def _transition_fetch_missing( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: self.fetch_count -= 1 return self._transition_generic_missing(ts, stimulus_id=stimulus_id) def _transition_fetch_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: self.fetch_count -= 1 return self._transition_generic_released(ts, stimulus_id=stimulus_id) def _transition_memory_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: assert ts.nbytes is not None self.nbytes -= ts.nbytes recs, instructions = self._transition_generic_released( ts, stimulus_id=stimulus_id ) instructions.append(ReleaseWorkerDataMsg(key=ts.key, stimulus_id=stimulus_id)) return recs, instructions def _transition_waiting_constrained( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: if self.validate: assert ts.state == "waiting" assert not ts.waiting_for_data assert all( dep.key in self.data or dep.key in self.actors for dep in ts.dependencies ) assert all(dep.state == "memory" for dep in ts.dependencies) assert ts not in self.ready assert ts not in self.constrained ts.state = "constrained" self.waiting.remove(ts) self.constrained.add(ts) return self._ensure_computing() def _transition_executing_rescheduled( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: """Note: this transition is triggered exclusively by a task raising the Reschedule() Exception; it is not involved in work stealing. """ assert ts.done return merge_recs_instructions( ({}, [RescheduleMsg(key=ts.key, stimulus_id=stimulus_id)]), # Note: this is not the same as recommending {ts: "released"} on the # previous line, as it would instead run the ("executing", "released") # transition, which would need special code for ts.done=True. self._transition_generic_released(ts, stimulus_id=stimulus_id), ) def _transition_waiting_ready( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: if self.validate: assert ts.state == "waiting" assert ts not in self.ready assert ts not in self.constrained assert not ts.waiting_for_data for dep in ts.dependencies: assert dep.key in self.data or dep.key in self.actors assert dep.state == "memory" if ts.resource_restrictions: return {ts: "constrained"}, [] ts.state = "ready" assert ts.priority is not None self.waiting.remove(ts) self.ready.add(ts) return self._ensure_computing() def _transition_generic_error( self, ts: TaskState, exception: Serialize, traceback: Serialize | None, exception_text: str, traceback_text: str, run_id: int, *, stimulus_id: str, ) -> RecsInstrs: ts.exception = exception ts.traceback = traceback ts.exception_text = exception_text ts.traceback_text = traceback_text ts.state = "error" smsg = TaskErredMsg.from_task( ts, run_id=run_id, stimulus_id=stimulus_id, thread=self.threads.get(ts.key), ) return {}, [smsg] def _transition_resumed_error( self, ts: TaskState, exception: Serialize, traceback: Serialize | None, exception_text: str, traceback_text: str, run_id: int, *, stimulus_id: str, ) -> RecsInstrs: """In case of failure of the previous state, discard the error and kick off the next state without informing the scheduler """ assert ts.done if ts.previous in ("executing", "long-running"): assert ts.next == "fetch" recs: Recs = {ts: "fetch"} else: assert ts.previous == "flight" assert ts.next == "waiting" recs = {ts: "waiting"} ts.state = "released" ts.done = False ts.previous = None ts.next = None return recs, [] def _transition_resumed_rescheduled( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: """If the task raises the Reschedule() exception, but the scheduler already told the worker to fetch it somewhere else, silently transition to fetch. Note that this transition effectively duplicates the logic of _transition_resumed_error. """ assert ts.done assert ts.previous in ("executing", "long-running") assert ts.next == "fetch" ts.state = "released" ts.done = False ts.previous = None ts.next = None return {ts: "fetch"}, [] def _transition_resumed_fetch( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: """ See also -------- _transition_cancelled_fetch _transition_cancelled_waiting _transition_flight_fetch """ if ts.previous == "flight": if self.validate: assert ts.next == "waiting" if ts.done: # We arrived here either from GatherDepNetworkFailureEvent or from # GatherDepSuccessEvent but without the key in the data attribute. # We would now normally try to fetch the task from another peer worker # or transition it to missing if none are left; here instead we're going # to compute the task as we had been asked by the scheduler. ts.state = "released" ts.done = False ts.previous = None ts.next = None return {ts: "waiting"}, [] else: # We're back where we started. We should forget about the entire # cancellation attempt ts.state = "flight" ts.previous = None ts.next = None elif self.validate: assert ts.previous in ("executing", "long-running") assert ts.next == "fetch" # None of the exit events of execute recommend a transition to fetch assert not ts.done return {}, [] def _transition_resumed_missing( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: return {ts: "fetch"}, [] def _transition_resumed_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: # None of the exit events of execute or gather_dep recommend a transition to # released assert not ts.done ts.state = "cancelled" ts.next = None return {}, [] def _transition_cancelled_fetch( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: """ See also -------- _transition_cancelled_waiting _transition_resumed_fetch """ if ts.previous == "flight": if ts.done: # gather_dep just completed for a cancelled task. # Discard output and possibly forget return {ts: "released"}, [] else: # Forget the task was cancelled to begin with ts.state = "flight" ts.previous = None return {}, [] else: assert ts.previous in ("executing", "long-running") # None of the exit events of execute recommend a transition to fetch assert not ts.done ts.state = "resumed" ts.next = "fetch" return {}, [] def _transition_cancelled_waiting( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: """ See also -------- _transition_cancelled_fetch _transition_cancelled_or_resumed_long_running _transition_resumed_fetch """ # None of the exit events of gather_dep or execute recommend a transition to # waiting assert not ts.done if ts.previous == "executing": # Forget the task was cancelled to begin with ts.state = "executing" ts.previous = None return {}, [] elif ts.previous == "long-running": # Forget the task was cancelled to begin with, and inform the scheduler # in arrears that it has seceded. # Note that, if the task seceded before it was cancelled, this will cause # the message to be sent twice. ts.state = "long-running" ts.previous = None smsg = LongRunningMsg( key=ts.key, compute_duration=None, stimulus_id=stimulus_id ) return {}, [smsg] else: assert ts.previous == "flight" ts.state = "resumed" ts.next = "waiting" return {}, [] def _transition_cancelled_released( self, ts: TaskState, *args: Any, # extra arguments of transitions to memory or error - ignored stimulus_id: str, ) -> RecsInstrs: if not ts.done: return {}, [] ts.previous = None ts.done = False return self._transition_generic_released(ts, stimulus_id=stimulus_id) def _transition_executing_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: """We can't stop executing a task just because the scheduler asked us to, so we're entering cancelled state and waiting until it completes. """ if self.validate: assert ts.state in ("executing", "long-running") assert not ts.next assert not ts.done ts.previous = cast(Literal["executing", "long-running"], ts.state) ts.state = "cancelled" return {}, [] def _transition_constrained_executing( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: if self.validate: assert ts.state == "constrained" assert not ts.waiting_for_data assert ts.key not in self.data assert ts not in self.ready assert ts not in self.constrained for dep in ts.dependencies: assert dep.key in self.data or dep.key in self.actors ts.state = "executing" instr = Execute(key=ts.key, stimulus_id=stimulus_id) return {}, [instr] def _transition_ready_executing( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: if self.validate: assert ts.state == "ready" assert not ts.waiting_for_data assert ts.key not in self.data assert ts not in self.ready assert ts not in self.constrained assert all( dep.key in self.data or dep.key in self.actors for dep in ts.dependencies ) ts.state = "executing" instr = Execute(key=ts.key, stimulus_id=stimulus_id) return {}, [instr] def _transition_flight_fetch( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: # If this transition is called after the flight coroutine has finished, # we can reset the task and transition to fetch again. If it is not yet # finished, this should be a no-op if not ts.done: return {}, [] return self._transition_generic_fetch(ts, stimulus_id=stimulus_id) def _transition_flight_released( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: # None of the exit events of gather_dep recommend a transition to released assert not ts.done ts.previous = "flight" ts.next = None # See https://github.com/dask/distributed/pull/5046#discussion_r685093940 ts.state = "cancelled" return {}, [] def _transition_executing_long_running( self, ts: TaskState, compute_duration: float, *, stimulus_id: str ) -> RecsInstrs: """ See also -------- _transition_cancelled_or_resumed_long_running """ ts.state = "long-running" self.executing.discard(ts) self.long_running.add(ts) smsg = LongRunningMsg( key=ts.key, compute_duration=compute_duration, stimulus_id=stimulus_id ) return merge_recs_instructions( ({}, [smsg]), self._ensure_computing(), ) def _transition_cancelled_or_resumed_long_running( self, ts: TaskState, compute_duration: float, *, stimulus_id: str ) -> RecsInstrs: """Handles transitions: - cancelled(executing) -> long-running - cancelled(long-running) -> long-running (user called secede() twice) - resumed(executing->fetch) -> long-running - resumed(long-running->fetch) -> long-running (user called secede() twice) Unlike in the executing->long_running transition, do not send LongRunningMsg. From the scheduler's perspective, this task no longer exists (cancelled) or is in memory on another worker (resumed). So it shouldn't hear about it. Instead, we're going to send the LongRunningMsg when and if the task transitions back to waiting. See also -------- _transition_executing_long_running _transition_cancelled_waiting """ assert ts.previous in ("executing", "long-running") ts.previous = "long-running" self.executing.discard(ts) self.long_running.add(ts) return self._ensure_computing() def _transition_executing_memory( self, ts: TaskState, value: object, run_id: int, *, stimulus_id: str ) -> RecsInstrs: """This transition is *normally* triggered by ExecuteSuccessEvent. However, beware that it can also be triggered by scatter(). """ return self._transition_to_memory( ts, value, "task-finished", run_id=run_id, stimulus_id=stimulus_id ) def _transition_released_memory( self, ts: TaskState, value: object, run_id: int, *, stimulus_id: str ) -> RecsInstrs: """This transition is triggered by scatter(). This transition does not send any message back to the scheduler, because the scheduler doesn't know this key exists yet. """ return self._transition_to_memory( ts, value, False, run_id=RUN_ID_SENTINEL, stimulus_id=stimulus_id ) def _transition_flight_memory( self, ts: TaskState, value: object, run_id: int, *, stimulus_id: str ) -> RecsInstrs: """This transition is *normally* triggered by GatherDepSuccessEvent. However, beware that it can also be triggered by scatter(). """ return self._transition_to_memory( ts, value, "add-keys", run_id=RUN_ID_SENTINEL, stimulus_id=stimulus_id ) def _transition_resumed_memory( self, ts: TaskState, value: object, run_id: int, *, stimulus_id: str ) -> RecsInstrs: """Normally, we send to the scheduler a 'task-finished' message for a completed execution and 'add-data' for a completed replication from another worker. The scheduler's reaction to the two messages is fundamentally different; namely, add-data is only admissible for tasks that are already in memory on another worker, and won't trigger transitions. In the case of resumed tasks, the scheduler's expectation is set by ts.next - which means, the opposite of what the worker actually just completed. """ msg_type: Literal["add-keys", "task-finished"] if ts.previous in ("executing", "long-running"): assert ts.next == "fetch" msg_type = "add-keys" else: assert ts.previous == "flight" assert ts.next == "waiting" msg_type = "task-finished" ts.previous = None ts.next = None return self._transition_to_memory( ts, value, msg_type, run_id=run_id, stimulus_id=stimulus_id ) def _transition_to_memory( self, ts: TaskState, value: object, msg_type: Literal[False, "add-keys", "task-finished"], run_id: int, *, stimulus_id: str, ) -> RecsInstrs: """Insert a task's output in self.data and set the state to memory. This method is the one and only place where keys are inserted in self.data. There are three ways to get here: 1. task execution just terminated successfully. Initial state is one of - executing - long-running - resumed(prev=executing next=fetch) - resumed(prev=long-running next=fetch) 2. transfer from another worker terminated successfully. Initial state is - flight - resumed(prev=flight next=waiting) 3. scatter. In this case *normally* the task is in released state, but nothing stops a client to scatter a key while is in any other state; these race conditions are not well tested and are expected to misbehave. """ recommendations: Recs = {} instructions: Instructions = [] if self.validate: assert ts.key not in self.data assert ts.state != "memory" if ts.key in self.actors: self.actors[ts.key] = value else: start = time() try: self.data[ts.key] = value except Exception as e: # distributed.worker.memory.target is enabled and the value is # individually larger than target * memory_limit. # Inserting this task in the SpillBuffer caused it to immediately # spilled to disk, and it failed to serialize. # Third-party MutableMappings (dask-cuda etc.) may have other use cases # for this. msg = error_message(e) return {ts: tuple(msg.values()) + (run_id,)}, [] stop = time() if stop - start > 0.005: # The SpillBuffer has spilled this task (if larger than target) or other # tasks (if smaller than target) to disk. # Alternatively, a third-party MutableMapping may have spent time in # other activities, e.g. transferring data between GPGPU and system # memory. ts.startstops.append( {"action": "disk-write", "start": start, "stop": stop} ) ts.state = "memory" if ts.nbytes is None: ts.nbytes = sizeof(value) self.nbytes += ts.nbytes ts.type = type(value) for dep in ts.dependents: dep.waiting_for_data.discard(ts) if not dep.waiting_for_data and dep.state == "waiting": recommendations[dep] = "ready" self.log.append((ts.key, "put-in-memory", stimulus_id, time())) # NOTE: The scheduler's reaction to these two messages is fundamentally # different. Namely, add-keys is only admissible for tasks that are already in # memory on another worker, and won't trigger transitions. if msg_type == "add-keys": instructions.append(AddKeysMsg(keys=[ts.key], stimulus_id=stimulus_id)) elif msg_type == "task-finished": assert run_id != RUN_ID_SENTINEL instructions.append( self._get_task_finished_msg( ts, run_id=run_id, stimulus_id=stimulus_id, ) ) else: # This happens on scatter(), where the scheduler doesn't know yet that the # key exists. assert msg_type is False return recommendations, instructions def _transition_released_forgotten( self, ts: TaskState, *, stimulus_id: str ) -> RecsInstrs: recommendations: Recs = {} # Dependents _should_ be released by the scheduler before this if self.validate: assert not any(d.state != "forgotten" for d in ts.dependents) for dep in ts.dependencies: dep.dependents.discard(ts) if dep.state == "released" and not dep.dependents: recommendations[dep] = "forgotten" self._purge_state(ts) # Mark state as forgotten in case it is still referenced ts.state = "forgotten" self.tasks.pop(ts.key, None) return recommendations, [] # { # (start, finish): # transition_<start>_<finish>( # self, ts: TaskState, *args, stimulus_id: str # ) -> (recommendations, instructions) # } _TRANSITIONS_TABLE: ClassVar[ Mapping[tuple[TaskStateState, TaskStateState], Callable[..., RecsInstrs]] ] = { ("cancelled", "error"): _transition_cancelled_released, ("cancelled", "fetch"): _transition_cancelled_fetch, ("cancelled", "long-running"): _transition_cancelled_or_resumed_long_running, ("cancelled", "memory"): _transition_cancelled_released, ("cancelled", "missing"): _transition_cancelled_released, ("cancelled", "released"): _transition_cancelled_released, ("cancelled", "rescheduled"): _transition_cancelled_released, ("cancelled", "waiting"): _transition_cancelled_waiting, ("resumed", "error"): _transition_resumed_error, ("resumed", "fetch"): _transition_resumed_fetch, ("resumed", "long-running"): _transition_cancelled_or_resumed_long_running, ("resumed", "memory"): _transition_resumed_memory, ("resumed", "released"): _transition_resumed_released, ("resumed", "rescheduled"): _transition_resumed_rescheduled, ("constrained", "executing"): _transition_constrained_executing, ("constrained", "released"): _transition_generic_released, ("error", "released"): _transition_generic_released, ("executing", "error"): _transition_generic_error, ("executing", "long-running"): _transition_executing_long_running, ("executing", "memory"): _transition_executing_memory, ("executing", "released"): _transition_executing_released, ("executing", "rescheduled"): _transition_executing_rescheduled, ("fetch", "flight"): _transition_fetch_flight, ("fetch", "missing"): _transition_fetch_missing, ("fetch", "released"): _transition_fetch_released, ("flight", "error"): _transition_generic_error, ("flight", "fetch"): _transition_flight_fetch, ("flight", "memory"): _transition_flight_memory, ("flight", "missing"): _transition_flight_missing, ("flight", "released"): _transition_flight_released, ("long-running", "error"): _transition_generic_error, ("long-running", "memory"): _transition_executing_memory, ("long-running", "rescheduled"): _transition_executing_rescheduled, ("long-running", "released"): _transition_executing_released, ("memory", "released"): _transition_memory_released, ("missing", "error"): _transition_generic_error, ("missing", "fetch"): _transition_missing_fetch, ("missing", "released"): _transition_missing_released, ("missing", "waiting"): _transition_missing_waiting, ("ready", "executing"): _transition_ready_executing, ("ready", "released"): _transition_generic_released, ("released", "error"): _transition_generic_error, ("released", "fetch"): _transition_released_fetch, ("released", "forgotten"): _transition_released_forgotten, ("released", "memory"): _transition_released_memory, ("released", "missing"): _transition_generic_missing, ("released", "waiting"): _transition_released_waiting, ("waiting", "constrained"): _transition_waiting_constrained, ("waiting", "ready"): _transition_waiting_ready, ("waiting", "released"): _transition_generic_released, } def _notify_plugins(self, method_name: str, *args: Any, **kwargs: Any) -> None: for name, plugin in self.plugins.items(): if hasattr(plugin, method_name): try: getattr(plugin, method_name)(*args, **kwargs) except Exception: logger.info( "Plugin '%s' failed with exception", name, exc_info=True ) def _transition( self, ts: TaskState, finish: TaskStateState | tuple, *args: Any, stimulus_id: str, ) -> RecsInstrs: """Transition a key from its current state to the finish state See Also -------- Worker.transitions: wrapper around this method """ if isinstance(finish, tuple): # the concatenated transition path might need to access the tuple assert not args args = finish[1:] finish = cast(TaskStateState, finish[0]) if ts.state == finish: return {}, [] start = ts.state func = self._TRANSITIONS_TABLE.get((start, finish)) # Notes: # - in case of transition through released, this counter is incremented by 2 # - this increase happens before the actual transitions, so that it can # catch potential infinite recursions self.transition_counter += 1 if ( self.transition_counter_max and self.transition_counter >= self.transition_counter_max ): raise TransitionCounterMaxExceeded(ts.key, start, finish, self.story(ts)) if func is not None: recs, instructions = func(self, ts, *args, stimulus_id=stimulus_id) self._notify_plugins("transition", ts.key, start, finish) elif "released" not in (start, finish): # start -> "released" -> finish try: recs, instructions = self._transition( ts, "released", stimulus_id=stimulus_id ) v_state: TaskStateState v_args: list | tuple while v := recs.pop(ts, None): if isinstance(v, tuple): v_state, *v_args = v else: v_state, v_args = v, () if v_state == "forgotten": # We do not want to forget. The purpose of this # transition path is to get to `finish` continue recs, instructions = merge_recs_instructions( (recs, instructions), self._transition(ts, v_state, *v_args, stimulus_id=stimulus_id), ) recs, instructions = merge_recs_instructions( (recs, instructions), self._transition(ts, finish, *args, stimulus_id=stimulus_id), ) except (InvalidTransition, RecommendationsConflict) as e: raise InvalidTransition(ts.key, start, finish, self.story(ts)) from e else: raise InvalidTransition(ts.key, start, finish, self.story(ts)) self.log.append( ( # key ts.key, # initial start, # recommended finish, # final ts.state, # new recommendations { ts.key: new[0] if isinstance(new, tuple) else new for ts, new in recs.items() }, stimulus_id, time(), ) ) return recs, instructions def _resource_restrictions_satisfied(self, ts: TaskState) -> bool: if not ts.resource_restrictions: return True return all( self.available_resources[resource] >= needed for resource, needed in ts.resource_restrictions.items() ) def _acquire_resources(self, ts: TaskState) -> None: if ts.resource_restrictions: for resource, needed in ts.resource_restrictions.items(): self.available_resources[resource] -= needed def _release_resources(self, ts: TaskState) -> None: if ts.resource_restrictions: for resource, needed in ts.resource_restrictions.items(): self.available_resources[resource] += needed def _transitions(self, recommendations: Recs, *, stimulus_id: str) -> Instructions: """Process transitions until none are left This includes feedback from previous transitions and continues until we reach a steady state """ instructions = [] tasks = set() initial_states: dict[TaskState, TaskStateState] = {} def process_recs(recs: Recs) -> None: while recs: ts, finish = recs.popitem() tasks.add(ts) initial_states.setdefault(ts, ts.state) a_recs, a_instructions = self._transition( ts, finish, stimulus_id=stimulus_id ) recs.update(a_recs) instructions.extend(a_instructions) process_recs(recommendations.copy()) # We could call _ensure_communicating after we change something that could # trigger a new call to gather_dep (e.g. on transitions to fetch, # GatherDepDoneEvent, or RetryBusyWorkerEvent). However, doing so we'd # potentially call it too early, before all tasks have transitioned to fetch. # This in turn would hurt aggregation of multiple tasks into a single GatherDep # instruction. # Read: https://github.com/dask/distributed/issues/6497 a_recs, a_instructions = self._ensure_communicating(stimulus_id=stimulus_id) instructions += a_instructions process_recs(a_recs) self.task_counter.transitions(initial_states) if self.validate: # Full state validation is very expensive for ts in tasks: self.validate_task(ts) return instructions ########## # Events # ########## @singledispatchmethod def _handle_event(self, ev: StateMachineEvent) -> RecsInstrs: raise TypeError(ev) # pragma: nocover @_handle_event.register def _handle_update_data(self, ev: UpdateDataEvent) -> RecsInstrs: recommendations: Recs = {} for key, value in ev.data.items(): try: ts = self.tasks[key] except KeyError: self.tasks[key] = ts = TaskState(key) self.task_counter.new_task(ts) recommendations[ts] = ("memory", value, RUN_ID_SENTINEL) self.log.append( (key, "receive-from-scatter", ts.state, ev.stimulus_id, time()) ) return recommendations, [] @_handle_event.register def _handle_free_keys(self, ev: FreeKeysEvent) -> RecsInstrs: """Handler to be called by the scheduler. The given keys are no longer referred to and required by the scheduler. The worker is now allowed to release the key, if applicable. This does not guarantee that the memory is released since the worker may still decide to hold on to the data and task since it is required by an upstream dependency. """ self.log.append(("free-keys", ev.keys, ev.stimulus_id, time())) recommendations: Recs = {} for key in ev.keys: ts = self.tasks.get(key) if ts: recommendations[ts] = "released" return recommendations, [] @_handle_event.register def _handle_remove_replicas(self, ev: RemoveReplicasEvent) -> RecsInstrs: """Stream handler notifying the worker that it might be holding unreferenced, superfluous data. This should not actually happen during ordinary operations and is only intended to correct any erroneous state. An example where this is necessary is if a worker fetches data for a downstream task but that task is released before the data arrives. In this case, the scheduler will notify the worker that it may be holding this unnecessary data, if the worker hasn't released the data itself, already. This handler only releases tasks that are indeed in state memory. For stronger guarantees, see handler free_keys """ recommendations: Recs = {} instructions: Instructions = [] for key in ev.keys: ts = self.tasks.get(key) if ts is None or ts.state != "memory": continue # If the task is still in executing or long-running, the scheduler # should never have asked the worker to drop this key. # We cannot simply forget it because there is a time window between # setting the state to executing/long-running and # preparing/collecting the data for the task. # If a dependency was released during this time, this would pop up # as a KeyError during execute which is hard to understand for dep in ts.dependents: if dep.state in ("executing", "long-running"): raise RuntimeError( f"Cannot remove replica of {ts.key!r} while {dep.key!r} in state {dep.state!r}." ) # pragma: no cover self.log.append((ts.key, "remove-replica", ev.stimulus_id, time())) recommendations[ts] = "released" return recommendations, instructions @_handle_event.register def _handle_acquire_replicas(self, ev: AcquireReplicasEvent) -> RecsInstrs: if self.validate: assert ev.who_has.keys() == ev.nbytes.keys() assert all(ev.who_has.values()) recommendations: Recs = {} for key, nbytes in ev.nbytes.items(): ts = self._ensure_task_exists( key=key, # Transfer this data after all dependency tasks of computations with # default or explicitly high (>0) user priority and before all # computations with low priority (<0). Note that the priority= parameter # of compute() is multiplied by -1 before it reaches TaskState.priority. priority=(1,), stimulus_id=ev.stimulus_id, ) if ts.state != "memory": ts.nbytes = nbytes recommendations[ts] = "fetch" self._update_who_has(ev.who_has) return recommendations, [] @_handle_event.register def _handle_compute_task(self, ev: ComputeTaskEvent) -> RecsInstrs: try: ts = self.tasks[ev.key] logger.debug( "Asked to compute an already known task %s", {"task": ts, "stimulus_id": ev.stimulus_id}, ) except KeyError: self.tasks[ev.key] = ts = TaskState(ev.key) self.task_counter.new_task(ts) self.log.append((ev.key, "compute-task", ts.state, ev.stimulus_id, time())) ts.run_id = ev.run_id recommendations: Recs = {} instructions: Instructions = [] if ts.state in READY | { "executing", "long-running", "waiting", }: pass elif ts.state == "memory": instructions.append( self._get_task_finished_msg( ts, run_id=ev.run_id, stimulus_id=ev.stimulus_id ) ) elif ts.state in { "released", "fetch", "flight", "missing", "cancelled", "resumed", "error", }: recommendations[ts] = "waiting" ts.run_spec = ev.run_spec priority = ev.priority + (self.generation,) self.generation -= 1 if ev.actor: self.actors[ts.key] = None ts.exception = None ts.traceback = None ts.exception_text = "" ts.traceback_text = "" ts.priority = priority ts.duration = ev.duration ts.annotations = ev.annotations ts.span_id = ev.span_id # If we receive ComputeTaskEvent twice for the same task, resources may have # changed, but the task is still running. Preserve the previous resource # restrictions so that they can be properly released when it eventually # completes. if not ( ts.state in ("cancelled", "resumed") and ts.previous in ("executing", "long-running") ): ts.resource_restrictions = ev.resource_restrictions if self.validate: assert ev.who_has.keys() == ev.nbytes.keys() for dep_workers in ev.who_has.values(): assert dep_workers assert len(dep_workers) == len(set(dep_workers)) for dep_key, nbytes in ev.nbytes.items(): dep_ts = self._ensure_task_exists( key=dep_key, priority=priority, stimulus_id=ev.stimulus_id, ) if dep_ts.state != "memory": dep_ts.nbytes = nbytes # link up to child / parents ts.dependencies.add(dep_ts) dep_ts.dependents.add(ts) self._update_who_has(ev.who_has) else: raise RuntimeError( # pragma: nocover f"Unexpected task state encountered for {ts}; " f"stimulus_id={ev.stimulus_id}; story={self.story(ts)}" ) return recommendations, instructions def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState]: """Common code for the handlers of all subclasses of GatherDepDoneEvent. Yields the tasks that need to transition out of flight. The task states can be flight, cancelled, or resumed, but in case of scatter() they can also be in memory or error states. See also -------- _execute_done_common """ self.transfer_incoming_bytes -= ev.total_nbytes keys = self.in_flight_workers.pop(ev.worker) for key in keys: ts = self.tasks[key] ts.done = True ts.coming_from = None self.in_flight_tasks.remove(ts) yield ts @_handle_event.register def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs: """gather_dep terminated successfully. The response may contain fewer keys than the request. """ recommendations: Recs = {} for ts in self._gather_dep_done_common(ev): if ts.key in ev.data: recommendations[ts] = ("memory", ev.data[ts.key], ts.run_id) else: self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) if self.validate: assert ts.state != "fetch" assert ts not in self.data_needed[ev.worker] ts.who_has.discard(ev.worker) if ev.worker in self.has_what: self.has_what[ev.worker].discard(ts.key) recommendations[ts] = "fetch" return recommendations, [] @_handle_event.register def _handle_gather_dep_busy(self, ev: GatherDepBusyEvent) -> RecsInstrs: """gather_dep terminated: remote worker is busy""" # Avoid hammering the worker. If there are multiple replicas # available, immediately try fetching from a different worker. self.busy_workers.add(ev.worker) recommendations: Recs = {} refresh_who_has = [] for ts in self._gather_dep_done_common(ev): recommendations[ts] = "fetch" if not ts.who_has - self.busy_workers: refresh_who_has.append(ts.key) instructions: Instructions = [ RetryBusyWorkerLater(worker=ev.worker, stimulus_id=ev.stimulus_id), ] if refresh_who_has: # All workers that hold known replicas of our tasks are busy. # Try querying the scheduler for unknown ones. instructions.append( RequestRefreshWhoHasMsg( keys=refresh_who_has, stimulus_id=ev.stimulus_id ) ) return recommendations, instructions @_handle_event.register def _handle_gather_dep_network_failure( self, ev: GatherDepNetworkFailureEvent ) -> RecsInstrs: """gather_dep terminated: network failure while trying to communicate with remote worker Though the network failure could be transient, we assume it is not, and preemptively act as though the other worker has died (including removing all keys from it, even ones we did not fetch). This optimization leads to faster completion of the fetch, since we immediately either retry a different worker, or ask the scheduler to inform us of a new worker if no other worker is available. """ recommendations: Recs = {} for ts in self._gather_dep_done_common(ev): self.log.append((ts.key, "missing-dep", ev.stimulus_id, time())) recommendations[ts] = "fetch" for ts in self.data_needed.pop(ev.worker, ()): if self.validate: assert ts.state == "fetch" assert ev.worker in ts.who_has if ts.who_has == {ev.worker}: # This can override a recommendation from the previous for loop recommendations[ts] = "missing" for key in self.has_what.pop(ev.worker, ()): ts = self.tasks[key] ts.who_has.remove(ev.worker) return recommendations, [] @_handle_event.register def _handle_gather_dep_failure(self, ev: GatherDepFailureEvent) -> RecsInstrs: """gather_dep terminated: generic error raised (not a network failure); e.g. data failed to deserialize. """ recommendations: Recs = { ts: ( "error", ev.exception, ev.traceback, ev.exception_text, ev.traceback_text, ts.run_id, ) for ts in self._gather_dep_done_common(ev) } return recommendations, [] @_handle_event.register def _handle_remove_worker(self, ev: RemoveWorkerEvent) -> RecsInstrs: recommendations: Recs = {} for ts in self.data_needed.pop(ev.worker, ()): if self.validate: assert ts.state == "fetch" assert ev.worker in ts.who_has if ts.who_has == {ev.worker}: # This can override a recommendation from the previous for loop recommendations[ts] = "missing" for key in self.has_what.pop(ev.worker, ()): ts = self.tasks[key] ts.who_has.remove(ev.worker) return recommendations, [] @_handle_event.register def _handle_secede(self, ev: SecedeEvent) -> RecsInstrs: ts = self.tasks.get(ev.key) if not ts: return {}, [] return {ts: ("long-running", ev.compute_duration)}, [] @_handle_event.register def _handle_steal_request(self, ev: StealRequestEvent) -> RecsInstrs: # There may be a race condition between stealing and releasing a task. # In this case the self.tasks is already cleared. The `None` will be # registered as `already-computing` on the other end ts = self.tasks.get(ev.key) state = ts.state if ts is not None else None smsg = StealResponseMsg(key=ev.key, state=state, stimulus_id=ev.stimulus_id) if state in READY | {"waiting"}: # If task is marked as "constrained" we haven't yet assigned it an # `available_resources` to run on, that happens in # `_transition_constrained_executing` assert ts return {ts: "released"}, [smsg] else: return {}, [smsg] @_handle_event.register def _handle_pause(self, ev: PauseEvent) -> RecsInstrs: """Prevent any further tasks to be executed or gathered. Tasks that are currently executing or in flight will continue to progress. """ self.running = False return {}, [] @_handle_event.register def _handle_unpause(self, ev: UnpauseEvent) -> RecsInstrs: """Emerge from paused status""" self.running = True return self._ensure_computing() @_handle_event.register def _handle_retry_busy_worker(self, ev: RetryBusyWorkerEvent) -> RecsInstrs: self.busy_workers.discard(ev.worker) return {}, [] @_handle_event.register def _handle_cancel_compute(self, ev: CancelComputeEvent) -> RecsInstrs: """Cancel a task on a best-effort basis. This is only possible while a task is in state `waiting` or `ready`; nothing will happen otherwise. """ ts = self.tasks.get(ev.key) if not ts or ts.state not in READY | {"waiting"}: return {}, [] self.log.append((ev.key, "cancel-compute", ev.stimulus_id, time())) # All possible dependents of ts should not be in state Processing on # scheduler side and therefore should not be assigned to a worker, yet. assert not ts.dependents return {ts: "released"}, [] def _execute_done_common( self, ev: ExecuteDoneEvent ) -> tuple[TaskState, Recs, Instructions]: """Common code for the handlers of all subclasses of ExecuteDoneEvent. The task state can be executing, cancelled, or resumed, but in case of scatter() it can also be in memory or error state. See also -------- _gather_dep_done_common """ # key *must* be still in tasks - see _transition_released_forgotten ts = self.tasks.get(ev.key) assert ts, self.story(ev.key) if self.validate: assert (ts in self.executing) != (ts in self.long_running) # XOR ts.done = True self.executed_count += 1 self._release_resources(ts) self.executing.discard(ts) self.long_running.discard(ts) recs, instr = self._ensure_computing() assert ts not in recs return ts, recs, instr @_handle_event.register def _handle_execute_success(self, ev: ExecuteSuccessEvent) -> RecsInstrs: """Task completed successfully""" ts, recs, instr = self._execute_done_common(ev) # This is used for scheduler-side occupancy heuristics; it's important that it # does not contain overhead from the thread pool or the worker's event loop # (which are not the task's fault and are unpredictable). ts.startstops.append({"action": "compute", "start": ev.start, "stop": ev.stop}) ts.nbytes = ev.nbytes ts.type = ev.type recs[ts] = ("memory", ev.value, ev.run_id) return recs, instr @_handle_event.register def _handle_execute_failure(self, ev: ExecuteFailureEvent) -> RecsInstrs: """Task execution failed""" ts, recs, instr = self._execute_done_common(ev) if ev.start is not None and ev.stop is not None: ts.startstops.append( {"action": "compute", "start": ev.start, "stop": ev.stop} ) recs[ts] = ( "error", ev.exception, ev.traceback, ev.exception_text, ev.traceback_text, ev.run_id, ) return recs, instr @_handle_event.register def _handle_reschedule(self, ev: RescheduleEvent) -> RecsInstrs: """Task raised Reschedule() exception while it was running. Note: this has nothing to do with work stealing, which instead causes a FreeKeysEvent. """ ts, recs, instr = self._execute_done_common(ev) recs[ts] = "rescheduled" return recs, instr @_handle_event.register def _handle_find_missing(self, ev: FindMissingEvent) -> RecsInstrs: if not self.missing_dep_flight: return {}, [] if self.validate: for ts in self.missing_dep_flight: assert not ts.who_has, self.story(ts) smsg = RequestRefreshWhoHasMsg( keys=[ts.key for ts in self.missing_dep_flight], stimulus_id=ev.stimulus_id, ) return {}, [smsg] @_handle_event.register def _handle_refresh_who_has(self, ev: RefreshWhoHasEvent) -> RecsInstrs: self._update_who_has(ev.who_has) recommendations: Recs = {} instructions: Instructions = [] for key in ev.who_has: ts = self.tasks.get(key) if not ts: continue if ts.who_has and ts.state == "missing": recommendations[ts] = "fetch" elif not ts.who_has and ts.state == "fetch": recommendations[ts] = "missing" # Note: if ts.who_has and ts.state == "fetch", we may have just acquired new # replicas whereas all previously known workers are in flight or busy. We # rely on _transitions to call _ensure_communicating every time, even in # absence of recommendations, to potentially kick off a new call to # gather_dep. return recommendations, instructions ############### # Diagnostics # ###############
[docs] def story(self, *keys_or_tasks_or_stimuli: str | Key | TaskState) -> list[tuple]: """Return all records from the transitions log involving one or more tasks or stimulus_id's """ keys_or_stimuli = { e.key if isinstance(e, TaskState) else e for e in keys_or_tasks_or_stimuli } return worker_story(keys_or_stimuli, self.log)
[docs] def stimulus_story( self, *keys_or_tasks: Key | TaskState ) -> list[StateMachineEvent]: """Return all state machine events involving one or more tasks""" keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks} return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys]
def _to_dict(self, *, exclude: Container[str] = ()) -> dict: """Dictionary representation for debugging purposes. Not type stable and not intended for roundtrips. See also -------- Client.dump_cluster_state distributed.utils.recursive_to_dict """ info = { "address": self.address, "nthreads": self.nthreads, "running": self.running, "ready": [ts.key for ts in self.ready.sorted()], "constrained": [ts.key for ts in self.constrained.sorted()], "data": dict.fromkeys(self.data), "data_needed": { w: [ts.key for ts in tss.sorted()] for w, tss in self.data_needed.items() }, "executing": {ts.key for ts in self.executing}, "has_what": dict(self.has_what), "long_running": {ts.key for ts in self.long_running}, "in_flight_tasks": {ts.key for ts in self.in_flight_tasks}, "in_flight_workers": self.in_flight_workers, "missing_dep_flight": [ts.key for ts in self.missing_dep_flight], "busy_workers": self.busy_workers, "log": self.log, "stimulus_log": self.stimulus_log, "transition_counter": self.transition_counter, "tasks": self.tasks, "task_counts": dict(self.task_counter.current_count()), "task_cumulative_elapsed": dict(self.task_counter.cumulative_elapsed()), } info = {k: v for k, v in info.items() if k not in exclude} return recursive_to_dict(info, exclude=exclude) ############## # Validation # ############## def _validate_task_memory(self, ts: TaskState) -> None: assert ts.key in self.data or ts.key in self.actors assert isinstance(ts.nbytes, int) assert not ts.waiting_for_data def _validate_task_executing(self, ts: TaskState) -> None: """Validate tasks: - ts.state == executing - ts.state == long-running - ts.state == cancelled, ts.previous == executing - ts.state == cancelled, ts.previous == long-running - ts.state == resumed, ts.previous == executing, ts.next == fetch - ts.state == resumed, ts.previous == long-running, ts.next == fetch """ if ts.state == "executing" or ts.previous == "executing": assert ts in self.executing assert ts not in self.long_running else: assert ts.state == "long-running" or ts.previous == "long-running" assert ts not in self.executing assert ts in self.long_running assert ts.run_spec is not None assert ts.key not in self.data assert not ts.waiting_for_data for dep in ts.dependents: assert dep not in self.ready assert dep not in self.constrained # FIXME https://github.com/dask/distributed/issues/6893 # This assertion can be false for # - cancelled or resumed tasks # - executing tasks which used to be cancelled in the past # for dep in ts.dependencies: # assert dep.state == "memory", self.story(dep) # assert dep.key in self.data or dep.key in self.actors def _validate_task_ready(self, ts: TaskState) -> None: """Validate tasks: - ts.state == ready - ts.state == constrained """ if ts.state == "ready": assert not ts.resource_restrictions assert ts in self.ready assert ts not in self.constrained else: assert ts.resource_restrictions assert ts.state == "constrained" assert ts not in self.ready assert ts in self.constrained assert ts.key not in self.data assert not ts.done assert not ts.waiting_for_data assert all( dep.key in self.data or dep.key in self.actors for dep in ts.dependencies ) def _validate_task_waiting(self, ts: TaskState) -> None: assert ts.key not in self.data assert not ts.done assert ts in self.waiting assert ts.waiting_for_data assert ts.waiting_for_data == { dep for dep in ts.dependencies if dep.key not in self.data and dep.key not in self.actors } for dep in ts.dependents: assert dep not in self.ready assert dep not in self.constrained def _validate_task_flight(self, ts: TaskState) -> None: """Validate tasks: - ts.state == flight - ts.state == cancelled, ts.previous == flight - ts.state == resumed, ts.previous == flight, ts.next == waiting """ assert ts.key not in self.data assert ts.key not in self.actors assert ts in self.in_flight_tasks for dep in ts.dependents: assert dep not in self.ready assert dep not in self.constrained assert ts.coming_from assert ts.coming_from in self.in_flight_workers assert ts.key in self.in_flight_workers[ts.coming_from] def _validate_task_fetch(self, ts: TaskState) -> None: assert ts.key not in self.data assert ts.key not in self.actors assert self.address not in ts.who_has assert not ts.done assert ts.who_has for w in ts.who_has: assert ts.key in self.has_what[w] assert ts in self.data_needed[w] for dep in ts.dependents: assert dep not in self.ready assert dep not in self.constrained def _validate_task_missing(self, ts: TaskState) -> None: assert ts.key not in self.data assert ts.key not in self.actors assert not ts.who_has assert not ts.done assert not any(ts.key in has_what for has_what in self.has_what.values()) assert ts in self.missing_dep_flight for dep in ts.dependents: assert dep not in self.ready assert dep not in self.constrained def _validate_task_cancelled(self, ts: TaskState) -> None: assert ts.next is None if ts.previous in ("executing", "long-running"): self._validate_task_executing(ts) else: assert ts.previous == "flight" self._validate_task_flight(ts) def _validate_task_resumed(self, ts: TaskState) -> None: if ts.previous in ("executing", "long-running"): assert ts.next == "fetch" self._validate_task_executing(ts) else: assert ts.previous == "flight" assert ts.next == "waiting" self._validate_task_flight(ts) for dep in ts.dependents: assert dep not in self.ready assert dep not in self.constrained def _validate_task_released(self, ts: TaskState) -> None: assert ts.key not in self.data assert ts.key not in self.actors assert not ts.next assert not ts.previous for tss in self.data_needed.values(): assert ts not in tss assert ts not in self.executing assert ts not in self.in_flight_tasks assert ts not in self.missing_dep_flight # The below assert statement is true most of the time. If a task performs the # transition flight->cancel->waiting, its dependencies are normally in released # state. However, the compute-task call for their previous dependent provided # them with who_has, such that this assert is no longer true. # # assert not any(ts.key in has_what for has_what in self.has_what.values()) assert not ts.waiting_for_data assert not ts.done assert not ts.exception assert not ts.traceback def validate_task(self, ts: TaskState) -> None: try: if ts.key in self.tasks: assert self.tasks[ts.key] is ts if ts.state == "memory": self._validate_task_memory(ts) elif ts.state == "waiting": self._validate_task_waiting(ts) elif ts.state == "missing": self._validate_task_missing(ts) elif ts.state == "cancelled": self._validate_task_cancelled(ts) elif ts.state == "resumed": self._validate_task_resumed(ts) elif ts.state in ("ready", "constrained"): self._validate_task_ready(ts) elif ts.state in ("executing", "long-running"): self._validate_task_executing(ts) elif ts.state == "flight": self._validate_task_flight(ts) elif ts.state == "fetch": self._validate_task_fetch(ts) elif ts.state == "released": self._validate_task_released(ts) except Exception as e: logger.exception(e) raise InvalidTaskState( key=ts.key, state=ts.state, story=self.story(ts) ) from e def validate_state(self) -> None: for ts in self.tasks.values(): # check that worker has task for worker in ts.who_has: assert worker != self.address assert ts.key in self.has_what[worker] # check that deps have a set state and that dependency<->dependent links # are there for dep in ts.dependencies: # self.tasks was just a dict of tasks # and this check was originally that the key was in `task_state` # so we may have popped the key out of `self.tasks` but the # dependency can still be in `memory` before GC grabs it...? # Might need better bookkeeping assert self.tasks[dep.key] is dep assert ts in dep.dependents, self.story(ts) for ts_wait in ts.waiting_for_data: assert self.tasks[ts_wait.key] is ts_wait assert ts_wait.state in WAITING_FOR_DATA, self.story(ts_wait) for worker, keys in self.has_what.items(): assert worker != self.address for k in keys: assert k in self.tasks, self.story(k) assert worker in self.tasks[k].who_has # Test contents of the various sets of TaskState objects fetch_tss = set() for worker, tss in self.data_needed.items(): for ts in tss: fetch_tss.add(ts) assert ts.state == "fetch", self.story(ts) assert worker in ts.who_has, f"{ts}; {ts.who_has=}" assert len(fetch_tss) == self.fetch_count for ts in self.missing_dep_flight: assert ts.state == "missing", self.story(ts) for ts in self.ready: assert ts.state == "ready", self.story(ts) for ts in self.constrained: assert ts.state == "constrained", self.story(ts) for ts in self.executing: assert ts.state == "executing" or ( ts.state in ("cancelled", "resumed") and ts.previous == "executing" ), self.story(ts) for ts in self.long_running: assert ts.state == "long-running" or ( ts.state in ("cancelled", "resumed") and ts.previous == "long-running" ), self.story(ts) for ts in self.in_flight_tasks: assert ts.state == "flight" or ( ts.state in ("cancelled", "resumed") and ts.previous == "flight" ), self.story(ts) for ts in self.waiting: assert ts.state == "waiting", self.story(ts) # Test that there aren't multiple TaskState objects with the same key in any # Set[TaskState]. See note in TaskState.__hash__. for ts in chain( *self.data_needed.values(), self.missing_dep_flight, self.ready, self.constrained, self.in_flight_tasks, self.executing, self.long_running, self.waiting, ): assert self.tasks[ts.key] is ts, f"{self.tasks[ts.key]} is not {ts}" expect_nbytes = sum( self.tasks[key].nbytes or 0 for key in chain(self.data, self.actors) ) assert self.nbytes == expect_nbytes, f"{self.nbytes=}; expected {expect_nbytes}" for key in self.data: assert key in self.tasks, self.story(key) for key in self.actors: assert key in self.tasks, self.story(key) for ts in self.tasks.values(): self.validate_task(ts) expect_state_count = Counter( (ts.prefix, ts.state) for ts in self.tasks.values() ) assert self.task_counter.current_count() == expect_state_count, ( self.task_counter.current_count(), expect_state_count, ) if self.transition_counter_max: assert self.transition_counter < self.transition_counter_max self._validate_resources() def _validate_resources(self) -> None: """Assert that available_resources + resources held by tasks = total_resources""" assert self.total_resources.keys() == self.available_resources.keys() total = self.total_resources.copy() for k, v in self.available_resources.items(): assert v > -1e-9, self.available_resources total[k] -= v for ts in self.all_running_tasks: if ts.resource_restrictions: for k, v in ts.resource_restrictions.items(): assert v >= 0, (ts, ts.resource_restrictions) total[k] -= v assert all((abs(v) < 1e-9) for v in total.values()), total
[docs]class BaseWorker(abc.ABC): """Wrapper around the :class:`WorkerState` that implements instructions handling. This is an abstract class with several ``@abc.abstractmethod`` methods, to be subclassed by :class:`~distributed.worker.Worker` and by unit test mock-ups. """ state: WorkerState _async_instructions: set[asyncio.Task] def __init__(self, state: WorkerState): self.state = state self._async_instructions = set() self.debug = dask.config.get("distributed.worker.validate") def _start_async_instruction( # type: ignore[valid-type] self, task_name: str, func: Callable[P, Awaitable[StateMachineEvent]], /, *args: P.args, span_id: str | None = None, **kwargs: P.kwargs, ) -> None: """Execute an asynchronous instruction inside an asyncio task""" ledger = DelayedMetricsLedger() @wraps(func) async def wrapper() -> StateMachineEvent: with ledger.record(key="async-instruction"): return await func(*args, **kwargs) task = asyncio.create_task(wrapper(), name=task_name) self._async_instructions.add(task) task.add_done_callback( partial(self._finish_async_instruction, ledger=ledger, span_id=span_id) ) def _finish_async_instruction( self, task: asyncio.Task[StateMachineEvent], ledger: DelayedMetricsLedger, span_id: str | None, ) -> None: """The asynchronous instruction just completed; process the returned stimulus. """ self._async_instructions.remove(task) try: # This *should* never raise any other exceptions stim = task.result() except asyncio.CancelledError: # This should exclusively happen in Worker.close() logger.warning(f"Async instruction for {task} ended with CancelledError") return except BaseException: # pragma: nocover # This should never happen logger.exception(f"Unhandled exception in async instruction for {task}") raise # Capture metric events in _transition_to_memory() # As this may trigger calls to _start_async_instruction for more tasks, # make sure we don't endlessly pile up context_meter callbacks by specifying # the same key as in _start_async_instruction. with ledger.record(key="async-instruction"): self.handle_stimulus(stim) self._finalize_metrics(stim, ledger, span_id) def _finalize_metrics( self, stim: StateMachineEvent, ledger: DelayedMetricsLedger, span_id: str | None = None, ) -> None: activity: tuple[str | None, ...] coarse_time: str | Literal[False] if not isinstance(stim, ExecuteDoneEvent): assert span_id is None if isinstance(stim, GatherDepSuccessEvent): activity = ("gather-dep",) for key in stim.data: ts = self.state.tasks.get(key) if ts and ts.state == "memory": # At least one key was fetched and stored in memory coarse_time = False break else: coarse_time = "cancelled" if stim.data else "missing" elif isinstance(stim, (GatherDepFailureEvent, GatherDepNetworkFailureEvent)): activity = ("gather-dep",) coarse_time = "failed" elif isinstance(stim, GatherDepBusyEvent): activity = ("gather-dep",) coarse_time = "busy" elif isinstance(stim, ExecuteSuccessEvent): activity = ("execute", span_id, key_split(stim.key)) ts = self.state.tasks.get(stim.key) coarse_time = False if ts and ts.state == "memory" else "cancelled" elif isinstance(stim, ExecuteFailureEvent): activity = ("execute", span_id, key_split(stim.key)) coarse_time = "failed" elif isinstance(stim, RescheduleEvent): activity = ("execute", span_id, key_split(stim.key)) coarse_time = "cancelled" else: assert isinstance(stim, RetryBusyWorkerEvent), stim return for label, value, unit in ledger.finalize(coarse_time=coarse_time): self.digest_metric((*activity, label, unit), value)
[docs] def handle_stimulus(self, *stims: StateMachineEvent) -> None: """Forward one or more external stimuli to :meth:`WorkerState.handle_stimulus` and process the returned instructions, invoking the relevant Worker callbacks (``@abc.abstractmethod`` methods below). Spawn asyncio tasks for all asynchronous instructions and start tracking them. See also -------- WorkerState.handle_stimulus """ instructions = self.state.handle_stimulus(*stims) for inst in instructions: if isinstance(inst, SendMessageToScheduler): self.batched_send(inst.to_dict()) elif isinstance(inst, GatherDep): assert inst.to_gather if self.debug: keys_str = ", ".join(map(str, peekn(27, inst.to_gather)[0])) if len(keys_str) > 80: keys_str = keys_str[:77] + "..." else: keys_str = "..." self._start_async_instruction( f"gather_dep({inst.worker}, {{{keys_str}}})", self.gather_dep, inst.worker, inst.to_gather, total_nbytes=inst.total_nbytes, stimulus_id=inst.stimulus_id, ) elif isinstance(inst, Execute): ts = self.state.tasks[inst.key] self._start_async_instruction( f"execute({inst.key!r})", self.execute, inst.key, span_id=ts.span_id, stimulus_id=inst.stimulus_id, ) elif isinstance(inst, RetryBusyWorkerLater): self._start_async_instruction( f"retry_busy_worker_later({inst.worker})", self.retry_busy_worker_later, inst.worker, ) else: raise TypeError(inst) # pragma: nocover
[docs] async def close(self, timeout: float = 30) -> None: """Cancel all asynchronous instructions""" if not self._async_instructions: return for task in self._async_instructions: task.cancel() # async tasks can handle cancellation and could take an arbitrary amount # of time to terminate _, pending = await asyncio.wait(self._async_instructions, timeout=timeout) for task in pending: logger.error( f"Failed to cancel asyncio task after {timeout} seconds: {task}" )
[docs] @abc.abstractmethod def batched_send(self, msg: dict[str, Any]) -> None: """Send a fire-and-forget message to the scheduler through bulk comms. Parameters ---------- msg: dict msgpack-serializable message to send to the scheduler. Must have a 'op' key which is registered in Scheduler.stream_handlers. """
[docs] @abc.abstractmethod async def gather_dep( self, worker: str, to_gather: Collection[Key], total_nbytes: int, *, stimulus_id: str, ) -> StateMachineEvent: """Gather dependencies for a task from a worker who has them Parameters ---------- worker : str Address of worker to gather dependencies from to_gather : list Keys of dependencies to gather from worker -- this is not necessarily equivalent to the full list of dependencies of ``dep`` as some dependencies may already be present on this worker. total_nbytes : int Total number of bytes for all the dependencies in to_gather combined """
[docs] @abc.abstractmethod async def execute(self, key: Key, *, stimulus_id: str) -> StateMachineEvent: """Execute a task"""
[docs] @abc.abstractmethod async def retry_busy_worker_later(self, worker: str) -> StateMachineEvent: """Wait some time, then take a peer worker out of busy state"""
[docs] @abc.abstractmethod def digest_metric(self, name: Hashable, value: float) -> None: """Log an arbitrary numerical metric"""
class TaskCounter: # When the monotonic timer was last sampled _previous_ts: float | None # Current task counts, per prefix, per state, on the worker _current_count: Counter[tuple[str, TaskStateState]] # Tasks that have just been inserted in the WorkerState and will be immediately # transitioned to a new state _new_tasks: set[TaskState] # Ever-increasing cumulative task runtimes, per prefix, including tasks that have # left a state or even that don't exist anymore, updated as of the latest call # to cumulative_elapsed() or transitions() _cumulative_elapsed: defaultdict[tuple[str, TaskStateState], float] __slots__ = tuple(__annotations__) def __init__(self) -> None: self._previous_ts = None self._current_count = Counter() self._new_tasks = set() self._cumulative_elapsed = defaultdict(float) def current_count(self, by_prefix: bool = True) -> Counter: """Return current count of tasks. Parameters ---------- by_prefix: bool, optional True (default) Return counter of (task prefix, task state) -> count False Return counter of task state -> count """ if by_prefix: return self._current_count out: Counter[TaskStateState] = Counter() for (_, state), n in self._current_count.items(): out[state] += n return out def cumulative_elapsed(self, by_prefix: bool = True) -> Mapping[Any, float]: """Ever-increasing cumulative task runtimes, including tasks that have left a state or even that don't exist anymore, updated as of the moment when this method is called. Parameters ---------- by_prefix: bool, optional True (default) Return mapping of (task prefix, task state) -> seconds False Return mapping of task state -> seconds """ if self._current_count: assert self._previous_ts is not None now = monotonic() elapsed = now - self._previous_ts self._previous_ts = now for k, n_tasks in self._current_count.items(): self._cumulative_elapsed[k] += elapsed * n_tasks if by_prefix: return self._cumulative_elapsed out: defaultdict[TaskStateState, float] = defaultdict(float) for (_, state), n in self._cumulative_elapsed.items(): out[state] += n return out def new_task(self, ts: TaskState) -> None: """A new task has just been created and will be immediately fed into the recommendations for transitions """ self._new_tasks.add(ts) def transitions(self, prev_states: dict[TaskState, TaskStateState]) -> None: """Tasks have just transitioned to a new state""" if not prev_states and not self._new_tasks: return now = monotonic() if self._current_count: assert self._previous_ts is not None elapsed = now - self._previous_ts for k, n_tasks in self._current_count.items(): self._cumulative_elapsed[k] += elapsed * n_tasks self._previous_ts = now for ts, prev_state in prev_states.items(): if ts.state != "forgotten": self._current_count[ts.prefix, ts.state] += 1 if ts in self._new_tasks: self._new_tasks.discard(ts) else: dec_count = self._current_count[ts.prefix, prev_state] - 1 if dec_count > 0: self._current_count[ts.prefix, prev_state] = dec_count else: assert dec_count == 0 del self._current_count[ts.prefix, prev_state] for ts in self._new_tasks: # This happens exclusively on a transition from cancelled(flight) to # resumed(flight->waiting) of a task with dependencies; the dependencies # will remain in released state and never transition to anything else. self._current_count[ts.prefix, ts.state] += 1 self._new_tasks.clear() class DeprecatedWorkerStateAttribute: name: str target: str | None def __init__(self, target: str | None = None): self.target = target def __set_name__(self, owner: type, name: str) -> None: self.name = name def _warn_deprecated(self) -> None: warnings.warn( f"The `Worker.{self.name}` attribute has been moved to " f"`Worker.state.{self.target or self.name}`", FutureWarning, ) def __get__(self, instance: Worker | None, owner: type[Worker]) -> Any: if instance is None: # This is triggered by Sphinx return None # pragma: nocover self._warn_deprecated() return getattr(instance.state, self.target or self.name) def __set__(self, instance: Worker, value: Any) -> None: self._warn_deprecated() setattr(instance.state, self.target or self.name, value)