Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_addition.py: 33%

151 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Add one or more `LinearOperators` efficiently.""" 

16 

17import abc 

18 

19from tensorflow.python.framework import ops 

20from tensorflow.python.framework import tensor_shape 

21from tensorflow.python.ops import array_ops 

22from tensorflow.python.ops import check_ops 

23from tensorflow.python.ops.linalg import linear_operator 

24from tensorflow.python.ops.linalg import linear_operator_diag 

25from tensorflow.python.ops.linalg import linear_operator_full_matrix 

26from tensorflow.python.ops.linalg import linear_operator_identity 

27from tensorflow.python.ops.linalg import linear_operator_lower_triangular 

28 

29__all__ = [] 

30 

31 

def add_operators(operators,
                  operator_name=None,
                  addition_tiers=None,
                  name=None):
  """Efficiently add one or more linear operators.

  Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
  operators `[B1, B2,...]` such that

  ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```

  The operators `Bk` result by adding some of the `Ak`, as allowed by
  `addition_tiers`.

  Example of efficient adding of diagonal operators.

  ```python
  A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
  A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")

  # Use two tiers, the first contains an Adder that returns Diag.  Since both
  # A1 and A2 are Diag, they can use this Adder.  The second tier will not be
  # used.
  addition_tiers = [
      [_AddAndReturnDiag()],
      [_AddAndReturnMatrix()]]
  B_list = add_operators([A1, A2], addition_tiers=addition_tiers)

  len(B_list)
  ==> 1

  B_list[0].__class__.__name__
  ==> 'LinearOperatorDiag'

  B_list[0].to_dense()
  ==> [[3., 0.],
       [0., 3.]]

  B_list[0].name
  ==> 'Add/A1__A2/'
  ```

  Args:
    operators:  Iterable of `LinearOperator` objects with same `dtype`, domain
      and range dimensions, and broadcastable batch shapes.
    operator_name:  String name for returned `LinearOperator`.  Defaults to
      concatenation of "Add/A__B/" that indicates the order of addition steps.
    addition_tiers:  List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
      is a list of `Adder` objects.  This function attempts to do all additions
      in tier `i` before trying tier `i + 1`.
    name:  A name for this `Op`.  Defaults to `add_operators`.

  Returns:
    Subclass of `LinearOperator`.  Class and order of addition may change as
    new (and better) addition strategies emerge.

  Raises:
    ValueError:  If `operators` argument is empty.
    ValueError:  If shapes are incompatible.
  """
  # Fall back to the module-level default strategy when the caller gives none.
  if addition_tiers is None:
    addition_tiers = _DEFAULT_ADDITION_TIERS

  # Argument checking.
  check_ops.assert_proper_iterable(operators)
  # Reversing produces a nicer default order of pairwise additions below.
  operators = list(reversed(operators))
  if len(operators) < 1:
    raise ValueError(
        f"Argument `operators` must contain at least one operator. "
        f"Received: {operators}.")
  if not all(
      isinstance(op, linear_operator.LinearOperator) for op in operators):
    raise TypeError(
        f"Argument `operators` must contain only LinearOperator instances. "
        f"Received: {operators}.")
  _static_check_for_same_dimensions(operators)
  _static_check_for_broadcastable_batch_shape(operators)

  with ops.name_scope(name or "add_operators"):
    # Greedy pairwise reduction: within each tier, keep combining operators
    # until no adder in that tier can match a remaining pair.  Whatever is
    # left over is handed to the next tier.
    remaining = list(operators)
    for tier in addition_tiers:
      unmatched = []
      while remaining:
        op1 = remaining.pop()
        op2, adder = _pop_a_match_at_tier(op1, remaining, tier)
        if op2 is None:
          # No partner for op1 at this tier; retry it at the next tier.
          unmatched.append(op1)
        else:
          # The freshly added operator may itself combine with others, so it
          # goes back into the pool for this same tier.
          remaining.append(adder.add(op1, op2, operator_name))
      remaining = unmatched

    return remaining

129 

130 

131def _pop_a_match_at_tier(op1, operator_list, tier): 

132 # Search from the back of list to the front in order to create nice default 

133 # order of operations. 

134 for i in range(1, len(operator_list) + 1): 

135 op2 = operator_list[-i] 

136 for adder in tier: 

137 if adder.can_add(op1, op2): 

138 return operator_list.pop(-i), adder 

139 return None, None 

140 

141 

def _infer_hints_allowing_override(op1, op2, hints):
  """Infer hints from op1 and op2.  hints argument is an override.

  Args:
    op1:  LinearOperator
    op2:  LinearOperator
    hints:  _Hints object holding "is_X" boolean hints to use for returned
      operator.
      If some hint is None, try to set using op1 and op2.  If the
      hint is provided, ignore op1 and op2 hints.  This allows an override
      of previous hints, but does not allow forbidden hints (e.g. you still
      cannot say a real diagonal operator is not self-adjoint).

  Returns:
    _Hints object.
  """
  hints = hints or _Hints()
  # If A, B are self-adjoint, then so is A + B.
  if hints.is_self_adjoint is None:
    is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
  else:
    is_self_adjoint = hints.is_self_adjoint

  # If A, B are positive definite, then so is A + B.
  if hints.is_positive_definite is None:
    is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
  else:
    is_positive_definite = hints.is_positive_definite

  # A positive definite operator is always non-singular.
  # BUG FIX: the guard must be on the *non-singular* hint being unset, not the
  # positive-definite hint.  The previous condition
  # (`hints.is_positive_definite is None`) had two defects: it dropped the
  # implied non-singularity whenever the caller explicitly hinted positive
  # definiteness, and it overrode an explicit `is_non_singular` hint whenever
  # positive definiteness was inferred from op1/op2.
  if is_positive_definite and hints.is_non_singular is None:
    is_non_singular = True
  else:
    is_non_singular = hints.is_non_singular

  return _Hints(
      is_non_singular=is_non_singular,
      is_self_adjoint=is_self_adjoint,
      is_positive_definite=is_positive_definite)

181 

182 

def _static_check_for_same_dimensions(operators):
  """ValueError if operators determined to have different dimensions."""
  # Nothing to compare with fewer than two operators.
  if len(operators) < 2:
    return

  def _named_static_dims(attr_name):
    # (operator name, statically known dimension) pairs; operators whose
    # dimension is unknown at graph-construction time are skipped.
    pairs = []
    for op in operators:
      dim = tensor_shape.dimension_value(getattr(op, attr_name))
      if dim is not None:
        pairs.append((op.name, dim))
    return pairs

  domain_dimensions = _named_static_dims("domain_dimension")
  if len({dim for _, dim in domain_dimensions}) > 1:
    raise ValueError(f"All `operators` must have the same `domain_dimension`. "
                     f"Received: {domain_dimensions}.")

  range_dimensions = _named_static_dims("range_dimension")
  if len({dim for _, dim in range_dimensions}) > 1:
    raise ValueError(f"All operators must have the same `range_dimension`. "
                     f"Received: {range_dimensions}.")

203 

204 

def _static_check_for_broadcastable_batch_shape(operators):
  """ValueError if operators determined to have non-broadcastable shapes."""
  if len(operators) < 2:
    return

  # broadcast_static_shape raises if a pair cannot be broadcast together; the
  # accumulated shape itself is discarded — only the side effect matters.
  accumulated = operators[0].batch_shape
  for operator in operators[1:]:
    accumulated = array_ops.broadcast_static_shape(
        accumulated, operator.batch_shape)

214 

215 

216class _Hints: 

217 """Holds 'is_X' flags that every LinearOperator is initialized with.""" 

218 

219 def __init__(self, 

220 is_non_singular=None, 

221 is_positive_definite=None, 

222 is_self_adjoint=None): 

223 self.is_non_singular = is_non_singular 

224 self.is_positive_definite = is_positive_definite 

225 self.is_self_adjoint = is_self_adjoint 

226 

227 

228################################################################################ 

229# Classes to add two linear operators. 

230################################################################################ 

231 

232 

class _Adder(metaclass=abc.ABCMeta):
  """Abstract base class to add two operators.

  Each `Adder` acts independently, adding everything it can, paying no
  attention as to whether another `Adder` could have done the addition more
  efficiently.
  """

  @property
  def name(self):
    # The concrete subclass name doubles as the adder's display name.
    return self.__class__.__name__

  @abc.abstractmethod
  def can_add(self, op1, op2):
    """Returns `True` if this `Adder` can add `op1` and `op2`.  Else `False`."""
    pass

  @abc.abstractmethod
  def _add(self, op1, op2, operator_name, hints):
    # Derived classes can assume op1 and op2 have been validated, e.g. they
    # have the same dtype, and their domain/range dimensions match.
    pass

  def add(self, op1, op2, operator_name, hints=None):
    """Return new `LinearOperator` acting like `op1 + op2`.

    Args:
      op1:  `LinearOperator`
      op2:  `LinearOperator`, with `shape` and `dtype` such that adding to
        `op1` is allowed.
      operator_name:  `String` name to give to returned `LinearOperator`
      hints:  `_Hints` object.  Returned `LinearOperator` will be created with
        these hints.

    Returns:
      `LinearOperator`
    """
    updated_hints = _infer_hints_allowing_override(op1, op2, hints)

    if operator_name is None:
      operator_name = "Add/" + op1.name + "__" + op2.name + "/"

    # Drop a single leading underscore (private class names) for the scope.
    scope_name = self.name[1:] if self.name.startswith("_") else self.name
    with ops.name_scope(scope_name):
      return self._add(op1, op2, operator_name, updated_hints)

279 

280 

class _AddAndReturnScaledIdentity(_Adder):
  """Handles additions resulting in an Identity family member.

  The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`)
  family is closed under addition.  This `Adder` respects that, and returns an
  Identity.
  """

  def can_add(self, op1, op2):
    # Both operands must belong to the identity family.
    return not {_type(op1), _type(op2)}.difference(_IDENTITY_FAMILY)

  def _add(self, op1, op2, operator_name, hints):
    # The result is always a LinearOperatorScaledIdentity whose multiplier is
    # the sum of each operand's (possibly implicit) multiplier.

    def _multiplier_of(op):
      if _type(op) == _SCALED_IDENTITY:
        return op.multiplier
      # A plain identity behaves as a scaled identity with multiplier one.
      return array_ops.ones(op.batch_shape_tensor(), dtype=op.dtype)

    return linear_operator_identity.LinearOperatorScaledIdentity(
        num_rows=op1.range_dimension_tensor(),
        multiplier=_multiplier_of(op1) + _multiplier_of(op2),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)

312 

313 

class _AddAndReturnDiag(_Adder):
  """Handles additions resulting in a Diag operator."""

  def can_add(self, op1, op2):
    # Both operands must be diag-like (diag, identity, or scaled identity).
    return not {_type(op1), _type(op2)}.difference(_DIAG_LIKE)

  def _add(self, op1, op2, operator_name, hints):
    # Diag-like operators add elementwise along their diagonals.
    summed_diag = op1.diag_part() + op2.diag_part()
    return linear_operator_diag.LinearOperatorDiag(
        diag=summed_diag,
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)

328 

329 

class _AddAndReturnTriL(_Adder):
  """Handles additions resulting in a TriL operator."""

  def can_add(self, op1, op2):
    # Diag-like + lower-triangular stays lower-triangular.
    allowed_types = _DIAG_LIKE.union({_TRIL})
    return not {_type(op1), _type(op2)}.difference(allowed_types)

  def _add(self, op1, op2, operator_name, hints):
    # Let the operand with an efficient add_to_tensor do the dense addition.
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      efficient_op, other_op = op1, op2
    else:
      efficient_op, other_op = op2, op1

    return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
        tril=efficient_op.add_to_tensor(other_op.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)

349 

350 

class _AddAndReturnMatrix(_Adder):
  """Handles additions resulting in a `LinearOperatorFullMatrix`."""

  def can_add(self, op1, op2):  # pylint: disable=unused-argument
    # Catch-all: any pair of LinearOperators can be added densely.
    return (isinstance(op1, linear_operator.LinearOperator) and
            isinstance(op2, linear_operator.LinearOperator))

  def _add(self, op1, op2, operator_name, hints):
    # Let the operand with an efficient add_to_tensor do the dense addition.
    if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
      efficient_op, other_op = op1, op2
    else:
      efficient_op, other_op = op2, op1
    return linear_operator_full_matrix.LinearOperatorFullMatrix(
        matrix=efficient_op.add_to_tensor(other_op.to_dense()),
        is_non_singular=hints.is_non_singular,
        is_self_adjoint=hints.is_self_adjoint,
        is_positive_definite=hints.is_positive_definite,
        name=operator_name)

369 

370 

371################################################################################ 

372# Constants designating types of LinearOperators 

373################################################################################ 

374 

375# Type name constants for LinearOperator classes. 

# Type name constants for LinearOperator classes.  Returned by _type() and
# used by the Adders' can_add() checks.
_IDENTITY = "identity"
_SCALED_IDENTITY = "scaled_identity"
_DIAG = "diag"
_TRIL = "tril"
_MATRIX = "matrix"

# Groups of operators.
# Operators whose dense form is diagonal (identity counts as diagonal of 1s).
_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
# operators with an efficient .add_to_tensor() method.
_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE

# Supported LinearOperator classes.
# These are exactly the classes recognized by _type(); anything else raises
# TypeError there.
SUPPORTED_OPERATORS = [
    linear_operator_diag.LinearOperatorDiag,
    linear_operator_lower_triangular.LinearOperatorLowerTriangular,
    linear_operator_full_matrix.LinearOperatorFullMatrix,
    linear_operator_identity.LinearOperatorIdentity,
    linear_operator_identity.LinearOperatorScaledIdentity
]

396 

397 

def _type(operator):
  """Returns the type name constant (e.g. _TRIL) for operator."""
  # Checked in the same order as before; first isinstance match wins.
  dispatch = (
      (linear_operator_diag.LinearOperatorDiag, _DIAG),
      (linear_operator_lower_triangular.LinearOperatorLowerTriangular, _TRIL),
      (linear_operator_full_matrix.LinearOperatorFullMatrix, _MATRIX),
      (linear_operator_identity.LinearOperatorIdentity, _IDENTITY),
      (linear_operator_identity.LinearOperatorScaledIdentity,
       _SCALED_IDENTITY),
  )
  for operator_class, type_name in dispatch:
    if isinstance(operator, operator_class):
      return type_name
  raise TypeError(f"Expected operator to be one of [LinearOperatorDiag, "
                  f"LinearOperatorLowerTriangular, LinearOperatorFullMatrix, "
                  f"LinearOperatorIdentity, LinearOperatorScaledIdentity]. "
                  f"Received: {operator}")

416 

417 

418################################################################################ 

419# Addition tiers: 

420# We attempt to use Adders in tier K before K+1. 

421# 

422# Organize tiers to 

423# (i) reduce O(..) complexity of forming final operator, and 

424# (ii) produce the "most efficient" final operator. 

425# Dev notes: 

426# * Results of addition at tier K will be added at tier K or higher. 

427# * Tiers may change, and we warn the user that it may change. 

428################################################################################ 

429 

# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
# dense matrix.  So it is sometimes very inefficient.
# Tiers are ordered cheapest-result first: scaled identity, then diag, then
# lower-triangular, then full matrix as the universal fallback.
_DEFAULT_ADDITION_TIERS = [
    [_AddAndReturnScaledIdentity()],
    [_AddAndReturnDiag()],
    [_AddAndReturnTriL()],
    [_AddAndReturnMatrix()],
]

437]