Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/tensor.py: 43%

301 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Tensor and TensorSpec classes.""" 

16 

17from typing import Type 

18 

19import numpy as np 

20 

21from tensorflow.core.framework import attr_value_pb2 

22from tensorflow.core.function import trace_type 

23from tensorflow.core.protobuf import struct_pb2 

24from tensorflow.python.framework import common_shapes 

25from tensorflow.python.framework import constant_op 

26from tensorflow.python.framework import dtypes 

27from tensorflow.python.framework import op_callbacks 

28from tensorflow.python.framework import ops 

29from tensorflow.python.framework import tensor_shape 

30from tensorflow.python.framework import tensor_util 

31from tensorflow.python.framework import type_spec 

32from tensorflow.python.framework import type_spec_registry 

33from tensorflow.python.ops import gen_array_ops 

34from tensorflow.python.ops import handle_data_util 

35from tensorflow.python.platform import tf_logging as logging 

36from tensorflow.python.saved_model import nested_structure_coder 

37from tensorflow.python.types import core as core_tf_types 

38from tensorflow.python.types import internal 

39from tensorflow.python.util import _pywrap_utils 

40from tensorflow.python.util import compat 

41from tensorflow.python.util.tf_export import tf_export 

42 

43 

# TODO(b/249802365): Sanitize all TensorSpec names.
def sanitize_spec_name(name: str) -> str:
  """Turns `name` into a lower-case identifier-like string.

  Without sanitization, names that are not legal Python parameter names can be
  set which makes it challenging to represent callables supporting the named
  calling capability.

  Args:
    name: The name to sanitize. A falsy value (e.g. `""` or `None`) maps to
      `"unknown"`.

  Returns:
    `name` lower-cased, with every non-alphanumeric character replaced by
    `"_"`, and prefixed with `"tensor_"` when it does not start with a letter.
  """
  if not name:
    return "unknown"

  # Lower-case the input and swap anything that is not alphanumeric for "_".
  cleaned = "".join(ch if ch.isalnum() else "_" for ch in name.lower())

  # Identifiers may not begin with a digit/underscore-only run; prefix those.
  if not cleaned[0].isalpha():
    cleaned = "tensor_" + cleaned
  return cleaned

68 

69 

def get_op_name(tensor_name):
  """Extract the Op name from a Tensor name.

  The Op name is everything before the first colon, if present, not including
  any ^ prefix denoting a control dependency.

  Args:
    tensor_name: the full name of a Tensor in the graph, e.g. "foo:0" or
      "^foo".

  Returns:
    The name of the Op of which the given Tensor is an output.

  Raises:
    ValueError: if tensor_name is None or empty.
  """
  if not tensor_name:
    raise ValueError(
        f"Tensor name cannot be empty or None. Received: {tensor_name}.")

  # Control dependency inputs start with ^.
  if tensor_name.startswith("^"):
    tensor_name = tensor_name[1:]

  # `partition` never raises. Two-way unpacking of `split(":")` failed with an
  # opaque "too many values to unpack" error on names such as "a:b:c".
  # When there is no colon, the whole (prefix-stripped) name is returned.
  op_name, _, _ = tensor_name.partition(":")
  return op_name

94 

95 

class DenseSpec(type_spec.TypeSpec):
  """Describes a dense object with shape, dtype, and name."""

  __slots__ = ["_shape", "_dtype", "_name"]

  # A dense spec acts as its own component spec.
  _component_specs = property(lambda self: self)

  def __init__(self, shape, dtype=dtypes.float32, name=None):
    """Creates a TensorSpec.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      name: Optional name for the Tensor.

    Raises:
      TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
        not convertible to a `tf.DType`.
    """
    # Shape is converted before dtype, so a bad shape is reported first.
    self._shape = tensor_shape.TensorShape(shape)
    self._dtype = dtypes.as_dtype(dtype)
    self._name = name

  @property
  def shape(self):
    """The `TensorShape` describing the tensor's shape."""
    return self._shape

  @property
  def dtype(self):
    """The `dtype` of the tensor's elements."""
    return self._dtype

  @property
  def name(self):
    """The name given for the described tensor, or `None` if none was given."""
    return self._name

  def is_compatible_with(self, spec_or_value):
    """Returns whether `spec_or_value` has a compatible dtype and shape.

    Only `DenseSpec` instances and instances of `self.value_type` qualify.
    """
    if not isinstance(spec_or_value, (DenseSpec, self.value_type)):
      return False
    return (self._dtype.is_compatible_with(spec_or_value.dtype) and
            self._shape.is_compatible_with(spec_or_value.shape))

  def __repr__(self):
    return (f"{type(self).__name__}(shape={self.shape}, "
            f"dtype={self.dtype!r}, name={self.name!r})")

  def __hash__(self):
    # The name is deliberately left out. Specs that compare equal still share
    # a hash, because __eq__ also requires equal shape and dtype.
    hash_key = (self._shape, self.dtype)
    return hash(hash_key)

  def __eq__(self, other):
    # pylint: disable=protected-access
    if type(self) is not type(other):
      return False
    return (self._shape == other._shape and self._dtype == other._dtype and
            self._name == other._name)

  def __ne__(self, other):
    return not (self == other)

  def _serialize(self):
    # Tuple of the constructor arguments: (shape, dtype, name).
    return (self._shape, self._dtype, self._name)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return self.value_type

165 

166 

@tf_export("TensorSpec")
@type_spec_registry.register("tf.TensorSpec")
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec,
                 trace_type.Serializable, internal.TensorSpec):
  """Describes the type of a tf.Tensor.

  >>> t = tf.constant([[1,2,3],[4,5,6]])
  >>> tf.TensorSpec.from_tensor(t)
  TensorSpec(shape=(2, 3), dtype=tf.int32, name=None)

  Contains metadata for describing the nature of `tf.Tensor` objects
  accepted or returned by some TensorFlow APIs.

  For example, it can be used to constrain the type of inputs accepted by
  a tf.function:

  >>> @tf.function(input_signature=[tf.TensorSpec([1, None])])
  ... def constrained_foo(t):
  ...   print("tracing...")
  ...   return t

  Now the `tf.function` is able to assume that `t` is always of the type
  `tf.TensorSpec([1, None])` which will avoid retracing as well as enforce the
  type restriction on inputs.

  As a result, the following call with tensor of type `tf.TensorSpec([1, 2])`
  triggers a trace and succeeds:
  >>> constrained_foo(tf.constant([[1., 2]])).numpy()
  tracing...
  array([[1., 2.]], dtype=float32)

  The following subsequent call with tensor of type `tf.TensorSpec([1, 4])`
  does not trigger a trace and succeeds:
  >>> constrained_foo(tf.constant([[1., 2, 3, 4]])).numpy()
  array([[1., 2., 3., 4.]], dtype=float32)

  But the following call with tensor of type `tf.TensorSpec([2, 2])` fails:
  >>> constrained_foo(tf.constant([[1., 2], [3, 4]])).numpy()
  Traceback (most recent call last):
  ...
  TypeError: Binding inputs to tf.function `constrained_foo` failed ...

  """

  __slots__ = []

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.TensorSpecProto]:
    """Returns the type of proto associated with TensorSpec serialization."""
    return struct_pb2.TensorSpecProto

  @classmethod
  def experimental_from_proto(
      cls, proto: struct_pb2.TensorSpecProto) -> "TensorSpec":
    """Returns a TensorSpec instance based on the serialized proto."""
    return TensorSpec(
        shape=tensor_shape.TensorShape.experimental_from_proto(proto.shape),
        dtype=proto.dtype,
        # An empty proto name string is mapped back to "no name".
        name=proto.name if proto.name else None)

  def experimental_as_proto(self) -> struct_pb2.TensorSpecProto:
    """Returns a proto representation of the TensorSpec instance."""
    return struct_pb2.TensorSpecProto(
        shape=self.shape.experimental_as_proto(),
        dtype=self.dtype.experimental_as_proto().datatype,
        name=self.name)

  def is_compatible_with(self, spec_or_tensor):  # pylint:disable=useless-super-delegation,arguments-renamed
    """Returns True if spec_or_tensor is compatible with this TensorSpec.

    Two tensors are considered compatible if they have the same dtype
    and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).

    Args:
      spec_or_tensor: A tf.TensorSpec or a tf.Tensor

    Returns:
      True if spec_or_tensor is compatible with self.
    """
    return super(TensorSpec, self).is_compatible_with(spec_or_tensor)

  def is_subtype_of(self, other):
    """Returns True if this spec is a subtype of `other`.

    `other` must be a `TensorSpec`. This spec must be unnamed or carry the same
    name as `other`, and its shape and dtype must be subtypes of `other`'s.
    """
    if not isinstance(other, TensorSpec):
      return False

    return (
        (not self.name or self.name == other.name)
        and self.shape.is_subtype_of(other.shape)
        and self.dtype.is_subtype_of(other.dtype)
    )

  def placeholder_value(self, placeholder_context):
    """Generates a graph placeholder with the given TensorSpec information."""
    if placeholder_context.unnest_only:
      return self

    name = self.name or placeholder_context.naming_scope
    context_graph = placeholder_context.context_graph
    if placeholder_context.with_none_control_dependencies:
      # Note: setting ops.control_dependencies(None) ensures we always put
      # capturing placeholders outside of any control flow context.
      with context_graph.control_dependencies(None):
        placeholder = self._graph_placeholder(context_graph, name=name)
    else:
      placeholder = self._graph_placeholder(context_graph, name=name)

    if name is not None:
      # Record the requested/user-specified name in case it's different than
      # the uniquified name, for validation when exporting signatures.
      placeholder.op._set_attr(  # pylint: disable=protected-access
          "_user_specified_name",
          attr_value_pb2.AttrValue(s=compat.as_bytes(name)))

    handle_data = self.dtype._handle_data  # pylint: disable=protected-access
    if (
        handle_data is not None
        and handle_data.is_set
        and handle_data.shape_and_type
    ):
      handle_data_util.set_handle_data(placeholder, handle_data)

    # Record the composite device as an attribute to the placeholder.
    # This attribute would be propagated into the arg_attr of the FunctionDef.
    # Currently, a packed eager tensor is always placed on a CompositeDevice.
    if placeholder_context.composite_device_name is not None:
      placeholder.op._set_attr(  # pylint: disable=protected-access
          "_composite_device",
          attr_value_pb2.AttrValue(s=compat.as_bytes(
              placeholder_context.composite_device_name)))

    return placeholder

  def _graph_placeholder(self, graph, name=None):
    """Graph-only version of tf.compat.v1.placeholder(), for internal use only."""
    dtype = self.dtype.base_dtype
    shape = self.shape
    dtype_value = attr_value_pb2.AttrValue(type=dtype.as_datatype_enum)
    if isinstance(shape, (list, tuple)):
      shape = tensor_shape.TensorShape(shape)
    shape = attr_value_pb2.AttrValue(shape=shape.as_proto())
    attrs = {"dtype": dtype_value, "shape": shape}
    try:
      op = graph._create_op_internal(  # pylint: disable=protected-access
          "Placeholder", [], [dtype], input_types=[],
          attrs=attrs, name=name)
    except ValueError as e:
      # TODO(b/262413656) Sometimes parameter names are not valid op names, in
      # which case an unnamed placeholder is created instead. Update this logic
      # to sanitize the name instead of falling back on unnamed placeholders.
      logging.warning(e)
      op = graph._create_op_internal(  # pylint: disable=protected-access
          "Placeholder", [], [dtype], input_types=[], attrs=attrs)
    (result,) = op.outputs
    if op_callbacks.should_invoke_op_callbacks():
      # TODO(b/147670703): Once the special-op creation code paths
      # are unified. Remove this `if` block.
      callback_outputs = op_callbacks.invoke_op_callbacks(
          "Placeholder", tuple(), attrs, tuple(op.outputs),
          op_name=name, graph=graph)
      if callback_outputs is not None:
        (result,) = callback_outputs
    return result

  def _to_tensors(self, value):
    assert isinstance(value, ops.Tensor)
    return [value]

  def _flatten(self):
    return [self]

  def _cast(self, value, casting_context):
    """Cast value to a tensor that is a subtype of this TensorSpec."""
    # This method is mainly used to cast Python primitives to tensor.
    # Currently, cast tensor to tensor with different types are not supported.
    # For example, casting int32 to float32 would raise a ValueError.
    if casting_context.allow_specs and isinstance(value, TensorSpec):
      assert value.is_subtype_of(self), f"Can not cast {value!r} to {self!r}"
      return self

    value = ops.convert_to_tensor(value, self.dtype)
    value_spec = TensorSpec(value.shape, value.dtype, self.name)

    if not value_spec.is_subtype_of(self):
      if self.is_subtype_of(value_spec):
        # `ensure_shape` returns a *new* tensor that carries the refined shape.
        # Its result must be kept. Previously it was discarded, so the
        # returned value still had the looser shape, which is not a subtype of
        # this spec.
        value = gen_array_ops.ensure_shape(value, self.shape)
      else:
        raise AssertionError(f"Can not cast {value_spec!r} to {self!r}")

    return value

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `TensorSpec` with the same shape and dtype as `spec`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="OriginalName")
    >>> tf.TensorSpec.from_spec(spec, "NewName")
    TensorSpec(shape=(8, 3), dtype=tf.int32, name='NewName')

    Args:
      spec: The `TypeSpec` used to create the new `TensorSpec`.
      name: The name for the new `TensorSpec`.  Defaults to `spec.name`.
    """
    return cls(spec.shape, spec.dtype, name or spec.name)

  @classmethod
  def from_tensor(cls, tensor, name=None):
    """Returns a `TensorSpec` that describes `tensor`.

    >>> tf.TensorSpec.from_tensor(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)

    Args:
      tensor: The `tf.Tensor` that should be described.
      name: A name for the `TensorSpec`.  Defaults to `tensor.op.name` for
        graph tensors; eager tensors get no default name.

    Returns:
      A `TensorSpec` that describes `tensor`.

    Raises:
      ValueError: If `tensor` is not a `tf.Tensor`.
    """
    if isinstance(tensor, ops.EagerTensor):
      return TensorSpec(tensor.shape, tensor.dtype, name)
    elif isinstance(tensor, ops.Tensor):
      # TODO(b/249802365): Return a sanitized version of op name or no name.
      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
    else:
      raise ValueError(
          f"`tensor` should be a tf.Tensor, but got type {type(tensor)}.")

  @property
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec."""
    return ops.Tensor

  def _to_components(self, value):
    assert isinstance(value, core_tf_types.Tensor)
    return value

  def _from_components(self, components):
    return components

  def _from_compatible_tensor_list(self, tensor_list):
    # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
    # op here and return that, instead of mutating the input's shape using
    # `Tensor.set_shape()`. However, that would add extra ops, which could
    # impact performance. When this bug is resolved, we should be able to add
    # the `ensure_shape()` ops and optimize them away using contextual shape
    # information.
    assert len(tensor_list) == 1
    tensor_list[0].set_shape(self._shape)
    return tensor_list[0]

  def _to_batchable_tensor_list(self, value, batched=False):
    if batched and self._shape.merge_with(value.shape).ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return self._to_components(value)

  def _batch(self, batch_size):
    """Returns an unnamed spec with a leading `batch_size` dimension."""
    return TensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype)

  def _unbatch(self):
    """Returns an unnamed spec without the leading dimension.

    Raises:
      ValueError: If this spec describes a scalar (rank 0).
    """
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return TensorSpec(self._shape[1:], self._dtype)

  @property
  def _flat_tensor_specs(self):
    return [self]

  def _to_tensor_list(self, value):
    return [self._to_components(value)]

  def _to_batched_tensor_list(self, value):
    return self._to_tensor_list(value)

  # TODO(b/206014848): Helper function to support logic that does not consider
  # Tensor name. Will be removed once load-bearing usages of Tensor name are
  # fixed.
  def _without_tensor_names(self) -> "TensorSpec":
    """Returns a version of `TensorSpec` with the name removed."""
    if self.name is None:
      return self
    else:
      return TensorSpec(self.shape, self.dtype)

trace_type.register_serializable(TensorSpec)
trace_type.register_tensor_type(TensorSpec)

454 

455 

class _TensorCodec:
  """nested_structure_coder codec that stores `tf.Tensor` values as protos."""

  def can_encode(self, pyobj):
    """Accepts any `ops.Tensor` instance."""
    return isinstance(pyobj, ops.Tensor)

  def do_encode(self, tensor_value, encode_fn):
    """Returns an encoded `TensorProto` for the given `tf.Tensor`.

    Eager tensors are encoded from their NumPy value. Graph tensors are
    supported only when produced by a "Const" op. Any other tensor raises
    `nested_structure_coder.NotEncodableError`.
    """
    del encode_fn  # Unused: the tensor is encoded directly.
    if isinstance(tensor_value, ops.EagerTensor):
      tensor_proto = tensor_util.make_tensor_proto(tensor_value.numpy())
    elif tensor_value.op.type == "Const":
      tensor_proto = tensor_value.op.get_attr("value")
    else:
      raise nested_structure_coder.NotEncodableError(
          f"No encoder for object {str(tensor_value)} of type"
          f" {type(tensor_value)}."
      )
    structured = struct_pb2.StructuredValue()
    structured.tensor_value.CopyFrom(tensor_proto)
    return structured

  def can_decode(self, value):
    """Accepts StructuredValue protos whose `tensor_value` field is set."""
    return value.HasField("tensor_value")

  def do_decode(self, value, decode_fn):
    """Returns the `tf.Tensor` encoded by the proto `value`."""
    del decode_fn  # Unused: the tensor is rebuilt directly.
    as_ndarray = tensor_util.MakeNdarray(value.tensor_value)
    return constant_op.constant(as_ndarray)


nested_structure_coder.register_codec(_TensorCodec())

492 

493 

class _TensorSpecCodec:
  """nested_structure_coder codec for plain `TensorSpec` objects."""

  def can_encode(self, pyobj):
    # The `BoundedTensorSpec` subclass is handled by its own, separate codec.
    if isinstance(pyobj, BoundedTensorSpec):
      return False
    return isinstance(pyobj, TensorSpec)

  def do_encode(self, tensor_spec_value, encode_fn):
    """Encodes shape and dtype via `encode_fn`, and stores the name verbatim."""
    shape_proto = encode_fn(tensor_spec_value.shape).tensor_shape_value
    dtype_proto = encode_fn(tensor_spec_value.dtype).tensor_dtype_value
    spec_proto = struct_pb2.TensorSpecProto(
        shape=shape_proto, dtype=dtype_proto, name=tensor_spec_value.name)
    structured = struct_pb2.StructuredValue()
    structured.tensor_spec_value.CopyFrom(spec_proto)
    return structured

  def can_decode(self, value):
    return value.HasField("tensor_spec_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `TensorSpec`; an empty stored name becomes `None`."""
    spec_proto = value.tensor_spec_value
    shape = decode_fn(
        struct_pb2.StructuredValue(tensor_shape_value=spec_proto.shape))
    dtype = decode_fn(
        struct_pb2.StructuredValue(tensor_dtype_value=spec_proto.dtype))
    return TensorSpec(shape=shape, dtype=dtype, name=spec_proto.name or None)


nested_structure_coder.register_codec(_TensorSpecCodec())

527 

528 

# TODO(b/133606651): Should is_compatible_with check min/max bounds?
@type_spec_registry.register("tf.BoundedTensorSpec")
class BoundedTensorSpec(TensorSpec, trace_type.Serializable):
  """A `TensorSpec` that specifies minimum and maximum values.

  Example usage:
  ```python
  spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5))
  tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype)
  tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype)
  ```

  Bounds are meant to be inclusive. This is especially important for
  integer types. The following spec will be satisfied by tensors
  with values in the set {0, 1, 2}:
  ```python
  spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2)
  ```
  """

  __slots__ = ("_minimum", "_maximum")

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Initializes a new `BoundedTensorSpec`.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` are not provided or not
        broadcastable to `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
        numpy dtype.
    """
    super(BoundedTensorSpec, self).__init__(shape, dtype, name)

    if minimum is None:
      raise ValueError("`minimum` can not be None.")
    if maximum is None:
      raise ValueError("`maximum` can not be None.")

    try:
      minimum_shape = np.shape(minimum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(minimum_shape), self.shape)
    except ValueError as exception:
      raise ValueError(
          f"`minimum` {minimum} is not compatible with shape {self.shape}."
      ) from exception

    try:
      maximum_shape = np.shape(maximum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(maximum_shape), self.shape)
    except ValueError as exception:
      raise ValueError(
          f"`maximum` {maximum} is not compatible with shape {self.shape}."
      ) from exception

    # Bounds are stored as read-only arrays of the spec's NumPy dtype.
    self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype)
    self._minimum.setflags(write=False)

    self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype)
    self._maximum.setflags(write=False)

  @classmethod
  def experimental_type_proto(cls) -> Type[struct_pb2.BoundedTensorSpecProto]:
    """Returns the type of proto associated with BoundedTensorSpec serialization."""
    return struct_pb2.BoundedTensorSpecProto

  @classmethod
  def experimental_from_proto(
      cls, proto: struct_pb2.BoundedTensorSpecProto) -> "BoundedTensorSpec":
    """Returns a BoundedTensorSpec instance based on the serialized proto."""
    return BoundedTensorSpec(
        shape=tensor_shape.TensorShape.experimental_from_proto(proto.shape),
        dtype=proto.dtype,
        minimum=tensor_util.MakeNdarray(proto.minimum),
        maximum=tensor_util.MakeNdarray(proto.maximum),
        name=proto.name if proto.name else None)

  def experimental_as_proto(self) -> struct_pb2.BoundedTensorSpecProto:
    """Returns a proto representation of the BoundedTensorSpec instance."""
    return struct_pb2.BoundedTensorSpecProto(
        shape=self.shape.experimental_as_proto(),
        dtype=self.dtype.experimental_as_proto().datatype,
        minimum=tensor_util.make_tensor_proto(self._minimum),
        maximum=tensor_util.make_tensor_proto(self._maximum),
        name=self.name)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `BoundedTensorSpec` with the same shape and dtype as `spec`.

    If `spec` is a `BoundedTensorSpec`, then the new spec's bounds are set to
    `spec.minimum` and `spec.maximum`; otherwise, the bounds are set to
    `spec.dtype.min` and `spec.dtype.max`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="x")
    >>> BoundedTensorSpec.from_spec(spec)
    BoundedTensorSpec(shape=(8, 3), dtype=tf.int32, name='x',
        minimum=array(-2147483648, dtype=int32),
        maximum=array(2147483647, dtype=int32))

    Args:
      spec: The `TypeSpec` used to create the new `BoundedTensorSpec`.
      name: The name for the new `BoundedTensorSpec`. Defaults to `spec.name`.
        Accepted for consistency with `TensorSpec.from_spec`.
    """
    dtype = dtypes.as_dtype(spec.dtype)
    minimum = getattr(spec, "minimum", dtype.min)
    maximum = getattr(spec, "maximum", dtype.max)
    return BoundedTensorSpec(spec.shape, dtype, minimum, maximum,
                             name or spec.name)

  @property
  def minimum(self):
    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
    return self._minimum

  @property
  def maximum(self):
    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
    return self._maximum

  def _cast(self, value, casting_context):
    """Casts `value` using an unbounded `TensorSpec` with the same shape/dtype/name.

    The bounds themselves are not checked here.
    """
    if casting_context.allow_specs and isinstance(value, BoundedTensorSpec):
      assert value.is_subtype_of(self), f"Can not cast {value!r} to {self!r}"
      return self

    actual_spec = TensorSpec(shape=self.shape, dtype=self.dtype, name=self.name)
    return actual_spec._cast(value, casting_context)  # pylint: disable=protected-access

  def __repr__(self):
    s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})"
    return s.format(self.shape, repr(self.dtype), repr(self.name),
                    repr(self.minimum), repr(self.maximum))

  def __eq__(self, other):
    # The parent check requires identical types, so `other` is guaranteed to
    # have `minimum`/`maximum` before they are compared.
    tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other)
    return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and
            np.allclose(self.maximum, other.maximum))

  def __hash__(self):
    # Bounds are excluded from the hash. Equal specs still hash equally,
    # because __eq__ also requires equal shape and dtype.
    return hash((self._shape, self.dtype))

  def __reduce__(self):
    return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
                               self._maximum, self._name)

  def _serialize(self):
    return (self._shape, self._dtype, self._minimum, self._maximum, self._name)

685 

686 

class _BoundedTensorSpecCodec:
  """nested_structure_coder codec for `BoundedTensorSpec` objects."""

  def can_encode(self, pyobj):
    return isinstance(pyobj, BoundedTensorSpec)

  def do_encode(self, bounded_tensor_spec_value, encode_fn):
    """Returns an encoded proto for the given `tf.BoundedTensorSpec`."""
    spec = bounded_tensor_spec_value
    # Shape and dtype go through `encode_fn`; the bounds become TensorProtos.
    bounded_proto = struct_pb2.BoundedTensorSpecProto(
        shape=encode_fn(spec.shape).tensor_shape_value,
        dtype=encode_fn(spec.dtype).tensor_dtype_value,
        name=spec.name,
        minimum=tensor_util.make_tensor_proto(spec.minimum),
        maximum=tensor_util.make_tensor_proto(spec.maximum))
    structured = struct_pb2.StructuredValue()
    structured.bounded_tensor_spec_value.CopyFrom(bounded_proto)
    return structured

  def can_decode(self, value):
    return value.HasField("bounded_tensor_spec_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `BoundedTensorSpec`; an empty stored name becomes `None`."""
    proto = value.bounded_tensor_spec_value
    shape = decode_fn(struct_pb2.StructuredValue(tensor_shape_value=proto.shape))
    dtype = decode_fn(struct_pb2.StructuredValue(tensor_dtype_value=proto.dtype))
    return BoundedTensorSpec(
        shape=shape,
        dtype=dtype,
        minimum=tensor_util.MakeNdarray(proto.minimum),
        maximum=tensor_util.MakeNdarray(proto.maximum),
        name=proto.name or None)


nested_structure_coder.register_codec(_BoundedTensorSpecCodec())

trace_type.register_serializable(BoundedTensorSpec)
_pywrap_utils.RegisterType("TensorSpec", TensorSpec)

# Note: we do not include Tensor names when constructing TypeSpecs.
type_spec.register_type_spec_from_value_converter(
    ops.Tensor, lambda tensor: TensorSpec(tensor.shape, tensor.dtype))

type_spec.register_type_spec_from_value_converter(
    np.ndarray, lambda array: TensorSpec(array.shape, array.dtype))