1"""Custom ipyparallel trait types"""
2
3import sys
4
5if sys.version_info < (3, 10):
6 from importlib_metadata import entry_points
7else:
8 from importlib.metadata import entry_points
9
10from traitlets import List, TraitError, Type
11
12
class Launcher(Type):
    """Type trait whose value may also be named by an entry point

    In addition to standard 'mypackage.MyClass' strings,
    a string that matches (case-insensitively) the name of an entry point
    registered in ``entry_point_group`` is resolved by loading that entry point.
    """

    def __init__(self, *args, entry_point_group, **kwargs):
        # fall back to the base launcher class when no klass keyword is given
        if 'klass' not in kwargs:
            kwargs['klass'] = 'ipyparallel.cluster.launcher.BaseLauncher'
        self.entry_point_group = entry_point_group
        super().__init__(*args, **kwargs)

    # help text as stored by the `help` setter below
    _original_help = ''

    @property
    def help(self):
        """The assigned help text followed by the currently installed choices"""
        lines = [self._original_help, "Currently installed: "]
        lines.extend(
            f" - {name}: {ep.value}"
            for name, ep in self.load_entry_points().items()
        )
        return '\n'.join(lines)

    @help.setter
    def help(self, value):
        self._original_help = value

    def load_entry_points(self):
        """Return a dict of my entry point group, keyed by lower-cased name"""
        found = {}
        for ep in entry_points(group=self.entry_point_group):
            found[ep.name.lower()] = ep
        return found

    def validate(self, obj, value):
        """Resolve entry point names, then defer to the standard Type validation"""
        if isinstance(value, str):
            # the entry point registry takes precedence over import strings;
            # unmatched strings are passed through unchanged
            match = self.load_entry_points().get(value.lower())
            if match is not None:
                value = match.load()
        return super().validate(obj, value)
55
56
class PortList(List):
    """List of ports

    For use configuring a list of ports to consume

    Ports will be a list of valid ports (integers in the range 1-65535)

    Can be specified as a port-range string for convenience
    (mainly for use on the command-line)
    e.g. '10101-10105,10108'
    """

    @staticmethod
    def parse_port_range(s):
        """Parse a port range string in the form '1,3-5,6' into [1,3,4,5,6]

        Ranges are inclusive on both ends.
        A range whose end is below its start (e.g. '5-3') contributes no ports.

        Raises ValueError (from ``int``) if any piece is not an integer.
        No range check is done here; that happens in ``validate``.
        """
        ports = []
        for r in s.split(","):
            start, _, end = r.partition("-")
            start = int(start)
            if end:
                # 'start-end' form: the end port is included
                ports.extend(range(start, int(end) + 1))
            else:
                # a single port
                ports.append(start)
        return ports

    def from_string_list(self, s_list):
        """Expand a list of port-range strings into one flat list of ports

        Each item is parsed with ``parse_port_range``,
        and the results are concatenated in order.
        NOTE(review): presumably the hook traitlets calls for command-line
        values -- confirm against the traitlets List implementation.
        """
        ports = []
        for s in s_list:
            ports.extend(self.parse_port_range(s))
        return ports

    def validate(self, obj, value):
        """Validate a list of ports, accepting a port-range string as well

        A string is first expanded with ``parse_port_range``;
        the result then goes through the standard List validation.

        Raises TraitError if any item is not an integer in the range 1-65535.
        """
        if isinstance(value, str):
            value = self.parse_port_range(value)
        value = super().validate(obj, value)
        for item in value:
            # port numbers are 16-bit, so 65535 is the highest valid port
            if not isinstance(item, int) or not 1 <= item <= 65535:
                raise TraitError(
                    f"Ports must be integers in range 1-65535, not {item!r}"
                )
        return value