Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/core/tf_op_layer.py: 34%

199 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Contains the TFOpLambda layer.""" 

16import tensorflow.compat.v2 as tf 

17 

18from keras.src import backend 

19from keras.src.engine import keras_tensor 

20from keras.src.engine.base_layer import Layer 

21 

22# isort: off 

23from tensorflow.python.platform import tf_logging 

24from tensorflow.python.util.tf_export import ( 

25 get_canonical_name_for_symbol, 

26) 

27from tensorflow.python.util.tf_export import ( 

28 get_symbol_from_name, 

29) 

30 

31 

class ClassMethod(Layer):
    """Wraps a TF API Class's class method in a `Layer` object.

    It is inserted by the Functional API construction whenever users call
    a supported TF Class's class method on KerasTensors.

    This is useful in the case where users do something like:
        x = keras.Input(...)
        y = keras.Input(...)
        out = tf.RaggedTensor.from_row_splits(x, y)

    Args:
        cls_ref: The class object owning the wrapped class method
            (e.g. `tf.RaggedTensor`).
        method_name: Name of the class method to invoke on `cls_ref`.
        **kwargs: Standard `Layer` keyword arguments. A `name` is generated
            when none is given, and `autocast` is always forced to `False`.
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, cls_ref, method_name, **kwargs):
        self.cls_ref = cls_ref
        self.method_name = method_name
        # Public API name of the class (TF API first, then Keras API), or a
        # falsy value when the class is not publicly exported. Required for
        # serialization in `get_config`.
        self.cls_symbol = get_canonical_name_for_symbol(
            self.cls_ref, add_prefix_to_v1_names=True
        ) or get_canonical_name_for_symbol(
            self.cls_ref, api_name="keras", add_prefix_to_v1_names=True
        )
        if "name" not in kwargs:
            # Fall back to the class's own `__name__` when it has no public
            # API symbol; concatenating `"tf." + None` would otherwise raise
            # a `TypeError` before the layer is even built.
            if self.cls_symbol:
                prefix = "tf." + self.cls_symbol
            else:
                prefix = self.cls_ref.__name__
            kwargs["name"] = backend.unique_object_name(
                prefix + "." + self.method_name,
                zero_based=True,
                avoid_observed_names=True,
            )
        kwargs["autocast"] = False

        # Do not individually trace op layers in the SavedModel.
        self._must_restore_from_config = True

        super().__init__(**kwargs)

        # Preserve all argument data structures when saving/loading a config
        # (e.g., don't unnest lists that contain one element)
        self._preserve_input_structure_in_config = True

        self._call_spec.expects_training_arg = False
        self._call_spec.expects_mask_arg = False

    def call(self, args, kwargs):
        """Invoke `cls_ref.<method_name>(*args, **kwargs)`.

        `args` and `kwargs` are received as two containers (this is how
        `TFClassMethodDispatcher.handle` calls the layer) and are unpacked
        into the real class-method call.
        """
        return getattr(self.cls_ref, self.method_name)(*args, **kwargs)

    def get_config(self):
        """Return the layer config, storing the class by its API name.

        Raises:
            ValueError: If the wrapped class has no public API symbol.
        """
        if not self.cls_symbol:
            # Report the class object itself: `self.cls_symbol` is falsy in
            # this branch, so interpolating it would print nothing useful.
            raise ValueError(
                "This Keras class method conversion tried to convert "
                f"a method belonging to class {self.cls_ref}, a class "
                "that is not publicly exposed in the TensorFlow API. "
                "To ensure cross-version compatibility of Keras models "
                "that use op layers, only op layers produced from "
                "public TensorFlow API symbols can be serialized."
            )

        config = {
            "cls_symbol": self.cls_symbol,
            "method_name": self.method_name,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Recreate the layer, resolving `config["cls_symbol"]` to a class.

        `custom_objects` is accepted for API compatibility and ignored.

        Raises:
            ValueError: If the stored symbol name cannot be resolved.
        """
        config = config.copy()
        symbol_name = config.pop("cls_symbol")
        cls_ref = get_symbol_from_name(symbol_name)
        if not cls_ref:
            raise ValueError(
                f"TensorFlow symbol `{symbol_name}` could not be found."
            )

        config["cls_ref"] = cls_ref

        return cls(**config)

107 

108 

class KerasOpDispatcher(tf.__internal__.dispatch.GlobalOpDispatcher):
    """Global TF op dispatcher used while building Functional models.

    Any TF op invoked with at least one `KerasTensor` somewhere in its
    (possibly nested) positional or keyword arguments is turned into a
    `TFOpLambda` layer wrapping that op. Every other call is declined.
    """

    def handle(self, op, args, kwargs):
        """Dispatch `op(*args, **kwargs)` to a `TFOpLambda` layer if possible.

        Returns:
            The output of `TFOpLambda(op)` applied to the arguments when they
            contain a `KerasTensor`, otherwise `self.NOT_SUPPORTED`.
        """
        leaves = tf.nest.flatten([args, kwargs])
        uses_keras_tensor = any(
            isinstance(leaf, keras_tensor.KerasTensor) for leaf in leaves
        )
        if not uses_keras_tensor:
            return self.NOT_SUPPORTED
        return TFOpLambda(op)(*args, **kwargs)

122 

123 

# Install one `KerasOpDispatcher` as a global op dispatcher at import time so
# that (per its `handle` method) TF ops called with KerasTensor arguments
# presumably get routed into `TFOpLambda` layers.
KerasOpDispatcher().register()

125 

126 

class InstanceProperty(Layer):
    """Wraps an instance property access (e.g. `x.foo`) in a Keras Layer.

    The constructor receives an attribute name `attr_name`; calling the layer
    on an input `obj` returns `obj.attr_name`.

    KerasTensors specialized for specific extension types rely on this layer
    to represent property accesses on the represented object whenever the
    property must be read dynamically rather than derived statically from the
    typespec, e.g.

        x = keras.Input(..., ragged=True)
        out = x.flat_values
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, attr_name, **kwargs):
        self.attr_name = attr_name

        if "name" not in kwargs:
            # Auto-generated names skip names that were already observed.
            kwargs["name"] = backend.unique_object_name(
                "input." + attr_name,
                zero_based=True,
                avoid_observed_names=True,
            )
        kwargs["autocast"] = False

        # Op layers are not traced individually in the SavedModel.
        self._must_restore_from_config = True

        super().__init__(**kwargs)

        # Keep argument data structures intact in the saved config (e.g. do
        # not unnest single-element lists).
        self._preserve_input_structure_in_config = True

    def call(self, obj):
        """Return the `attr_name` attribute of `obj`."""
        attribute = self.attr_name
        return getattr(obj, attribute)

    def get_config(self):
        """Return the base layer config extended with `attr_name`."""
        return {**super().get_config(), "attr_name": self.attr_name}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the layer from `config`; `custom_objects` is ignored."""
        return cls(**config)

176 

177 

class InstanceMethod(InstanceProperty):
    """Wraps an instance method call (e.g. `x.foo(arg)`) in a Keras Layer.

    The constructor receives an attribute name `attr_name`; calling the layer
    on an input `obj` together with `args` and `kwargs` returns
    `obj.attr_name(*args, **kwargs)`.

    KerasTensors specialized for specific extension types rely on this layer
    to represent dynamic method calls on the represented object, e.g.

        x = keras.Input(..., ragged=True)
        new_values = keras.Input(...)
        out = x.with_values(new_values)
    """

    def call(self, obj, args, kwargs):
        """Call `obj.<attr_name>` with the unpacked `args` and `kwargs`."""
        return getattr(obj, self.attr_name)(*args, **kwargs)

196 

197 

class TFOpLambda(Layer):
    """Wraps TF API symbols in a `Layer` object.

    It is inserted by the Functional API construction whenever users call
    a supported TF symbol on KerasTensors.

    Like Lambda layers, this layer checks how variables are used in the call
    (to let users know that the layer will not capture the variables). It
    raises a `ValueError` if the wrapped call creates variables that the
    layer does not track, and logs a one-time warning if the call reads
    variables that the layer does not track.

    This is useful in the case where users do something like:
        x = keras.Input(...)
        y = tf.Variable(...)
        out = x * y

    Args:
        function: The TF API callable to wrap.
        **kwargs: Standard `Layer` keyword arguments. A `name` is generated
            when none is given, and `autocast` is always forced to `False`.
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, function, **kwargs):
        self.function = function
        # Public API name of `function` (TF API first, then Keras API), or a
        # falsy value if it is not publicly exported. Used for the generated
        # layer name and required by `get_config`.
        self.symbol = get_canonical_name_for_symbol(
            self.function, add_prefix_to_v1_names=True
        ) or get_canonical_name_for_symbol(
            self.function, api_name="keras", add_prefix_to_v1_names=True
        )
        if "name" not in kwargs:
            # Generate a name.
            # TFOpLambda layers avoid already-observed names,
            # because users cannot easily control the generated names.
            # Without this avoidance, users would be more likely to run
            # into unavoidable duplicate layer name collisions.
            # (For standard layers users could just set `name` when creating the
            # layer to work around a collision, but they can't do that for
            # auto-generated layers)
            if self.symbol:
                name = "tf." + self.symbol
            else:
                name = self.function.__name__
            kwargs["name"] = backend.unique_object_name(
                name, zero_based=True, avoid_observed_names=True
            )
        kwargs["autocast"] = False

        # Decorate the function to produce this layer's call method.
        # The closure forwards to the bound `_call_wrapper` method, and
        # `make_decorator` wraps it around `function` (presumably so that
        # `self.call` carries `function`'s metadata/signature — TODO confirm).
        def _call_wrapper(*args, **kwargs):
            return self._call_wrapper(*args, **kwargs)

        self.call = tf.__internal__.decorator.make_decorator(
            function, _call_wrapper
        )

        # Do not individually trace op layers in the SavedModel.
        self._must_restore_from_config = True

        super().__init__(**kwargs)

        # Preserve all argument data structures when saving/loading a config
        # (e.g., don't unnest lists that contain one element)
        self._preserve_input_structure_in_config = True

        # Warning on every invocation will be quite irksome in Eager mode.
        self._already_warned = False

        self._call_spec.expects_training_arg = False
        self._call_spec.expects_mask_arg = False

    def _call_wrapper(self, *args, **kwargs):
        """Run `self.function` while recording variable creation and access.

        Variables created during the call are collected through a variable
        creator scope. Variables read during the call are collected by a
        `GradientTape` built with `watch_accessed_variables=True`. Both lists
        are then validated by `_check_variables`.

        Returns:
            The result of `self.function(*args, **kwargs)`, called with any
            `name` keyword argument removed.
        """
        created_variables = []

        def _variable_creator(next_creator, **creator_kwargs):
            # Delegate to the next creator in the chain and record the
            # resulting variable.
            var = next_creator(**creator_kwargs)
            created_variables.append(var)
            return var

        with tf.GradientTape(
            watch_accessed_variables=True
        ) as tape, tf.variable_creator_scope(_variable_creator):
            # We explicitly drop `name` arguments here,
            # to guard against the case where an op explicitly has a
            # `name` passed (which is susceptible to producing
            # multiple ops w/ the same name when the layer is reused)
            kwargs.pop("name", None)
            result = self.function(*args, **kwargs)
        self._check_variables(created_variables, tape.watched_variables())
        return result

    def _check_variables(self, created_variables, accessed_variables):
        """Validate the variables created and read during a call.

        Args:
            created_variables: Variables created while running the call.
            accessed_variables: Variables watched (read) while running the
                call.

        Raises:
            ValueError: If any created variable is not among `self.weights`.
        """
        if not created_variables and not accessed_variables:
            # In the common case that a Lambda layer does not touch a Variable,
            # we don't want to incur the runtime cost of assembling any state
            # used for checking only to immediately discard it.
            return

        # Variables are compared through `.ref()` handles (presumably because
        # the variables themselves are not usable as set members — confirm).
        tracked_weights = set(v.ref() for v in self.weights)
        untracked_new_vars = [
            v for v in created_variables if v.ref() not in tracked_weights
        ]
        if untracked_new_vars:
            variable_str = "\n".join(f"  {i}" for i in untracked_new_vars)
            raise ValueError(
                "The following Variables were created within a Lambda layer "
                f"({self.name}) but are not tracked by said layer: "
                f"{variable_str}\n"
                "The layer cannot safely ensure proper Variable reuse "
                "across multiple calls, and consequently this behavior "
                "is disallowed for safety reasons. Lambda layers are "
                "not well suited for stateful computation; instead, "
                "writing a subclassed Layer is the recommend "
                "way to define layers with Variables."
            )

        untracked_used_vars = [
            v for v in accessed_variables if v.ref() not in tracked_weights
        ]
        # Warn at most once per layer instance (see `_already_warned`).
        if untracked_used_vars and not self._already_warned:
            variable_str = "\n".join(f"  {i}" for i in untracked_used_vars)
            self._warn(
                "The following Variables were used in a Lambda layer's call "
                f"({self.name}), but are not present in its tracked objects: "
                f"{variable_str}. This is a strong indication that the Lambda "
                "layer should be rewritten as a subclassed Layer."
            )
            self._already_warned = True

    def _warn(self, msg):
        # This method will be overridden in a unit test to raise an error,
        # because self.assertWarns is not universally implemented.
        return tf_logging.warning(msg)

    def get_config(self):
        """Return the layer config, storing the wrapped op by its API name.

        Raises:
            ValueError: If the wrapped function has no public API symbol.
        """
        if not self.symbol:
            raise ValueError(
                f"This Keras op layer was generated from {self.function}, a "
                "method that is not publicly exposed in the TensorFlow API. "
                "This may have happened if the method was explicitly "
                "decorated to add dispatching support, and it was used "
                "during Functional model construction. "
                "To ensure cross-version compatibility of Keras models "
                "that use op layers, only op layers produced from "
                "public TensorFlow API symbols can be serialized."
            )
        config = {"function": self.symbol}

        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Recreate the layer, resolving `config["function"]` to a symbol.

        `custom_objects` is accepted for API compatibility and ignored.

        Raises:
            ValueError: If the stored symbol name cannot be resolved.
        """
        config = config.copy()
        symbol_name = config["function"]
        function = get_symbol_from_name(symbol_name)
        if not function:
            raise ValueError(f"TF symbol `{symbol_name}` could not be found.")

        config["function"] = function

        return cls(**config)

354 

355 

356def _delegate_property(keras_tensor_cls, property_name): 

357 """Register property on a KerasTensor class. 

358 

359 Calling this multiple times with the same arguments should be a no-op. 

360 

361 This method exposes a property on the KerasTensor class that will use an 

362 `InstanceProperty` layer to access the property on the represented 

363 intermediate values in the model. 

364 

365 Args: 

366 keras_tensor_cls: The KerasTensor subclass that should expose the 

367 property. 

368 property_name: The name of the property to expose and delegate to the 

369 represented (Composite)Tensor. 

370 """ 

371 # We use a lambda because we can't create a Keras layer at import time 

372 # due to dynamic layer class versioning. 

373 property_access = property( 

374 lambda self: InstanceProperty(property_name)(self) 

375 ) 

376 setattr(keras_tensor_cls, property_name, property_access) 

377 

378 

379def _delegate_method(keras_tensor_cls, method_name): 

380 """Register method on a KerasTensor class. 

381 

382 Calling this function times with the same arguments should be a no-op. 

383 

384 This method exposes an instance method on the KerasTensor class that will 

385 use an `InstanceMethod` layer to run the desired method on the represented 

386 intermediate values in the model. 

387 

388 Args: 

389 keras_tensor_cls: The KerasTensor subclass that should expose the 

390 property. 

391 method_name: The name of the method to expose and delegate to the 

392 represented (Composite)Tensor. 

393 """ 

394 

395 def delegate(self, *args, **kwargs): 

396 return InstanceMethod(method_name)(self, args, kwargs) 

397 

398 setattr(keras_tensor_cls, method_name, delegate) 

399 

400 

# Expose selected properties and methods on `keras_tensor.RaggedKerasTensor`
# and `keras_tensor.SparseKerasTensor`. Each access or call is recorded as an
# `InstanceProperty` / `InstanceMethod` layer.
#
# We do not support the `uniform_row_length` property because it
# returns either `None` or an int tensor, and code that relies on it tends
# to check `is None` directly. Delegating it here would always return a
# `KerasTensor`, regardless of what can be statically inferred. This would
# never equal `None`, breaking code that expects it to be partially-static
# in unpredictable ways.
for ragged_property in (
    "values",
    "flat_values",
    "row_splits",
    "nested_row_splits",
):
    _delegate_property(keras_tensor.RaggedKerasTensor, ragged_property)

for ragged_method_name in (
    "value_rowids",
    "nested_value_rowids",
    "nrows",
    "row_starts",
    "row_limits",
    "row_lengths",
    "nested_row_lengths",
    "bounding_shape",
    "with_values",
    "with_flat_values",
    "with_row_splits_dtype",
    "merge_dims",
    "to_tensor",
    "to_sparse",
):
    _delegate_method(keras_tensor.RaggedKerasTensor, ragged_method_name)

for sparse_property in (
    "indices",
    "values",
    "dense_shape",
):
    _delegate_property(keras_tensor.SparseKerasTensor, sparse_property)

for sparse_method in ("with_values",):
    _delegate_method(keras_tensor.SparseKerasTensor, sparse_method)

444 

445 

class TFClassMethodDispatcher(tf.__internal__.dispatch.OpDispatcher):
    """Per-method dispatcher that turns TF class-method calls into layers.

    When the dispatched class method receives a `KerasTensor` anywhere in its
    arguments, the call is replaced by a `ClassMethod` layer. This allows
    Functional models to be built from TF class methods.

    Args:
        cls: The TF class owning the class method.
        method_name: Name of the class method being dispatched.
    """

    def __init__(self, cls, method_name):
        self.cls = cls
        self.method_name = method_name

    def handle(self, args, kwargs):
        """Dispatch the class-method call, or decline without KerasTensors.

        Returns:
            The output of a `ClassMethod` layer, or `self.NOT_SUPPORTED`.
        """
        involves_keras_tensor = any(
            isinstance(leaf, keras_tensor.KerasTensor)
            for leaf in tf.nest.flatten([args, kwargs])
        )
        if not involves_keras_tensor:
            return self.NOT_SUPPORTED
        # The first positional argument is dropped before calling the layer
        # (presumably it is the implicit class argument, which the layer
        # re-binds via `getattr(cls_ref, method_name)` — TODO confirm).
        method_args = args[1:]
        return ClassMethod(self.cls, self.method_name)(method_args, kwargs)

463 

464 

# Register a `TFClassMethodDispatcher` for each supported `tf.RaggedTensor`
# factory class method, keyed on the class method itself.
for ragged_class_method in (
    "from_value_rowids",
    "from_row_splits",
    "from_row_lengths",
    "from_row_starts",
    "from_row_limits",
    "from_uniform_row_length",
    "from_nested_value_rowids",
    "from_nested_row_splits",
    "from_nested_row_lengths",
    "from_tensor",
    "from_sparse",
):
    TFClassMethodDispatcher(tf.RaggedTensor, ragged_class_method).register(
        getattr(tf.RaggedTensor, ragged_class_method)
    )

481 

482 

class SlicingOpLambda(TFOpLambda):
    """Wraps TF slicing API symbols in a `Layer` object.

    It is inserted by the Functional API construction whenever users call a
    supported slicing op (such as `tf.__operators__.getitem` or
    `tf.boolean_mask`) on KerasTensors, e.g.:
        x = keras.Input(...)
        mask = keras.Input(...)
        out = tf.boolean_mask(x, mask)

    `TFSlicingOpDispatcher` encodes every `slice` argument as a
    `{"start", "stop", "step"}` dict before calling this layer (presumably so
    the arguments can be kept in the layer's call-argument config). This
    layer converts those dicts back into `slice` objects before running the
    wrapped op. Variable checks are inherited from `TFOpLambda`.
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, function, **kwargs):
        super().__init__(function, **kwargs)

        original_call = self.call

        def _restore_slices(value):
            # Turn slice dicts back into `slice` objects, at the top level
            # and one level deep inside a list/tuple (tuples come back as
            # lists). This cannot use nest.flatten/map_structure: nest
            # flattens dicts while slices are leaves, so map_structure would
            # only see the individual dict elements. map_structure_up_to
            # doesn't fit either, because the shallow tree's depth would
            # depend on whether one or several dims are sliced.
            value = _dict_to_slice(value)
            if isinstance(value, (list, tuple)):
                value = [_dict_to_slice(item) for item in value]
            return value

        # Decorate the function to produce this layer's call method
        def _call_wrapper(*args, **kwargs):
            new_args = [_restore_slices(arg) for arg in args]
            new_kwargs = {
                key: _restore_slices(value) for key, value in kwargs.items()
            }
            return original_call(*new_args, **new_kwargs)

        self.call = tf.__internal__.decorator.make_decorator(
            original_call, _call_wrapper
        )

541 

542 

543def _slice_to_dict(x): 

544 if isinstance(x, slice): 

545 return {"start": x.start, "stop": x.stop, "step": x.step} 

546 return x 

547 

548 

549def _dict_to_slice(x): 

550 if isinstance(x, dict): 

551 return slice(x["start"], x["stop"], x["step"]) 

552 return x 

553 

554 

class TFSlicingOpDispatcher(tf.__internal__.dispatch.OpDispatcher):
    """Per-op dispatcher that turns TF slicing-op calls into layers.

    Calls that involve a `KerasTensor` (including inside `slice` bounds) are
    replaced by a `SlicingOpLambda` layer wrapping the op, which allows
    Functional models to be built from slicing ops.

    Args:
        op: The TF slicing op being dispatched.
    """

    def __init__(self, op):
        self.op = op

    def handle(self, args, kwargs):
        """Dispatch the slicing-op call, or decline without KerasTensors.

        Returns:
            The output of a `SlicingOpLambda` layer, or `self.NOT_SUPPORTED`.
        """
        # Encode slices first: once they are dicts, nest flattens into their
        # start/stop/step values, so KerasTensors used as slice bounds are
        # detected by the check below.
        encoded_args = tf.nest.map_structure(_slice_to_dict, args)
        encoded_kwargs = tf.nest.map_structure(_slice_to_dict, kwargs)
        leaves = tf.nest.flatten([encoded_args, encoded_kwargs])
        if not any(isinstance(leaf, keras_tensor.KerasTensor) for leaf in leaves):
            return self.NOT_SUPPORTED
        return SlicingOpLambda(self.op)(*encoded_args, **encoded_kwargs)

573 

574 

# Register a `TFSlicingOpDispatcher` for each supported slicing op, keyed on
# the op itself.
for slicing_op in (
    tf.__operators__.getitem,
    tf.compat.v1.boolean_mask,
    tf.boolean_mask,
    tf.__operators__.ragged_getitem,
):
    TFSlicingOpDispatcher(slicing_op).register(slicing_op)

582