Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/torch.py: 23%
61 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""
2Required functions for optimized contractions of numpy arrays using pytorch.
3"""
5import numpy as np
7from ..parser import convert_to_valid_einsum_chars
8from ..sharing import to_backend_cache_wrap
__all__ = [
    "transpose",
    "einsum",
    "tensordot",
    "to_torch",
    "build_expression",
    "evaluate_constants",
]

# Filled in on first use by _get_torch_and_device() with a (torch module, device) pair.
_TORCH_DEVICE = None
# Filled in on first use by _get_torch_and_device(): whether torch provides tensordot.
_TORCH_HAS_TENSORDOT = None

# Subscript symbols handed out by the einsum-based tensordot fallback.
_torch_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
def _get_torch_and_device():
    """Return the cached ``(torch, device)`` pair, importing torch on first call.

    ``device`` is ``'cuda'`` when ``torch.cuda.is_available()`` is true and
    ``'cpu'`` otherwise. The first call also records in the module-level
    ``_TORCH_HAS_TENSORDOT`` whether the installed torch has ``tensordot``.
    """
    global _TORCH_DEVICE
    global _TORCH_HAS_TENSORDOT

    # Fast path: torch was already imported and the device chosen.
    if _TORCH_DEVICE is not None:
        return _TORCH_DEVICE

    # Deferred import so this backend module can be imported without torch.
    import torch

    _TORCH_DEVICE = (torch, 'cuda' if torch.cuda.is_available() else 'cpu')
    _TORCH_HAS_TENSORDOT = hasattr(torch, 'tensordot')
    return _TORCH_DEVICE
def transpose(a, axes):
    """Reorder the dimensions of ``a`` according to ``axes`` (numpy semantics).

    ``torch.transpose`` only swaps two dimensions, so ``Tensor.permute`` is
    used to apply a full axis permutation instead.
    """
    permutation = tuple(axes)
    return a.permute(*permutation)
def einsum(equation, *operands):
    """Evaluate ``equation`` with ``torch.einsum`` using numpy's variadic API.

    The equation is first passed through ``convert_to_valid_einsum_chars``
    so that it only uses symbols accepted by old PyTorch releases
    (0.4.1 and earlier allowed only a-z). The operands are forwarded
    to ``torch.einsum`` as a single tuple.
    """
    safe_equation = convert_to_valid_einsum_chars(equation)

    torch_module, _device = _get_torch_and_device()
    return torch_module.einsum(safe_equation, operands)
def tensordot(x, y, axes=2):
    """Contract ``x`` with ``y`` over ``axes``, following ``numpy.tensordot``.

    When the installed torch provides ``torch.tensordot`` it is called
    directly. Otherwise the contraction is rewritten as an einsum equation
    and evaluated with this module's ``einsum``.

    ``axes`` may be an int ``n`` (the last ``n`` dims of ``x`` against the
    first ``n`` dims of ``y``), or a pair ``(x_axes, y_axes)`` whose entries
    are each an int or a sequence of ints.
    """
    torch_module, _device = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch_module.tensordot(x, y, dims=axes)

    x_ndim = x.ndimension()
    y_ndim = y.ndimension()

    # An int n pairs the trailing n dims of x with the leading n dims of y.
    if isinstance(axes, int):
        axes = range(x_ndim - axes, x_ndim), range(axes)

    # Promote a bare int on either side to a 1-tuple so both sides iterate.
    if isinstance(axes[0], int):
        axes = (axes[0], ), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1], )

    symbols = iter(_torch_symbols_base)
    x_subs = [None] * x_ndim
    y_subs = [None] * y_ndim

    # Contracted dimensions share one symbol between x and y.
    for x_axis, y_axis in zip(*axes):
        shared = next(symbols)
        x_subs[x_axis] = y_subs[y_axis] = shared

    # The remaining dims of x, then of y, get fresh symbols and form the output.
    out_subs = []
    for subs in (x_subs, y_subs):
        for position, symbol in enumerate(subs):
            if symbol is None:
                fresh = next(symbols)
                subs[position] = fresh
                out_subs.append(fresh)

    equation = "{},{}->{}".format("".join(x_subs), "".join(y_subs), "".join(out_subs))
    return einsum(equation, x, y)
@to_backend_cache_wrap
def to_torch(array):
    """Convert a numpy array to a torch tensor on the preferred device.

    A ``numpy.ndarray`` is wrapped with ``torch.from_numpy`` and moved to the
    device chosen by ``_get_torch_and_device``. Any other object is returned
    unchanged. ``to_backend_cache_wrap`` comes from ``..sharing``; presumably
    it caches conversions while sharing is active (confirm in that module).
    """
    torch_module, device = _get_torch_and_device()

    if not isinstance(array, np.ndarray):
        return array

    return torch_module.from_numpy(array).to(device)
def build_expression(_, expr):  # pragma: no cover
    """Wrap ``expr`` in a function that contracts arrays using torch.

    The first argument is ignored. The returned ``torch_contract`` converts
    each input with ``to_torch`` and runs ``expr._contract`` with
    ``backend='torch'``. It returns the result as a numpy array, copying it
    to the CPU first when it lives on another device.
    """
    def torch_contract(*arrays):
        tensors = [to_torch(arr) for arr in arrays]
        result = expr._contract(tensors, backend='torch')

        host_result = result if result.device.type == 'cpu' else result.cpu()
        return host_result.numpy()

    return torch_contract
def evaluate_constants(const_arrays, expr):
    """Move constant operands to torch and pre-contract what ``expr`` can.

    Each entry of ``const_arrays`` is converted with ``to_torch``, and
    ``expr`` is then called on them with ``backend='torch'`` and
    ``evaluate_constants=True``. Whatever ``expr`` returns is passed back.
    """
    return expr(
        *(to_torch(const) for const in const_arrays),
        backend='torch',
        evaluate_constants=True,
    )