Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/tensor.py: 43%
301 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tensor and TensorSpec classes."""
17from typing import Type
19import numpy as np
21from tensorflow.core.framework import attr_value_pb2
22from tensorflow.core.function import trace_type
23from tensorflow.core.protobuf import struct_pb2
24from tensorflow.python.framework import common_shapes
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import op_callbacks
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.framework import type_spec
32from tensorflow.python.framework import type_spec_registry
33from tensorflow.python.ops import gen_array_ops
34from tensorflow.python.ops import handle_data_util
35from tensorflow.python.platform import tf_logging as logging
36from tensorflow.python.saved_model import nested_structure_coder
37from tensorflow.python.types import core as core_tf_types
38from tensorflow.python.types import internal
39from tensorflow.python.util import _pywrap_utils
40from tensorflow.python.util import compat
41from tensorflow.python.util.tf_export import tf_export
44# TODO(b/249802365): Sanitize all TensorSpec names.
def sanitize_spec_name(name: str) -> str:
  """Turns an arbitrary spec name into a legal Python parameter name.

  Names that are not valid Python parameter names make it hard to represent
  callables that support calling by name, so every spec name is normalized
  to match Graph Node and Python naming conventions.

  Args:
    name: The name to sanitize. Falsy values (e.g. `None` or `""`) are allowed.

  Returns:
    `"unknown"` for a falsy `name`. Otherwise, the lower-cased name with every
    non-alphanumeric character replaced by `"_"`, prefixed with `"tensor_"`
    when the result does not start with a letter.
  """
  if not name:
    return "unknown"

  # Map each character of the lower-cased name to itself or to "_".
  sanitized = "".join(ch if ch.isalnum() else "_" for ch in name.lower())

  # Identifiers must start with a letter; otherwise add a fixed prefix.
  return sanitized if sanitized[0].isalpha() else "tensor_" + sanitized
def get_op_name(tensor_name):
  """Returns the name of the Op that produces the Tensor named `tensor_name`.

  A leading "^" (which marks a control-dependency input) is dropped first.
  The Op name is then everything before a ":" when one is present, or the
  whole remaining string otherwise.

  Args:
    tensor_name: the full name of a Tensor in the graph.

  Returns:
    The name of the Op of which the given Tensor is an output.

  Raises:
    ValueError: if tensor_name is None or empty. A name containing more than
      one ":" also raises ValueError, because the split cannot be unpacked
      into two parts.
  """
  if not tensor_name:
    raise ValueError(
        f"Tensor name cannot be empty or None. Received: {tensor_name}.")

  # Control dependency inputs start with ^.
  bare_name = tensor_name[1:] if tensor_name.startswith("^") else tensor_name
  if ":" not in bare_name:
    return bare_name

  op_name, _ = bare_name.split(":")
  return op_name
class DenseSpec(type_spec.TypeSpec):
  """Describes a dense object with shape, dtype, and name.

  Stores a `tf.TensorShape`, a `tf.DType` and an optional name. The
  `value_type` used by `is_compatible_with` and `_to_legacy_output_classes`
  is not defined in this class and is expected to come from a subclass.
  """

  __slots__ = ["_shape", "_dtype", "_name"]

  @property
  def _component_specs(self):
    # A dense spec acts as its own component spec.
    return self

  def __init__(self, shape, dtype=dtypes.float32, name=None):
    """Initializes the spec from a shape, a dtype and an optional name.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      name: Optional name for the Tensor.

    Raises:
      TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
        not convertible to a `tf.DType`.
    """
    # The shape is converted before the dtype, so a bad shape is reported first.
    self._shape = tensor_shape.TensorShape(shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._name = name

  @property
  def shape(self):
    """The `TensorShape` describing the shape of the tensor."""
    return self._shape

  @property
  def dtype(self):
    """The `dtype` of the elements of the tensor."""
    return self._dtype

  @property
  def name(self):
    """The name of the described tensor, or `None` if none was given."""
    return self._name

  def is_compatible_with(self, spec_or_value):
    # Accept another dense spec or an instance of this spec's value type whose
    # dtype and shape are both compatible.
    if not isinstance(spec_or_value, (DenseSpec, self.value_type)):
      return False
    if not self._dtype.is_compatible_with(spec_or_value.dtype):
      return False
    return self._shape.is_compatible_with(spec_or_value.shape)

  def __repr__(self):
    return (f"{type(self).__name__}(shape={self.shape}, "
            f"dtype={self.dtype!r}, name={self.name!r})")

  def __hash__(self):
    # The name is not part of the hash; equal specs still hash equally since
    # equality also requires equal shapes and dtypes.
    hash_key = (self._shape, self.dtype)
    return hash(hash_key)

  def __eq__(self, other):
    if type(self) is not type(other):
      return False
    # pylint: disable=protected-access
    return (self._shape == other._shape and
            self._dtype == other._dtype and
            self._name == other._name)

  def __ne__(self, other):
    return not self == other

  def _serialize(self):
    return self._shape, self._dtype, self._name

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return self.value_type
@tf_export("TensorSpec")
@type_spec_registry.register("tf.TensorSpec")
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec,
                 trace_type.Serializable, internal.TensorSpec):
  """Describes the type of a tf.Tensor.

  >>> t = tf.constant([[1,2,3],[4,5,6]])
  >>> tf.TensorSpec.from_tensor(t)
  TensorSpec(shape=(2, 3), dtype=tf.int32, name=None)

  Contains metadata for describing the nature of `tf.Tensor` objects
  accepted or returned by some TensorFlow APIs.

  For example, it can be used to constrain the type of inputs accepted by
  a tf.function:

  >>> @tf.function(input_signature=[tf.TensorSpec([1, None])])
  ... def constrained_foo(t):
  ...   print("tracing...")
  ...   return t

  Now the `tf.function` is able to assume that `t` is always of the type
  `tf.TensorSpec([1, None])` which will avoid retracing as well as enforce the
  type restriction on inputs.

  As a result, the following call with tensor of type `tf.TensorSpec([1, 2])`
  triggers a trace and succeeds:
  >>> constrained_foo(tf.constant([[1., 2]])).numpy()
  tracing...
  array([[1., 2.]], dtype=float32)

  The following subsequent call with tensor of type `tf.TensorSpec([1, 4])`
  does not trigger a trace and succeeds:
  >>> constrained_foo(tf.constant([[1., 2, 3, 4]])).numpy()
  array([[1., 2., 3., 4.]], dtype=float32)

  But the following call with tensor of type `tf.TensorSpec([2, 2])` fails:
  >>> constrained_foo(tf.constant([[1., 2], [3, 4]])).numpy()
  Traceback (most recent call last):
  ...
  TypeError: Binding inputs to tf.function `constrained_foo` failed ...

  """

  __slots__ = []

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.TensorSpecProto]:
    """Returns the type of proto associated with TensorSpec serialization."""
    return struct_pb2.TensorSpecProto

  @classmethod
  def experimental_from_proto(
      cls, proto: struct_pb2.TensorSpecProto) -> "TensorSpec":
    """Returns a TensorSpec instance based on the serialized proto."""
    return TensorSpec(
        shape=tensor_shape.TensorShape.experimental_from_proto(proto.shape),
        dtype=proto.dtype,
        # An empty proto name means "no name".
        name=proto.name if proto.name else None)

  def experimental_as_proto(self) -> struct_pb2.TensorSpecProto:
    """Returns a proto representation of the TensorSpec instance."""
    return struct_pb2.TensorSpecProto(
        shape=self.shape.experimental_as_proto(),
        dtype=self.dtype.experimental_as_proto().datatype,
        name=self.name)

  def is_compatible_with(self, spec_or_tensor):  # pylint:disable=useless-super-delegation,arguments-renamed
    """Returns True if spec_or_tensor is compatible with this TensorSpec.

    Two tensors are considered compatible if they have the same dtype
    and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).

    Args:
      spec_or_tensor: A tf.TensorSpec or a tf.Tensor

    Returns:
      True if spec_or_tensor is compatible with self.
    """
    return super(TensorSpec, self).is_compatible_with(spec_or_tensor)

  def is_subtype_of(self, other):
    """Returns True if `self` is a subtype of the TensorSpec `other`.

    An unnamed spec matches any name; a named spec requires `other` to have
    the same name. Shape and dtype must both be subtypes of `other`'s.
    """
    if not isinstance(other, TensorSpec):
      return False

    return (
        (not self.name or self.name == other.name)
        and self.shape.is_subtype_of(other.shape)
        and self.dtype.is_subtype_of(other.dtype)
    )

  def placeholder_value(self, placeholder_context):
    """Generates a graph placeholder with the given TensorSpec information.

    Returns `self` unchanged when `placeholder_context.unnest_only` is set.
    Otherwise creates a Placeholder in `placeholder_context.context_graph`,
    named after this spec's name or the context's naming scope. The requested
    name, handle data and composite device are recorded on the placeholder.
    """
    if placeholder_context.unnest_only:
      return self

    name = self.name or placeholder_context.naming_scope
    context_graph = placeholder_context.context_graph
    if placeholder_context.with_none_control_dependencies:
      # Note: setting ops.control_dependencies(None) ensures we always put
      # capturing placeholders outside of any control flow context.
      with context_graph.control_dependencies(None):
        placeholder = self._graph_placeholder(context_graph, name=name)
    else:
      placeholder = self._graph_placeholder(context_graph, name=name)

    if name is not None:
      # Record the requested/user-specified name in case it's different than
      # the uniquified name, for validation when exporting signatures.
      placeholder.op._set_attr(  # pylint: disable=protected-access
          "_user_specified_name",
          attr_value_pb2.AttrValue(s=compat.as_bytes(name)))

    handle_data = self.dtype._handle_data  # pylint: disable=protected-access
    if (
        handle_data is not None
        and handle_data.is_set
        and handle_data.shape_and_type
    ):
      handle_data_util.set_handle_data(placeholder, handle_data)

    # Record the composite device as an attribute to the placeholder.
    # This attribute would be propagated into the arg_attr of the FunctionDef.
    # Currently, a packed eager tensor is always placed on a CompositeDevice.
    if placeholder_context.composite_device_name is not None:
      placeholder.op._set_attr(  # pylint: disable=protected-access
          "_composite_device",
          attr_value_pb2.AttrValue(s=compat.as_bytes(
              placeholder_context.composite_device_name)))

    return placeholder

  def _graph_placeholder(self, graph, name=None):
    """Graph-only version of tf.compat.v1.placeholder(), for internal use only."""
    dtype = self.dtype.base_dtype
    shape = self.shape
    dtype_value = attr_value_pb2.AttrValue(type=dtype.as_datatype_enum)
    if isinstance(shape, (list, tuple)):
      shape = tensor_shape.TensorShape(shape)
    shape = attr_value_pb2.AttrValue(shape=shape.as_proto())
    attrs = {"dtype": dtype_value, "shape": shape}
    try:
      op = graph._create_op_internal(  # pylint: disable=protected-access
          "Placeholder", [], [dtype], input_types=[],
          attrs=attrs, name=name)
    except ValueError as e:
      # TODO(b/262413656) Sometimes parameter names are not valid op names, in
      # which case an unnamed placeholder is created instead. Update this logic
      # to sanitize the name instead of falling back on unnamed placeholders.
      logging.warning(e)
      op = graph._create_op_internal(  # pylint: disable=protected-access
          "Placeholder", [], [dtype], input_types=[], attrs=attrs)
    (result,) = op.outputs
    if op_callbacks.should_invoke_op_callbacks():
      # TODO(b/147670703): Once the special-op creation code paths
      # are unified. Remove this `if` block.
      callback_outputs = op_callbacks.invoke_op_callbacks(
          "Placeholder", tuple(), attrs, tuple(op.outputs),
          op_name=name, graph=graph)
      if callback_outputs is not None:
        (result,) = callback_outputs
    return result

  def _to_tensors(self, value):
    """Returns `[value]`; `value` must already be a `Tensor`."""
    assert isinstance(value, ops.Tensor)
    return [value]

  def _flatten(self):
    return [self]

  def _cast(self, value, casting_context):
    """Cast value to a tensor that is a subtype of this TensorSpec.

    Raises:
      AssertionError: If `value` cannot be cast to this spec.
    """
    # This method is mainly used to cast Python primitives to tensor.
    # Currently, cast tensor to tensor with different types are not supported.
    # For example, casting int32 to float32 would raise a ValueError.
    if casting_context.allow_specs and isinstance(value, TensorSpec):
      assert value.is_subtype_of(self), f"Can not cast {value!r} to {self!r}"
      return self

    value = ops.convert_to_tensor(value, self.dtype)
    value_spec = TensorSpec(value.shape, value.dtype, self.name)

    if not value_spec.is_subtype_of(self):
      if self.is_subtype_of(value_spec):
        # `value`'s static shape is less specific than this spec's. Keep the
        # output of `ensure_shape`: it carries the refined shape. Discarding it
        # would lose that shape and, in graph mode, leave the check op without
        # any consumer (so it may never run).
        value = gen_array_ops.ensure_shape(value, self.shape)
      else:
        raise AssertionError(f"Can not cast {value_spec!r} to {self!r}")

    return value

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `TensorSpec` with the same shape and dtype as `spec`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="OriginalName")
    >>> tf.TensorSpec.from_spec(spec, "NewName")
    TensorSpec(shape=(8, 3), dtype=tf.int32, name='NewName')

    Args:
      spec: The `TypeSpec` used to create the new `TensorSpec`.
      name: The name for the new `TensorSpec`.  Defaults to `spec.name`.
    """
    return cls(spec.shape, spec.dtype, name or spec.name)

  @classmethod
  def from_tensor(cls, tensor, name=None):
    """Returns a `TensorSpec` that describes `tensor`.

    >>> tf.TensorSpec.from_tensor(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)

    Args:
      tensor: The `tf.Tensor` that should be described.
      name: A name for the `TensorSpec`.  For graph tensors this defaults to
        `tensor.op.name`; for eager tensors it defaults to `None`.

    Returns:
      A `TensorSpec` that describes `tensor`.

    Raises:
      ValueError: If `tensor` is not a `tf.Tensor`.
    """
    if isinstance(tensor, ops.EagerTensor):
      return TensorSpec(tensor.shape, tensor.dtype, name)
    elif isinstance(tensor, ops.Tensor):
      # TODO(b/249802365): Return a sanitized version of op name or no name.
      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
    else:
      raise ValueError(
          f"`tensor` should be a tf.Tensor, but got type {type(tensor)}.")

  @property
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec."""
    return ops.Tensor

  def _to_components(self, value):
    assert isinstance(value, core_tf_types.Tensor)
    return value

  def _from_components(self, components):
    return components

  def _from_compatible_tensor_list(self, tensor_list):
    # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
    # op here and return that, instead of mutating the input's shape using
    # `Tensor.set_shape()`. However, that would add extra ops, which could
    # impact performance. When this bug is resolved, we should be able to add
    # the `ensure_shape()` ops and optimize them away using contextual shape
    # information.
    assert len(tensor_list) == 1
    tensor_list[0].set_shape(self._shape)
    return tensor_list[0]

  def _to_batchable_tensor_list(self, value, batched=False):
    if batched and self._shape.merge_with(value.shape).ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return self._to_components(value)

  def _batch(self, batch_size):
    # The batched spec prepends `batch_size` to the shape and drops the name.
    return TensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype)

  def _unbatch(self):
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    # The unbatched spec drops the leading dimension and the name.
    return TensorSpec(self._shape[1:], self._dtype)

  @property
  def _flat_tensor_specs(self):
    return [self]

  def _to_tensor_list(self, value):
    return [self._to_components(value)]

  def _to_batched_tensor_list(self, value):
    return self._to_tensor_list(value)

  # TODO(b/206014848): Helper function to support logic that does not consider
  # Tensor name. Will be removed once load-bearing usages of Tensor name are
  # fixed.
  def _without_tensor_names(self) -> "TensorSpec":
    """Returns a version of `TensorSpec` with the name removed."""
    if self.name is None:
      return self
    else:
      return TensorSpec(self.shape, self.dtype)
# Register TensorSpec with trace_type, both as a Serializable type (it
# implements experimental_as_proto / experimental_from_proto) and via
# register_tensor_type.
trace_type.register_serializable(TensorSpec)
trace_type.register_tensor_type(TensorSpec)
class _TensorCodec:
  """Codec mapping `tf.Tensor` values to `StructuredValue.tensor_value`.

  Eager tensors are encoded from their NumPy value; graph tensors are only
  encodable when produced by a "Const" op, whose `value` attr is reused.
  """

  def can_encode(self, pyobj):
    return isinstance(pyobj, ops.Tensor)

  def do_encode(self, tensor_value, encode_fn):
    """Returns an encoded `TensorProto` for the given `tf.Tensor`."""
    del encode_fn  # Tensors are leaves; nothing nested to encode.
    if isinstance(tensor_value, ops.EagerTensor):
      tensor_proto = tensor_util.make_tensor_proto(tensor_value.numpy())
    elif tensor_value.op.type == "Const":
      tensor_proto = tensor_value.op.get_attr("value")
    else:
      raise nested_structure_coder.NotEncodableError(
          f"No encoder for object {str(tensor_value)} of type"
          f" {type(tensor_value)}."
      )
    structured = struct_pb2.StructuredValue()
    structured.tensor_value.CopyFrom(tensor_proto)
    return structured

  def can_decode(self, value):
    return value.HasField("tensor_value")

  def do_decode(self, value, decode_fn):
    """Returns the `tf.Tensor` encoded by the proto `value`."""
    del decode_fn  # Tensors are leaves; nothing nested to decode.
    return constant_op.constant(tensor_util.MakeNdarray(value.tensor_value))
# Make `tf.Tensor` values (eager tensors and graph "Const" outputs) encodable
# by nested_structure_coder.
nested_structure_coder.register_codec(_TensorCodec())
class _TensorSpecCodec:
  """Codec mapping `TensorSpec` (but not `BoundedTensorSpec`) to protos."""

  def can_encode(self, pyobj):
    # BoundedTensorSpec is handled by its own codec.
    if isinstance(pyobj, BoundedTensorSpec):
      return False
    return isinstance(pyobj, TensorSpec)

  def do_encode(self, tensor_spec_value, encode_fn):
    spec_proto = struct_pb2.TensorSpecProto(
        shape=encode_fn(tensor_spec_value.shape).tensor_shape_value,
        dtype=encode_fn(tensor_spec_value.dtype).tensor_dtype_value,
        name=tensor_spec_value.name)
    structured = struct_pb2.StructuredValue()
    structured.tensor_spec_value.CopyFrom(spec_proto)
    return structured

  def can_decode(self, value):
    return value.HasField("tensor_spec_value")

  def do_decode(self, value, decode_fn):
    spec_proto = value.tensor_spec_value
    decoded_shape = decode_fn(
        struct_pb2.StructuredValue(tensor_shape_value=spec_proto.shape))
    decoded_dtype = decode_fn(
        struct_pb2.StructuredValue(tensor_dtype_value=spec_proto.dtype))
    # An empty proto name is decoded as "no name".
    return TensorSpec(
        shape=decoded_shape,
        dtype=decoded_dtype,
        name=spec_proto.name or None)
# Register the codec for plain (non-bounded) TensorSpecs.
nested_structure_coder.register_codec(_TensorSpecCodec())
# TODO(b/133606651): Should is_compatible_with check min/max bounds?
@type_spec_registry.register("tf.BoundedTensorSpec")
class BoundedTensorSpec(TensorSpec, trace_type.Serializable):
  """A `TensorSpec` that specifies minimum and maximum values.

  Example usage:
  ```python
  spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5))
  tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype)
  tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype)
  ```

  Bounds are meant to be inclusive. This is especially important for
  integer types. The following spec will be satisfied by tensors
  with values in the set {0, 1, 2}:
  ```python
  spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2)
  ```
  """

  __slots__ = ("_minimum", "_maximum")

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Initializes a new `BoundedTensorSpec`.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` are not provided or not
        broadcastable to `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
        numpy dtype.
    """
    super(BoundedTensorSpec, self).__init__(shape, dtype, name)

    if minimum is None:
      raise ValueError("`minimum` can not be None.")
    if maximum is None:
      raise ValueError("`maximum` can not be None.")

    try:
      minimum_shape = np.shape(minimum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(minimum_shape), self.shape)
    except ValueError as exception:
      raise ValueError(
          f"`minimum` {minimum} is not compatible with shape {self.shape}."
      ) from exception

    try:
      maximum_shape = np.shape(maximum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(maximum_shape), self.shape)
    except ValueError as exception:
      raise ValueError(
          f"`maximum` {maximum} is not compatible with shape {self.shape}."
      ) from exception

    # Bounds are stored as read-only NumPy arrays of the spec's dtype.
    self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype)
    self._minimum.setflags(write=False)

    self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype)
    self._maximum.setflags(write=False)

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.BoundedTensorSpecProto]:
    """Returns the type of proto associated with BoundedTensorSpec serialization."""
    return struct_pb2.BoundedTensorSpecProto

  @classmethod
  def experimental_from_proto(
      cls, proto: struct_pb2.BoundedTensorSpecProto) -> "BoundedTensorSpec":
    """Returns a BoundedTensorSpec instance based on the serialized proto."""
    return BoundedTensorSpec(
        shape=tensor_shape.TensorShape.experimental_from_proto(proto.shape),
        dtype=proto.dtype,
        minimum=tensor_util.MakeNdarray(proto.minimum),
        maximum=tensor_util.MakeNdarray(proto.maximum),
        name=proto.name if proto.name else None)

  def experimental_as_proto(self) -> struct_pb2.BoundedTensorSpecProto:
    """Returns a proto representation of the BoundedTensorSpec instance."""
    return struct_pb2.BoundedTensorSpecProto(
        shape=self.shape.experimental_as_proto(),
        dtype=self.dtype.experimental_as_proto().datatype,
        minimum=tensor_util.make_tensor_proto(self._minimum),
        maximum=tensor_util.make_tensor_proto(self._maximum),
        name=self.name)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `BoundedTensorSpec` with the same shape and dtype as `spec`.

    If `spec` has `minimum`/`maximum` attributes (e.g. it is a
    `BoundedTensorSpec`), the new spec's bounds are set to `spec.minimum` and
    `spec.maximum`. Otherwise, the bounds are set to `spec.dtype.min` and
    `spec.dtype.max`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="x")
    >>> BoundedTensorSpec.from_spec(spec)
    BoundedTensorSpec(shape=(8, 3), dtype=tf.int32, name='x',
        minimum=array(-2147483648, dtype=int32),
        maximum=array(2147483647, dtype=int32))

    Args:
      spec: The `TypeSpec` used to create the new `BoundedTensorSpec`.
      name: The name for the new `BoundedTensorSpec`. Defaults to `spec.name`.
        Matches the signature of `TensorSpec.from_spec`.
    """
    dtype = dtypes.as_dtype(spec.dtype)
    # Only fall back to the dtype's limits when `spec` has no bounds of its
    # own: `dtype.min`/`dtype.max` raise for dtypes without limits (e.g. bool),
    # so they must not be evaluated eagerly.
    minimum = spec.minimum if hasattr(spec, "minimum") else dtype.min
    maximum = spec.maximum if hasattr(spec, "maximum") else dtype.max
    return BoundedTensorSpec(spec.shape, dtype, minimum, maximum,
                             name or spec.name)

  @property
  def minimum(self):
    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
    return self._minimum

  @property
  def maximum(self):
    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
    return self._maximum

  def _cast(self, value, casting_context):
    if casting_context.allow_specs and isinstance(value, BoundedTensorSpec):
      assert value.is_subtype_of(self), f"Can not cast {value!r} to {self!r}"
      return self

    # Bounds are not checked here; casting is delegated to an unbounded
    # TensorSpec with the same shape, dtype and name.
    actual_spec = TensorSpec(shape=self.shape, dtype=self.dtype, name=self.name)
    return actual_spec._cast(value, casting_context)  # pylint: disable=protected-access

  def __repr__(self):
    s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})"
    return s.format(self.shape, repr(self.dtype), repr(self.name),
                    repr(self.minimum), repr(self.maximum))

  def __eq__(self, other):
    # Bounds are compared with np.allclose, so broadcast-equivalent bounds
    # (e.g. 0 and [0, 0]) compare equal.
    tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other)
    return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and
            np.allclose(self.maximum, other.maximum))

  def __hash__(self):
    return hash((self._shape, self.dtype))

  def __reduce__(self):
    return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
                               self._maximum, self._name)

  def _serialize(self):
    return (self._shape, self._dtype, self._minimum, self._maximum, self._name)
class _BoundedTensorSpecCodec:
  """Codec mapping `BoundedTensorSpec` to `bounded_tensor_spec_value`."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, BoundedTensorSpec)

  def do_encode(self, bounded_tensor_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.BoundedTensorSpec`."""
    spec = bounded_tensor_spec_value
    bounded_proto = struct_pb2.BoundedTensorSpecProto(
        shape=encode_fn(spec.shape).tensor_shape_value,
        dtype=encode_fn(spec.dtype).tensor_dtype_value,
        name=spec.name,
        minimum=tensor_util.make_tensor_proto(spec.minimum),
        maximum=tensor_util.make_tensor_proto(spec.maximum))
    structured = struct_pb2.StructuredValue()
    structured.bounded_tensor_spec_value.CopyFrom(bounded_proto)
    return structured

  def can_decode(self, value):
    return value.HasField("bounded_tensor_spec_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `BoundedTensorSpec` from `value.bounded_tensor_spec_value`."""
    proto = value.bounded_tensor_spec_value
    decoded_shape = decode_fn(
        struct_pb2.StructuredValue(tensor_shape_value=proto.shape))
    decoded_dtype = decode_fn(
        struct_pb2.StructuredValue(tensor_dtype_value=proto.dtype))
    # An empty proto name is decoded as "no name".
    return BoundedTensorSpec(
        shape=decoded_shape,
        dtype=decoded_dtype,
        minimum=tensor_util.MakeNdarray(proto.minimum),
        maximum=tensor_util.MakeNdarray(proto.maximum),
        name=proto.name or None)
# Register the BoundedTensorSpec codec and its trace_type serialization.
nested_structure_coder.register_codec(_BoundedTensorSpecCodec())
trace_type.register_serializable(BoundedTensorSpec)
# Expose TensorSpec to `_pywrap_utils` under the name "TensorSpec".
_pywrap_utils.RegisterType("TensorSpec", TensorSpec)

# Note: we do not include Tensor names when constructing TypeSpecs.
type_spec.register_type_spec_from_value_converter(
    ops.Tensor, lambda tensor: TensorSpec(tensor.shape, tensor.dtype))

# NumPy arrays map to an unnamed TensorSpec with the array's shape and dtype.
type_spec.register_type_spec_from_value_converter(
    np.ndarray, lambda array: TensorSpec(array.shape, array.dtype))