Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/parallel_device/parallel_device.py: 34%

86 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility for eagerly executing operations in parallel on multiple devices."""

import threading
import weakref

from tensorflow.python import _pywrap_parallel_device
from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.util import nest
from tensorflow.python.util import variable_utils

_next_device_number = 0
_next_device_number_lock = threading.Lock()

_all_parallel_devices = weakref.WeakValueDictionary()


def unpack(tensor):
  """Finds `tensor`'s parallel device and unpacks its components."""
  parallel_device = _all_parallel_devices.get(tensor.device, None)
  if parallel_device is None:
    raise ValueError("{} is not a parallel device".format(tensor.device))
  return parallel_device.unpack(tensor)
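# A minimal usage sketch, assuming a ParallelDevice `d` with two components has
# already packed `packed = d.pack([tf.constant(1.), tf.constant(2.)])`; the
# module-level `unpack` looks the device up by name and delegates to it:
#
#   unpack(packed)  # -> [<tf.Tensor: 1.0>, <tf.Tensor: 2.0>]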

# TODO(allenl): Expand this docstring once things like getting components on and
# off the device are stable.
#
# TODO(allenl): Make multi-client work; we need an offset for device IDs, and an
# indication of how many other devices there are total for collectives which
# don't have a number of participants hard-coded in their attributes.
class ParallelDevice(object):
  """A device which executes operations in parallel."""

  def __init__(self, components):
    """Creates a device which executes operations in parallel on `components`.

    Args:
      components: A list of device names. Each operation executed on the
        returned device executes on these component devices.
    """
    global _next_device_number, _next_device_number_lock
    self.components = tuple(device_util.canonicalize(d) for d in components)
    if not self.components:
      raise ValueError("ParallelDevice requires at least one component.")
    ctx = context.context()
    with _next_device_number_lock:
      # TODO(allenl): Better names for parallel devices (right now "CUSTOM" is
      # special-cased).
      self._name = "{}/device:CUSTOM:{}".format(ctx.host_address_space(),
                                                _next_device_number)
      _next_device_number += 1
    device, device_info = _pywrap_parallel_device.GetParallelDeviceCapsules(
        self._name, self.components)
    context.register_custom_device(device, self._name, device_info)
    self._device_ids = None
    self._device_scope = None
    _all_parallel_devices[self._name] = self
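  # An illustrative construction sketch, assuming two CPU devices are visible
  # (e.g. after tf.config.set_logical_device_configuration); component names
  # are canonicalized, so short forms are accepted:
  #
  #   device = ParallelDevice(components=["/device:CPU:0", "/device:CPU:1"])
  #   len(device.components)  # -> 2, fully-qualified device names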

  def _pack_tensor(self, *tensors):
    """Helper to pack plain-old-tensors, not structures or composites."""
    for tensor in tensors:
      if not isinstance(tensor, (ops.Tensor, composite_tensor.CompositeTensor,
                                 variables.Variable)):
        raise ValueError(
            ("Every component to pack onto the ParallelDevice must already be "
             "a tensor, got {}. Consider running `tf.constant` or "
             "`tf.convert_to_tensor` first on literal values.")
            .format(tensors))
    with ops.device(self._name):
      return tpu_ops.tpu_replicated_input(inputs=tensors)

  def pack(self, tensors):
    """Create a tensor on the parallel device from a sequence of tensors.

    Args:
      tensors: A list of tensors, one per device in `self.components`. The
        list can contain composite tensors and nests (lists, dicts, etc.
        supported by `tf.nest`) with the same structure for each device, but
        every component of nests must already be a `tf.Tensor` or composite.
        Passing `tf.Variable` objects reads their values; it does not share a
        mutable reference between the packed and unpacked forms.

    Returns:
      A tensor placed on the ParallelDevice. For nested structures, returns a
      single structure containing tensors placed on the ParallelDevice (same
      structure as each component of `tensors`).

    Raises:
      ValueError: If the length of `tensors` does not match the number of
        component devices, or if there are non-tensor inputs.
    """
    self._assert_eager()
    if len(tensors) != len(self.components):
      raise ValueError(
          ("Creating a parallel tensor requires one tensor per component. "
           "Got {} but was expecting {}.")
          .format(len(tensors), len(self.components)))
    with ops.device(None):
      # Explicitly read variable values. This cannot be done on the parallel
      # device since the tensors are to be packed.
      tensors = variable_utils.convert_variables_to_tensors(tensors)
    return nest.map_structure(self._pack_tensor, *tensors,
                              expand_composites=True)
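  # An illustrative `pack` sketch, assuming `device` has two components; nests
  # must have the same structure per component and leaves must be tensors:
  #
  #   x = device.pack([{"a": tf.constant(1.)}, {"a": tf.constant(2.)}])
  #   # `x` is {"a": <parallel tensor>}, placed on the ParallelDevice.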

  def _unpack_tensor(self, parallel_tensor):
    """Helper to unpack a single tensor."""
    if not isinstance(parallel_tensor, (
        ops.Tensor, composite_tensor.CompositeTensor, variables.Variable)):
      raise ValueError(
          "Expected a tensor, got {}.".format(parallel_tensor))
    with ops.device(self._name):
      return tpu_ops.tpu_replicated_output(
          parallel_tensor, num_replicas=len(self.components))

  def unpack(self, parallel_tensor):
    """Unpack a parallel tensor into its components.

    Args:
      parallel_tensor: A tensor, composite tensor, or `tf.nest` of such placed
        on the ParallelDevice. Passing `tf.Variable` objects reads their
        values; it does not share a mutable reference between the packed and
        unpacked forms.

    Returns:
      A list with the same length as `self.components`, each element with the
      same structure as `parallel_tensor`, containing component tensors.
    """
    self._assert_eager()
    unpacked_components = [[] for _ in range(len(self.components))]
    with ops.device(self._name):
      parallel_tensor = variable_utils.convert_variables_to_tensors(
          parallel_tensor)
      for tensor in nest.flatten(parallel_tensor, expand_composites=True):
        for accumulator, unpacked_tensor in zip(
            unpacked_components, self._unpack_tensor(tensor)):
          accumulator.append(unpacked_tensor)
    return [nest.pack_sequence_as(parallel_tensor, unpacked,
                                  expand_composites=True)
            for unpacked in unpacked_components]
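  # An illustrative `unpack` sketch, the inverse of the `pack` example above:
  #
  #   device.unpack(x)  # -> [{"a": <tf.Tensor: 1.0>}, {"a": <tf.Tensor: 2.0>}]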

  @property
  def device_ids(self):
    """A parallel tensor with scalar integers numbering component devices.

    Each device ID is placed on its corresponding device, in the same order as
    the `components` constructor argument.

    Returns:
      A parallel tensor containing 0 on the first device, 1 on the second, etc.
    """
    if self._device_ids is None:
      # device_ids may be called from inside a tf.function, in which case the
      # function captures the eager tensor. We can't pack tensors in a function
      # at the moment, and even if we could we don't want to hold on to a
      # symbolic tensor, so we need to init_scope out of the function
      # temporarily.
      with ops.init_scope():
        # TODO(allenl): Functions which capture eager device ID tensors won't be
        # saveable in SavedModels. Ideally we'd run a DeviceID op every time
        # device IDs are required, with functions using the op in their bodies
        # but not hard-coding a fixed number of devices (so they can be re-used
        # with a different replica count).
        device_ids_list = []
        for index, device in enumerate(self.components):
          with ops.device(device):
            # The identity op ensures each device ID tensor is placed on its
            # device.
            device_ids_list.append(
                array_ops.identity(constant_op.constant(index)))
        self._device_ids = self.pack(device_ids_list)

    return self._device_ids
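  # An illustrative sketch: with two components, unpacking `device_ids` yields
  # one scalar per component device, numbered in constructor order:
  #
  #   device.unpack(device.device_ids)  # -> [<tf.Tensor: 0>, <tf.Tensor: 1>]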

  def _assert_eager(self):
    """Verifies that tracing is not active."""
    if not context.executing_eagerly():
      raise NotImplementedError(
          "ParallelDevice is currently not supported inside `tf.function`. It "
          "can however run calls to a `tf.function` in parallel:\n\n"
          "with ParallelDevice() as p:\n  f()")

  def __enter__(self):
    """Runs ops in parallel, makes variables which save independent buffers."""
    if self._device_scope is not None:
      raise AssertionError(
          "Re-entered a ParallelDevice scope without first exiting it.")
    self._assert_eager()
    self._device_scope = ops.device(self._name)
    self._device_scope.__enter__()
    return self

  def __exit__(self, typ, exc, tb):
    self._device_scope.__exit__(typ, exc, tb)
    self._device_scope = None
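# A minimal end-to-end sketch, assuming it runs eagerly and that the logical
# device configuration is set before any other TensorFlow ops create devices:
#
#   import tensorflow as tf
#   from tensorflow.python.distribute.parallel_device import parallel_device
#
#   cpus = tf.config.list_physical_devices("CPU")
#   tf.config.set_logical_device_configuration(
#       cpus[0], [tf.config.LogicalDeviceConfiguration()] * 2)
#
#   device = parallel_device.ParallelDevice(
#       components=["/device:CPU:0", "/device:CPU:1"])
#
#   @tf.function
#   def square(x):
#     return x * x
#
#   with device:
#     # Ops and `tf.function` calls inside the scope run once per component.
#     y = square(device.device_ids)
#
#   device.unpack(y)  # -> [<tf.Tensor: 0>, <tf.Tensor: 1>]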