1""" 

2Required functions for optimized contractions of numpy arrays using pytorch. 

3""" 

4 

5import numpy as np 

6 

7from ..parser import convert_to_valid_einsum_chars 

8from ..sharing import to_backend_cache_wrap 

9 

10__all__ = ["transpose", "einsum", "tensordot", "to_torch", "build_expression", "evaluate_constants"] 

11 

12_TORCH_DEVICE = None 

13_TORCH_HAS_TENSORDOT = None 

14 

15_torch_symbols_base = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ' 

16 

17 

18def _get_torch_and_device(): 

19 global _TORCH_DEVICE 

20 global _TORCH_HAS_TENSORDOT 

21 

22 if _TORCH_DEVICE is None: 

23 import torch 

24 device = 'cuda' if torch.cuda.is_available() else 'cpu' 

25 _TORCH_DEVICE = torch, device 

26 _TORCH_HAS_TENSORDOT = hasattr(torch, 'tensordot') 

27 

28 return _TORCH_DEVICE 
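
# Example (illustrative note, not part of the original module): torch is
# imported lazily so that opt_einsum imports cleanly even when torch is not
# installed; the first call performs the import and caches the result:
#
#   torch, device = _get_torch_and_device()  # device is 'cuda' if available,
#                                            # otherwise 'cpu'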



def transpose(a, axes):
    """torch.transpose only swaps a single pair of dimensions, so use permute
    for a full axes reordering, matching np.transpose.
    """
    return a.permute(*axes)
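
# Example (illustrative sketch): permute reorders all dimensions at once,
# matching np.transpose(a, axes):
#
#   import torch
#   t = torch.arange(24).reshape(2, 3, 4)
#   transpose(t, (2, 0, 1)).shape  # torch.Size([4, 2, 3])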



def einsum(equation, *operands):
    """Variadic version of torch.einsum to match the numpy API.
    """
    # rename symbols to support PyTorch 0.4.1 and earlier,
    # which allow only the symbols a-z
    equation = convert_to_valid_einsum_chars(equation)

    torch, _ = _get_torch_and_device()
    return torch.einsum(equation, operands)
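
# Example (illustrative sketch): operands are accepted variadically, like
# np.einsum, then forwarded to torch.einsum as a single sequence:
#
#   import torch
#   a, b = torch.ones(2, 3), torch.ones(3, 4)
#   einsum('ij,jk->ik', a, b).shape  # torch.Size([2, 4])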



def tensordot(x, y, axes=2):
    """Simple translation of tensordot syntax to einsum.
    """
    torch, _ = _get_torch_and_device()

    if _TORCH_HAS_TENSORDOT:
        return torch.tensordot(x, y, dims=axes)

    xnd = x.ndimension()
    ynd = y.ndimension()

    # convert int argument to (list[int], list[int])
    if isinstance(axes, int):
        axes = range(xnd - axes, xnd), range(axes)

    # convert (int, int) to (list[int], list[int])
    if isinstance(axes[0], int):
        axes = (axes[0], ), axes[1]
    if isinstance(axes[1], int):
        axes = axes[0], (axes[1], )

    # initialize empty indices
    x_ix = [None] * xnd
    y_ix = [None] * ynd
    out_ix = []

    # fill in repeated indices
    available_ix = iter(_torch_symbols_base)
    for ax1, ax2 in zip(*axes):
        repeat = next(available_ix)
        x_ix[ax1] = repeat
        y_ix[ax2] = repeat

    # fill in the rest, and maintain output order
    for i in range(xnd):
        if x_ix[i] is None:
            leave = next(available_ix)
            x_ix[i] = leave
            out_ix.append(leave)
    for i in range(ynd):
        if y_ix[i] is None:
            leave = next(available_ix)
            y_ix[i] = leave
            out_ix.append(leave)

    # form full string and contract!
    einsum_str = "{},{}->{}".format(*map("".join, (x_ix, y_ix, out_ix)))
    return einsum(einsum_str, x, y)
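
# Worked example of the einsum fallback above (illustrative, not executed):
# for x.shape == (2, 3, 4), y.shape == (3, 4, 5) and axes=2, the contracted
# axis pairs are (1, 0) and (2, 1), giving x_ix = ['c', 'a', 'b'],
# y_ix = ['a', 'b', 'd'] and out_ix = ['c', 'd'], so einsum_str is
# "cab,abd->cd" and the result has shape (2, 5), matching
# np.tensordot(x, y, axes=2).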



@to_backend_cache_wrap
def to_torch(array):
    """Convert a numpy array to a torch tensor on the cached device; other
    objects are returned unchanged.
    """
    torch, device = _get_torch_and_device()

    if isinstance(array, np.ndarray):
        return torch.from_numpy(array).to(device)

    return array
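
# Example (illustrative sketch): numpy input is converted and moved to the
# cached device; anything else (e.g. an existing torch.Tensor) passes through:
#
#   t = to_torch(np.ones((2, 2)))  # torch.Tensor, on 'cuda' if available
#   to_torch(t)                    # already a tensor, so returned unchanged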



def build_expression(_, expr):  # pragma: no cover
    """Build a torch function based on ``arrays`` and ``expr``.
    """
    def torch_contract(*arrays):
        torch_arrays = [to_torch(x) for x in arrays]
        torch_out = expr._contract(torch_arrays, backend='torch')

        # return a numpy array, copying back from the GPU if necessary
        if torch_out.device.type == 'cpu':
            return torch_out.numpy()

        return torch_out.cpu().numpy()

    return torch_contract
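
# Example (illustrative sketch; in normal use this hook is invoked by
# opt_einsum's backend dispatch rather than called directly):
#
#   import opt_einsum
#   expr = opt_einsum.contract_expression('ab,bc->ac', (5, 5), (5, 5))
#   fn = build_expression(None, expr)           # the first argument is ignored
#   out = fn(np.ones((5, 5)), np.ones((5, 5)))  # numpy in, numpy out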



def evaluate_constants(const_arrays, expr):
    """Convert constant arguments to torch, and perform any possible constant
    contractions.
    """
    const_arrays = [to_torch(x) for x in const_arrays]
    return expr(*const_arrays, backend='torch', evaluate_constants=True)
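
# Example (illustrative sketch; normally triggered by building an expression
# with constant arrays and then calling it with the torch backend):
#
#   import opt_einsum
#   const = np.ones((5, 5))
#   expr = opt_einsum.contract_expression('ab,bc->ac', const, (5, 5),
#                                         constants=[0])
#   out = expr(np.ones((5, 5)), backend='torch')  # constants are converted to
#                                                 # torch once and cached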