Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/dtensor/utils.py: 36%
36 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras Utilities for DTensor related API."""
17import inspect
19import tensorflow.compat.v2 as tf
21from keras.src.dtensor import dtensor_api as dtensor
# All the variable names in the default keras layers. We will use those to map
# against the args in the __init__ method to find corresponding layout args.
# See allow_initializer_layout() for more details.
KERAS_VARIABLE_NAMES = [
    "alpha",
    "beta",
    "bias",
    "depthwise",
    "embeddings",
    "gamma",
    "kernel",
    "moving_mean",
    "moving_variance",
    "pointwise",
    "recurrent",
]
def allow_initializer_layout(init_method):
    """A decorator for injecting layout information to layer.__init__.

    Layout will be a new param for any of the weights for all the keras
    layers. Adding the param to every __init__ method would be a big and
    duplicated piece of work.

    This decorator is designed to reduce code duplication and make it easy to
    add/remove the dtensor feature if needed.

    Sample usage:
    ```python
    class Dense(tf.keras.layer.Layer):

      @allow_initializer_layout
      def __init__(self, units,
                   kernel_initializer='zeros',
                   bias_initializer='zeros',
                   **kwargs):
        super().__init__(**kwargs)

    d = Dense(units=8, kernel_layout=layout1, bias_layout=layout2)
    d.kernel_layout == layout1
    d.bias_layout == layout2
    ```

    By adding this annotation, it will:

    1. Filter the kwargs based on some keywords, e.g. if
       'kernel_initializer' appears in the method signature, then it will try
       to pop 'kernel_layout' if it is present. Same for "bias" and
       "recurrent_kernel", etc. This makes sure the layout related params are
       not passed to `BaseLayer.__init__`, which would raise an error about
       unexpected keyword args.
    2. Set the self.kernel/bias_layout attribute after the `__init__` method
       is called. Keras framework will use those fields to create weights
       down the stream.

    Args:
      init_method: the `__init__` method of the Keras layer to annotate.

    Returns:
      the annotated __init__ method.
    """
    # The signature only depends on init_method, so compute it once at
    # decoration time instead of on every constructor invocation.
    signature = inspect.signature(init_method)

    def _wrap_function(layer_instance, *args, **kwargs):
        layout_args = {}
        # Check args like 'kernel_initializer' and pop the matching
        # 'kernel_layout' kwarg if it is present.
        for variable_name in KERAS_VARIABLE_NAMES:
            if variable_name + "_initializer" in signature.parameters:
                layout = kwargs.pop(variable_name + "_layout", None)
                # Explicit None check so a falsy (but valid) layout object is
                # not silently dropped.
                if layout is not None:
                    layout_args[variable_name + "_layout"] = layout

        init_method(layer_instance, *args, **kwargs)

        # Inject the layout parameter after the invocation of __init__()
        for layout_param_name, layout in layout_args.items():
            setattr(layer_instance, layout_param_name, layout)

    return tf.__internal__.decorator.make_decorator(
        target=init_method, decorator_func=_wrap_function
    )
def inject_mesh(init_method):
    """Inject DTensor mesh information to an object.

    This is useful for keras objects like `Metric` and `Optimizer` which need
    a DTensor mesh to create their weights, but don't want to change the
    current public API interface.

    This is for temporary usage and eventually the mesh/layout information
    will be public arguments in the `__init__` method.

    Sample usage:
    ```python
    class Accuracy(tf.keras.metrics.Metric):

      @inject_mesh
      def __init__(self, name='accuracy', dtype=None, **kwargs):
        super().__init__(name=name, dtype=dtype, **kwargs)

    acc = Accuracy(mesh=mesh)
    assert acc._mesh == mesh
    ```

    Args:
      init_method: the `__init__` method of the Keras class to annotate.

    Returns:
      the annotated __init__ method.
    """

    def _wrapped_init(instance, *args, **kwargs):
        # The mesh must be attached to the instance before __init__ runs,
        # since the class may need it to create weights inside __init__.
        if (mesh := kwargs.pop("mesh", None)) is not None:
            instance._mesh = mesh
        init_method(instance, *args, **kwargs)

    return tf.__internal__.decorator.make_decorator(
        target=init_method, decorator_func=_wrapped_init
    )
def call_with_layout(fn, layout, *args, **kwargs):
    """Invoke the function with inputs and relayout the result.

    Args:
      fn: the function to invoke.
      layout: if not None, the output of the fn will be relayout with this.
      *args: positional arguments to be called with fn.
      **kwargs: keyword arguments to be called with fn.

    Returns:
      The output of fn, with potential relayout with the layout specified.
    """
    # Fast path: no layout requested, just forward the call.
    if not layout:
        return fn(*args, **kwargs)
    # Run fn under the layout's mesh so any DTensor it creates lands on the
    # right mesh, then relayout the output.
    with dtensor.default_mesh(layout.mesh):
        output = fn(*args, **kwargs)
    return dtensor.relayout(output, layout)
def running_with_dtensor_strategy():
    """Check whether running with a `Strategy` that is backed by DTensor.

    In the DTensor based training, all the tensors are in global context,
    which is different from the local context. Some keras components need to
    behave differently, e.g. BatchNormalization and SyncBatchNormalization,
    as well as optimizers.

    This check will help those layers to branch the logic and keep the
    correct behavior between different contexts.
    """
    if not tf.distribute.has_strategy():
        return False
    # TODO(scottzhu): Finalize the strategy API to check if a strategy is
    # backed by DTensor.
    return getattr(tf.distribute.get_strategy(), "_mesh", None) is not None