# Source code for distributed.spill

from __future__ import annotations

import logging
from collections.abc import Iterator, Mapping, MutableMapping, Sized
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from time import perf_counter
from typing import Any, Literal, NamedTuple, Protocol, cast

from packaging.version import parse as parse_version

import zict

from distributed.protocol import deserialize_bytes, serialize_bytelist
from distributed.sizeof import safe_sizeof

logger = logging.getLogger(__name__)

# Feature flags gating code paths that depend on the installed zict release.
has_zict_220, has_zict_230 = (
    parse_version(zict.__version__) >= parse_version(minimum)
    for minimum in ("2.2.0", "2.3.0")
)


class SpilledSize(NamedTuple):
    """Size, in bytes, of a key/value pair that has been spilled to disk.

    ``+`` and ``-`` operate element-wise and return a new SpilledSize,
    overriding the tuple concatenation that ``+`` would otherwise perform.
    """

    # Estimated in-memory footprint, i.e. the output of sizeof()
    memory: int
    # Size of the serialized (pickled) representation
    disk: int

    def __add__(self, other: SpilledSize) -> SpilledSize:  # type: ignore
        memory = self.memory + other.memory
        disk = self.disk + other.disk
        return SpilledSize(memory=memory, disk=disk)

    def __sub__(self, other: SpilledSize) -> SpilledSize:
        memory = self.memory - other.memory
        disk = self.disk - other.disk
        return SpilledSize(memory=memory, disk=disk)


class ManualEvictProto(Protocol):
    """Structural interface for third-party replacements of SpillBuffer.

    Besides implementing MutableMapping, an alternative to SpillBuffer must
    provide the members below if it wants to take part in spilling once the
    ``distributed.worker.memory.spill`` threshold is exceeded.

    This is public API. At the time of writing, Dask-CUDA implements this
    protocol in its ProxifyHostFile class.
    """

    @property
    def fast(self) -> Sized | bool:
        """Handle on fast (in-memory) storage.

        Usually a MutableMapping. For manual eviction, however, it is only
        tested for emptiness, to decide whether anything is left to evict.
        """
        ...  # pragma: nocover

    def evict(self) -> int:
        """Move one key/value pair from fast to slow memory on demand.

        Return the size the evicted value occupied in fast memory, or -1 if
        the eviction failed for any reason. On failure, implementations must
        keep the key/value pair that caused the problem in fast memory and
        log the problem internally. Implementations must never raise.
        """
        ...  # pragma: nocover
@dataclass
class FastMetrics:
    """Cumulative metrics for SpillBuffer.fast since the latest worker restart"""

    # Number of SpillBuffer reads whose key was found in fast (in-memory) storage
    read_count_total: int = 0
    # Sum of the in-memory weights (see _in_memory_weight) of those keys
    read_bytes_total: int = 0

    def log_read(self, key_bytes: int) -> None:
        """Record a single read of a key whose in-memory weight is ``key_bytes``."""
        self.read_count_total += 1
        self.read_bytes_total += key_bytes


@dataclass
class SlowMetrics:
    """Cumulative metrics for SpillBuffer.slow since the latest worker restart

    As populated by Slow, the time totals are sums of ``perf_counter()``
    deltas (seconds) and the byte totals count serialized bytes.
    """

    # Reads from disk: number of keys, serialized bytes, time spent reading
    read_count_total: int = 0
    read_bytes_total: int = 0
    read_time_total: float = 0
    # Writes to disk: number of keys, serialized bytes, time spent writing
    write_count_total: int = 0
    write_bytes_total: int = 0
    write_time_total: float = 0
    # Time spent serializing (on write) and deserializing (on read)
    pickle_time_total: float = 0
    unpickle_time_total: float = 0

    def log_write(
        self,
        key_bytes: int,
        pickle_time: float,
        write_time: float,
    ) -> None:
        """Record one key/value pair written to disk.

        Parameters
        ----------
        key_bytes: int
            Size of the serialized value
        pickle_time: float
            Time spent serializing the value
        write_time: float
            Time spent writing the serialized value to disk
        """
        self.write_count_total += 1
        self.write_bytes_total += key_bytes
        self.pickle_time_total += pickle_time
        self.write_time_total += write_time

    def log_read(self, key_bytes: int, read_time: float, unpickle_time: float) -> None:
        """Record one key/value pair read back from disk.

        Parameters
        ----------
        key_bytes: int
            Size of the serialized value
        read_time: float
            Time spent reading the serialized value from disk
        unpickle_time: float
            Time spent deserializing it
        """
        self.read_count_total += 1
        self.read_bytes_total += key_bytes
        self.read_time_total += read_time
        self.unpickle_time_total += unpickle_time


# zict.Buffer[str, Any] requires zict >= 2.2.0
class SpillBuffer(zict.Buffer):
    """MutableMapping that automatically spills out dask key/value pairs to disk when
    the total size of the stored data exceeds the target. If max_spill is provided
    the key/value pairs won't be spilled once this threshold has been reached.

    Parameters
    ----------
    spill_directory: str
        Location on disk to write the spill files to
    target: int
        Managed memory, in bytes, to start spilling at
    max_spill: int | False, optional
        Limit of number of bytes to be spilled on disk. Set to False to disable.
    min_log_interval: float, optional
        Minimum interval, in seconds, between warnings on the log file about full
        disk
    """

    # perf_counter() timestamp of the latest "disk full" / "max_spill" log entry;
    # used to rate-limit those messages to one per min_log_interval seconds
    last_logged: float
    min_log_interval: float
    # Keys whose pickling failure has already been logged, so that repeated
    # eviction attempts of the same unpicklable key don't flood the log
    logged_pickle_errors: set[str]
    fast_metrics: FastMetrics

    def __init__(
        self,
        spill_directory: str,
        target: int,
        max_spill: int | Literal[False] = False,
        min_log_interval: float = 2,
    ):
        slow: MutableMapping[str, Any] = Slow(spill_directory, max_spill)
        if has_zict_220:
            # If a value is still in use somewhere on the worker since the last time it
            # was unspilled, don't duplicate it
            slow = zict.Cache(slow, zict.WeakValueMapping())

        # fast starts as a plain dict; values are weighed with _in_memory_weight
        # (safe_sizeof) and compared against target (n)
        super().__init__(fast={}, slow=slow, n=target, weight=_in_memory_weight)
        self.last_logged = 0
        self.min_log_interval = min_log_interval
        self.logged_pickle_errors = set()  # keys logged with pickle error
        self.fast_metrics = FastMetrics()

    @contextmanager
    def handle_errors(self, key: str | None) -> Iterator[None]:
        """Translate failures raised while spilling into SpillBuffer's contract.

        Parameters
        ----------
        key: str | None
            The key being inserted by __setitem__, or None when called from
            evict().

        Raises
        ------
        HandledError
            The failure (max_spill exceeded, OSError, or a pickling failure of a
            key other than ``key``) was logged (rate-limited or deduplicated)
            and the affected key/value pair is still in fast.
        Exception
            The original exception, if ``key`` itself failed to pickle. In that
            case ``key`` has been deleted from the buffer first.
        """
        try:
            yield
        except MaxSpillExceeded as e:
            # key is in self.fast; no keys have been lost on eviction
            # Note: requires zict > 2.0
            (key_e,) = e.args
            assert key_e in self.fast
            assert key_e not in self.slow
            now = perf_counter()
            if now - self.last_logged >= self.min_log_interval:
                logger.warning(
                    "Spill file on disk reached capacity; keeping data in memory"
                )
                self.last_logged = now
            raise HandledError()
        except OSError:
            # Typically, this is a disk full error
            now = perf_counter()
            if now - self.last_logged >= self.min_log_interval:
                logger.error(
                    "Spill to disk failed; keeping data in memory", exc_info=True
                )
                self.last_logged = now
            raise HandledError()
        except PickleError as e:
            # Slow wraps serialization failures as PickleError(key, original_exc)
            key_e, orig_e = e.args
            if parse_version(zict.__version__) <= parse_version("2.0.0"):
                # Older zict does not guarantee the failed key stays in fast,
                # so the invariant below can't be asserted there
                pass
            else:
                assert key_e in self.fast
                assert key_e not in self.slow
            if key_e == key:
                assert key is not None
                # The key we just inserted failed to serialize.
                # This happens only when the key is individually larger than target.
                # The exception will be caught by Worker and logged; the status of
                # the task will be set to error.
                del self[key]
                raise orig_e
            else:
                # The key we just inserted is smaller than target, but it caused
                # another, unrelated key to be spilled out of the LRU, and that key
                # failed to serialize. There's nothing wrong with the new key. The older
                # key is still in memory.
                if key_e not in self.logged_pickle_errors:
                    logger.error(f"Failed to pickle {key_e!r}", exc_info=True)
                    self.logged_pickle_errors.add(key_e)
                raise HandledError()

    def __setitem__(self, key: str, value: Any) -> None:
        """If sizeof(value) < target, write key/value pair to self.fast; this may in
        turn cause older keys to be spilled from fast to slow.
        If sizeof(value) >= target, write key/value pair directly to self.slow instead.

        Raises
        ------
        Exception
            sizeof(value) >= target, and value failed to pickle.
            The key/value pair has been forgotten.

        In all other cases:

        - an older value was evicted and failed to pickle,
        - this value or an older one caused the disk to fill and raise OSError,
        - this value or an older one caused the max_spill threshold to be exceeded,

        this method does not raise and guarantees that the key/value that caused the
        issue remained in fast.
        """
        try:
            with self.handle_errors(key):
                super().__setitem__(key, value)
                # A successful store clears any earlier "failed to pickle" record
                self.logged_pickle_errors.discard(key)
        except HandledError:
            assert key in self.fast
            assert key not in self.slow

    def evict(self) -> int:
        """Implementation of :meth:`ManualEvictProto.evict`.

        Manually evict the oldest key/value pair, even if target has not been
        reached. Returns sizeof(value).
        If the eviction failed (value failed to pickle, disk full, or max_spill
        exceeded), return -1; the key/value pair that caused the issue will remain in
        fast. The exception has been logged internally.
        This method never raises.

        NOTE(review): the "never raises" guarantee covers spilling failures only;
        it presumably relies on callers checking that ``fast`` is non-empty
        first (see ManualEvictProto.fast) — confirm what zict's LRU.evict() does
        on an empty mapping.
        """
        try:
            with self.handle_errors(None):
                # LRU.evict() returns a 3-tuple; the last item is the evicted weight
                _, _, weight = self.fast.evict()
                return cast(int, weight)
        except HandledError:
            return -1

    def __getitem__(self, key: str) -> Any:
        """Return the value for ``key``, counting fast-memory hits in fast_metrics.

        Disk reads are metered separately by the slow mapping.
        """
        if key in self.fast:
            # Note: don't log from self.fast.__getitem__, because that's called every
            # time a key is evicted, and we don't want to count those events here.
            nbytes = cast(int, self.fast.weights[key])
            self.fast_metrics.log_read(nbytes)
        return super().__getitem__(key)

    def __delitem__(self, key: str) -> None:
        """Remove ``key`` and forget any pickling failure logged for it."""
        super().__delitem__(key)
        self.logged_pickle_errors.discard(key)

    @property
    def memory(self) -> Mapping[str, Any]:
        """Key/value pairs stored in RAM. Alias of zict.Buffer.fast.
        For inspection only - do not modify directly!
        """
        return self.fast

    @property
    def disk(self) -> Mapping[str, Any]:
        """Key/value pairs spilled out to disk. Alias of zict.Buffer.slow.
        For inspection only - do not modify directly!
        """
        return self.slow

    @property
    def _slow_uncached(self) -> Slow:
        """The underlying Slow mapping, bypassing the zict.Cache wrapper that
        __init__ adds when zict >= 2.2.0 is installed."""
        slow = cast(zict.Cache, self.slow).data if has_zict_220 else self.slow
        return cast(Slow, slow)

    @property
    def spilled_total(self) -> SpilledSize:
        """Number of bytes spilled to disk. Tuple of

        - output of sizeof()
        - pickled size

        The two may differ substantially, e.g. if sizeof() is inaccurate or in case of
        compression.
        """
        return self._slow_uncached.total_weight

    def get_metrics(self) -> dict[str, float]:
        """Metrics to be exported to Prometheus or to be parsed directly.

        Output keys:

        - ``memory_count``, ``memory_bytes``: keys and total in-memory weight in fast
        - ``disk_count``, ``disk_bytes``: keys and total pickled size in slow
        - ``memory_<field>`` for every FastMetrics field
        - ``disk_<field>`` for every SlowMetrics field, except ``pickle_time_total``
          and ``unpickle_time_total``, which are emitted without prefix

        From these you may generate derived metrics:

        cache hit ratio:
          by keys  = memory_read_count_total / (memory_read_count_total + disk_read_count_total)
          by bytes = memory_read_bytes_total / (memory_read_bytes_total + disk_read_bytes_total)

        mean times per key:
          pickle   = pickle_time_total     / disk_write_count_total
          write    = disk_write_time_total / disk_write_count_total
          unpickle = unpickle_time_total   / disk_read_count_total
          read     = disk_read_time_total  / disk_read_count_total

        mean bytes per key:
          write    = disk_write_bytes_total / disk_write_count_total
          read     = disk_read_bytes_total  / disk_read_count_total

        mean bytes per second:
          write    = disk_write_bytes_total / disk_write_time_total
          read     = disk_read_bytes_total  / disk_read_time_total
        """
        fm = self.fast_metrics
        sm = self._slow_uncached.metrics

        out = {
            "memory_count": len(self.fast),
            "memory_bytes": self.fast.total_weight,
            "disk_count": len(self.slow),
            "disk_bytes": self._slow_uncached.total_weight.disk,
        }

        for k, v in fm.__dict__.items():
            out[f"memory_{k}"] = v

        for k, v in sm.__dict__.items():
            # pickle_time_total / unpickle_time_total keep their bare names
            out[k if "pickle" in k else f"disk_{k}"] = v
        return out


def _in_memory_weight(key: str, value: Any) -> int:
    """Weight function for SpillBuffer.fast: the value's safe_sizeof()."""
    return safe_sizeof(value)


# Internal exceptions. These are never raised by SpillBuffer.
class MaxSpillExceeded(Exception):
    """Raised by Slow.__setitem__ when storing a value would push the total
    serialized size past ``max_weight``. Its only argument is the key."""


class PickleError(Exception):
    """Raised by Slow.__setitem__ when a value fails to serialize.

    ``args`` is ``(key, original_exception)``.
    """


class HandledError(Exception):
    """Raised by SpillBuffer.handle_errors once a failure has been dealt with."""


# zict.Func[str, Any] requires zict >= 2.2.0
class Slow(zict.Func):
    """Disk-backed mapping that serializes values into a zict.File directory.

    Keeps per-key and aggregate SpilledSize accounting (sizeof() and serialized
    size). It refuses writes that would push the total serialized size above
    ``max_weight`` (unless that is False) and records I/O timings in
    ``metrics``.
    """

    max_weight: int | Literal[False]
    weight_by_key: dict[str, SpilledSize]
    total_weight: SpilledSize
    metrics: SlowMetrics

    def __init__(self, spill_directory: str, max_weight: int | Literal[False] = False):
        dump = partial(serialize_bytelist, on_error="raise")
        super().__init__(dump, deserialize_bytes, zict.File(spill_directory))
        self.max_weight = max_weight
        self.weight_by_key = {}
        self.total_weight = SpilledSize(0, 0)
        self.metrics = SlowMetrics()

    def __getitem__(self, key: str) -> Any:
        """Read ``key`` from disk, deserialize it and record read timings."""
        t_start = perf_counter()
        blob = self.d[key]
        assert isinstance(blob, bytearray if has_zict_230 else bytes)
        t_read = perf_counter()
        value = self.load(blob)  # type: ignore
        t_loaded = perf_counter()

        # Failures are deliberately not metered, to keep things simple.
        self.metrics.log_read(
            key_bytes=len(blob),
            read_time=t_read - t_start,
            unpickle_time=t_loaded - t_read,
        )
        return value

    def __setitem__(self, key: str, value: Any) -> None:
        """Serialize ``value`` and write it to disk under ``key``.

        Raises
        ------
        PickleError
            Serialization failed; args are ``(key, original_exception)``.
        MaxSpillExceeded
            Writing would exceed ``max_weight``; nothing was written.
        OSError
            Propagated from the underlying zict.File (e.g. disk full).
        """
        t_start = perf_counter()
        try:
            # FIXME https://github.com/python/mypy/issues/708
            frames = self.dump(value)  # type: ignore
        except Exception as exc:
            # zict.LRU keeps the key in fast when this raises. Wrapping the
            # error lets SpillBuffer recognize the failure and unwrap it.
            raise PickleError(key, exc)

        nbytes = 0
        for frame in frames:
            nbytes += frame.nbytes if isinstance(frame, memoryview) else len(frame)
        t_dumped = perf_counter()

        # Buffer.__setitem__ never overwrites a key in slow: it always deletes
        # and reinserts, so the key cannot already be present here.
        assert key not in self.d
        assert key not in self.weight_by_key

        limit = self.max_weight
        if limit is not False and self.total_weight.disk + nbytes > limit:
            # Abort the spill callbacks. SpillBuffer catches this, and the key
            # stays in its fast mapping.
            raise MaxSpillExceeded(key)

        # Write through zict.File. An OSError raised here is handled upstream
        # by SpillBuffer.
        self.d[key] = frames
        t_written = perf_counter()

        size = SpilledSize(safe_sizeof(value), nbytes)
        self.weight_by_key[key] = size
        self.total_weight += size

        # Failures are deliberately not metered, to keep things simple.
        self.metrics.log_write(
            key_bytes=nbytes,
            pickle_time=t_dumped - t_start,
            write_time=t_written - t_dumped,
        )

    def __delitem__(self, key: str) -> None:
        """Delete ``key`` from disk and subtract its size from the totals."""
        super().__delitem__(key)
        removed = self.weight_by_key.pop(key)
        self.total_weight -= removed