Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/data_flow_ops.py: 24%
621 statements
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#==============================================================================
15"""Data Flow Operations."""
16# pylint: disable=g-bad-name
17import functools
18import hashlib
19import threading
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes as _dtypes
23from tensorflow.python.framework import indexed_slices
24from tensorflow.python.framework import ops
25from tensorflow.python.framework import random_seed
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import tensor_util
28from tensorflow.python.lib.io import python_io
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import array_ops_stack
31from tensorflow.python.ops import control_flow_ops
32from tensorflow.python.ops import gen_data_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import resource_variable_ops
35# go/tf-wildcard-import
36# pylint: disable=wildcard-import
37from tensorflow.python.ops.gen_data_flow_ops import *
38from tensorflow.python.util import deprecation
39from tensorflow.python.util.compat import collections_abc
40from tensorflow.python.util.tf_export import tf_export
42# pylint: enable=wildcard-import
45def _as_type_list(dtypes):
46 """Convert dtypes to a list of types."""
47 assert dtypes is not None
48 if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)):
49 # We have a single type.
50 return [dtypes]
51 else:
52 # We have a list or tuple of types.
53 return list(dtypes)
56def _as_shape_list(shapes,
57 dtypes,
58 unknown_dim_allowed=False,
59 unknown_rank_allowed=False):
60 """Convert shapes to a list of tuples of int (or None)."""
61 del dtypes
62 if unknown_dim_allowed:
63 if (not isinstance(shapes, collections_abc.Sequence) or not shapes or
64 any(shape is None or isinstance(shape, int) for shape in shapes)):
65 raise ValueError(
66 "When providing partial shapes, a list of shapes must be provided.")
67 if shapes is None:
68 return None
69 if isinstance(shapes, tensor_shape.TensorShape):
70 shapes = [shapes]
71 if not isinstance(shapes, (tuple, list)):
72 raise TypeError(
73 "Shapes must be a TensorShape or a list or tuple of TensorShapes, "
74 f"got {type(shapes)} instead.")
75 if all(shape is None or isinstance(shape, int) for shape in shapes):
76 # We have a single shape.
77 shapes = [shapes]
78 shapes = [tensor_shape.as_shape(shape) for shape in shapes]
79 if not unknown_dim_allowed:
80 if any(not shape.is_fully_defined() for shape in shapes):
81 raise ValueError(f"All shapes must be fully defined: {shapes}")
82 if not unknown_rank_allowed:
83 if any(shape.dims is None for shape in shapes):
84 raise ValueError(f"All shapes must have a defined rank: {shapes}")
86 return shapes
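# A small illustrative sketch (hypothetical `_example_*` helper, not part of
# the module) of how `_as_shape_list` normalizes the `shapes` argument that the
# queue constructors below pass in; it relies on the module-level imports above.
def _example_as_shape_list():
  # A single scalar shape and a list with one full shape both normalize to a
  # one-element list of `TensorShape`s; `None` is passed through unchanged.
  assert [s.as_list() for s in _as_shape_list((), [_dtypes.int32])] == [[]]
  assert [s.as_list() for s in _as_shape_list([(2, 3)], [_dtypes.int32])] == [[2, 3]]
  assert _as_shape_list(None, [_dtypes.int32]) is None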
89def _as_name_list(names, dtypes):
90 if names is None:
91 return None
92 if not isinstance(names, (list, tuple)):
93 names = [names]
94 if len(names) != len(dtypes):
95 raise ValueError("List of names must have the same length as the list "
96 f"of dtypes, received len(names)={len(names)},"
97 f"len(dtypes)={len(dtypes)}")
98 return list(names)
101def _shape_common(s1, s2):
102 """The greatest lower bound (ordered by specificity) TensorShape."""
103 s1 = tensor_shape.TensorShape(s1)
104 s2 = tensor_shape.TensorShape(s2)
105 if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims:
106 return tensor_shape.unknown_shape()
107 d = [
108 d1 if d1 is not None and d1 == d2 else None
109 for (d1, d2) in zip(s1.as_list(), s2.as_list())
110 ]
111 return tensor_shape.TensorShape(d)
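# A small illustrative sketch (hypothetical helper, not part of the module) of
# the "greatest lower bound" behaviour that `QueueBase.from_list` below uses to
# merge the shapes of several queues.
def _example_shape_common():
  # Dimensions survive only where both shapes pin the same value.
  assert _shape_common([2, 3], [2, 4]).as_list() == [2, None]
  # Mismatched ranks collapse to a completely unknown shape.
  assert _shape_common([2, 3], [2, 3, 4]).ndims is None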
114# pylint: disable=protected-access
115@tf_export("queue.QueueBase",
116 v1=["queue.QueueBase", "io.QueueBase", "QueueBase"])
117@deprecation.deprecated_endpoints(["io.QueueBase", "QueueBase"])
118class QueueBase:
119 """Base class for queue implementations.
121 A queue is a TensorFlow data structure that stores tensors across
122 multiple steps, and exposes operations that enqueue and dequeue
123 tensors.
125 Each queue element is a tuple of one or more tensors, where each
126 tuple component has a static dtype, and may have a static shape. The
127 queue implementations support versions of enqueue and dequeue that
128 handle single elements, and versions that enqueue and
129 dequeue a batch of elements at once.
131 See `tf.queue.FIFOQueue` and
132 `tf.queue.RandomShuffleQueue` for concrete
133 implementations of this class, and instructions on how to create
134 them.
135 """
137 def __init__(self, dtypes, shapes, names, queue_ref):
138 """Constructs a queue object from a queue reference.
140 The two optional lists, `shapes` and `names`, must be of the same length
141 as `dtypes` if provided. The values at a given index `i` indicate the
142 shape and name to use for the corresponding queue component in `dtypes`.
144 Args:
145 dtypes: A list of types. The length of dtypes must equal the number
146 of tensors in each element.
147 shapes: Constraints on the shapes of tensors in an element:
148 A list of shape tuples or None. This list is the same length
149 as dtypes. If the shape of any tensor in the element is constrained,
150 all must be; shapes can be None if the shapes should not be constrained.
151 names: Optional list of names. If provided, the `enqueue()` and
152 `dequeue()` methods will use dictionaries with these names as keys.
153 Must be None or a list or tuple of the same length as `dtypes`.
154 queue_ref: The queue reference, i.e. the output of the queue op.
156 Raises:
157 ValueError: If one of the arguments is invalid.
158 """
159 self._dtypes = dtypes
160 if shapes is not None:
161 if len(shapes) != len(dtypes):
162 raise ValueError("Queue shapes must have the same length as dtypes, "
163 f"received len(shapes)={len(shapes)}, "
164 f"len(dtypes)={len(dtypes)}")
165 self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
166 else:
167 self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
168 if names is not None:
169 if len(names) != len(dtypes):
170 raise ValueError("Queue names must have the same length as dtypes, "
171 f"received len(names)={len(names)}, "
172 f"len(dtypes)={len(dtypes)}")
173 self._names = names
174 else:
175 self._names = None
176 self._queue_ref = queue_ref
177 if isinstance(queue_ref, ops.EagerTensor):
178 if context.context().scope_name:
179 self._name = context.context().scope_name
180 else:
181 self._name = "Empty"
182 self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
183 queue_ref, None)
184 else:
185 self._name = self._queue_ref.op.name.split("/")[-1]
187 @staticmethod
188 def from_list(index, queues):
189 """Create a queue using the queue reference from `queues[index]`.
191 Args:
192 index: An integer scalar tensor that determines the input that gets
193 selected.
194 queues: A list of `QueueBase` objects.
196 Returns:
197 A `QueueBase` object.
199 Raises:
200 TypeError: When `queues` is not a list of `QueueBase` objects,
201 or when the data types of `queues` are not all the same.
202 """
203 if ((not queues) or (not isinstance(queues, list)) or
204 (not all(isinstance(x, QueueBase) for x in queues))):
205 raise TypeError("A list of queues expected")
207 dtypes = queues[0].dtypes
208 if not all(dtypes == q.dtypes for q in queues[1:]):
209 raise TypeError("Queues do not have matching component dtypes.")
211 names = queues[0].names
212 if not all(names == q.names for q in queues[1:]):
213 raise TypeError("Queues do not have matching component names.")
215 queue_shapes = [q.shapes for q in queues]
216 reduced_shapes = [
217 functools.reduce(_shape_common, s) for s in zip(*queue_shapes)
218 ]
220 queue_refs = array_ops_stack.stack([x.queue_ref for x in queues])
221 selected_queue = array_ops.gather(queue_refs, index)
222 return QueueBase(
223 dtypes=dtypes,
224 shapes=reduced_shapes,
225 names=names,
226 queue_ref=selected_queue)
228 @property
229 def queue_ref(self):
230 """The underlying queue reference."""
231 return self._queue_ref
233 @property
234 def name(self):
235 """The name of the underlying queue."""
236 if context.executing_eagerly():
237 return self._name
238 return self._queue_ref.op.name
240 @property
241 def dtypes(self):
242 """The list of dtypes for each component of a queue element."""
243 return self._dtypes
245 @property
246 def shapes(self):
247 """The list of shapes for each component of a queue element."""
248 return self._shapes
250 @property
251 def names(self):
252 """The list of names for each component of a queue element."""
253 return self._names
255 def _check_enqueue_dtypes(self, vals):
256 """Validate and convert `vals` to a list of `Tensor`s.
258 The `vals` argument can be a Tensor, a list or tuple of tensors, or a
259 dictionary with tensor values.
261 If it is a dictionary, the queue must have been constructed with a
262 `names` attribute and the dictionary keys must match the queue names.
263 If the queue was constructed with a `names` attribute, `vals` must
264 be a dictionary.
266 Args:
267 vals: A tensor, a list or tuple of tensors, or a dictionary.
269 Returns:
270 A list of `Tensor` objects.
272 Raises:
273 ValueError: If `vals` is invalid.
274 """
275 if isinstance(vals, dict):
276 if not self._names:
277 raise ValueError("Queue must have names to enqueue a dictionary")
278 if sorted(self._names, key=str) != sorted(vals.keys(), key=str):
279 raise ValueError("Keys in dictionary to enqueue do not match "
280 f"names of Queue. Dictionary: {sorted(vals.keys())},"
281 f"Queue: {sorted(self._names)}")
282 # The order of values in `self._names` indicates the order in which the
283 # tensors in the dictionary `vals` must be listed.
284 vals = [vals[k] for k in self._names]
285 else:
286 if self._names:
287 raise ValueError("You must enqueue a dictionary in a Queue with names")
288 if not isinstance(vals, (list, tuple)):
289 vals = [vals]
291 tensors = []
292 for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
293 tensors.append(
294 ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
296 return tensors
298 def _scope_vals(self, vals):
299 """Return a list of values to pass to `name_scope()`.
301 Args:
302 vals: A tensor, a list or tuple of tensors, or a dictionary.
304 Returns:
305 The values in vals as a list.
306 """
307 if isinstance(vals, (list, tuple)):
308 return vals
309 elif isinstance(vals, dict):
310 return vals.values()
311 else:
312 return [vals]
314 def enqueue(self, vals, name=None):
315 """Enqueues one element to this queue.
317 If the queue is full when this operation executes, it will block
318 until the element has been enqueued.
320 At runtime, this operation may raise an error if the queue is
321 closed (`tf.QueueBase.close`) before or during its execution. If the
322 queue is closed before this operation runs,
323 `tf.errors.CancelledError` will be raised. If this operation is
324 blocked, and either (i) the queue is closed by a close operation
325 with `cancel_pending_enqueues=True`, or (ii) the session is
326 closed (`tf.Session.close`),
327 `tf.errors.CancelledError` will be raised.
329 Args:
330 vals: A tensor, a list or tuple of tensors, or a dictionary containing
331 the values to enqueue.
332 name: A name for the operation (optional).
334 Returns:
335 The operation that enqueues a new tuple of tensors to the queue.
336 """
337 with ops.name_scope(name, "%s_enqueue" % self._name,
338 self._scope_vals(vals)) as scope:
339 vals = self._check_enqueue_dtypes(vals)
341 # NOTE(mrry): Not using a shape function because we need access to
342 # the `QueueBase` object.
343 for val, shape in zip(vals, self._shapes):
344 val.get_shape().assert_is_compatible_with(shape)
346 if self._queue_ref.dtype == _dtypes.resource:
347 return gen_data_flow_ops.queue_enqueue_v2(
348 self._queue_ref, vals, name=scope)
349 else:
350 return gen_data_flow_ops.queue_enqueue(
351 self._queue_ref, vals, name=scope)
353 def enqueue_many(self, vals, name=None):
354 """Enqueues zero or more elements to this queue.
356 This operation slices each component tensor along the 0th dimension to
357 make multiple queue elements. All of the tensors in `vals` must have the
358 same size in the 0th dimension.
360 If the queue is full when this operation executes, it will block
361 until all of the elements have been enqueued.
363 At runtime, this operation may raise an error if the queue is
364 closed (`tf.QueueBase.close`) before or during its execution. If the
365 queue is closed before this operation runs,
366 `tf.errors.CancelledError` will be raised. If this operation is
367 blocked, and either (i) the queue is closed by a close operation
368 with `cancel_pending_enqueues=True`, or (ii) the session is
369 closed (`tf.Session.close`),
370 `tf.errors.CancelledError` will be raised.
372 Args:
373 vals: A tensor, a list or tuple of tensors, or a dictionary
374 from which the queue elements are taken.
375 name: A name for the operation (optional).
377 Returns:
378 The operation that enqueues a batch of tuples of tensors to the queue.
379 """
380 with ops.name_scope(name, "%s_EnqueueMany" % self._name,
381 self._scope_vals(vals)) as scope:
382 vals = self._check_enqueue_dtypes(vals)
384 # NOTE(mrry): Not using a shape function because we need access to
385 # the `QueueBase` object.
386 # NOTE(fchollet): the code that follows is verbose because it needs to be
387 # compatible with both TF v1 TensorShape behavior and TF v2 behavior.
388 batch_dim = tensor_shape.dimension_value(
389 vals[0].get_shape().with_rank_at_least(1)[0])
390 batch_dim = tensor_shape.Dimension(batch_dim)
391 for val, shape in zip(vals, self._shapes):
392 val_batch_dim = tensor_shape.dimension_value(
393 val.get_shape().with_rank_at_least(1)[0])
394 val_batch_dim = tensor_shape.Dimension(val_batch_dim)
395 batch_dim = batch_dim.merge_with(val_batch_dim)
396 val.get_shape()[1:].assert_is_compatible_with(shape)
398 return gen_data_flow_ops.queue_enqueue_many_v2(
399 self._queue_ref, vals, name=scope)
401 def _dequeue_return_value(self, tensors):
402 """Return the value to return from a dequeue op.
404 If the queue has names, return a dictionary with the
405 names as keys. Otherwise return either a single tensor
406 or a list of tensors depending on the length of `tensors`.
408 Args:
409 tensors: List of tensors from the dequeue op.
411 Returns:
412 A single tensor, a list of tensors, or a dictionary
413 of tensors.
414 """
415 if self._names:
416 # The returned values in `tensors` are in the same order as
417 # the names in `self._names`.
418 return {n: tensors[i] for i, n in enumerate(self._names)}
419 elif len(tensors) == 1:
420 return tensors[0]
421 else:
422 return tensors
424 def dequeue(self, name=None):
425 """Dequeues one element from this queue.
427 If the queue is empty when this operation executes, it will block
428 until there is an element to dequeue.
430 At runtime, this operation may raise an error if the queue is
431 closed (`tf.QueueBase.close`) before or during its execution. If the
432 queue is closed, the queue is empty, and there are no pending
433 enqueue operations that can fulfill this request,
434 `tf.errors.OutOfRangeError` will be raised. If the session is
435 closed (`tf.Session.close`),
436 `tf.errors.CancelledError` will be raised.
438 Args:
439 name: A name for the operation (optional).
441 Returns:
442 The tuple of tensors that was dequeued.
443 """
444 if name is None:
445 name = "%s_Dequeue" % self._name
446 if self._queue_ref.dtype == _dtypes.resource:
447 ret = gen_data_flow_ops.queue_dequeue_v2(
448 self._queue_ref, self._dtypes, name=name)
449 else:
450 ret = gen_data_flow_ops.queue_dequeue(
451 self._queue_ref, self._dtypes, name=name)
453 # NOTE(mrry): Not using a shape function because we need access to
454 # the `QueueBase` object.
455 if not context.executing_eagerly():
456 op = ret[0].op
457 for output, shape in zip(op.values(), self._shapes):
458 output.set_shape(shape)
460 return self._dequeue_return_value(ret)
462 def dequeue_many(self, n, name=None):
463 """Dequeues and concatenates `n` elements from this queue.
465 This operation concatenates queue-element component tensors along
466 the 0th dimension to make a single component tensor. All of the
467 components in the dequeued tuple will have size `n` in the 0th dimension.
469 If the queue is closed and there are fewer than `n` elements left, then an
470 `OutOfRange` exception is raised.
472 At runtime, this operation may raise an error if the queue is
473 closed (`tf.QueueBase.close`) before or during its execution. If the
474 queue is closed, the queue contains fewer than `n` elements, and
475 there are no pending enqueue operations that can fulfill this
476 request, `tf.errors.OutOfRangeError` will be raised. If the
477 session is closed (`tf.Session.close`),
478 `tf.errors.CancelledError` will be raised.
480 Args:
481 n: A scalar `Tensor` containing the number of elements to dequeue.
482 name: A name for the operation (optional).
484 Returns:
485 The list of concatenated tensors that was dequeued.
486 """
487 if name is None:
488 name = "%s_DequeueMany" % self._name
490 ret = gen_data_flow_ops.queue_dequeue_many_v2(
491 self._queue_ref, n=n, component_types=self._dtypes, name=name)
493 # NOTE(mrry): Not using a shape function because we need access to
494 # the Queue object.
495 if not context.executing_eagerly():
496 op = ret[0].op
497 batch_dim = tensor_shape.Dimension(
498 tensor_util.constant_value(op.inputs[1]))
499 for output, shape in zip(op.values(), self._shapes):
500 output.set_shape(
501 tensor_shape.TensorShape([batch_dim]).concatenate(shape))
503 return self._dequeue_return_value(ret)
505 def dequeue_up_to(self, n, name=None):
506 """Dequeues and concatenates `n` elements from this queue.
508 **Note** This operation is not supported by all queues. If a queue does not
509 support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.
511 This operation concatenates queue-element component tensors along
512 the 0th dimension to make a single component tensor. If the queue
513 has not been closed, all of the components in the dequeued tuple
514 will have size `n` in the 0th dimension.
516 If the queue is closed and there are more than `0` but fewer than
517 `n` elements remaining, then instead of raising a
518 `tf.errors.OutOfRangeError` like `tf.QueueBase.dequeue_many`,
519 fewer than `n` elements are returned immediately. If the queue is
520 closed and there are `0` elements left in the queue, then a
521 `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
522 Otherwise the behavior is identical to `dequeue_many`.
524 Args:
525 n: A scalar `Tensor` containing the number of elements to dequeue.
526 name: A name for the operation (optional).
528 Returns:
529 The tuple of concatenated tensors that was dequeued.
530 """
531 if name is None:
532 name = "%s_DequeueUpTo" % self._name
534 ret = gen_data_flow_ops.queue_dequeue_up_to_v2(
535 self._queue_ref, n=n, component_types=self._dtypes, name=name)
537 # NOTE(mrry): Not using a shape function because we need access to
538 # the Queue object.
539 if not context.executing_eagerly():
540 op = ret[0].op
541 for output, shape in zip(op.values(), self._shapes):
542 output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
544 return self._dequeue_return_value(ret)
546 def close(self, cancel_pending_enqueues=False, name=None):
547 """Closes this queue.
549 This operation signals that no more elements will be enqueued in
550 the given queue. Subsequent `enqueue` and `enqueue_many`
551 operations will fail. Subsequent `dequeue` and `dequeue_many`
552 operations will continue to succeed if sufficient elements remain
553 in the queue. Subsequent dequeue and dequeue_many operations
554 that would otherwise block waiting for more elements (if close
555 hadn't been called) will now fail immediately.
557 If `cancel_pending_enqueues` is `True`, all pending requests will also
558 be canceled.
560 Args:
561 cancel_pending_enqueues: (Optional.) A boolean, defaulting to
562 `False` (described above).
563 name: A name for the operation (optional).
565 Returns:
566 The operation that closes the queue.
567 """
568 if name is None:
569 name = "%s_Close" % self._name
570 if self._queue_ref.dtype == _dtypes.resource:
571 return gen_data_flow_ops.queue_close_v2(
572 self._queue_ref,
573 cancel_pending_enqueues=cancel_pending_enqueues,
574 name=name)
575 else:
576 return gen_data_flow_ops.queue_close(
577 self._queue_ref,
578 cancel_pending_enqueues=cancel_pending_enqueues,
579 name=name)
581 def is_closed(self, name=None):
582 """Returns true if queue is closed.
584 This operation returns true if the queue is closed and false if the queue
585 is open.
587 Args:
588 name: A name for the operation (optional).
590 Returns:
591 True if the queue is closed and false if the queue is open.
592 """
593 if name is None:
594 name = "%s_Is_Closed" % self._name
595 if self._queue_ref.dtype == _dtypes.resource:
596 return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name)
597 else:
598 return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name)
600 def size(self, name=None):
601 """Compute the number of elements in this queue.
603 Args:
604 name: A name for the operation (optional).
606 Returns:
607 A scalar tensor containing the number of elements in this queue.
608 """
609 if name is None:
610 name = "%s_Size" % self._name
611 if self._queue_ref.dtype == _dtypes.resource:
612 return gen_data_flow_ops.queue_size_v2(self._queue_ref, name=name)
613 else:
614 return gen_data_flow_ops.queue_size(self._queue_ref, name=name)
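# A minimal usage sketch of the dictionary-based `names` interface described
# above (hypothetical `_example_*` helper; assumes TensorFlow 2.x eager
# execution, where these queue ops run directly).
def _example_named_queue_roundtrip():
  import tensorflow as tf

  q = tf.queue.FIFOQueue(capacity=2,
                         dtypes=[tf.int32, tf.string],
                         shapes=[(), ()],
                         names=["step", "tag"])
  # With `names` set, enqueue and dequeue exchange dictionaries keyed by name.
  q.enqueue({"step": 1, "tag": "a"})
  assert int(q.size()) == 1
  element = q.dequeue()
  assert int(element["step"]) == 1 and element["tag"].numpy() == b"a"
  q.close()
  assert bool(q.is_closed())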
616def _shared_name(shared_name):
617 if context.executing_eagerly():
618 return str(ops.uid())
619 return shared_name
622@tf_export(
623 "queue.RandomShuffleQueue",
624 v1=["queue.RandomShuffleQueue",
625 "io.RandomShuffleQueue", "RandomShuffleQueue"])
626@deprecation.deprecated_endpoints(
627 ["io.RandomShuffleQueue", "RandomShuffleQueue"])
628class RandomShuffleQueue(QueueBase):
629 """A queue implementation that dequeues elements in a random order.
631 See `tf.queue.QueueBase` for a description of the methods on
632 this class.
633 """
635 def __init__(self,
636 capacity,
637 min_after_dequeue,
638 dtypes,
639 shapes=None,
640 names=None,
641 seed=None,
642 shared_name=None,
643 name="random_shuffle_queue"):
644 """Create a queue that dequeues elements in a random order.
646 A `RandomShuffleQueue` has bounded capacity; supports multiple
647 concurrent producers and consumers; and provides exactly-once
648 delivery.
650 A `RandomShuffleQueue` holds a list of up to `capacity`
651 elements. Each element is a fixed-length tuple of tensors whose
652 dtypes are described by `dtypes`, and whose shapes are optionally
653 described by the `shapes` argument.
655 If the `shapes` argument is specified, each component of a queue
656 element must have the respective fixed shape. If it is
657 unspecified, different queue elements may have different shapes,
658 but the use of `dequeue_many` is disallowed.
660 The `min_after_dequeue` argument allows the caller to specify a
661 minimum number of elements that will remain in the queue after a
662 `dequeue` or `dequeue_many` operation completes, to ensure a
663 minimum level of mixing of elements. This invariant is maintained
664 by blocking those operations until sufficient elements have been
665 enqueued. The `min_after_dequeue` argument is ignored after the
666 queue has been closed.
668 Args:
669 capacity: An integer. The upper bound on the number of elements
670 that may be stored in this queue.
671 min_after_dequeue: An integer (described above).
672 dtypes: A list of `DType` objects. The length of `dtypes` must equal
673 the number of tensors in each queue element.
674 shapes: (Optional.) A list of fully-defined `TensorShape` objects
675 with the same length as `dtypes`, or `None`.
676 names: (Optional.) A list of strings naming the components in the queue
677 with the same length as `dtypes`, or `None`. If specified the dequeue
678 methods return a dictionary with the names as keys.
679 seed: A Python integer. Used to create a random seed. See
680 `tf.compat.v1.set_random_seed`
681 for behavior.
682 shared_name: (Optional.) If non-empty, this queue will be shared under
683 the given name across multiple sessions.
684 name: Optional name for the queue operation.
685 """
686 dtypes = _as_type_list(dtypes)
687 shapes = _as_shape_list(shapes, dtypes)
688 names = _as_name_list(names, dtypes)
689 seed1, seed2 = random_seed.get_seed(seed)
690 if seed1 is None and seed2 is None:
691 seed1, seed2 = 0, 0
692 elif seed is None and shared_name is not None:
693 # This means that graph seed is provided but op seed is not provided.
694 # If shared_name is also provided, make seed2 depend only on the graph
695 # seed and shared_name. (seed2 from get_seed() is generally dependent on
696 # the id of the last op created.)
697 string = (str(seed1) + shared_name).encode("utf-8")
698 seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
699 queue_ref = gen_data_flow_ops.random_shuffle_queue_v2(
700 component_types=dtypes,
701 shapes=shapes,
702 capacity=capacity,
703 min_after_dequeue=min_after_dequeue,
704 seed=seed1,
705 seed2=seed2,
706 shared_name=_shared_name(shared_name),
707 name=name)
709 super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
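# A minimal usage sketch (hypothetical helper; assumes TensorFlow 2.x eager
# execution, so every blocking op below must already be satisfiable when it
# runs, since there is no separate producer thread in this sketch).
def _example_random_shuffle_queue():
  import tensorflow as tf

  q = tf.queue.RandomShuffleQueue(
      capacity=10, min_after_dequeue=2, dtypes=[tf.int32], shapes=[()], seed=7)
  q.enqueue_many([[0, 1, 2, 3, 4]])
  # dequeue_many(3) may proceed because min_after_dequeue=2 elements remain.
  batch = q.dequeue_many(3).numpy()
  assert set(batch) <= {0, 1, 2, 3, 4}  # a randomly chosen 3-element subset
  assert int(q.size()) == 2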
712@tf_export("queue.FIFOQueue", v1=["queue.FIFOQueue", "FIFOQueue"])
713@deprecation.deprecated_endpoints("FIFOQueue")
714class FIFOQueue(QueueBase):
715 """A queue implementation that dequeues elements in first-in first-out order.
717 See `tf.queue.QueueBase` for a description of the methods on
718 this class.
719 """
721 def __init__(self,
722 capacity,
723 dtypes,
724 shapes=None,
725 names=None,
726 shared_name=None,
727 name="fifo_queue"):
728 """Creates a queue that dequeues elements in a first-in first-out order.
730 A `FIFOQueue` has bounded capacity; supports multiple concurrent
731 producers and consumers; and provides exactly-once delivery.
733 A `FIFOQueue` holds a list of up to `capacity` elements. Each
734 element is a fixed-length tuple of tensors whose dtypes are
735 described by `dtypes`, and whose shapes are optionally described
736 by the `shapes` argument.
738 If the `shapes` argument is specified, each component of a queue
739 element must have the respective fixed shape. If it is
740 unspecified, different queue elements may have different shapes,
741 but the use of `dequeue_many` is disallowed.
743 Args:
744 capacity: An integer. The upper bound on the number of elements
745 that may be stored in this queue.
746 dtypes: A list of `DType` objects. The length of `dtypes` must equal
747 the number of tensors in each queue element.
748 shapes: (Optional.) A list of fully-defined `TensorShape` objects
749 with the same length as `dtypes`, or `None`.
750 names: (Optional.) A list of strings naming the components in the queue
751 with the same length as `dtypes`, or `None`. If specified the dequeue
752 methods return a dictionary with the names as keys.
753 shared_name: (Optional.) If non-empty, this queue will be shared under
754 the given name across multiple sessions.
755 name: Optional name for the queue operation.
756 """
757 dtypes = _as_type_list(dtypes)
758 shapes = _as_shape_list(shapes, dtypes)
759 names = _as_name_list(names, dtypes)
760 with ops.init_scope(), ops.device("CPU"):
761 queue_ref = gen_data_flow_ops.fifo_queue_v2(
762 component_types=dtypes,
763 shapes=shapes,
764 capacity=capacity,
765 shared_name=_shared_name(shared_name),
766 name=name)
768 super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
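# A minimal usage sketch (hypothetical helper; assumes TensorFlow 2.x eager
# execution).
def _example_fifo_queue():
  import tensorflow as tf

  q = tf.queue.FIFOQueue(capacity=4, dtypes=[tf.float32], shapes=[()])
  q.enqueue(1.0)
  q.enqueue_many([[2.0, 3.0]])  # sliced along dimension 0 into two elements
  # Elements come back in exactly the order they were enqueued.
  assert q.dequeue_many(3).numpy().tolist() == [1.0, 2.0, 3.0]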
771# TODO(allenl): If GPU-compatible queues turn out to be useful, we should
772# implement GPU kernels for EnqueueMany and DequeueMany so we can make the
773# public FIFOQueue GPU-compatible and remove this internal version.
774class GPUCompatibleFIFOQueue(QueueBase):
775 """A queue implementation that dequeues elements in first-in first-out order.
777 GPUCompatibleFIFOQueue is like FIFOQueue, but the queue resource may be placed
778 either on a CPU or on a GPU. It is not cross-device: enqueues and dequeues
779 will be colocated with the queue resource. GPUCompatibleFIFOQueue only
780 supports enqueue and dequeue at the moment, not enqueue_many or dequeue_many.
782 See `tf.queue.QueueBase` for a description of the methods on this class.
783 """
785 def __init__(self,
786 capacity,
787 dtypes,
788 shapes=None,
789 names=None,
790 shared_name=None,
791 name="fifo_queue"):
792 """Creates a queue that dequeues elements in a first-in first-out order.
794 A `FIFOQueue` has bounded capacity; supports multiple concurrent
795 producers and consumers; and provides exactly-once delivery.
797 A `FIFOQueue` holds a list of up to `capacity` elements. Each
798 element is a fixed-length tuple of tensors whose dtypes are
799 described by `dtypes`, and whose shapes are optionally described
800 by the `shapes` argument.
802 If the `shapes` argument is specified, each component of a queue
803 element must have the respective fixed shape. If it is
804 unspecified, different queue elements may have different shapes,
805 but the use of `dequeue_many` is disallowed.
807 Args:
808 capacity: An integer. The upper bound on the number of elements
809 that may be stored in this queue.
810 dtypes: A list of `DType` objects. The length of `dtypes` must equal
811 the number of tensors in each queue element.
812 shapes: (Optional.) A list of fully-defined `TensorShape` objects
813 with the same length as `dtypes`, or `None`.
814 names: (Optional.) A list of strings naming the components in the queue
815 with the same length as `dtypes`, or `None`. If specified the dequeue
816 methods return a dictionary with the names as keys.
817 shared_name: (Optional.) If non-empty, this queue will be shared under
818 the given name across multiple sessions.
819 name: Optional name for the queue operation.
820 """
821 dtypes = _as_type_list(dtypes)
822 shapes = _as_shape_list(shapes, dtypes)
823 names = _as_name_list(names, dtypes)
824 with ops.init_scope():
825 queue_ref = gen_data_flow_ops.fifo_queue_v2(
826 component_types=dtypes,
827 shapes=shapes,
828 capacity=capacity,
829 shared_name=_shared_name(shared_name),
830 name=name)
832 super(GPUCompatibleFIFOQueue, self).__init__(
833 dtypes, shapes, names, queue_ref)
835 def enqueue_many(self, vals, name=None):
836 """enqueue_many is not supported on GPUCompatibleFIFOQueue."""
837 raise NotImplementedError(
838 "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, "
839 "only enqueue and dequeue.")
841 def dequeue_many(self, n, name=None):
842 """dequeue_many is not supported on GPUCompatibleFIFOQueue."""
843 raise NotImplementedError(
844 "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, "
845 "only enqueue and dequeue.")
848@tf_export(
849 "queue.PaddingFIFOQueue",
850 v1=["queue.PaddingFIFOQueue", "io.PaddingFIFOQueue", "PaddingFIFOQueue"])
851@deprecation.deprecated_endpoints(["io.PaddingFIFOQueue", "PaddingFIFOQueue"])
852class PaddingFIFOQueue(QueueBase):
853 """A FIFOQueue that supports batching variable-sized tensors by padding.
855 A `PaddingFIFOQueue` may contain components with dynamic shape, while also
856 supporting `dequeue_many`. See the constructor for more details.
858 See `tf.queue.QueueBase` for a description of the methods on
859 this class.
860 """
862 def __init__(self,
863 capacity,
864 dtypes,
865 shapes,
866 names=None,
867 shared_name=None,
868 name="padding_fifo_queue"):
869 """Creates a queue that dequeues elements in a first-in first-out order.
871 A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent
872 producers and consumers; and provides exactly-once delivery.
874 A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each
875 element is a fixed-length tuple of tensors whose dtypes are
876 described by `dtypes`, and whose shapes are described by the `shapes`
877 argument.
879 The `shapes` argument must be specified; each component of a queue
880 element must have the respective shape. Shapes of fixed
881 rank but variable size are allowed by setting any shape dimension to None.
882 In this case, the inputs' shape may vary along the given dimension, and
883 `dequeue_many` will pad the given dimension with zeros up to the maximum
884 shape of all elements in the given batch.
886 Args:
887 capacity: An integer. The upper bound on the number of elements
888 that may be stored in this queue.
889 dtypes: A list of `DType` objects. The length of `dtypes` must equal
890 the number of tensors in each queue element.
891 shapes: A list of `TensorShape` objects, with the same length as
892 `dtypes`. Any dimension in the `TensorShape` containing value
893 `None` is dynamic and allows values to be enqueued with
894 variable size in that dimension.
895 names: (Optional.) A list of strings naming the components in the queue
896 with the same length as `dtypes`, or `None`. If specified the dequeue
897 methods return a dictionary with the names as keys.
898 shared_name: (Optional.) If non-empty, this queue will be shared under
899 the given name across multiple sessions.
900 name: Optional name for the queue operation.
902 Raises:
903 ValueError: If shapes is not a list of shapes, or the lengths of dtypes
904 and shapes do not match, or if names is specified and the lengths of
905 dtypes and names do not match.
906 """
907 dtypes = _as_type_list(dtypes)
908 shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True)
909 names = _as_name_list(names, dtypes)
910 if len(dtypes) != len(shapes):
911 raise ValueError("Shapes must be provided for all components, "
912 f"but received {len(dtypes)} dtypes and "
913 f"{len(shapes)} shapes.")
914 queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
915 component_types=dtypes,
916 shapes=shapes,
917 capacity=capacity,
918 shared_name=_shared_name(shared_name),
919 name=name)
921 super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
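# A minimal usage sketch of zero-padding on the dynamic dimension (hypothetical
# helper; assumes TensorFlow 2.x eager execution).
def _example_padding_fifo_queue():
  import tensorflow as tf

  # The single component has shape (None,), so elements may vary in length.
  q = tf.queue.PaddingFIFOQueue(capacity=4, dtypes=[tf.int32], shapes=[(None,)])
  q.enqueue([[1, 2]])
  q.enqueue([[3]])
  # dequeue_many pads the dynamic dimension with zeros up to the batch maximum.
  assert q.dequeue_many(2).numpy().tolist() == [[1, 2], [3, 0]]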
924@tf_export("queue.PriorityQueue",
925 v1=["queue.PriorityQueue", "io.PriorityQueue", "PriorityQueue"])
926@deprecation.deprecated_endpoints(["io.PriorityQueue", "PriorityQueue"])
927class PriorityQueue(QueueBase):
928 """A queue implementation that dequeues elements in prioritized order.
930 See `tf.queue.QueueBase` for a description of the methods on
931 this class.
932 """
934 def __init__(self,
935 capacity,
936 types,
937 shapes=None,
938 names=None,
939 shared_name=None,
940 name="priority_queue"):
941 """Creates a queue that dequeues elements in a first-in first-out order.
943 A `PriorityQueue` has bounded capacity; supports multiple concurrent
944 producers and consumers; and provides exactly-once delivery.
946 A `PriorityQueue` holds a list of up to `capacity` elements. Each
947 element is a fixed-length tuple of tensors whose dtypes are
948 described by `types`, and whose shapes are optionally described
949 by the `shapes` argument.
951 If the `shapes` argument is specified, each component of a queue
952 element must have the respective fixed shape. If it is
953 unspecified, different queue elements may have different shapes,
954 but the use of `dequeue_many` is disallowed.
956 Enqueues and Dequeues to the `PriorityQueue` must include an additional
957 tuple entry at the beginning: the `priority`. The priority must be
958 an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`).
960 Args:
961 capacity: An integer. The upper bound on the number of elements
962 that may be stored in this queue.
963 types: A list of `DType` objects. The length of `types` must equal
964 the number of tensors in each queue element, not counting the leading
965 priority component. The first tensor in each element is the priority,
966 which must be of type int64.
967 shapes: (Optional.) A list of fully-defined `TensorShape` objects,
968 with the same length as `types`, or `None`.
969 names: (Optional.) A list of strings naming the components in the queue
970 with the same length as `types`, or `None`. If specified, the dequeue
971 methods return a dictionary with the names as keys.
972 shared_name: (Optional.) If non-empty, this queue will be shared under
973 the given name across multiple sessions.
974 name: Optional name for the queue operation.
975 """
976 types = _as_type_list(types)
977 shapes = _as_shape_list(shapes, types)
979 queue_ref = gen_data_flow_ops.priority_queue_v2(
980 component_types=types,
981 shapes=shapes,
982 capacity=capacity,
983 shared_name=_shared_name(shared_name),
984 name=name)
986 priority_dtypes = [_dtypes.int64] + types
987 priority_shapes = [()] + shapes if shapes else shapes
989 super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names,
990 queue_ref)
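# A minimal usage sketch of the extra leading int64 priority component
# (hypothetical helper; assumes TensorFlow 2.x eager execution).
def _example_priority_queue():
  import tensorflow as tf

  q = tf.queue.PriorityQueue(capacity=4, types=[tf.string], shapes=[()])
  q.enqueue((tf.constant(2, tf.int64), "low"))
  q.enqueue((tf.constant(1, tf.int64), "high"))
  # Elements come back ordered by ascending priority value.
  priority, value = q.dequeue()
  assert int(priority) == 1 and value.numpy() == b"high"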
993# TODO(josh11b): class BatchQueue(QueueBase):
996class Barrier:
997 """Represents a key-value map that persists across graph executions."""
999 def __init__(self, types, shapes=None, shared_name=None, name="barrier"):
1000 """Creates a barrier that persists across different graph executions.
1002 A barrier represents a key-value map, where each key is a string, and
1003 each value is a tuple of tensors.
1005 At runtime, the barrier contains 'complete' and 'incomplete'
1006 elements. A complete element has defined tensors for all
1007 components of its value tuple, and may be accessed using
1008 take_many. An incomplete element has some undefined components in
1009 its value tuple, and may be updated using insert_many.
1011 The barrier call `take_many` outputs values in a particular order.
1012 First, it only outputs completed values. Second, the order in which
1013 completed values are returned matches the order in which their very
1014 first component was inserted into the barrier. So, for example, for this
1015 sequence of insertions and removals:
1017 barrier = Barrier((tf.string, tf.int32), shapes=((), ()))
1018 barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run()
1019 barrier.insert_many(1, keys=["k1"], values=[1]).run()
1020 barrier.insert_many(0, keys=["k3"], values=["c"]).run()
1021 barrier.insert_many(1, keys=["k3"], values=[3]).run()
1022 barrier.insert_many(1, keys=["k2"], values=[2]).run()
1024 (indices, keys, values) = barrier.take_many(2)
1025 (indices_val, keys_val, values0_val, values1_val) =
1026 session.run([indices, keys, values[0], values[1]])
1028 The output will be (up to permutation of "k1" and "k2"):
1030 indices_val == (-2**63, -2**63)
1031 keys_val == ("k1", "k2")
1032 values0_val == ("a", "b")
1033 values1_val == (1, 2)
1035 Note the key "k2" was inserted into the barrier before "k3". Even though
1036 "k3" was completed first, both are complete by the time
1037 take_many is called. As a result, "k2" is prioritized and "k1" and "k2"
1038 are returned first. "k3" remains in the barrier until the next execution
1039 of `take_many`. Since "k1" and "k2" had their first insertions into
1040 the barrier together, their indices are the same (-2**63). The index
1041 of "k3" will be -2**63 + 1, because it was the next new inserted key.
1043 Args:
1044 types: A single dtype or a tuple of dtypes, corresponding to the
1045 dtypes of the tensor elements that comprise a value in this barrier.
1046 shapes: Optional. Constraints on the shapes of tensors in the values:
1047 a single tensor shape tuple; a tuple of tensor shape tuples
1048 for each barrier-element tuple component; or None if the shape should
1049 not be constrained.
1050 shared_name: Optional. If non-empty, this barrier will be shared under
1051 the given name across multiple sessions.
1052 name: Optional name for the barrier op.
1054 Raises:
1055 ValueError: If one of the `shapes` indicate no elements.
1056 """
1057 self._types = _as_type_list(types)
1059 if shapes is not None:
1060 shapes = _as_shape_list(shapes, self._types)
1061 self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
1062 for i, shape in enumerate(self._shapes):
1063 if shape.num_elements() == 0:
1064 raise ValueError("Empty tensors are not supported, but received "
1065 f"shape '{shape}' at index {i}")
1066 else:
1067 self._shapes = [tensor_shape.unknown_shape() for _ in self._types]
1069 self._barrier_ref = gen_data_flow_ops.barrier(
1070 component_types=self._types,
1071 shapes=self._shapes,
1072 shared_name=shared_name,
1073 name=name)
1074 if context.executing_eagerly():
1075 self._name = context.context().scope_name
1076 else:
1077 self._name = self._barrier_ref.op.name.split("/")[-1]
1079 @property
1080 def barrier_ref(self):
1081 """Get the underlying barrier reference."""
1082 return self._barrier_ref
1084 @property
1085 def name(self):
1086 """The name of the underlying barrier."""
1087 if context.executing_eagerly():
1088 return self._name
1089 return self._barrier_ref.op.name
1091 def insert_many(self, component_index, keys, values, name=None):
1092 """For each key, assigns the respective value to the specified component.
1094 This operation updates each element at component_index.
1096 Args:
1097 component_index: The component of the value that is being assigned.
1098 keys: A vector of keys, with length n.
1099 values: An any-dimensional tensor of values, which are associated with the
1100 respective keys. The first dimension must have length n.
1101 name: Optional name for the op.
1103 Returns:
1104 The operation that performs the insertion.
1105 Raises:
1106 InvalidArgumentError: If inserting keys and values without elements.
1107 """
1108 if name is None:
1109 name = "%s_BarrierInsertMany" % self._name
1110 return gen_data_flow_ops.barrier_insert_many(
1111 self._barrier_ref, keys, values, component_index, name=name)
1113 def take_many(self,
1114 num_elements,
1115 allow_small_batch=False,
1116 timeout=None,
1117 name=None):
1118 """Takes the given number of completed elements from this barrier.
1120 This operation concatenates completed-element component tensors along
1121 the 0th dimension to make a single component tensor.
1123 If the barrier has no completed elements, this operation will block
1124 until there are 'num_elements' elements to take.
1126 TODO(b/25743580): the semantics of `allow_small_batch` are experimental
1127 and may be extended to other cases in the future.
1129 TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
1130 already when the barrier is closed, it will block forever. Fix this
1131 by using asynchronous operations.
1133 Args:
1134 num_elements: The number of elements to take.
1135 allow_small_batch: If the barrier is closed, don't block if there are fewer
1136 completed elements than requested, but instead return all available
1137 completed elements.
1138 timeout: This specifies the number of milliseconds to block
1139 before returning with DEADLINE_EXCEEDED. (This option is not
1140 supported yet.)
1141 name: A name for the operation (optional).
1143 Returns:
1144 A tuple of (index, key, value_list).
1145 "index" is a int64 tensor of length num_elements containing the
1146 index of the insert_many call for which the very first component of
1147 the given element was inserted into the Barrier, starting with
1148 the value -2**63. Note, this value is different from the
1149 index of the insert_many call for which the element was completed.
1150 "key" is a string tensor of length num_elements containing the keys.
1151 "value_list" is a tuple of tensors, each one with size num_elements
1152 in the 0th dimension for each component in the barrier's values.
1154 """
1155 if name is None:
1156 name = "%s_BarrierTakeMany" % self._name
1157 ret = gen_data_flow_ops.barrier_take_many(
1158 self._barrier_ref,
1159 num_elements,
1160 self._types,
1161 allow_small_batch,
1162 timeout,
1163 name=name)
1165 # NOTE(mrry): Not using a shape function because we need access to
1166 # the Barrier object.
1167 if not context.executing_eagerly():
1168 op = ret[0].op
1169 if allow_small_batch:
1170 batch_dim = None
1171 else:
1172 batch_dim = tensor_shape.Dimension(
1173 tensor_util.constant_value(op.inputs[1]))
1174 op.outputs[0].set_shape(tensor_shape.TensorShape([batch_dim])) # indices
1175 op.outputs[1].set_shape(tensor_shape.TensorShape([batch_dim])) # keys
1176 for output, shape in zip(op.outputs[2:], self._shapes): # value_list
1177 output.set_shape(
1178 tensor_shape.TensorShape([batch_dim]).concatenate(shape))
1180 return ret
1182 def close(self, cancel_pending_enqueues=False, name=None):
1183 """Closes this barrier.
1185 This operation signals that no more new key values will be inserted in the
1186 given barrier. Subsequent InsertMany operations with new keys will fail.
1187 InsertMany operations that just complement already existing keys with other
1188 components will continue to succeed. Subsequent TakeMany operations will
1189 continue to succeed if sufficient elements remain in the barrier. Subsequent
1190 TakeMany operations that would block will fail immediately.
1192 If `cancel_pending_enqueues` is `True`, all pending requests to the
1193 underlying queue will also be canceled, and completing already
1194 started values will no longer be possible.
1196 Args:
1197 cancel_pending_enqueues: (Optional.) A boolean, defaulting to
1198 `False` (described above).
1199 name: Optional name for the op.
1201 Returns:
1202 The operation that closes the barrier.
1203 """
1204 if name is None:
1205 name = "%s_BarrierClose" % self._name
1206 return gen_data_flow_ops.barrier_close(
1207 self._barrier_ref,
1208 cancel_pending_enqueues=cancel_pending_enqueues,
1209 name=name)
1211 def ready_size(self, name=None):
1212 """Compute the number of complete elements in the given barrier.
1214 Args:
1215 name: A name for the operation (optional).
1217 Returns:
1218 A single-element tensor containing the number of complete elements in the
1219 given barrier.
1220 """
1221 if name is None:
1222 name = "%s_BarrierReadySize" % self._name
1223 return gen_data_flow_ops.barrier_ready_size(self._barrier_ref, name=name)
1225 def incomplete_size(self, name=None):
1226 """Compute the number of incomplete elements in the given barrier.
1228 Args:
1229 name: A name for the operation (optional).
1231 Returns:
1232 A single-element tensor containing the number of incomplete elements in
1233 the given barrier.
1234 """
1235 if name is None:
1236 name = "%s_BarrierIncompleteSize" % self._name
1237 return gen_data_flow_ops.barrier_incomplete_size(
1238 self._barrier_ref, name=name)
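# A minimal sketch of the insert/take workflow from the class docstring
# (hypothetical helper; assumes TF1-style graph execution driven through
# `tf.compat.v1.Session`, since the barrier handle is a reference-typed tensor,
# and access through this non-public module).
def _example_barrier():
  import tensorflow.compat.v1 as tf1
  from tensorflow.python.ops import data_flow_ops

  with tf1.Graph().as_default(), tf1.Session() as sess:
    b = data_flow_ops.Barrier((tf1.string, tf1.int32), shapes=((), ()))
    insert_strings = b.insert_many(0, keys=["k1"], values=["a"])
    insert_ints = b.insert_many(1, keys=["k1"], values=[1])
    indices, keys, values = b.take_many(1)
    sess.run([insert_strings, insert_ints])  # "k1" is now complete
    _, keys_val, (strings_val, ints_val) = sess.run([indices, keys, values])
    assert keys_val.tolist() == [b"k1"]
    assert strings_val.tolist() == [b"a"] and ints_val.tolist() == [1]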
1241@tf_export(v1=["ConditionalAccumulatorBase"])
1242class ConditionalAccumulatorBase:
1243 """A conditional accumulator for aggregating gradients.
1245 Up-to-date gradients (i.e., time step at which gradient was computed is
1246 equal to the accumulator's time step) are added to the accumulator.
1248 Extraction of the average gradient is blocked until the required number of
1249 gradients has been accumulated.
1250 """
1252 def __init__(self, dtype, shape, accumulator_ref):
1253 """Creates a new ConditionalAccumulator.
1255 Args:
1256 dtype: Datatype of the accumulated gradients.
1257 shape: Shape of the accumulated gradients.
1258 accumulator_ref: A handle to the conditional accumulator, created by
1259 subclasses.
1260 """
1261 self._dtype = dtype
1262 if shape is not None:
1263 self._shape = tensor_shape.TensorShape(shape)
1264 else:
1265 self._shape = tensor_shape.unknown_shape()
1266 self._accumulator_ref = accumulator_ref
1267 if context.executing_eagerly():
1268 self._name = context.context().scope_name
1269 else:
1270 self._name = self._accumulator_ref.op.name.split("/")[-1]
1272 @property
1273 def accumulator_ref(self):
1274 """The underlying accumulator reference."""
1275 return self._accumulator_ref
1277 @property
1278 def name(self):
1279 """The name of the underlying accumulator."""
1280 return self._name
1282 @property
1283 def dtype(self):
1284 """The datatype of the gradients accumulated by this accumulator."""
1285 return self._dtype
1287 def num_accumulated(self, name=None):
1288 """Number of gradients that have currently been aggregated in accumulator.
1290 Args:
1291 name: Optional name for the operation.
1293 Returns:
1294 Number of accumulated gradients currently in accumulator.
1295 """
1296 if name is None:
1297 name = "%s_NumAccumulated" % self._name
1299 return gen_data_flow_ops.resource_accumulator_num_accumulated(
1300 self._accumulator_ref, name=name)
1302 def set_global_step(self, new_global_step, name=None):
1303 """Sets the global time step of the accumulator.
1305 The operation logs a warning if we attempt to set it to a time step that is
1306 lower than the accumulator's own time step.
1308 Args:
1309 new_global_step: Value of new time step. Can be a variable or a constant
1310 name: Optional name for the operation.
1312 Returns:
1313 Operation that sets the accumulator's time step.
1314 """
1315 return gen_data_flow_ops.resource_accumulator_set_global_step(
1316 self._accumulator_ref,
1317 math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
1318 name=name)
1321@tf_export(v1=["ConditionalAccumulator"])
1322class ConditionalAccumulator(ConditionalAccumulatorBase):
1323 """A conditional accumulator for aggregating gradients.
1325 Up-to-date gradients (i.e., time step at which gradient was computed is
1326 equal to the accumulator's time step) are added to the accumulator.
1328 Extraction of the average gradient is blocked until the required number of
1329 gradients has been accumulated.
1330 """
1332 def __init__(self,
1333 dtype,
1334 shape=None,
1335 shared_name=None,
1336 name="conditional_accumulator",
1337 reduction_type="MEAN"):
1338 """Creates a new ConditionalAccumulator.
1340 Args:
1341 dtype: Datatype of the accumulated gradients.
1342 shape: Shape of the accumulated gradients.
1343 shared_name: Optional. If non-empty, this accumulator will be shared under
1344 the given name across multiple sessions.
1345 name: Optional name for the accumulator.
1346 reduction_type: Reduction type to use when taking the gradient.
1347 """
1348 accumulator_ref = gen_data_flow_ops.resource_conditional_accumulator(
1349 dtype=dtype,
1350 shape=shape,
1351 shared_name=shared_name,
1352 name=name,
1353 reduction_type=reduction_type)
1354 if context.executing_eagerly():
1355 self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
1356 handle=accumulator_ref, handle_device=context.context().device_name)
1358 super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
1360 def apply_grad(self, grad, local_step=0, name=None):
1361 """Attempts to apply a gradient to the accumulator.
1363 The attempt is silently dropped if the gradient is stale, i.e., local_step
1364 is less than the accumulator's global time step.
1366 Args:
1367 grad: The gradient tensor to be applied.
1368 local_step: Time step at which the gradient was computed.
1369 name: Optional name for the operation.
1371 Returns:
1372 The operation that (conditionally) applies a gradient to the accumulator.
1374 Raises:
1375 ValueError: If grad is of the wrong shape
1376 """
1377 grad = ops.convert_to_tensor(grad, self._dtype)
1378 grad.get_shape().assert_is_compatible_with(self._shape)
1379 local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
1381 return gen_data_flow_ops.resource_accumulator_apply_gradient(
1382 self._accumulator_ref, local_step=local_step, gradient=grad, name=name)
1384 def take_grad(self, num_required, name=None):
1385 """Attempts to extract the average gradient from the accumulator.
1387 The operation blocks until sufficient number of gradients have been
1388 successfully applied to the accumulator.
1390 Once successful, the following actions are also triggered:
1392 - Counter of accumulated gradients is reset to 0.
1393 - Aggregated gradient is reset to 0 tensor.
1394 - Accumulator's internal time step is incremented by 1.
1396 Args:
1397 num_required: Number of gradients that need to have been aggregated
1398 name: Optional name for the operation
1400 Returns:
1401 A tensor holding the value of the average gradient.
1403 Raises:
1404 InvalidArgumentError: If num_required < 1
1405 """
1406 out = gen_data_flow_ops.resource_accumulator_take_gradient(
1407 self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1408 out.set_shape(self._shape)
1409 return out
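# A minimal accumulate-then-average sketch (hypothetical helper; assumes
# TensorFlow 2.x eager execution and the `tf.compat.v1` endpoint under which
# this class is exported).
def _example_conditional_accumulator():
  import tensorflow as tf

  acc = tf.compat.v1.ConditionalAccumulator(dtype=tf.float32, shape=())
  acc.apply_grad(1.0, local_step=0)
  acc.apply_grad(3.0, local_step=0)
  assert int(acc.num_accumulated()) == 2
  # take_grad blocks until `num_required` gradients have been applied, then
  # returns their MEAN (the default reduction) and resets the accumulator.
  assert float(acc.take_grad(num_required=2)) == 2.0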
1412@tf_export(
1413 v1=["sparse.SparseConditionalAccumulator", "SparseConditionalAccumulator"])
1414class SparseConditionalAccumulator(ConditionalAccumulatorBase):
1415 """A conditional accumulator for aggregating sparse gradients.
1417 Sparse gradients are represented by `IndexedSlices`.
1419 Up-to-date gradients (i.e., time step at which gradient was computed is
1420 equal to the accumulator's time step) are added to the accumulator.
1422 Extraction of the average gradient is blocked until the required number of
1423 gradients has been accumulated.
1425 Args:
1426 dtype: Datatype of the accumulated gradients.
1427 shape: Shape of the accumulated gradients.
1428 shared_name: Optional. If non-empty, this accumulator will be shared under
1429 the given name across multiple sessions.
1430 name: Optional name for the accumulator.
1431 reduction_type: Reduction type to use when taking the gradient.
1432 """
1434 def __init__(self,
1435 dtype,
1436 shape=None,
1437 shared_name=None,
1438 name="sparse_conditional_accumulator",
1439 reduction_type="MEAN"):
1440 accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
1441 dtype=dtype,
1442 shape=shape,
1443 shared_name=shared_name,
1444 name=name,
1445 reduction_type=reduction_type)
1446 super(SparseConditionalAccumulator, self).__init__(dtype, shape,
1447 accumulator_ref)
1449 def apply_indexed_slices_grad(self, grad, local_step=0, name=None):
1450 """Attempts to apply a gradient to the accumulator.
1452 The attempt is silently dropped if the gradient is stale, i.e., `local_step`
1453 is less than the accumulator's global time step.
1455 Args:
1456 grad: The gradient `IndexedSlices` to be applied.
1457 local_step: Time step at which the gradient was computed.
1458 name: Optional name for the operation.
1460 Returns:
1461 The operation that (conditionally) applies a gradient to the accumulator.
1463 Raises:
1464 InvalidArgumentError: If grad is of the wrong shape
1465 """
1466 return self.apply_grad(
1467 grad_indices=grad.indices,
1468 grad_values=grad.values,
1469 grad_shape=grad.dense_shape,
1470 local_step=local_step,
1471 name=name)
1473 def apply_grad(self,
1474 grad_indices,
1475 grad_values,
1476 grad_shape=None,
1477 local_step=0,
1478 name=None):
1479 """Attempts to apply a sparse gradient to the accumulator.
1481 The attempt is silently dropped if the gradient is stale, i.e., `local_step`
1482 is less than the accumulator's global time step.
1484 A sparse gradient is represented by its indices, values and possibly empty
1485 or None shape. Indices must be a vector representing the locations of
1486 non-zero entries in the tensor. Values are the non-zero slices of the
1487 gradient, and must have the same first dimension as indices, i.e., the nnz
1488 represented by indices and values must be consistent. Shape, if not empty or
1489 None, must be consistent with the accumulator's shape (if also provided).
1491 Example:
1492 A tensor [[0, 0], [0, 1], [2, 3]] can be represented as:
1493 indices: [1,2]
1494 values: [[0,1],[2,3]]
1495 shape: [3, 2]
1497 Args:
1498 grad_indices: Indices of the sparse gradient to be applied.
1499 grad_values: Values of the sparse gradient to be applied.
1500 grad_shape: Shape of the sparse gradient to be applied.
1501 local_step: Time step at which the gradient was computed.
1502 name: Optional name for the operation.
1504 Returns:
1505 The operation that (conditionally) applies a gradient to the accumulator.
1507 Raises:
1508 InvalidArgumentError: If grad is of the wrong shape
1509 """
1510 local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
1511 return gen_data_flow_ops.sparse_accumulator_apply_gradient(
1512 self._accumulator_ref,
1513 local_step=local_step,
1514 gradient_indices=math_ops.cast(grad_indices, _dtypes.int64),
1515 gradient_values=grad_values,
1516 gradient_shape=math_ops.cast(
1517 [] if grad_shape is None else grad_shape, _dtypes.int64),
1518 has_known_shape=(grad_shape is not None),
1519 name=name)
1521 def take_grad(self, num_required, name=None):
1522 """Attempts to extract the average gradient from the accumulator.
1524 The operation blocks until a sufficient number of gradients have been
1525 successfully applied to the accumulator.
1527 Once successful, the following actions are also triggered:
1528 - Counter of accumulated gradients is reset to 0.
1529 - Aggregated gradient is reset to 0 tensor.
1530 - Accumulator's internal time step is incremented by 1.
1532 Args:
1533 num_required: Number of gradients that need to have been aggregated
1534 name: Optional name for the operation
1536 Returns:
1537 A tuple of indices, values, and shape representing the average gradient.
1539 Raises:
1540 InvalidArgumentError: If `num_required` < 1
1541 """
1542 return gen_data_flow_ops.sparse_accumulator_take_gradient(
1543 self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1545 def take_indexed_slices_grad(self, num_required, name=None):
1546 """Attempts to extract the average gradient from the accumulator.
1548 The operation blocks until a sufficient number of gradients have been
1549 successfully applied to the accumulator.
1551 Once successful, the following actions are also triggered:
1552 - Counter of accumulated gradients is reset to 0.
1553 - Aggregated gradient is reset to 0 tensor.
1554 - Accumulator's internal time step is incremented by 1.
1556 Args:
1557 num_required: Number of gradients that need to have been aggregated
1558 name: Optional name for the operation
1560 Returns:
1561 An `IndexedSlices` holding the value of the average gradient.
1563 Raises:
1564 InvalidArgumentError: If `num_required` < 1
1565 """
1566 return_val = gen_data_flow_ops.sparse_accumulator_take_gradient(
1567 self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1568 return indexed_slices.IndexedSlices(
1569 indices=return_val.indices,
1570 values=return_val.values,
1571 dense_shape=return_val.shape)
1573 # SparseConditionalAccumulator is not switched to resource. Use old kernels.
1574 def num_accumulated(self, name=None):
1575 """Number of gradients that have currently been aggregated in accumulator.
1577 Args:
1578 name: Optional name for the operation.
1580 Returns:
1581 Number of accumulated gradients currently in the accumulator.
1582 """
1583 if name is None:
1584 name = "%s_NumAccumulated" % self._name
1586 return gen_data_flow_ops.accumulator_num_accumulated(
1587 self._accumulator_ref, name=name)
1589 def set_global_step(self, new_global_step, name=None):
1590 """Sets the global time step of the accumulator.
1592 The operation logs a warning if we attempt to set to a time step that is
1593 lower than the accumulator's own time step.
1595 Args:
1596 new_global_step: Value of new time step. Can be a variable or a constant
1597 name: Optional name for the operation.
1599 Returns:
1600 Operation that sets the accumulator's time step.
1601 """
1602 return gen_data_flow_ops.accumulator_set_global_step(
1603 self._accumulator_ref,
1604 math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
1605 name=name)
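
# --- Illustrative usage sketch (not part of the original module). ---
# A minimal, hypothetical example of driving the SparseConditionalAccumulator
# defined above from user code: one worker applies an IndexedSlices gradient,
# and take_indexed_slices_grad() returns the average once num_required
# gradients have arrived. Assumes TF1-style graph execution via
# tf.compat.v1.Session; the function name is an illustration only.
def _example_sparse_conditional_accumulator():
  import tensorflow as tf  # assumed available to user code

  with tf.Graph().as_default():
    accum = SparseConditionalAccumulator(
        dtype=tf.float32, shape=[3, 2], reduction_type="MEAN")
    # Sparse gradient for a [3, 2] tensor whose rows 1 and 2 are non-zero,
    # matching the representation described in apply_grad() above.
    grad = tf.IndexedSlices(
        values=tf.constant([[0., 1.], [2., 3.]]),
        indices=tf.constant([1, 2], dtype=tf.int64),
        dense_shape=tf.constant([3, 2], dtype=tf.int64))
    apply_op = accum.apply_indexed_slices_grad(grad, local_step=0)
    avg_grad = accum.take_indexed_slices_grad(num_required=1)

    with tf.compat.v1.Session() as sess:
      sess.run(apply_op)  # accepted: local_step equals the global step (0)
      # Returns the averaged gradient and advances the accumulator's step.
      return sess.run([avg_grad.indices, avg_grad.values])
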
1608class BaseStagingArea:
1609 """Base class for Staging Areas."""
1610 _identifier = 0
1611 _lock = threading.Lock()
1613 def __init__(self,
1614 dtypes,
1615 shapes=None,
1616 names=None,
1617 shared_name=None,
1618 capacity=0,
1619 memory_limit=0):
1620 if shared_name is None:
1621 self._name = (
1622 ops.get_default_graph().unique_name(self.__class__.__name__))
1623 elif isinstance(shared_name, str):
1624 self._name = shared_name
1625 else:
1626 raise ValueError(f"shared_name must be a string, got {shared_name}")
1628 self._dtypes = dtypes
1630 if shapes is not None:
1631 if len(shapes) != len(dtypes):
1632 raise ValueError("StagingArea shapes must be the same length as dtypes")
1633 self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
1634 else:
1635 self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
1637 if names is not None:
1638 if len(names) != len(dtypes):
1639 raise ValueError("StagingArea names must be the same length as dtypes")
1640 self._names = names
1641 else:
1642 self._names = None
1644 self._capacity = capacity
1645 self._memory_limit = memory_limit
1647 # all get and put ops must colocate with this op
1648 with ops.name_scope("%s_root" % self._name):
1649 self._coloc_op = control_flow_ops.no_op()
1651 @property
1652 def name(self):
1653 """The name of the staging area."""
1654 return self._name
1656 @property
1657 def dtypes(self):
1658 """The list of dtypes for each component of a staging area element."""
1659 return self._dtypes
1661 @property
1662 def shapes(self):
1663 """The list of shapes for each component of a staging area element."""
1664 return self._shapes
1666 @property
1667 def names(self):
1668 """The list of names for each component of a staging area element."""
1669 return self._names
1671 @property
1672 def capacity(self):
1673 """The maximum number of elements of this staging area."""
1674 return self._capacity
1676 @property
1677 def memory_limit(self):
1678 """The maximum number of bytes of this staging area."""
1679 return self._memory_limit
1681 def _check_put_dtypes(self, vals, indices=None):
1682 """Validate and convert `vals` to a list of `Tensor`s.
1684 The `vals` argument can be a Tensor, a list or tuple of tensors, or a
1685 dictionary with tensor values.
1687 If `vals` is a list, then the appropriate indices associated with the
1688 values must be provided.
1690 If it is a dictionary, the staging area must have been constructed with a
1691 `names` attribute and the dictionary keys must match the staging area names.
1692 `indices` will be inferred from the dictionary keys.
1693 If the staging area was constructed with a `names` attribute, `vals` must
1694 be a dictionary.
1696 Checks that the dtype and shape of each value matches that
1697 of the staging area.
1699 Args:
1700 vals: A tensor, a list or tuple of tensors, or a dictionary.
1702 Returns:
1703 A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects
1704 and `indices` is a list of indices associated with the tensors.
1706 Raises:
1707 ValueError: If `vals` or `indices` is invalid.
1708 """
1709 if isinstance(vals, dict):
1710 if not self._names:
1711 raise ValueError(
1712 "Staging areas must have names to enqueue a dictionary")
1713 if not set(vals.keys()).issubset(self._names):
1714 raise ValueError("Keys in dictionary to put do not match names "
1715 f"of staging area. Dictionary: {sorted(vals.keys())}"
1716 f"Queue: {sorted(self._names)}")
1717 # The order of values in `self._names` indicates the order in which the
1718 # tensors in the dictionary `vals` must be listed.
1719 vals, indices, _ = zip(*[(vals[k], i, k)
1720 for i, k in enumerate(self._names)
1721 if k in vals])
1722 else:
1723 if self._names:
1724 raise ValueError("You must enqueue a dictionary in a staging area "
1725 "with names")
1727 if indices is None:
1728 raise ValueError("Indices must be supplied when inserting a list "
1729 "of tensors")
1731 if len(indices) != len(vals):
1732 raise ValueError(f"Number of indices {len(indices)} doesn't match "
1733 f"number of values {len(vals)}")
1735 if not isinstance(vals, (list, tuple)):
1736 vals = [vals]
1737 indices = [0]
1739 # Sanity check number of values
1740 if len(vals) > len(self._dtypes):
1741 raise ValueError(f"Unexpected number of inputs {len(vals)} vs "
1742 f"{len(self._dtypes)}")
1744 tensors = []
1746 for val, i in zip(vals, indices):
1747 dtype, shape = self._dtypes[i], self._shapes[i]
1748 # Check dtype
1749 if val.dtype != dtype:
1750 raise ValueError(f"Datatypes do not match. "
1751 f"Received val.dtype {str(val.dtype)} and "
1752 f"dtype {str(dtype)}")
1753 # Check shape
1754 val.get_shape().assert_is_compatible_with(shape)
1756 tensors.append(
1757 ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
1759 return tensors, indices
1761 def _create_device_transfers(self, tensors):
1762 """Encode inter-device transfers if the current device
1763 is not the same as the Staging Area's device.
1764 """
1766 if not isinstance(tensors, (tuple, list)):
1767 tensors = [tensors]
1769 curr_device_scope = control_flow_ops.no_op().device
1771 if curr_device_scope != self._coloc_op.device:
1772 tensors = [array_ops.identity(t) for t in tensors]
1774 return tensors
1776 def _get_return_value(self, tensors, indices):
1777 """Return the value to return from a get op.
1779 If the staging area has names, return a dictionary with the
1780 names as keys. Otherwise return either a single tensor
1781 or a list of tensors depending on the length of `tensors`.
1783 Args:
1784 tensors: List of tensors from the get op.
1785 indices: Indices of associated names and shapes
1787 Returns:
1788 A single tensor, a list of tensors, or a dictionary
1789 of tensors.
1790 """
1792 tensors = self._create_device_transfers(tensors)
1794 # Sets shape
1795 for output, i in zip(tensors, indices):
1796 output.set_shape(self._shapes[i])
1798 if self._names:
1799 # The returned values in `tensors` are in the same order as
1800 # the names in `self._names`.
1801 return {self._names[i]: t for t, i in zip(tensors, indices)}
1802 return tensors
1804 def _scope_vals(self, vals):
1805 """Return a list of values to pass to `name_scope()`.
1807 Args:
1808 vals: A tensor, a list or tuple of tensors, or a dictionary.
1810 Returns:
1811 The values in vals as a list.
1812 """
1813 if isinstance(vals, (list, tuple)):
1814 return vals
1815 elif isinstance(vals, dict):
1816 return vals.values()
1817 else:
1818 return [vals]
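
# --- Illustrative sketch of the names/dict contract enforced by
# _check_put_dtypes() above (not part of the original module). ---
# When a staging area is built with `names`, put() must be given a dictionary
# keyed by those names and get() returns a dictionary. This hypothetical
# helper uses the StagingArea subclass defined below and assumes TF1-style
# graph execution via tf.compat.v1.Session.
def _example_named_staging_area():
  import tensorflow as tf  # assumed available to user code

  with tf.Graph().as_default():
    area = StagingArea(
        dtypes=[tf.float32, tf.int32],
        shapes=[[2], []],
        names=["features", "label"])
    # The dictionary keys select the components; order follows `names`.
    put_op = area.put({"features": tf.constant([1.0, 2.0]),
                       "label": tf.constant(7)})
    got = area.get()  # a dict: {"features": <float32 [2]>, "label": <int32>}

    with tf.compat.v1.Session() as sess:
      sess.run(put_op)
      return sess.run(got)
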
1821class StagingArea(BaseStagingArea):
1822 """Class for staging inputs. No ordering guarantees.
1824 A `StagingArea` is a TensorFlow data structure that stores tensors across
1825 multiple steps, and exposes operations that can put and get tensors.
1827 Each `StagingArea` element is a tuple of one or more tensors, where each
1828 tuple component has a static dtype, and may have a static shape.
1830 The capacity of a `StagingArea` may be bounded or unbounded.
1831 It supports multiple concurrent producers and consumers; and
1832 provides exactly-once delivery.
1834 Each element of a `StagingArea` is a fixed-length tuple of tensors whose
1835 dtypes are described by `dtypes`, and whose shapes are optionally described
1836 by the `shapes` argument.
1838 If the `shapes` argument is specified, each component of a staging area
1839 element must have the respective fixed shape. If it is
1840 unspecified, different elements may have different shapes.
1842 It can be configured with a capacity, in which case
1843 put(values) will block until space becomes available.
1845 Similarly, it can be configured with a memory limit which
1846 will block put(values) until space is available.
1847 This is mostly useful for limiting the number of tensors on
1848 devices such as GPUs.
1850 All get() and peek() commands block if the requested data
1851 is not present in the Staging Area.
1853 """
1855 def __init__(self,
1856 dtypes,
1857 shapes=None,
1858 names=None,
1859 shared_name=None,
1860 capacity=0,
1861 memory_limit=0):
1862 """Constructs a staging area object.
1864 The two optional lists, `shapes` and `names`, must be of the same length
1865 as `dtypes` if provided. The values at a given index `i` indicate the
1866 shape and name to use for the corresponding queue component in `dtypes`.
1868 The device scope at the time of object creation determines where the
1869 storage for the `StagingArea` will reside. Calls to `put` will incur a copy
1870 to this memory space, if necessary. Tensors returned by `get` will be
1871 placed according to the device scope when `get` is called.
1873 Args:
1874 dtypes: A list of types. The length of dtypes must equal the number
1875 of tensors in each element.
1876 shapes: (Optional.) Constraints on the shapes of tensors in an element.
1877 A list of shape tuples or None. This list is the same length
1878 as dtypes. If the shape of any tensors in the element are constrained,
1879 all must be; shapes can be None if the shapes should not be constrained.
1880 names: (Optional.) If provided, the `get()` and
1881 `put()` methods will use dictionaries with these names as keys.
1882 Must be None or a list or tuple of the same length as `dtypes`.
1883 shared_name: (Optional.) A name to be used for the shared object. By
1884 passing the same name to two different python objects they will share
1885 the underlying staging area. Must be a string.
1886 capacity: (Optional.) Maximum number of elements.
1887 An integer. If zero, the Staging Area is unbounded
1888 memory_limit: (Optional.) Maximum number of bytes of all tensors
1889 in the Staging Area.
1890 An integer. If zero, the Staging Area is unbounded
1892 Raises:
1893 ValueError: If one of the arguments is invalid.
1894 """
1896 super(StagingArea, self).__init__(dtypes, shapes, names, shared_name,
1897 capacity, memory_limit)
1899 def put(self, values, name=None):
1900 """Create an op that places a value into the staging area.
1902 This operation will block if the `StagingArea` has reached
1903 its capacity.
1905 Args:
1906 values: A single tensor, a list or tuple of tensors, or a dictionary with
1907 tensor values. The number of elements must match the length of the
1908 list provided to the dtypes argument when creating the StagingArea.
1909 name: A name for the operation (optional).
1911 Returns:
1912 The created op.
1914 Raises:
1915 ValueError: If the number or type of inputs don't match the staging area.
1916 """
1917 with ops.name_scope(name, "%s_put" % self._name,
1918 self._scope_vals(values)) as scope:
1920 if not isinstance(values, (list, tuple, dict)):
1921 values = [values]
1923 # Hard-code indices for this staging area
1924 indices = list(range(len(values)))
1925 vals, _ = self._check_put_dtypes(values, indices)
1927 with ops.colocate_with(self._coloc_op):
1928 op = gen_data_flow_ops.stage(
1929 values=vals,
1930 shared_name=self._name,
1931 name=scope,
1932 capacity=self._capacity,
1933 memory_limit=self._memory_limit)
1935 return op
1937 def __internal_get(self, get_fn, name):
1938 with ops.colocate_with(self._coloc_op):
1939 ret = get_fn()
1941 indices = list(range(len(self._dtypes))) # Hard coded
1942 return self._get_return_value(ret, indices)
1944 def get(self, name=None):
1945 """Gets one element from this staging area.
1947 If the staging area is empty when this operation executes, it will block
1948 until there is an element to dequeue.
1950 Note that unlike other ops that can block, like the queue Dequeue
1951 operations, this can stop other work from happening. To avoid this, the
1952 intended use is for this to be called only when there will be an element
1953 already available. One method for doing this in a training loop would be to
1954 run a `put()` call during a warmup session.run call, and then call both
1955 `get()` and `put()` in each subsequent step.
1957 The placement of the returned tensor will be determined by the current
1958 device scope when this function is called.
1960 Args:
1961 name: A name for the operation (optional).
1963 Returns:
1964 The tuple of tensors that was gotten.
1965 """
1966 if name is None:
1967 name = "%s_get" % self._name
1969 # pylint: disable=bad-continuation
1970 fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes,
1971 shared_name=self._name, name=name,
1972 capacity=self._capacity,
1973 memory_limit=self._memory_limit)
1974 # pylint: enable=bad-continuation
1976 return self.__internal_get(fn, name)
1978 def peek(self, index, name=None):
1979 """Peeks at an element in the staging area.
1981 If the staging area is too small to contain the element at
1982 the specified index, it will block until enough elements
1983 are inserted to complete the operation.
1985 The placement of the returned tensor will be determined by
1986 the current device scope when this function is called.
1988 Args:
1989 index: The index of the tensor within the staging area
1990 to look up.
1991 name: A name for the operation (optional).
1993 Returns:
1994 The tuple of tensors that was gotten.
1995 """
1996 if name is None:
1997 name = "%s_peek" % self._name
1999 # pylint: disable=bad-continuation
2000 fn = lambda: gen_data_flow_ops.stage_peek(index,
2001 dtypes=self._dtypes, shared_name=self._name,
2002 name=name, capacity=self._capacity,
2003 memory_limit=self._memory_limit)
2004 # pylint: enable=bad-continuation
2006 return self.__internal_get(fn, name)
2008 def size(self, name=None):
2009 """Returns the number of elements in the staging area.
2011 Args:
2012 name: A name for the operation (optional)
2014 Returns:
2015 The created op
2016 """
2017 if name is None:
2018 name = "%s_size" % self._name
2020 return gen_data_flow_ops.stage_size(
2021 name=name,
2022 shared_name=self._name,
2023 dtypes=self._dtypes,
2024 capacity=self._capacity,
2025 memory_limit=self._memory_limit)
2027 def clear(self, name=None):
2028 """Clears the staging area.
2030 Args:
2031 name: A name for the operation (optional)
2033 Returns:
2034 The created op
2035 """
2036 if name is None:
2037 name = "%s_clear" % self._name
2039 return gen_data_flow_ops.stage_clear(
2040 name=name,
2041 shared_name=self._name,
2042 dtypes=self._dtypes,
2043 capacity=self._capacity,
2044 memory_limit=self._memory_limit)
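
# --- Illustrative sketch of the warm-up pattern described in
# StagingArea.get() above (not part of the original module). ---
# A put() is run once up front so a later get() never blocks; each subsequent
# step then runs get() and put() together, keeping one element staged at all
# times. Hypothetical names; assumes TF1-style graph execution.
def _example_staging_area_pipeline(num_steps=3):
  import tensorflow as tf  # assumed available to user code

  with tf.Graph().as_default():
    area = StagingArea(dtypes=[tf.float32], shapes=[[]])
    value = tf.compat.v1.placeholder(tf.float32, shape=[])
    put_op = area.put([value])
    staged = area.get()        # list holding a single scalar tensor
    result = staged[0] * 2.0   # stand-in for the real per-step computation

    with tf.compat.v1.Session() as sess:
      sess.run(put_op, feed_dict={value: 0.0})  # warm-up put
      outputs = []
      for step in range(1, num_steps + 1):
        # Consume the previously staged value and stage the next one.
        out, _ = sess.run([result, put_op], feed_dict={value: float(step)})
        outputs.append(out)
      return outputs
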
2047class MapStagingArea(BaseStagingArea):
2048 """A `MapStagingArea` is a TensorFlow data structure that stores tensors
2049 across multiple steps, and exposes operations that can put and get tensors.
2051 Each `MapStagingArea` element is a (key, value) pair.
2052 Only int64 keys are supported, other types should be
2053 hashed to produce a key.
2054 Values are a tuple of one or more tensors.
2055 Each tuple component has a static dtype,
2056 and may have a static shape.
2058 The capacity of a `MapStagingArea` may be bounded or unbounded.
2059 It supports multiple concurrent producers and consumers; and
2060 provides exactly-once delivery.
2062 Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors
2063 whose dtypes are described by `dtypes`,
2064 and whose shapes are optionally described
2065 by the `shapes` argument.
2067 If the `shapes` argument is specified, each component of a staging area
2068 element must have the respective fixed shape. If it is
2069 unspecified, different elements may have different shapes.
2071 It behaves like an associative container with support for:
2073 - put(key, values)
2074 - peek(key) like dict.get(key)
2075 - get(key) like dict.pop(key)
2076 - get(key=None) like dict.popitem()
2077 - size()
2078 - clear()
2080 If ordered, a tree structure ordered by key will be used and
2081 get(key=None) will remove (key, value) pairs in increasing key order.
2082 Otherwise an unordered hashtable is used.
2084 It can be configured with a capacity, in which case
2085 put(key, values) will block until space becomes available.
2087 Similarly, it can be configured with a memory limit which
2088 will block put(key, values) until space is available.
2089 This is mostly useful for limiting the number of tensors on
2090 devices such as GPUs.
2092 All get() and peek() commands block if the requested
2093 (key, value) pair is not present in the staging area.
2095 Partial puts are supported and will be placed in an incomplete
2096 map until such time as all values associated with the key have
2097 been inserted. Once completed, this (key, value) pair will be
2098 inserted into the map. Data in the incomplete map
2099 counts towards the memory limit, but not towards capacity limit.
2101 Partial gets from the map are also supported.
2102 This removes the partially requested tensors from the entry,
2103 but the entry is only removed from the map once all tensors
2104 associated with it are removed.
2105 """
2107 def __init__(self,
2108 dtypes,
2109 shapes=None,
2110 names=None,
2111 shared_name=None,
2112 ordered=False,
2113 capacity=0,
2114 memory_limit=0):
2115 """Args:
2117 dtypes: A list of types. The length of dtypes must equal the number
2118 of tensors in each element.
2119 capacity: (Optional.) Maximum number of elements.
2120 An integer. If zero, the Staging Area is unbounded
2121 memory_limit: (Optional.) Maximum number of bytes of all tensors
2122 in the Staging Area (excluding keys).
2123 An integer. If zero, the Staging Area is unbounded
2124 ordered: (Optional.) If True the underlying data structure
2125 is a tree ordered on key. Otherwise assume a hashtable.
2126 shapes: (Optional.) Constraints on the shapes of tensors in an element.
2127 A list of shape tuples or None. This list is the same length
2128 as dtypes. If the shape of any tensors in the element are constrained,
2129 all must be; shapes can be None if the shapes should not be constrained.
2130 names: (Optional.) If provided, the `get()` and
2131 `put()` methods will use dictionaries with these names as keys.
2132 Must be None or a list or tuple of the same length as `dtypes`.
2133 shared_name: (Optional.) A name to be used for the shared object. By
2134 passing the same name to two different python objects they will share
2135 the underlying staging area. Must be a string.
2137 Raises:
2138 ValueError: If one of the arguments is invalid.
2140 """
2142 super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name,
2143 capacity, memory_limit)
2145 # Defer to different methods depending if the map is ordered
2146 self._ordered = ordered
2148 if ordered:
2149 self._put_fn = gen_data_flow_ops.ordered_map_stage
2150 self._pop_fn = gen_data_flow_ops.ordered_map_unstage
2151 self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key
2152 self._peek_fn = gen_data_flow_ops.ordered_map_peek
2153 self._size_fn = gen_data_flow_ops.ordered_map_size
2154 self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size
2155 self._clear_fn = gen_data_flow_ops.ordered_map_clear
2156 else:
2157 self._put_fn = gen_data_flow_ops.map_stage
2158 self._pop_fn = gen_data_flow_ops.map_unstage
2159 self._popitem_fn = gen_data_flow_ops.map_unstage_no_key
2160 self._peek_fn = gen_data_flow_ops.map_peek
2161 self._size_fn = gen_data_flow_ops.map_size
2162 self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size
2163 self._clear_fn = gen_data_flow_ops.map_clear
2165 def put(self, key, vals, indices=None, name=None):
2166 """Create an op that stores the (key, vals) pair in the staging area.
2168 Incomplete puts are possible, preferably using a dictionary for vals,
2169 as the appropriate dtypes and shapes can be inferred from the
2170 dictionary keys (the value names). If vals is a list or tuple, indices must
2171 also be specified so that the op knows at which element position
2172 to perform the insert.
2174 This operation will block if the capacity or memory limit of this
2175 container is reached.
2177 Args:
2178 key: Key associated with the data
2179 vals: Tensor (or a dict/tuple of Tensors) to place
2180 into the staging area.
2181 indices: (Optional) if vals is a tuple/list, this is required.
2182 name: A name for the operation (optional)
2184 Returns:
2185 The created op
2187 Raises:
2188 ValueError: If the number or type of inputs don't match the staging
2189 area.
2190 """
2192 with ops.name_scope(name, "%s_put" % self._name,
2193 self._scope_vals(vals)) as scope:
2195 vals, indices = self._check_put_dtypes(vals, indices)
2197 with ops.colocate_with(self._coloc_op):
2198 op = self._put_fn(
2199 key,
2200 indices,
2201 vals,
2202 dtypes=self._dtypes,
2203 shared_name=self._name,
2204 name=scope,
2205 capacity=self._capacity,
2206 memory_limit=self._memory_limit)
2207 return op
2209 def _get_indices_and_dtypes(self, indices=None):
2210 if indices is None:
2211 indices = list(range(len(self._dtypes)))
2213 if not isinstance(indices, (tuple, list)):
2214 raise TypeError(f"Invalid indices type {type(indices)}")
2216 if len(indices) == 0:
2217 raise ValueError("Empty indices")
2219 if all(isinstance(i, str) for i in indices):
2220 if self._names is None:
2221 raise ValueError(f"String indices provided {indices}, but "
2222 "this Staging Area was not created with names.")
2224 try:
2225 indices = [self._names.index(n) for n in indices]
2226 except ValueError:
2227 raise ValueError(f"Named index not in "
2228 f"Staging Area names {self._names}")
2229 elif all(isinstance(i, int) for i in indices):
2230 pass
2231 else:
2232 raise TypeError(f"Mixed types in indices {indices}. "
2233 "May only be str or int")
2235 dtypes = [self._dtypes[i] for i in indices]
2237 return indices, dtypes
2239 def peek(self, key, indices=None, name=None):
2240 """Peeks at staging area data associated with the key.
2242 If the key is not in the staging area, it will block
2243 until the associated (key, value) is inserted.
2245 Args:
2246 key: Key associated with the required data
2247 indices: Partial list of tensors to retrieve (optional).
2248 A list of integer or string indices.
2249 String indices are only valid if the Staging Area
2250 has names associated with it.
2251 name: A name for the operation (optional)
2253 Returns:
2254 The values (a tensor, a list of tensors, or a dictionary) associated with the key.
2255 """
2257 if name is None:
2258 name = "%s_pop" % self._name
2260 indices, dtypes = self._get_indices_and_dtypes(indices)
2262 with ops.colocate_with(self._coloc_op):
2263 result = self._peek_fn(
2264 key,
2265 shared_name=self._name,
2266 indices=indices,
2267 dtypes=dtypes,
2268 name=name,
2269 capacity=self._capacity,
2270 memory_limit=self._memory_limit)
2272 return self._get_return_value(result, indices)
2274 def get(self, key=None, indices=None, name=None):
2275 """If the key is provided, the associated (key, value) is returned from the staging area.
2277 If the key is not in the staging area, this method will block until
2278 the associated (key, value) is inserted.
2279 If no key is provided and the staging area is ordered,
2280 the (key, value) with the smallest key will be returned.
2281 Otherwise, a random (key, value) will be returned.
2283 If the staging area is empty when this operation executes,
2284 it will block until there is an element to dequeue.
2286 Args:
2287 key: Key associated with the required data (Optional)
2288 indices: Partial list of tensors to retrieve (optional).
2289 A list of integer or string indices.
2290 String indices are only valid if the Staging Area
2291 has names associated with it.
2292 name: A name for the operation (optional)
2294 Returns:
2295 A (key, value) tuple, where value is a tensor, a list of tensors, or a dictionary.
2296 """
2297 if key is None:
2298 return self._popitem(indices=indices, name=name)
2299 else:
2300 return self._pop(key, indices=indices, name=name)
2302 def _pop(self, key, indices=None, name=None):
2303 """Remove and return the associated (key, value) is returned from the staging area.
2305 If the key is not in the staging area, this method will block until
2306 the associated (key, value) is inserted.
2307 Args:
2308 key: Key associated with the required data
2309 indices: Partial list of tensors to retrieve (optional).
2310 A list of integer or string indices.
2311 String indices are only valid if the Staging Area
2312 has names associated with it.
2313 name: A name for the operation (optional)
2315 Returns:
2316 A (key, value) tuple, where value is a tensor, a list of tensors, or a dictionary.
2317 """
2318 if name is None:
2319 name = "%s_get" % self._name
2321 indices, dtypes = self._get_indices_and_dtypes(indices)
2323 with ops.colocate_with(self._coloc_op):
2324 result = self._pop_fn(
2325 key,
2326 shared_name=self._name,
2327 indices=indices,
2328 dtypes=dtypes,
2329 name=name,
2330 capacity=self._capacity,
2331 memory_limit=self._memory_limit)
2333 return key, self._get_return_value(result, indices)
2335 def _popitem(self, indices=None, name=None):
2336 """If the staging area is ordered, the (key, value) with the smallest key will be returned.
2338 Otherwise, a random (key, value) will be returned.
2339 If the staging area is empty when this operation executes,
2340 it will block until there is an element to dequeue.
2342 Args:
2344 indices: Partial list of tensors to retrieve (optional).
2345 A list of integer or string indices.
2346 String indices are only valid if the Staging Area
2347 has names associated with it.
2348 name: A name for the operation (optional)
2350 Returns:
2351 A (key, value) tuple, where value is a tensor, a list of tensors, or a dictionary.
2352 """
2353 if name is None:
2354 name = "%s_get_nokey" % self._name
2356 indices, dtypes = self._get_indices_and_dtypes(indices)
2358 with ops.colocate_with(self._coloc_op):
2359 key, result = self._popitem_fn(
2360 shared_name=self._name,
2361 indices=indices,
2362 dtypes=dtypes,
2363 name=name,
2364 capacity=self._capacity,
2365 memory_limit=self._memory_limit)
2367 # Separate keys and results out from
2368 # underlying namedtuple
2369 key = self._create_device_transfers(key)[0]
2370 result = self._get_return_value(result, indices)
2372 return key, result
2374 def size(self, name=None):
2375 """Returns the number of elements in the staging area.
2377 Args:
2378 name: A name for the operation (optional)
2380 Returns:
2381 The created op
2382 """
2383 if name is None:
2384 name = "%s_size" % self._name
2386 return self._size_fn(
2387 shared_name=self._name,
2388 name=name,
2389 dtypes=self._dtypes,
2390 capacity=self._capacity,
2391 memory_limit=self._memory_limit)
2393 def incomplete_size(self, name=None):
2394 """Returns the number of incomplete elements in the staging area.
2396 Args:
2397 name: A name for the operation (optional)
2399 Returns:
2400 The created op
2401 """
2402 if name is None:
2403 name = "%s_incomplete_size" % self._name
2405 return self._incomplete_size_fn(
2406 shared_name=self._name,
2407 name=name,
2408 dtypes=self._dtypes,
2409 capacity=self._capacity,
2410 memory_limit=self._memory_limit)
2412 def clear(self, name=None):
2413 """Clears the staging area.
2415 Args:
2416 name: A name for the operation (optional)
2418 Returns:
2419 The created op
2420 """
2421 if name is None:
2422 name = "%s_clear" % self._name
2424 return self._clear_fn(
2425 shared_name=self._name,
2426 name=name,
2427 dtypes=self._dtypes,
2428 capacity=self._capacity,
2429 memory_limit=self._memory_limit)
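
# --- Illustrative sketch of MapStagingArea as an associative container
# (not part of the original module). ---
# Keys are int64 scalars. With ordered=True, get() without a key pops the
# smallest key first, as described in the class docstring above. Hypothetical
# names; assumes TF1-style graph execution via tf.compat.v1.Session.
def _example_map_staging_area():
  import tensorflow as tf  # assumed available to user code

  with tf.Graph().as_default():
    area = MapStagingArea(dtypes=[tf.float32], shapes=[[]], ordered=True)
    # put() with a list of values requires explicit component indices.
    put_3 = area.put(tf.constant(3, dtype=tf.int64),
                     [tf.constant(30.0)], indices=[0])
    put_1 = area.put(tf.constant(1, dtype=tf.int64),
                     [tf.constant(10.0)], indices=[0])
    key, vals = area.get()  # no key: pops the smallest key first

    with tf.compat.v1.Session() as sess:
      sess.run([put_3, put_1])
      return sess.run([key, vals])  # expected: key 1 with value 10.0
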
2432class RecordInput:
2433 """RecordInput asynchronously reads and randomly yields TFRecords.
2435 A RecordInput Op will continuously read a batch of records asynchronously
2436 into a buffer of some fixed capacity. It can also asynchronously yield
2437 random records from this buffer.
2439 It will not start yielding until at least `buffer_size / 2` elements have been
2440 placed into the buffer so that sufficient randomization can take place.
2442 The order in which the files are read will be shifted each epoch by `shift_ratio` so
2443 that the data is presented in a different order every epoch.
2444 """
2446 def __init__(self,
2447 file_pattern,
2448 batch_size=1,
2449 buffer_size=1,
2450 parallelism=1,
2451 shift_ratio=0,
2452 seed=0,
2453 name=None,
2454 batches=None,
2455 compression_type=None):
2456 """Constructs a RecordInput Op.
2458 Args:
2459 file_pattern: File path to the dataset, possibly containing wildcards.
2460 All matching files will be iterated over each epoch.
2461 batch_size: How many records to return at a time.
2462 buffer_size: The maximum number of records the buffer will contain.
2463 parallelism: How many reader threads to use for reading from files.
2464 shift_ratio: What percentage of the total number of files to move the start
2465 file forward by each epoch.
2466 seed: Specify the random number seed used by generator that randomizes
2467 records.
2468 name: Optional name for the operation.
2469 batches: None by default, creating a single batch op. Otherwise specifies
2470 how many batches to create, which are returned as a list when
2471 `get_yield_op()` is called. An example use case is to split processing
2472 between devices on one computer.
2473 compression_type: The type of compression for the file. Currently ZLIB and
2474 GZIP are supported. Defaults to none.
2476 Raises:
2477 ValueError: If one of the arguments is invalid.
2478 """
2479 self._batch_size = batch_size
2480 if batches is not None:
2481 self._batch_size *= batches
2482 self._batches = batches
2483 self._file_pattern = file_pattern
2484 self._buffer_size = buffer_size
2485 self._parallelism = parallelism
2486 self._shift_ratio = shift_ratio
2487 self._seed = seed
2488 self._name = name
2489 self._compression_type = python_io.TFRecordCompressionType.NONE
2490 if compression_type is not None:
2491 self._compression_type = compression_type
2493 def get_yield_op(self):
2494 """Adds a node that yields a group of records every time it is executed.
2495 If the RecordInput's `batches` parameter is not None, it yields a list of
2496 record batches with the specified `batch_size`.
2497 """
2498 compression_type = python_io.TFRecordOptions.get_compression_type_string(
2499 python_io.TFRecordOptions(self._compression_type))
2500 records = gen_data_flow_ops.record_input(
2501 file_pattern=self._file_pattern,
2502 file_buffer_size=self._buffer_size,
2503 file_parallelism=self._parallelism,
2504 file_shuffle_shift_ratio=self._shift_ratio,
2505 batch_size=self._batch_size,
2506 file_random_seed=self._seed,
2507 compression_type=compression_type,
2508 name=self._name)
2509 if self._batches is None:
2510 return records
2511 else:
2512 with ops.name_scope(self._name):
2513 batch_list = [[] for _ in range(self._batches)]
2514 records = array_ops.split(records, self._batch_size, 0)
2515 for index, protobuf in enumerate(records):
2516 batch_index = index % self._batches
2517 batch_list[batch_index].append(array_ops.reshape(protobuf, []))
2518 return batch_list
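
# --- Illustrative sketch of RecordInput usage (not part of the original
# module). ---
# Reads TFRecord files matching a glob and yields a batch of serialized
# records per session.run() call, once the internal buffer is at least half
# full. The file pattern below is a hypothetical placeholder, not a real
# dataset; assumes TF1-style graph execution via tf.compat.v1.Session.
def _example_record_input():
  import tensorflow as tf  # assumed available to user code

  with tf.Graph().as_default():
    reader = RecordInput(
        file_pattern="/tmp/data/train-*.tfrecord",  # hypothetical path
        batch_size=4,
        buffer_size=100,
        parallelism=2,
        shift_ratio=0.1,
        seed=42)
    batch = reader.get_yield_op()  # 1-D string tensor of 4 serialized records

    with tf.compat.v1.Session() as sess:
      return sess.run(batch)  # numpy array of serialized record bytes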