Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/gen_batch_ops.py: 10%

274 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Python wrappers around TensorFlow ops. 

2 

3This file is MACHINE GENERATED! Do not edit. 

4""" 

5 

6import collections 

7 

8from tensorflow.python import pywrap_tfe as pywrap_tfe 

9from tensorflow.python.eager import context as _context 

10from tensorflow.python.eager import core as _core 

11from tensorflow.python.eager import execute as _execute 

12from tensorflow.python.framework import dtypes as _dtypes 

13from tensorflow.security.fuzzing.py import annotation_types as _atypes 

14 

15from tensorflow.python.framework import op_def_registry as _op_def_registry 

16from tensorflow.python.framework import ops as _ops 

17from tensorflow.python.framework import op_def_library as _op_def_library 

18from tensorflow.python.util.deprecation import deprecated_endpoints 

19from tensorflow.python.util import dispatch as _dispatch 

20from tensorflow.python.util.tf_export import tf_export 

21 

22from typing import TypeVar 

23_BatchOutput = collections.namedtuple( 

24 "Batch", 

25 ["batched_tensors", "batch_index", "id"]) 

26 

27 

28def batch(in_tensors, num_batch_threads, max_batch_size, batch_timeout_micros, grad_timeout_micros, max_enqueued_batches=10, allowed_batch_sizes=[], container="", shared_name="", batching_queue="", name=None): 

29 r"""Batches all input tensors nondeterministically. 

30 

31 When many instances of this Op are being run concurrently with the same 

32 container/shared_name on the same device, some will output zero-shaped Tensors

33 and others will output Tensors of size up to max_batch_size. 

34 

35 All Tensors in in_tensors are batched together (so, for example, labels and 

36 features should be batched with a single instance of this operation).

37 

38 Each invocation of batch emits an `id` scalar which will be used to identify 

39 this particular invocation when doing unbatch or its gradient. 

40 

41 Each op which emits a non-empty batch will also emit a non-empty batch_index 

42 Tensor, which is a [K, 3] matrix where each row contains the invocation's id,

43 start, and length of elements of each set of Tensors present in batched_tensors. 

44 

45 Batched tensors are concatenated along the first dimension, and all tensors in 

46 in_tensors must have the same size in their first dimension.

47 

48 in_tensors: The tensors to be batched. 

49 num_batch_threads: Number of scheduling threads for processing batches of work. 

50 Determines the number of batches processed in parallel. 

51 max_batch_size: Batch sizes will never be bigger than this. 

52 batch_timeout_micros: Maximum number of microseconds to wait before outputting 

53 an incomplete batch. 

54 allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does 

55 nothing. Otherwise, supplies a list of batch sizes, causing the op to pad 

56 batches up to one of those sizes. The entries must increase monotonically, and 

57 the final entry must equal max_batch_size. 

58 grad_timeout_micros: The timeout to use for the gradient. See Unbatch. 

59 batched_tensors: Either empty tensors or a batch of concatenated Tensors. 

60 batch_index: If batched_tensors is non-empty, has information to invert it.

61 container: Controls the scope of sharing of this batch. 

62 id: always contains a scalar with a unique ID for this invocation of Batch. 

63 shared_name: Concurrently running instances of batch on the same device with the

64 same container and shared_name will batch their elements together. If left 

65 empty, the op name will be used as the shared name. 

66 T: the types of tensors to be batched. 

67 

68 Args: 

69 in_tensors: A list of `Tensor` objects. 

70 num_batch_threads: An `int`. 

71 max_batch_size: An `int`. 

72 batch_timeout_micros: An `int`. 

73 grad_timeout_micros: An `int`. 

74 max_enqueued_batches: An optional `int`. Defaults to `10`. 

75 allowed_batch_sizes: An optional list of `ints`. Defaults to `[]`. 

76 container: An optional `string`. Defaults to `""`. 

77 shared_name: An optional `string`. Defaults to `""`. 

78 batching_queue: An optional `string`. Defaults to `""`. 

79 name: A name for the operation (optional). 

80 

81 Returns: 

82 A tuple of `Tensor` objects (batched_tensors, batch_index, id). 

83 

84 batched_tensors: A list of `Tensor` objects. Has the same type as `in_tensors`. 

85 batch_index: A `Tensor` of type `int64`. 

86 id: A `Tensor` of type `int64`. 

87 """ 

88 _ctx = _context._context or _context.context() 

89 tld = _ctx._thread_local_data 

90 if tld.is_eager: 

91 try: 

92 _result = pywrap_tfe.TFE_Py_FastPathExecute( 

93 _ctx, "Batch", name, in_tensors, "num_batch_threads", 

94 num_batch_threads, "max_batch_size", max_batch_size, 

95 "max_enqueued_batches", max_enqueued_batches, "batch_timeout_micros", 

96 batch_timeout_micros, "allowed_batch_sizes", allowed_batch_sizes, 

97 "grad_timeout_micros", grad_timeout_micros, "container", container, 

98 "shared_name", shared_name, "batching_queue", batching_queue) 

99 _result = _BatchOutput._make(_result) 

100 return _result 

101 except _core._NotOkStatusException as e: 

102 _ops.raise_from_not_ok_status(e, name) 

103 except _core._FallbackException: 

104 pass 

105 try: 

106 return batch_eager_fallback( 

107 in_tensors, num_batch_threads=num_batch_threads, 

108 max_batch_size=max_batch_size, 

109 max_enqueued_batches=max_enqueued_batches, 

110 batch_timeout_micros=batch_timeout_micros, 

111 allowed_batch_sizes=allowed_batch_sizes, 

112 grad_timeout_micros=grad_timeout_micros, container=container, 

113 shared_name=shared_name, batching_queue=batching_queue, name=name, 

114 ctx=_ctx) 

115 except _core._SymbolicException: 

116 pass # Add nodes to the TensorFlow graph. 

117 # Add nodes to the TensorFlow graph. 

118 num_batch_threads = _execute.make_int(num_batch_threads, "num_batch_threads") 

119 max_batch_size = _execute.make_int(max_batch_size, "max_batch_size") 

120 batch_timeout_micros = _execute.make_int(batch_timeout_micros, "batch_timeout_micros") 

121 grad_timeout_micros = _execute.make_int(grad_timeout_micros, "grad_timeout_micros") 

122 if max_enqueued_batches is None: 

123 max_enqueued_batches = 10 

124 max_enqueued_batches = _execute.make_int(max_enqueued_batches, "max_enqueued_batches") 

125 if allowed_batch_sizes is None: 

126 allowed_batch_sizes = [] 

127 if not isinstance(allowed_batch_sizes, (list, tuple)): 

128 raise TypeError( 

129 "Expected list for 'allowed_batch_sizes' argument to " 

130 "'batch' Op, not %r." % allowed_batch_sizes) 

131 allowed_batch_sizes = [_execute.make_int(_i, "allowed_batch_sizes") for _i in allowed_batch_sizes] 

132 if container is None: 

133 container = "" 

134 container = _execute.make_str(container, "container") 

135 if shared_name is None: 

136 shared_name = "" 

137 shared_name = _execute.make_str(shared_name, "shared_name") 

138 if batching_queue is None: 

139 batching_queue = "" 

140 batching_queue = _execute.make_str(batching_queue, "batching_queue") 

141 _, _, _op, _outputs = _op_def_library._apply_op_helper( 

142 "Batch", in_tensors=in_tensors, num_batch_threads=num_batch_threads, 

143 max_batch_size=max_batch_size, 

144 batch_timeout_micros=batch_timeout_micros, 

145 grad_timeout_micros=grad_timeout_micros, 

146 max_enqueued_batches=max_enqueued_batches, 

147 allowed_batch_sizes=allowed_batch_sizes, container=container, 

148 shared_name=shared_name, batching_queue=batching_queue, 

149 name=name) 

150 _result = _outputs[:] 

151 if _execute.must_record_gradient(): 

152 _attrs = ("num_batch_threads", _op._get_attr_int("num_batch_threads"), 

153 "max_batch_size", _op._get_attr_int("max_batch_size"), 

154 "max_enqueued_batches", 

155 _op._get_attr_int("max_enqueued_batches"), 

156 "batch_timeout_micros", 

157 _op._get_attr_int("batch_timeout_micros"), 

158 "allowed_batch_sizes", _op.get_attr("allowed_batch_sizes"), 

159 "grad_timeout_micros", _op._get_attr_int("grad_timeout_micros"), 

160 "container", _op.get_attr("container"), "shared_name", 

161 _op.get_attr("shared_name"), "batching_queue", 

162 _op.get_attr("batching_queue"), "T", _op.get_attr("T")) 

163 _inputs_flat = _op.inputs 

164 _execute.record_gradient( 

165 "Batch", _inputs_flat, _attrs, _result) 

166 _result = [_result[:len(in_tensors)]] + _result[len(in_tensors):] 

167 _result = _BatchOutput._make(_result) 

168 return _result 

169 

170Batch = tf_export("raw_ops.Batch")(_ops.to_raw_op(batch)) 
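# Illustrative sketch, not part of the generated module: calling the raw Batch
# op eagerly. Assumes a TF 2.x eager context; with a single caller the op
# emits a batch once max_batch_size is reached or batch_timeout_micros
# elapses. "batch_example" is just an arbitrary shared_name for illustration.
import tensorflow as tf

x = tf.constant([[1.0], [2.0]])
out = tf.raw_ops.Batch(
    in_tensors=[x],
    num_batch_threads=1,
    max_batch_size=2,            # the two rows above fill a batch immediately
    batch_timeout_micros=0,
    grad_timeout_micros=0,
    shared_name="batch_example")
# out.batched_tensors mirrors in_tensors (one concatenated tensor per input),
# out.batch_index is an int64 [K, 3] matrix describing each invocation's
# slice, and out.id is the scalar later handed to Unbatch / UnbatchGrad.
print(out.batched_tensors[0].shape, out.batch_index.shape, out.id)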

171 

172 

173def batch_eager_fallback(in_tensors, num_batch_threads, max_batch_size, batch_timeout_micros, grad_timeout_micros, max_enqueued_batches, allowed_batch_sizes, container, shared_name, batching_queue, name, ctx): 

174 num_batch_threads = _execute.make_int(num_batch_threads, "num_batch_threads") 

175 max_batch_size = _execute.make_int(max_batch_size, "max_batch_size") 

176 batch_timeout_micros = _execute.make_int(batch_timeout_micros, "batch_timeout_micros") 

177 grad_timeout_micros = _execute.make_int(grad_timeout_micros, "grad_timeout_micros") 

178 if max_enqueued_batches is None: 

179 max_enqueued_batches = 10 

180 max_enqueued_batches = _execute.make_int(max_enqueued_batches, "max_enqueued_batches") 

181 if allowed_batch_sizes is None: 

182 allowed_batch_sizes = [] 

183 if not isinstance(allowed_batch_sizes, (list, tuple)): 

184 raise TypeError( 

185 "Expected list for 'allowed_batch_sizes' argument to " 

186 "'batch' Op, not %r." % allowed_batch_sizes) 

187 allowed_batch_sizes = [_execute.make_int(_i, "allowed_batch_sizes") for _i in allowed_batch_sizes] 

188 if container is None: 

189 container = "" 

190 container = _execute.make_str(container, "container") 

191 if shared_name is None: 

192 shared_name = "" 

193 shared_name = _execute.make_str(shared_name, "shared_name") 

194 if batching_queue is None: 

195 batching_queue = "" 

196 batching_queue = _execute.make_str(batching_queue, "batching_queue") 

197 _attr_T, in_tensors = _execute.convert_to_mixed_eager_tensors(in_tensors, ctx) 

198 _inputs_flat = list(in_tensors) 

199 _attrs = ("num_batch_threads", num_batch_threads, "max_batch_size", 

200 max_batch_size, "max_enqueued_batches", max_enqueued_batches, 

201 "batch_timeout_micros", batch_timeout_micros, "allowed_batch_sizes", 

202 allowed_batch_sizes, "grad_timeout_micros", grad_timeout_micros, 

203 "container", container, "shared_name", shared_name, "batching_queue", 

204 batching_queue, "T", _attr_T) 

205 _result = _execute.execute(b"Batch", len(in_tensors) + 2, 

206 inputs=_inputs_flat, attrs=_attrs, ctx=ctx, 

207 name=name) 

208 if _execute.must_record_gradient(): 

209 _execute.record_gradient( 

210 "Batch", _inputs_flat, _attrs, _result) 

211 _result = [_result[:len(in_tensors)]] + _result[len(in_tensors):] 

212 _result = _BatchOutput._make(_result) 

213 return _result 

214 

215 

216def batch_function(in_tensors, captured_tensors, f, num_batch_threads, max_batch_size, batch_timeout_micros, Tout, max_enqueued_batches=10, allowed_batch_sizes=[], container="", shared_name="", batching_queue="", enable_large_batch_splitting=False, name=None): 

217 r"""Batches all the input tensors to the computation done by the function.

218 

219 So, for example, in the following code 

220 

221 ```python 

222 

223 # This input will be captured. 

224 y = tf.placeholder_with_default(1.0, shape=[]) 

225 

226 @tf.Defun(tf.float32) 

227 def computation(a): 

228 return tf.matmul(a, a) + y 

229 

230 b = gen_batch_ops.batch_function( 

231 f=computation,

232 in_tensors=[a], 

233 captured_tensors=computation.captured_inputs, 

234 Tout=[o.type for o in computation.definition.signature.output_arg], 

235 num_batch_threads=1, 

236 max_batch_size=10, 

237 batch_timeout_micros=100000, # 100ms 

238 allowed_batch_sizes=[3, 10], 

239 batching_queue="") 

240 ``` 

241 

242 If more than one session.run call is simultaneously trying to compute `b`,

243 the values of `a` will be gathered, non-deterministically concatenated 

244 along the first axis, and only one thread will run the computation. 

245 

246 Assumes that all arguments of the function are Tensors which will be batched 

247 along their first dimension. 

248 

249 Arguments that are captured are not batched. The session.run call which does

250 the concatenation will use the values of the captured tensors available to it.

251 Therefore, typical uses of captured tensors should involve values which remain 

252 unchanged across session.run calls. Inference is a good example of this. 

253 

254 SparseTensor is not supported. The return value of the decorated function 

255 must be a Tensor or a list/tuple of Tensors. 

256 

257 Args: 

258 in_tensors: A list of `Tensor` objects. The tensors to be batched. 

259 captured_tensors: A list of `Tensor` objects. 

260 The tensors which are captured in the function, and don't need 

261 to be batched. 

262 f: A function decorated with @Defun. 

263 num_batch_threads: An `int`. 

264 Number of scheduling threads for processing batches of work. 

265 Determines the number of batches processed in parallel. 

266 max_batch_size: An `int`. Batch sizes will never be bigger than this. 

267 batch_timeout_micros: An `int`. 

268 Maximum number of microseconds to wait before outputting 

269 an incomplete batch. 

270 Tout: A list of `tf.DTypes` that has length `>= 1`. 

271 the types of the output tensors. 

272 max_enqueued_batches: An optional `int`. Defaults to `10`. 

273 Maximum number of batches enqueued. Default: 10. 

274 allowed_batch_sizes: An optional list of `ints`. Defaults to `[]`. 

275 Optional list of allowed batch sizes. If left empty, does 

276 nothing. Otherwise, supplies a list of batch sizes, causing the op to pad 

277 batches up to one of those sizes. The entries must increase monotonically. 

278 If enable_large_batch_splitting is false (i.e., large-input-split is not 

279 enabled) the final entry must equal max_batch_size. 

280 container: An optional `string`. Defaults to `""`. 

281 Controls the scope of sharing of this batch. 

282 shared_name: An optional `string`. Defaults to `""`. 

283 Concurrently running instances of batch on the same device with the

284 same container and shared_name will batch their elements together. If left 

285 empty, the op name will be used as the shared name. 

286 batching_queue: An optional `string`. Defaults to `""`. 

287 enable_large_batch_splitting: An optional `bool`. Defaults to `False`. 

288 input with a large size (i.e., larger than the largest value of 

289 `allowed_batch_sizes`) will be split into multiple batches, each no larger than the largest allowed batch size.

290 name: A name for the operation (optional). 

291 

292 Returns: 

293 A list of `Tensor` objects of type `Tout`. 

294 """ 

295 _ctx = _context._context or _context.context() 

296 tld = _ctx._thread_local_data 

297 if tld.is_eager: 

298 try: 

299 _result = pywrap_tfe.TFE_Py_FastPathExecute( 

300 _ctx, "BatchFunction", name, in_tensors, captured_tensors, "f", f, 

301 "num_batch_threads", num_batch_threads, "max_batch_size", 

302 max_batch_size, "batch_timeout_micros", batch_timeout_micros, 

303 "max_enqueued_batches", max_enqueued_batches, "allowed_batch_sizes", 

304 allowed_batch_sizes, "container", container, "shared_name", 

305 shared_name, "batching_queue", batching_queue, "Tout", Tout, 

306 "enable_large_batch_splitting", enable_large_batch_splitting) 

307 return _result 

308 except _core._NotOkStatusException as e: 

309 _ops.raise_from_not_ok_status(e, name) 

310 except _core._FallbackException: 

311 pass 

312 try: 

313 return batch_function_eager_fallback( 

314 in_tensors, captured_tensors, f=f, 

315 num_batch_threads=num_batch_threads, max_batch_size=max_batch_size, 

316 batch_timeout_micros=batch_timeout_micros, 

317 max_enqueued_batches=max_enqueued_batches, 

318 allowed_batch_sizes=allowed_batch_sizes, container=container, 

319 shared_name=shared_name, batching_queue=batching_queue, Tout=Tout, 

320 enable_large_batch_splitting=enable_large_batch_splitting, 

321 name=name, ctx=_ctx) 

322 except _core._SymbolicException: 

323 pass # Add nodes to the TensorFlow graph. 

324 # Add nodes to the TensorFlow graph. 

325 num_batch_threads = _execute.make_int(num_batch_threads, "num_batch_threads") 

326 max_batch_size = _execute.make_int(max_batch_size, "max_batch_size") 

327 batch_timeout_micros = _execute.make_int(batch_timeout_micros, "batch_timeout_micros") 

328 if not isinstance(Tout, (list, tuple)): 

329 raise TypeError( 

330 "Expected list for 'Tout' argument to " 

331 "'batch_function' Op, not %r." % Tout) 

332 Tout = [_execute.make_type(_t, "Tout") for _t in Tout] 

333 if max_enqueued_batches is None: 

334 max_enqueued_batches = 10 

335 max_enqueued_batches = _execute.make_int(max_enqueued_batches, "max_enqueued_batches") 

336 if allowed_batch_sizes is None: 

337 allowed_batch_sizes = [] 

338 if not isinstance(allowed_batch_sizes, (list, tuple)): 

339 raise TypeError( 

340 "Expected list for 'allowed_batch_sizes' argument to " 

341 "'batch_function' Op, not %r." % allowed_batch_sizes) 

342 allowed_batch_sizes = [_execute.make_int(_i, "allowed_batch_sizes") for _i in allowed_batch_sizes] 

343 if container is None: 

344 container = "" 

345 container = _execute.make_str(container, "container") 

346 if shared_name is None: 

347 shared_name = "" 

348 shared_name = _execute.make_str(shared_name, "shared_name") 

349 if batching_queue is None: 

350 batching_queue = "" 

351 batching_queue = _execute.make_str(batching_queue, "batching_queue") 

352 if enable_large_batch_splitting is None: 

353 enable_large_batch_splitting = False 

354 enable_large_batch_splitting = _execute.make_bool(enable_large_batch_splitting, "enable_large_batch_splitting") 

355 _, _, _op, _outputs = _op_def_library._apply_op_helper( 

356 "BatchFunction", in_tensors=in_tensors, 

357 captured_tensors=captured_tensors, f=f, 

358 num_batch_threads=num_batch_threads, 

359 max_batch_size=max_batch_size, 

360 batch_timeout_micros=batch_timeout_micros, Tout=Tout, 

361 max_enqueued_batches=max_enqueued_batches, 

362 allowed_batch_sizes=allowed_batch_sizes, 

363 container=container, shared_name=shared_name, 

364 batching_queue=batching_queue, 

365 enable_large_batch_splitting=enable_large_batch_splitting, 

366 name=name) 

367 _result = _outputs[:] 

368 if _execute.must_record_gradient(): 

369 _attrs = ("f", _op.get_attr("f"), "num_batch_threads", 

370 _op._get_attr_int("num_batch_threads"), "max_batch_size", 

371 _op._get_attr_int("max_batch_size"), "batch_timeout_micros", 

372 _op._get_attr_int("batch_timeout_micros"), 

373 "max_enqueued_batches", 

374 _op._get_attr_int("max_enqueued_batches"), 

375 "allowed_batch_sizes", _op.get_attr("allowed_batch_sizes"), 

376 "container", _op.get_attr("container"), "shared_name", 

377 _op.get_attr("shared_name"), "batching_queue", 

378 _op.get_attr("batching_queue"), "Tin", _op.get_attr("Tin"), 

379 "Tcaptured", _op.get_attr("Tcaptured"), "Tout", 

380 _op.get_attr("Tout"), "enable_large_batch_splitting", 

381 _op._get_attr_bool("enable_large_batch_splitting")) 

382 _inputs_flat = _op.inputs 

383 _execute.record_gradient( 

384 "BatchFunction", _inputs_flat, _attrs, _result) 

385 return _result 

386 

387BatchFunction = tf_export("raw_ops.BatchFunction")(_ops.to_raw_op(batch_function)) 
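# Illustrative sketch, not part of the generated module: in current TF 2.x
# releases the public wrapper tf.nondifferentiable_batch_function (assumed
# available here) drives BatchFunction, giving a Defun-free version of the
# docstring example above. With a single caller, execution waits up to
# batch_timeout_micros before running the padded batch.
import tensorflow as tf

@tf.nondifferentiable_batch_function(
    num_batch_threads=1,
    max_batch_size=10,
    batch_timeout_micros=100000,     # 100ms
    allowed_batch_sizes=[3, 10])
def computation(a):
  # Concurrent callers' inputs are concatenated along axis 0 before this body
  # runs; captured values are not batched.
  return a * a + 1.0

result = computation(tf.constant([[2.0]]))   # padded to an allowed batch size internally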

388 

389 

390def batch_function_eager_fallback(in_tensors, captured_tensors, f, num_batch_threads, max_batch_size, batch_timeout_micros, Tout, max_enqueued_batches, allowed_batch_sizes, container, shared_name, batching_queue, enable_large_batch_splitting, name, ctx): 

391 num_batch_threads = _execute.make_int(num_batch_threads, "num_batch_threads") 

392 max_batch_size = _execute.make_int(max_batch_size, "max_batch_size") 

393 batch_timeout_micros = _execute.make_int(batch_timeout_micros, "batch_timeout_micros") 

394 if not isinstance(Tout, (list, tuple)): 

395 raise TypeError( 

396 "Expected list for 'Tout' argument to " 

397 "'batch_function' Op, not %r." % Tout) 

398 Tout = [_execute.make_type(_t, "Tout") for _t in Tout] 

399 if max_enqueued_batches is None: 

400 max_enqueued_batches = 10 

401 max_enqueued_batches = _execute.make_int(max_enqueued_batches, "max_enqueued_batches") 

402 if allowed_batch_sizes is None: 

403 allowed_batch_sizes = [] 

404 if not isinstance(allowed_batch_sizes, (list, tuple)): 

405 raise TypeError( 

406 "Expected list for 'allowed_batch_sizes' argument to " 

407 "'batch_function' Op, not %r." % allowed_batch_sizes) 

408 allowed_batch_sizes = [_execute.make_int(_i, "allowed_batch_sizes") for _i in allowed_batch_sizes] 

409 if container is None: 

410 container = "" 

411 container = _execute.make_str(container, "container") 

412 if shared_name is None: 

413 shared_name = "" 

414 shared_name = _execute.make_str(shared_name, "shared_name") 

415 if batching_queue is None: 

416 batching_queue = "" 

417 batching_queue = _execute.make_str(batching_queue, "batching_queue") 

418 if enable_large_batch_splitting is None: 

419 enable_large_batch_splitting = False 

420 enable_large_batch_splitting = _execute.make_bool(enable_large_batch_splitting, "enable_large_batch_splitting") 

421 _attr_Tin, in_tensors = _execute.convert_to_mixed_eager_tensors(in_tensors, ctx) 

422 _attr_Tcaptured, captured_tensors = _execute.convert_to_mixed_eager_tensors(captured_tensors, ctx) 

423 _inputs_flat = list(in_tensors) + list(captured_tensors) 

424 _attrs = ("f", f, "num_batch_threads", num_batch_threads, "max_batch_size", 

425 max_batch_size, "batch_timeout_micros", batch_timeout_micros, 

426 "max_enqueued_batches", max_enqueued_batches, "allowed_batch_sizes", 

427 allowed_batch_sizes, "container", container, "shared_name", shared_name, 

428 "batching_queue", batching_queue, "Tin", _attr_Tin, "Tcaptured", 

429 _attr_Tcaptured, "Tout", Tout, "enable_large_batch_splitting", 

430 enable_large_batch_splitting) 

431 _result = _execute.execute(b"BatchFunction", len(Tout), inputs=_inputs_flat, 

432 attrs=_attrs, ctx=ctx, name=name) 

433 if _execute.must_record_gradient(): 

434 _execute.record_gradient( 

435 "BatchFunction", _inputs_flat, _attrs, _result) 

436 return _result 

437 

438 

439def unbatch(batched_tensor, batch_index, id, timeout_micros, container="", shared_name="", name=None): 

440 r"""Reverses the operation of Batch for a single output Tensor. 

441 

442 An instance of Unbatch either receives an empty batched_tensor, in which case it 

443 asynchronously waits until the values become available from a concurrently 

444 running instance of Unbatch with the same container and shared_name, or receives 

445 a non-empty batched_tensor in which case it finalizes all other concurrently 

446 running instances and outputs its own element from the batch. 

447 

448 batched_tensor: The possibly transformed output of Batch. The size of the first 

449 dimension should remain unchanged by the transformations for the operation to 

450 work. 

451 batch_index: The matching batch_index obtained from Batch. 

452 id: The id scalar emitted by Batch. 

453 unbatched_tensor: The Tensor corresponding to this execution. 

454 timeout_micros: Maximum amount of time (in microseconds) to wait to receive the 

455 batched input tensor associated with a given invocation of the op. 

456 container: Container to control resource sharing. 

457 shared_name: Instances of Unbatch with the same container and shared_name are 

458 assumed to possibly belong to the same batch. If left empty, the op name will 

459 be used as the shared name. 

460 

461 Args: 

462 batched_tensor: A `Tensor`. 

463 batch_index: A `Tensor` of type `int64`. 

464 id: A `Tensor` of type `int64`. 

465 timeout_micros: An `int`. 

466 container: An optional `string`. Defaults to `""`. 

467 shared_name: An optional `string`. Defaults to `""`. 

468 name: A name for the operation (optional). 

469 

470 Returns: 

471 A `Tensor`. Has the same type as `batched_tensor`. 

472 """ 

473 _ctx = _context._context or _context.context() 

474 tld = _ctx._thread_local_data 

475 if tld.is_eager: 

476 try: 

477 _result = pywrap_tfe.TFE_Py_FastPathExecute( 

478 _ctx, "Unbatch", name, batched_tensor, batch_index, id, 

479 "timeout_micros", timeout_micros, "container", container, 

480 "shared_name", shared_name) 

481 return _result 

482 except _core._NotOkStatusException as e: 

483 _ops.raise_from_not_ok_status(e, name) 

484 except _core._FallbackException: 

485 pass 

486 try: 

487 return unbatch_eager_fallback( 

488 batched_tensor, batch_index, id, timeout_micros=timeout_micros, 

489 container=container, shared_name=shared_name, name=name, ctx=_ctx) 

490 except _core._SymbolicException: 

491 pass # Add nodes to the TensorFlow graph. 

492 # Add nodes to the TensorFlow graph. 

493 timeout_micros = _execute.make_int(timeout_micros, "timeout_micros") 

494 if container is None: 

495 container = "" 

496 container = _execute.make_str(container, "container") 

497 if shared_name is None: 

498 shared_name = "" 

499 shared_name = _execute.make_str(shared_name, "shared_name") 

500 _, _, _op, _outputs = _op_def_library._apply_op_helper( 

501 "Unbatch", batched_tensor=batched_tensor, batch_index=batch_index, 

502 id=id, timeout_micros=timeout_micros, container=container, 

503 shared_name=shared_name, name=name) 

504 _result = _outputs[:] 

505 if _execute.must_record_gradient(): 

506 _attrs = ("timeout_micros", _op._get_attr_int("timeout_micros"), 

507 "container", _op.get_attr("container"), "shared_name", 

508 _op.get_attr("shared_name"), "T", _op._get_attr_type("T")) 

509 _inputs_flat = _op.inputs 

510 _execute.record_gradient( 

511 "Unbatch", _inputs_flat, _attrs, _result) 

512 _result, = _result 

513 return _result 

514 

515Unbatch = tf_export("raw_ops.Unbatch")(_ops.to_raw_op(unbatch)) 
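# Illustrative sketch, not part of the generated module: undoing a Batch with
# Unbatch in eager mode. The batch_index/id emitted by Batch route the result
# back to this invocation; the first dimension of the batched tensor must be
# left unchanged by any transformation in between. Names are example values.
import tensorflow as tf

x = tf.constant([[3.0, 4.0]])
b = tf.raw_ops.Batch(
    in_tensors=[x], num_batch_threads=1, max_batch_size=1,
    batch_timeout_micros=0, grad_timeout_micros=0,
    shared_name="unbatch_example")
doubled = b.batched_tensors[0] * 2.0          # leading dimension unchanged
y = tf.raw_ops.Unbatch(
    batched_tensor=doubled,
    batch_index=b.batch_index,
    id=b.id,
    timeout_micros=1000000,                   # wait at most 1s for the batch
    shared_name="unbatch_example")
# y holds only this invocation's slice: [[6.0, 8.0]]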

516 

517 

518def unbatch_eager_fallback(batched_tensor, batch_index, id, timeout_micros, container, shared_name, name, ctx): 

519 timeout_micros = _execute.make_int(timeout_micros, "timeout_micros") 

520 if container is None: 

521 container = "" 

522 container = _execute.make_str(container, "container") 

523 if shared_name is None: 

524 shared_name = "" 

525 shared_name = _execute.make_str(shared_name, "shared_name") 

526 _attr_T, (batched_tensor,) = _execute.args_to_matching_eager([batched_tensor], ctx, []) 

527 batch_index = _ops.convert_to_tensor(batch_index, _dtypes.int64) 

528 id = _ops.convert_to_tensor(id, _dtypes.int64) 

529 _inputs_flat = [batched_tensor, batch_index, id] 

530 _attrs = ("timeout_micros", timeout_micros, "container", container, 

531 "shared_name", shared_name, "T", _attr_T) 

532 _result = _execute.execute(b"Unbatch", 1, inputs=_inputs_flat, attrs=_attrs, 

533 ctx=ctx, name=name) 

534 if _execute.must_record_gradient(): 

535 _execute.record_gradient( 

536 "Unbatch", _inputs_flat, _attrs, _result) 

537 _result, = _result 

538 return _result 

539 

540 

541def unbatch_grad(original_input, batch_index, grad, id, container="", shared_name="", name=None): 

542 r"""Gradient of Unbatch. 

543 

544 Acts like Batch but uses the given batch_index to index elements as they

545 become available. This ensures that the gradients are propagated back in the 

546 same session which did the forward pass. 

547 

548 original_input: The input to the Unbatch operation this is the gradient of. 

549 batch_index: The batch_index given to the Unbatch operation this is the gradient 

550 of. 

551 grad: The downstream gradient. 

552 id: The id scalar emitted by Batch. 

553 batched_grad: The return value, either an empty tensor or the batched gradient. 

554 container: Container to control resource sharing. 

555 shared_name: Instances of UnbatchGrad with the same container and shared_name 

556 are assumed to possibly belong to the same batch. If left empty, the op name 

557 will be used as the shared name. 

558 

559 Args: 

560 original_input: A `Tensor`. 

561 batch_index: A `Tensor` of type `int64`. 

562 grad: A `Tensor`. Must have the same type as `original_input`. 

563 id: A `Tensor` of type `int64`. 

564 container: An optional `string`. Defaults to `""`. 

565 shared_name: An optional `string`. Defaults to `""`. 

566 name: A name for the operation (optional). 

567 

568 Returns: 

569 A `Tensor`. Has the same type as `original_input`. 

570 """ 

571 _ctx = _context._context or _context.context() 

572 tld = _ctx._thread_local_data 

573 if tld.is_eager: 

574 try: 

575 _result = pywrap_tfe.TFE_Py_FastPathExecute( 

576 _ctx, "UnbatchGrad", name, original_input, batch_index, grad, id, 

577 "container", container, "shared_name", shared_name) 

578 return _result 

579 except _core._NotOkStatusException as e: 

580 _ops.raise_from_not_ok_status(e, name) 

581 except _core._FallbackException: 

582 pass 

583 try: 

584 return unbatch_grad_eager_fallback( 

585 original_input, batch_index, grad, id, container=container, 

586 shared_name=shared_name, name=name, ctx=_ctx) 

587 except _core._SymbolicException: 

588 pass # Add nodes to the TensorFlow graph. 

589 # Add nodes to the TensorFlow graph. 

590 if container is None: 

591 container = "" 

592 container = _execute.make_str(container, "container") 

593 if shared_name is None: 

594 shared_name = "" 

595 shared_name = _execute.make_str(shared_name, "shared_name") 

596 _, _, _op, _outputs = _op_def_library._apply_op_helper( 

597 "UnbatchGrad", original_input=original_input, batch_index=batch_index, 

598 grad=grad, id=id, container=container, 

599 shared_name=shared_name, name=name) 

600 _result = _outputs[:] 

601 if _execute.must_record_gradient(): 

602 _attrs = ("container", _op.get_attr("container"), "shared_name", 

603 _op.get_attr("shared_name"), "T", _op._get_attr_type("T")) 

604 _inputs_flat = _op.inputs 

605 _execute.record_gradient( 

606 "UnbatchGrad", _inputs_flat, _attrs, _result) 

607 _result, = _result 

608 return _result 

609 

610UnbatchGrad = tf_export("raw_ops.UnbatchGrad")(_ops.to_raw_op(unbatch_grad)) 
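# Illustrative sketch, not part of the generated module: UnbatchGrad takes
# this invocation's downstream gradient plus the batch_index/id recorded by
# the forward Batch and reassembles a gradient for the whole batch. Normally
# this wiring is done by the Unbatch gradient; names here are example values.
import tensorflow as tf

x = tf.constant([[1.0], [2.0]])
fwd = tf.raw_ops.Batch(
    in_tensors=[x], num_batch_threads=1, max_batch_size=2,
    batch_timeout_micros=0, grad_timeout_micros=1000000,
    shared_name="unbatch_grad_example")
grad_for_this_call = tf.ones_like(x)          # gradient w.r.t. this invocation's slice
batched_grad = tf.raw_ops.UnbatchGrad(
    original_input=fwd.batched_tensors[0],
    batch_index=fwd.batch_index,
    grad=grad_for_this_call,
    id=fwd.id,
    shared_name="unbatch_grad_example")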

611 

612 

613def unbatch_grad_eager_fallback(original_input, batch_index, grad, id, container, shared_name, name, ctx): 

614 if container is None: 

615 container = "" 

616 container = _execute.make_str(container, "container") 

617 if shared_name is None: 

618 shared_name = "" 

619 shared_name = _execute.make_str(shared_name, "shared_name") 

620 _attr_T, _inputs_T = _execute.args_to_matching_eager([original_input, grad], ctx, []) 

621 (original_input, grad) = _inputs_T 

622 batch_index = _ops.convert_to_tensor(batch_index, _dtypes.int64) 

623 id = _ops.convert_to_tensor(id, _dtypes.int64) 

624 _inputs_flat = [original_input, batch_index, grad, id] 

625 _attrs = ("container", container, "shared_name", shared_name, "T", _attr_T) 

626 _result = _execute.execute(b"UnbatchGrad", 1, inputs=_inputs_flat, 

627 attrs=_attrs, ctx=ctx, name=name) 

628 if _execute.must_record_gradient(): 

629 _execute.record_gradient( 

630 "UnbatchGrad", _inputs_flat, _attrs, _result) 

631 _result, = _result 

632 return _result 

633