Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/dtensor/layout_map.py: 26%

137 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Library for map layout and corresponding tf.Variable.""" 

16 

17import collections 

18import contextlib 

19import re 

20import threading 

21 

22import tensorflow.compat.v2 as tf 

23 

24from keras.src.dtensor import dtensor_api as dtensor 

25from keras.src.dtensor import lazy_variable 

26from keras.src.dtensor import utils 

27from keras.src.engine import base_layer 

28 

29# isort: off 

30from tensorflow.python.util.tf_export import keras_export 

31 

32 

# Attribute names whose paths are skipped when mapping variables to
# layouts, e.g. `model._self_tracked_trackables` or a layer's
# `_trainable_weights`/`_non_trainable_weights`. These attributes act as
# caches: the canonical reference to each variable lives elsewhere, and
# that is the reference that should be mapped/replaced.
_KERAS_ATTRIBUTES_TO_SKIP = [
    "_self_tracked_trackables",
    "_trainable_weights",
    "_non_trainable_weights",
    "_captured_weight_regularizer",
]

43 

44 

# Thread-local holder for the LayoutMap installed by `layout_map_scope`;
# each thread sees its own `layout_map` attribute (unset by default).
_LAYOUT_MAP = threading.local()

46 

47 

def get_current_layout_map():
    """Return the `LayoutMap` active on this thread, or None if none is set."""
    try:
        return _LAYOUT_MAP.layout_map
    except AttributeError:
        return None

50 

51 

@keras_export("keras.dtensor.experimental.LayoutMap", v1=[])
class LayoutMap(collections.abc.MutableMapping):
    """A dict-like container mapping string patterns to `Layout` instances.

    `LayoutMap` stores string keys and `Layout` values, but differs from a
    plain Python dict on retrieval: when a queried key has no exact entry,
    every stored key is treated as a regex and matched against the query in
    insertion order, and the first match wins. See the docstring of `get`
    for more details.

    Define the naming schema of your layouts once, then look up the layout
    for any concrete variable path:

    ```python
    map = LayoutMap(mesh=None)
    map['.*dense.*kernel'] = layout_2d
    map['.*dense.*bias'] = layout_1d
    map['.*conv2d.*kernel'] = layout_4d
    map['.*conv2d.*bias'] = layout_1d

    layout_1 = map['dense_1.kernel']  # layout_1 == layout_2d
    layout_2 = map['dense_1.bias']  # layout_2 == layout_1d
    layout_3 = map['dense_2.kernel']  # layout_3 == layout_2d
    layout_4 = map['dense_2.bias']  # layout_4 == layout_1d
    layout_5 = map['my_model/conv2d_123/kernel']  # layout_5 == layout_4d
    layout_6 = map['my_model/conv2d_123/bias']  # layout_6 == layout_1d
    ```

    To use the `LayoutMap` with a `Model`, please see the docstring of
    `tf.keras.dtensor.experimental.layout_map_scope`.

    Args:
        mesh: An optional `Mesh` used to build a default all-replicated
            layout when no stored key matches an input string query.
    """

    def __init__(self, mesh=None):
        # Insertion order drives regex-match priority, hence OrderedDict.
        self._layout_map = collections.OrderedDict()
        self._default_mesh = mesh

    def __getitem__(self, key):
        """Retrieve the layout registered for the string `key`.

        When there is no exact entry for `key`, each stored key is treated
        as a regex (via `re.match`) against `key`, in insertion order, and
        the first matching entry's layout is returned. Returns None when
        nothing matches.

        Args:
            key: the string key as the query for the layout.

        Returns:
            Corresponding layout based on the query, or None.
        """
        try:
            return self._layout_map[key]
        except KeyError:
            pass
        return next(
            (
                layout
                for pattern, layout in self._layout_map.items()
                if re.match(pattern, key)
            ),
            None,
        )

    def __setitem__(self, key, layout):
        # Duplicate keys are rejected: a silent overwrite could change
        # regex-match priority in surprising ways.
        if key in self._layout_map:
            raise ValueError(
                f"{key} already exist in the LayoutMap with "
                f"value {self._layout_map[key]}. Please make sure to "
                "not use duplicated keys."
            )
        if not isinstance(layout, dtensor.Layout):
            raise ValueError(
                f"{layout} should be a dtensor.Layout type, got {type(layout)}"
            )
        self._layout_map[key] = layout

    def __delitem__(self, key):
        # A missing key raises KeyError straight from the underlying dict.
        return self._layout_map.pop(key)

    def __len__(self):
        return len(self._layout_map)

    def __iter__(self):
        return iter(self._layout_map)

    def get_default_mesh(self):
        """Return the default `Mesh` supplied at instance creation.

        The `Mesh` can be used to create a default replicated `Layout` when
        an input string query has no match.
        """
        return self._default_mesh

    def scope(self):
        """Apply this map's layouts to `tf.Variable`s created in the scope.

        All `tf.Variable` instances created under the returned scope are
        lazily initialized first. Once they are attached as model or layer
        attributes and their object path is stable, each is reinitialized
        as a `tf.experimental.dtensor.DVariable` with the layout this map
        associates with that path.

        Lookup keys are object/attribute names: for subclassed models the
        full attribute path is used (e.g. `'d1.kernel'`); for
        Functional/Sequential models the key is `layer.name` followed by
        the attribute name, and Keras ensures layer-name uniqueness within
        such models.

        ```python
        layout_map = layout_map_lib.LayoutMap(mesh=self.mesh)
        layout_map['d1.kernel'] = layout_1
        layout_map['d1.bias'] = layout_2
        layout_map['d2.kernel'] = layout_3
        layout_map['d2.bias'] = layout_4

        ## Subclassed model
        with layout_map.scope():
            model = SubclassModel()
        inputs = tf.zeros((10, 10))
        results = model(inputs)

        model.d1.kernel.layout == layout_1
        model.d1.bias.layout == layout_2

        ## Functional model
        with layout_map.scope():
            inputs = tf.keras.Input((10,), batch_size=10)
            x = tf.keras.layers.Dense(20, name='d1')(inputs)
            output = tf.keras.layers.Dense(30, name='d2')(x)
            model = tf.keras.Model(inputs, output)

        model.layers[1].kernel.layout == layout_1
        model.layers[1].bias.layout == layout_2

        ## Sequential model
        with layout_map.scope():
            model = tf.keras.Sequential([
                tf.keras.layers.Dense(20, name='d1', input_shape=(10,)),
                tf.keras.layers.Dense(30, name='d2')
            ])

        model.layers[0].kernel.layout == layout_1
        model.layers[0].bias.layout == layout_2
        ```

        Returns:
            A context that will lazily initialize all `tf.Variable` objects
            within the model, with their attributed layouts.
        """
        return layout_map_scope(self)

232 

233 

# `MutableMapping` implements `get` in terms of `__getitem__`; copy the
# docstring over so `help(LayoutMap.get)` documents the regex-matching
# lookup behavior as well.
LayoutMap.get.__doc__ = LayoutMap.__getitem__.__doc__

235 

236 

@contextlib.contextmanager
def layout_map_scope(layout_map):
    """Apply `layout_map` to all `tf.Variable`s created under the scope.

    Creates a scope in which every `tf.Variable` is lazily initialized;
    once the object path of a variable within the model is
    stable/finalized, the variable is initialized with the layout that
    `layout_map` assigns to that path.

    Lookup keys are object/attribute names. For subclassed models the full
    attribute path is used. For Functional/Sequential models — where layers
    are not assigned to meaningful attributes — the key is `layer.name`
    followed by the attribute name; Keras ensures layer-name uniqueness
    among the layers in all Functional/Sequential models.

    ```python
    layout_map = layout_map_lib.LayoutMap(mesh=self.mesh)
    layout_map['d1.kernel'] = layout_1
    layout_map['d1.bias'] = layout_2
    layout_map['d2.kernel'] = layout_3
    layout_map['d2.bias'] = layout_4

    ## Subclassed model
    class SubclassModel(tf.keras.Model):

        def __init__(self, name=None):
            super().__init__(name=name)
            self.d1 = tf.keras.layers.Dense(1000)
            self.d2 = tf.keras.layers.Dense(1000)

        def call(self, inputs):
            x = self.d1(inputs)
            return self.d2(x)

    with layout_map_scope(layout_map):
        model = SubclassModel()
    # Triggering the creation of weights within or outside of the scope
    # works.
    inputs = tf.zeros((10, 10))
    results = model(inputs)

    model.d1.kernel.layout == layout_1
    model.d1.bias.layout == layout_2
    model.d2.kernel.layout == layout_3
    model.d2.bias.layout == layout_4

    ## Functional model
    with layout_map_scope(layout_map):
        inputs = tf.keras.Input((10,), batch_size=10)
        x = tf.keras.layers.Dense(20, name='d1')(inputs)
        output = tf.keras.layers.Dense(30, name='d2')(x)
        model = tf.keras.Model(inputs, output)

    model.layers[1].kernel.layout == layout_1
    model.layers[1].bias.layout == layout_2

    ## Sequential model
    with layout_map_scope(layout_map):
        model = tf.keras.Sequential([
            tf.keras.layers.Dense(20, name='d1', input_shape=(10,)),
            tf.keras.layers.Dense(30, name='d2')
        ])

    model.layers[0].kernel.layout == layout_1
    model.layers[0].bias.layout == layout_2
    ```

    Args:
        layout_map: a LayoutMap which contains the variable_object_path
            (string) -> Layout. When a layout is not found for the
            variable, a default all replicated layout will be created for
            the variable.

    Yields:
        A context that will lazily initialize all `tf.Variable` objects
        within the model, with their attributed layouts.
    """
    # Remember whatever map an enclosing scope installed so nesting works.
    previous = get_current_layout_map()
    _LAYOUT_MAP.layout_map = layout_map
    with lazy_variable.lazy_init_scope():
        try:
            yield
        finally:
            # Restore before leaving the lazy-init scope, mirroring the
            # installation order on entry.
            _LAYOUT_MAP.layout_map = previous

337 

338 

def _map_subclass_model_variable(model, layout_map):
    """Replace every LazyInitVariable in a subclassed model with a DVariable.

    Args:
        model: the subclassed model whose variables are replaced in place.
        layout_map: `LayoutMap` used to pick the layout for each variable.

    Returns:
        The same `model`, with all lazy variables swapped for DVariables.
    """
    replaced_by_id = {}

    # `model._flatten` comes from tf.Module and may yield the same variable
    # several times, under different paths.
    for path, variable in model._flatten(
        predicate=_is_lazy_init_variable,
        with_path=True,
    ):
        # `path` mixes attribute names and list indices, e.g.
        # ('d1', '_trainable_weights', 0) maps to
        # model.d1._trainable_weights[0].
        if any(skipped in path for skipped in _KERAS_ATTRIBUTES_TO_SKIP):
            continue
        # Stringify every component (ints included) and join with dots.
        object_path = ".".join(str(part) for part in path)

        dvariable = _create_dvariable(layout_map, object_path, variable)
        _set_object_by_path(model, path, dvariable)
        replaced_by_id[id(variable)] = dvariable

    for layer in model._flatten(
        predicate=lambda obj: isinstance(obj, base_layer.Layer)
    ):
        _config_dvariable_regularization(layer, replaced_by_id)

    # Second pass: cached attributes may still alias an old
    # LazyInitVariable; point every remaining alias at its replacement.
    for path, variable in model._flatten(
        predicate=_is_lazy_init_variable,
        with_path=True,
    ):
        _set_object_by_path(model, path, replaced_by_id[id(variable)])

    _init_state_variable_for_rng(model, layout_map)
    _update_trackable_reference(model, replaced_by_id)
    return model

378 

379 

def _map_functional_model_variable(model, layout_map):
    """Replace LazyInitVariables in a functional/sequential model.

    Args:
        model: the Functional/Sequential model whose variables are replaced
            in place.
        layout_map: `LayoutMap` used to pick the layout for each variable.

    Returns:
        The same `model`, with all lazy variables swapped for DVariables.
    """
    replaced_by_id = {}

    for layer in model.layers:
        # Layer names are unique within a functional/sequential model;
        # when no name is provided, Keras auto-generates one from the
        # class name.
        prefix = layer.name
        for path, variable in layer._flatten(
            predicate=_is_lazy_init_variable,
            with_path=True,
        ):
            # `path` mixes attribute names and list indices, e.g.
            # ('d1', '_trainable_weights', 0) maps to
            # model.d1._trainable_weights[0].
            if any(
                skipped in path for skipped in _KERAS_ATTRIBUTES_TO_SKIP
            ):
                continue
            # Stringify components, join with dots, and prepend the layer
            # name so the key matches the documented naming schema.
            object_path = prefix + "." + ".".join(str(part) for part in path)

            dvariable = _create_dvariable(layout_map, object_path, variable)
            _set_object_by_path(layer, path, dvariable)
            replaced_by_id[id(variable)] = dvariable

        _config_dvariable_regularization(layer, replaced_by_id)

        # Second pass: cached attributes may still alias an old
        # LazyInitVariable; point every remaining alias at its
        # replacement.
        for path, variable in layer._flatten(
            predicate=_is_lazy_init_variable,
            with_path=True,
        ):
            _set_object_by_path(layer, path, replaced_by_id[id(variable)])

    _init_state_variable_for_rng(model, layout_map)
    _update_trackable_reference(model, replaced_by_id)
    return model

424 

425 

def _init_state_variable_for_rng(model, layout_map):
    """Initialize the state variable of each `tf.random.Generator`.

    `BaseRandomLayer` in Keras explicitly untracks its
    `tf.random.Generator`, so the generator's state variable stays a
    LazyInitVariable, which causes a runtime error unless it is replaced
    with a proper DVariable. Users are usually unaware these variables
    exist, so they simply receive a replicated layout (they are tiny).

    Args:
        model: the model whose layers will be checked to find the
            BaseRandomLayers.
        layout_map: used to get the default mesh information to create
            DVariable.
    """

    random_layers = model._flatten(
        predicate=lambda obj: isinstance(obj, base_layer.BaseRandomLayer)
    )
    for random_layer in random_layers:
        generator = random_layer._random_generator
        if generator._built and generator._generator is None:
            raise ValueError(
                "Keras is expected to use tf.random.Generator when using "
                "DTensor API. Please call "
                "`tf.keras.backend.experimental.enable_tf_random_generator` at "
                "the beginning of your program."
            )
        if hasattr(generator, "_generator") and _is_lazy_init_variable(
            generator._generator._state_var
        ):
            # Swap the lazy state variable for a DVariable.
            generator._generator._state_var = _create_dvariable(
                layout_map, "", generator._generator._state_var
            )
        else:
            # Generator not built yet: run its init under the DTensor
            # default mesh so its variables get a default replicated
            # layout.
            with dtensor.default_mesh(layout_map.get_default_mesh()):
                generator._maybe_init()

465 

466 

def _config_dvariable_regularization(
    layer, lazy_init_variable_to_tf_variable_map
):
    """Attach deferred weight-regularization losses to new `DVariable`s.

    Weight regularization normally happens inside `layer.add_weight()`, at
    which point the library has only created a `LazyInitVariable`. The loss
    creation is deferred (captured in
    `layer._captured_weight_regularizer`) until the real `DVariable`
    exists; this function replays those captures against the replacements.

    Args:
        layer: the layer instance for DVariable regularization config.
        lazy_init_variable_to_tf_variable_map: the dict between
            LazyInitVariable ID and newly created DVariable.
    """

    for name, variable, regularizer in layer._captured_weight_regularizer:
        if not _is_lazy_init_variable(variable):
            raise ValueError(
                "Expect the regularization loss are created from "
                f"LazyInitVariable, got {variable}"
            )
        replacement = lazy_init_variable_to_tf_variable_map[id(variable)]
        layer._handle_weight_regularization(name, replacement, regularizer)
    # The deferred losses are materialized now; clear the capture list so
    # they cannot be applied twice.
    layer._captured_weight_regularizer = []

495 

496 

def _create_dvariable(layout_map, object_path, variable):
    """Build a fresh variable to stand in for the LazyInitVariable.

    Even though a LazyInitVariable may behave like a normal
    tf.Variable/DVariable, relying on that is not future proof against
    changes to the variable class, and it fails `isinstance` checks that
    user code may perform when filtering variables by type.

    Args:
        layout_map: a LayoutMap which contains the variable_object_path
            (string) -> Layout.
        object_path: string, the object attribute path for the variable.
        variable: LazyInitVariable which will be replaced by the newly
            created tf.Variable.
    Returns:
        A new tf.Variable with correct layout information.
    """
    # TODO(b/228209108): Revisit this in future and see if we can just
    # reuse the LazyInitVariable rather than creating a new tf.Variable
    # instance.
    layout = layout_map[object_path]
    if layout is None:
        # No explicit mapping for this path: default to fully replicated
        # on the map's default mesh.
        layout = dtensor.Layout.replicated(
            mesh=layout_map.get_default_mesh(), rank=variable.shape.rank
        )
    init_val = variable._initial_value
    if callable(init_val):
        with lazy_variable.disable_init_variable_creator():
            init_val = utils.call_with_layout(init_val, layout)
    else:
        # The init value was already materialized as a tensor; copy it to
        # the mesh with the chosen layout.
        init_val = dtensor.copy_to_mesh(init_val, layout)
    # Reuse the original variable name, minus the ":0" suffix TF appends.
    variable_name = variable.name
    if variable_name.endswith(":0"):
        variable_name = variable_name[:-2]
    return dtensor.DVariable(
        init_val, trainable=variable.trainable, name=variable_name
    )

540 

541 

542def _set_object_by_path(object_to_set, path, value): 

543 """Set the attribute of instance to the object. 

544 

545 Args: 

546 object_to_set: the instance whose attribute should be set. 

547 path: the tuple/list of string and ints, representing the attribute names. 

548 Int means that the attribute to set is a item a list. 

549 value: the value of the attribute. 

550 """ 

551 

552 for i, attr_name in enumerate(path): 

553 if i == len(path) - 1: 

554 # We found the actual attribute to set 

555 if isinstance(attr_name, int): 

556 # This means we are trying to set an element in the array, make 

557 # sure the instance is array like object. 

558 object_to_set[attr_name] = value 

559 else: 

560 setattr(object_to_set, attr_name, value) 

561 else: 

562 if isinstance(attr_name, int): 

563 object_to_set = object_to_set[attr_name] 

564 else: 

565 object_to_set = getattr(object_to_set, attr_name) 

566 

567 

# TODO(b/228209108): Revisit this after we can reinit LazyInitVariable.
def _update_trackable_reference(model, lazy_init_variable_to_tf_variable_map):
    """Update the trackable object references for the model.

    This is only needed for a checkpointing corner case: a
    LazyInitVariable can be caught in a checkpoint dependency without
    being visible in the model attribute graph itself.

    Args:
        model: the keras model instance whose checkpoint dependency will
            be examed.
        lazy_init_variable_to_tf_variable_map: the dict between
            LazyInitVariable ID and newly created DVariable.
    """
    # See b/234621758 for more details.
    graph_view = tf.__internal__.tracking.ObjectGraphView(model)
    all_trackables, _ = graph_view.breadth_first_traversal()
    for node in all_trackables:
        for child_name, child in node._trackable_children().items():
            if not _is_lazy_init_variable(child):
                continue
            # Re-track under the same name, replacing the LazyInitVariable
            # with its DVariable replacement.
            node._track_trackable(
                lazy_init_variable_to_tf_variable_map[id(child)],
                child_name,
                overwrite=True,
            )

594 

595 

def _is_lazy_init_variable(obj):
    """Return True if `obj` is a `lazy_variable.LazyInitVariable`."""
    return isinstance(obj, lazy_variable.LazyInitVariable)

598