from __future__ import annotations

from collections import OrderedDict, deque
from dataclasses import dataclass, field
from types import TracebackType
from typing import Generic, NamedTuple, TypeVar

from .. import (
    BrokenResourceError,
    ClosedResourceError,
    EndOfStream,
    WouldBlock,
    get_cancelled_exc_class,
)
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
from ..lowlevel import checkpoint
16
# Receive streams only hand items out, so they may be covariant in the item type.
T_co = TypeVar("T_co", covariant=True)
# Send streams only take items in, so they may be contravariant in the item type.
T_contra = TypeVar("T_contra", contravariant=True)
# The shared state both stores and yields items, so its item type is invariant.
T_Item = TypeVar("T_Item")
20
21
class MemoryObjectStreamStatistics(NamedTuple):
    """Snapshot of the counters describing one memory object stream."""

    #: number of items stored in the buffer
    current_buffer_used: int
    #: maximum number of items that can be stored on this stream (or :data:`math.inf`)
    max_buffer_size: float
    #: number of unclosed clones of the send stream
    open_send_streams: int
    #: number of unclosed clones of the receive stream
    open_receive_streams: int
    #: number of tasks blocked on :meth:`MemoryObjectSendStream.send`
    tasks_waiting_send: int
    #: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive`
    tasks_waiting_receive: int
32
33
@dataclass(eq=False)
class MemoryObjectStreamState(Generic[T_Item]):
    """State shared by every send and receive clone of one memory object stream."""

    #: upper bound on ``len(buffer)``; may be :data:`math.inf`
    max_buffer_size: float
    #: items accepted by a sender but not yet taken by a receiver
    buffer: deque[T_Item] = field(init=False, default_factory=deque)
    #: number of send-stream clones that have not been closed yet
    open_send_channels: int = field(init=False, default=0)
    #: number of receive-stream clones that have not been closed yet
    open_receive_channels: int = field(init=False, default=0)
    #: blocked receivers, oldest first: wake-up event -> list a sender appends the item to
    waiting_receivers: OrderedDict[Event, list[T_Item]] = field(
        init=False, default_factory=OrderedDict
    )
    #: blocked senders, oldest first: wake-up event -> the item they are trying to send
    waiting_senders: OrderedDict[Event, T_Item] = field(
        init=False, default_factory=OrderedDict
    )

    def statistics(self) -> MemoryObjectStreamStatistics:
        """Return the current counters of this stream as a statistics tuple."""
        return MemoryObjectStreamStatistics(
            current_buffer_used=len(self.buffer),
            max_buffer_size=self.max_buffer_size,
            open_send_streams=self.open_send_channels,
            open_receive_streams=self.open_receive_channels,
            tasks_waiting_send=len(self.waiting_senders),
            tasks_waiting_receive=len(self.waiting_receivers),
        )
56
57
@dataclass(eq=False)
class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
    """
    Receiving end of a memory object stream.

    All clones created via :meth:`clone` share the same
    :class:`MemoryObjectStreamState` and can be closed independently.
    """

    _state: MemoryObjectStreamState[T_co]
    _closed: bool = field(init=False, default=False)

    def __post_init__(self) -> None:
        # Count this clone as an open receiver in the shared state.
        self._state.open_receive_channels += 1

    def receive_nowait(self) -> T_co:
        """
        Receive the next item if it can be done without waiting.

        :return: the received item
        :raises ~anyio.ClosedResourceError: if this receive stream has been closed
        :raises ~anyio.EndOfStream: if the buffer is empty and this stream has been
            closed from the sending end
        :raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
            waiting to send

        """
        if self._closed:
            raise ClosedResourceError

        if self._state.waiting_senders:
            # Get the item from the next sender. Appending it to the end of the
            # buffer keeps already-buffered items ahead of it (FIFO order).
            send_event, item = self._state.waiting_senders.popitem(last=False)
            self._state.buffer.append(item)
            # The sender is removed from waiting_senders before being woken;
            # send() relies on that to know its item was accepted.
            send_event.set()

        if self._state.buffer:
            return self._state.buffer.popleft()
        elif not self._state.open_send_channels:
            raise EndOfStream

        raise WouldBlock

    async def receive(self) -> T_co:
        """
        Receive the next item, waiting for a sender if none is available.

        :return: the received item
        :raises ~anyio.ClosedResourceError: if this receive stream has been closed
        :raises ~anyio.EndOfStream: if no item is available and every send stream
            has been closed

        """
        await checkpoint()
        try:
            return self.receive_nowait()
        except WouldBlock:
            # Add ourselves in the queue. A sender appends the item to
            # ``container`` and sets ``receive_event``; closing the last send
            # stream sets the event without appending anything.
            receive_event = Event()
            container: list[T_co] = []
            self._state.waiting_receivers[receive_event] = container

            try:
                await receive_event.wait()
            except get_cancelled_exc_class():
                # A sender may already have handed us an item (and removed us
                # from waiting_receivers) before the cancellation was delivered.
                # That item exists nowhere else, so re-raising would lose it.
                # Return it instead; under cancel-scope semantics the
                # cancellation is presumably re-delivered at the caller's next
                # checkpoint.
                if not container:
                    raise
            finally:
                self._state.waiting_receivers.pop(receive_event, None)

            if container:
                return container[0]
            else:
                raise EndOfStream

    def clone(self) -> MemoryObjectReceiveStream[T_co]:
        """
        Create a clone of this receive stream.

        Each clone can be closed separately. Only when all clones have been closed will
        the receiving end of the memory stream be considered closed by the sending ends.

        :return: the cloned stream
        :raises ~anyio.ClosedResourceError: if this receive stream has been closed

        """
        if self._closed:
            raise ClosedResourceError

        return MemoryObjectReceiveStream(_state=self._state)

    def close(self) -> None:
        """
        Close the stream.

        This works the exact same way as :meth:`aclose`, but is provided as a special
        case for the benefit of synchronous callbacks.

        """
        if not self._closed:
            self._closed = True
            self._state.open_receive_channels -= 1
            if self._state.open_receive_channels == 0:
                # Wake every blocked sender but leave it in waiting_senders:
                # send() treats still being queued after wake-up as "the
                # receiving end is gone" and raises BrokenResourceError.
                send_events = list(self._state.waiting_senders.keys())
                for event in send_events:
                    event.set()

    async def aclose(self) -> None:
        self.close()

    def statistics(self) -> MemoryObjectStreamStatistics:
        """
        Return statistics about the current state of this stream.

        .. versionadded:: 3.0
        """
        return self._state.statistics()

    def __enter__(self) -> MemoryObjectReceiveStream[T_co]:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()
166
167
@dataclass(eq=False)
class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
    """
    Sending end of a memory object stream.

    All clones created via :meth:`clone` share the same
    :class:`MemoryObjectStreamState` and can be closed independently.
    """

    _state: MemoryObjectStreamState[T_contra]
    _closed: bool = field(init=False, default=False)

    def __post_init__(self) -> None:
        # Count this clone as an open sender in the shared state.
        self._state.open_send_channels += 1

    def send_nowait(self, item: T_contra) -> None:
        """
        Send an item immediately if it can be done without waiting.

        :param item: the item to send
        :raises ~anyio.ClosedResourceError: if this send stream has been closed
        :raises ~anyio.BrokenResourceError: if the stream has been closed from the
            receiving end
        :raises ~anyio.WouldBlock: if the buffer is full and there are no tasks waiting
            to receive

        """
        if self._closed:
            raise ClosedResourceError
        if not self._state.open_receive_channels:
            raise BrokenResourceError

        if self._state.waiting_receivers:
            # Hand the item straight to the longest-waiting receiver.
            receive_event, container = self._state.waiting_receivers.popitem(last=False)
            container.append(item)
            receive_event.set()
        elif len(self._state.buffer) < self._state.max_buffer_size:
            self._state.buffer.append(item)
        else:
            raise WouldBlock

    async def send(self, item: T_contra) -> None:
        """
        Send an item to the stream.

        If the buffer is full, this method blocks until there is again room in the
        buffer or the item can be sent directly to a receiver.

        :param item: the item to send
        :raises ~anyio.ClosedResourceError: if this send stream has been closed
        :raises ~anyio.BrokenResourceError: if the stream has been closed from the
            receiving end

        """
        await checkpoint()
        try:
            self.send_nowait(item)
        except WouldBlock:
            # Wait until there's someone on the receiving end
            send_event = Event()
            self._state.waiting_senders[send_event] = item
            try:
                await send_event.wait()
            except BaseException:
                self._state.waiting_senders.pop(send_event, None)
                raise

            # A receiver that accepts the item removes this entry before setting
            # the event, whereas closing the last receive stream sets the event
            # and leaves the entry in place. Test membership rather than the
            # truthiness of the stored item, which would silently miss falsy
            # items such as 0, "" or None.
            if send_event in self._state.waiting_senders:
                del self._state.waiting_senders[send_event]
                raise BrokenResourceError from None

    def clone(self) -> MemoryObjectSendStream[T_contra]:
        """
        Create a clone of this send stream.

        Each clone can be closed separately. Only when all clones have been closed will
        the sending end of the memory stream be considered closed by the receiving ends.

        :return: the cloned stream
        :raises ~anyio.ClosedResourceError: if this send stream has been closed

        """
        if self._closed:
            raise ClosedResourceError

        return MemoryObjectSendStream(_state=self._state)

    def close(self) -> None:
        """
        Close the stream.

        This works the exact same way as :meth:`aclose`, but is provided as a special
        case for the benefit of synchronous callbacks.

        """
        if not self._closed:
            self._closed = True
            self._state.open_send_channels -= 1
            if self._state.open_send_channels == 0:
                # Wake every blocked receiver with an empty container; receive()
                # turns an empty container into EndOfStream.
                receive_events = list(self._state.waiting_receivers.keys())
                self._state.waiting_receivers.clear()
                for event in receive_events:
                    event.set()

    async def aclose(self) -> None:
        self.close()

    def statistics(self) -> MemoryObjectStreamStatistics:
        """
        Return statistics about the current state of this stream.

        .. versionadded:: 3.0
        """
        return self._state.statistics()

    def __enter__(self) -> MemoryObjectSendStream[T_contra]:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.close()