Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/lite/python/tflite_keras_util.py: 18%

67 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Keras functions required by TensorFlow Lite. 

17 

18The functions defined in this library have been copied over from Keras in order 

19to remove the dependency from TensorFlow Lite to Keras. The functions which 

20could not be copied over are accessed using the dependency inversion principle. 

21(for details, refer to tensorflow/python/util/keras_deps.py). 

22""" 

23 

24import copy 

25 

26from tensorflow.python.eager import def_function 

27from tensorflow.python.util import keras_deps 

28from tensorflow.python.util import nest 

29from tensorflow.python.util.compat import collections_abc 

30 

31 

def _enforce_names_consistency(specs):
  """Makes spec naming all-or-nothing across a (possibly nested) structure.

  If every flattened spec has a non-None `name`, or none of them do, `specs`
  is returned unchanged (same object). Otherwise a new structure of the same
  shape is returned. In it, every leaf is a deep copy, and any copy exposing
  a `name` attribute has its private `_name` reset to None.

  Args:
    specs: A possibly-nested structure of specs (presumably `tf.TensorSpec`s).

  Returns:
    Either `specs` itself or a name-stripped, deep-copied structure.
  """

  def _is_named(candidate):
    # A spec counts as named only if it has a `name` attribute whose value is
    # not None.
    return hasattr(candidate, 'name') and candidate.name is not None

  def _without_name(candidate):
    # Work on a copy so the caller's specs are never mutated.
    clone = copy.deepcopy(candidate)
    if hasattr(clone, 'name'):
      # NOTE(review): assumes `name` is backed by `_name`; confirm for the
      # spec types in use.
      clone._name = None  # pylint:disable=protected-access
    return clone

  leaves = nest.flatten(specs)
  some_named = any(_is_named(leaf) for leaf in leaves)
  if not some_named or all(_is_named(leaf) for leaf in leaves):
    # Consistent naming (all or none): nothing to do.
    return specs
  return nest.map_structure(_without_name, specs)

52 

53 

def model_input_signature(model, keep_original_batch_size=False):
  """Inspects a model to derive its input signature.

  Keras requires tensor inputs to be passed as the first argument, so the
  signature is a sequence holding a single (possibly nested) spec object.
  For example, a model whose input is {'feature1': <Tensor>, 'feature2':
  <Tensor>} yields [{'feature1': TensorSpec, 'feature2': TensorSpec}].

  Args:
    model: Keras Model object. It must provide either a public
      `save_spec(dynamic_batch=...)` method or the private
      `_get_save_spec(dynamic_batch=...)` method.
    keep_original_batch_size: If False (the default), the batch dimension of
      the returned signature is requested as dynamic (`None`). If True, the
      model's original batch size is kept.

  Returns:
    None if the model reports no save spec (presumably because it has not
    been called or built yet). Otherwise a single-element sequence. If the
    name-normalized spec is already a length-1 `collections.abc.Sequence`, it
    is returned as-is (so it may be a tuple). Anything else, including a
    single-element dict, is wrapped in a new list. The `training` argument
    is never included.
  """
  dynamic_batch = not keep_original_batch_size

  if not hasattr(model, 'save_spec'):
    # Fallback for models without a public `save_spec`. The private
    # accessor's result is used directly, with no (args, kwargs) unpacking.
    specs = model._get_save_spec(  # pylint: disable=protected-access
        dynamic_batch=dynamic_batch)
    if specs is None:
      return None
  else:
    save_spec = model.save_spec(dynamic_batch=dynamic_batch)
    if save_spec is None:
      return None
    # `save_spec` is an (args, kwargs) pair. Only the first positional
    # argument is used as the input spec.
    # TODO(b/188105669): Add support for multiple tensor arguments.
    positional_specs = save_spec[0]
    specs = positional_specs[0]

  specs = _enforce_names_consistency(specs)

  # A one-element Sequence already has the required shape. Dicts are not
  # Sequences, so a single-element dict still gets wrapped below.
  already_wrapped = (
      isinstance(specs, collections_abc.Sequence) and len(specs) == 1)
  return specs if already_wrapped else [specs]

97 

98 

def raise_model_input_error(model):
  """Raises a ValueError explaining that `model` has no known input shapes.

  This function never returns normally.

  Args:
    model: The model being saved. It is interpolated into the message with
      standard string formatting.

  Raises:
    ValueError: Always.
  """
  message = (
      f'Model {model} cannot be saved because the input shapes have not been '
      'set. Usually, input shapes are automatically determined from calling'
      ' `.fit()` or `.predict()`. To manually set the shapes, call '
      '`model.build(input_shape)`.')
  raise ValueError(message)

105 

106 

def _create_pseudo_names(tensors, prefix):
  """Builds pseudo {input | output} names for subclassed Models.

  Warning: this function should only be used to define default names for
  `Metrics` and `SavedModel`. No other use cases should rely on a `Model`'s
  input or output names.

  Each flattened tensor gets one name, built from its path in the nested
  structure. Integer path elements are shifted to be 1-based. Path elements
  are joined with '_', and `prefix` is prepended when the path's first
  element is an integer.

  Example with dict:

    `{'a': [x1, x2], 'b': x3}` becomes:
    `['a_1', 'a_2', 'b']`

  Example with list:

    `[x, y]` becomes:
    `['output_1', 'output_2']`

  Args:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def _to_one_based(element):
    # Start numbering at "<prefix>1" instead of "<prefix>0".
    return element + 1 if isinstance(element, int) else element

  def _name_for(path):
    if not path:
      # Empty path: the single (un-nested) output case.
      return prefix + '1'
    joined = '_'.join(str(part) for part in path)
    # Integer-led paths (e.g. positions in a top-level list) get the prefix.
    # Paths that start with a non-int dict key are used as-is.
    return prefix + joined if isinstance(path[0], int) else joined

  raw_paths = list(nest.yield_flat_paths(tensors))
  one_based_paths = nest.map_structure(_to_one_based, raw_paths)
  return [_name_for(path) for path in one_based_paths]

150 

151 

def create_pseudo_output_names(outputs):
  """Returns default output names for a subclassed Model.

  This delegates to `_create_pseudo_names` with the 'output_' prefix, so
  positional outputs are named 'output_1', 'output_2', and so on.

  Args:
    outputs: The model's (possibly nested) outputs.

  Returns:
    A flat list of pseudo names, one per flattened output.
  """
  output_prefix = 'output_'
  return _create_pseudo_names(outputs, output_prefix)

155 

156 

157def trace_model_call(model, input_signature=None): 

158 """Trace the model call to create a tf.function for exporting a Keras model. 

159 

160 Args: 

161 model: A Keras model. 

162 input_signature: optional, a list of tf.TensorSpec objects specifying the 

163 inputs to the model. 

164 

165 Returns: 

166 A tf.function wrapping the model's call function with input signatures set. 

167 

168 Raises: 

169 ValueError: if input signature cannot be inferred from the model. 

170 """ 

171 if input_signature is None: 

172 if isinstance(model.call, def_function.Function): 

173 input_signature = model.call.input_signature 

174 

175 if input_signature is None: 

176 input_signature = model_input_signature(model) 

177 

178 if input_signature is None: 

179 raise_model_input_error(model) 

180 

181 @def_function.function(input_signature=input_signature, autograph=False) 

182 def _wrapped_model(*args): 

183 """A concrete tf.function that wraps the model's call function.""" 

184 # When given a single input, Keras models will call the model on the tensor 

185 # rather than a list consisting of the single tensor. 

186 inputs = args[0] if len(input_signature) == 1 else list(args) 

187 

188 with keras_deps.get_call_context_function()().enter( 

189 model, inputs=inputs, build_graph=False, training=False, saving=True): 

190 outputs = model(inputs, training=False) 

191 

192 return outputs 

193 

194 return _wrapped_model