Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/state_ops.py: 37%

128 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Variables. 

17 

18See the [Variables](https://www.tensorflow.org/guide/variables) guide. 

19""" 

20 

21from tensorflow.python.framework import ops 

22from tensorflow.python.framework import tensor_shape 

23from tensorflow.python.ops import array_ops 

24from tensorflow.python.ops import gen_math_ops 

25from tensorflow.python.ops import gen_resource_variable_ops 

26from tensorflow.python.ops import gen_state_ops 

27# go/tf-wildcard-import 

28# pylint: disable=wildcard-import 

29from tensorflow.python.ops.gen_state_ops import * 

30# pylint: enable=wildcard-import 

31from tensorflow.python.util import deprecation 

32from tensorflow.python.util.deprecation import deprecated 

33from tensorflow.python.util.tf_export import tf_export 

34 

35 

36# pylint: disable=protected-access,g-doc-return-or-yield,g-doc-args 

def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
                shared_name=""):
  """Deprecated. Use `variable_op_v2` instead.

  Creates a legacy `Variable` op through `gen_state_ops.variable`.

  Args:
    shape: The shape of the tensor managed by this variable. When `set_shape`
      is False this value is ignored and an unknown shape is used instead.
    dtype: The underlying type of the tensor values.
    name: Optional name for the variable op.
    set_shape: If True, `shape` is given to the op and is also set as the
      static shape of the returned tensor. If False, the op is created with
      `tensor_shape.unknown_shape()` and no static shape is set.
    container: Optional container string forwarded to the op.
    shared_name: Optional shared name forwarded to the op.

  Returns:
    The tensor produced by `gen_state_ops.variable`.
  """
  op_shape = shape if set_shape else tensor_shape.unknown_shape()
  variable_tensor = gen_state_ops.variable(
      shape=op_shape,
      dtype=dtype,
      name=name,
      container=container,
      shared_name=shared_name)
  # TODO(mrry): Move this to where it is used, so we can get rid of this op
  # wrapper?
  if set_shape:
    variable_tensor.set_shape(op_shape)
  return variable_tensor

49 

50 

def variable_op_v2(shape, dtype, name="Variable", container="", shared_name=""):
  """Create a variable Operation.

  See also variables.Variable.

  Args:
    shape: The shape of the tensor managed by this variable.
    dtype: The underlying type of the tensor values.
    name: Optional name to use for the variable op.
    container: An optional string. Defaults to "".
      If non-empty, this variable is placed in the given container.
      Otherwise, a default container is used.
    shared_name: An optional string. Defaults to "".
      If non-empty, this variable is named in the given bucket
      with this shared_name. Otherwise, the node name is used instead.

  Returns:
    A variable tensor, as produced by `gen_state_ops.variable_v2`.
  """
  op_attrs = {
      "shape": shape,
      "dtype": dtype,
      "name": name,
      "container": container,
      "shared_name": shared_name,
  }
  return gen_state_ops.variable_v2(**op_attrs)

76 

77 

def init_variable(v, init, name="init"):
  """Initializes variable `v` with `init`.

  Behaviour:
    * if `init` is callable: v = init(shape_of_v_as_list, v.dtype.base_dtype)
    * otherwise `init` is converted to a tensor and assigned to `v`.

  The assign op is built inside a name scope derived from `v.op.name` and is
  colocated with `v`.

  Args:
    v: Variable to initialize.
    init: Tensor to assign to v,
      Or an object convertible to Tensor e.g. nparray,
      Or an Initializer that generates a tensor given the shape and type of v.
      An "Initializer" is a callable that returns a tensor that "v" should be
      set to. It will be called as init(shape, dtype).
    name: Optional name for the op.

  Returns:
    The operation that initializes v.

  Raises:
    AssertionError: if `init` is callable and the shape of `v` is not fully
      defined.
  """
  with ops.name_scope(None, v.op.name + "/", [v, init]):
    with ops.name_scope(name) as scope:
      with ops.colocate_with(v):
        if callable(init):
          assert v.get_shape().is_fully_defined(), "Variable shape unknown."
          # TODO(mrry): Convert to v.shape when the property and
          # accessor are reconciled (and all initializers support
          # tf.TensorShape objects).
          generated = init(v.get_shape().as_list(), v.dtype.base_dtype)
          initial_value = ops.convert_to_tensor(generated, name="value")
        else:
          initial_value = ops.convert_to_tensor(init, name="init")
        return gen_state_ops.assign(v, initial_value, name=scope)

111 

112 

def is_variable_initialized(ref, name=None):
  """Checks whether a tensor has been initialized.

  Outputs boolean scalar indicating whether the tensor has been initialized.

  Args:
    ref: A mutable `Tensor` (ref dtype) or a resource variable.
      Should be from a `Variable` node. May be uninitialized.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  if not ref.dtype._is_ref_dtype:
    # Resource variables expose their own initialization check.
    return ref.is_initialized(name=name)
  return gen_state_ops.is_variable_initialized(ref=ref, name=name)

130 

131 

@tf_export(v1=["assign_sub"])
def assign_sub(ref, value, use_locking=None, name=None):
  """Update `ref` by subtracting `value` from it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Unlike `tf.math.subtract`, this op does not broadcast. `ref` and `value`
  must have the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be subtracted from the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the
      subtraction will be protected by a lock; otherwise the behavior is
      undefined, but may exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want
    to use the new value after the variable has been updated.

  @compatibility(TF2)
  `tf.compat.v1.assign_sub` is mostly compatible with eager
  execution and `tf.function`.

  To switch to the native TF2 style, one could use method 'assign_sub' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign_sub()` method   |
  | `value`               | `value`         | In `assign_sub()` method   |
  | `use_locking`         | `use_locking`   | In `assign_sub()` method   |
  | `name`                | `name`          | In `assign_sub()` method   |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :


  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(1, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign_sub(a, 1)
  ...     res_a = sess.run(update_op)
  ...     res_a
  0

  After:

  >>> b = tf.Variable(1, dtype=tf.int64)
  >>> res_b = b.assign_sub(1)
  >>> res_b.numpy()
  0

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_sub(
        ref, value, use_locking=use_locking, name=name)
  # Resource (and other non-ref) variables: forward `use_locking` and `name`
  # to `Variable.assign_sub`, matching the argument mapping documented above,
  # instead of silently dropping them.
  return ref.assign_sub(value, use_locking=use_locking, name=name)

202 

203 

@tf_export(v1=["assign_add"])
def assign_add(ref, value, use_locking=None, name=None):
  """Update `ref` by adding `value` to it.

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Unlike `tf.math.add`, this op does not broadcast. `ref` and `value` must have
  the same shape.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int64`, `int32`, `uint8`, `uint16`, `int16`, `int8`,
      `complex64`, `complex128`, `qint8`, `quint8`, `qint32`, `half`. Should be
      from a `Variable` node.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be added to the variable.
    use_locking: An optional `bool`. Defaults to `False`. If True, the addition
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want
    to use the new value after the variable has been updated.

  @compatibility(TF2)
  `tf.compat.v1.assign_add` is mostly compatible with eager
  execution and `tf.function`.

  To switch to the native TF2 style, one could use method 'assign_add' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign_add()` method   |
  | `value`               | `value`         | In `assign_add()` method   |
  | `use_locking`         | `use_locking`   | In `assign_add()` method   |
  | `name`                | `name`          | In `assign_add()` method   |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :


  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(0, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign_add(a, 1)
  ...     res_a = sess.run(update_op)
  ...     res_a
  1

  After:

  >>> b = tf.Variable(0, dtype=tf.int64)
  >>> res_b = b.assign_add(1)
  >>> res_b.numpy()
  1

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign_add(
        ref, value, use_locking=use_locking, name=name)
  # Resource (and other non-ref) variables: forward `use_locking` and `name`
  # to `Variable.assign_add`, matching the argument mapping documented above,
  # instead of silently dropping them.
  return ref.assign_add(value, use_locking=use_locking, name=name)

274 

275 

@tf_export(v1=["assign"])
def assign(ref, value, validate_shape=None, use_locking=None, name=None):
  """Update `ref` by assigning `value` to it.

  This operation outputs a Tensor that holds the new value of `ref` after
  the value has been assigned. This makes it easier to chain operations that
  need to use the updated value.

  Args:
    ref: A mutable `Tensor`. Should be from a `Variable` node. May be
      uninitialized.
    value: A `Tensor`. Must have the same shape and dtype as `ref`. The value to
      be assigned to the variable.
    validate_shape: An optional `bool`. Defaults to `True`. If true, the
      operation will validate that the shape of 'value' matches the shape of the
      Tensor being assigned to. If false, 'ref' will take on the shape of
      'value'. Only honored for ref-dtype variables.
    use_locking: An optional `bool`. Defaults to `True`. If True, the assignment
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` that will hold the new value of `ref` after
    the assignment has completed.

  @compatibility(TF2)
  `tf.compat.v1.assign` is mostly compatible with eager
  execution and `tf.function`. However, argument 'validate_shape' will be
  ignored. To avoid shape validation, set 'shape' to tf.TensorShape(None) when
  constructing the variable:

  >>> import tensorflow as tf
  >>> a = tf.Variable([1], shape=tf.TensorShape(None))
  >>> _ = tf.compat.v1.assign(a, [2, 3])

  To switch to the native TF2 style, one could use method 'assign' of
  `tf.Variable`:

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `ref`                 | `self`          | In `assign()` method       |
  | `value`               | `value`         | In `assign()` method       |
  | `validate_shape`      | Not supported   | Specify `shape` in the     |
  :                       :                 : constructor to replicate   :
  :                       :                 : behavior                   :
  | `use_locking`         | `use_locking`   | In `assign()` method       |
  | `name`                | `name`          | In `assign()` method       |
  | -                     | `read_value`    | Set to True to replicate   |
  :                       :                 : behavior (True is default) :


  #### Before & After Usage Example

  Before:

  >>> with tf.Graph().as_default():
  ...   with tf.compat.v1.Session() as sess:
  ...     a = tf.compat.v1.Variable(0, dtype=tf.int64)
  ...     sess.run(a.initializer)
  ...     update_op = tf.compat.v1.assign(a, 2)
  ...     res_a = sess.run(update_op)
  ...     res_a
  2

  After:

  >>> b = tf.Variable(0, dtype=tf.int64)
  >>> res_b = b.assign(2)
  >>> res_b.numpy()
  2

  @end_compatibility
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.assign(
        ref, value, use_locking=use_locking, name=name,
        validate_shape=validate_shape)
  # Resource (and other non-ref) variables: `validate_shape` has no
  # counterpart on `Variable.assign` and is ignored (see the compatibility
  # note above); `use_locking` and `name` are forwarded.
  return ref.assign(value, use_locking=use_locking, name=name)

356 

357 

@tf_export(v1=["count_up_to"])
@deprecated(None, "Prefer Dataset.range instead.")
def count_up_to(ref, limit, name=None):
  r"""Increments 'ref' until it reaches 'limit'.

  Args:
    ref: A Variable. Must be one of the following types: `int32`, `int64`.
      Should be from a scalar `Variable` node.
    limit: An `int`.
      If incrementing ref would bring it above limit, instead generates an
      'OutOfRange' error.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `ref`.
    A copy of the input before increment. If nothing else modifies the
    input, the values produced will all be distinct.
  """
  if not ref.dtype._is_ref_dtype:
    # Non-ref (resource) variables go through the handle-based kernel.
    return gen_state_ops.resource_count_up_to(
        ref.handle, limit, T=ref.dtype, name=name)
  return gen_state_ops.count_up_to(ref, limit=limit, name=name)

380 

381 

@tf_export(v1=["scatter_update"])
def scatter_update(ref, indices, updates, use_locking=True, name=None):
  # pylint: disable=line-too-long
  r"""Applies sparse updates to a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] = updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] = updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] = updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  If a value in `ref` is updated more than once because `indices` contains
  duplicate entries, the order in which those updates are applied is
  undefined.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in `ref`.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Only used for ref-dtype variables.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if not ref.dtype._is_ref_dtype:
    scatter_op = gen_resource_variable_ops.resource_scatter_update(
        ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
        name=name)
    return ref._lazy_read(scatter_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_update(
      ref, indices, updates, use_locking=use_locking, name=name)

434 

435 

@tf_export(v1=["scatter_nd_update"])
def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
  r"""Applies sparse `updates` to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to update 4 scattered elements of a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
      ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
      indices = tf.constant([[4], [3], [1], [7]])
      updates = tf.constant([9, 10, 11, 12])
      update = tf.compat.v1.scatter_nd_update(ref, indices, updates)
      with tf.compat.v1.Session() as sess:
        print(sess.run(update))
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A Variable.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to store in ref.
    use_locking: An optional `bool`. Defaults to `True`.
      If True, the assignment will be protected by a lock; otherwise the
      behavior is undefined, but may exhibit less contention. Only used for
      ref-dtype variables.
    name: A name for the operation (optional).

  Returns:
    The value of the variable after the update.
  """
  if ref.dtype._is_ref_dtype:
    # Keyword arguments keep this call robust to generated-op signature order.
    return gen_state_ops.scatter_nd_update(
        ref, indices, updates, use_locking=use_locking, name=name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_update(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))

496 

497 

@tf_export(v1=["scatter_add"])
def scatter_add(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Adds sparse updates to the variable referenced by `ref`.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] += updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] += updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] += updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.
  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src='https://www.tensorflow.org/images/ScatterAdd.png' alt>
  </div>

  Args:
    ref: A `Variable`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of values to add to `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the addition will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Only used for ref-dtype variables.
    name: A name for the operation (optional).

  Returns:
    Same as `ref`. Returned as a convenience for operations that want
    to use the updated values after the update is done.
  """
  if ref.dtype._is_ref_dtype:
    return gen_state_ops.scatter_add(ref, indices, updates,
                                     use_locking=use_locking, name=name)
  return ref._lazy_read(gen_resource_variable_ops.resource_scatter_add(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))

548 

549 

@tf_export(v1=["scatter_nd_add"])
def scatter_nd_add(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse addition to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor with
  8 elements. In Python, that addition would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  add = tf.compat.v1.scatter_nd_add(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print(sess.run(add))
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor` or resource variable. Must be one of the following
      types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`,
      `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`,
      `complex128`, `half`, `uint32`, `uint64`. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to add to ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the assignment will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Only used for ref-dtype variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    add_op = gen_state_ops.resource_scatter_nd_add(
        ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
        name=name)
    return ref._lazy_read(add_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_nd_add(
      ref, indices, updates, use_locking, name)

611 

612 

@tf_export(v1=["scatter_sub"])
def scatter_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Subtracts sparse updates from a variable reference.

  ```python
      # Scalar indices
      ref[indices, ...] -= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] -= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] -= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their (negated) contributions add.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
      src="https://www.tensorflow.org/images/ScatterSub.png" alt>
  </div>

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of updated values to subtract from `ref`.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock;
      otherwise the behavior is undefined, but may exhibit less contention.
      Only used for ref-dtype variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    sub_op = gen_resource_variable_ops.resource_scatter_sub(
        ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
        name=name)
    return ref._lazy_read(sub_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_sub(
      ref, indices, updates, use_locking=use_locking, name=name)

665 

666 

@tf_export(v1=["scatter_nd_sub"])
def scatter_nd_sub(ref, indices, updates, use_locking=False, name=None):
  r"""Applies sparse subtraction to individual values or slices in a Variable.

  `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.

  `indices` must be integer tensor, containing indices into `ref`.
  It must be shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]]
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  op = tf.compat.v1.scatter_nd_sub(ref, indices, updates)
  with tf.compat.v1.Session() as sess:
    print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  (e.g. index 4 becomes `5 - 9 = -4` and index 3 becomes `4 - 10 = -6`.)

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    ref: A mutable `Tensor` or resource variable. Must be one of the following
      types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`,
      `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`,
      `complex128`, `half`, `uint32`, `uint64`. Should be from a Variable node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A tensor of indices into ref.
    updates: A `Tensor`. Must have the same type as `ref`.
      A tensor of values to subtract from ref.
    use_locking: An optional `bool`. Defaults to `False`.
      If True, the subtraction will be protected by a lock; otherwise the
      behavior is undefined, but may exhibit less contention. Only used for
      ref-dtype variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if ref.dtype._is_ref_dtype:
    # Keyword arguments keep this call robust to generated-op signature order.
    return gen_state_ops.scatter_nd_sub(
        ref, indices, updates, use_locking=use_locking, name=name)
  return ref._lazy_read(gen_state_ops.resource_scatter_nd_sub(  # pylint: disable=protected-access
      ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
      name=name))

729 

730 

@tf_export(v1=["scatter_mul"])
def scatter_mul(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Multiplies sparse updates into a variable reference.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] *= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] *= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] *= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions multiply.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of values
      that `ref` is multiplied by.
    use_locking: An optional `bool`. Defaults to `False`. If True, the operation
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Only used for ref-dtype variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    mul_op = gen_resource_variable_ops.resource_scatter_mul(
        ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
        name=name)
    return ref._lazy_read(mul_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_mul(
      ref, indices, updates, use_locking=use_locking, name=name)

781 

782 

@tf_export(v1=["scatter_div"])
def scatter_div(ref, indices, updates, use_locking=False, name=None):
  # pylint: disable=line-too-long
  r"""Divides a variable reference by sparse updates.

  This operation computes

  ```python
      # Scalar indices
      ref[indices, ...] /= updates[...]

      # Vector indices (for each i)
      ref[indices[i], ...] /= updates[i, ...]

      # High rank indices (for each i, ..., j)
      ref[indices[i, ..., j], ...] /= updates[i, ..., j, ...]
  ```

  This operation outputs `ref` after the update is done.
  This makes it easier to chain operations that need to use the updated value.

  Duplicate entries are handled correctly: if multiple `indices` reference
  the same location, their contributions divide.

  Requires `updates.shape = indices.shape + ref.shape[1:]` or
  `updates.shape = []`.

  Args:
    ref: A mutable `Tensor`. Must be one of the following types: `float32`,
      `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`,
      `qint8`, `quint8`, `qint32`, `bfloat16`, `uint16`, `complex128`, `half`,
      `uint32`, `uint64`. Should be from a `Variable` node.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A
      tensor of indices into the first dimension of `ref`.
    updates: A `Tensor`. Must have the same type as `ref`. A tensor of values
      that `ref` is divided by.
    use_locking: An optional `bool`. Defaults to `False`. If True, the operation
      will be protected by a lock; otherwise the behavior is undefined, but may
      exhibit less contention. Only used for ref-dtype variables.
    name: A name for the operation (optional).

  Returns:
    A mutable `Tensor`. Has the same type as `ref`.
  """
  if not ref.dtype._is_ref_dtype:
    div_op = gen_resource_variable_ops.resource_scatter_div(
        ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
        name=name)
    return ref._lazy_read(div_op)  # pylint: disable=protected-access
  return gen_state_ops.scatter_div(
      ref, indices, updates, use_locking=use_locking, name=name)

833 

834 

835@tf_export(v1=["scatter_max"]) 

836def scatter_max(ref, indices, updates, use_locking=False, name=None): 

837 # pylint: disable=line-too-long 

838 r"""Reduces sparse updates into a variable reference using the `max` operation. 

839 

840 This operation computes 

841 

842 # Scalar indices 

843 ref[indices, ...] = max(ref[indices, ...], updates[...]) 

844 

845 # Vector indices (for each i) 

846 ref[indices[i], ...] = max(ref[indices[i], ...], updates[i, ...]) 

847 

848 # High rank indices (for each i, ..., j) 

849 ref[indices[i, ..., j], ...] = max(ref[indices[i, ..., j], ...], 

850 updates[i, ..., j, ...]) 

851 

852 This operation outputs `ref` after the update is done. 

853 This makes it easier to chain operations that need to use the reset value. 

854 

855 Duplicate entries are handled correctly: if multiple `indices` reference 

856 the same location, their contributions combine. 

857 

858 Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = 

859 []`. 

860 

861 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 

862 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" 

863 alt> 

864 </div> 

865 

866 Args: 

867 ref: A mutable `Tensor`. Must be one of the following types: `half`, 

868 `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a 

869 `Variable` node. 

870 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A 

871 tensor of indices into the first dimension of `ref`. 

872 updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated 

873 values to reduce into `ref`. 

874 use_locking: An optional `bool`. Defaults to `False`. If True, the update 

875 will be protected by a lock; otherwise the behavior is undefined, but may 

876 exhibit less contention. 

877 name: A name for the operation (optional). 

878 

879 Returns: 

880 A mutable `Tensor`. Has the same type as `ref`. 

881 """ 

882 if ref.dtype._is_ref_dtype: 

883 return gen_state_ops.scatter_max(ref, indices, updates, 

884 use_locking=use_locking, name=name) 

885 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_max( # pylint: disable=protected-access 

886 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 

887 name=name)) 

888 

889 

890@tf_export(v1=["scatter_min"]) 

891def scatter_min(ref, indices, updates, use_locking=False, name=None): 

892 # pylint: disable=line-too-long 

893 r"""Reduces sparse updates into a variable reference using the `min` operation. 

894 

895 This operation computes 

896 

897 # Scalar indices 

898 ref[indices, ...] = min(ref[indices, ...], updates[...]) 

899 

900 # Vector indices (for each i) 

901 ref[indices[i], ...] = min(ref[indices[i], ...], updates[i, ...]) 

902 

903 # High rank indices (for each i, ..., j) 

904 ref[indices[i, ..., j], ...] = min(ref[indices[i, ..., j], ...], 

905 updates[i, ..., j, ...]) 

906 

907 This operation outputs `ref` after the update is done. 

908 This makes it easier to chain operations that need to use the reset value. 

909 

910 Duplicate entries are handled correctly: if multiple `indices` reference 

911 the same location, their contributions combine. 

912 

913 Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = 

914 []`. 

915 

916 <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> 

917 <img style="width:100%" src="https://www.tensorflow.org/images/ScatterAdd.png" 

918 alt> 

919 </div> 

920 

921 Args: 

922 ref: A mutable `Tensor`. Must be one of the following types: `half`, 

923 `bfloat16`, `float32`, `float64`, `int32`, `int64`. Should be from a 

924 `Variable` node. 

925 indices: A `Tensor`. Must be one of the following types: `int32`, `int64`. A 

926 tensor of indices into the first dimension of `ref`. 

927 updates: A `Tensor`. Must have the same type as `ref`. A tensor of updated 

928 values to reduce into `ref`. 

929 use_locking: An optional `bool`. Defaults to `False`. If True, the update 

930 will be protected by a lock; otherwise the behavior is undefined, but may 

931 exhibit less contention. 

932 name: A name for the operation (optional). 

933 

934 Returns: 

935 A mutable `Tensor`. Has the same type as `ref`. 

936 """ 

937 if ref.dtype._is_ref_dtype: 

938 return gen_state_ops.scatter_min(ref, indices, updates, 

939 use_locking=use_locking, name=name) 

940 return ref._lazy_read(gen_resource_variable_ops.resource_scatter_min( # pylint: disable=protected-access 

941 ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), 

942 name=name)) 

943 

944 

945@tf_export(v1=["batch_scatter_update"]) 

946@deprecation.deprecated( 

947 "2018-11-29", "Use the batch_scatter_update method of Variable instead.") 

948def batch_scatter_update(ref, indices, updates, use_locking=True, name=None): 

949 """Generalization of `tf.compat.v1.scatter_update` to axis different than 0. 

950 

951 Analogous to `batch_gather`. This assumes that `ref`, `indices` and `updates` 

952 have a series of leading dimensions that are the same for all of them, and the 

953 updates are performed on the last dimension of indices. In other words, the 

954 dimensions should be the following: 

955 

956 `num_prefix_dims = indices.ndims - 1` 

957 `batch_dim = num_prefix_dims + 1` 

958 `updates.shape = indices.shape + var.shape[batch_dim:]` 

959 

960 where 

961 

962 `updates.shape[:num_prefix_dims]` 

963 `== indices.shape[:num_prefix_dims]` 

964 `== var.shape[:num_prefix_dims]` 

965 

966 And the operation performed can be expressed as: 

967 

968 `var[i_1, ..., i_n, indices[i_1, ..., i_n, j]] = updates[i_1, ..., i_n, j]` 

969 

970 When indices is a 1D tensor, this operation is equivalent to 

971 `tf.compat.v1.scatter_update`. 

972 

973 To avoid this operation there would be 2 alternatives: 

974 1) Reshaping the variable by merging the first `ndims` dimensions. However, 

975 this is not possible because `tf.reshape` returns a Tensor, which we 

976 cannot use `tf.compat.v1.scatter_update` on. 

977 2) Looping over the first `ndims` of the variable and using 

978 `tf.compat.v1.scatter_update` on the subtensors that result of slicing the 

979 first 

980 dimension. This is a valid option for `ndims = 1`, but less efficient than 

981 this implementation. 

982 

983 See also `tf.compat.v1.scatter_update` and `tf.compat.v1.scatter_nd_update`. 

984 

985 Args: 

986 ref: `Variable` to scatter onto. 

987 indices: Tensor containing indices as described above. 

988 updates: Tensor of updates to apply to `ref`. 

989 use_locking: Boolean indicating whether to lock the writing operation. 

990 name: Optional scope name string. 

991 

992 Returns: 

993 Ref to `variable` after it has been modified. 

994 

995 Raises: 

996 ValueError: If the initial `ndims` of `ref`, `indices`, and `updates` are 

997 not the same. 

998 """ 

999 with ops.name_scope(name): 

1000 indices = ops.convert_to_tensor(indices, name="indices") 

1001 indices_shape = array_ops.shape(indices) 

1002 indices_dimensions = indices.get_shape().ndims 

1003 

1004 if indices_dimensions is None: 

1005 raise ValueError("batch_gather does not allow indices with unknown " 

1006 "shape.") 

1007 

1008 nd_indices = array_ops.expand_dims(indices, axis=-1) 

1009 nd_indices_list = [] 

1010 

1011 # Scatter ND requires indices to have an additional dimension, in which the 

1012 # coordinates of the updated things are specified. For this to be adapted to 

1013 # the scatter_update with several leading dimensions, we simply make use of 

1014 # a tf.range for all the leading dimensions followed by concat of all the 

1015 # coordinates we created with the original indices. 

1016 

1017 # For example if indices.shape = [2, 3, 4], we should generate the following 

1018 # indices for tf.compat.v1.scatter_nd_update: 

1019 # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]] 

1020 # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]] 

1021 # nd_indices[:, :, 2] = indices 

1022 for dimension in range(indices_dimensions - 1): 

1023 # In this loop we generate the following for the example (one for each 

1024 # iteration). 

1025 # nd_indices[:, :, 0] = [[0, 0, 0], [1, 1, 1]] 

1026 # nd_indices[:, :, 1] = [[0, 1, 2], [0, 1, 2]] 

1027 # This is done at every iteration with a tf.range over the size of the 

1028 # i-th dimension and using broadcasting over the desired shape. 

1029 dimension_size = indices_shape[dimension] 

1030 shape_to_broadcast = [1] * (indices_dimensions + 1) 

1031 shape_to_broadcast[dimension] = dimension_size 

1032 dimension_range = array_ops.reshape( 

1033 gen_math_ops._range(0, dimension_size, 1), shape_to_broadcast) 

1034 if dimension_range.dtype.base_dtype != nd_indices.dtype: 

1035 dimension_range = gen_math_ops.cast(dimension_range, nd_indices.dtype) 

1036 nd_indices_list.append( 

1037 dimension_range * array_ops.ones_like(nd_indices)) 

1038 # Add the original indices at the end, as described above, and concat. 

1039 nd_indices_list.append(nd_indices) 

1040 final_indices = array_ops.concat(nd_indices_list, axis=-1) 

1041 return scatter_nd_update( 

1042 ref, final_indices, updates, use_locking=use_locking)