Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py: 35%

425 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Helper classes for tensor shape inference.""" 

16import functools 

17import operator 

18from typing import Optional, Sequence, Type 

19 

20from tensorflow.core.framework import tensor_shape_pb2 

21from tensorflow.core.function import trace_type 

22from tensorflow.core.protobuf import struct_pb2 

23from tensorflow.python import tf2 

24from tensorflow.python.eager import monitoring 

25from tensorflow.python.platform import tf_logging as logging 

26from tensorflow.python.saved_model import nested_structure_coder 

27from tensorflow.python.types import trace 

28from tensorflow.python.util.tf_export import tf_export 

29from tensorflow.tools.docs import doc_controls 

30 

# Tri-state override for TensorShape's V2 behavior:
#   None  -> no explicit choice was made; defer to tf2.enabled() (see
#            TensorShape._v2_behavior).
#   True  -> enable_v2_tensorshape() was called.
#   False -> disable_v2_tensorshape() was called.
_TENSORSHAPE_V2_OVERRIDE = None

# Telemetry gauge flipped by enable_v2_tensorshape()/disable_v2_tensorshape()
# so API usage of the V2 opt-in can be monitored.
_api_usage_gauge = monitoring.BoolGauge(
    "/tensorflow/api/v2_tensorshape",
    "Whether tensor_shape.enable_v2_tensorshape() is called.")

36 

37 

@tf_export(v1=["enable_v2_tensorshape"])
def enable_v2_tensorshape():
  """Opts in to the TensorFlow 2.0 `TensorShape` behavior.

  In V1, indexing or iterating a `TensorShape` yields `Dimension` objects;
  in V2 it yields plain integers (or `None` for unknown sizes). Calling this
  function turns the V2 behavior on.

  Migration examples:

  ```
  # V1:                              # V2:
  value = tensor_shape[i].value      value = tensor_shape[i]

  # V1:                              # V2:
  for dim in tensor_shape:           for value in tensor_shape:
    print(dim.value)                   print(value)

  # V1:
  dim = tensor_shape[i]
  dim.assert_is_compatible_with(other_shape)

  # V2 (explicit, and avoids the V1 trap where in-place edits to `dim`
  # were never reflected back into `tensor_shape[i]`):
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]
  dim.assert_is_compatible_with(other_shape)
  ```
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = True
  logging.vlog(1, "Enabling v2 tensorshape")
  _api_usage_gauge.get_cell().set(True)

89 

90 

@tf_export(v1=["disable_v2_tensorshape"])
def disable_v2_tensorshape():
  """Opts out of the V2 `TensorShape` behavior, restoring V1 semantics.

  See the docstring of `enable_v2_tensorshape` for a description of the two
  behaviors and migration examples.
  """
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = False
  logging.vlog(1, "Disabling v2 tensorshape")
  _api_usage_gauge.get_cell().set(False)

101 

102 

@tf_export(
    "compat.dimension_value", v1=["dimension_value", "compat.dimension_value"])
def dimension_value(dimension):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  When accessing the value of a TensorShape dimension,
  use this utility, like this:

  ```
  # If you had this in your V1 code:
  value = tensor_shape[i].value

  # Use `dimension_value` as direct replacement compatible with both V1 & V2:
  value = dimension_value(tensor_shape[i])

  # This would be the V2 equivalent:
  value = tensor_shape[i]  # Warning: this will return the dim value in V2!
  ```

  Args:
    dimension: Either a `Dimension` instance, an integer, or None.

  Returns:
    A plain value, i.e. an integer or None.
  """
  # A `Dimension` gets unwrapped; anything else (int or None) passes through.
  return dimension.value if isinstance(dimension, Dimension) else dimension

134 

135 

@tf_export(
    "compat.dimension_at_index",
    v1=["dimension_at_index", "compat.dimension_at_index"])
def dimension_at_index(shape, index):
  """Compatibility utility required to allow for both V1 and V2 behavior in TF.

  Until the release of TF 2.0, we need the legacy behavior of `TensorShape` to
  coexist with the new behavior. This utility is a bridge between the two.

  If you want to retrieve the Dimension instance corresponding to a certain
  index in a TensorShape instance, use this utility, like this:

  ```
  # If you had this in your V1 code:
  dim = tensor_shape[i]

  # Use `dimension_at_index` as direct replacement compatible with both V1 & V2:
  dim = dimension_at_index(tensor_shape, i)

  # Another possibility would be this, but WARNING: it only works if the
  # tensor_shape instance has a defined rank.
  dim = tensor_shape.dims[i]  # `dims` may be None if the rank is undefined!

  # In native V2 code, we recommend instead being more explicit:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]

  # Being more explicit will save you from the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be (as the Dimension object was
  # instantiated on the fly).
  ```

  Args:
    shape: A TensorShape instance.
    index: An integer index.

  Returns:
    A dimension object.
  """
  assert isinstance(shape, TensorShape)
  # An unknown-rank shape yields an unknown Dimension for any index.
  return Dimension(None) if shape.rank is None else shape.dims[index]

183 

184 

@tf_export(v1=["Dimension"])
class Dimension(object):
  """Represents the value of one dimension in a TensorShape.

  A `Dimension` is an immutable wrapper around either a non-negative int
  (a known size) or `None` (an unknown size). Arithmetic and comparison
  operators propagate unknowns: combining a known with an unknown dimension
  generally yields `Dimension(None)` (for arithmetic) or `None`
  (for comparisons).

  @compatibility(TF2)
  In TF2, members of a `TensorShape` object are integers. The `Dimension` class
  is not part of TF2's data model.

  Please refer to the [TensorShape section of the migration guide]
  (https://www.tensorflow.org/guide/migrate/index#tensorshape) on common code
  patterns adapting Dimension objects to a TF2 syntax.
  @end_compatibility
  """

  # Only state is the wrapped int-or-None; __slots__ keeps instances small.
  __slots__ = ["_value"]

  def __init__(self, value):
    """Creates a new Dimension with the given value.

    Args:
      value: A non-negative int, `None` (unknown), another `Dimension`
        (copied), or any object providing `__index__`.

    Raises:
      ValueError: If `value` is a negative integer.
      TypeError: If `value` is not an int, `None`, `Dimension`, or
        index-like object.
    """
    if isinstance(value, int):  # Most common case.
      if value < 0:
        raise ValueError("Dimension %d must be >= 0" % value)
      self._value = value
    elif value is None:
      self._value = None
    elif isinstance(value, Dimension):
      self._value = value._value
    else:
      try:
        # int(...) compensates for the int/long dichotomy on Python 2.X.
        # TODO(b/143206389): Remove once we fully migrate to 3.X.
        self._value = int(value.__index__())
      except AttributeError:
        raise TypeError(
            "Dimension value must be integer or None or have "
            "an __index__ method, got value '{0!r}' with type '{1!r}'".format(
                value, type(value))) from None
      if self._value < 0:
        raise ValueError("Dimension %d must be >= 0" % self._value)

  def __repr__(self):
    return "Dimension(%s)" % repr(self._value)

  def __str__(self):
    # Unknown dimensions print as "?" (e.g. in TensorShape.__str__).
    value = self._value
    return "?" if value is None else str(value)

  def __eq__(self, other):
    """Returns true if `other` has the same known value as this Dimension.

    NOTE: deliberately three-valued (V1 semantics): returns `None`, not
    False, when either side is unknown.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value == other.value

  def __ne__(self, other):
    """Returns true if `other` has a different known value from `self`.

    NOTE: like `__eq__`, returns `None` when either side is unknown.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return None
    return self._value != other.value

  def __bool__(self):
    """Equivalent to `bool(self.value)`."""
    return bool(self._value)

  def __int__(self):
    return self._value

  # This is needed for Windows.
  # See https://github.com/tensorflow/tensorflow/pull/9780
  def __long__(self):
    return self._value

  def __index__(self):
    # Allow use in Python 3 range
    return self._value

  @property
  def value(self):
    """The value of this dimension, or None if it is unknown."""
    return self._value

  # TODO(b/225058047): Reconsider semantics.
  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this Dimension.

    Two known Dimensions are compatible if they have the same value.
    An unknown Dimension is compatible with all other Dimensions.

    Args:
      other: Another Dimension.

    Returns:
      True if this Dimension and `other` are compatible.
    """
    other = as_dimension(other)
    return (self._value is None or other.value is None or
            self._value == other.value)

  def assert_is_compatible_with(self, other):
    """Raises an exception if `other` is not compatible with this Dimension.

    Args:
      other: Another Dimension.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    if not self.is_compatible_with(other):
      raise ValueError("Dimensions %s and %s are not compatible" %
                       (self, other))

  def merge_with(self, other):
    """Returns a Dimension that combines the information in `self` and `other`.

    Dimensions are combined as follows:

    ```python
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(n))     ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(None))  ==
    tf.compat.v1.Dimension(n)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(n))     ==
    tf.compat.v1.Dimension(n)
    # equivalent to tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None).merge_with(tf.compat.v1.Dimension(None))

    # raises ValueError for n != m
    tf.compat.v1.Dimension(n)   .merge_with(tf.compat.v1.Dimension(m))
    ```

    Args:
      other: Another Dimension.

    Returns:
      A Dimension containing the combined information of `self` and
      `other`.

    Raises:
      ValueError: If `self` and `other` are not compatible (see
        is_compatible_with).
    """
    other = as_dimension(other)
    self.assert_is_compatible_with(other)
    # After the compatibility check, at most one side is unknown; prefer
    # whichever value is known.
    if self._value is None:
      return Dimension(other.value)
    else:
      return Dimension(self._value)

  def __add__(self, other):
    """Returns the sum of `self` and `other`.

    Dimensions are summed as follows:

    ```python
    tf.compat.v1.Dimension(m)    + tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m + n)
    tf.compat.v1.Dimension(m)    + tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) + tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value + other.value)

  def __radd__(self, other):
    """Returns the sum of `other` and `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the sum of `self` and `other`.
    """
    return self + other

  def __sub__(self, other):
    """Returns the subtraction of `other` from `self`.

    Dimensions are subtracted as follows:

    ```python
    tf.compat.v1.Dimension(m)    - tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m - n)
    tf.compat.v1.Dimension(m)    - tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) - tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `other` from `self`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value - other.value)

  def __rsub__(self, other):
    """Returns the subtraction of `self` from `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the subtraction of `self` from `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value - self._value)

  def __mul__(self, other):
    """Returns the product of `self` and `other`.

    Dimensions are multiplied as follows:

    ```python
    tf.compat.v1.Dimension(m)    * tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m * n)
    tf.compat.v1.Dimension(m)    * tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) * tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented

    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value * other.value)

  def __rmul__(self, other):
    """Returns the product of `self` and `other`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is the product of `self` and `other`.
    """
    return self * other

  def __floordiv__(self, other):
    """Returns the quotient of `self` and `other` rounded down.

    Dimensions are divided as follows:

    ```python
    tf.compat.v1.Dimension(m)    // tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m // n)
    tf.compat.v1.Dimension(m)    // tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) // tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    try:
      other = as_dimension(other)
    except (TypeError, ValueError):
      return NotImplemented
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value // other.value)

  def __rfloordiv__(self, other):
    """Returns the quotient of `other` and `self` rounded down.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(other.value // self._value)

  def __div__(self, other):
    """DEPRECATED: Use `__floordiv__` via `x // y` instead.

    This function exists only for backwards compatibility purposes; new code
    should use `__floordiv__` via the syntax `x // y`. Using `x // y`
    communicates clearly that the result rounds down, and is forward compatible
    to Python 3.

    Args:
      other: Another `Dimension`.

    Returns:
      A `Dimension` whose value is the integer quotient of `self` and `other`.
    """
    return self // other

  def __rdiv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __truediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'Dimension' and 'int'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: 'Dimension' and '{}', "
                    "please use // instead".format(type(other).__name__))

  def __rtruediv__(self, other):
    """Use `__floordiv__` via `x // y` instead.

    This function exists only to have a better error message. Instead of:
    `TypeError: unsupported operand type(s) for /: 'int' and 'Dimension'`,
    this function will explicitly call for usage of `//` instead.

    Args:
      other: Another `Dimension`.

    Raises:
      TypeError.
    """
    raise TypeError("unsupported operand type(s) for /: '{}' and 'Dimension', "
                    "please use // instead".format(type(other).__name__))

  def __mod__(self, other):
    """Returns `self` modulo `other`.

    Dimension modulo are computed as follows:

    ```python
    tf.compat.v1.Dimension(m)    % tf.compat.v1.Dimension(n)     ==
    tf.compat.v1.Dimension(m % n)
    tf.compat.v1.Dimension(m)    % tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(n)     # equiv. to
    tf.compat.v1.Dimension(None)
    tf.compat.v1.Dimension(None) % tf.compat.v1.Dimension(None)  # equiv. to
    tf.compat.v1.Dimension(None)
    ```

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `self` modulo `other`.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return Dimension(None)
    else:
      return Dimension(self._value % other.value)

  def __rmod__(self, other):
    """Returns `other` modulo `self`.

    Args:
      other: Another Dimension, or a value accepted by `as_dimension`.

    Returns:
      A Dimension whose value is `other` modulo `self`.
    """
    # Delegate to __mod__ with the operands swapped.
    other = as_dimension(other)
    return other % self

  def __lt__(self, other):
    """Returns True if `self` is known to be less than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    < tf.compat.v1.Dimension(n))    == (m < n)
    (tf.compat.v1.Dimension(m)    < tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) < tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value < other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value < other.value

  def __le__(self, other):
    """Returns True if `self` is known to be less than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    <= tf.compat.v1.Dimension(n))    == (m <= n)
    (tf.compat.v1.Dimension(m)    <= tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) <= tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value <= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value <= other.value

  def __gt__(self, other):
    """Returns True if `self` is known to be greater than `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    > tf.compat.v1.Dimension(n))    == (m > n)
    (tf.compat.v1.Dimension(m)    > tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) > tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value > other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value > other.value

  def __ge__(self, other):
    """Returns True if `self` is known to be greater than or equal to `other`.

    Dimensions are compared as follows:

    ```python
    (tf.compat.v1.Dimension(m)    >= tf.compat.v1.Dimension(n))    == (m >= n)
    (tf.compat.v1.Dimension(m)    >= tf.compat.v1.Dimension(None)) == None
    (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(n))    == None
    (tf.compat.v1.Dimension(None) >= tf.compat.v1.Dimension(None)) == None
    ```

    Args:
      other: Another Dimension.

    Returns:
      The value of `self.value >= other.value` if both are known, otherwise
      None.
    """
    other = as_dimension(other)
    if self._value is None or other.value is None:
      return None
    else:
      return self._value >= other.value

  def __reduce__(self):
    # Pickle support: reconstruct from the wrapped value.
    return Dimension, (self._value,)

723 

724 

def as_dimension(value):
  """Converts the given value to a Dimension.

  A Dimension input will be returned unmodified.
  An input of `None` will be converted to an unknown Dimension.
  An integer input will be converted to a Dimension with that value.

  Args:
    value: The value to be converted.

  Returns:
    A Dimension corresponding to the given value.
  """
  # Pass existing Dimensions through untouched; wrap everything else.
  return value if isinstance(value, Dimension) else Dimension(value)

742 

743 

744@tf_export("TensorShape") 

745class TensorShape(trace.TraceType, trace_type.Serializable): 

746 """Represents the shape of a `Tensor`. 

747 

748 >>> t = tf.constant([[1,2,3],[4,5,6]]) 

749 >>> t.shape 

750 TensorShape([2, 3]) 

751 

752 `TensorShape` is the *static* shape representation of a Tensor. 

753 During eager execution a Tensor always has a fully specified shape but 

754 when tracing a `tf.function` it may be one of the following: 

755 

756 * *Fully-known shape:* has a known number of dimensions and a known size 

757 for each dimension. e.g. `TensorShape([16, 256])` 

758 * *Partially-known shape:* has a known number of dimensions, and an unknown 

759 size for one or more dimension. e.g. `TensorShape([None, 256])` 

760 * *Unknown shape:* has an unknown number of dimensions, and an unknown 

761 size in all dimensions. e.g. `TensorShape(None)` 

762 

763 During function tracing `t.shape` will return a `TensorShape` object 

764 representing the shape of Tensor as it is known during tracing. 

765 This static representation will be partially defined in cases where the 

766 exact shape depends on the values within the tensors. To get the 

767 *dynamic* representation, please use `tf.shape(t)` 

768 which will return Tensor representing the fully defined shape of `t`. 

769 This way, you can express logic that manipulates the shapes of tensors by 

770 building other tensors that depend on the dynamic shape of `t`. 

771 

772 Note: `tf.RaggedTensor.shape` also returns a `tf.TensorShape`, 

773 the lengths of any ragged dimensions are unknown (`None`). 

774 

775 For example, this function prints the `TensorShape' (`t.shape`), when you 

776 trace the function, and returns a tensor `tf.shape(t)` for given input `t`: 

777 

778 >>> @tf.function 

779 ... def get_dynamic_shape(t): 

780 ... print("tracing...") 

781 ... print(f"static shape is {t.shape}") 

782 ... return tf.shape(t) 

783 

784 Just calling the function traces it with a fully-specified static shape: 

785 

786 >>> result = get_dynamic_shape(tf.constant([[1, 1, 1], [0, 0, 0]])) 

787 tracing... 

788 static shape is (2, 3) 

789 >>> result.numpy() 

790 array([2, 3], dtype=int32) 

791 

792 But `tf.function` can also trace the function with a partially specified 

793 (or even unspecified) shape: 

794 

795 >>> cf1 = get_dynamic_shape.get_concrete_function(tf.TensorSpec( 

796 ... shape=[None, 2])) 

797 tracing... 

798 static shape is (None, 2) 

799 >>> cf1(tf.constant([[1., 0],[1, 0],[1, 0]])).numpy() 

800 array([3, 2], dtype=int32) 

801 

802 >>> cf2 = get_dynamic_shape.get_concrete_function(tf.TensorSpec(shape=None)) 

803 tracing... 

804 static shape is <unknown> 

805 >>> cf2(tf.constant([[[[[1., 0]]]]])).numpy() 

806 array([1, 1, 1, 1, 2], dtype=int32) 

807 

808 If a tensor is produced by an operation of type `"Foo"`, its shape 

809 may be inferred if there is a registered shape function for 

810 `"Foo"`. See [Shape 

811 functions](https://www.tensorflow.org/guide/create_op#shape_functions_in_c) 

812 for details of shape functions and how to register them. Alternatively, 

813 you may set the shape explicitly using `tf.Tensor.ensure_shape`. 

814 """ 

815 __slots__ = ["_dims"] 

816 

817 def __init__(self, dims): 

818 """Creates a new TensorShape with the given dimensions. 

819 

820 Args: 

821 dims: A list of Dimensions, or None if the shape is unspecified. 

822 

823 Raises: 

824 TypeError: If dims cannot be converted to a list of dimensions. 

825 """ 

826 if isinstance(dims, (tuple, list)): # Most common case. 

827 self._dims = tuple(as_dimension(d).value for d in dims) 

828 elif dims is None: 

829 self._dims = None 

830 elif isinstance(dims, tensor_shape_pb2.TensorShapeProto): 

831 if dims.unknown_rank: 

832 self._dims = None 

833 else: 

834 self._dims = tuple( 

835 # Protos store variable-size dimensions as -1 

836 dim.size if dim.size != -1 else None 

837 for dim in dims.dim 

838 ) 

839 elif isinstance(dims, TensorShape): 

840 self._dims = dims._dims 

841 else: 

842 try: 

843 dims_iter = iter(dims) 

844 except TypeError: 

845 # Treat as a singleton dimension 

846 self._dims = (as_dimension(dims).value,) 

847 else: 

848 self._dims = [] 

849 for d in dims_iter: 

850 try: 

851 self._dims.append(as_dimension(d).value) 

852 except TypeError as e: 

853 raise TypeError( 

854 "Failed to convert '{0!r}' to a shape: '{1!r}'" 

855 "could not be converted to a dimension. A shape should " 

856 "either be single dimension (e.g. 10), or an iterable of " 

857 "dimensions (e.g. [1, 10, None]).".format(dims, d)) from e 

858 self._dims = tuple(self._dims) 

859 

860 @property 

861 def _v2_behavior(self): 

862 if _TENSORSHAPE_V2_OVERRIDE is None: 

863 return tf2.enabled() 

864 return _TENSORSHAPE_V2_OVERRIDE 

865 

866 def __repr__(self): 

867 if self._v2_behavior: 

868 if self._dims is not None: 

869 return f"TensorShape({list(self._dims)})" 

870 else: 

871 return "TensorShape(None)" 

872 else: 

873 return f"TensorShape({self.dims})" 

874 

875 def __str__(self): 

876 if self.rank is None: 

877 return "<unknown>" 

878 elif self.rank == 1: 

879 if self._v2_behavior: 

880 return "(%s,)" % self._dims[0] 

881 else: 

882 return "(%s,)" % self.dims[0] 

883 else: 

884 if self._v2_behavior: 

885 return "(%s)" % ", ".join(str(d) for d in self._dims) 

886 else: 

887 return "(%s)" % ", ".join(str(d) for d in self.dims) 

888 

889 @property 

890 def rank(self): 

891 """Returns the rank of this shape, or None if it is unspecified.""" 

892 if self._dims is not None: 

893 return len(self._dims) 

894 return None 

895 

896 @property 

897 def dims(self): 

898 """Deprecated. Returns list of dimensions for this shape. 

899 

900 Suggest `TensorShape.as_list` instead. 

901 

902 Returns: 

903 A list containing `tf.compat.v1.Dimension`s, or None if the shape is 

904 unspecified. 

905 """ 

906 if self._dims is None: 

907 return None 

908 return [as_dimension(d) for d in self._dims] 

909 

910 @property 

911 def ndims(self): 

912 """Deprecated accessor for `rank`.""" 

913 return self.rank 

914 

915 def __len__(self): 

916 """Returns the rank of this shape, or raises ValueError if unspecified.""" 

917 if self._dims is None: 

918 raise ValueError("Cannot take the length of shape with unknown rank.") 

919 return len(self._dims) 

920 

921 def __bool__(self): 

922 """Returns True if this shape contains non-zero information.""" 

923 return self._dims is not None 

924 

925 # Python 3 wants __bool__, Python 2.7 wants __nonzero__ 

926 __nonzero__ = __bool__ 

927 

928 def __iter__(self): 

929 """Returns `self.dims` if the rank is known, otherwise raises ValueError.""" 

930 if self._dims is None: 

931 raise ValueError("Cannot iterate over a shape with unknown rank.") 

932 else: 

933 if self._v2_behavior: 

934 return iter(d for d in self._dims) 

935 else: 

936 return iter(d for d in self.dims) 

937 

  def __getitem__(self, key):
    """Returns the value of a dimension or a shape, depending on the key.

    Args:
      key: If `key` is an integer, returns the dimension at that index;
        otherwise if `key` is a slice, returns a TensorShape whose dimensions
        are those selected by the slice from `self`.

    Returns:
      An integer if `key` is an integer, or a `TensorShape` if `key` is a
      slice.

    Raises:
      ValueError: If `key` is a slice and `self` is completely unknown and
        the step is set.
    """
    if self._dims is not None:
      if isinstance(key, slice):
        return TensorShape(self._dims[key])
      else:
        # v2 returns the raw int/None entry; v1 returns a Dimension.
        if self._v2_behavior:
          return self._dims[key]
        else:
          return self.dims[key]
    else:
      # Rank unknown: we can still report the *rank* of some slices.
      if isinstance(key, slice):
        start = key.start if key.start is not None else 0
        stop = key.stop

        if key.step is not None:
          # TODO(mrry): Handle these maybe.
          raise ValueError("Steps are not yet handled")
        if stop is None:
          # NOTE(mrry): This implies that TensorShape(None) is compatible with
          # TensorShape(None)[1:], which is obviously not true. It would be
          # possible to track the number of dimensions symbolically,
          # and perhaps we should do that.
          return unknown_shape()
        elif start < 0 or stop < 0:
          # TODO(mrry): Handle this better, as it will be useful for handling
          # suffixes of otherwise unknown shapes.
          return unknown_shape()
        else:
          # Non-negative bounds on an unknown-rank shape still pin the length.
          return unknown_shape(rank=stop - start)
      else:
        # Single unknown dimension of an unknown-rank shape.
        if self._v2_behavior:
          return None
        else:
          return Dimension(None)

987 

988 def num_elements(self): 

989 """Returns the total number of elements, or none for incomplete shapes.""" 

990 if self.is_fully_defined(): 

991 return functools.reduce(operator.mul, self.as_list(), 1) 

992 else: 

993 return None 

994 

995 def merge_with(self, other): 

996 """Returns a `TensorShape` combining the information in `self` and `other`. 

997 

998 The dimensions in `self` and `other` are merged element-wise, 

999 according to the rules below: 

1000 

1001 ```python 

1002 Dimension(n).merge_with(Dimension(None)) == Dimension(n) 

1003 Dimension(None).merge_with(Dimension(n)) == Dimension(n) 

1004 Dimension(None).merge_with(Dimension(None)) == Dimension(None) 

1005 # raises ValueError for n != m 

1006 Dimension(n).merge_with(Dimension(m)) 

1007 ``` 

1008 >> ts = tf.TensorShape([1,2]) 

1009 >> ot1 = tf.TensorShape([1,2]) 

1010 >> ts.merge_with(ot).as_list() 

1011 [1,2] 

1012 

1013 >> ot2 = tf.TensorShape([1,None]) 

1014 >> ts.merge_with(ot2).as_list() 

1015 [1,2] 

1016 

1017 >> ot3 = tf.TensorShape([None, None]) 

1018 >> ot3.merge_with(ot2).as_list() 

1019 [1, None] 

1020 

1021 Args: 

1022 other: Another `TensorShape`. 

1023 

1024 Returns: 

1025 A `TensorShape` containing the combined information of `self` and 

1026 `other`. 

1027 

1028 Raises: 

1029 ValueError: If `self` and `other` are not compatible. 

1030 """ 

1031 other = as_shape(other) 

1032 if self.dims is None: 

1033 return other 

1034 if other.dims is None: 

1035 return self 

1036 else: 

1037 try: 

1038 self.assert_same_rank(other) 

1039 new_dims = [ 

1040 dim.merge_with(other_dim) 

1041 for dim, other_dim in zip(self.dims, other.dims) 

1042 ] 

1043 return TensorShape(new_dims) 

1044 except ValueError: 

1045 raise ValueError("Shapes %s and %s are not compatible" % (self, other)) 

1046 

  def __add__(self, other):
    """Returns the concatenation `self + other`; see `concatenate`."""
    return self.concatenate(other)

1049 

1050 def __radd__(self, other): 

1051 if not isinstance(other, TensorShape): 

1052 other = TensorShape(other) 

1053 return other.concatenate(self) 

1054 

1055 def concatenate(self, other): 

1056 """Returns the concatenation of the dimension in `self` and `other`. 

1057 

1058 *N.B.* If either `self` or `other` is completely unknown, 

1059 concatenation will discard information about the other shape. In 

1060 future, we might support concatenation that preserves this 

1061 information for use with slicing. 

1062 

1063 Args: 

1064 other: Another `TensorShape`. 

1065 

1066 Returns: 

1067 A `TensorShape` whose dimensions are the concatenation of the 

1068 dimensions in `self` and `other`. 

1069 """ 

1070 # TODO(mrry): Handle the case where we concatenate a known shape with a 

1071 # completely unknown shape, so that we can use the partial information. 

1072 other = as_shape(other) 

1073 if self.dims is None or other.dims is None: 

1074 return unknown_shape() 

1075 else: 

1076 return TensorShape(self.dims + other.dims) 

1077 

1078 def assert_same_rank(self, other): 

1079 """Raises an exception if `self` and `other` do not have compatible ranks. 

1080 

1081 Args: 

1082 other: Another `TensorShape`. 

1083 

1084 Raises: 

1085 ValueError: If `self` and `other` do not represent shapes with the 

1086 same rank. 

1087 """ 

1088 other = as_shape(other) 

1089 if self.rank is not None and other.rank is not None: 

1090 if self.rank != other.rank: 

1091 raise ValueError("Shapes %s and %s must have the same rank" % 

1092 (self, other)) 

1093 

1094 def assert_has_rank(self, rank): 

1095 """Raises an exception if `self` is not compatible with the given `rank`. 

1096 

1097 Args: 

1098 rank: An integer. 

1099 

1100 Raises: 

1101 ValueError: If `self` does not represent a shape with the given `rank`. 

1102 """ 

1103 if self.rank not in (None, rank): 

1104 raise ValueError("Shape %s must have rank %d" % (self, rank)) 

1105 

1106 def with_rank(self, rank): 

1107 """Returns a shape based on `self` with the given rank. 

1108 

1109 This method promotes a completely unknown shape to one with a 

1110 known rank. 

1111 

1112 Args: 

1113 rank: An integer. 

1114 

1115 Returns: 

1116 A shape that is at least as specific as `self` with the given rank. 

1117 

1118 Raises: 

1119 ValueError: If `self` does not represent a shape with the given `rank`. 

1120 """ 

1121 try: 

1122 return self.merge_with(unknown_shape(rank=rank)) 

1123 except ValueError: 

1124 raise ValueError("Shape %s must have rank %d" % (self, rank)) 

1125 

1126 def with_rank_at_least(self, rank): 

1127 """Returns a shape based on `self` with at least the given rank. 

1128 

1129 Args: 

1130 rank: An integer. 

1131 

1132 Returns: 

1133 A shape that is at least as specific as `self` with at least the given 

1134 rank. 

1135 

1136 Raises: 

1137 ValueError: If `self` does not represent a shape with at least the given 

1138 `rank`. 

1139 """ 

1140 if self.rank is not None and self.rank < rank: 

1141 raise ValueError("Shape %s must have rank at least %d" % (self, rank)) 

1142 else: 

1143 return self 

1144 

1145 def with_rank_at_most(self, rank): 

1146 """Returns a shape based on `self` with at most the given rank. 

1147 

1148 Args: 

1149 rank: An integer. 

1150 

1151 Returns: 

1152 A shape that is at least as specific as `self` with at most the given 

1153 rank. 

1154 

1155 Raises: 

1156 ValueError: If `self` does not represent a shape with at most the given 

1157 `rank`. 

1158 """ 

1159 if self.rank is not None and self.rank > rank: 

1160 raise ValueError("Shape %s must have rank at most %d" % (self, rank)) 

1161 else: 

1162 return self 

1163 

1164 def is_subtype_of(self, other: trace.TraceType) -> bool: 

1165 """Returns True iff `self` is subtype of `other`. 

1166 

1167 Shape A is a subtype of shape B if shape B can successfully represent it: 

1168 

1169 * A `TensorShape` of any rank is a subtype of `TensorShape(None)`. 

1170 

1171 * TensorShapes of equal ranks are covariant, i.e. 

1172 `TensorShape([A1, A2, ..])` is a subtype of 

1173 `TensorShape([B1, B2, ..])` iff An is a subtype of Bn. 

1174 

1175 An is subtype of Bn iff An == Bn or Bn is None. 

1176 

1177 * TensorShapes of different defined ranks have no subtyping relation. 

1178 

1179 The subtyping relation is reflexive and transitive, but not symmetric. 

1180 

1181 Some examples: 

1182 * `TensorShape([32, 784])` is a subtype of `TensorShape(None)`, and 

1183 `TensorShape([4, 4])` is also a subtype of `TensorShape(None)` but 

1184 `TensorShape([32, 784])` and `TensorShape([4, 4])` are not subtypes of 

1185 each other. 

1186 

1187 * All two-dimensional shapes are subtypes of `TensorShape([None, None])`, 

1188 such as `TensorShape([32, 784])`. There is no subtype relationship with, 

1189 for example, `TensorShape([None])` or `TensorShape([None, None, None])`. 

1190 

1191 * `TensorShape([32, None])` is also a subtype of `TensorShape([None, None])` 

1192 and `TensorShape(None)`. It is not a subtype of, for example, 

1193 `TensorShape([32])`, `TensorShape([32, None, 1])`, 

1194 `TensorShape([64, None])` or `TensorShape([None, 32])`. 

1195 

1196 * `TensorShape([32, 784])` is a subtype of itself, and also 

1197 `TensorShape([32, None])`, `TensorShape([None, 784])`, 

1198 `TensorShape([None, None])` and `TensorShape(None)`. 

1199 It has no subtype relation with, for example, `TensorShape([32, 1, 784])` 

1200 or `TensorShape([None])`. 

1201 

1202 Args: 

1203 other: Another `TensorShape`. 

1204 

1205 Returns: 

1206 True iff `self` is subtype of `other`. 

1207 

1208 """ 

1209 if not isinstance(other, TensorShape): 

1210 return False 

1211 

1212 # All Tensors are subtypes of a Tensor with no shape. 

1213 if other.rank is None: 

1214 return True 

1215 

1216 # Tensor with a defined shape can only be subtype of another with a defined 

1217 # shape if they have the same number of dimensions. 

1218 if self.rank != other.rank: 

1219 return False 

1220 

1221 # A Tensor is a subtype if each corresponding dimension is a subtype. 

1222 return all(o is None or s == o for s, o in zip(self._dims, other._dims)) # pylint: disable=protected-access 

1223 

  def most_specific_common_supertype(
      self, others: Sequence[trace.TraceType]) -> Optional["TensorShape"]:
    """Returns the most specific supertype `TensorShape` of self and others.

    * `TensorShape([None, 1])` is the most specific `TensorShape` supertyping
      both `TensorShape([2, 1])` and `TensorShape([5, 1])`. Note that
      `TensorShape(None)` is also a supertype but it is not "most specific".

    * `TensorShape([1, 2, 3])` is the most specific `TensorShape` supertyping
      both `TensorShape([1, 2, 3])` and `TensorShape([1, 2, 3])`. There are
      other less specific TensorShapes that supertype above mentioned
      TensorShapes, e.g. `TensorShape([1, 2, None])`, `TensorShape(None)`.

    * `TensorShape([None, None])` is the most specific `TensorShape`
      supertyping both `TensorShape([2, None])` and `TensorShape([None, 3])`.
      As always, `TensorShape(None)` is also a supertype but not the most
      specific one.

    * `TensorShape(None)` is the only `TensorShape` supertyping both
      `TensorShape([1, 2, 3])` and `TensorShape([1, 2])`. In general, any two
      shapes that have different ranks will only have `TensorShape(None)`
      as a common supertype.

    * `TensorShape(None)` is the only `TensorShape` supertyping both
      `TensorShape([1, 2, 3])` and `TensorShape(None)`. In general, the common
      supertype of any shape with `TensorShape(None)` is `TensorShape(None)`.

    Args:
      others: Sequence of `TensorShape`.

    Returns:
      A `TensorShape` which is the most specific supertype shape of `self`
      and `others`. None if it does not exist.
    """
    # Supertypes only exist among TensorShapes.
    if any(not isinstance(other, TensorShape) for other in others):
      return None

    # A Rankless TensorShape is already a global supertype so we return another
    # instance of it.
    if self.rank is None:
      return unknown_shape()

    # A Rankless TensorShape is the most specific supertype for shapes whose
    # ranks do not match.
    if any(other.dims is None or self.rank != other.rank for other in others):
      return unknown_shape()

    # Retain the integer dimension if it is the same across all others, else
    # use an undefined dimension.
    dims = [
        dim if all(dim == other._dims[i]
                   for other in others) else None
        for i, dim in enumerate(self._dims)
    ]
    return TensorShape(dims)

1279 

1280 @doc_controls.do_not_doc_inheritable 

1281 def placeholder_value(self, placeholder_context): 

1282 raise NotImplementedError("A graph placeholder is not currently supported" 

1283 "for an object of type: TensorShape.") 

1284 

  @classmethod
  def experimental_type_proto(cls) -> Type[tensor_shape_pb2.TensorShapeProto]:
    """Returns the type of proto associated with TensorShape serialization."""
    return tensor_shape_pb2.TensorShapeProto

1289 

  @classmethod
  def experimental_from_proto(
      cls, proto: tensor_shape_pb2.TensorShapeProto) -> "TensorShape":
    """Returns a TensorShape instance based on the serialized proto.

    Args:
      proto: A `TensorShapeProto` previously produced by `as_proto` /
        `experimental_as_proto`.
    """
    return TensorShape(proto)

1295 

  def experimental_as_proto(self) -> tensor_shape_pb2.TensorShapeProto:
    """Returns a proto representation of the TensorShape instance.

    Delegates to `as_proto`; exists to satisfy the serializable TraceType
    interface.
    """
    return self.as_proto()

1299 

1300 # TODO(b/216206374): Consider deprecation at TraceType release. 

1301 def is_compatible_with(self, other): 

1302 """Returns True iff `self` is compatible with `other`. 

1303 

1304 Two possibly-partially-defined shapes are compatible if there 

1305 exists a fully-defined shape that both shapes can represent. Thus, 

1306 compatibility allows the shape inference code to reason about 

1307 partially-defined shapes. For example: 

1308 

1309 * TensorShape(None) is compatible with all shapes. 

1310 

1311 * TensorShape([None, None]) is compatible with all two-dimensional 

1312 shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is 

1313 not compatible with, for example, TensorShape([None]) or 

1314 TensorShape([None, None, None]). 

1315 

1316 * TensorShape([32, None]) is compatible with all two-dimensional shapes 

1317 with size 32 in the 0th dimension, and also TensorShape([None, None]) 

1318 and TensorShape(None). It is not compatible with, for example, 

1319 TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]). 

1320 

1321 * TensorShape([32, 784]) is compatible with itself, and also 

1322 TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None, 

1323 None]) and TensorShape(None). It is not compatible with, for example, 

1324 TensorShape([32, 1, 784]) or TensorShape([None]). 

1325 

1326 The compatibility relation is reflexive and symmetric, but not 

1327 transitive. For example, TensorShape([32, 784]) is compatible with 

1328 TensorShape(None), and TensorShape(None) is compatible with 

1329 TensorShape([4, 4]), but TensorShape([32, 784]) is not compatible with 

1330 TensorShape([4, 4]). 

1331 

1332 Args: 

1333 other: Another TensorShape. 

1334 

1335 Returns: 

1336 True iff `self` is compatible with `other`. 

1337 

1338 """ 

1339 other = as_shape(other) 

1340 if self.dims is not None and other.dims is not None: 

1341 if self.rank != other.rank: 

1342 return False 

1343 for x_dim, y_dim in zip(self.dims, other.dims): 

1344 if not x_dim.is_compatible_with(y_dim): 

1345 return False 

1346 return True 

1347 

1348 def assert_is_compatible_with(self, other): 

1349 """Raises exception if `self` and `other` do not represent the same shape. 

1350 

1351 This method can be used to assert that there exists a shape that both 

1352 `self` and `other` represent. 

1353 

1354 Args: 

1355 other: Another TensorShape. 

1356 

1357 Raises: 

1358 ValueError: If `self` and `other` do not represent the same shape. 

1359 """ 

1360 if not self.is_compatible_with(other): 

1361 raise ValueError("Shapes %s and %s are incompatible" % (self, other)) 

1362 

1363 def most_specific_compatible_shape(self, other): 

1364 """Returns the most specific TensorShape compatible with `self` and `other`. 

1365 

1366 * TensorShape([None, 1]) is the most specific TensorShape compatible with 

1367 both TensorShape([2, 1]) and TensorShape([5, 1]). Note that 

1368 TensorShape(None) is also compatible with above mentioned TensorShapes. 

1369 

1370 * TensorShape([1, 2, 3]) is the most specific TensorShape compatible with 

1371 both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more 

1372 less specific TensorShapes compatible with above mentioned TensorShapes, 

1373 e.g. TensorShape([1, 2, None]), TensorShape(None). 

1374 

1375 Args: 

1376 other: Another `TensorShape`. 

1377 

1378 Returns: 

1379 A `TensorShape` which is the most specific compatible shape of `self` 

1380 and `other`. 

1381 """ 

1382 

1383 other = as_shape(other) 

1384 if self.dims is None or other.dims is None or self.rank != other.rank: 

1385 return unknown_shape() 

1386 

1387 dims = [ 

1388 d1 if d1 is not None and d2 is not None and d1 == d2 else None 

1389 for d1, d2 in zip(self.dims, other.dims) 

1390 ] 

1391 return TensorShape(dims) 

1392 

1393 def is_fully_defined(self): 

1394 """Returns True iff `self` is fully defined in every dimension.""" 

1395 return (self._dims is not None and 

1396 all(dim is not None for dim in self._dims)) 

1397 

1398 def assert_is_fully_defined(self): 

1399 """Raises an exception if `self` is not fully defined in every dimension. 

1400 

1401 Raises: 

1402 ValueError: If `self` does not have a known value for every dimension. 

1403 """ 

1404 if not self.is_fully_defined(): 

1405 raise ValueError("Shape %s is not fully defined" % self) 

1406 

1407 def as_list(self): 

1408 """Returns a list of integers or `None` for each dimension. 

1409 

1410 Returns: 

1411 A list of integers or `None` for each dimension. 

1412 

1413 Raises: 

1414 ValueError: If `self` is an unknown shape with an unknown rank. 

1415 """ 

1416 if self._dims is None: 

1417 raise ValueError("as_list() is not defined on an unknown TensorShape.") 

1418 return list(self._dims) 

1419 

  def as_proto(self):
    """Returns this shape as a `TensorShapeProto`."""
    if self._dims is None:
      # Unknown rank is encoded with an explicit flag, not an empty dim list.
      return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
    else:
      # The proto has no notion of None; -1 encodes an unknown dimension size.
      return tensor_shape_pb2.TensorShapeProto(dim=[
          tensor_shape_pb2.TensorShapeProto.Dim(
              size=-1 if d is None else d) for d in self._dims
      ])

1429 

  def __eq__(self, other):
    """Returns True if `self` is equivalent to `other`.

    It first tries to convert `other` to `TensorShape`. `TypeError` is thrown
    when the conversion fails. Otherwise, it compares each element in the
    TensorShape dimensions.

    * Two *Fully known* shapes, return True iff each element is equal.
    >>> t_a = tf.TensorShape([1,2])
    >>> a = [1, 2]
    >>> t_b = tf.TensorShape([1,2])
    >>> t_c = tf.TensorShape([1,2,3])
    >>> t_a.__eq__(a)
    True
    >>> t_a.__eq__(t_b)
    True
    >>> t_a.__eq__(t_c)
    False

    * Two *Partially-known* shapes, return True iff each element is equal.
    >>> p_a = tf.TensorShape([1,None])
    >>> p_b = tf.TensorShape([1,None])
    >>> p_c = tf.TensorShape([2,None])
    >>> p_a.__eq__(p_b)
    True
    >>> t_a.__eq__(p_a)
    False
    >>> p_a.__eq__(p_c)
    False

    * Two *Unknown shape*, return True.
    >>> unk_a = tf.TensorShape(None)
    >>> unk_b = tf.TensorShape(None)
    >>> unk_a.__eq__(unk_b)
    True
    >>> unk_a.__eq__(t_a)
    False

    Args:
      other: A `TensorShape` or type that can be converted to `TensorShape`.

    Returns:
      True if the dimensions are all equal.

    Raises:
      TypeError if `other` can not be converted to `TensorShape`.
    """

    try:
      other = as_shape(other)
    except TypeError:
      # Let Python try the reflected comparison on `other` instead.
      return NotImplemented

    return self._dims == other._dims

1484 

  def __hash__(self):
    """Returns a hash consistent with `__eq__` (both use `_dims`)."""
    return hash(self._dims)

1487 

  def __reduce__(self):
    """Pickle support: reconstruct via `TensorShape(self.dims)`."""
    return TensorShape, (self.dims,)

1490 

  def __concat__(self, other):
    """Alias for `concatenate`."""
    return self.concatenate(other)

1493 

# Register TensorShape so trace_type can (de)serialize it via the
# experimental_type_proto / experimental_as_proto / experimental_from_proto
# methods defined above.
trace_type.register_serializable(TensorShape)

1495 

1496 

class _TensorShapeCodec:
  """Codec for `TensorShape`."""

  def can_encode(self, pyobj):
    """True iff `pyobj` is a `TensorShape` this codec can serialize."""
    return isinstance(pyobj, TensorShape)

  def do_encode(self, tensor_shape_value, encode_fn):
    """Wraps the shape's proto in a `StructuredValue`."""
    del encode_fn  # Unused: shapes contain no nested values to encode.
    wrapper = struct_pb2.StructuredValue()
    wrapper.tensor_shape_value.CopyFrom(tensor_shape_value.as_proto())
    return wrapper

  def can_decode(self, value):
    """True iff `value` carries an encoded `TensorShape`."""
    return value.HasField("tensor_shape_value")

  def do_decode(self, value, decode_fn):
    """Rebuilds a `TensorShape` from its encoded proto."""
    del decode_fn  # Unused: shapes contain no nested values to decode.
    return TensorShape(value.tensor_shape_value)

1516 

1517 

# Hook TensorShape (de)serialization into the nested-structure coder used by
# SavedModel.
nested_structure_coder.register_codec(_TensorShapeCodec())

1519 

1520 

def as_shape(shape):
  """Converts the given object to a TensorShape."""
  # Already a TensorShape: return it unchanged (no copy).
  return shape if isinstance(shape, TensorShape) else TensorShape(shape)

1527 

1528 

def unknown_shape(rank=None, **kwargs):
  """Returns an unknown TensorShape, optionally with a known rank.

  Args:
    rank: (Optional) If specified, the number of dimensions in the shape.
    **kwargs: For backwards compatibility (accepts the legacy `ndims`
      keyword as an alias for `rank`).

  Returns:
    An unknown TensorShape.

  Raises:
    TypeError: In case of invalid arguments.
  """
  # Legacy alias: `ndims` behaves like `rank` when `rank` is not given.
  if rank is None and "ndims" in kwargs:
    rank = kwargs.pop("ndims")
  if kwargs:
    raise TypeError("Unknown argument: %s" % kwargs)
  if rank is None:
    return TensorShape(None)
  return TensorShape([Dimension(None)] * rank)