Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/row_partition.py: 23%

494 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""A class used to partition a sequence into contiguous subsequences ("rows"). 

16""" 

17 

18 

19# TODO(edloper): Make into a ExtensionType (if possible) 

20 

21 

22import numpy as np 

23 

24from tensorflow.core.protobuf import struct_pb2 

25from tensorflow.python.framework import composite_tensor 

26from tensorflow.python.framework import constant_op 

27from tensorflow.python.framework import dtypes 

28from tensorflow.python.framework import ops 

29from tensorflow.python.framework import tensor_conversion 

30from tensorflow.python.framework import tensor_shape 

31from tensorflow.python.framework import tensor_spec 

32from tensorflow.python.framework import tensor_util 

33from tensorflow.python.framework import type_spec 

34from tensorflow.python.framework import type_spec_registry 

35from tensorflow.python.ops import array_ops 

36from tensorflow.python.ops import check_ops 

37from tensorflow.python.ops import control_flow_ops 

38from tensorflow.python.ops import gen_ragged_math_ops 

39from tensorflow.python.ops import math_ops 

40from tensorflow.python.ops.ragged import segment_id_ops 

41from tensorflow.python.saved_model import nested_structure_coder 

42from tensorflow.python.util.tf_export import tf_export 

43 

44# =============================================================================== 

45# RowPartition 

46# =============================================================================== 

47# TODO(edloper): Consider removing row_starts and row_limits factory methods 

48# and accessors from RowPartition. In particular, these two encodings are 

49# "second-class citizens": we never cache them, and if you do construct a 

50# RowPartition from them then it may be more expensive than you might expect 

51# (because we append a value to the beginning/end to transform them into 

52# splits). If we do remove them from RowPartition, then we would still keep 

53# the from_row_starts and from_row_limits factory methods in RaggedTensor. 

54 

55 

56@tf_export("experimental.RowPartition") 

57class RowPartition(composite_tensor.CompositeTensor): 

58 """Partitioning of a sequence of values into contiguous subsequences ("rows"). 

59 

60 A `RowPartition` describes how a sequence with `nvals` items should be 

61 divided into `nrows` contiguous subsequences ("rows"). For example, a 

62 `RowPartition` could be used to partition the vector `[1, 2, 3, 4, 5]` into 

63 subsequences `[[1, 2], [3], [], [4, 5]]`. Note that `RowPartition` stores 

64 information about how values are partitioned, but does not include the 

65 partitioned values themselves. `tf.RaggedTensor` is used to pair a `values` 

66 tensor with one or more `RowPartition`s, providing a complete encoding for a 

67 ragged tensor (i.e. a tensor with variable-length dimensions). 

68 

69 `RowPartition`s may be defined using several different schemes: 

70 

71 * `row_lengths`: an integer vector with shape `[nrows]`, which specifies 

72 the length of each row. 

73 

74 * `row_splits`: an integer vector with shape `[nrows+1]`, specifying the 

75 "split points" between each row. 

76 

77 * `row_starts`: an integer vector with shape `[nrows]`, which specifies 

78 the start offset for each row. Equivalent to `row_splits[:-1]`. 

79 

80 * `row_limits`: an integer vector with shape `[nrows]`, which specifies 

81 the stop offset for each row. Equivalent to `row_splits[1:]`. 

82 

83 * `value_rowids` is an integer vector with shape `[nvals]`, corresponding 

84 one-to-one with sequence values, which specifies the row that each value 

85 belongs to. If the partition has empty trailing rows, then `nrows` 

86 must also be specified. 

87 

88 * `uniform_row_length` is an integer scalar, specifying the length of every 

89 row. This scheme may only be used if all rows have the same length. 

90 

91 For example, the following `RowPartition`s all represent the partitioning of 

92 8 values into 5 sublists as follows: `[[*, *, *, *], [], [*, *, *], [*], []]`. 

93 

94 >>> p1 = RowPartition.from_row_lengths([4, 0, 3, 1, 0]) 

95 >>> p2 = RowPartition.from_row_splits([0, 4, 4, 7, 8, 8]) 

96 >>> p3 = RowPartition.from_row_starts([0, 4, 4, 7, 8], nvals=8) 

97 >>> p4 = RowPartition.from_row_limits([4, 4, 7, 8, 8]) 

98 >>> p5 = RowPartition.from_value_rowids([0, 0, 0, 0, 2, 2, 2, 3], nrows=5) 

99 

100 For more information about each scheme, see the documentation for

101 its factory method. For additional examples, see the documentation on 

102 `tf.RaggedTensor`. 

103 

104 ### Precomputed Encodings 

105 

106 `RowPartition` always stores at least one encoding of the partitioning, but 

107 it can be configured to cache additional encodings as well. This can 

108 avoid unnecessary recomputation in eager mode. (In graph mode, optimizations 

109 such as common subexpression elimination will typically prevent these 

110 unnecessary recomputations.) To check which encodings are precomputed, use 

111 `RowPartition.has_precomputed_<encoding>`. To cache an additional 

112 encoding, use `RowPartition.with_precomputed_<encoding>`. 

113 """ 

114 

115 # ============================================================================= 

116 # Constructor (private) 

117 # ============================================================================= 

118 def __init__(self, 

119 row_splits, 

120 row_lengths=None, 

121 value_rowids=None, 

122 nrows=None, 

123 uniform_row_length=None, 

124 nvals=None, 

125 internal=False): 

126 """Creates a `RowPartition` from the specified encoding tensor(s). 

127 

128 This constructor is private -- please use one of the following ops to 

129 build `RowPartition`s: 

130 

131 * `RowPartition.from_row_lengths` 

132 * `RowPartition.from_value_rowids` 

133 * `RowPartition.from_row_splits` 

134 * `RowPartition.from_row_starts` 

135 * `RowPartition.from_row_limits` 

136 * `RowPartition.from_uniform_row_length` 

137 

138 If row_splits is has a constant value, then all other arguments should 

139 have a constant value. 

140 

141 Args: 

142 row_splits: A 1-D integer tensor with shape `[nrows+1]`. 

143 row_lengths: A 1-D integer tensor with shape `[nrows]` 

144 value_rowids: A 1-D integer tensor with shape `[nvals]`. 

145 nrows: A 1-D integer scalar tensor. 

146 uniform_row_length: A scalar tensor. 

147 nvals: A scalar tensor. 

148 internal: Private key value, required to ensure that this private 

149 constructor is *only* called from the factory methods. 

150 

151 Raises: 

152 TypeError: If a row partitioning tensor has an inappropriate dtype. 

153 TypeError: If exactly one row partitioning argument was not specified. 

154 ValueError: If a row partitioning tensor has an inappropriate shape. 

155 ValueError: If multiple partitioning arguments are specified. 

156 ValueError: If nrows is specified but value_rowids is not None. 

157 """ 

158 if internal is not _row_partition_factory_key: 

159 raise ValueError("RowPartition constructor is private; please use one " 

160 "of the factory methods instead (e.g., " 

161 "RowPartition.from_row_lengths())") 

162 

163 # Validate the arguments. 

164 if not isinstance(row_splits, ops.Tensor): 

165 raise TypeError("Row-partitioning argument must be a Tensor, got %r" % 

166 row_splits) 

167 if row_splits.dtype not in (dtypes.int32, dtypes.int64): 

168 raise ValueError("Row-partitioning argument must be int32 or int64") 

169 

170 # Validate shapes & dtypes. 

171 row_splits.shape.assert_has_rank(1) 

172 row_splits.set_shape([None]) 

173 self._row_splits = row_splits 

174 

175 # Store any cached tensors. These are used to avoid unnecessary 

176 # round-trip conversions when a RowPartition is constructed from 

177 # lengths or rowids, and we later want those lengths/rowids back. 

178 for tensor in [row_lengths, value_rowids, nrows, uniform_row_length, nvals]: 

179 if tensor is not None: 

180 if not isinstance(tensor, ops.Tensor): 

181 raise TypeError("Cached value must be a Tensor or None.") 

182 elif tensor.dtype != row_splits.dtype: 

183 raise ValueError(f"Inconsistent dtype for encoding tensors: " 

184 f"{tensor} vs {row_splits}") 

185 self._row_lengths = row_lengths 

186 self._value_rowids = value_rowids 

187 self._nrows = nrows 

188 self._uniform_row_length = uniform_row_length 

189 self._nvals = nvals 

190 

191 # ============================================================================= 

192 # Factory Methods 

193 # ============================================================================= 

194 

  @classmethod
  def from_value_rowids(cls,
                        value_rowids,
                        nrows=None,
                        validate=True,
                        dtype=None,
                        dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `value_rowids`.

    This `RowPartition` divides a sequence `values` into rows by specifying
    which row each value should be added to:

    ```python
    partitioned_rows = [[] for _ in nrows]
    for (value, rowid) in zip(values, value_rowids):
      partitioned_rows[rowid].append(value)
    ```

    Args:
      value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
        one-to-one with `values`, and specifies each value's row index. Must be
        nonnegative, and must be sorted in ascending order.
      nrows: An integer scalar specifying the number of rows. This should be
        specified if the `RowPartition` may contain empty trailing rows. Must
        be greater than `value_rowids[-1]` (or greater than or equal to zero if
        `value_rowids` is empty). Defaults to `value_rowids[-1] + 1` (or zero if
        `value_rowids` is empty).
      validate: If true, then use assertions to check that the arguments form a
        valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `value_rowids`, dtype_hint, or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.

    Raises:
      ValueError: If `nrows` is incompatible with `value_rowids`.

    #### Example:

    >>> print(RowPartition.from_value_rowids(
    ...     value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
    ...     nrows=4))
    tf.RowPartition(row_splits=[0 4 4 7 8])
    """
    # Imported locally (not at module level) to break an import cycle:
    # bincount_ops itself imports ragged_tensor.
    from tensorflow.python.ops import bincount_ops  # pylint: disable=g-import-not-at-top
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    with ops.name_scope(None, "RowPartitionFromValueRowIds",
                        [value_rowids, nrows]):
      value_rowids = cls._convert_row_partition(
          value_rowids, "value_rowids", dtype_hint=dtype_hint, dtype=dtype)
      if nrows is None:
        const_rowids = tensor_util.constant_value(value_rowids)
        if const_rowids is None:
          # Dynamic case: nrows is one more than the last rowid, or zero when
          # value_rowids is empty (the [-1] sentinel appended by the concat
          # makes the empty case yield -1 + 1 == 0).
          nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1
          const_nrows = None
        else:
          # Static case: fold nrows to a constant immediately.
          const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
          nrows = ops.convert_to_tensor(
              const_nrows, value_rowids.dtype, name="nrows")
      else:
        nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows")
        const_nrows = tensor_util.constant_value(nrows)
        # When both nrows and value_rowids are statically known, validate
        # their compatibility eagerly (raising at graph-build time).
        if const_nrows is not None:
          if const_nrows < 0:
            raise ValueError("Expected nrows >= 0; got %d" % const_nrows)
          const_rowids = tensor_util.constant_value(value_rowids)
          if const_rowids is not None and const_rowids.size > 0:
            if not const_nrows >= const_rowids[-1] + 1:
              raise ValueError(
                  "Expected nrows >= value_rowids[-1] + 1; got nrows=%d, "
                  "value_rowids[-1]=%d" % (const_nrows, const_rowids[-1]))

      value_rowids.shape.assert_has_rank(1)
      nrows.shape.assert_has_rank(0)

      if validate:
        msg = ("Arguments to from_value_rowids do not form a valid "
               "RowPartition")
        checks = [
            check_ops.assert_rank(value_rowids, 1, message=msg),
            check_ops.assert_rank(nrows, 0, message=msg),
            check_ops.assert_non_negative(value_rowids[:1], message=msg),
            _assert_monotonic_increasing(value_rowids, message=msg),
            check_ops.assert_less(value_rowids[-1:], nrows, message=msg),
        ]
        value_rowids = control_flow_ops.with_dependencies(checks, value_rowids)

      # Convert value_rowids & nrows to row_splits.
      # Note: we don't use segment_ids_to_row_splits() here because we want
      # to save the intermediate value `row_lengths`, so we can cache it.
      # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
      # cast.
      value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
      nrows_int32 = math_ops.cast(nrows, dtypes.int32)
      # minlength == maxlength pins the bincount output length to nrows, so
      # trailing empty rows are represented as zero-length rows.
      row_lengths = bincount_ops.bincount(
          value_rowids_int32,
          minlength=nrows_int32,
          maxlength=nrows_int32,
          dtype=value_rowids.dtype)
      row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
      if const_nrows is not None:
        # Propagate the statically-known row count into the shapes.
        row_lengths.set_shape([const_nrows])
        row_splits.set_shape([const_nrows + 1])

      return cls(
          row_splits=row_splits,
          row_lengths=row_lengths,
          value_rowids=value_rowids,
          nrows=nrows,
          internal=_row_partition_factory_key)

314 

315 @classmethod 

316 def from_row_splits(cls, 

317 row_splits, 

318 validate=True, 

319 dtype=None, 

320 dtype_hint=None): 

321 """Creates a `RowPartition` with rows partitioned by `row_splits`. 

322 

323 This `RowPartition` divides a sequence `values` into rows by indicating 

324 where each row begins and ends: 

325 

326 ```python 

327 partitioned_rows = [] 

328 for i in range(len(row_splits) - 1): 

329 row_start = row_splits[i] 

330 row_end = row_splits[i + 1] 

331 partitioned_rows.append(values[row_start:row_end]) 

332 ``` 

333 

334 Args: 

335 row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be 

336 empty, and must be sorted in ascending order. `row_splits[0]` must be 

337 zero. 

338 validate: If true, then use assertions to check that the arguments form a 

339 valid `RowPartition`. 

340 dtype: Optional dtype for the RowPartition. If missing, the type 

341 is inferred from the type of `row_splits`, dtype_hint, or tf.int64. 

342 dtype_hint: Optional dtype for the RowPartition, used when dtype 

343 is None. In some cases, a caller may not have a dtype in mind when 

344 converting to a tensor, so dtype_hint can be used as a soft preference. 

345 If the conversion to `dtype_hint` is not possible, this argument has no 

346 effect. 

347 

348 Returns: 

349 A `RowPartition`. 

350 

351 Raises: 

352 ValueError: If `row_splits` is an empty list. 

353 """ 

354 if not isinstance(validate, bool): 

355 raise TypeError("validate must have type bool") 

356 if isinstance(row_splits, (list, tuple)) and not row_splits: 

357 raise ValueError("row_splits tensor may not be empty.") 

358 if isinstance(row_splits, tensor_spec.TensorSpec): 

359 return cls(row_splits=row_splits, internal=_row_partition_factory_key) 

360 

361 with ops.name_scope(None, "RowPartitionFromRowSplits", [row_splits]): 

362 row_splits = cls._convert_row_partition( 

363 row_splits, "row_splits", dtype_hint=dtype_hint, dtype=dtype) 

364 row_splits.shape.assert_has_rank(1) 

365 

366 if validate: 

367 msg = "Arguments to from_row_splits do not form a valid RaggedTensor:" 

368 checks = [ 

369 check_ops.assert_rank(row_splits, 1, message=(msg + "rank")), 

370 _assert_zero(row_splits[0], message=(msg + "zero")), 

371 _assert_monotonic_increasing( 

372 row_splits, message=(msg + "monotonic")), 

373 ] 

374 row_splits = control_flow_ops.with_dependencies(checks, row_splits) 

375 

376 return cls(row_splits=row_splits, internal=_row_partition_factory_key) 

377 

378 @classmethod 

379 def from_row_lengths(cls, 

380 row_lengths, 

381 validate=True, 

382 dtype=None, 

383 dtype_hint=None): 

384 """Creates a `RowPartition` with rows partitioned by `row_lengths`. 

385 

386 This `RowPartition` divides a sequence `values` into rows by indicating 

387 the length of each row: 

388 

389 ```python 

390 partitioned_rows = [[values.pop(0) for _ in range(length)] 

391 for length in row_lengths] 

392 ``` 

393 

394 Args: 

395 row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be 

396 nonnegative. 

397 validate: If true, then use assertions to check that the arguments form a 

398 valid `RowPartition`. 

399 

400 dtype: Optional dtype for the RowPartition. If missing, the type 

401 is inferred from the type of `row_lengths`, dtype_hint, or tf.int64. 

402 dtype_hint: Optional dtype for the RowPartition, used when dtype 

403 is None. In some cases, a caller may not have a dtype in mind when 

404 converting to a tensor, so dtype_hint can be used as a soft preference. 

405 If the conversion to `dtype_hint` is not possible, this argument has no 

406 effect. 

407 

408 Returns: 

409 A `RowPartition`. 

410 """ 

411 if not isinstance(validate, bool): 

412 raise TypeError("validate must have type bool") 

413 with ops.name_scope(None, "RowPartitionFromRowLengths", [row_lengths]): 

414 row_lengths = cls._convert_row_partition( 

415 row_lengths, "row_lengths", dtype_hint=dtype_hint, dtype=dtype) 

416 row_lengths.shape.assert_has_rank(1) 

417 

418 if validate: 

419 msg = "Arguments to from_row_lengths do not form a valid RowPartition" 

420 checks = [ 

421 check_ops.assert_rank(row_lengths, 1, message=msg), 

422 check_ops.assert_non_negative(row_lengths, message=msg), 

423 ] 

424 row_lengths = control_flow_ops.with_dependencies(checks, row_lengths) 

425 

426 row_limits = math_ops.cumsum(row_lengths) 

427 row_splits = array_ops.concat([[0], row_limits], axis=0) 

428 return cls( 

429 row_splits=row_splits, 

430 row_lengths=row_lengths, 

431 internal=_row_partition_factory_key) 

432 

433 @classmethod 

434 def from_row_starts(cls, 

435 row_starts, 

436 nvals, 

437 validate=True, 

438 dtype=None, 

439 dtype_hint=None): 

440 """Creates a `RowPartition` with rows partitioned by `row_starts`. 

441 

442 Equivalent to: `from_row_splits(concat([row_starts, nvals], axis=0))`. 

443 

444 Args: 

445 row_starts: A 1-D integer tensor with shape `[nrows]`. Must be 

446 nonnegative and sorted in ascending order. If `nrows>0`, then 

447 `row_starts[0]` must be zero. 

448 nvals: A scalar tensor indicating the number of values. 

449 validate: If true, then use assertions to check that the arguments form a 

450 valid `RowPartition`. 

451 dtype: Optional dtype for the RowPartition. If missing, the type 

452 is inferred from the type of `row_starts`, dtype_hint, or tf.int64. 

453 dtype_hint: Optional dtype for the RowPartition, used when dtype 

454 is None. In some cases, a caller may not have a dtype in mind when 

455 converting to a tensor, so dtype_hint can be used as a soft preference. 

456 If the conversion to `dtype_hint` is not possible, this argument has no 

457 effect. 

458 

459 Returns: 

460 A `RowPartition`. 

461 """ 

462 if not isinstance(validate, bool): 

463 raise TypeError("validate must have type bool") 

464 with ops.name_scope(None, "RowPartitionFromRowStarts", [row_starts]): 

465 row_starts = cls._convert_row_partition( 

466 row_starts, "row_starts", dtype_hint=dtype_hint, dtype=dtype) 

467 row_starts.shape.assert_has_rank(1) 

468 # TODO(martinz): nvals and row_starts could be inconsistent at call time, 

469 # even though they eventually end up the same type. 

470 nvals = math_ops.cast(nvals, row_starts.dtype) 

471 if validate: 

472 msg = "Arguments to from_row_starts do not form a valid RaggedTensor" 

473 checks = [ 

474 check_ops.assert_rank(row_starts, 1, message=msg), 

475 _assert_zero(row_starts[:1], message=msg), 

476 _assert_monotonic_increasing(row_starts, message=msg), 

477 check_ops.assert_less_equal(row_starts[-1:], nvals, message=msg), 

478 ] 

479 row_starts = control_flow_ops.with_dependencies(checks, row_starts) 

480 

481 row_splits = array_ops.concat([row_starts, [nvals]], axis=0) 

482 return cls(row_splits=row_splits, nvals=nvals, 

483 internal=_row_partition_factory_key) 

484 

485 @classmethod 

486 def from_row_limits(cls, 

487 row_limits, 

488 validate=True, 

489 dtype=None, 

490 dtype_hint=None): 

491 """Creates a `RowPartition` with rows partitioned by `row_limits`. 

492 

493 Equivalent to: `from_row_splits(values, concat([0, row_limits], axis=0))`. 

494 

495 Args: 

496 row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in 

497 ascending order. 

498 validate: If true, then use assertions to check that the arguments form a 

499 valid `RowPartition`. 

500 dtype: Optional dtype for the RowPartition. If missing, the type 

501 is inferred from the type of `row_limits`, dtype_hint, or tf.int64. 

502 dtype_hint: Optional dtype for the RowPartition, used when dtype 

503 is None. In some cases, a caller may not have a dtype in mind when 

504 converting to a tensor, so dtype_hint can be used as a soft preference. 

505 If the conversion to `dtype_hint` is not possible, this argument has no 

506 effect. 

507 

508 Returns: 

509 A `RowPartition`. 

510 """ 

511 if not isinstance(validate, bool): 

512 raise TypeError("validate must have type bool") 

513 with ops.name_scope(None, "RowPartitionFromRowLimits", [row_limits]): 

514 row_limits = cls._convert_row_partition( 

515 row_limits, "row_limits", dtype_hint=dtype_hint, dtype=dtype) 

516 row_limits.shape.assert_has_rank(1) 

517 

518 if validate: 

519 msg = "Arguments to from_row_limits do not form a valid RaggedTensor" 

520 checks = [ 

521 check_ops.assert_rank(row_limits, 1, message=msg), 

522 check_ops.assert_non_negative(row_limits[:1], message=msg), 

523 _assert_monotonic_increasing(row_limits, message=msg), 

524 ] 

525 row_limits = control_flow_ops.with_dependencies(checks, row_limits) 

526 

527 zero = array_ops.zeros([1], row_limits.dtype) 

528 row_splits = array_ops.concat([zero, row_limits], axis=0) 

529 return cls(row_splits=row_splits, internal=_row_partition_factory_key) 

530 

  @classmethod
  def from_uniform_row_length(cls,
                              uniform_row_length,
                              nvals=None,
                              nrows=None,
                              validate=True,
                              dtype=None,
                              dtype_hint=None):
    """Creates a `RowPartition` with rows partitioned by `uniform_row_length`.

    This `RowPartition` divides a sequence `values` into rows that all have
    the same length:

    ```python
    partitioned_rows = [[values.pop(0) for _ in range(uniform_row_length)]
                        for _ in range(nrows)]
    ```

    Note that either or both of nvals and nrows must be specified.

    Args:
      uniform_row_length: A scalar integer tensor. Must be nonnegative. The
        size of the outer axis of `values` must be evenly divisible by
        `uniform_row_length`.
      nvals: a non-negative scalar integer tensor for the number of values.
        Must be specified if nrows is not specified. If not specified,
        defaults to uniform_row_length*nrows
      nrows: The number of rows in the constructed RowPartition.  If not
        specified, then it defaults to `nvals/uniform_row_length` (or `0` if
        `uniform_row_length==0`).  `nrows` only needs to be specified if
        `uniform_row_length` might be zero.  `uniform_row_length*nrows` must
        be `nvals`.
      validate: If true, then use assertions to check that the arguments form
        a valid `RowPartition`.
      dtype: Optional dtype for the RowPartition. If missing, the type
        is inferred from the type of `uniform_row_length`, dtype_hint,
        or tf.int64.
      dtype_hint: Optional dtype for the RowPartition, used when dtype
        is None. In some cases, a caller may not have a dtype in mind when
        converting to a tensor, so dtype_hint can be used as a soft preference.
        If the conversion to `dtype_hint` is not possible, this argument has no
        effect.

    Returns:
      A `RowPartition`.
    """
    if not isinstance(validate, bool):
      raise TypeError("validate must have type bool")
    if nrows is None and nvals is None:
      raise ValueError("Either (or both) of nvals and nrows must be specified")
    with ops.name_scope(None, "RowPartitionFromUniformRowLength",
                        [uniform_row_length, nrows]):
      # Convert all three encodings to tensors of a single agreed dtype.
      # Entries passed as None stay None.
      [uniform_row_length, nvals, nrows
      ] = _convert_all_to_tensors([(uniform_row_length, "uniform_row_length"),
                                   (nvals, "nvals"), (nrows, "nrows")],
                                  dtype=dtype,
                                  dtype_hint=dtype_hint)

      uniform_row_length.shape.assert_has_rank(0)

      # Find nrows.  Prefer a statically-known row length so that nrows can
      # be folded to a constant.
      const_row_length = tensor_util.constant_value(uniform_row_length)
      if nrows is None:
        if const_row_length is None:
          # Avoid division by zero if uniform_row_length==0 (and nvals==0).
          rowlen_or_1 = math_ops.maximum(
              uniform_row_length,
              constant_op.constant(1, uniform_row_length.dtype))
          nrows = nvals // rowlen_or_1
        elif const_row_length == 0:
          nrows = constant_op.constant(0, dtype=uniform_row_length.dtype)
        else:
          nrows = nvals // const_row_length
      const_nrows = None if nrows is None else tensor_util.constant_value(nrows)
      const_nvals = None if nvals is None else tensor_util.constant_value(nvals)
      const_uniform_row_length = tensor_util.constant_value(uniform_row_length)

      # Runtime assertions accumulated below; attached to row_splits at the
      # end when validate is True.
      checks = []

      # If nrows and uniform_row_length are both statically known, nvals can
      # be folded to a constant too; a dynamic caller-supplied nvals is then
      # checked against it at runtime.
      if const_nvals is None and const_nrows is not None and const_uniform_row_length is not None:
        const_nvals = const_nrows * const_uniform_row_length
        if nvals is not None and validate:
          checks.append(check_ops.assert_equal(nvals, const_nvals))
        nvals = constant_op.constant(const_nvals, uniform_row_length.dtype)

      if nvals is None:
        nvals = nrows * uniform_row_length

      # Find row_splits: fold to a constant when possible, otherwise build a
      # range op scaled by the row length.
      if const_nrows is not None and const_row_length is not None:
        row_splits = [v * const_row_length for v in range(const_nrows + 1)]
        row_splits = constant_op.constant(row_splits, uniform_row_length.dtype)
      else:
        row_splits = math_ops.range(
            nrows + 1, dtype=uniform_row_length.dtype) * uniform_row_length

      if validate:

        # nrows * uniform_row_length == nvals: checked statically when all
        # three are constants (raising immediately), at runtime otherwise.
        if (const_nrows is None or const_row_length is None or
            const_nvals is None):
          checks.append(
              check_ops.assert_equal(
                  nrows * uniform_row_length, nvals,
                  ("uniform_row_length", uniform_row_length, "times nrows",
                   nrows, "must equal nvals", nvals)))
        else:
          if const_nrows * const_row_length != const_nvals:
            raise ValueError(
                "uniform_row_length=%d times nrows=%d must equal nvals=%d" %
                (const_row_length, const_nrows, const_nvals))

        if uniform_row_length.shape.rank is None:
          checks.append(
              check_ops.assert_rank(
                  uniform_row_length,
                  0,
                  message="uniform_row_length must be a scalar."))

        # Nonnegativity of the row length: static check when known, runtime
        # assertion otherwise.
        const_row_length = tensor_util.constant_value(uniform_row_length)
        if const_row_length is None:
          checks.append(
              check_ops.assert_greater_equal(
                  uniform_row_length,
                  constant_op.constant(0, uniform_row_length.dtype),
                  message="uniform_row_length must be >= 0."))
        else:
          if const_row_length < 0:
            raise ValueError("uniform_row_length must be >= 0.")

        row_splits = control_flow_ops.with_dependencies(checks, row_splits)

      return cls(
          row_splits=row_splits,
          uniform_row_length=uniform_row_length,
          nrows=nrows,
          nvals=nvals,
          internal=_row_partition_factory_key)

668 

669 @classmethod 

670 def _convert_row_partition(cls, partition, name, dtype=None, dtype_hint=None): 

671 """Converts `partition` to Tensors. 

672 

673 Args: 

674 partition: A row-partitioning tensor for the `RowPartition` being 

675 constructed. I.e., one of: row_splits, row_lengths, row_starts, 

676 row_limits, value_rowids, uniform_row_length. 

677 name: The name of the row-partitioning tensor. 

678 dtype: Optional dtype for the RowPartition. If missing, the type 

679 is inferred from the type of `uniform_row_length`, dtype_hint, 

680 or tf.int64. 

681 dtype_hint: Optional dtype for the RowPartition, used when dtype 

682 is None. In some cases, a caller may not have a dtype in mind when 

683 converting to a tensor, so dtype_hint can be used as a soft preference. 

684 If the conversion to `dtype_hint` is not possible, this argument has no 

685 effect. 

686 

687 Returns: 

688 A tensor equivalent to partition. 

689 

690 Raises: 

691 ValueError: if dtype is not int32 or int64. 

692 """ 

693 if dtype_hint is None: 

694 dtype_hint = dtypes.int64 

695 if (isinstance(partition, np.ndarray) and 

696 partition.dtype == np.int32 and dtype is None): 

697 partition = ops.convert_to_tensor(partition, name=name) 

698 else: 

699 partition = tensor_conversion.convert_to_tensor_v2( 

700 partition, dtype_hint=dtype_hint, dtype=dtype, name=name 

701 ) 

702 if partition.dtype not in (dtypes.int32, dtypes.int64): 

703 raise ValueError("%s must have dtype int32 or int64" % name) 

704 

705 return partition 

706 

707 def _with_dependencies(self, dependencies): 

708 """Returns a new RowPartition equal to self with control dependencies. 

709 

710 Specifically, self._row_splits is gated by the given control dependencies. 

711 Used to add sanity checks to the constructors. 

712 

713 Args: 

714 dependencies: a list of tensors to use as dependencies. 

715 

716 Returns: 

717 A new RowPartition object. 

718 """ 

719 new_row_splits = control_flow_ops.with_dependencies(dependencies, 

720 self._row_splits) 

721 return RowPartition( 

722 row_splits=new_row_splits, 

723 row_lengths=self._row_lengths, 

724 value_rowids=self._value_rowids, 

725 nrows=self._nrows, 

726 uniform_row_length=self._uniform_row_length, 

727 internal=_row_partition_factory_key) 

728 

729 # ============================================================================= 

730 # Accessors 

731 # ============================================================================= 

732 

733 @property 

734 def dtype(self): 

735 """The `DType` used to encode the row partition (either int32 or int64).""" 

736 return self._row_splits.dtype 

737 

738 def row_splits(self): 

739 """Returns the row-split indices for this row partition. 

740 

741 `row_splits` specifies where the values for each row begin and end. 

742 In particular, the values for row `i` are stored in the slice 

743 `values[row_splits[i]:row_splits[i+1]]`. 

744 

745 Returns: 

746 A 1-D integer `Tensor` with shape `[self.nrows+1]`. 

747 The returned tensor is non-empty, and is sorted in ascending order. 

748 `self.row_splits()[0] == 0`. 

749 `self.row_splits()[-1] == self.nvals()`. 

750 """ 

751 return self._row_splits 

752 

753 def value_rowids(self): 

754 """Returns the row indices for this row partition. 

755 

756 `value_rowids` specifies the row index fo reach value. In particular, 

757 `value_rowids[i]` is the row index for `values[i]`. 

758 

759 Returns: 

760 A 1-D integer `Tensor` with shape `[self.nvals()]`. 

761 The returned tensor is nonnegative, and is sorted in ascending order. 

762 """ 

763 if self._value_rowids is not None: 

764 return self._value_rowids 

765 return segment_id_ops.row_splits_to_segment_ids(self._row_splits) 

766 

767 def nvals(self): 

768 """Returns the number of values partitioned by this `RowPartition`. 

769 

770 If the sequence partitioned by this `RowPartition` is a tensor, then 

771 `nvals` is the size of that tensor's outermost dimension -- i.e., 

772 `nvals == values.shape[0]`. 

773 

774 Returns: 

775 scalar integer Tensor 

776 """ 

777 # TODO(martinz): Uncomment these lines. 

778 # if self._nvals is not None: 

779 # return self._nvals 

780 return self._row_splits[-1] 

781 

  def nrows(self):
    """Returns the number of rows created by this `RowPartition`.

    Returns:
      scalar integer Tensor
    """
    if self._nrows is not None:
      return self._nrows
    # nrows == len(row_splits) - 1.  Prefer the static shape of row_splits;
    # only fall back to a runtime shape op when the length is unknown.
    nsplits = tensor_shape.dimension_at_index(self._row_splits.shape, 0)
    if nsplits.value is None:
      return array_ops.shape(self._row_splits, out_type=self.dtype)[0] - 1
    else:
      return constant_op.constant(nsplits.value - 1, dtype=self.dtype)

795 

796 def uniform_row_length(self): 

797 """Returns the length of each row in this partition, if rows are uniform. 

798 

799 If all rows in this `RowPartition` have the same length, then this returns 

800 that length as a scalar integer `Tensor`. Otherwise, it returns `None`. 

801 

802 Returns: 

803 scalar Tensor with `type=self.dtype`, or `None`. 

804 """ 

805 return self._uniform_row_length 

806 

807 def row_starts(self): 

808 """Returns the start indices for rows in this row partition. 

809 

810 These indices specify where the values for each row begin. 

811 `partition.row_starts()` is equal to `partition.row_splits()[:-1]`. 

812 

813 Returns: 

814 A 1-D integer Tensor with shape `[self.nrows()]`. 

815 The returned tensor is nonnegative, and is sorted in ascending order. 

816 `self.row_starts()[0] == 0`. 

817 `self.row_starts()[-1] <= self.nvals()`. 

818 """ 

819 return self._row_splits[:-1] 

820 

821 def row_limits(self): 

822 """Returns the limit indices for rows in this row partition. 

823 

824 These indices specify where the values for each row end. 

825 `partition.row_limits()` is equal to `partition.row_splits()[:-1]`. 

826 

827 Returns: 

828 A 1-D integer Tensor with shape `[self.nrows]`. 

829 The returned tensor is nonnegative, and is sorted in ascending order. 

830 `self.row_limits()[-1] == self.nvals()`. 

831 """ 

832 return self._row_splits[1:] 

833 

834 def row_lengths(self): 

835 """Returns the lengths of rows in this `RowPartition`. 

836 

837 Returns: 

838 A 1-D integer Tensor with shape `[self.nrows]`. 

839 The returned tensor is nonnegative. 

840 `tf.reduce_sum(self.row_lengths) == self.nvals()`. 

841 """ 

842 if self._row_lengths is not None: 

843 return self._row_lengths 

844 splits = self._row_splits 

845 return splits[1:] - splits[:-1] 

846 

  @property
  def static_nrows(self):
    """The number of rows in this partition, if statically known.

    ```python
    self.row_lengths().shape == [self.static_nrows]
    self.row_starts().shape == [self.static_nrows]
    self.row_limits().shape == [self.static_nrows]
    self.row_splits().shape == [self.static_nrows + 1]
    ```

    Returns:
      The number of rows in this partition as an `int` (if statically known);
      or `None` (otherwise).
    """
    # Try each precomputed encoding in turn: the static length of row_splits
    # (nrows + 1), the static length of row_lengths, then a constant value
    # folded out of the _nrows tensor.
    if self._row_splits is not None:
      nrows_plus_one = tensor_shape.dimension_value(self._row_splits.shape[0])
      if nrows_plus_one is not None:
        return nrows_plus_one - 1
    if self._row_lengths is not None:
      nrows = tensor_shape.dimension_value(self._row_lengths.shape[0])
      if nrows is not None:
        return nrows
    if self._nrows is not None:
      return tensor_util.constant_value(self._nrows)
    return None

873 

  @property
  def static_nvals(self):
    """The number of values in this partition, if statically known.

    ```python
    self.value_rowids().shape == [self.static_nvals]
    ```

    Returns:
      The number of values in this partition as an `int` (if statically known);
      or `None` (otherwise).
    """
    # Prefer a constant folded out of the _nvals tensor; otherwise fall back
    # to the static length of value_rowids (one entry per value).
    if self._nvals is not None:
      nvals = tensor_util.constant_value(self._nvals)
      if nvals is not None:
        return nvals
    if self._value_rowids is not None:
      nvals = tensor_shape.dimension_at_index(self._value_rowids.shape, 0)
      if nvals.value is not None:
        return nvals.value
    return None

895 

  @property
  def static_uniform_row_length(self):
    """The number of values in each row of this partition, if statically known.

    Returns:
      The number of values in each row of this partition as an `int` (if
      statically known); or `None` (otherwise).
    """
    if self._uniform_row_length is not None:
      # constant_value returns None when the tensor is not a foldable constant.
      return tensor_util.constant_value(self._uniform_row_length)
    return None

907 

  def offsets_in_rows(self):
    """Return the offset of each value within its row.

    RowPartition takes an array x and converts it into sublists.
    offsets[i] is the index of x[i] in its sublist.
    Given a shape, such as:
    [*,*,*],[*,*],[],[*,*]
    This returns:
    0,1,2,0,1,0,1

    Returns:
      an offset for every value.
    """
    # ragged_range(starts=0, limits=row_lengths, deltas=1) builds the sequence
    # [0, 1, ..., len-1] for each row; its flattened values are exactly the
    # per-row offsets.
    return gen_ragged_math_ops.ragged_range(
        starts=constant_op.constant(0, self.dtype),
        limits=self.row_lengths(),
        deltas=constant_op.constant(1, self.dtype)).rt_dense_values

925 

926 def is_uniform(self): 

927 """Returns true if the partition is known to be uniform statically. 

928 

929 This is based upon the existence of self._uniform_row_length. For example: 

930 RowPartition.from_row_lengths([3,3,3]).is_uniform()==false 

931 RowPartition.from_uniform_row_length(5, nvals=20).is_uniform()==true 

932 RowPartition.from_row_lengths([2,0,2]).is_uniform()==false 

933 

934 Returns: 

935 Whether a RowPartition is known to be uniform statically. 

936 """ 

937 return self._uniform_row_length is not None 

938 

  def _static_check(self):
    """Checks if the object is internally consistent.

    Only dtype agreement between the precomputed encodings and the canonical
    row_splits encoding is verified; values are not compared.

    Raises:
      ValueError if inconsistent.
    """
    my_dtype = self.dtype  # dtype of _row_splits, the canonical encoding.
    if self._uniform_row_length is not None:
      if self._uniform_row_length.dtype != my_dtype:
        raise ValueError("_uniform_row_length.dtype=" +
                         str(self._uniform_row_length.dtype) + ", not " +
                         str(my_dtype))

    if self._row_lengths is not None and self._row_lengths.dtype != my_dtype:
      raise ValueError("_row_lengths.dtype=" + str(self._row_lengths.dtype) +
                       ", not " + str(my_dtype))

    if self._value_rowids is not None and self._value_rowids.dtype != my_dtype:
      raise ValueError("_value_rowids.dtype=" + str(self._value_rowids.dtype) +
                       ", not " + str(my_dtype))

    if self._nrows is not None and self._nrows.dtype != my_dtype:
      raise ValueError("_nrows.dtype=" + str(self._nrows.dtype) + ", not " +
                       str(my_dtype))

963 

964 # ============================================================================= 

965 # Transformation 

966 # ============================================================================= 

967 

  def with_dtype(self, dtype):
    """Returns a copy of this RowPartition with the given encoding dtype.

    Args:
      dtype: The dtype for encoding tensors, such as `row_splits` and `nrows`.
        One of `tf.int32` or `tf.int64`.

    Returns:
      A copy of this RowPartition, with the encoding tensors cast to the given
      type.

    Raises:
      ValueError: If `dtype` is not `tf.int32` or `tf.int64`.
    """
    dtype = dtypes.as_dtype(dtype)
    if dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("dtype must be int32 or int64")
    if self.dtype == dtype:
      return self  # Already encoded with the requested dtype.

    # Cast every precomputed encoding; None entries pass through unchanged.
    return RowPartition(
        row_splits=_cast_if_not_none(self._row_splits, dtype),
        row_lengths=_cast_if_not_none(self._row_lengths, dtype),
        value_rowids=_cast_if_not_none(self._value_rowids, dtype),
        nrows=_cast_if_not_none(self._nrows, dtype),
        uniform_row_length=_cast_if_not_none(self._uniform_row_length, dtype),
        internal=_row_partition_factory_key)

992 

993 # ============================================================================= 

994 # String Encoding 

995 # ============================================================================= 

996 

  def __repr__(self):
    # Uniform partitions are summarized by (nrows, uniform_row_length);
    # ragged partitions are summarized by their row_splits encoding.
    if self._uniform_row_length is not None:
      return (f"tf.RowPartition(nrows={self._nrows}, "
              f"uniform_row_length={self._uniform_row_length})")
    else:
      return f"tf.RowPartition(row_splits={self._row_splits})"

1003 

1004 # ============================================================================= 

1005 # Precomputed Encodings 

1006 # ============================================================================= 

1007 

  def _has_precomputed_row_splits(self) -> bool:
    """Returns true if `row_splits` has already been computed.

    If true, then `self.row_splits()` will return its value without calling
    any TensorFlow ops.
    """
    return self._row_splits is not None

1015 

  def _has_precomputed_row_lengths(self) -> bool:
    """Returns true if `row_lengths` has already been computed.

    If true, then `self.row_lengths()` will return its value without calling
    any TensorFlow ops.
    """
    return self._row_lengths is not None

1023 

  def _has_precomputed_value_rowids(self) -> bool:
    """Returns true if `value_rowids` has already been computed.

    If true, then `self.value_rowids()` will return its value without calling
    any TensorFlow ops.
    """
    return self._value_rowids is not None

1031 

  def _has_precomputed_nrows(self) -> bool:
    """Returns true if `nrows` has already been computed.

    If true, then `self.nrows()` will return its value without calling
    any TensorFlow ops.
    """
    return self._nrows is not None

1039 

  def _has_precomputed_nvals(self) -> bool:
    """Returns true if `nvals` has already been computed.

    If true, then `self.nvals()` will return its value without calling
    any TensorFlow ops.
    """
    return self._nvals is not None

1047 

  def _with_precomputed_row_splits(self):
    """Returns a copy of `self` with `row_splits` precomputed."""
    # `self.row_splits()` computes the encoding if it was not already cached;
    # every other encoding is forwarded unchanged.
    return RowPartition(
        row_splits=self.row_splits(),
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        uniform_row_length=self._uniform_row_length,
        nvals=self._nvals,
        internal=_row_partition_factory_key)

1058 

  def _with_precomputed_row_lengths(self):
    """Returns a copy of `self` with `row_lengths` precomputed."""
    # `self.row_lengths()` computes the encoding if it was not already cached;
    # every other encoding is forwarded unchanged.
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self.row_lengths(),
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        nvals=self._nvals,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

1069 

  def _with_precomputed_value_rowids(self):
    """Returns a copy of `self` with `value_rowids` precomputed."""
    # `self.value_rowids()` computes the encoding if it was not already cached;
    # every other encoding is forwarded unchanged.
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self.value_rowids(),
        nrows=self._nrows,
        nvals=self._nvals,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

1080 

  def _with_precomputed_nrows(self):
    """Returns a copy of `self` with `nrows` precomputed."""
    # `self.nrows()` computes the encoding if it was not already cached;
    # every other encoding is forwarded unchanged.
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self.nrows(),
        nvals=self._nvals,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

1091 

  def _with_precomputed_nvals(self):
    """Returns a copy of `self` with `nvals` precomputed."""
    # `self.nvals()` is derived from row_splits (its last element), so this
    # also materializes row_splits via self.row_splits().
    return RowPartition(
        row_splits=self.row_splits(),
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        nvals=self.nvals(),
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

1102 

  def _merge_with_spec(self, b):
    """Merge with a TypeSpec to create a new RowPartition.

    Args:
      b: a `RowPartitionSpec` compatible with this partition's spec.

    Returns:
      A new `RowPartition` in which any dimension that is statically known in
      `b` (nrows, nvals, uniform_row_length) is materialized as a constant.

    Raises:
      ValueError: If `b` is not compatible with this partition's spec.
    """
    a_spec = self._type_spec
    if not a_spec.is_compatible_with(b):
      # TODO(martinz): Should a dynamic check be used here?
      raise ValueError("RowPartition and RowPartitionSpec are not compatible")
    # Prefer static values from the spec; fall back to self's cached tensors.
    nrows = constant_op.constant(
        b.nrows, self.dtype) if b.nrows is not None else self._nrows
    nvals = constant_op.constant(
        b.nvals, self.dtype) if b.nvals is not None else self._nvals
    uniform_row_length = constant_op.constant(
        b.uniform_row_length, self.dtype
    ) if b.uniform_row_length is not None else self._uniform_row_length
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nvals=nvals,
        uniform_row_length=uniform_row_length,
        nrows=nrows,
        internal=_row_partition_factory_key)

1124 

  def _merge_precomputed_encodings(self, other, validate=True):
    """Returns a RowPartition that merges encodings from `self` and `other`.

    Requires that `self` and `other` describe the same partition.

    Args:
      other: A `RowPartition` that encodes the same partition as `self`.
      validate: If true, then add runtime checks to verify that `self` and
        `other` encode the same row partition.

    Returns:
      A `RowPartition`.
    """
    # pylint: disable=protected-access
    if (self is other or  # Fast path if row partitions are equal.
        (self._row_splits is other._row_splits and
         self._row_lengths is other._row_lengths and
         self._value_rowids is other._value_rowids and
         self._nrows is other._nrows and
         self._nvals is other._nvals and
         self._uniform_row_length is other._uniform_row_length)):
      return self

    # Merge the component tensors. We only need to validate one encoding.
    # We merge less-expensive encodings first (to avoid expensive validation).
    # Each _merge_tensors call returns (tensor, validated); once enough
    # encodings have been verified equal, `validate` is cleared so that the
    # remaining (more expensive) merges skip their runtime checks.
    nrows, nrows_validated = _merge_tensors(self._nrows, other._nrows, "nrows",
                                            validate)
    nvals, _ = _merge_tensors(self._nvals, other._nvals, "nvals", validate)
    uniform_row_length, uniform_row_length_validated = _merge_tensors(
        self._uniform_row_length, other._uniform_row_length,
        "uniform_row_length", validate)
    if uniform_row_length_validated and nrows_validated:
      validate = False  # Validation complete.
    row_splits, row_splits_validated = _merge_tensors(self._row_splits,
                                                      other._row_splits,
                                                      "row_splits", validate)
    if row_splits_validated:
      validate = False  # Validation complete.
    row_lengths, row_lengths_validated = _merge_tensors(self._row_lengths,
                                                        other._row_lengths,
                                                        "row_lengths", validate)
    if row_lengths_validated:
      validate = False  # Validation complete.
    value_rowids, value_rowids_validated = _merge_tensors(
        self._value_rowids, other._value_rowids, "value_rowids", validate)
    if value_rowids_validated and nrows_validated:
      validate = False  # Validation complete.
    # TODO(edloper): If we make the row_splits encoding optional, then there
    # will be cases where we need to do validation at this point -- e.g. if
    # self has only row_splits and other has only value_rowids. But for
    # now, we are guaranteed to have done validation by this point.

    # Avoid creating new RowPartition objects if we don't need to.
    if (row_splits is self._row_splits and row_lengths is self._row_lengths and
        value_rowids is self._value_rowids and nrows is self._nrows and
        uniform_row_length is self._uniform_row_length):
      return self
    if (row_splits is other._row_splits and
        row_lengths is other._row_lengths and
        value_rowids is other._value_rowids and nrows is other._nrows and
        uniform_row_length is other._uniform_row_length):
      return other

    return RowPartition(
        row_splits=row_splits,
        row_lengths=row_lengths,
        value_rowids=value_rowids,
        nrows=nrows,
        uniform_row_length=uniform_row_length,
        nvals=nvals,
        internal=_row_partition_factory_key)

1196 

1197 # ============================================================================= 

1198 # Composite Tensor 

1199 # ============================================================================= 

1200 

  @property
  def _type_spec(self):
    """A `RowPartitionSpec` describing this value's static properties."""
    return RowPartitionSpec.from_value(self)

1204 

1205 

1206# =============================================================================== 

1207# RowPartitionSpec 

1208# =============================================================================== 

1209# TODO(edloper): Consider refactoring RowPartitionSpec to allow any combination 

1210# of precomputed row-partition encodings (rather than always using row_splits). 

1211 

1212 

@type_spec_registry.register("tf.RowPartitionSpec")
class RowPartitionSpec(type_spec.TypeSpec):
  """Type specification for a `tf.RowPartition`."""

  __slots__ = ["_nrows", "_nvals", "_uniform_row_length", "_dtype"]

  # The value class described by this spec.
  value_type = property(lambda self: RowPartition)

  def __init__(self,
               nrows=None,
               nvals=None,
               uniform_row_length=None,
               dtype=dtypes.int64):
    """Constructs a new RowPartitionSpec.

    Args:
      nrows: The number of rows in the RowPartition, or `None` if unspecified.
      nvals: The number of values partitioned by the RowPartition, or `None` if
        unspecified.
      uniform_row_length: The number of values in each row for this
        RowPartition, or `None` if rows are ragged or row length is unspecified.
      dtype: The data type used to encode the partition. One of `tf.int64` or
        `tf.int32`.

    Raises:
      ValueError: If `dtype` is not int32/int64, or if the given dimensions
        are inconsistent with each other.
    """
    # Wrap dimension sizes in 1D TensorShapes so the default implementations
    # of TypeSpec methods such as `is_compatible_with` will work.
    nrows = tensor_shape.TensorShape([nrows])
    nvals = tensor_shape.TensorShape([nvals])
    if not isinstance(uniform_row_length, tensor_shape.TensorShape):
      uniform_row_length = tensor_shape.TensorShape([uniform_row_length])
    else:
      uniform_row_length = uniform_row_length.with_rank(1)

    self._nrows = nrows
    self._nvals = nvals
    self._uniform_row_length = uniform_row_length
    self._dtype = dtypes.as_dtype(dtype)
    if self._dtype not in (dtypes.int32, dtypes.int64):
      raise ValueError("dtype must be tf.int32 or tf.int64")

    # Check dimension consistency, & infer dimensions when possible.
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0:  # no rows -> no values.
      if nvals is None:
        self._nvals = tensor_shape.TensorShape([0])
      elif nvals != 0:
        raise ValueError("nvals=%s is not compatible with nrows=%s" %
                         (nvals, nrows))
    if ncols == 0:  # there are no values in each row -> no values.
      if nvals is None:
        self._nvals = tensor_shape.TensorShape([0])
      elif nvals != 0:
        raise ValueError("nvals=%s is not compatible with uniform_row_length"
                         "=%s" % (nvals, uniform_row_length))
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        raise ValueError("nvals=%s is not compatible with uniform_row_length"
                         "=%s (doesn't divide evenly)" % (nvals, ncols))
      if nrows is not None and nvals != ncols * nrows:
        raise ValueError("nvals=%s is not compatible with nrows=%s and "
                         "uniform_row_length=%s" % (nvals, nrows, ncols))
      if nrows is None and ncols != 0:
        # Infer nrows from nvals and a nonzero uniform row length.
        self._nrows = tensor_shape.TensorShape([nvals // ncols])
    if ncols is not None and nrows is not None and nvals is None:
      # Infer nvals from nrows and the uniform row length.
      self._nvals = tensor_shape.TensorShape([ncols * nrows])

  def is_compatible_with(self, other):
    """Returns true if `other` is compatible with this spec."""
    if not super(RowPartitionSpec, self).is_compatible_with(other):
      return False
    # Merge each dimension pair, then check the merged dimensions against
    # each other for joint consistency.
    nrows = self._nrows.merge_with(other.nrows)
    nvals = self._nvals.merge_with(other.nvals)
    ncols = self._uniform_row_length.merge_with(other.uniform_row_length)
    return self._dimensions_compatible(nrows, nvals, ncols)

  def _serialize(self):
    # Dimensions are serialized as the 1-D TensorShape wrappers.
    return (self._nrows, self._nvals, self._uniform_row_length, self._dtype)

  @classmethod
  def _deserialize(cls, serialization):
    # Remove TensorShape wrappers from serialization.
    (nrows, nvals, uniform_row_length, dtype) = serialization
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    return cls(nrows, nvals, uniform_row_length, dtype)

  @property
  def nrows(self):
    """Static number of rows, as an `int` or `None`."""
    return tensor_shape.dimension_value(self._nrows[0])

  @property
  def nvals(self):
    """Static number of values, as an `int` or `None`."""
    return tensor_shape.dimension_value(self._nvals[0])

  @property
  def uniform_row_length(self):
    """Static uniform row length, as an `int` or `None`."""
    return tensor_shape.dimension_value(self._uniform_row_length[0])

  @property
  def dtype(self):
    """The encoding dtype (`tf.int32` or `tf.int64`)."""
    return self._dtype

  @property
  def _component_specs(self):
    # A RowPartition component-encodes as its row_splits vector, which has
    # nrows + 1 elements when nrows is statically known.
    row_splits_shape = tensor_shape.TensorShape(
        [tensor_shape.dimension_at_index(self._nrows, 0) + 1])
    return tensor_spec.TensorSpec(row_splits_shape, self._dtype)

  def _to_components(self, value):
    return value.row_splits()

  def _from_components(self, tensor):
    # Validation is skipped: the components are trusted to be well-formed.
    return RowPartition.from_row_splits(tensor, validate=False)

  @classmethod
  def from_value(cls, value):
    """Builds the spec describing `value`'s statically-known dimensions."""
    if not isinstance(value, RowPartition):
      raise TypeError("Expected `value` to be a `RowPartition`")
    return cls(value.static_nrows, value.static_nvals,
               value.static_uniform_row_length, value.dtype)

  def __repr__(self):
    return ("RowPartitionSpec(nrows=%s, nvals=%s, uniform_row_length=%s, "
            "dtype=%r)" % (self.nrows, self.nvals, self.uniform_row_length,
                           self.dtype))

  @staticmethod
  def _dimensions_compatible(nrows, nvals, uniform_row_length):
    """Returns true if the given dimensions are compatible."""
    nrows = tensor_shape.dimension_value(nrows[0])
    nvals = tensor_shape.dimension_value(nvals[0])
    ncols = tensor_shape.dimension_value(uniform_row_length[0])
    if nrows == 0 and nvals not in (0, None):
      return False  # can't have values if we have no rows.
    if ncols == 0 and nvals not in (0, None):
      return False  # can't have values if we have no values in each row.
    if ncols is not None and nvals is not None:
      if ncols != 0 and nvals % ncols != 0:
        return False  # rows aren't uniform.
      if nrows is not None and nvals != ncols * nrows:
        return False  # inconsistent number of values.
    return True

  def _merge_with(self, other):
    """Merge two RowPartitionSpecs."""
    nrows = self._nrows.merge_with(other.nrows)
    nvals = self._nvals.merge_with(other.nvals)
    ncols = self._uniform_row_length.merge_with(other.uniform_row_length)

    if not RowPartitionSpec._dimensions_compatible(nrows, nvals, ncols):
      raise ValueError("Merging incompatible RowPartitionSpecs")

    # NOTE: if the dtypes are unequal, behavior is unspecified.
    if self.dtype != other.dtype:
      raise ValueError("Merging RowPartitionSpecs with incompatible dtypes")

    return RowPartitionSpec(nrows=nrows[0],
                            nvals=nvals[0],
                            uniform_row_length=ncols[0],
                            dtype=self.dtype)

  def with_dtype(self, dtype):
    """Returns a copy of this spec with the given encoding dtype."""
    nrows = tensor_shape.dimension_value(self._nrows[0])
    nvals = tensor_shape.dimension_value(self._nvals[0])
    return RowPartitionSpec(nrows, nvals, self._uniform_row_length, dtype)

  def __deepcopy__(self, memo):
    # The spec is immutable, so a "deep" copy is just a reconstruction from
    # the unwrapped dimension values.
    del memo
    dtype = self.dtype
    nrows = tensor_shape.dimension_value(self._nrows[0])
    nvals = tensor_shape.dimension_value(self._nvals[0])
    uniform_row_length = (None if self._uniform_row_length is None else
                          tensor_shape.dimension_value(
                              self._uniform_row_length[0]))
    return RowPartitionSpec(nrows, nvals, uniform_row_length, dtype)

1389 

1390 

# Register a codec so RowPartitionSpec values can be encoded to / decoded
# from their `struct_pb2.TypeSpecProto` representation.
nested_structure_coder.register_codec(
    nested_structure_coder.BuiltInTypeSpecCodec(
        RowPartitionSpec, struct_pb2.TypeSpecProto.ROW_PARTITION_SPEC
    )
)

1396 

1397 

1398# =============================================================================== 

1399# Helper Functions 

1400# =============================================================================== 

1401 

1402 

def _assert_monotonic_increasing(tensor, message=None):
  """Asserts that consecutive elements of `tensor` never decrease."""
  diffs = tensor[1:] - tensor[:-1]
  return check_ops.assert_non_negative(diffs, message=message)

1406 

1407 

def _assert_zero(tensor, message=None):
  """Asserts that every element of `tensor` equals zero."""
  zero = constant_op.constant(0, dtype=tensor.dtype)
  return check_ops.assert_equal(tensor, zero, message=message)

1411 

1412 

def _cast_if_not_none(tensor, dtype):
  """Casts `tensor` to `dtype`, passing `None` through unchanged."""
  if tensor is None:
    return None
  return math_ops.cast(tensor, dtype)

1415 

1416 

def _merge_tensors(t1, t2, name, validate):
  """Merge two optional Tensors with equal values into a single Tensor.

  Args:
    t1: tf.Tensor or None
    t2: tf.Tensor or None
    name: A name for the tensors (for error messages)
    validate: If true, then check that `t1` is compatible with `t2` (if both are
      non-None).

  Returns:
    A pair `(merged_value, validated)`:
    * `merged_value` is `t1` if it is not None; or `t2` otherwise.
    * `validated` is true if we validated that t1 and t2 are equal (either
      by adding a check, or because t1 is t2).
  """
  if t1 is None:
    return t2, False
  elif t2 is None:
    return t1, False
  elif t1 is t2:
    # Identical objects are trivially equal -- no runtime check needed.
    return t1, True
  else:
    err_msg = ("RowPartition._merge_precomputed_encodings: partitions "
               "have incompatible %s" % name)
    # Static shape incompatibility is an error even when validate is False.
    if not t1.shape.is_compatible_with(t2.shape):
      raise ValueError(err_msg)
    if validate:
      # Gate t1 on a runtime equality check against t2.
      checks = [check_ops.assert_equal(t1, t2, message=err_msg)]
      return control_flow_ops.with_dependencies(checks, t1), True
    else:
      return t1, False

1449 

# Sentinel passed as the `internal=` argument when code in this module
# constructs a RowPartition (see the `internal=_row_partition_factory_key`
# call sites above); being module-private, it is not available to callers
# outside this file.
_row_partition_factory_key = object()  # unique private object

1451 

1452 

def _get_dtype_or_none(value):
  """Returns `value.dtype` if `value` is a Tensor, else `None`."""
  return value.dtype if isinstance(value, ops.Tensor) else None

1457 

1458 

def _get_target_dtype(values, dtype=None, dtype_hint=None):
  """Gets the target dtype of a family of values."""
  # An explicit dtype always wins.
  if dtype is not None:
    return dtype

  # Otherwise prefer the dtype of the first Tensor, then of the first ndarray.
  for candidate in values:
    if isinstance(candidate, ops.Tensor):
      return candidate.dtype

  for candidate in values:
    if isinstance(candidate, np.ndarray):
      return dtypes.as_dtype(candidate.dtype)

  # Fall back to the hint, and finally to int64.
  return dtype_hint if dtype_hint is not None else dtypes.int64

1476 

1477 

def _convert_all_to_tensors(values, dtype=None, dtype_hint=None):
  """Convert a list of objects to tensors of the same dtype."""
  target_dtype = _get_target_dtype([x for (x, _) in values], dtype, dtype_hint)

  # If dtype is None, we use convert behavior; otherwise, cast behavior.
  if dtype is None:
    converted = []
    for (value, name) in values:
      if value is None:
        converted.append(None)
      else:
        converted.append(
            ops.convert_to_tensor(value, dtype=target_dtype, name=name))
    return converted
  return [
      None if value is None else math_ops.cast(
          value, dtype=target_dtype, name=name)
      for (value, name) in values
  ]