Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/dtensor/utils.py: 36%

36 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras Utilities for DTensor related API.""" 

16 

17import inspect 

18 

19import tensorflow.compat.v2 as tf 

20 

21from keras.src.dtensor import dtensor_api as dtensor 

22 

# All the variable names in the default keras layers. We will use those to map
# against the args in the __init__ method to find corresponding layout args.
# See allow_initializer_layout() for more details.
# Each name is matched as "<name>_initializer" in a layer's __init__
# signature, and the corresponding "<name>_layout" kwarg is popped/injected.
KERAS_VARIABLE_NAMES = [
    "alpha",
    "beta",
    "bias",
    "depthwise",
    "embeddings",
    "gamma",
    "kernel",
    "moving_mean",
    "moving_variance",
    "pointwise",
    "recurrent",
]

39 

40 

def allow_initializer_layout(init_method):
    """A decorator for injecting layout information to layer.__init__.

    Layout will be a new param for any of the weights for all the keras
    layers. Adding the param to all the __init__ methods would be a big,
    duplicated effort.

    This decorator is designed to reduce code duplication and make it easy
    to add/remove the dtensor feature if needed.

    Sample usage:
    ```python
    class Dense(tf.keras.layer.Layer):

      @allow_initializer_layout
      def __init__(self, units,
                   kernel_initializer='zeros',
                   bias_initializer='zeros',
                   **kwargs):
        super().__init__(**kwargs)

    d = Dense(units=8, kernel_layout=layout1, bias_layout=layout2)
    d.kernel_layout == layout1
    d.bias_layout == layout2
    ```

    By adding this annotation, it will:

    1. Filter out the kwargs based on some keywords, e.g. if
       'kernel_initializer' appears in the method signature, then it will
       try to pop 'kernel_layout' if it is present. Same for "bias" and
       "recurrent_kernel", etc. This makes sure the layout-related params
       are not passed to `BaseLayer.__init__`, which would raise an error
       about unexpected keyword args.
    2. Set the self.kernel/bias_layout attribute after the `__init__` method
       is called. The Keras framework will use those fields to create
       weights down the stream.

    Args:
      init_method: the `__init__` method of the Keras layer to annotate.

    Returns:
      the annotated __init__ method.
    """

    def _wrap_function(layer_instance, *args, **kwargs):
        signature = inspect.signature(init_method)
        layout_args = {}
        # Check args like 'kernel_initializer' and pop the corresponding
        # 'kernel_layout' if it is present.
        for variable_name in KERAS_VARIABLE_NAMES:
            if variable_name + "_initializer" in signature.parameters:
                layout_kwarg = variable_name + "_layout"
                layout = kwargs.pop(layout_kwarg, None)
                # Explicit None check (rather than truthiness) so that a
                # falsy-but-valid layout object is not silently dropped.
                if layout is not None:
                    layout_args[layout_kwarg] = layout

        init_method(layer_instance, *args, **kwargs)

        # Inject the layout parameter after the invocation of __init__()
        for layout_param_name, layout in layout_args.items():
            setattr(layer_instance, layout_param_name, layout)

    return tf.__internal__.decorator.make_decorator(
        target=init_method, decorator_func=_wrap_function
    )

106 

107 

def inject_mesh(init_method):
    """Inject DTensor mesh information into an object at construction time.

    This is useful for Keras objects like `Metric` and `Optimizer` which
    need a DTensor mesh to create their weights, but don't want to change
    the current public API interface.

    This is for temporary usage and eventually the mesh/layout information
    will be public arguments in the `__init__` method.

    Sample usage:
    ```python
    class Accuracy(tf.keras.metrics.Metric):

      @inject_mesh
      def __init__(self, name='accuracy', dtype=None):
        super().__init__(name=name, dtype=dtype)

    acc = Accuracy(mesh=mesh)
    assert acc._mesh == mesh
    ```

    Args:
      init_method: the `__init__` method of the Keras class to annotate.

    Returns:
      the annotated __init__ method.
    """

    def _wrapped_init(obj, *args, **kwargs):
        dtensor_mesh = kwargs.pop("mesh", None)
        # Assign _mesh *before* delegating to __init__, since the wrapped
        # class may need the mesh to create weights inside its __init__.
        if dtensor_mesh is not None:
            obj._mesh = dtensor_mesh
        init_method(obj, *args, **kwargs)

    return tf.__internal__.decorator.make_decorator(
        target=init_method, decorator_func=_wrapped_init
    )

149 

150 

def call_with_layout(fn, layout, *args, **kwargs):
    """Invoke `fn` with the given inputs and relayout the result.

    Args:
      fn: the function to invoke.
      layout: if not None, the output of `fn` will be relaid out with this.
      *args: positional arguments passed through to `fn`.
      **kwargs: keyword arguments passed through to `fn`.

    Returns:
      The output of `fn`, potentially relaid out with the specified layout.
    """
    # Without a layout, this is a plain passthrough invocation.
    if not layout:
        return fn(*args, **kwargs)
    # Run under the layout's mesh so any tensors created by `fn` land on
    # the right mesh, then relayout the result.
    with dtensor.default_mesh(layout.mesh):
        output = fn(*args, **kwargs)
        return dtensor.relayout(output, layout)

168 

169 

def running_with_dtensor_strategy():
    """Check whether running with a `Strategy` that is backed by DTensor.

    In DTensor-based training, all the tensors are in a global context,
    which is different from the local context. Some Keras components need
    to behave differently in that case, e.g. BatchNormalization and
    SyncBatchNormalization, as well as optimizers.

    This check helps those layers branch their logic and keep the correct
    behavior between the different contexts.
    """
    if not tf.distribute.has_strategy():
        return False
    # TODO(scottzhu): Finalize the strategy API to check if a strategy is
    # backed by DTensor.
    return getattr(tf.distribute.get_strategy(), "_mesh", None) is not None

187