Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/cross_device_utils.py: 18%

294 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for cross_device_ops.""" 

16 

17import copy 

18import threading 

19from typing import Callable, List, Optional, Union 

20 

21from tensorflow.python.distribute import collective_util 

22from tensorflow.python.distribute import values as value_lib 

23from tensorflow.python.eager import backprop_util 

24from tensorflow.python.eager import context 

25from tensorflow.python.framework import dtypes 

26from tensorflow.python.framework import indexed_slices 

27from tensorflow.python.framework import ops 

28from tensorflow.python.framework import tensor_spec 

29from tensorflow.python.ops import array_ops 

30from tensorflow.python.ops import collective_ops 

31from tensorflow.python.ops import cond 

32from tensorflow.python.ops import math_ops 

33from tensorflow.python.ops import nccl_ops 

34from tensorflow.python.ops import resource_variable_ops 

35from tensorflow.python.platform import tf_logging as logging 

36from tensorflow.python.types import core 

37 

38INSTANCE_KEY_START_NUMBER = 100 

39 

40 

41def aggregate_gradients_using_nccl(replica_grads): 

42 """Aggregate gradients using nccl allreduce.""" 

43 agg_all_g_and_v = [] 

44 for single_g_and_v in zip(*replica_grads): 

45 single_grads = [g for g, _ in single_g_and_v] 

46 agg_grads = nccl_ops.all_sum(single_grads) 

47 agg_all_g_and_v.append( 

48 [(g, v) for g, (_, v) in zip(agg_grads, single_g_and_v)]) 

49 

50 agg_all_g_and_v = list(zip(*agg_all_g_and_v)) 

51 

52 return agg_all_g_and_v 
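
As a minimal sketch of the data layout this function assumes, here is the zip(*replica_grads) regrouping in plain Python, with hypothetical placeholder strings standing in for real gradient tensors and variables:

    replica_grads = [
        [('g0_v0', 'v0'), ('g0_v1', 'v1')],  # replica 0: one (grad, var) per variable
        [('g1_v0', 'v0'), ('g1_v1', 'v1')],  # replica 1
    ]
    per_variable = list(zip(*replica_grads))
    # [(('g0_v0', 'v0'), ('g1_v0', 'v0')), (('g0_v1', 'v1'), ('g1_v1', 'v1'))]
    # aggregate_gradients_using_nccl all-sums each such group with nccl_ops.all_sum,
    # re-attaches each replica's variable, and the final zip(*...) restores the
    # per-replica layout of the input.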

53 

54 

55def aggregate_gradients_using_hierarchical_copy(avail_devices, replica_grads): 

56 """Aggregate gradients using hierarchical copies. 

57 

58 Args: 

59 avail_devices: available GPU devices. 

60 replica_grads: List of lists of (gradient, variable) tuples. The outer list 

61 is over replicas. The inner list is over individual gradients. 

62 

63 Returns: 

64 The list of (aggregated_gradient, variable), where the gradient has been 

65 summed across all replicas and the variable is chosen from the first 

66 replica. 

67 """ 

68 # This only works for a DGX-1-type machine topology.

69 # Device peer-to-peer (DMA) matrix:

70 # DMA: 0 1 2 3 4 5 6 7 

71 # 0: Y Y Y Y Y N N N 

72 # 1: Y Y Y Y N Y N N 

73 # 2: Y Y Y Y N N Y N 

74 # 3: Y Y Y Y N N N Y 

75 # 4: Y N N N Y Y Y Y 

76 # 5: N Y N N Y Y Y Y 

77 # 6: N N Y N Y Y Y Y 

78 # 7: N N N Y Y Y Y Y 

79 agg_grads = [] 

80 num_devices = len(avail_devices) 

81 # In the special case of DGX-1 machine topology, the two groups have equal 

82 # size. 

83 group_size = num_devices // 2 

84 for i, single_grads in enumerate(zip(*replica_grads)): 

85 group_0_main_device = i % num_devices 

86 group_1_main_device = (group_0_main_device + group_size) % num_devices 

87 if group_0_main_device < group_size: 

88 group_0_begin = 0 

89 group_1_begin = group_size 

90 else: 

91 group_0_begin = group_size 

92 group_1_begin = 0 

93 

94 # Aggregate the first group. 

95 group_0_device_grads = single_grads[group_0_begin: 

96 group_0_begin + group_size] 

97 with ops.device(avail_devices[group_0_main_device]): 

98 group_0_agg_grads, _ = aggregate_single_gradient_using_copy( 

99 group_0_device_grads, False, False) 

100 

101 # Aggregate the second group. 

102 group_1_device_grads = single_grads[group_1_begin: 

103 group_1_begin + group_size] 

104 with ops.device(avail_devices[group_1_main_device]): 

105 group_1_agg_grads, _ = aggregate_single_gradient_using_copy( 

106 group_1_device_grads, False, False) 

107 

108 # Aggregate between the groups. 

109 with ops.device(avail_devices[group_0_main_device]): 

110 (agg_total_grads, _), _ = aggregate_single_gradient_using_copy( 

111 [group_0_agg_grads, group_1_agg_grads], False, False) 

112 

113 # Broadcast the result back to the root of each group.

114 with ops.device(avail_devices[group_0_main_device]): 

115 group_0_agg_grads_bcast = array_ops.identity(agg_total_grads) 

116 with ops.device(avail_devices[group_1_main_device]): 

117 group_1_agg_grads_bcast = array_ops.identity(agg_total_grads) 

118 

119 agg_grads_bcast = [] 

120 for j in range(len(single_grads)): 

121 with ops.device(avail_devices[j]): 

122 # Broadcast the result back to each member in the group from the root. 

123 if (group_0_main_device < group_size) == (j < group_size): 

124 src_device_grad = group_0_agg_grads_bcast 

125 else: 

126 src_device_grad = group_1_agg_grads_bcast 

127 agg_grads_bcast.append(array_ops.identity(src_device_grad)) 

128 

129 agg_grads.append( 

130 [(g, v) for g, (_, v) in zip(agg_grads_bcast, single_grads)]) 

131 

132 agg_grads = list(zip(*agg_grads)) 

133 

134 return agg_grads 
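
For intuition, a small pure-Python sketch (assuming an 8-device, DGX-1-style topology) of how the two group roots rotate with the gradient index `i`, spreading the reduction load instead of funnelling everything through device 0:

    num_devices = 8
    group_size = num_devices // 2
    for i in range(3):
        group_0_main_device = i % num_devices
        group_1_main_device = (group_0_main_device + group_size) % num_devices
        print(i, group_0_main_device, group_1_main_device)
    # 0 0 4
    # 1 1 5
    # 2 2 6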

135 

136 

137def aggregate_single_gradient_using_copy(grad_and_vars, use_mean, 

138 check_inf_nan): 

139 """Calculate the average gradient for a shared variable across all replicas. 

140 

141 Note that this function provides a synchronization point across all replicas. 

142 

143 Args: 

144 grad_and_vars: A list or tuple of (gradient, variable) tuples. Each 

145 (gradient, variable) pair within the outer list represents the gradient 

146 of the variable calculated for a single replica, and the number of pairs 

147 equals the number of replicas. 

148 use_mean: if True, the mean of the gradients is taken; otherwise their sum.

149 check_inf_nan: check grads for nans and infs. 

150 

151 Returns: 

152 The tuple ([(average_gradient, variable),], has_nan_or_inf) where the 

153 gradient has been averaged across all replicas. The variable is chosen 

154 from the first replica. has_nan_or_inf indicates whether any gradient

155 contains a NaN or an Inf.

156 """ 

157 grads = [g for g, _ in grad_and_vars] 

158 grad = math_ops.add_n(grads) 

159 

160 if use_mean and len(grads) > 1: 

161 grad = math_ops.multiply(grad, 1.0 / len(grads))

162 

163 v = grad_and_vars[0][1] 

164 if check_inf_nan: 

165 has_nan_or_inf = math_ops.logical_not(

166 math_ops.reduce_all(math_ops.is_finite(grad)))

167 return (grad, v), has_nan_or_inf 

168 else: 

169 return (grad, v), None 

170 

171 

172# TODO(yuefengz): use random key starts to avoid reusing keys? 

173class CollectiveKeys(object): 

174 """Class that manages collective keys. 

175 

176 We need to manage two different keys for collectives:

177 

178 *Group key*: an integer key to identify the set of cooperative devices. 

179 Collective ops that run on the same set of devices must use the same

180 group key.

181 

182 *Instance key*: an integer key to identify the set of corresponding tensors

183 on different devices within a device group that need to be all-reduced.

184 

185 This class is thread safe. 

186 """ 

187 

188 def __init__(self, group_key_start=1): 

189 """Initializes the object. 

190 

191 Args: 

192 group_key_start: the starting integer of group key. 

193 """ 

194 self._group_key = group_key_start 

195 self._instance_key_table = {} 

196 self._lock = threading.Lock() 

197 self._known_groups = {} 

198 

199 def get_group_key(self, devices): 

200 """Returns a group key for the list of local devices. 

201 

202 The same group key is returned if the list of local devices is the same. 

203 

204 Args: 

205 devices: a list of local canonical device strings in a collective group. 

206 

207 Returns: 

208 a group key. 

209 """ 

210 with self._lock: 

211 devices_key = ','.join(devices) 

212 if devices_key not in self._known_groups: 

213 self._known_groups[devices_key] = self._get_new_group_key(devices) 

214 return self._known_groups[devices_key] 

215 

216 def _get_new_group_key(self, devices): 

217 """Returns a new group key. 

218 

219 The caller should store and reuse the same group key for the same set of 

220 devices. Calling this method always returns a new group key. 

221 

222 This method is not thread-safe. 

223 

224 Args: 

225 devices: a list of canonical device strings in a collective group. 

226 

227 Returns: 

228 a new group key. 

229 """ 

230 new_key = self._group_key 

231 self._group_key += 1 

232 self._instance_key_table[new_key] = {} 

233 for device in devices: 

234 self._instance_key_table[new_key][device] = INSTANCE_KEY_START_NUMBER 

235 return new_key 

236 

237 def get_instance_key(self, group_key, device): 

238 """Returns a new instance key for use in defining a collective op. 

239 

240 Call this once for each collective op of a collective instance.

241 

242 Args: 

243 group_key: the group key returned by get_group_key(). You should not 

244 assign the group key yourself. 

245 device: a canonical device string. It should be the device this collective 

246 op is on. 

247 

248 Returns: 

249 a new instance key. 

250 

251 Raises: 

252 ValueError: when the group key is invalid or the device is not in the 

253 group. 

254 """ 

255 with self._lock: 

256 group = self._instance_key_table.get(group_key, None) 

257 if group is None: 

258 raise ValueError(f'Group {group_key} is not found.') 

259 if device not in group: 

260 raise ValueError(f'Device {device} is not present in group {group_key}') 

261 v = group[device] 

262 group[device] += 1 

263 return v 

264 

265 def __deepcopy__(self, memo): 

266 # distribute_coordinator deep-copies the strategy object, so 

267 # CollectiveKeys needs to support deep copy as well. 

268 copied = CollectiveKeys() 

269 copied._group_key = self._group_key 

270 copied._instance_key_table = copy.deepcopy(self._instance_key_table, memo) 

271 return copied 
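
A short usage sketch of CollectiveKeys (the device strings below are hypothetical; the assertions follow from the code above):

    from tensorflow.python.distribute.cross_device_utils import (
        INSTANCE_KEY_START_NUMBER, CollectiveKeys)

    keys = CollectiveKeys(group_key_start=1)
    devices = ['/job:worker/replica:0/task:0/device:GPU:0',
               '/job:worker/replica:0/task:0/device:GPU:1']
    group_key = keys.get_group_key(devices)
    assert keys.get_group_key(devices) == group_key  # same devices, same group key
    ik0 = keys.get_instance_key(group_key, devices[0])
    ik1 = keys.get_instance_key(group_key, devices[0])
    assert (ik0, ik1) == (INSTANCE_KEY_START_NUMBER, INSTANCE_KEY_START_NUMBER + 1)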

272 

273 

274class CollectiveReplicaLauncher(object): 

275 """Launch collectives on one replica.""" 

276 

277 _prefer_unique_instance_key = True 

278 _prefer_ordering_token = True 

279 

280 def __init__(self, group_key: int, group_size: int, 

281 collective_keys: CollectiveKeys, device: str, 

282 options: collective_util.Options): 

283 self._group_key = group_key 

284 self._group_size = group_size 

285 self._collective_keys = collective_keys 

286 self._device = device 

287 self._options = options 

288 if self._use_ordering_token(): 

289 with ops.init_scope(), ops.device(device): 

290 self._ordering_token = resource_variable_ops.ResourceVariable(0.) 

291 else: 

292 self._ordering_token = None 

293 

294 def _control_input(self, control_input: Union[core.TensorLike, 

295 ops.Operation]): 

296 if control_input is not None and not self._use_ordering_token(): 

297 return ops.control_dependencies([control_input]) 

298 return ops.NullContextmanager() 

299 

300 def _use_unique_instance_key(self): 

301 if not ops.executing_eagerly_outside_functions(): 

302 return False 

303 return CollectiveReplicaLauncher._prefer_unique_instance_key 

304 

305 def _use_ordering_token(self): 

306 # We rely on automatic control dependencies to insert control edges between

307 # NCCL calls, but in TF1 graph mode automatic control dependencies are not used.

308 if not ops.executing_eagerly_outside_functions(): 

309 return False 

310 return CollectiveReplicaLauncher._prefer_ordering_token 

311 

312 def _next_instance_key(self): 

313 """Returns the next instance key.""" 

314 if self._use_unique_instance_key(): 

315 # Assigning instance keys at function building time has issues since

316 # different workers may retrace the function at different times. With 

317 # collective V2 we can use capture_call_time_value to use a placeholder as 

318 # the instance key and feed it at function call time. In this way we also 

319 # don't reuse instance keys, which allows for per-instance cancellation. 

320 graph = ops.get_default_graph() 

321 # Control flow ops don't work with capture_call_time_value, so we put the 

322 # capture in the function graph of that control flow op. 

323 while getattr(graph, 'is_control_flow_graph', False): 

324 graph = graph.outer_graph 

325 if not context.executing_eagerly() and graph.building_function: 

326 with graph.as_default(): 

327 # Capture self._next_instance_key so that when building a function 

328 # that calls another tf.function, the instance key assignment is 

329 # further delayed until we actually call the function in eager. Note 

330 # that capture_call_time_value doesn't automatically propagate the 

331 # deferred capture to the outer function. 

332 return graph.capture_call_time_value( 

333 self._next_instance_key, tensor_spec.TensorSpec([], dtypes.int32)) 

334 else: 

335 instance_key = self._collective_keys.get_instance_key( 

336 self._group_key, self._device) 

337 with ops.device('CPU:0'): 

338 return ops.convert_to_tensor(instance_key, dtype=dtypes.int32) 

339 else: 

340 return self._collective_keys.get_instance_key(self._group_key, 

341 self._device) 

342 

343 def _get_ordering_token(self): 

344 if self._use_ordering_token(): 

345 return self._ordering_token.handle # pytype: disable=attribute-error 

346 

347 def can_order_nccl(self): 

348 """Whether this launcher can order NCCL operations.""" 

349 return self._use_ordering_token() 

350 

351 def all_reduce( 

352 self, 

353 input_tensor: core.TensorLike, 

354 control_input: Optional[Union[core.TensorLike, ops.Operation]] = None, 

355 options: Optional[collective_util.Options] = None) -> core.Tensor: 

356 """All-reduce a dense tensor. 

357 

358 Args: 

359 input_tensor: a dense tensor. It must have the same shape on all replicas. 

360 control_input: if not None, add control edges between control_input and 

361 the all-reduce. 

362 options: an optional tf.distribute.experimental.CommunicationOptions. If 

363 provided, it overrides the default options. 

364 

365 Returns: 

366 The reduced tensor. 

367 """ 

368 instance_key = self._next_instance_key() 

369 options = self._options.merge(options) 

370 ordering_token = self._get_ordering_token() 

371 with ops.device(self._device), \ 

372 self._control_input(control_input): 

373 return collective_ops.all_reduce_v2( 

374 input_tensor, 

375 self._group_size, 

376 self._group_key, 

377 instance_key, 

378 communication_hint=options.implementation.value, 

379 timeout=options.timeout_seconds, 

380 ordering_token=ordering_token) 
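
For context, a heavily simplified sketch of the all-reduce semantics this launcher implements, expressed with the public tf.distribute API rather than the internal launcher (a single-CPU MirroredStrategy is used only so the snippet runs anywhere; collective-based strategies route their reductions through launchers like this one):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(['/cpu:0'])

    def replica_fn():
      ctx = tf.distribute.get_replica_context()
      # Sums the per-replica value across all replicas in the group.
      return ctx.all_reduce(tf.distribute.ReduceOp.SUM, tf.constant(2.0))

    print(strategy.run(replica_fn))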

381 

382 def _all_gather(self, input_tensor: core.TensorLike, 

383 options: Optional[collective_util.Options]) -> core.Tensor: 

384 """All-gather a dense tensor. 

385 

386 Args: 

387 input_tensor: a dense tensor. It must have the same shape on all replicas. 

388 options: an optional tf.distribute.experimental.CommunicationOptions. If 

389 provided, it overrides the default options. 

390 

391 Returns: 

392 The gathered tensor.

393 """ 

394 instance_key = self._next_instance_key() 

395 options = self._options.merge(options) 

396 ordering_token = self._get_ordering_token() 

397 with ops.device(self._device): 

398 return collective_ops.all_gather_v2( 

399 input_tensor, 

400 self._group_size, 

401 self._group_key, 

402 instance_key, 

403 communication_hint=options.implementation.value, 

404 timeout=options.timeout_seconds, 

405 ordering_token=ordering_token) 

406 

407 def batch_all_reduce( 

408 self, 

409 input_tensor_packs: List[List[core.TensorLike]], 

410 options: Optional[collective_util.Options] = None) -> core.Tensor: 

411 """Batch all-reduce dense tensors. 

412 

413 This takes a list of batches of tensors. Using multiple batches has the

414 benefit that the all-reduce on one batch can start before all inputs are

415 ready.

416 

417 Args: 

418 input_tensor_packs: a list of lists of dense tensors. 

419 options: an optional tf.distribute.experimental.CommunicationOptions. If 

420 provided, it overrides the default options. 

421 

422 Returns: 

423 A flat list of reduced tensors. 

424 """ 

425 options = self._options.merge(options) 

426 outputs = [] 

427 for pack in input_tensor_packs: 

428 if context.executing_eagerly(): 

429 # We don't batch in eager mode as it sometimes degrades performance

430 # due to the extra concat/split ops.

431 for input_tensor in pack: 

432 outputs.append(self.all_reduce(input_tensor, None, options)) 

433 else: 

434 # TODO(b/169168846): inserts a parallel all_gather to verify packings 

435 # are the same on each replica. 

436 with ops.device(self._device): 

437 flat_tensors = [array_ops.reshape(t, [-1]) for t in pack] 

438 shapes = [array_ops.shape(t) for t in pack] 

439 if (options.implementation 

440 == collective_util.CommunicationImplementation.NCCL and outputs): 

441 control_input = outputs[-1] 

442 else: 

443 control_input = None 

444 reduced = self.all_reduce( 

445 array_ops.concat(flat_tensors, axis=0), control_input, options) 

446 num_elements = [math_ops.reduce_prod(s) for s in shapes] 

447 flat_outputs = array_ops.split(reduced, num_elements, axis=0) 

448 for shape, flat_output in zip(shapes, flat_outputs): 

449 outputs.append(array_ops.reshape(flat_output, shape)) 

450 

451 return outputs 
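
The non-eager branch above packs each batch into a single flat tensor before the collective and unpacks it afterwards. A standalone sketch of that pack/unpack round trip using public tf ops, with tf.identity standing in for the actual all-reduce:

    import tensorflow as tf

    pack = [tf.ones([2, 3]), tf.range(4, dtype=tf.float32)]
    flat_tensors = [tf.reshape(t, [-1]) for t in pack]
    shapes = [tf.shape(t) for t in pack]
    reduced = tf.identity(tf.concat(flat_tensors, axis=0))   # stand-in for all_reduce
    num_elements = [tf.reduce_prod(s) for s in shapes]
    flat_outputs = tf.split(reduced, num_elements, axis=0)
    outputs = [tf.reshape(o, s) for o, s in zip(flat_outputs, shapes)]
    # outputs[i] has the same shape and contents as pack[i].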

452 

453 def all_gather( 

454 self, 

455 input_tensor: core.TensorLike, 

456 axis: core.TensorLike, 

457 options: Optional[collective_util.Options] = None) -> core.Tensor: 

458 """All-gather a dense tensor. 

459 

460 This method must be called inside a tf.function. 

461 

462 Args: 

463 input_tensor: a dense tensor. It must have the same rank on all replicas, 

464 and dimensions other than `axis` need to be the same as well. 

465 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the 

466 range [0, rank(value)). 

467 options: an optional tf.distribute.experimental.CommunicationOptions. If 

468 provided, it overrides the default options. 

469 

470 Returns: 

471 The gathered Tensor. 

472 

473 Raises: 

474 RuntimeError: if called in eager mode. 

475 """ 

476 if context.executing_eagerly(): 

477 raise RuntimeError('all_gather is not supported in eager mode.') 

478 

479 with ops.device(self._device), \ 

480 ops.control_dependencies([array_ops.identity(input_tensor)]): 

481 # 1. Transpose 

482 # E.g. given an input_tensor with shape [2,2,5,1] and axis=3 to gather along,

483 # we use perm_pre=[3,0,1,2] to transpose it to [1,2,2,5], which brings the

484 # gather dimension to the front; afterwards we use perm_after=[1,2,3,0] to

485 # move it back into place.

486 perm_pre = array_ops.concat( 

487 ([axis], math_ops.range(axis), 

488 math_ops.range(axis + 1, array_ops.rank(input_tensor))), 

489 axis=0) 

490 input_tensor_t = array_ops.transpose(input_tensor, perm=perm_pre) 

491 # 2. Pad 

492 gathered_shape = self._all_gather( 

493 array_ops.expand_dims_v2(array_ops.shape_v2(input_tensor_t), axis=0), 

494 options) 

495 first_dims = gathered_shape[:, 0] 

496 full_axis_dim = math_ops.reduce_max(first_dims) 

497 padded_input_tensor = _pad_util(input_tensor_t, full_axis_dim) 

498 

499 # 3. Gather 

500 gather_padded_out_tensor = self._all_gather(padded_input_tensor, options) 

501 # 4. Unpad 

502 split_tensors = [] 

503 for i in range(self._group_size): 

504 start_pos = i * full_axis_dim 

505 split_tensors.append(gather_padded_out_tensor[start_pos:start_pos + 

506 first_dims[i]]) 

507 out_tensor_t = array_ops.concat(split_tensors, 0) 

508 

509 # 5. Transpose back 

510 perm_after = array_ops.concat( 

511 (math_ops.range(1, axis + 1), [0], 

512 math_ops.range(axis + 1, array_ops.rank(input_tensor_t))), 

513 axis=0) 

514 return array_ops.transpose(out_tensor_t, perm=perm_after) 
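
A quick worked example of the permutation arithmetic used above, in plain Python for a rank-4 input gathered along axis=3 (matching the shapes in the comment):

    axis, rank = 3, 4
    perm_pre = [axis] + list(range(axis)) + list(range(axis + 1, rank))
    # [3, 0, 1, 2]: moves the gather axis to the front, e.g. [2,2,5,1] -> [1,2,2,5]
    perm_after = list(range(1, axis + 1)) + [0] + list(range(axis + 1, rank))
    # [1, 2, 3, 0]: inverse permutation that moves the gathered axis back into place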

515 

516 def all_reduce_indexed_slices( 

517 self, 

518 input_slices: indexed_slices.IndexedSlices, 

519 options: Optional[collective_util.Options] = None 

520 ) -> indexed_slices.IndexedSlices: 

521 """All-reduce an IndexedSlices. 

522 

523 This method can be called outside tf.function. 

524 

525 Args: 

526 input_slices: an IndexedSlices. 

527 options: an optional tf.distribute.experimental.CommunicationOptions. If 

528 provided, it overrides the default options. 

529 

530 Returns: 

531 The reduced IndexedSlices. 

532 """ 

533 

534 # Current CollectiveAllGather implementations require the input IndexedSlices

535 # to have a consistent length across participants, so we handle the reduction

536 # of IndexedSlices as follows:

537 # 1. Gather the lengths of the IndexedSlices from all participants.

538 # 2. If the lengths are all the same, apply all_gather directly.

539 # 3. Otherwise, pad the IndexedSlices to the same length across all

540 # participants and then apply all_gather.

541 options = self._options.merge(options) 

542 with ops.device(self._device): 

543 

544 def all_gather_indexed_slices( 

545 all_gather_fn: Callable[ 

546 [core.TensorLike, Optional[collective_util.Options]], core.Tensor] 

547 ) -> indexed_slices.IndexedSlices: 

548 """Use all_gather_fn to aggregate `IndexedSlices`.""" 

549 all_values = all_gather_fn(input_slices.values, options) 

550 # Add control dependency to order the all-gather. 

551 if (options.implementation == 

552 collective_util.CommunicationImplementation.NCCL): 

553 control = [all_values] 

554 else: 

555 control = [] 

556 with ops.control_dependencies(control): 

557 all_indices = all_gather_fn(input_slices.indices, options) 

558 return indexed_slices.IndexedSlices( 

559 values=all_values, 

560 indices=all_indices, 

561 dense_shape=input_slices.dense_shape) 

562 

563 length = array_ops.shape(input_slices.indices) 

564 all_lengths = self._all_gather(length, options) 

565 

566 def all_gather_with_padding( 

567 input_tensor: core.TensorLike, 

568 options: Optional[collective_util.Options]) -> core.Tensor: 

569 """all_gather tensors of different sizes using padding.""" 

570 max_length = math_ops.reduce_max(all_lengths) 

571 padded_tensor = _pad_util(input_tensor, max_length) 

572 all_padded_tensors = self._all_gather(padded_tensor, options) 

573 split_tensors = [] 

574 for i in range(self._group_size): 

575 start_pos = i * max_length 

576 split_tensors.append(all_padded_tensors[start_pos:start_pos + 

577 all_lengths[i]]) 

578 return array_ops.concat(split_tensors, 0) 

579 

580 return cond.cond( 

581 math_ops.equal( 

582 math_ops.reduce_max(all_lengths), 

583 math_ops.reduce_min(all_lengths)), 

584 lambda: all_gather_indexed_slices(self._all_gather), 

585 lambda: all_gather_indexed_slices(all_gather_with_padding)) 

586 

587 

588def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n): 

589 """Aggregate tensors using `accumulation_fn` and IndexedSlices via concat.""" 

590 if any(isinstance(v, indexed_slices.IndexedSlices) for v in values): 

591 return backprop_util.AggregateIndexedSlicesGradients(values) 

592 else: 

593 return accumulation_fn(values) 
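
A small illustration of the two branches (assuming this module is importable under its usual path; the values are arbitrary):

    import tensorflow as tf
    from tensorflow.python.distribute import cross_device_utils

    dense = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])]
    print(cross_device_utils.aggregate_tensors_or_indexed_slices(dense))  # [4. 6.]

    sparse = [
        tf.IndexedSlices(values=tf.constant([[1.0]]), indices=tf.constant([0]),
                         dense_shape=tf.constant([3, 1])),
        tf.IndexedSlices(values=tf.constant([[2.0]]), indices=tf.constant([2]),
                         dense_shape=tf.constant([3, 1])),
    ]
    agg = cross_device_utils.aggregate_tensors_or_indexed_slices(sparse)
    print(type(agg))  # remains an IndexedSlices rather than being densified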

594 

595 

596def divide_by_n_tensors_or_indexed_slices(value, n): 

597 if isinstance(value, indexed_slices.IndexedSlices): 

598 value = backprop_util.FlattenNestedIndexedSlices(value) 

599 return indexed_slices.IndexedSlices(value.values / n, value.indices, 

600 value.dense_shape) 

601 else: 

602 return value / n 

603 

604 

605def copy_tensor_or_indexed_slices_to_device(value, device): 

606 """Copies a tensor or IndexedSlices to a device.""" 

607 with ops.device(device): 

608 if isinstance(value, indexed_slices.IndexedSlices): 

609 copied_values = array_ops.identity(value.values) 

610 copied_indices = array_ops.identity(value.indices) 

611 if value.dense_shape is not None: 

612 copied_shape = array_ops.identity(value.dense_shape) 

613 else: 

614 copied_shape = None 

615 result = indexed_slices.IndexedSlices(copied_values, copied_indices, 

616 copied_shape) 

617 else: 

618 result = array_ops.identity(value) 

619 return result 

620 

621 

622def is_indexed_slices(value): 

623 if isinstance(value, indexed_slices.IndexedSlices): 

624 return True 

625 if isinstance(value, value_lib.DistributedValues): 

626 return all( 

627 isinstance(v, indexed_slices.IndexedSlices) for v in value.values) 

628 return False 

629 

630 

631def split_by_sparsity(values): 

632 """Split values into dense and sparse values. 

633 

634 Args: 

635 values: a list of tensors or `PerReplica`s. 

636 

637 Returns: 

638 Four lists: 

639 a list of dense values, a list of their indices in `values`, a list of

640 sparse values, and a list of their indices in `values`.

641 """ 

642 dense_values = [] 

643 dense_indices = [] 

644 sparse_values = [] 

645 sparse_indices = [] 

646 for i, v in enumerate(values): 

647 if is_indexed_slices(v): 

648 sparse_values.append(v) 

649 sparse_indices.append(i) 

650 else: 

651 dense_values.append(v) 

652 dense_indices.append(i) 

653 return dense_values, dense_indices, sparse_values, sparse_indices 

654 

655 

656def stitch_values(values_and_indices_list): 

657 """Stitch values together according to their indices. 

658 

659 Args: 

660 values_and_indices_list: a list of (values, indices) tuples, where each

661 indices entry gives the positions of the corresponding values in the returned list.

662 

663 Returns: 

664 a stitched list of values. 

665 """ 

666 length = 0 

667 for values_and_indices in values_and_indices_list: 

668 length += len(values_and_indices[0]) 

669 

670 result = [None] * length 

671 for values_and_indices in values_and_indices_list: 

672 if values_and_indices and values_and_indices[0]: 

673 for v, i in zip(*values_and_indices): 

674 assert result[i] is None 

675 result[i] = v 

676 return result 
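
A round-trip sketch of split_by_sparsity followed by stitch_values (assuming the module is importable under its usual path):

    import tensorflow as tf
    from tensorflow.python.distribute import cross_device_utils

    values = [
        tf.constant([1.0, 2.0]),                               # dense
        tf.IndexedSlices(values=tf.constant([[3.0]]),
                         indices=tf.constant([1]),
                         dense_shape=tf.constant([4, 1])),     # sparse
        tf.constant([4.0]),                                    # dense
    ]
    dense_v, dense_i, sparse_v, sparse_i = cross_device_utils.split_by_sparsity(values)
    # dense_i == [0, 2], sparse_i == [1]
    restored = cross_device_utils.stitch_values(
        [(dense_v, dense_i), (sparse_v, sparse_i)])
    # `restored` preserves the original order: dense, sparse, dense.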

677 

678 

679def group_by_size(input_tensors, bytes_per_pack): 

680 """Groups `input_tensors` into chunks of `bytes_per_pack`. 

681 

682 The method preserves the original order of `input_tensors`. The grouping is

683 best effort; each pack could have more or fewer bytes than `bytes_per_pack`.

684 It only groups values with known shapes.

685 

686 Args: 

687 input_tensors: a list of Tensor. 

688 bytes_per_pack: an integer. 

689 

690 Returns: 

691 A list of packs of Tensors. All values are grouped into a single pack if

692 `bytes_per_pack` is zero or if any value has an unknown shape.

693 """ 

694 

695 if bytes_per_pack == 0: 

696 return [input_tensors] 

697 packs = [] 

698 last_pack_size = 0 

699 for value in input_tensors: 

700 num_elements = value.shape.num_elements() 

701 if num_elements is None: 

702 # Can't pack values with unknown shape. 

703 logging.warning( 

704 'not packing values due to the unknown or inconsistent shape of %s', 

705 value) 

706 return [input_tensors] 

707 size = num_elements * value.dtype.size 

708 # Try to keep each pack as close to bytes_per_pack as possible, while each

709 # pack is at least bytes_per_pack large, i.e. we err on the side of having

710 # fewer but larger packs.

711 if not packs or last_pack_size > bytes_per_pack: 

712 packs.append([]) 

713 last_pack_size = 0 

714 packs[-1].append(value) 

715 last_pack_size += size 

716 return packs 
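
A small usage sketch of group_by_size (assuming the module is importable under its usual path); with float32 tensors each element is 4 bytes, so the sizes below are 1024, 1024, 4096 and 32 bytes:

    import tensorflow as tf
    from tensorflow.python.distribute import cross_device_utils

    tensors = [tf.zeros([256]), tf.zeros([256]), tf.zeros([1024]), tf.zeros([8])]
    packs = cross_device_utils.group_by_size(tensors, bytes_per_pack=1024)
    print([len(p) for p in packs])  # [2, 1, 1]
    # A new pack is only started once the current one already exceeds
    # bytes_per_pack, so the grouping errs toward fewer, larger packs.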

717 

718 

719def _pad_util(input_tensor, full_axis_dim): 

720 """Pad the `input_tensor`'s first dimension to be `full_axis_dim`.""" 

721 missing_axis_dim = full_axis_dim - array_ops.shape_v2(input_tensor)[0] 

722 tensor_rank = array_ops.rank(input_tensor) 

723 paddings_axis = [[0, missing_axis_dim]] 

724 paddings = array_ops.concat([ 

725 paddings_axis, 

726 array_ops.zeros(shape=(tensor_rank - 1, 2), dtype=dtypes.int32) 

727 ], 

728 axis=0) 

729 padded_input_tensor = array_ops.pad(input_tensor, paddings) 

730 return padded_input_tensor
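
A quick example of _pad_util on a rank-2 tensor (assuming the module is importable under its usual path):

    import tensorflow as tf
    from tensorflow.python.distribute import cross_device_utils

    x = tf.reshape(tf.range(6, dtype=tf.float32), [3, 2])
    padded = cross_device_utils._pad_util(x, full_axis_dim=5)
    print(padded.shape)  # (5, 2); the two extra rows are zero padding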