Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/ipyparallel/util.py: 22%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

453 statements  

1"""Some generic utilities for dealing with classes, urls, and serialization.""" 

2 

3# Copyright (c) IPython Development Team. 

4# Distributed under the terms of the Modified BSD License. 

5import asyncio 

6import functools 

7import inspect 

8import logging 

9import os 

10import re 

11import shlex 

12import socket 

13import sys 

14import warnings 

15from datetime import datetime, timezone 

16from functools import lru_cache, partial 

17from signal import SIGABRT, SIGINT, SIGTERM, signal 

18from threading import Thread, current_thread 

19from types import FunctionType 

20 

21import traitlets 

22import zmq 

23from dateutil.parser import parse as dateutil_parse 

24from dateutil.tz import tzlocal 

25from IPython import get_ipython 

26from IPython.core.profiledir import ProfileDir, ProfileDirError 

27from IPython.paths import get_ipython_dir 

28from jupyter_client import session 

29from jupyter_client.localinterfaces import is_public_ip, localhost, public_ips 

30from tornado.ioloop import IOLoop 

31from traitlets.log import get_logger 

32from zmq.log import handlers 

33 

# Shared timezone-aware UTC tzinfo (used by utcnow below).
utc = timezone.utc

35 

36 

37# ----------------------------------------------------------------------------- 

38# Classes 

39# ----------------------------------------------------------------------------- 

40 

41 

class Namespace(dict):
    """Subclass of dict for attribute access to keys.

    ``ns.x`` reads ``ns['x']`` and ``ns.x = v`` writes ``ns['x'] = v``.
    Assigning an attribute whose name is an attribute of ``dict`` itself
    (``keys``, ``items``, ...) raises KeyError instead of shadowing it.
    """

    def __getattr__(self, key):
        """getattr aliased to getitem.

        Only called when normal attribute lookup fails. A missing regular
        name raises NameError (this class's historical behavior). A missing
        dunder name raises AttributeError, so protocol probes such as
        ``getattr(ns, '__deepcopy__', None)`` (used by ``copy``) and
        ``hasattr`` work as expected.
        """
        if key in self:
            return self[key]
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        raise NameError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, refusing names defined on dict."""
        if hasattr(dict, key):
            raise KeyError(f"Cannot override dict keys {key!r}")
        self[key] = value

57 

58 

class ReverseDict(dict):
    """simple double-keyed subset of dict methods.

    ``rd[k]`` returns the value stored under *k*; if *k* is not a forward
    key, it is looked up among the values and the matching key is returned.

    The reverse (value -> key) mapping is maintained by ``__init__``,
    ``__setitem__``, ``__delitem__`` and ``pop``. Other dict mutators
    (``update``, ``setdefault``, ``clear``, ...) bypass it.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # value -> key; for duplicate values the last key wins
        self._reverse = dict()
        for key, value in self.items():
            self._reverse[value] = key

    def __getitem__(self, key):
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return self._reverse[key]

    def __setitem__(self, key, value):
        if key in self._reverse:
            raise KeyError(f"Can't have key {key!r} on both sides!")
        if key in self:
            # overwriting an existing key: drop the reverse entry of the old
            # value so it no longer resolves back to this key
            old_value = dict.__getitem__(self, key)
            if old_value in self._reverse and self._reverse[old_value] == key:
                del self._reverse[old_value]
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def __delitem__(self, key):
        value = dict.pop(self, key)
        self._reverse.pop(value, None)

    def pop(self, key):
        value = dict.pop(self, key)
        self._reverse.pop(value)
        return value

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

90 

91 

92# ----------------------------------------------------------------------------- 

93# Functions 

94# ----------------------------------------------------------------------------- 

95 

96 

def log_errors(f):
    """decorator to log unhandled exceptions raised in a method.

    For use wrapping on_recv callbacks, so that exceptions
    do not cause the stream to be closed. The decorated method's ``self``
    must have a ``log`` attribute.

    - Synchronous exceptions are logged with ``self.log.exception``.
    - If the method returns an awaitable, it is scheduled with
      ``asyncio.ensure_future`` and any exception it eventually raises is
      logged with ``self.log.error`` (including the traceback).
      Cancellation is not treated as an error.

    The wrapper always returns None.
    """

    @functools.wraps(f)
    def logs_errors(self, *args, **kwargs):
        try:
            result = f(self, *args, **kwargs)
        except Exception as e:
            self.log.exception(f"Uncaught exception in {f}: {e}")
            return

        if inspect.isawaitable(result):
            # if it's async, schedule logging for when the future resolves
            future = asyncio.ensure_future(result)

            def _log_error(future):
                # Future.exception() raises CancelledError on a cancelled
                # future, which would escape this done-callback.
                if future.cancelled():
                    return
                exc = future.exception()
                if exc is not None:
                    self.log.error(
                        f"Uncaught exception in {f}: {exc}", exc_info=exc
                    )

            future.add_done_callback(_log_error)

    return logs_errors

123 

124 

def is_url(url):
    """Return True if *url* looks like a zmq url (``proto://...``).

    Only the scheme is checked: it must be tcp, pgm, epgm, ipc or inproc
    (case-insensitive). The address part is not inspected.
    """
    known_schemes = ('tcp', 'pgm', 'epgm', 'ipc', 'inproc')
    return '://' in url and url.split('://', 1)[0].lower() in known_schemes

133 

134 

def validate_url(url):
    """validate a url for zeromq

    Returns True if *url* is valid. Only ``tcp`` urls have their host and
    port checked. Other supported protocols (pgm, epgm, ipc, inproc) only
    need the ``proto://addr`` shape.

    Raises
    ------
    TypeError
        if *url* is not a string.
    AssertionError
        if the url is malformed. It is raised explicitly rather than via
        ``assert`` so validation still happens under ``python -O``.
    """
    if not isinstance(url, str):
        raise TypeError(f"url must be a string, not {type(url)!r}")
    url = url.lower()

    proto_addr = url.split('://')
    if len(proto_addr) != 2:
        raise AssertionError(f'Invalid url: {url!r}')
    proto, addr = proto_addr
    if proto not in ['tcp', 'pgm', 'epgm', 'ipc', 'inproc']:
        raise AssertionError(f"Invalid protocol: {proto!r}")

    # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
    # author: Remi Sabourin
    pat = re.compile(
        r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$'
    )

    if proto == 'tcp':
        lis = addr.split(':')
        if len(lis) != 2:
            raise AssertionError(f'Invalid url: {url!r}')
        addr, s_port = lis
        try:
            int(s_port)
        except ValueError:
            # report the raw port string: no parsed port exists at this point
            raise AssertionError(
                f"Invalid port {s_port!r} in url: {url!r}"
            ) from None

        if not (addr == '*' or pat.match(addr) is not None):
            raise AssertionError(f'Invalid url: {url!r}')
    # only tcp urls are validated beyond their shape currently

    return True

174 

175 

def validate_url_container(container):
    """Validate every url in a possibly nested collection.

    A string is validated directly and the result of ``validate_url`` is
    returned. For a dict, its values are walked. Any other iterable is
    walked element by element, recursing into each one. For collections
    the return value is None; invalid urls surface as exceptions.
    """
    if isinstance(container, str):
        return validate_url(container)
    children = container.values() if isinstance(container, dict) else container
    for child in children:
        validate_url_container(child)

186 

187 

def split_url(url):
    """split a zmq url (tcp://ip:port) into ('tcp','ip','port').

    All three parts are returned as strings.

    Raises AssertionError if *url* is not of the form
    ``proto://host:port``. It is raised explicitly rather than via
    ``assert``: disambiguate_url catches AssertionError to detect non-tcp
    urls, and that must keep working under ``python -O``.
    """
    proto_addr = url.split('://')
    if len(proto_addr) != 2:
        raise AssertionError(f'Invalid url: {url!r}')
    proto, addr = proto_addr
    lis = addr.split(':')
    if len(lis) != 2:
        raise AssertionError(f'Invalid url: {url!r}')
    addr, s_port = lis
    return proto, addr, s_port

197 

198 

def is_ip(location):
    """Is a location an ip?

    It could be a hostname. Returns True only if *location* is a
    dotted-quad IPv4 string such as ``'10.0.0.1'``. Octet ranges are not
    checked.
    """
    # pattern first, string second (the arguments were previously swapped,
    # which made real IP addresses never match)
    return re.fullmatch(r'(\d+\.){3}\d+', location) is not None

205 

206 

@lru_cache
def ip_for_host(host):
    """Resolve *host* to an IP address, memoizing the result.

    The first address reported by ``socket.gethostbyname_ex`` is used.
    If resolution fails for any reason, a RuntimeWarning is issued and
    *host* itself is returned unchanged.
    """
    try:
        addresses = socket.gethostbyname_ex(host)[2]
        resolved = addresses[0]
    except Exception as err:
        warnings.warn(
            f"IPython could not determine IPs for {host}: {err}", RuntimeWarning
        )
        resolved = host
    return resolved

221 

222 

def disambiguate_ip_address(ip, location=None):
    """Turn the wildcard interfaces '0.0.0.0' and '*' into a connectable address.

    Any other *ip* is returned unchanged.

    Parameters
    ----------
    ip : IP address
        An IP address, or one of the special values 0.0.0.0 or *
    location : IP address or hostname, optional
        A public IP of the target machine, or its hostname.
        With no location, or a location that is this machine
        (matching hostname, or one of its public IPs), localhost is
        returned. Otherwise the location (resolved to an IP if it was a
        hostname) is returned.
    """
    if ip not in {'0.0.0.0', '*'}:
        return ip
    if not location:
        # no hint about the target: localhost is the only choice
        return localhost()
    if not is_ip(location):
        if location == socket.gethostname():
            # hostname matches this machine
            return localhost()
        # the machine may have several names; compare by IP instead
        location = ip_for_host(location)

    if is_public_ip(location):
        # location is one of this machine's public IPs
        return localhost()
    if not public_ips():
        # this machine's public IPs are unknown, so assume `location`
        # is a different machine
        warnings.warn("IPython could not determine public IPs", RuntimeWarning)
    # location is (assumed) not this machine: do not use loopback
    return location

262 

263 

def disambiguate_url(url, location=None):
    """Make a zmq url with a wildcard host ('0.0.0.0' or '*') connectable.

    This is for zeromq urls, such as ``tcp://*:10101``. The host is
    resolved with disambiguate_ip_address, using *location* as a hint
    (by default the result is localhost). Urls that are not of the
    ``proto://host:port`` form (e.g. ipc) are returned unchanged.
    """
    try:
        parts = split_url(url)
    except AssertionError:
        # probably not a tcp url; could be ipc, etc.
        return url
    proto, host, port = parts
    return f"{proto}://{disambiguate_ip_address(host, location)}:{port}"

279 

280 

281# -------------------------------------------------------------------------- 

282# helpers for implementing old MEC API via view.apply 

283# -------------------------------------------------------------------------- 

284 

285 

def interactive(f):
    """decorator for making functions appear as interactively defined.
    This results in the function being linked to the user_ns as globals()
    instead of the module globals().

    Plain functions are rebuilt with ``__main__.__dict__`` as their
    globals. The rebuilt function keeps its name, positional defaults and
    keyword-only defaults. Closures are not carried over: interactive
    functions are expected not to have any, and ``FunctionType`` raises
    if the code needs free variables. In all cases ``__module__`` is set
    to ``'__main__'``.
    """

    # build new FunctionType, so it can have the right globals
    # interactive functions never have closures, that's kind of the point
    if isinstance(f, FunctionType):
        mainmod = __import__('__main__')
        rebuilt = FunctionType(
            f.__code__,
            mainmod.__dict__,
            f.__name__,
            f.__defaults__,
        )
        # FunctionType() takes no keyword-only defaults; copy them over or
        # e.g. `def f(*, x=1)` would fail with a missing-argument error
        if f.__kwdefaults__:
            rebuilt.__kwdefaults__ = dict(f.__kwdefaults__)
        f = rebuilt
    # associate with __main__ for uncanning
    f.__module__ = '__main__'
    return f

305 

306 

def _push(**ns):
    """Bind each keyword argument as a global in the IPython user namespace.

    Helper for implementing `client.push` via `client.apply`. Each value is
    parked under a temporary name that no existing user global uses, then
    bound to its target name by executing an assignment in the user
    namespace. The temporary name is always removed afterwards.
    """
    shell_globals = get_ipython().user_global_ns
    placeholder = '_IP_PUSH_TMP_'
    while placeholder in shell_globals:
        placeholder = f"{placeholder}_"
    try:
        for target, payload in ns.items():
            shell_globals[placeholder] = payload
            exec(f"{target} = {placeholder}", shell_globals)
    finally:
        shell_globals.pop(placeholder, None)

319 

320 

def _pull(keys):
    """Evaluate *keys* in the IPython user namespace and return the results.

    Helper for implementing `client.pull` via `client.apply`. A list, tuple
    or set of keys yields a list of values. Anything else is evaluated as a
    single expression.
    """
    shell_globals = get_ipython().user_global_ns
    # NOTE(review): keys are passed to eval(); only safe with trusted clients.
    if not isinstance(keys, (list, tuple, set)):
        return eval(keys, shell_globals)
    return [eval(expr, shell_globals) for expr in keys]

328 

329 

def _execute(code):
    """Run *code* in the IPython user's global namespace.

    Helper for implementing `client.execute` via `client.apply`.
    """
    exec(code, get_ipython().user_global_ns)

334 

335 

336# -------------------------------------------------------------------------- 

337# extra process management utilities 

338# -------------------------------------------------------------------------- 

339 

340_random_ports = set() 

341 

342 

343def select_random_ports(n): 

344 """Selects and return n random ports that are available.""" 

345 ports = [] 

346 for i in range(n): 

347 sock = socket.socket() 

348 sock.bind(('', 0)) 

349 while sock.getsockname()[1] in _random_ports: 

350 sock.close() 

351 sock = socket.socket() 

352 sock.bind(('', 0)) 

353 ports.append(sock) 

354 for i, sock in enumerate(ports): 

355 port = sock.getsockname()[1] 

356 sock.close() 

357 ports[i] = port 

358 _random_ports.add(port) 

359 return ports 

360 

361 

def signal_children(children):
    """Relay interrupt/term signals to children, for more solid process cleanup.

    Installs one handler for SIGINT, SIGABRT and SIGTERM. The handler logs
    the signal, calls ``terminate()`` on every object in *children*, then
    calls ``sys.exit`` with status 0 for SIGINT and 1 otherwise.
    """

    def terminate_children(sig, frame):
        get_logger().critical("Got signal %i, terminating children...", sig)
        for proc in children:
            proc.terminate()
        # bool exit status: False (0) for SIGINT, True (1) for the others
        sys.exit(sig != SIGINT)

    for signum in (SIGINT, SIGABRT, SIGTERM):
        signal(signum, terminate_children)

376 

377 

def integer_loglevel(loglevel):
    """Convert *loglevel* to an integer logging level.

    Accepts an int, a numeric string (``"10"``), or a level name in any
    case (``"DEBUG"``, ``"debug"``), resolved against the ``logging``
    module's level constants. Values that ``int()`` rejects with
    ValueError and that are not strings are returned unchanged.

    Raises AttributeError for unknown level names.
    """
    try:
        return int(loglevel)
    except ValueError:
        if isinstance(loglevel, str):
            # uppercase first: getattr(logging, "debug") would return the
            # logging.debug *function*, not the DEBUG level
            return getattr(logging, loglevel.upper())
        return loglevel

385 

386 

def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
    """Attach a zmq PUBHandler to the logger named *logname*.

    A PUB socket is created from *context* and connected to *iface*. The
    handler and logger are set to *loglevel*, and the handler publishes
    under the root topic *root*. Returns the logger, or None (doing
    nothing) if the logger already has a PUBHandler.
    """
    logger = logging.getLogger(logname)
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        # don't add a second PUBHandler
        return
    level = integer_loglevel(loglevel)
    pub_socket = context.socket(zmq.PUB)
    pub_socket.connect(iface)
    pub_handler = handlers.PUBHandler(pub_socket)
    pub_handler.setLevel(level)
    pub_handler.root_topic = root
    logger.addHandler(pub_handler)
    logger.setLevel(level)
    return logger

401 

402 

def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
    """Attach an EnginePUBHandler for *engine* to the root logger.

    A PUB socket is created from *context* and connected to *iface*. The
    handler and the root logger are set to *loglevel*. Returns the root
    logger, or None (doing nothing) if it already has a PUBHandler.
    """
    from ipyparallel.engine.log import EnginePUBHandler

    root_logger = logging.getLogger()
    already_connected = any(
        isinstance(h, handlers.PUBHandler) for h in root_logger.handlers
    )
    if already_connected:
        # don't add a second PUBHandler
        return
    level = integer_loglevel(loglevel)
    pub_socket = context.socket(zmq.PUB)
    pub_socket.connect(iface)
    engine_handler = EnginePUBHandler(engine, pub_socket)
    engine_handler.setLevel(level)
    root_logger.addHandler(engine_handler)
    root_logger.setLevel(level)
    return root_logger

418 

419 

def local_logger(logname, loglevel=logging.DEBUG):
    """Attach a timestamped StreamHandler to the logger named *logname*.

    The handler and logger are set to *loglevel*. Returns the logger, or
    None (doing nothing) if it already has a StreamHandler (including
    subclasses such as FileHandler).
    """
    level = integer_loglevel(loglevel)
    logger = logging.getLogger(logname)
    if any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
        # don't add a second StreamHandler
        return
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(level)
    stream_handler.setFormatter(
        logging.Formatter(
            "%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    logger.addHandler(stream_handler)
    logger.setLevel(level)
    return logger

437 

438 

def set_hwm(sock, hwm=0):
    """Set the zmq High Water Mark on *sock*.

    Tries each of the HWM, SNDHWM and RCVHWM options that the installed
    pyzmq defines, and ignores ZMQError from options the socket rejects.
    This makes it work across pyzmq / libzmq versions.
    """
    import zmq

    option_names = ('HWM', 'SNDHWM', 'RCVHWM')
    for opt in (getattr(zmq, name, None) for name in option_names):
        if opt is not None:
            try:
                sock.setsockopt(opt, hwm)
            except zmq.ZMQError:
                # option not supported by this libzmq: best effort
                pass

454 

455 

def int_keys(dikt):
    """Rekey a dict whose numeric keys were turned into strings by JSON.

    String keys that parse as ``int`` (tried first) or ``float`` are
    replaced in place by the number. Other keys are left alone. Raises
    KeyError if the numeric key already exists. Returns the same dict.
    """
    for key in list(dikt):
        if not isinstance(key, str):
            continue
        for cast in (int, float):
            try:
                new_key = cast(key)
            except ValueError:
                continue
            break
        else:
            # neither int nor float: keep the string key
            continue
        if new_key in dikt:
            raise KeyError(f"already have key {new_key!r}")
        dikt[new_key] = dikt.pop(key)
    return dikt

475 

476 

def become_dask_worker(address, nanny=False, **kwargs):
    """Task function for becoming a dask.distributed Worker

    Parameters
    ----------
    address : str
        The URL of the dask Scheduler.
    nanny : bool, optional
        If True, start a ``distributed.Nanny`` instead of a ``Worker``.
    **kwargs
        Any additional keyword arguments will be passed to the Worker
        (or Nanny) constructor.

    The worker is stored as ``kernel.distributed_worker`` and in the user
    namespace as both ``dask_worker`` and ``distributed_worker``. Its
    ``start()`` coroutine is scheduled with ``asyncio.ensure_future``, and
    startup errors are logged. Does nothing if a worker is already
    registered on the kernel.
    """
    shell = get_ipython()
    kernel = shell.kernel
    # the worker is recorded as kernel.distributed_worker below (and cleared
    # by stop_distributed_worker); `dask_worker` is also checked for
    # compatibility with anything that set that attribute
    if (
        getattr(kernel, 'distributed_worker', None) is not None
        or getattr(kernel, 'dask_worker', None) is not None
    ):
        kernel.log.info("Dask worker is already running.")
        return
    from distributed import Nanny, Worker

    if nanny:
        w = Nanny(address, **kwargs)
    else:
        w = Worker(address, **kwargs)
    shell.user_ns['dask_worker'] = shell.user_ns['distributed_worker'] = (
        kernel.distributed_worker
    ) = w

    # call_soon doesn't launch coroutines
    def _log_error(f):
        kernel.log.info(f"dask start finished {f=}")
        if f.cancelled():
            # f.result() would raise CancelledError out of this callback
            kernel.log.warning("dask worker start was cancelled")
            return
        try:
            f.result()
        except Exception:
            kernel.log.error("Error starting dask worker", exc_info=True)

    f = asyncio.ensure_future(w.start())
    f.add_done_callback(_log_error)

512 

513 

def stop_distributed_worker():
    """Task function for stopping the distributed worker on an engine.

    Clears ``kernel.distributed_worker``. It also removes the
    ``distributed_worker`` and ``dask_worker`` user-namespace entries, but
    only if they still refer to this worker. Then it schedules
    ``w.terminate(None)`` on the current tornado IOLoop. Logs and returns
    if no worker is registered.
    """
    shell = get_ipython()
    kernel = shell.kernel
    if getattr(kernel, 'distributed_worker', None) is None:
        kernel.log.info("Distributed worker already stopped.")
        return
    w = kernel.distributed_worker
    kernel.distributed_worker = None
    # become_dask_worker publishes the worker under both names
    for name in ('distributed_worker', 'dask_worker'):
        if shell.user_ns.get(name, None) is w:
            shell.user_ns.pop(name, None)
    IOLoop.current().add_callback(lambda: w.terminate(None))

526 

527 

def ensure_timezone(dt):
    """Return *dt* with a timezone, attaching the local one if it is naive.

    Aware datetimes are returned as the same object.
    """
    return dt if dt.tzinfo is not None else dt.replace(tzinfo=tzlocal())

537 

538 

# extract_dates forward-port from jupyter_client 5.0
# timestamp formats
# strftime/strptime format for timestamps with microseconds
ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
# Matches an ISO8601 timestamp (used by _parse_date):
# YYYY-MM-DDTHH:MM:SS, optional 1-6 digit fractional seconds, and an
# optional `Z` or +HH:MM / -HHMM style UTC offset.
ISO8601_PAT = re.compile(
    r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(\.\d{1,6})?(Z|([\+\-]\d{2}:?\d{2}))?$"
)

545 

546 

547def _ensure_tzinfo(dt): 

548 """Ensure a datetime object has tzinfo 

549 

550 If no tzinfo is present, add tzlocal 

551 """ 

552 if not dt.tzinfo: 

553 # No more naïve datetime objects! 

554 warnings.warn( 

555 f"Interpreting naïve datetime as local {dt}. Please add timezone info to timestamps.", 

556 DeprecationWarning, 

557 stacklevel=4, 

558 ) 

559 dt = dt.replace(tzinfo=tzlocal()) 

560 return dt 

561 

562 

563def _parse_date(s): 

564 """parse an ISO8601 date string 

565 

566 If it is None or not a valid ISO8601 timestamp, 

567 it will be returned unmodified. 

568 Otherwise, it will return a datetime object. 

569 """ 

570 if s is None: 

571 return s 

572 m = ISO8601_PAT.match(s) 

573 if m: 

574 dt = dateutil_parse(s) 

575 return _ensure_tzinfo(dt) 

576 return s 

577 

578 

def extract_dates(obj):
    """Recursively convert ISO8601 strings in unpacked JSON to datetimes.

    Dicts are rebuilt as new dicts (the input is not modified). Lists and
    tuples become new lists. Strings go through ``_parse_date``. Anything
    else is returned as-is.
    """
    if isinstance(obj, dict):
        return {key: extract_dates(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [extract_dates(item) for item in obj]
    if isinstance(obj, str):
        return _parse_date(obj)
    return obj

591 

592 

def compare_datetimes(a, b):
    """Return ``a - b`` even when only one of them is timezone-aware.

    If exactly one of the two datetimes is naive, it is treated as local
    time (``tzlocal()``) before subtracting, instead of raising TypeError.
    Returns the timedelta.
    """
    a_aware = a.tzinfo is not None
    b_aware = b.tzinfo is not None
    if b_aware and not a_aware:
        a = a.replace(tzinfo=tzlocal())
    elif a_aware and not b_aware:
        b = b.replace(tzinfo=tzlocal())
    return a - b

606 

607 

def utcnow():
    """Return the current time as a timezone-aware datetime in UTC."""
    return datetime.now(tz=utc)

611 

612 

613def _v(version_s): 

614 return tuple(int(s) for s in re.findall(r"\d+", version_s)) 

615 

616 

@lru_cache
def _disable_session_extract_dates():
    """Replace jupyter_client's ``session.extract_dates`` with a pass-through.

    Avoids the cost of parsing timestamp strings that are never used.
    With ``lru_cache`` on this zero-argument function, the body runs only
    on the first call (while the cache is intact).
    """

    def _passthrough(obj):
        return obj

    session.extract_dates = _passthrough

624 

625 

def progress(*args, widget=None, **kwargs):
    """Create a tqdm progress bar.

    If `widget` is None, use the notebook widget bar only when running
    inside an IPython kernel with ipywidgets importable. Otherwise use
    plain tqdm writing to ``sys.stdout`` by default. All other arguments
    are passed to the tqdm constructor.
    """
    if widget is None:
        # auto widget if in a kernel
        shell = get_ipython()
        widget = False
        if shell is not None and getattr(shell, 'kernel', None) is not None:
            try:
                import ipywidgets  # noqa
            except ImportError:
                pass
            else:
                widget = True
    if not widget:
        import tqdm

        kwargs.setdefault("file", sys.stdout)
        return tqdm.tqdm(*args, **kwargs)

    import tqdm.notebook

    return tqdm.notebook.tqdm_notebook(*args, **kwargs)

654 

655 

def abbreviate_profile_dir(profile_dir):
    """Shorten a profile directory inside $IPYTHONDIR to just its profile name.

    ``<ipython_dir>/profile_foo`` becomes ``foo``. Any other path is
    returned unchanged.
    """
    prefix = os.path.join(get_ipython_dir(), "profile_")
    if not profile_dir.startswith(prefix):
        return profile_dir
    return profile_dir[len(prefix) :]

664 

665 

def _all_profile_dirs():
    """Return the paths of all ``profile_*`` directories in $IPYTHONDIR.

    Returns an empty list if the IPython directory does not exist.
    """
    if not os.path.isdir(get_ipython_dir()):
        return []
    with os.scandir(get_ipython_dir()) as entries:
        return [
            entry.path
            for entry in entries
            if entry.is_dir() and entry.name.startswith('profile_')
        ]

676 

677 

def _default_profile_dir(profile=None):
    """Locate the default IPython profile directory.

    Priorities:

    - named profile, if specified
    - current IPython profile, if run inside IPython
    - $IPYTHONDIR/profile_default

    Returns the absolute profile directory path. A named or default profile
    that does not exist yet is created.
    """
    if not profile:
        shell = get_ipython()
        if shell is not None:
            return shell.profile_dir.location
        profile = 'default'
    ipython_dir = get_ipython_dir()
    try:
        found = ProfileDir.find_profile_dir_by_name(ipython_dir, name=profile)
    except ProfileDirError:
        found = ProfileDir.create_profile_dir_by_name(ipython_dir, name=profile)
    return found.location

701 

702 

703def _locate_profiles(profiles=None): 

704 """Locate one or more IPython profiles by name""" 

705 ipython_dir = get_ipython_dir() 

706 return [ 

707 ProfileDir.find_profile_dir_by_name(ipython_dir, name=profile).location 

708 for profile in profiles 

709 ] 

710 

711 

def shlex_join(cmd):
    """Shell-quote each argument of *cmd* and join with spaces (shlex.join backport)."""
    return " ".join(map(shlex.quote, cmd))

715 

716 

# Map from traitlets trait classes to the annotation used for them in
# signatures built by _traitlet_signature. Builtin types are used where a
# direct equivalent exists; _trait_annotation adds _TraitAnnotation
# entries for other trait classes on first use.
_traitlet_annotations = {
    traitlets.Bool: bool,
    traitlets.Integer: int,
    traitlets.Float: float,
    traitlets.List: list,
    traitlets.Dict: dict,
    traitlets.Set: set,
    traitlets.Unicode: str,
    traitlets.Tuple: tuple,
}

727 

728 

729class _TraitAnnotation: 

730 """Trait annotation for a trait type""" 

731 

732 def __init__(self, trait_type): 

733 self.trait_type = trait_type 

734 

735 def __repr__(self): 

736 return self.trait_type.__name__ 

737 

738 

def _trait_annotation(trait_type):
    """Return the signature annotation for a trait class.

    Unknown trait classes get a _TraitAnnotation, which is cached in
    ``_traitlet_annotations`` for later lookups.
    """
    if trait_type not in _traitlet_annotations:
        _traitlet_annotations[trait_type] = _TraitAnnotation(trait_type)
    return _traitlet_annotations[trait_type]

746 

747 

def _traitlet_signature(cls):
    """Set ``cls.__signature__`` from the class's traits and return *cls*.

    Each trait whose name does not start with ``_`` and which lacks
    ``nosignature`` metadata becomes a keyword-only parameter. The
    parameter is named by the trait's ``alias`` metadata if present,
    annotated via ``_trait_annotation``, and defaults to the trait's
    default value (``None`` when the default is ``traitlets.Undefined``).
    """
    params = []
    for trait_name, trait in cls.class_traits().items():
        # omit private traits and explicitly excluded ones
        if trait_name.startswith("_") or trait.metadata.get("nosignature"):
            continue
        param_name = trait.metadata.get("alias", trait_name)
        # traitlets 5 exposes default(); older versions use default_value
        default = trait.default() if hasattr(trait, 'default') else trait.default_value
        if default is traitlets.Undefined:
            default = None
        params.append(
            inspect.Parameter(
                name=param_name,
                kind=inspect.Parameter.KEYWORD_ONLY,
                annotation=_trait_annotation(trait.__class__),
                default=default,
            )
        )
    cls.__signature__ = inspect.Signature(params)
    return cls

779 

780 

def bind(socket, url, curve_publickey=None, curve_secretkey=None):
    """Bind *socket* to *url*, enabling CURVE server mode when a secret key is given.

    ``curve_publickey`` is accepted but not used here.
    Returns the result of ``socket.bind(url)``.
    """
    if not curve_secretkey:
        return socket.bind(url)
    socket.setsockopt(zmq.CURVE_SERVER, 1)
    socket.setsockopt(zmq.CURVE_SECRETKEY, curve_secretkey)
    return socket.bind(url)

787 

788 

def connect(
    socket,
    url,
    curve_serverkey=None,
    curve_publickey=None,
    curve_secretkey=None,
):
    """Connect *socket* to *url*, enabling CURVE encryption if a server key is given.

    When *curve_serverkey* is set but either client key is missing, a fresh
    client keypair is generated. Client secret auth is not used, so these
    keys only serve encryption and any values will do.
    Returns the result of ``socket.connect(url)``.
    """
    if curve_serverkey:
        if not (curve_publickey and curve_secretkey):
            curve_publickey, curve_secretkey = zmq.curve_keypair()
        curve_options = (
            (zmq.CURVE_SERVERKEY, curve_serverkey),
            (zmq.CURVE_SECRETKEY, curve_secretkey),
            (zmq.CURVE_PUBLICKEY, curve_publickey),
        )
        for option, value in curve_options:
            socket.setsockopt(option, value)
    return socket.connect(url)

808 

809 

810def _detach_thread_output(ident=None): 

811 """undo thread-parent mapping in ipykernel#1186""" 

812 # disable ipykernel's association of thread output with the cell that 

813 # spawned the thread. 

814 # there should be a public API for this... 

815 if ident is None: 

816 ident = current_thread().ident 

817 for stream in (sys.stdout, sys.stderr): 

818 for name in ("_thread_to_parent", "_thread_to_parent_header"): 

819 mapping = getattr(stream, name, None) 

820 if mapping: 

821 mapping.pop(ident, None) 

822 

823 

class _OutputProducingThread(Thread):
    """Thread that detaches itself from ipykernel's output-to-cell mapping.

    Works around an ipykernel bug that associates thread output with the
    wrong cell. See https://github.com/ipython/ipykernel/issues/1289
    """

    def __init__(self, target, **kwargs):
        super().__init__(target=partial(self._wrapped_target, target), **kwargs)

    def _wrapped_target(self, target, *args, **kwargs):
        # runs in the new thread, before the real target
        _detach_thread_output(self.ident)
        return target(*args, **kwargs)

839 

840 

841# minimal subset of TermColors, removed from IPython 

842# not for public consumption 

843class _TermColors: 

844 Normal = '\033[0m' 

845 Red = '\033[0;31m'