1"""
2Required functions for optimized contractions of numpy arrays using pytorch.
3"""
5import numpy as np
7from ..parser import convert_to_valid_einsum_chars
8from ..sharing import to_backend_cache_wrap
10__all__ = [
11 "transpose",
12 "einsum",
13 "tensordot",
14 "to_torch",
15 "build_expression",
16 "evaluate_constants",
17]
19_TORCH_DEVICE = None
20_TORCH_HAS_TENSORDOT = None
22_torch_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
25def _get_torch_and_device():
26 global _TORCH_DEVICE
27 global _TORCH_HAS_TENSORDOT
29 if _TORCH_DEVICE is None:
30 import torch
32 device = "cuda" if torch.cuda.is_available() else "cpu"
33 _TORCH_DEVICE = torch, device
34 _TORCH_HAS_TENSORDOT = hasattr(torch, "tensordot")
36 return _TORCH_DEVICE
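
# A minimal usage sketch (assumes torch is installed; the (module, device)
# tuple is cached globally, so only the first call pays for the import):
#
#     torch, device = _get_torch_and_device()
#     assert device in ("cpu", "cuda")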


def transpose(a, axes):
    """Normal torch transpose is only valid for 2D matrices."""
    return a.permute(*axes)
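
# For example (hypothetical shapes), ``permute`` generalizes transposition
# to arbitrary rank, matching numpy's ``transpose`` semantics:
#
#     t = to_torch(np.ones((2, 3, 4)))
#     transpose(t, (2, 0, 1)).shape  # torch.Size([4, 2, 3])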


def einsum(equation, *operands):
    """Variadic version of torch.einsum to match numpy api."""
    # rename symbols to support PyTorch 0.4.1 and earlier,
    # which allow only symbols a-z.
    equation = convert_to_valid_einsum_chars(equation)

    torch, _ = _get_torch_and_device()
    return torch.einsum(equation, operands)
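
# Sketch of the numpy-matching call convention (hypothetical operands):
#
#     a = to_torch(np.random.rand(2, 3))
#     b = to_torch(np.random.rand(3, 4))
#     einsum("ij,jk->ik", a, b)  # matches np.einsum on the numpy arrays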


def tensordot(x, y, axes=2):
    """Simple translation of tensordot syntax to einsum."""
    torch, _ = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    xnd = x.ndimension()
    ynd = y.ndimension()

    # convert int argument to (list[int], list[int])
    if isinstance(axes, int):
        axes = range(xnd - axes, xnd), range(axes)
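    # e.g. axes=2 with xnd == 3 gives (range(1, 3), range(0, 2)): the last
    # two axes of x contract against the first two axes of y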

    # convert (int, int) to (list[int], list[int])
    if isinstance(axes[0], int):
        axes = (axes[0],), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1],)

    # initialize empty indices
    x_ix = [None] * xnd
    y_ix = [None] * ynd
    out_ix = []

    # fill in repeated indices
    available_ix = iter(_torch_symbols_base)
    for ax1, ax2 in zip(*axes):
        repeat = next(available_ix)
        x_ix[ax1] = repeat
        y_ix[ax2] = repeat

    # fill in the rest, and maintain output order
    for i in range(xnd):
        if x_ix[i] is None:
            leave = next(available_ix)
            x_ix[i] = leave
            out_ix.append(leave)
    for i in range(ynd):
        if y_ix[i] is None:
            leave = next(available_ix)
            y_ix[i] = leave
            out_ix.append(leave)

    # form full string and contract!
    einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
    return einsum(einsum_str, x, y)
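
# Worked example of the fallback path above (hypothetical shapes): for x of
# shape (4, 3, 2), y of shape (2, 3, 5) and axes=([2, 1], [0, 1]), the loops
# assign x_ix = ['c', 'b', 'a'], y_ix = ['a', 'b', 'd'], out_ix = ['c', 'd'],
# so the contraction becomes einsum("cba,abd->cd", x, y) with output (4, 5).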


@to_backend_cache_wrap
def to_torch(array):
    torch, device = _get_torch_and_device()

    if isinstance(array, np.ndarray):
        return torch.from_numpy(array).to(device)

    return array
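
# Sketch (hypothetical array; whether the tensor lands on "cuda" or "cpu"
# depends on torch.cuda.is_available()); non-numpy inputs pass through:
#
#     t = to_torch(np.arange(6.0).reshape(2, 3))  # torch.Tensor on device
#     to_torch(t)                                 # returned unchanged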


def build_expression(_, expr):  # pragma: no cover
    """Build a torch function based on ``arrays`` and ``expr``."""

    def torch_contract(*arrays):
        torch_arrays = [to_torch(x) for x in arrays]
        torch_out = expr._contract(torch_arrays, backend="torch")

        if torch_out.device.type == "cpu":
            return torch_out.numpy()

        return torch_out.cpu().numpy()

    return torch_contract
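
# Hedged usage sketch: ``expr`` would come from opt_einsum's public
# contract_expression; the built function takes and returns numpy arrays:
#
#     import opt_einsum
#     expr = opt_einsum.contract_expression("ij,jk->ik", (2, 3), (3, 4))
#     fn = build_expression(None, expr)
#     out = fn(np.ones((2, 3)), np.ones((3, 4)))  # numpy in, numpy out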


def evaluate_constants(const_arrays, expr):
    """Convert constant arguments to torch, and perform any possible constant
    contractions.
    """
    const_arrays = [to_torch(x) for x in const_arrays]
    return expr(*const_arrays, backend="torch", evaluate_constants=True)
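
# Hedged sketch of how this hooks into the constants machinery (hypothetical
# shapes; ``constants=[0]`` marks the first operand as fixed so its part of
# the contraction can be folded once on the torch backend):
#
#     import opt_einsum
#     expr = opt_einsum.contract_expression(
#         "ij,jk->ik", np.ones((2, 3)), (3, 4), constants=[0]
#     )
#     evaluate_constants([np.ones((2, 3))], expr)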