# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Library for map layout and corresponding tf.Variable."""

import collections
import contextlib
import re
import threading

import tensorflow.compat.v2 as tf

from keras.src.dtensor import dtensor_api as dtensor
from keras.src.dtensor import lazy_variable
from keras.src.dtensor import utils
from keras.src.engine import base_layer

# isort: off
from tensorflow.python.util.tf_export import keras_export


# We will skip the path for certain attributes when mapping the layout, e.g.
# model._self_tracked_trackables, or layer._trainable_weights/
# _non_trainable_weights, etc. Those attributes usually serve as a cache,
# and the actual variable is stored somewhere else.
_KERAS_ATTRIBUTES_TO_SKIP = [
    "_self_tracked_trackables",
    "_trainable_weights",
    "_non_trainable_weights",
    "_captured_weight_regularizer",
]


_LAYOUT_MAP = threading.local()


def get_current_layout_map():
    return getattr(_LAYOUT_MAP, "layout_map", None)
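
# A minimal sketch of how the thread-local map is read back (hypothetical;
# `layout_map` is assumed to be a LayoutMap instance):
#
#   with layout_map.scope():
#       assert get_current_layout_map() is layout_map
#   assert get_current_layout_map() is None  # restored on exit, assuming
#                                            # there was no enclosing scope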


@keras_export("keras.dtensor.experimental.LayoutMap", v1=[])
class LayoutMap(collections.abc.MutableMapping):
    """A dict-like object that maps string to `Layout` instances.

    `LayoutMap` uses a string as key and a `Layout` as value. There is a
    behavior difference between a normal Python dict and this class. The
    string key will be treated as a regex when retrieving the value. See the
    docstring of `get` for more details.

    See below for a usage example. You can define the naming schema
    of the `Layout`, and then retrieve the corresponding `Layout` instance.

    To use the `LayoutMap` with a `Model`, please see the docstring of
    `tf.keras.dtensor.experimental.layout_map_scope`.

    ```python
    map = LayoutMap(mesh=None)
    map['.*dense.*kernel'] = layout_2d
    map['.*dense.*bias'] = layout_1d
    map['.*conv2d.*kernel'] = layout_4d
    map['.*conv2d.*bias'] = layout_1d

    layout_1 = map['dense_1.kernel']  # layout_1 == layout_2d
    layout_2 = map['dense_1.bias']  # layout_2 == layout_1d
    layout_3 = map['dense_2.kernel']  # layout_3 == layout_2d
    layout_4 = map['dense_2.bias']  # layout_4 == layout_1d
    layout_5 = map['my_model/conv2d_123/kernel']  # layout_5 == layout_4d
    layout_6 = map['my_model/conv2d_123/bias']  # layout_6 == layout_1d
    ```

    Args:
        mesh: An optional `Mesh` that can be used to create an all-replicated
            layout as the default when there isn't a layout found based on
            the input string query.
    """

    def __init__(self, mesh=None):
        self._layout_map = collections.OrderedDict()
        self._default_mesh = mesh

    def __getitem__(self, key):
        """Retrieve the corresponding layout by the string key.

        When there isn't an exact match, all the existing keys in the layout
        map will be treated as a regex and matched against the input key
        again. The first match will be returned, based on the key insertion
        order. Returns None if there isn't any match found.

        Args:
            key: the string key as the query for the layout.

        Returns:
            Corresponding layout based on the query.
        """
        if key in self._layout_map:
            return self._layout_map[key]

        for k in self._layout_map:
            if re.match(k, key):
                return self._layout_map[k]
        return None
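
    # A minimal sketch of the regex fallback above (hypothetical; `layout_a`
    # and `layout_b` are assumed to be `dtensor.Layout` instances):
    #
    #   lm = LayoutMap()
    #   lm['.*kernel'] = layout_a   # inserted first, so matched first
    #   lm['dense.*'] = layout_b
    #   lm['dense_1.kernel']        # -> layout_a, by insertion order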

    def __setitem__(self, key, layout):
        if key in self._layout_map:
            raise ValueError(
                f"{key} already exists in the LayoutMap with "
                f"value {self._layout_map[key]}. Please make sure to "
                "not use duplicated keys."
            )
        if not isinstance(layout, dtensor.Layout):
            raise ValueError(
                f"{layout} should be a dtensor.Layout type, got {type(layout)}"
            )

        self._layout_map[key] = layout

    def __delitem__(self, key):
        # Let the dict handle the missing-key error.
        return self._layout_map.pop(key)

    def __len__(self):
        return len(self._layout_map)

    def __iter__(self):
        return iter(self._layout_map)
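
    # Note: because this class subclasses collections.abc.MutableMapping,
    # methods such as `get`, `items`, and `update` are derived automatically
    # from the dunder methods implemented above.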

    def get_default_mesh(self):
        """Return the default `Mesh` set at instance creation.

        The `Mesh` can be used to create a default replicated `Layout` when
        there isn't a match for the input string query.
        """
        return self._default_mesh

    def scope(self):
        """Apply layout to all `tf.Variable` instances created under the scope.

        All `tf.Variable` instances created under this scope
        will be lazily initialized first. Once they are attached as the model
        or layer attributes, and there is a stable layout mapping for them,
        the variables will be reinitialized into a
        `tf.experimental.dtensor.DVariable` with the corresponding layout.

        Note that the layout mapping will use object/attribute names as the
        keys to map the variable to the layout.

        For subclassed models, the full object/attribute name is used as the
        key. For Functional/Sequential models, we use `layer.name` as
        the key for the layer, followed by the attribute name. Keras ensures
        name uniqueness among the layers within a Functional/Sequential model.

        See the following examples that show variable object names
        for different Keras model types:

        ```python
        layout_map = layout_map_lib.LayoutMap(mesh=self.mesh)
        layout_map['d1.kernel'] = layout_1
        layout_map['d1.bias'] = layout_2
        layout_map['d2.kernel'] = layout_3
        layout_map['d2.bias'] = layout_4

        ## Subclassed model
        class SubclassModel(tf.keras.Model):

            def __init__(self, name=None):
                super().__init__(name=name)
                self.d1 = tf.keras.layers.Dense(1000)
                self.d2 = tf.keras.layers.Dense(1000)

            def call(self, inputs):
                x = self.d1(inputs)
                return self.d2(x)

        with layout_map.scope():
            model = SubclassModel()
        inputs = tf.zeros((10, 10))
        results = model(inputs)

        model.d1.kernel.layout == layout_1
        model.d1.bias.layout == layout_2
        model.d2.kernel.layout == layout_3
        model.d2.bias.layout == layout_4

        ## Functional model
        with layout_map.scope():
            inputs = tf.keras.Input((10,), batch_size=10)
            x = tf.keras.layers.Dense(20, name='d1')(inputs)
            output = tf.keras.layers.Dense(30, name='d2')(x)

            model = tf.keras.Model(inputs, output)

        d1 = model.layers[1]
        d2 = model.layers[2]

        d1.kernel.layout == layout_1
        d1.bias.layout == layout_2
        d2.kernel.layout == layout_3
        d2.bias.layout == layout_4

        ## Sequential model
        with layout_map.scope():
            model = tf.keras.Sequential([
                tf.keras.layers.Dense(20, name='d1', input_shape=(10,)),
                tf.keras.layers.Dense(30, name='d2')
            ])

        d1 = model.layers[0]
        d2 = model.layers[1]

        d1.kernel.layout == layout_1
        d1.bias.layout == layout_2
        d2.kernel.layout == layout_3
        d2.bias.layout == layout_4
        ```

        Returns:
            A context that will lazily initialize all `tf.Variable` objects
            within the model, with the layouts assigned to them.
        """
        return layout_map_scope(self)


LayoutMap.get.__doc__ = LayoutMap.__getitem__.__doc__


@contextlib.contextmanager
def layout_map_scope(layout_map):
    """Apply the layout to all the tf.Variables created under the scope.

    Create a scope in which all the `tf.Variable`s created will be lazily
    initialized, and then initialized later with the proper layout once the
    object path in the model is stable/finalized.

    Note that the layout mapping will use the object/attribute names as the
    key to map the variable against the layout.

    For subclassed models, the full object/attribute name is used as the key.
    For Functional/Sequential models, since the layers within the model do
    not get assigned to a meaningful attribute, we use `layer.name` as the
    key for the layer, followed by the attribute name. Keras ensures name
    uniqueness among the layers in all Functional/Sequential models.

    See the following examples that show the variable object names
    for different Keras model types:

    ```python
    layout_map = layout_map_lib.LayoutMap(mesh=self.mesh)
    layout_map['d1.kernel'] = layout_1
    layout_map['d1.bias'] = layout_2
    layout_map['d2.kernel'] = layout_3
    layout_map['d2.bias'] = layout_4

    ## Subclassed model
    class SubclassModel(tf.keras.Model):

        def __init__(self, name=None):
            super().__init__(name=name)
            self.d1 = tf.keras.layers.Dense(1000)
            self.d2 = tf.keras.layers.Dense(1000)

        def call(self, inputs):
            x = self.d1(inputs)
            return self.d2(x)

    with layout_map_scope(layout_map):
        model = SubclassModel()
    # Triggering the creation of weights within or outside of the scope
    # works.
    inputs = tf.zeros((10, 10))
    results = model(inputs)

    model.d1.kernel.layout == layout_1
    model.d1.bias.layout == layout_2
    model.d2.kernel.layout == layout_3
    model.d2.bias.layout == layout_4

    ## Functional model
    with layout_map_scope(layout_map):
        inputs = tf.keras.Input((10,), batch_size=10)
        x = tf.keras.layers.Dense(20, name='d1')(inputs)
        output = tf.keras.layers.Dense(30, name='d2')(x)

        model = tf.keras.Model(inputs, output)

    d1 = model.layers[1]
    d2 = model.layers[2]

    d1.kernel.layout == layout_1
    d1.bias.layout == layout_2
    d2.kernel.layout == layout_3
    d2.bias.layout == layout_4

    ## Sequential model
    with layout_map_scope(layout_map):
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(20, name='d1', input_shape=(10,)),
            tf.keras.layers.Dense(30, name='d2')
        ])

    d1 = model.layers[0]
    d2 = model.layers[1]

    d1.kernel.layout == layout_1
    d1.bias.layout == layout_2
    d2.kernel.layout == layout_3
    d2.bias.layout == layout_4
    ```

    Args:
        layout_map: a LayoutMap which contains the variable_object_path
            (string) -> Layout mapping. When a layout is not found for a
            variable, a default all-replicated layout will be created for it.

    Yields:
        A context that will lazily initialize all `tf.Variable` objects
        within the model, with the layouts assigned to them.
    """
    previous_layout_map = get_current_layout_map()
    global _LAYOUT_MAP
    _LAYOUT_MAP.layout_map = layout_map

    with lazy_variable.lazy_init_scope():
        try:
            yield
        finally:
            _LAYOUT_MAP.layout_map = previous_layout_map
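
# Because the previous map is captured and restored in the `finally` block,
# scopes nest safely. A minimal sketch (hypothetical; `map_a` and `map_b`
# are assumed LayoutMap instances):
#
#   with layout_map_scope(map_a):
#       with layout_map_scope(map_b):
#           assert get_current_layout_map() is map_b
#       assert get_current_layout_map() is map_a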


def _map_subclass_model_variable(model, layout_map):
    """Map/Replace LazyInitVariable for subclass model."""
    lazy_init_variable_to_tf_variable_map = {}

    # Note that model._flatten is a method from tf.Module, and it returns
    # duplicated items (since the same item can be reached via several
    # different paths).
    for path, variable in model._flatten(
        predicate=_is_lazy_init_variable,
        with_path=True,
    ):
        # Note that path is a tuple that contains strings and ints, e.g.
        # ('d1', '_trainable_weights', 0) maps to
        # model.d1._trainable_weights[0].
        if [a for a in _KERAS_ATTRIBUTES_TO_SKIP if a in path]:
            continue
        # Convert all the ints to strings and join with ".".
        object_path = ".".join([str(item) for item in path])
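        # e.g. a non-skipped path like ('d1', 'kernel') becomes the string
        # key "d1.kernel", which is what the LayoutMap regex keys are
        # matched against.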

        new_variable = _create_dvariable(layout_map, object_path, variable)
        _set_object_by_path(model, path, new_variable)
        lazy_init_variable_to_tf_variable_map[id(variable)] = new_variable

    for layer in model._flatten(
        predicate=lambda o: isinstance(o, base_layer.Layer)
    ):
        _config_dvariable_regularization(
            layer, lazy_init_variable_to_tf_variable_map
        )

    # After we have replaced all the variables, we want to make sure all the
    # cached attributes refer to the new variable, rather than the old
    # LazyInitVariable.
    for path, variable in model._flatten(
        predicate=_is_lazy_init_variable,
        with_path=True,
    ):
        tf_variable = lazy_init_variable_to_tf_variable_map[id(variable)]
        _set_object_by_path(model, path, tf_variable)

    _init_state_variable_for_rng(model, layout_map)
    _update_trackable_reference(model, lazy_init_variable_to_tf_variable_map)
    return model


def _map_functional_model_variable(model, layout_map):
    """Map/Replace LazyInitVariable for functional/sequential model."""
    lazy_init_variable_to_tf_variable_map = {}

    for layer in model.layers:
        # Note that the layer name is unique within a functional/sequential
        # model; when a layer name is not provided, Keras will auto-generate
        # one based on the class name.
        layer_name = layer.name
        for path, variable in layer._flatten(
            predicate=_is_lazy_init_variable,
            with_path=True,
        ):
            # Note that path is a tuple that contains strings and ints, e.g.
            # ('d1', '_trainable_weights', 0) maps to
            # model.d1._trainable_weights[0].
            if [a for a in _KERAS_ATTRIBUTES_TO_SKIP if a in path]:
                continue
            # Convert all the ints to strings and join with ".".
            object_path = ".".join([str(item) for item in path])
            # Also prepend the layer name.
            object_path = layer_name + "." + object_path

            new_variable = _create_dvariable(layout_map, object_path, variable)
            _set_object_by_path(layer, path, new_variable)
            lazy_init_variable_to_tf_variable_map[id(variable)] = new_variable

        _config_dvariable_regularization(
            layer, lazy_init_variable_to_tf_variable_map
        )

        # After we have replaced all the variables, we want to make sure all
        # the cached attributes refer to the new variable, rather than the
        # old LazyInitVariable.
        for path, variable in layer._flatten(
            predicate=_is_lazy_init_variable,
            with_path=True,
        ):
            tf_variable = lazy_init_variable_to_tf_variable_map[id(variable)]
            _set_object_by_path(layer, path, tf_variable)

    _init_state_variable_for_rng(model, layout_map)
    _update_trackable_reference(model, lazy_init_variable_to_tf_variable_map)
    return model


def _init_state_variable_for_rng(model, layout_map):
    """Init the state variable in tf.random.Generator.

    Since the BaseRandomLayer in keras explicitly untracks the
    tf.random.Generator, the variable in it will stay as a LazyInitVariable,
    which causes a runtime error if we don't replace it with a proper
    DVariable. Since users are usually not aware of the existence of those
    variables, we will just give them a replicated layout since they are
    tiny.

    Args:
        model: the model whose layers will be checked to find the
            BaseRandomLayers.
        layout_map: used to get the default mesh information to create the
            DVariable.
    """
    for l in model._flatten(
        predicate=lambda o: isinstance(o, base_layer.BaseRandomLayer)
    ):
        keras_generator = l._random_generator
        if keras_generator._built and keras_generator._generator is None:
            raise ValueError(
                "Keras is expected to use tf.random.Generator when using "
                "the DTensor API. Please call "
                "`tf.keras.backend.experimental.enable_tf_random_generator` "
                "at the beginning of your program."
            )
        if hasattr(keras_generator, "_generator") and _is_lazy_init_variable(
            keras_generator._generator._state_var
        ):
            # Replace it with a DVariable.
            keras_generator._generator._state_var = _create_dvariable(
                layout_map, "", keras_generator._generator._state_var
            )
        else:
            # The keras_generator is not built yet; call the init function
            # with the DTensor device to init all the variables with the
            # default replicated layout.
            with dtensor.default_mesh(layout_map.get_default_mesh()):
                keras_generator._maybe_init()
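
# Usage note: the ValueError above expects programs to opt in via
# `tf.keras.backend.experimental.enable_tf_random_generator()` before layers
# are built, so that BaseRandomLayer state is backed by tf.random.Generator
# and its state variable can be swapped for a DVariable here.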


def _config_dvariable_regularization(
    layer, lazy_init_variable_to_tf_variable_map
):
    """Update the weight regularizers for newly created `DVariable`s.

    Weight regularization usually happens when `layer.add_weight()` is
    called, at which point the library will first create a
    `LazyInitVariable`, and then replace it with a `DVariable`. We defer the
    creation of those losses until the DVariable is created.

    See `layer._captured_weight_regularizer` for more details.

    Args:
        layer: the layer instance for DVariable regularization config.
        lazy_init_variable_to_tf_variable_map: the dict between a
            LazyInitVariable ID and the newly created DVariable.
    """
    for name, variable, regularizer in layer._captured_weight_regularizer:
        if not _is_lazy_init_variable(variable):
            raise ValueError(
                "Expected the regularization loss to be created from a "
                f"LazyInitVariable, got {variable}"
            )
        d_variable = lazy_init_variable_to_tf_variable_map[id(variable)]
        layer._handle_weight_regularization(name, d_variable, regularizer)
    # After that, clean up `layer._captured_weight_regularizer`.
    layer._captured_weight_regularizer = []


def _create_dvariable(layout_map, object_path, variable):
    """Create a new variable instead of using the LazyInitVariable.

    We choose to do this since, even though the LazyInitVariable might behave
    like a normal tf.Variable/DVariable, it is not future-proof against any
    new changes to the variable class. It would also fail instance type
    checks in Python, which could affect users' code when they do any
    filtering based on type to find variables.

    Args:
        layout_map: a LayoutMap which contains the variable_object_path
            (string) -> Layout mapping.
        object_path: string, the object attribute path for the variable.
        variable: LazyInitVariable which will be replaced by the newly
            created tf.Variable.

    Returns:
        A new tf.Variable with the correct layout information.
    """
    # TODO(b/228209108): Revisit this in the future and see if we can just
    # reuse the LazyInitVariable rather than creating a new tf.Variable
    # instance.
    layout = layout_map[object_path]
    if layout is None:
        variable_rank = variable.shape.rank
        layout = dtensor.Layout.replicated(
            mesh=layout_map.get_default_mesh(), rank=variable_rank
        )
    init_val = variable._initial_value
    if callable(init_val):
        with lazy_variable.disable_init_variable_creator():
            init_val = utils.call_with_layout(init_val, layout)
    else:
        # The init value is probably already created as a tensor; we will
        # just copy it to the mesh and give it a proper layout.
        init_val = dtensor.copy_to_mesh(init_val, layout)
    # Use the original variable name for the new DVariable creation. TF
    # appends a ":0" suffix to it.
    variable_name = variable.name
    if variable_name.endswith(":0"):
        variable_name = variable_name[:-2]
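        # e.g. a TF-assigned name like "dense/kernel:0" becomes
        # "dense/kernel" before being passed to the DVariable constructor.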
    new_variable = dtensor.DVariable(
        init_val, trainable=variable.trainable, name=variable_name
    )
    return new_variable
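
# Note: when `layout_map[object_path]` finds no match, the fallback above
# gives the variable a rank-matched, fully replicated layout on the default
# mesh, so unmapped weights still become valid DVariables.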


def _set_object_by_path(object_to_set, path, value):
    """Set the attribute at the given path on the object.

    Args:
        object_to_set: the instance whose attribute should be set.
        path: the tuple/list of strings and ints, representing the attribute
            names. An int means that the attribute to set is an item in a
            list.
        value: the value of the attribute.
    """
    for i, attr_name in enumerate(path):
        if i == len(path) - 1:
            # We found the actual attribute to set.
            if isinstance(attr_name, int):
                # This means we are trying to set an element in the array;
                # make sure the instance is a list-like object.
                object_to_set[attr_name] = value
            else:
                setattr(object_to_set, attr_name, value)
        else:
            if isinstance(attr_name, int):
                object_to_set = object_to_set[attr_name]
            else:
                object_to_set = getattr(object_to_set, attr_name)
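
# For example, with path ("d1", "_trainable_weights", 0), the loop above
# walks object_to_set.d1._trainable_weights and assigns `value` to index 0.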


# TODO(b/228209108): Revisit this after we can reinit LazyInitVariable.
def _update_trackable_reference(model, lazy_init_variable_to_tf_variable_map):
    """Update the trackable object references for the model.

    Note that this method is only needed because of a corner case for model
    checkpointing, where a LazyInitVariable could accidentally be caught as
    a checkpoint dependency while not being visible in the model attribute
    graph itself.

    Args:
        model: the keras model instance whose checkpoint dependencies will
            be examined.
        lazy_init_variable_to_tf_variable_map: the dict between a
            LazyInitVariable ID and the newly created DVariable.
    """
    # See b/234621758 for more details.
    object_graph = tf.__internal__.tracking.ObjectGraphView(model)
    trackables, _ = object_graph.breadth_first_traversal()
    for trackable in trackables:
        for ref_name, ref in trackable._trackable_children().items():
            if _is_lazy_init_variable(ref):
                # Replace the LazyInitVariable with the DVariable.
                trackable._track_trackable(
                    lazy_init_variable_to_tf_variable_map[id(ref)],
                    ref_name,
                    overwrite=True,
                )


def _is_lazy_init_variable(obj):
    return isinstance(obj, lazy_variable.LazyInitVariable)