Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/pywrap_mlir.py: 43%
28 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python module for MLIR functions exported by pybind11."""
17# pylint: disable=invalid-import-order, g-bad-import-order, wildcard-import, unused-import, undefined-variable
18from tensorflow.python import pywrap_tensorflow
19from tensorflow.python.eager import context
20from tensorflow.python._pywrap_mlir import *
def import_graphdef(
    graphdef,
    pass_pipeline,
    show_debug_info,
    input_names=None,
    input_data_types=None,
    input_data_shapes=None,
    output_names=None,
):
  """Calls the pybind-exported `ImportGraphDef` with UTF-8 encoded arguments.

  Args:
    graphdef: Object whose `str()` form is UTF-8 encoded and passed on.
    pass_pipeline: String; UTF-8 encoded before the call.
    show_debug_info: Forwarded unchanged.
    input_names: Optional iterable of strings, joined with ','. When None, the
      three-argument form of `ImportGraphDef` is used and every argument below
      is ignored.
    input_data_types: Iterable of strings, joined with ','. Must not be None
      when `input_names` is given, because it is joined unconditionally.
    input_data_shapes: Iterable of strings, joined with ':'. Must not be None
      when `input_names` is given, because it is joined unconditionally.
    output_names: Optional iterable of strings, joined with ','. None (the
      default) is treated as an empty list.

  Returns:
    Whatever `ImportGraphDef` returns.
  """
  # A None sentinel replaces the former `[]` default so that no mutable object
  # is shared between calls.
  if output_names is None:
    output_names = []

  graphdef_bytes = str(graphdef).encode('utf-8')
  pipeline_bytes = pass_pipeline.encode('utf-8')

  if input_names is None:
    return ImportGraphDef(graphdef_bytes, pipeline_bytes, show_debug_info)

  return ImportGraphDef(
      graphdef_bytes,
      pipeline_bytes,
      show_debug_info,
      ','.join(input_names).encode('utf-8'),
      ','.join(input_data_types).encode('utf-8'),
      # Shapes are separated by ':' rather than ','.
      ':'.join(input_data_shapes).encode('utf-8'),
      ','.join(output_names).encode('utf-8'),
  )
def import_function(concrete_function, pass_pipeline, show_debug_info):
  """Calls the pybind-exported `ImportFunction` for a concrete function.

  Args:
    concrete_function: Object exposing a `function_def` attribute; the `str()`
      form of that attribute is UTF-8 encoded and passed on.
    pass_pipeline: String; UTF-8 encoded before the call.
    show_debug_info: Forwarded unchanged.

  Returns:
    Whatever `ImportFunction` returns.
  """
  eager_context = context.context()
  # The context handle is only read after initialization has been ensured.
  eager_context.ensure_initialized()
  context_handle = eager_context._handle  # pylint: disable=protected-access
  function_def_bytes = str(concrete_function.function_def).encode('utf-8')
  pipeline_bytes = pass_pipeline.encode('utf-8')
  return ImportFunction(
      context_handle, function_def_bytes, pipeline_bytes, show_debug_info
  )
def experimental_convert_saved_model_to_mlir(
    saved_model_path, exported_names, show_debug_info
):
  """Calls the pybind-exported `ExperimentalConvertSavedModelToMlir`.

  Args:
    saved_model_path: Converted with `str()` and UTF-8 encoded.
    exported_names: Converted with `str()` and UTF-8 encoded.
    show_debug_info: Forwarded unchanged.

  Returns:
    Whatever `ExperimentalConvertSavedModelToMlir` returns.
  """
  path_bytes = str(saved_model_path).encode('utf-8')
  names_bytes = str(exported_names).encode('utf-8')
  return ExperimentalConvertSavedModelToMlir(
      path_bytes, names_bytes, show_debug_info
  )
def experimental_convert_saved_model_v1_to_mlir_lite(
    saved_model_path, exported_names, tags, upgrade_legacy, show_debug_info
):
  """Calls the pybind-exported `ExperimentalConvertSavedModelV1ToMlirLite`.

  Args:
    saved_model_path: Converted with `str()` and UTF-8 encoded.
    exported_names: Converted with `str()` and UTF-8 encoded.
    tags: Converted with `str()` and UTF-8 encoded.
    upgrade_legacy: Forwarded unchanged.
    show_debug_info: Forwarded unchanged.

  Returns:
    Whatever `ExperimentalConvertSavedModelV1ToMlirLite` returns.
  """
  # The three textual arguments get the same str() -> UTF-8 treatment, in the
  # order they are passed.
  path_bytes, names_bytes, tags_bytes = [
      str(value).encode('utf-8')
      for value in (saved_model_path, exported_names, tags)
  ]
  return ExperimentalConvertSavedModelV1ToMlirLite(
      path_bytes, names_bytes, tags_bytes, upgrade_legacy, show_debug_info
  )
def experimental_convert_saved_model_v1_to_mlir(
    saved_model_path,
    exported_names,
    tags,
    lift_variables,
    include_variables_in_initializers,
    upgrade_legacy,
    show_debug_info,
):
  """Calls the pybind-exported `ExperimentalConvertSavedModelV1ToMlir`.

  Args:
    saved_model_path: Converted with `str()` and UTF-8 encoded.
    exported_names: Converted with `str()` and UTF-8 encoded.
    tags: Converted with `str()` and UTF-8 encoded.
    lift_variables: Forwarded unchanged.
    include_variables_in_initializers: Forwarded unchanged.
    upgrade_legacy: Forwarded unchanged.
    show_debug_info: Forwarded unchanged.

  Returns:
    Whatever `ExperimentalConvertSavedModelV1ToMlir` returns.
  """
  path_bytes = str(saved_model_path).encode('utf-8')
  names_bytes = str(exported_names).encode('utf-8')
  tags_bytes = str(tags).encode('utf-8')
  # The remaining options are handed to the binding exactly as received.
  passthrough_options = (
      lift_variables,
      include_variables_in_initializers,
      upgrade_legacy,
      show_debug_info,
  )
  return ExperimentalConvertSavedModelV1ToMlir(
      path_bytes, names_bytes, tags_bytes, *passthrough_options
  )
def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
  """Calls the pybind-exported `ExperimentalRunPassPipeline`.

  Args:
    mlir_txt: String; UTF-8 encoded before the call.
    pass_pipeline: String; UTF-8 encoded before the call.
    show_debug_info: Forwarded unchanged.

  Returns:
    Whatever `ExperimentalRunPassPipeline` returns.
  """
  mlir_bytes = mlir_txt.encode('utf-8')
  pipeline_bytes = pass_pipeline.encode('utf-8')
  return ExperimentalRunPassPipeline(
      mlir_bytes, pipeline_bytes, show_debug_info
  )
def experimental_write_bytecode(filename, mlir_txt):
  """Calls the pybind-exported `ExperimentalWriteBytecode`.

  Args:
    filename: String; UTF-8 encoded before the call.
    mlir_txt: String; encoded with `encode()`'s default codec before the call.

  Returns:
    Whatever `ExperimentalWriteBytecode` returns.
  """
  filename_bytes = filename.encode('utf-8')
  # No codec is named here; for `str`, encode() defaults to UTF-8.
  mlir_bytes = mlir_txt.encode()
  return ExperimentalWriteBytecode(filename_bytes, mlir_bytes)
def experimental_tflite_to_tosa_bytecode(
    flatbuffer,
    bytecode,
    use_external_constant=False,
    ordered_input_arrays=None,
    ordered_output_arrays=None,
):
  """Calls the pybind-exported `ExperimentalTFLiteToTosaBytecode`.

  Args:
    flatbuffer: String; UTF-8 encoded before the call. NOTE(review): presumably
      a file path, confirm against the binding.
    bytecode: String; UTF-8 encoded before the call. NOTE(review): presumably
      a file path, confirm against the binding.
    use_external_constant: Forwarded unchanged; defaults to False.
    ordered_input_arrays: Forwarded unchanged; None is replaced by a new empty
      list.
    ordered_output_arrays: Forwarded unchanged; None is replaced by a new empty
      list.

  Returns:
    Whatever `ExperimentalTFLiteToTosaBytecode` returns.
  """
  # Only None is replaced, so a caller-supplied empty container is kept.
  input_arrays = [] if ordered_input_arrays is None else ordered_input_arrays
  output_arrays = (
      [] if ordered_output_arrays is None else ordered_output_arrays
  )
  return ExperimentalTFLiteToTosaBytecode(
      flatbuffer.encode('utf-8'),
      bytecode.encode('utf-8'),
      use_external_constant,
      input_arrays,
      output_arrays,
  )