1from __future__ import annotations
2
3import ssl
4import typing
5
6from .._exceptions import ReadError
7from .base import (
8 SOCKET_OPTION,
9 AsyncNetworkBackend,
10 AsyncNetworkStream,
11 NetworkBackend,
12 NetworkStream,
13)
14
15
class MockSSLObject:
    """Minimal stand-in for an SSL object, exposing only ALPN negotiation.

    The negotiated protocol is fixed at construction time by the ``http2``
    flag rather than by any real TLS handshake.
    """

    def __init__(self, http2: bool):
        # Decides which ALPN protocol name selected_alpn_protocol() reports.
        self._http2 = http2

    def selected_alpn_protocol(self) -> str:
        """Return ``"h2"`` if ``http2`` was truthy, otherwise ``"http/1.1"``."""
        if self._http2:
            return "h2"
        return "http/1.1"
22
23
class MockStream(NetworkStream):
    """Synchronous in-memory network stream that replays canned byte chunks.

    Reads consume chunks from the front of ``buffer`` (the list object passed
    in is mutated in place), writes are discarded, ``start_tls`` is a no-op
    returning the same stream, and ``get_extra_info("ssl_object")`` reports an
    ALPN protocol chosen by ``http2``.
    """

    def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
        self._buffer = buffer
        self._http2 = http2
        self._closed = False

    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """Return the next canned chunk, never longer than ``max_bytes``.

        A chunk longer than ``max_bytes`` is split: its first ``max_bytes``
        bytes are returned and the remainder stays at the front of the buffer
        for the next call. A non-positive ``max_bytes`` returns the whole
        chunk unsplit. Returns ``b""`` once the buffer is exhausted.
        ``timeout`` is accepted for interface compatibility and ignored.

        Raises:
            ReadError: if the stream has been closed.
        """
        if self._closed:
            raise ReadError("Connection closed")
        if not self._buffer:
            return b""
        chunk = self._buffer[0]
        if 0 < max_bytes < len(chunk):
            # Honour the "at most max_bytes" read contract; keep the unread
            # tail queued so no data is lost.
            self._buffer[0] = chunk[max_bytes:]
            return chunk[:max_bytes]
        return self._buffer.pop(0)

    def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """Discard ``buffer``; the mock never transmits anything."""
        pass

    def close(self) -> None:
        """Mark the stream closed so later reads raise ``ReadError``."""
        self._closed = True

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> NetworkStream:
        """Pretend to upgrade to TLS by returning this same stream."""
        return self

    def get_extra_info(self, info: str) -> typing.Any:
        """Return a ``MockSSLObject`` for ``"ssl_object"``, else ``None``."""
        return MockSSLObject(http2=self._http2) if info == "ssl_object" else None

    def __repr__(self) -> str:
        return "<httpcore.MockStream>"
56
57
class MockBackend(NetworkBackend):
    """Synchronous network backend whose connections replay canned bytes.

    Each connection gets a shallow copy of ``buffer``, so every new
    ``MockStream`` starts from the first chunk and consuming one stream does
    not drain the backend's own list. ``sleep`` returns immediately.
    """

    def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
        self._buffer = buffer
        self._http2 = http2

    def _open_stream(self) -> NetworkStream:
        # Hand each stream its own list so streams never share consumed state.
        chunks = list(self._buffer)
        return MockStream(chunks, http2=self._http2)

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        """Return a fresh mock stream; every connection argument is ignored."""
        return self._open_stream()

    def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        """Return a fresh mock stream; every connection argument is ignored."""
        return self._open_stream()

    def sleep(self, seconds: float) -> None:
        """Return immediately instead of sleeping."""
        return None
83
84
class AsyncMockStream(AsyncNetworkStream):
    """Asynchronous in-memory network stream that replays canned byte chunks.

    Reads consume chunks from the front of ``buffer`` (the list object passed
    in is mutated in place), writes are discarded, ``start_tls`` is a no-op
    returning the same stream, and ``get_extra_info("ssl_object")`` reports an
    ALPN protocol chosen by ``http2``. No method actually awaits anything.
    """

    def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
        self._buffer = buffer
        self._http2 = http2
        self._closed = False

    async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """Return the next canned chunk, never longer than ``max_bytes``.

        A chunk longer than ``max_bytes`` is split: its first ``max_bytes``
        bytes are returned and the remainder stays at the front of the buffer
        for the next call. A non-positive ``max_bytes`` returns the whole
        chunk unsplit. Returns ``b""`` once the buffer is exhausted.
        ``timeout`` is accepted for interface compatibility and ignored.

        Raises:
            ReadError: if the stream has been closed.
        """
        if self._closed:
            raise ReadError("Connection closed")
        if not self._buffer:
            return b""
        chunk = self._buffer[0]
        if 0 < max_bytes < len(chunk):
            # Honour the "at most max_bytes" read contract; keep the unread
            # tail queued so no data is lost.
            self._buffer[0] = chunk[max_bytes:]
            return chunk[:max_bytes]
        return self._buffer.pop(0)

    async def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """Discard ``buffer``; the mock never transmits anything."""
        pass

    async def aclose(self) -> None:
        """Mark the stream closed so later reads raise ``ReadError``."""
        self._closed = True

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> AsyncNetworkStream:
        """Pretend to upgrade to TLS by returning this same stream."""
        return self

    def get_extra_info(self, info: str) -> typing.Any:
        """Return a ``MockSSLObject`` for ``"ssl_object"``, else ``None``."""
        return MockSSLObject(http2=self._http2) if info == "ssl_object" else None

    def __repr__(self) -> str:
        return "<httpcore.AsyncMockStream>"
117
118
class AsyncMockBackend(AsyncNetworkBackend):
    """Asynchronous network backend whose connections replay canned bytes.

    Each connection gets a shallow copy of ``buffer``, so every new
    ``AsyncMockStream`` starts from the first chunk and consuming one stream
    does not drain the backend's own list. ``sleep`` returns immediately.
    """

    def __init__(self, buffer: list[bytes], http2: bool = False) -> None:
        self._buffer = buffer
        self._http2 = http2

    def _open_stream(self) -> AsyncNetworkStream:
        # Hand each stream its own list so streams never share consumed state.
        chunks = list(self._buffer)
        return AsyncMockStream(chunks, http2=self._http2)

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        """Return a fresh mock stream; every connection argument is ignored."""
        return self._open_stream()

    async def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> AsyncNetworkStream:
        """Return a fresh mock stream; every connection argument is ignored."""
        return self._open_stream()

    async def sleep(self, seconds: float) -> None:
        """Return immediately instead of sleeping."""
        return None