Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/data_flow_ops.py: 24%

621 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14#============================================================================== 

15"""Data Flow Operations.""" 

16# pylint: disable=g-bad-name 

17import functools 

18import hashlib 

19import threading 

20 

21from tensorflow.python.eager import context 

22from tensorflow.python.framework import dtypes as _dtypes 

23from tensorflow.python.framework import indexed_slices 

24from tensorflow.python.framework import ops 

25from tensorflow.python.framework import random_seed 

26from tensorflow.python.framework import tensor_shape 

27from tensorflow.python.framework import tensor_util 

28from tensorflow.python.lib.io import python_io 

29from tensorflow.python.ops import array_ops 

30from tensorflow.python.ops import array_ops_stack 

31from tensorflow.python.ops import control_flow_ops 

32from tensorflow.python.ops import gen_data_flow_ops 

33from tensorflow.python.ops import math_ops 

34from tensorflow.python.ops import resource_variable_ops 

35# go/tf-wildcard-import 

36# pylint: disable=wildcard-import 

37from tensorflow.python.ops.gen_data_flow_ops import * 

38from tensorflow.python.util import deprecation 

39from tensorflow.python.util.compat import collections_abc 

40from tensorflow.python.util.tf_export import tf_export 

41 

42# pylint: enable=wildcard-import 

43 

44 

45def _as_type_list(dtypes): 

46 """Convert dtypes to a list of types.""" 

47 assert dtypes is not None 

48 if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)): 

49 # We have a single type. 

50 return [dtypes] 

51 else: 

52 # We have a list or tuple of types. 

53 return list(dtypes) 

54 

55 

56def _as_shape_list(shapes, 

57 dtypes, 

58 unknown_dim_allowed=False, 

59 unknown_rank_allowed=False): 

60 """Convert shapes to a list of tuples of int (or None).""" 

61 del dtypes 

62 if unknown_dim_allowed: 

63 if (not isinstance(shapes, collections_abc.Sequence) or not shapes or 

64 any(shape is None or isinstance(shape, int) for shape in shapes)): 

65 raise ValueError( 

66 "When providing partial shapes, a list of shapes must be provided.") 

67 if shapes is None: 

68 return None 

69 if isinstance(shapes, tensor_shape.TensorShape): 

70 shapes = [shapes] 

71 if not isinstance(shapes, (tuple, list)): 

72 raise TypeError( 

73 "Shapes must be a TensorShape or a list or tuple of TensorShapes, " 

74 f"got {type(shapes)} instead.") 

75 if all(shape is None or isinstance(shape, int) for shape in shapes): 

76 # We have a single shape. 

77 shapes = [shapes] 

78 shapes = [tensor_shape.as_shape(shape) for shape in shapes] 

79 if not unknown_dim_allowed: 

80 if any(not shape.is_fully_defined() for shape in shapes): 

81 raise ValueError(f"All shapes must be fully defined: {shapes}") 

82 if not unknown_rank_allowed: 

83 if any(shape.dims is None for shape in shapes): 

84 raise ValueError(f"All shapes must have a defined rank: {shapes}") 

85 

86 return shapes 

87 

88 

89def _as_name_list(names, dtypes): 

90 if names is None: 

91 return None 

92 if not isinstance(names, (list, tuple)): 

93 names = [names] 

94 if len(names) != len(dtypes): 

95 raise ValueError("List of names must have the same length as the list " 

96 f"of dtypes, received len(names)={len(names)}," 

97 f"len(dtypes)={len(dtypes)}") 

98 return list(names) 

99 

100 

101def _shape_common(s1, s2): 

102 """The greatest lower bound (ordered by specificity) TensorShape.""" 

103 s1 = tensor_shape.TensorShape(s1) 

104 s2 = tensor_shape.TensorShape(s2) 

105 if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims: 

106 return tensor_shape.unknown_shape() 

107 d = [ 

108 d1 if d1 is not None and d1 == d2 else None 

109 for (d1, d2) in zip(s1.as_list(), s2.as_list()) 

110 ] 

111 return tensor_shape.TensorShape(d) 

112 

113 

114# pylint: disable=protected-access 

115@tf_export("queue.QueueBase", 

116 v1=["queue.QueueBase", "io.QueueBase", "QueueBase"]) 

117@deprecation.deprecated_endpoints(["io.QueueBase", "QueueBase"]) 

118class QueueBase: 

119 """Base class for queue implementations. 

120 

121 A queue is a TensorFlow data structure that stores tensors across 

122 multiple steps, and exposes operations that enqueue and dequeue 

123 tensors. 

124 

125 Each queue element is a tuple of one or more tensors, where each 

126 tuple component has a static dtype, and may have a static shape. The 

127 queue implementations support versions of enqueue and dequeue that 

128 handle single elements, versions that support enqueuing and 

129 dequeuing a batch of elements at once. 

130 

131 See `tf.queue.FIFOQueue` and 

132 `tf.queue.RandomShuffleQueue` for concrete 

133 implementations of this class, and instructions on how to create 

134 them. 

135 """ 

136 

137 def __init__(self, dtypes, shapes, names, queue_ref): 

138 """Constructs a queue object from a queue reference. 

139 

140 The two optional lists, `shapes` and `names`, must be of the same length 

141 as `dtypes` if provided. The values at a given index `i` indicate the 

142 shape and name to use for the corresponding queue component in `dtypes`. 

143 

144 Args: 

145 dtypes: A list of types. The length of dtypes must equal the number 

146 of tensors in each element. 

147 shapes: Constraints on the shapes of tensors in an element: 

148 A list of shape tuples or None. This list is the same length 

149 as dtypes. If the shape of any tensors in the element are constrained, 

150 all must be; shapes can be None if the shapes should not be constrained. 

151 names: Optional list of names. If provided, the `enqueue()` and 

152 `dequeue()` methods will use dictionaries with these names as keys. 

153 Must be None or a list or tuple of the same length as `dtypes`. 

154 queue_ref: The queue reference, i.e. the output of the queue op. 

155 

156 Raises: 

157 ValueError: If one of the arguments is invalid. 

158 """ 

159 self._dtypes = dtypes 

160 if shapes is not None: 

161 if len(shapes) != len(dtypes): 

162 raise ValueError("Queue shapes must have the same length as dtypes, " 

163 f"received len(shapes)={len(shapes)}, " 

164 f"len(dtypes)={len(dtypes)}") 

165 self._shapes = [tensor_shape.TensorShape(s) for s in shapes] 

166 else: 

167 self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes] 

168 if names is not None: 

169 if len(names) != len(dtypes): 

170 raise ValueError("Queue names must have the same length as dtypes," 

171 f"received len(names)={len(names)}," 

172 f"len {len(dtypes)}") 

173 self._names = names 

174 else: 

175 self._names = None 

176 self._queue_ref = queue_ref 

177 if isinstance(queue_ref, ops.EagerTensor): 

178 if context.context().scope_name: 

179 self._name = context.context().scope_name 

180 else: 

181 self._name = "Empty" 

182 self._resource_deleter = resource_variable_ops.EagerResourceDeleter( 

183 queue_ref, None) 

184 else: 

185 self._name = self._queue_ref.op.name.split("/")[-1] 

186 

187 @staticmethod 

188 def from_list(index, queues): 

189 """Create a queue using the queue reference from `queues[index]`. 

190 

191 Args: 

192 index: An integer scalar tensor that determines the input that gets 

193 selected. 

194 queues: A list of `QueueBase` objects. 

195 

196 Returns: 

197 A `QueueBase` object. 

198 

199 Raises: 

200 TypeError: When `queues` is not a list of `QueueBase` objects, 

201 or when the data types of `queues` are not all the same. 

202 """ 

203 if ((not queues) or (not isinstance(queues, list)) or 

204 (not all(isinstance(x, QueueBase) for x in queues))): 

205 raise TypeError("A list of queues expected") 

206 

207 dtypes = queues[0].dtypes 

208 if not all(dtypes == q.dtypes for q in queues[1:]): 

209 raise TypeError("Queues do not have matching component dtypes.") 

210 

211 names = queues[0].names 

212 if not all(names == q.names for q in queues[1:]): 

213 raise TypeError("Queues do not have matching component names.") 

214 

215 queue_shapes = [q.shapes for q in queues] 

216 reduced_shapes = [ 

217 functools.reduce(_shape_common, s) for s in zip(*queue_shapes) 

218 ] 

219 

220 queue_refs = array_ops_stack.stack([x.queue_ref for x in queues]) 

221 selected_queue = array_ops.gather(queue_refs, index) 

222 return QueueBase( 

223 dtypes=dtypes, 

224 shapes=reduced_shapes, 

225 names=names, 

226 queue_ref=selected_queue) 

227 

228 @property 

229 def queue_ref(self): 

230 """The underlying queue reference.""" 

231 return self._queue_ref 

232 

233 @property 

234 def name(self): 

235 """The name of the underlying queue.""" 

236 if context.executing_eagerly(): 

237 return self._name 

238 return self._queue_ref.op.name 

239 

240 @property 

241 def dtypes(self): 

242 """The list of dtypes for each component of a queue element.""" 

243 return self._dtypes 

244 

245 @property 

246 def shapes(self): 

247 """The list of shapes for each component of a queue element.""" 

248 return self._shapes 

249 

250 @property 

251 def names(self): 

252 """The list of names for each component of a queue element.""" 

253 return self._names 

254 

255 def _check_enqueue_dtypes(self, vals): 

256 """Validate and convert `vals` to a list of `Tensor`s. 

257 

258 The `vals` argument can be a Tensor, a list or tuple of tensors, or a 

259 dictionary with tensor values. 

260 

261 If it is a dictionary, the queue must have been constructed with a 

262 `names` attribute and the dictionary keys must match the queue names. 

263 If the queue was constructed with a `names` attribute, `vals` must 

264 be a dictionary. 

265 

266 Args: 

267 vals: A tensor, a list or tuple of tensors, or a dictionary.. 

268 

269 Returns: 

270 A list of `Tensor` objects. 

271 

272 Raises: 

273 ValueError: If `vals` is invalid. 

274 """ 

275 if isinstance(vals, dict): 

276 if not self._names: 

277 raise ValueError("Queue must have names to enqueue a dictionary") 

278 if sorted(self._names, key=str) != sorted(vals.keys(), key=str): 

279 raise ValueError("Keys in dictionary to enqueue do not match " 

280 f"names of Queue. Dictionary: {sorted(vals.keys())}," 

281 f"Queue: {sorted(self._names)}") 

282 # The order of values in `self._names` indicates the order in which the 

283 # tensors in the dictionary `vals` must be listed. 

284 vals = [vals[k] for k in self._names] 

285 else: 

286 if self._names: 

287 raise ValueError("You must enqueue a dictionary in a Queue with names") 

288 if not isinstance(vals, (list, tuple)): 

289 vals = [vals] 

290 

291 tensors = [] 

292 for i, (val, dtype) in enumerate(zip(vals, self._dtypes)): 

293 tensors.append( 

294 ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i)) 

295 

296 return tensors 

297 

298 def _scope_vals(self, vals): 

299 """Return a list of values to pass to `name_scope()`. 

300 

301 Args: 

302 vals: A tensor, a list or tuple of tensors, or a dictionary. 

303 

304 Returns: 

305 The values in vals as a list. 

306 """ 

307 if isinstance(vals, (list, tuple)): 

308 return vals 

309 elif isinstance(vals, dict): 

310 return vals.values() 

311 else: 

312 return [vals] 

313 

314 def enqueue(self, vals, name=None): 

315 """Enqueues one element to this queue. 

316 

317 If the queue is full when this operation executes, it will block 

318 until the element has been enqueued. 

319 

320 At runtime, this operation may raise an error if the queue is 

321 `tf.QueueBase.close` before or during its execution. If the 

322 queue is closed before this operation runs, 

323 `tf.errors.CancelledError` will be raised. If this operation is 

324 blocked, and either (i) the queue is closed by a close operation 

325 with `cancel_pending_enqueues=True`, or (ii) the session is 

326 `tf.Session.close`, 

327 `tf.errors.CancelledError` will be raised. 

328 

329 Args: 

330 vals: A tensor, a list or tuple of tensors, or a dictionary containing 

331 the values to enqueue. 

332 name: A name for the operation (optional). 

333 

334 Returns: 

335 The operation that enqueues a new tuple of tensors to the queue. 

336 """ 

337 with ops.name_scope(name, "%s_enqueue" % self._name, 

338 self._scope_vals(vals)) as scope: 

339 vals = self._check_enqueue_dtypes(vals) 

340 

341 # NOTE(mrry): Not using a shape function because we need access to 

342 # the `QueueBase` object. 

343 for val, shape in zip(vals, self._shapes): 

344 val.get_shape().assert_is_compatible_with(shape) 

345 

346 if self._queue_ref.dtype == _dtypes.resource: 

347 return gen_data_flow_ops.queue_enqueue_v2( 

348 self._queue_ref, vals, name=scope) 

349 else: 

350 return gen_data_flow_ops.queue_enqueue( 

351 self._queue_ref, vals, name=scope) 

352 

353 def enqueue_many(self, vals, name=None): 

354 """Enqueues zero or more elements to this queue. 

355 

356 This operation slices each component tensor along the 0th dimension to 

357 make multiple queue elements. All of the tensors in `vals` must have the 

358 same size in the 0th dimension. 

359 

360 If the queue is full when this operation executes, it will block 

361 until all of the elements have been enqueued. 

362 

363 At runtime, this operation may raise an error if the queue is 

364 `tf.QueueBase.close` before or during its execution. If the 

365 queue is closed before this operation runs, 

366 `tf.errors.CancelledError` will be raised. If this operation is 

367 blocked, and either (i) the queue is closed by a close operation 

368 with `cancel_pending_enqueues=True`, or (ii) the session is 

369 `tf.Session.close`, 

370 `tf.errors.CancelledError` will be raised. 

371 

372 Args: 

373 vals: A tensor, a list or tuple of tensors, or a dictionary 

374 from which the queue elements are taken. 

375 name: A name for the operation (optional). 

376 

377 Returns: 

378 The operation that enqueues a batch of tuples of tensors to the queue. 

379 """ 

380 with ops.name_scope(name, "%s_EnqueueMany" % self._name, 

381 self._scope_vals(vals)) as scope: 

382 vals = self._check_enqueue_dtypes(vals) 

383 

384 # NOTE(mrry): Not using a shape function because we need access to 

385 # the `QueueBase` object. 

386 # NOTE(fchollet): the code that follow is verbose because it needs to be 

387 # compatible with both TF v1 TensorShape behavior and TF v2 behavior. 

388 batch_dim = tensor_shape.dimension_value( 

389 vals[0].get_shape().with_rank_at_least(1)[0]) 

390 batch_dim = tensor_shape.Dimension(batch_dim) 

391 for val, shape in zip(vals, self._shapes): 

392 val_batch_dim = tensor_shape.dimension_value( 

393 val.get_shape().with_rank_at_least(1)[0]) 

394 val_batch_dim = tensor_shape.Dimension(val_batch_dim) 

395 batch_dim = batch_dim.merge_with(val_batch_dim) 

396 val.get_shape()[1:].assert_is_compatible_with(shape) 

397 

398 return gen_data_flow_ops.queue_enqueue_many_v2( 

399 self._queue_ref, vals, name=scope) 

400 

401 def _dequeue_return_value(self, tensors): 

402 """Return the value to return from a dequeue op. 

403 

404 If the queue has names, return a dictionary with the 

405 names as keys. Otherwise return either a single tensor 

406 or a list of tensors depending on the length of `tensors`. 

407 

408 Args: 

409 tensors: List of tensors from the dequeue op. 

410 

411 Returns: 

412 A single tensor, a list of tensors, or a dictionary 

413 of tensors. 

414 """ 

415 if self._names: 

416 # The returned values in `tensors` are in the same order as 

417 # the names in `self._names`. 

418 return {n: tensors[i] for i, n in enumerate(self._names)} 

419 elif len(tensors) == 1: 

420 return tensors[0] 

421 else: 

422 return tensors 

423 

424 def dequeue(self, name=None): 

425 """Dequeues one element from this queue. 

426 

427 If the queue is empty when this operation executes, it will block 

428 until there is an element to dequeue. 

429 

430 At runtime, this operation may raise an error if the queue is 

431 `tf.QueueBase.close` before or during its execution. If the 

432 queue is closed, the queue is empty, and there are no pending 

433 enqueue operations that can fulfill this request, 

434 `tf.errors.OutOfRangeError` will be raised. If the session is 

435 `tf.Session.close`, 

436 `tf.errors.CancelledError` will be raised. 

437 

438 Args: 

439 name: A name for the operation (optional). 

440 

441 Returns: 

442 The tuple of tensors that was dequeued. 

443 """ 

444 if name is None: 

445 name = "%s_Dequeue" % self._name 

446 if self._queue_ref.dtype == _dtypes.resource: 

447 ret = gen_data_flow_ops.queue_dequeue_v2( 

448 self._queue_ref, self._dtypes, name=name) 

449 else: 

450 ret = gen_data_flow_ops.queue_dequeue( 

451 self._queue_ref, self._dtypes, name=name) 

452 

453 # NOTE(mrry): Not using a shape function because we need access to 

454 # the `QueueBase` object. 

455 if not context.executing_eagerly(): 

456 op = ret[0].op 

457 for output, shape in zip(op.values(), self._shapes): 

458 output.set_shape(shape) 

459 

460 return self._dequeue_return_value(ret) 

461 

462 def dequeue_many(self, n, name=None): 

463 """Dequeues and concatenates `n` elements from this queue. 

464 

465 This operation concatenates queue-element component tensors along 

466 the 0th dimension to make a single component tensor. All of the 

467 components in the dequeued tuple will have size `n` in the 0th dimension. 

468 

469 If the queue is closed and there are less than `n` elements left, then an 

470 `OutOfRange` exception is raised. 

471 

472 At runtime, this operation may raise an error if the queue is 

473 `tf.QueueBase.close` before or during its execution. If the 

474 queue is closed, the queue contains fewer than `n` elements, and 

475 there are no pending enqueue operations that can fulfill this 

476 request, `tf.errors.OutOfRangeError` will be raised. If the 

477 session is `tf.Session.close`, 

478 `tf.errors.CancelledError` will be raised. 

479 

480 Args: 

481 n: A scalar `Tensor` containing the number of elements to dequeue. 

482 name: A name for the operation (optional). 

483 

484 Returns: 

485 The list of concatenated tensors that was dequeued. 

486 """ 

487 if name is None: 

488 name = "%s_DequeueMany" % self._name 

489 

490 ret = gen_data_flow_ops.queue_dequeue_many_v2( 

491 self._queue_ref, n=n, component_types=self._dtypes, name=name) 

492 

493 # NOTE(mrry): Not using a shape function because we need access to 

494 # the Queue object. 

495 if not context.executing_eagerly(): 

496 op = ret[0].op 

497 batch_dim = tensor_shape.Dimension( 

498 tensor_util.constant_value(op.inputs[1])) 

499 for output, shape in zip(op.values(), self._shapes): 

500 output.set_shape( 

501 tensor_shape.TensorShape([batch_dim]).concatenate(shape)) 

502 

503 return self._dequeue_return_value(ret) 

504 

505 def dequeue_up_to(self, n, name=None): 

506 """Dequeues and concatenates `n` elements from this queue. 

507 

508 **Note** This operation is not supported by all queues. If a queue does not 

509 support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised. 

510 

511 This operation concatenates queue-element component tensors along 

512 the 0th dimension to make a single component tensor. If the queue 

513 has not been closed, all of the components in the dequeued tuple 

514 will have size `n` in the 0th dimension. 

515 

516 If the queue is closed and there are more than `0` but fewer than 

517 `n` elements remaining, then instead of raising a 

518 `tf.errors.OutOfRangeError` like `tf.QueueBase.dequeue_many`, 

519 less than `n` elements are returned immediately. If the queue is 

520 closed and there are `0` elements left in the queue, then a 

521 `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`. 

522 Otherwise the behavior is identical to `dequeue_many`. 

523 

524 Args: 

525 n: A scalar `Tensor` containing the number of elements to dequeue. 

526 name: A name for the operation (optional). 

527 

528 Returns: 

529 The tuple of concatenated tensors that was dequeued. 

530 """ 

531 if name is None: 

532 name = "%s_DequeueUpTo" % self._name 

533 

534 ret = gen_data_flow_ops.queue_dequeue_up_to_v2( 

535 self._queue_ref, n=n, component_types=self._dtypes, name=name) 

536 

537 # NOTE(mrry): Not using a shape function because we need access to 

538 # the Queue object. 

539 if not context.executing_eagerly(): 

540 op = ret[0].op 

541 for output, shape in zip(op.values(), self._shapes): 

542 output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape)) 

543 

544 return self._dequeue_return_value(ret) 

545 

546 def close(self, cancel_pending_enqueues=False, name=None): 

547 """Closes this queue. 

548 

549 This operation signals that no more elements will be enqueued in 

550 the given queue. Subsequent `enqueue` and `enqueue_many` 

551 operations will fail. Subsequent `dequeue` and `dequeue_many` 

552 operations will continue to succeed if sufficient elements remain 

553 in the queue. Subsequently dequeue and dequeue_many operations 

554 that would otherwise block waiting for more elements (if close 

555 hadn't been called) will now fail immediately. 

556 

557 If `cancel_pending_enqueues` is `True`, all pending requests will also 

558 be canceled. 

559 

560 Args: 

561 cancel_pending_enqueues: (Optional.) A boolean, defaulting to 

562 `False` (described above). 

563 name: A name for the operation (optional). 

564 

565 Returns: 

566 The operation that closes the queue. 

567 """ 

568 if name is None: 

569 name = "%s_Close" % self._name 

570 if self._queue_ref.dtype == _dtypes.resource: 

571 return gen_data_flow_ops.queue_close_v2( 

572 self._queue_ref, 

573 cancel_pending_enqueues=cancel_pending_enqueues, 

574 name=name) 

575 else: 

576 return gen_data_flow_ops.queue_close( 

577 self._queue_ref, 

578 cancel_pending_enqueues=cancel_pending_enqueues, 

579 name=name) 

580 

581 def is_closed(self, name=None): 

582 """Returns true if queue is closed. 

583 

584 This operation returns true if the queue is closed and false if the queue 

585 is open. 

586 

587 Args: 

588 name: A name for the operation (optional). 

589 

590 Returns: 

591 True if the queue is closed and false if the queue is open. 

592 """ 

593 if name is None: 

594 name = "%s_Is_Closed" % self._name 

595 if self._queue_ref.dtype == _dtypes.resource: 

596 return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name) 

597 else: 

598 return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name) 

599 

600 def size(self, name=None): 

601 """Compute the number of elements in this queue. 

602 

603 Args: 

604 name: A name for the operation (optional). 

605 

606 Returns: 

607 A scalar tensor containing the number of elements in this queue. 

608 """ 

609 if name is None: 

610 name = "%s_Size" % self._name 

611 if self._queue_ref.dtype == _dtypes.resource: 

612 return gen_data_flow_ops.queue_size_v2(self._queue_ref, name=name) 

613 else: 

614 return gen_data_flow_ops.queue_size(self._queue_ref, name=name) 

615 

616def _shared_name(shared_name): 

617 if context.executing_eagerly(): 

618 return str(ops.uid()) 

619 return shared_name 

620 

621 

622@tf_export( 

623 "queue.RandomShuffleQueue", 

624 v1=["queue.RandomShuffleQueue", 

625 "io.RandomShuffleQueue", "RandomShuffleQueue"]) 

626@deprecation.deprecated_endpoints( 

627 ["io.RandomShuffleQueue", "RandomShuffleQueue"]) 

628class RandomShuffleQueue(QueueBase): 

629 """A queue implementation that dequeues elements in a random order. 

630 

631 See `tf.queue.QueueBase` for a description of the methods on 

632 this class. 

633 """ 

634 

635 def __init__(self, 

636 capacity, 

637 min_after_dequeue, 

638 dtypes, 

639 shapes=None, 

640 names=None, 

641 seed=None, 

642 shared_name=None, 

643 name="random_shuffle_queue"): 

644 """Create a queue that dequeues elements in a random order. 

645 

646 A `RandomShuffleQueue` has bounded capacity; supports multiple 

647 concurrent producers and consumers; and provides exactly-once 

648 delivery. 

649 

650 A `RandomShuffleQueue` holds a list of up to `capacity` 

651 elements. Each element is a fixed-length tuple of tensors whose 

652 dtypes are described by `dtypes`, and whose shapes are optionally 

653 described by the `shapes` argument. 

654 

655 If the `shapes` argument is specified, each component of a queue 

656 element must have the respective fixed shape. If it is 

657 unspecified, different queue elements may have different shapes, 

658 but the use of `dequeue_many` is disallowed. 

659 

660 The `min_after_dequeue` argument allows the caller to specify a 

661 minimum number of elements that will remain in the queue after a 

662 `dequeue` or `dequeue_many` operation completes, to ensure a 

663 minimum level of mixing of elements. This invariant is maintained 

664 by blocking those operations until sufficient elements have been 

665 enqueued. The `min_after_dequeue` argument is ignored after the 

666 queue has been closed. 

667 

668 Args: 

669 capacity: An integer. The upper bound on the number of elements 

670 that may be stored in this queue. 

671 min_after_dequeue: An integer (described above). 

672 dtypes: A list of `DType` objects. The length of `dtypes` must equal 

673 the number of tensors in each queue element. 

674 shapes: (Optional.) A list of fully-defined `TensorShape` objects 

675 with the same length as `dtypes`, or `None`. 

676 names: (Optional.) A list of string naming the components in the queue 

677 with the same length as `dtypes`, or `None`. If specified the dequeue 

678 methods return a dictionary with the names as keys. 

679 seed: A Python integer. Used to create a random seed. See 

680 `tf.compat.v1.set_random_seed` 

681 for behavior. 

682 shared_name: (Optional.) If non-empty, this queue will be shared under 

683 the given name across multiple sessions. 

684 name: Optional name for the queue operation. 

685 """ 

686 dtypes = _as_type_list(dtypes) 

687 shapes = _as_shape_list(shapes, dtypes) 

688 names = _as_name_list(names, dtypes) 

689 seed1, seed2 = random_seed.get_seed(seed) 

690 if seed1 is None and seed2 is None: 

691 seed1, seed2 = 0, 0 

692 elif seed is None and shared_name is not None: 

693 # This means that graph seed is provided but op seed is not provided. 

694 # If shared_name is also provided, make seed2 depend only on the graph 

695 # seed and shared_name. (seed2 from get_seed() is generally dependent on 

696 # the id of the last op created.) 

697 string = (str(seed1) + shared_name).encode("utf-8") 

698 seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF 

699 queue_ref = gen_data_flow_ops.random_shuffle_queue_v2( 

700 component_types=dtypes, 

701 shapes=shapes, 

702 capacity=capacity, 

703 min_after_dequeue=min_after_dequeue, 

704 seed=seed1, 

705 seed2=seed2, 

706 shared_name=_shared_name(shared_name), 

707 name=name) 

708 

709 super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref) 

710 

711 

712@tf_export("queue.FIFOQueue", v1=["queue.FIFOQueue", "FIFOQueue"]) 

713@deprecation.deprecated_endpoints("FIFOQueue") 

714class FIFOQueue(QueueBase): 

715 """A queue implementation that dequeues elements in first-in first-out order. 

716 

717 See `tf.queue.QueueBase` for a description of the methods on 

718 this class. 

719 """ 

720 

721 def __init__(self, 

722 capacity, 

723 dtypes, 

724 shapes=None, 

725 names=None, 

726 shared_name=None, 

727 name="fifo_queue"): 

728 """Creates a queue that dequeues elements in a first-in first-out order. 

729 

730 A `FIFOQueue` has bounded capacity; supports multiple concurrent 

731 producers and consumers; and provides exactly-once delivery. 

732 

733 A `FIFOQueue` holds a list of up to `capacity` elements. Each 

734 element is a fixed-length tuple of tensors whose dtypes are 

735 described by `dtypes`, and whose shapes are optionally described 

736 by the `shapes` argument. 

737 

738 If the `shapes` argument is specified, each component of a queue 

739 element must have the respective fixed shape. If it is 

740 unspecified, different queue elements may have different shapes, 

741 but the use of `dequeue_many` is disallowed. 

742 

743 Args: 

744 capacity: An integer. The upper bound on the number of elements 

745 that may be stored in this queue. 

746 dtypes: A list of `DType` objects. The length of `dtypes` must equal 

747 the number of tensors in each queue element. 

748 shapes: (Optional.) A list of fully-defined `TensorShape` objects 

749 with the same length as `dtypes`, or `None`. 

750 names: (Optional.) A list of string naming the components in the queue 

751 with the same length as `dtypes`, or `None`. If specified the dequeue 

752 methods return a dictionary with the names as keys. 

753 shared_name: (Optional.) If non-empty, this queue will be shared under 

754 the given name across multiple sessions. 

755 name: Optional name for the queue operation. 

756 """ 

757 dtypes = _as_type_list(dtypes) 

758 shapes = _as_shape_list(shapes, dtypes) 

759 names = _as_name_list(names, dtypes) 

760 with ops.init_scope(), ops.device("CPU"): 

761 queue_ref = gen_data_flow_ops.fifo_queue_v2( 

762 component_types=dtypes, 

763 shapes=shapes, 

764 capacity=capacity, 

765 shared_name=_shared_name(shared_name), 

766 name=name) 

767 

768 super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) 

769 

770 

771# TODO(allenl): If GPU-compatible queues turn out to be useful, we should 

772# implement GPU kernels for EnqueueMany and DequeueMany so we can make the 

773# public FIFOQueue GPU-compatible and remove this internal version. 

774class GPUCompatibleFIFOQueue(QueueBase): 

775 """A queue implementation that dequeues elements in first-in first-out order. 

776 

777 GPUCompatibleFIFOQueue is like FIFOQueue, but the queue resource may be placed 

778 either on a CPU or on a GPU. It is not cross-device: enqueues and dequeues 

779 will be colocated with the queue resource. GPUCompatibleFIFOQueue only 

780 supports enqueue and dequeue at the moment, not enqueue_many or dequeue_many. 

781 

782 See `tf.queue.QueueBase` for a description of the methods on this class. 

783 """ 

784 

785 def __init__(self, 

786 capacity, 

787 dtypes, 

788 shapes=None, 

789 names=None, 

790 shared_name=None, 

791 name="fifo_queue"): 

792 """Creates a queue that dequeues elements in a first-in first-out order. 

793 

794 A `FIFOQueue` has bounded capacity; supports multiple concurrent 

795 producers and consumers; and provides exactly-once delivery. 

796 

797 A `FIFOQueue` holds a list of up to `capacity` elements. Each 

798 element is a fixed-length tuple of tensors whose dtypes are 

799 described by `dtypes`, and whose shapes are optionally described 

800 by the `shapes` argument. 

801 

802 If the `shapes` argument is specified, each component of a queue 

803 element must have the respective fixed shape. If it is 

804 unspecified, different queue elements may have different shapes, 

805 but the use of `dequeue_many` is disallowed. 

806 

807 Args: 

808 capacity: An integer. The upper bound on the number of elements 

809 that may be stored in this queue. 

810 dtypes: A list of `DType` objects. The length of `dtypes` must equal 

811 the number of tensors in each queue element. 

812 shapes: (Optional.) A list of fully-defined `TensorShape` objects 

813 with the same length as `dtypes`, or `None`. 

814 names: (Optional.) A list of string naming the components in the queue 

815 with the same length as `dtypes`, or `None`. If specified the dequeue 

816 methods return a dictionary with the names as keys. 

817 shared_name: (Optional.) If non-empty, this queue will be shared under 

818 the given name across multiple sessions. 

819 name: Optional name for the queue operation. 

820 """ 

821 dtypes = _as_type_list(dtypes) 

822 shapes = _as_shape_list(shapes, dtypes) 

823 names = _as_name_list(names, dtypes) 

824 with ops.init_scope(): 

825 queue_ref = gen_data_flow_ops.fifo_queue_v2( 

826 component_types=dtypes, 

827 shapes=shapes, 

828 capacity=capacity, 

829 shared_name=_shared_name(shared_name), 

830 name=name) 

831 

832 super(GPUCompatibleFIFOQueue, self).__init__( 

833 dtypes, shapes, names, queue_ref) 

834 

835 def enqueue_many(self, vals, name=None): 

836 """enqueue_many is not supported on GPUCompatibleFIFOQueue.""" 

837 raise NotImplementedError( 

838 "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, " 

839 "only enqueue and dequeue.") 

840 

841 def dequeue_many(self, n, name=None): 

842 """dequeue_many is not supported on GPUCompatibleFIFOQueue.""" 

843 raise NotImplementedError( 

844 "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, " 

845 "only enqueue and dequeue.") 

846 

847 

848@tf_export( 

849 "queue.PaddingFIFOQueue", 

850 v1=["queue.PaddingFIFOQueue", "io.PaddingFIFOQueue", "PaddingFIFOQueue"]) 

851@deprecation.deprecated_endpoints(["io.PaddingFIFOQueue", "PaddingFIFOQueue"]) 

852class PaddingFIFOQueue(QueueBase): 

853 """A FIFOQueue that supports batching variable-sized tensors by padding. 

854 

855 A `PaddingFIFOQueue` may contain components with dynamic shape, while also 

856 supporting `dequeue_many`. See the constructor for more details. 

857 

858 See `tf.queue.QueueBase` for a description of the methods on 

859 this class. 

860 """ 

861 

862 def __init__(self, 

863 capacity, 

864 dtypes, 

865 shapes, 

866 names=None, 

867 shared_name=None, 

868 name="padding_fifo_queue"): 

869 """Creates a queue that dequeues elements in a first-in first-out order. 

870 

871 A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent 

872 producers and consumers; and provides exactly-once delivery. 

873 

874 A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each 

875 element is a fixed-length tuple of tensors whose dtypes are 

876 described by `dtypes`, and whose shapes are described by the `shapes` 

877 argument. 

878 

879 The `shapes` argument must be specified; each component of a queue 

880 element must have the respective shape. Shapes of fixed 

881 rank but variable size are allowed by setting any shape dimension to None. 

882 In this case, the inputs' shape may vary along the given dimension, and 

883 `dequeue_many` will pad the given dimension with zeros up to the maximum 

884 shape of all elements in the given batch. 

885 

886 Args: 

887 capacity: An integer. The upper bound on the number of elements 

888 that may be stored in this queue. 

889 dtypes: A list of `DType` objects. The length of `dtypes` must equal 

890 the number of tensors in each queue element. 

891 shapes: A list of `TensorShape` objects, with the same length as 

892 `dtypes`. Any dimension in the `TensorShape` containing value 

893 `None` is dynamic and allows values to be enqueued with 

894 variable size in that dimension. 

895 names: (Optional.) A list of string naming the components in the queue 

896 with the same length as `dtypes`, or `None`. If specified the dequeue 

897 methods return a dictionary with the names as keys. 

898 shared_name: (Optional.) If non-empty, this queue will be shared under 

899 the given name across multiple sessions. 

900 name: Optional name for the queue operation. 

901 

902 Raises: 

903 ValueError: If shapes is not a list of shapes, or the lengths of dtypes 

904 and shapes do not match, or if names is specified and the lengths of 

905 dtypes and names do not match. 

906 """ 

907 dtypes = _as_type_list(dtypes) 

908 shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True) 

909 names = _as_name_list(names, dtypes) 

910 if len(dtypes) != len(shapes): 

911 raise ValueError("Shapes must be provided for all components, " 

912 f"but received {len(dtypes)} dtypes and " 

913 f"{len(shapes)} shapes.") 

914 queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( 

915 component_types=dtypes, 

916 shapes=shapes, 

917 capacity=capacity, 

918 shared_name=_shared_name(shared_name), 

919 name=name) 

920 

921 super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) 

922 

923 

924@tf_export("queue.PriorityQueue", 

925 v1=["queue.PriorityQueue", "io.PriorityQueue", "PriorityQueue"]) 

926@deprecation.deprecated_endpoints(["io.PriorityQueue", "PriorityQueue"]) 

927class PriorityQueue(QueueBase): 

928 """A queue implementation that dequeues elements in prioritized order. 

929 

930 See `tf.queue.QueueBase` for a description of the methods on 

931 this class. 

932 """ 

933 

934 def __init__(self, 

935 capacity, 

936 types, 

937 shapes=None, 

938 names=None, 

939 shared_name=None, 

940 name="priority_queue"): 

941 """Creates a queue that dequeues elements in a first-in first-out order. 

942 

943 A `PriorityQueue` has bounded capacity; supports multiple concurrent 

944 producers and consumers; and provides exactly-once delivery. 

945 

946 A `PriorityQueue` holds a list of up to `capacity` elements. Each 

947 element is a fixed-length tuple of tensors whose dtypes are 

948 described by `types`, and whose shapes are optionally described 

949 by the `shapes` argument. 

950 

951 If the `shapes` argument is specified, each component of a queue 

952 element must have the respective fixed shape. If it is 

953 unspecified, different queue elements may have different shapes, 

954 but the use of `dequeue_many` is disallowed. 

955 

956 Enqueues and Dequeues to the `PriorityQueue` must include an additional 

957 tuple entry at the beginning: the `priority`. The priority must be 

958 an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`). 

959 

960 Args: 

961 capacity: An integer. The upper bound on the number of elements 

962 that may be stored in this queue. 

963 types: A list of `DType` objects. The length of `types` must equal 

964 the number of tensors in each queue element, except the first priority 

965 element. The first tensor in each element is the priority, 

966 which must be type int64. 

967 shapes: (Optional.) A list of fully-defined `TensorShape` objects, 

968 with the same length as `types`, or `None`. 

969 names: (Optional.) A list of strings naming the components in the queue 

970 with the same length as `dtypes`, or `None`. If specified, the dequeue 

971 methods return a dictionary with the names as keys. 

972 shared_name: (Optional.) If non-empty, this queue will be shared under 

973 the given name across multiple sessions. 

974 name: Optional name for the queue operation. 

975 """ 

976 types = _as_type_list(types) 

977 shapes = _as_shape_list(shapes, types) 

978 

979 queue_ref = gen_data_flow_ops.priority_queue_v2( 

980 component_types=types, 

981 shapes=shapes, 

982 capacity=capacity, 

983 shared_name=_shared_name(shared_name), 

984 name=name) 

985 

986 priority_dtypes = [_dtypes.int64] + types 

987 priority_shapes = [()] + shapes if shapes else shapes 

988 

989 super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names, 

990 queue_ref) 

991 

992 

993# TODO(josh11b): class BatchQueue(QueueBase): 

994 

995 

996class Barrier: 

997 """Represents a key-value map that persists across graph executions.""" 

998 

999 def __init__(self, types, shapes=None, shared_name=None, name="barrier"): 

1000 """Creates a barrier that persists across different graph executions. 

1001 

1002 A barrier represents a key-value map, where each key is a string, and 

1003 each value is a tuple of tensors. 

1004 

1005 At runtime, the barrier contains 'complete' and 'incomplete' 

1006 elements. A complete element has defined tensors for all 

1007 components of its value tuple, and may be accessed using 

1008 take_many. An incomplete element has some undefined components in 

1009 its value tuple, and may be updated using insert_many. 

1010 

1011 The barrier call `take_many` outputs values in a particular order. 

1012 First, it only outputs completed values. Second, the order in which 

1013 completed values are returned matches the order in which their very 

1014 first component was inserted into the barrier. So, for example, for this 

1015 sequence of insertions and removals: 

1016 

1017 barrier = Barrier((tf.string, tf.int32), shapes=((), ())) 

1018 barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run() 

1019 barrier.insert_many(1, keys=["k1"], values=[1]).run() 

1020 barrier.insert_many(0, keys=["k3"], values=["c"]).run() 

1021 barrier.insert_many(1, keys=["k3"], values=[3]).run() 

1022 barrier.insert_many(1, keys=["k2"], values=[2]).run() 

1023 

1024 (indices, keys, values) = barrier.take_many(2) 

1025 (indices_val, keys_val, values0_val, values1_val) = 

1026 session.run([indices, keys, values[0], values[1]]) 

1027 

1028 The output will be (up to permutation of "k1" and "k2"): 

1029 

1030 indices_val == (-2**63, -2**63) 

1031 keys_val == ("k1", "k2") 

1032 values0_val == ("a", "b") 

1033 values1_val == (1, 2) 

1034 

1035 Note the key "k2" was inserted into the barrier before "k3". Even though 

1036 "k3" was completed first, both are complete by the time 

1037 take_many is called. As a result, "k2" is prioritized and "k1" and "k2" 

1038 are returned first. "k3" remains in the barrier until the next execution 

1039 of `take_many`. Since "k1" and "k2" had their first insertions into 

1040 the barrier together, their indices are the same (-2**63). The index 

1041 of "k3" will be -2**63 + 1, because it was the next new inserted key. 

1042 

1043 Args: 

1044 types: A single dtype or a tuple of dtypes, corresponding to the 

1045 dtypes of the tensor elements that comprise a value in this barrier. 

1046 shapes: Optional. Constraints on the shapes of tensors in the values: 

1047 a single tensor shape tuple; a tuple of tensor shape tuples 

1048 for each barrier-element tuple component; or None if the shape should 

1049 not be constrained. 

1050 shared_name: Optional. If non-empty, this barrier will be shared under 

1051 the given name across multiple sessions. 

1052 name: Optional name for the barrier op. 

1053 

1054 Raises: 

1055 ValueError: If one of the `shapes` indicate no elements. 

1056 """ 

1057 self._types = _as_type_list(types) 

1058 

1059 if shapes is not None: 

1060 shapes = _as_shape_list(shapes, self._types) 

1061 self._shapes = [tensor_shape.TensorShape(s) for s in shapes] 

1062 for i, shape in enumerate(self._shapes): 

1063 if shape.num_elements() == 0: 

1064 raise ValueError("Empty tensors are not supported, but received " 

1065 f"shape '{shape}' at index {i}") 

1066 else: 

1067 self._shapes = [tensor_shape.unknown_shape() for _ in self._types] 

1068 

1069 self._barrier_ref = gen_data_flow_ops.barrier( 

1070 component_types=self._types, 

1071 shapes=self._shapes, 

1072 shared_name=shared_name, 

1073 name=name) 

1074 if context.executing_eagerly(): 

1075 self._name = context.context().scope_name 

1076 else: 

1077 self._name = self._barrier_ref.op.name.split("/")[-1] 

1078 

1079 @property 

1080 def barrier_ref(self): 

1081 """Get the underlying barrier reference.""" 

1082 return self._barrier_ref 

1083 

1084 @property 

1085 def name(self): 

1086 """The name of the underlying barrier.""" 

1087 if context.executing_eagerly(): 

1088 return self._name 

1089 return self._barrier_ref.op.name 

1090 

1091 def insert_many(self, component_index, keys, values, name=None): 

1092 """For each key, assigns the respective value to the specified component. 

1093 

1094 This operation updates each element at component_index. 

1095 

1096 Args: 

1097 component_index: The component of the value that is being assigned. 

1098 keys: A vector of keys, with length n. 

1099 values: An any-dimensional tensor of values, which are associated with the 

1100 respective keys. The first dimension must have length n. 

1101 name: Optional name for the op. 

1102 

1103 Returns: 

1104 The operation that performs the insertion. 

1105 Raises: 

1106 InvalidArgumentsError: If inserting keys and values without elements. 

1107 """ 

1108 if name is None: 

1109 name = "%s_BarrierInsertMany" % self._name 

1110 return gen_data_flow_ops.barrier_insert_many( 

1111 self._barrier_ref, keys, values, component_index, name=name) 

1112 

1113 def take_many(self, 

1114 num_elements, 

1115 allow_small_batch=False, 

1116 timeout=None, 

1117 name=None): 

1118 """Takes the given number of completed elements from this barrier. 

1119 

1120 This operation concatenates completed-element component tensors along 

1121 the 0th dimension to make a single component tensor. 

1122 

1123 If barrier has no completed elements, this operation will block 

1124 until there are 'num_elements' elements to take. 

1125 

1126 TODO(b/25743580): the semantics of `allow_small_batch` are experimental 

1127 and may be extended to other cases in the future. 

1128 

1129 TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking 

1130 already when the barrier is closed, it will block for ever. Fix this 

1131 by using asynchronous operations. 

1132 

1133 Args: 

1134 num_elements: The number of elements to take. 

1135 allow_small_batch: If the barrier is closed, don't block if there are less 

1136 completed elements than requested, but instead return all available 

1137 completed elements. 

1138 timeout: This specifies the number of milliseconds to block 

1139 before returning with DEADLINE_EXCEEDED. (This option is not 

1140 supported yet.) 

1141 name: A name for the operation (optional). 

1142 

1143 Returns: 

1144 A tuple of (index, key, value_list). 

1145 "index" is a int64 tensor of length num_elements containing the 

1146 index of the insert_many call for which the very first component of 

1147 the given element was inserted into the Barrier, starting with 

1148 the value -2**63. Note, this value is different from the 

1149 index of the insert_many call for which the element was completed. 

1150 "key" is a string tensor of length num_elements containing the keys. 

1151 "value_list" is a tuple of tensors, each one with size num_elements 

1152 in the 0th dimension for each component in the barrier's values. 

1153 

1154 """ 

1155 if name is None: 

1156 name = "%s_BarrierTakeMany" % self._name 

1157 ret = gen_data_flow_ops.barrier_take_many( 

1158 self._barrier_ref, 

1159 num_elements, 

1160 self._types, 

1161 allow_small_batch, 

1162 timeout, 

1163 name=name) 

1164 

1165 # NOTE(mrry): Not using a shape function because we need access to 

1166 # the Barrier object. 

1167 if not context.executing_eagerly(): 

1168 op = ret[0].op 

1169 if allow_small_batch: 

1170 batch_dim = None 

1171 else: 

1172 batch_dim = tensor_shape.Dimension( 

1173 tensor_util.constant_value(op.inputs[1])) 

1174 op.outputs[0].set_shape(tensor_shape.TensorShape([batch_dim])) # indices 

1175 op.outputs[1].set_shape(tensor_shape.TensorShape([batch_dim])) # keys 

1176 for output, shape in zip(op.outputs[2:], self._shapes): # value_list 

1177 output.set_shape( 

1178 tensor_shape.TensorShape([batch_dim]).concatenate(shape)) 

1179 

1180 return ret 

1181 

1182 def close(self, cancel_pending_enqueues=False, name=None): 

1183 """Closes this barrier. 

1184 

1185 This operation signals that no more new key values will be inserted in the 

1186 given barrier. Subsequent InsertMany operations with new keys will fail. 

1187 InsertMany operations that just complement already existing keys with other 

1188 components, will continue to succeed. Subsequent TakeMany operations will 

1189 continue to succeed if sufficient elements remain in the barrier. Subsequent 

1190 TakeMany operations that would block will fail immediately. 

1191 

1192 If `cancel_pending_enqueues` is `True`, all pending requests to the 

1193 underlying queue will also be canceled, and completing of already 

1194 started values is also not acceptable anymore. 

1195 

1196 Args: 

1197 cancel_pending_enqueues: (Optional.) A boolean, defaulting to 

1198 `False` (described above). 

1199 name: Optional name for the op. 

1200 

1201 Returns: 

1202 The operation that closes the barrier. 

1203 """ 

1204 if name is None: 

1205 name = "%s_BarrierClose" % self._name 

1206 return gen_data_flow_ops.barrier_close( 

1207 self._barrier_ref, 

1208 cancel_pending_enqueues=cancel_pending_enqueues, 

1209 name=name) 

1210 

1211 def ready_size(self, name=None): 

1212 """Compute the number of complete elements in the given barrier. 

1213 

1214 Args: 

1215 name: A name for the operation (optional). 

1216 

1217 Returns: 

1218 A single-element tensor containing the number of complete elements in the 

1219 given barrier. 

1220 """ 

1221 if name is None: 

1222 name = "%s_BarrierReadySize" % self._name 

1223 return gen_data_flow_ops.barrier_ready_size(self._barrier_ref, name=name) 

1224 

1225 def incomplete_size(self, name=None): 

1226 """Compute the number of incomplete elements in the given barrier. 

1227 

1228 Args: 

1229 name: A name for the operation (optional). 

1230 

1231 Returns: 

1232 A single-element tensor containing the number of incomplete elements in 

1233 the given barrier. 

1234 """ 

1235 if name is None: 

1236 name = "%s_BarrierIncompleteSize" % self._name 

1237 return gen_data_flow_ops.barrier_incomplete_size( 

1238 self._barrier_ref, name=name) 

1239 

1240 

1241@tf_export(v1=["ConditionalAccumulatorBase"]) 

1242class ConditionalAccumulatorBase: 

1243 """A conditional accumulator for aggregating gradients. 

1244 

1245 Up-to-date gradients (i.e., time step at which gradient was computed is 

1246 equal to the accumulator's time step) are added to the accumulator. 

1247 

1248 Extraction of the average gradient is blocked until the required number of 

1249 gradients has been accumulated. 

1250 """ 

1251 

1252 def __init__(self, dtype, shape, accumulator_ref): 

1253 """Creates a new ConditionalAccumulator. 

1254 

1255 Args: 

1256 dtype: Datatype of the accumulated gradients. 

1257 shape: Shape of the accumulated gradients. 

1258 accumulator_ref: A handle to the conditional accumulator, created by sub- 

1259 classes 

1260 """ 

1261 self._dtype = dtype 

1262 if shape is not None: 

1263 self._shape = tensor_shape.TensorShape(shape) 

1264 else: 

1265 self._shape = tensor_shape.unknown_shape() 

1266 self._accumulator_ref = accumulator_ref 

1267 if context.executing_eagerly(): 

1268 self._name = context.context().scope_name 

1269 else: 

1270 self._name = self._accumulator_ref.op.name.split("/")[-1] 

1271 

1272 @property 

1273 def accumulator_ref(self): 

1274 """The underlying accumulator reference.""" 

1275 return self._accumulator_ref 

1276 

1277 @property 

1278 def name(self): 

1279 """The name of the underlying accumulator.""" 

1280 return self._name 

1281 

1282 @property 

1283 def dtype(self): 

1284 """The datatype of the gradients accumulated by this accumulator.""" 

1285 return self._dtype 

1286 

1287 def num_accumulated(self, name=None): 

1288 """Number of gradients that have currently been aggregated in accumulator. 

1289 

1290 Args: 

1291 name: Optional name for the operation. 

1292 

1293 Returns: 

1294 Number of accumulated gradients currently in accumulator. 

1295 """ 

1296 if name is None: 

1297 name = "%s_NumAccumulated" % self._name 

1298 

1299 return gen_data_flow_ops.resource_accumulator_num_accumulated( 

1300 self._accumulator_ref, name=name) 

1301 

1302 def set_global_step(self, new_global_step, name=None): 

1303 """Sets the global time step of the accumulator. 

1304 

1305 The operation logs a warning if we attempt to set to a time step that is 

1306 lower than the accumulator's own time step. 

1307 

1308 Args: 

1309 new_global_step: Value of new time step. Can be a variable or a constant 

1310 name: Optional name for the operation. 

1311 

1312 Returns: 

1313 Operation that sets the accumulator's time step. 

1314 """ 

1315 return gen_data_flow_ops.resource_accumulator_set_global_step( 

1316 self._accumulator_ref, 

1317 math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64), 

1318 name=name) 

1319 

1320 

1321@tf_export(v1=["ConditionalAccumulator"]) 

1322class ConditionalAccumulator(ConditionalAccumulatorBase): 

1323 """A conditional accumulator for aggregating gradients. 

1324 

1325 Up-to-date gradients (i.e., time step at which gradient was computed is 

1326 equal to the accumulator's time step) are added to the accumulator. 

1327 

1328 Extraction of the average gradient is blocked until the required number of 

1329 gradients has been accumulated. 

1330 """ 

1331 

1332 def __init__(self, 

1333 dtype, 

1334 shape=None, 

1335 shared_name=None, 

1336 name="conditional_accumulator", 

1337 reduction_type="MEAN"): 

1338 """Creates a new ConditionalAccumulator. 

1339 

1340 Args: 

1341 dtype: Datatype of the accumulated gradients. 

1342 shape: Shape of the accumulated gradients. 

1343 shared_name: Optional. If non-empty, this accumulator will be shared under 

1344 the given name across multiple sessions. 

1345 name: Optional name for the accumulator. 

1346 reduction_type: Reduction type to use when taking the gradient. 

1347 """ 

1348 accumulator_ref = gen_data_flow_ops.resource_conditional_accumulator( 

1349 dtype=dtype, 

1350 shape=shape, 

1351 shared_name=shared_name, 

1352 name=name, 

1353 reduction_type=reduction_type) 

1354 if context.executing_eagerly(): 

1355 self._resource_deleter = resource_variable_ops.EagerResourceDeleter( 

1356 handle=accumulator_ref, handle_device=context.context().device_name) 

1357 

1358 super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref) 

1359 

1360 def apply_grad(self, grad, local_step=0, name=None): 

1361 """Attempts to apply a gradient to the accumulator. 

1362 

1363 The attempt is silently dropped if the gradient is stale, i.e., local_step 

1364 is less than the accumulator's global time step. 

1365 

1366 Args: 

1367 grad: The gradient tensor to be applied. 

1368 local_step: Time step at which the gradient was computed. 

1369 name: Optional name for the operation. 

1370 

1371 Returns: 

1372 The operation that (conditionally) applies a gradient to the accumulator. 

1373 

1374 Raises: 

1375 ValueError: If grad is of the wrong shape 

1376 """ 

1377 grad = ops.convert_to_tensor(grad, self._dtype) 

1378 grad.get_shape().assert_is_compatible_with(self._shape) 

1379 local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64) 

1380 

1381 return gen_data_flow_ops.resource_accumulator_apply_gradient( 

1382 self._accumulator_ref, local_step=local_step, gradient=grad, name=name) 

1383 

1384 def take_grad(self, num_required, name=None): 

1385 """Attempts to extract the average gradient from the accumulator. 

1386 

1387 The operation blocks until sufficient number of gradients have been 

1388 successfully applied to the accumulator. 

1389 

1390 Once successful, the following actions are also triggered: 

1391 

1392 - Counter of accumulated gradients is reset to 0. 

1393 - Aggregated gradient is reset to 0 tensor. 

1394 - Accumulator's internal time step is incremented by 1. 

1395 

1396 Args: 

1397 num_required: Number of gradients that needs to have been aggregated 

1398 name: Optional name for the operation 

1399 

1400 Returns: 

1401 A tensor holding the value of the average gradient. 

1402 

1403 Raises: 

1404 InvalidArgumentError: If num_required < 1 

1405 """ 

1406 out = gen_data_flow_ops.resource_accumulator_take_gradient( 

1407 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 

1408 out.set_shape(self._shape) 

1409 return out 

1410 

1411 

1412@tf_export( 

1413 v1=["sparse.SparseConditionalAccumulator", "SparseConditionalAccumulator"]) 

1414class SparseConditionalAccumulator(ConditionalAccumulatorBase): 

1415 """A conditional accumulator for aggregating sparse gradients. 

1416 

1417 Sparse gradients are represented by `IndexedSlices`. 

1418 

1419 Up-to-date gradients (i.e., time step at which gradient was computed is 

1420 equal to the accumulator's time step) are added to the accumulator. 

1421 

1422 Extraction of the average gradient is blocked until the required number of 

1423 gradients has been accumulated. 

1424 

1425 Args: 

1426 dtype: Datatype of the accumulated gradients. 

1427 shape: Shape of the accumulated gradients. 

1428 shared_name: Optional. If non-empty, this accumulator will be shared under 

1429 the given name across multiple sessions. 

1430 name: Optional name for the accumulator. 

1431 reduction_type: Reduction type to use when taking the gradient. 

1432 """ 

1433 

1434 def __init__(self, 

1435 dtype, 

1436 shape=None, 

1437 shared_name=None, 

1438 name="sparse_conditional_accumulator", 

1439 reduction_type="MEAN"): 

1440 accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator( 

1441 dtype=dtype, 

1442 shape=shape, 

1443 shared_name=shared_name, 

1444 name=name, 

1445 reduction_type=reduction_type) 

1446 super(SparseConditionalAccumulator, self).__init__(dtype, shape, 

1447 accumulator_ref) 

1448 

1449 def apply_indexed_slices_grad(self, grad, local_step=0, name=None): 

1450 """Attempts to apply a gradient to the accumulator. 

1451 

1452 The attempt is silently dropped if the gradient is stale, i.e., `local_step` 

1453 is less than the accumulator's global time step. 

1454 

1455 Args: 

1456 grad: The gradient `IndexedSlices` to be applied. 

1457 local_step: Time step at which the gradient was computed. 

1458 name: Optional name for the operation. 

1459 

1460 Returns: 

1461 The operation that (conditionally) applies a gradient to the accumulator. 

1462 

1463 Raises: 

1464 InvalidArgumentError: If grad is of the wrong shape 

1465 """ 

1466 return self.apply_grad( 

1467 grad_indices=grad.indices, 

1468 grad_values=grad.values, 

1469 grad_shape=grad.dense_shape, 

1470 local_step=local_step, 

1471 name=name) 

1472 

1473 def apply_grad(self, 

1474 grad_indices, 

1475 grad_values, 

1476 grad_shape=None, 

1477 local_step=0, 

1478 name=None): 

1479 """Attempts to apply a sparse gradient to the accumulator. 

1480 

1481 The attempt is silently dropped if the gradient is stale, i.e., `local_step` 

1482 is less than the accumulator's global time step. 

1483 

1484 A sparse gradient is represented by its indices, values and possibly empty 

1485 or None shape. Indices must be a vector representing the locations of 

1486 non-zero entries in the tensor. Values are the non-zero slices of the 

1487 gradient, and must have the same first dimension as indices, i.e., the nnz 

1488 represented by indices and values must be consistent. Shape, if not empty or 

1489 None, must be consistent with the accumulator's shape (if also provided). 

1490 

1491 Example: 

1492 A tensor [[0, 0], [0, 1], [2, 3]] can be represented 

1493 indices: [1,2] 

1494 values: [[0,1],[2,3]] 

1495 shape: [3, 2] 

1496 

1497 Args: 

1498 grad_indices: Indices of the sparse gradient to be applied. 

1499 grad_values: Values of the sparse gradient to be applied. 

1500 grad_shape: Shape of the sparse gradient to be applied. 

1501 local_step: Time step at which the gradient was computed. 

1502 name: Optional name for the operation. 

1503 

1504 Returns: 

1505 The operation that (conditionally) applies a gradient to the accumulator. 

1506 

1507 Raises: 

1508 InvalidArgumentError: If grad is of the wrong shape 

1509 """ 

1510 local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64) 

1511 return gen_data_flow_ops.sparse_accumulator_apply_gradient( 

1512 self._accumulator_ref, 

1513 local_step=local_step, 

1514 gradient_indices=math_ops.cast(grad_indices, _dtypes.int64), 

1515 gradient_values=grad_values, 

1516 gradient_shape=math_ops.cast( 

1517 [] if grad_shape is None else grad_shape, _dtypes.int64), 

1518 has_known_shape=(grad_shape is not None), 

1519 name=name) 

1520 

1521 def take_grad(self, num_required, name=None): 

1522 """Attempts to extract the average gradient from the accumulator. 

1523 

1524 The operation blocks until sufficient number of gradients have been 

1525 successfully applied to the accumulator. 

1526 

1527 Once successful, the following actions are also triggered: 

1528 - Counter of accumulated gradients is reset to 0. 

1529 - Aggregated gradient is reset to 0 tensor. 

1530 - Accumulator's internal time step is incremented by 1. 

1531 

1532 Args: 

1533 num_required: Number of gradients that needs to have been aggregated 

1534 name: Optional name for the operation 

1535 

1536 Returns: 

1537 A tuple of indices, values, and shape representing the average gradient. 

1538 

1539 Raises: 

1540 InvalidArgumentError: If `num_required` < 1 

1541 """ 

1542 return gen_data_flow_ops.sparse_accumulator_take_gradient( 

1543 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 

1544 

1545 def take_indexed_slices_grad(self, num_required, name=None): 

1546 """Attempts to extract the average gradient from the accumulator. 

1547 

1548 The operation blocks until sufficient number of gradients have been 

1549 successfully applied to the accumulator. 

1550 

1551 Once successful, the following actions are also triggered: 

1552 - Counter of accumulated gradients is reset to 0. 

1553 - Aggregated gradient is reset to 0 tensor. 

1554 - Accumulator's internal time step is incremented by 1. 

1555 

1556 Args: 

1557 num_required: Number of gradients that needs to have been aggregated 

1558 name: Optional name for the operation 

1559 

1560 Returns: 

1561 An `IndexedSlices` holding the value of the average gradient. 

1562 

1563 Raises: 

1564 InvalidArgumentError: If `num_required` < 1 

1565 """ 

1566 return_val = gen_data_flow_ops.sparse_accumulator_take_gradient( 

1567 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 

1568 return indexed_slices.IndexedSlices( 

1569 indices=return_val.indices, 

1570 values=return_val.values, 

1571 dense_shape=return_val.shape) 

1572 

1573 # SparseConditionalAccumulator is not switched to resource. Use old kernels. 

1574 def num_accumulated(self, name=None): 

1575 """Number of gradients that have currently been aggregated in accumulator. 

1576 

1577 Args: 

1578 name: Optional name for the operation. 

1579 

1580 Returns: 

1581 Number of accumulated gradients currently in accumulator. 

1582 """ 

1583 if name is None: 

1584 name = "%s_NumAccumulated" % self._name 

1585 

1586 return gen_data_flow_ops.accumulator_num_accumulated( 

1587 self._accumulator_ref, name=name) 

1588 

1589 def set_global_step(self, new_global_step, name=None): 

1590 """Sets the global time step of the accumulator. 

1591 

1592 The operation logs a warning if we attempt to set to a time step that is 

1593 lower than the accumulator's own time step. 

1594 

1595 Args: 

1596 new_global_step: Value of new time step. Can be a variable or a constant 

1597 name: Optional name for the operation. 

1598 

1599 Returns: 

1600 Operation that sets the accumulator's time step. 

1601 """ 

1602 return gen_data_flow_ops.accumulator_set_global_step( 

1603 self._accumulator_ref, 

1604 math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64), 

1605 name=name) 

1606 

1607 

1608class BaseStagingArea: 

1609 """Base class for Staging Areas.""" 

1610 _identifier = 0 

1611 _lock = threading.Lock() 

1612 

1613 def __init__(self, 

1614 dtypes, 

1615 shapes=None, 

1616 names=None, 

1617 shared_name=None, 

1618 capacity=0, 

1619 memory_limit=0): 

1620 if shared_name is None: 

1621 self._name = ( 

1622 ops.get_default_graph().unique_name(self.__class__.__name__)) 

1623 elif isinstance(shared_name, str): 

1624 self._name = shared_name 

1625 else: 

1626 raise ValueError(f"shared_name must be a string, got {shared_name}") 

1627 

1628 self._dtypes = dtypes 

1629 

1630 if shapes is not None: 

1631 if len(shapes) != len(dtypes): 

1632 raise ValueError("StagingArea shapes must be the same length as dtypes") 

1633 self._shapes = [tensor_shape.TensorShape(s) for s in shapes] 

1634 else: 

1635 self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes] 

1636 

1637 if names is not None: 

1638 if len(names) != len(dtypes): 

1639 raise ValueError("StagingArea names must be the same length as dtypes") 

1640 self._names = names 

1641 else: 

1642 self._names = None 

1643 

1644 self._capacity = capacity 

1645 self._memory_limit = memory_limit 

1646 

1647 # all get and put ops must colocate with this op 

1648 with ops.name_scope("%s_root" % self._name): 

1649 self._coloc_op = control_flow_ops.no_op() 

1650 

1651 @property 

1652 def name(self): 

1653 """The name of the staging area.""" 

1654 return self._name 

1655 

1656 @property 

1657 def dtypes(self): 

1658 """The list of dtypes for each component of a staging area element.""" 

1659 return self._dtypes 

1660 

1661 @property 

1662 def shapes(self): 

1663 """The list of shapes for each component of a staging area element.""" 

1664 return self._shapes 

1665 

1666 @property 

1667 def names(self): 

1668 """The list of names for each component of a staging area element.""" 

1669 return self._names 

1670 

1671 @property 

1672 def capacity(self): 

1673 """The maximum number of elements of this staging area.""" 

1674 return self._capacity 

1675 

1676 @property 

1677 def memory_limit(self): 

1678 """The maximum number of bytes of this staging area.""" 

1679 return self._memory_limit 

1680 

1681 def _check_put_dtypes(self, vals, indices=None): 

1682 """Validate and convert `vals` to a list of `Tensor`s. 

1683 

1684 The `vals` argument can be a Tensor, a list or tuple of tensors, or a 

1685 dictionary with tensor values. 

1686 

1687 If `vals` is a list, then the appropriate indices associated with the 

1688 values must be provided. 

1689 

1690 If it is a dictionary, the staging area must have been constructed with a 

1691 `names` attribute and the dictionary keys must match the staging area names. 

1692 `indices` will be inferred from the dictionary keys. 

1693 If the staging area was constructed with a `names` attribute, `vals` must 

1694 be a dictionary. 

1695 

1696 Checks that the dtype and shape of each value matches that 

1697 of the staging area. 

1698 

1699 Args: 

1700 vals: A tensor, a list or tuple of tensors, or a dictionary. 

1701 

1702 Returns: 

1703 A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects 

1704 and `indices` is a list of indices associated with the tensors. 

1705 

1706 Raises: 

1707 ValueError: If `vals` or `indices` is invalid. 

1708 """ 

1709 if isinstance(vals, dict): 

1710 if not self._names: 

1711 raise ValueError( 

1712 "Staging areas must have names to enqueue a dictionary") 

1713 if not set(vals.keys()).issubset(self._names): 

1714 raise ValueError("Keys in dictionary to put do not match names " 

1715 f"of staging area. Dictionary: {sorted(vals.keys())}" 

1716 f"Queue: {sorted(self._names)}") 

1717 # The order of values in `self._names` indicates the order in which the 

1718 # tensors in the dictionary `vals` must be listed. 

1719 vals, indices, _ = zip(*[(vals[k], i, k) 

1720 for i, k in enumerate(self._names) 

1721 if k in vals]) 

1722 else: 

1723 if self._names: 

1724 raise ValueError("You must enqueue a dictionary in a staging area " 

1725 "with names") 

1726 

1727 if indices is None: 

1728 raise ValueError("Indices must be supplied when inserting a list " 

1729 "of tensors") 

1730 

1731 if len(indices) != len(vals): 

1732 raise ValueError(f"Number of indices {len(indices)} doesn't match " 

1733 f"number of values {len(vals)}") 

1734 

1735 if not isinstance(vals, (list, tuple)): 

1736 vals = [vals] 

1737 indices = [0] 

1738 

1739 # Sanity check number of values 

1740 if not len(vals) <= len(self._dtypes): 

1741 raise ValueError(f"Unexpected number of inputs {len(vals)} vs " 

1742 f"{len(self._dtypes)}") 

1743 

1744 tensors = [] 

1745 

1746 for val, i in zip(vals, indices): 

1747 dtype, shape = self._dtypes[i], self._shapes[i] 

1748 # Check dtype 

1749 if val.dtype != dtype: 

1750 raise ValueError(f"Datatypes do not match. " 

1751 f"Received val.dtype {str(val.dtype)} and " 

1752 f"dtype {str(dtype)}") 

1753 # Check shape 

1754 val.get_shape().assert_is_compatible_with(shape) 

1755 

1756 tensors.append( 

1757 ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i)) 

1758 

1759 return tensors, indices 

1760 

1761 def _create_device_transfers(self, tensors): 

1762 """Encode inter-device transfers if the current device 

1763 is not the same as the Staging Area's device. 

1764 """ 

1765 

1766 if not isinstance(tensors, (tuple, list)): 

1767 tensors = [tensors] 

1768 

1769 curr_device_scope = control_flow_ops.no_op().device 

1770 

1771 if curr_device_scope != self._coloc_op.device: 

1772 tensors = [array_ops.identity(t) for t in tensors] 

1773 

1774 return tensors 

1775 

1776 def _get_return_value(self, tensors, indices): 

1777 """Return the value to return from a get op. 

1778 

1779 If the staging area has names, return a dictionary with the 

1780 names as keys. Otherwise return either a single tensor 

1781 or a list of tensors depending on the length of `tensors`. 

1782 

1783 Args: 

1784 tensors: List of tensors from the get op. 

1785 indices: Indices of associated names and shapes 

1786 

1787 Returns: 

1788 A single tensor, a list of tensors, or a dictionary 

1789 of tensors. 

1790 """ 

1791 

1792 tensors = self._create_device_transfers(tensors) 

1793 

1794 # Sets shape 

1795 for output, i in zip(tensors, indices): 

1796 output.set_shape(self._shapes[i]) 

1797 

1798 if self._names: 

1799 # The returned values in `tensors` are in the same order as 

1800 # the names in `self._names`. 

1801 return {self._names[i]: t for t, i in zip(tensors, indices)} 

1802 return tensors 

1803 

1804 def _scope_vals(self, vals): 

1805 """Return a list of values to pass to `name_scope()`. 

1806 

1807 Args: 

1808 vals: A tensor, a list or tuple of tensors, or a dictionary. 

1809 

1810 Returns: 

1811 The values in vals as a list. 

1812 """ 

1813 if isinstance(vals, (list, tuple)): 

1814 return vals 

1815 elif isinstance(vals, dict): 

1816 return vals.values() 

1817 else: 

1818 return [vals] 

1819 

1820 

1821class StagingArea(BaseStagingArea): 

1822 """Class for staging inputs. No ordering guarantees. 

1823 

1824 A `StagingArea` is a TensorFlow data structure that stores tensors across 

1825 multiple steps, and exposes operations that can put and get tensors. 

1826 

1827 Each `StagingArea` element is a tuple of one or more tensors, where each 

1828 tuple component has a static dtype, and may have a static shape. 

1829 

1830 The capacity of a `StagingArea` may be bounded or unbounded. 

1831 It supports multiple concurrent producers and consumers; and 

1832 provides exactly-once delivery. 

1833 

1834 Each element of a `StagingArea` is a fixed-length tuple of tensors whose 

1835 dtypes are described by `dtypes`, and whose shapes are optionally described 

1836 by the `shapes` argument. 

1837 

1838 If the `shapes` argument is specified, each component of a staging area 

1839 element must have the respective fixed shape. If it is 

1840 unspecified, different elements may have different shapes, 

1841 

1842 It can be configured with a capacity in which case 

1843 put(values) will block until space becomes available. 

1844 

1845 Similarly, it can be configured with a memory limit which 

1846 will block put(values) until space is available. 

1847 This is mostly useful for limiting the number of tensors on 

1848 devices such as GPUs. 

1849 

1850 All get() and peek() commands block if the requested data 

1851 is not present in the Staging Area. 

1852 

1853 """ 

1854 

1855 def __init__(self, 

1856 dtypes, 

1857 shapes=None, 

1858 names=None, 

1859 shared_name=None, 

1860 capacity=0, 

1861 memory_limit=0): 

1862 """Constructs a staging area object. 

1863 

1864 The two optional lists, `shapes` and `names`, must be of the same length 

1865 as `dtypes` if provided. The values at a given index `i` indicate the 

1866 shape and name to use for the corresponding queue component in `dtypes`. 

1867 

1868 The device scope at the time of object creation determines where the 

1869 storage for the `StagingArea` will reside. Calls to `put` will incur a copy 

1870 to this memory space, if necessary. Tensors returned by `get` will be 

1871 placed according to the device scope when `get` is called. 

1872 

1873 Args: 

1874 dtypes: A list of types. The length of dtypes must equal the number 

1875 of tensors in each element. 

1876 shapes: (Optional.) Constraints on the shapes of tensors in an element. 

1877 A list of shape tuples or None. This list is the same length 

1878 as dtypes. If the shape of any tensors in the element are constrained, 

1879 all must be; shapes can be None if the shapes should not be constrained. 

1880 names: (Optional.) If provided, the `get()` and 

1881 `put()` methods will use dictionaries with these names as keys. 

1882 Must be None or a list or tuple of the same length as `dtypes`. 

1883 shared_name: (Optional.) A name to be used for the shared object. By 

1884 passing the same name to two different python objects they will share 

1885 the underlying staging area. Must be a string. 

1886 capacity: (Optional.) Maximum number of elements. 

1887 An integer. If zero, the Staging Area is unbounded 

1888 memory_limit: (Optional.) Maximum number of bytes of all tensors 

1889 in the Staging Area. 

1890 An integer. If zero, the Staging Area is unbounded 

1891 

1892 Raises: 

1893 ValueError: If one of the arguments is invalid. 

1894 """ 

1895 

1896 super(StagingArea, self).__init__(dtypes, shapes, names, shared_name, 

1897 capacity, memory_limit) 

1898 

1899 def put(self, values, name=None): 

1900 """Create an op that places a value into the staging area. 

1901 

1902 This operation will block if the `StagingArea` has reached 

1903 its capacity. 

1904 

1905 Args: 

1906 values: A single tensor, a list or tuple of tensors, or a dictionary with 

1907 tensor values. The number of elements must match the length of the 

1908 list provided to the dtypes argument when creating the StagingArea. 

1909 name: A name for the operation (optional). 

1910 

1911 Returns: 

1912 The created op. 

1913 

1914 Raises: 

1915 ValueError: If the number or type of inputs don't match the staging area. 

1916 """ 

1917 with ops.name_scope(name, "%s_put" % self._name, 

1918 self._scope_vals(values)) as scope: 

1919 

1920 if not isinstance(values, (list, tuple, dict)): 

1921 values = [values] 

1922 

1923 # Hard-code indices for this staging area 

1924 indices = list(range(len(values))) 

1925 vals, _ = self._check_put_dtypes(values, indices) 

1926 

1927 with ops.colocate_with(self._coloc_op): 

1928 op = gen_data_flow_ops.stage( 

1929 values=vals, 

1930 shared_name=self._name, 

1931 name=scope, 

1932 capacity=self._capacity, 

1933 memory_limit=self._memory_limit) 

1934 

1935 return op 

1936 

1937 def __internal_get(self, get_fn, name): 

1938 with ops.colocate_with(self._coloc_op): 

1939 ret = get_fn() 

1940 

1941 indices = list(range(len(self._dtypes))) # Hard coded 

1942 return self._get_return_value(ret, indices) 

1943 

1944 def get(self, name=None): 

1945 """Gets one element from this staging area. 

1946 

1947 If the staging area is empty when this operation executes, it will block 

1948 until there is an element to dequeue. 

1949 

1950 Note that unlike others ops that can block, like the queue Dequeue 

1951 operations, this can stop other work from happening. To avoid this, the 

1952 intended use is for this to be called only when there will be an element 

1953 already available. One method for doing this in a training loop would be to 

1954 run a `put()` call during a warmup session.run call, and then call both 

1955 `get()` and `put()` in each subsequent step. 

1956 

1957 The placement of the returned tensor will be determined by the current 

1958 device scope when this function is called. 

1959 

1960 Args: 

1961 name: A name for the operation (optional). 

1962 

1963 Returns: 

1964 The tuple of tensors that was gotten. 

1965 """ 

1966 if name is None: 

1967 name = "%s_get" % self._name 

1968 

1969 # pylint: disable=bad-continuation 

1970 fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes, 

1971 shared_name=self._name, name=name, 

1972 capacity=self._capacity, 

1973 memory_limit=self._memory_limit) 

1974 # pylint: enable=bad-continuation 

1975 

1976 return self.__internal_get(fn, name) 

1977 

1978 def peek(self, index, name=None): 

1979 """Peeks at an element in the staging area. 

1980 

1981 If the staging area is too small to contain the element at 

1982 the specified index, it will block until enough elements 

1983 are inserted to complete the operation. 

1984 

1985 The placement of the returned tensor will be determined by 

1986 the current device scope when this function is called. 

1987 

1988 Args: 

1989 index: The index of the tensor within the staging area 

1990 to look up. 

1991 name: A name for the operation (optional). 

1992 

1993 Returns: 

1994 The tuple of tensors that was gotten. 

1995 """ 

1996 if name is None: 

1997 name = "%s_peek" % self._name 

1998 

1999 # pylint: disable=bad-continuation 

2000 fn = lambda: gen_data_flow_ops.stage_peek(index, 

2001 dtypes=self._dtypes, shared_name=self._name, 

2002 name=name, capacity=self._capacity, 

2003 memory_limit=self._memory_limit) 

2004 # pylint: enable=bad-continuation 

2005 

2006 return self.__internal_get(fn, name) 

2007 

2008 def size(self, name=None): 

2009 """Returns the number of elements in the staging area. 

2010 

2011 Args: 

2012 name: A name for the operation (optional) 

2013 

2014 Returns: 

2015 The created op 

2016 """ 

2017 if name is None: 

2018 name = "%s_size" % self._name 

2019 

2020 return gen_data_flow_ops.stage_size( 

2021 name=name, 

2022 shared_name=self._name, 

2023 dtypes=self._dtypes, 

2024 capacity=self._capacity, 

2025 memory_limit=self._memory_limit) 

2026 

2027 def clear(self, name=None): 

2028 """Clears the staging area. 

2029 

2030 Args: 

2031 name: A name for the operation (optional) 

2032 

2033 Returns: 

2034 The created op 

2035 """ 

2036 if name is None: 

2037 name = "%s_clear" % self._name 

2038 

2039 return gen_data_flow_ops.stage_clear( 

2040 name=name, 

2041 shared_name=self._name, 

2042 dtypes=self._dtypes, 

2043 capacity=self._capacity, 

2044 memory_limit=self._memory_limit) 

2045 

2046 

2047class MapStagingArea(BaseStagingArea): 

2048 """A `MapStagingArea` is a TensorFlow data structure that stores tensors 

2049 across multiple steps, and exposes operations that can put and get tensors. 

2050 

2051 Each `MapStagingArea` element is a (key, value) pair. 

2052 Only int64 keys are supported, other types should be 

2053 hashed to produce a key. 

2054 Values are a tuple of one or more tensors. 

2055 Each tuple component has a static dtype, 

2056 and may have a static shape. 

2057 

2058 The capacity of a `MapStagingArea` may be bounded or unbounded. 

2059 It supports multiple concurrent producers and consumers; and 

2060 provides exactly-once delivery. 

2061 

2062 Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors 

2063 whose 

2064 dtypes are described by `dtypes`, and whose shapes are optionally described 

2065 by the `shapes` argument. 

2066 

2067 If the `shapes` argument is specified, each component of a staging area 

2068 element must have the respective fixed shape. If it is 

2069 unspecified, different elements may have different shapes, 

2070 

2071 It behaves like an associative container with support for: 

2072 

2073 - put(key, values) 

2074 - peek(key) like dict.get(key) 

2075 - get(key) like dict.pop(key) 

2076 - get(key=None) like dict.popitem() 

2077 - size() 

2078 - clear() 

2079 

2080 If ordered a tree structure ordered by key will be used and 

2081 get(key=None) will remove (key, value) pairs in increasing key order. 

2082 Otherwise a hashtable 

2083 

2084 It can be configured with a capacity in which case 

2085 put(key, values) will block until space becomes available. 

2086 

2087 Similarly, it can be configured with a memory limit which 

2088 will block put(key, values) until space is available. 

2089 This is mostly useful for limiting the number of tensors on 

2090 devices such as GPUs. 

2091 

2092 All get() and peek() commands block if the requested 

2093 (key, value) pair is not present in the staging area. 

2094 

2095 Partial puts are supported and will be placed in an incomplete 

2096 map until such time as all values associated with the key have 

2097 been inserted. Once completed, this (key, value) pair will be 

2098 inserted into the map. Data in the incomplete map 

2099 counts towards the memory limit, but not towards capacity limit. 

2100 

2101 Partial gets from the map are also supported. 

2102 This removes the partially requested tensors from the entry, 

2103 but the entry is only removed from the map once all tensors 

2104 associated with it are removed. 

2105 """ 

2106 

2107 def __init__(self, 

2108 dtypes, 

2109 shapes=None, 

2110 names=None, 

2111 shared_name=None, 

2112 ordered=False, 

2113 capacity=0, 

2114 memory_limit=0): 

2115 """Args: 

2116 

2117 dtypes: A list of types. The length of dtypes must equal the number 

2118 of tensors in each element. 

2119 capacity: (Optional.) Maximum number of elements. 

2120 An integer. If zero, the Staging Area is unbounded 

2121 memory_limit: (Optional.) Maximum number of bytes of all tensors 

2122 in the Staging Area (excluding keys). 

2123 An integer. If zero, the Staging Area is unbounded 

2124 ordered: (Optional.) If True the underlying data structure 

2125 is a tree ordered on key. Otherwise assume a hashtable. 

2126 shapes: (Optional.) Constraints on the shapes of tensors in an element. 

2127 A list of shape tuples or None. This list is the same length 

2128 as dtypes. If the shape of any tensors in the element are constrained, 

2129 all must be; shapes can be None if the shapes should not be constrained. 

2130 names: (Optional.) If provided, the `get()` and 

2131 `put()` methods will use dictionaries with these names as keys. 

2132 Must be None or a list or tuple of the same length as `dtypes`. 

2133 shared_name: (Optional.) A name to be used for the shared object. By 

2134 passing the same name to two different python objects they will share 

2135 the underlying staging area. Must be a string. 

2136 

2137 Raises: 

2138 ValueError: If one of the arguments is invalid. 

2139 

2140 """ 

2141 

2142 super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name, 

2143 capacity, memory_limit) 

2144 

2145 # Defer to different methods depending if the map is ordered 

2146 self._ordered = ordered 

2147 

2148 if ordered: 

2149 self._put_fn = gen_data_flow_ops.ordered_map_stage 

2150 self._pop_fn = gen_data_flow_ops.ordered_map_unstage 

2151 self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key 

2152 self._peek_fn = gen_data_flow_ops.ordered_map_peek 

2153 self._size_fn = gen_data_flow_ops.ordered_map_size 

2154 self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size 

2155 self._clear_fn = gen_data_flow_ops.ordered_map_clear 

2156 else: 

2157 self._put_fn = gen_data_flow_ops.map_stage 

2158 self._pop_fn = gen_data_flow_ops.map_unstage 

2159 self._popitem_fn = gen_data_flow_ops.map_unstage_no_key 

2160 self._peek_fn = gen_data_flow_ops.map_peek 

2161 self._size_fn = gen_data_flow_ops.map_size 

2162 self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size 

2163 self._clear_fn = gen_data_flow_ops.map_clear 

2164 

2165 def put(self, key, vals, indices=None, name=None): 

2166 """Create an op that stores the (key, vals) pair in the staging area. 

2167 

2168 Incomplete puts are possible, preferably using a dictionary for vals 

2169 as the appropriate dtypes and shapes can be inferred from the value names 

2170 dictionary key values. If vals is a list or tuple, indices must 

2171 also be specified so that the op knows at which element position 

2172 to perform the insert. 

2173 

2174 This operation will block if the capacity or memory limit of this 

2175 container is reached. 

2176 

2177 Args: 

2178 key: Key associated with the data 

2179 vals: Tensor (or a dict/tuple of Tensors) to place 

2180 into the staging area. 

2181 indices: (Optional) if vals is a tuple/list, this is required. 

2182 name: A name for the operation (optional) 

2183 

2184 Returns: 

2185 The created op 

2186 

2187 Raises: 

2188 ValueError: If the number or type of inputs don't match the staging 

2189 area. 

2190 """ 

2191 

2192 with ops.name_scope(name, "%s_put" % self._name, 

2193 self._scope_vals(vals)) as scope: 

2194 

2195 vals, indices = self._check_put_dtypes(vals, indices) 

2196 

2197 with ops.colocate_with(self._coloc_op): 

2198 op = self._put_fn( 

2199 key, 

2200 indices, 

2201 vals, 

2202 dtypes=self._dtypes, 

2203 shared_name=self._name, 

2204 name=scope, 

2205 capacity=self._capacity, 

2206 memory_limit=self._memory_limit) 

2207 return op 

2208 

2209 def _get_indices_and_dtypes(self, indices=None): 

2210 if indices is None: 

2211 indices = list(range(len(self._dtypes))) 

2212 

2213 if not isinstance(indices, (tuple, list)): 

2214 raise TypeError(f"Invalid indices type {type(indices)}") 

2215 

2216 if len(indices) == 0: 

2217 raise ValueError("Empty indices") 

2218 

2219 if all(isinstance(i, str) for i in indices): 

2220 if self._names is None: 

2221 raise ValueError(f"String indices provided {indices}, but " 

2222 "this Staging Area was not created with names.") 

2223 

2224 try: 

2225 indices = [self._names.index(n) for n in indices] 

2226 except ValueError: 

2227 raise ValueError(f"Named index not in " 

2228 f"Staging Area names {self._names}") 

2229 elif all(isinstance(i, int) for i in indices): 

2230 pass 

2231 else: 

2232 raise TypeError(f"Mixed types in indices {indices}. " 

2233 "May only be str or int") 

2234 

2235 dtypes = [self._dtypes[i] for i in indices] 

2236 

2237 return indices, dtypes 

2238 

2239 def peek(self, key, indices=None, name=None): 

2240 """Peeks at staging area data associated with the key. 

2241 

2242 If the key is not in the staging area, it will block 

2243 until the associated (key, value) is inserted. 

2244 

2245 Args: 

2246 key: Key associated with the required data 

2247 indices: Partial list of tensors to retrieve (optional). 

2248 A list of integer or string indices. 

2249 String indices are only valid if the Staging Area 

2250 has names associated with it. 

2251 name: A name for the operation (optional) 

2252 

2253 Returns: 

2254 The created op 

2255 """ 

2256 

2257 if name is None: 

2258 name = "%s_pop" % self._name 

2259 

2260 indices, dtypes = self._get_indices_and_dtypes(indices) 

2261 

2262 with ops.colocate_with(self._coloc_op): 

2263 result = self._peek_fn( 

2264 key, 

2265 shared_name=self._name, 

2266 indices=indices, 

2267 dtypes=dtypes, 

2268 name=name, 

2269 capacity=self._capacity, 

2270 memory_limit=self._memory_limit) 

2271 

2272 return self._get_return_value(result, indices) 

2273 

2274 def get(self, key=None, indices=None, name=None): 

2275 """If the key is provided, the associated (key, value) is returned from the staging area. 

2276 

2277 If the key is not in the staging area, this method will block until 

2278 the associated (key, value) is inserted. 

2279 If no key is provided and the staging area is ordered, 

2280 the (key, value) with the smallest key will be returned. 

2281 Otherwise, a random (key, value) will be returned. 

2282 

2283 If the staging area is empty when this operation executes, 

2284 it will block until there is an element to dequeue. 

2285 

2286 Args: 

2287 key: Key associated with the required data (Optional) 

2288 indices: Partial list of tensors to retrieve (optional). 

2289 A list of integer or string indices. 

2290 String indices are only valid if the Staging Area 

2291 has names associated with it. 

2292 name: A name for the operation (optional) 

2293 

2294 Returns: 

2295 The created op 

2296 """ 

2297 if key is None: 

2298 return self._popitem(indices=indices, name=name) 

2299 else: 

2300 return self._pop(key, indices=indices, name=name) 

2301 

2302 def _pop(self, key, indices=None, name=None): 

2303 """Remove and return the associated (key, value) is returned from the staging area. 

2304 

2305 If the key is not in the staging area, this method will block until 

2306 the associated (key, value) is inserted. 

2307 Args: 

2308 key: Key associated with the required data 

2309 indices: Partial list of tensors to retrieve (optional). 

2310 A list of integer or string indices. 

2311 String indices are only valid if the Staging Area 

2312 has names associated with it. 

2313 name: A name for the operation (optional) 

2314 

2315 Returns: 

2316 The created op 

2317 """ 

2318 if name is None: 

2319 name = "%s_get" % self._name 

2320 

2321 indices, dtypes = self._get_indices_and_dtypes(indices) 

2322 

2323 with ops.colocate_with(self._coloc_op): 

2324 result = self._pop_fn( 

2325 key, 

2326 shared_name=self._name, 

2327 indices=indices, 

2328 dtypes=dtypes, 

2329 name=name, 

2330 capacity=self._capacity, 

2331 memory_limit=self._memory_limit) 

2332 

2333 return key, self._get_return_value(result, indices) 

2334 

2335 def _popitem(self, indices=None, name=None): 

2336 """If the staging area is ordered, the (key, value) with the smallest key will be returned. 

2337 

2338 Otherwise, a random (key, value) will be returned. 

2339 If the staging area is empty when this operation executes, 

2340 it will block until there is an element to dequeue. 

2341 

2342 Args: 

2343 key: Key associated with the required data 

2344 indices: Partial list of tensors to retrieve (optional). 

2345 A list of integer or string indices. 

2346 String indices are only valid if the Staging Area 

2347 has names associated with it. 

2348 name: A name for the operation (optional) 

2349 

2350 Returns: 

2351 The created op 

2352 """ 

2353 if name is None: 

2354 name = "%s_get_nokey" % self._name 

2355 

2356 indices, dtypes = self._get_indices_and_dtypes(indices) 

2357 

2358 with ops.colocate_with(self._coloc_op): 

2359 key, result = self._popitem_fn( 

2360 shared_name=self._name, 

2361 indices=indices, 

2362 dtypes=dtypes, 

2363 name=name, 

2364 capacity=self._capacity, 

2365 memory_limit=self._memory_limit) 

2366 

2367 # Separate keys and results out from 

2368 # underlying namedtuple 

2369 key = self._create_device_transfers(key)[0] 

2370 result = self._get_return_value(result, indices) 

2371 

2372 return key, result 

2373 

2374 def size(self, name=None): 

2375 """Returns the number of elements in the staging area. 

2376 

2377 Args: 

2378 name: A name for the operation (optional) 

2379 

2380 Returns: 

2381 The created op 

2382 """ 

2383 if name is None: 

2384 name = "%s_size" % self._name 

2385 

2386 return self._size_fn( 

2387 shared_name=self._name, 

2388 name=name, 

2389 dtypes=self._dtypes, 

2390 capacity=self._capacity, 

2391 memory_limit=self._memory_limit) 

2392 

2393 def incomplete_size(self, name=None): 

2394 """Returns the number of incomplete elements in the staging area. 

2395 

2396 Args: 

2397 name: A name for the operation (optional) 

2398 

2399 Returns: 

2400 The created op 

2401 """ 

2402 if name is None: 

2403 name = "%s_incomplete_size" % self._name 

2404 

2405 return self._incomplete_size_fn( 

2406 shared_name=self._name, 

2407 name=name, 

2408 dtypes=self._dtypes, 

2409 capacity=self._capacity, 

2410 memory_limit=self._memory_limit) 

2411 

2412 def clear(self, name=None): 

2413 """Clears the staging area. 

2414 

2415 Args: 

2416 name: A name for the operation (optional) 

2417 

2418 Returns: 

2419 The created op 

2420 """ 

2421 if name is None: 

2422 name = "%s_clear" % self._name 

2423 

2424 return self._clear_fn( 

2425 shared_name=self._name, 

2426 name=name, 

2427 dtypes=self._dtypes, 

2428 capacity=self._capacity, 

2429 memory_limit=self._memory_limit) 

2430 

2431 

2432class RecordInput: 

2433 """RecordInput asynchronously reads and randomly yields TFRecords. 

2434 

2435 A RecordInput Op will continuously read a batch of records asynchronously 

2436 into a buffer of some fixed capacity. It can also asynchronously yield 

2437 random records from this buffer. 

2438 

2439 It will not start yielding until at least `buffer_size / 2` elements have been 

2440 placed into the buffer so that sufficient randomization can take place. 

2441 

2442 The order the files are read will be shifted each epoch by `shift_amount` so 

2443 that the data is presented in a different order every epoch. 

2444 """ 

2445 

2446 def __init__(self, 

2447 file_pattern, 

2448 batch_size=1, 

2449 buffer_size=1, 

2450 parallelism=1, 

2451 shift_ratio=0, 

2452 seed=0, 

2453 name=None, 

2454 batches=None, 

2455 compression_type=None): 

2456 """Constructs a RecordInput Op. 

2457 

2458 Args: 

2459 file_pattern: File path to the dataset, possibly containing wildcards. 

2460 All matching files will be iterated over each epoch. 

2461 batch_size: How many records to return at a time. 

2462 buffer_size: The maximum number of records the buffer will contain. 

2463 parallelism: How many reader threads to use for reading from files. 

2464 shift_ratio: What percentage of the total number files to move the start 

2465 file forward by each epoch. 

2466 seed: Specify the random number seed used by generator that randomizes 

2467 records. 

2468 name: Optional name for the operation. 

2469 batches: None by default, creating a single batch op. Otherwise specifies 

2470 how many batches to create, which are returned as a list when 

2471 `get_yield_op()` is called. An example use case is to split processing 

2472 between devices on one computer. 

2473 compression_type: The type of compression for the file. Currently ZLIB and 

2474 GZIP are supported. Defaults to none. 

2475 

2476 Raises: 

2477 ValueError: If one of the arguments is invalid. 

2478 """ 

2479 self._batch_size = batch_size 

2480 if batches is not None: 

2481 self._batch_size *= batches 

2482 self._batches = batches 

2483 self._file_pattern = file_pattern 

2484 self._buffer_size = buffer_size 

2485 self._parallelism = parallelism 

2486 self._shift_ratio = shift_ratio 

2487 self._seed = seed 

2488 self._name = name 

2489 self._compression_type = python_io.TFRecordCompressionType.NONE 

2490 if compression_type is not None: 

2491 self._compression_type = compression_type 

2492 

2493 def get_yield_op(self): 

2494 """Adds a node that yields a group of records every time it is executed. 

2495 If RecordInput `batches` parameter is not None, it yields a list of 

2496 record batches with the specified `batch_size`. 

2497 """ 

2498 compression_type = python_io.TFRecordOptions.get_compression_type_string( 

2499 python_io.TFRecordOptions(self._compression_type)) 

2500 records = gen_data_flow_ops.record_input( 

2501 file_pattern=self._file_pattern, 

2502 file_buffer_size=self._buffer_size, 

2503 file_parallelism=self._parallelism, 

2504 file_shuffle_shift_ratio=self._shift_ratio, 

2505 batch_size=self._batch_size, 

2506 file_random_seed=self._seed, 

2507 compression_type=compression_type, 

2508 name=self._name) 

2509 if self._batches is None: 

2510 return records 

2511 else: 

2512 with ops.name_scope(self._name): 

2513 batch_list = [[] for _ in range(self._batches)] 

2514 records = array_ops.split(records, self._batch_size, 0) 

2515 for index, protobuf in enumerate(records): 

2516 batch_index = index % self._batches 

2517 batch_list[batch_index].append(array_ops.reshape(protobuf, [])) 

2518 return batch_list