Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/structured/structured_tensor.py: 19%

648 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Structured Tensors.""" 

16 

17import re 

18from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union 

19 

20import numpy as np 

21 

22from tensorflow.python.framework import constant_op 

23from tensorflow.python.framework import dtypes 

24from tensorflow.python.framework import extension_type 

25from tensorflow.python.framework import ops 

26from tensorflow.python.framework import tensor_shape 

27from tensorflow.python.framework import tensor_spec 

28from tensorflow.python.framework import type_spec 

29from tensorflow.python.ops import array_ops 

30from tensorflow.python.ops import check_ops 

31from tensorflow.python.ops import control_flow_ops 

32from tensorflow.python.ops import math_ops 

33from tensorflow.python.ops.ragged import dynamic_ragged_shape 

34from tensorflow.python.ops.ragged import ragged_factory_ops 

35from tensorflow.python.ops.ragged import ragged_tensor 

36from tensorflow.python.ops.ragged.row_partition import RowPartition 

37from tensorflow.python.util import compat 

38from tensorflow.python.util import nest 

39from tensorflow.python.util.tf_export import tf_export 

40 

# Types a single field of a StructuredTensor may hold. `StructuredTensor` is
# named as a string forward reference because the class is defined below.
_FieldValue = Union[
    ops.Tensor,
    ragged_tensor.RaggedTensor,
    'StructuredTensor',
    extension_type.ExtensionType,
]

# A per-field transformation: receives the current `_FieldValue` and returns
# the `_FieldValue` that should replace it.
_FieldFn = Callable[[_FieldValue], _FieldValue]

47 

48 

49@tf_export('experimental.StructuredTensor') 

50class StructuredTensor(extension_type.BatchableExtensionType): 

51 """A multidimensional collection of structures with the same schema. 

52 

53 A **`StructuredTensor`** is a multi-dimensional collection of ***structures*** 

54 with the same ***schema***, where: 

55 

56 * A ***schema*** is a collection of fields, each of which has a name and type. 

57 * A ***structure*** maps each field in the schema to a tensor value (which 

58 could be a nested StructuredTensor). 

59 

60 As an important special case, a 1D `StructuredTensor` encodes a 2D table, 

61 where columns are heterogeneous `Tensor`s, and rows are the aligned elements 

62 in each of those `Tensor`s. 

63 

64 Internally, StructuredTensors use a "field-major" encoding: for each leaf 

65 field, there is a single tensor that stores the value of that field for all 

66 structures in the `StructuredTensor`. 

67 

68 ### Examples 

69 

70 >>> # A scalar StructuredTensor describing a single person. 

71 >>> s1 = tf.experimental.StructuredTensor.from_pyval( 

72 ... {"age": 82, "nicknames": ["Bob", "Bobby"]}) 

73 >>> s1.shape 

74 TensorShape([]) 

75 >>> s1["age"] 

76 <tf.Tensor: shape=(), dtype=int32, numpy=82> 

77 

78 >>> # A vector StructuredTensor describing three people. 

79 >>> s2 = tf.experimental.StructuredTensor.from_pyval([ 

80 ... {"age": 12, "nicknames": ["Josaphine"]}, 

81 ... {"age": 82, "nicknames": ["Bob", "Bobby"]}, 

82 ... {"age": 42, "nicknames": ["Elmo"]}]) 

83 >>> s2.shape 

84 TensorShape([3]) 

85 >>> s2[0]["age"] 

86 <tf.Tensor: shape=(), dtype=int32, numpy=12> 

87 

88 

89 ### Field Paths 

90 

91 A *field path* is a tuple of field names, specifying the path to a nested 

92 field. 

93 """ 

  # Annotated attributes of this extension type.
  # NOTE(review): the extension-type machinery presumably derives the type's
  # fields from these annotations, so keep their names and order unchanged.
  _fields: Mapping[str, _FieldValue]  # field name -> field value
  _ragged_shape: dynamic_ragged_shape.DynamicRaggedShape  # shape of this tensor

  # String stored as a class attribute named `__name__`.
  # NOTE(review): presumably the public display name ('tf.StructuredTensor');
  # confirm where it is consumed.
  __name__ = 'tf.StructuredTensor'
  #=============================================================================
  # Common Types
  #=============================================================================
  # pylint: disable=invalid-name
  # Field names work as key, and they can be a sequence to refer to the
  # sub-levels (embedded) StructuredTensor's.
  FieldName = Union[str, Sequence[str]]

  # pylint: enable=invalid-name

107 

108 #============================================================================= 

109 # Constructor & Factory Methods 

110 #============================================================================= 

111 def __init__(self, fields: Mapping[str, _FieldValue], 

112 ragged_shape: dynamic_ragged_shape.DynamicRaggedShape): 

113 self._fields = fields 

114 self._ragged_shape = ragged_shape 

115 

116 @classmethod 

117 def _old_init(cls, fields, shape, nrows, row_partitions, internal=False): 

118 """Private constructor -- use factory methods to create StructuredTensors. 

119 

120 This constructor builds a `StructuredTensor` from the given attributes, 

121 performing minimal validation. 

122 

123 Args: 

124 fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or 

125 `StructuredTensor`. (This dict is not copied, so the caller must ensure 

126 that it does not get mutated via leaked references.) 

127 shape: `tf.TensorShape` with statically known rank. 

128 nrows: scalar integer `tf.Tensor`, or `None` if `shape.rank==0`. 

129 row_partitions: tuple of `RowPartition`s, with length `shape.rank-1`. 

130 internal: ignored argument. 

131 

132 Returns: 

133 a StructuredTensor. 

134 """ 

135 assert isinstance(fields, dict), fields 

136 assert isinstance(shape, tensor_shape.TensorShape), shape 

137 assert nrows is None or isinstance(nrows, ops.Tensor), nrows 

138 assert row_partitions is None or isinstance(row_partitions, 

139 tuple), row_partitions 

140 return StructuredTensor( 

141 fields=fields, 

142 ragged_shape=_dynamic_ragged_shape_init(fields, shape, nrows, 

143 row_partitions)) 

144 

145 @classmethod 

146 def from_shape( 

147 cls, ragged_shape: dynamic_ragged_shape.DynamicRaggedShape 

148 ) -> 'StructuredTensor': 

149 """Creates a `StructuredTensor` with no fields and ragged_shape. 

150 

151 Args: 

152 ragged_shape: the shape of the structured tensor. 

153 

154 Returns: 

155 a StructuredTensor with no fields and ragged_shape. 

156 """ 

157 return StructuredTensor(fields={}, ragged_shape=ragged_shape) 

158 

159 @classmethod 

160 def from_fields(cls, 

161 fields, 

162 shape=(), 

163 nrows=None, 

164 row_partitions=None, 

165 validate=False): 

166 """Creates a `StructuredTensor` from a dictionary of fields. 

167 

168 Args: 

169 fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or 

170 `StructuredTensor`, providing the values for individual fields in each 

171 structure. If `shape.rank > 0`, then every tensor in `fields` must have 

172 the same shape in the first `shape.rank` dimensions; and that shape must 

173 be compatible with `shape`; and `result[i1...iN][key] = 

174 fields[key][i1...iN]` (where `N==shape.rank`). 

175 shape: A `TensorShape`: static information about the shape of the 

176 `StructuredTensor`. Must have a known `rank`. Defaults to scalar shape 

177 (i.e. `rank=0`). 

178 nrows: scalar integer tensor containing the number of rows in this 

179 `StructuredTensor`. Should only be specified if `shape.rank > 0`. 

180 Default value is inferred from the `fields` values. If `fields` is 

181 empty, then this must be specified. 

182 row_partitions: A list of `RowPartition`s describing the (possibly ragged) 

183 shape of this `StructuredTensor`. Should only be specified if 

184 `shape.rank > 1`. Default value is inferred from the `fields` values. 

185 If `fields` is empty, then this must be specified. 

186 validate: If true, then add runtime validation ops that check that the 

187 field values all have compatible shapes in the outer `shape.rank` 

188 dimensions. 

189 

190 Returns: 

191 A `StructuredTensor`. 

192 

193 Examples: 

194 

195 >>> tf.experimental.StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]}) 

196 <StructuredTensor( 

197 fields={ 

198 "x": tf.Tensor(1, shape=(), dtype=int32), 

199 "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)}, 

200 shape=())> 

201 

202 >>> tf.experimental.StructuredTensor.from_fields( 

203 ... {'foo': [1, 2], 'bar': [3, 4]}, shape=[2]) 

204 <StructuredTensor( 

205 fields={ 

206 "bar": tf.Tensor([3 4], shape=(2,), dtype=int32), 

207 "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)}, 

208 shape=(2,))> 

209 """ 

210 shape = tensor_shape.as_shape(shape) 

211 rank = shape.rank 

212 if rank is None: 

213 raise ValueError("StructuredTensor's shape must have known rank.") 

214 if not isinstance(fields, dict): 

215 raise TypeError('fields must be a dictionary, got %s' % 

216 type(fields).__name__) 

217 if rank < 2 and row_partitions: 

218 raise ValueError('row_partitions must be None or [] if shape.rank<2') 

219 if rank == 0 and nrows is not None: 

220 raise ValueError('nrows must be None if shape.rank==0') 

221 if row_partitions is not None: 

222 row_partitions = tuple(row_partitions) 

223 if len(row_partitions) != max(0, rank - 1): 

224 raise ValueError('len(row_partitions) must be shape.rank-1') 

225 elif rank < 2: 

226 row_partitions = () 

227 

228 fields = dict(fields) # Make a private copy. 

229 with ops.name_scope(None, 'StructuredTensor', fields.values()): 

230 # TODO(martinz): Make this have better errors. 

231 shape = _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions) 

232 

233 # TODO(martinz): This may not need to be done if all fields are dense. 

234 if shape.rank > 1: 

235 shape = shape._with_num_row_partitions(shape.rank - 1) 

236 

237 # Validate keys and convert field values to tensors. 

238 for key, value in fields.items(): 

239 if not isinstance(key, str): 

240 raise TypeError(f'Unexpected type for key in `fields`: {key}') 

241 if not _FIELD_NAME_RE.match(key): 

242 raise ValueError('Field name %r is not currently allowed.' % key) 

243 fields[key] = _convert_to_structured_field_value(value) 

244 

245 fields = dict([(k, _replace_row_partitions(v, row_partitions)) 

246 for (k, v) in fields.items()]) 

247 return cls(fields=fields, ragged_shape=shape) 

248 

249 @classmethod 

250 def from_fields_and_rank( 

251 cls, 

252 fields: Mapping[str, _FieldValue], 

253 rank: int, 

254 validate: bool = False, 

255 dtype: Optional[dtypes.DType] = None) -> 'StructuredTensor': 

256 """Creates a `StructuredTensor` from a nonempty dictionary of fields. 

257 

258 Note that if the shape dtype is not specified, the shape dtype will be 

259 inferred from any fields that have a shape dtype. If fields differ, then 

260 int64 will be preferred to int32, because coercing from int32 to int64 is 

261 safer than coercing from int64 to int32. 

262 

263 If there are no ragged fields, then it will be int64 by default, but this 

264 will be changed to int32 in the future. 

265 

266 Args: 

267 fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or 

268 `StructuredTensor`, providing the values for individual fields in each 

269 structure. If `rank > 0`, then every tensor in `fields` must have the 

270 same shape in the first `rank` dimensions. Cannot be empty. 

271 rank: The rank of the resulting structured tensor. 

272 validate: If true, then add runtime validation ops that check that the 

273 field values all have compatible shapes in the outer `rank` dimensions. 

274 dtype: If specified, then forces dtype of the shape to be this. 

275 

276 Returns: 

277 A `StructuredTensor`. 

278 Examples: 

279 >>> tf.experimental.StructuredTensor.from_fields_and_rank( 

280 ... {'x': 1, 'y': [1, 2, 3]}, 0) 

281 <StructuredTensor( 

282 fields={ 

283 "x": tf.Tensor(1, shape=(), dtype=int32), 

284 "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)}, 

285 shape=())> 

286 >>> StructuredTensor.from_fields_and_rank({'foo': [1, 2], 'bar': [3, 4]}, 

287 ... 1) 

288 <StructuredTensor( 

289 fields={ 

290 "bar": tf.Tensor([3 4], shape=(2,), dtype=int32), 

291 "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)}, 

292 shape=(2,))> 

293 """ 

294 if not fields: 

295 raise ValueError('Must provide at least one field') 

296 if not isinstance(rank, int): 

297 raise ValueError('rank must be an integer') 

298 if rank < 0: 

299 raise ValueError('rank must be nonnegative') 

300 fields = { 

301 k: _convert_to_structured_field_value(v) for (k, v) in fields.items() 

302 } 

303 if dtype is None: 

304 dtype = _find_shape_dtype(fields, None, None) 

305 fields = _fields_with_dtype(fields, dtype) 

306 

307 shape = _shape_from_fields(fields, rank, dtype) 

308 if rank > 1: 

309 shape = shape._with_num_row_partitions(rank - 1) 

310 new_rp = shape._row_partitions # pylint: disable=protected-access 

311 fields = { 

312 k: _replace_row_partitions(v, new_rp) for (k, v) in fields.items() 

313 } 

314 return StructuredTensor(fields=fields, ragged_shape=shape) 

315 

316 def with_updates(self, 

317 updates: Dict[FieldName, Union[_FieldValue, _FieldFn, None]], 

318 validate: bool = False) -> 'StructuredTensor': 

319 """Creates a new `StructuredTensor` with the updated fields. 

320 

321 If this `StructuredTensor` is a scalar, and `k` is the `FieldName` being 

322 updated and `v` the new value, then: 

323 

324 ``` 

325 result[k] = v # If (k, v) is in updates and v is a FieldValue 

326 result[k] = f(self[k]) # If (k, f) is in updates and f is a FieldFn 

327 result[k] = self[k] # If k is in self.field_names but not in updates 

328 ``` 

329 

330 If this `StructuredTensor` has rank `N` and shape `[D1...DN]`, then each 

331 FieldValue `v` in `updates` must have shape `[D1...DN, ...]`, that is, 

332 prefixed with the same shape as the `StructuredTensor`. Then the resulting 

333 `StructuredTensor` will have: 

334 

335 ``` 

336 result[i1...iN][k] = v[i1...iN] # (k, v) in updates 

337 result[i1...iN][k] = f(self.field_value(k))[i1...iN] # (k, f) in updates 

338 result[i1...iN][k] = self[i1...iN][k] # k not in updates 

339 ``` 

340 

341 Note that `result.shape` is always equal to `self.shape` (but the shapes 

342 of nested StructuredTensors may be changed if they are updated with new 

343 values). 

344 

345 Args: 

346 updates: A dictionary mapping `FieldName` to either a `FieldValue` to be 

347 used to update, or a `FieldFn` that will transform the value for the 

348 given `FieldName`. `FieldName` can be a string for a direct field, or a 

349 sequence of strings to refer to a nested sub-field. `FieldFn` is a 

350 function that takes a `FieldValue` as input and should return a 

351 `FieldValue`. All other fields are copied over to the new 

352 `StructuredTensor`. New `FieldName` can be given (to add new fields), 

353 but only to existing `StructuredTensor`, it won't automatically create 

354 new nested structures -- but one can create a whole `StructureTensor` 

355 sub-structure and set that into an existing structure. If the new value 

356 is set to `None`, it is removed. 

357 validate: If true, then add runtime validation ops that check that the 

358 field values all have compatible shapes in the outer `shape.rank` 

359 dimensions. 

360 

361 Returns: 

362 A `StructuredTensor`. 

363 

364 Raises: 

365 `ValueError`: If the any of the `FieldName` keys points to non-existent 

366 sub-structures, if parent and child nodes are updated, if shapes 

367 change, if a delete update is given for a non-existent field, or if a 

368 `FieldFn` transforming function is given for a `FieldName` that doesn't 

369 yet exist. 

370 

371 Examples: 

372 

373 >>> shoes_us = tf.experimental.StructuredTensor.from_pyval([ 

374 ... {"age": 12, "nicknames": ["Josaphine"], 

375 ... "shoes": {"sizes": [8.0, 7.5, 7.5]}}, 

376 ... {"age": 82, "nicknames": ["Bob", "Bobby"], 

377 ... "shoes": {"sizes": [11.0, 11.5, 12.0]}}, 

378 ... {"age": 42, "nicknames": ["Elmo"], 

379 ... "shoes": {"sizes": [9.0, 9.5, 10.0]}}]) 

380 >>> def us_to_europe(t): 

381 ... return tf.round(t * 2.54 + 17.0) # Rough approximation. 

382 >>> shoe_sizes_key = ("shoes", "sizes") 

383 >>> shoes_eu = shoes_us.with_updates({shoe_sizes_key: us_to_europe}) 

384 >>> shoes_eu.field_value(shoe_sizes_key) 

385 <tf.RaggedTensor [[37.0, 36.0, 36.0], [45.0, 46.0, 47.0], 

386 [40.0, 41.0, 42.0]]> 

387 """ 

388 updates_items = [(_normalize_field_name_to_tuple(name), value) 

389 for name, value in updates.items()] 

390 

391 # Sort by keys and check for updates of both parent and child nodes. 

392 updates_items = sorted(updates_items) 

393 for i in range(1, len(updates_items)): 

394 # Parent of a node would precede node in the sorted order. 

395 name = updates_items[i][0] # item[0] is the name, item[1] is the value. 

396 prev_name = updates_items[i - 1][0] 

397 if name[:len(prev_name)] == prev_name: 

398 raise ValueError( 

399 '`StructuredTensor.with_updates` does not allow both parent and ' 

400 'child nodes to be updated: parent={}, child={}. If needed you can ' 

401 'update child nodes in the parent update value.'.format( 

402 prev_name, name)) 

403 return self._with_updates_impl((), updates_items, validate) 

404 

405 def _with_updates_impl(self, error_prefix: Tuple[str, ...], 

406 updates: List[Tuple[FieldName, Union[_FieldValue, 

407 _FieldFn]]], 

408 validate: bool) -> 'StructuredTensor': 

409 """Recursive part of `with_updates` implementation.""" 

410 # Get current fields. 

411 new_fields = dict(self._fields) 

412 

413 # Convert field name to string with full path for error messages. 

414 def name_fullpath(name: Sequence[str]) -> str: 

415 return str(error_prefix + (name,)) 

416 

417 # Apply value if a function or the value itself. 

418 def apply_value(name: str, value: Union[_FieldValue, 

419 _FieldFn]) -> _FieldValue: 

420 if callable(value): 

421 # `value` is actually a transforming function. 

422 if name not in new_fields: 

423 raise ValueError( 

424 '`StructuredTensor.with_updates` cannot update the field {} ' 

425 'because a transforming function was given, but that field ' 

426 'does not already exist.'.format(name_fullpath(name))) 

427 value = value(new_fields[name]) 

428 return value 

429 

430 # Merge updates. 

431 for name, value in updates: 

432 if not name or not name[0]: 

433 raise ValueError( 

434 '`StructuredTensor.with_updates` does not allow empty names ' 

435 '{}.'.format(name_fullpath(name))) 

436 

437 if len(name) == 1: 

438 name = name[0] 

439 if value is None: 

440 if name not in new_fields: 

441 raise ValueError( 

442 '`StructuredTensor.with_updates` cannot delete field ' 

443 '{} because it is not present.'.format(name_fullpath(name))) 

444 new_fields.pop(name) 

445 else: 

446 new_fields[name] = apply_value(name, value) 

447 else: 

448 # Recursive 

449 prefix = name[0] 

450 suffix = name[1:] 

451 if prefix not in new_fields: 

452 raise ValueError( 

453 '`StructuredTensor.with_updates` cannot create new sub-field ' 

454 '{} if parent field {} is not set.'.format( 

455 error_prefix + tuple(name), name_fullpath(prefix))) 

456 current_value = new_fields[prefix] 

457 if not isinstance(current_value, StructuredTensor): 

458 raise ValueError( 

459 '`StructuredTensor.with_updates` cannot create new sub-field ' 

460 '{} if parent structure {} is not a `StructuredTensor` that ' 

461 'can contain sub-structures -- it is a `{}`.'.format( 

462 error_prefix + tuple(name), name_fullpath(prefix), 

463 type(current_value))) 

464 one_update = [(suffix, value)] 

465 

466 # Accessing protected member in recursion. 

467 # FutureWork: optimize by aggregating the recursions, instead of 

468 # calling one at a time. 

469 # pylint: disable=protected-access 

470 value = current_value._with_updates_impl(error_prefix + (prefix,), 

471 one_update, validate) 

472 # pylint: enable=protected-access 

473 new_fields[prefix] = value 

474 

475 # TODO(edloper): When validate=True, only validate the modified fields. 

476 try: 

477 return StructuredTensor.from_fields( 

478 new_fields, 

479 shape=self.shape, 

480 row_partitions=self.row_partitions, 

481 nrows=self.nrows(), 

482 validate=validate) 

483 

484 except ValueError as e: 

485 msg = '`StructuredTensor.with_updates` failed' 

486 if error_prefix: 

487 msg = '{} for field {}'.format(msg, error_prefix) 

488 raise ValueError(msg) from e 

489 

490 def _promote_helper(self, source_path, new_parent_path): 

491 """Creates a promoted field without adding it to the structure. 

492 

493 Args: 

494 source_path: the source path in the structured tensor. 

495 new_parent_path: the new parent path. Must be a prefix of source_path. 

496 

497 Returns: 

498 a composite tensor of source_path promoted. 

499 Raises: 

500 ValueError: if the shape of the field is unknown and the right strategy 

501 cannot be determined. 

502 """ 

503 current_field = self.field_value(source_path) 

504 new_parent_rank = self.field_value(new_parent_path).rank 

505 parent_rank = self.field_value(source_path[:-1]).rank 

506 if new_parent_rank == parent_rank: 

507 return current_field 

508 current_field_rank = current_field.shape.rank 

509 if current_field_rank is None: 

510 raise ValueError('Cannot determine if dimensions should be merged.') 

511 inner_dim = min(parent_rank, current_field_rank - 1) 

512 if inner_dim <= new_parent_rank: 

513 return current_field 

514 return _merge_dims_generic(current_field, new_parent_rank, inner_dim) 

515 

516 def promote(self, source_path, new_name): 

517 """Promotes a field, merging dimensions between grandparent and parent. 

518 

519 >>> d = [ 

520 ... {'docs': [{'tokens':[1, 2]}, {'tokens':[3]}]}, 

521 ... {'docs': [{'tokens':[7]}]}] 

522 >>> st = tf.experimental.StructuredTensor.from_pyval(d) 

523 >>> st2 =st.promote(('docs','tokens'), 'docs_tokens') 

524 >>> st2[0]['docs_tokens'] 

525 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)> 

526 >>> st2[1]['docs_tokens'] 

527 <tf.Tensor: shape=(1,), dtype=int32, numpy=array([7], dtype=int32)> 

528 

529 Args: 

530 source_path: the path of the field or substructure to promote; must have 

531 length at least 2. 

532 new_name: the name of the new field (must be a string). 

533 

534 Returns: 

535 a modified structured tensor with the new field as a child of the 

536 grandparent of the source_path. 

537 

538 Raises: 

539 ValueError: if source_path is not a list or a tuple or has a length 

540 less than two, or new_name is not a string, or the rank 

541 of source_path is unknown and it is needed. 

542 """ 

543 if not isinstance(new_name, str): 

544 raise ValueError('new_name is not a string') 

545 if not isinstance(source_path, (list, tuple)): 

546 raise ValueError('source_path must be a list or tuple') 

547 

548 if len(source_path) < 2: 

549 raise ValueError('source_path must have length at least two') 

550 

551 grandparent_path = source_path[:-2] 

552 new_field = self._promote_helper(source_path, grandparent_path) 

553 new_path = grandparent_path + (new_name,) 

554 return self.with_updates({new_path: new_field}) 

555 

556 #============================================================================= 

557 # Properties 

558 #============================================================================= 

559 

560 @property 

561 def rank(self): 

562 """The rank of this StructuredTensor. Guaranteed not to be `None`.""" 

563 return self._ragged_shape.rank 

564 

565 @property 

566 def shape(self): 

567 """The static shape of this StructuredTensor. 

568 

569 The returned `TensorShape` is guaranteed to have a known rank, but the 

570 individual dimension sizes may be unknown. 

571 

572 Returns: 

573 `tf.TensorShape` 

574 """ 

575 return self._ragged_shape._to_tensor_shape() # pylint: disable=protected-access 

576 

577 # TODO(martinz): for backwards compatibility 

578 @property 

579 def _row_partitions(self): 

580 """Deprecated form of row_partitions.""" 

581 return self.row_partitions 

582 

583 # TODO(edloper): Make this a func instead of a property? Or make nrows 

584 # a property instead of a func? Seems like these should be consistent. 

585 @property 

586 def row_partitions(self): 

587 """A tuple of `RowPartition`s defining the shape of this `StructuredTensor`. 

588 

589 When `self.rank <= 1`, this tuple will be empty. 

590 

591 When `self.rank > 1`, these `RowPartitions` define the shape of the 

592 `StructuredTensor` by describing how a flat (1D) list of structures can be 

593 repeatedly partitioned to form a higher-dimensional object. In particular, 

594 the flat list is first partitioned into sublists using `row_partitions[-1]`, 

595 and then those sublists are further partitioned using `row_partitions[-2]`, 

596 etc. The following examples show the row partitions used to describe 

597 several different `StructuredTensor`, each of which contains 8 copies of 

598 the same structure (`x`): 

599 

600 >>> x = {'a': 1, 'b': ['foo', 'bar', 'baz']} # shape = [] (scalar) 

601 

602 >>> s1 = [[x, x, x, x], [x, x, x, x]] # shape = [2, 4] 

603 >>> tf.experimental.StructuredTensor.from_pyval(s1).row_partitions 

604 (tf.RowPartition(row_splits=[0 4 8]),) 

605 

606 >>> s2 = [[x, x], [x, x], [x, x], [x, x]] # shape = [4, 2] 

607 >>> tf.experimental.StructuredTensor.from_pyval(s2).row_partitions 

608 (tf.RowPartition(row_splits=[0 2 4 6 8]),) 

609 

610 >>> s3 = [[x, x, x], [], [x, x, x, x], [x]] # shape = [2, None] 

611 >>> tf.experimental.StructuredTensor.from_pyval(s3).row_partitions 

612 (tf.RowPartition(row_splits=[0 3 3 7 8]),) 

613 

614 >>> s4 = [[[x, x], [x, x]], [[x, x], [x, x]]] # shape = [2, 2, 2] 

615 >>> tf.experimental.StructuredTensor.from_pyval(s4).row_partitions 

616 (tf.RowPartition(row_splits=[0 2 4]), 

617 tf.RowPartition(row_splits=[0 2 4 6 8])) 

618 

619 

620 >>> s5 = [[[x, x], [x]], [[x, x]], [[x, x], [x]]] # shape = [3, None, None] 

621 >>> tf.experimental.StructuredTensor.from_pyval(s5).row_partitions 

622 (tf.RowPartition(row_splits=[0 2 3 5]), 

623 tf.RowPartition(row_splits=[0 2 3 5 7 8])) 

624 

625 Note that shapes for nested fields (such as `x['b']` in the above example) 

626 are not considered part of the shape of a `StructuredTensor`, and are not 

627 included in `row_partitions`. 

628 

629 If this `StructuredTensor` has a ragged shape (i.e., if any of the 

630 `row_partitions` is not uniform in size), then all fields will be encoded 

631 as either `RaggedTensor`s or `StructuredTensor`s with these `RowPartition`s 

632 used to define their outermost `self.rank` dimensions. 

633 

634 Returns: 

635 A `tuple` of `RowPartition` objects with length `self.rank - 1` 

636 (or `0` if `self.rank < 2`) 

637 

638 """ 

639 if self.rank < 2: 

640 return () 

641 return self._ragged_shape._as_row_partitions() # pylint:disable=protected-access 

642 

643 def nrows(self): 

644 """The number of rows in this StructuredTensor (if rank>0). 

645 

646 This means the length of the outer-most dimension of the StructuredTensor. 

647 

648 Notice that if `self.rank > 1`, then this equals the number of rows 

649 of the first row partition. That is, 

650 `self.nrows() == self.row_partitions[0].nrows()`. 

651 

652 Otherwise `self.nrows()` will be the first dimension of the field values. 

653 

654 Returns: 

655 A scalar integer `Tensor` (or `None` if `self.rank == 0`). 

656 """ 

657 if self.rank == 0: 

658 return None 

659 return self._ragged_shape[0] 

660 

661 def with_shape_dtype(self, dtype: dtypes.DType) -> 'StructuredTensor': 

662 if dtype == self._ragged_shape.dtype: 

663 return self 

664 return StructuredTensor( 

665 fields=_fields_with_dtype(self._fields, dtype), 

666 ragged_shape=self._ragged_shape.with_dtype(dtype)) 

667 

668 def _is_eager(self): 

669 """True if all fields are composed of eager tensors.""" 

670 tensors = nest.flatten(self, expand_composites=True) 

671 return all(isinstance(t, ops.EagerTensor) for t in tensors) 

672 

673 #============================================================================= 

674 # Encoding 

675 #============================================================================= 

676 

677 def field_names(self): 

678 """Returns the string field names for this `StructuredTensor`.""" 

679 return tuple(self._fields.keys()) 

680 

681 def field_value(self, field_name): 

682 """Returns the tensor value for the specified field or path. 

683 

684 If `field_name` is a `string`, then it names a field directly owned by this 

685 `StructuredTensor`. If this `StructuredTensor` has shape `[D1...DN]`, then 

686 the returned tensor will have shape `[D1...DN, V1...VM]`, where the slice 

687 `result[d1...dN]` contains the field value for the structure at 

688 `self[d1...dN]`. 

689 

690 If `field_name` is a `tuple` of `string`, then it specifies a path to a 

691 field owned by nested `StructuredTensor`. In particular, 

692 `struct.field_value((f1, f2, ..., fN))` is equivalent to 

693 `struct.field_value(f1).field_value(f2)....field_value(fN)` 

694 

695 Args: 

696 field_name: `string` or `tuple` of `string`: The field whose values should 

697 be returned. 

698 

699 Returns: 

700 `Tensor`, `StructuredTensor`, or `RaggedTensor`. 

701 

702 Raises: 

703 KeyError: If the given field_name is not found. 

704 """ 

705 if isinstance(field_name, (list, tuple)): 

706 value = self 

707 for f in field_name: 

708 if not isinstance(value, StructuredTensor): 

709 raise KeyError('Field path {} not found in {}'.format( 

710 field_name, self)) 

711 value = value.field_value(f) 

712 return value 

713 return self._fields[field_name] 

714 

715 #============================================================================= 

716 # Operators 

717 #============================================================================= 

718 

719 # TODO(edloper): Add support for ellipsis and/or newaxis? 

def __getitem__(self, key):
  """Returns the specified piece of this StructuredTensor.

  * If `struct_tensor` is scalar (i.e., a single structure), then
    `struct_tensor[f]` returns the value of field `f` (where `f` must be a
    string).

  * If `struct_tensor` is non-scalar (i.e., a vector or higher-dimensional
    tensor of structures), `struct_tensor[i]` selects an element or slice of
    the tensor using standard Python semantics (e.g., negative values index
    from the end). `i` may have any of the following types:

    * `int` constant
    * `string` constant
    * scalar integer `Tensor`
    * `slice` containing integer constants and/or scalar integer
      `Tensor`s

  #### Multidimensional indexing

  `StructuredTensor` supports multidimensional indexing. I.e., `key` may be a
  `tuple` of values, indexing or slicing multiple dimensions at once. For
  example, if `people` is a vector of structures, each of which has a vector-
  valued `names` field, then `people[3, 'names', 0]` is equivalent to
  `people[3]['names'][0]`; and `people[:, 'names', :]` will return a (possibly
  ragged) matrix of names, with shape `[num_people, num_names_per_person]`.

  A `list` key is treated exactly like the equivalent `tuple`, and an empty
  key returns `self`.

  Args:
    key: Indicates which piece of the StructuredTensor to return.

  Returns:
    A `Tensor`, `StructuredTensor`, or `RaggedTensor`.
  """
  # Normalize every key form to a tuple so that single- and multi-dimensional
  # indexing go through the same code path.
  if isinstance(key, tuple):
    index = key
  elif isinstance(key, list):
    index = tuple(key)
  else:
    index = (key,)

  if not index:
    return self
  if self.rank != 0:
    return self._tensor_getitem(index)
  return self._scalar_getitem(index)

764 

def _scalar_getitem(self, key):
  """Indexes into a scalar (rank-0) StructuredTensor.

  Args:
    key: A non-empty tuple. `key[0]` must be a field name (string) or a full
      slice (`:`). The remaining elements are forwarded to the selected field
      value(s).

  Returns:
    For a field name, that field's value indexed by `key[1:]`. For a full
    slice, a new StructuredTensor (with this tensor's shape) whose fields have
    each been indexed by `key[1:]`.

  Raises:
    ValueError: If `key[0]` is neither a string nor a full slice.
  """
  head, rest = key[0], key[1:]
  is_full_slice = (isinstance(head, slice) and head.start is None and
                   head.stop is None and head.step is None)
  if is_full_slice:
    # Apply the remaining key to every field.
    sliced_fields = {
        name: value.__getitem__(rest) for name, value in self._fields.items()
    }
    return StructuredTensor.from_fields(sliced_fields, self.shape)

  if not isinstance(head, compat.bytes_or_text_types):
    raise ValueError('Key for indexing a StructuredTensor must be a '
                     "string or a full slice (':')")
  return self._fields[head].__getitem__(rest)

777 

def _tensor_getitem(self, key):
  """Indexes into a StructuredTensor whose rank is greater than zero.

  If `len(key) <= rank`, every element of `key` indexes one of the outer
  (shared) dimensions, and the key is applied to every field. Otherwise
  `key[rank]` must be a field name. That field is selected, and the other key
  elements (with the field name removed) are applied to its value.

  Args:
    key: A non-empty tuple of index values.

  Returns:
    A `StructuredTensor`, or the indexed value of a single field.

  Raises:
    ValueError: If `key` contains an unsupported index (e.g. `Ellipsis` or
      `tf.newaxis`) in an outer-dimension position, or if `key[rank]` is not
      a string.
  """
  rank = self.rank
  if len(key) > rank:
    field_name = key[rank]
    if not isinstance(field_name, compat.bytes_or_text_types):
      # TODO(edloper): Also support full slice here?
      raise ValueError('Key for indexing a StructuredTensor must be a string')
    return self._fields[field_name].__getitem__(key[:rank] + key[rank + 1:])

  # Validate the key and compute the static result shape *before* slicing any
  # fields, so unsupported keys fail fast without building field ops first.
  result_shape = self.shape.as_list()
  for d, k in enumerate(key):
    if isinstance(k, slice):
      if not (k.start is None and k.stop is None and k.step is None):
        # TODO(edloper): Better static shape analysis here.
        result_shape[d] = None
    elif isinstance(k, (int, ops.Tensor)):
      result_shape[d] = -1  # Integer index: this dimension is dropped.
    elif k is None:
      raise ValueError('Slicing not supported for tf.newaxis')
    else:
      # E.g. Ellipsis. Wrap `k` in a 1-tuple so that `%` formats it correctly
      # even when `k` is itself a tuple.
      raise ValueError('Slicing not supported for %r' % (k,))
  result_shape = [d for d in result_shape if d != -1]

  new_fields = {
      field_name: field_value.__getitem__(key)
      for field_name, field_value in self._fields.items()
  }
  return StructuredTensor.from_fields(new_fields, result_shape)

804 

def __repr__(self):
  """Returns a multi-line debug string listing the fields (sorted) and shape."""
  field_lines = []
  for name, value in sorted(self._fields.items()):
    # Indent continuation lines of multi-line field representations.
    value_str = str(value).replace('\n', '\n ')
    field_lines.append('"{}": {}'.format(name, value_str))
  dict_repr = ',\n '.join(field_lines)
  return ('<StructuredTensor(\n'
          ' fields={\n'
          ' %s},\n'
          ' shape=%s)>' % (dict_repr, self.shape))

814 

815 #============================================================================= 

816 # Conversion 

817 #============================================================================= 

818 

def to_pyval(self):
  """Returns this StructuredTensor as a nested Python dict or list of dicts.

  Converts this `StructuredTensor` to a nested python value:

  * `StructTensors` with `rank=0` are converted into a dictionary, with an
    entry for each field. Field names are used as keys and field values are
    converted to python values. In particular:

    * Scalar Tensor fields are converted to simple values (such as
      `int` or `float` or `string`)
    * Non-scalar Tensor fields and RaggedTensor fields are converted to
      nested lists of simple values.
    * StructuredTensor fields are converted recursively using `to_pyval`.

  * `StructTensors` with `rank>0` are converted to nested python `list`s,
    containing one dictionary for each structure (where each structure's
    dictionary is defined as described above).

  Requires that all fields are Eager tensors.

  >>> tf.experimental.StructuredTensor.from_fields(
  ... {'a': [1, 2, 3]}, [3]).to_pyval()
  [{'a': 1}, {'a': 2}, {'a': 3}]

  Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.

  Returns:
    A nested Python dict or list of dicts.

  Raises:
    ValueError: If this StructuredTensor is not eager.
  """
  if not self._is_eager():
    raise ValueError(
        'StructuredTensor.to_pyval() is only supported in eager mode.')

  # Field-major encoding: each field name maps to one python value that
  # covers every structure in this tensor.
  field_pyvals = {}
  for name, value in self._fields.items():
    if isinstance(value, ops.EagerTensor):
      value = value.numpy()
    # An EagerTensor converted above falls through into the ndarray case.
    if isinstance(value, np.ndarray):
      field_pyvals[name] = value.tolist()
    elif isinstance(value, ragged_tensor.RaggedTensor):
      field_pyvals[name] = value.to_list()
    elif isinstance(value, StructuredTensor):
      field_pyvals[name] = value.to_pyval()
    else:
      # TODO(edloper): Throw an exception if value is an unexpected type.
      field_pyvals[name] = value

  if len(self.shape) == 0:  # pylint: disable=g-explicit-length-test
    return field_pyvals
  if not field_pyvals:
    # No fields: rebuild the nested list of empty dicts from the row
    # partitions alone.
    return _empty_dict_pylist_from_row_partitions(self.row_partitions,
                                                  self.nrows())
  # Regroup from dict-of-lists to (nested) list-of-dicts.
  return _pyval_field_major_to_node_major(
      list(field_pyvals.keys()), list(field_pyvals.values()), self.rank)

876 

@classmethod
def from_pyval(cls, pyval, typespec=None):
  """Constructs a StructuredTensor from a nested Python structure.

  >>> tf.experimental.StructuredTensor.from_pyval(
  ... {'a': [1, 2, 3], 'b': [[4, 5], [6, 7]]})
  <StructuredTensor(
  fields={
  "a": tf.Tensor([1 2 3], shape=(3,), dtype=int32),
  "b": <tf.RaggedTensor [[4, 5], [6, 7]]>},
  shape=())>

  Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.

  Args:
    pyval: The nested Python structure that should be used to create the new
      `StructuredTensor`.
    typespec: A `StructuredTensor.Spec` specifying the expected type for each
      field. If not specified, then all nested dictionaries are turned into
      StructuredTensors, and all nested lists are turned into Tensors (if
      rank<2) or RaggedTensors (if rank>=2).

  Returns:
    A `StructuredTensor`.
  """
  # The conversion is recursive. Start at the root, with an empty field path
  # (the path is only used in error messages).
  root_path = ()
  return cls._from_pyval(pyval, typespec, root_path)

903 

@classmethod
def _from_pyval(cls, pyval, typespec, path_so_far):
  """Recursive helper for `from_pyval`; dispatches on the type of `pyval`.

  * `dict` -> `_from_pydict`
  * `list`/`tuple` that contains (nested) dicts -> `_from_pylist_of_dict`
  * any other `list`/`tuple` -> `_from_pylist_of_value`
  * anything else -> `_from_pyscalar`

  Args:
    pyval: The nested Python structure that should be converted.
    typespec: Optional spec describing the expected type of `pyval`. If not
      specified, then all nested dictionaries are turned into
      StructuredTensors, and all nested lists are turned into Tensors (if
      rank<2) or RaggedTensors (if rank>=2).
    path_so_far: the path of fields that led here (for error messages).

  Returns:
    A `StructuredTensor`, or (for non-dict values) whatever the selected
    helper returns, e.g. a `Tensor` or `RaggedTensor`.
  """
  if isinstance(pyval, dict):
    return cls._from_pydict(pyval, typespec, path_so_far)
  if not isinstance(pyval, (list, tuple)):
    return cls._from_pyscalar(pyval, typespec, path_so_far)

  struct_keys = set()
  dict_depth = _pyval_find_struct_keys_and_depth(pyval, struct_keys)
  if dict_depth is None:
    # No dictionaries anywhere inside the list.
    return cls._from_pylist_of_value(pyval, typespec, path_so_far)
  return cls._from_pylist_of_dict(pyval, struct_keys, dict_depth, typespec,
                                  path_so_far)

933 

@classmethod
def _from_pydict(cls, pyval, typespec, path_so_far):
  """Converts python dictionary `pyval` to a StructuredTensor with rank=0.

  Args:
    pyval: A python `dict` mapping field names to (nested) python values.
    typespec: Optional `StructuredTensor.Spec`. If given, it must have rank 0
      and exactly the same field names as `pyval`.
    path_so_far: The path of fields that led here (for error messages).

  Returns:
    A rank-0 `StructuredTensor`.

  Raises:
    ValueError: If `typespec` is given but is not a rank-0
      `StructuredTensor.Spec` whose field names equal the keys of `pyval`.
  """
  if typespec is None:
    fields = dict((k, cls._from_pyval(v, None, path_so_far + (k,)))
                  for (k, v) in pyval.items())
  else:
    # Check the spec's type *before* reading Spec-only attributes. Otherwise
    # a different spec type (e.g. a TensorSpec, which has no `_field_specs`)
    # raises AttributeError instead of the intended ValueError.
    if not isinstance(typespec, StructuredTensor.Spec):
      raise ValueError('Value at %r does not match typespec: %r vs %r' %
                       (path_so_far, pyval, typespec))
    spec_shape = typespec._shape  # pylint: disable=protected-access
    field_specs = typespec._field_specs  # pylint: disable=protected-access
    if not (spec_shape.rank == 0 and set(pyval) == set(field_specs)):
      raise ValueError('Value at %r does not match typespec: %r vs %r' %
                       (path_so_far, pyval, typespec))
    fields = dict((k, cls._from_pyval(v, field_specs[k], path_so_far + (k,)))
                  for (k, v) in pyval.items())
  return StructuredTensor.from_fields(fields=fields, shape=(), validate=False)

950 

@classmethod
def _from_pylist_of_dict(cls, pyval, keys, rank, typespec, path_so_far):
  """Converts python list `pyval` to a StructuredTensor with rank>=1.

  Args:
    pyval: A (nested) python list/tuple whose elements at nesting depth `rank`
      are dicts.
    keys: The set of field names found in those dicts.
    rank: The list-nesting depth at which the dicts appear.
    typespec: Optional `StructuredTensor.Spec` describing the expected type.
    path_so_far: The path of fields that led here (for error messages).

  Returns:
    A `StructuredTensor`.

  Raises:
    ValueError: If `typespec` does not match `pyval`, or if the converted
      fields cannot be combined into a `StructuredTensor`.
  """
  # Regroup into field-major form: each key maps to a nested list of values.
  fields = dict((key, []) for key in keys)
  for child in pyval:
    _pyval_update_fields(child, fields, 1)
  if typespec is None:
    shape = tensor_shape.TensorShape([None] * rank)
    for (key, target) in fields.items():
      fields[key] = cls._from_pyval(target, None, path_so_far + (key,))
  else:
    # Check the spec's type before reading Spec-only attributes, so that a
    # mismatched spec raises ValueError rather than AttributeError.
    if not isinstance(typespec, StructuredTensor.Spec):
      raise ValueError('Value at %r does not match typespec: %r vs %r' %
                       (path_so_far, pyval, typespec))
    field_specs = typespec._fields  # pylint: disable=protected-access
    if set(fields) - set(field_specs):
      raise ValueError('Value at %r does not match typespec: %r vs %r' %
                       (path_so_far, pyval, typespec))
    shape = typespec._shape  # pylint: disable=protected-access
    # An unknown spec rank cannot be compared here (`None < rank` would raise
    # TypeError).
    if shape.rank is not None and shape.rank < rank:
      raise ValueError('Value at %r does not match typespec (rank mismatch): '
                       '%r vs %r' % (path_so_far, pyval, typespec))
    for (key, spec) in field_specs.items():
      # Fields declared by the spec but absent from the data default to [].
      fields[key] = cls._from_pyval(
          fields.get(key, []), spec, path_so_far + (key,))
  try:
    if not fields and typespec is None:
      # TODO(b/183245576): handle cases where the typespec is known
      # but the dictionary is empty.
      return StructuredTensor._from_pylist_of_empty_dict(pyval, rank)
    return StructuredTensor.from_fields(
        fields=fields, shape=shape, validate=False)
  except Exception as exc:
    raise ValueError('Error parsing path %r' % (path_so_far,)) from exc

983 

@classmethod
def _from_pylist_of_empty_dict(cls, pyval, rank):
  """Converts a (nested) python list of empty dicts to a StructuredTensor.

  Args:
    pyval: A (nested) python list whose elements at nesting depth `rank` are
      empty dicts.
    rank: The nesting depth of the dictionaries.

  Returns:
    A `StructuredTensor` with no fields and rank `rank`. For a negative
    `rank`, no branch matches and `None` is returned.
  """
  if rank == 0:
    return StructuredTensor.from_fields(fields={}, shape=(), validate=False)
  if rank == 1:
    num_rows = len(pyval)
    return StructuredTensor.from_fields(
        fields={}, shape=(num_rows,), nrows=num_rows)
  if rank > 1:
    # Build a ragged tensor with the same nesting as `pyval`
    # (`_dicts_to_zeros` presumably replaces each dict with a zero) and reuse
    # its row partitions.
    ragged_zeros = ragged_factory_ops.constant(_dicts_to_zeros(pyval))
    num_rows = len(pyval)
    outer_shape = tensor_shape.TensorShape([num_rows] + [None] * (rank - 1))
    return StructuredTensor.from_fields(
        fields={},
        shape=outer_shape,
        row_partitions=ragged_zeros._nested_row_partitions,  # pylint:disable=protected-access
        nrows=num_rows)

1002 

@classmethod
def _from_pylist_of_value(cls, pyval, typespec, path_so_far):
  """Converts a dict-free python list `pyval` to a Tensor or RaggedTensor.

  Args:
    pyval: A (nested) python list/tuple that contains no dictionaries.
    typespec: Optional spec: a `TensorSpec`, `RaggedTensorSpec`, or (for a
      nested list of empty lists only) a `StructuredTensor.Spec`.
    path_so_far: The path of fields that led here (for error messages).

  Returns:
    A `Tensor` or `RaggedTensor`, or a `StructuredTensor` when `typespec` is
    a `StructuredTensor.Spec`.

  Raises:
    ValueError: If `pyval` cannot be converted, or does not match `typespec`.
  """
  if typespec is None:
    try:
      return ragged_factory_ops.constant(pyval)
    except Exception as exc:
      raise ValueError('Error parsing path %r' % (path_so_far,)) from exc

  if isinstance(typespec, tensor_spec.TensorSpec):
    try:
      dense = constant_op.constant(pyval, typespec.dtype)
    except Exception as exc:
      raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
    if typespec.shape.is_compatible_with(dense.shape):
      return dense
    raise ValueError('Value at %r does not match typespec: %r vs %r' %
                     (path_so_far, typespec, pyval))

  if isinstance(typespec, ragged_tensor.RaggedTensorSpec):
    # pylint: disable=protected-access
    try:
      return ragged_factory_ops.constant(
          pyval,
          dtype=typespec._dtype,
          ragged_rank=typespec._ragged_rank,
          row_splits_dtype=typespec._row_splits_dtype,
          inner_shape=typespec._shape[typespec._ragged_rank + 1:])
    except Exception as exc:
      raise ValueError('Error parsing path %r' % (path_so_far,)) from exc

  if isinstance(typespec, StructuredTensor.Spec):
    # Only a nested list of empty lists can match a StructuredTensor spec.
    empty_rank = _pyval_empty_list_depth(pyval)
    if empty_rank is not None:
      return cls._from_pylist_of_dict(pyval, set(), empty_rank, typespec,
                                      path_so_far)

  # Either an unsupported spec type, or a non-empty list for a struct spec.
  raise ValueError('Value at %r does not match typespec: %r vs %r' %
                   (path_so_far, typespec, pyval))

1042 

@classmethod
def _from_pyscalar(cls, pyval, typespec, path_so_far):
  """Converts python scalar value `pyval` to a Tensor.

  Args:
    pyval: A python value that is not a dict, list, or tuple.
    typespec: Optional spec. If given, it must be a rank-0 `TensorSpec`.
    path_so_far: The path of fields that led here (for error messages).

  Returns:
    A `Tensor` (of `typespec.dtype`, if `typespec` is given).

  Raises:
    ValueError: If `typespec` is not a rank-0 `TensorSpec`, if `pyval` cannot
      be converted to a tensor, or if the converted value's shape is not
      compatible with `typespec.shape`.
  """
  if typespec is not None and not (
      isinstance(typespec, tensor_spec.TensorSpec) and
      typespec.shape.rank == 0):
    raise ValueError('Value at %r does not match typespec: %r vs %r' %
                     (path_so_far, typespec, pyval))
  dtype = None if typespec is None else typespec.dtype
  try:
    result = constant_op.constant(pyval, dtype)
  except Exception as exc:
    # Report conversion failures with the field path, in both the typed and
    # the untyped case.
    raise ValueError('Error parsing path %r' % (path_so_far,)) from exc
  # A non-list value (e.g. an array) may still produce a non-scalar tensor,
  # so check the result against the spec's (scalar) shape.
  if typespec is not None and not typespec.shape.is_compatible_with(
      result.shape):
    raise ValueError('Value at %r does not match typespec: %r vs %r' %
                     (path_so_far, typespec, pyval))
  return result

1058 

1059 #============================================================================= 

1060 # Transforms 

1061 #============================================================================= 

1062 

1063 # TODO(edloper): Add a 'validate' option here? 

1064 # TODO(edloper): Unify nomenclature with RaggedTensor. Should RaggedTensor 

1065 # have a partition_outer_dimension method? 

def partition_outer_dimension(self, row_partition):
  """Partitions the outer dimension of this StructuredTensor.

  Returns a new `StructuredTensor` with the same values as `self`, where
  the outer dimension is partitioned into two (possibly ragged) dimensions.
  Requires that this StructuredTensor have an outer dimension (i.e.,
  `self.shape.rank > 0`).

  >>> st = tf.experimental.StructuredTensor.from_pyval(
  ... [{'foo': 12}, {'foo': 33}, {'foo': 99}])
  >>> partition = RowPartition.from_row_lengths([2, 0, 1])
  >>> st.partition_outer_dimension(partition)
  <StructuredTensor(
  fields={
  "foo": <tf.RaggedTensor [[12, 33], [], [99]]>},
  shape=(3, None))>

  Args:
    row_partition: A `RowPartition`.

  Returns:
    A `StructuredTensor` with rank `self.rank + 1`.

  Raises:
    TypeError: If `row_partition` is not a `RowPartition`.
    ValueError: If this StructuredTensor is a scalar.
  """
  if isinstance(row_partition, RowPartition):
    if self.shape.rank != 0:
      return _partition_outer_dimension(self, row_partition)
    raise ValueError('Shape %s must have rank at least 1' % self.shape)
  raise TypeError('row_partition must be a RowPartition.')

1094 

def merge_dims(self, outer_axis, inner_axis):
  """Merges outer_axis...inner_axis into a single dimension.

  Returns a copy of this StructuredTensor with the specified range of
  dimensions flattened into a single dimension, with elements in row-major
  order.

  >>> st = tf.experimental.StructuredTensor.from_pyval(
  ... [[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]])
  >>> st.merge_dims(0, 1)
  <StructuredTensor(
  fields={
  "foo": tf.Tensor([12 33 99], shape=(3,), dtype=int32)},
  shape=(3,))>

  Args:
    outer_axis: `int`: The first dimension in the range of dimensions to
      merge. May be negative (to index from the last dimension).
    inner_axis: `int`: The last dimension in the range of dimensions to merge.
      May be negative (to index from the last dimension).

  Returns:
    A copy of this tensor, with the specified dimensions merged into a
    single dimension. The shape of the returned tensor will be
    `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
    is the total number of slices in the merged dimensions. If
    `outer_axis == inner_axis`, `self` is returned unchanged.

  Raises:
    ValueError: If `outer_axis` is greater than `inner_axis` (after resolving
      negative axes).
  """
  outer_axis = array_ops.get_positive_axis(
      outer_axis,
      self.shape.rank,
      axis_name='outer_axis',
      ndims_name='rank(self)')
  inner_axis = array_ops.get_positive_axis(
      inner_axis,
      self.shape.rank,
      axis_name='inner_axis',
      ndims_name='rank(self)')
  if not outer_axis <= inner_axis:
    raise ValueError('Expected outer_axis (%d) to be less than or equal to '
                     'inner_axis (%d)' % (outer_axis, inner_axis))
  if outer_axis == inner_axis:
    # A single-dimension range has nothing to merge.
    return self
  return _merge_dims(self, outer_axis, inner_axis)

1135 

class Spec:
  """A spec for StructuredTensor.

  A spec holds `_ragged_shape` (the shape spec of the StructuredTensor) and
  `_fields` (a dict mapping each field name to that field's `TypeSpec`). The
  `_shape`, `_field_specs`, `shape`, and `rank` properties are kept for
  backwards compatibility.
  """

  def __validate__(self):
    # A spec must carry a shape. Note that `_from_fields_and_rank` leaves the
    # shape unset when `fields` is empty.
    assert self._ragged_shape is not None

  @classmethod
  def _from_fields_and_rank(cls, fields, rank):
    """Creates a spec of a StructuredTensor with fields and rank.

    Each field's shape spec is truncated to its outer `rank` dimensions, and
    the truncated shapes of all fields are merged.

    Args:
      fields: A dict mapping field names to field `TypeSpec`s.
      rank: The rank of the StructuredTensor.

    Returns:
      A `StructuredTensor.Spec`.

    Raises:
      ValueError: If a field's spec cannot be converted to a shape spec, or
        if a field's known rank is smaller than `rank`.
    """
    merged_shape = None
    for k, v in fields.items():
      full_shape_spec = _dynamic_ragged_shape_spec_from_spec(v)
      if full_shape_spec is None:
        raise ValueError(f'Cannot convert spec of {k}.')
      untruncated_rank = full_shape_spec.rank
      if untruncated_rank is not None and untruncated_rank < rank:
        raise ValueError(f'Rank of field {k} is {untruncated_rank}, '
                         f'but must be at least {rank}.')
      outer_shape_spec = full_shape_spec._truncate(rank)  # pylint: disable=protected-access
      if merged_shape is None:
        merged_shape = outer_shape_spec
      else:
        merged_shape = merged_shape._merge_with(outer_shape_spec)  # pylint: disable=protected-access
    return StructuredTensor.Spec(_ragged_shape=merged_shape, _fields=fields)

  @classmethod
  def _from_shape(
      cls, shape: dynamic_ragged_shape.DynamicRaggedShape
  ) -> 'StructuredTensor.Spec':
    """Creates the spec of a StructuredTensor with no fields."""
    return StructuredTensor.Spec(_ragged_shape=shape, _fields={})

  # The properties below are kept for backwards compatibility.

  @property
  def _shape(self) -> tensor_shape.TensorShape:
    """The spec's shape, converted to a `TensorShape`."""
    return self._ragged_shape._to_tensor_shape()  # pylint: disable=protected-access

  @property
  def _field_specs(self) -> Dict[str, type_spec.TypeSpec]:
    """Alias for `_fields`."""
    return self._fields

  @property
  def shape(self) -> tensor_shape.TensorShape:
    """The spec's shape, converted to a `TensorShape`."""
    return self._ragged_shape._to_tensor_shape()  # pylint: disable=protected-access

  @property
  def rank(self):
    """The rank of the spec's ragged shape (possibly None if unknown)."""
    return self._ragged_shape.rank

1187 

1188 

1189# Regular expression used to determine whether a string is a valid field name. 

1190# Note: we plan to relax (or possibly eliminate) this in the future; you 

1191# should not rely on the fact that some field names are currently disallowed. 

1192_FIELD_NAME_RE = re.compile('^[a-zA-Z][a-zA-Z0-9_]*$') 

1193 

1194#============================================================================= 

1195# Helper functions 

1196#============================================================================= 

1197# TODO(edloper): Move some of these helpers to row_partition.py? 

1198 

1199 

def _convert_to_structured_field_value(value):
  """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.

  `Tensor`, `RaggedTensor`, `StructuredTensor`, and other `ExtensionType`
  values are returned as-is. Values for which `ragged_tensor.is_ragged` is
  true are converted with `convert_to_tensor_or_ragged_tensor`. Everything
  else is converted with `ops.convert_to_tensor`.

  Args:
    value: The field value to convert.

  Returns:
    The converted value.

  Raises:
    TypeError: If `value` cannot be converted to a tensor.
  """
  if isinstance(value,
                (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
    return value
  elif ragged_tensor.is_ragged(value):
    return ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
  elif isinstance(value, extension_type.ExtensionType):
    return value
  else:
    try:
      return ops.convert_to_tensor(value)
    except (ValueError, TypeError) as e:
      # Wrap `value` in a 1-tuple. If `value` were itself a tuple (e.g. a
      # nested tuple that cannot be converted), `%` would treat it as the
      # argument list and fail with a misleading formatting error.
      raise TypeError('Unexpected type for value in `fields`: %r' %
                      (value,)) from e

1215 

1216 

def _find_shape_dtype(
    fields: Mapping[str, _FieldValue], nrows: Optional[ops.Tensor],
    row_partitions: Optional[Sequence[RowPartition]]) -> dtypes.DType:
  """Picks one shape dtype that is consistent with fields, nrows, and partitions.

  In the future, the default will switch from int64 to int32, but for now,
  we stick with int64.

  Args:
    fields: the fields of the StructuredTensor.
    nrows: the nrows of the StructuredTensor (only used if it is a Tensor).
    row_partitions: the row_partitions of the StructuredTensor, or None.

  Returns:
    `int64` if any input uses int64. Otherwise `int32` if any input uses
    int32. Otherwise `int64`.
  """
  candidates = [_field_shape_dtype(v) for v in fields.values()]
  if isinstance(nrows, ops.Tensor):
    candidates.append(nrows.dtype)
  if row_partitions is not None:
    candidates.extend(rp.dtype for rp in row_partitions)

  # int64 takes precedence over int32. With no explicit int32, use int64.
  # TODO(martinz): Eventually, shift this to tf.int32.
  if dtypes.int64 not in candidates and dtypes.int32 in candidates:
    return dtypes.int32
  return dtypes.int64

1249 

1250 

def _merge_nrows(nrows, static_nrows, value, dtype, validate):
  """Merges `nrows` with `nrows(value)`.

  If both row counts are statically known, they are checked for
  compatibility immediately. Otherwise, if `validate` is true, a runtime
  assertion is attached to the returned `nrows`.

  Args:
    nrows: scalar integer Tensor, or None if not yet known.
    static_nrows: tf.Dimension: static value of nrows, if known.
    value: Tensor or RaggedTensor or StructuredTensor
    dtype: dtype for the row count computed when `value` is a `Tensor`.
    validate: bool -- whether to add validation ops.

  Returns:
    A tuple `(nrows, static_nrows)`.

  Raises:
    ValueError: If the statically-known row counts are incompatible.
  """
  value_static_nrows = tensor_shape.dimension_at_index(value.shape, 0)
  if isinstance(value, ops.Tensor):
    value_nrows = array_ops.shape(value, out_type=dtype)[0]
  else:
    value_nrows = value.nrows()

  if nrows is None:
    merged_nrows = value_nrows
  elif (value_static_nrows.value is not None and
        static_nrows.value is not None):
    # Both counts are static, so check them now; no assertion op is needed.
    if not value_static_nrows.is_compatible_with(static_nrows):
      raise ValueError('fields have incompatible nrows')
    merged_nrows = value_nrows
  elif validate:
    nrows_check = check_ops.assert_equal(
        nrows, value_nrows, message='fields have incompatible nrows')
    merged_nrows = control_flow_ops.with_dependencies([nrows_check], nrows)
  else:
    merged_nrows = nrows
  merged_static = static_nrows._merge_with(value_static_nrows)  # pylint: disable=protected-access
  return merged_nrows, merged_static

1286 

1287 

def _merge_row_partitions(row_partitions, value, rank, dtype, validate):
  """Merges `row_partitions` with `row_partitions(value)`.

  Args:
    row_partitions: The row partitions merged so far, or None.
    value: A Tensor, RaggedTensor, or StructuredTensor.
    rank: The number of outer dimensions to consider; `rank - 1` partitions
      are produced.
    dtype: dtype used when building partitions for uniform (dense)
      dimensions.
    validate: Passed to `_merge_precomputed_encodings` when merging.

  Returns:
    A tuple of `rank - 1` `RowPartition`s.
  """
  if isinstance(value, ops.Tensor):
    partitions_of_value = _row_partitions_for_tensor(value, rank, dtype)
  elif isinstance(value, ragged_tensor.RaggedTensor):
    partitions_of_value = _row_partitions_for_ragged_tensor(value, rank, dtype)
  else:
    assert isinstance(value, StructuredTensor), type(value)
    partitions_of_value = value.row_partitions[:rank - 1]

  assert len(partitions_of_value) == rank - 1
  if row_partitions is None:
    return tuple(partitions_of_value)

  merged = []
  for existing, incoming in zip(row_partitions, partitions_of_value):
    merged.append(existing._merge_precomputed_encodings(incoming, validate))  # pylint: disable=protected-access
  return tuple(merged)

1308 

1309 

def _row_partitions_for_tensor(value, rank, dtype):
  """Returns uniform row partitions for the outer `rank` dims of a tf.Tensor."""
  dense_shape = array_ops.shape(value, out_type=dtype)
  return _row_partitions_for_uniform_shape(dense_shape, rank)

1314 

1315 

def _row_partitions_for_ragged_tensor(value, rank, dtype):
  """Returns the `rank - 1` outer row partitions for a tf.RaggedTensor.

  The ragged tensor's own nested row partitions are used first. If it has
  fewer than `rank - 1` of them, the rest are built from the (uniform) shape
  of its `flat_values`.
  """
  assert rank > 1
  needed = rank - 1
  partitions = value._nested_row_partitions[:needed]  # pylint: disable=protected-access
  missing = needed - len(partitions)
  if missing > 0:
    # `_row_partitions_for_tensor(x, r, ...)` yields `r - 1` partitions.
    partitions += _row_partitions_for_tensor(value.flat_values, missing + 1,
                                             dtype)
  assert len(partitions) == needed
  return partitions

1325 

1326 

def _row_partitions_for_uniform_shape(shape, rank):
  """Returns row partitions for the given shape Tensor.

  Args:
    shape: A vector describing a uniform shape.
    rank: The number of dimensions to generate row partitions for

  Returns:
    A tuple of (rank-1) `RowPartition`s with uniform row length.
  """
  # cumulative_sizes[i] is the product of shape[0], ..., shape[i], i.e. the
  # number of elements across the first i+1 dimensions.
  cumulative_sizes = math_ops.cumprod(shape[:rank])
  partitions = []
  for axis in range(1, rank):
    partitions.append(
        RowPartition.from_uniform_row_length(
            uniform_row_length=shape[axis],
            nvals=cumulative_sizes[axis],
            nrows=cumulative_sizes[axis - 1]))
  return tuple(partitions)

1345 

1346 

1347def _pyval_field_major_to_node_major(keys, values, depth): 

1348 """Regroup each field (k, v) from dict-of-list to list-of-dict. 

1349 

1350 Given a "field-major" encoding of the StructuredTensor (which maps each key to 

1351 a single nested list containing the values for all structs), return a 

1352 corresponding "node-major" encoding, consisting of a nested list of dicts. 

1353 

1354 Args: 

1355 keys: The field names (list of string). Must not be empty. 

1356 values: The field values (list of python values). Must have the same length 

1357 as `keys`. 

1358 depth: The list depth at which dictionaries should be created. 

1359 

1360 Returns: 

1361 A nested list of dict, with depth `depth`. 

1362 """ 

1363 assert keys 

1364 if depth == 0: 

1365 return dict(zip(keys, values)) 

1366 nvals = len(values[0]) 

1367 assert all(nvals == len(values[i]) for i in range(1, len(values))) 

1368 return [ 

1369 _pyval_field_major_to_node_major(keys, value_slice, depth - 1) 

1370 for value_slice in zip(*values) 

1371 ] 

1372 

1373 

1374def _empty_dict_pylist_from_row_partitions(row_partitions, nrows): 

1375 """Returns a python list of empty dicts from the given row partitions. 

1376 

1377 Args: 

1378 row_partitions: The row-partitions describing the ragged shape of the 

1379 result. 

1380 nrows: The number of rows in the outermost row-partition. (Or if 

1381 `len(row_partitions)==0`, then the number of empty dicts to return.) 

1382 

1383 Returns: 

1384 A nested python list whose leaves (if any) are empty python dicts. 

1385 """ 

1386 if not row_partitions: 

1387 return [{} for _ in range(nrows)] 

1388 else: 

1389 values = _empty_dict_pylist_from_row_partitions( 

1390 row_partitions[1:], row_partitions[0].row_splits()[-1]) 

1391 splits = row_partitions[0].row_splits() 

1392 return [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)] 

1393 

1394 

1395def _pyval_find_struct_keys_and_depth(pyval, keys): 

1396 """Finds the keys & depth of nested dictionaries in `pyval`. 

1397 

1398 Args: 

1399 pyval: A nested structure of lists, tuples, and dictionaries. 

1400 keys: (output parameter) A set, which will be updated with any keys that are 

1401 found in the nested dictionaries. 

1402 

1403 Returns: 

1404 The nesting depth of dictionaries in `pyval`, or `None` if `pyval` does 

1405 not contain any dictionaries. 

1406 Raises: 

1407 ValueError: If dictionaries have inconsistent depth. 

1408 """ 

1409 if isinstance(pyval, dict): 

1410 keys.update(pyval.keys()) 

1411 return 0 

1412 elif isinstance(pyval, (list, tuple)): 

1413 depth = None 

1414 for child in pyval: 

1415 child_depth = _pyval_find_struct_keys_and_depth(child, keys) 

1416 if child_depth is not None: 

1417 if depth is None: 

1418 depth = child_depth + 1 

1419 elif depth != child_depth + 1: 

1420 raise ValueError('Inconsistent depth of dictionaries') 

1421 return depth 

1422 else: 

1423 return None 

1424 

1425 

1426def _pyval_update_fields(pyval, fields, depth): 

1427 """Append the field values from `pyval` to `fields`. 

1428 

1429 Args: 

1430 pyval: A python `dict`, or nested list/tuple of `dict`, whose value(s) 

1431 should be appended to `fields`. 

1432 fields: A dictionary mapping string keys to field values. Field values 

1433 extracted from `pyval` are appended to this dictionary's values. 

1434 depth: The depth at which `pyval` should be appended to the field values. 

1435 """ 

1436 if not isinstance(pyval, (dict, list, tuple)): 

1437 raise ValueError('Expected dict or nested list/tuple of dict') 

1438 

1439 for (key, target) in fields.items(): 

1440 for _ in range(1, depth): 

1441 target = target[-1] 

1442 target.append(pyval[key] if isinstance(pyval, dict) else []) 

1443 

1444 if isinstance(pyval, (list, tuple)): 

1445 for child in pyval: 

1446 _pyval_update_fields(child, fields, depth + 1) 

1447 

1448 

1449def _pyval_empty_list_depth(pyval): 

1450 """Find the max depth for nested empty lists. 

1451 

1452 Args: 

1453 pyval: A nested python list. 

1454 

1455 Returns: 

1456 The maximum depth of empty lists in `pyval`, or None if `pyval` contains 

1457 anything other than nested empty lists. 

1458 """ 

1459 if isinstance(pyval, list): 

1460 if not pyval: 

1461 return 1 

1462 depths = [_pyval_empty_list_depth(v) for v in pyval] 

1463 if any(depth is None for depth in depths): 

1464 return None 

1465 else: 

1466 return max(depths) + 1 

1467 else: 

1468 return None 

1469 

1470 

def _replace_row_partitions(value, new_partitions):
  """Updates `value` to use `new_partitions` as its (outer) row partitions.

  This is used to ensure that all fields in a `StructuredTensor` use identical
  `RowPartition` objects for the shared dimensions. In particular,
  `StructuredTensor.from_fields` first merges all of the row partitions from
  any fields, and then replaces the outer row partitions of all fields with
  the merged row partitions (using this function).

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    new_partitions: A list of row-partitions that should be used by `value`.
      Must be equivalent to `value`'s current row partitions.

  Returns:
    A value that is equivalent to `value`, where outer row partitions have
    been replaced by `new_partitions`.
  """
  # Dense tensors carry no row partitions, and an empty list replaces nothing.
  if not new_partitions or isinstance(value, ops.Tensor):
    return value

  if isinstance(value, ragged_tensor.RaggedTensor):
    # Peel off one partition per ragged level.
    inner_values = _replace_row_partitions(value.values, new_partitions[1:])
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        values=inner_values, row_partition=new_partitions[0])

  assert isinstance(value, StructuredTensor)
  replaced_fields = {
      name: _replace_row_partitions(field, new_partitions)
      for name, field in value._fields.items()  # pylint: disable=protected-access
  }
  kept_partitions = tuple(value.row_partitions[len(new_partitions):])
  return StructuredTensor._old_init(  # pylint: disable=protected-access
      fields=replaced_fields,
      shape=value.shape,
      nrows=value.nrows(),
      row_partitions=tuple(new_partitions) + kept_partitions)

1507 

1508 

def _partition_outer_dimension(value, row_partition):
  """Partitions the outer dimension of `value` using `row_partition`.

  Examples:

  >>> partition = RowPartition.from_row_lengths([2, 0, 1])
  >>> _partition_outer_dimension(tf.constant([1, 2, 3]), partition)
  <tf.RaggedTensor [[1, 2], [], [3]]>

  >>> struct_value = tf.experimental.StructuredTensor.from_pyval(
  ... [{'x': 1}, {'x': 2}, {'x': 3}])
  >>> _partition_outer_dimension(struct_value, partition)
  <StructuredTensor(
  fields={
  "x": <tf.RaggedTensor [[1, 2], [], [3]]>},
  shape=(3, None))>

  Args:
    value: Tensor, RaggedTensor, or StructuredTensor
    row_partition: RowPartition

  Returns:
    A value where `result.rank = value.rank + 1`. A dense `Tensor` stays
    dense when the partition has a uniform row length; otherwise it becomes
    a `RaggedTensor`.
  """
  has_uniform_rows = row_partition.uniform_row_length() is not None

  if isinstance(value, ops.Tensor) and has_uniform_rows:
    # Uniform rows: reshape to [nrows, row_length] + value.shape[1:].
    outer_dims = [row_partition.nrows(), row_partition.uniform_row_length()]
    inner_dims = array_ops.shape(value, out_type=row_partition.dtype)[1:]
    new_shape = array_ops.concat([outer_dims, inner_dims], axis=0)
    return array_ops.reshape(value, new_shape)

  if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        value, row_partition)

  assert isinstance(value, StructuredTensor)
  static_outer = tensor_shape.TensorShape(
      [row_partition.static_nrows, row_partition.static_uniform_row_length])
  new_shape = static_outer.concatenate(value.shape[1:])
  partitioned_fields = {
      name: _partition_outer_dimension(field, row_partition)
      for name, field in value._fields.items()  # pylint: disable=protected-access
  }
  return StructuredTensor._old_init(  # pylint: disable=protected-access
      partitioned_fields, new_shape, row_partition.nrows(),
      (row_partition,) + value.row_partitions)

1556 

1557 

def _merge_dims(value, outer_axis, inner_axis):
  """Collapses dimensions `outer_axis` through `inner_axis` of `value` into one.

  Args:
    value: A Tensor, RaggedTensor, or StructuredTensor.
    outer_axis: First dimension to merge; must be strictly less than
      `inner_axis`.
    inner_axis: Last dimension to merge.

  Returns:
    `value` with dimensions `outer_axis...inner_axis` merged into a single
    dimension.
  """
  assert outer_axis < inner_axis
  if not isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    # Structured case: merge every field, then merge the shape to match.
    assert isinstance(value, StructuredTensor)
    merged_fields = {
        name: _merge_dims(field, outer_axis, inner_axis)
        for name, field in value._fields.items()
    }
    merged_shape = value._ragged_shape._merge_dims(  # pylint: disable=protected-access
        outer_axis, inner_axis)
    return StructuredTensor(merged_fields, merged_shape)
  return ragged_tensor.merge_dims(value, outer_axis, inner_axis)

1570 

1571 

# Unique private object: a bare `object()` carries no state, so the only
# meaningful thing about it is its identity.
_structured_tensor_factory_key = object()

1573 

1574 

def _dynamic_ragged_shape_spec_from_spec(
    spec: Union[dynamic_ragged_shape.DynamicRaggedShape.Spec,
                ragged_tensor.RaggedTensorSpec, StructuredTensor.Spec,
                tensor_spec.TensorSpec]
) -> dynamic_ragged_shape.DynamicRaggedShape.Spec:
  """Returns the `DynamicRaggedShape.Spec` that describes `spec`.

  A `StructuredTensor.Spec` already stores its shape spec, which is returned
  as-is; any other spec is converted via `DynamicRaggedShape.Spec._from_spec`.
  """
  if not isinstance(spec, StructuredTensor.Spec):
    return dynamic_ragged_shape.DynamicRaggedShape.Spec._from_spec(spec)  # pylint: disable=protected-access
  return spec._ragged_shape  # pylint: disable=protected-access

1584 

1585 

def _normalize_field_name_to_tuple(name: 'FieldName') -> Sequence[str]:
  """Returns `name` as a tuple of path components.

  A `FieldName` may be given as a single string, a list of strings, or a tuple
  of strings. A string is wrapped in a 1-tuple, a list is converted to a tuple,
  and a tuple is returned unchanged.
  """
  if isinstance(name, str):
    normalized = (name,)
  elif isinstance(name, list):
    normalized = tuple(name)
  else:
    assert isinstance(name, tuple)
    normalized = name
  return normalized

1594 

1595 

def _dicts_to_zeros(pyval):
  """Replaces every dictionary in a (possibly nested) pylist with `0`.

  Args:
    pyval: A dict, or a nested list (or other iterable) whose leaves are
      expected to be dicts.

  Returns:
    `0` if `pyval` is a dict. Otherwise, a nested `list` with the same nesting
    structure as `pyval`, where each dict has been replaced by `0`.
  """
  if isinstance(pyval, dict):
    return 0
  return [_dicts_to_zeros(x) for x in pyval]

1601 

1602 

def _merge_dims_generic(source, outer, inner):
  """Merges dimensions `outer...inner` (inclusive) into a single dimension.

  If `outer == inner`, this is a no-op. If `inner < outer`, this fails.
  If `inner >= source.shape.rank`, the behavior is undefined.

  Args:
    source: a tensor, ragged tensor, or structured tensor.
    outer: a python int, the first dimension to merge (must be nonnegative).
    inner: a python int, the last dimension to merge, inclusive (must be
      nonnegative).

  Returns:
    `source` with dimensions `outer...inner` merged into a single dimension.
  """
  if isinstance(source, StructuredTensor):
    return source.merge_dims(outer, inner)
  # Dense tensors and ragged tensors are both handled by
  # ragged_tensor.merge_dims.
  return ragged_tensor.merge_dims(source, outer, inner)

1624 

1625 

def _dynamic_ragged_shape_from_tensor(
    field, dtype=None) -> dynamic_ragged_shape.DynamicRaggedShape:
  """Returns the `DynamicRaggedShape` of `field`.

  Behaves like `DynamicRaggedShape.from_tensor`, but also accepts a
  StructuredTensor, whose stored ragged shape is returned directly.

  Args:
    field: A field value (e.g. a Tensor, RaggedTensor, or StructuredTensor).
    dtype: Optional dtype forwarded as `out_type` to `array_ops.shape_v2`;
      not used for StructuredTensors.

  Raises:
    TypeError: if `array_ops.shape_v2(field)` returns neither a Tensor nor a
      DynamicRaggedShape.
  """
  if isinstance(field, StructuredTensor):
    return field._ragged_shape  # pylint: disable=protected-access

  shape = array_ops.shape_v2(field, out_type=dtype)
  if isinstance(shape, ops.Tensor):
    # A dense shape vector: wrap it as a shape with no ragged partitions.
    return dynamic_ragged_shape.DynamicRaggedShape(
        row_partitions=[], inner_shape=shape)
  if isinstance(shape, dynamic_ragged_shape.DynamicRaggedShape):
    return shape
  # TODO(martinz): add a test for the following line.
  raise TypeError(f'Expected shape tf.shape({field}) to return a Tensor or a '
                  f'DynamicRaggedShape. Instead, got: {shape}.')

1641 

1642 

def _merge_with_optional(
    a: Optional[dynamic_ragged_shape.DynamicRaggedShape],
    b: Optional[dynamic_ragged_shape.DynamicRaggedShape]
) -> Optional[dynamic_ragged_shape.DynamicRaggedShape]:
  """Merges two optional shapes, treating None as "no information".

  Returns:
    `a._merge_with(b)` when both arguments are present; otherwise whichever
    argument is not None, or None when both are None.
  """
  if a is not None and b is not None:
    return a._merge_with(b)  # pylint: disable=protected-access
  return b if a is None else a

1652 

1653 

def _shape_from_fields(
    fields, rank: int,
    dtype: dtypes.DType) -> Optional[dynamic_ragged_shape.DynamicRaggedShape]:
  """Merges the leading `rank` dimensions of the shapes of all `fields`.

  Args:
    fields: dict-like mapping of field names to field values.
    rank: Number of leading dimensions of each field's shape to keep.
    dtype: dtype forwarded to `_dynamic_ragged_shape_from_tensor`.

  Returns:
    The merged DynamicRaggedShape, or None if `fields` is empty.

  Raises:
    ValueError: if computing or merging the shape of some field fails; the
      underlying exception is chained as the cause.
  """
  merged = None
  for k, value in fields.items():
    try:
      full_shape = _dynamic_ragged_shape_from_tensor(value, dtype=dtype)
      merged = _merge_with_optional(merged, full_shape[:rank])
    except Exception as err:
      raise ValueError(f'Error in shape of {k}') from err
  return merged

1670 

1671 

def _field_shape_dtype(field: _FieldValue) -> Optional[dtypes.DType]:
  """Returns the dtype of `field`'s partitions or shape, if it has one.

  Returns:
    The row-partition dtype for a RaggedTensor, the ragged-shape dtype for a
    StructuredTensor, and None for any other field value.
  """
  if isinstance(field, ragged_tensor.RaggedTensor):
    found = field._row_partition.dtype  # pylint: disable=protected-access
  elif isinstance(field, StructuredTensor):
    found = field._ragged_shape.dtype  # pylint: disable=protected-access
  else:
    found = None
  return found

1678 

1679 

def _field_with_shape_dtype(field: _FieldValue,
                            dtype: dtypes.DType) -> _FieldValue:
  """Returns `field` converted so that its partitions/shape use `dtype`.

  A RaggedTensor is converted with `with_row_splits_dtype`, and a
  StructuredTensor with `with_shape_dtype`. Any other field value is returned
  unchanged.
  """
  if isinstance(field, ragged_tensor.RaggedTensor):
    converted = field.with_row_splits_dtype(dtype)
  elif isinstance(field, StructuredTensor):
    converted = field.with_shape_dtype(dtype)
  else:
    converted = field
  return converted

1688 

1689 

def _fields_with_dtype(fields: Mapping[str, _FieldValue],
                       dtype: dtypes.DType) -> Mapping[str, _FieldValue]:
  """Returns a new dict with every value of `fields` converted to `dtype`.

  Each value is converted with `_field_with_shape_dtype`. `fields` itself is
  not modified.
  """
  converted = {}
  for name, value in fields.items():
    converted[name] = _field_with_shape_dtype(value, dtype)
  return converted

1693 

1694 

# pylint:disable=protected-access
def _dynamic_ragged_shape_init(fields, shape, nrows, row_partitions):
  """Builds the `DynamicRaggedShape` for a StructuredTensor.

  The result merges every available source of shape information:
  * the static `shape`, when it is fully defined;
  * the shapes of `fields`, truncated to `shape.rank` dimensions;
  * `nrows` (only when the rank is 1);
  * `row_partitions` (only when the rank is greater than 1).

  Args:
    fields: dict mapping field names to field values.
    shape: `tensor_shape.TensorShape`; its rank must be known.
    nrows: None, a python int, or a Tensor.
    row_partitions: None, or a tuple passed to
      `DynamicRaggedShape.from_row_partitions`.

  Returns:
    A `DynamicRaggedShape`.

  Raises:
    TypeError: if `shape.rank` is None.
    ValueError: if no source of shape information is available, or if the
      shape of some field cannot be computed or merged.
  """
  assert isinstance(fields, dict), fields
  assert isinstance(shape, tensor_shape.TensorShape), shape
  assert nrows is None or isinstance(nrows, (ops.Tensor, int)), nrows
  assert row_partitions is None or isinstance(row_partitions,
                                              tuple), row_partitions

  rank = shape.rank
  if rank is None:
    raise TypeError("StructuredTensor's shape must have known rank.")

  # TODO(martinz): figure out whether to validate.
  dtype = _find_shape_dtype(fields, nrows, row_partitions)
  fields = _fields_with_dtype(fields, dtype)
  drs = dynamic_ragged_shape.DynamicRaggedShape

  # The static shape only contributes when every dimension is known.
  result = (drs._from_inner_shape(shape.as_list(), dtype=dtype)
            if shape.is_fully_defined() else None)

  if rank == 0:
    # Scalar StructuredTensor: an empty inner shape.
    return drs._from_inner_shape(array_ops.zeros((0,), dtype=dtype))

  result = _merge_with_optional(result, _shape_from_fields(fields, rank, dtype))

  if rank == 1:
    # A statically-known outer dimension overrides the `nrows` argument.
    static_nrows = tensor_shape.dimension_value(shape[0])
    if static_nrows is not None:
      nrows = static_nrows
    if nrows is not None:
      result = _merge_with_optional(
          result, drs._from_inner_shape([nrows], dtype=dtype))
    if result is None:
      raise ValueError('Must specify `nrows`, a fully specified `shape`,' +
                       ' or have `fields` if `rank=1`')
    return result

  if row_partitions:
    result = _merge_with_optional(
        result, drs.from_row_partitions(row_partitions, dtype=dtype))
  if result is None:
    raise ValueError('Must specify row_partitions, a fully specified shape, ' +
                     'or have fields if rank > 1')
  return result

1749 

1750 

# TODO(martinz): Drop this method or rename.
def StructuredTensorSpec(shape, field_specs):  # pylint:disable=invalid-name
  """A placeholder for the old StructuredTensorSpec.

  Builds a `StructuredTensor.Spec` whose ragged shape spec merges `shape` with
  the leading `rank` dimensions of every field spec's shape.

  Args:
    shape: Anything accepted by `tensor_shape.as_shape`; must have known rank.
    field_specs: dict mapping string field names to `type_spec.TypeSpec`s.

  Returns:
    A `StructuredTensor.Spec`.

  Raises:
    TypeError: if `field_specs` is not a dict with str keys and TypeSpec
      values, or if `shape` has unknown rank.
    ValueError: if a field spec cannot be converted to a shape spec, or if its
      rank is smaller than the rank of `shape`.
  """
  if not isinstance(field_specs, dict):
    raise TypeError('field_specs must be a dictionary.')
  if not all(isinstance(key, str) for key in field_specs):
    raise TypeError('field_specs must be a dictionary with string keys.')
  if not all(isinstance(spec, type_spec.TypeSpec)
             for spec in field_specs.values()):
    raise TypeError('field_specs must be a dictionary with TypeSpec values.')

  merged_spec = dynamic_ragged_shape.DynamicRaggedShape.Spec._from_tensor_shape(
      tensor_shape.as_shape(shape), 0, dtypes.int32)
  rank = merged_spec.rank
  if rank is None:
    raise TypeError("StructuredTensor's shape must have known rank.")

  for k, spec in field_specs.items():
    full_field_spec = _dynamic_ragged_shape_spec_from_spec(spec)
    if full_field_spec is None:
      raise ValueError(f'Cannot convert spec of {k}.')
    untruncated_rank = full_field_spec.rank
    if untruncated_rank is not None and untruncated_rank < rank:
      raise ValueError(f'Rank of field {k} is {untruncated_rank},'
                       f' but must be at least {rank}.')
    # Only the outer `rank` dimensions of a field constrain the struct shape.
    merged_spec = merged_spec._merge_with(full_field_spec._truncate(rank))
  return StructuredTensor.Spec(_ragged_shape=merged_spec, _fields=field_specs)