Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/load_v1_in_v2.py: 23%

151 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Import a TF v1-style SavedModel when executing eagerly.""" 

16 

17import functools 

18 

19from tensorflow.python.eager import context 

20from tensorflow.python.eager import lift_to_graph 

21from tensorflow.python.eager import wrap_function 

22from tensorflow.python.framework import composite_tensor 

23from tensorflow.python.framework import constant_op 

24from tensorflow.python.framework import func_graph 

25from tensorflow.python.framework import ops 

26from tensorflow.python.framework import sparse_tensor 

27from tensorflow.python.platform import tf_logging as logging 

28from tensorflow.python.saved_model import function_deserialization 

29from tensorflow.python.saved_model import loader_impl 

30from tensorflow.python.saved_model import signature_serialization 

31from tensorflow.python.saved_model.pywrap_saved_model import metrics 

32from tensorflow.python.trackable import asset 

33from tensorflow.python.trackable import autotrackable 

34from tensorflow.python.trackable import resource 

35from tensorflow.python.training import monitored_session 

36from tensorflow.python.training import saver as tf_saver 

37from tensorflow.python.util import nest 

38 

# API label for SavedModel metrics: passed to `metrics.IncrementReadApi` by
# the module-level `load` function to tag reads made through this loader.
_LOAD_V1_V2_LABEL = "load_v1_in_v2"

41 

42 

class _Initializer(resource.CapturableResource):
  """Carries the initialization routine of a loaded TF 1.x SavedModel.

  An instance is built by `tf.saved_model.load` whenever a TF 1.x-style
  SavedModel with an initialization op is loaded. It wraps a function that
  performs that initialization, fed with the paths of the model's assets.

  Keeping this object around means that re-exporting the imported model does
  not drop the original initialization procedure: `tf.saved_model.save`
  picks the object up and adds it to the new export, and
  `tf.saved_model.load` invokes the initialization function on its own. No
  manual action from the user is needed.
  """

  def __init__(self, init_fn, asset_paths):
    super().__init__()
    # Objects exposing `.asset_path`; their paths are fed to `init_fn`.
    self._asset_paths = asset_paths
    # Callable that runs the SavedModel's initialization.
    self._init_fn = init_fn

  def _create_resource(self):
    # The value itself is irrelevant. It only gives the traced
    # `create_resource` function a valid output when the object is saved
    # again.
    return constant_op.constant(1.0)

  def _initialize(self):
    asset_file_paths = [tracked.asset_path for tracked in self._asset_paths]
    return self._init_fn(*asset_file_paths)

69 

70 

class _EagerSavedModelLoader(loader_impl.SavedModelLoader):
  """Loads a SavedModel without using Sessions.

  The selected MetaGraph is imported into a graph built by
  `wrap_function.wrap_function`. Functions pruned from that graph restore the
  variables, run the initialization op, and implement each SignatureDef.
  """

  def get_meta_graph_def_from_tags(self, tags):
    """Override to support implicit one-MetaGraph loading with tags=None.

    Args:
      tags: Tags identifying the MetaGraph to load, or None. With None the
        SavedModel must contain exactly one MetaGraph, which is returned.

    Returns:
      The selected MetaGraphDef.

    Raises:
      ValueError: If `tags` is None and the SavedModel does not contain
        exactly one MetaGraph.
    """
    if tags is None:
      meta_graphs = self._saved_model.meta_graphs
      if not meta_graphs:
        # Without this check the "more than one MetaGraph" message below
        # would be reported for an empty SavedModel, which is misleading.
        raise ValueError(
            "Importing a SavedModel with `tf.saved_model.load` requires at "
            "least one MetaGraph, but the SavedModel contains none."
        )
      if len(meta_graphs) != 1:
        tag_sets = [mg.meta_info_def.tags for mg in meta_graphs]
        raise ValueError(
            "Importing a SavedModel with `tf.saved_model.load` requires a "
            "`tags=` argument if there is more than one MetaGraph. Got "
            f"`tags=None`, but there are {len(meta_graphs)} "
            f"MetaGraphs in the SavedModel with tag sets: {tag_sets}. Pass a "
            "`tags=` argument to load this SavedModel."
        )
      return meta_graphs[0]
    return super(_EagerSavedModelLoader, self).get_meta_graph_def_from_tags(
        tags
    )

  def load_graph(self, returns, meta_graph_def):
    """Called from wrap_function to import `meta_graph_def`.

    Args:
      returns: A one-element list. The saver returned by the import (which may
        be None) is stored in `returns[0]` so the caller can read it back
        after `wrap_function` has traced this method.
      meta_graph_def: The MetaGraphDef to import into the current graph.
    """
    # pylint: disable=protected-access
    saver, _ = tf_saver._import_meta_graph_with_return_elements(meta_graph_def)
    # pylint: enable=protected-access
    returns[0] = saver

  def _extract_saver_restore(self, wrapped, saver):
    """Prunes `wrapped` into a function that runs the saver's restore op.

    Args:
      wrapped: The wrapped function holding the imported graph.
      saver: The saver produced by `load_graph`, or None.

    Returns:
      None when `saver` is None. Otherwise a pruned function that feeds the
      saver's filename tensor and fetches `[filename_tensor, restore_op]`.
    """
    if saver is None:
      return None
    saver_def = saver.saver_def
    filename_tensor = wrapped.graph.as_graph_element(
        saver_def.filename_tensor_name
    )
    # We both feed and fetch filename_tensor so we have an operation to use to
    # feed into variable initializers (only relevant for v1 graph building).
    return wrapped.prune(
        feeds=[filename_tensor],
        fetches=[
            filename_tensor,
            wrapped.graph.as_graph_element(saver_def.restore_op_name),
        ],
    )

  def restore_variables(self, wrapped, restore_from_saver):
    """Restores variables from the checkpoint.

    Args:
      wrapped: The wrapped function holding the imported graph.
      restore_from_saver: The function built by `_extract_saver_restore`, or
        None, in which case nothing is restored.
    """
    if restore_from_saver is not None:
      initializer, _ = restore_from_saver(
          constant_op.constant(self._variables_path)
      )
      if not ops.executing_eagerly_outside_functions():
        # Add the initialization operation to the "saved_model_initializers"
        # collection in case we don't have any lifted variables to attach it to.
        ops.add_to_collection("saved_model_initializers", initializer)
        one_unlifted = False

        for variable in wrapped.graph.get_collection_ref(
            ops.GraphKeys.GLOBAL_VARIABLES
        ):
          # Variables still owned by the wrapped graph were not lifted out.
          if variable.graph is wrapped.graph:
            one_unlifted = True
          # pylint: disable=protected-access
          variable._initializer_op = initializer
          # pylint: enable=protected-access
        if one_unlifted:
          logging.warning(
              "Some variables could not be lifted out of a loaded function. "
              "Please run "
              '`sess.run(tf.get_collection("saved_model_initializers"))` to '
              "restore these variables."
          )

  @staticmethod
  def _flatten_signature_inputs(original_input_names, feeds):
    """Flattens signature feeds into parallel lists of names and tensors.

    A SparseTensor feed expands to "<name>_indices", "<name>_values" and
    "<name>_dense_shape". Any other CompositeTensor expands to
    "<name>_component_<n>" for each of its flattened components. A plain
    tensor keeps its original name.

    Args:
      original_input_names: SignatureDef input keys, aligned with `feeds`.
      feeds: Graph elements corresponding to those inputs.

    Returns:
      A `(input_names, input_tensors)` tuple of equal-length lists.
    """
    input_names = []
    input_tensors = []
    for original_input_name, feed in zip(original_input_names, feeds):
      if isinstance(feed, sparse_tensor.SparseTensor):
        # We have to give explicit name for SparseTensor arguments, because
        # these are not present in the TensorInfo.
        indices_name = "%s_indices" % original_input_name
        values_name = "%s_values" % original_input_name
        dense_shape_name = "%s_dense_shape" % original_input_name
        input_names.extend([indices_name, values_name, dense_shape_name])
        input_tensors.extend([feed.indices, feed.values, feed.dense_shape])
      elif isinstance(feed, composite_tensor.CompositeTensor):
        component_tensors = nest.flatten(feed, expand_composites=True)
        input_names.extend(
            "%s_component_%d" % (original_input_name, n)
            for n in range(len(component_tensors))
        )
        input_tensors.extend(component_tensors)
      else:
        input_names.append(original_input_name)
        input_tensors.append(feed)
    return input_names, input_tensors

  def _extract_signatures(self, wrapped, meta_graph_def):
    """Creates ConcreteFunctions for signatures in `meta_graph_def`.

    Args:
      wrapped: The wrapped function holding the imported graph.
      meta_graph_def: The MetaGraphDef whose `signature_def` map is read.

    Returns:
      A dict mapping each signature key to a pruned function. Its keyword
      arguments are the flattened input names. It also accepts a single
      positional argument when there is exactly one input.

    Raises:
      lift_to_graph.UnliftableError: If a signature output depends on a
        placeholder that is not one of the signature's inputs. The message is
        prefixed with an explanation naming the signature.
    """
    signature_functions = {}
    for signature_key, signature_def in meta_graph_def.signature_def.items():
      if signature_def.inputs:
        # Sort by input key so argument order is deterministic.
        input_items = sorted(
            signature_def.inputs.items(), key=lambda item: item[0]
        )
        original_input_names, input_specs = zip(*input_items)
      else:
        original_input_names = []
        input_specs = []
      # TODO(b/205015292): Support optional arguments
      feeds = [
          wrap_function._get_element_from_tensor_info(input_spec, wrapped.graph)  # pylint: disable=protected-access
          for input_spec in input_specs
      ]
      input_names, input_tensors = self._flatten_signature_inputs(
          original_input_names, feeds
      )
      fetches = dict(signature_def.outputs.items())
      try:
        signature_fn = wrapped.prune(feeds=feeds, fetches=fetches)
      except lift_to_graph.UnliftableError as ex:
        # Mutate the exception to add a bit more detail.
        args = ex.args
        message = args[0] if args else ""
        message = (
            "A SavedModel signature needs an input for each placeholder the "
            "signature's outputs use. An output for signature '{}' depends on "
            "a placeholder which is not an input (i.e. the placeholder is not "
            "fed a value).\n\n"
        ).format(signature_key) + message
        ex.args = (message,) + args[1:]
        raise
      # pylint: disable=protected-access
      signature_fn._arg_keywords = input_names
      signature_fn._func_graph.structured_input_signature = (
          (),
          func_graph.convert_structure_to_signature(
              dict(zip(input_names, input_tensors))
          ),
      )

      # Allowing positional arguments does not create any ambiguity if there's
      # only one.
      signature_fn._num_positional_args = 1 if len(input_names) == 1 else 0
      # pylint: enable=protected-access
      signature_functions[signature_key] = signature_fn
    return signature_functions

  def load(self, tags):
    """Creates an object from the MetaGraph identified by `tags`.

    Args:
      tags: Tags of the MetaGraph to load, or None (see
        `get_meta_graph_def_from_tags`).

    Returns:
      An `AutoTrackable` root. It has `initializer`, `asset_paths`,
      `signatures`, `variables`, `tensorflow_version`,
      `tensorflow_git_version`, `graph` and `prune` attributes. It also has a
      `restore` attribute when the MetaGraph provided a saver.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    load_shared_name_suffix = "_load_{}".format(ops.uid())
    functions = function_deserialization.load_function_def_library(
        meta_graph_def.graph_def.library,
        load_shared_name_suffix=load_shared_name_suffix,
    )
    # Replace existing functions in the MetaGraphDef with renamed functions so
    # we don't have duplicates or name collisions.
    meta_graph_def.graph_def.library.Clear()
    for function in functions.values():
      meta_graph_def.graph_def.library.function.add().CopyFrom(
          function.function_def
      )
    # We've renamed functions and shared names. We need the same operation on
    # the GraphDef itself for consistency.
    for node_def in meta_graph_def.graph_def.node:
      function_deserialization.fix_node_def(
          node_def, functions, load_shared_name_suffix
      )

    # `load_graph` stores the saver here while `wrap_function` traces it.
    load_graph_returns = [None]
    wrapped = wrap_function.wrap_function(
        functools.partial(self.load_graph, load_graph_returns, meta_graph_def),
        signature=[],
    )
    (saver,) = load_graph_returns
    restore_from_saver = self._extract_saver_restore(wrapped, saver)
    self.restore_variables(wrapped, restore_from_saver)
    with wrapped.graph.as_default():
      init_op = (
          loader_impl.get_init_op(meta_graph_def)
          or monitored_session.Scaffold.default_local_init_op()
      )
      # Add a dummy Tensor we know we can fetch to add control dependencies to.
      init_anchor = constant_op.constant(0.0, name="dummy_fetch")

    root = autotrackable.AutoTrackable()
    if restore_from_saver is not None:
      root.restore = lambda path: restore_from_saver(constant_op.constant(path))
    asset_feed_tensors = []
    asset_paths = []
    for tensor_name, value in loader_impl.get_asset_tensors(
        self._export_dir, meta_graph_def
    ).items():
      asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
      asset_paths.append(asset.Asset(value))
    init_fn = wrapped.prune(
        feeds=asset_feed_tensors,
        fetches=[init_anchor, wrapped.graph.as_graph_element(init_op)],
    )
    initializer = _Initializer(init_fn, asset_paths)
    # pylint: disable=protected-access
    local_init_op, _ = initializer._initialize()
    # pylint: enable=protected-access
    with ops.init_scope():
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, local_init_op)
        for variable in wrapped.graph.get_collection_ref(
            ops.GraphKeys.LOCAL_VARIABLES
        ):
          # pylint: disable=protected-access
          variable._initializer_op = local_init_op
          # pylint: enable=protected-access
    root.initializer = initializer
    root.asset_paths = asset_paths
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)

    root.signatures = signature_serialization.create_signature_map(
        signature_functions
    )
    root.variables = list(wrapped.graph.variables)
    root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
    root.tensorflow_git_version = (
        meta_graph_def.meta_info_def.tensorflow_git_version
    )
    root.graph = wrapped.graph
    root.prune = wrapped.prune
    return root

300 

301 

def load(export_dir, tags):
  """Loads a TF v1-style SavedModel and returns it as a trackable object.

  Args:
    export_dir: Location of the SavedModel to read.
    tags: Tags selecting the MetaGraph to load. May be None when the
      SavedModel holds exactly one MetaGraph.

  Returns:
    The root object produced by `_EagerSavedModelLoader.load`.
  """
  # Count the API call up front; the successful-read metric is only bumped
  # once loading has finished without raising.
  metrics.IncrementReadApi(_LOAD_V1_V2_LABEL)
  loaded_root = _EagerSavedModelLoader(export_dir).load(tags=tags)
  metrics.IncrementRead(write_version="1")
  return loaded_root