Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/structured/structured_array_ops.py: 25%

228 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""StructuredTensor array ops.""" 

16 

17from typing import Sequence 

18 

19from tensorflow.python.framework import constant_op 

20from tensorflow.python.framework import dtypes 

21from tensorflow.python.framework import ops 

22from tensorflow.python.ops import array_ops 

23from tensorflow.python.ops import math_ops 

24from tensorflow.python.ops import random_ops 

25from tensorflow.python.ops.ragged import dynamic_ragged_shape 

26from tensorflow.python.ops.ragged import ragged_tensor 

27from tensorflow.python.ops.ragged.row_partition import RowPartition 

28from tensorflow.python.ops.structured.structured_tensor import StructuredTensor 

29from tensorflow.python.util import deprecation 

30from tensorflow.python.util import dispatch 

31 

32 

@dispatch.dispatch_for_api(array_ops.shape_v2)
def shape_v2(input: StructuredTensor, out_type=dtypes.int32,  # pylint: disable=redefined-builtin
             name=None) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns a DynamicRaggedShape containing the shape of the input.

  Dispatch target of the v2 `array_ops.shape_v2` API for StructuredTensor.

  Args:
    input: the StructuredTensor whose shape is requested.
    out_type: the dtype the returned shape should use.
    name: accepted for API compatibility; ignored.

  Returns:
    `input`'s DynamicRaggedShape converted with `with_dtype(out_type)`.
  """
  del name  # Unused by this implementation.
  ragged_shape = input._ragged_shape  # pylint: disable=protected-access
  return ragged_shape.with_dtype(out_type)

39 

40 

@dispatch.dispatch_for_api(array_ops.shape)
def shape_v1(input: StructuredTensor, name=None,  # pylint: disable=redefined-builtin
             out_type=dtypes.int32) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns a DynamicRaggedShape containing the shape of the input.

  Dispatch target of the v1 `array_ops.shape` API for StructuredTensor. The
  v1 API takes `name` before `out_type`.

  Args:
    input: the StructuredTensor whose shape is requested.
    name: accepted for API compatibility; ignored.
    out_type: the dtype the returned shape should use.

  Returns:
    `input`'s DynamicRaggedShape converted with `with_dtype(out_type)`.
  """
  del name  # Unused by this implementation.
  ragged_shape = input._ragged_shape  # pylint: disable=protected-access
  return ragged_shape.with_dtype(out_type)

47 

48 

@dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor)
@deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim')
def expand_dims(input, axis=None, name=None, dim=None):  # pylint: disable=redefined-builtin
  """Creates a StructuredTensor with a length 1 axis inserted at index `axis`.

  This is the v1 `tf.expand_dims` implementation for StructuredTensor. It
  still accepts the deprecated `dim` alias for `axis`. The `axis` must be
  less than or equal to rank.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    input: the original StructuredTensor.
    axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
    name: the name of the op.
    dim: deprecated alias for `axis`.

  Returns:
    a new structured tensor whose rank is one larger.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  # Merge the deprecated `dim` spelling into `axis`.
  resolved_axis = deprecation.deprecated_argument_lookup(
      'axis', axis, 'dim', dim)
  return _expand_dims_impl(input, resolved_axis, name=name)

81 

82 

@dispatch.dispatch_for_types(array_ops.expand_dims_v2, StructuredTensor)
def expand_dims_v2(input, axis, name=None):  # pylint: disable=redefined-builtin
  """Creates a StructuredTensor with a length 1 axis inserted at index `axis`.

  This is the v2 `tf.expand_dims` implementation for StructuredTensor. The
  `axis` must be less than or equal to rank.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    input: the original StructuredTensor.
    axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
    name: the name of the op.

  Returns:
    a new structured tensor whose rank is one larger.

  Raises:
    an error if `axis < -(rank + 1)` or `rank < axis`.
  """
  result = _expand_dims_impl(input, axis, name=name)
  return result

112 

113 

@dispatch.dispatch_for_types(array_ops.gather, StructuredTensor)
def gather(params,
           indices,
           validate_indices=None,
           name=None,
           axis=None,
           batch_dims=0):
  """tf.gather for structured tensors.

  Applies `array_ops.gather` with the same arguments to every non-structured
  field of `params` and rebuilds the structure around the results. Indices
  must be a ragged or dense tensor.

  Args:
    params: a structured tensor to be gathered.
    indices: a ragged tensor or tensor to gather by.
    validate_indices: whether to validate the indices (forwarded to each
      per-field gather).
    name: the name of the op(s); `None` means 'gather'.
    axis: the axis in params to gather on; `None` means `batch_dims`.
    batch_dims: the number of batch dimensions.

  Returns:
    the params reorganized according to indices.
  """
  scope_name = 'gather' if name is None else name
  with ops.name_scope(scope_name):
    if axis is None:
      axis = batch_dims
    gather_axis = array_ops.get_positive_axis(
        axis, params.shape.rank, ndims_name='params.shape.rank')
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')

    def gather_leaf(leaf):
      # Every leaf field is gathered with identical arguments.
      return array_ops.gather(
          leaf,
          indices,
          validate_indices=validate_indices,
          axis=gather_axis,
          batch_dims=batch_dims,
          name=None)

    return _extend_op_single(params, gather_leaf)

157 

158 

@dispatch.dispatch_for_types(array_ops.concat, StructuredTensor)
def concat(values, axis, name: str = 'concat'):
  """tf.concat for structured tensors.

  Concatenates each non-structured field of `values` along `axis` with
  `array_ops.concat` and rebuilds the StructuredTensor around the results.
  A tensor-valued `axis` is not supported yet.

  Args:
    values: a non-empty sequence of StructuredTensors. All of them must have
      the same field paths and the same submessage ranks.
    axis: a Python int axis to concatenate upon. It is resolved with
      `array_ops.get_positive_axis` against the rank of `values[0]`.
    name: the name of the op(s). `None` is treated as 'concat'.

  Returns:
    A StructuredTensor holding the concatenation of `values` along `axis`.

  Raises:
    ValueError: if `values` is not a non-empty sequence of StructuredTensors,
      or their paths or submessage ranks are inconsistent.
  """
  if name is None:
    name = 'concat'
  # Static validation first, so leaf-level concat never silently drops paths.
  _assert_concat_compatible_structured_tensors(values)
  # TODO(martinz): handle axis when it is a tensor.
  axis = array_ops.get_positive_axis(axis, values[0].rank)

  def leaf_op(leaf_values):
    return array_ops.concat(leaf_values, axis)

  with ops.name_scope(name, 'StructuredConcat', values):
    return _extend_op(values, leaf_op)

182 

183 

@dispatch.dispatch_for_types(random_ops.random_shuffle, StructuredTensor)
def random_shuffle(value, seed=None, name=None):
  """Shuffle a structured tensor on the zeroth axis.

  A random permutation of `range(value.nrows())` is drawn and applied to
  `value` with `gather` along axis 0.

  Args:
    value: a structured tensor of rank at least one.
    seed: the seed for shuffling.
    name: the name for shuffle.

  Returns:
    The shuffled structured tensor.

  Raises:
    ValueError: if `value` is a scalar (rank 0).
  """
  with ops.name_scope(name, 'shuffle', [value, seed]):
    if value.rank == 0:
      raise ValueError('Cannot shuffle a scalar StructuredTensor')
    row_ids = math_ops.range(value.nrows())
    permutation = random_ops.random_shuffle(row_ids, seed=seed)
    return gather(value, permutation, axis=0)

203 

204 

@dispatch.dispatch_for_types(array_ops.size_v2, StructuredTensor)
def size_v2(input, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns the size of a tensor.

  v2 argument order (`out_type` before `name`); delegates to `size`.

  Args:
    input: a StructuredTensor.
    out_type: the dtype of the result.
    name: a name for the op.

  Returns:
    Whatever `size` returns for the same arguments.
  """
  return size(input, name=name, out_type=out_type)

210 

211 

# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.size, StructuredTensor)
def size(input, name=None, out_type=dtypes.int32):  # pylint: disable=redefined-builtin
  """Returns the size of a tensor.

  The size is computed as follows:
    * rank 0 (no row partitions, `nrows()` is None): 1.
    * rank 1 (no row partitions): `nrows()`.
    * rank >= 2: `nvals()` of the innermost row partition.

  Args:
    input: a StructuredTensor.
    name: a name for the op.
    out_type: the dtype of the result. For rank >= 2 only, `None` returns the
      innermost partition's `nvals()` without casting.

  Returns:
    A scalar tensor holding the number of elements.
  """
  with ops.name_scope(name, 'size', [input]):
    partitions = input.row_partitions
    if partitions:
      # 2D and up: the innermost partition counts every element.
      nvals = partitions[-1].nvals()
      if nvals is None or out_type is None:
        return nvals
      return math_ops.cast(nvals, dtype=out_type)
    nrows = input.nrows()
    if nrows is None:
      return math_ops.cast(1, out_type)  # scalar.
    return math_ops.cast(nrows, out_type)  # vector.

228 

229 

# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.zeros_like, StructuredTensor)
def zeros_like(tensor, dtype=None, name=None, optimize=True):
  """Implementation of zeros_like for StructuredTensor for TF v1.

  Args:
    tensor: a StructuredTensor.
    dtype: the dtype of the zeros; forwarded to `zeros_like_v2`.
    name: a name for the op.
    optimize: accepted for v1 API compatibility; ignored.

  Returns:
    The result of `zeros_like_v2(tensor, dtype=dtype, name=name)`.
  """
  del optimize  # Not used by the StructuredTensor implementation.
  result = zeros_like_v2(tensor, dtype=dtype, name=name)
  return result

236 

237 

# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.zeros_like_v2, StructuredTensor)
def zeros_like_v2(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  """Replace every object with a zero.

  The result ignores the fields of `input` and only mirrors its shape:
    * rank 0: a scalar zero.
    * rank 1: a dense vector of `nrows()` zeros.
    * rank >= 2: a RaggedTensor reusing `input.row_partitions`.

  Example:
  >>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
  >>> tf.zeros_like(st)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 0.], dtype=float32)>
  >>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
  >>> tf.zeros_like(st, dtype=tf.int32)
  <tf.RaggedTensor [[0], [0, 0]]>

  Args:
    input: a structured tensor.
    dtype: the dtype of the resulting zeros. (default is tf.float32)
    name: a name for the op.

  Returns:
    a tensor (a RaggedTensor for rank >= 2) of zeros of the same shape.
  """
  if dtype is None:
    dtype = dtypes.float32
  with ops.name_scope(name, 'zeros_like', [input]):
    if not input.row_partitions:
      if input.nrows() is not None:
        return array_ops.zeros([input.nrows()], dtype)  # vector.
      else:
        return array_ops.zeros([], dtype)  # scalar.
    # 2D and up: the zeros are the flat values of a RaggedTensor that shares
    # the input's row partitions; the innermost partition's nvals() is the
    # number of flat values needed.
    last_row_partition = input.row_partitions[-1]
    return ragged_tensor.RaggedTensor._from_nested_row_partitions(
        array_ops.zeros(last_row_partition.nvals(), dtype=dtype),
        input.row_partitions)

273 

274 

# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.ones_like, StructuredTensor)
def ones_like(tensor, dtype=None, name=None, optimize=True):
  """Implementation of ones_like for StructuredTensor for TF v1.

  Args:
    tensor: a StructuredTensor.
    dtype: the dtype of the ones; forwarded to `ones_like_v2`.
    name: a name for the op.
    optimize: accepted for v1 API compatibility; ignored.

  Returns:
    The result of `ones_like_v2(tensor, dtype=dtype, name=name)`.
  """
  del optimize  # Not used by the StructuredTensor implementation.
  return ones_like_v2(tensor, dtype=dtype, name=name)

281 

282 

# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.ones_like_v2, StructuredTensor)
def ones_like_v2(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  """Replace every object with a one.

  The result ignores the fields of `input` and only mirrors its shape:
    * rank 0: a scalar one.
    * rank 1: a dense vector of `nrows()` ones.
    * rank >= 2: a RaggedTensor reusing `input.row_partitions`.

  Example:
  >>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
  >>> tf.ones_like(st)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
  >>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
  >>> tf.ones_like(st, dtype=tf.int32)
  <tf.RaggedTensor [[1], [1, 1]]>

  Args:
    input: a structured tensor.
    dtype: the dtype of the resulting ones. (default is tf.float32)
    name: a name for the op.

  Returns:
    a tensor (a RaggedTensor for rank >= 2) of ones of the same shape.
  """
  if dtype is None:
    dtype = dtypes.float32
  with ops.name_scope(name, 'ones_like', [input]):
    if not input.row_partitions:
      if input.nrows() is not None:
        return array_ops.ones([input.nrows()], dtype)  # vector.
      else:
        return array_ops.ones([], dtype)  # scalar.
    # 2D and up: the ones are the flat values of a RaggedTensor that shares
    # the input's row partitions; the innermost partition's nvals() is the
    # number of flat values needed.
    last_row_partition = input.row_partitions[-1]
    return ragged_tensor.RaggedTensor._from_nested_row_partitions(
        array_ops.ones(last_row_partition.nvals(), dtype=dtype),
        input.row_partitions)

318 

319 

@dispatch.dispatch_for_types(array_ops.rank, StructuredTensor)
def rank(input, name=None):  # pylint: disable=redefined-builtin
  """Returns the rank of a tensor.

  Args:
    input: a StructuredTensor.
    name: a name for the op.

  Returns:
    An int32 constant holding `input.rank`.
  """
  with ops.name_scope(name, 'rank', [input]):
    static_rank = input.rank
    return constant_op.constant(static_rank, dtype=dtypes.int32)

326 

327 

def _expand_dims_impl(st, axis, name=None):  # pylint: disable=redefined-builtin
  """Creates a StructuredTensor with a length 1 axis inserted at index `axis`.

  Shared implementation behind the v1 and v2 `tf.expand_dims` dispatchers.
  Every field is expanded at the same axis, and the shape, row partitions and
  nrows of the result are derived to match. The `axis` must be less than or
  equal to rank.

  >>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
  >>> tf.expand_dims(st, 0).to_pyval()
  [[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
  >>> tf.expand_dims(st, 1).to_pyval()
  [[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, 2).to_pyval()
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
  >>> tf.expand_dims(st, -1).to_pyval()  # -1 is the same as 2
  [[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]

  Args:
    st: the original StructuredTensor.
    axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
    name: the name of the op.

  Returns:
    a new structured tensor whose rank is one larger.

  Raises:
    an error (from `array_ops.get_positive_axis`) if `axis < -(rank + 1)` or
    `rank < axis`.
  """
  # The result has rank st.rank + 1, so the axis is resolved against that.
  axis = array_ops.get_positive_axis(
      axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)')
  with ops.name_scope(name, 'ExpandDims', [st, axis]):
    expanded_fields = {}
    for field_name, field_value in st._fields.items():
      expanded_fields[field_name] = array_ops.expand_dims(field_value, axis)
    old_shape = st.shape
    expanded_shape = old_shape[:axis] + (1,) + old_shape[axis:]
    expanded_partitions = _expand_st_row_partitions(st, axis)
    if axis > 0:
      expanded_nrows = st.nrows()
    else:
      # A new outermost dimension of length 1.
      expanded_nrows = 1
    return StructuredTensor.from_fields(
        expanded_fields,
        shape=expanded_shape,
        row_partitions=expanded_partitions,
        nrows=expanded_nrows)

369 

370 

def _expand_st_row_partitions(st, axis):
  """Create the row_partitions for expand_dims.

  Returns the row partitions of `st` after a dimension of size 1 has been
  inserted at `axis`.

  Args:
    st: the StructuredTensor being expanded.
    axis: the insertion axis. `_expand_dims_impl` resolves it with
      `get_positive_axis` before calling, so it is expected to satisfy
      `0 <= axis <= st.rank`.

  Returns:
    A tuple of RowPartitions for the expanded StructuredTensor.
  """
  if axis == 0:
    if st.shape.rank == 0:
      # A scalar becomes a vector, which needs no row partitions.
      return ()
    outer_size = st.nrows()
    # One new outer row whose uniform length covers every existing row.
    new_outer = RowPartition.from_uniform_row_length(
        outer_size, outer_size, nrows=1, validate=False)
    return (new_outer,) + st.row_partitions

  if axis == st.rank:
    # New innermost dimension: each existing innermost element becomes a
    # row of length 1.
    if axis >= 2:
      element_count = st.row_partitions[axis - 2].nvals()
    else:
      element_count = st.nrows()
    new_inner = RowPartition.from_uniform_row_length(
        1, element_count, nrows=element_count, validate=False)
    return st.row_partitions + (new_inner,)

  # Interior insertion: splice a length-1 partition in at index axis - 1.
  split_at = axis - 1
  if split_at >= 0:
    row_count = st.row_partitions[split_at].nrows()
  else:
    row_count = st.nrows()
  new_middle = RowPartition.from_uniform_row_length(
      1, row_count, nrows=row_count, validate=False)
  return (st.row_partitions[:split_at] + (new_middle,) +
          st.row_partitions[split_at:])

390 

391 

# TODO(martinz): consider allowing values to be nested.
def _extend_op(values, leaf_op, empty_st_op=None):
  """Extend an op from RaggedTensor and Tensor to StructuredTensor.

  Walks the fields of the structured tensors recursively. `leaf_op` is applied
  to the list of corresponding values whenever a non-structured field is
  reached. `empty_st_op` builds the result for a structured node; its shape is
  also used for nodes that do have fields.

  Args:
    values: a list of structured tensors, ragged tensors, or tensors. All must
      have the same type. If they are structured tensors, they must have the
      same paths.
    leaf_op: an op for handling a list of non-structured tensors.
    empty_st_op: op to create a structured tensor without fields. Defaults to
      `empty_st_op_like_zeros(leaf_op)`.

  Returns:
    the result of the extended op (a StructuredTensor, RaggedTensor, or Tensor)

  Raises:
    ValueError:
      If values is not a Sequence or is empty.
  """
  if not isinstance(values, Sequence):
    raise ValueError('Expected a list')
  if not values:
    raise ValueError('List cannot be empty')
  if empty_st_op is None:
    empty_st_op = empty_st_op_like_zeros(leaf_op)

  # The first element's structure is used as the template; all values are
  # assumed to share it.
  template = values[0]
  if not isinstance(template, StructuredTensor):
    return leaf_op(values)

  # TODO(martinz): Calling empty_st_op may add unnecessary ops. Revisit later.
  empty_result = empty_st_op(values)
  field_names = template.field_names()
  if not field_names:
    return empty_result
  new_fields = {
      field: _extend_op([v.field_value(field) for v in values], leaf_op,
                        empty_st_op)
      for field in field_names
  }
  return StructuredTensor.from_fields(new_fields, shape=empty_result.shape)

438 

439 

def _extend_op_single(value, leaf_op, empty_st_op=None):
  """Extend an op to a value instead of a list of values.

  `_extend_op` calls its ops with a list of values. Each single-value op is
  wrapped here so it receives the sole element of that list.

  Args:
    value: a StructuredTensor, RaggedTensor, or Tensor.
    leaf_op: a one-argument op for non-structured values.
    empty_st_op: optional one-argument op that creates a field-less
      StructuredTensor.

  Returns:
    The result of `_extend_op([value], ...)` with the wrapped ops.
  """

  def wrap_single(element_op):
    """Adapts a one-argument op to accept a one-element list (None stays None)."""
    if element_op is None:
      return None

    def unwrap_and_call(values):
      (only_value,) = values
      return element_op(only_value)

    return unwrap_and_call

  wrapped_leaf_op = wrap_single(leaf_op)
  wrapped_empty_st_op = wrap_single(empty_st_op)
  return _extend_op([value], wrapped_leaf_op, wrapped_empty_st_op)

454 

455 

def empty_st_op_like_zeros(leaf_op):
  """Returns an `empty_st_op` that derives its shape from `leaf_op` on zeros.

  The returned op replaces each value with `zeros_like_v2(value,
  dtype=int32)`, applies `leaf_op` to that list, and wraps the result's shape
  in a field-less StructuredTensor via `_structured_tensor_like`.

  Args:
    leaf_op: an op taking a list of tensors/ragged tensors.

  Returns:
    A function mapping a list of values to a field-less StructuredTensor.
  """

  def shape_from_zeros(values):
    zero_values = [zeros_like_v2(v, dtype=dtypes.int32) for v in values]
    return _structured_tensor_like(leaf_op(zero_values))

  return shape_from_zeros

466 

467 

def _structured_tensor_from_dense_tensor(t):
  """Create a structured tensor with the shape of a dense tensor.

  Args:
    t: a dense `Tensor` whose rank must be known.

  Returns:
    A StructuredTensor without fields whose shape matches `t`.

  Raises:
    ValueError: if the rank of `t` is unknown.
  """
  # Note: If a tensor will have rank 0,
  # it either has a fully defined shape or has unknown rank.
  if t.shape.is_fully_defined():
    return StructuredTensor.from_fields({}, shape=t.shape)
  rank = t.shape.rank
  if rank is None:
    raise ValueError("Can't build StructuredTensor w/ unknown rank")
  if rank == 1:
    return StructuredTensor.from_fields({}, shape=t.shape,
                                        nrows=array_ops.shape(t)[0])
  # A StructuredTensor keeps one RowPartition per dimension after the first
  # (compare `size`, which reads the total element count from
  # row_partitions[-1]). The default ragged_rank=1 of from_tensor would only
  # partition the outermost dimension, so request rank - 1 partitions.
  rt = ragged_tensor.RaggedTensor.from_tensor(t, ragged_rank=rank - 1)
  return _structured_tensor_from_row_partitions(t.shape,
                                                rt._nested_row_partitions)

483 

484 

def _structured_tensor_from_row_partitions(shape, row_partitions):
  """Returns a field-less StructuredTensor with the given shape and partitions.

  Args:
    shape: the static shape of the result.
    row_partitions: the RowPartitions describing the result's dimensions.

  Returns:
    A StructuredTensor with no fields.
  """
  no_fields = {}
  return StructuredTensor.from_fields(
      no_fields, shape=shape, row_partitions=row_partitions)

489 

490 

# pylint: disable=protected-access
def _all_nested_row_partitions(rt):
  """Returns all nested row partitions in rt, including for dense dimensions.

  For a RaggedTensor, its own nested row partitions are followed by the
  partitions of its (dense) flat_values.

  Args:
    rt: a `Tensor` or `RaggedTensor`. Any dense Tensor reached (rt itself or
      its flat_values) must have known rank.

  Returns:
    A tuple of RowPartitions, one per dimension of `rt` after the first
    (empty when `rt` is a Tensor of rank <= 1).

  Raises:
    ValueError: if a dense Tensor of unknown rank is encountered.
  """
  if isinstance(rt, ops.Tensor):
    rank = rt.shape.rank
    if rank is None:
      raise ValueError(
          "Can't get row partitions of a Tensor with unknown rank")
    if rank <= 1:
      return ()
    # The default ragged_rank=1 would only partition the outermost dimension
    # and drop the remaining dense dimensions; partition all of them.
    rt2 = ragged_tensor.RaggedTensor.from_tensor(rt, ragged_rank=rank - 1)
    return rt2._nested_row_partitions
  tail_partitions = _all_nested_row_partitions(rt.flat_values)
  head_partitions = rt._nested_row_partitions
  return head_partitions + tail_partitions

504 

505 

def _structured_tensor_like(t):
  """Create a StructuredTensor with the shape of a (composite) tensor.

  Args:
    t: a Tensor, RaggedTensor, or StructuredTensor.

  Returns:
    A StructuredTensor without fields that has the same shape as `t`.
  """
  if isinstance(t, ops.Tensor):
    return _structured_tensor_from_dense_tensor(t)
  if not ragged_tensor.is_ragged(t):
    # Neither dense nor ragged: `t` is treated as a StructuredTensor.
    return StructuredTensor.from_fields(
        {}, shape=t.shape, row_partitions=t.row_partitions, nrows=t.nrows())
  ragged_partitions = None
  ragged_shape = t.get_shape()
  ragged_partitions = _all_nested_row_partitions(t)
  return StructuredTensor.from_fields(
      {}, shape=ragged_shape, row_partitions=ragged_partitions)

518 

519 

def _get_all_paths(st):
  """Get all the paths from a StructuredTensor.

  A path is a tuple of field names. The root path `()` is always included,
  and so is every path that leads to a nested StructuredTensor.

  Args:
    st: a StructuredTensor.

  Returns:
    A set of tuples of field names.
  """
  paths = {()}
  for field_name in st.field_names():
    child = st.field_value(field_name)
    if not isinstance(child, StructuredTensor):
      paths.add((field_name,))
      continue
    child_paths = {(field_name,) + sub for sub in _get_all_paths(child)}
    paths = paths | child_paths
  return paths

531 

532 

def _get_all_ranks(st):
  """Get ranks of all submessages of a StructuredTensor.

  Args:
    st: a StructuredTensor.

  Returns:
    A dict mapping each path to a (nested) StructuredTensor, including the
    root path `()`, to that StructuredTensor's rank. Non-structured fields
    are not included.
  """
  ranks = {(): st.rank}
  for field_name in st.field_names():
    child = st.field_value(field_name)
    if not isinstance(child, StructuredTensor):
      continue
    for sub_path, sub_rank in _get_all_ranks(child).items():
      ranks[(field_name,) + sub_path] = sub_rank
  return ranks

543 

544 

def _assert_all_paths_match(values):
  """Raises an error if the paths are not identical.

  Args:
    values: a sequence of StructuredTensors.

  Raises:
    ValueError: listing every path that is present in some, but not all, of
      the structured tensors (compared against the first one).
  """
  all_paths = [_get_all_paths(st) for st in values]
  mismatched = set()
  for paths in all_paths[1:]:
    mismatched = mismatched | (all_paths[0] ^ paths)
  if mismatched:
    raise ValueError(
        'Some paths are present in some, but not all, structured tensors: %r' %
        (mismatched,))

555 

556 

def _assert_all_ranks_match(values):
  """Raises an error if the ranks of submessages are not identical.

  Args:
    values: a sequence of StructuredTensors.

  Raises:
    ValueError: if any structured tensor's path-to-rank map differs from the
      first one's.
  """
  rank_maps = [_get_all_ranks(st) for st in values]
  # TODO(martinz): If this becomes common, we can provide more detail.
  # e.g.: which path is inconsistent.
  if any(other != rank_maps[0] for other in rank_maps[1:]):
    raise ValueError('Ranks of sub-message do not match')

565 

566 

def _assert_concat_compatible_structured_tensors(values):
  """Statically checks that concat makes sense on `values`, raising if not.

  `values` must be a non-empty sequence of StructuredTensors that all have the
  same paths, and each submessage path must have the same rank in every one
  of them.

  These constraints make concatenating the fields equivalent to
  concatenating the structured tensors. They catch cases that the recursive
  algorithm would otherwise mishandle silently:
    * a path missing from the first structured tensor but present in a later
      one would simply be ignored;
    * a submessage whose rank differs between structured tensors would be a
      nonsensical merge.

  All of these checks are static, because paths and submessage ranks are
  known at graph-construction time.

  Args:
    values: a Sequence of StructuredTensors.

  Raises:
    ValueError: if there is any inconsistency as described above.
  """
  if not isinstance(values, Sequence):
    raise ValueError('values must be a list of StructuredTensors (not a list)')
  if not values:
    raise ValueError('values must not be an empty list')
  if not all(isinstance(st, StructuredTensor) for st in values):
    raise ValueError('values must be a list of StructuredTensors')
  _assert_all_paths_match(values)
  _assert_all_ranks_match(values)