Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/state_ops.py: 37%
128 statements
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variables.

See the [Variables](https://www.tensorflow.org/guide/variables) guide.
"""

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_state_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_state_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import deprecation
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args
def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Deprecated. Use variable_op_v2 instead."""
  if not set_shape:
    shape = tensor_shape.unknown_shape()
  ret = gen_state_ops.variable(shape=shape, dtype=dtype, name=name,
                               container=container, shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  # wrapper?
  if set_shape:
    ret.set_shape(shape)
  return ret


def variable_op_v2(shape, dtype, name="Variable", container="",
                   shared_name=""):
  """Create a variable Operation.

  See also variables.Variable.

  Args:
    shape: The shape of the tensor managed by this variable.
    dtype: The underlying type of the tensor values.
    name: optional name to use for the variable op.
    container: An optional string. Defaults to "".
      If non-empty, this variable is placed in the given container.
      Otherwise, a default container is used.
    shared_name: An optional string. Defaults to "".
      If non-empty, this variable is named in the given bucket
      with this shared_name. Otherwise, the node name is used instead.

  Returns:
    A variable tensor.
  """
  return gen_state_ops.variable_v2(
      shape=shape,
      dtype=dtype,
      name=name,
      container=container,
      shared_name=shared_name)
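

# A minimal usage sketch for variable_op_v2 (assumes a TF1-style graph
# context; the `dtypes` import below is only needed for the example):
#
#   from tensorflow.python.framework import dtypes
#   with ops.Graph().as_default():
#     var = variable_op_v2([2, 3], dtypes.float32, name="my_var")
#     # `var` is a ref-typed tensor of shape [2, 3]; it must be initialized
#     # (e.g. via init_variable below) before it can be read.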


def init_variable(v, init, name="init"):
  """Initializes variable with "init".

  This op does the following:
  if init is a Tensor, v = init
  if callable(init): v = init(VariableShape(v), v.dtype)

  Args:
    v: Variable to initialize.
    init: Tensor to assign to v, or an object convertible to Tensor
      (e.g. a NumPy array), or an Initializer that generates a tensor
      given the shape and type of v. An "Initializer" is a callable that
      returns a tensor that "v" should be set to. It will be called as
      init(shape, dtype).
    name: Optional name for the op.

  Returns:
    The operation that initializes v.
  """
  with ops.name_scope(None, v.op.name + "/", [v, init]):
    with ops.name_scope(name) as scope:
      with ops.colocate_with(v):
        if callable(init):
          assert v.get_shape().is_fully_defined(), "Variable shape unknown."
          # TODO(mrry): Convert to v.shape when the property and
          # accessor are reconciled (and all initializers support
          # tf.TensorShape objects).
          value = init(v.get_shape().as_list(), v.dtype.base_dtype)
          value = ops.convert_to_tensor(value, name="value")
          return gen_state_ops.assign(v, value, name=scope)
        else:
          init = ops.convert_to_tensor(init, name="init")
          return gen_state_ops.assign(v, init, name=scope)
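

# A minimal usage sketch for init_variable with a callable initializer
# (assumes a TF1-style graph; the callable receives (shape, dtype) and
# returns the initial value):
#
#   from tensorflow.python.framework import dtypes
#   with ops.Graph().as_default():
#     var = variable_op_v2([2], dtypes.float32, name="v")
#     init_op = init_variable(
#         var, lambda shape, dtype: array_ops.zeros(shape, dtype))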


def is_variable_initialized(ref, name=None):
  """Checks whether a tensor has been initialized.

  Outputs boolean scalar indicating whether the tensor has been initialized.

  Args:
    ref: A mutable `Tensor`.
      Should be from a `Variable` node. May be uninitialized.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.is_variable_initialized(ref=ref, name=name)
  # Handle resource variables.
  return ref.is_initialized(name=name)
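

# A minimal usage sketch (assumes TF2 eager mode, where `ref` is a resource
# variable and the check dispatches to ref.is_initialized()):
#
#   import tensorflow as tf
#   v = tf.Variable([1.0, 2.0])
#   is_variable_initialized(v)  # ==> tf.Tensor(True, ...)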


@tf_export(v1=["assign_sub"])
def assign_sub(ref, value, use_locking=None, name=None):
  """Update `ref` by subtracting `value` from it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.
  Unlike `tf.math.subtract`, this op does not broadcast. `ref` and `value`
  must have the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value
      to be subtracted from the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      subtraction will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want
    to use the new value after the variable has been updated.

  @compatibility(TF2)
  `tf.compat.v1.assign_sub` is mostly compatible with eager
  execution and `tf.function`.

  To switch to the native TF2 style, use the `assign_sub` method of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign_sub()` method   |
  | `value`               | `value`         | In `assign_sub()` method   |
  | `use_locking`         | `use_locking`   | In `assign_sub()` method   |
  | `name`                | `name`          | In `assign_sub()` method   |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :

  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(1, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign_sub(a, 1)
  ...     res_a = sess.run(update_op)
  ...     res_a
  0

  After:

  >>> b = tf.Variable(1, dtype=tf.int64)
  >>> res_b = b.assign_sub(1)
  >>> res_b.numpy()
  0

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_sub(
        ref, value, use_locking=use_locking, name=name)
  return ref.assign_sub(value)


@tf_export(v1=["assign_add"])
def assign_add(ref, value, use_locking=None, name=None):
  """Update `ref` by adding `value` to it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.
  Unlike `tf.math.add`, this op does not broadcast. `ref` and `value` must
  have the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value
      to be added to the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the addition
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want
    to use the new value after the variable has been updated.

  @compatibility(TF2)
  `tf.compat.v1.assign_add` is mostly compatible with eager
  execution and `tf.function`.

  To switch to the native TF2 style, use the `assign_add` method of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign_add()` method   |
  | `value`               | `value`         | In `assign_add()` method   |
  | `use_locking`         | `use_locking`   | In `assign_add()` method   |
  | `name`                | `name`          | In `assign_add()` method   |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :

  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(0, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign_add(a, 1)
  ...     res_a = sess.run(update_op)
  ...     res_a
  1

  After:

  >>> b = tf.Variable(0, dtype=tf.int64)
  >>> res_b = b.assign_add(1)
  >>> res_b.numpy()
  1

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_add(
        ref, value, use_locking=use_locking, name=name)
  return ref.assign_add(value)


@tf_export(v1=["assign"])
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update `ref` by assigning `value` to it.

  This operation outputs a Tensor that holds the new value of `ref` after
  the value has been assigned. This makes it easier to chain operations that
  need to use the reset value.

  Args:
    ref: A mutable `Tensor`. Should be from a `Variable` node. May be
      uninitialized.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value
      to be assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`. If true, the
      operation will validate that the shape of 'value' matches the shape of
      the Tensor being assigned to. If false, 'ref' will take on the shape of
      'value'.
    use_locking: An optional `bool`. Defaults to `True`. If True, the
      assignment will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of `ref` after
    the assignment has completed.

  @compatibility(TF2)
  `tf.compat.v1.assign` is mostly compatible with eager
  execution and `tf.function`. However, the argument 'validate_shape' will be
  ignored. To avoid shape validation, set 'shape' to tf.TensorShape(None) when
  constructing the variable:

  >>> import tensorflow as tf
  >>> a = tf.Variable([1], shape=tf.TensorShape(None))
  >>> tf.compat.v1.assign(a, [2,3])

  To switch to the native TF2 style, use the `assign` method of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign()` method       |
  | `value`               | `value`         | In `assign()` method       |
  | `validate_shape`      | Not supported   | Specify `shape` in the     |
  :                       :                 : constructor to replicate   :
  :                       :                 : behavior                   :
  | `use_locking`         | `use_locking`   | In `assign()` method       |
  | `name`                | `name`          | In `assign()` method       |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :
  @end_compatibility

  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(0, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign(a, 2)
  ...     res_a = sess.run(update_op)
  ...     res_a
  2

  After:

  >>> b = tf.Variable(0, dtype=tf.int64)
  >>> res_b = b.assign(2)
  >>> res_b.numpy()
  2
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign(
        ref, value, use_locking=use_locking, name=name,
        validate_shape=validate_shape)
  return ref.assign(value, name=name)


@tf_export(v1=["count_up_to"])
@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(ref, limit, name=None):
  r"""Increments 'ref' until it reaches 'limit'.

  Args:
    ref: A Variable. Must be one of the following types: `int32`, `int64`.
      Should be from a scalar `Variable` node.
    limit: An `int`.
      If incrementing ref would bring it above limit, instead generates an
      'OutOfRange' error.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `ref`.
    A copy of the input before increment. If nothing else modifies the
    input, the values produced will all be distinct.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.count_up_to(ref, limit=limit, name=name)
  return gen_state_ops.resource_count_up_to(
      ref.handle, limit, T=ref.dtype, name=name)
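

# A minimal usage sketch (assumes TF2 eager mode with a scalar resource
# variable; the op returns the value *before* the increment):
#
#   import tensorflow as tf
#   v = tf.Variable(0, dtype=tf.int64)
#   tf.compat.v1.count_up_to(v, limit=2)  # ==> 0; v is now 1
#   tf.compat.v1.count_up_to(v, limit=2)  # ==> 1; v is now 2
#   # A third call would raise an OutOfRange error.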


@tf_export(v1=["scatter_update"])
def scatter_update(ref, indices, updates, use_locking=True, name=None):
  # pylint: disable=line-too-long
  r"""Applies sparse updates to a variable reference.

  This operation computes

  ```python
  # Scalar indices
  ref[indices, ...] = updates[...]

  # Vector indices (for each i)
  ref[indices[i], ...] = updates[i, ...]

  # High rank indices (for each i, ..., j)
  ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  If values in `ref` are to be updated more than once, because there are
  duplicate entries in `indices`, the order in which the updates happen
  for each value is undefined.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in `ref`.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_update(ref, indices, updates,
                                        use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
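

# A minimal usage sketch (assumes TF2 eager mode with a resource variable):
#
#   import tensorflow as tf
#   v = tf.Variable([1.0, 2.0, 3.0, 4.0])
#   tf.compat.v1.scatter_update(v, [0, 2], [9.0, 7.0])
#   v.numpy()  # ==> array([9., 2., 7., 4.], dtype=float32)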


@tf_export(v1=["scatter_nd_update"])
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
  r"""Applies sparse `updates` to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be an integer tensor, containing indices into `ref`.
  It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to update 4 scattered elements of a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  update = tf.compat.v1.scatter_nd_update(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print(sess.run(update))
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A Variable.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in ref.
    use_locking: An optional `bool`. Defaults to `True`. If True, the
      assignment will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    The value of the variable after the update.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_update(
        ref, indices, updates, use_locking, name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_add"])
def scatter_add(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Adds sparse updates to the variable referenced by `ref`.

  This operation computes

  ```python
  # Scalar indices
  ref[indices, ...] += updates[...]

  # Vector indices (for each i)
  ref[indices[i], ...] += updates[i, ...]

  # High rank indices (for each i, ..., j)
  ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to add to `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_add(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
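

# A minimal usage sketch (assumes TF2 eager mode). Duplicate indices
# accumulate, as described above:
#
#   import tensorflow as tf
#   v = tf.Variable([1.0, 2.0, 3.0])
#   tf.compat.v1.scatter_add(v, [0, 0], [1.0, 1.0])
#   v.numpy()  # ==> array([3., 2., 3.], dtype=float32)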


@tf_export(v1=["scatter_nd_add"])
def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse addition to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be an integer tensor, containing indices into `ref`.
  It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor
  with 8 elements. In Python, that addition would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  add = tf.compat.v1.scatter_nd_add(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print(sess.run(add))
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to add to ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_add(
        ref, indices, updates, use_locking, name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_sub"])
def scatter_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Subtracts sparse updates from a variable reference.

  ```python
  # Scalar indices
  ref[indices, ...] -= updates[...]

  # Vector indices (for each i)
  ref[indices[i], ...] -= updates[i, ...]

  # High rank indices (for each i, ..., j)
  ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their (negated) contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
  src="https://www.tensorflow.org/images/ScatterSub.png" alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_sub(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
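

# A minimal usage sketch (assumes TF2 eager mode). Duplicate indices add
# their negated contributions:
#
#   import tensorflow as tf
#   v = tf.Variable([4.0, 4.0])
#   tf.compat.v1.scatter_sub(v, [0, 0], [1.0, 2.0])
#   v.numpy()  # ==> array([1., 4.], dtype=float32)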


@tf_export(v1=["scatter_nd_sub"])
def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse subtraction to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be an integer tensor, containing indices into `ref`.
  It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1
  tensor with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  op = tf.compat.v1.scatter_nd_sub(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. A mutable Tensor. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_nd_sub(
        ref, indices, updates, use_locking, name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))


@tf_export(v1=["scatter_mul"])
def scatter_mul(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Multiplies sparse updates into a variable reference.

  This operation computes

  ```python
  # Scalar indices
  ref[indices, ...] *= updates[...]

  # Vector indices (for each i)
  ref[indices[i], ...] *= updates[i, ...]

  # High rank indices (for each i, ..., j)
  ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions multiply.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to multiply into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      operation will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_mul(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_mul(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
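

# A minimal usage sketch (assumes TF2 eager mode):
#
#   import tensorflow as tf
#   v = tf.Variable([1.0, 2.0, 3.0])
#   tf.compat.v1.scatter_mul(v, [1], [10.0])
#   v.numpy()  # ==> array([ 1., 20.,  3.], dtype=float32)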


@tf_export(v1=["scatter_div"])
def scatter_div(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Divides a variable reference by sparse updates.

  This operation computes

  ```python
  # Scalar indices
  ref[indices, ...] /= updates[...]

  # Vector indices (for each i)
  ref[indices[i], ...] /= updates[i, ...]

  # High rank indices (for each i, ..., j)
  ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions divide.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of values
      that `ref` is divided by.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      operation will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_div(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_div(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
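

# A minimal usage sketch (assumes TF2 eager mode):
#
#   import tensorflow as tf
#   v = tf.Variable([10.0, 20.0, 30.0])
#   tf.compat.v1.scatter_div(v, [1], [4.0])
#   v.numpy()  # ==> array([10.,  5., 30.], dtype=float32)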


@tf_export(v1=["scatter_max"])
def scatter_max(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `max` operation.

  This operation computes

      # Scalar indices
      ref[indices, ...] = max(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...],
                                         updates[i, ..., j, ...])

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to reduce into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the update
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_max(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_max(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
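

# A minimal usage sketch (assumes TF2 eager mode). Each location keeps the
# larger of its current value and the update:
#
#   import tensorflow as tf
#   v = tf.Variable([1.0, 4.0, 2.0])
#   tf.compat.v1.scatter_max(v, [0, 1, 2], [3.0, 3.0, 3.0])
#   v.numpy()  # ==> array([3., 4., 3.], dtype=float32)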


@tf_export(v1=["scatter_min"])
def scatter_min(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Reduces sparse updates into a variable reference using the `min` operation.

  This operation computes

      # Scalar indices
      ref[indices, ...] = min(ref[indices, ...], updates[...])

      # Vector indices (for each i)
      ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...])

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...],
                                         updates[i, ..., j, ...])

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the reset value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions combine.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png"
  alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `half`,
      `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a
      `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated
      values to reduce into `ref`.
    use_locking: An optional `bool`. Defaults to `False`. If True, the update
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_min(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_min(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))
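

# A minimal usage sketch (assumes TF2 eager mode). Each location keeps the
# smaller of its current value and the update:
#
#   import tensorflow as tf
#   v = tf.Variable([1.0, 4.0, 2.0])
#   tf.compat.v1.scatter_min(v, [0, 1, 2], [3.0, 3.0, 3.0])
#   v.numpy()  # ==> array([1., 3., 2.], dtype=float32)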


@tf_export(v1=["batch_scatter_update"])
@deprecation.deprecated(
    "2018-11-29", "Use the batch_scatter_update method of Variable instead.")
def batch_scatter_update(ref, indices, updates, use_locking=True, name=None):
  """Generalization of `tf.compat.v1.scatter_update` to axis different than 0.

  Analogous to `batch_gather`. This assumes that `ref`, `indices` and
  `updates` have a series of leading dimensions that are the same for all of
  them, and the updates are performed on the last dimension of indices. In
  other words, the dimensions should be the following:

  `num_prefix_dims = indices.ndims - 1`
  `batch_dim = num_prefix_dims + 1`
  `updates.shape = indices.shape + var.shape[batch_dim:]`

  where

  `updates.shape[:num_prefix_dims]`
  `== indices.shape[:num_prefix_dims]`
  `== var.shape[:num_prefix_dims]`

  And the operation performed can be expressed as:

  `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]`

  When indices is a 1D tensor, this operation is equivalent to
  `tf.compat.v1.scatter_update`.

  There are two alternatives to this operation:
  1) Reshaping the variable by merging the first `ndims` dimensions. However,
     this is not possible because `tf.reshape` returns a Tensor, which we
     cannot use `tf.compat.v1.scatter_update` on.
  2) Looping over the first `ndims` of the variable and using
     `tf.compat.v1.scatter_update` on the subtensors that result from slicing
     the first dimension. This is a valid option for `ndims = 1`, but less
     efficient than this implementation.

  See also `tf.compat.v1.scatter_update` and `tf.compat.v1.scatter_nd_update`.

  Args:
    ref: `Variable` to scatter onto.
    indices: Tensor containing indices as described above.
    updates: Tensor of updates to apply to `ref`.
    use_locking: Boolean indicating whether to lock the writing operation.
    name: Optional scope name string.

  Returns:
    Ref to `variable` after it has been modified.

  Raises:
    ValueError: If the initial `ndims` of `ref`, `indices`, and `updates` are
      not the same.
  """
  with ops.name_scope(name):
    indices = ops.convert_to_tensor(indices, name="indices")
    indices_shape = array_ops.shape(indices)
    indices_dimensions = indices.get_shape().ndims

    if indices_dimensions is None:
      raise ValueError("batch_scatter_update does not allow indices with "
                       "unknown shape.")

    nd_indices = array_ops.expand_dims(indices, axis=-1)
    nd_indices_list = []

    # Scatter ND requires indices to have an additional dimension, in which
    # the coordinates of the updated elements are specified. For this to be
    # adapted to the scatter_update with several leading dimensions, we simply
    # make use of a tf.range for all the leading dimensions followed by concat
    # of all the coordinates we created with the original indices.

    # For example if indices.shape = [2, 3, 4], we should generate the
    # following indices for tf.compat.v1.scatter_nd_update:
    # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
    # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
    # nd_indices[:, :, 2] = indices
    for dimension in range(indices_dimensions - 1):
      # In this loop we generate the following for the example (one for each
      # iteration).
      # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]]
      # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]]
      # This is done at every iteration with a tf.range over the size of the
      # i-th dimension and using broadcasting over the desired shape.
      dimension_size = indices_shape[dimension]
      shape_to_broadcast = [1] * (indices_dimensions + 1)
      shape_to_broadcast[dimension] = dimension_size
      dimension_range = array_ops.reshape(
          gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast)
      if dimension_range.dtype.base_dtype != nd_indices.dtype:
        dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype)
      nd_indices_list.append(
          dimension_range * array_ops.ones_like(nd_indices))
    # Add the original indices at the end, as described above, and concat.
    nd_indices_list.append(nd_indices)
    final_indices = array_ops.concat(nd_indices_list, axis=-1)
    return scatter_nd_update(
        ref, final_indices, updates, use_locking=use_locking)
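

# A minimal usage sketch (assumes TF2 eager mode). Each row of `indices`
# addresses positions within the corresponding row of the variable:
#
#   import tensorflow as tf
#   v = tf.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
#   tf.compat.v1.batch_scatter_update(
#       v, tf.constant([[0], [2]]), tf.constant([[9.0], [8.0]]))
#   v.numpy()  # ==> array([[9., 2., 3.], [4., 5., 8.]], dtype=float32)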