1import ssl
2import typing
3from typing import Optional
4
5from .._exceptions import ReadError
6from .base import (
7 SOCKET_OPTION,
8 AsyncNetworkBackend,
9 AsyncNetworkStream,
10 NetworkBackend,
11 NetworkStream,
12)
13
14
class MockSSLObject:
    """Minimal stand-in for an SSL object that only models ALPN.

    ``selected_alpn_protocol()`` reports ``"h2"`` when the object was built
    with a truthy ``http2`` flag and ``"http/1.1"`` otherwise.
    """

    def __init__(self, http2: bool):
        # Which protocol this fake TLS session pretends to have negotiated.
        self._http2 = http2

    def selected_alpn_protocol(self) -> str:
        """Return the ALPN protocol name chosen at construction time."""
        if not self._http2:
            return "http/1.1"
        return "h2"
21
22
class MockStream(NetworkStream):
    """In-memory ``NetworkStream`` that replays a scripted list of byte chunks.

    Writes are discarded and ``start_tls`` is a no-op returning the same
    stream. ``get_extra_info("ssl_object")`` returns a ``MockSSLObject`` whose
    ALPN result reflects the ``http2`` flag.
    """

    def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
        # Copy the chunks so that consuming this stream never mutates the
        # caller's list.
        self._buffer = list(buffer)
        self._http2 = http2
        self._closed = False

    def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
        """Return the next scripted chunk, never more than ``max_bytes`` long.

        If the next chunk is longer than a positive ``max_bytes``, only its
        first ``max_bytes`` bytes are returned and the remainder is served by
        the following read. A non-positive ``max_bytes`` returns the whole
        chunk (the previous, unlimited behaviour). Returns ``b""`` once every
        chunk has been consumed. ``timeout`` is accepted for interface
        compatibility and ignored.

        Raises:
            ReadError: if the stream has been closed.
        """
        if self._closed:
            raise ReadError("Connection closed")
        if not self._buffer:
            return b""
        chunk = self._buffer.pop(0)
        if 0 < max_bytes < len(chunk):
            # Honour the read contract: keep the unread tail at the front so
            # the next read continues exactly where this one stopped.
            self._buffer.insert(0, chunk[max_bytes:])
            chunk = chunk[:max_bytes]
        return chunk

    def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
        """Discard ``buffer``; outgoing bytes are not recorded."""

    def close(self) -> None:
        """Mark the stream closed so that later reads raise ``ReadError``."""
        self._closed = True

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> NetworkStream:
        """Pretend to upgrade to TLS; returns this same stream unchanged."""
        return self

    def get_extra_info(self, info: str) -> typing.Any:
        """Return a ``MockSSLObject`` for ``"ssl_object"``, else ``None``."""
        return MockSSLObject(http2=self._http2) if info == "ssl_object" else None

    def __repr__(self) -> str:
        return "<httpcore.MockStream>"
55
56
class MockBackend(NetworkBackend):
    """``NetworkBackend`` whose every connection replays the same byte script.

    Each ``connect_*`` call returns a new ``MockStream`` built from a fresh
    copy of ``buffer``, so one connection's reads do not drain another's.
    The list itself is held by reference, so changes made to it after
    construction are seen by streams opened afterwards.
    """

    def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
        self._http2 = http2
        self._buffer = buffer

    def _open_stream(self) -> NetworkStream:
        # Shared by both connect methods: a new stream over a private copy.
        chunks = list(self._buffer)
        return MockStream(chunks, http2=self._http2)

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: Optional[float] = None,
        local_address: Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> NetworkStream:
        """Return a scripted stream; all address/option arguments are ignored."""
        return self._open_stream()

    def connect_unix_socket(
        self,
        path: str,
        timeout: Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> NetworkStream:
        """Return a scripted stream; ``path`` and options are ignored."""
        return self._open_stream()

    def sleep(self, seconds: float) -> None:
        """Return immediately without waiting."""
82
83
class AsyncMockStream(AsyncNetworkStream):
    """In-memory ``AsyncNetworkStream`` that replays a scripted list of chunks.

    Writes are discarded and ``start_tls`` is a no-op returning the same
    stream. ``get_extra_info("ssl_object")`` returns a ``MockSSLObject`` whose
    ALPN result reflects the ``http2`` flag.
    """

    def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
        # Copy the chunks so that consuming this stream never mutates the
        # caller's list.
        self._buffer = list(buffer)
        self._http2 = http2
        self._closed = False

    async def read(self, max_bytes: int, timeout: Optional[float] = None) -> bytes:
        """Return the next scripted chunk, never more than ``max_bytes`` long.

        If the next chunk is longer than a positive ``max_bytes``, only its
        first ``max_bytes`` bytes are returned and the remainder is served by
        the following read. A non-positive ``max_bytes`` returns the whole
        chunk (the previous, unlimited behaviour). Returns ``b""`` once every
        chunk has been consumed. ``timeout`` is accepted for interface
        compatibility and ignored.

        Raises:
            ReadError: if the stream has been closed.
        """
        if self._closed:
            raise ReadError("Connection closed")
        if not self._buffer:
            return b""
        chunk = self._buffer.pop(0)
        if 0 < max_bytes < len(chunk):
            # Honour the read contract: keep the unread tail at the front so
            # the next read continues exactly where this one stopped.
            self._buffer.insert(0, chunk[max_bytes:])
            chunk = chunk[:max_bytes]
        return chunk

    async def write(self, buffer: bytes, timeout: Optional[float] = None) -> None:
        """Discard ``buffer``; outgoing bytes are not recorded."""

    async def aclose(self) -> None:
        """Mark the stream closed so that later reads raise ``ReadError``."""
        self._closed = True

    async def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: Optional[str] = None,
        timeout: Optional[float] = None,
    ) -> AsyncNetworkStream:
        """Pretend to upgrade to TLS; returns this same stream unchanged."""
        return self

    def get_extra_info(self, info: str) -> typing.Any:
        """Return a ``MockSSLObject`` for ``"ssl_object"``, else ``None``."""
        return MockSSLObject(http2=self._http2) if info == "ssl_object" else None

    def __repr__(self) -> str:
        return "<httpcore.AsyncMockStream>"
116
117
class AsyncMockBackend(AsyncNetworkBackend):
    """``AsyncNetworkBackend`` whose every connection replays the same script.

    Each ``connect_*`` call returns a new ``AsyncMockStream`` built from a
    fresh copy of ``buffer``, so one connection's reads do not drain
    another's. The list itself is held by reference, so changes made to it
    after construction are seen by streams opened afterwards.
    """

    def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None:
        self._http2 = http2
        self._buffer = buffer

    def _open_stream(self) -> AsyncNetworkStream:
        # Shared by both connect methods: a new stream over a private copy.
        chunks = list(self._buffer)
        return AsyncMockStream(chunks, http2=self._http2)

    async def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: Optional[float] = None,
        local_address: Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:
        """Return a scripted stream; all address/option arguments are ignored."""
        return self._open_stream()

    async def connect_unix_socket(
        self,
        path: str,
        timeout: Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> AsyncNetworkStream:
        """Return a scripted stream; ``path`` and options are ignored."""
        return self._open_stream()

    async def sleep(self, seconds: float) -> None:
        """Return immediately without waiting."""