1from __future__ import annotations
2
3from collections.abc import Callable, Mapping, Sequence
4from dataclasses import dataclass
5from typing import Any, Generic, TypeVar
6
7from ..abc import (
8 ByteReceiveStream,
9 ByteSendStream,
10 ByteStream,
11 Listener,
12 ObjectReceiveStream,
13 ObjectSendStream,
14 ObjectStream,
15 TaskGroup,
16)
17
# Type parameters: T_Item is the object type carried by the stapled object stream,
# T_Stream is the stream type yielded by the listeners combined in MultiListener.
T_Item, T_Stream = TypeVar("T_Item"), TypeVar("T_Stream")
20
21
@dataclass(eq=False)
class StapledByteStream(ByteStream):
    """
    Combines two byte streams into a single, bidirectional byte stream.

    Extra attributes will be provided from both streams, with the receive stream
    providing the values in case of a conflict.

    :param ByteSendStream send_stream: the sending byte stream
    :param ByteReceiveStream receive_stream: the receiving byte stream
    """

    send_stream: ByteSendStream
    receive_stream: ByteReceiveStream

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """Receive up to ``max_bytes`` bytes by delegating to the receive stream."""
        return await self.receive_stream.receive(max_bytes)

    async def send(self, item: bytes) -> None:
        """Send ``item`` by delegating to the send stream."""
        await self.send_stream.send(item)

    async def send_eof(self) -> None:
        """
        Signal end-of-file by closing the send stream.

        The receive stream is not touched by this call.
        """
        await self.send_stream.aclose()

    async def aclose(self) -> None:
        """
        Close both underlying streams, the send stream first.

        The receive stream is closed even if closing the send stream raises or is
        interrupted, so a failure on one half does not leak the other.
        """
        try:
            await self.send_stream.aclose()
        finally:
            await self.receive_stream.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # The receive stream's mapping is unpacked last, so it wins on key conflicts.
        return {
            **self.send_stream.extra_attributes,
            **self.receive_stream.extra_attributes,
        }
56
57
@dataclass(eq=False)
class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]):
    """
    Combines two object streams into a single, bidirectional object stream.

    Extra attributes will be provided from both streams, with the receive stream
    providing the values in case of a conflict.

    :param ObjectSendStream send_stream: the sending object stream
    :param ObjectReceiveStream receive_stream: the receiving object stream
    """

    send_stream: ObjectSendStream[T_Item]
    receive_stream: ObjectReceiveStream[T_Item]

    async def receive(self) -> T_Item:
        """Receive the next item by delegating to the receive stream."""
        return await self.receive_stream.receive()

    async def send(self, item: T_Item) -> None:
        """Send ``item`` by delegating to the send stream."""
        await self.send_stream.send(item)

    async def send_eof(self) -> None:
        """
        Signal end-of-file by closing the send stream.

        The receive stream is not touched by this call.
        """
        await self.send_stream.aclose()

    async def aclose(self) -> None:
        """
        Close both underlying streams, the send stream first.

        The receive stream is closed even if closing the send stream raises or is
        interrupted, so a failure on one half does not leak the other.
        """
        try:
            await self.send_stream.aclose()
        finally:
            await self.receive_stream.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # The receive stream's mapping is unpacked last, so it wins on key conflicts.
        return {
            **self.send_stream.extra_attributes,
            **self.receive_stream.extra_attributes,
        }
92
93
@dataclass(eq=False)
class MultiListener(Generic[T_Stream], Listener[T_Stream]):
    """
    Combines multiple listeners into one, serving connections from all of them at once.

    Any MultiListeners in the given collection of listeners will have their listeners
    moved into this one.

    Extra attributes are provided from each listener, with each successive listener
    overriding any conflicting attributes from the previous one.

    :param listeners: listeners to serve
    :type listeners: Sequence[Listener[T_Stream]]
    """

    listeners: Sequence[Listener[T_Stream]]

    def __post_init__(self) -> None:
        # Flatten nested MultiListeners one level: adopt their listeners and empty
        # their lists so the nested instance no longer refers to them.
        listeners: list[Listener[T_Stream]] = []
        for listener in self.listeners:
            if isinstance(listener, MultiListener):
                listeners.extend(listener.listeners)
                del listener.listeners[:]  # type: ignore[attr-defined]
            else:
                listeners.append(listener)

        # Always store a list (even if a tuple was passed in) so that a later
        # enclosing MultiListener can empty it in place with ``del [:]``.
        self.listeners = listeners

    async def serve(
        self, handler: Callable[[T_Stream], Any], task_group: TaskGroup | None = None
    ) -> None:
        """
        Serve all contained listeners concurrently.

        Each listener's ``serve`` is started in a task group created here, with the
        same ``handler`` and ``task_group`` arguments; this call returns once all of
        those tasks have finished.
        """
        from .. import create_task_group

        async with create_task_group() as tg:
            for listener in self.listeners:
                tg.start_soon(listener.serve, handler, task_group)

    async def aclose(self) -> None:
        """
        Close every contained listener.

        Every listener is attempted even if closing an earlier one raises; the
        first exception encountered is re-raised after all listeners were tried.
        """
        first_error: BaseException | None = None
        for listener in self.listeners:
            try:
                await listener.aclose()
            except BaseException as exc:
                # Keep going so that one failing listener does not leak the rest.
                if first_error is None:
                    first_error = exc

        if first_error is not None:
            try:
                raise first_error
            finally:
                # Drop the local reference to avoid a frame <-> traceback cycle.
                del first_error

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # Later listeners overwrite earlier ones on conflicting keys.
        attributes: dict[Any, Callable[[], Any]] = {}
        for listener in self.listeners:
            attributes.update(listener.extra_attributes)

        return attributes