Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/backends/torch.py: 23%

61 statements  

coverage.py v7.3.1, created at 2023-09-25 06:41 +0000

1""" 

2Required functions for optimized contractions of numpy arrays using pytorch. 

3""" 

4 

5import numpy as np 

6 

7from ..parser import convert_to_valid_einsum_chars 

8from ..sharing import to_backend_cache_wrap 

9 

10__all__ = [ 

11 "transpose", 

12 "einsum", 

13 "tensordot", 

14 "to_torch", 

15 "build_expression", 

16 "evaluate_constants", 

17] 

18 

19_TORCH_DEVICE = None 

20_TORCH_HAS_TENSORDOT = None 

21 

22_torch_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" 


def _get_torch_and_device():
    """Lazily import torch, detect the default device, record whether
    ``torch.tensordot`` is available, and cache the results."""
    global _TORCH_DEVICE
    global _TORCH_HAS_TENSORDOT

    if _TORCH_DEVICE is None:
        import torch

        device = "cuda" if torch.cuda.is_available() else "cpu"
        _TORCH_DEVICE = torch, device
        _TORCH_HAS_TENSORDOT = hasattr(torch, "tensordot")

    return _TORCH_DEVICE
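
# Usage sketch (illustrative, not part of the original module): only the first
# call pays the torch import cost; later calls return the cached
# (module, device) pair.
#
#     torch, device = _get_torch_and_device()
#     torch_again, device_again = _get_torch_and_device()
#     assert torch is torch_again and device == device_again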


def transpose(a, axes):
    """Normal torch transpose is only valid for 2D matrices."""
    return a.permute(*axes)
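
# Usage sketch (illustrative, assumes torch is installed): unlike
# ``torch.transpose``, which swaps exactly two dimensions, ``permute`` takes a
# full axis ordering, matching ``numpy.transpose``.
#
#     import torch
#     a = torch.randn(2, 3, 4)
#     transpose(a, (2, 0, 1)).shape  # torch.Size([4, 2, 3])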


def einsum(equation, *operands):
    """Variadic version of torch.einsum to match numpy api."""
    # rename symbols to support PyTorch 0.4.1 and earlier,
    # which allow only symbols a-z.
    equation = convert_to_valid_einsum_chars(equation)

    torch, _ = _get_torch_and_device()
    return torch.einsum(equation, operands)
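
# Usage sketch (illustrative): the equation is normalized by
# ``convert_to_valid_einsum_chars`` before reaching ``torch.einsum``, so index
# symbols outside the range older PyTorch accepts are remapped transparently.
#
#     import torch
#     out = einsum("ij,jk->ik", torch.randn(2, 3), torch.randn(3, 4))
#     out.shape  # torch.Size([2, 4])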


def tensordot(x, y, axes=2):
    """Simple translation of tensordot syntax to einsum."""
    torch, _ = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    xnd = x.ndimension()
    ynd = y.ndimension()

    # convert int argument to (list[int], list[int])
    if isinstance(axes, int):
        axes = range(xnd - axes, xnd), range(axes)

    # convert (int, int) to (list[int], list[int])
    if isinstance(axes[0], int):
        axes = (axes[0],), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1],)

    # initialize empty indices
    x_ix = [None] * xnd
    y_ix = [None] * ynd
    out_ix = []

    # fill in repeated indices
    available_ix = iter(_torch_symbols_base)
    for ax1, ax2 in zip(*axes):
        repeat = next(available_ix)
        x_ix[ax1] = repeat
        y_ix[ax2] = repeat

    # fill in the rest, and maintain output order
    for i in range(xnd):
        if x_ix[i] is None:
            leave = next(available_ix)
            x_ix[i] = leave
            out_ix.append(leave)
    for i in range(ynd):
        if y_ix[i] is None:
            leave = next(available_ix)
            y_ix[i] = leave
            out_ix.append(leave)

    # form full string and contract!
    einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
    return einsum(einsum_str, x, y)
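
# Usage sketch (illustrative): with ``axes=1`` on shapes (2, 3, 4) and (4, 5),
# the fallback above assigns the shared symbol "a" to the contracted axis pair
# and builds "bca,ad->bcd", equivalent to what ``torch.tensordot`` computes
# natively.
#
#     import torch
#     x, y = torch.randn(2, 3, 4), torch.randn(4, 5)
#     tensordot(x, y, axes=1).shape  # torch.Size([2, 3, 5])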


@to_backend_cache_wrap
def to_torch(array):
    """Convert a numpy ndarray to a torch tensor on the cached device,
    passing any other input through unchanged."""
    torch, device = _get_torch_and_device()

    if isinstance(array, np.ndarray):
        return torch.from_numpy(array).to(device)

    return array
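
# Usage sketch (illustrative): numpy inputs are copied onto the detected
# device; anything that is not an ndarray (e.g. an existing torch tensor) is
# returned unchanged.
#
#     import numpy as np
#     t = to_torch(np.ones((2, 2)))  # torch tensor, on "cuda" if available
#     to_torch(t)                    # already a tensor -> passed through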


def build_expression(_, expr):  # pragma: no cover
    """Build a torch function based on ``expr``; the initial ``arrays``
    argument is unused."""

    def torch_contract(*arrays):
        torch_arrays = [to_torch(x) for x in arrays]
        torch_out = expr._contract(torch_arrays, backend="torch")

        if torch_out.device.type == "cpu":
            return torch_out.numpy()

        return torch_out.cpu().numpy()

    return torch_contract
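
# Usage sketch (illustrative, using a pre-built contraction expression):
# ``build_expression`` wraps an optimized contraction so callers pass numpy
# arrays in and receive a numpy array back, with the torch round trip hidden.
#
#     import numpy as np
#     import opt_einsum as oe
#     expr = oe.contract_expression("ij,jk->ik", (2, 3), (3, 4))
#     fn = build_expression(None, expr)
#     fn(np.random.rand(2, 3), np.random.rand(3, 4))  # numpy ndarray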


def evaluate_constants(const_arrays, expr):
    """Convert constant arguments to torch, and perform any possible constant
    contractions.
    """
    const_arrays = [to_torch(x) for x in const_arrays]
    return expr(*const_arrays, backend="torch", evaluate_constants=True)
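
# Usage sketch (illustrative): this hook is normally triggered indirectly,
# when a contraction expression built with constants is first evaluated with
# ``backend="torch"``; the constants are moved to torch and pre-contracted
# once, so repeated calls skip that work.
#
#     import numpy as np
#     import opt_einsum as oe
#     const = np.random.rand(3, 4)
#     expr = oe.contract_expression("ij,jk->ik", (2, 3), const, constants=[1])
#     expr(np.random.rand(2, 3), backend="torch")  # constants folded on first call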