Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/compiler/tf2tensorrt/ops/gen_trt_ops.py: 13%

387 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Python wrappers around TensorFlow ops. 

2 

3This file is MACHINE GENERATED! Do not edit. 

4""" 

5 

6import collections 

7 

8from tensorflow.python import pywrap_tfe as pywrap_tfe 

9from tensorflow.python.eager import context as _context 

10from tensorflow.python.eager import core as _core 

11from tensorflow.python.eager import execute as _execute 

12from tensorflow.python.framework import dtypes as _dtypes 

13from tensorflow.security.fuzzing.py import annotation_types as _atypes 

14 

15from tensorflow.python.framework import op_def_registry as _op_def_registry 

16from tensorflow.python.framework import ops as _ops 

17from tensorflow.python.framework import op_def_library as _op_def_library 

18from tensorflow.python.util.deprecation import deprecated_endpoints 

19from tensorflow.python.util import dispatch as _dispatch 

20from tensorflow.python.util.tf_export import tf_export 

21 

22from typing import TypeVar 

23 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('create_trt_resource_handle')
def create_trt_resource_handle(resource_name, name=None):
  r"""Builds (graph mode) or runs (eager mode) a `CreateTRTResourceHandle` op.

  The op's semantics are defined by its registered kernel, which is not part
  of this file; the name suggests it yields a handle to the TensorRT resource
  called `resource_name` (TODO: confirm against the op registration).

  Eager mode tries, in order: `pywrap_tfe.TFE_Py_FastPathExecute`, the
  type-based API dispatcher, and `create_trt_resource_handle_eager_fallback`.
  A `_core._SymbolicException` from the last two falls through to graph
  construction. Graph mode consults the type-based dispatcher, validates
  `resource_name` with `_execute.make_str`, and adds the node with
  `_op_def_library._apply_op_helper`.

  Args:
    resource_name: A `string`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `resource`.

  Raises:
    TypeError, ValueError: from attribute validation or op construction.
      Those raised inside the guarded sections are first offered to
      `_dispatch.dispatch` and re-raised only if it declines the call.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Fast path: hand the raw Python arguments straight to pywrap_tfe.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CreateTRTResourceHandle", name, "resource_name", resource_name)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / eager slow path below.
      pass
    try:
      # The type-based dispatcher is consulted before the eager slow path.
      _result = _dispatcher_for_create_trt_resource_handle(
          (resource_name, name,), None)
      if _result is not NotImplemented:
        return _result
      return create_trt_resource_handle_eager_fallback(
          resource_name=resource_name, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle the call before re-raising.
      _result = _dispatch.dispatch(
            create_trt_resource_handle, (), dict(resource_name=resource_name,
                                                 name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_create_trt_resource_handle(
        (resource_name, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Validation happens after dispatch, so dispatchers see the raw argument.
  resource_name = _execute.make_str(resource_name, "resource_name")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CreateTRTResourceHandle", resource_name=resource_name, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          create_trt_resource_handle, (), dict(resource_name=resource_name,
                                               name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("resource_name", _op.get_attr("resource_name"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CreateTRTResourceHandle", _inputs_flat, _attrs, _result)
  # The op has exactly one output; unwrap it.
  _result, = _result
  return _result

91 

# Raw-op endpoint built by `_ops.to_raw_op`, exported as
# `raw_ops.CreateTRTResourceHandle`.
CreateTRTResourceHandle = tf_export(
    "raw_ops.CreateTRTResourceHandle")(
        _ops.to_raw_op(create_trt_resource_handle))
# Type-based dispatcher that `create_trt_resource_handle` consults at call time.
_dispatcher_for_create_trt_resource_handle = (
    create_trt_resource_handle._tf_type_based_dispatcher.Dispatch)

94 

95 

def create_trt_resource_handle_eager_fallback(resource_name, name, ctx):
  """Eager slow path for the `CreateTRTResourceHandle` op.

  Validates `resource_name` with `_execute.make_str`, then executes the op
  through `_execute.execute` with no tensor inputs and a single output.

  Args:
    resource_name: Value for the op's `resource_name` string attribute.
    name: A name for the operation, or None.
    ctx: The eager context to execute in.

  Returns:
    The op's single output.
  """
  resource_attr = _execute.make_str(resource_name, "resource_name")
  op_inputs = []
  op_attrs = ("resource_name", resource_attr)
  outputs = _execute.execute(b"CreateTRTResourceHandle", 1,
                             inputs=op_inputs, attrs=op_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CreateTRTResourceHandle", op_inputs, op_attrs, outputs)
  (handle,) = outputs
  return handle

108 

109 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('get_calibration_data_op')
def get_calibration_data_op(resource_name, name=None):
  r"""Returns calibration data for the given resource name.

  Builds (graph mode) or runs (eager mode) the `GetCalibrationDataOp` op.

  Eager mode tries, in order: `pywrap_tfe.TFE_Py_FastPathExecute`, the
  type-based API dispatcher, and `get_calibration_data_op_eager_fallback`.
  A `_core._SymbolicException` from the last two falls through to graph
  construction. Graph mode consults the type-based dispatcher and then adds
  the node with `_op_def_library._apply_op_helper`.

  Args:
    resource_name: A `Tensor` of type `string`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `string`.

  Raises:
    TypeError, ValueError: from argument conversion or op construction,
      re-raised only after `_dispatch.dispatch` declines the call.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Fast path: hand the raw Python arguments straight to pywrap_tfe.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "GetCalibrationDataOp", name, resource_name)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / eager slow path below.
      pass
    try:
      # The type-based dispatcher is consulted before the eager slow path.
      _result = _dispatcher_for_get_calibration_data_op(
          (resource_name, name,), None)
      if _result is not NotImplemented:
        return _result
      return get_calibration_data_op_eager_fallback(
          resource_name, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle the call before re-raising.
      _result = _dispatch.dispatch(
            get_calibration_data_op, (), dict(resource_name=resource_name,
                                              name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_get_calibration_data_op(
        (resource_name, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "GetCalibrationDataOp", resource_name=resource_name, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          get_calibration_data_op, (), dict(resource_name=resource_name,
                                            name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # No attributes are recorded for this op.
    _attrs = ()
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "GetCalibrationDataOp", _inputs_flat, _attrs, _result)
  # The op has exactly one output; unwrap it.
  _result, = _result
  return _result

176 

# Raw-op endpoint built by `_ops.to_raw_op`, exported as
# `raw_ops.GetCalibrationDataOp`.
GetCalibrationDataOp = tf_export(
    "raw_ops.GetCalibrationDataOp")(
        _ops.to_raw_op(get_calibration_data_op))
# Type-based dispatcher that `get_calibration_data_op` consults at call time.
_dispatcher_for_get_calibration_data_op = (
    get_calibration_data_op._tf_type_based_dispatcher.Dispatch)

179 

180 

def get_calibration_data_op_eager_fallback(resource_name, name, ctx):
  """Eager slow path for the `GetCalibrationDataOp` op.

  Converts `resource_name` to a string tensor and executes the op through
  `_execute.execute` with no attributes, returning its single output.

  Args:
    resource_name: Value convertible to a `string` tensor.
    name: A name for the operation, or None.
    ctx: The eager context to execute in.

  Returns:
    The op's single output.
  """
  name_tensor = _ops.convert_to_tensor(resource_name, _dtypes.string)
  op_inputs = [name_tensor]
  op_attrs = None  # no attributes are passed to execute
  outputs = _execute.execute(b"GetCalibrationDataOp", 1, inputs=op_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "GetCalibrationDataOp", op_inputs, op_attrs, outputs)
  (calibration_data,) = outputs
  return calibration_data

192 

193 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('initialize_trt_resource')
def initialize_trt_resource(resource_handle, filename, max_cached_engines_count=1, name=None):
  r"""Builds (graph mode) or runs (eager mode) an `InitializeTRTResource` op.

  The op's semantics are defined by its kernel, which is not part of this
  file; the names suggest it initializes the TensorRT resource behind
  `resource_handle` from `filename` (TODO: confirm against the op
  registration).

  Eager mode tries, in order: `pywrap_tfe.TFE_Py_FastPathExecute`, the
  type-based API dispatcher, and `initialize_trt_resource_eager_fallback`.
  A `_core._SymbolicException` from the last two falls through to graph
  construction. Graph mode consults the type-based dispatcher, normalizes
  `max_cached_engines_count` (`None` becomes `1`), and adds the node with
  `_op_def_library._apply_op_helper`.

  Args:
    resource_handle: A `Tensor` of type `resource`.
    filename: A `Tensor` of type `string`.
    max_cached_engines_count: An optional `int`. Defaults to `1`.
    name: A name for the operation (optional).

  Returns:
    In graph mode, the created Operation. The eager slow path
    (`initialize_trt_resource_eager_fallback`) returns `None`. The fast-path
    return value comes from `TFE_Py_FastPathExecute` and is not visible here.

  Raises:
    TypeError, ValueError: from attribute validation or op construction.
      Those raised inside the guarded sections are first offered to
      `_dispatch.dispatch` and re-raised only if it declines the call.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Fast path: hand the raw Python arguments straight to pywrap_tfe.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "InitializeTRTResource", name, resource_handle, filename,
        "max_cached_engines_count", max_cached_engines_count)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / eager slow path below.
      pass
    try:
      # The type-based dispatcher is consulted before the eager slow path.
      _result = _dispatcher_for_initialize_trt_resource(
          (resource_handle, filename, max_cached_engines_count, name,), None)
      if _result is not NotImplemented:
        return _result
      return initialize_trt_resource_eager_fallback(
          resource_handle, filename,
          max_cached_engines_count=max_cached_engines_count, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle the call before re-raising.
      _result = _dispatch.dispatch(
            initialize_trt_resource, (), dict(resource_handle=resource_handle,
                                              filename=filename,
                                              max_cached_engines_count=max_cached_engines_count,
                                              name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_initialize_trt_resource(
        (resource_handle, filename, max_cached_engines_count, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # `None` selects the same default as the signature.
  if max_cached_engines_count is None:
    max_cached_engines_count = 1
  max_cached_engines_count = _execute.make_int(max_cached_engines_count, "max_cached_engines_count")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "InitializeTRTResource", resource_handle=resource_handle,
                                 filename=filename,
                                 max_cached_engines_count=max_cached_engines_count,
                                 name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          initialize_trt_resource, (), dict(resource_handle=resource_handle,
                                            filename=filename,
                                            max_cached_engines_count=max_cached_engines_count,
                                            name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  # The op has no outputs, so the Operation itself is returned.
  return _op

# Raw-op endpoint built by `_ops.to_raw_op`, exported as
# `raw_ops.InitializeTRTResource`.
InitializeTRTResource = tf_export(
    "raw_ops.InitializeTRTResource")(
        _ops.to_raw_op(initialize_trt_resource))
# Type-based dispatcher that `initialize_trt_resource` consults at call time.
_dispatcher_for_initialize_trt_resource = (
    initialize_trt_resource._tf_type_based_dispatcher.Dispatch)

270 

271 

def initialize_trt_resource_eager_fallback(resource_handle, filename, max_cached_engines_count, name, ctx):
  """Eager slow path for the `InitializeTRTResource` op.

  Normalizes `max_cached_engines_count` (`None` becomes `1`) and converts the
  two inputs to `resource` / `string` tensors. It then executes the op
  through `_execute.execute` with zero outputs.

  Args:
    resource_handle: Value convertible to a `resource` tensor.
    filename: Value convertible to a `string` tensor.
    max_cached_engines_count: `int` attribute value, or None for the default.
    name: A name for the operation, or None.
    ctx: The eager context to execute in.

  Returns:
    None.
  """
  engine_count = 1 if max_cached_engines_count is None else max_cached_engines_count
  engine_count = _execute.make_int(engine_count, "max_cached_engines_count")
  handle_tensor = _ops.convert_to_tensor(resource_handle, _dtypes.resource)
  filename_tensor = _ops.convert_to_tensor(filename, _dtypes.string)
  _execute.execute(
      b"InitializeTRTResource", 0,
      inputs=[handle_tensor, filename_tensor],
      attrs=("max_cached_engines_count", engine_count),
      ctx=ctx, name=name)
  return None

284 

285 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('serialize_trt_resource')
def serialize_trt_resource(resource_name, filename, delete_resource=False, save_gpu_specific_engines=True, name=None):
  r"""Builds (graph mode) or runs (eager mode) a `SerializeTRTResource` op.

  The op's semantics are defined by its kernel, which is not part of this
  file; the names suggest it writes the TensorRT resource `resource_name` to
  `filename` (TODO: confirm against the op registration).

  Eager mode tries, in order: `pywrap_tfe.TFE_Py_FastPathExecute`, the
  type-based API dispatcher, and `serialize_trt_resource_eager_fallback`.
  A `_core._SymbolicException` from the last two falls through to graph
  construction. Graph mode consults the type-based dispatcher and normalizes
  both boolean attributes (`None` becomes the signature default). It then
  adds the node with `_op_def_library._apply_op_helper`.

  Args:
    resource_name: A `Tensor` of type `string`.
    filename: A `Tensor` of type `string`.
    delete_resource: An optional `bool`. Defaults to `False`.
    save_gpu_specific_engines: An optional `bool`. Defaults to `True`.
    name: A name for the operation (optional).

  Returns:
    In graph mode, the created Operation. The eager slow path
    (`serialize_trt_resource_eager_fallback`) returns `None`. The fast-path
    return value comes from `TFE_Py_FastPathExecute` and is not visible here.

  Raises:
    TypeError, ValueError: from attribute validation or op construction.
      Those raised inside the guarded sections are first offered to
      `_dispatch.dispatch` and re-raised only if it declines the call.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Fast path: hand the raw Python arguments straight to pywrap_tfe.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "SerializeTRTResource", name, resource_name, filename,
        "delete_resource", delete_resource, "save_gpu_specific_engines",
        save_gpu_specific_engines)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / eager slow path below.
      pass
    try:
      # The type-based dispatcher is consulted before the eager slow path.
      _result = _dispatcher_for_serialize_trt_resource(
          (resource_name, filename, delete_resource,
          save_gpu_specific_engines, name,), None)
      if _result is not NotImplemented:
        return _result
      return serialize_trt_resource_eager_fallback(
          resource_name, filename, delete_resource=delete_resource,
          save_gpu_specific_engines=save_gpu_specific_engines, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle the call before re-raising.
      _result = _dispatch.dispatch(
            serialize_trt_resource, (), dict(resource_name=resource_name,
                                             filename=filename,
                                             delete_resource=delete_resource,
                                             save_gpu_specific_engines=save_gpu_specific_engines,
                                             name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_serialize_trt_resource(
        (resource_name, filename, delete_resource, save_gpu_specific_engines,
        name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # `None` selects the same defaults as the signature.
  if delete_resource is None:
    delete_resource = False
  delete_resource = _execute.make_bool(delete_resource, "delete_resource")
  if save_gpu_specific_engines is None:
    save_gpu_specific_engines = True
  save_gpu_specific_engines = _execute.make_bool(save_gpu_specific_engines, "save_gpu_specific_engines")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "SerializeTRTResource", resource_name=resource_name,
                                filename=filename,
                                delete_resource=delete_resource,
                                save_gpu_specific_engines=save_gpu_specific_engines,
                                name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          serialize_trt_resource, (), dict(resource_name=resource_name,
                                           filename=filename,
                                           delete_resource=delete_resource,
                                           save_gpu_specific_engines=save_gpu_specific_engines,
                                           name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  # The op has no outputs, so the Operation itself is returned.
  return _op

# Raw-op endpoint built by `_ops.to_raw_op`, exported as
# `raw_ops.SerializeTRTResource`.
SerializeTRTResource = tf_export(
    "raw_ops.SerializeTRTResource")(
        _ops.to_raw_op(serialize_trt_resource))
# Type-based dispatcher that `serialize_trt_resource` consults at call time.
_dispatcher_for_serialize_trt_resource = (
    serialize_trt_resource._tf_type_based_dispatcher.Dispatch)

372 

373 

def serialize_trt_resource_eager_fallback(resource_name, filename, delete_resource, save_gpu_specific_engines, name, ctx):
  """Eager slow path for the `SerializeTRTResource` op.

  Normalizes both boolean attributes (`None` becomes `False` for
  `delete_resource` and `True` for `save_gpu_specific_engines`). Both inputs
  are converted to `string` tensors, and the op is executed through
  `_execute.execute` with zero outputs.

  Args:
    resource_name: Value convertible to a `string` tensor.
    filename: Value convertible to a `string` tensor.
    delete_resource: `bool` attribute value, or None for the default.
    save_gpu_specific_engines: `bool` attribute value, or None for the default.
    name: A name for the operation, or None.
    ctx: The eager context to execute in.

  Returns:
    None.
  """
  delete_flag = False if delete_resource is None else delete_resource
  delete_flag = _execute.make_bool(delete_flag, "delete_resource")
  gpu_flag = True if save_gpu_specific_engines is None else save_gpu_specific_engines
  gpu_flag = _execute.make_bool(gpu_flag, "save_gpu_specific_engines")
  name_tensor = _ops.convert_to_tensor(resource_name, _dtypes.string)
  filename_tensor = _ops.convert_to_tensor(filename, _dtypes.string)
  _execute.execute(
      b"SerializeTRTResource", 0,
      inputs=[name_tensor, filename_tensor],
      attrs=("delete_resource", delete_flag,
             "save_gpu_specific_engines", gpu_flag),
      ctx=ctx, name=name)
  return None

390 

391 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('trt_engine_op')
def trt_engine_op(in_tensor, serialized_segment, OutT, workspace_size_bytes, precision_mode, segment_func="", input_shapes=[], output_shapes=[], max_cached_engines_count=1, max_batch_size=1, calibration_data="", use_calibration=True, segment_funcdef_name="", cached_engine_batches=[], fixed_input_size=True, static_engine=True, profile_strategy="", use_explicit_precision=False, name=None):
  r"""Builds (graph mode) or runs (eager mode) a `TRTEngineOp`.

  The engine semantics live in the op's kernel, which is not part of this
  file. The names suggest `in_tensor` is fed through a TensorRT engine
  described by `serialized_segment` (TODO: confirm against the op
  registration).

  Eager mode tries, in order: `pywrap_tfe.TFE_Py_FastPathExecute`, the
  type-based API dispatcher, and `trt_engine_op_eager_fallback`. A
  `_core._SymbolicException` from the last two falls through to graph
  construction. Graph mode consults the type-based dispatcher and normalizes
  every attribute (`None` becomes the signature default). It then adds the
  node with `_op_def_library._apply_op_helper`.

  Note: the `[]` defaults are never mutated here. Graph mode rebinds each
  list-valued attribute to a freshly built list before use.

  Args:
    in_tensor: A list of `Tensor` objects with types from: `bool`, `int8`, `half`, `float32`, `int32`, `resource`.
    serialized_segment: A `string`.
    OutT: A list of `tf.DTypes` from: `tf.bool, tf.int8, tf.half, tf.float32, tf.int32` that has length `>= 1`.
    workspace_size_bytes: An `int`.
    precision_mode: A `string` from: `"FP32", "FP16", "INT8"`.
    segment_func: An optional function decorated with @Defun. Defaults to `""`.
    input_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    max_cached_engines_count: An optional `int`. Defaults to `1`.
    max_batch_size: An optional `int`. Defaults to `1`.
    calibration_data: An optional `string`. Defaults to `""`.
    use_calibration: An optional `bool`. Defaults to `True`.
    segment_funcdef_name: An optional `string`. Defaults to `""`.
    cached_engine_batches: An optional list of `ints`. Defaults to `[]`.
    fixed_input_size: An optional `bool`. Defaults to `True`.
    static_engine: An optional `bool`. Defaults to `True`.
    profile_strategy: An optional `string`. Defaults to `""`.
    use_explicit_precision: An optional `bool`. Defaults to `False`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `OutT`.

  Raises:
    TypeError: in graph mode, if `OutT`, `input_shapes`, `output_shapes` or
      `cached_engine_batches` is not a list or tuple.
    TypeError, ValueError: from op construction, re-raised only after
      `_dispatch.dispatch` declines the call.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Fast path: hand the raw Python arguments straight to pywrap_tfe.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "TRTEngineOp", name, in_tensor, "serialized_segment",
        serialized_segment, "segment_func", segment_func, "OutT", OutT,
        "input_shapes", input_shapes, "output_shapes", output_shapes,
        "max_cached_engines_count", max_cached_engines_count,
        "max_batch_size", max_batch_size, "workspace_size_bytes",
        workspace_size_bytes, "precision_mode", precision_mode,
        "calibration_data", calibration_data, "use_calibration",
        use_calibration, "segment_funcdef_name", segment_funcdef_name,
        "cached_engine_batches", cached_engine_batches, "fixed_input_size",
        fixed_input_size, "static_engine", static_engine, "profile_strategy",
        profile_strategy, "use_explicit_precision", use_explicit_precision)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / eager slow path below.
      pass
    try:
      # The type-based dispatcher is consulted before the eager slow path.
      _result = _dispatcher_for_trt_engine_op(
          (in_tensor, serialized_segment, OutT, workspace_size_bytes,
          precision_mode, segment_func, input_shapes, output_shapes,
          max_cached_engines_count, max_batch_size, calibration_data,
          use_calibration, segment_funcdef_name, cached_engine_batches,
          fixed_input_size, static_engine, profile_strategy,
          use_explicit_precision, name,), None)
      if _result is not NotImplemented:
        return _result
      return trt_engine_op_eager_fallback(
          in_tensor, serialized_segment=serialized_segment,
          segment_func=segment_func, OutT=OutT, input_shapes=input_shapes,
          output_shapes=output_shapes,
          max_cached_engines_count=max_cached_engines_count,
          max_batch_size=max_batch_size,
          workspace_size_bytes=workspace_size_bytes,
          precision_mode=precision_mode, calibration_data=calibration_data,
          use_calibration=use_calibration,
          segment_funcdef_name=segment_funcdef_name,
          cached_engine_batches=cached_engine_batches,
          fixed_input_size=fixed_input_size, static_engine=static_engine,
          profile_strategy=profile_strategy,
          use_explicit_precision=use_explicit_precision, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle the call before re-raising.
      _result = _dispatch.dispatch(
            trt_engine_op, (), dict(in_tensor=in_tensor,
                                    serialized_segment=serialized_segment,
                                    OutT=OutT,
                                    workspace_size_bytes=workspace_size_bytes,
                                    precision_mode=precision_mode,
                                    segment_func=segment_func,
                                    input_shapes=input_shapes,
                                    output_shapes=output_shapes,
                                    max_cached_engines_count=max_cached_engines_count,
                                    max_batch_size=max_batch_size,
                                    calibration_data=calibration_data,
                                    use_calibration=use_calibration,
                                    segment_funcdef_name=segment_funcdef_name,
                                    cached_engine_batches=cached_engine_batches,
                                    fixed_input_size=fixed_input_size,
                                    static_engine=static_engine,
                                    profile_strategy=profile_strategy,
                                    use_explicit_precision=use_explicit_precision,
                                    name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_trt_engine_op(
        (in_tensor, serialized_segment, OutT, workspace_size_bytes,
        precision_mode, segment_func, input_shapes, output_shapes,
        max_cached_engines_count, max_batch_size, calibration_data,
        use_calibration, segment_funcdef_name, cached_engine_batches,
        fixed_input_size, static_engine, profile_strategy,
        use_explicit_precision, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Attribute normalization: each `None` is replaced by the signature default
  # and each value is validated with the matching `_execute.make_*` helper.
  serialized_segment = _execute.make_str(serialized_segment, "serialized_segment")
  if not isinstance(OutT, (list, tuple)):
    raise TypeError(
        "Expected list for 'OutT' argument to "
        "'trt_engine_op' Op, not %r." % OutT)
  OutT = [_execute.make_type(_t, "OutT") for _t in OutT]
  workspace_size_bytes = _execute.make_int(workspace_size_bytes, "workspace_size_bytes")
  precision_mode = _execute.make_str(precision_mode, "precision_mode")
  # `segment_func` is not passed through a make_* helper; only None is replaced.
  if segment_func is None:
    segment_func = ""
  if input_shapes is None:
    input_shapes = []
  if not isinstance(input_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'input_shapes' argument to "
        "'trt_engine_op' Op, not %r." % input_shapes)
  input_shapes = [_execute.make_shape(_s, "input_shapes") for _s in input_shapes]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'trt_engine_op' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  if max_cached_engines_count is None:
    max_cached_engines_count = 1
  max_cached_engines_count = _execute.make_int(max_cached_engines_count, "max_cached_engines_count")
  if max_batch_size is None:
    max_batch_size = 1
  max_batch_size = _execute.make_int(max_batch_size, "max_batch_size")
  if calibration_data is None:
    calibration_data = ""
  calibration_data = _execute.make_str(calibration_data, "calibration_data")
  if use_calibration is None:
    use_calibration = True
  use_calibration = _execute.make_bool(use_calibration, "use_calibration")
  if segment_funcdef_name is None:
    segment_funcdef_name = ""
  segment_funcdef_name = _execute.make_str(segment_funcdef_name, "segment_funcdef_name")
  if cached_engine_batches is None:
    cached_engine_batches = []
  if not isinstance(cached_engine_batches, (list, tuple)):
    raise TypeError(
        "Expected list for 'cached_engine_batches' argument to "
        "'trt_engine_op' Op, not %r." % cached_engine_batches)
  cached_engine_batches = [_execute.make_int(_i, "cached_engine_batches") for _i in cached_engine_batches]
  if fixed_input_size is None:
    fixed_input_size = True
  fixed_input_size = _execute.make_bool(fixed_input_size, "fixed_input_size")
  if static_engine is None:
    static_engine = True
  static_engine = _execute.make_bool(static_engine, "static_engine")
  if profile_strategy is None:
    profile_strategy = ""
  profile_strategy = _execute.make_str(profile_strategy, "profile_strategy")
  if use_explicit_precision is None:
    use_explicit_precision = False
  use_explicit_precision = _execute.make_bool(use_explicit_precision, "use_explicit_precision")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "TRTEngineOp", in_tensor=in_tensor,
                       serialized_segment=serialized_segment, OutT=OutT,
                       workspace_size_bytes=workspace_size_bytes,
                       precision_mode=precision_mode,
                       segment_func=segment_func, input_shapes=input_shapes,
                       output_shapes=output_shapes,
                       max_cached_engines_count=max_cached_engines_count,
                       max_batch_size=max_batch_size,
                       calibration_data=calibration_data,
                       use_calibration=use_calibration,
                       segment_funcdef_name=segment_funcdef_name,
                       cached_engine_batches=cached_engine_batches,
                       fixed_input_size=fixed_input_size,
                       static_engine=static_engine,
                       profile_strategy=profile_strategy,
                       use_explicit_precision=use_explicit_precision,
                       name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          trt_engine_op, (), dict(in_tensor=in_tensor,
                                  serialized_segment=serialized_segment,
                                  OutT=OutT,
                                  workspace_size_bytes=workspace_size_bytes,
                                  precision_mode=precision_mode,
                                  segment_func=segment_func,
                                  input_shapes=input_shapes,
                                  output_shapes=output_shapes,
                                  max_cached_engines_count=max_cached_engines_count,
                                  max_batch_size=max_batch_size,
                                  calibration_data=calibration_data,
                                  use_calibration=use_calibration,
                                  segment_funcdef_name=segment_funcdef_name,
                                  cached_engine_batches=cached_engine_batches,
                                  fixed_input_size=fixed_input_size,
                                  static_engine=static_engine,
                                  profile_strategy=profile_strategy,
                                  use_explicit_precision=use_explicit_precision,
                                  name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # `InT` is not a Python argument; it is read back from the built op.
    _attrs = ("serialized_segment", _op.get_attr("serialized_segment"),
              "segment_func", _op.get_attr("segment_func"), "InT",
              _op.get_attr("InT"), "OutT", _op.get_attr("OutT"),
              "input_shapes", _op.get_attr("input_shapes"), "output_shapes",
              _op.get_attr("output_shapes"), "max_cached_engines_count",
              _op._get_attr_int("max_cached_engines_count"), "max_batch_size",
              _op._get_attr_int("max_batch_size"), "workspace_size_bytes",
              _op._get_attr_int("workspace_size_bytes"), "precision_mode",
              _op.get_attr("precision_mode"), "calibration_data",
              _op.get_attr("calibration_data"), "use_calibration",
              _op._get_attr_bool("use_calibration"), "segment_funcdef_name",
              _op.get_attr("segment_funcdef_name"), "cached_engine_batches",
              _op.get_attr("cached_engine_batches"), "fixed_input_size",
              _op._get_attr_bool("fixed_input_size"), "static_engine",
              _op._get_attr_bool("static_engine"), "profile_strategy",
              _op.get_attr("profile_strategy"), "use_explicit_precision",
              _op._get_attr_bool("use_explicit_precision"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "TRTEngineOp", _inputs_flat, _attrs, _result)
  # Multiple outputs: the whole list is returned (no single-output unwrap).
  return _result

630 

# Raw-op endpoint built by `_ops.to_raw_op`, exported as `raw_ops.TRTEngineOp`.
TRTEngineOp = tf_export(
    "raw_ops.TRTEngineOp")(
        _ops.to_raw_op(trt_engine_op))
# Type-based dispatcher that `trt_engine_op` consults at call time.
_dispatcher_for_trt_engine_op = (
    trt_engine_op._tf_type_based_dispatcher.Dispatch)

633 

634 

def trt_engine_op_eager_fallback(in_tensor, serialized_segment, OutT, workspace_size_bytes, precision_mode, segment_func, input_shapes, output_shapes, max_cached_engines_count, max_batch_size, calibration_data, use_calibration, segment_funcdef_name, cached_engine_batches, fixed_input_size, static_engine, profile_strategy, use_explicit_precision, name, ctx):
  """Eager slow path for the `TRTEngineOp` op.

  Each attribute is validated with the matching `_execute.make_*` helper,
  after `None` is replaced by the same default that appears in
  `trt_engine_op`'s signature. The inputs are converted with
  `_execute.convert_to_mixed_eager_tensors`, whose first result is passed as
  the `InT` attribute. The op is then executed expecting one output per
  entry of `OutT`.

  Returns:
    The list of output tensors returned by `_execute.execute`.

  Raises:
    TypeError: if `OutT`, `input_shapes`, `output_shapes` or
      `cached_engine_batches` is not a list or tuple.
  """
  segment = _execute.make_str(serialized_segment, "serialized_segment")
  if not isinstance(OutT, (list, tuple)):
    raise TypeError(
        "Expected list for 'OutT' argument to "
        "'trt_engine_op' Op, not %r." % OutT)
  out_types = [_execute.make_type(dtype, "OutT") for dtype in OutT]
  workspace = _execute.make_int(workspace_size_bytes, "workspace_size_bytes")
  precision = _execute.make_str(precision_mode, "precision_mode")
  # segment_func is forwarded as given; only None is replaced.
  func = "" if segment_func is None else segment_func

  in_shapes = [] if input_shapes is None else input_shapes
  if not isinstance(in_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'input_shapes' argument to "
        "'trt_engine_op' Op, not %r." % in_shapes)
  in_shapes = [_execute.make_shape(shape, "input_shapes") for shape in in_shapes]

  out_shapes = [] if output_shapes is None else output_shapes
  if not isinstance(out_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'trt_engine_op' Op, not %r." % out_shapes)
  out_shapes = [_execute.make_shape(shape, "output_shapes") for shape in out_shapes]

  max_engines = _execute.make_int(
      1 if max_cached_engines_count is None else max_cached_engines_count,
      "max_cached_engines_count")
  batch_size = _execute.make_int(
      1 if max_batch_size is None else max_batch_size, "max_batch_size")
  calib_data = _execute.make_str(
      "" if calibration_data is None else calibration_data, "calibration_data")
  calibrate = _execute.make_bool(
      True if use_calibration is None else use_calibration, "use_calibration")
  funcdef_name = _execute.make_str(
      "" if segment_funcdef_name is None else segment_funcdef_name,
      "segment_funcdef_name")

  engine_batches = [] if cached_engine_batches is None else cached_engine_batches
  if not isinstance(engine_batches, (list, tuple)):
    raise TypeError(
        "Expected list for 'cached_engine_batches' argument to "
        "'trt_engine_op' Op, not %r." % engine_batches)
  engine_batches = [_execute.make_int(batch, "cached_engine_batches") for batch in engine_batches]

  fixed_size = _execute.make_bool(
      True if fixed_input_size is None else fixed_input_size, "fixed_input_size")
  static = _execute.make_bool(
      True if static_engine is None else static_engine, "static_engine")
  strategy = _execute.make_str(
      "" if profile_strategy is None else profile_strategy, "profile_strategy")
  explicit_precision = _execute.make_bool(
      False if use_explicit_precision is None else use_explicit_precision,
      "use_explicit_precision")

  in_types, converted_inputs = _execute.convert_to_mixed_eager_tensors(in_tensor, ctx)
  flat_inputs = list(converted_inputs)
  op_attrs = (
      "serialized_segment", segment,
      "segment_func", func,
      "InT", in_types,
      "OutT", out_types,
      "input_shapes", in_shapes,
      "output_shapes", out_shapes,
      "max_cached_engines_count", max_engines,
      "max_batch_size", batch_size,
      "workspace_size_bytes", workspace,
      "precision_mode", precision,
      "calibration_data", calib_data,
      "use_calibration", calibrate,
      "segment_funcdef_name", funcdef_name,
      "cached_engine_batches", engine_batches,
      "fixed_input_size", fixed_size,
      "static_engine", static,
      "profile_strategy", strategy,
      "use_explicit_precision", explicit_precision,
  )
  outputs = _execute.execute(b"TRTEngineOp", len(out_types), inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "TRTEngineOp", flat_inputs, op_attrs, outputs)
  return outputs

712