Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/saving/legacy/saved_model/utils.py: 28%

112 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utility functions shared between SavedModel saving/loading 

16implementations.""" 

17 

18import copy 

19import itertools 

20import threading 

21import types 

22 

23import tensorflow.compat.v2 as tf 

24 

25from keras.src import backend 

26from keras.src.engine import base_layer_utils 

27from keras.src.utils import control_flow_util 

28from keras.src.utils import tf_contextlib 

29from keras.src.utils.generic_utils import LazyLoader 

30from keras.src.utils.layer_utils import CallFunctionSpec 

31 

32training_lib = LazyLoader("training_lib", globals(), "keras.src.engine.training") 

33 

34 

def use_wrapped_call(
    layer, call_fn, call_spec, default_training_value=None, return_method=False
):
    """Creates fn that adds losses returned by call_fn & returns the outputs.

    Args:
        layer: A Keras layer object
        call_fn: tf.function that takes layer inputs (and possibly a training
            arg), and returns a tuple of (outputs, list of losses).
        call_spec: The `CallFunctionSpec` for the layer's call function.
        default_training_value: Default value of the training kwarg. If
            `None`, the default is `tf.keras.backend.learning_phase()`.
        return_method: Whether to return a method bound to the layer.

    Returns:
        function that calls call_fn and returns the outputs. Losses returned
        by call_fn are added to the layer losses.
    """
    expects_training_arg = layer_uses_training_bool(layer)
    fn, arg_spec = maybe_add_training_arg(
        call_spec, call_fn, expects_training_arg, default_training_value
    )

    def return_outputs_and_add_losses(*args, **kwargs):
        """Returns the outputs from the layer call function, and adds the
        losses."""
        # When bound as a method, the first positional argument is the layer
        # itself; drop it so `fn` only sees the real call arguments.
        if return_method:
            args = args[1:]

        outputs, losses = fn(*args, **kwargs)
        layer.add_loss(losses)

        # TODO(kathywu): This is a temporary hack. When a network of layers is
        # revived from SavedModel, only the top-level layer will have losses.
        # This causes issues in eager mode because the child layers may have
        # graph losses (thus model.losses returns a mix of Eager and graph
        # tensors). To fix this, whenever eager losses are added to one layer,
        # add eager losses to all child layers. This causes `.losses` to only
        # return eager losses.
        if tf.executing_eagerly():
            for sublayer in layer._flatten_layers():
                if sublayer is not layer:
                    sublayer._eager_losses = [
                        base_layer_utils.REVIVED_LOSS_PLACEHOLDER
                    ]

        return outputs

    decorated = tf.__internal__.decorator.make_decorator(
        target=call_fn,
        decorator_func=return_outputs_and_add_losses,
        decorator_argspec=arg_spec,
    )

    if return_method:
        return types.MethodType(decorated, layer)
    return decorated

95 

96 

def layer_uses_training_bool(layer):
    """Returns whether this layer or any of its children uses the training
    arg."""
    if layer._expects_training_arg:
        return True
    visited = {layer}
    to_visit = list_all_layers(layer)
    while to_visit:
        current = to_visit.pop()
        if current not in visited:
            # Layers that don't declare the attribute are conservatively
            # assumed to expect a training argument.
            if getattr(current, "_expects_training_arg", True):
                return True
            visited.add(current)
            to_visit.extend(list_all_layers(current))
    return False

113 

114 

def list_all_layers(obj):
    """Returns the immediate child layers of `obj`."""
    if not isinstance(obj, training_lib.Model):
        return list(obj._flatten_layers(include_self=False, recursive=False))
    # Handle special case of Sequential, which doesn't return the `Input`
    # layer.
    return obj.layers

122 

123 

def list_all_layers_and_sublayers(obj):
    """Returns a set containing `obj` and every layer nested beneath it."""
    result = {obj}
    for child in list_all_layers(obj):
        result.update(list_all_layers_and_sublayers(child))
    return result

133 

134 

def maybe_add_training_arg(
    call_spec, wrapped_call, expects_training_arg, default_training_value
):
    """Decorate call and optionally adds training argument.

    If a layer expects a training argument, this function ensures that
    'training' is present in the layer args or kwonly args, with the default
    training value.

    Args:
        call_spec: CallFunctionSpec of the layer.
        wrapped_call: Wrapped call function.
        expects_training_arg: Whether to include 'training' argument.
        default_training_value: Default value of the training kwarg to
            include in the arg spec. If `None`, the default is
            `tf.keras.backend.learning_phase()`.

    Returns:
        Tuple of (
            function that calls `wrapped_call` and sets the training arg,
            Argspec of returned function or `None` if the argspec is
            unchanged)
    """
    if not expects_training_arg:
        return wrapped_call, None

    arg_spec = set_training_arg_spec(
        call_spec.full_argspec, default_training_value
    )
    call_spec = CallFunctionSpec(arg_spec)

    def wrap_with_training_arg(*args, **kwargs):
        """Wrap the `wrapped_call` function, and set training argument."""
        try:
            training = call_spec.get_arg_value(
                "training", args, kwargs, inputs_in_args=True
            )
        except KeyError:
            training = None

        if training is None:
            # Fall back, in order, to the explicit default, the surrounding
            # call context, and finally the global learning phase.
            training = (
                default_training_value
                or base_layer_utils.call_context().training
                or backend.learning_phase()
            )

        mutable_args = list(args)
        mutable_kwargs = kwargs.copy()

        def call_with_training(value):
            # Overwrite the training arg/kwarg and invoke the wrapped call.
            new_args, new_kwargs = call_spec.set_arg_value(
                "training",
                value,
                mutable_args,
                mutable_kwargs,
                inputs_in_args=True,
            )
            return wrapped_call(*new_args, **new_kwargs)

        return control_flow_util.smart_cond(
            training,
            lambda: call_with_training(True),
            lambda: call_with_training(False),
        )

    return wrap_with_training_arg, arg_spec

197 

198 

def set_training_arg_spec(arg_spec, default_training_value):
    """Set `training=DEFAULT` argument in an ArgSpec."""
    if "training" in arg_spec.args:
        # `training` is a positional arg: only patch its default when the
        # current default is `None`.
        position_from_end = (
            len(arg_spec.args) - arg_spec.args.index("training")
        )
        defaults = (
            [] if arg_spec.defaults is None else list(arg_spec.defaults)
        )
        if (
            len(defaults) >= position_from_end
            and defaults[-position_from_end] is None
        ):
            defaults[-position_from_end] = default_training_value
            return arg_spec._replace(defaults=defaults)
    elif "training" not in arg_spec.kwonlyargs:
        # `training` is absent entirely: append it as a keyword-only arg
        # with the requested default.
        return arg_spec._replace(
            kwonlyargs=arg_spec.kwonlyargs + ["training"],
            kwonlydefaults={
                **(arg_spec.kwonlydefaults or {}),
                "training": default_training_value,
            },
        )

    return arg_spec

225 

226 

class SaveOptionsContext(threading.local):
    """Thread-local holder for Keras SavedModel saving options."""

    def __init__(self):
        super().__init__()
        # Whether traced layer call functions are serialized when saving.
        self.save_traces = True
        # Whether execution is currently inside a `keras_option_scope`.
        self.in_tf_saved_model_scope = False

232 

233 

# Module-level singleton read and mutated by `keras_option_scope`,
# `should_save_traces` and `in_tf_saved_model_scope` below.
_save_options_context = SaveOptionsContext()

235 

236 

@tf_contextlib.contextmanager
def keras_option_scope(save_traces, in_tf_saved_model_scope=True):
    """Temporarily overrides the thread-local Keras saving options."""
    previous = (
        _save_options_context.save_traces,
        _save_options_context.in_tf_saved_model_scope,
    )
    try:
        _save_options_context.save_traces = save_traces
        _save_options_context.in_tf_saved_model_scope = in_tf_saved_model_scope
        yield
    finally:
        # Restore whatever was in effect before entering the scope.
        _save_options_context.save_traces = previous[0]
        _save_options_context.in_tf_saved_model_scope = previous[1]

248 

249 

def should_save_traces():
    """Returns whether layer call functions should be traced when saving.

    Tracing can be disabled via the `save_traces` option (see
    `keras_option_scope`).
    """
    return _save_options_context.save_traces

254 

255 

def in_tf_saved_model_scope():
    """Returns whether currently inside a `keras_option_scope` that set the
    tf.saved_model flag."""
    return _save_options_context.in_tf_saved_model_scope

258 

259 

@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
    """Context that disables automatic dependency tracking when assigning attrs.

    Objects that inherit from Autotrackable automatically create dependencies
    on trackable objects through attribute assignments, and wrap data
    structures (lists or dicts) with trackable classes. This scope may be
    used to temporarily disable that behavior, similar to the decorator
    `no_automatic_dependency_tracking`.

    Example usage:
    ```
    model = tf.keras.Model()
    model.arr1 = []  # Creates a ListWrapper object
    with no_automatic_dependency_tracking_scope(model):
        model.arr2 = []  # Creates a regular, untracked python list
    ```

    Args:
        obj: A trackable object.

    Yields:
        a scope in which the object doesn't track dependencies.
    """
    # Objects that never had the flag set are treated as tracking-enabled.
    original_tracking = getattr(obj, "_setattr_tracking", True)
    obj._setattr_tracking = False
    try:
        yield
    finally:
        obj._setattr_tracking = original_tracking

290