1"""
This module defines the request and response body Validator interfaces, with base functionality that can be
subclassed for custom validators (such as those provided to the RequestValidationMiddleware).
4"""
5import copy
6import json
7import typing as t
8
9from starlette.datastructures import Headers, MutableHeaders
10from starlette.types import Receive, Scope, Send
11
12from connexion.exceptions import BadRequestProblem
13
14
class AbstractRequestBodyValidator:
    """
    Validator interface with base functionality that can be subclassed for custom validators.

    Subclasses provide :meth:`_parse` and :meth:`_validate`; :meth:`wrap_receive` reads the body
    from the ASGI ``receive`` channel, parses and validates it, and returns a new ``receive``
    channel that replays the body to the downstream application.

    .. note: Validators load the whole body into memory, which can be a problem for large payloads.
    """

    MUTABLE_VALIDATION = False
    """
    Whether mutations to the body during validation should be transmitted via the receive channel.
    Note that this does not apply to the substitution of a missing body with the default body, which always
    updates the receive channel.
    """
    MAX_MESSAGE_LENGTH = 256000
    """Maximum message length that will be sent via the receive channel for mutated bodies."""

    def __init__(
        self,
        *,
        schema: dict,
        required: bool = False,
        nullable: bool = False,
        encoding: str,
        strict_validation: bool,
        **kwargs,
    ):
        """
        :param schema: Schema of operation to validate
        :param required: Whether RequestBody is required
        :param nullable: Whether RequestBody is nullable
        :param encoding: Encoding of body (passed via Content-Type header)
        :param strict_validation: Whether to reject parameters not defined in the spec
        :param kwargs: Additional arguments for subclasses (ignored by this base class)
        """
        self._schema = schema
        self._nullable = nullable
        self._required = required
        self._encoding = encoding
        self._strict_validation = strict_validation

    async def _parse(
        self, stream: t.AsyncGenerator[bytes, None], scope: Scope
    ) -> t.Any:
        """
        Parse the incoming stream.

        To be implemented by subclasses; this base implementation returns ``None``.
        """

    def _validate(self, body: t.Any) -> t.Optional[dict]:
        """
        Validate the parsed body.

        To be implemented by subclasses; this base implementation accepts any body. Subclasses
        may mutate `body` in place; those changes are only forwarded downstream when
        :attr:`MUTABLE_VALIDATION` is enabled.

        :raises: :class:`connexion.exceptions.BadRequestProblem`
        """

    def _insert_body(self, receive: Receive, *, body: t.Any, scope: Scope) -> Receive:
        """
        Insert messages transmitting the body at the start of the `receive` channel.

        The body is serialized as JSON, encoded with the validator's encoding and split into
        ``http.request`` messages of at most :attr:`MAX_MESSAGE_LENGTH` bytes.

        This method updates the provided `scope` in place with the right `Content-Length` header.

        :param receive: Receive channel to prepend the body messages to.
        :param body: Body to transmit. If ``None``, `receive` is returned unchanged and `scope`
            is not modified.
        :param scope: Scope whose ``headers`` are updated in place.
        :return: The wrapped receive channel.
        """
        if body is None:
            return receive

        bytes_body = json.dumps(body).encode(self._encoding)

        # Update the content-length header on the caller's scope so the downstream app sees
        # the length of the body actually transmitted. The header list is copied first so
        # other holders of the original list are not affected.
        scope["headers"] = copy.deepcopy(scope["headers"])
        headers = MutableHeaders(scope=scope)
        headers["content-length"] = str(len(bytes_body))

        # Wrap in new receive channel
        messages = (
            {
                "type": "http.request",
                "body": bytes_body[i : i + self.MAX_MESSAGE_LENGTH],
                "more_body": i + self.MAX_MESSAGE_LENGTH < len(bytes_body),
            }
            for i in range(0, len(bytes_body), self.MAX_MESSAGE_LENGTH)
        )

        return self._insert_messages(receive, messages=messages)

    @staticmethod
    def _insert_messages(
        receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
    ) -> Receive:
        """
        Insert messages at the start of the `receive` channel.

        The returned channel yields each of `messages` once, then delegates to `receive`.
        """
        # Ensure that messages is an iterator so each message is replayed once.
        message_iterator = iter(messages)

        async def receive_() -> t.MutableMapping[str, t.Any]:
            try:
                return next(message_iterator)
            except StopIteration:
                return await receive()

        return receive_

    async def wrap_receive(self, receive: Receive, *, scope: Scope) -> Receive:
        """
        Wrap the provided `receive` channel with request body validation.

        The body is read from `receive`, parsed with :meth:`_parse` and validated with
        :meth:`_validate`. The returned channel replays the body before delegating to
        `receive`. It replays the re-serialized body when :attr:`MUTABLE_VALIDATION` is
        enabled and the parsed body is not ``None``; otherwise it replays the original messages.

        This method updates the provided `scope` in place with the right `Content-Length` header.

        :raises: :class:`connexion.exceptions.BadRequestProblem` if a required body is missing,
            the ``Content-Length`` header is malformed, or validation fails.
        """
        # Handle missing bodies
        headers = Headers(scope=scope)
        try:
            content_length = int(headers.get("content-length", 0))
        except ValueError:
            # A malformed header is a client error, not a server error.
            raise BadRequestProblem("Invalid Content-Length header") from None

        if not content_length:
            body = self._schema.get("default")
            if body is None and self._required:
                raise BadRequestProblem("RequestBody is required")
            # The default body is encoded as a `receive` channel to mimic an incoming body
            receive = self._insert_body(receive, body=body, scope=scope)

        # The receive channel is converted to a stream for convenient access. Every message
        # read is recorded so it can be replayed to the downstream application.
        messages: t.List[t.MutableMapping[str, t.Any]] = []

        async def stream() -> t.AsyncGenerator[bytes, None]:
            more_body = True
            while more_body:
                message = await receive()
                messages.append(message)
                more_body = message.get("more_body", False)
                yield message.get("body", b"")
            yield b""

        # The body is parsed and validated
        body = await self._parse(stream(), scope=scope)
        if not (body is None and self._nullable):
            self._validate(body)

        if self.MUTABLE_VALIDATION and body is not None:
            # Re-serialize the body so changes made during validation reach the app
            receive = self._insert_body(receive, body=body, scope=scope)
        else:
            # Replay the original messages. This also covers a parsed body of None:
            # `_insert_body` would insert nothing, and the messages already consumed from
            # `receive` would otherwise never reach the downstream application.
            receive = self._insert_messages(receive, messages=messages)

        return receive
155
156
class AbstractResponseBodyValidator:
    """
    Validator interface with base functionality that can be subclassed for custom validators.

    Subclasses provide :meth:`_parse` and :meth:`_validate`; :meth:`wrap_send` buffers the
    outgoing ASGI response messages, validates the complete body and only then forwards the
    buffered messages.

    .. note: Validators load the whole body into memory, which can be a problem for large payloads.
    """

    def __init__(
        self,
        scope: Scope,
        *,
        schema: dict,
        nullable: bool = False,
        encoding: str,
    ) -> None:
        """
        :param scope: Scope of the connection; stored on the instance, not used by this base class
        :param schema: Schema to validate the response body against
        :param nullable: Whether the response body is nullable
        :param encoding: Encoding of the response body
        """
        self._scope = scope
        self._schema = schema
        self._nullable = nullable
        self._encoding = encoding

    def _parse(self, stream: t.Generator[bytes, None, None]) -> t.Any:
        """
        Parse the outgoing body stream.

        To be implemented by subclasses; this base implementation returns ``None``.
        """

    def _validate(self, body: t.Any) -> t.Optional[dict]:
        """
        Validate the body.

        To be implemented by subclasses; this base implementation accepts any body.

        :raises: :class:`connexion.exceptions.NonConformingResponse`
        """

    def wrap_send(self, send: Send) -> Send:
        """
        Wrap the provided send channel with response body validation.

        Messages are buffered until a non-start message without ``more_body`` arrives. The
        buffered bodies are then parsed with :meth:`_parse` and validated with
        :meth:`_validate`. Only then are all buffered messages forwarded to `send`, in order.

        :raises: :class:`connexion.exceptions.NonConformingResponse` via :meth:`_validate`
        """
        messages: t.List[t.MutableMapping[str, t.Any]] = []

        async def send_(message: t.MutableMapping[str, t.Any]) -> None:
            messages.append(message)

            # Keep buffering until the final body message has been received.
            if message["type"] == "http.response.start" or message.get(
                "more_body", False
            ):
                return

            body = self._parse(
                buffered.get("body", b"") for buffered in messages
            )

            if not (body is None and self._nullable):
                self._validate(body)

            # Forward the buffered messages in order. Iterating a snapshot and clearing the
            # buffer avoids the quadratic cost of repeatedly popping from the front of a list.
            pending = messages.copy()
            messages.clear()
            for pending_message in pending:
                await send(pending_message)

        return send_