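# Source code for distributed.comm.core: the abstract Comm, Listener and
# Connector interfaces plus the connect() / listen() entry points.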

from abc import ABCMeta, abstractmethod, abstractproperty
from datetime import timedelta
import logging
import weakref

import dask
from six import with_metaclass
from tornado import gen

from ..metrics import time
from ..utils import parse_timedelta
from . import registry
from .addressing import parse_address

# Module-level logger; connect() uses it to report retry sleeps at DEBUG level.
logger = logging.getLogger(__name__)

class CommClosedError(IOError):
    """
    ``IOError`` subclass used to signal that a communication channel
    (``Comm``) is closed.
    """


class FatalCommClosedError(CommClosedError):
    """
    A ``CommClosedError`` that must not be retried.

    ``connect()`` re-raises it immediately, instead of treating it like the
    other ``EnvironmentError`` failures it retries until its timeout expires.
    """

class Comm(metaclass=ABCMeta):
    """
    A message-oriented communication object, representing an established
    communication channel.

    There should be only one reader and one writer at a time: to manage
    current communications, even with a single peer, you must create
    distinct ``Comm`` objects.

    Messages are arbitrary Python objects.  Concrete implementations of
    this class can implement different serialization mechanisms depending
    on the underlying transport's characteristics.
    """

    # Registry of every live Comm instance, shared by all subclasses.  It is
    # a WeakSet, so an entry disappears once its comm is garbage collected.
    _instances = weakref.WeakSet()

    def __init__(self):
        self._instances.add(self)
        # Optional human-readable label; __repr__ shows it when it is set.
        self.name = None

    # XXX add set_close_callback()?

    @abstractmethod
    def read(self, deserializers=None):
        """
        Read and return a message (a Python object).

        This method is a coroutine.

        Parameters
        ----------
        deserializers : Optional[Dict[str, Tuple[Callable, Callable, bool]]]
            An optional dict appropriate for distributed.protocol.deserialize.
            See :ref:`serialization` for more.
        """

    @abstractmethod
    def write(self, msg, on_error=None):
        """
        Write a message (a Python object).

        This method is a coroutine.

        Parameters
        ----------
        msg : object
            The message to send.
        on_error : Optional[str]
            The behavior when serialization fails. See
            ``distributed.protocol.core.dumps`` for valid values.
        """

    @abstractmethod
    def close(self):
        """
        Close the communication cleanly.  This will attempt to flush
        outgoing buffers before actually closing the underlying transport.

        This method is a coroutine.
        """

    @abstractmethod
    def abort(self):
        """
        Close the communication immediately and abruptly.
        Useful in destructors or generators' ``finally`` blocks.
        """

    @abstractmethod
    def closed(self):
        """
        Return whether the stream is closed.
        """

    # ``property`` + ``abstractmethod`` replaces ``abc.abstractproperty``,
    # which has been deprecated since Python 3.3; subclasses must still
    # override these to be instantiable.
    @property
    @abstractmethod
    def local_address(self):
        """
        The local address.  For logging and debugging purposes only.
        """

    @property
    @abstractmethod
    def peer_address(self):
        """
        The peer's address.  For logging and debugging purposes only.
        """

    @property
    def extra_info(self):
        """
        Return backend-specific information about the communication,
        as a dict.  Typically, this is information which is initialized
        when the communication is established and doesn't vary afterwards.

        The base implementation returns an empty dict.
        """
        return {}

    def __repr__(self):
        clsname = self.__class__.__name__
        if self.closed():
            return "<closed %s>" % (clsname,)
        else:
            return "<%s %s local=%s remote=%s>" % (
                clsname,
                self.name or "",
                self.local_address,
                self.peer_address,
            )
class Listener(metaclass=ABCMeta):
    """
    Abstract server-side endpoint that accepts incoming connections.

    Usable as a context manager: entering calls :meth:`start` and returns
    the listener, exiting calls :meth:`stop` (exceptions are not suppressed).
    """

    @abstractmethod
    def start(self):
        """
        Start listening for incoming connections.
        """

    @abstractmethod
    def stop(self):
        """
        Stop listening.  This does not shutdown already established
        communications, but prevents accepting new ones.
        """

    # ``property`` + ``abstractmethod`` replaces the deprecated
    # ``abc.abstractproperty``.
    @property
    @abstractmethod
    def listen_address(self):
        """
        The listening address as a URI string.
        """

    @property
    @abstractmethod
    def contact_address(self):
        """
        An address this listener can be contacted on.  This can be
        different from `listen_address` if the latter is some wildcard
        address such as ``tcp://0.0.0.0:<port>``.
        """

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.stop()
class Connector(metaclass=ABCMeta):
    """
    Abstract client-side factory that opens ``Comm`` objects to a remote
    address.
    """

    @abstractmethod
    def connect(self, address, deserialize=True):
        """
        Connect to the given address and return a Comm object.
        This function is a coroutine.  It may raise EnvironmentError
        if the other endpoint is unreachable or unavailable.  It
        may raise ValueError if the address is malformed.
        """
async def connect(addr, timeout=None, deserialize=True, connection_args=None):
    """
    Connect to the given address (a URI such as ``tcp://host:port``) and
    return a ``Comm`` object.

    The URI scheme selects the comm backend from the registry.  Attempts
    failing with ``EnvironmentError`` are retried every 10 ms until
    *timeout* has elapsed; a ``FatalCommClosedError`` propagates at once.

    Parameters
    ----------
    addr : str
        Address to connect to.
    timeout : optional
        Overall time budget, in any form ``parse_timedelta`` accepts (bare
        numbers are seconds).  Defaults to the
        ``distributed.comm.timeouts.connect`` config value.
    deserialize : bool
        Forwarded to the backend connector's ``connect``.
    connection_args : dict, optional
        Extra keyword arguments forwarded to the backend connector's
        ``connect``.

    Raises
    ------
    IOError
        If no connection was established before the deadline; the message
        includes the last connection error seen, if any.
    """
    if timeout is None:
        timeout = dask.config.get("distributed.comm.timeouts.connect")
    timeout = parse_timedelta(timeout, default="seconds")

    scheme, loc = parse_address(addr)
    connector = registry.get_backend(scheme).get_connector()
    connect_kwargs = connection_args or {}

    deadline = time() + timeout
    last_error = None

    def _fail(reason):
        # Raise the user-facing timeout error, falling back to a generic
        # reason when no attempt produced an error message.
        reason = reason or "connect() didn't finish in time"
        raise IOError(
            "Timed out trying to connect to %r after %s s: %s"
            % (addr, timeout, reason)
        )

    while True:
        try:
            attempt = connector.connect(
                loc, deserialize=deserialize, **connect_kwargs
            )
            remaining = timedelta(seconds=deadline - time())
            # quiet_exceptions stops tornado from logging an EnvironmentError
            # raised later by an attempt abandoned after the timeout.
            return await gen.with_timeout(
                remaining, attempt, quiet_exceptions=EnvironmentError
            )
        except FatalCommClosedError:
            raise
        except EnvironmentError as exc:
            last_error = str(exc)
            if time() >= deadline:
                _fail(last_error)
            await gen.sleep(0.01)
            logger.debug("sleeping on connect")
        except gen.TimeoutError:
            _fail(last_error)
def listen(addr, handle_comm, deserialize=True, connection_args=None):
    """
    Create a listener object with the given parameters.  When its
    ``start()`` method is called, the listener will listen on the given
    address (a URI such as ``tcp://host:port``) and call *handle_comm*
    with a ``Comm`` object for each incoming connection.

    *handle_comm* can be a regular function or a coroutine.

    If *addr* cannot be parsed strictly (e.g. it has no scheme), it is
    retried with a ``tls://`` prefix when *connection_args* carries a truthy
    ``ssl_context`` and with ``tcp://`` otherwise.
    """
    listener_kwargs = connection_args or {}
    try:
        scheme, loc = parse_address(addr, strict=True)
    except ValueError:
        prefix = "tls://" if listener_kwargs.get("ssl_context") else "tcp://"
        addr = prefix + addr
        scheme, loc = parse_address(addr, strict=True)

    backend = registry.get_backend(scheme)
    return backend.get_listener(loc, handle_comm, deserialize, **listener_kwargs)