Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/engine/keras_tensor.py: 33%
262 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras Input Tensor used to track functional API Topology."""
17import tensorflow.compat.v2 as tf
19from keras.src.utils import object_identity
21# isort: off
22from tensorflow.python.data.util import structure
23from tensorflow.python.util.tf_export import keras_export
# TensorFlow tensors have a maximum rank of 254
# (see `MaxDimensions()` in //tensorflow/core/framework/tensor_shape.h).
# So we do not try to infer values for int32 tensors with more elements than
# this, as they cannot represent shapes.
_MAX_TENSOR_RANK = 254
@keras_export("keras.__internal__.KerasTensor", v1=[])
class KerasTensor:
    """A representation of a Keras in/output during Functional API construction.

    `KerasTensor`s are tensor-like objects that represent the symbolic inputs
    and outputs of Keras layers during Functional model construction. They are
    comprised of the `tf.TypeSpec` of the (Composite)Tensor that will be
    consumed/produced in the corresponding location of the Functional model.

    KerasTensors are intended as a private API, so users should never need to
    directly instantiate `KerasTensor`s.

    **Building Functional Models with KerasTensors**
    `tf.keras.Input` produces `KerasTensor`s that represent the symbolic inputs
    to your model.

    Passing a `KerasTensor` to a `tf.keras.Layer` `__call__` lets the layer know
    that you are building a Functional model. The layer __call__ will
    infer the output signature and return `KerasTensor`s with `tf.TypeSpec`s
    corresponding to the symbolic outputs of that layer call. These output
    `KerasTensor`s will have all of the internal KerasHistory metadata attached
    to them that Keras needs to construct a Functional Model.

    Currently, layers infer the output signature by:
      * creating a scratch `FuncGraph`
      * making placeholders in the scratch graph that match the input typespecs
      * Calling `layer.call` on these placeholders
      * extracting the signatures of the outputs before clearing the scratch
        graph

    (Note: names assigned to KerasTensors by this process are not guaranteed to
    be unique, and are subject to implementation details).

    `tf.nest` methods are used to ensure all of the inputs/output data
    structures get maintained, with elements swapped between KerasTensors and
    placeholders.

    In rare cases (such as when directly manipulating shapes using Keras
    layers), the layer may be able to partially infer the value of the output in
    addition to just inferring the signature.
    When this happens, the returned KerasTensor will also contain the inferred
    value information. Follow-on layers can use this information
    during their own output signature inference.
    E.g. if one layer produces a symbolic `KerasTensor` that the next layer uses
    as the shape of its outputs, partially knowing the value helps infer the
    output shape.

    **Automatically converting TF APIs to layers**:
    If you pass a `KerasTensor` to a TF API that supports dispatching,
    Keras will automatically turn that API call into a lambda
    layer in the Functional model, and return KerasTensors representing the
    symbolic outputs.

    Most TF APIs that take only tensors as input and produce output tensors
    will support dispatching.

    Calling a `tf.function` does not support dispatching, so you cannot pass
    `KerasTensor`s as inputs to a `tf.function`.

    Higher-order APIs that take methods which produce tensors (e.g. `tf.while`,
    `tf.map_fn`, `tf.cond`) also do not currently support dispatching. So, you
    cannot directly pass KerasTensors as inputs to these APIs either. If you
    want to use these APIs inside of a Functional model, you must put them
    inside of a custom layer.

    Args:
      type_spec: The `tf.TypeSpec` for the symbolic input created by
        `tf.keras.Input`, or symbolically inferred for the output
        during a symbolic layer `__call__`.
      inferred_value: (Optional) a non-symbolic static value, possibly partially
        specified, that could be symbolically inferred for the outputs during
        a symbolic layer `__call__`. This will generally only happen when
        grabbing and manipulating `tf.int32` shapes directly as tensors.
        Statically inferring values in this way and storing them in the
        KerasTensor allows follow-on layers to infer output signatures
        more effectively. (e.g. when using a symbolic shape tensor to later
        construct a tensor with that shape).
      name: (optional) string name for this KerasTensor. Names automatically
        generated by symbolic layer `__call__`s are not guaranteed to be unique,
        and are subject to implementation details.
    """

    def __init__(self, type_spec, inferred_value=None, name=None):
        """Constructs a KerasTensor.

        Raises:
          ValueError: If `type_spec` is not a `tf.TypeSpec`, or if it is not a
            `NoneTensorSpec` and has no `shape` attribute.
          TypeError: If `type_spec.shape` exists but is not a
            `tf.TensorShape`.
        """
        if not isinstance(type_spec, tf.TypeSpec):
            raise ValueError(
                "KerasTensors must be constructed with a `tf.TypeSpec`."
            )

        self._type_spec = type_spec
        self._inferred_value = inferred_value
        self._name = name

        # `NoneTensorSpec` is exempt from the shape requirements below.
        if not isinstance(type_spec, structure.NoneTensorSpec):
            if not hasattr(type_spec, "shape"):
                raise ValueError(
                    "KerasTensor only supports TypeSpecs that have a shape "
                    f"field; got {type(type_spec).__qualname__}, "
                    "which does not have a shape."
                )
            if not isinstance(type_spec.shape, tf.TensorShape):
                # Report the type of the offending `shape` field itself (not
                # of `dtype`, which may not even exist on this spec).
                raise TypeError(
                    "KerasTensor requires that wrapped TypeSpec's shape is a "
                    f"TensorShape; got TypeSpec {type(type_spec).__qualname__}"
                    ", whose shape field has unexpected type "
                    f"{type(type_spec.shape).__qualname__}."
                )

    @property
    def type_spec(self):
        """Returns the `tf.TypeSpec` symbolically inferred for Keras output."""
        return self._type_spec

    @property
    def shape(self):
        """Returns the `TensorShape` symbolically inferred for Keras output."""
        return self._type_spec.shape

    @classmethod
    def from_tensor(cls, tensor):
        """Convert a traced (composite)tensor to a representative
        KerasTensor."""
        if isinstance(tensor, tf.Tensor):
            name = getattr(tensor, "name", None)
            type_spec = tf.type_spec_from_value(tensor)
            inferred_value = None
            if (
                type_spec.dtype == tf.int32
                and type_spec.shape.rank is not None
                and type_spec.shape.rank < 2
            ):
                # If this tensor might be representing shape information,
                # (dtype=int32, rank of 0 or 1, not too large to represent a
                # shape) we attempt to capture any value information
                # tensorflow's shape handling can extract from the current
                # scratch graph.
                #
                # Even though keras layers each trace in their own scratch
                # graph, this shape value info extraction allows us to capture a
                # sizable and useful subset of the C++ shape value inference TF
                # can do if all tf ops appear in the same graph when using shape
                # ops.
                #
                # Examples of things this cannot infer concrete dimensions for
                # that the full single-graph C++ shape inference sometimes can
                # are:
                # * cases where the shape tensor is cast out of int32 before
                #   being manipulated w/ floating point numbers then converted
                #   back
                # * cases where int32 tensors w/ rank >= 2 are manipulated
                #   before being used as a shape tensor
                # * cases where int32 tensors too large to represent shapes are
                #   manipulated to a smaller size before being used as a shape
                #   tensor
                inferred_value = tf.ones(shape=tensor).shape
                if inferred_value.dims:
                    inferred_value = inferred_value.as_list()
                    if len(inferred_value) > _MAX_TENSOR_RANK:
                        inferred_value = None
                else:
                    inferred_value = None

            return KerasTensor(
                type_spec, inferred_value=inferred_value, name=name
            )
        else:
            # Fallback to the generic arbitrary-typespec KerasTensor
            name = getattr(tensor, "name", None)
            type_spec = tf.type_spec_from_value(tensor)
            return cls(type_spec, name=name)

    @classmethod
    def from_type_spec(cls, type_spec, name=None):
        """Construct an instance of `cls` directly from a `tf.TypeSpec`."""
        return cls(type_spec=type_spec, name=name)

    def _to_placeholder(self):
        """Convert this KerasTensor to a placeholder in a graph."""
        # If there is an inferred value for this tensor, inject the inferred
        # value
        if self._inferred_value is not None:
            # If we suspect this KerasTensor might be representing a shape
            # tensor, and we were able to extract value information with
            # TensorFlow's shape handling when making the KerasTensor, we
            # construct the placeholder by re-injecting the inferred value
            # information into the graph. We do this injection through the shape
            # of a placeholder, because that allows us to specify
            # partially-unspecified shape values.
            #
            # See the comment on value extraction inside `from_tensor` for more
            # info.
            inferred_value = tf.shape(
                tf.compat.v1.placeholder(
                    shape=self._inferred_value, dtype=tf.int32
                )
            )
            if self.type_spec.shape.rank == 0:
                # `tf.shape` always returns a rank-1, we may need to turn it
                # back to a scalar.
                inferred_value = inferred_value[0]
            return inferred_value

        # Use the generic conversion from typespec to a placeholder.
        def component_to_placeholder(component):
            return tf.compat.v1.placeholder(component.dtype, component.shape)

        return tf.nest.map_structure(
            component_to_placeholder, self.type_spec, expand_composites=True
        )

    def get_shape(self):
        """Returns the same value as the `shape` property."""
        return self.shape

    def __len__(self):
        """Always raises `TypeError`; symbolic values have no length."""
        raise TypeError(
            "Keras symbolic inputs/outputs do not "
            "implement `__len__`. You may be "
            "trying to pass Keras symbolic inputs/outputs "
            "to a TF API that does not register dispatching, "
            "preventing Keras from automatically "
            "converting the API call to a lambda layer "
            "in the Functional Model. This error will also get raised "
            "if you try asserting a symbolic input/output directly."
        )

    @property
    def op(self):
        """Always raises `TypeError`; symbolic values expose no `op`."""
        raise TypeError(
            "Keras symbolic inputs/outputs do not "
            "implement `op`. You may be "
            "trying to pass Keras symbolic inputs/outputs "
            "to a TF API that does not register dispatching, "
            "preventing Keras from automatically "
            "converting the API call to a lambda layer "
            "in the Functional Model."
        )

    def __hash__(self):
        """Always raises `TypeError`; use `ref()` to obtain a hashable key."""
        raise TypeError(
            f"Tensors are unhashable (this tensor: {self}). "
            "Instead, use tensor.ref() as the key."
        )

    # Note: This enables the KerasTensor's overloaded "right" binary
    # operators to run when the left operand is an ndarray, because it
    # accords the Tensor class higher priority than an ndarray, or a
    # numpy matrix.
    # In the future explore changing this to using numpy's __numpy_ufunc__
    # mechanism, which allows more control over how Tensors interact
    # with ndarrays.
    __array_priority__ = 100

    def __array__(self, dtype=None):
        """Always raises `TypeError`; symbolic values cannot become arrays."""
        raise TypeError(
            f"You are passing {self}, an intermediate Keras symbolic "
            "input/output, to a TF API that does not allow registering custom "
            "dispatchers, such as `tf.cond`, `tf.function`, gradient tapes, "
            "or `tf.map_fn`. Keras Functional model construction only supports "
            "TF API calls that *do* support dispatching, such as `tf.math.add` "
            "or `tf.reshape`. "
            "Other APIs cannot be called directly on symbolic Keras "
            "inputs/outputs. You can work around "
            "this limitation by putting the operation in a custom Keras layer "
            "`call` and calling that layer "
            "on this symbolic input/output."
        )

    @property
    def is_tensor_like(self):
        """Always `True`."""
        return True

    def set_shape(self, shape):
        """Updates the shape of this KerasTensor. Mimics
        `tf.Tensor.set_shape()`.

        Raises:
          ValueError: If `shape` is not compatible with the current shape.
        """
        if not isinstance(shape, tf.TensorShape):
            shape = tf.TensorShape(shape)
        if not self.shape.is_compatible_with(shape):
            raise ValueError(
                f"Keras symbolic input/output's shape {self.shape} is not "
                f"compatible with supplied shape {shape}."
            )
        else:
            shape = self.shape.merge_with(shape)
        self._type_spec = type_spec_with_shape(self._type_spec, shape)

    def __str__(self):
        """Returns a `KerasTensor(type_spec=...)` description string."""
        symbolic_description = ""
        inferred_value_string = ""
        name_string = ""

        if hasattr(self, "_keras_history"):
            layer = self._keras_history.layer
            symbolic_description = ", description=\"created by layer '%s'\"" % (
                layer.name,
            )
        if self._inferred_value is not None:
            inferred_value_string = f", inferred_value={self._inferred_value}"
        if self.name is not None:
            name_string = f", name='{self._name}'"
        return "KerasTensor(type_spec=%s%s%s%s)" % (
            self.type_spec,
            inferred_value_string,
            name_string,
            symbolic_description,
        )

    def __repr__(self):
        """Returns a compact `<KerasTensor: ...>` description string."""
        symbolic_description = ""
        inferred_value_string = ""
        if isinstance(self.type_spec, tf.TensorSpec):
            type_spec_string = f"shape={self.shape} dtype={self.dtype.name}"
        else:
            type_spec_string = f"type_spec={self.type_spec}"

        if hasattr(self, "_keras_history"):
            layer = self._keras_history.layer
            symbolic_description = f" (created by layer '{layer.name}')"
        if self._inferred_value is not None:
            inferred_value_string = f" inferred_value={self._inferred_value}"
        return "<KerasTensor: %s%s%s>" % (
            type_spec_string,
            inferred_value_string,
            symbolic_description,
        )

    @property
    def dtype(self):
        """Returns the `dtype` symbolically inferred for this Keras output.

        Raises:
          AttributeError: If the wrapped TypeSpec has no `dtype` attribute.
          TypeError: If the wrapped TypeSpec's `dtype` is not a `tf.DType`.
        """
        type_spec = self._type_spec
        if not hasattr(type_spec, "dtype"):
            raise AttributeError(
                f"KerasTensor wraps TypeSpec {type(type_spec).__qualname__}, "
                "which does not have a dtype."
            )
        if not isinstance(type_spec.dtype, tf.DType):
            raise TypeError(
                "KerasTensor requires that wrapped TypeSpec's dtype is a "
                f"DType; got TypeSpec {type(type_spec).__qualname__}, whose "
                "dtype field has unexpected type "
                f"{type(type_spec.dtype).__qualname__}."
            )
        return type_spec.dtype

    def ref(self):
        """Returns a hashable reference object to this KerasTensor.

        The primary use case for this API is to put KerasTensors in a
        set/dictionary. We can't put tensors in a set/dictionary as
        `tensor.__hash__()` is not available and tensor equality (`==`) is
        supposed to produce a tensor representing if the two inputs are equal.

        See the documentation of `tf.Tensor.ref()` for more info.
        """
        return object_identity.Reference(self)

    @property
    def node(self):
        """Find the corresponding `Node` that produce this keras_tensor.

        During functional model construction, Keras will attach `KerasHistory`
        to keras tensor to track the connectivity between calls of layers.
        Return None if there isn't any KerasHistory attached to this tensor.
        """
        if hasattr(self, "_keras_history"):
            layer, node_index, _ = self._keras_history
            return layer.inbound_nodes[node_index]
        return None

    def __iter__(self):
        """Iterate over the leading dimension of this KerasTensor.

        Raises:
          TypeError: If the rank is unknown, the tensor is a scalar, or the
            first dimension is unknown.
        """
        shape = None
        if self.shape.ndims is not None:
            shape = [dim.value for dim in self.shape.dims]

        if shape is None:
            raise TypeError("Cannot iterate over a Tensor with unknown shape.")
        if not shape:
            raise TypeError("Cannot iterate over a scalar.")
        if shape[0] is None:
            raise TypeError(
                "Cannot iterate over a Tensor with unknown first dimension."
            )
        return _KerasTensorIterator(self, shape[0])

    @property
    def name(self):
        """Returns the (non-unique, optional) name of this symbolic Keras
        value."""
        return self._name

    @classmethod
    def _overload_all_operators(cls, tensor_class):
        """Register overloads for all operators."""
        for operator in tf.Tensor.OVERLOADABLE_OPERATORS:
            cls._overload_operator(tensor_class, operator)

        # We include `experimental_ref` for versions of TensorFlow that
        # still include the deprecated method in Tensors.
        if hasattr(tensor_class, "experimental_ref"):
            cls._overload_operator(tensor_class, "experimental_ref")

    @classmethod
    def _overload_operator(cls, tensor_class, operator):
        """Overload operator with the same implementation as the Tensor class.

        We pull the operator out of the class dynamically to avoid ordering
        issues.

        Args:
          tensor_class: The (Composite)Tensor to get the method from.
          operator: string. The operator name.
        """
        tensor_oper = getattr(tensor_class, operator)

        # Compatibility with Python 2:
        # Python 2 unbound methods have type checks for the first arg,
        # so we need to extract the underlying function
        tensor_oper = getattr(tensor_oper, "__func__", tensor_oper)

        setattr(cls, operator, tensor_oper)
# Copy every operator named in `tf.Tensor.OVERLOADABLE_OPERATORS` (plus the
# deprecated `experimental_ref`, when `tf.Tensor` still defines it) from
# `tf.Tensor` onto `KerasTensor`.
KerasTensor._overload_all_operators(tf.Tensor)
@keras_export("keras.__internal__.SparseKerasTensor", v1=[])
class SparseKerasTensor(KerasTensor):
    """KerasTensor variant used for `tf.sparse.SparseTensor` values.

    Its only specialization is how the placeholder is built, so that the
    dense shape recorded in the type spec is kept.
    """

    def _to_placeholder(self):
        """Create a sparse placeholder with this tensor's dtype and shape."""
        # The generic `nest.map_structure` path drops the dense shape of a
        # sparse tensor, hence the dedicated sparse placeholder. Only a
        # top-level sparse tensor benefits from this; one nested inside
        # another composite tensor does not.
        sparse_spec = self.type_spec
        placeholder_kwargs = {
            "dtype": sparse_spec.dtype,
            "shape": sparse_spec.shape,
        }
        return tf.compat.v1.sparse_placeholder(**placeholder_kwargs)
@keras_export("keras.__internal__.RaggedKerasTensor", v1=[])
class RaggedKerasTensor(KerasTensor):
    """KerasTensor variant used for `tf.RaggedTensor` values.

    Compared to the base class it:

    1. Builds its placeholder in a way that keeps the shape information of
       the non-ragged dimensions.
    2. Replaces the operators whose `tf.RaggedTensor` implementation differs
       from the `tf.Tensor` one.
    3. Exposes instance attributes specific to the RaggedTensor API
       (such as `ragged_rank`).
    """

    def _to_placeholder(self):
        """Build a ragged placeholder that mirrors this tensor's type spec."""
        spec = self.type_spec
        ragged_rank = spec.ragged_rank
        if ragged_rank == 0 or spec.shape.rank is None:
            return super()._to_placeholder()

        def size_is_unknown(size):
            # A dimension is unknown when it is `None` itself or when it is
            # an object whose `value` attribute is `None`.
            return size is None or getattr(size, "value", True) is None

        # Placeholder for the flat values below the ragged dimensions.
        placeholder = tf.compat.v1.placeholder(
            spec.dtype, spec.shape[ragged_rank:]
        )

        # row_counts[i] is the product of the sizes of axes 0..i, or `None`
        # from the first unknown axis onwards.
        row_counts = []
        running = 1
        for size in spec.shape:
            if running is not None:
                running = None if size_is_unknown(size) else running * size
            row_counts.append(running)

        # Wrap the flat values one ragged axis at a time, innermost first.
        for axis in range(ragged_rank, 0, -1):
            size = spec.shape[axis]
            if not size_is_unknown(size):
                row_length = tf.constant(size, spec.row_splits_dtype)
                placeholder = tf.RaggedTensor.from_uniform_row_length(
                    placeholder, row_length, validate=False
                )
                continue

            # Unknown axis size: a row-splits vector has one more entry than
            # the number of rows, when that number is known.
            rows = row_counts[axis - 1]
            splits_shape = [None if rows is None else rows + 1]
            row_splits = tf.compat.v1.placeholder(
                spec.row_splits_dtype, splits_shape
            )
            placeholder = tf.RaggedTensor.from_row_splits(
                placeholder, row_splits, validate=False
            )
        return placeholder

    @property
    def ragged_rank(self):
        """The `ragged_rank` of the wrapped type spec."""
        return self.type_spec.ragged_rank
# Overload slicing with the `tf.RaggedTensor` implementation.
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__getitem__")

# Overload math ops with the `tf.RaggedTensor` implementations.
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__add__")
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__radd__")
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__mul__")
RaggedKerasTensor._overload_operator(tf.RaggedTensor, "__rmul__")
545# TODO(b/161487382):
546# Special-case user-registered symbolic objects (registered by the
547# private `register_symbolic_tensor_type` method) by passing them between
548# scratch graphs directly.
549# This is needed to not break Tensorflow probability
550# while they finish migrating to composite tensors.
class UserRegisteredSpec(tf.TypeSpec):
    """TypeSpec to represent user-registered symbolic objects.

    Only the shape and dtype are recorded; every other `tf.TypeSpec` hook
    defined here raises `NotImplementedError`.
    """

    def __init__(self, shape, dtype):
        """Store `shape` and `dtype` (the latter under two attribute names)."""
        self.shape = shape
        self._dtype = self.dtype = dtype

    def _component_specs(self):
        """Not supported for user-registered symbolic objects."""
        raise NotImplementedError

    def _from_components(self, components):
        """Not supported for user-registered symbolic objects."""
        raise NotImplementedError

    def _serialize(self):
        """Not supported for user-registered symbolic objects."""
        raise NotImplementedError

    def _to_components(self, value):
        """Not supported for user-registered symbolic objects."""
        raise NotImplementedError

    def value_type(self):
        """Not supported for user-registered symbolic objects."""
        raise NotImplementedError
575# TODO(b/161487382):
576# Special-case user-registered symbolic objects (registered by the
577# private `register_symbolic_tensor_type` method) by passing them between
578# scratch graphs directly.
579# This is needed to not break Tensorflow probability
580# while they finish migrating to composite tensors.
class UserRegisteredTypeKerasTensor(KerasTensor):
    """KerasTensor that represents legacy register_symbolic_tensor_type.

    The wrapped user-registered symbolic object is kept as-is and handed back
    unchanged by `_to_placeholder`.
    """

    def __init__(self, user_registered_symbolic_object):
        """Wrap `user_registered_symbolic_object`.

        Args:
          user_registered_symbolic_object: the object to wrap. Its `shape` and
            `dtype` attributes are read to build a `UserRegisteredSpec`; its
            `name` attribute, if present, becomes this KerasTensor's name.
        """
        x = user_registered_symbolic_object
        self._user_registered_symbolic_object = x
        type_spec = UserRegisteredSpec(x.shape, x.dtype)
        name = getattr(x, "name", None)

        # `name` must be passed by keyword: the second positional parameter
        # of `KerasTensor.__init__` is `inferred_value`, not `name`.
        super().__init__(type_spec, name=name)

    @classmethod
    def from_tensor(cls, tensor):
        """Wrap `tensor` directly, without deriving a TypeSpec from it."""
        return cls(tensor)

    @classmethod
    def from_type_spec(cls, type_spec, name=None):
        """Always raises `NotImplementedError`; a wrapped object is needed."""
        raise NotImplementedError(
            "You cannot instantiate a KerasTensor directly from TypeSpec: %s"
            % type_spec
        )

    def _to_placeholder(self):
        """Return the wrapped user-registered symbolic object itself."""
        return self._user_registered_symbolic_object
607class _KerasTensorIterator:
608 """Iterates over the leading dim of a KerasTensor. Performs 0 error
609 checks."""
611 def __init__(self, tensor, dim0):
612 self._tensor = tensor
613 self._index = 0
614 self._limit = dim0
616 def __iter__(self):
617 return self
619 def __next__(self):
620 if self._index == self._limit:
621 raise StopIteration
622 result = self._tensor[self._index]
623 self._index += 1
624 return result
# Specify the mappings of tensor class to KerasTensor class.
# This is specifically a list instead of a dict for now because
# 1. we do a check w/ isinstance because a key lookup based on class
#    would miss subclasses
# 2. a list allows us to control lookup ordering
# We include tf.Tensor -> KerasTensor in the first position as a fastpath,
# *and* include object -> KerasTensor at the end as a catch-all.
# We can re-visit these choices in the future as needed.
keras_tensor_classes = [
    (tf.Tensor, KerasTensor),
    (tf.SparseTensor, SparseKerasTensor),
    (tf.RaggedTensor, RaggedKerasTensor),
    (object, KerasTensor),
]
def register_keras_tensor_specialization(cls, keras_tensor_subclass):
    """Map tensor type `cls` to the KerasTensor subclass that represents it.

    Args:
      cls: the tensor class to match.
      keras_tensor_subclass: the KerasTensor subclass to use for it.
    """
    mapping = (cls, keras_tensor_subclass)
    # Slot the new mapping in just ahead of the last entry, so the generic
    # (object, KerasTensor) fallback stays at the end.
    keras_tensor_classes.insert(len(keras_tensor_classes) - 1, mapping)
def keras_tensor_to_placeholder(x):
    """Return a graph placeholder for `x` if it is a KerasTensor.

    Anything that is not a `KerasTensor` is returned unchanged.
    """
    if not isinstance(x, KerasTensor):
        return x
    return x._to_placeholder()
def keras_tensor_from_tensor(tensor):
    """Build the KerasTensor that represents a traced (composite)tensor.

    The first entry of `keras_tensor_classes` whose tensor type matches
    `tensor` decides which KerasTensor class is used. A non-None
    `_keras_mask` on `tensor` is converted as well and attached to the
    result.
    """
    # Pick the specialized KerasTensor class, which supplies instance
    # methods, operators and extra value inference where available.
    matches = (
        candidate
        for tensor_type, candidate in keras_tensor_classes
        if isinstance(tensor, tensor_type)
    )
    keras_tensor_cls = next(matches, None)

    out = keras_tensor_cls.from_tensor(tensor)

    mask = getattr(tensor, "_keras_mask", None)
    if mask is not None:
        out._keras_mask = keras_tensor_from_tensor(mask)
    return out
def keras_tensor_from_type_spec(type_spec, name=None):
    """Build the KerasTensor that represents a TypeSpec.

    The first entry of `keras_tensor_classes` whose tensor type is a
    superclass of `type_spec.value_type` decides which KerasTensor class is
    used.
    """
    # Pick the specialized KerasTensor class, which supplies instance
    # methods, operators and extra value inference where available.
    value_type = type_spec.value_type
    matches = (
        candidate
        for tensor_type, candidate in keras_tensor_classes
        if issubclass(value_type, tensor_type)
    )
    keras_tensor_cls = next(matches, None)

    return keras_tensor_cls.from_type_spec(type_spec, name=name)
def type_spec_with_shape(spec, shape):
    """Return TypeSpec `spec` with its shape set to `shape`.

    A `tf.TensorSpec` is updated in place and returned; ragged and sparse
    specs are rebuilt; any other spec must provide a `with_shape` method.

    Raises:
      ValueError: if `spec` is none of the above and has no `with_shape`.
    """
    if isinstance(spec, tf.TensorSpec):
        # TODO(b/203201161) Figure out why mutation is needed here, and remove
        # it. (TensorSpec objects should be immutable; and we should not be
        # modifying private fields.)
        spec._shape = tf.TensorShape(shape)
        return spec

    if isinstance(spec, tf.RaggedTensorSpec):
        ragged_args = (
            shape,
            spec.dtype,
            spec.ragged_rank,
            spec.row_splits_dtype,
            spec.flat_values_spec,
        )
        return tf.RaggedTensorSpec(*ragged_args)

    if isinstance(spec, tf.SparseTensorSpec):
        return tf.SparseTensorSpec(shape, spec.dtype)

    if not hasattr(spec, "with_shape"):
        # TODO(edloper): Consider moving this check to the KerasTensor
        # constructor.
        raise ValueError(
            "Keras requires TypeSpec to have a `with_shape` method "
            "that returns a copy of `self` with an updated shape."
        )

    # TODO(edloper): Consider adding .with_shape method to TensorSpec,
    # RaggedTensorSpec, and SparseTensorSpec.
    return spec.with_shape(shape)