Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/execution_time/lazy_sequence.py: 36%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

99 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import collections 

21import itertools 

22from collections.abc import Iterator, Sequence 

23from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload 

24 

25import attrs 

26import structlog 

27 

28if TYPE_CHECKING: 

29 from airflow.sdk.definitions.xcom_arg import PlainXComArg 

30 from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance 

31 

T = TypeVar("T")

# Values received from the API are wrapped in this one-field namedtuple so that
# they expose the ``value`` structure ``XCom.deserialize_value`` is fed with.
# Wrapping them in a nested {"value": value} dict on the wire instead would
# waste bandwidth.
_XComWrapper = collections.namedtuple("_XComWrapper", ["value"])

# Module-level structured logger, named after this module.
log = structlog.get_logger(logger_name=__name__)

40 

41 

@attrs.define
class LazyXComIterator(Iterator[T]):
    """
    Iterator that walks a :class:`LazyXComSequence` one position at a time.

    Each ``__next__`` call reads ``seq[index]`` and then adds ``dir`` to
    ``index``: ``1`` walks forward and ``-1`` walks backward. Iteration stops
    when the sequence raises ``IndexError`` or when a backward walk moves
    below position 0.
    """

    seq: LazyXComSequence[T]
    index: int = 0
    dir: Literal[1, -1] = 1

    def __next__(self) -> T:
        # A backward walk that has gone past position 0 ends here, without
        # issuing another (HTTP) request for a negative position.
        if self.index < 0:
            raise StopIteration()
        position = self.index
        try:
            item = self.seq[position]
        except IndexError:
            raise StopIteration from None
        self.index = position + self.dir
        return item

    def __iter__(self) -> Iterator[T]:
        return self

61 

62 

@attrs.define
class LazyXComSequence(Sequence[T]):
    """
    Read-only sequence over the XCom values pushed by an upstream task.

    Nothing is fetched up front. The length, single items and slices are each
    requested from the supervisor (``SUPERVISOR_COMMS``) when they are accessed.
    Requests are keyed by the XCom arg's ``key``, the producing operator's
    ``dag_id``/``task_id``, and this task instance's ``run_id``.

    Only the length is cached; items are re-requested on every access.
    """

    # Cached result of ``__len__``; ``None`` until the supervisor has been asked.
    _len: int | None = attrs.field(init=False, default=None)
    _xcom_arg: PlainXComArg = attrs.field(alias="xcom_arg")
    _ti: RuntimeTaskInstance = attrs.field(alias="ti")

    def __repr__(self) -> str:
        # Only report the length when it is already cached, so repr() never
        # triggers a request to the supervisor.
        if self._len is not None:
            counter = "item" if (length := len(self)) == 1 else "items"
            return f"LazyXComSequence([{length} {counter}])"
        return "LazyXComSequence(<unevaluated length>)"

    def __str__(self) -> str:
        return repr(self)

    def __eq__(self, other: Any) -> bool:
        """
        Compare element-wise against any other ``Sequence``.

        Items are fetched one at a time, and comparison stops at the first
        mismatch. A fresh sentinel object pads the shorter side, so sequences
        of different lengths normally compare unequal.
        """
        if not isinstance(other, Sequence):
            return NotImplemented
        z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
        return all(x == y for x, y in z)

    def __hash__(self):
        # Hash the materialised contents so the result is consistent with
        # ``__eq__``. Note: this fetches every item.
        return hash(tuple(self))

    def __iter__(self) -> Iterator[T]:
        # Each step of the returned iterator performs one ``self[index]`` lookup.
        return LazyXComIterator(seq=self)

    def __len__(self) -> int:
        """
        Return the number of XCom values, asking the supervisor on first use.

        :raises RuntimeError: if the supervisor answers with an ``ErrorResponse``.
        :raises TypeError: if the supervisor answers with any other unexpected message.
        """
        if self._len is None:
            # Imported lazily, presumably to avoid import cycles at module load.
            from airflow.sdk.execution_time.comms import ErrorResponse, GetXComCount, XComCountResponse
            from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

            task = self._xcom_arg.operator

            msg = SUPERVISOR_COMMS.send(
                GetXComCount(
                    key=self._xcom_arg.key,
                    dag_id=task.dag_id,
                    run_id=self._ti.run_id,
                    task_id=task.task_id,
                ),
            )
            if isinstance(msg, ErrorResponse):
                raise RuntimeError(msg)
            if not isinstance(msg, XComCountResponse):
                raise TypeError(f"Got unexpected response to GetXComCount: {msg!r}")
            self._len = msg.len
        return self._len

    @overload
    def __getitem__(self, key: int) -> T: ...

    @overload
    def __getitem__(self, key: slice) -> Sequence[T]: ...

    def __getitem__(self, key: int | slice) -> T | Sequence[T]:
        """
        Fetch a single item or a slice of items from the supervisor.

        Keys that are not ``int`` but implement ``__index__`` are converted to
        ``int``, as built-in sequences do. A slice is returned as a new ``list``.

        :raises IndexError: if the supervisor answers an item request with an
            ``ErrorResponse`` (presumably an out-of-range offset).
        :raises TypeError: for keys that are neither slices nor integer-like,
            for slice components that are not integer-like, or when the
            supervisor replies with an unexpected message type.
        :raises ValueError: for a slice whose step is zero.
        """
        from airflow.sdk.execution_time.comms import (
            ErrorResponse,
            GetXComSequenceItem,
            GetXComSequenceSlice,
            XComSequenceIndexResult,
            XComSequenceSliceResult,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
        from airflow.sdk.execution_time.xcom import XCom

        if isinstance(key, slice):
            start, stop, step = _coerce_slice(key)
            source = (xcom_arg := self._xcom_arg).operator
            msg = SUPERVISOR_COMMS.send(
                GetXComSequenceSlice(
                    key=xcom_arg.key,
                    dag_id=source.dag_id,
                    task_id=source.task_id,
                    run_id=self._ti.run_id,
                    start=start,
                    stop=stop,
                    step=step,
                ),
            )
            # NOTE(review): an ErrorResponse here surfaces as the TypeError below,
            # unlike the item path (IndexError) and __len__ (RuntimeError).
            if not isinstance(msg, XComSequenceSliceResult):
                raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}")
            return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root]

        if not isinstance(key, int):
            # Accept integer-like objects via ``__index__``; reject everything
            # else. Only raise when no conversion is possible.
            if (index := getattr(key, "__index__", None)) is None:
                raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}")
            key = index()

        source = (xcom_arg := self._xcom_arg).operator
        msg = SUPERVISOR_COMMS.send(
            GetXComSequenceItem(
                key=xcom_arg.key,
                dag_id=source.dag_id,
                task_id=source.task_id,
                run_id=self._ti.run_id,
                offset=key,
            ),
        )
        if isinstance(msg, ErrorResponse):
            raise IndexError(key)
        if not isinstance(msg, XComSequenceIndexResult):
            raise TypeError(f"Got unexpected response to GetXComSequenceItem: {msg!r}")
        return XCom.deserialize_value(_XComWrapper(msg.root))

167 

168 

169def _coerce_slice_index(value: Any) -> int | None: 

170 """ 

171 Check slice attribute's type and convert it to int. 

172 

173 See CPython documentation on this: 

174 https://docs.python.org/3/reference/datamodel.html#object.__index__ 

175 """ 

176 if value is None or isinstance(value, int): 

177 return value 

178 if (index := getattr(value, "__index__", None)) is not None: 

179 return index() 

180 raise TypeError("slice indices must be integers or None or have an __index__ method") 

181 

182 

def _coerce_slice(key: slice) -> tuple[int | None, int | None, int | None]:
    """
    Validate a slice and return its ``(start, stop, step)`` as ints or ``None``.

    The step is checked first, so a zero step fails with ``ValueError`` before
    start and stop are examined. Each component is normalised with
    ``_coerce_slice_index``, which raises ``TypeError`` for values that are not
    integer-like.

    See CPython documentation on this:
    https://docs.python.org/3/reference/datamodel.html#slice-objects
    """
    step = _coerce_slice_index(key.step)
    if step == 0:
        raise ValueError("slice step cannot be zero")
    start = _coerce_slice_index(key.start)
    stop = _coerce_slice_index(key.stop)
    return start, stop, step