Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/testclient.py: 23%
324 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
1import contextlib
2import inspect
3import io
4import json
5import math
6import queue
7import sys
8import typing
9import warnings
10from concurrent.futures import Future
11from types import GeneratorType
12from urllib.parse import unquote, urljoin
14import anyio
15import anyio.from_thread
16import httpx
17from anyio.streams.stapled import StapledObjectStream
19from starlette._utils import is_async_callable
20from starlette.types import ASGIApp, Message, Receive, Scope, Send
21from starlette.websockets import WebSocketDisconnect
23if sys.version_info >= (3, 8): # pragma: no cover
24 from typing import TypedDict
25else: # pragma: no cover
26 from typing_extensions import TypedDict
# Form payload accepted by the request helpers: each field maps to a single
# string or to an iterable of strings.
_RequestData = typing.Mapping[str, typing.Union[str, typing.Iterable[str]]]

# ASGI 3 ("single callable") application: awaited as app(scope, receive, send).
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

# ASGI 2 ("double callable") application: app(scope) returns an instance that
# is then awaited as instance(receive, send).
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]

# Zero-argument factory producing a context manager that yields an anyio
# blocking portal.
_PortalFactoryType = typing.Callable[
    [], typing.ContextManager[anyio.abc.BlockingPortal]
]
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
    """Tell whether *app* should be driven through the ASGI 3 interface.

    A class counts as ASGI 3 when it defines ``__await__``; any other object
    counts as ASGI 3 when ``is_async_callable`` accepts it.
    """
    return (
        hasattr(app, "__await__")
        if inspect.isclass(app)
        else is_async_callable(app)
    )
class _WrapASGI2:
    """Adapter exposing an ASGI 2 ("double callable") app as an ASGI 3 app.

    The wrapped app is first called with the scope alone; the object it
    returns is then awaited with ``receive`` and ``send``.
    """

    def __init__(self, app: ASGI2App) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope)(receive, send)
59class _AsyncBackend(TypedDict):
60 backend: str
61 backend_options: typing.Dict[str, typing.Any]
64class _Upgrade(Exception):
65 def __init__(self, session: "WebSocketTestSession") -> None:
66 self.session = session
class WebSocketTestSession:
    """Test-side handle on one websocket connection to an ASGI app.

    The app runs as a task on an anyio blocking portal. Two ``queue.Queue``
    objects carry traffic between that task and the calling thread:

    * ``_receive_queue``: messages for the app, handed out by its ``receive``.
    * ``_send_queue``: messages the app passed to ``send``, plus any exception
      that escaped the app.

    Entering the context manager performs the ``websocket.connect`` handshake.
    Leaving it sends ``websocket.disconnect`` (code 1000), closes the portal,
    and re-raises an exception the app left in ``_send_queue``.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.portal_factory = portal_factory
        # Populated in __enter__ from the app's reply to the handshake.
        self.accepted_subprotocol = None
        self.extra_headers = None
        # Test thread -> app.
        self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
        # App -> test thread (messages or exceptions).
        self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()

    def __enter__(self) -> "WebSocketTestSession":
        """Start the app and complete the connect handshake.

        Raises:
            WebSocketDisconnect: if the app answers the connect with a close.
        """
        self.exit_stack = contextlib.ExitStack()
        self.portal = self.exit_stack.enter_context(self.portal_factory())

        try:
            self.portal.start_task_soon(self._run)
            self.send({"type": "websocket.connect"})
            reply = self.receive()
            self._raise_on_close(reply)
        except Exception:
            # A failed handshake must not leak the portal.
            self.exit_stack.close()
            raise

        self.accepted_subprotocol = reply.get("subprotocol", None)
        self.extra_headers = reply.get("headers", None)
        return self

    def __exit__(self, *args: typing.Any) -> None:
        """Disconnect, release the portal, then surface a queued app error."""
        try:
            self.close(1000)
        finally:
            self.exit_stack.close()
        # Drain what the app left behind; the first exception found is raised
        # and ordinary messages are discarded.
        while not self._send_queue.empty():
            leftover = self._send_queue.get()
            if isinstance(leftover, BaseException):
                raise leftover

    async def _run(self) -> None:
        """Run the app on the portal; queue any exception before re-raising it."""
        scope, receive, send = self.scope, self._asgi_receive, self._asgi_send
        try:
            await self.app(scope, receive, send)
        except BaseException as exc:
            self._send_queue.put(exc)
            raise

    async def _asgi_receive(self) -> Message:
        # Poll, yielding to the event loop each time, until the test thread
        # has queued something; a blocking Queue.get here would stall the loop.
        while self._receive_queue.empty():
            await anyio.sleep(0)
        return self._receive_queue.get()

    async def _asgi_send(self, message: Message) -> None:
        self._send_queue.put(message)

    def _raise_on_close(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if *message* is a ``websocket.close``."""
        if message["type"] != "websocket.close":
            return
        raise WebSocketDisconnect(
            message.get("code", 1000), message.get("reason", "")
        )

    def send(self, message: Message) -> None:
        """Queue a raw ASGI message for the app."""
        self._receive_queue.put(message)

    def send_text(self, data: str) -> None:
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """Send *data* as JSON in a text frame, or UTF-8 encoded in a binary frame."""
        assert mode in ["text", "binary"]
        encoded = json.dumps(data)
        frame = (
            {"text": encoded}
            if mode == "text"
            else {"bytes": encoded.encode("utf-8")}
        )
        self.send({"type": "websocket.receive", **frame})

    def close(self, code: int = 1000) -> None:
        self.send({"type": "websocket.disconnect", "code": code})

    def receive(self) -> Message:
        """Block until the app sends a message; re-raise a queued app error."""
        item = self._send_queue.get()
        if isinstance(item, BaseException):
            raise item
        return item

    def receive_text(self) -> str:
        incoming = self.receive()
        self._raise_on_close(incoming)
        return incoming["text"]

    def receive_bytes(self) -> bytes:
        incoming = self.receive()
        self._raise_on_close(incoming)
        return incoming["bytes"]

    def receive_json(self, mode: str = "text") -> typing.Any:
        """Receive a frame and decode its JSON payload (text or UTF-8 bytes)."""
        assert mode in ["text", "binary"]
        incoming = self.receive()
        self._raise_on_close(incoming)
        payload = (
            incoming["text"]
            if mode == "text"
            else incoming["bytes"].decode("utf-8")
        )
        return json.loads(payload)
class _TestClientTransport(httpx.BaseTransport):
    """httpx transport that dispatches requests directly into an ASGI app.

    HTTP requests run to completion on a blocking portal obtained from
    ``portal_factory`` and are turned into ``httpx.Response`` objects.
    Requests with a ``ws``/``wss`` scheme are not executed here. Instead a
    ``WebSocketTestSession`` is built and handed to the caller by raising
    ``_Upgrade``.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Run *request* against the app and return the collected response.

        Raises:
            _Upgrade: for ``ws``/``wss`` URLs, carrying the websocket session.
            KeyError: if the URL scheme is not http, https, ws or wss.
            BaseException: whatever escaped the app (including assertion
                failures from out-of-order response messages), but only when
                ``raise_server_exceptions`` is true.
            AssertionError: if ``raise_server_exceptions`` is true and the app
                never sent ``http.response.start``.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]

        # NOTE(review): splitting on the first ":" assumes a "host[:port]"
        # netloc; a bracketed IPv6 literal such as "[::1]:8000" would make
        # int() below raise ValueError.
        if ":" in netloc:
            host, port_string = netloc.split(":", 1)
            port = int(port_string)
        else:
            host = netloc
            port = default_port

        # Include the 'host' header: synthesize one only when the request
        # lacks it (an existing one is copied with the other headers below).
        if "host" in request.headers:
            headers: typing.List[typing.Tuple[bytes, bytes]] = []
        elif port == default_port:  # pragma: no cover
            headers = [(b"host", host.encode())]
        else:  # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers as lower-cased byte pairs.
        headers += [
            (key.lower().encode(), value.encode())
            for key, value in request.headers.items()
        ]

        scope: typing.Dict[str, typing.Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: typing.Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path,
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
                "subprotocols": subprotocols,
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path,
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": ["testclient", 50000],
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: typing.Dict[str, typing.Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            """ASGI ``receive``: deliver the request body, then a disconnect."""
            nonlocal request_complete

            if request_complete:
                # Hold the disconnect back until the response is finished.
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            """ASGI ``send``: collect status, headers, body and debug info."""
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert (
                    not response_started
                ), 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [
                    (key.decode(), value.decode())
                    for key, value in message.get("headers", [])
                ]
                response_started = True
            elif message["type"] == "http.response.body":
                assert (
                    response_started
                ), 'Received "http.response.body" without "http.response.start".'
                assert (
                    not response_complete.is_set()
                ), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        # getvalue() returns every byte written regardless of stream position.
        # read() would yield b"" whenever the app stopped before a final
        # (more_body=False) body message, because only that message rewinds
        # the stream.
        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].getvalue())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response
class TestClient(httpx.Client):
    """Synchronous ``httpx.Client`` whose requests go to an in-process ASGI app.

    Requests are routed through ``_TestClientTransport``. Websocket sessions
    are opened with ``websocket_connect``. When used as a context manager, the
    client keeps one blocking portal alive for its whole lifetime and drives
    the app's lifespan protocol (startup on enter, shutdown on exit).

    The request helpers accept ``allow_redirects`` as a deprecated alias for
    ``follow_redirects``; passing both raises ``RuntimeError``.
    """

    # Keeps pytest from collecting this class as a test case despite its name.
    __test__ = False
    task: "Future[None]"
    # Set while inside ``with client:``; ``_portal_factory`` reuses it instead
    # of starting a fresh portal for each request.
    portal: typing.Optional[anyio.abc.BlockingPortal] = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: str = "asyncio",
        backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
        cookies: httpx._client.CookieTypes = None,
        headers: typing.Optional[typing.Dict[str, str]] = None,
    ) -> None:
        """Create a client for *app* (ASGI 2 apps are wrapped as ASGI 3).

        *backend* and *backend_options* select the anyio backend used for the
        blocking portal. *headers* is copied, never modified; a
        ``user-agent: testclient`` header is added unless one is already
        present (compared case-insensitively).
        """
        self.async_backend = _AsyncBackend(
            backend=backend, backend_options=backend_options or {}
        )
        if _is_asgi3(app):
            asgi_app = typing.cast(ASGI3App, app)
        else:
            asgi_app = _WrapASGI2(typing.cast(ASGI2App, app))
        self.app = asgi_app
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
        )
        # Copy so the caller's dict is never mutated; header names are
        # case-insensitive, so "User-Agent" must suppress the default too.
        headers = dict(headers) if headers is not None else {}
        if not any(key.lower() == "user-agent" for key in headers):
            headers["user-agent"] = "testclient"
        super().__init__(
            app=self.app,
            base_url=base_url,
            headers=headers,
            transport=transport,
            follow_redirects=True,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the client's long-lived portal if one is open.

        Otherwise, start a temporary portal for the duration of the context.
        """
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(
                **self.async_backend
            ) as portal:
                yield portal

    def _choose_redirect_arg(
        self,
        follow_redirects: typing.Optional[bool],
        allow_redirects: typing.Optional[bool],
    ) -> typing.Union[bool, httpx._client.UseClientDefault]:
        """Resolve the redirect policy for one request.

        Returns whichever of *follow_redirects* / *allow_redirects* was given,
        or ``USE_CLIENT_DEFAULT`` when neither was. Using the deprecated
        *allow_redirects* emits a ``DeprecationWarning``.

        Raises:
            RuntimeError: if both arguments are given.
        """
        # The conflict check must come first: placed after the
        # ``follow_redirects`` test it would be unreachable.
        if allow_redirects is not None and follow_redirects is not None:
            raise RuntimeError(
                "Cannot use both `allow_redirects` and `follow_redirects`."
            )
        if follow_redirects is not None:
            return follow_redirects
        if allow_redirects is not None:
            message = (
                "The `allow_redirects` argument is deprecated. "
                "Use `follow_redirects` instead."
            )
            warnings.warn(message, DeprecationWarning)
            return allow_redirects
        return httpx._client.USE_CLIENT_DEFAULT

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: typing.Optional[httpx._types.RequestContent] = None,
        data: typing.Optional[_RequestData] = None,
        files: typing.Optional[httpx._types.RequestFiles] = None,
        json: typing.Any = None,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a *method* request to *url*, resolved against ``base_url``."""
        url = self.base_url.join(url)
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().request(
            method,
            url,
            content=content,
            data=data,  # type: ignore[arg-type]
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a GET request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send an OPTIONS request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a HEAD request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: typing.Optional[httpx._types.RequestContent] = None,
        data: typing.Optional[_RequestData] = None,
        files: typing.Optional[httpx._types.RequestFiles] = None,
        json: typing.Any = None,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a POST request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().post(
            url,
            content=content,
            data=data,  # type: ignore[arg-type]
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: typing.Optional[httpx._types.RequestContent] = None,
        data: typing.Optional[_RequestData] = None,
        files: typing.Optional[httpx._types.RequestFiles] = None,
        json: typing.Any = None,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a PUT request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().put(
            url,
            content=content,
            data=data,  # type: ignore[arg-type]
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: typing.Optional[httpx._types.RequestContent] = None,
        data: typing.Optional[_RequestData] = None,
        files: typing.Optional[httpx._types.RequestFiles] = None,
        json: typing.Any = None,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a PATCH request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().patch(
            url,
            content=content,
            data=data,  # type: ignore[arg-type]
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: typing.Optional[httpx._types.QueryParamTypes] = None,
        headers: typing.Optional[httpx._types.HeaderTypes] = None,
        cookies: typing.Optional[httpx._types.CookieTypes] = None,
        auth: typing.Union[
            httpx._types.AuthTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: typing.Optional[bool] = None,
        allow_redirects: typing.Optional[bool] = None,
        timeout: typing.Union[
            httpx._client.TimeoutTypes, httpx._client.UseClientDefault
        ] = httpx._client.USE_CLIENT_DEFAULT,
        extensions: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ) -> httpx.Response:
        """Send a DELETE request."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: typing.Optional[typing.Sequence[str]] = None,
        **kwargs: typing.Any,
    ) -> typing.Any:
        """Open a websocket session to *url* (joined onto ``ws://testserver``).

        The upgrade headers are added unless the caller already supplied them
        (compared case-insensitively). A caller-supplied ``headers`` mapping
        is copied, never modified. Returns a ``WebSocketTestSession`` to be
        used as a context manager.
        """
        url = urljoin("ws://testserver", url)
        # Copy: the caller's mapping must not be mutated, and headers=None or
        # a list of (name, value) pairs must also be accepted.
        headers = dict(kwargs.get("headers") or {})
        present = {key.lower() for key in headers}
        defaults = {
            "connection": "upgrade",
            "sec-websocket-key": "testserver==",
            "sec-websocket-version": "13",
        }
        if subprotocols is not None:
            defaults["sec-websocket-protocol"] = ", ".join(subprotocols)
        for name, value in defaults.items():
            if name not in present:
                headers[name] = value
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> "TestClient":
        """Open a long-lived portal and run the app's lifespan startup."""
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(
                anyio.from_thread.start_blocking_portal(**self.async_backend)
            )

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            # stream_send: app -> client lifespan messages.
            # stream_receive: client -> app lifespan messages.
            self.stream_send = StapledObjectStream(
                *anyio.create_memory_object_stream(math.inf)
            )
            self.stream_receive = StapledObjectStream(
                *anyio.create_memory_object_stream(math.inf)
            )
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Startup succeeded: detach the callbacks from this ``with`` so
            # they run in ``__exit__`` rather than right now.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: typing.Any) -> None:
        """Run lifespan shutdown, close the portal and clear ``self.portal``."""
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app with a lifespan scope.

        ``None`` is always pushed onto ``stream_send`` afterwards, so readers
        can tell that the app has returned or raised.
        """
        scope = {"type": "lifespan"}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the app's answer.

        After a ``lifespan.startup.failed`` message this also reads the end
        sentinel, which calls ``self.task.result()`` to surface the lifespan
        task's outcome.
        """
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await receive()

    async def wait_shutdown(self) -> None:
        """Send ``lifespan.shutdown`` and wait for the app's answer.

        ``stream_send`` is closed once this returns.
        """

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                self.task.result()
            return message

        async with self.stream_send:
            await self.stream_receive.send({"type": "lifespan.shutdown"})
            message = await receive()
            assert message["type"] in (
                "lifespan.shutdown.complete",
                "lifespan.shutdown.failed",
            )
            if message["type"] == "lifespan.shutdown.failed":
                await receive()