Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorboard/compat/tensorflow_stub/tensor_shape.py: 26%

317 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Helper classes for tensor shape inference.""" 

16 

17from . import compat, dtypes 

18from tensorboard.compat.proto import tensor_shape_pb2 

19 

20 

21# @tf_export("Dimension") 

class Dimension:
    """The size of one axis of a TensorShape: a non-negative int or unknown.

    An unknown size is stored as `None`. Arithmetic operators yield
    `Dimension(None)` and comparison operators yield `None` whenever either
    operand is unknown.
    """

    def __init__(self, value):
        """Builds a Dimension from `value`.

        Args:
          value: `None` for an unknown size, otherwise anything `int()`
            accepts (byte and text strings included).

        Raises:
          TypeError: If `value` is a `DType`.
          ValueError: If a non-string `value` is not equal to its `int()`
            conversion, or if the converted size is negative.
        """
        if value is None:
            self._value = None
            return
        if isinstance(value, dtypes.DType):
            raise TypeError("Cannot convert %s to Dimension" % value)
        self._value = int(value)
        # Strings are exempt from the round-trip check: "5" != 5 in Python
        # even though the conversion is unambiguous.
        is_textual = isinstance(value, compat.bytes_or_text_types)
        if not is_textual and self._value != value:
            raise ValueError("Ambiguous dimension: %s" % value)
        if self._value < 0:
            raise ValueError("Dimension %d must be >= 0" % self._value)

    def __repr__(self):
        return "Dimension(%r)" % (self._value,)

    def __str__(self):
        if self._value is None:
            return "?"
        return str(self._value)

    def __eq__(self, other):
        """Compares the known values of `self` and `other`.

        Returns:
          `NotImplemented` if `other` cannot be converted by `as_dimension`,
          `None` if either size is unknown, otherwise whether the sizes are
          equal.
        """
        try:
            rhs = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is not None and rhs.value is not None:
            return self._value == rhs.value
        return None

    def __ne__(self, other):
        """Compares the known values of `self` and `other` for difference.

        Returns:
          `NotImplemented` if `other` cannot be converted by `as_dimension`,
          `None` if either size is unknown, otherwise whether the sizes
          differ.
        """
        try:
            rhs = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is not None and rhs.value is not None:
            return self._value != rhs.value
        return None

    def __int__(self):
        return self._value

    # This is needed for Windows.
    # See https://github.com/tensorflow/tensorflow/pull/9780
    def __long__(self):
        return self._value

    def __index__(self):
        # Lets a Dimension be passed straight to Python 3 `range`.
        return self._value

    @property
    def value(self):
        """The size held by this dimension, or `None` when it is unknown."""
        return self._value

    def is_convertible_with(self, other):
        """Tells whether `other` is convertible with this Dimension.

        Two known Dimensions are convertible when their values match; an
        unknown Dimension is convertible with every Dimension.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          True if this Dimension and `other` are convertible.
        """
        rhs = as_dimension(other)
        if self._value is None or rhs.value is None:
            return True
        return self._value == rhs.value

    def assert_is_convertible_with(self, other):
        """Fails unless `other` is convertible with this Dimension.

        Args:
          other: Another Dimension.

        Raises:
          ValueError: If `self` and `other` are not convertible (see
            `is_convertible_with`).
        """
        if self.is_convertible_with(other):
            return
        raise ValueError(
            "Dimensions %s and %s are not convertible" % (self, other)
        )

    def merge_with(self, other):
        """Combines the information held by `self` and `other`.

        ```python
        tf.Dimension(n)   .merge_with(tf.Dimension(n))    == tf.Dimension(n)
        tf.Dimension(n)   .merge_with(tf.Dimension(None)) == tf.Dimension(n)
        tf.Dimension(None).merge_with(tf.Dimension(n))    == tf.Dimension(n)
        tf.Dimension(None).merge_with(tf.Dimension(None)) == tf.Dimension(None)
        tf.Dimension(n)   .merge_with(tf.Dimension(m))  # raises ValueError for n != m
        ```

        Args:
          other: Another Dimension.

        Returns:
          A new Dimension carrying whichever value is known.

        Raises:
          ValueError: If `self` and `other` are not convertible (see
            `is_convertible_with`).
        """
        other = as_dimension(other)
        self.assert_is_convertible_with(other)
        return Dimension(other.value if self._value is None else self._value)

    def __add__(self, other):
        """Returns `self + other`.

        ```python
        tf.Dimension(m)    + tf.Dimension(n)    == tf.Dimension(m + n)
        tf.Dimension(m)    + tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) + tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) + tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the sum of `self` and `other`.
        """
        rhs = as_dimension(other)
        unknown = self._value is None or rhs.value is None
        return Dimension(None if unknown else self._value + rhs.value)

    def __radd__(self, other):
        """Returns `other + self`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the sum of `self` and `other`.
        """
        return self + other

    def __sub__(self, other):
        """Returns `self - other`.

        ```python
        tf.Dimension(m)    - tf.Dimension(n)    == tf.Dimension(m - n)
        tf.Dimension(m)    - tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) - tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) - tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is `other` subtracted from `self`.
        """
        rhs = as_dimension(other)
        unknown = self._value is None or rhs.value is None
        return Dimension(None if unknown else self._value - rhs.value)

    def __rsub__(self, other):
        """Returns `other - self`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is `self` subtracted from `other`.
        """
        lhs = as_dimension(other)
        unknown = self._value is None or lhs.value is None
        return Dimension(None if unknown else lhs.value - self._value)

    def __mul__(self, other):
        """Returns `self * other`.

        Dimensions are multiplied as follows:

        ```python
        tf.Dimension(m)    * tf.Dimension(n)    == tf.Dimension(m * n)
        tf.Dimension(m)    * tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) * tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) * tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the product of `self` and `other`, or
          `NotImplemented` when `other` cannot be converted.
        """
        try:
            rhs = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented

        if self._value is not None and rhs.value is not None:
            return Dimension(self._value * rhs.value)
        return Dimension(None)

    def __rmul__(self, other):
        """Returns `other * self`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is the product of `self` and `other`.
        """
        return self * other

    def __floordiv__(self, other):
        """Returns `self // other`, i.e. the quotient rounded down.

        ```python
        tf.Dimension(m)    // tf.Dimension(n)    == tf.Dimension(m // n)
        tf.Dimension(m)    // tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) // tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) // tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A `Dimension` holding the integer quotient of `self` and `other`,
          or `NotImplemented` when `other` cannot be converted.
        """
        try:
            rhs = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is not None and rhs.value is not None:
            return Dimension(self._value // rhs.value)
        return Dimension(None)

    def __rfloordiv__(self, other):
        """Returns `other // self`, i.e. the quotient rounded down.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A `Dimension` holding the integer quotient of `other` and `self`.
        """
        lhs = as_dimension(other)
        unknown = self._value is None or lhs.value is None
        return Dimension(None if unknown else lhs.value // self._value)

    def __div__(self, other):
        """DEPRECATED: Use `__floordiv__` via `x // y` instead.

        Kept only for backwards compatibility; new code should write
        `x // y`, which makes the rounding-down explicit and also works on
        Python 3.

        Args:
          other: Another `Dimension`.

        Returns:
          A `Dimension` holding the integer quotient of `self` and `other`.
        """
        return self // other

    def __mod__(self, other):
        """Returns `self % other`.

        ```python
        tf.Dimension(m)    % tf.Dimension(n)    == tf.Dimension(m % n)
        tf.Dimension(m)    % tf.Dimension(None) == tf.Dimension(None)
        tf.Dimension(None) % tf.Dimension(n)    == tf.Dimension(None)
        tf.Dimension(None) % tf.Dimension(None) == tf.Dimension(None)
        ```

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is `self` modulo `other`, or
          `NotImplemented` when `other` cannot be converted.
        """
        try:
            rhs = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        if self._value is not None and rhs.value is not None:
            return Dimension(self._value % rhs.value)
        return Dimension(None)

    def __rmod__(self, other):
        """Returns `other % self`.

        Args:
          other: Another Dimension, or a value accepted by `as_dimension`.

        Returns:
          A Dimension whose value is `other` modulo `self`, or
          `NotImplemented` when `other` cannot be converted.
        """
        try:
            lhs = as_dimension(other)
        except (TypeError, ValueError):
            return NotImplemented
        return lhs % self

    def __lt__(self, other):
        """Tells whether `self` is known to be less than `other`.

        ```python
        (tf.Dimension(m)    < tf.Dimension(n))    == (m < n)
        (tf.Dimension(m)    < tf.Dimension(None)) == None
        (tf.Dimension(None) < tf.Dimension(n))    == None
        (tf.Dimension(None) < tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          `self.value < other.value` when both are known, otherwise None.
        """
        rhs = as_dimension(other)
        if self._value is not None and rhs.value is not None:
            return self._value < rhs.value
        return None

    def __le__(self, other):
        """Tells whether `self` is known to be less than or equal to `other`.

        ```python
        (tf.Dimension(m)    <= tf.Dimension(n))    == (m <= n)
        (tf.Dimension(m)    <= tf.Dimension(None)) == None
        (tf.Dimension(None) <= tf.Dimension(n))    == None
        (tf.Dimension(None) <= tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          `self.value <= other.value` when both are known, otherwise None.
        """
        rhs = as_dimension(other)
        if self._value is not None and rhs.value is not None:
            return self._value <= rhs.value
        return None

    def __gt__(self, other):
        """Tells whether `self` is known to be greater than `other`.

        ```python
        (tf.Dimension(m)    > tf.Dimension(n))    == (m > n)
        (tf.Dimension(m)    > tf.Dimension(None)) == None
        (tf.Dimension(None) > tf.Dimension(n))    == None
        (tf.Dimension(None) > tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          `self.value > other.value` when both are known, otherwise None.
        """
        rhs = as_dimension(other)
        if self._value is not None and rhs.value is not None:
            return self._value > rhs.value
        return None

    def __ge__(self, other):
        """Tells whether `self` is known to be greater than or equal to
        `other`.

        ```python
        (tf.Dimension(m)    >= tf.Dimension(n))    == (m >= n)
        (tf.Dimension(m)    >= tf.Dimension(None)) == None
        (tf.Dimension(None) >= tf.Dimension(n))    == None
        (tf.Dimension(None) >= tf.Dimension(None)) == None
        ```

        Args:
          other: Another Dimension.

        Returns:
          `self.value >= other.value` when both are known, otherwise None.
        """
        rhs = as_dimension(other)
        if self._value is not None and rhs.value is not None:
            return self._value >= rhs.value
        return None

    def __reduce__(self):
        # Pickle support: rebuild from the raw value.
        return (Dimension, (self._value,))

469 

470 

def as_dimension(value):
    """Coerces `value` into a Dimension.

    A Dimension is handed back as-is (the same object). Anything else is
    passed to the `Dimension` constructor, so `None` becomes an unknown
    Dimension and an integer becomes a Dimension with that value.

    Args:
      value: The value to be converted.

    Returns:
      A Dimension corresponding to the given value.
    """
    return value if isinstance(value, Dimension) else Dimension(value)

488 

489 

490# @tf_export("TensorShape") 

class TensorShape:
    """Represents the shape of a `Tensor`.

    A `TensorShape` represents a possibly-partial shape specification for a
    `Tensor`. It may be one of the following:

    * *Fully-known shape:* has a known number of dimensions and a known size
      for each dimension. e.g. `TensorShape([16, 256])`
    * *Partially-known shape:* has a known number of dimensions, and an unknown
      size for one or more dimension. e.g. `TensorShape([None, 256])`
    * *Unknown shape:* has an unknown number of dimensions, and an unknown
      size in all dimensions. e.g. `TensorShape(None)`

    If a tensor is produced by an operation of type `"Foo"`, its shape
    may be inferred if there is a registered shape function for
    `"Foo"`. See @{$adding_an_op#shape-functions-in-c$`Shape functions in C++`}
    for details of shape functions and how to register them. Alternatively,
    the shape may be set explicitly using @{tf.Tensor.set_shape}.
    """

    def __init__(self, dims):
        """Creates a new TensorShape with the given dimensions.

        Args:
          dims: A list of Dimensions (or values accepted by `as_dimension`),
            a `TensorShapeProto`, another `TensorShape`, or None if the
            shape is unspecified.
            DEPRECATED: A single integer is treated as a singleton list.

        Raises:
          TypeError: If dims cannot be converted to a list of dimensions
            (in particular, if it is a byte or text string).
        """
        # TODO(irving): Eliminate the single integer special case.
        if dims is None:
            self._dims = None
        elif isinstance(dims, compat.bytes_or_text_types):
            raise TypeError(
                "A string has ambiguous TensorShape, please wrap in a "
                "list or convert to an int: %s" % dims
            )
        elif isinstance(dims, tensor_shape_pb2.TensorShapeProto):
            if dims.unknown_rank:
                self._dims = None
            else:
                self._dims = [
                    # Protos store variable-size dimensions as -1
                    as_dimension(dim.size if dim.size != -1 else None)
                    for dim in dims.dim
                ]
        elif isinstance(dims, TensorShape):
            self._dims = dims.dims
        else:
            try:
                dims_iter = iter(dims)
            except TypeError:
                # Treat as a singleton dimension
                self._dims = [as_dimension(dims)]
            else:
                # Got a list of dimensions
                self._dims = [as_dimension(d) for d in dims_iter]
        # Lazily computed cache for `ndims`.
        self._ndims = None

    def __repr__(self):
        return "TensorShape(%r)" % self._dims

    def __str__(self):
        if self.ndims is None:
            return "<unknown>"
        elif self.ndims == 1:
            # Trailing comma mirrors Python's one-element tuple notation.
            return "(%s,)" % self._dims[0]
        else:
            return "(%s)" % ", ".join(str(d) for d in self._dims)

    @property
    def dims(self):
        """Returns a list of Dimensions, or None if the shape is
        unspecified."""
        return self._dims

    @dims.setter
    def dims(self, dims):
        self._dims = dims
        # Invalidate the cached rank; it is recomputed on next access.
        self._ndims = None

    @property
    def ndims(self):
        """Returns the rank of this shape, or None if it is unspecified."""
        if self._dims is None:
            return None
        else:
            if self._ndims is None:
                self._ndims = len(self._dims)
            return self._ndims

    def __len__(self):
        """Returns the rank of this shape, or raises ValueError if
        unspecified."""
        if self._dims is None:
            raise ValueError(
                "Cannot take the length of Shape with unknown rank."
            )
        return self.ndims

    def __bool__(self):
        """Returns True if the rank of this shape is known.

        Note that a scalar shape (`TensorShape([])`) is therefore truthy.
        """
        return self._dims is not None

    # Python 3 wants __bool__, Python 2.7 wants __nonzero__
    __nonzero__ = __bool__

    def __iter__(self):
        """Returns an iterator over `self.dims` if the rank is known,
        otherwise raises ValueError."""
        if self._dims is None:
            raise ValueError("Cannot iterate over a shape with unknown rank.")
        else:
            return iter(self._dims)

    def __getitem__(self, key):
        """Returns the value of a dimension or a shape, depending on the key.

        Args:
          key: If `key` is an integer, returns the dimension at that index;
            otherwise if `key` is a slice, returns a TensorShape whose
            dimensions are those selected by the slice from `self`.

        Returns:
          A dimension if `key` is an integer, or a `TensorShape` if `key` is a
          slice. For a shape of unknown rank the result is an unknown
          Dimension or a (possibly rank-only) unknown shape.

        Raises:
          ValueError: If `self` is completely unknown and the slice step is
            set.
        """
        if self._dims is not None:
            if isinstance(key, slice):
                return TensorShape(self._dims[key])
            else:
                return self._dims[key]
        else:
            if isinstance(key, slice):
                start = key.start if key.start is not None else 0
                stop = key.stop

                if key.step is not None:
                    # TODO(mrry): Handle these maybe.
                    raise ValueError("Steps are not yet handled")
                if stop is None:
                    # NOTE(mrry): This implies that TensorShape(None) is convertible with
                    # TensorShape(None)[1:], which is obviously not true. It would be
                    # possible to track the number of dimensions symbolically,
                    # and perhaps we should do that.
                    return unknown_shape()
                elif start < 0 or stop < 0:
                    # TODO(mrry): Handle this better, as it will be useful for handling
                    # suffixes of otherwise unknown shapes.
                    return unknown_shape()
                else:
                    return unknown_shape(ndims=stop - start)
            else:
                return Dimension(None)

    def num_elements(self):
        """Returns the total number of elements, or None for incomplete
        shapes."""
        if self.is_fully_defined():
            size = 1
            for dim in self._dims:
                size *= dim.value
            return size
        else:
            return None

    def merge_with(self, other):
        """Returns a `TensorShape` combining the information in `self` and
        `other`.

        The dimensions in `self` and `other` are merged elementwise,
        according to the rules defined for `Dimension.merge_with()`.

        Args:
          other: Another `TensorShape`.

        Returns:
          A `TensorShape` containing the combined information of `self` and
          `other`.

        Raises:
          ValueError: If `self` and `other` are not convertible.
        """
        other = as_shape(other)
        if self._dims is None:
            return other
        else:
            try:
                self.assert_same_rank(other)
                # `other[i]` (rather than `other.dims[i]`) also works when
                # `other` has unknown rank: it yields an unknown Dimension.
                new_dims = [
                    dim.merge_with(other[i])
                    for i, dim in enumerate(self._dims)
                ]
                return TensorShape(new_dims)
            except ValueError:
                raise ValueError(
                    "Shapes %s and %s are not convertible" % (self, other)
                )

    def concatenate(self, other):
        """Returns the concatenation of the dimension in `self` and `other`.

        *N.B.* If either `self` or `other` is completely unknown,
        concatenation will discard information about the other shape. In
        future, we might support concatenation that preserves this
        information for use with slicing.

        Args:
          other: Another `TensorShape`.

        Returns:
          A `TensorShape` whose dimensions are the concatenation of the
          dimensions in `self` and `other`.
        """
        # TODO(mrry): Handle the case where we concatenate a known shape with a
        # completely unknown shape, so that we can use the partial information.
        other = as_shape(other)
        if self._dims is None or other.dims is None:
            return unknown_shape()
        else:
            return TensorShape(self._dims + other.dims)

    def assert_same_rank(self, other):
        """Raises an exception if `self` and `other` do not have convertible
        ranks.

        Args:
          other: Another `TensorShape`.

        Raises:
          ValueError: If `self` and `other` do not represent shapes with the
            same rank.
        """
        other = as_shape(other)
        if self.ndims is not None and other.ndims is not None:
            if self.ndims != other.ndims:
                raise ValueError(
                    "Shapes %s and %s must have the same rank" % (self, other)
                )

    def assert_has_rank(self, rank):
        """Raises an exception if `self` is not convertible with the given
        `rank`.

        Args:
          rank: An integer.

        Raises:
          ValueError: If `self` does not represent a shape with the given `rank`.
        """
        if self.ndims not in (None, rank):
            raise ValueError("Shape %s must have rank %d" % (self, rank))

    def with_rank(self, rank):
        """Returns a shape based on `self` with the given rank.

        This method promotes a completely unknown shape to one with a
        known rank.

        Args:
          rank: An integer.

        Returns:
          A shape that is at least as specific as `self` with the given rank.

        Raises:
          ValueError: If `self` does not represent a shape with the given `rank`.
        """
        try:
            return self.merge_with(unknown_shape(ndims=rank))
        except ValueError:
            raise ValueError("Shape %s must have rank %d" % (self, rank))

    def with_rank_at_least(self, rank):
        """Returns a shape based on `self` with at least the given rank.

        Args:
          rank: An integer.

        Returns:
          A shape that is at least as specific as `self` with at least the given
          rank.

        Raises:
          ValueError: If `self` does not represent a shape with at least the given
            `rank`.
        """
        if self.ndims is not None and self.ndims < rank:
            raise ValueError(
                "Shape %s must have rank at least %d" % (self, rank)
            )
        else:
            return self

    def with_rank_at_most(self, rank):
        """Returns a shape based on `self` with at most the given rank.

        Args:
          rank: An integer.

        Returns:
          A shape that is at least as specific as `self` with at most the given
          rank.

        Raises:
          ValueError: If `self` does not represent a shape with at most the given
            `rank`.
        """
        if self.ndims is not None and self.ndims > rank:
            raise ValueError(
                "Shape %s must have rank at most %d" % (self, rank)
            )
        else:
            return self

    def is_convertible_with(self, other):
        """Returns True iff `self` is convertible with `other`.

        Two possibly-partially-defined shapes are convertible if there
        exists a fully-defined shape that both shapes can represent. Thus,
        convertibility allows the shape inference code to reason about
        partially-defined shapes. For example:

        * TensorShape(None) is convertible with all shapes.

        * TensorShape([None, None]) is convertible with all two-dimensional
          shapes, such as TensorShape([32, 784]), and also TensorShape(None). It is
          not convertible with, for example, TensorShape([None]) or
          TensorShape([None, None, None]).

        * TensorShape([32, None]) is convertible with all two-dimensional shapes
          with size 32 in the 0th dimension, and also TensorShape([None, None])
          and TensorShape(None). It is not convertible with, for example,
          TensorShape([32]), TensorShape([32, None, 1]) or TensorShape([64, None]).

        * TensorShape([32, 784]) is convertible with itself, and also
          TensorShape([32, None]), TensorShape([None, 784]), TensorShape([None,
          None]) and TensorShape(None). It is not convertible with, for example,
          TensorShape([32, 1, 784]) or TensorShape([None]).

        The convertibility relation is reflexive and symmetric, but not
        transitive. For example, TensorShape([32, 784]) is convertible with
        TensorShape(None), and TensorShape(None) is convertible with
        TensorShape([4, 4]), but TensorShape([32, 784]) is not convertible with
        TensorShape([4, 4]).

        Args:
          other: Another TensorShape.

        Returns:
          True iff `self` is convertible with `other`.
        """
        other = as_shape(other)
        if self._dims is not None and other.dims is not None:
            if self.ndims != other.ndims:
                return False
            for x_dim, y_dim in zip(self._dims, other.dims):
                if not x_dim.is_convertible_with(y_dim):
                    return False
        return True

    def assert_is_convertible_with(self, other):
        """Raises exception if `self` and `other` do not represent the same
        shape.

        This method can be used to assert that there exists a shape that both
        `self` and `other` represent.

        Args:
          other: Another TensorShape.

        Raises:
          ValueError: If `self` and `other` do not represent the same shape.
        """
        if not self.is_convertible_with(other):
            raise ValueError(
                "Shapes %s and %s are inconvertible" % (self, other)
            )

    def most_specific_convertible_shape(self, other):
        """Returns the most specific TensorShape convertible with `self` and
        `other`.

        * TensorShape([None, 1]) is the most specific TensorShape convertible with
          both TensorShape([2, 1]) and TensorShape([5, 1]). Note that
          TensorShape(None) is also convertible with above mentioned TensorShapes.

        * TensorShape([1, 2, 3]) is the most specific TensorShape convertible with
          both TensorShape([1, 2, 3]) and TensorShape([1, 2, 3]). There are more
          less specific TensorShapes convertible with above mentioned TensorShapes,
          e.g. TensorShape([1, 2, None]), TensorShape(None).

        Args:
          other: Another `TensorShape`.

        Returns:
          A `TensorShape` which is the most specific convertible shape of `self`
          and `other`.
        """

        other = as_shape(other)
        if (
            self._dims is None
            or other.dims is None
            or self.ndims != other.ndims
        ):
            return unknown_shape()

        dims = [(Dimension(None))] * self.ndims
        for i, (d1, d2) in enumerate(zip(self._dims, other.dims)):
            # `d1 == d2` is None (falsy) when either size is unknown, so only
            # axes whose known sizes agree are kept.
            if d1 is not None and d2 is not None and d1 == d2:
                dims[i] = d1
        return TensorShape(dims)

    def is_fully_defined(self):
        """Returns True iff `self` is fully defined in every dimension."""
        return self._dims is not None and all(
            dim.value is not None for dim in self._dims
        )

    def assert_is_fully_defined(self):
        """Raises an exception if `self` is not fully defined in every
        dimension.

        Raises:
          ValueError: If `self` does not have a known value for every dimension.
        """
        if not self.is_fully_defined():
            raise ValueError("Shape %s is not fully defined" % self)

    def as_list(self):
        """Returns a list of integers or `None` for each dimension.

        Returns:
          A list of integers or `None` for each dimension.

        Raises:
          ValueError: If `self` is an unknown shape with an unknown rank.
        """
        if self._dims is None:
            raise ValueError(
                "as_list() is not defined on an unknown TensorShape."
            )
        return [dim.value for dim in self._dims]

    def as_proto(self):
        """Returns this shape as a `TensorShapeProto`.

        Unknown rank is encoded as `unknown_rank=True`; an unknown dimension
        size is encoded as -1.
        """
        if self._dims is None:
            return tensor_shape_pb2.TensorShapeProto(unknown_rank=True)
        else:
            return tensor_shape_pb2.TensorShapeProto(
                dim=[
                    tensor_shape_pb2.TensorShapeProto.Dim(
                        size=-1 if d.value is None else d.value
                    )
                    for d in self._dims
                ]
            )

    def __eq__(self, other):
        """Returns True if `self` is equivalent to `other`.

        Returns `NotImplemented` (so that `==` evaluates to False unless the
        other operand says otherwise) when `other` cannot be converted to a
        `TensorShape`.
        """
        try:
            other = as_shape(other)
        except (TypeError, ValueError):
            # ValueError covers sequences whose elements are not valid
            # dimensions (e.g. negative sizes); an equality test against
            # such a value should not raise. Mirrors `Dimension.__eq__`.
            return NotImplemented
        return self._dims == other.dims

    def __ne__(self, other):
        """Returns True if `self` is known to be different from `other`.

        Returns `NotImplemented` when `other` cannot be converted to a
        `TensorShape`.

        Raises:
          ValueError: If either shape has unknown rank.
        """
        try:
            other = as_shape(other)
        except (TypeError, ValueError):
            # See `__eq__`: inconvertible operands are simply "not handled".
            return NotImplemented
        if self.ndims is None or other.ndims is None:
            raise ValueError(
                "The inequality of unknown TensorShapes is undefined."
            )
        if self.ndims != other.ndims:
            return True
        return self._dims != other.dims

    def __reduce__(self):
        # Pickle support: rebuild from the list of Dimensions (or None).
        return TensorShape, (self._dims,)

978 

979 

def as_shape(shape):
    """Coerces `shape` into a TensorShape.

    A TensorShape is handed back as-is (the same object); anything else is
    passed to the `TensorShape` constructor.
    """
    return shape if isinstance(shape, TensorShape) else TensorShape(shape)

986 

987 

def unknown_shape(ndims=None):
    """Builds a TensorShape whose dimension sizes are all unknown.

    Args:
      ndims: (Optional) If specified, the number of dimensions in the shape.
        When omitted, the rank itself is unknown as well.

    Returns:
      An unknown TensorShape.
    """
    if ndims is not None:
        return TensorShape([Dimension(None)] * ndims)
    return TensorShape(None)

1001 

1002 

# Single rank-0 shape instance, created once at import and returned by
# every call to `scalar()`.
_SCALAR_SHAPE = TensorShape([])


def scalar():
    """Returns the shared shape object that represents a scalar (rank 0)."""
    return _SCALAR_SHAPE

1009 

1010 

def vector(length):
    """Builds the rank-1 shape of a vector.

    Args:
      length: The length of the vector, which may be None if unknown.

    Returns:
      A TensorShape representing a vector of the given length.
    """
    dims = [length]
    return TensorShape(dims)

1021 

1022 

def matrix(rows, cols):
    """Builds the rank-2 shape of a matrix.

    Args:
      rows: The number of rows in the matrix, which may be None if unknown.
      cols: The number of columns in the matrix, which may be None if unknown.

    Returns:
      A TensorShape representing a matrix of the given size.
    """
    dims = [rows, cols]
    return TensorShape(dims)