from __future__ import annotations
import copy
import uuid
import weakref
from collections import defaultdict
from collections.abc import Hashable, Iterable, Iterator, Mapping
from contextlib import contextmanager
from itertools import islice
from typing import TYPE_CHECKING, Any, TypedDict
import dask.config
from dask.typing import Key
from distributed.collections import sum_mappings
from distributed.itertools import ffill
from distributed.metrics import time
if TYPE_CHECKING:
# Needed to avoid Sphinx WARNING: more than one target found for cross-reference
    # 'TaskState' and 'WorkerState'
# https://github.com/agronholm/sphinx-autodoc-typehints#dealing-with-circular-imports
from distributed import Scheduler, Worker
from distributed import scheduler as scheduler_module
from distributed.client import SourceCode
from distributed.scheduler import TaskGroup, TaskStateState
CONTEXTS_WITH_SPAN_ID = ("execute", "p2p")
class SpanMetadata(TypedDict):
collections: list[dict]
@contextmanager
def span(*tags: str) -> Iterator[str]:
"""Tag group of tasks to be part of a certain group, called a span.
This context manager can be nested, thus creating sub-spans. If you close and
re-open a span context manager with the same tag, you'll end up with two separate
spans.
Every cluster defines a global "default" span when no span has been defined by the
client; the default span is automatically closed and reopened when all tasks
    associated with it have been completed; in other words, the cluster is idle save for
tasks that are explicitly annotated by a span. Note that, in some edge cases, you
may end up with overlapping default spans, e.g. if a worker crashes and all unique
tasks that were in memory on it need to be recomputed.
    You may capture the ID of a span on the client to match it with the
    :class:`~distributed.spans.Span` objects on the scheduler:
>>> client = Client()
>>> with span("my workflow") as span_id:
... client.submit(lambda: "Hello world!").result()
>>> client.cluster.scheduler.extensions["spans"].spans[span_id]
Span<name=('my workflow',), id=5dc9b908-116b-49a5-b0d7-5a681f49a111>
Notes
-----
You may retrieve the current span with ``dask.get_annotations().get("span")``.
You can do so in the client code as well as from inside a task.
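    For example, nesting the context manager accumulates both tags and ids (a
    minimal sketch; ``outer_id`` and ``inner_id`` are illustrative names):

    >>> with span("workflow") as outer_id:
    ...     with span("step-1") as inner_id:
    ...         ann = dask.get_annotations()["span"]
    >>> ann["name"]
    ('workflow', 'step-1')
    >>> ann["ids"] == (outer_id, inner_id)
    True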
"""
if not tags:
raise ValueError("Must specify at least one span tag")
annotation = dask.get_annotations().get("span")
prev_tags = annotation["name"] if annotation else ()
# You must specify the full history of IDs, not just the parent, because
# otherwise you would not be able to uniquely identify grandparents when
# they have no tasks of their own.
prev_ids = annotation["ids"] if annotation else ()
ids = tuple(str(uuid.uuid4()) for _ in tags)
with dask.annotate(span={"name": prev_tags + tags, "ids": prev_ids + ids}):
yield ids[-1]
class Span:
#: (<tag>, <tag>, ...)
#: Matches ``TaskState.annotations["span"]["name"]``, both on the scheduler and the
#: worker.
name: tuple[str, ...]
#: Unique ID, generated by :func:`~distributed.span` and
    #: taken from ``TaskState.annotations["span"]["ids"][-1]``.
#: Matches ``distributed.scheduler.TaskState.group.span_id``
#: and ``distributed.worker_state_machine.TaskState.span_id``.
id: str
_parent: weakref.ref[Span] | None
#: Direct children of this span, sorted by creation time
children: list[Span]
#: Task groups *directly* belonging to this span.
#:
#: See Also
#: --------
    #: traverse_groups
#:
#: Notes
#: -----
#: TaskGroups are forgotten by the Scheduler when the last task is forgotten, but
#: remain referenced here indefinitely. If a user calls compute() twice on the same
#: collection, you'll have more than one group with the same tg.name in this set!
#: For the same reason, while the same TaskGroup object is guaranteed to be attached
#: to exactly one Span, you may have different TaskGroups with the same key attached
#: to different Spans.
groups: set[TaskGroup]
#: Time when the span first appeared on the scheduler.
#: The same property on parent spans is always less than or equal to this.
#:
#: See Also
#: --------
#: start
#: stop
enqueued: float
    #: Source code snippets, if they were sent by the client.
#: We're using a dict without values as an insertion-sorted set.
_code: dict[tuple[SourceCode, ...], None]
_metadata: SpanMetadata | None
_cumulative_worker_metrics: defaultdict[tuple[Hashable, ...], float]
#: reference to SchedulerState.total_nthreads_history
_total_nthreads_history: list[tuple[float, int]]
#: Length of total_nthreads_history when this span was enqueued
_total_nthreads_offset: int
# Support for weakrefs to a class with __slots__
__weakref__: Any
__slots__ = tuple(__annotations__)
_metadata_seen: set[int] = set()
def __init__(
self,
name: tuple[str, ...],
id_: str,
parent: Span | None,
total_nthreads_history: list[tuple[float, int]],
):
self.name = name
self.id = id_
self._parent = weakref.ref(parent) if parent is not None else None
self.enqueued = time()
self.children = []
self.groups = set()
self._code = {}
self._metadata = None
# Don't cast int metrics to float
self._cumulative_worker_metrics = defaultdict(int)
assert len(total_nthreads_history) > 0
self._total_nthreads_history = total_nthreads_history
self._total_nthreads_offset = len(total_nthreads_history) - 1
def __repr__(self) -> str:
return f"Span<name={self.name}, id={self.id}>"
@property
def parent(self) -> Span | None:
if self._parent:
out = self._parent()
assert out
return out
return None
@property
def annotation(self) -> dict[str, tuple[str, ...]] | None:
"""Rebuild the dask graph annotation which contains the full id history
Note that this may not match the original annotation in case of TaskGroup
collision.
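        For example, a span created as ``("workflow", "step-1")`` rebuilds to
        (sketch; the actual ids are UUID strings)::

            {"name": ("workflow", "step-1"), "ids": ("<parent id>", "<own id>")}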
"""
if self.name == ("default",):
return None
ids = []
node: Span | None = self
while node:
ids.append(node.id)
node = node.parent
return {"name": self.name, "ids": tuple(reversed(ids))}
    def traverse_spans(self) -> Iterator[Span]:
"""Top-down recursion of all spans belonging to this branch off span tree,
including self
"""
yield self
for child in self.children:
yield from child.traverse_spans()
    def traverse_groups(self) -> Iterator[TaskGroup]:
"""All TaskGroups belonging to this branch of span tree"""
for span in self.traverse_spans():
yield from span.groups
@property
def start(self) -> float:
"""Earliest time when a task belonging to this span tree started computing;
0 if no task has *finished* computing yet.
Notes
-----
This is not updated until at least one task has *finished* computing.
It could move backwards as tasks complete.
See Also
--------
enqueued
stop
distributed.scheduler.TaskGroup.start
"""
out = min(
(tg.start for tg in self.traverse_groups() if tg.start != 0.0),
default=0.0,
)
if out:
# absorb small errors in worker delay calculation
out = max(out, self.enqueued)
return out
@property
def stop(self) -> float:
"""When this span tree finished computing, or current timestamp if it didn't
finish yet.
Notes
-----
        This differs from ``TaskGroup.stop`` when there are unfinished tasks; it will
        also never be zero.
See Also
--------
enqueued
start
done
distributed.scheduler.TaskGroup.stop
"""
if self.done:
out = max(tg.stop for tg in self.traverse_groups())
else:
out = time()
# absorb small errors in worker delay calculation, as well as in time() not
# being perfectly monotonic
return max(out, self.enqueued)
@property
def metadata(self) -> SpanMetadata | None:
return self._metadata
@property
def states(self) -> dict[TaskStateState, int]:
"""The number of tasks currently in each state in this span tree;
e.g. ``{"memory": 10, "processing": 3, "released": 4, ...}``.
See Also
--------
distributed.scheduler.TaskGroup.states
"""
return sum_mappings(tg.states for tg in self.traverse_groups())
@property
def done(self) -> bool:
"""Return True if all tasks in this span tree are completed; False otherwise.
Notes
-----
This property may transition from True to False, e.g. when a new sub-span is
added or when a worker that contained the only replica of a task in memory
        crashes and the task needs to be recomputed.
See Also
--------
distributed.scheduler.TaskGroup.done
"""
return all(tg.done for tg in self.traverse_groups())
@property
def all_durations(self) -> dict[str, float]:
"""Cumulative duration of all completed actions in this span tree, by action
See Also
--------
duration
distributed.scheduler.TaskGroup.all_durations
"""
return sum_mappings(tg.all_durations for tg in self.traverse_groups())
@property
def duration(self) -> float:
"""The total amount of time spent on all tasks in this span tree
See Also
--------
all_durations
distributed.scheduler.TaskGroup.duration
"""
return sum(tg.duration for tg in self.traverse_groups())
@property
def nbytes_total(self) -> int:
"""The total number of bytes that this span tree has produced
See Also
--------
distributed.scheduler.TaskGroup.nbytes_total
"""
return sum(tg.nbytes_total for tg in self.traverse_groups())
@property
def code(self) -> list[tuple[SourceCode, ...]]:
"""Code snippets, sent by the client on compute(), persist(), and submit().
Only populated if ``distributed.diagnostics.computations.nframes`` is non-zero.
"""
# Deduplicate, but preserve order
return list(
dict.fromkeys(sc for child in self.traverse_spans() for sc in child._code)
)
@property
def cumulative_worker_metrics(self) -> dict[tuple[Hashable, ...], float]:
"""Replica of ``Worker.digests_total`` and
``Scheduler.cumulative_worker_metrics``, but only for the metrics that can be
attributed to the current span tree. The span id has been removed from the key.
At the moment of writing, all keys are
``("execute", <task prefix>, <activity>, <unit>)``
or
``("p2p", <where>, <activity>, <unit>)``
but more may be added in the future with a different format; please test e.g.
for ``k[0] == "execute"``.
"""
out = sum_mappings(
child._cumulative_worker_metrics for child in self.traverse_spans()
)
known_seconds = sum(
v for k, v in out.items() if k[0] == "execute" and k[-1] == "seconds"
)
# Besides rounding errors, you may get negative unknown seconds if a user
# manually invokes `context_meter.digest_metric`.
unknown_seconds = max(0.0, self.active_cpu_seconds - known_seconds)
out["execute", "N/A", "idle or other spans", "seconds"] = unknown_seconds
return out
    @staticmethod
def merge(*items: Span) -> Span:
"""Merge multiple spans into a synthetic one.
        The input spans must not be related to each other.
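        For example, :meth:`SpansSchedulerExtension.merge_all` uses this to build a
        synthetic span covering the whole cluster (sketch; ``spans_extension`` is the
        scheduler's ``SpansSchedulerExtension``)::

            Span.merge(*spans_extension.root_spans)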
"""
if not items:
raise ValueError("Nothing to merge")
out = Span(
name=("(merged)",),
id_="(merged)",
parent=None,
total_nthreads_history=items[0]._total_nthreads_history,
)
out._total_nthreads_offset = min(
child._total_nthreads_offset for child in items
)
out.children.extend(items)
out.enqueued = min(child.enqueued for child in items)
return out
def _nthreads_timeseries(self) -> Iterator[tuple[float, int]]:
"""Yield (timestamp, number of threads across the cluster), forward-fill"""
stop = self.stop if self.done else 0
for t, n in islice(
self._total_nthreads_history, self._total_nthreads_offset, None
):
if stop and t >= stop:
break
yield max(self.enqueued, t), n
def _active_timeseries(self) -> Iterator[tuple[float, bool]]:
"""If this span is the output of :meth:`merge`, yield
(timestamp, True if at least one input span is active), forward-fill.
"""
if self.id != "(merged)":
yield self.enqueued, True
yield self.stop, False
return
events = []
for child in self.children:
events += [(child.enqueued, 1), (child.stop, -1)]
# enqueued <= stop by construction.
# Occasionally, enqueued == stop, e.g. when the clock is adjusted backwards.
# Prevent negative n_active when this happens.
events.sort(key=lambda el: el[0])
n_active = 0
for t, delta in events:
if not n_active:
assert delta == 1
yield t, True
n_active += delta
if not n_active:
yield t, False
@property
def nthreads_intervals(self) -> list[tuple[float, float, int]]:
"""
Returns
-------
List of tuples:
- begin timestamp
- end timestamp
- Scheduler.total_nthreads during this interval
When the Span is the output of :meth:`merge`, the intervals may not be
contiguous.
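        For example, a span that ran from t=100 to t=130 on a cluster that scaled
        from 4 to 8 threads at t=120 would yield (illustrative timestamps)::

            [(100.0, 120.0, 4), (120.0, 130.0, 8)]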
See Also
--------
enqueued
stop
active_cpu_seconds
distributed.scheduler.SchedulerState.total_nthreads
"""
nthreads_t, nthreads_count = zip(*self._nthreads_timeseries())
is_active_t, is_active_flag = zip(*self._active_timeseries())
t_interp = sorted({*nthreads_t, *is_active_t})
nthreads_count_interp = ffill(t_interp, nthreads_t, nthreads_count, left=0)
is_active_flag_interp = ffill(t_interp, is_active_t, is_active_flag, left=False)
return [
(t0, t1, n)
for t0, t1, n, active in zip(
t_interp, t_interp[1:], nthreads_count_interp, is_active_flag_interp
)
if active
]
@property
def active_cpu_seconds(self) -> float:
"""Return number of CPU seconds that were made available on the cluster while
this Span was running; in other words
``(Span.stop - Span.enqueued) * Scheduler.total_nthreads``.
This accounts for workers joining and leaving the cluster while this Span was
active. If this Span is the output of :meth:`merge`, do not count gaps between
input spans.
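        For example, two intervals of 20 s at 4 threads and 10 s at 8 threads amount
        to ``20 * 4 + 10 * 8 = 160`` active CPU seconds (illustrative numbers).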
See Also
--------
enqueued
stop
nthreads_intervals
distributed.scheduler.SchedulerState.total_nthreads
"""
return sum((t1 - t0) * nthreads for t0, t1, nthreads in self.nthreads_intervals)
class SpansSchedulerExtension:
"""Scheduler extension for spans support"""
scheduler: Scheduler
#: All Span objects by id
spans: dict[str, Span]
#: Only the spans that don't have any parents, sorted by creation time.
#: This is a convenience helper structure to speed up searches.
root_spans: list[Span]
#: All spans, keyed by their full name and sorted by creation time.
#: This is a convenience helper structure to speed up searches.
spans_search_by_name: defaultdict[tuple[str, ...], list[Span]]
#: All spans, keyed by the individual tags that make up their name and sorted by
#: creation time.
#: This is a convenience helper structure to speed up searches.
#:
#: See Also
#: --------
#: find_by_tags
#: merge_by_tags
spans_search_by_tag: defaultdict[str, list[Span]]
def __init__(self, scheduler: Scheduler):
self.scheduler = scheduler
self.spans = {}
self.root_spans = []
self.spans_search_by_name = defaultdict(list)
self.spans_search_by_tag = defaultdict(list)
    def observe_tasks(
self,
tss: Iterable[scheduler_module.TaskState],
code: tuple[SourceCode, ...],
span_metadata: SpanMetadata,
) -> dict[Key, dict]:
"""Acknowledge the existence of runnable tasks on the scheduler. These may
either be new tasks, tasks that were previously unrunnable, or tasks that were
already fed into this method already.
        Attach newly observed tasks to either the desired span or to ("default",).
Update TaskGroup.span_id and wipe TaskState.annotations["span"].
Returns
-------
Updated 'span' annotations: {key: {"name": (..., ...), "ids": (..., ...)}}
"""
out = {}
default_span = None
for ts in tss:
if ts.annotations is None:
ts.annotations = dict()
# You may have different tasks belonging to the same TaskGroup but to
# different spans. If that happens, arbitrarily force everything onto the
# span of the earliest encountered TaskGroup.
tg = ts.group
if tg.span_id:
span = self.spans[tg.span_id]
else:
ann = ts.annotations.get("span")
if ann:
span = self._ensure_span(ann["name"], ann["ids"])
else:
if not default_span:
default_span = self._ensure_default_span()
span = default_span
tg.span_id = span.id
span.groups.add(tg)
if code:
span._code[code] = None
if span_metadata:
span.add_metadata(span_metadata)
# The span may be completely different from the one referenced by the
# annotation, due to the TaskGroup collision issue explained above.
if ann := span.annotation:
ts.annotations["span"] = out[ts.key] = ann
else:
ts.annotations.pop("span", None)
return out
def _ensure_default_span(self) -> Span:
"""Return the currently active default span, or create one if the previous one
terminated. In other words, do not reuse the previous default span if all tasks
        that were not explicitly annotated with :func:`span` on the client side are
finished.
"""
defaults = self.spans_search_by_name["default",]
if defaults and not defaults[-1].done:
return defaults[-1]
return self._ensure_span(("default",), (str(uuid.uuid4()),))
def _ensure_span(self, name: tuple[str, ...], ids: tuple[str, ...]) -> Span:
"""Create Span if it doesn't exist and return it"""
try:
return self.spans[ids[-1]]
except KeyError:
pass
assert len(name) == len(ids)
assert len(name) > 0
parent = None
for i in range(1, len(name)):
parent = self._ensure_span(name[:i], ids[:i])
span = Span(
name=name,
id_=ids[-1],
parent=parent,
total_nthreads_history=self.scheduler.total_nthreads_history,
)
self.spans[span.id] = span
self.spans_search_by_name[name].append(span)
for tag in name:
self.spans_search_by_tag[tag].append(span)
if parent:
parent.children.append(span)
else:
self.root_spans.append(span)
return span
    def merge_all(self) -> Span:
"""Return a synthetic Span which is the sum of all spans"""
return Span.merge(*self.root_spans)
    def heartbeat(
self, ws: scheduler_module.WorkerState, data: dict[tuple[Hashable, ...], float]
) -> None:
"""Triggered by :meth:`SpansWorkerExtension.heartbeat`.
Populate :meth:`Span.cumulative_worker_metrics` with data from the worker.
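        For example, an incoming entry keyed
        ``("execute", "<span id>", "sum", "thread-cpu", "seconds")`` is accumulated on
        the matching span under ``("execute", "sum", "thread-cpu", "seconds")``
        (illustrative key; actual task prefixes and activities depend on the workload).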
See Also
--------
SpansWorkerExtension.heartbeat
Span.cumulative_worker_metrics
"""
for (context, span_id, *other), v in data.items():
assert isinstance(span_id, str)
span = self.spans[span_id]
span._cumulative_worker_metrics[(context, *other)] += v
class SpansWorkerExtension:
"""Worker extension for spans support"""
worker: Worker
digests_total_since_heartbeat: dict[tuple[Hashable, ...], float]
def __init__(self, worker: Worker):
self.worker = worker
self.digests_total_since_heartbeat = {}
def collect_digests(
self, digests_total_since_heartbeat: Mapping[Hashable, float]
) -> None:
# Note: this method may be called spuriously by Worker._register_with_scheduler,
# but when it does it's guaranteed not to find any metrics
assert not self.digests_total_since_heartbeat
self.digests_total_since_heartbeat = {
k: v
for k, v in digests_total_since_heartbeat.items()
if isinstance(k, tuple) and k[0] in CONTEXTS_WITH_SPAN_ID
}
    def heartbeat(self) -> dict[tuple[Hashable, ...], float]:
"""Apportion the metrics that do have a span to the Spans on the scheduler
Returns
-------
        ``{(context, span_id, prefix, activity, unit): value}``
See Also
--------
SpansSchedulerExtension.heartbeat
Span.cumulative_worker_metrics
distributed.worker.Worker.get_metrics
"""
out = self.digests_total_since_heartbeat
self.digests_total_since_heartbeat = {}
return out