Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/keras_tensor.py: 33%

262 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Keras Input Tensor used to track functional API Topology.""" 

16 

17import tensorflow.compat.v2 as tf 

18 

19from keras.src.utils import object_identity 

20 

21# isort: off 

22from tensorflow.python.data.util import structure 

23from tensorflow.python.util.tf_export import keras_export 

24 

25 

# Tensorflow tensors have a maximum rank of 254
# (See `MaxDimensions()` in //tensorflow/core/framework/tensor_shape.h )
# So we do not try to infer values for int32 tensors larger than this,
# As they cannot represent shapes.
# Used by `KerasTensor.from_tensor` to discard inferred shape values that
# are too long to ever be a valid shape.
_MAX_TENSOR_RANK = 254

31 

32 

@keras_export("keras.__internal__.KerasTensor", v1=[])
class KerasTensor:
    """A representation of a Keras in/output during Functional API construction.

    `KerasTensor`s are tensor-like objects that represent the symbolic inputs
    and outputs of Keras layers during Functional model construction. They are
    comprised of the `tf.TypeSpec` of the (Composite)Tensor that will be
    consumed/produced in the corresponding location of the Functional model.

    KerasTensors are intended as a private API, so users should never need to
    directly instantiate `KerasTensor`s.

    **Building Functional Models with KerasTensors**
    `tf.keras.Input` produces `KerasTensor`s that represent the symbolic inputs
    to your model.

    Passing a `KerasTensor` to a `tf.keras.Layer` `__call__` lets the layer know
    that you are building a Functional model. The layer __call__ will
    infer the output signature and return `KerasTensor`s with `tf.TypeSpec`s
    corresponding to the symbolic outputs of that layer call. These output
    `KerasTensor`s will have all of the internal KerasHistory metadata attached
    to them that Keras needs to construct a Functional Model.

    Currently, layers infer the output signature by:
      * creating a scratch `FuncGraph`
      * making placeholders in the scratch graph that match the input typespecs
      * Calling `layer.call` on these placeholders
      * extracting the signatures of the outputs before clearing the scratch
        graph

    (Note: names assigned to KerasTensors by this process are not guaranteed to
    be unique, and are subject to implementation details).

    `tf.nest` methods are used to ensure all of the inputs/output data
    structures get maintained, with elements swapped between KerasTensors and
    placeholders.

    In rare cases (such as when directly manipulating shapes using Keras
    layers), the layer may be able to partially infer the value of the output in
    addition to just inferring the signature.
    When this happens, the returned KerasTensor will also contain the inferred
    value information. Follow-on layers can use this information
    during their own output signature inference.
    E.g. if one layer produces a symbolic `KerasTensor` that the next layer uses
    as the shape of its outputs, partially knowing the value helps infer the
    output shape.

    **Automatically converting TF APIs to layers**:
    If you pass a `KerasTensor` to a TF API that supports dispatching,
    Keras will automatically turn that API call into a lambda
    layer in the Functional model, and return KerasTensors representing the
    symbolic outputs.

    Most TF APIs that take only tensors as input and produce output tensors
    will support dispatching.

    Calling a `tf.function` does not support dispatching, so you cannot pass
    `KerasTensor`s as inputs to a `tf.function`.

    Higher-order APIs that take methods which produce tensors (e.g. `tf.while`,
    `tf.map_fn`, `tf.cond`) also do not currently support dispatching. So, you
    cannot directly pass KerasTensors as inputs to these APIs either. If you
    want to use these APIs inside of a Functional model, you must put them
    inside of a custom layer.

    Args:
      type_spec: The `tf.TypeSpec` for the symbolic input created by
        `tf.keras.Input`, or symbolically inferred for the output
        during a symbolic layer `__call__`.
      inferred_value: (Optional) a non-symbolic static value, possibly partially
        specified, that could be symbolically inferred for the outputs during
        a symbolic layer `__call__`. This will generally only happen when
        grabbing and manipulating `tf.int32` shapes directly as tensors.
        Statically inferring values in this way and storing them in the
        KerasTensor allows follow-on layers to infer output signatures
        more effectively. (e.g. when using a symbolic shape tensor to later
        construct a tensor with that shape).
      name: (optional) string name for this KerasTensor. Names automatically
        generated by symbolic layer `__call__`s are not guaranteed to be unique,
        and are subject to implementation details.
    """

    def __init__(self, type_spec, inferred_value=None, name=None):
        """Constructs a KerasTensor."""
        if not isinstance(type_spec, tf.TypeSpec):
            raise ValueError(
                "KerasTensors must be constructed with a `tf.TypeSpec`."
            )

        self._type_spec = type_spec
        self._inferred_value = inferred_value
        self._name = name

        # `NoneTensorSpec` stands in for a `None` value in a nested
        # structure and has no shape, so it is exempt from the checks.
        if not isinstance(type_spec, structure.NoneTensorSpec):
            if not hasattr(type_spec, "shape"):
                raise ValueError(
                    "KerasTensor only supports TypeSpecs that have a shape "
                    f"field; got {type(type_spec).__qualname__}, "
                    "which does not have a shape."
                )
            if not isinstance(type_spec.shape, tf.TensorShape):
                raise TypeError(
                    "KerasTensor requires that wrapped TypeSpec's shape is a "
                    f"TensorShape; got TypeSpec {type(type_spec).__qualname__}"
                    ", whose shape field has unexpected type "
                    # Bug fix: report the type of the offending `shape` field
                    # (this previously printed `type(type_spec.dtype)`,
                    # copy-pasted from the dtype check in the `dtype`
                    # property below).
                    f"{type(type_spec.shape).__qualname__}."
                )

    @property
    def type_spec(self):
        """Returns the `tf.TypeSpec` symbolically inferred for Keras output."""
        return self._type_spec

    @property
    def shape(self):
        """Returns the `TensorShape` symbolically inferred for Keras output."""
        return self._type_spec.shape

    @classmethod
    def from_tensor(cls, tensor):
        """Convert a traced (composite)tensor to a representative
        KerasTensor."""
        if isinstance(tensor, tf.Tensor):
            name = getattr(tensor, "name", None)
            type_spec = tf.type_spec_from_value(tensor)
            inferred_value = None
            if (
                type_spec.dtype == tf.int32
                and type_spec.shape.rank is not None
                and type_spec.shape.rank < 2
            ):
                # If this tensor might be representing shape information,
                # (dtype=int32, rank of 0 or 1, not too large to represent a
                # shape) we attempt to capture any value information
                # tensorflow's shape handling can extract from the current
                # scratch graph.
                #
                # Even though keras layers each trace in their own scratch
                # graph, this shape value info extraction allows us to capture a
                # sizable and useful subset of the C++ shape value inference TF
                # can do if all tf ops appear in the same graph when using shape
                # ops.
                #
                # Examples of things this cannot infer concrete dimensions for
                # that the full single-graph C++ shape inference sometimes can
                # are:
                # * cases where the shape tensor is cast out of int32 before
                #   being manipulated w/ floating point numbers then converted
                #   back
                # * cases where int32 tensors w/ rank >= 2 are manipulated
                #   before being used as a shape tensor
                # * cases where int32 tensors too large to represent shapes are
                #   manipulated to a smaller size before being used as a shape
                #   tensor
                inferred_value = tf.ones(shape=tensor).shape
                if inferred_value.dims:
                    inferred_value = inferred_value.as_list()
                    if len(inferred_value) > _MAX_TENSOR_RANK:
                        inferred_value = None
                else:
                    inferred_value = None

            # NOTE: intentionally `KerasTensor`, not `cls` — a dense
            # `tf.Tensor` always maps to the base representation.
            return KerasTensor(
                type_spec, inferred_value=inferred_value, name=name
            )
        else:
            # Fallback to the generic arbitrary-typespec KerasTensor
            name = getattr(tensor, "name", None)
            type_spec = tf.type_spec_from_value(tensor)
            return cls(type_spec, name=name)

    @classmethod
    def from_type_spec(cls, type_spec, name=None):
        """Alternate constructor: build directly from a `tf.TypeSpec`."""
        return cls(type_spec=type_spec, name=name)

    def _to_placeholder(self):
        """Convert this KerasTensor to a placeholder in a graph."""
        # If there is an inferred value for this tensor, inject the inferred
        # value
        if self._inferred_value is not None:
            # If we suspect this KerasTensor might be representing a shape
            # tensor, and we were able to extract value information with
            # TensorFlow's shape handling when making the KerasTensor, we
            # construct the placeholder by re-injecting the inferred value
            # information into the graph. We do this injection through the shape
            # of a placeholder, because that allows us to specify
            # partially-unspecified shape values.
            #
            # See the comment on value extraction inside `from_tensor` for more
            # info.
            inferred_value = tf.shape(
                tf.compat.v1.placeholder(
                    shape=self._inferred_value, dtype=tf.int32
                )
            )
            if self.type_spec.shape.rank == 0:
                # `tf.shape` always returns a rank-1, we may need to turn it
                # back to a scalar.
                inferred_value = inferred_value[0]
            return inferred_value

        # Use the generic conversion from typespec to a placeholder.
        def component_to_placeholder(component):
            return tf.compat.v1.placeholder(component.dtype, component.shape)

        return tf.nest.map_structure(
            component_to_placeholder, self.type_spec, expand_composites=True
        )

    def get_shape(self):
        """Alias for `self.shape`, mirroring the `tf.Tensor` API."""
        return self.shape

    def __len__(self):
        raise TypeError(
            "Keras symbolic inputs/outputs do not "
            "implement `__len__`. You may be "
            "trying to pass Keras symbolic inputs/outputs "
            "to a TF API that does not register dispatching, "
            "preventing Keras from automatically "
            "converting the API call to a lambda layer "
            "in the Functional Model. This error will also get raised "
            "if you try asserting a symbolic input/output directly."
        )

    @property
    def op(self):
        raise TypeError(
            "Keras symbolic inputs/outputs do not "
            "implement `op`. You may be "
            "trying to pass Keras symbolic inputs/outputs "
            "to a TF API that does not register dispatching, "
            "preventing Keras from automatically "
            "converting the API call to a lambda layer "
            "in the Functional Model."
        )

    def __hash__(self):
        raise TypeError(
            f"Tensors are unhashable (this tensor: {self}). "
            "Instead, use tensor.ref() as the key."
        )

    # Note: This enables the KerasTensor's overloaded "right" binary
    # operators to run when the left operand is an ndarray, because it
    # accords the Tensor class higher priority than an ndarray, or a
    # numpy matrix.
    # In the future explore changing this to using numpy's __numpy_ufunc__
    # mechanism, which allows more control over how Tensors interact
    # with ndarrays.
    __array_priority__ = 100

    def __array__(self, dtype=None):
        raise TypeError(
            f"You are passing {self}, an intermediate Keras symbolic "
            "input/output, to a TF API that does not allow registering custom "
            "dispatchers, such as `tf.cond`, `tf.function`, gradient tapes, "
            "or `tf.map_fn`. Keras Functional model construction only supports "
            "TF API calls that *do* support dispatching, such as `tf.math.add` "
            "or `tf.reshape`. "
            # Bug fix: added the missing trailing space so the message does
            # not render as "symbolic Kerasinputs/outputs".
            "Other APIs cannot be called directly on symbolic Keras "
            "inputs/outputs. You can work around "
            "this limitation by putting the operation in a custom Keras layer "
            "`call` and calling that layer "
            "on this symbolic input/output."
        )

    @property
    def is_tensor_like(self):
        return True

    def set_shape(self, shape):
        """Updates the shape of this KerasTensor. Mimics
        `tf.Tensor.set_shape()`."""
        if not isinstance(shape, tf.TensorShape):
            shape = tf.TensorShape(shape)
        if not self.shape.is_compatible_with(shape):
            raise ValueError(
                f"Keras symbolic input/output's shape {self.shape} is not "
                f"compatible with supplied shape {shape}."
            )
        else:
            shape = self.shape.merge_with(shape)
        self._type_spec = type_spec_with_shape(self._type_spec, shape)

    def __str__(self):
        symbolic_description = ""
        inferred_value_string = ""
        name_string = ""

        if hasattr(self, "_keras_history"):
            layer = self._keras_history.layer
            symbolic_description = ", description=\"created by layer '%s'\"" % (
                layer.name,
            )
        if self._inferred_value is not None:
            inferred_value_string = f", inferred_value={self._inferred_value}"
        if self.name is not None:
            name_string = f", name='{self._name}'"
        return "KerasTensor(type_spec=%s%s%s%s)" % (
            self.type_spec,
            inferred_value_string,
            name_string,
            symbolic_description,
        )

    def __repr__(self):
        symbolic_description = ""
        inferred_value_string = ""
        if isinstance(self.type_spec, tf.TensorSpec):
            type_spec_string = f"shape={self.shape} dtype={self.dtype.name}"
        else:
            type_spec_string = f"type_spec={self.type_spec}"

        if hasattr(self, "_keras_history"):
            layer = self._keras_history.layer
            symbolic_description = f" (created by layer '{layer.name}')"
        if self._inferred_value is not None:
            inferred_value_string = f" inferred_value={self._inferred_value}"
        return "<KerasTensor: %s%s%s>" % (
            type_spec_string,
            inferred_value_string,
            symbolic_description,
        )

    @property
    def dtype(self):
        """Returns the `dtype` symbolically inferred for this Keras output."""
        type_spec = self._type_spec
        if not hasattr(type_spec, "dtype"):
            raise AttributeError(
                f"KerasTensor wraps TypeSpec {type(type_spec).__qualname__}, "
                "which does not have a dtype."
            )
        if not isinstance(type_spec.dtype, tf.DType):
            raise TypeError(
                "KerasTensor requires that wrapped TypeSpec's dtype is a "
                f"DType; got TypeSpec {type(type_spec).__qualname__}, whose "
                "dtype field has unexpected type "
                f"{type(type_spec.dtype).__qualname__}."
            )
        return type_spec.dtype

    def ref(self):
        """Returns a hashable reference object to this KerasTensor.

        The primary use case for this API is to put KerasTensors in a
        set/dictionary. We can't put tensors in a set/dictionary as
        `tensor.__hash__()` is not available and tensor equality (`==`) is
        supposed to produce a tensor representing if the two inputs are equal.

        See the documentation of `tf.Tensor.ref()` for more info.
        """
        return object_identity.Reference(self)

    @property
    def node(self):
        """Find the corresponding `Node` that produce this keras_tensor.

        During functional model construction, Keras will attach `KerasHistory`
        to keras tensor to track the connectivity between calls of layers.
        Return None if there isn't any KerasHistory attached to this tensor.
        """
        if hasattr(self, "_keras_history"):
            layer, node_index, _ = self._keras_history
            return layer.inbound_nodes[node_index]
        return None

    def __iter__(self):
        shape = None
        if self.shape.ndims is not None:
            shape = [dim.value for dim in self.shape.dims]

        if shape is None:
            raise TypeError("Cannot iterate over a Tensor with unknown shape.")
        if not shape:
            raise TypeError("Cannot iterate over a scalar.")
        if shape[0] is None:
            raise TypeError(
                "Cannot iterate over a Tensor with unknown first dimension."
            )
        return _KerasTensorIterator(self, shape[0])

    @property
    def name(self):
        """Returns the (non-unique, optional) name of this symbolic Keras
        value."""
        return self._name

    @classmethod
    def _overload_all_operators(cls, tensor_class):
        """Register overloads for all operators."""
        for operator in tf.Tensor.OVERLOADABLE_OPERATORS:
            cls._overload_operator(tensor_class, operator)

        # We include `experimental_ref` for versions of TensorFlow that
        # still include the deprecated method in Tensors.
        if hasattr(tensor_class, "experimental_ref"):
            cls._overload_operator(tensor_class, "experimental_ref")

    @classmethod
    def _overload_operator(cls, tensor_class, operator):
        """Overload operator with the same implementation as the Tensor class.

        We pull the operator out of the class dynamically to avoid ordering
        issues.

        Args:
          tensor_class: The (Composite)Tensor to get the method from.
          operator: string. The operator name.
        """
        tensor_oper = getattr(tensor_class, operator)

        # Compatibility with Python 2:
        # Python 2 unbound methods have type checks for the first arg,
        # so we need to extract the underlying function
        tensor_oper = getattr(tensor_oper, "__func__", tensor_oper)

        setattr(cls, operator, tensor_oper)

451 

452 

# Install `tf.Tensor`'s overloadable operators (see
# `tf.Tensor.OVERLOADABLE_OPERATORS` in `_overload_all_operators`) on the
# base KerasTensor class so symbolic tensors support `+`, `[]`, etc.
KerasTensor._overload_all_operators(tf.Tensor)

454 

455 

@keras_export("keras.__internal__.SparseKerasTensor", v1=[])
class SparseKerasTensor(KerasTensor):
    """A KerasTensor specialization for `tf.sparse.SparseTensor` values.

    Its only behavioral difference from the base class is placeholder
    creation, which is special-cased so that dense shape information is
    kept.
    """

    def _to_placeholder(self):
        # The generic base implementation goes through
        # `tf.nest.map_structure`, which loses dense shape information for
        # sparse tensors, so build the sparse placeholder directly from
        # this tensor's spec.
        # This only preserves shape information for top-level sparse
        # tensors; not for sparse tensors that are nested inside another
        # composite tensor.
        type_spec = self.type_spec
        return tf.compat.v1.sparse_placeholder(
            dtype=type_spec.dtype, shape=type_spec.shape
        )

475 

476 

@keras_export("keras.__internal__.RaggedKerasTensor", v1=[])
class RaggedKerasTensor(KerasTensor):
    """A specialized KerasTensor representation for `tf.RaggedTensor`s.

    Specifically, it:

    1. Specializes the conversion to a placeholder in order
    to maintain shape information for non-ragged dimensions.
    2. Overloads the KerasTensor's operators with the RaggedTensor versions
    when they don't match the `tf.Tensor` versions
    3. Exposes some of the instance method/attribute that are unique to
    the RaggedTensor API (such as ragged_rank).
    """

    def _to_placeholder(self):
        ragged_spec = self.type_spec
        # With no ragged dimensions (or a fully unknown rank) the generic
        # base-class placeholder logic is sufficient.
        if ragged_spec.ragged_rank == 0 or ragged_spec.shape.rank is None:
            return super()._to_placeholder()

        # Start from a dense placeholder holding the flat values: the
        # trailing (non-ragged) dimensions of the spec's shape.
        flat_shape = ragged_spec.shape[ragged_spec.ragged_rank :]
        result = tf.compat.v1.placeholder(ragged_spec.dtype, flat_shape)

        # known_num_splits[i] is the product of axis sizes 0..i when all of
        # them are statically known, otherwise None. An axis size may be an
        # int, None, or a Dimension object whose `.value` is None; any
        # unknown size makes the running product (and all later entries)
        # None.
        known_num_splits = []
        prod = 1
        for axis_size in ragged_spec.shape:
            if prod is not None:
                if axis_size is None or (
                    getattr(axis_size, "value", True) is None
                ):
                    prod = None
                else:
                    prod = prod * axis_size
            known_num_splits.append(prod)

        # Re-wrap the flat values with one ragged dimension per axis,
        # innermost axis first.
        for axis in range(ragged_spec.ragged_rank, 0, -1):
            axis_size = ragged_spec.shape[axis]
            if axis_size is None or (getattr(axis_size, "value", True) is None):
                # Unknown (truly ragged) axis: feed the row splits via a
                # placeholder. A row_splits vector has one more entry than
                # the number of rows, hence the +1 when the row count is
                # statically known.
                num_splits = known_num_splits[axis - 1]
                if num_splits is not None:
                    num_splits = num_splits + 1
                splits = tf.compat.v1.placeholder(
                    ragged_spec.row_splits_dtype, [num_splits]
                )
                result = tf.RaggedTensor.from_row_splits(
                    result, splits, validate=False
                )
            else:
                # Statically-known axis: encode it as a uniform row length
                # instead of explicit splits.
                rowlen = tf.constant(axis_size, ragged_spec.row_splits_dtype)
                result = tf.RaggedTensor.from_uniform_row_length(
                    result, rowlen, validate=False
                )
        return result

    @property
    def ragged_rank(self):
        # Mirrors `tf.RaggedTensor.ragged_rank` by delegating to the spec.
        return self.type_spec.ragged_rank

533 

534 

# Replace selected operators inherited from the base class (which uses the
# `tf.Tensor` implementations) with the `tf.RaggedTensor` versions.
# Overload slicing
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__getitem__")

# Overload math ops
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__add__")
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__radd__")
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__mul__")
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__rmul__")

543 

544 

545# TODO(b/161487382): 

546# Special-case user-registered symbolic objects (registered by the 

547# private `register_symbolic_tensor_type` method) by passing them between 

548# scratch graphs directly. 

549# This is needed to not break Tensorflow probability 

550# while they finish migrating to composite tensors. 

class UserRegisteredSpec(tf.TypeSpec):
    """TypeSpec to represent user-registered symbolic objects."""

    def __init__(self, shape, dtype):
        # Record the wrapped object's shape and dtype. `_dtype` mirrors
        # `dtype`; both are set for compatibility with code that reads
        # either attribute.
        self.shape = shape
        self.dtype = dtype
        self._dtype = dtype

    # These legacy symbolic objects are only carried around, never
    # decomposed or serialized, so none of the TypeSpec hooks are
    # implemented.
    def value_type(self):
        raise NotImplementedError

    def _component_specs(self):
        raise NotImplementedError

    def _to_components(self, value):
        raise NotImplementedError

    def _from_components(self, components):
        raise NotImplementedError

    def _serialize(self):
        raise NotImplementedError

573 

574 

575# TODO(b/161487382): 

576# Special-case user-registered symbolic objects (registered by the 

577# private `register_symbolic_tensor_type` method) by passing them between 

578# scratch graphs directly. 

579# This is needed to not break Tensorflow probability 

580# while they finish migrating to composite tensors. 

class UserRegisteredTypeKerasTensor(KerasTensor):
    """KerasTensor that represents legacy register_symbolic_tensor_type."""

    def __init__(self, user_registered_symbolic_object):
        """Wraps a user-registered symbolic object.

        Args:
          user_registered_symbolic_object: an object with `shape` and
            `dtype` attributes (and optionally `name`), previously
            registered via `register_symbolic_tensor_type`.
        """
        x = user_registered_symbolic_object
        self._user_registered_symbolic_object = x
        type_spec = UserRegisteredSpec(x.shape, x.dtype)
        name = getattr(x, "name", None)

        # Bug fix: pass `name` by keyword. `KerasTensor.__init__`'s second
        # positional parameter is `inferred_value`, so the previous
        # positional call `super().__init__(type_spec, name)` stored the
        # name in `_inferred_value` and left `_name` as None.
        super().__init__(type_spec, name=name)

    @classmethod
    def from_tensor(cls, tensor):
        # The registered object itself is the "tensor"; wrap it directly.
        return cls(tensor)

    @classmethod
    def from_type_spec(cls, type_spec, name=None):
        raise NotImplementedError(
            "You cannot instantiate a KerasTensor directly from TypeSpec: %s"
            % type_spec
        )

    def _to_placeholder(self):
        # The registered symbolic object is passed between scratch graphs
        # as-is instead of being replaced by a placeholder.
        return self._user_registered_symbolic_object

605 

606 

class _KerasTensorIterator:
    """Iterates over the leading dim of a KerasTensor. Performs 0 error
    checks."""

    def __init__(self, tensor, dim0):
        # `dim0` is the size of the leading dimension; iteration yields
        # exactly that many indexed slices of `tensor`.
        self._index = 0
        self._limit = dim0
        self._tensor = tensor

    def __iter__(self):
        return self

    def __next__(self):
        if self._index == self._limit:
            raise StopIteration
        item = self._tensor[self._index]
        self._index += 1
        return item

625 

626 

# Specify the mappings of tensor class to KerasTensor class.
# This is specifically a list instead of a dict for now because
# 1. we do a check w/ isinstance because a key lookup based on class
#    would miss subclasses
# 2. a list allows us to control lookup ordering
# We include ops.Tensor -> KerasTensor in the first position as a fastpath,
# *and* include object -> KerasTensor at the end as a catch-all.
# We can re-visit these choices in the future as needed.
# Consumed (first match wins) by `keras_tensor_from_tensor` and
# `keras_tensor_from_type_spec`; extended via
# `register_keras_tensor_specialization`.
keras_tensor_classes = [
    (tf.Tensor, KerasTensor),
    (tf.SparseTensor, SparseKerasTensor),
    (tf.RaggedTensor, RaggedKerasTensor),
    (object, KerasTensor),
]

641 

642 

def register_keras_tensor_specialization(cls, keras_tensor_subclass):
    """Register a specialized KerasTensor subclass for a Tensor type."""
    # Insert just before the final entry so the generic
    # `(object, KerasTensor)` fallback always remains last in lookup order.
    position = len(keras_tensor_classes) - 1
    keras_tensor_classes.insert(position, (cls, keras_tensor_subclass))

647 

648 

def keras_tensor_to_placeholder(x):
    """Construct a graph placeholder to represent a KerasTensor when tracing."""
    # Anything that is not a KerasTensor (constants, plain Python values)
    # passes through unchanged.
    if not isinstance(x, KerasTensor):
        return x
    return x._to_placeholder()

655 

656 

def keras_tensor_from_tensor(tensor):
    """Convert a traced (composite)tensor to a representative KerasTensor."""
    # Pick the most specific KerasTensor subclass registered for this
    # tensor's type; the trailing `(object, KerasTensor)` entry in
    # `keras_tensor_classes` guarantees a match. The specialized class
    # supplies instance methods, operators, and any extra value inference.
    keras_tensor_cls = next(
        cls
        for tensor_type, cls in keras_tensor_classes
        if isinstance(tensor, tensor_type)
    )

    out = keras_tensor_cls.from_tensor(tensor)

    # Carry over any Keras mask attached to the traced tensor, converted
    # to its own KerasTensor.
    mask = getattr(tensor, "_keras_mask", None)
    if mask is not None:
        out._keras_mask = keras_tensor_from_tensor(mask)
    return out

672 

673 

def keras_tensor_from_type_spec(type_spec, name=None):
    """Convert a TypeSpec to a representative KerasTensor."""
    # Match on the spec's value type rather than the spec itself; the
    # trailing `(object, KerasTensor)` entry in `keras_tensor_classes`
    # guarantees a match. The specialized class supplies instance methods,
    # operators, and any extra value inference.
    value_type = type_spec.value_type
    keras_tensor_cls = next(
        cls
        for tensor_type, cls in keras_tensor_classes
        if issubclass(value_type, tensor_type)
    )

    return keras_tensor_cls.from_type_spec(type_spec, name=name)

686 

687 

def type_spec_with_shape(spec, shape):
    """Returns a copy of TypeSpec `spec` with its shape set to `shape`."""
    if isinstance(spec, tf.TensorSpec):
        # TODO(b/203201161) Figure out why mutation is needed here, and remove
        # it. (TensorSpec objects should be immutable; and we should not be
        # modifying private fields.)
        spec._shape = tf.TensorShape(shape)
        return spec
    if isinstance(spec, tf.RaggedTensorSpec):
        return tf.RaggedTensorSpec(
            shape,
            spec.dtype,
            spec.ragged_rank,
            spec.row_splits_dtype,
            spec.flat_values_spec,
        )
    if isinstance(spec, tf.SparseTensorSpec):
        return tf.SparseTensorSpec(shape, spec.dtype)
    if hasattr(spec, "with_shape"):
        # TODO(edloper): Consider adding .with_shape method to TensorSpec,
        # RaggedTensorSpec, and SparseTensorSpec.
        return spec.with_shape(shape)
    # TODO(edloper): Consider moving this check to the KerasTensor
    # constructor.
    raise ValueError(
        "Keras requires TypeSpec to have a `with_shape` method "
        "that returns a copy of `self` with an updated shape."
    )

719