1from __future__ import annotations
2
3import warnings
4from collections import OrderedDict, deque
5from dataclasses import dataclass, field
6from types import TracebackType
7from typing import Generic, NamedTuple, TypeVar
8
9from .. import (
10 BrokenResourceError,
11 ClosedResourceError,
12 EndOfStream,
13 WouldBlock,
14)
15from .._core._testing import TaskInfo, get_current_task
16from ..abc import Event, ObjectReceiveStream, ObjectSendStream
17from ..lowlevel import checkpoint
18
# Invariant item type used by the shared stream state and the per-receiver slot,
# which both store and hand out items.
T_Item = TypeVar("T_Item")
# Covariant item type for the receiving end (items only come out of it).
T_co = TypeVar("T_co", covariant=True)
# Contravariant item type for the sending end (items only go into it).
T_contra = TypeVar("T_contra", contravariant=True)
22
23
class MemoryObjectStreamStatistics(NamedTuple):
    """
    Point-in-time snapshot of a memory object stream, as returned by
    ``statistics()`` on either end of the stream.
    """

    current_buffer_used: int  #: number of items stored in the buffer
    #: maximum number of items that can be stored on this stream (or :data:`math.inf`)
    max_buffer_size: float
    open_send_streams: int  #: number of unclosed clones of the send stream
    open_receive_streams: int  #: number of unclosed clones of the receive stream
    #: number of tasks blocked on :meth:`MemoryObjectSendStream.send`
    tasks_waiting_send: int
    #: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive`
    tasks_waiting_receive: int
34
35
@dataclass(eq=False)
class MemoryObjectItemReceiver(Generic[T_Item]):
    """
    Slot through which a sender hands an item directly to a task blocked in
    :meth:`MemoryObjectReceiveStream.receive`.

    ``item`` is intentionally left unset until a sender assigns it. Reading it
    before then raises :exc:`AttributeError`, which the receiving task uses to
    detect that it was woken without an item.
    """

    #: info about the task that was current when this receiver was created
    task_info: TaskInfo = field(init=False, default_factory=get_current_task)
    #: the delivered item; unset until a sender assigns it
    item: T_Item = field(init=False)

    def __repr__(self) -> str:
        # The dataclass-generated __repr__ reads ``self.item`` unconditionally and
        # raises AttributeError while no item has been delivered. That would also
        # break repr() of any state/stream object holding this receiver, so show
        # None for a missing item instead.
        item = getattr(self, "item", None)
        name = self.__class__.__name__
        return f"{name}(task_info={self.task_info!r}, item={item!r})"
40
41
@dataclass(eq=False)
class MemoryObjectStreamState(Generic[T_Item]):
    """
    State shared by every send and receive stream created for one memory
    object stream, including all of their clones.
    """

    #: maximum number of buffered items (may be :data:`math.inf`)
    max_buffer_size: float
    #: items that were sent but not yet received, oldest first
    buffer: deque[T_Item] = field(init=False, default_factory=deque)
    #: number of send streams that have not been closed yet
    open_send_channels: int = field(init=False, default=0)
    #: number of receive streams that have not been closed yet
    open_receive_channels: int = field(init=False, default=0)
    #: blocked receivers in arrival order, keyed by the event that wakes them
    waiting_receivers: OrderedDict[Event, MemoryObjectItemReceiver[T_Item]] = field(
        init=False, default_factory=OrderedDict
    )
    #: blocked senders in arrival order, mapping their wake-up event to their item
    waiting_senders: OrderedDict[Event, T_Item] = field(
        init=False, default_factory=OrderedDict
    )

    def statistics(self) -> MemoryObjectStreamStatistics:
        """Return a snapshot of buffer usage, open ends and blocked tasks."""
        blocked_senders = len(self.waiting_senders)
        blocked_receivers = len(self.waiting_receivers)
        return MemoryObjectStreamStatistics(
            current_buffer_used=len(self.buffer),
            max_buffer_size=self.max_buffer_size,
            open_send_streams=self.open_send_channels,
            open_receive_streams=self.open_receive_channels,
            tasks_waiting_send=blocked_senders,
            tasks_waiting_receive=blocked_receivers,
        )
64
65
@dataclass(eq=False)
class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
    """
    Receiving end of a memory object stream.

    All clones share a single :class:`MemoryObjectStreamState`; the number of
    unclosed receive streams is tracked in its ``open_receive_channels``.
    """

    _state: MemoryObjectStreamState[T_co]
    _closed: bool = field(init=False, default=False)

    def __post_init__(self) -> None:
        # Count this instance (original or clone) as an open receiving end.
        self._state.open_receive_channels += 1

    def receive_nowait(self) -> T_co:
        """
        Receive the next item if it can be done without waiting.

        :return: the received item
        :raises ~anyio.ClosedResourceError: if this receive stream has been closed
        :raises ~anyio.EndOfStream: if the buffer is empty and all send streams have
            been closed
        :raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
            waiting to send

        """
        if self._closed:
            raise ClosedResourceError

        if self._state.waiting_senders:
            # Take the item from the longest-waiting sender. It is appended to the
            # back of the buffer so already-buffered items are still delivered
            # first. The sender's entry is removed before waking it, which is how
            # it knows the send succeeded.
            send_event, item = self._state.waiting_senders.popitem(last=False)
            self._state.buffer.append(item)
            send_event.set()

        if self._state.buffer:
            return self._state.buffer.popleft()
        elif not self._state.open_send_channels:
            raise EndOfStream

        raise WouldBlock

    async def receive(self) -> T_co:
        """
        Receive the next item, waiting for one if necessary.

        A checkpoint is always executed first. If no item is available right away,
        the calling task queues itself and waits until a sender hands it an item
        directly, or until the last send stream is closed.

        :return: the received item
        :raises ~anyio.ClosedResourceError: if this receive stream has been closed
        :raises ~anyio.EndOfStream: if there are no more items and all send streams
            have been closed

        """
        await checkpoint()
        try:
            return self.receive_nowait()
        except WouldBlock:
            # Add ourselves in the queue
            receive_event = Event()
            receiver = MemoryObjectItemReceiver[T_co]()
            self._state.waiting_receivers[receive_event] = receiver

            try:
                await receive_event.wait()
            finally:
                # Leave the queue whether we were woken normally or interrupted.
                self._state.waiting_receivers.pop(receive_event, None)

            try:
                return receiver.item
            except AttributeError:
                # Woken without an item: the last send stream was closed. Suppress
                # the internal AttributeError/WouldBlock chain from the traceback.
                raise EndOfStream from None

    def clone(self) -> MemoryObjectReceiveStream[T_co]:
        """
        Create a clone of this receive stream.

        Each clone can be closed separately. Only when all clones have been closed will
        the receiving end of the memory stream be considered closed by the sending ends.

        :return: the cloned stream
        :raises ~anyio.ClosedResourceError: if this receive stream has been closed

        """
        if self._closed:
            raise ClosedResourceError

        return MemoryObjectReceiveStream(_state=self._state)

    def close(self) -> None:
        """
        Close the stream.

        This works the exact same way as :meth:`aclose`, but is provided as a special
        case for the benefit of synchronous callbacks. Closing an already closed
        stream does nothing.

        """
        if not self._closed:
            self._closed = True
            self._state.open_receive_channels -= 1
            if self._state.open_receive_channels == 0:
                # Wake every blocked sender. Their entries are deliberately left in
                # waiting_senders: a sender that still finds its entry after waking
                # knows its item was never taken and raises BrokenResourceError.
                send_events = list(self._state.waiting_senders.keys())
                for event in send_events:
                    event.set()

    async def aclose(self) -> None:
        """Close the stream; equivalent to :meth:`close`."""
        self.close()

    def statistics(self) -> MemoryObjectStreamStatistics:
        """
        Return statistics about the current state of this stream.

        .. versionadded:: 3.0
        """
        return self._state.statistics()

    def __enter__(self) -> MemoryObjectReceiveStream[T_co]:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def __del__(self) -> None:
        # Warn when a stream is garbage collected without having been closed.
        if not self._closed:
            warnings.warn(
                f"Unclosed <{self.__class__.__name__}>",
                ResourceWarning,
                source=self,
            )
182
183
@dataclass(eq=False)
class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
    """
    Sending end of a memory object stream.

    All clones share a single :class:`MemoryObjectStreamState`; the number of
    unclosed send streams is tracked in its ``open_send_channels``.
    """

    _state: MemoryObjectStreamState[T_contra]
    _closed: bool = field(init=False, default=False)

    def __post_init__(self) -> None:
        # Every new instance, clones included, counts as one open sending end.
        self._state.open_send_channels += 1

    def send_nowait(self, item: T_contra) -> None:
        """
        Deliver ``item`` without blocking, or fail with :exc:`~anyio.WouldBlock`.

        The item is handed straight to the longest-waiting receiver that has no
        pending cancellation. Failing that, it is appended to the buffer if there
        is room.

        :param item: the item to send
        :raises ~anyio.ClosedResourceError: if this send stream has been closed
        :raises ~anyio.BrokenResourceError: if every receive stream has been closed
        :raises ~anyio.WouldBlock: if no receiver is waiting and the buffer is full

        """
        if self._closed:
            raise ClosedResourceError

        state = self._state
        if state.open_receive_channels == 0:
            raise BrokenResourceError

        receivers = state.waiting_receivers
        while receivers:
            receive_event, receiver = receivers.popitem(last=False)
            if receiver.task_info.has_pending_cancellation():
                # A receiver whose wait is interrupted never reads its item, so
                # drop it from the queue and offer the item to the next one.
                continue
            receiver.item = item
            receive_event.set()
            return

        if len(state.buffer) < state.max_buffer_size:
            state.buffer.append(item)
            return

        raise WouldBlock

    async def send(self, item: T_contra) -> None:
        """
        Send ``item``, waiting for buffer space or a receiver if necessary.

        A checkpoint is always executed first. If the item cannot be delivered
        right away, the calling task parks it in the shared queue of waiting
        senders until a receiver takes it or the last receive stream is closed.

        :param item: the item to send
        :raises ~anyio.ClosedResourceError: if this send stream has been closed
        :raises ~anyio.BrokenResourceError: if every receive stream has been closed,
            either before sending or while waiting

        """
        await checkpoint()
        try:
            self.send_nowait(item)
        except WouldBlock:
            waiting_senders = self._state.waiting_senders
            send_event = Event()
            waiting_senders[send_event] = item
            try:
                await send_event.wait()
            except BaseException:
                # Interrupted (e.g. cancelled): withdraw the parked item.
                waiting_senders.pop(send_event, None)
                raise

            # A receiver removes our entry before waking us, whereas closing the
            # last receive stream wakes us but leaves the entry in place.
            if send_event in waiting_senders:
                del waiting_senders[send_event]
                raise BrokenResourceError from None

    def clone(self) -> MemoryObjectSendStream[T_contra]:
        """
        Return a new send stream sharing this stream's state.

        Clones are closed independently. The receiving side only sees the sending
        end as closed once every clone has been closed.

        :return: the cloned stream
        :raises ~anyio.ClosedResourceError: if this send stream has been closed

        """
        if self._closed:
            raise ClosedResourceError
        return MemoryObjectSendStream(_state=self._state)

    def close(self) -> None:
        """
        Close the stream synchronously.

        Behaves exactly like :meth:`aclose` and exists for the benefit of
        synchronous callbacks. Closing an already closed stream does nothing.

        """
        if self._closed:
            return
        self._closed = True

        state = self._state
        state.open_send_channels -= 1
        if state.open_send_channels:
            return

        # That was the last sending end: dequeue and wake every blocked receiver.
        # None of them gets an item, so each one raises EndOfStream.
        pending_events = list(state.waiting_receivers)
        state.waiting_receivers.clear()
        for pending_event in pending_events:
            pending_event.set()

    async def aclose(self) -> None:
        """Close the stream; equivalent to :meth:`close`."""
        self.close()

    def statistics(self) -> MemoryObjectStreamStatistics:
        """
        Report the current buffer usage, open ends and blocked tasks.

        .. versionadded:: 3.0
        """
        return self._state.statistics()

    def __enter__(self) -> MemoryObjectSendStream[T_contra]:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()

    def __del__(self) -> None:
        # Emit a ResourceWarning if the stream is collected while still open.
        if self._closed:
            return
        warnings.warn(
            f"Unclosed <{self.__class__.__name__}>",
            ResourceWarning,
            source=self,
        )