Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py: 24%


# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions shared between SavedModel saving/loading implementations."""

import itertools
import threading
import types

from tensorflow.python.eager import context
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.util import tf_decorator


# pylint:disable=g-inconsistent-quotes
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
# pylint:enable=g-inconsistent-quotes


def use_wrapped_call(layer, call_fn, default_training_value=None,
                     return_method=False):
  """Creates a function that adds losses returned by `call_fn` and returns outputs.

  Args:
    layer: A Keras layer object.
    call_fn: tf.function that takes layer inputs (and possibly a training arg),
      and returns a tuple of (outputs, list of losses).
    default_training_value: Default value of the training kwarg. If `None`, the
      default is `K.learning_phase()`.
    return_method: Whether to return a method bound to the layer.

  Returns:
    A function that calls `call_fn` and returns the outputs. Losses returned
    by `call_fn` are added to the layer's losses.
  """
  expects_training_arg = layer_uses_training_bool(layer)
  if hasattr(call_fn, 'original_layer_call'):  # call_fn is a LayerCall object
    original_call = call_fn.original_layer_call
    # In Python 3, callable objects are not compatible with inspect.getargspec,
    # so inspect the underlying __call__ method instead.
    call_fn = call_fn.__call__
  else:
    original_call = call_fn
  fn, arg_spec = maybe_add_training_arg(
      original_call, call_fn, expects_training_arg, default_training_value)

  def return_outputs_and_add_losses(*args, **kwargs):
    """Returns the outputs from the layer call function, and adds the losses."""
    if return_method:
      args = args[1:]

    outputs, losses = fn(*args, **kwargs)
    layer.add_loss(losses, inputs=True)

    # TODO(kathywu): This is a temporary hack. When a network of layers is
    # revived from SavedModel, only the top-level layer will have losses. This
    # causes issues in eager mode because the child layers may have graph
    # losses (thus model.losses returns a mix of eager and graph tensors). To
    # fix this, whenever eager losses are added to one layer, add eager losses
    # to all child layers. This causes `.losses` to only return eager losses.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      for i in layer._flatten_layers():
        if i is not layer:
          i._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER]
    # pylint: enable=protected-access
    return outputs

  decorated = tf_decorator.make_decorator(
      target=call_fn,
      decorator_func=return_outputs_and_add_losses,
      decorator_argspec=arg_spec)

  if return_method:
    return types.MethodType(decorated, layer)
  else:
    return decorated
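
# Hedged usage sketch, not part of the original module: how a caller might
# wire up `use_wrapped_call`. The layer, call_fn body, and shapes below are
# illustrative assumptions, not the actual SavedModel serialization code.
#
#   import tensorflow as tf
#
#   layer = tf.keras.layers.Dense(4)
#   layer.build((None, 8))
#
#   @tf.function
#   def call_fn(inputs):
#     # Honors the documented contract: returns (outputs, list of losses).
#     outputs = layer(inputs)
#     return outputs, [0.01 * tf.reduce_sum(tf.square(outputs))]
#
#   wrapped = use_wrapped_call(layer, call_fn)
#   outputs = wrapped(tf.ones((2, 8)))  # the returned loss lands in layer.losses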


def layer_uses_training_bool(layer):
  """Returns whether this layer or any of its children uses the training arg."""
  if layer._expects_training_arg:  # pylint: disable=protected-access
    return True
  visited = {layer}
  to_visit = list_all_layers(layer)
  while to_visit:
    layer = to_visit.pop()
    if layer in visited:
      continue
    if getattr(layer, '_expects_training_arg', True):
      return True
    visited.add(layer)
    to_visit.extend(list_all_layers(layer))
  return False
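
# Hedged sketch (assumes tf.keras layers are in scope): Dropout's call
# signature accepts `training`, so a model containing it reports True here.
#
#   import tensorflow as tf
#   model = tf.keras.Sequential([tf.keras.layers.Dense(2),
#                                tf.keras.layers.Dropout(0.5)])
#   layer_uses_training_bool(model)  # -> True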


def list_all_layers(obj):
  if isinstance(obj, training_lib.Model):
    # Handle special case of Sequential, which doesn't return
    # the `Input` layer.
    return obj.layers
  else:
    return list(obj._flatten_layers(include_self=False, recursive=False))  # pylint: disable=protected-access


def list_all_layers_and_sublayers(obj):
  s = set([obj])
  s.update(itertools.chain.from_iterable(
      list_all_layers_and_sublayers(layer) for layer in list_all_layers(obj)))
  return s
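
# Hedged sketch (assumes tf.keras objects): list_all_layers returns a Model's
# immediate layers, while list_all_layers_and_sublayers walks the whole tree,
# including `obj` itself.
#
#   import tensorflow as tf
#   inner = tf.keras.Sequential([tf.keras.layers.Dense(2)])
#   outer = tf.keras.Sequential([inner])
#   list_all_layers(outer)                # -> [inner]
#   list_all_layers_and_sublayers(outer)  # -> {outer, inner, inner's Dense}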


def maybe_add_training_arg(
    original_call, wrapped_call, expects_training_arg, default_training_value):
  """Decorates the call function, optionally adding a training argument.

  If a layer expects a training argument, this function ensures that 'training'
  is present in the layer args or kwonly args, with the default training value.

  Args:
    original_call: Original call function.
    wrapped_call: Wrapped call function.
    expects_training_arg: Whether to include a 'training' argument.
    default_training_value: Default value of the training kwarg to include in
      the arg spec. If `None`, the default is `K.learning_phase()`.

  Returns:
    Tuple of (
      function that calls `wrapped_call` and sets the training arg,
      argspec of the returned function, or `None` if the argspec is unchanged).
  """
  if not expects_training_arg:
    return wrapped_call, None

  def wrap_with_training_arg(*args, **kwargs):
    """Wraps the `wrapped_call` function, and sets the training argument."""
    training_arg_index = get_training_arg_index(original_call)
    training = get_training_arg(training_arg_index, args, kwargs)
    if training is None:
      training = default_training_value or K.learning_phase()

    args = list(args)
    kwargs = kwargs.copy()

    def replace_training_and_call(training):
      set_training_arg(training, training_arg_index, args, kwargs)
      return wrapped_call(*args, **kwargs)

    return control_flow_util.smart_cond(
        training, lambda: replace_training_and_call(True),
        lambda: replace_training_and_call(False))

  # Create an arg spec for the decorated function. If 'training' is not defined
  # in the args of the original arg spec, then add it to kwonlyargs.
  arg_spec = tf_inspect.getfullargspec(original_call)
  defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else []

  kwonlyargs = arg_spec.kwonlyargs
  kwonlydefaults = arg_spec.kwonlydefaults or {}
  # Add the training arg if it does not exist, or set the default training
  # value.
  if 'training' not in arg_spec.args:
    kwonlyargs.append('training')
    kwonlydefaults['training'] = default_training_value
  else:
    index = arg_spec.args.index('training')
    training_default_index = len(arg_spec.args) - index
    if (arg_spec.defaults and
        len(arg_spec.defaults) >= training_default_index and
        defaults[-training_default_index] is None):
      defaults[-training_default_index] = default_training_value

  decorator_argspec = tf_inspect.FullArgSpec(
      args=arg_spec.args,
      varargs=arg_spec.varargs,
      varkw=arg_spec.varkw,
      defaults=defaults,
      kwonlyargs=kwonlyargs,
      kwonlydefaults=kwonlydefaults,
      annotations=arg_spec.annotations)
  return wrap_with_training_arg, decorator_argspec
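
# Hedged sketch (function name is illustrative): when the original call has no
# 'training' arg, the decorated argspec grows a keyword-only 'training' with
# the supplied default.
#
#   def call_fn(inputs):
#     return inputs, []
#
#   _, spec = maybe_add_training_arg(call_fn, call_fn,
#                                    expects_training_arg=True,
#                                    default_training_value=False)
#   'training' in spec.kwonlyargs    # -> True
#   spec.kwonlydefaults['training']  # -> False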


def get_training_arg_index(call_fn):
  """Returns the index of 'training' in the layer call function arguments.

  Args:
    call_fn: Call function.

  Returns:
    - n: index of 'training' in the call function arguments.
    - -1: if 'training' is not found in the arguments, but layer.call accepts
      variable keyword arguments.
    - None: if the layer doesn't expect a training argument.
  """
  argspec = tf_inspect.getfullargspec(call_fn)
  if argspec.varargs:
    # When there are variable args, training must be a keyword arg.
    if 'training' in argspec.kwonlyargs or argspec.varkw:
      return -1
    return None
  else:
    # Try to find 'training' in the list of args or kwargs.
    arg_list = argspec.args
    if tf_inspect.ismethod(call_fn):
      arg_list = arg_list[1:]

    if 'training' in arg_list:
      return arg_list.index('training')
    elif 'training' in argspec.kwonlyargs or argspec.varkw:
      return -1
    return None
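
# Hedged sketch of how different call signatures resolve (function names are
# illustrative; results follow the docstring above):
#
#   def call_positional(inputs, training=None): ...  # -> 1 ('training' in args)
#   def call_kwonly(inputs, *, training=None): ...   # -> -1 (keyword-only)
#   def call_varargs(*args, **kwargs): ...           # -> -1 (caught by **kwargs)
#   def call_plain(inputs): ...                      # -> None (no training arg)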


def set_training_arg(training, index, args, kwargs):
  if index is None or index < 0 or len(args) <= index:  # index is invalid
    kwargs['training'] = training
  else:
    args[index] = training
  return args, kwargs


def get_training_arg(index, args, kwargs):
  if index is None or index < 0 or len(args) <= index:  # index is invalid
    return kwargs.get('training', None)
  else:
    return args[index]


def remove_training_arg(index, args, kwargs):
  if index is None or index < 0 or len(args) <= index:  # index is invalid
    kwargs.pop('training', None)
  else:
    args.pop(index)
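
# Hedged sketch of the three helpers above (values are illustrative): the
# index comes from get_training_arg_index; a None, negative, or out-of-range
# index routes everything through kwargs instead of positional args.
#
#   args, kwargs = ['input_tensor', None], {}
#   get_training_arg(1, args, kwargs)        # -> None (slot present, unset)
#   set_training_arg(True, 1, args, kwargs)  # args becomes ['input_tensor', True]
#   remove_training_arg(-1, args, kwargs)    # invalid index: pops 'training' from kwargs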


class SaveOptionsContext(threading.local):
  """Thread-local state holding the save options for the current thread."""

  def __init__(self):
    super(SaveOptionsContext, self).__init__()
    self.save_traces = True


_save_options_context = SaveOptionsContext()


@tf_contextlib.contextmanager
def keras_option_scope(save_traces):
  previous_value = _save_options_context.save_traces
  try:
    _save_options_context.save_traces = save_traces
    yield
  finally:
    _save_options_context.save_traces = previous_value


def should_save_traces():
  """Whether to trace layer call functions; can be disabled via `save_traces`."""
  return _save_options_context.save_traces
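
# Hedged usage sketch: callers can suppress tracing of layer call functions
# for the duration of a save, which is how a `save_traces=False` option would
# be threaded through the saving code.
#
#   with keras_option_scope(save_traces=False):
#     assert not should_save_traces()
#   assert should_save_traces()  # previous value is restored on exit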


@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
  """A context that disables automatic dependency tracking when assigning attrs.

  Objects that inherit from AutoTrackable automatically create dependencies
  on trackable objects through attribute assignments, and wrap data structures
  (lists or dicts) with trackable classes. This scope may be used to
  temporarily disable this behavior. It works similarly to the decorator
  `no_automatic_dependency_tracking`.

  Example usage:
  ```
  model = tf.keras.Model()
  model.arr1 = []  # Creates a ListWrapper object
  with no_automatic_dependency_tracking_scope(model):
    model.arr2 = []  # Creates a regular, untracked python list
  ```

  Args:
    obj: A trackable object.

  Yields:
    A scope in which the object doesn't track dependencies.
  """
  previous_value = getattr(obj, '_setattr_tracking', True)
  obj._setattr_tracking = False  # pylint: disable=protected-access
  try:
    yield
  finally:
    obj._setattr_tracking = previous_value  # pylint: disable=protected-access