Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/pywrap_mlir.py: 43%

28 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Python module for MLIR functions exported by pybind11.""" 

16 

17# pylint: disable=invalid-import-order, g-bad-import-order, wildcard-import, unused-import, undefined-variable 

18from tensorflow.python import pywrap_tensorflow 

19from tensorflow.python.eager import context 

20from tensorflow.python._pywrap_mlir import * 

21 

22 

def import_graphdef(
    graphdef,
    pass_pipeline,
    show_debug_info,
    input_names=None,
    input_data_types=None,
    input_data_shapes=None,
    output_names=None,
):
  """Imports a GraphDef through the native `ImportGraphDef` binding.

  Args:
    graphdef: The graph to import. It is converted with `str()` and then
      UTF-8 encoded before being passed to the binding.
    pass_pipeline: Pass-pipeline string, UTF-8 encoded and forwarded.
    show_debug_info: Forwarded unchanged to the binding.
    input_names: Optional sequence of input names. When it is `None`, only
      the three arguments above are forwarded. Otherwise the input/output
      specification arguments below are forwarded as well.
    input_data_types: Optional sequence of strings, joined with ','.
      `None` is treated as empty.
    input_data_shapes: Optional sequence of strings, joined with ':'.
      `None` is treated as empty.
    output_names: Optional sequence of strings, joined with ','.
      `None` (the default) is treated as empty.

  Returns:
    Whatever the native `ImportGraphDef` binding returns.
  """
  encoded_graph = str(graphdef).encode('utf-8')
  encoded_pipeline = pass_pipeline.encode('utf-8')

  if input_names is None:
    return ImportGraphDef(encoded_graph, encoded_pipeline, show_debug_info)

  def _join(values, separator):
    # Joins an optional sequence; `None` becomes '' instead of raising
    # TypeError from `str.join(None)`.
    return separator.join([] if values is None else values).encode('utf-8')

  return ImportGraphDef(
      encoded_graph,
      encoded_pipeline,
      show_debug_info,
      _join(input_names, ','),
      _join(input_data_types, ','),
      _join(input_data_shapes, ':'),
      _join(output_names, ','),
  )

47 

48 

def import_function(concrete_function, pass_pipeline, show_debug_info):
  """Imports a concrete function via the native `ImportFunction` binding.

  The eager context is initialized first so that its handle can be passed
  to the binding.

  Args:
    concrete_function: Object exposing a `function_def` attribute. That
      attribute is converted with `str()` and UTF-8 encoded.
    pass_pipeline: Pass-pipeline string, UTF-8 encoded and forwarded.
    show_debug_info: Forwarded unchanged to the binding.

  Returns:
    Whatever the native `ImportFunction` binding returns.
  """
  eager_ctx = context.context()
  eager_ctx.ensure_initialized()
  ctx_handle = eager_ctx._handle  # pylint: disable=protected-access
  function_def_bytes = str(concrete_function.function_def).encode('utf-8')
  pipeline_bytes = pass_pipeline.encode('utf-8')
  return ImportFunction(
      ctx_handle, function_def_bytes, pipeline_bytes, show_debug_info
  )

58 

59 

def experimental_convert_saved_model_to_mlir(
    saved_model_path, exported_names, show_debug_info
):
  """Converts a SavedModel to MLIR via `ExperimentalConvertSavedModelToMlir`.

  Args:
    saved_model_path: Converted with `str()` and UTF-8 encoded.
    exported_names: Converted with `str()` and UTF-8 encoded. Callers
      presumably pass an already-delimited string; verify against caller.
    show_debug_info: Forwarded unchanged to the binding.

  Returns:
    Whatever the native binding returns.
  """
  path_bytes, names_bytes = (
      str(item).encode('utf-8') for item in (saved_model_path, exported_names)
  )
  return ExperimentalConvertSavedModelToMlir(
      path_bytes, names_bytes, show_debug_info
  )

68 

69 

def experimental_convert_saved_model_v1_to_mlir_lite(
    saved_model_path, exported_names, tags, upgrade_legacy, show_debug_info
):
  """Forwards to the native `ExperimentalConvertSavedModelV1ToMlirLite`.

  Args:
    saved_model_path: Converted with `str()` and UTF-8 encoded.
    exported_names: Converted with `str()` and UTF-8 encoded.
    tags: Converted with `str()` and UTF-8 encoded.
    upgrade_legacy: Forwarded unchanged.
    show_debug_info: Forwarded unchanged.

  Returns:
    Whatever the native binding returns.
  """
  string_args = [
      str(value).encode('utf-8')
      for value in (saved_model_path, exported_names, tags)
  ]
  return ExperimentalConvertSavedModelV1ToMlirLite(
      *string_args, upgrade_legacy, show_debug_info
  )

80 

81 

def experimental_convert_saved_model_v1_to_mlir(
    saved_model_path,
    exported_names,
    tags,
    lift_variables,
    include_variables_in_initializers,
    upgrade_legacy,
    show_debug_info,
):
  """Forwards to the native `ExperimentalConvertSavedModelV1ToMlir`.

  The three string-like arguments are converted with `str()` and UTF-8
  encoded. The boolean-like flags are forwarded unchanged.

  Returns:
    Whatever the native binding returns.
  """
  to_bytes = lambda obj: str(obj).encode('utf-8')
  path_bytes = to_bytes(saved_model_path)
  names_bytes = to_bytes(exported_names)
  tags_bytes = to_bytes(tags)
  flags = (
      lift_variables,
      include_variables_in_initializers,
      upgrade_legacy,
      show_debug_info,
  )
  return ExperimentalConvertSavedModelV1ToMlir(
      path_bytes, names_bytes, tags_bytes, *flags
  )

100 

101 

def experimental_run_pass_pipeline(mlir_txt, pass_pipeline, show_debug_info):
  """Runs a pass pipeline over MLIR text via the native binding.

  Args:
    mlir_txt: MLIR module text, UTF-8 encoded before forwarding.
    pass_pipeline: Pass-pipeline string, UTF-8 encoded before forwarding.
    show_debug_info: Forwarded unchanged.

  Returns:
    Whatever `ExperimentalRunPassPipeline` returns.
  """
  module_bytes = mlir_txt.encode('utf-8')
  pipeline_bytes = pass_pipeline.encode('utf-8')
  return ExperimentalRunPassPipeline(
      module_bytes, pipeline_bytes, show_debug_info
  )

106 

107 

def experimental_write_bytecode(filename, mlir_txt):
  """Forwards `filename` and `mlir_txt` to `ExperimentalWriteBytecode`.

  Both strings are UTF-8 encoded first. Judging by the binding's name, it
  presumably writes MLIR bytecode to `filename`; confirm in the binding.

  Returns:
    Whatever the native binding returns.
  """
  target = filename.encode('utf-8')
  # `str.encode()` defaults to UTF-8, so this matches a bare `.encode()`.
  payload = mlir_txt.encode('utf-8')
  return ExperimentalWriteBytecode(target, payload)

110 

111 

def experimental_tflite_to_tosa_bytecode(
    flatbuffer,
    bytecode,
    use_external_constant=False,
    ordered_input_arrays=None,
    ordered_output_arrays=None,
):
  """Forwards to the native `ExperimentalTFLiteToTosaBytecode` binding.

  Args:
    flatbuffer: String, UTF-8 encoded before forwarding (presumably a file
      path; confirm against the binding).
    bytecode: String, UTF-8 encoded before forwarding (presumably an output
      path; confirm against the binding).
    use_external_constant: Forwarded unchanged.
    ordered_input_arrays: Forwarded as-is; `None` becomes a fresh `[]`.
    ordered_output_arrays: Forwarded as-is; `None` becomes a fresh `[]`.

  Returns:
    Whatever the native binding returns.
  """
  inputs = [] if ordered_input_arrays is None else ordered_input_arrays
  outputs = [] if ordered_output_arrays is None else ordered_output_arrays
  flatbuffer_bytes = flatbuffer.encode('utf-8')
  bytecode_bytes = bytecode.encode('utf-8')
  return ExperimentalTFLiteToTosaBytecode(
      flatbuffer_bytes,
      bytecode_bytes,
      use_external_constant,
      inputs,
      outputs,
  )