Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/module/module.py: 35%
113 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Modules encapsulate building stateful components."""
17import re
19from tensorflow.python import tf2
20from tensorflow.python.framework import composite_tensor
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import variables
23from tensorflow.python.trackable import autotrackable
24from tensorflow.python.util import nest
25from tensorflow.python.util import tf_decorator
26from tensorflow.python.util.tf_export import tf_export
@tf_export("Module")
class Module(autotrackable.AutoTrackable):
  """Base neural network module class.

  A module is a named container for `tf.Variable`s, other `tf.Module`s and
  functions which apply to user input. A dense layer of a neural network, for
  instance, could be written as a `tf.Module`:

  >>> class Dense(tf.Module):
  ...   def __init__(self, input_dim, output_size, name=None):
  ...     super().__init__(name=name)
  ...     self.w = tf.Variable(
  ...       tf.random.normal([input_dim, output_size]), name='w')
  ...     self.b = tf.Variable(tf.zeros([output_size]), name='b')
  ...   def __call__(self, x):
  ...     y = tf.matmul(x, self.w) + self.b
  ...     return tf.nn.relu(y)

  The Dense layer can then be used as you would expect:

  >>> d = Dense(input_dim=3, output_size=2)
  >>> d(tf.ones([1, 3]))
  <tf.Tensor: shape=(1, 2), dtype=float32, numpy=..., dtype=float32)>

  Because `Dense` subclasses `tf.Module` rather than `object`, every
  `tf.Variable` or `tf.Module` instance assigned to an object property can be
  collected through the `variables`, `trainable_variables` or `submodules`
  property:

  >>> d.variables
  (<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=...,
  dtype=float32)>,
  <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=..., dtype=float32)>)

  Subclasses of `tf.Module` can also use the `_flatten` method to implement
  tracking of any other types.

  Every `tf.Module` has an associated `tf.name_scope` which can be used to
  group operations in TensorBoard and to create hierarchies for variable names,
  which can help with debugging. We suggest using the name scope when creating
  nested submodules/parameters or for forward methods whose graph you might
  want to inspect in TensorBoard. Enter the name scope explicitly with
  `with self.name_scope:`, or annotate methods (apart from `__init__`) with
  `@tf.Module.with_name_scope`.

  >>> class MLP(tf.Module):
  ...   def __init__(self, input_size, sizes, name=None):
  ...     super().__init__(name=name)
  ...     self.layers = []
  ...     with self.name_scope:
  ...       for size in sizes:
  ...         self.layers.append(Dense(input_dim=input_size, output_size=size))
  ...         input_size = size
  ...   @tf.Module.with_name_scope
  ...   def __call__(self, x):
  ...     for layer in self.layers:
  ...       x = layer(x)
  ...     return x

  >>> module = MLP(input_size=5, sizes=[5, 5])
  >>> module.variables
  (<tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>,
  <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...,
  dtype=float32)>,
  <tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>,
  <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...,
  dtype=float32)>)
  """

  # Attributes added by AutoTrackable. They reference dependencies that are
  # reachable through other object attributes, so users would not expect them
  # to show up when a module is flattened.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset({
      "_self_unconditional_checkpoint_dependencies",
      "_self_unconditional_dependency_names",
  })

  def __init__(self, name=None):
    """Stores the module name and sets up its name scope.

    Args:
      name: (Optional) Name of the module. When omitted, the class name
        converted by `camel_to_snake` is used. An explicit name is checked
        with `valid_identifier`.

    Raises:
      ValueError: If `name` is given and `valid_identifier` rejects it.
    """
    if name is None:
      name = camel_to_snake(type(self).__name__)
    elif not valid_identifier(name):
      raise ValueError(
          "%r is not a valid module name. Module names must be valid Python "
          "identifiers (e.g. a valid class name)." % name)

    self._name = name

    if not tf2.enabled():
      # TF1 path: remember only the scope string; the `name_scope` property
      # builds a new scope object from it on every access.
      with ops.name_scope(name, skip_on_eager=False) as scope_name:
        self._scope_name = scope_name
      return

    # TF2 path: keep one scope object built from the name yielded on entry.
    with ops.name_scope_v2(name) as scope_name:
      self._name_scope = ops.name_scope_v2(scope_name)

  @property
  def name(self):
    """Returns the name of this module as passed or determined in the ctor.

    NOTE: This is not the same as the `self.name_scope.name` which includes
    parent module names.
    """
    return self._name

  @property
  def name_scope(self):
    """Returns a `tf.name_scope` instance for this class."""
    if not tf2.enabled():
      # In TF1 name_scope is not re-entrant in eager so we cannot memoize it.
      return ops.name_scope(self._scope_name, skip_on_eager=False)
    return self._name_scope

  @property
  def variables(self):
    """Sequence of variables owned by this module and its submodules.

    Note: this property uses reflection to find variables on the current
    instance and its submodules. For performance reasons you may wish to cache
    the result if you don't expect it to change.

    Returns:
      A tuple holding the variables of the current module (in attribute-name
      order), followed by the variables of each submodule, visited
      recursively.
    """
    matches = self._flatten(predicate=_is_variable, expand_composites=True)
    return tuple(matches)

  @property
  def trainable_variables(self):
    """Sequence of trainable variables owned by this module and its submodules.

    Note: this property uses reflection to find variables on the current
    instance and its submodules. For performance reasons you may wish to cache
    the result if you don't expect it to change.

    Returns:
      A tuple holding the trainable variables of the current module (in
      attribute-name order), followed by those of each submodule, visited
      recursively.
    """
    matches = self._flatten(
        predicate=_is_trainable_variable, expand_composites=True)
    return tuple(matches)

  @property
  def non_trainable_variables(self):
    """Sequence of non-trainable variables owned by this module and submodules.

    Note: this property uses reflection to find variables on the current
    instance and its submodules. For performance reasons you may wish to cache
    the result if you don't expect it to change.

    Returns:
      A tuple holding the non-trainable variables of the current module (in
      attribute-name order), followed by those of each submodule, visited
      recursively.
    """
    matches = self._flatten(
        predicate=_is_non_trainable_variable, expand_composites=True)
    return tuple(matches)

  @property
  def submodules(self):
    """Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as
    properties of modules which are properties of this module (and so on).

    >>> a = tf.Module()
    >>> b = tf.Module()
    >>> c = tf.Module()
    >>> a.b = b
    >>> b.c = c
    >>> list(a.submodules) == [b, c]
    True
    >>> list(b.submodules) == [c]
    True
    >>> list(c.submodules) == []
    True

    Returns:
      A sequence of all submodules.
    """
    matches = self._flatten(predicate=_is_module)
    return tuple(matches)

  def _flatten(self,
               recursive=True,
               predicate=None,
               attribute_traversal_key=None,
               with_path=False,
               expand_composites=False):
    """Flattened attribute values in sorted order by attribute name.

    Attributes of the module are walked in name order and each attribute value
    is flattened into its leaf values. When `recursive` is set, a leaf that is
    itself a `Module` is flattened in turn. Each leaf is tested against
    `predicate` (when given) before it is yielded.

    ```
    class Foo(tf.Module):
      def __init__(self):
        super().__init__()
        self.x = [tf.constant('a'), tf.constant('b')]
        self.y = {'i': tf.constant('c'), 'j': tf.constant('d')}
        self.z = tf.constant('e')

      @property
      def tensors(self):
        return tuple(self._flatten(predicate=is_tensor, with_path=True))

    foo = Foo()
    foo.tensors
    # ==> ((('x', 0), <tf.Tensor: ...'a'>),
    #     (('x', 1), <tf.Tensor: ...'b'>),
    #     (('y', 'i'), <tf.Tensor: ...'c'>),
    #     (('y', 'j'), <tf.Tensor: ...'d'>),
    #     (('z',), <tf.Tensor: ...'e'>))
    ```

    `attribute_traversal_key` controls the order object properties are visited.
    If not set objects are visited in ascending order by name.

    Args:
      recursive: Whether to recurse into child modules or not.
      predicate: (Optional) If set then only values matching predicate are
        yielded. A value of `None` (the default) means no items will be
        filtered.
      attribute_traversal_key: (Optional) Method to rekey object attributes
        before they are sorted. Contract is the same as `key` argument to
        builtin `sorted` and only applies to object properties.
      with_path: (Optional) Whether to include the path to the object as well
        as the object itself. If `with_path` is `True` then leaves will not be
        de-duplicated (e.g. if the same leaf instance is reachable via multiple
        modules then it will be yielded multiple times with different paths).
      expand_composites: If true, then composite tensors are expanded into their
        component tensors.

    Returns:
      Flat generator for leaves of the current module and optionally all
      submodules.
    """
    # Without a predicate every leaf is accepted.
    accept = predicate if predicate is not None else (lambda _: True)

    return _flatten_module(
        self,
        recursive=recursive,
        predicate=accept,
        attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES,
        attribute_traversal_key=attribute_traversal_key,
        with_path=with_path,
        expand_composites=expand_composites)

  @classmethod
  def with_name_scope(cls, method):
    """Decorator to automatically enter the module name scope.

    >>> class MyModule(tf.Module):
    ...   @tf.Module.with_name_scope
    ...   def __call__(self, x):
    ...     if not hasattr(self, 'w'):
    ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    ...     return tf.matmul(x, self.w)

    Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
    names included the module name:

    >>> mod = MyModule()
    >>> mod(tf.ones([1, 2]))
    <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
    >>> mod.w
    <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
    numpy=..., dtype=float32)>

    Args:
      method: The method to wrap.

    Returns:
      The original method wrapped such that it enters the module's name scope.
    """

    def method_with_name_scope(self, *args, **kwargs):
      # The scope is looked up on each call via the `name_scope` property.
      scope = self.name_scope
      with scope:
        return method(self, *args, **kwargs)

    return tf_decorator.make_decorator(method, method_with_name_scope)
def _is_variable(obj):
  """Returns True when `obj` is an instance of `variables.Variable`."""
  return isinstance(obj, variables.Variable)
def _is_trainable_variable(obj):
  """Predicate selecting variables whose `trainable` attribute is truthy.

  Returns `False` for non-variables; for a variable the value of its
  `trainable` attribute is returned as-is (`False` when the attribute is
  missing).
  """
  if not _is_variable(obj):
    return False
  return getattr(obj, "trainable", False)
def _is_non_trainable_variable(obj):
  """Predicate selecting variables whose `trainable` attribute is falsy.

  A variable without a `trainable` attribute counts as non-trainable;
  non-variables never match.
  """
  if not _is_variable(obj):
    return False
  return not getattr(obj, "trainable", False)
def _is_module(obj):
  """Returns True when `obj` is an instance of `Module`."""
  return isinstance(obj, Module)
# Matches every upper-case letter that starts a new word in a CamelCase name:
# one preceded by a lower-case letter or digit, or a non-leading one followed
# by a lower-case letter (so a run of capitals such as "HTTP" stays together).
_CAMEL_TO_SNAKE_R = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
# An ASCII letter or underscore followed by ASCII letters, digits, underscores.
_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_]([a-zA-Z0-9_])*$")


def valid_identifier(name):
  """Returns whether `name` is acceptable as a module name.

  Args:
    name: String to check.

  Returns:
    `True` if the whole of `name` is an ASCII letter or underscore followed by
    any number of ASCII letters, digits or underscores, `False` otherwise.
  """
  # `fullmatch` rather than `match`: the `$` in the pattern also matches just
  # before a trailing newline, so `match` would let "name\n" through.
  return _VALID_IDENTIFIER.fullmatch(name) is not None


def camel_to_snake(value):
  """Converts a CamelCase string to snake_case.

  An underscore is inserted before each word-starting capital found by
  `_CAMEL_TO_SNAKE_R` and the result is lower-cased, e.g. "MyModule" becomes
  "my_module" and "HTTPServer" becomes "http_server".

  Args:
    value: String to convert.

  Returns:
    The snake_case form of `value`.
  """
  return _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower()
def _flatten_non_variable_composites_with_tuple_path(structure, path_prefix=()):
  """Flattens `structure` with tuple paths, expanding composites except variables.

  Each leaf returned by `nest.flatten_with_tuple_paths` is yielded together
  with its path. A leaf that is a `CompositeTensor` but not a variable is
  instead replaced by the recursive flattening of the components its type spec
  produces, with the spec's value type name appended to the path.

  Args:
    structure: The (possibly nested) value to flatten.
    path_prefix: Tuple prepended to every yielded path.

  Yields:
    `(path, leaf)` pairs, where `path` is a tuple.
  """
  for path, child in nest.flatten_with_tuple_paths(structure):
    child_path = path_prefix + path
    expandable = (
        isinstance(child, composite_tensor.CompositeTensor) and
        not _is_variable(child))
    if not expandable:
      yield child_path, child
      continue

    # pylint: disable=protected-access
    spec = child._type_spec
    yield from _flatten_non_variable_composites_with_tuple_path(
        spec._to_components(child),
        child_path + (spec.value_type.__name__,))
    # pylint: enable=protected-access
def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    attributes_to_ignore,
                    with_path,
                    expand_composites,
                    module_path=(),
                    seen=None,
                    recursion_stack=None):
  """Implementation of `flatten`.

  The attributes of `module` are visited in sorted order and every attribute
  value is flattened into leaves. Leaves accepted by `predicate` are yielded
  first; module leaves collected along the way are then flattened one after
  another by recursive calls.

  Args:
    module: Current module to process.
    recursive: Whether to recurse into child modules or not.
    predicate: Callable applied to every leaf; only leaves for which it returns
      a truthy value are yielded.
    attribute_traversal_key: (Optional) Method to rekey object attributes
      before they are sorted. Contract is the same as `key` argument to
      builtin `sorted` and only applies to object properties.
    attributes_to_ignore: Container of attribute names to skip.
    with_path: Whether to include the path to the object as well as the object
      itself. If `with_path` is `True` then leaves will not be de-duplicated
      (e.g. if the same leaf instance is reachable via multiple modules then
      it will be yielded multiple times with different paths).
    expand_composites: If true, then composite tensors are expanded into their
      component tensors.
    module_path: The path to the current module as a tuple.
    seen: A set containing all leaf IDs seen so far.
    recursion_stack: A list containing all module IDs associated with the
      current call stack.

  Yields:
    Matched leaves with the optional corresponding paths of the current module
    and optionally all its submodules.

  Raises:
    ValueError: If flattening the value of an attribute fails; the original
      exception is chained as the cause.
  """
  current_id = id(module)
  if seen is None:
    seen = {current_id}
  if recursion_stack is None:
    recursion_stack = []

  # When calling `_flatten_module` with `with_path=False`, the shared lookup
  # table `seen` guarantees the uniqueness of the matched objects.
  # With `with_path=True` the same object may be reachable through several
  # paths, so `seen` is not consulted and every path is returned.
  # Cycles between submodules are broken by not following back edges (links
  # pointing to a module already in `recursion_stack`).
  if current_id in recursion_stack:
    recursive = False

  attributes = vars(module)
  pending_submodules = []

  for key in sorted(attributes, key=attribute_traversal_key):
    if key in attributes_to_ignore:
      continue

    value = attributes[key]
    try:
      leaves = (
          list(_flatten_non_variable_composites_with_tuple_path(value))
          if expand_composites else nest.flatten_with_tuple_paths(value))
    except Exception as cause:  # pylint: disable=broad-except
      raise ValueError("Error processing property {!r} of {!r}".format(
          key, value)) from cause

    for relative_path, leaf in leaves:
      leaf_path = module_path + (key,) + relative_path

      if not with_path:
        leaf_id = id(leaf)
        if leaf_id in seen:
          continue
        seen.add(leaf_id)

      if predicate(leaf):
        yield (leaf_path, leaf) if with_path else leaf

      if recursive and _is_module(leaf):
        # Walk direct properties first then recurse.
        pending_submodules.append((leaf_path, leaf))

  recursion_stack.append(current_id)

  for submodule_path, submodule in pending_submodules:
    # Predicate is already tested for the values produced by the nested call.
    yield from _flatten_module(
        submodule,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES,  # pylint: disable=protected-access
        with_path=with_path,
        expand_composites=expand_composites,
        module_path=submodule_path,
        seen=seen,
        recursion_stack=recursion_stack)

  recursion_stack.pop()