Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/functional_ops.py: 19%

257 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Functional operations."""

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import while_loop
# pylint: disable=unused-import
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export(v1=["foldl"])
@dispatch.add_dispatch_support
def foldl(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError(
        f"{fn.__name__} is not callable. Please provide a callable function.")

  def create_ta(elem):
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldl", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array. n may be known statically.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    ]
    n = (
        tensor_shape.dimension_value(elems_flat[0].shape[0]) or
        array_ops.shape(elems_flat[0])[0])

    elems_ta = nest.map_structure(create_ta, elems)

    if initializer is None:
      a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
      i = constant_op.constant(1)
    else:
      a = initializer
      i = constant_op.constant(0)

    def compute(i, a):
      elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      a = fn(a, elem_i)
      return [i + 1, a]

    _, r_a = while_loop.while_loop(
        lambda i, a: i < n,
        compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return r_a


@tf_export("foldl", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldl(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldl(fn, elems))""",
    warn_once=True,
    back_prop=False)
def foldl_v2(fn,
             elems,
             initializer=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from first
    to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = tf.foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  return foldl(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name)
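

# A hypothetical sketch (illustration only, not part of the TensorFlow API) of
# the multi-arity form described in the docstrings above: `elems` is a tuple of
# tensors with matching first dimensions, and `fn` receives the accumulator
# plus one slice of that structure per step. `_example_foldl_multi_arity` is an
# assumed helper name used only for this example.
def _example_foldl_multi_arity():
  elems = (constant_op.constant([1, 2, 3]), constant_op.constant([10, 20, 30]))
  # Sums both components of every slice; the result is 66.
  return foldl(lambda acc, x: acc + x[0] + x[1], elems,
               initializer=constant_op.constant(0))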


@tf_export(v1=["foldr"])
@dispatch.add_dispatch_support
def foldr(fn,
          elems,
          initializer=None,
          parallel_iterations=10,
          back_prop=True,
          swap_memory=False,
          name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  if not callable(fn):
    raise TypeError(
        f"{fn.__name__} is not callable. Please provide a callable function.")

  def create_ta(elem):
    return tensor_array_ops.TensorArray(
        dtype=elem.dtype, size=n, dynamic_size=False,
        infer_shape=True).unstack(elem)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "foldr", [elems]):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally and not
      # issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array. n may be known statically.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    ]
    n = (
        tensor_shape.dimension_value(elems_flat[0].shape[0]) or
        array_ops.shape(elems_flat[0])[0])

    elems_ta = nest.map_structure(create_ta, elems)

    if initializer is None:
      i = n - 1
      a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
    else:
      i = n
      a = initializer

    def compute(i, a):
      i -= 1
      elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
      a_out = fn(a, elem)
      return [i, a_out]

    _, r_a = while_loop.while_loop(
        lambda i, a: i > 0,
        compute, [i, a],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return r_a


@tf_export("foldr", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.foldr(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems))""",
    warn_once=True,
    back_prop=False)
def foldr_v2(fn,
             elems,
             initializer=None,
             parallel_iterations=10,
             back_prop=True,
             swap_memory=False,
             name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn, and the second is the value at the current position of
  `elems`. If `initializer` is None, `elems` must contain at least one element,
  and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  This method also allows multi-arity `elems` and output of `fn`. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The signature of `fn` may
  match the structure of `elems`. That is, if `elems` is
  `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
  `fn = lambda (t1, [t2, t3, [t4, t5]]):`.

  Args:
    fn: The callable to be performed.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      as the initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors, resulting from applying
    `fn` consecutively to the list of tensors unpacked from `elems`, from last
    to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = tf.foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  return foldr(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name)


@tf_export(v1=["scan"])
@dispatch.add_dispatch_support
def scan(fn,
         elems,
         initializer=None,
         parallel_iterations=10,
         back_prop=True,
         swap_memory=False,
         infer_shape=True,
         reverse=False,
         name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  See also `tf.map_fn`.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If reverse=True, it's fn(initializer, values[-1]).shape.

  This method also allows multi-arity `elems` and accumulator. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed. It accepts two arguments. The first will
      have the same structure as `initializer` if one is provided, otherwise it
      will have the same structure as `elems`. The second will have the same
      (possibly nested) structure as `elems`. Its output must have the same
      structure as `initializer` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) True enables support for back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first to
      last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  if not callable(fn):
    raise TypeError(
        f"{fn.__name__} is not callable. Please provide a callable function.")

  input_is_sequence = nest.is_nested(elems)
  input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]

  def input_pack(x):
    return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]

  if initializer is None:
    output_is_sequence = input_is_sequence
    output_flatten = input_flatten
    output_pack = input_pack
  else:
    output_is_sequence = nest.is_nested(initializer)
    output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]

    def output_pack(x):
      return (nest.pack_sequence_as(initializer, x)
              if output_is_sequence else x[0])

  elems_flat = input_flatten(elems)

  in_graph_mode = not context.executing_eagerly()
  with ops.name_scope(name, "scan", elems_flat):
    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode:
      # Any get_variable calls in fn will cache the first call locally
      # and not issue repeated network I/O requests for each iteration.
      varscope = vs.get_variable_scope()
      varscope_caching_device_was_none = False
      if varscope.caching_device is None:
        # TODO(ebrevdo): Change to using colocate_with here and in other
        # methods.
        varscope.set_caching_device(lambda op: op.device)
        varscope_caching_device_was_none = True

    # Convert elems to tensor array.
    elems_flat = [
        ops.convert_to_tensor(elem, name="elem") for elem in elems_flat
    ]

    # Convert elems to tensor array. n may be known statically.
    n = tensor_shape.dimension_value(elems_flat[0].shape[0])
    if n is None:
      n = array_ops.shape(elems_flat[0])[0]

    # TensorArrays are always flat
    elems_ta = [
        tensor_array_ops.TensorArray(
            dtype=elem.dtype,
            size=n,
            dynamic_size=False,
            element_shape=elem.shape[1:],
            infer_shape=True) for elem in elems_flat
    ]
    # Unpack elements
    elems_ta = [
        elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)
    ]

    if initializer is None:
      a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
      i = 1
    else:
      initializer_flat = output_flatten(initializer)
      a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
      i = 0

    # Create a tensor array to store the intermediate values.
    accs_ta = [
        tensor_array_ops.TensorArray(
            dtype=init.dtype,
            size=n,
            element_shape=init.shape if infer_shape else None,
            dynamic_size=False,
            infer_shape=infer_shape) for init in a_flat
    ]

    if initializer is None:
      accs_ta = [
          acc_ta.write(n - 1 if reverse else 0, a)
          for (acc_ta, a) in zip(accs_ta, a_flat)
      ]

    def compute(i, a_flat, tas):
      """The loop body of scan.

      Args:
        i: the loop counter.
        a_flat: the accumulator value(s), flattened.
        tas: the output accumulator TensorArray(s), flattened.

      Returns:
        [i + 1, a_flat, tas]: the updated counter + new accumulator values +
          updated TensorArrays

      Raises:
        TypeError: if initializer and fn() output structure do not match
        ValueError: if initializer and fn() output lengths do not match
      """
      packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
      packed_a = output_pack(a_flat)
      a_out = fn(packed_a, packed_elems)
      nest.assert_same_structure(elems if initializer is None else initializer,
                                 a_out)
      flat_a_out = output_flatten(a_out)
      tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
      if reverse:
        next_i = i - 1
      else:
        next_i = i + 1
      return (next_i, flat_a_out, tas)

    if reverse:
      initial_i = n - 1 - i
      condition = lambda i, _1, _2: i >= 0
    else:
      initial_i = i
      condition = lambda i, _1, _2: i < n
    _, _, r_a = while_loop.while_loop(
        condition,
        compute, (initial_i, a_flat, accs_ta),
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory,
        maximum_iterations=n)

    results_flat = [r.stack() for r in r_a]

    n_static = tensor_shape.Dimension(
        tensor_shape.dimension_value(
            elems_flat[0].get_shape().with_rank_at_least(1)[0]))
    for elem in elems_flat[1:]:
      n_static.assert_is_compatible_with(
          tensor_shape.Dimension(
              tensor_shape.dimension_value(
                  elem.get_shape().with_rank_at_least(1)[0])))
    for r in results_flat:
      r.set_shape(
          tensor_shape.TensorShape(n_static).concatenate(r.get_shape()[1:]))

    # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    # supported in Eager
    if in_graph_mode and varscope_caching_device_was_none:
      varscope.set_caching_device(None)

    return output_pack(results_flat)


@tf_export("scan", v1=[])
@dispatch.add_dispatch_support
@deprecation.deprecated_arg_values(
    None,
    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.scan(fn, elems, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.scan(fn, elems))""",
    warn_once=True,
    back_prop=False)
def scan_v2(fn,
            elems,
            initializer=None,
            parallel_iterations=10,
            back_prop=True,
            swap_memory=False,
            infer_shape=True,
            reverse=False,
            name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  The simplest version of `scan` repeatedly applies the callable `fn` to a
  sequence of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn, and the second is the value at the current
  position of `elems`. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
  If reverse=True, it's fn(initializer, values[-1]).shape.

  This method also allows multi-arity `elems` and accumulator. If `elems`
  is a (possibly nested) list or tuple of tensors, then each of these tensors
  must have a matching first (unpack) dimension. The second argument of
  `fn` must match the structure of `elems`.

  If no `initializer` is provided, the output structure and dtypes of `fn`
  are assumed to be the same as its input; and in this case, the first
  argument of `fn` must match the structure of `elems`.

  If an `initializer` is provided, then the output of `fn` must have the same
  structure as `initializer`; and the first argument of `fn` must match
  this structure.

  For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
  `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
  `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
  `[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
  one that works in `python3`, is:
  `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.

  Args:
    fn: The callable to be performed. It accepts two arguments. The first will
      have the same structure as `initializer` if one is provided, otherwise it
      will have the same structure as `elems`. The second will have the same
      (possibly nested) structure as `elems`. Its output must have the same
      structure as `initializer` if one is provided, otherwise it must have the
      same structure as `elems`.
    elems: A tensor or (possibly nested) sequence of tensors, each of which will
      be unpacked along their first dimension. The nested sequence of the
      resulting slices will be the first argument to `fn`.
    initializer: (optional) A tensor or (possibly nested) sequence of tensors,
      initial value for the accumulator, and the expected output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run in
      parallel.
    back_prop: (optional) Deprecated. False disables support for back
      propagation. Prefer using `tf.stop_gradient` instead.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    infer_shape: (optional) False disables tests for consistent output shapes.
    reverse: (optional) True scans the tensor last to first (instead of first to
      last).
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor or (possibly nested) sequence of tensors. Each tensor packs the
    results of applying `fn` to tensors unpacked from `elems` along the first
    dimension, and the previous accumulator value(s), from first to last (or
    last to first, if `reverse=True`).

  Raises:
    TypeError: if `fn` is not callable or the structure of the output of
      `fn` and `initializer` do not match.
    ValueError: if the lengths of the output of `fn` and `initializer`
      do not match.

  Examples:
    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    sum = scan(lambda a, x: a + x, elems, reverse=True)
    # sum == [21, 20, 18, 15, 11, 6]
    ```

    ```python
    elems = np.array([1, 2, 3, 4, 5, 6])
    initializer = np.array(0)
    sum_one = scan(
        lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    # sum_one == [1, 2, 3, 4, 5, 6]
    ```

    ```python
    elems = np.array([1, 0, 0, 0, 0, 0])
    initializer = (np.array(0), np.array(1))
    fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    ```
  """
  return scan(
      fn=fn,
      elems=elems,
      initializer=initializer,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      infer_shape=infer_shape,
      reverse=reverse,
      name=name)


# pylint: disable=invalid-name
def If(cond, inputs, then_branch, else_branch, name=None):
  r"""output = Cond(inputs) ?

  then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
      converted to a boolean according to the following rule: if the scalar is a
      numerical value, non-zero means True and zero means False; if the scalar
      is a string, non-empty means True and empty means False.
    inputs: A list of input tensors.
    then_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what else_branch returns.
    else_branch: A function that takes 'inputs' and returns a list of tensors,
      whose types are the same as what then_branch returns.
    name: A name for the operation (optional).

  Returns:
    A list of tensors returned by either then_branch(inputs)
    or else_branch(inputs).
  """
  # pylint: disable=protected-access
  # Handle the Defun case until users have transitioned to tf.function. Note
  # that composites may need to be re-packed by the caller.
  if isinstance(then_branch, function._DefinedFunction):
    tlist = [_.type for _ in then_branch.definition.signature.output_arg]
    return gen_functional_ops._if(
        cond, inputs, tlist, then_branch, else_branch, name=name)

  # We assume that `then_branch` is a ConcreteFunction here.
  then_out = then_branch.structured_outputs
  else_out = else_branch.structured_outputs

  # Ensure then/else are the same type of composites to avoid an invalid call
  # to pack_sequence_as later on.
  nest.assert_same_structure(then_out, else_out, expand_composites=True)

  tlist = nest.flatten(then_branch.output_dtypes)
  ret = gen_functional_ops._if(
      cond, inputs, tlist, then_branch, else_branch, name=name)

  # Re-pack the outputs to restore any CompositeTensors
  return nest.pack_sequence_as(then_out, ret, expand_composites=True)
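

# A hypothetical usage sketch of `If` (illustration only, not part of the
# TensorFlow API): two `function.Defun` branches with matching output types,
# selected by a scalar predicate. `_example_if` is an assumed helper name.
def _example_if():
  @function.Defun(dtypes.float32)
  def TwoTimes(x):
    return x * 2.0

  @function.Defun(dtypes.float32)
  def ThreeTimes(x):
    return x * 3.0

  pred = constant_op.constant(True)
  inputs = [constant_op.constant(4.0)]
  # Returns [8.0] when pred is True, [12.0] otherwise.
  return If(pred, inputs, TwoTimes, ThreeTimes)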


def Gradient(inputs, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    inputs: A list of tensors of size N + M.
    f: The function we want to compute the gradient for. The function 'f' must
      be a numerical function which takes N inputs and produces M outputs. Its
      gradient function 'g' is a function taking N + M inputs and producing N
      outputs. I.e., if we have (y1, y2, ..., yM) = f(x1, x2, ..., xN), then g
      is (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN, dL/dy1, dL/dy2, ...,
      dL/dyM), where L is a scalar-valued function of (x1, x2, ..., xN) (e.g.,
      the loss function). dL/dxi is the partial derivative of L with respect to
      xi.
    name: A name for the operation (optional).

  Returns:
    A list of tensors of size N.
  """
  # TODO(zhifengc): Pretty-print the above spec in latex.
  # TODO(zhfiengc): Needs some math expert to say the comment above better.
  tlist = [_.type for _ in f.definition.signature.input_arg]
  return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)
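

# A hypothetical sketch (illustration only) of `Gradient` applied to a
# `function.Defun`: `inputs` carries the N function inputs followed by the M
# upstream output gradients. `_example_gradient` is an assumed helper name, not
# part of this module's API.
def _example_gradient():
  @function.Defun(dtypes.float32)
  def Square(x):
    return x * x

  x = constant_op.constant(3.0)
  dy = constant_op.constant(1.0)  # upstream gradient dL/dy
  # d(x * x)/dx at x == 3.0 is 6.0, so this evaluates to [6.0].
  return Gradient([x, dy], Square)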


def _GetInputDtypes(func):
  """Returns the input dtypes of func, excluding dtypes for captured inputs."""
  if isinstance(func, function._DefinedFunction):  # pylint: disable=protected-access
    return func.declared_input_types

  # We assume that `func` is a ConcreteFunction here, but we are not able to
  # verify since importing eager function library will cause cyclic dependence.
  #
  # ConcreteFunction.inputs includes captured inputs.
  num_non_captured_inputs = len(func.inputs) - len(func.captured_inputs)
  inputs_without_captured = func.inputs[:num_non_captured_inputs]
  return [t.dtype for t in inputs_without_captured]


def _LoopBodyCaptureWrapper(func):
  """Returns a wrapper for `func` that handles loop-carried captured inputs."""

  @function.Defun(*_GetInputDtypes(func), func_name="%s_Wrapper" % func.name)
  def Wrapper(*args):
    """A wrapper that handles loop-carried captured inputs."""
    result = func(*args)
    extra_args = tuple(function.get_extra_args())
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(result, ops.Operation):
      return extra_args
    # Unary functions return a single Tensor value.
    elif not isinstance(result, (list, tuple)):
      return (result,) + extra_args
    # N-ary functions return a tuple of Tensors.
    else:
      return result + type(result)(extra_args)

  return Wrapper


# pylint: disable=invalid-name,protected-access
def While(input_, cond, body, name=None, hostmem=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }.

  Args:
    input_: A list of `Tensor` objects. A list of input tensors whose types are
      T.
    cond: A function that takes 'input' and returns a tensor. If the tensor is
      a non-boolean scalar, the scalar is converted to a boolean according to
      the following rule: if the scalar is a numerical value, non-zero means
      True and zero means False; if the scalar is a string, non-empty means
      True and empty means False. If the tensor is not a scalar, non-emptiness
      means True and emptiness means False.
    body: A function that takes a list of tensors and returns another list of
      tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, input[i] is a host memory
      tensor.

  Raises:
    ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
      have different signatures.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if cond.captured_inputs:
    raise ValueError(
        "The 'cond' argument can not have implicitly captured inputs. Received "
        f"captured_inputs: {cond.captured_inputs}")

  cond_input_types = _GetInputDtypes(cond)
  body_input_types = _GetInputDtypes(body)

  if cond_input_types != body_input_types:
    raise ValueError(
        "The 'cond' and 'body' signatures do not match. Received: "
        f"cond_input_types={cond_input_types}, body_input_types="
        f"{body_input_types}")

  if body.captured_inputs:
    cond_dtypes = list(body_input_types) + [
        t.dtype for t in body.captured_inputs
    ]

    @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
    def CondWrapper(*args):
      """A wrapper that handles loop-carried captured inputs."""
      return cond(*args[:len(body_input_types)])

    ret = gen_functional_ops._while(
        input_ + body.captured_inputs,
        CondWrapper,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._while(input_, cond, body, name=name)
  if hostmem:
    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
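

# A hypothetical sketch (illustration only) of driving `While` with a pair of
# `function.Defun`s whose input signatures match, as required above.
# `_example_while` is an assumed helper name, not part of this module's API.
def _example_while():
  @function.Defun(dtypes.int32)
  def Cond(i):
    return i < 10

  @function.Defun(dtypes.int32)
  def Body(i):
    return i + 1

  # Starts from 0 and increments until Cond is False; evaluates to [10].
  return While([constant_op.constant(0)], Cond, Body)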


# b/36459430
#
# Ideally, we would not need to rewrite a For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a for loop is much
# preferred when running on GPU.
#
# On the other hand, the For op has no direct XLA kernel. So, when we run
# a for loop, we need to rewrite it using a While op.
#
# It should be possible and probably better to write an XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While."""
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over the range(n) and use iter * delta + start as the real
  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
  # 100).
  d = math_ops.abs(delta)
  # XLA on TPUs doesn't support integer division
  n = math_ops.cast(
      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input list
  # of the WhileBody function. WhileCond does not call forbody, and so does not
  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
  # must have identical inputs, we have to augment the cond signature to take
  # the same types as the carried loop variables.
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]

  cond_name = "%s_Cond" % forbody.name

  @function.Defun(*body_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    return (i + 1, n, start, delta) + tuple(for_result)

  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Slice off the loop-carried captured inputs.
  return list(results[4:len(results)])


def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
  r"""out = input; for i in range(start, limit, delta) out = body(i, out).

  Args:
    start: A `Tensor` of type `int32`.
    limit: A `Tensor` of type `int32`.
    delta: A `Tensor` of type `int32`.
    inputs: A list of `Tensor` objects. A list of input tensors whose types are
      T.
    body: A function takes a list of tensors and returns another list of
      tensors. Both lists have the same types as (int32, T...).
    name: A name for the operation (optional).
    hostmem: A list of integers. If i is in the list, inputs[i] is a host
      memory tensor. In other words, the (i+1)-th argument of the body function
      expects a host memory tensor.
    rewrite_with_while: If True, use a While op to implement the For.

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  if rewrite_with_while:
    return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
  if body.captured_inputs:
    ret = gen_functional_ops._for(
        start,
        limit,
        delta,
        inputs + body.captured_inputs,
        _LoopBodyCaptureWrapper(body),
        name=name)
    # Slice off the loop-carried captured inputs.
    ret = ret[:-len(body.captured_inputs)]
  else:
    ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
  if hostmem:
    num_for_params = 3  # start/limit/delta

    input_attr = attr_value_pb2.AttrValue()
    input_attr.list.i.extend([num_for_params + i for i in hostmem])
    ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access

    output_attr = attr_value_pb2.AttrValue()
    output_attr.list.i.extend(hostmem)
    ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
  return ret
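

# A hypothetical sketch (illustration only) of a `For` loop body built with
# `function.Defun`: the loop index is always passed as the first argument of
# `body`, followed by the loop-carried tensors. `_example_for` is an assumed
# helper name, not part of this module's API.
def _example_for():
  @function.Defun(dtypes.int32, dtypes.int32)
  def Body(i, acc):
    return acc + i

  # Accumulates 0 + 1 + 2 + 3 + 4 into the single loop-carried value: [10].
  return For(start=0, limit=5, delta=1,
             inputs=[constant_op.constant(0)], body=Body)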