1"""A semi-synchronous Client for IPython parallel"""
2
3# Copyright (c) IPython Development Team.
4# Distributed under the terms of the Modified BSD License.
5import asyncio
6import json
7import os
8import re
9import socket
10import time
11import types
12import warnings
13from collections.abc import Iterable
14from concurrent.futures import Future
15from functools import partial
16from getpass import getpass
17from pprint import pprint
18from threading import Event, current_thread
19
20import jupyter_client.session
21import zmq
22from decorator import decorator
23from ipykernel.comm import Comm
24from IPython import get_ipython
25from IPython.core.application import BaseIPythonApplication
26from IPython.core.profiledir import ProfileDir, ProfileDirError
27from IPython.paths import get_ipython_dir
28from IPython.utils.capture import RichOutput
29from IPython.utils.path import compress_user
30from jupyter_client.localinterfaces import is_local_ip, localhost
31from jupyter_client.session import Session
32from tornado import ioloop
33from traitlets import (
34 Any,
35 Bool,
36 Bytes,
37 Dict,
38 HasTraits,
39 Instance,
40 List,
41 Set,
42 Unicode,
43 default,
44)
45from traitlets.config.configurable import MultipleInstanceError
46from zmq.eventloop.zmqstream import ZMQStream
47
48import ipyparallel as ipp
49from ipyparallel import error, serialize, util
50from ipyparallel.serialize import PrePickled, Reference
51from ipyparallel.util import _OutputProducingThread as Thread
52from ipyparallel.util import _TermColors
53
54from .asyncresult import AsyncHubResult, AsyncResult
55from .futures import MessageFuture, multi_future
56from .view import BroadcastView, DirectView, LoadBalancedView
57
58pjoin = os.path.join
# leave timestamps in messages as ISO8601 strings instead of parsing them to datetime
jupyter_client.session.extract_dates = lambda obj: obj
60# --------------------------------------------------------------------------
61# Decorators for Client methods
62# --------------------------------------------------------------------------
63
64
65@decorator
66def unpack_message(f, self, msg_parts):
67 """Unpack a message before calling the decorated method."""
68 idents, msg = self.session.feed_identities(msg_parts, copy=False)
69 try:
70 msg = self.session.deserialize(msg, content=True, copy=False)
71 except Exception:
72 self.log.error("Invalid Message", exc_info=True)
73 else:
74 if self.debug:
75 pprint(msg)
76 return f(self, msg)
77
78
79# --------------------------------------------------------------------------
80# Classes
81# --------------------------------------------------------------------------
82
83
84_no_connection_file_msg = """
85Failed to connect because no Controller could be found.
86Please double-check your profile and ensure that a cluster is running.
87"""
88
89
90class ExecuteReply(RichOutput):
91 """wrapper for finished Execute results"""
92
93 def __init__(self, msg_id, content, metadata):
94 self.msg_id = msg_id
95 self._content = content
96 self.execution_count = content['execution_count']
97 self.metadata = metadata
98
99 # RichOutput overrides
100
101 @property
102 def source(self):
103 execute_result = self.metadata['execute_result']
104 if execute_result:
105 return execute_result.get('source', '')
106
107 @property
108 def data(self):
109 execute_result = self.metadata['execute_result']
110 if execute_result:
111 return execute_result.get('data', {})
112 return {}
113
114 @property
115 def _metadata(self):
116 execute_result = self.metadata['execute_result']
117 if execute_result:
118 return execute_result.get('metadata', {})
119 return {}
120
121 def display(self):
122 from IPython.display import publish_display_data
123
124 publish_display_data(self.data, self.metadata)
125
126 def _repr_mime_(self, mime):
127 if mime not in self.data:
128 return
129 data = self.data[mime]
130 if mime in self._metadata:
131 return data, self._metadata[mime]
132 else:
133 return data
134
135 def _repr_mimebundle_(self, *args, **kwargs):
        data, md = self.data, self._metadata
137 if 'text/plain' in data:
138 data = data.copy()
139 data['text/plain'] = self._plaintext()
140 return data, md
141
142 def __getitem__(self, key):
143 return self.metadata[key]
144
145 def __getattr__(self, key):
146 if key not in self.metadata:
147 raise AttributeError(key)
148 return self.metadata[key]
149
150 def __repr__(self):
151 execute_result = self.metadata['execute_result'] or {'data': {}}
152 text_out = execute_result['data'].get('text/plain', '')
153 if len(text_out) > 32:
154 text_out = text_out[:29] + '...'
155
156 return f"<ExecuteReply[{self.execution_count}]: {text_out}>"
157
158 def _plaintext(self) -> str:
159 execute_result = self.metadata['execute_result'] or {'data': {}}
160 text_out = execute_result['data'].get('text/plain', '')
161
162 if not text_out:
163 return ''
164
165 ip = get_ipython()
166 if ip is None:
167 colors = "NoColor"
168 else:
169 colors = ip.colors
170
171 if colors.lower() == "nocolor":
172 out = normal = ""
173 else:
174 out = _TermColors.Red
175 normal = _TermColors.Normal
176
177 if '\n' in text_out and not text_out.startswith('\n'):
178 # add newline for multiline reprs
179 text_out = '\n' + text_out
180
181 return ''.join(
182 [
183 out,
184 f"Out[{self.metadata['engine_id']}:{self.execution_count}]: ",
185 normal,
186 text_out,
187 ]
188 )
189
190 def _repr_pretty_(self, p, cycle):
191 p.text(self._plaintext())
192
193
194class Metadata(dict):
195 """Subclass of dict for initializing metadata values.
196
197 Attribute access works on keys.
198
    These objects have a strict set of keys - errors will be raised if you try
    to add new keys.
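
    For example (an illustrative sketch, not exercised anywhere in this module)::

        md = Metadata(status='ok')
        md.status           # 'ok' -- attribute access reads existing keys
        md['stdout'] = ''   # allowed: 'stdout' is a known key
        md['other'] = 1     # raises KeyError, because unknown keys are rejected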
201 """
202
203 def __init__(self, *args, **kwargs):
204 dict.__init__(self)
205 md = {
206 'msg_id': None,
207 'submitted': None,
208 'started': None,
209 'completed': None,
210 'received': None,
211 'engine_uuid': None,
212 'engine_id': None,
213 'follow': None,
214 'after': None,
215 'status': None,
216 'execute_input': None,
217 'execute_result': None,
218 'error': None,
219 'stdout': '',
220 'stderr': '',
221 'outputs': [],
222 'data': {},
223 }
224 self.update(md)
225 self.update(dict(*args, **kwargs))
226
227 def __getattr__(self, key):
228 """getattr aliased to getitem"""
229 if key in self:
230 return self[key]
231 else:
232 raise AttributeError(key)
233
234 def __setattr__(self, key, value):
235 """setattr aliased to setitem, with strict"""
236 if key in self:
237 self[key] = value
238 else:
239 raise AttributeError(key)
240
241 def __setitem__(self, key, value):
242 """strict static key enforcement"""
243 if key in self:
244 dict.__setitem__(self, key, value)
245 else:
246 raise KeyError(key)
247
248
249def _is_future(f):
250 """light duck-typing check for Futures"""
251 return hasattr(f, 'add_done_callback')
252
253
254# carriage return pattern
255_cr_pat = re.compile(r'.*\r(?=[^\n])')
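# e.g. (illustrative): _cr_pat.sub('', 'step 1\rstep 2\n') -> 'step 2\n'
# i.e. text overwritten by a carriage return is dropped, mimicking terminal behavior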
256
257
258class Client(HasTraits):
259 """A semi-synchronous client to an IPython parallel cluster
260
261 Parameters
262 ----------
263
264 connection_info : str or dict
265 The path to ipcontroller-client.json, or a dict containing the same information.
266 This JSON file should contain all the information needed to connect to a cluster,
267 and is usually the only argument needed.
268 [Default: use profile]
269 profile : str
270 The name of the Cluster profile to be used to find connector information.
271 If run from an IPython application, the default profile will be the same
272 as the running application, otherwise it will be 'default'.
273 cluster_id : str
        String id to be added to runtime files, to prevent name collisions when
        using multiple clusters with a single profile simultaneously.
        When set, files named like 'ipcontroller-<cluster_id>-client.json' will be used.
        Since this is text inserted into filenames, typical recommendations apply:
        simple character strings are ideal, and spaces are not recommended (but
        should generally work).
280 context : zmq.Context
281 Pass an existing zmq.Context instance, otherwise the client will create its own.
282 debug : bool
283 flag for lots of message printing for debug purposes
284 timeout : float
285 time (in seconds) to wait for connection replies from the Hub
286 [Default: 10]
287
288 Other Parameters
289 ----------------
290
291 sshserver : str
        A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
        If keyfile or password is specified but this is not, it will default to
        the ip given in addr.
    sshkey : str; path to ssh private key file
        This specifies a key to be used in ssh login, default None.
        Regular default ssh keys will be used without specifying this argument.
    password : str
        Your ssh password to sshserver. Note that if this is left None,
        you will be prompted for it if passwordless key-based login is unavailable.
    paramiko : bool
        flag for whether to use paramiko instead of shell ssh for tunneling.
        [default: True on win32, False otherwise]
304
305
306 Attributes
307 ----------
308
309 ids : list of int engine IDs
        requesting the ids attribute always synchronizes
        the registration state. To request ids without synchronization,
        use the semi-private _ids attribute.
313
314 history : list of msg_ids
315 a list of msg_ids, keeping track of all the execution
316 messages you have submitted in order.
317
318 outstanding : set of msg_ids
319 a set of msg_ids that have been submitted, but whose
320 results have not yet been received.
321
322 results : dict
323 a dict of all our results, keyed by msg_id
324
325 block : bool
        determines the default behavior when block is not specified
        in execution methods
328
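    Examples
    --------
    A minimal connection sketch (assumes a cluster has already been started,
    e.g. with ``ipcluster start -n 4``)::

        import ipyparallel as ipp

        rc = ipp.Client()          # connect using the default profile
        rc.wait_for_engines(4)     # block until 4 engines have registered
        rc[:].apply_sync(str, 42)  # run on all engines via a DirectView
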
329 """
330
331 block = Bool(False)
332 outstanding = Set()
333 results = Instance('collections.defaultdict', (dict,))
334 metadata = Instance('collections.defaultdict', (Metadata,))
335 cluster = Instance('ipyparallel.cluster.Cluster', allow_none=True)
336 history = List()
337 debug = Bool(False)
338 _futures = Dict()
339 _output_futures = Dict()
340 _io_loop = Any()
341 _io_thread = Any()
342
343 profile = Unicode()
344
345 def _profile_default(self):
346 if BaseIPythonApplication.initialized():
347 # an IPython app *might* be running, try to get its profile
348 try:
349 return BaseIPythonApplication.instance().profile
350 except (AttributeError, MultipleInstanceError):
351 # could be a *different* subclass of config.Application,
352 # which would raise one of these two errors.
353 return 'default'
354 else:
355 return 'default'
356
357 _outstanding_dict = Instance('collections.defaultdict', (set,))
358 _ids = List()
359 _connected = Bool(False)
360 _ssh = Bool(False)
361 _context = Instance('zmq.Context', allow_none=True)
362
363 @default("_context")
364 def _default_context(self):
365 return zmq.Context.instance()
366
367 _config = Dict()
368 _engines = Instance(util.ReverseDict, (), {})
369 _query_socket = Instance('zmq.Socket', allow_none=True)
370 _control_socket = Instance('zmq.Socket', allow_none=True)
371 _iopub_socket = Instance('zmq.Socket', allow_none=True)
372 _notification_socket = Instance('zmq.Socket', allow_none=True)
373 _mux_socket = Instance('zmq.Socket', allow_none=True)
374 _task_socket = Instance('zmq.Socket', allow_none=True)
375 _broadcast_socket = Instance('zmq.Socket', allow_none=True)
376 _registration_callbacks = List()
377
378 curve_serverkey = Bytes(allow_none=True)
379 curve_secretkey = Bytes(allow_none=True)
380 curve_publickey = Bytes(allow_none=True)
381
382 _task_scheme = Unicode()
383 _closed = False
384
385 def __new__(self, *args, **kw):
386 # don't raise on positional args
387 return HasTraits.__new__(self, **kw)
388
389 def __init__(
390 self,
391 connection_info=None,
392 *,
393 url_file=None,
394 profile=None,
395 profile_dir=None,
396 ipython_dir=None,
397 context=None,
398 debug=False,
399 sshserver=None,
400 sshkey=None,
401 password=None,
402 paramiko=None,
403 timeout=10,
404 cluster_id=None,
405 cluster=None,
406 **extra_args,
407 ):
408 super_kwargs = {'debug': debug, 'cluster': cluster}
409 if profile:
410 super_kwargs['profile'] = profile
411 super().__init__(**super_kwargs)
412 if context is not None:
413 self._context = context
414
415 for argname in ('url_or_file', 'url_file'):
416 if argname in extra_args:
417 connection_info = extra_args[argname]
418 warnings.warn(
419 f"{argname} arg no longer supported, use positional connection_info argument",
420 DeprecationWarning,
421 stacklevel=2,
422 )
423
424 if isinstance(connection_info, str) and util.is_url(connection_info):
425 raise ValueError(
426 f"single urls ({connection_info!r}) cannot be specified, url-files must be used."
427 )
428
429 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
430
431 no_file_msg = '\n'.join(
432 [
433 "You have attempted to connect to an IPython Cluster but no Controller could be found.",
434 "Please double-check your configuration and ensure that a cluster is running.",
435 ]
436 )
437
438 if connection_info is None and self._profile_dir is not None:
439 # default: find connection info from profile
440 if cluster_id:
441 client_json = f'ipcontroller-{cluster_id}-client.json'
442 else:
443 client_json = 'ipcontroller-client.json'
444 connection_file = pjoin(self._profile_dir.security_dir, client_json)
445 short = compress_user(connection_file)
446 if not os.path.exists(connection_file):
447 print(f"Waiting for connection file: {short}")
448 waiting_time = 0.0
449 while waiting_time < timeout:
450 time.sleep(min(timeout - waiting_time, 1))
451 waiting_time += 1
452 if os.path.exists(connection_file):
453 break
454 if not os.path.exists(connection_file):
455 msg = '\n'.join([f"Connection file {short!r} not found.", no_file_msg])
456 raise OSError(msg)
457
458 with open(connection_file) as f:
459 connection_info = json.load(f)
460
461 if connection_info is None:
462 raise OSError(no_file_msg)
463
464 if isinstance(connection_info, dict):
465 cfg = connection_info.copy()
466 else:
467 # connection_info given as path to connection file
468 connection_file = connection_info
469 if not os.path.exists(connection_file):
470 # Connection file explicitly specified, but not found
471 raise OSError(
472 f"Connection file {compress_user(connection_file)} not found. Is a controller running?"
473 )
474
475 with open(connection_file) as f:
476 connection_info = cfg = json.load(f)
477
478 self._task_scheme = cfg['task_scheme']
479
480 if not cfg.get("curve_serverkey") and "IPP_CURVE_SERVERKEY" in os.environ:
481 # load from env, if not set in connection file
482 cfg["curve_serverkey"] = os.environ["IPP_CURVE_SERVERKEY"]
483
484 if cfg.get("curve_serverkey"):
485 self.curve_serverkey = cfg["curve_serverkey"].encode('ascii')
486 if not self.curve_publickey or not self.curve_secretkey:
487 # if context: this could crash!
488 # inappropriately closes libsodium random_bytes source
489 # with libzmq <= 4.3.4
490 self.curve_publickey, self.curve_secretkey = zmq.curve_keypair()
491
492 # sync defaults from args, json:
493 if sshserver:
494 cfg['ssh'] = sshserver
495
496 location = cfg.setdefault('location', None)
497
498 proto, addr = cfg['interface'].split('://')
499 addr = util.disambiguate_ip_address(addr, location)
500 cfg['interface'] = f"{proto}://{addr}"
501
502 # turn interface,port into full urls:
503 for key in (
504 'control',
505 'task',
506 'mux',
507 'iopub',
508 'notification',
509 'registration',
510 'broadcast',
511 ):
512 cfg[key] = f"{cfg['interface']}:{cfg[key]}"
513
514 url = cfg['registration']
515
516 if location is not None and addr == localhost():
517 # location specified, and connection is expected to be local
518 location_ip = util.ip_for_host(location)
519
520 if not is_local_ip(location_ip) and not sshserver:
521 # load ssh from JSON *only* if the controller is not on
522 # this machine
523 sshserver = cfg['ssh']
524 if (
525 not is_local_ip(location_ip)
526 and not sshserver
527 and location != socket.gethostname()
528 ):
529 # warn if no ssh specified, but SSH is probably needed
530 # This is only a warning, because the most likely cause
531 # is a local Controller on a laptop whose IP is dynamic
532 warnings.warn(
533 f"""
534 Controller appears to be listening on localhost, but not on this machine.
535 If this is true, you should specify Client(...,sshserver='you@{location}')
536 or instruct your controller to listen on an external IP.""",
537 RuntimeWarning,
538 )
539 elif not sshserver:
540 # otherwise sync with cfg
541 sshserver = cfg['ssh']
542
543 self._config = cfg
544
545 self._ssh = bool(sshserver or sshkey or password)
546 if self._ssh and sshserver is None:
547 # default to ssh via localhost
548 sshserver = addr
549 if self._ssh and password is None:
550 from zmq.ssh import tunnel
551
552 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
553 password = False
554 else:
555 password = getpass(f"SSH Password for {sshserver}: ")
556 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
557
558 # configure and construct the session
559 try:
560 extra_args['packer'] = cfg['pack']
561 extra_args['unpacker'] = cfg['unpack']
562 extra_args['key'] = cfg['key'].encode("utf8")
563 extra_args['signature_scheme'] = cfg['signature_scheme']
564 except KeyError as exc:
565 msg = '\n'.join(
566 [
567 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
568 "If you are reusing connection files, remove them and start ipcontroller again.",
569 ]
570 )
            # KeyError has no .message attribute in Python 3; format with the missing key itself
            raise ValueError(msg.format(exc.args[0]))
572
573 util._disable_session_extract_dates()
574 self.session = Session(**extra_args)
575
576 self._query_socket = self._context.socket(zmq.DEALER)
577 if self.curve_serverkey:
578 self._query_socket.curve_serverkey = self.curve_serverkey
579 self._query_socket.curve_secretkey = self.curve_secretkey
580 self._query_socket.curve_publickey = self.curve_publickey
581
582 if self._ssh:
583 from zmq.ssh import tunnel
584
585 tunnel.tunnel_connection(
586 self._query_socket,
587 cfg['registration'],
588 sshserver,
589 timeout=timeout,
590 **ssh_kwargs,
591 )
592 else:
593 self._query_socket.connect(cfg['registration'])
594
595 self.session.debug = self.debug
596
597 self._notification_handlers = {
598 'registration_notification': self._register_engine,
599 'unregistration_notification': self._unregister_engine,
600 'shutdown_notification': lambda msg: self.close(),
601 }
602 self._queue_handlers = {
603 'execute_reply': self._handle_execute_reply,
604 'apply_reply': self._handle_apply_reply,
605 }
606
607 try:
608 self._connect(sshserver, ssh_kwargs, timeout)
609 except Exception:
610 self.close(linger=0)
611 raise
612
613 # last step: setup magics, if we are in IPython:
614
615 ip = get_ipython()
616 if ip is None:
617 return
618 else:
619 if 'px' not in ip.magics_manager.magics["line"]:
620 # in IPython but we are the first Client.
621 # activate a default view for parallel magics.
622 self.activate()
623
624 def __del__(self):
625 """cleanup sockets, but _not_ context."""
626 self.close()
627
628 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
629 if ipython_dir is None:
630 ipython_dir = get_ipython_dir()
631 if profile_dir is not None:
632 try:
633 self._profile_dir = ProfileDir.find_profile_dir(profile_dir)
634 return
635 except ProfileDirError:
636 pass
637 elif profile is not None:
638 try:
639 self._profile_dir = ProfileDir.find_profile_dir_by_name(
640 ipython_dir, profile
641 )
642 return
643 except ProfileDirError:
644 pass
645 self._profile_dir = None
646
647 def __enter__(self):
648 """A client can be used as a context manager
649
650 which will close the client on exit
651
        .. versionadded:: 7.0
653 """
654 return self
655
656 def __exit__(self, exc_type, exc_value, traceback):
657 """Exiting a client context closes the client"""
658 self.close()
659
660 def _update_engines(self, engines):
661 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
662 for k, v in engines.items():
663 eid = int(k)
664 if eid not in self._engines:
665 self._ids.append(eid)
666 self._engines[eid] = v
667 self._ids = sorted(self._ids)
668 if (
669 sorted(self._engines.keys()) != list(range(len(self._engines)))
670 and self._task_scheme == 'pure'
671 and self._task_socket
672 ):
673 self._stop_scheduling_tasks()
674
675 def _stop_scheduling_tasks(self):
676 """Stop scheduling tasks because an engine has been unregistered
677 from a pure ZMQ scheduler.
678 """
679 self._task_socket.close()
680 self._task_socket = None
681 msg = (
682 "An engine has been unregistered, and we are using pure "
683 + "ZMQ task scheduling. Task farming will be disabled."
684 )
685 if self.outstanding:
686 msg += (
687 " If you were running tasks when this happened, "
688 + "some `outstanding` msg_ids may never resolve."
689 )
690 warnings.warn(msg, RuntimeWarning)
691
692 def _build_targets(self, targets):
693 """Turn valid target IDs or 'all' into two lists:
694 (int_ids, uuids).
695 """
696 if not self._ids:
697 # flush notification socket if no engines yet, just in case
698 if not self.ids:
699 raise error.NoEnginesRegistered(
700 "Can't build targets without any engines"
701 )
702
703 if targets is None:
704 targets = self._ids
705 elif isinstance(targets, str):
706 if targets.lower() == 'all':
707 targets = self._ids
708 else:
709 raise TypeError(f"{targets!r} not valid str target, must be 'all'")
710 elif isinstance(targets, int):
711 if targets < 0:
712 targets = self.ids[targets]
713 if targets not in self._ids:
714 raise IndexError(f"No such engine: {targets}")
715 targets = [targets]
716
717 if isinstance(targets, slice):
718 indices = list(range(len(self._ids))[targets])
719 ids = self.ids
720 targets = [ids[i] for i in indices]
721
722 if not isinstance(targets, (tuple, list, range)):
723 raise TypeError(
724 f"targets by int/slice/collection of ints only, not {type(targets)}"
725 )
726
727 return [self._engines[t].encode("utf8") for t in targets], list(targets)
728
729 def _connect(self, sshserver, ssh_kwargs, timeout):
730 """setup all our socket connections to the cluster. This is called from
731 __init__."""
732
733 # Maybe allow reconnecting?
734 if self._connected:
735 return
736 self._connected = True
737
738 def connect_socket(s, url):
739 if self._ssh:
740 from zmq.ssh import tunnel
741
742 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
743 else:
744 return util.connect(
745 s,
746 url,
747 curve_serverkey=self.curve_serverkey,
748 curve_secretkey=self.curve_secretkey,
749 curve_publickey=self.curve_publickey,
750 )
751
752 self.session.send(self._query_socket, 'connection_request')
753 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
754 poller = zmq.Poller()
755 poller.register(self._query_socket, zmq.POLLIN)
756 # poll expects milliseconds, timeout is seconds
757 evts = poller.poll(timeout * 1000)
758 if not evts:
759 raise TimeoutError("Hub connection request timed out")
760 idents, msg = self.session.recv(self._query_socket, mode=0)
761 if self.debug:
762 pprint(msg)
763 content = msg['content']
764 # self._config['registration'] = dict(content)
765 cfg = self._config
766 if content['status'] == 'ok':
767 self._mux_socket = self._context.socket(zmq.DEALER)
768 connect_socket(self._mux_socket, cfg['mux'])
769
770 self._task_socket = self._context.socket(zmq.DEALER)
771 connect_socket(self._task_socket, cfg['task'])
772
773 self._broadcast_socket = self._context.socket(zmq.DEALER)
774 connect_socket(self._broadcast_socket, cfg['broadcast'])
775
776 self._notification_socket = self._context.socket(zmq.SUB)
777 self._notification_socket.RCVHWM = 0
778 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
779 connect_socket(self._notification_socket, cfg['notification'])
780
781 self._control_socket = self._context.socket(zmq.DEALER)
782 connect_socket(self._control_socket, cfg['control'])
783
784 self._iopub_socket = self._context.socket(zmq.SUB)
785 self._iopub_socket.RCVHWM = 0
786 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
787 connect_socket(self._iopub_socket, cfg['iopub'])
788
789 self._update_engines(dict(content['engines']))
790 else:
791 self._connected = False
792 tb = '\n'.join(content.get('traceback', []))
793 raise Exception(f"Failed to connect! {tb}")
794
795 self._start_io_thread()
796
797 # --------------------------------------------------------------------------
798 # handlers and callbacks for incoming messages
799 # --------------------------------------------------------------------------
800
801 def _unwrap_exception(self, content):
802 """unwrap exception, and remap engine_id to int."""
803 e = error.unwrap_exception(content)
804 # print e.traceback
805 if e.engine_info and 'engine_id' not in e.engine_info:
806 e_uuid = e.engine_info['engine_uuid']
807 eid = self._engines[e_uuid]
808 e.engine_info['engine_id'] = eid
809 return e
810
811 def _extract_metadata(self, msg):
812 header = msg['header']
813 parent = msg['parent_header']
814 msg_meta = msg['metadata']
815 content = msg['content']
816 md = {
817 'msg_id': parent['msg_id'],
818 'received': util.utcnow(),
819 'engine_uuid': msg_meta.get('engine', None),
820 'follow': msg_meta.get('follow', []),
821 'after': msg_meta.get('after', []),
822 'status': content['status'],
823 'is_broadcast': msg_meta.get('is_broadcast', False),
824 'is_coalescing': msg_meta.get('is_coalescing', False),
825 }
826
827 if md['engine_uuid'] is not None:
828 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
829
830 if md['is_coalescing']:
831 # get destinations from target metadata
832 targets = msg_meta.get("broadcast_targets", [])
833 md['engine_uuid'], md['engine_id'] = map(list, zip(*targets))
834
835 if 'date' in parent:
836 md['submitted'] = parent['date']
837 if 'started' in msg_meta:
838 md['started'] = util._parse_date(msg_meta['started'])
839 if 'date' in header:
840 md['completed'] = header['date']
841 return md
842
843 def _register_engine(self, msg):
844 """Register a new engine, and update our connection info."""
845 content = msg['content']
846 eid = content['id']
847 d = {eid: content['uuid']}
848 self._update_engines(d)
849 event = {'event': 'register'}
850 event.update(content)
851 for callback in self._registration_callbacks:
852 callback(event)
853
854 def _unregister_engine(self, msg):
855 """Unregister an engine that has died."""
856 content = msg['content']
857 eid = int(content['id'])
858 if eid in self._ids:
859 self._ids.remove(eid)
860 uuid = self._engines.pop(eid)
861
862 self._handle_stranded_msgs(eid, uuid)
863
864 if self._task_socket and self._task_scheme == 'pure':
865 self._stop_scheduling_tasks()
866
867 event = {"event": "unregister"}
868 event.update(content)
869 for callback in self._registration_callbacks:
870 callback(event)
871
872 def _handle_stranded_msgs(self, eid, uuid):
873 """Handle messages known to be on an engine when the engine unregisters.
874
875 It is possible that this will fire prematurely - that is, an engine will
876 go down after completing a result, and the client will be notified
877 of the unregistration and later receive the successful result.
878 """
879
880 outstanding = self._outstanding_dict[uuid]
881
882 for msg_id in list(outstanding):
883 if msg_id in self.results:
                # we already have a result for this message, nothing to do
885 continue
886 try:
887 raise error.EngineError(
888 f"Engine {eid!r} died while running task {msg_id!r}"
889 )
890 except Exception:
891 content = error.wrap_exception()
892 # build a fake message:
893 msg = self.session.msg('apply_reply', content=content)
894 msg['parent_header']['msg_id'] = msg_id
895 msg['metadata']['engine'] = uuid
896 self._handle_apply_reply(msg)
897
898 def _handle_execute_reply(self, msg):
899 """Save the reply to an execute_request into our results.
900
901 execute messages are never actually used. apply is used instead.
902 """
903
904 parent = msg['parent_header']
905 if self._should_use_metadata_msg_id(msg):
906 msg_id = msg['metadata']['original_msg_id']
907 else:
908 msg_id = parent['msg_id']
909
910 future = self._futures.get(msg_id, None)
911 if msg_id not in self.outstanding:
912 if msg_id in self.history:
913 print(f"got stale result: {msg_id}")
914 else:
915 print(f"got unknown result: {msg_id}")
916 else:
917 self.outstanding.remove(msg_id)
918
919 content = msg['content']
920 header = msg['header']
921
922 # construct metadata:
923 md = self.metadata[msg_id]
924 md.update(self._extract_metadata(msg))
925
926 if md['is_coalescing']:
927 engine_uuids = md['engine_uuid'] or []
928 else:
929 engine_uuids = [md['engine_uuid']]
930
931 for engine_uuid in engine_uuids:
932 if engine_uuid is not None:
933 e_outstanding = self._outstanding_dict[engine_uuid]
934 if msg_id in e_outstanding:
935 e_outstanding.remove(msg_id)
936
937 # construct result:
938 if content['status'] == 'ok':
939 self.results[msg_id] = ExecuteReply(msg_id, content, md)
940 elif content['status'] == 'aborted':
941 self.results[msg_id] = error.TaskAborted(msg_id)
942 # aborted tasks will not get output
943 out_future = self._output_futures.get(msg_id)
944 if out_future and not out_future.done():
945 out_future.set_result(None)
946 elif content['status'] == 'resubmitted':
947 # TODO: handle resubmission
948 pass
949 else:
950 self.results[msg_id] = self._unwrap_exception(content)
951 if content['status'] != 'ok' and not content.get('engine_info'):
952 # not an engine failure, don't expect output
953 out_future = self._output_futures.get(msg_id)
954 if out_future and not out_future.done():
955 out_future.set_result(None)
956 if future:
957 future.set_result(self.results[msg_id])
958
959 def _should_use_metadata_msg_id(self, msg):
960 md = msg['metadata']
961 return md.get('is_broadcast', False) and md.get('is_coalescing', False)
962
963 def _handle_apply_reply(self, msg):
964 """Save the reply to an apply_request into our results."""
965 parent = msg['parent_header']
966 if self._should_use_metadata_msg_id(msg):
967 msg_id = msg['metadata']['original_msg_id']
968 else:
969 msg_id = parent['msg_id']
970
971 future = self._futures.get(msg_id, None)
972 if msg_id not in self.outstanding:
973 if msg_id in self.history:
974 print(f"got stale result: {msg_id}")
975 print(self.results[msg_id])
976 print(msg)
977 else:
978 print(f"got unknown result: {msg_id}")
979 else:
980 self.outstanding.remove(msg_id)
981 content = msg['content']
982 header = msg['header']
983
984 # construct metadata:
985 md = self.metadata[msg_id]
986 md.update(self._extract_metadata(msg))
987
988 if md['is_coalescing']:
989 engine_uuids = md['engine_uuid'] or []
990 else:
991 engine_uuids = [md['engine_uuid']]
992
993 for engine_uuid in engine_uuids:
994 if engine_uuid is not None:
995 e_outstanding = self._outstanding_dict[engine_uuid]
996 if msg_id in e_outstanding:
997 e_outstanding.remove(msg_id)
998
999 # construct result:
1000 if content['status'] == 'ok':
1001 if md.get('is_coalescing', False):
1002 deserialized_bufs = []
1003 bufs = msg['buffers']
1004 while bufs:
1005 deserialized, bufs = serialize.deserialize_object(bufs)
1006 deserialized_bufs.append(deserialized)
1007 self.results[msg_id] = deserialized_bufs
1008 else:
1009 self.results[msg_id] = serialize.deserialize_object(msg['buffers'])[0]
1010 elif content['status'] == 'aborted':
1011 self.results[msg_id] = error.TaskAborted(msg_id)
1012 out_future = self._output_futures.get(msg_id)
1013 if out_future and not out_future.done():
1014 out_future.set_result(None)
1015 elif content['status'] == 'resubmitted':
1016 # TODO: handle resubmission
1017 pass
1018 else:
1019 self.results[msg_id] = self._unwrap_exception(content)
1020 if content['status'] != 'ok' and not content.get('engine_info'):
1021 # not an engine failure, don't expect output
1022 out_future = self._output_futures.get(msg_id)
1023 if out_future and not out_future.done():
1024 out_future.set_result(None)
1025 if future:
1026 future.set_result(self.results[msg_id])
1027
1028 def _make_io_loop(self):
1029 """Make my IOLoop. Override with IOLoop.current to return"""
1030 # runs first thing in the io thread
1031 # always create a fresh asyncio loop for the thread
1032 if os.name == "nt":
1033 asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
1034 loop = ioloop.IOLoop(make_current=False)
1035 return loop
1036
1037 def _stop_io_thread(self):
1038 """Stop my IO thread"""
1039 if self._io_loop:
1040 self._io_loop.add_callback(self._io_loop.stop)
1041 if self._io_thread and self._io_thread is not current_thread():
1042 self._io_thread.join()
1043
1044 def _setup_streams(self):
1045 self._query_stream = ZMQStream(self._query_socket, self._io_loop)
1046 self._query_stream.on_recv(self._dispatch_single_reply, copy=False)
1047 self._control_stream = ZMQStream(self._control_socket, self._io_loop)
1048 self._control_stream.on_recv(self._dispatch_single_reply, copy=False)
1049 self._mux_stream = ZMQStream(self._mux_socket, self._io_loop)
1050 self._mux_stream.on_recv(self._dispatch_reply, copy=False)
1051 self._task_stream = ZMQStream(self._task_socket, self._io_loop)
1052 self._task_stream.on_recv(self._dispatch_reply, copy=False)
1053 self._iopub_stream = ZMQStream(self._iopub_socket, self._io_loop)
1054 self._iopub_stream.on_recv(self._dispatch_iopub, copy=False)
1055 self._notification_stream = ZMQStream(self._notification_socket, self._io_loop)
1056 self._notification_stream.on_recv(self._dispatch_notification, copy=False)
1057
1058 self._broadcast_stream = ZMQStream(self._broadcast_socket, self._io_loop)
1059 self._broadcast_stream.on_recv(self._dispatch_reply, copy=False)
1060
1061 def _start_io_thread(self):
1062 """Start IOLoop in a background thread."""
1063 evt = Event()
1064 self._io_thread = Thread(target=self._io_main, args=(evt,))
1065 self._io_thread.daemon = True
1066 self._io_thread.start()
1067 # wait for the IOLoop to start
1068 for i in range(20):
1069 if evt.wait(1):
1070 return
1071 if not self._io_thread.is_alive():
1072 raise RuntimeError("IO Loop failed to start")
1073 else:
1074 raise RuntimeError(
1075 "Start event was never set. Maybe a problem in the IO thread."
1076 )
1077
1078 def _io_main(self, start_evt=None):
1079 """main loop for background IO thread"""
1080 self._io_loop = self._make_io_loop()
1081 self._setup_streams()
1082 # signal that start has finished
1083 # so that the main thread knows that all our attributes are defined
1084 if start_evt:
1085 start_evt.set()
1086 try:
1087 self._io_loop.start()
1088 finally:
1089 self._io_loop.close(all_fds=True)
1090
1091 @unpack_message
1092 def _dispatch_single_reply(self, msg):
1093 """Dispatch single (non-execution) replies"""
1094 msg_id = msg['parent_header'].get('msg_id', None)
1095 future = self._futures.get(msg_id)
1096 if future is not None:
1097 future.set_result(msg)
1098
1099 @unpack_message
1100 def _dispatch_notification(self, msg):
1101 """Dispatch notification messages"""
1102 msg_type = msg['header']['msg_type']
1103 handler = self._notification_handlers.get(msg_type, None)
1104 if handler is None:
1105 raise KeyError(f"Unhandled notification message type: {msg_type}")
1106 else:
1107 handler(msg)
1108
1109 @unpack_message
1110 def _dispatch_reply(self, msg):
1111 """handle execution replies waiting in ZMQ queue."""
1112 msg_type = msg['header']['msg_type']
1113 handler = self._queue_handlers.get(msg_type, None)
1114 if handler is None:
1115 raise KeyError(f"Unhandled reply message type: {msg_type}")
1116 else:
1117 handler(msg)
1118
1119 @unpack_message
1120 def _dispatch_iopub(self, msg):
1121 """handler for IOPub messages"""
1122 parent = msg['parent_header']
1123 if not parent or parent['session'] != self.session.session:
1124 # ignore IOPub messages not from here
1125 return
1126 msg_id = parent['msg_id']
1127 content = msg['content']
1128 header = msg['header']
1129 msg_type = msg['header']['msg_type']
1130
1131 if msg_type == 'status' and msg_id not in self.metadata:
1132 # ignore status messages if they aren't mine
1133 return
1134
1135 # init metadata:
1136 md = self.metadata[msg_id]
1137
1138 if md['engine_id'] is None and 'engine' in msg['metadata']:
1139 e_uuid = msg['metadata']['engine']
1140 try:
1141 md['engine_uuid'] = e_uuid
1142 md['engine_id'] = self._engines[e_uuid]
1143 except KeyError:
1144 pass
1145
1146 ip = get_ipython()
1147
1148 if msg_type == 'stream':
1149 name = content['name']
1150 new_text = (md[name] or '') + content['text']
1151 if '\r' in content['text']:
1152 new_text = _cr_pat.sub('', new_text)
1153 md[name] = new_text
1154 elif msg_type == 'error':
1155 md.update({'error': self._unwrap_exception(content)})
1156 elif msg_type == 'execute_input':
1157 md.update({'execute_input': content['code']})
1158 elif msg_type == 'display_data':
1159 md['outputs'].append(content)
1160 elif msg_type == 'execute_result':
1161 md['execute_result'] = content
1162 elif msg_type == 'data_message':
1163 data, remainder = serialize.deserialize_object(msg['buffers'])
1164 md['data'].update(data)
1165 elif msg_type == 'status':
1166 # idle message comes after all outputs
1167 if content['execution_state'] == 'idle':
1168 future = self._output_futures.get(msg_id)
1169 if future and not future.done():
1170 # TODO: should probably store actual outputs on the Future
1171 future.set_result(None)
1172 elif msg_type.startswith("comm_") and ip is not None and ip.kernel is not None:
1173 # only handle comm messages when we're in an IPython kernel
1174 if msg_type == "comm_open":
1175 # create proxy comm
1176 engine_uuid = msg['metadata'].get('engine', '')
1177 engine_ident = engine_uuid.encode("utf8", "replace")
1178 # DEBUG: engine_uuid can still be missing?!
1179
1180 comm = Comm(
1181 comm_id=content['comm_id'],
1182 primary=False,
1183 )
1184
1185 send_to_engine = partial(
1186 self._send,
1187 self._mux_socket,
1188 ident=engine_ident,
1189 )
1190
1191 def relay_comm(msg):
1192 send_to_engine(
1193 msg["msg_type"],
1194 content=msg['content'],
1195 metadata=msg['metadata'],
1196 buffers=msg["buffers"],
1197 )
1198
1199 comm.on_msg(relay_comm)
1200 comm.on_close(
1201 lambda: send_to_engine(
1202 "comm_close",
1203 content={
1204 "comm_id": comm.comm_id,
1205 },
1206 )
1207 )
1208 ip.kernel.comm_manager.register_comm(comm)
1209
1210 # relay all comm msgs
1211 ip.kernel.session.send(
1212 ip.kernel.iopub_socket,
1213 msg_type,
1214 content=msg['content'],
1215 metadata=msg['metadata'],
1216 buffers=msg['buffers'],
1217 # different parent!
1218 parent=ip.kernel.get_parent("shell"),
1219 )
1220
1221 msg_future = self._futures.get(msg_id, None)
1222 if msg_future:
1223 # Run any callback functions
1224 for callback in msg_future.iopub_callbacks:
1225 callback(msg)
1226
1227 def create_message_futures(self, msg_id, header, async_result=False, track=False):
1228 msg_future = MessageFuture(msg_id, header=header, track=track)
1229 futures = [msg_future]
1230 self._futures[msg_id] = msg_future
1231 if async_result:
1232 output = MessageFuture(msg_id, header=header)
1233 # add future for output
1234 self._output_futures[msg_id] = output
1235 # hook up metadata
1236 output.metadata = self.metadata[msg_id]
1237 output.metadata['submitted'] = util.utcnow()
1238 msg_future.output = output
1239 futures.append(output)
1240 return futures
1241
1242 def _send(
1243 self,
1244 socket,
1245 msg_type,
1246 content=None,
1247 parent=None,
1248 ident=None,
1249 buffers=None,
1250 track=False,
1251 header=None,
1252 metadata=None,
1253 track_outstanding=False,
1254 message_future_hook=None,
1255 ):
1256 """Send a message in the IO thread
1257
        returns a MessageFuture for the reply, or None if no reply is expected"""
1259 if self._closed:
1260 raise OSError("Connections have been closed.")
1261 msg = self.session.msg(
1262 msg_type, content=content, parent=parent, header=header, metadata=metadata
1263 )
1264 msg_id = msg['header']['msg_id']
1265
1266 expect_reply = msg_type not in {"comm_msg", "comm_close", "comm_open"}
1267
1268 if expect_reply and track_outstanding:
1269 # add to outstanding, history
1270 self.outstanding.add(msg_id)
1271 self.history.append(msg_id)
1272
1273 if ident:
1274 # possibly routed to a specific engine
1275 ident_str = ident
1276 if isinstance(ident_str, list):
1277 ident_str = ident_str[-1]
1278 ident_str = ident_str.decode("utf-8")
1279 if ident_str in self._engines.values():
1280 # save for later, in case of engine death
1281 self._outstanding_dict[ident_str].add(msg_id)
        self.metadata[msg_id]['submitted'] = util.utcnow()
1283
1284 if expect_reply:
1285 futures = self.create_message_futures(
1286 msg_id,
1287 msg['header'],
1288 async_result=msg_type in {'execute_request', 'apply_request'},
1289 track=track,
1290 )
1291 if message_future_hook is not None:
1292 message_future_hook(futures[0])
1293
1294 def cleanup(f):
1295 """Purge caches on Future resolution"""
1296 self.results.pop(msg_id, None)
1297 self._futures.pop(msg_id, None)
1298 self._output_futures.pop(msg_id, None)
1299 self.metadata.pop(msg_id, None)
1300
1301 multi_future(futures).add_done_callback(cleanup)
1302
1303 def _really_send():
1304 sent = self.session.send(
1305 socket, msg, track=track, buffers=buffers, ident=ident
1306 )
1307 if track:
1308 futures[0].tracker.set_result(sent['tracker'])
1309
1310 # hand off actual send to IO thread
1311 self._io_loop.add_callback(_really_send)
1312 if expect_reply:
1313 return futures[0]
1314
1315 def _send_recv(self, *args, **kwargs):
1316 """Send a message in the IO thread and return its reply"""
1317 future = self._send(*args, **kwargs)
1318 future.wait()
1319 return future.result()
1320
1321 # --------------------------------------------------------------------------
1322 # len, getitem
1323 # --------------------------------------------------------------------------
1324
1325 def __len__(self):
1326 """len(client) returns # of engines."""
1327 return len(self.ids)
1328
1329 def __getitem__(self, key):
1330 """index access returns DirectView multiplexer objects
1331
1332 Must be int, slice, or list/tuple/range of ints"""
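        # Illustrative examples (not exercised here):
        #   rc[0]    -> DirectView for engine 0
        #   rc[0:4]  -> DirectView for engines 0-3
        #   rc[::2]  -> DirectView for every other engine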
1333 if not isinstance(key, (int, slice, tuple, list, range)):
1334 raise TypeError(f"key by int/slice/iterable of ints only, not {type(key)}")
1335 else:
1336 return self.direct_view(key)
1337
1338 def __iter__(self):
1339 """Since we define getitem, Client is iterable
1340
        but without also defining __iter__, iteration would only work correctly
        if engine IDs start at zero and are contiguous.
1343 """
1344 for eid in self.ids:
1345 yield self.direct_view(eid)
1346
1347 # --------------------------------------------------------------------------
1348 # Begin public methods
1349 # --------------------------------------------------------------------------
1350
1351 @property
1352 def ids(self):
1353 # always copy:
1354 return list(self._ids)
1355
1356 def activate(self, targets='all', suffix=''):
1357 """Create a DirectView and register it with IPython magics
1358
1359 Defines the magics `%px, %autopx, %pxresult, %%px`
1360
1361 Parameters
1362 ----------
1363 targets : int, list of ints, or 'all'
1364 The engines on which the view's magics will run
1365 suffix : str [default: '']
1366 The suffix, if any, for the magics. This allows you to have
1367 multiple views associated with parallel magics at the same time.
1368
1369 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
1370 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
1371 on engine 0.
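
        A hypothetical session (the cluster and engine IDs are assumed)::

            rc = Client()
            rc.activate(targets=[0, 1], suffix='01')
            %px01 import numpy   # runs only on engines 0 and 1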
1372 """
1373 view = self.direct_view(targets)
1374 view.block = True
1375 view.activate(suffix)
1376 return view
1377
1378 def close(self, linger=None):
1379 """Close my zmq Sockets
1380
1381 If `linger`, set the zmq LINGER socket option,
1382 which allows discarding of messages.
1383 """
1384 if self._closed:
1385 return
1386 self._stop_io_thread()
1387 snames = [trait for trait in self.trait_names() if trait.endswith("socket")]
1388 for name in snames:
1389 socket = getattr(self, name)
1390 if socket is not None and not socket.closed:
1391 if linger is not None:
1392 socket.close(linger=linger)
1393 else:
1394 socket.close()
1395 self._closed = True
1396
1397 def spin_thread(self, interval=1):
1398 """DEPRECATED, DOES NOTHING"""
1399 warnings.warn(
1400 "Client.spin_thread is deprecated now that IO is always in a thread",
1401 DeprecationWarning,
1402 )
1403
1404 def stop_spin_thread(self):
1405 """DEPRECATED, DOES NOTHING"""
1406 warnings.warn(
1407 "Client.spin_thread is deprecated now that IO is always in a thread",
1408 DeprecationWarning,
1409 )
1410
1411 def spin(self):
1412 """DEPRECATED, DOES NOTHING"""
1413 warnings.warn(
1414 "Client.spin is deprecated now that IO is in a thread", DeprecationWarning
1415 )
1416
1417 def _await_futures(self, futures, timeout):
1418 """Wait for a collection of futures"""
1419 if not futures:
1420 return True
1421
1422 event = Event()
1423 if timeout and timeout < 0:
1424 timeout = None
1425
1426 f = multi_future(futures)
1427 f.add_done_callback(lambda f: event.set())
1428 return event.wait(timeout)
1429
1430 def _futures_for_msgs(self, msg_ids):
1431 """Turn msg_ids into Futures
1432
1433 msg_ids not in futures dict are presumed done.
1434 """
1435 futures = []
1436 for msg_id in msg_ids:
1437 f = self._futures.get(msg_id, None)
1438 if f:
1439 futures.append(f)
1440 return futures
1441
1442 def wait_for_engines(
1443 self, n=None, *, timeout=-1, block=True, interactive=None, widget=None
1444 ):
1445 """Wait for `n` engines to become available.
1446
1447 Returns when `n` engines are available,
1448 or raises TimeoutError if `timeout` is reached
1449 before `n` engines are ready.
1450
1451 Parameters
1452 ----------
1453 n : int
1454 Number of engines to wait for.
1455 timeout : float
1456 Time (in seconds) to wait before raising a TimeoutError
1457 block : bool
1458 if False, return Future instead of waiting
1459 interactive : bool
1460 default: True if in IPython, False otherwise.
1461 if True, show a progress bar while waiting for engines
1462 widget : bool
1463 default: True if in an IPython kernel (notebook), False otherwise.
1464 Only has an effect if `interactive` is True.
1465 if True, forces use of widget progress bar.
1466 If False, forces use of terminal tqdm.
1467
1468 Returns
        -------
1470 f : concurrent.futures.Future or None
1471 Future object to wait on if block is False,
1472 None if block is True.
1473
1474 Raises
1475 ------
1476 TimeoutError : if timeout is reached.
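
        Examples
        --------
        A brief sketch (the engine count and timeout are arbitrary)::

            rc.wait_for_engines(4, timeout=60)        # block for up to 60 seconds
            f = rc.wait_for_engines(4, block=False)   # or get a Future instead
            f.result()                                # and resolve it later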
1477 """
1478 if n is None:
1479 # get n from cluster, if not specified
1480 if self.cluster is None:
1481 raise TypeError("n engines to wait for must be specified")
1482
1483 if self.cluster.n:
1484 n = self.cluster.n
1485 else:
1486 # compute n from engine sets,
1487 # e.g. the default where n is calculated at runtime from `cpu_count()`
1488 n = sum(engine_set.n for engine_set in self.cluster.engines.values())
1489
1490 if len(self.ids) >= n:
1491 if block:
1492 return
1493 else:
1494 f = Future()
1495 f.set_result(None)
1496 return f
1497 tic = now = time.perf_counter()
1498 if timeout >= 0:
1499 deadline = tic + timeout
1500 else:
1501 deadline = None
1502 seconds_remaining = 1000
1503
1504 if interactive is None:
1505 if ipp._NONINTERACTIVE:
1506 interactive = False
1507 else:
1508 interactive = get_ipython() is not None
1509
1510 if interactive:
1511 progress_bar = util.progress(
1512 widget=widget,
1513 initial=len(self.ids),
1514 total=n,
1515 unit='engine',
1516 )
1517
1518 # watch for engine-stop events
1519
1520 engine_stop_future = Future()
1521 if self.cluster and self.cluster.engines:
1522 # we have a parent cluster,
1523 # monitor for engines stopping
1524 def _signal_stopped(stop_data):
1525 if not engine_stop_future.done():
1526 engine_stop_future.set_result(stop_data)
1527
1528 def _remove_signal_stopped(f, es):
1529 try:
1530 es.stop_callbacks.remove(_signal_stopped)
1531 except ValueError:
1532 # already removed
1533 pass
1534
1535 for es in self.cluster.engines.values():
1536 es.on_stop(_signal_stopped)
1537 engine_stop_future.add_done_callback(
1538 partial(_remove_signal_stopped, es=es)
1539 )
1540
1541 future = Future()
1542
1543 def cancel_engine_stop(_):
1544 if not engine_stop_future.done():
1545 engine_stop_future.cancel()
1546
1547 future.add_done_callback(cancel_engine_stop)
1548
1549 def notice_engine_stop(f):
1550 if future.done():
1551 return
1552 stop_data = f.result()
1553 future.set_exception(error.EngineError(f"Engine set stopped: {stop_data}"))
1554
1555 engine_stop_future.add_done_callback(notice_engine_stop)
1556
1557 def notify(event):
1558 if future.done():
1559 return
1560 if event["event"] == "unregister":
1561 future.set_exception(
1562 error.EngineError(
1563 f"Engine {event['id']} unregistered while waiting for engines."
1564 )
1565 )
1566 return
1567 current_n = len(self.ids)
1568 if interactive:
1569 progress_bar.update(current_n - progress_bar.n)
1570 if current_n >= n:
1571 # ensure we refresh when we finish
1572 if interactive:
1573 progress_bar.close()
1574 future.set_result(None)
1575
1576 self._registration_callbacks.append(notify)
1577 future.add_done_callback(lambda f: self._registration_callbacks.remove(notify))
1578
1579 def on_timeout():
1580 """Called when timeout is reached"""
1581 if future.done():
1582 return
1583
1584 current_n = len(self.ids)
1585 if current_n >= n:
1586 future.set_result(None)
1587 else:
1588 future.set_exception(
1589 TimeoutError(
1590 f"{n} engines not ready in {timeout} seconds. Currently ready: {current_n}"
1591 )
1592 )
1593
1594 def schedule_timeout():
1595 handle = self._io_loop.add_timeout(
1596 self._io_loop.time() + timeout, on_timeout
1597 )
1598 future.add_done_callback(lambda f: self._io_loop.remove_timeout(handle))
1599
1600 if timeout >= 0:
1601 self._io_loop.add_callback(schedule_timeout)
1602
1603 if block:
1604 return future.result()
1605 else:
1606 return future
1607
1608 def wait(self, jobs=None, timeout=-1):
1609 """waits on one or more `jobs`, for up to `timeout` seconds.
1610
1611 Parameters
1612 ----------
1613 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1614 ints are indices to self.history
1615 strs are msg_ids
1616 default: wait on all outstanding messages
1617 timeout : float
1618 a time in seconds, after which to give up.
1619 default is -1, which means no timeout
1620
1621 Returns
1622 -------
1623 True : when all msg_ids are done
1624 False : timeout reached, some msg_ids still outstanding
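
        For example (an illustrative sketch)::

            import time
            ar = rc[:].apply_async(time.sleep, 1)
            rc.wait([ar], timeout=5)   # True once the sleeps have finished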
1625 """
1626 futures = []
1627 if jobs is None:
1628 if not self.outstanding:
1629 return True
1630 # make a copy, so that we aren't passing a mutable collection to _futures_for_msgs
1631 theids = set(self.outstanding)
1632 else:
1633 if isinstance(jobs, (str, int, AsyncResult)) or not isinstance(
1634 jobs, Iterable
1635 ):
1636 jobs = [jobs]
1637 theids = set()
1638 for job in jobs:
1639 if isinstance(job, int):
1640 # index access
1641 job = self.history[job]
1642 elif isinstance(job, AsyncResult):
1643 theids.update(job.msg_ids)
1644 continue
1645 elif _is_future(job):
1646 futures.append(job)
1647 continue
1648 theids.add(job)
1649 if not futures and not theids.intersection(self.outstanding):
1650 return True
1651
1652 futures.extend(self._futures_for_msgs(theids))
1653 return self._await_futures(futures, timeout)
1654
1655 def wait_interactive(self, jobs=None, interval=1.0, timeout=-1.0):
1656 """Wait interactively for jobs
1657
1658 If no job is specified, will wait for all outstanding jobs to complete.
1659 """
1660 if jobs is None:
1661 # get futures for results
1662 futures = [f for f in self._futures.values() if hasattr(f, 'output')]
1663 if not futures:
1664 return
1665 ar = AsyncResult(self, futures, owner=False)
1666 else:
1667 ar = self._asyncresult_from_jobs(jobs, owner=False)
1668 return ar.wait_interactive(interval=interval, timeout=timeout)
1669
1670 # --------------------------------------------------------------------------
1671 # Control methods
1672 # --------------------------------------------------------------------------
1673
1674 def _send_control_request(self, targets, msg_type, content, block):
1675 """Send a request on the control channel"""
1676 target_identities = self._build_targets(targets)[0]
1677 futures = []
1678 for ident in target_identities:
1679 futures.append(
1680 self._send(self._control_stream, msg_type, content=content, ident=ident)
1681 )
1682 if not block:
1683 return multi_future(futures)
1684 for future in futures:
1685 future.wait()
1686 msg = future.result()
1687 if msg['content']['status'] != 'ok':
1688 raise self._unwrap_exception(msg['content'])
1689
1690 def send_signal(self, sig, targets=None, block=None):
1691 """Send a signal target(s).
1692
1693 Parameters
1694 ----------
1695
1696 sig: int or str
1697 The signal number or name to send.
1698 If a str, will evaluate to getattr(signal, sig) on the engine,
1699 which is useful for sending signals cross-platform.
1700
1701 .. versionadded:: 7.0
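
        For example (illustrative; the signal name is resolved on the engine)::

            rc.send_signal('SIGINT', targets='all')   # interrupt all engines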
1702 """
1703 block = self.block if block is None else block
1704 return self._send_control_request(
1705 targets=targets,
1706 msg_type='signal_request',
1707 content={'sig': sig},
1708 block=block,
1709 )
1710
1711 def clear(self, targets=None, block=None):
1712 """Clear the namespace in target(s)."""
1713 block = self.block if block is None else block
1714 return self._send_control_request(
1715 targets=targets, msg_type='clear_request', content={}, block=block
1716 )
1717
1718 def abort(self, jobs=None, targets=None, block=None):
1719 """Abort specific jobs from the execution queues of target(s).
1720
1721 This is a mechanism to prevent jobs that have already been submitted
1722 from executing.
1723 To halt a running job,
1724 you must interrupt the engine(s) by sending a signal.
1725 This can be done via os.kill for local engines,
1726 or :meth:`.Cluster.signal_engines` for multiple engines.
1727
1728 Parameters
1729 ----------
1730 jobs : msg_id, list of msg_ids, or AsyncResult
1731 The jobs to be aborted
1732
1733 If unspecified/None: abort all outstanding jobs.
1734
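        For example (a sketch; ``view`` and ``slow_f`` are assumed to exist)::

            ar = view.apply_async(slow_f)
            rc.abort(ar)   # keep it from starting if it is still queued
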
1735 """
1736 block = self.block if block is None else block
1737 jobs = jobs if jobs is not None else list(self.outstanding)
1738
1739 msg_ids = []
1740 if isinstance(jobs, (str, AsyncResult)):
1741 jobs = [jobs]
1742 bad_ids = [obj for obj in jobs if not isinstance(obj, (str, AsyncResult))]
1743 if bad_ids:
1744 raise TypeError(
1745 f"Invalid msg_id type {bad_ids[0]!r}, expected str or AsyncResult"
1746 )
1747 for j in jobs:
1748 if isinstance(j, AsyncResult):
1749 msg_ids.extend(j.msg_ids)
1750 else:
1751 msg_ids.append(j)
1752 content = dict(msg_ids=msg_ids)
1753
1754 return self._send_control_request(
1755 targets,
1756 msg_type='abort_request',
1757 content=content,
1758 block=block,
1759 )
1760
1761 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1762 """Terminates one or more engine processes, optionally including the hub.
1763
1764 Parameters
1765 ----------
1766 targets : list of ints or 'all' [default: all]
1767 Which engines to shutdown.
1768 hub : bool [default: False]
1769 Whether to include the Hub. hub=True implies targets='all'.
1770 block : bool [default: self.block]
1771 Whether to wait for clean shutdown replies or not.
1772 restart : bool [default: False]
1773 NOT IMPLEMENTED
1774 whether to restart engines after shutting them down.
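
        For example (illustrative)::

            rc.shutdown(targets=[0, 1])   # stop two engines
            rc.shutdown(hub=True)         # stop all engines and the Hub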
1775 """
1776 from ipyparallel.error import NoEnginesRegistered
1777
1778 if restart:
1779 raise NotImplementedError("Engine restart is not yet implemented")
1780
1781 block = self.block if block is None else block
1782 if hub:
1783 targets = 'all'
1784 try:
1785 targets = self._build_targets(targets)[0]
1786 except NoEnginesRegistered:
1787 targets = []
1788
1789 futures = []
1790 for t in targets:
1791 futures.append(
1792 self._send(
1793 self._control_stream,
1794 'shutdown_request',
1795 content={'restart': restart},
1796 ident=t,
1797 )
1798 )
1799 error = False
1800 if block or hub:
1801 for f in futures:
1802 f.wait()
1803 msg = f.result()
1804 if msg['content']['status'] != 'ok':
1805 error = self._unwrap_exception(msg['content'])
1806
1807 if hub:
1808 # don't trigger close on shutdown notification, which will prevent us from receiving the reply
1809 self._notification_handlers['shutdown_notification'] = lambda msg: None
1810 msg = self._send_recv(self._query_stream, 'shutdown_request')
1811 if msg['content']['status'] != 'ok':
1812 error = self._unwrap_exception(msg['content'])
1813 if not error:
1814 self.close()
1815
1816 if error:
1817 raise error
1818
1819 def become_dask(
1820 self, targets='all', port=0, nanny=False, scheduler_args=None, **worker_args
1821 ):
1822 """Turn the IPython cluster into a dask.distributed cluster
1823
1824 Parameters
1825 ----------
1826 targets : target spec (default: all)
1827 Which engines to turn into dask workers.
1828 port : int (default: random)
            Which port the dask scheduler should listen on (0, the default, selects a random port).
1830 nanny : bool (default: False)
1831 Whether to start workers as subprocesses instead of in the engine process.
1832 Using a nanny allows restarting the worker processes via ``executor.restart``.
1833 scheduler_args : dict
1834 Keyword arguments (e.g. ip) to pass to the distributed.Scheduler constructor.
1835 **worker_args
1836 Any additional keyword arguments (e.g. nthreads) are passed to the distributed.Worker constructor.
1837
1838 Returns
1839 -------
1840        client : distributed.Client
1841 A dask.distributed.Client connected to the dask cluster.
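
        Examples
        --------
        A minimal sketch (assuming ``rc`` is a connected Client and
        ``distributed`` is installed)::

            executor = rc.become_dask()
            future = executor.submit(lambda x: x + 1, 10)
            future.result()  # 11
            rc.stop_dask()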
1842 """
1843 import distributed
1844
1845 dview = self.direct_view(targets)
1846
1847 if scheduler_args is None:
1848 scheduler_args = {}
1849 else:
1850 scheduler_args = dict(scheduler_args) # copy
1851
1852 # Start a Scheduler on the Hub:
1853 reply = self._send_recv(
1854 self._query_stream,
1855 'become_dask_request',
1856 {'scheduler_args': scheduler_args},
1857 )
1858 if reply['content']['status'] != 'ok':
1859 raise self._unwrap_exception(reply['content'])
1860 distributed_info = reply['content']
1861
1862 # Start a Worker on the selected engines:
1863 worker_args['address'] = distributed_info['address']
1864 worker_args['nanny'] = nanny
1865 # distributed 2.0 renamed ncores to nthreads
1866 if int(distributed.__version__.partition(".")[0]) >= 2:
1867 nthreads = "nthreads"
1868 else:
1869 nthreads = "ncores"
1870 # set default nthreads=1, since that's how an IPython cluster is typically set up.
1871 worker_args.setdefault(nthreads, 1)
1872 dview.apply_sync(util.become_dask_worker, **worker_args)
1873
1874 # Finally, return a Client connected to the Scheduler
1875 try:
1876 distributed_Client = distributed.Client
1877 except AttributeError:
1878 # For distributed pre-1.18.1
1879 distributed_Client = distributed.Executor
1880
1881        client = distributed_Client(distributed_info['address'])
1882
1883 return client
1884
1885 def stop_dask(self, targets='all'):
1886 """Stop the distributed Scheduler and Workers started by become_dask.
1887
1888 Parameters
1889 ----------
1890 targets : target spec (default: all)
1891 Which engines to stop dask workers on.
1892 """
1893 dview = self.direct_view(targets)
1894
1895 # Start a Scheduler on the Hub:
1896 reply = self._send_recv(self._query_stream, 'stop_distributed_request')
1897 if reply['content']['status'] != 'ok':
1898 raise self._unwrap_exception(reply['content'])
1899
1900 # Finally, stop all the Workers on the engines
1901 dview.apply_sync(util.stop_distributed_worker)
1902
1903 # aliases:
1904 become_distributed = become_dask
1905 stop_distributed = stop_dask
1906
1907 # --------------------------------------------------------------------------
1908 # Execution related methods
1909 # --------------------------------------------------------------------------
1910
1911 def _maybe_raise(self, result):
1912 """wrapper for maybe raising an exception if apply failed."""
1913 if isinstance(result, error.RemoteError):
1914 raise result
1915
1916 return result
1917
1918 def send_apply_request(
1919 self,
1920 socket,
1921 f,
1922 args=None,
1923 kwargs=None,
1924 metadata=None,
1925 track=False,
1926 ident=None,
1927 message_future_hook=None,
1928 ):
1929 """construct and send an apply message via a socket.
1930
1931 This is the principal method with which all engine execution is performed by views.
1932 """
1933
1934 if self._closed:
1935 raise RuntimeError(
1936 "Client cannot be used after its sockets have been closed"
1937 )
1938
1939 # defaults:
1940 args = args if args is not None else []
1941 kwargs = kwargs if kwargs is not None else {}
1942 metadata = metadata if metadata is not None else {}
1943
1944 # validate arguments
1945 if not callable(f) and not isinstance(f, (Reference, PrePickled)):
1946 raise TypeError(f"f must be callable, not {type(f)}")
1947 if not isinstance(args, (tuple, list)):
1948 raise TypeError(f"args must be tuple or list, not {type(args)}")
1949 if not isinstance(kwargs, dict):
1950 raise TypeError(f"kwargs must be dict, not {type(kwargs)}")
1951 if not isinstance(metadata, dict):
1952 raise TypeError(f"metadata must be dict, not {type(metadata)}")
1953
1954 bufs = serialize.pack_apply_message(
1955 f,
1956 args,
1957 kwargs,
1958 buffer_threshold=self.session.buffer_threshold,
1959 item_threshold=self.session.item_threshold,
1960 )
1961
1962 future = self._send(
1963 socket,
1964 "apply_request",
1965 buffers=bufs,
1966 ident=ident,
1967 metadata=metadata,
1968 track=track,
1969 track_outstanding=True,
1970 message_future_hook=message_future_hook,
1971 )
1972 msg_id = future.msg_id
1973
1974 return future
1975
1976 def send_execute_request(
1977 self,
1978 socket,
1979 code,
1980 silent=True,
1981 metadata=None,
1982 ident=None,
1983 message_future_hook=None,
1984 ):
1985 """construct and send an execute request via a socket."""
1986
1987 if self._closed:
1988 raise RuntimeError(
1989 "Client cannot be used after its sockets have been closed"
1990 )
1991
1992 # defaults:
1993 metadata = metadata if metadata is not None else {}
1994
1995 # validate arguments
1996 if not isinstance(code, str):
1997 raise TypeError(f"code must be text, not {type(code)}")
1998 if not isinstance(metadata, dict):
1999 raise TypeError(f"metadata must be dict, not {type(metadata)}")
2000
2001 content = dict(code=code, silent=bool(silent), user_expressions={})
2002
2003 future = self._send(
2004 socket,
2005 "execute_request",
2006 content=content,
2007 ident=ident,
2008 metadata=metadata,
2009 track_outstanding=True,
2010 message_future_hook=message_future_hook,
2011 )
2012
2013 return future
2014
2015 # --------------------------------------------------------------------------
2016 # construct a View object
2017 # --------------------------------------------------------------------------
2018
2019 def load_balanced_view(self, targets=None, **kwargs):
2020        """construct a LoadBalancedView object.
2021
2022 If no arguments are specified, create a LoadBalancedView
2023 using all engines.
2024
2025 Parameters
2026 ----------
2027 targets : list,slice,int,etc. [default: use all engines]
2028 The subset of engines across which to load-balance execution
2029 **kwargs : passed to LoadBalancedView
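
        Examples
        --------
        A minimal sketch (assuming ``rc`` is a connected Client)::

            view = rc.load_balanced_view()
            ar = view.map_async(lambda x: x * 2, range(8))
            ar.get()  # [0, 2, 4, ..., 14]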
2030 """
2031 if targets == 'all':
2032 targets = None
2033 if targets is not None:
2034 targets = self._build_targets(targets)[1]
2035 return LoadBalancedView(
2036 client=self, socket=self._task_stream, targets=targets, **kwargs
2037 )
2038
2039 def executor(self, targets=None):
2040 """Construct a PEP-3148 Executor with a LoadBalancedView
2041
2042 Parameters
2043 ----------
2044 targets : list,slice,int,etc. [default: use all engines]
2045 The subset of engines across which to load-balance execution
2046
2047 Returns
2048 -------
2049 executor: Executor
2050 The Executor object
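
        Examples
        --------
        A minimal sketch (assuming ``rc`` is a connected Client)::

            ex = rc.executor()
            fut = ex.submit(pow, 2, 10)
            fut.result()  # 1024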
2051 """
2052 return self.load_balanced_view(targets).executor
2053
2054 def direct_view(self, targets='all', **kwargs):
2055 """construct a DirectView object.
2056
2057 If no targets are specified, create a DirectView using all engines.
2058
2059 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
2060 evaluate the target engines at each execution, whereas rc[:] will connect to
2061 all *current* engines, and that list will not change.
2062
2063 That is, 'all' will always use all engines, whereas rc[:] will not use
2064 engines added after the DirectView is constructed.
2065
2066 Parameters
2067 ----------
2068 targets : list,slice,int,etc. [default: use all engines]
2069 The engines to use for the View
2070 **kwargs : passed to DirectView
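
        Examples
        --------
        A minimal sketch (assuming ``rc`` is a connected Client)::

            import os

            dv = rc.direct_view('all')    # also uses engines that register later
            dv.apply_sync(os.getpid)      # one pid per engine
            rc[:2].apply_sync(os.getpid)  # only engines 0 and 1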
2071 """
2072 single = isinstance(targets, int)
2073 # allow 'all' to be lazily evaluated at each execution
2074 if targets != 'all':
2075 targets = self._build_targets(targets)[1]
2076 if single:
2077 targets = targets[0]
2078 return DirectView(
2079 client=self, socket=self._mux_stream, targets=targets, **kwargs
2080 )
2081
2082 def broadcast_view(self, targets='all', is_coalescing=False, **kwargs):
2083        """construct a BroadcastView object.
2084
2085        If no targets are specified, create a BroadcastView using all engines.
2086
2087 Parameters
2088 ----------
2089 targets : list,slice,int,etc. [default: use all engines]
2090            The engines to broadcast to
2091        is_coalescing : bool [default: False]
            Whether the scheduler should collect the replies from all engines and return them as a single reply
2092        **kwargs : passed to BroadcastView
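
        Examples
        --------
        A minimal sketch (assuming ``rc`` is a connected Client)::

            import os

            bv = rc.broadcast_view()
            ar = bv.apply_async(os.getpid)
            ar.get()  # one pid per engine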
2093 """
2094 targets = self._build_targets(targets)[1]
2095
2096 bcast_view = BroadcastView(
2097 client=self,
2098 socket=self._broadcast_stream,
2099 targets=targets,
2100 **kwargs,
2101 )
2102 bcast_view.is_coalescing = is_coalescing
2103 return bcast_view
2104
2105 # --------------------------------------------------------------------------
2106 # Query methods
2107 # --------------------------------------------------------------------------
2108
2109 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
2110 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
2111
2112 If the client already has the results, no request to the Hub will be made.
2113
2114 This is a convenient way to construct AsyncResult objects, which are wrappers
2115 that include metadata about execution, and allow for awaiting results that
2116 were not submitted by this Client.
2117
2118        It can also be a convenient way to retrieve the metadata associated with
2119        blocking execution, since it always returns an AsyncResult, which carries
        the execution metadata even when the original call was blocking.
2120
2121        Examples
2122        --------
2123        ::
2124
2125            In [10]: ar = client[:].apply_async(os.getpid)

            In [11]: ar2 = client.get_result(ar.msg_ids[0])

            In [12]: ar2.get()
2126
2127 Parameters
2128 ----------
2129 indices_or_msg_ids : integer history index, str msg_id, AsyncResult,
2130 or a list of same.
2131            The history indices or msg_ids of the results to be retrieved
2132 block : bool
2133 Whether to wait for the result to be done
2134 owner : bool [default: True]
2135 Whether this AsyncResult should own the result.
2136 If so, calling `ar.get()` will remove data from the
2137 client's result and metadata cache.
2138 There should only be one owner of any given msg_id.
2139
2140 Returns
2141 -------
2142 AsyncResult
2143 A single AsyncResult object will always be returned.
2144 AsyncHubResult
2145 A subclass of AsyncResult that retrieves results from the Hub
2146
2147 """
2148 block = self.block if block is None else block
2149 if indices_or_msg_ids is None:
2150 indices_or_msg_ids = -1
2151
2152 ar = self._asyncresult_from_jobs(indices_or_msg_ids, owner=owner)
2153
2154 if block:
2155 ar.wait()
2156
2157 return ar
2158
2159 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
2160 """Resubmit one or more tasks.
2161
2162        In-flight tasks may not be resubmitted.
2163
2164 Parameters
2165 ----------
2166 indices_or_msg_ids : integer history index, str msg_id, or list of either
2167            The history indices or msg_ids of the tasks to be resubmitted
2168 block : bool
2169 Whether to wait for the result to be done
2170
2171 Returns
2172 -------
2173 AsyncHubResult
2174 A subclass of AsyncResult that retrieves results from the Hub
2175
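        Examples
        --------
        A minimal sketch (assuming ``ar`` is a finished AsyncResult from this
        Client, and ``rc`` is the Client itself)::

            ar2 = rc.resubmit(ar.msg_ids)
            ar2.get()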
2176 """
2177 block = self.block if block is None else block
2178 if indices_or_msg_ids is None:
2179 indices_or_msg_ids = -1
2180
2181 theids = self._msg_ids_from_jobs(indices_or_msg_ids)
2182 content = dict(msg_ids=theids)
2183
2184 reply = self._send_recv(self._query_stream, 'resubmit_request', content)
2185 content = reply['content']
2186 if content['status'] != 'ok':
2187 raise self._unwrap_exception(content)
2188 mapping = content['resubmitted']
2189 new_ids = [mapping[msg_id] for msg_id in theids]
2190
2191 ar = AsyncHubResult(self, new_ids)
2192
2193 if block:
2194 ar.wait()
2195
2196 return ar
2197
2198 def result_status(self, msg_ids, status_only=True):
2199 """Check on the status of the result(s) of the apply request with `msg_ids`.
2200
2201 If status_only is False, then the actual results will be retrieved, else
2202 only the status of the results will be checked.
2203
2204 Parameters
2205 ----------
2206        msg_ids : list of msg_ids, or ints
2207            Integer entries are passed as indices into self.history
2208            for convenience.
2209 status_only : bool (default: True)
2210 if False:
2211 Retrieve the actual results of completed tasks.
2212
2213 Returns
2214 -------
2215 results : dict
2216 There will always be the keys 'pending' and 'completed', which will
2217 be lists of msg_ids that are incomplete or complete. If `status_only`
2218 is False, then completed results will be keyed by their `msg_id`.
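
        Examples
        --------
        A minimal sketch (assuming ``rc`` is a connected Client with some history)::

            status = rc.result_status(rc.history[-5:])
            status['pending'], status['completed']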
2219 """
2220 theids = self._msg_ids_from_jobs(msg_ids)
2221
2222 completed = []
2223 local_results = {}
2224
2225 # comment this block out to temporarily disable local shortcut:
2226        for msg_id in list(theids):  # iterate over a copy, since theids is modified below
2227 if msg_id in self.results:
2228 completed.append(msg_id)
2229 local_results[msg_id] = self.results[msg_id]
2230 theids.remove(msg_id)
2231
2232 if theids: # some not locally cached
2233 content = dict(msg_ids=theids, status_only=status_only)
2234 reply = self._send_recv(
2235 self._query_stream, "result_request", content=content
2236 )
2237 content = reply['content']
2238 if content['status'] != 'ok':
2239 raise self._unwrap_exception(content)
2240 buffers = reply['buffers']
2241 else:
2242 content = dict(completed=[], pending=[])
2243
2244 content['completed'].extend(completed)
2245
2246 if status_only:
2247 return content
2248
2249 failures = []
2250 # load cached results into result:
2251 content.update(local_results)
2252
2253 # update cache with results:
2254 for msg_id in sorted(theids):
2255 if msg_id in content['completed']:
2256 rec = content[msg_id]
2257 parent = util.extract_dates(rec['header'])
2258 header = util.extract_dates(rec['result_header'])
2259 rcontent = rec['result_content']
2260 iodict = rec['io']
2261 if isinstance(rcontent, str):
2262 rcontent = self.session.unpack(rcontent)
2263
2264 md = self.metadata[msg_id]
2265 md_msg = dict(
2266 content=rcontent,
2267 parent_header=parent,
2268 header=header,
2269 metadata=rec['result_metadata'],
2270 )
2271 md.update(self._extract_metadata(md_msg))
2272 if rec.get('received'):
2273 md['received'] = util._parse_date(rec['received'])
2274 md.update(iodict)
2275
2276 if rcontent['status'] == 'ok':
2277 if header['msg_type'] == 'apply_reply':
2278 res, buffers = serialize.deserialize_object(buffers)
2279 elif header['msg_type'] == 'execute_reply':
2280 res = ExecuteReply(msg_id, rcontent, md)
2281 else:
2282 raise KeyError(
2283 "unhandled msg type: {!r}".format(header['msg_type'])
2284 )
2285 else:
2286 res = self._unwrap_exception(rcontent)
2287 failures.append(res)
2288
2289 self.results[msg_id] = res
2290 content[msg_id] = res
2291
2292 if len(theids) == 1 and failures:
2293 raise failures[0]
2294
2295 error.collect_exceptions(failures, "result_status")
2296 return content
2297
2298 def queue_status(self, targets='all', verbose=False):
2299 """Fetch the status of engine queues.
2300
2301 Parameters
2302 ----------
2303        targets : int/str/list of ints/strs [default: 'all']
2304            The engines whose queue states are to be queried.
2306 verbose : bool
2307 Whether to return lengths only, or lists of ids for each element
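
        Examples
        --------
        A minimal sketch (assuming ``rc`` is a connected Client)::

            rc.queue_status()  # e.g. {0: {'queue': 0, 'completed': 2, 'tasks': 0}, ..., 'unassigned': 0}
            rc.queue_status(0, verbose=True)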
2308 """
2309 if targets == 'all':
2310 # allow 'all' to be evaluated on the engine
2311 engine_ids = None
2312 else:
2313 engine_ids = self._build_targets(targets)[1]
2314 content = dict(targets=engine_ids, verbose=verbose)
2315 reply = self._send_recv(self._query_stream, "queue_request", content=content)
2316 content = reply['content']
2317 status = content.pop('status')
2318 if status != 'ok':
2319 raise self._unwrap_exception(content)
2320 content = util.int_keys(content)
2321 if isinstance(targets, int):
2322 return content[targets]
2323 else:
2324 return content
2325
2326 def _msg_ids_from_target(self, targets=None):
2327 """Build a list of msg_ids from the list of engine targets"""
2328 if not targets: # needed as _build_targets otherwise uses all engines
2329 return []
2330 target_ids = self._build_targets(targets)[0]
2331 return [
2332 md_id
2333 for md_id in self.metadata
2334 if self.metadata[md_id]["engine_uuid"] in target_ids
2335 ]
2336
2337 def _msg_ids_from_jobs(self, jobs=None):
2338 """Given a 'jobs' argument, convert it to a list of msg_ids.
2339
2340 Can be either one or a list of:
2341
2342 - msg_id strings
2343 - integer indices to this Client's history
2344 - AsyncResult objects
2345 """
2346 if not isinstance(jobs, (list, tuple, set, types.GeneratorType)):
2347 jobs = [jobs]
2348 msg_ids = []
2349 for job in jobs:
2350 if isinstance(job, int):
2351 msg_ids.append(self.history[job])
2352 elif isinstance(job, str):
2353 msg_ids.append(job)
2354 elif isinstance(job, AsyncResult):
2355 msg_ids.extend(job.msg_ids)
2356 else:
2357 raise TypeError(f"Expected msg_id, int, or AsyncResult, got {job!r}")
2358 return msg_ids
2359
2360 def _asyncresult_from_jobs(self, jobs=None, owner=False):
2361 """Construct an AsyncResult from msg_ids or asyncresult objects"""
2362 if not isinstance(jobs, (list, tuple, set, types.GeneratorType)):
2363 single = True
2364 jobs = [jobs]
2365 else:
2366 single = False
2367 futures = []
2368 msg_ids = []
2369 for job in jobs:
2370 if isinstance(job, int):
2371 job = self.history[job]
2372 if isinstance(job, str):
2373 if job in self._futures:
2374 futures.append(job)
2375 elif job in self.results:
2376 f = MessageFuture(job)
2377 f.set_result(self.results[job])
2378 f.output = Future()
2379 f.output.metadata = self.metadata[job]
2380 f.output.set_result(None)
2381 futures.append(f)
2382 else:
2383 msg_ids.append(job)
2384 elif isinstance(job, AsyncResult):
2385 if job._children:
2386 futures.extend(job._children)
2387 else:
2388 msg_ids.extend(job.msg_ids)
2389 else:
2390 raise TypeError(f"Expected msg_id, int, or AsyncResult, got {job!r}")
2391 if msg_ids:
2392 if single:
2393 msg_ids = msg_ids[0]
2394 return AsyncHubResult(self, msg_ids, owner=owner)
2395 else:
2396 if single and futures:
2397 futures = futures[0]
2398 return AsyncResult(self, futures, owner=owner)
2399
2400 def purge_local_results(self, jobs=[], targets=[]):
2401 """Clears the client caches of results and their metadata.
2402
2403 Individual results can be purged by msg_id, or the entire
2404 history of specific targets can be purged.
2405
2406        Use `purge_local_results('all')` to scrub everything from the Client's
2407 results and metadata caches.
2408
2409 After this call all `AsyncResults` are invalid and should be discarded.
2410
2411 If you must "reget" the results, you can still do so by using
2412 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
2413 redownload the results from the hub if they are still available
2414        (i.e. `client.purge_hub_results(...)` has not been called).
2415
2416 Parameters
2417 ----------
2418 jobs : str or list of str or AsyncResult objects
2419 the msg_ids whose results should be purged.
2420 targets : int/list of ints
2421 The engines, by integer ID, whose entire result histories are to be purged.
2422
2423 Raises
2424 ------
2425 RuntimeError : if any of the tasks to be purged are still outstanding.
2426
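        Examples
        --------
        A minimal sketch (assuming ``ar`` is a finished AsyncResult and ``rc``
        is the Client that submitted it)::

            rc.purge_local_results(jobs=ar)  # drop cached results for these msg_ids
            rc.purge_local_results('all')    # clear the entire local cache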
2427 """
2428 if not targets and not jobs:
2429 raise ValueError("Must specify at least one of `targets` and `jobs`")
2430
2431 if jobs == 'all':
2432 if self.outstanding:
2433 raise RuntimeError(f"Can't purge outstanding tasks: {self.outstanding}")
2434 self.results.clear()
2435 self.metadata.clear()
2436 self._futures.clear()
2437 self._output_futures.clear()
2438 else:
2439 msg_ids = set()
2440 msg_ids.update(self._msg_ids_from_target(targets))
2441 msg_ids.update(self._msg_ids_from_jobs(jobs))
2442 still_outstanding = self.outstanding.intersection(msg_ids)
2443 if still_outstanding:
2444 raise RuntimeError(
2445 f"Can't purge outstanding tasks: {still_outstanding}"
2446 )
2447 for mid in msg_ids:
2448 self.results.pop(mid, None)
2449 self.metadata.pop(mid, None)
2450 self._futures.pop(mid, None)
2451 self._output_futures.pop(mid, None)
2452
2453 def purge_hub_results(self, jobs=[], targets=[]):
2454 """Tell the Hub to forget results.
2455
2456 Individual results can be purged by msg_id, or the entire
2457 history of specific targets can be purged.
2458
2459 Use `purge_results('all')` to scrub everything from the Hub's db.
2460
2461 Parameters
2462 ----------
2463 jobs : str or list of str or AsyncResult objects
2464 the msg_ids whose results should be forgotten.
2465        targets : int/str/list of ints/strs [default: None]
2466            The targets, by int_id, whose entire history is to be purged.
2469 """
2470 if not targets and not jobs:
2471 raise ValueError("Must specify at least one of `targets` and `jobs`")
2472 if targets:
2473 targets = self._build_targets(targets)[1]
2474
2475 # construct msg_ids from jobs
2476 if jobs == 'all':
2477 msg_ids = jobs
2478 else:
2479 msg_ids = self._msg_ids_from_jobs(jobs)
2480
2481 content = dict(engine_ids=targets, msg_ids=msg_ids)
2482 reply = self._send_recv(self._query_stream, "purge_request", content=content)
2483 content = reply['content']
2484 if content['status'] != 'ok':
2485 raise self._unwrap_exception(content)
2486
2487 def purge_results(self, jobs=[], targets=[]):
2488 """Clears the cached results from both the hub and the local client
2489
2490 Individual results can be purged by msg_id, or the entire
2491 history of specific targets can be purged.
2492
2493        Use `purge_results('all')` to scrub every cached result from both the Hub's
2494        database and the Client's caches.
2495
2496        Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
2497 the same arguments.
2498
2499 Parameters
2500 ----------
2501 jobs : str or list of str or AsyncResult objects
2502 the msg_ids whose results should be forgotten.
2503        targets : int/str/list of ints/strs [default: None]
2504            The targets, by int_id, whose entire history is to be purged.
2507 """
2508 self.purge_local_results(jobs=jobs, targets=targets)
2509 self.purge_hub_results(jobs=jobs, targets=targets)
2510
2511 def purge_everything(self):
2512        """Clears all content from previous tasks, from both the Hub and the local client.
2513
2514        In addition to calling `purge_results("all")`, this also deletes the history and
2515        other bookkeeping lists.
2516 """
2517 self.purge_results("all")
2518 self.history = []
2519 self.session.digest_history.clear()
2520
2521 def hub_history(self):
2522 """Get the Hub's history
2523
2524 Just like the Client, the Hub has a history, which is a list of msg_ids.
2525 This will contain the history of all clients, and, depending on configuration,
2526 may contain history across multiple cluster sessions.
2527
2528 Any msg_id returned here is a valid argument to `get_result`.
2529
2530 Returns
2531 -------
2532 msg_ids : list of strs
2533 list of all msg_ids, ordered by task submission time.
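
        Examples
        --------
        A minimal sketch (assuming ``rc`` is a connected Client)::

            msg_ids = rc.hub_history()
            ar = rc.get_result(msg_ids[-1])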
2534 """
2535
2536 reply = self._send_recv(self._query_stream, "history_request", content={})
2537 content = reply['content']
2538 if content['status'] != 'ok':
2539 raise self._unwrap_exception(content)
2540 else:
2541 return content['history']
2542
2543 def db_query(self, query, keys=None):
2544 """Query the Hub's TaskRecord database
2545
2546 This will return a list of task record dicts that match `query`
2547
2548 Parameters
2549 ----------
2550 query : mongodb query dict
2551 The search dict. See mongodb query docs for details.
2552 keys : list of strs [optional]
2553 The subset of keys to be returned. The default is to fetch everything but buffers.
2554 'msg_id' will *always* be included.
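
        Examples
        --------
        A minimal sketch (assuming ``ar`` is a finished AsyncResult and ``rc``
        is a connected Client)::

            records = rc.db_query({'msg_id': {'$in': ar.msg_ids}},
                                  keys=['started', 'completed'])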
2555 """
2556 if isinstance(keys, str):
2557 keys = [keys]
2558 content = dict(query=query, keys=keys)
2559 reply = self._send_recv(self._query_stream, "db_request", content=content)
2560 content = reply['content']
2561 if content['status'] != 'ok':
2562 raise self._unwrap_exception(content)
2563
2564 records = content['records']
2565
2566 buffer_lens = content['buffer_lens']
2567 result_buffer_lens = content['result_buffer_lens']
2568 buffers = reply['buffers']
2569 has_bufs = buffer_lens is not None
2570 has_rbufs = result_buffer_lens is not None
2571 for i, rec in enumerate(records):
2572 # unpack datetime objects
2573 for hkey in ('header', 'result_header'):
2574 if hkey in rec:
2575 rec[hkey] = util.extract_dates(rec[hkey])
2576 for dtkey in ('submitted', 'started', 'completed', 'received'):
2577 if dtkey in rec:
2578 rec[dtkey] = util._parse_date(rec[dtkey])
2579 # relink buffers
2580 if has_bufs:
2581 blen = buffer_lens[i]
2582 rec['buffers'], buffers = buffers[:blen], buffers[blen:]
2583 if has_rbufs:
2584 blen = result_buffer_lens[i]
2585 rec['result_buffers'], buffers = buffers[:blen], buffers[blen:]
2586
2587 return records
2588
2589
2590__all__ = ['Client']