Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/module/module.py: 35%

113 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Modules encapsulate building stateful components.""" 

16 

17import re 

18 

19from tensorflow.python import tf2 

20from tensorflow.python.framework import composite_tensor 

21from tensorflow.python.framework import ops 

22from tensorflow.python.ops import variables 

23from tensorflow.python.trackable import autotrackable 

24from tensorflow.python.util import nest 

25from tensorflow.python.util import tf_decorator 

26from tensorflow.python.util.tf_export import tf_export 

27 

28 

@tf_export("Module")
class Module(autotrackable.AutoTrackable):
  """Base neural network module class.

  A module is a named container for `tf.Variable`s, other `tf.Module`s and
  functions which apply to user input. For example a dense layer in a neural
  network might be implemented as a `tf.Module`:

  >>> class Dense(tf.Module):
  ...   def __init__(self, input_dim, output_size, name=None):
  ...     super().__init__(name=name)
  ...     self.w = tf.Variable(
  ...       tf.random.normal([input_dim, output_size]), name='w')
  ...     self.b = tf.Variable(tf.zeros([output_size]), name='b')
  ...   def __call__(self, x):
  ...     y = tf.matmul(x, self.w) + self.b
  ...     return tf.nn.relu(y)

  You can use the Dense layer as you would expect:

  >>> d = Dense(input_dim=3, output_size=2)
  >>> d(tf.ones([1, 3]))
  <tf.Tensor: shape=(1, 2), dtype=float32, numpy=..., dtype=float32)>


  By subclassing `tf.Module` instead of `object` any `tf.Variable` or
  `tf.Module` instances assigned to object properties can be collected using
  the `variables`, `trainable_variables` or `submodules` property:

  >>> d.variables
      (<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=...,
      dtype=float32)>,
      <tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=..., dtype=float32)>)


  Subclasses of `tf.Module` can also take advantage of the `_flatten` method
  which can be used to implement tracking of any other types.

  All `tf.Module` classes have an associated `tf.name_scope` which can be used
  to group operations in TensorBoard and create hierarchies for variable names
  which can help with debugging. We suggest using the name scope when creating
  nested submodules/parameters or for forward methods whose graph you might want
  to inspect in TensorBoard. You can enter the name scope explicitly using
  `with self.name_scope:` or you can annotate methods (apart from `__init__`)
  with `@tf.Module.with_name_scope`.

  >>> class MLP(tf.Module):
  ...   def __init__(self, input_size, sizes, name=None):
  ...     super().__init__(name=name)
  ...     self.layers = []
  ...     with self.name_scope:
  ...       for size in sizes:
  ...         self.layers.append(Dense(input_dim=input_size, output_size=size))
  ...         input_size = size
  ...   @tf.Module.with_name_scope
  ...   def __call__(self, x):
  ...     for layer in self.layers:
  ...       x = layer(x)
  ...     return x

  >>> module = MLP(input_size=5, sizes=[5, 5])
  >>> module.variables
  (<tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>,
  <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...,
     dtype=float32)>,
  <tf.Variable 'mlp/b:0' shape=(5,) dtype=float32, numpy=..., dtype=float32)>,
  <tf.Variable 'mlp/w:0' shape=(5, 5) dtype=float32, numpy=...,
     dtype=float32)>)
  """

  # AutoTrackable adds object attributes that users will not expect us to
  # include when flattening (these reference dependencies reachable via other
  # object attributes).
  _TF_MODULE_IGNORED_PROPERTIES = frozenset((
      "_self_unconditional_checkpoint_dependencies",
      "_self_unconditional_dependency_names"
  ))

  def __init__(self, name=None):
    """Creates the module and records its name scope.

    Args:
      name: Optional module name. When omitted, the snake_case form of the
        class name is used (e.g. `MyModule` becomes `my_module`). When given,
        it must be a valid Python-style identifier.

    Raises:
      ValueError: If `name` is given and is not a valid module name.
    """
    if name is None:
      name = camel_to_snake(type(self).__name__)
    else:
      if not valid_identifier(name):
        raise ValueError(
            "%r is not a valid module name. Module names must be valid Python "
            "identifiers (e.g. a valid class name)." % name)

    self._name = name
    if tf2.enabled():
      # Enter the scope once to resolve the full scope name at construction
      # time, then memoize a scope object for that exact name.
      with ops.name_scope_v2(name) as scope_name:
        self._name_scope = ops.name_scope_v2(scope_name)
    else:
      # In TF1 only the resolved scope name is stored; `name_scope` builds a
      # fresh scope object from it on every access.
      with ops.name_scope(name, skip_on_eager=False) as scope_name:
        self._scope_name = scope_name

  @property
  def name(self):
    """Returns the name of this module as passed or determined in the ctor.

    NOTE: This is not the same as the `self.name_scope.name` which includes
    parent module names.
    """
    return self._name

  @property
  def name_scope(self):
    """Returns a `tf.name_scope` instance for this class."""
    if tf2.enabled():
      return self._name_scope
    else:
      # In TF1 name_scope is not re-entrant in eager so we cannot memoize it.
      return ops.name_scope(self._scope_name, skip_on_eager=False)

  @property
  def variables(self):
    """Sequence of variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of variables for the current module (sorted by attribute
      name) followed by variables from all submodules, found recursively
      (each module's own attributes are walked before its submodules are
      descended into).
    """
    return tuple(self._flatten(predicate=_is_variable, expand_composites=True))

  @property
  def trainable_variables(self):
    """Sequence of trainable variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of trainable variables for the current module (sorted by
      attribute name) followed by trainable variables from all submodules,
      found recursively (each module's own attributes are walked before its
      submodules are descended into).
    """
    return tuple(
        self._flatten(predicate=_is_trainable_variable, expand_composites=True))

  @property
  def non_trainable_variables(self):
    """Sequence of non-trainable variables owned by this module and its submodules.

    Note: this method uses reflection to find variables on the current instance
    and submodules. For performance reasons you may wish to cache the result
    of calling this method if you don't expect the return value to change.

    Returns:
      A sequence of non-trainable variables for the current module (sorted by
      attribute name) followed by non-trainable variables from all submodules,
      found recursively (each module's own attributes are walked before its
      submodules are descended into).
    """
    return tuple(self._flatten(
        predicate=_is_non_trainable_variable, expand_composites=True))

  @property
  def submodules(self):
    """Sequence of all sub-modules.

    Submodules are modules which are properties of this module, or found as
    properties of modules which are properties of this module (and so on).

    >>> a = tf.Module()
    >>> b = tf.Module()
    >>> c = tf.Module()
    >>> a.b = b
    >>> b.c = c
    >>> list(a.submodules) == [b, c]
    True
    >>> list(b.submodules) == [c]
    True
    >>> list(c.submodules) == []
    True

    Returns:
      A sequence of all submodules.
    """
    return tuple(self._flatten(predicate=_is_module))

  def _flatten(self,
               recursive=True,
               predicate=None,
               attribute_traversal_key=None,
               with_path=False,
               expand_composites=False):
    """Flattened attribute values in sorted order by attribute name.

    Modules are flattened by first walking their attributes in name order.
    Each attribute value is then flattened to find leaf values. If flatten is
    applied `recursive`ly and if the leaf is a `Module` it will also be
    flattened to find leaves. Finally every leaf value is optionally tested
    against the given `predicate` and then yielded.

    ```
    class Foo(tf.Module):
      def __init__(self):
        super().__init__()
        self.x = [tf.constant('a'), tf.constant('b')]
        self.y = {'i': tf.constant('c'), 'j': tf.constant('d')}
        self.z = tf.constant('e')

      @property
      def tensors(self):
        return tuple(self._flatten(predicate=is_tensor, with_path=True))

    foo = Foo()
    foo.tensors
    # ==> ((('x', 0),   <tf.Tensor: ...'a'>),
    #     (('x', 1),   <tf.Tensor: ...'b'>),
    #     (('y', 'i'), <tf.Tensor: ...'c'>),
    #     (('y', 'j'), <tf.Tensor: ...'d'>),
    #     (('z',),     <tf.Tensor: ...'e'>))
    ```

    `attribute_traversal_key` controls the order in which object properties
    are visited. If not set objects are visited in ascending order by name.

    Args:
      recursive: Whether to recurse into child modules or not.
      predicate: (Optional) If set then only values matching predicate are
        yielded. A value of `None` (the default) means no items will be
        filtered.
      attribute_traversal_key: (Optional) Method to rekey object attributes
        before they are sorted. Contract is the same as `key` argument to
        builtin `sorted` and only applies to object properties.
      with_path: (Optional) Whether to include the path to the object as well
        as the object itself. If `with_path` is `True` then leaves will not be
        de-duplicated (e.g. if the same leaf instance is reachable via multiple
        modules then it will be yielded multiple times with different paths).
      expand_composites: If true, then composite tensors are expanded into their
        component tensors.

    Returns:
      Flat generator for leaves of the current module and optionally all
      submodules.
    """
    if predicate is None:
      predicate = lambda _: True

    return _flatten_module(
        self,
        recursive=recursive,
        predicate=predicate,
        attributes_to_ignore=self._TF_MODULE_IGNORED_PROPERTIES,
        attribute_traversal_key=attribute_traversal_key,
        with_path=with_path,
        expand_composites=expand_composites)

  @classmethod
  def with_name_scope(cls, method):
    """Decorator to automatically enter the module name scope.

    >>> class MyModule(tf.Module):
    ...   @tf.Module.with_name_scope
    ...   def __call__(self, x):
    ...     if not hasattr(self, 'w'):
    ...       self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    ...     return tf.matmul(x, self.w)

    Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose
    names included the module name:

    >>> mod = MyModule()
    >>> mod(tf.ones([1, 2]))
    <tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
    >>> mod.w
    <tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
    numpy=..., dtype=float32)>

    Args:
      method: The method to wrap.

    Returns:
      The original method wrapped such that it enters the module's name scope.
    """
    def method_with_name_scope(self, *args, **kwargs):
      with self.name_scope:
        return method(self, *args, **kwargs)

    return tf_decorator.make_decorator(method, method_with_name_scope)

314 

315 

def _is_variable(obj):
  """Predicate: whether `obj` is an instance of `variables.Variable`."""
  variable_cls = variables.Variable
  return isinstance(obj, variable_cls)

318 

319 

def _is_trainable_variable(obj):
  """Predicate: `obj` is a variable whose `trainable` attribute is truthy.

  Returns False for non-variables; otherwise returns the variable's
  `trainable` attribute (False if the attribute is missing).
  """
  if not _is_variable(obj):
    return False
  return getattr(obj, "trainable", False)

322 

323 

def _is_non_trainable_variable(obj):
  """Predicate: `obj` is a variable whose `trainable` attribute is falsy.

  A variable without a `trainable` attribute counts as non-trainable.
  """
  if not _is_variable(obj):
    return False
  return not getattr(obj, "trainable", False)

326 

327 

def _is_module(obj):
  """Predicate: whether `obj` is a `Module` (including subclass instances)."""
  module_cls = Module
  return isinstance(obj, module_cls)

330 

331_CAMEL_TO_SNAKE_R = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))") 

332_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_]([a-zA-Z0-9_])*$") 

333 

334 

def valid_identifier(name):
  """Returns True if `name` matches the `_VALID_IDENTIFIER` pattern."""
  match = _VALID_IDENTIFIER.match(name)
  return match is not None

337 

338 

def camel_to_snake(value):
  """Converts a CamelCase string to snake_case (e.g. "MyModule" -> "my_module")."""
  underscored = _CAMEL_TO_SNAKE_R.sub(r"_\1", value)
  return underscored.lower()

341 

342 

def _flatten_non_variable_composites_with_tuple_path(structure, path_prefix=()):
  """Flattens `structure` into `(path, leaf)` pairs, expanding composites.

  Any leaf that is a `CompositeTensor` but not a variable is replaced,
  recursively, by its components (obtained from its type spec), and the
  type spec's `value_type.__name__` is appended to those components' paths.
  Variables are yielded as-is even when they are composite tensors.

  Args:
    structure: A nested structure accepted by `nest.flatten_with_tuple_paths`.
    path_prefix: Tuple prepended to every yielded path.

  Yields:
    `(path, leaf)` tuples, where `path` is a tuple.
  """
  for leaf_path, leaf in nest.flatten_with_tuple_paths(structure):
    full_path = path_prefix + leaf_path
    expandable = (isinstance(leaf, composite_tensor.CompositeTensor)
                  and not _is_variable(leaf))
    if not expandable:
      yield full_path, leaf
      continue
    # pylint: disable=protected-access
    type_spec = leaf._type_spec
    components = type_spec._to_components(leaf)
    # pylint: enable=protected-access
    yield from _flatten_non_variable_composites_with_tuple_path(
        components, full_path + (type_spec.value_type.__name__,))

356 

357 

def _flatten_module(module,
                    recursive,
                    predicate,
                    attribute_traversal_key,
                    attributes_to_ignore,
                    with_path,
                    expand_composites,
                    module_path=(),
                    seen=None,
                    recursion_stack=None):
  """Implementation of `flatten`.

  Walks the attributes of `module` (taken from `vars(module)`) in sorted key
  order, flattens each attribute value into leaves, and yields the leaves that
  satisfy `predicate`. When `recursive` is set, every leaf that is a `Module`
  is then flattened in turn, after all of this module's own leaves.

  Args:
    module: Current module to process.
    recursive: Whether to recurse into child modules or not.
    predicate: (Optional) If set then only values matching predicate are
      yielded. A value of `None` means no items will be filtered.
    attribute_traversal_key: (Optional) Method to rekey object attributes
      before they are sorted. Contract is the same as `key` argument to
      builtin `sorted` and only applies to object properties.
    attributes_to_ignore: Names of object attributes to skip.
    with_path: (Optional) Whether to include the path to the object as well
      as the object itself. If `with_path` is `True` then leaves will not be
      de-duplicated (e.g. if the same leaf instance is reachable via multiple
      modules then it will be yielded multiple times with different paths).
    expand_composites: If true, then composite tensors are expanded into their
      component tensors.
    module_path: The path to the current module as a tuple.
    seen: A set containing all leaf IDs seen so far. Only consulted and
      updated when `with_path` is False.
    recursion_stack: A list containing all module IDs associated with the
      current call stack.

  Yields:
    Matched leaves with the optional corresponding paths of the current module
    and optionally all its submodules.

  Raises:
    ValueError: (during iteration) If an attribute value cannot be flattened;
      the original exception is chained as the cause.
  """
  if predicate is None:
    # Honour the documented contract: `None` means "do not filter".
    predicate = lambda _: True

  module_id = id(module)
  if seen is None:
    seen = {module_id}
  if recursion_stack is None:
    recursion_stack = []

  # When calling `_flatten_module` with `with_path=False`, the global lookup
  # table `seen` guarantees the uniqueness of the matched objects.
  # In the case of `with_path=True`, there might be multiple paths associated
  # with the same predicate, so we don't stop traversing according to `seen`
  # to make sure all these paths are returned.
  # When there are cycles connecting submodules, we break cycles by avoiding
  # following back edges (links pointing to a node in `recursion_stack`).
  if module_id in recursion_stack:
    recursive = False

  module_dict = vars(module)
  submodules = []
  for key in sorted(module_dict, key=attribute_traversal_key):
    if key in attributes_to_ignore:
      continue

    prop = module_dict[key]
    try:
      if expand_composites:
        leaves = list(_flatten_non_variable_composites_with_tuple_path(prop))
      else:
        leaves = nest.flatten_with_tuple_paths(prop)
    except Exception as cause:  # pylint: disable=broad-except
      raise ValueError("Error processing property {!r} of {!r}".format(
          key, prop)) from cause

    for leaf_path, leaf in leaves:
      leaf_path = (key,) + leaf_path

      if not with_path:
        leaf_id = id(leaf)
        if leaf_id in seen:
          continue
        seen.add(leaf_id)

      if predicate(leaf):
        if with_path:
          yield module_path + leaf_path, leaf
        else:
          yield leaf

      if recursive and _is_module(leaf):
        # Walk direct properties first then recurse.
        submodules.append((module_path + leaf_path, leaf))

  recursion_stack.append(module_id)

  for submodule_path, submodule in submodules:
    # Predicate is already tested for these values inside the recursive call.
    yield from _flatten_module(
        submodule,
        recursive=recursive,
        predicate=predicate,
        attribute_traversal_key=attribute_traversal_key,
        attributes_to_ignore=submodule._TF_MODULE_IGNORED_PROPERTIES,  # pylint: disable=protected-access
        with_path=with_path,
        expand_composites=expand_composites,
        module_path=submodule_path,
        seen=seen,
        recursion_stack=recursion_stack)

  recursion_stack.pop()