Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/compiler/tf2tensorrt/ops/gen_trt_ops.py: 13%
387 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
"""Python wrappers around TensorFlow ops.

This file is MACHINE GENERATED! Do not edit.
"""

import collections

from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.security.fuzzing.py import annotation_types as _atypes

from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export

from typing import TypeVar

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('create_trt_resource_handle')
def create_trt_resource_handle(resource_name, name=None):
  r"""TODO: add doc.

  Args:
    resource_name: A `string`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `resource`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CreateTRTResourceHandle", name, "resource_name", resource_name)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_create_trt_resource_handle(
          (resource_name, name,), None)
      if _result is not NotImplemented:
        return _result
      return create_trt_resource_handle_eager_fallback(
          resource_name=resource_name, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            create_trt_resource_handle, (), dict(resource_name=resource_name,
                                                 name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_create_trt_resource_handle(
        (resource_name, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  resource_name = _execute.make_str(resource_name, "resource_name")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CreateTRTResourceHandle", resource_name=resource_name, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          create_trt_resource_handle, (), dict(resource_name=resource_name,
                                               name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("resource_name", _op.get_attr("resource_name"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CreateTRTResourceHandle", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

CreateTRTResourceHandle = tf_export("raw_ops.CreateTRTResourceHandle")(_ops.to_raw_op(create_trt_resource_handle))
_dispatcher_for_create_trt_resource_handle = create_trt_resource_handle._tf_type_based_dispatcher.Dispatch

def create_trt_resource_handle_eager_fallback(resource_name, name, ctx):
  resource_name = _execute.make_str(resource_name, "resource_name")
  _inputs_flat = []
  _attrs = ("resource_name", resource_name)
  _result = _execute.execute(b"CreateTRTResourceHandle", 1,
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CreateTRTResourceHandle", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
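
# --- Illustrative usage sketch (not part of the generated op definitions) ---
# A minimal example of how the wrapper above might be called, assuming a
# TensorRT-enabled TensorFlow build in which the CreateTRTResourceHandle
# kernel is registered; the resource name "trt_engine_cache_example" is a
# hypothetical placeholder. The same op is also reachable as
# tf.raw_ops.CreateTRTResourceHandle (see the raw_ops export above).
def _example_create_trt_resource_handle():
  # Returns a DT_RESOURCE tensor identifying a TRT engine-cache resource.
  return create_trt_resource_handle(resource_name="trt_engine_cache_example")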

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('get_calibration_data_op')
def get_calibration_data_op(resource_name, name=None):
  r"""Returns calibration data for the given resource name

  Args:
    resource_name: A `Tensor` of type `string`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `string`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "GetCalibrationDataOp", name, resource_name)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_get_calibration_data_op(
          (resource_name, name,), None)
      if _result is not NotImplemented:
        return _result
      return get_calibration_data_op_eager_fallback(
          resource_name, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            get_calibration_data_op, (), dict(resource_name=resource_name,
                                              name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_get_calibration_data_op(
        (resource_name, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "GetCalibrationDataOp", resource_name=resource_name, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          get_calibration_data_op, (), dict(resource_name=resource_name,
                                            name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ()
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "GetCalibrationDataOp", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

GetCalibrationDataOp = tf_export("raw_ops.GetCalibrationDataOp")(_ops.to_raw_op(get_calibration_data_op))
_dispatcher_for_get_calibration_data_op = get_calibration_data_op._tf_type_based_dispatcher.Dispatch

def get_calibration_data_op_eager_fallback(resource_name, name, ctx):
  resource_name = _ops.convert_to_tensor(resource_name, _dtypes.string)
  _inputs_flat = [resource_name]
  _attrs = None
  _result = _execute.execute(b"GetCalibrationDataOp", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "GetCalibrationDataOp", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
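
# --- Illustrative usage sketch (not part of the generated op definitions) ---
# A sketch of reading back calibration data (used by TF-TRT for INT8
# calibration) for a named engine resource, assuming the resource
# "trt_engine_cache_example" (hypothetical placeholder) exists and has been
# calibrated.
def _example_get_calibration_data():
  # Takes the resource name as a string tensor and returns a string tensor
  # holding the serialized calibration data.
  return get_calibration_data_op(resource_name="trt_engine_cache_example")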

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('initialize_trt_resource')
def initialize_trt_resource(resource_handle, filename, max_cached_engines_count=1, name=None):
  r"""TODO: add doc.

  Args:
    resource_handle: A `Tensor` of type `resource`.
    filename: A `Tensor` of type `string`.
    max_cached_engines_count: An optional `int`. Defaults to `1`.
    name: A name for the operation (optional).

  Returns:
    The created Operation.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "InitializeTRTResource", name, resource_handle, filename,
        "max_cached_engines_count", max_cached_engines_count)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_initialize_trt_resource(
          (resource_handle, filename, max_cached_engines_count, name,), None)
      if _result is not NotImplemented:
        return _result
      return initialize_trt_resource_eager_fallback(
          resource_handle, filename,
          max_cached_engines_count=max_cached_engines_count, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            initialize_trt_resource, (), dict(resource_handle=resource_handle,
                                              filename=filename,
                                              max_cached_engines_count=max_cached_engines_count,
                                              name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_initialize_trt_resource(
        (resource_handle, filename, max_cached_engines_count, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  if max_cached_engines_count is None:
    max_cached_engines_count = 1
  max_cached_engines_count = _execute.make_int(max_cached_engines_count, "max_cached_engines_count")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "InitializeTRTResource", resource_handle=resource_handle,
                                 filename=filename,
                                 max_cached_engines_count=max_cached_engines_count,
                                 name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          initialize_trt_resource, (), dict(resource_handle=resource_handle,
                                            filename=filename,
                                            max_cached_engines_count=max_cached_engines_count,
                                            name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  return _op

InitializeTRTResource = tf_export("raw_ops.InitializeTRTResource")(_ops.to_raw_op(initialize_trt_resource))
_dispatcher_for_initialize_trt_resource = initialize_trt_resource._tf_type_based_dispatcher.Dispatch

def initialize_trt_resource_eager_fallback(resource_handle, filename, max_cached_engines_count, name, ctx):
  if max_cached_engines_count is None:
    max_cached_engines_count = 1
  max_cached_engines_count = _execute.make_int(max_cached_engines_count, "max_cached_engines_count")
  resource_handle = _ops.convert_to_tensor(resource_handle, _dtypes.resource)
  filename = _ops.convert_to_tensor(filename, _dtypes.string)
  _inputs_flat = [resource_handle, filename]
  _attrs = ("max_cached_engines_count", max_cached_engines_count)
  _result = _execute.execute(b"InitializeTRTResource", 0, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  _result = None
  return _result
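
# --- Illustrative usage sketch (not part of the generated op definitions) ---
# A sketch of initializing a TRT resource from a serialized engine file,
# assuming a TensorRT-enabled build; "engines.trt" is a hypothetical path.
def _example_initialize_trt_resource():
  handle = create_trt_resource_handle(resource_name="trt_engine_cache_example")
  # The op has no outputs: in graph mode the wrapper returns the created
  # Operation, and the eager fallback above returns None.
  initialize_trt_resource(resource_handle=handle, filename="engines.trt",
                          max_cached_engines_count=1)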

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('serialize_trt_resource')
def serialize_trt_resource(resource_name, filename, delete_resource=False, save_gpu_specific_engines=True, name=None):
  r"""TODO: add doc.

  Args:
    resource_name: A `Tensor` of type `string`.
    filename: A `Tensor` of type `string`.
    delete_resource: An optional `bool`. Defaults to `False`.
    save_gpu_specific_engines: An optional `bool`. Defaults to `True`.
    name: A name for the operation (optional).

  Returns:
    The created Operation.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "SerializeTRTResource", name, resource_name, filename,
        "delete_resource", delete_resource, "save_gpu_specific_engines",
        save_gpu_specific_engines)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_serialize_trt_resource(
          (resource_name, filename, delete_resource,
          save_gpu_specific_engines, name,), None)
      if _result is not NotImplemented:
        return _result
      return serialize_trt_resource_eager_fallback(
          resource_name, filename, delete_resource=delete_resource,
          save_gpu_specific_engines=save_gpu_specific_engines, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            serialize_trt_resource, (), dict(resource_name=resource_name,
                                             filename=filename,
                                             delete_resource=delete_resource,
                                             save_gpu_specific_engines=save_gpu_specific_engines,
                                             name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_serialize_trt_resource(
        (resource_name, filename, delete_resource, save_gpu_specific_engines,
        name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  if delete_resource is None:
    delete_resource = False
  delete_resource = _execute.make_bool(delete_resource, "delete_resource")
  if save_gpu_specific_engines is None:
    save_gpu_specific_engines = True
  save_gpu_specific_engines = _execute.make_bool(save_gpu_specific_engines, "save_gpu_specific_engines")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "SerializeTRTResource", resource_name=resource_name,
                                filename=filename,
                                delete_resource=delete_resource,
                                save_gpu_specific_engines=save_gpu_specific_engines,
                                name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          serialize_trt_resource, (), dict(resource_name=resource_name,
                                           filename=filename,
                                           delete_resource=delete_resource,
                                           save_gpu_specific_engines=save_gpu_specific_engines,
                                           name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  return _op

SerializeTRTResource = tf_export("raw_ops.SerializeTRTResource")(_ops.to_raw_op(serialize_trt_resource))
_dispatcher_for_serialize_trt_resource = serialize_trt_resource._tf_type_based_dispatcher.Dispatch

def serialize_trt_resource_eager_fallback(resource_name, filename, delete_resource, save_gpu_specific_engines, name, ctx):
  if delete_resource is None:
    delete_resource = False
  delete_resource = _execute.make_bool(delete_resource, "delete_resource")
  if save_gpu_specific_engines is None:
    save_gpu_specific_engines = True
  save_gpu_specific_engines = _execute.make_bool(save_gpu_specific_engines, "save_gpu_specific_engines")
  resource_name = _ops.convert_to_tensor(resource_name, _dtypes.string)
  filename = _ops.convert_to_tensor(filename, _dtypes.string)
  _inputs_flat = [resource_name, filename]
  _attrs = ("delete_resource", delete_resource, "save_gpu_specific_engines",
            save_gpu_specific_engines)
  _result = _execute.execute(b"SerializeTRTResource", 0, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  _result = None
  return _result
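
# --- Illustrative usage sketch (not part of the generated op definitions) ---
# A sketch of serializing the engines held by a named TRT resource to a file;
# the resource name and the output path "saved_engines.trt" are hypothetical
# placeholders.
def _example_serialize_trt_resource():
  serialize_trt_resource(resource_name="trt_engine_cache_example",
                         filename="saved_engines.trt",
                         delete_resource=False,
                         save_gpu_specific_engines=True)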

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('trt_engine_op')
def trt_engine_op(in_tensor, serialized_segment, OutT, workspace_size_bytes, precision_mode, segment_func="", input_shapes=[], output_shapes=[], max_cached_engines_count=1, max_batch_size=1, calibration_data="", use_calibration=True, segment_funcdef_name="", cached_engine_batches=[], fixed_input_size=True, static_engine=True, profile_strategy="", use_explicit_precision=False, name=None):
  r"""TODO: add doc.

  Args:
    in_tensor: A list of `Tensor` objects with types from: `bool`, `int8`, `half`, `float32`, `int32`, `resource`.
    serialized_segment: A `string`.
    OutT: A list of `tf.DTypes` from: `tf.bool, tf.int8, tf.half, tf.float32, tf.int32` that has length `>= 1`.
    workspace_size_bytes: An `int`.
    precision_mode: A `string` from: `"FP32", "FP16", "INT8"`.
    segment_func: An optional function decorated with @Defun. Defaults to `""`.
    input_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    max_cached_engines_count: An optional `int`. Defaults to `1`.
    max_batch_size: An optional `int`. Defaults to `1`.
    calibration_data: An optional `string`. Defaults to `""`.
    use_calibration: An optional `bool`. Defaults to `True`.
    segment_funcdef_name: An optional `string`. Defaults to `""`.
    cached_engine_batches: An optional list of `ints`. Defaults to `[]`.
    fixed_input_size: An optional `bool`. Defaults to `True`.
    static_engine: An optional `bool`. Defaults to `True`.
    profile_strategy: An optional `string`. Defaults to `""`.
    use_explicit_precision: An optional `bool`. Defaults to `False`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `OutT`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "TRTEngineOp", name, in_tensor, "serialized_segment",
        serialized_segment, "segment_func", segment_func, "OutT", OutT,
        "input_shapes", input_shapes, "output_shapes", output_shapes,
        "max_cached_engines_count", max_cached_engines_count,
        "max_batch_size", max_batch_size, "workspace_size_bytes",
        workspace_size_bytes, "precision_mode", precision_mode,
        "calibration_data", calibration_data, "use_calibration",
        use_calibration, "segment_funcdef_name", segment_funcdef_name,
        "cached_engine_batches", cached_engine_batches, "fixed_input_size",
        fixed_input_size, "static_engine", static_engine, "profile_strategy",
        profile_strategy, "use_explicit_precision", use_explicit_precision)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      _result = _dispatcher_for_trt_engine_op(
          (in_tensor, serialized_segment, OutT, workspace_size_bytes,
          precision_mode, segment_func, input_shapes, output_shapes,
          max_cached_engines_count, max_batch_size, calibration_data,
          use_calibration, segment_funcdef_name, cached_engine_batches,
          fixed_input_size, static_engine, profile_strategy,
          use_explicit_precision, name,), None)
      if _result is not NotImplemented:
        return _result
      return trt_engine_op_eager_fallback(
          in_tensor, serialized_segment=serialized_segment,
          segment_func=segment_func, OutT=OutT, input_shapes=input_shapes,
          output_shapes=output_shapes,
          max_cached_engines_count=max_cached_engines_count,
          max_batch_size=max_batch_size,
          workspace_size_bytes=workspace_size_bytes,
          precision_mode=precision_mode, calibration_data=calibration_data,
          use_calibration=use_calibration,
          segment_funcdef_name=segment_funcdef_name,
          cached_engine_batches=cached_engine_batches,
          fixed_input_size=fixed_input_size, static_engine=static_engine,
          profile_strategy=profile_strategy,
          use_explicit_precision=use_explicit_precision, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      _result = _dispatch.dispatch(
            trt_engine_op, (), dict(in_tensor=in_tensor,
                                    serialized_segment=serialized_segment,
                                    OutT=OutT,
                                    workspace_size_bytes=workspace_size_bytes,
                                    precision_mode=precision_mode,
                                    segment_func=segment_func,
                                    input_shapes=input_shapes,
                                    output_shapes=output_shapes,
                                    max_cached_engines_count=max_cached_engines_count,
                                    max_batch_size=max_batch_size,
                                    calibration_data=calibration_data,
                                    use_calibration=use_calibration,
                                    segment_funcdef_name=segment_funcdef_name,
                                    cached_engine_batches=cached_engine_batches,
                                    fixed_input_size=fixed_input_size,
                                    static_engine=static_engine,
                                    profile_strategy=profile_strategy,
                                    use_explicit_precision=use_explicit_precision,
                                    name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_trt_engine_op(
        (in_tensor, serialized_segment, OutT, workspace_size_bytes,
        precision_mode, segment_func, input_shapes, output_shapes,
        max_cached_engines_count, max_batch_size, calibration_data,
        use_calibration, segment_funcdef_name, cached_engine_batches,
        fixed_input_size, static_engine, profile_strategy,
        use_explicit_precision, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  serialized_segment = _execute.make_str(serialized_segment, "serialized_segment")
  if not isinstance(OutT, (list, tuple)):
    raise TypeError(
        "Expected list for 'OutT' argument to "
        "'trt_engine_op' Op, not %r." % OutT)
  OutT = [_execute.make_type(_t, "OutT") for _t in OutT]
  workspace_size_bytes = _execute.make_int(workspace_size_bytes, "workspace_size_bytes")
  precision_mode = _execute.make_str(precision_mode, "precision_mode")
  if segment_func is None:
    segment_func = ""
  if input_shapes is None:
    input_shapes = []
  if not isinstance(input_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'input_shapes' argument to "
        "'trt_engine_op' Op, not %r." % input_shapes)
  input_shapes = [_execute.make_shape(_s, "input_shapes") for _s in input_shapes]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'trt_engine_op' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  if max_cached_engines_count is None:
    max_cached_engines_count = 1
  max_cached_engines_count = _execute.make_int(max_cached_engines_count, "max_cached_engines_count")
  if max_batch_size is None:
    max_batch_size = 1
  max_batch_size = _execute.make_int(max_batch_size, "max_batch_size")
  if calibration_data is None:
    calibration_data = ""
  calibration_data = _execute.make_str(calibration_data, "calibration_data")
  if use_calibration is None:
    use_calibration = True
  use_calibration = _execute.make_bool(use_calibration, "use_calibration")
  if segment_funcdef_name is None:
    segment_funcdef_name = ""
  segment_funcdef_name = _execute.make_str(segment_funcdef_name, "segment_funcdef_name")
  if cached_engine_batches is None:
    cached_engine_batches = []
  if not isinstance(cached_engine_batches, (list, tuple)):
    raise TypeError(
        "Expected list for 'cached_engine_batches' argument to "
        "'trt_engine_op' Op, not %r." % cached_engine_batches)
  cached_engine_batches = [_execute.make_int(_i, "cached_engine_batches") for _i in cached_engine_batches]
  if fixed_input_size is None:
    fixed_input_size = True
  fixed_input_size = _execute.make_bool(fixed_input_size, "fixed_input_size")
  if static_engine is None:
    static_engine = True
  static_engine = _execute.make_bool(static_engine, "static_engine")
  if profile_strategy is None:
    profile_strategy = ""
  profile_strategy = _execute.make_str(profile_strategy, "profile_strategy")
  if use_explicit_precision is None:
    use_explicit_precision = False
  use_explicit_precision = _execute.make_bool(use_explicit_precision, "use_explicit_precision")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "TRTEngineOp", in_tensor=in_tensor,
                       serialized_segment=serialized_segment, OutT=OutT,
                       workspace_size_bytes=workspace_size_bytes,
                       precision_mode=precision_mode,
                       segment_func=segment_func, input_shapes=input_shapes,
                       output_shapes=output_shapes,
                       max_cached_engines_count=max_cached_engines_count,
                       max_batch_size=max_batch_size,
                       calibration_data=calibration_data,
                       use_calibration=use_calibration,
                       segment_funcdef_name=segment_funcdef_name,
                       cached_engine_batches=cached_engine_batches,
                       fixed_input_size=fixed_input_size,
                       static_engine=static_engine,
                       profile_strategy=profile_strategy,
                       use_explicit_precision=use_explicit_precision,
                       name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          trt_engine_op, (), dict(in_tensor=in_tensor,
                                  serialized_segment=serialized_segment,
                                  OutT=OutT,
                                  workspace_size_bytes=workspace_size_bytes,
                                  precision_mode=precision_mode,
                                  segment_func=segment_func,
                                  input_shapes=input_shapes,
                                  output_shapes=output_shapes,
                                  max_cached_engines_count=max_cached_engines_count,
                                  max_batch_size=max_batch_size,
                                  calibration_data=calibration_data,
                                  use_calibration=use_calibration,
                                  segment_funcdef_name=segment_funcdef_name,
                                  cached_engine_batches=cached_engine_batches,
                                  fixed_input_size=fixed_input_size,
                                  static_engine=static_engine,
                                  profile_strategy=profile_strategy,
                                  use_explicit_precision=use_explicit_precision,
                                  name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("serialized_segment", _op.get_attr("serialized_segment"),
              "segment_func", _op.get_attr("segment_func"), "InT",
              _op.get_attr("InT"), "OutT", _op.get_attr("OutT"),
              "input_shapes", _op.get_attr("input_shapes"), "output_shapes",
              _op.get_attr("output_shapes"), "max_cached_engines_count",
              _op._get_attr_int("max_cached_engines_count"), "max_batch_size",
              _op._get_attr_int("max_batch_size"), "workspace_size_bytes",
              _op._get_attr_int("workspace_size_bytes"), "precision_mode",
              _op.get_attr("precision_mode"), "calibration_data",
              _op.get_attr("calibration_data"), "use_calibration",
              _op._get_attr_bool("use_calibration"), "segment_funcdef_name",
              _op.get_attr("segment_funcdef_name"), "cached_engine_batches",
              _op.get_attr("cached_engine_batches"), "fixed_input_size",
              _op._get_attr_bool("fixed_input_size"), "static_engine",
              _op._get_attr_bool("static_engine"), "profile_strategy",
              _op.get_attr("profile_strategy"), "use_explicit_precision",
              _op._get_attr_bool("use_explicit_precision"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "TRTEngineOp", _inputs_flat, _attrs, _result)
  return _result

TRTEngineOp = tf_export("raw_ops.TRTEngineOp")(_ops.to_raw_op(trt_engine_op))
_dispatcher_for_trt_engine_op = trt_engine_op._tf_type_based_dispatcher.Dispatch

def trt_engine_op_eager_fallback(in_tensor, serialized_segment, OutT, workspace_size_bytes, precision_mode, segment_func, input_shapes, output_shapes, max_cached_engines_count, max_batch_size, calibration_data, use_calibration, segment_funcdef_name, cached_engine_batches, fixed_input_size, static_engine, profile_strategy, use_explicit_precision, name, ctx):
  serialized_segment = _execute.make_str(serialized_segment, "serialized_segment")
  if not isinstance(OutT, (list, tuple)):
    raise TypeError(
        "Expected list for 'OutT' argument to "
        "'trt_engine_op' Op, not %r." % OutT)
  OutT = [_execute.make_type(_t, "OutT") for _t in OutT]
  workspace_size_bytes = _execute.make_int(workspace_size_bytes, "workspace_size_bytes")
  precision_mode = _execute.make_str(precision_mode, "precision_mode")
  if segment_func is None:
    segment_func = ""
  if input_shapes is None:
    input_shapes = []
  if not isinstance(input_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'input_shapes' argument to "
        "'trt_engine_op' Op, not %r." % input_shapes)
  input_shapes = [_execute.make_shape(_s, "input_shapes") for _s in input_shapes]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'trt_engine_op' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  if max_cached_engines_count is None:
    max_cached_engines_count = 1
  max_cached_engines_count = _execute.make_int(max_cached_engines_count, "max_cached_engines_count")
  if max_batch_size is None:
    max_batch_size = 1
  max_batch_size = _execute.make_int(max_batch_size, "max_batch_size")
  if calibration_data is None:
    calibration_data = ""
  calibration_data = _execute.make_str(calibration_data, "calibration_data")
  if use_calibration is None:
    use_calibration = True
  use_calibration = _execute.make_bool(use_calibration, "use_calibration")
  if segment_funcdef_name is None:
    segment_funcdef_name = ""
  segment_funcdef_name = _execute.make_str(segment_funcdef_name, "segment_funcdef_name")
  if cached_engine_batches is None:
    cached_engine_batches = []
  if not isinstance(cached_engine_batches, (list, tuple)):
    raise TypeError(
        "Expected list for 'cached_engine_batches' argument to "
        "'trt_engine_op' Op, not %r." % cached_engine_batches)
  cached_engine_batches = [_execute.make_int(_i, "cached_engine_batches") for _i in cached_engine_batches]
  if fixed_input_size is None:
    fixed_input_size = True
  fixed_input_size = _execute.make_bool(fixed_input_size, "fixed_input_size")
  if static_engine is None:
    static_engine = True
  static_engine = _execute.make_bool(static_engine, "static_engine")
  if profile_strategy is None:
    profile_strategy = ""
  profile_strategy = _execute.make_str(profile_strategy, "profile_strategy")
  if use_explicit_precision is None:
    use_explicit_precision = False
  use_explicit_precision = _execute.make_bool(use_explicit_precision, "use_explicit_precision")
  _attr_InT, in_tensor = _execute.convert_to_mixed_eager_tensors(in_tensor, ctx)
  _inputs_flat = list(in_tensor)
  _attrs = ("serialized_segment", serialized_segment, "segment_func",
            segment_func, "InT", _attr_InT, "OutT", OutT, "input_shapes",
            input_shapes, "output_shapes", output_shapes,
            "max_cached_engines_count", max_cached_engines_count,
            "max_batch_size", max_batch_size, "workspace_size_bytes",
            workspace_size_bytes, "precision_mode", precision_mode,
            "calibration_data", calibration_data, "use_calibration",
            use_calibration, "segment_funcdef_name", segment_funcdef_name,
            "cached_engine_batches", cached_engine_batches, "fixed_input_size",
            fixed_input_size, "static_engine", static_engine,
            "profile_strategy", profile_strategy, "use_explicit_precision",
            use_explicit_precision)
  _result = _execute.execute(b"TRTEngineOp", len(OutT), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "TRTEngineOp", _inputs_flat, _attrs, _result)
  return _result
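
# --- Illustrative usage sketch (not part of the generated op definitions) ---
# TRTEngineOp nodes are normally emitted by the TF-TRT converter rather than
# written by hand; this sketch only illustrates the shape of a direct call.
# The output dtype, workspace size, and precision mode below are hypothetical
# placeholder values, and `inputs`/`serialized_segment` are assumed to come
# from an existing converted engine.
def _example_trt_engine_op(inputs, serialized_segment):
  # `inputs` is a list of tensors matching the engine's inputs, OutT lists the
  # output dtypes, and workspace_size_bytes bounds TensorRT scratch memory.
  return trt_engine_op(
      in_tensor=inputs,
      serialized_segment=serialized_segment,
      OutT=[_dtypes.float32],
      workspace_size_bytes=1 << 30,
      precision_mode="FP16",
      static_engine=False)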