Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_algebra.py: 55%

109 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Registration mechanisms for various n-ary operations on LinearOperators.""" 

17 

18import itertools 

19 

20from tensorflow.python.framework import ops 

21from tensorflow.python.util import tf_inspect 

22 

23 

24_ADJOINTS = {} 

25_CHOLESKY_DECOMPS = {} 

26_MATMUL = {} 

27_SOLVE = {} 

28_INVERSES = {} 

29 

30 

31def _registered_function(type_list, registry): 

32 """Given a list of classes, finds the most specific function registered.""" 

33 enumerated_hierarchies = [enumerate(tf_inspect.getmro(t)) for t in type_list] 

34 # Get all possible combinations of hierarchies. 

35 cls_combinations = list(itertools.product(*enumerated_hierarchies)) 

36 

37 def hierarchy_distance(cls_combination): 

38 candidate_distance = sum(c[0] for c in cls_combination) 

39 if tuple(c[1] for c in cls_combination) in registry: 

40 return candidate_distance 

41 return 10000 

42 

43 registered_combination = min(cls_combinations, key=hierarchy_distance) 

44 return registry.get(tuple(r[1] for r in registered_combination), None) 

45 

46 

def _registered_adjoint(type_a):
  """Return the Adjoint function registered for class `type_a`, if any.

  The lookup goes through `_registered_function`, so a function registered
  for a base class of `type_a` is also found. Returns `None` when nothing
  matches.
  """
  return _registered_function(type_list=[type_a], registry=_ADJOINTS)

50 

51 

def _registered_cholesky(type_a):
  """Return the Cholesky function registered for class `type_a`, if any.

  The lookup goes through `_registered_function`, so a function registered
  for a base class of `type_a` is also found. Returns `None` when nothing
  matches.
  """
  return _registered_function(type_list=[type_a], registry=_CHOLESKY_DECOMPS)

55 

56 

def _registered_matmul(type_a, type_b):
  """Return the Matmul function registered for `(type_a, type_b)`, if any.

  The lookup goes through `_registered_function`, so a function registered
  for base classes of either argument is also found. Returns `None` when
  nothing matches.
  """
  operand_types = [type_a, type_b]
  return _registered_function(type_list=operand_types, registry=_MATMUL)

60 

61 

def _registered_solve(type_a, type_b):
  """Return the Solve function registered for `(type_a, type_b)`, if any.

  The lookup goes through `_registered_function`, so a function registered
  for base classes of either argument is also found. Returns `None` when
  nothing matches.
  """
  operand_types = [type_a, type_b]
  return _registered_function(type_list=operand_types, registry=_SOLVE)

65 

66 

def _registered_inverse(type_a):
  """Get the Inverse function registered for class a.

  Args:
    type_a: The class to look up in the `_INVERSES` registry.

  Returns:
    The most specific registered Inverse function for `type_a` (a function
    registered for one of its base classes also matches), or `None` if
    nothing is registered.
  """
  return _registered_function([type_a], _INVERSES)

70 

71 

def adjoint(lin_op_a, name=None):
  """Get the adjoint associated to lin_op_a.

  Dispatches on `type(lin_op_a)` to the most specific function registered
  with `RegisterAdjoint` and calls it inside a name scope.

  Args:
    lin_op_a: The LinearOperator to take the adjoint of.
    name: Name to use for this operation.

  Returns:
    A LinearOperator that represents the adjoint of `lin_op_a`.

  Raises:
    ValueError: If no Adjoint method is registered for the LinearOperator
      type of `lin_op_a` or any of its base classes.
  """
  adjoint_fn = _registered_adjoint(type(lin_op_a))
  if adjoint_fn is None:
    raise ValueError("No adjoint registered for {}".format(
        type(lin_op_a)))

  with ops.name_scope(name, "Adjoint"):
    return adjoint_fn(lin_op_a)

93 

94 

def cholesky(lin_op_a, name=None):
  """Get the Cholesky factor associated to lin_op_a.

  Dispatches on `type(lin_op_a)` to the most specific function registered
  with `RegisterCholesky` and calls it inside a name scope.

  Args:
    lin_op_a: The LinearOperator to decompose.
    name: Name to use for this operation.

  Returns:
    A LinearOperator that represents the lower Cholesky factor of `lin_op_a`.

  Raises:
    ValueError: If no Cholesky method is registered for the LinearOperator
      type of `lin_op_a` or any of its base classes.
  """
  cholesky_fn = _registered_cholesky(type(lin_op_a))
  if cholesky_fn is None:
    raise ValueError("No cholesky decomposition registered for {}".format(
        type(lin_op_a)))

  with ops.name_scope(name, "Cholesky"):
    return cholesky_fn(lin_op_a)

116 

117 

def matmul(lin_op_a, lin_op_b, name=None):
  """Compute lin_op_a.matmul(lin_op_b).

  Dispatches on the types of both operands to the most specific function
  registered with `RegisterMatmul` and calls it inside a name scope.

  Args:
    lin_op_a: The LinearOperator on the left.
    lin_op_b: The LinearOperator on the right.
    name: Name to use for this operation.

  Returns:
    A LinearOperator that represents the matmul between `lin_op_a` and
      `lin_op_b`.

  Raises:
    ValueError: If no matmul method is registered between the types of
      `lin_op_a` and `lin_op_b` (or any of their base classes).
  """
  matmul_fn = _registered_matmul(type(lin_op_a), type(lin_op_b))
  if matmul_fn is None:
    raise ValueError("No matmul registered for {}.matmul({})".format(
        type(lin_op_a), type(lin_op_b)))

  with ops.name_scope(name, "Matmul"):
    return matmul_fn(lin_op_a, lin_op_b)

141 

142 

def solve(lin_op_a, lin_op_b, name=None):
  """Compute lin_op_a.solve(lin_op_b).

  Dispatches on the types of both operands to the most specific function
  registered with `RegisterSolve` and calls it inside a name scope.

  Args:
    lin_op_a: The LinearOperator on the left.
    lin_op_b: The LinearOperator on the right.
    name: Name to use for this operation.

  Returns:
    A LinearOperator that represents the solve between `lin_op_a` and
      `lin_op_b`.

  Raises:
    ValueError: If no solve method is registered between the types of
      `lin_op_a` and `lin_op_b` (or any of their base classes).
  """
  solve_fn = _registered_solve(type(lin_op_a), type(lin_op_b))
  if solve_fn is None:
    raise ValueError("No solve registered for {}.solve({})".format(
        type(lin_op_a), type(lin_op_b)))

  with ops.name_scope(name, "Solve"):
    return solve_fn(lin_op_a, lin_op_b)

166 

167 

def inverse(lin_op_a, name=None):
  """Get the Inverse associated to lin_op_a.

  Dispatches on `type(lin_op_a)` to the most specific function registered
  with `RegisterInverse` and calls it inside a name scope.

  Args:
    lin_op_a: The LinearOperator to invert.
    name: Name to use for this operation.

  Returns:
    A LinearOperator that represents the inverse of `lin_op_a`.

  Raises:
    ValueError: If no Inverse method is registered for the LinearOperator
      type of `lin_op_a` or any of its base classes.
  """
  inverse_fn = _registered_inverse(type(lin_op_a))
  if inverse_fn is None:
    raise ValueError("No inverse registered for {}".format(
        type(lin_op_a)))

  with ops.name_scope(name, "Inverse"):
    return inverse_fn(lin_op_a)

189 

190 

class RegisterAdjoint:
  """Decorator to register an Adjoint implementation function.

  Usage:

  @linear_operator_algebra.RegisterAdjoint(lin_op.LinearOperatorIdentity)
  def _adjoint_identity(lin_op_a):
    # Return the identity matrix.
  """

  def __init__(self, lin_op_cls_a):
    """Initialize the LinearOperator registrar.

    Args:
      lin_op_cls_a: the class of the LinearOperator to take the adjoint of.
    """
    # Registry keys are tuples of classes; unary operations use a 1-tuple.
    self._key = (lin_op_cls_a,)

  def __call__(self, adjoint_fn):
    """Perform the Adjoint registration.

    Args:
      adjoint_fn: The function to use for the Adjoint.

    Returns:
      adjoint_fn

    Raises:
      TypeError: if adjoint_fn is not a callable.
      ValueError: if an Adjoint function has already been registered for
        the given argument classes.
    """
    if not callable(adjoint_fn):
      raise TypeError(
          "adjoint_fn must be callable, received: {}".format(adjoint_fn))
    if self._key in _ADJOINTS:
      raise ValueError("Adjoint({}) has already been registered to: {}".format(
          self._key[0].__name__, _ADJOINTS[self._key]))
    _ADJOINTS[self._key] = adjoint_fn
    # Return the function unchanged so the decorated name stays usable.
    return adjoint_fn

231 

232 

class RegisterCholesky:
  """Class decorator-factory that registers a Cholesky implementation.

  Usage:

  @linear_operator_algebra.RegisterCholesky(lin_op.LinearOperatorIdentity)
  def _cholesky_identity(lin_op_a):
    # Return the identity matrix.
  """

  def __init__(self, lin_op_cls_a):
    """Remember which LinearOperator class the registration is for.

    Args:
      lin_op_cls_a: the class of the LinearOperator to decompose.
    """
    # Registry keys are tuples of classes; unary operations use a 1-tuple.
    self._key = (lin_op_cls_a,)

  def __call__(self, cholesky_fn):
    """Store `cholesky_fn` in the Cholesky registry.

    Args:
      cholesky_fn: The function to use for the Cholesky.

    Returns:
      cholesky_fn, unchanged.

    Raises:
      TypeError: if cholesky_fn is not a callable.
      ValueError: if a Cholesky function has already been registered for
        the given argument classes.
    """
    if not callable(cholesky_fn):
      raise TypeError(
          "cholesky_fn must be callable, received: {}".format(cholesky_fn))

    key = self._key
    if key in _CHOLESKY_DECOMPS:
      (registered_cls,) = key
      raise ValueError("Cholesky({}) has already been registered to: {}".format(
          registered_cls.__name__, _CHOLESKY_DECOMPS[key]))

    _CHOLESKY_DECOMPS[key] = cholesky_fn
    return cholesky_fn

273 

274 

class RegisterMatmul:
  """Class decorator-factory that registers a Matmul implementation.

  Usage:

  @linear_operator_algebra.RegisterMatmul(
    lin_op.LinearOperatorIdentity,
    lin_op.LinearOperatorIdentity)
  def _matmul_identity(a, b):
    # Return the identity matrix.
  """

  def __init__(self, lin_op_cls_a, lin_op_cls_b):
    """Remember which pair of LinearOperator classes is being registered.

    Args:
      lin_op_cls_a: the class of the LinearOperator to multiply.
      lin_op_cls_b: the class of the second LinearOperator to multiply.
    """
    # Binary operations are keyed by the ordered pair of operand classes.
    self._key = (lin_op_cls_a, lin_op_cls_b)

  def __call__(self, matmul_fn):
    """Store `matmul_fn` in the Matmul registry.

    Args:
      matmul_fn: The function to use for the Matmul.

    Returns:
      matmul_fn, unchanged.

    Raises:
      TypeError: if matmul_fn is not a callable.
      ValueError: if a Matmul function has already been registered for
        the given argument classes.
    """
    if not callable(matmul_fn):
      raise TypeError(
          "matmul_fn must be callable, received: {}".format(matmul_fn))

    key = self._key
    if key in _MATMUL:
      left_cls, right_cls = key
      raise ValueError("Matmul({}, {}) has already been registered.".format(
          left_cls.__name__, right_cls.__name__))

    _MATMUL[key] = matmul_fn
    return matmul_fn

319 

320 

class RegisterSolve:
  """Class decorator-factory that registers a Solve implementation.

  Usage:

  @linear_operator_algebra.RegisterSolve(
    lin_op.LinearOperatorIdentity,
    lin_op.LinearOperatorIdentity)
  def _solve_identity(a, b):
    # Return the identity matrix.
  """

  def __init__(self, lin_op_cls_a, lin_op_cls_b):
    """Remember which pair of LinearOperator classes is being registered.

    Args:
      lin_op_cls_a: the class of the LinearOperator that is computing solve.
      lin_op_cls_b: the class of the second LinearOperator to solve.
    """
    # Binary operations are keyed by the ordered pair of operand classes.
    self._key = (lin_op_cls_a, lin_op_cls_b)

  def __call__(self, solve_fn):
    """Store `solve_fn` in the Solve registry.

    Args:
      solve_fn: The function to use for the Solve.

    Returns:
      solve_fn, unchanged.

    Raises:
      TypeError: if solve_fn is not a callable.
      ValueError: if a Solve function has already been registered for
        the given argument classes.
    """
    if not callable(solve_fn):
      raise TypeError(
          "solve_fn must be callable, received: {}".format(solve_fn))

    key = self._key
    if key in _SOLVE:
      left_cls, right_cls = key
      raise ValueError("Solve({}, {}) has already been registered.".format(
          left_cls.__name__, right_cls.__name__))

    _SOLVE[key] = solve_fn
    return solve_fn

365 

366 

class RegisterInverse:
  """Decorator to register an Inverse implementation function.

  Usage:

  @linear_operator_algebra.RegisterInverse(lin_op.LinearOperatorIdentity)
  def _inverse_identity(lin_op_a):
    # Return the identity matrix.
  """

  def __init__(self, lin_op_cls_a):
    """Initialize the LinearOperator registrar.

    Args:
      lin_op_cls_a: the class of the LinearOperator to invert.
    """
    # Registry keys are tuples of classes; unary operations use a 1-tuple.
    self._key = (lin_op_cls_a,)

  def __call__(self, inverse_fn):
    """Perform the Inverse registration.

    Args:
      inverse_fn: The function to use for the Inverse.

    Returns:
      inverse_fn

    Raises:
      TypeError: if inverse_fn is not a callable.
      ValueError: if an Inverse function has already been registered for
        the given argument classes.
    """
    if not callable(inverse_fn):
      raise TypeError(
          "inverse_fn must be callable, received: {}".format(inverse_fn))
    if self._key in _INVERSES:
      raise ValueError("Inverse({}) has already been registered to: {}".format(
          self._key[0].__name__, _INVERSES[self._key]))
    _INVERSES[self._key] = inverse_fn
    # Return the function unchanged so the decorated name stays usable.
    return inverse_fn