Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/linear_operator_algebra.py: 55%
109 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Registration mechanisms for various n-ary operations on LinearOperators."""
18import itertools
20from tensorflow.python.framework import ops
21from tensorflow.python.util import tf_inspect
# Module-level registries filled in by the Register* decorator classes. Each
# one maps a tuple of LinearOperator classes (a 1-tuple for unary operations,
# a 2-tuple for matmul/solve) to the implementation function registered for
# that combination.
_ADJOINTS = dict()
_CHOLESKY_DECOMPS = dict()
_MATMUL = dict()
_SOLVE = dict()
_INVERSES = dict()
31def _registered_function(type_list, registry):
32 """Given a list of classes, finds the most specific function registered."""
33 enumerated_hierarchies = [enumerate(tf_inspect.getmro(t)) for t in type_list]
34 # Get all possible combinations of hierarchies.
35 cls_combinations = list(itertools.product(*enumerated_hierarchies))
37 def hierarchy_distance(cls_combination):
38 candidate_distance = sum(c[0] for c in cls_combination)
39 if tuple(c[1] for c in cls_combination) in registry:
40 return candidate_distance
41 return 10000
43 registered_combination = min(cls_combinations, key=hierarchy_distance)
44 return registry.get(tuple(r[1] for r in registered_combination), None)
def _registered_adjoint(type_a):
  """Returns the adjoint implementation registered for `type_a`, or None.

  The lookup searches `type_a`'s class hierarchy, so an implementation
  registered for a base class is also found for its subclasses.
  """
  candidate_types = [type_a]
  return _registered_function(candidate_types, _ADJOINTS)
def _registered_cholesky(type_a):
  """Looks up the Cholesky implementation for `type_a` or its nearest base.

  Returns None when nothing in the class hierarchy has one registered.
  """
  candidate_types = [type_a]
  return _registered_function(candidate_types, _CHOLESKY_DECOMPS)
def _registered_matmul(type_a, type_b):
  """Returns the matmul implementation registered for (`type_a`, `type_b`).

  Pairs of base classes are also considered. Returns None if no pair drawn
  from the two class hierarchies is registered.
  """
  operand_types = [type_a, type_b]
  return _registered_function(operand_types, _MATMUL)
def _registered_solve(type_a, type_b):
  """Returns the solve implementation registered for (`type_a`, `type_b`).

  Pairs of base classes are also considered. Returns None if no pair drawn
  from the two class hierarchies is registered.
  """
  operand_types = [type_a, type_b]
  return _registered_function(operand_types, _SOLVE)
def _registered_inverse(type_a):
  """Get the Inverse function registered for class a.

  Returns None if neither `type_a` nor any of its base classes has an
  Inverse function registered.
  """
  return _registered_function([type_a], _INVERSES)
def adjoint(lin_op_a, name=None):
  """Get the adjoint associated to lin_op_a.

  Args:
    lin_op_a: The LinearOperator to take the adjoint of.
    name: Name to use for this operation.

  Returns:
    A LinearOperator that represents the adjoint of `lin_op_a`.

  Raises:
    ValueError: If no Adjoint method is registered for the type of `lin_op_a`
      or any of its base classes.
  """
  adjoint_fn = _registered_adjoint(type(lin_op_a))
  if adjoint_fn is None:
    raise ValueError("No adjoint registered for {}".format(
        type(lin_op_a)))

  # Run the registered implementation inside a name scope ("Adjoint" is
  # passed as the default name).
  with ops.name_scope(name, "Adjoint"):
    return adjoint_fn(lin_op_a)
def cholesky(lin_op_a, name=None):
  """Get the Cholesky factor associated to lin_op_a.

  Args:
    lin_op_a: The LinearOperator to decompose.
    name: Name to use for this operation.

  Returns:
    A LinearOperator that represents the lower Cholesky factor of `lin_op_a`.

  Raises:
    ValueError: If no Cholesky method is registered for the type of
      `lin_op_a` or any of its base classes.
  """
  cholesky_fn = _registered_cholesky(type(lin_op_a))
  if cholesky_fn is None:
    raise ValueError("No cholesky decomposition registered for {}".format(
        type(lin_op_a)))

  # Run the registered implementation inside a name scope ("Cholesky" is
  # passed as the default name).
  with ops.name_scope(name, "Cholesky"):
    return cholesky_fn(lin_op_a)
def matmul(lin_op_a, lin_op_b, name=None):
  """Compute lin_op_a.matmul(lin_op_b).

  Args:
    lin_op_a: The LinearOperator on the left.
    lin_op_b: The LinearOperator on the right.
    name: Name to use for this operation.

  Returns:
    A LinearOperator that represents the matmul between `lin_op_a` and
      `lin_op_b`.

  Raises:
    ValueError: If no matmul method is registered for the types of
      `lin_op_a` and `lin_op_b` (or any combination of their base classes).
  """
  matmul_fn = _registered_matmul(type(lin_op_a), type(lin_op_b))
  if matmul_fn is None:
    raise ValueError("No matmul registered for {}.matmul({})".format(
        type(lin_op_a), type(lin_op_b)))

  # Run the registered implementation inside a name scope ("Matmul" is
  # passed as the default name).
  with ops.name_scope(name, "Matmul"):
    return matmul_fn(lin_op_a, lin_op_b)
def solve(lin_op_a, lin_op_b, name=None):
  """Compute lin_op_a.solve(lin_op_b).

  Args:
    lin_op_a: The LinearOperator on the left.
    lin_op_b: The LinearOperator on the right.
    name: Name to use for this operation.

  Returns:
    A LinearOperator that represents the solve between `lin_op_a` and
      `lin_op_b`.

  Raises:
    ValueError: If no solve method is registered for the types of
      `lin_op_a` and `lin_op_b` (or any combination of their base classes).
  """
  solve_fn = _registered_solve(type(lin_op_a), type(lin_op_b))
  if solve_fn is None:
    raise ValueError("No solve registered for {}.solve({})".format(
        type(lin_op_a), type(lin_op_b)))

  # Run the registered implementation inside a name scope ("Solve" is
  # passed as the default name).
  with ops.name_scope(name, "Solve"):
    return solve_fn(lin_op_a, lin_op_b)
def inverse(lin_op_a, name=None):
  """Get the Inverse associated to lin_op_a.

  Args:
    lin_op_a: The LinearOperator to invert.
    name: Name to use for this operation.

  Returns:
    A LinearOperator that represents the inverse of `lin_op_a`.

  Raises:
    ValueError: If no Inverse method is registered for the type of `lin_op_a`
      or any of its base classes.
  """
  inverse_fn = _registered_inverse(type(lin_op_a))
  if inverse_fn is None:
    raise ValueError("No inverse registered for {}".format(
        type(lin_op_a)))

  # Run the registered implementation inside a name scope ("Inverse" is
  # passed as the default name).
  with ops.name_scope(name, "Inverse"):
    return inverse_fn(lin_op_a)
class RegisterAdjoint:
  """Decorator to register an Adjoint implementation function.

  Usage:

  @linear_operator_algebra.RegisterAdjoint(lin_op.LinearOperatorIdentity)
  def _adjoint_identity(lin_op_a):
    # Return the identity matrix.
  """

  def __init__(self, lin_op_cls_a):
    """Initialize the LinearOperator registrar.

    Args:
      lin_op_cls_a: the class of the LinearOperator to take the adjoint of.
    """
    self._key = (lin_op_cls_a,)

  def __call__(self, adjoint_fn):
    """Perform the Adjoint registration.

    Args:
      adjoint_fn: The function to use for the Adjoint.

    Returns:
      adjoint_fn, unchanged, so this can be used as a decorator.

    Raises:
      TypeError: if adjoint_fn is not a callable.
      ValueError: if an Adjoint function has already been registered for
        the given argument classes.
    """
    if not callable(adjoint_fn):
      raise TypeError(
          "adjoint_fn must be callable, received: {}".format(adjoint_fn))
    if self._key in _ADJOINTS:
      raise ValueError("Adjoint({}) has already been registered to: {}".format(
          self._key[0].__name__, _ADJOINTS[self._key]))
    _ADJOINTS[self._key] = adjoint_fn
    return adjoint_fn
class RegisterCholesky:
  """Decorator that records a Cholesky implementation for a class.

  An instance is created with the LinearOperator class. When it is then
  applied to a function, it stores that function in the Cholesky registry.

  Usage:

  @linear_operator_algebra.RegisterCholesky(lin_op.LinearOperatorIdentity)
  def _cholesky_identity(lin_op_a):
    # Return the identity matrix.
  """

  def __init__(self, lin_op_cls_a):
    """Remember which LinearOperator class the implementation will serve.

    Args:
      lin_op_cls_a: the class of the LinearOperator to decompose.
    """
    key = (lin_op_cls_a,)
    self._key = key

  def __call__(self, cholesky_fn):
    """Store `cholesky_fn` in the registry and hand it back unchanged.

    Args:
      cholesky_fn: The function to use for the Cholesky.

    Returns:
      cholesky_fn, so the instance can be used as a decorator.

    Raises:
      TypeError: if cholesky_fn is not a callable.
      ValueError: if a Cholesky function has already been registered for
        the given argument classes.
    """
    if not callable(cholesky_fn):
      message = "cholesky_fn must be callable, received: {}".format(
          cholesky_fn)
      raise TypeError(message)
    key = self._key
    if key in _CHOLESKY_DECOMPS:
      cls_name = key[0].__name__
      previous = _CHOLESKY_DECOMPS[key]
      message = "Cholesky({}) has already been registered to: {}".format(
          cls_name, previous)
      raise ValueError(message)
    _CHOLESKY_DECOMPS[key] = cholesky_fn
    return cholesky_fn
class RegisterMatmul:
  """Decorator that records a Matmul implementation for a pair of classes.

  Usage:

  @linear_operator_algebra.RegisterMatmul(
    lin_op.LinearOperatorIdentity,
    lin_op.LinearOperatorIdentity)
  def _matmul_identity(a, b):
    # Return the identity matrix.
  """

  def __init__(self, lin_op_cls_a, lin_op_cls_b):
    """Remember the (left, right) LinearOperator classes to register for.

    Args:
      lin_op_cls_a: the class of the LinearOperator to multiply.
      lin_op_cls_b: the class of the second LinearOperator to multiply.
    """
    pair = (lin_op_cls_a, lin_op_cls_b)
    self._key = pair

  def __call__(self, matmul_fn):
    """Store `matmul_fn` in the Matmul registry and return it unchanged.

    Args:
      matmul_fn: The function to use for the Matmul.

    Returns:
      matmul_fn, so the instance can be used as a decorator.

    Raises:
      TypeError: if matmul_fn is not a callable.
      ValueError: if a Matmul function has already been registered for
        the given argument classes.
    """
    if not callable(matmul_fn):
      message = "matmul_fn must be callable, received: {}".format(matmul_fn)
      raise TypeError(message)
    key = self._key
    if key in _MATMUL:
      cls_a, cls_b = key
      message = "Matmul({}, {}) has already been registered.".format(
          cls_a.__name__, cls_b.__name__)
      raise ValueError(message)
    _MATMUL[key] = matmul_fn
    return matmul_fn
class RegisterSolve:
  """Decorator that records a Solve implementation for a pair of classes.

  Usage:

  @linear_operator_algebra.RegisterSolve(
    lin_op.LinearOperatorIdentity,
    lin_op.LinearOperatorIdentity)
  def _solve_identity(a, b):
    # Return the identity matrix.
  """

  def __init__(self, lin_op_cls_a, lin_op_cls_b):
    """Remember the (solver, right-hand side) classes to register for.

    Args:
      lin_op_cls_a: the class of the LinearOperator that is computing solve.
      lin_op_cls_b: the class of the second LinearOperator to solve.
    """
    pair = (lin_op_cls_a, lin_op_cls_b)
    self._key = pair

  def __call__(self, solve_fn):
    """Store `solve_fn` in the Solve registry and return it unchanged.

    Args:
      solve_fn: The function to use for the Solve.

    Returns:
      solve_fn, so the instance can be used as a decorator.

    Raises:
      TypeError: if solve_fn is not a callable.
      ValueError: if a Solve function has already been registered for
        the given argument classes.
    """
    if not callable(solve_fn):
      message = "solve_fn must be callable, received: {}".format(solve_fn)
      raise TypeError(message)
    key = self._key
    if key in _SOLVE:
      cls_a, cls_b = key
      message = "Solve({}, {}) has already been registered.".format(
          cls_a.__name__, cls_b.__name__)
      raise ValueError(message)
    _SOLVE[key] = solve_fn
    return solve_fn
class RegisterInverse:
  """Decorator to register an Inverse implementation function.

  Usage:

  @linear_operator_algebra.RegisterInverse(lin_op.LinearOperatorIdentity)
  def _inverse_identity(lin_op_a):
    # Return the identity matrix.
  """

  def __init__(self, lin_op_cls_a):
    """Initialize the LinearOperator registrar.

    Args:
      lin_op_cls_a: the class of the LinearOperator to invert.
    """
    self._key = (lin_op_cls_a,)

  def __call__(self, inverse_fn):
    """Perform the Inverse registration.

    Args:
      inverse_fn: The function to use for the Inverse.

    Returns:
      inverse_fn, unchanged, so this can be used as a decorator.

    Raises:
      TypeError: if inverse_fn is not a callable.
      ValueError: if an Inverse function has already been registered for
        the given argument classes.
    """
    if not callable(inverse_fn):
      raise TypeError(
          "inverse_fn must be callable, received: {}".format(inverse_fn))
    if self._key in _INVERSES:
      raise ValueError("Inverse({}) has already been registered to: {}".format(
          self._key[0].__name__, _INVERSES[self._key]))
    _INVERSES[self._key] = inverse_fn
    return inverse_fn