1"""Classes used in scattering and gathering sequences.
2
3Scattering consists of partitioning a sequence and sending the various
4pieces to individual nodes in a cluster.
5"""
6
7# Copyright (c) IPython Development Team.
8# Distributed under the terms of the Modified BSD License.
9import sys
10from itertools import chain, islice
11
# Module-level handle for numpy. It stays None until is_array() runs while
# numpy is already present in sys.modules; this module never imports it first.
numpy = None


def is_array(obj):
    """Return True if *obj* is a numpy ndarray, without forcing a numpy import.

    If nothing has imported numpy yet (it is absent from ``sys.modules``), no
    object can be an ndarray, so the answer is False and numpy stays unloaded.
    Otherwise the module-global ``numpy`` is bound on first use and the object
    is checked against ``numpy.ndarray``.
    """
    global numpy
    if 'numpy' in sys.modules:
        if numpy is None:
            # Cheap: numpy is already loaded; this just binds the global name.
            import numpy
        return isinstance(obj, numpy.ndarray)
    return False
27
28
class Map:
    """Partition a sequence into contiguous, nearly equal-sized chunks.

    ``getPartition`` hands out chunk ``p`` of ``q``; ``joinPartitions``
    reassembles a list of chunks, in order, into a single object.
    """

    def getPartition(self, seq, p, q, n=None):
        """Return the ``p``-th of ``q`` contiguous partitions of ``seq``.

        The first ``n % q`` partitions receive one extra element, so partition
        sizes differ by at most one. Concatenating all partitions in order of
        ``p`` reproduces the first ``n`` items of ``seq``.

        Parameters
        ----------
        seq : sequence or iterable
            Objects supporting slicing are sliced directly (so the result has
            the same type as ``seq``). Objects whose slicing raises TypeError,
            such as iterators, are read with ``itertools.islice`` and the
            result is a list; note that this consumes the iterator.
        p : int
            Index of the requested partition; must satisfy ``0 <= p < q``.
        q : int
            Total number of partitions.
        n : int, optional
            Length of ``seq``. Defaults to ``len(seq)``, so it must be given
            for objects without ``__len__``.

        Raises
        ------
        ValueError
            If ``p`` is not in ``[0, q)``.
        """
        n = len(seq) if n is None else n
        if p < 0 or p >= q:
            raise ValueError(f"must have 0 <= p < q, but have p={p},q={q}")

        remainder = n % q
        basesize = n // q

        if p < remainder:
            # Leading partitions absorb the remainder, one extra item each.
            low = p * (basesize + 1)
            high = low + basesize + 1
        else:
            low = p * basesize + remainder
            high = low + basesize

        try:
            result = seq[low:high]
        except TypeError:
            # Some objects (e.g. iterators) can't be sliced; fall back to islice.
            result = list(islice(seq, low, high))

        return result

    def joinPartitions(self, listOfPartitions):
        """Reassemble partitions produced by ``getPartition``; see ``concatenate``."""
        return self.concatenate(listOfPartitions)

    def concatenate(self, listOfPartitions):
        """Join a list of partitions into one object.

        The type of the first partition decides the strategy:

        - numpy arrays are joined with ``numpy.concatenate``;
        - lists and tuples are flattened into a single list;
        - anything else (e.g. scalars) is returned unchanged as the input list.
        """
        testObject = listOfPartitions[0]
        # First see if we have a known array type. is_array() binds the
        # module-global ``numpy`` used on the next line.
        if is_array(testObject):
            return numpy.concatenate(listOfPartitions)
        # Next try for Python sequence types.
        if isinstance(testObject, (list, tuple)):
            return list(chain.from_iterable(listOfPartitions))
        # If we have scalars, just return listOfPartitions.
        return listOfPartitions
75
76
class RoundRobinMap(Map):
    """Partition a sequence in round-robin (strided) fashion.

    Partition ``p`` of ``q`` holds items ``p, p + q, p + 2q, ...`` of the
    sequence. ``joinPartitions`` interleaves the partitions back into the
    original order.

    NOTE(review): the original docstring stated "This currently does not
    work!". Unlike ``Map.getPartition`` there is no islice fallback, so ``seq``
    must support extended slicing. Confirm any remaining limitations before
    relying on this class.
    """

    def getPartition(self, seq, p, q, n=None):
        """Return the ``p``-th of ``q`` round-robin partitions of ``seq``.

        ``n`` is the length of ``seq`` to partition and defaults to
        ``len(seq)``.

        Raises
        ------
        ValueError
            If ``p`` is not in ``[0, q)``, matching ``Map.getPartition``.
        """
        n = len(seq) if n is None else n
        # Without this check, p >= q silently yields overlapping partitions
        # and q == 0 fails with a zero-step slice error.
        if p < 0 or p >= q:
            raise ValueError(f"must have 0 <= p < q, but have p={p},q={q}")
        return seq[p:n:q]

    def joinPartitions(self, listOfPartitions):
        """Interleave round-robin partitions back into one object.

        numpy arrays go to ``flatten_array`` and lists/tuples to
        ``flatten_list``. Anything else is returned unchanged.
        """
        testObject = listOfPartitions[0]
        # First see if we have a known array type.
        if is_array(testObject):
            return self.flatten_array(listOfPartitions)
        if isinstance(testObject, (list, tuple)):
            return self.flatten_list(listOfPartitions)
        return listOfPartitions

    def flatten_array(self, listOfPartitions):
        """Interleave round-robin numpy array partitions along axis 0.

        The result's dtype is the promotion of all partition dtypes, the same
        promotion ``numpy.concatenate`` would apply.
        """
        # Partitions are ndarrays, so numpy is already loaded. Importing here
        # avoids depending on is_array() having bound the module-global name.
        from functools import reduce

        import numpy

        test = listOfPartitions[0]
        shape = list(test.shape)
        shape[0] = sum(part.shape[0] for part in listOfPartitions)
        # numpy.ndarray(shape) always allocates float64, which silently
        # converted integer/complex/object data. Preserve the input dtype.
        dtype = reduce(numpy.promote_types, (part.dtype for part in listOfPartitions))
        A = numpy.empty(shape, dtype=dtype)
        N = shape[0]
        q = len(listOfPartitions)
        for p, part in enumerate(listOfPartitions):
            # Partition p fills rows p, p + q, p + 2q, ...
            A[p:N:q] = part
        return A

    def flatten_list(self, listOfPartitions):
        """Interleave round-robin list/tuple partitions into a single list.

        Takes the i-th item of each partition in turn. Partition 0 is assumed
        to be the longest, as ``getPartition`` produces; shorter partitions
        simply contribute nothing once exhausted.
        """
        flat = []
        for i in range(len(listOfPartitions[0])):
            flat.extend([part[i] for part in listOfPartitions if len(part) > i])
        return flat
112
113
def mappable(obj):
    """Tell whether *obj* can be partitioned by a map.

    Lists, tuples and numpy arrays qualify; everything else does not.
    """
    # Cheap builtin check first; is_array() is only consulted otherwise.
    return isinstance(obj, (tuple, list)) or is_array(obj)
121
122
# Registry of partitioning strategies keyed by a one-letter code:
# 'b' -> Map (contiguous chunks), 'r' -> RoundRobinMap (strided/interleaved).
dists = {'b': Map, 'r': RoundRobinMap}