# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal utilities for `LinearOperator` classes."""


import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.util import nest


################################################################################
# To make more friendly for TF2.
################################################################################


def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor` if input is nonreference type.

  This function converts Python objects of various types to `Tensor` objects
  except if the input has reference semantics. Reference semantics are
  characterized by `is_ref` and apply to any object which is a
  `tf.Variable` or an instance of `tf.Module`. This function accepts any input
  which `tf.convert_to_tensor` would also accept.

  Note: This function diverges from default Numpy behavior for `float` and
  `string` types when `None` is present in a Python list or scalar. Rather
  than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so dtype_hint
      can be used as a soft preference. If the conversion to
      `dtype_hint` is not possible, this argument has no effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    tensor: A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.

  #### Examples:

  ```python
  x = tf.Variable(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = tf.constant(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = np.array(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> False
  tf.is_tensor(y)
  # ==> True

  x = tfp.util.DeferredTensor(13.37, lambda x: x)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True
  tf.is_tensor(y)
  # ==> False
  tf.equal(y, 13.37)
  # ==> True
  ```
  """
  # We explicitly do not use a tf.name_scope to avoid graph clutter.
  if value is None:
    return None
  if is_ref(value):
    if dtype is None:
      return value
    dtype_base = base_dtype(dtype)
    value_dtype_base = base_dtype(value.dtype)
    if dtype_base != value_dtype_base:
      raise TypeError(
          f"Argument `value` must be of dtype `{dtype_name(dtype_base)}`. "
          f"Received: `{dtype_name(value_dtype_base)}`.")
    return value
  return tensor_conversion.convert_to_tensor_v2_with_dispatch(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name
  )


def base_dtype(dtype):
  """Returns a non-reference `dtype` based on this `dtype`."""
  dtype = dtypes.as_dtype(dtype)
  if hasattr(dtype, "base_dtype"):
    return dtype.base_dtype
  return dtype


def dtype_name(dtype):
  """Returns the string name for this `dtype`."""
  dtype = dtypes.as_dtype(dtype)
  if hasattr(dtype, "name"):
    return dtype.name
  if hasattr(dtype, "__name__"):
    return dtype.__name__
  return str(dtype)

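# Illustrative behavior of the two dtype helpers above (a sketch, not part of
# the library API; assumes `import tensorflow as tf` on the caller's side):
#
#   base_dtype(tf.float32)   # ==> tf.float32 (already a base dtype)
#   dtype_name(tf.float64)   # ==> "float64"
#   dtype_name(np.float32)   # ==> "float32", converted via dtypes.as_dtype.
#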

def check_dtype(arg, dtype):
  """Check that `arg.dtype` matches the expected `dtype`."""
  if arg.dtype.base_dtype != dtype:
    raise TypeError(
        f"Expected argument to have dtype {dtype}. Found: {arg.dtype} in "
        f"tensor {arg}.")


def is_ref(x):
  """Evaluates if the object has reference semantics.

  An object is deemed "reference" if it is a `tf.Variable` instance or is
  derived from a `tf.Module` with `dtype` and `shape` properties.

  Args:
    x: Any object.

  Returns:
    is_ref: Python `bool` indicating the input has reference semantics, i.e.,
      is a `tf.Variable` or a `tf.Module` with `dtype` and `shape` properties.
  """
  return (
      # Note: we check that tf.Variable is a class because we might be using a
      # backend other than TF.
      isinstance(x, variables_module.Variable) or
      (isinstance(x, module.Module) and hasattr(x, "dtype") and
       hasattr(x, "shape")))

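# Illustrative behavior of `is_ref` (a sketch, not part of the library API;
# assumes `import tensorflow as tf` and eager TF2 execution on the caller's
# side):
#
#   is_ref(tf.Variable(0.))   # ==> True; variables have reference semantics.
#   is_ref(tf.constant(0.))   # ==> False; a constant Tensor is a plain value.
#   is_ref(np.array(0.))      # ==> False
#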

def assert_not_ref_type(x, arg_name):
  if is_ref(x):
    raise TypeError(
        f"Argument {arg_name} cannot be reference type. Found: {type(x)}.")


################################################################################
# Asserts.
################################################################################


def assert_no_entries_with_modulus_zero(
    x, message=None, name="assert_no_entries_with_modulus_zero"):
  """Returns `Op` that asserts Tensor `x` has no entries with modulus zero.

  Args:
    x: Numeric `Tensor`, real, integer, or complex.
    message: A string message to prepend to failure message.
    name: A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no entries with modulus zero.
  """
  with ops.name_scope(name, values=[x]):
    x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x, name="x")
    dtype = x.dtype.base_dtype
    should_be_nonzero = math_ops.abs(x)
    zero = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        0, dtype=dtype.real_dtype
    )
    return check_ops.assert_less(zero, should_be_nonzero, message=message)


def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
  """Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts.

  Args:
    x: Numeric `Tensor`, real, integer, or complex.
    message: A string message to prepend to failure message.
    name: A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no entries with a non-zero imaginary part.
  """
  with ops.name_scope(name, values=[x]):
    x = tensor_conversion.convert_to_tensor_v2_with_dispatch(x, name="x")
    dtype = x.dtype.base_dtype

    if dtype.is_floating:
      return control_flow_ops.no_op()

    zero = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        0, dtype=dtype.real_dtype
    )
    return check_ops.assert_equal(zero, math_ops.imag(x), message=message)

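# Illustrative behavior of the two assert helpers above (a sketch, not part of
# the library API; assumes `import tensorflow as tf` and eager execution, where
# a failed check raises `tf.errors.InvalidArgumentError`):
#
#   assert_zero_imag_part(tf.constant([1., 2.]))                 # no_op: real.
#   assert_zero_imag_part(tf.constant([1. + 0.j]))               # passes.
#   assert_zero_imag_part(tf.constant([1. + 2.j]))               # raises.
#   assert_no_entries_with_modulus_zero(tf.constant([1., 2.]))   # passes.
#   assert_no_entries_with_modulus_zero(tf.constant([0., 1.]))   # raises.
#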

def assert_compatible_matrix_dimensions(operator, x):
  """Assert that an argument to solve/matmul has proper domain dimension.

  If `operator.shape[-2:] = [M, N]`, and `x.shape[-2:] = [Q, R]`, then
  `operator.matmul(x)` is defined only if `N = Q`. This `Op` returns an
  `Assert` that "fires" if this is not the case. Static checks are already
  done by the base class `LinearOperator`.

  Args:
    operator: `LinearOperator`.
    x: `Tensor`.

  Returns:
    `Assert` `Op`.
  """
  # Static checks are done in the base class. Only tensor asserts here.
  assert_same_dd = check_ops.assert_equal(
      array_ops.shape(x)[-2],
      operator.domain_dimension_tensor(),
      # This error message is made to look similar to the error raised by the
      # static check in the base class.
      message=("Dimensions are not compatible. "
               "Expected shape[-2] of argument to be the same as this "
               "operator's domain dimension."))

  return assert_same_dd

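# Illustrative usage (a sketch, not part of the library API; assumes
# `import tensorflow as tf` and eager execution):
#
#   op = tf.linalg.LinearOperatorFullMatrix(tf.ones([3, 4]))   # domain dim 4
#   assert_compatible_matrix_dimensions(op, tf.ones([4, 2]))   # passes.
#   assert_compatible_matrix_dimensions(op, tf.ones([5, 2]))
#   # ==> raises tf.errors.InvalidArgumentError, since shape[-2] = 5 != 4.
#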

def assert_is_batch_matrix(tensor):
  """Static assert that `tensor` has rank `2` or higher."""
  sh = tensor.shape
  if sh.ndims is not None and sh.ndims < 2:
    raise ValueError(
        f"Expected [batch] matrix to have at least two dimensions. Found: "
        f"{tensor}.")


def shape_tensor(shape, name=None):
  """Convert Tensor using default type, unless empty list or tuple."""
  # Works just like random_ops._ShapeTensor.
  if isinstance(shape, (tuple, list)) and not shape:
    dtype = dtypes.int32
  else:
    dtype = None
  return tensor_conversion.convert_to_tensor_v2_with_dispatch(
      shape, dtype=dtype, name=name
  )

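# Illustrative behavior of `shape_tensor` (a sketch, not part of the library
# API; assumes `import tensorflow as tf`):
#
#   shape_tensor([]).dtype       # ==> tf.int32; empty shapes default to int32.
#   shape_tensor([2, 3]).dtype   # ==> tf.int32, inferred from the values.
#   shape_tensor(np.array([2, 3], dtype=np.int64)).dtype   # ==> tf.int64
#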

################################################################################
# Broadcasting versions of common linear algebra functions.
# TODO(b/77519145) Do this more efficiently in some special cases.
################################################################################


def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]  # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices: Iterable of `Tensor`s, each having two or more dimensions.
    name: A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError: If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = tensor_conversion.convert_to_tensor_v2_with_dispatch(
          mat
      )
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    #   x.shape =    [2, j, k]  (batch shape =    [2])
    #   y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].shape[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape,
          mat.shape[:-2])
    if bcast_batch_shape.is_fully_defined():
      for i, mat in enumerate(batch_matrices):
        if mat.shape[:-2] != bcast_batch_shape:
          bcast_shape = array_ops.concat(
              [bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]], axis=0)
          batch_matrices[i] = array_ops.broadcast_to(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape,
          array_ops.shape(mat)[:-2])
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = array_ops.broadcast_to(
          mat,
          array_ops.concat(
              [bcast_batch_shape, array_ops.shape(mat)[-2:]], axis=0))

    return batch_matrices


def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
  """Solve systems of linear equations."""
  with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
    matrix = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        matrix, name="matrix"
    )
    rhs = tensor_conversion.convert_to_tensor_v2_with_dispatch(
        rhs, name="rhs", dtype=matrix.dtype
    )

    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # This will broadcast by brute force if we still need to.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_solve(
        matrix, rhs, adjoint=adjoint and still_need_to_transpose)

    return reshape_inv(solution)

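# Illustrative shapes for `matrix_solve_with_broadcast` (a sketch, not part of
# the library API; assumes `import tensorflow as tf`):
#
#   matrix = tf.random.normal(shape=(2, 1, 4, 4))   # batch shape [2, 1]
#   rhs = tf.random.normal(shape=(3, 4, 1))         # batch shape [3]
#   x = matrix_solve_with_broadcast(matrix, rhs)
#   x.shape   # ==> (2, 3, 4, 1); the batch shapes broadcast to [2, 3].
#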

def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map. For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough. Why?
  # Assume adjoint/transpose = False. Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory. But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a. This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code. Why? To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = array_ops.matrix_transpose(a, conjugate=True)
  elif transpose_a:
    a = array_ops.matrix_transpose(a, conjugate=False)
  if adjoint_b:
    b = array_ops.matrix_transpose(b, conjugate=True)
  elif transpose_b:
    b = array_ops.matrix_transpose(b, conjugate=False)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, b.shape.ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose


################################################################################
# Helpers for hints.
################################################################################


def is_adjoint_pair(x, y):
  """True iff x and y are adjoints of each other (by id, not entries)."""
  if x is y:  # Note that if x is y then all of their hints are the same!
    if x.is_self_adjoint is False:  # pylint:disable=g-bool-id-comparison
      return False
    if x.is_self_adjoint:
      return True
  # Use the fact that if x = LinearOperatorAdjoint(y), then x.H is y.
  return x.H is y or y.H is x


def is_aat_form(operators):
  """Returns True if operators is of the form A @ A.H, possibly recursively."""
  operators = list(operators)
  if not operators:
    raise ValueError("AAT form is undefined for empty operators")

  if len(operators) % 2:
    return False

  # Check for forms like (A1 @ A2) @ (A2.H @ A1.H)
  return all(
      is_adjoint_pair(operators[i], operators[-1 - i])
      for i in range(len(operators) // 2))

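# Illustrative behavior of `is_aat_form` (a sketch, not part of the library
# API; assumes `import tensorflow as tf`):
#
#   op = tf.linalg.LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
#   is_aat_form([op, op.H])   # ==> True; the list reads A @ A.H.
#   is_aat_form([op, op])     # ==> False; op is not hinted self-adjoint.
#   is_aat_form([op])         # ==> False; odd-length products never match.
#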

def use_operator_or_provided_hint_unless_contradicting(
    operator, hint_attr_name, provided_hint_value, message):
  """Get the combined hint in the case where `operator.hint` should equal `hint`.

  Args:
    operator: LinearOperator that a meta-operator was initialized with.
    hint_attr_name: String name for the attribute.
    provided_hint_value: Bool or None. Value passed by user in initialization.
    message: Error message to print if hints contradict.

  Returns:
    True, False, or None.

  Raises:
    ValueError: If hints contradict.
  """
  op_hint = getattr(operator, hint_attr_name)
  # pylint: disable=g-bool-id-comparison
  if op_hint is False and provided_hint_value:
    raise ValueError(message)
  if op_hint and provided_hint_value is False:
    raise ValueError(message)
  if op_hint or provided_hint_value:
    return True
  if op_hint is False or provided_hint_value is False:
    return False
  # pylint: enable=g-bool-id-comparison
  return None

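# How the combined hint resolves (an illustrative summary of the logic above,
# not part of the library API):
#
#   operator hint   provided hint   result
#   -------------   -------------   -------------------
#   True            None            True
#   None            True            True
#   None            False           False
#   False           None            False
#   None            None            None
#   True            False           ValueError(message)
#   False           True            ValueError(message)
#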

################################################################################
# Utilities for blockwise operators.
################################################################################


def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
  """Detect if input should be interpreted as a list of blocks."""
  # Tuples and lists of length equal to the number of operators may be
  # blockwise.
  if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)):
    # If the elements of the iterable are not nested, interpret the input as
    # blockwise.
    if not any(nest.is_nested(x) for x in arg):
      return True
    else:
      arg_dims = [
          tensor_conversion.convert_to_tensor_v2_with_dispatch(x).shape[
              arg_split_dim
          ]
          for x in arg
      ]
      self_dims = [dim.value for dim in block_dimensions]

      # If none of the operator dimensions are known, interpret the input as
      # blockwise if its matching dimensions are unequal.
      if all(self_d is None for self_d in self_dims):

        # A nested tuple/list with a single outermost element is not blockwise.
        if len(arg_dims) == 1:
          return False
        elif any(dim != arg_dims[0] for dim in arg_dims):
          return True
        else:
          raise ValueError(
              "Parsing of the input structure is ambiguous. Please input "
              "a blockwise iterable of `Tensor`s or a single `Tensor`.")

      # If input dimensions equal the respective (known) blockwise operator
      # dimensions, then the input is blockwise.
      if all(self_d == arg_d or self_d is None
             for self_d, arg_d in zip(self_dims, arg_dims)):
        return True

      # If the input dimensions are all equal and are greater than or equal to
      # the sum of the known operator dimensions, then the input is not
      # blockwise.
      self_dim = sum(self_d for self_d in self_dims if self_d is not None)
      if all(s == arg_dims[0] for s in arg_dims) and arg_dims[0] >= self_dim:
        return False

      # If none of these conditions is met, the input shape is mismatched.
      raise ValueError("Input dimension does not match operator dimension.")
  else:
    return False

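# Illustrative behavior of `arg_is_blockwise` (a sketch, not part of the
# library API; assumes `import tensorflow as tf` and a hypothetical blockwise
# operator with two blocks of sizes 2 and 3, so that `dims` has one entry per
# block, e.g. `dims = [op.domain_dimension for op in blockwise_op.operators]`):
#
#   arg_is_blockwise(dims, [tf.ones([2, 1]), tf.ones([3, 1])], -2)  # ==> True
#       # A flat list with one Tensor per block is treated as blockwise.
#   arg_is_blockwise(dims, tf.ones([5, 1]), -2)                     # ==> False
#       # A single Tensor is never blockwise.
#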

def split_arg_into_blocks(block_dims, block_dims_fn, arg, axis=-1):
  """Split `arg` into blocks matching `operators`'s `domain_dimension`.

  Specifically, if we have a blockwise lower-triangular matrix, with block
  sizes along the diagonal `[M_j, M_j] j = 0,1,2..J`, this method splits `arg`
  on `axis` into `J` tensors, whose shape at `axis` is `M_j`.

  Args:
    block_dims: Iterable of `TensorShapes`.
    block_dims_fn: Callable returning an iterable of `Tensor`s.
    arg: `Tensor`. `arg` is split into `J` tensors.
    axis: Python `Integer` representing the axis to split `arg` on.

  Returns:
    A list of `Tensor`s.
  """
  block_sizes = [dim.value for dim in block_dims]
  if any(d is None for d in block_sizes):
    block_sizes = block_dims_fn()
  return array_ops.split(arg, block_sizes, axis=axis)
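

# Illustrative usage of `split_arg_into_blocks` (a sketch, not part of the
# library API; assumes `import tensorflow as tf`):
#
#   dims = tf.TensorShape([2, 3]).dims        # static block sizes 2 and 3
#   blocks = split_arg_into_blocks(
#       dims, lambda: [2, 3], tf.ones([5, 1]), axis=-2)
#   [b.shape for b in blocks]   # ==> [TensorShape([2, 1]), TensorShape([3, 1])]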