1from __future__ import annotations
2
# Public API of this module: the two stapled stream wrappers and the combined listener.
__all__ = ("MultiListener", "StapledByteStream", "StapledObjectStream")
8
9from collections.abc import Callable, Mapping, Sequence
10from dataclasses import dataclass
11from typing import Any, Generic, TypeVar
12
13from ..abc import (
14 ByteReceiveStream,
15 ByteSendStream,
16 ByteStream,
17 Listener,
18 ObjectReceiveStream,
19 ObjectSendStream,
20 ObjectStream,
21 TaskGroup,
22)
23
# Type of the items carried by a StapledObjectStream.
T_Item = TypeVar("T_Item")
# Type of the stream objects produced by the listeners wrapped in a MultiListener.
T_Stream = TypeVar("T_Stream")
26
27
@dataclass(eq=False)
class StapledByteStream(ByteStream):
    """
    Combines two byte streams into a single, bidirectional byte stream.

    Extra attributes will be provided from both streams, with the receive stream
    providing the values in case of a conflict.

    :param ByteSendStream send_stream: the sending byte stream
    :param ByteReceiveStream receive_stream: the receiving byte stream
    """

    send_stream: ByteSendStream
    receive_stream: ByteReceiveStream

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """Receive at most ``max_bytes`` bytes from :attr:`receive_stream`."""
        return await self.receive_stream.receive(max_bytes)

    async def send(self, item: bytes) -> None:
        """Send ``item`` through :attr:`send_stream`."""
        await self.send_stream.send(item)

    async def send_eof(self) -> None:
        """
        Signal end-of-file by closing :attr:`send_stream`.

        :attr:`receive_stream` is not touched, so receiving remains possible.
        """
        await self.send_stream.aclose()

    async def aclose(self) -> None:
        """
        Close both underlying streams.

        The receive stream is closed even if closing the send stream raises. In
        that case the send stream's exception propagates, unless closing the
        receive stream raises as well, in which case that later exception
        propagates with the first one attached as its ``__context__``.
        """
        try:
            await self.send_stream.aclose()
        finally:
            # Without the finally, a failing send-side close would leak the
            # receive stream.
            await self.receive_stream.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Extra attributes merged from both streams.

        Entries from the receive stream are unpacked last, so they win over the
        send stream's entries on key conflicts.
        """
        return {
            **self.send_stream.extra_attributes,
            **self.receive_stream.extra_attributes,
        }
62
63
@dataclass(eq=False)
class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]):
    """
    Combines two object streams into a single, bidirectional object stream.

    Extra attributes will be provided from both streams, with the receive stream
    providing the values in case of a conflict.

    :param ObjectSendStream send_stream: the sending object stream
    :param ObjectReceiveStream receive_stream: the receiving object stream
    """

    send_stream: ObjectSendStream[T_Item]
    receive_stream: ObjectReceiveStream[T_Item]

    async def receive(self) -> T_Item:
        """Receive the next item from :attr:`receive_stream`."""
        return await self.receive_stream.receive()

    async def send(self, item: T_Item) -> None:
        """Send ``item`` through :attr:`send_stream`."""
        await self.send_stream.send(item)

    async def send_eof(self) -> None:
        """
        Signal end-of-file by closing :attr:`send_stream`.

        :attr:`receive_stream` is not touched, so receiving remains possible.
        """
        await self.send_stream.aclose()

    async def aclose(self) -> None:
        """
        Close both underlying streams.

        The receive stream is closed even if closing the send stream raises. In
        that case the send stream's exception propagates, unless closing the
        receive stream raises as well, in which case that later exception
        propagates with the first one attached as its ``__context__``.
        """
        try:
            await self.send_stream.aclose()
        finally:
            # Without the finally, a failing send-side close would leak the
            # receive stream.
            await self.receive_stream.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Extra attributes merged from both streams.

        Entries from the receive stream are unpacked last, so they win over the
        send stream's entries on key conflicts.
        """
        return {
            **self.send_stream.extra_attributes,
            **self.receive_stream.extra_attributes,
        }
98
99
@dataclass(eq=False)
class MultiListener(Generic[T_Stream], Listener[T_Stream]):
    """
    Combines multiple listeners into one, serving connections from all of them at once.

    Any MultiListeners in the given collection of listeners will have their listeners
    moved into this one.

    Extra attributes are provided from each listener, with each successive listener
    overriding any conflicting attributes from the previous one.

    :param listeners: listeners to serve
    :type listeners: Sequence[Listener[T_Stream]]
    """

    listeners: Sequence[Listener[T_Stream]]

    def __post_init__(self) -> None:
        """
        Flatten nested MultiListeners into a new list owned by this instance.

        The children of a nested MultiListener are appended in place of it, and
        the nested instance's own list is emptied. Ownership therefore moves
        here, and the nested instance no longer serves or closes them.
        """
        listeners: list[Listener[T_Stream]] = []
        for listener in self.listeners:
            if isinstance(listener, MultiListener):
                listeners.extend(listener.listeners)
                # After its own __post_init__, a MultiListener's `listeners` is a
                # list, so slice deletion is valid even though it is typed as Sequence.
                del listener.listeners[:]  # type: ignore[attr-defined]
            else:
                listeners.append(listener)

        self.listeners = listeners

    async def serve(
        self, handler: Callable[[T_Stream], Any], task_group: TaskGroup | None = None
    ) -> None:
        """
        Run every wrapped listener's ``serve`` concurrently in a new task group.

        :param handler: passed unchanged to each listener's ``serve``
        :param task_group: passed unchanged to each listener's ``serve``
        """
        from .. import create_task_group

        async with create_task_group() as tg:
            for listener in self.listeners:
                tg.start_soon(listener.serve, handler, task_group)

    async def aclose(self) -> None:
        """
        Close every wrapped listener, in order.

        Each listener's ``aclose`` is attempted even if an earlier one raises.
        Exceptions propagate as they would from nested ``try``/``finally``
        blocks: the last one raised wins, with earlier ones chained as
        ``__context__``.
        """
        from contextlib import AsyncExitStack

        async with AsyncExitStack() as stack:
            # The exit stack runs callbacks in LIFO order, so push them in
            # reverse to close the listeners in their original order.
            for listener in reversed(self.listeners):
                stack.push_async_callback(listener.aclose)

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Extra attributes merged from all wrapped listeners.

        Later listeners override earlier ones on key conflicts.
        """
        attributes: dict[Any, Callable[[], Any]] = {}
        for listener in self.listeners:
            attributes.update(listener.extra_attributes)

        return attributes