1"""
2This module defines a Validator interface with base functionality that can be subclassed
3for custom validators provided to the RequestValidationMiddleware.
4"""
5import copy
6import json
7import typing as t
8
9from starlette.datastructures import Headers, MutableHeaders
10from starlette.types import Receive, Scope, Send
11
12from connexion.exceptions import BadRequestProblem
13
14
class AbstractRequestBodyValidator:
    """
    Validator interface with base functionality that can be subclassed for custom validators.

    Subclasses provide the actual behavior by overriding :meth:`_parse` and :meth:`_validate`;
    the implementations in this base class do nothing and return ``None``.

    .. note:: Validators load the whole body into memory, which can be a problem for large payloads.
    """

    MUTABLE_VALIDATION = False
    """
    Whether mutations to the body during validation should be transmitted via the receive channel.
    Note that this does not apply to the substitution of a missing body with the default body, which always
    updates the receive channel.
    """
    MAX_MESSAGE_LENGTH = 256000
    """Maximum message length, in bytes, that will be sent via the receive channel for mutated bodies."""

    def __init__(
        self,
        *,
        schema: dict,
        required: bool = False,
        nullable: bool = False,
        encoding: str,
        strict_validation: bool,
        **kwargs,
    ):
        """
        :param schema: Schema of operation to validate
        :param required: Whether RequestBody is required
        :param nullable: Whether RequestBody is nullable
        :param encoding: Encoding of body (passed via Content-Type header)
        :param strict_validation: Whether to allow parameters not defined in the spec
        :param kwargs: Additional arguments for subclasses (ignored by this base class)
        """
        self._schema = schema
        self._nullable = nullable
        self._required = required
        self._encoding = encoding
        self._strict_validation = strict_validation

    async def _parse(
        self, stream: t.AsyncGenerator[bytes, None], scope: Scope
    ) -> t.Any:
        """
        Parse the incoming stream.

        The base implementation does not read the stream and returns ``None``;
        subclasses are expected to override it.

        :param stream: Async generator yielding the body chunks of the request
        :param scope: ASGI scope of the request
        """

    def _validate(self, body: t.Any) -> t.Optional[dict]:
        """
        Validate the parsed body.

        The base implementation performs no validation and returns ``None``;
        subclasses are expected to override it.

        :param body: The value returned by :meth:`_parse`
        :raises: :class:`connexion.exceptions.BadRequestProblem`
        """

    def _insert_body(
        self, receive: Receive, *, body: t.Any, scope: Scope
    ) -> t.Tuple[Receive, Scope]:
        """
        Insert messages transmitting the body at the start of the `receive` channel.

        The provided `scope` is not modified. Instead, a copy carrying the matching
        `Content-Length` header is returned together with the wrapped `receive` channel.

        :param receive: The `receive` channel to put the body messages in front of
        :param body: The body to serialize as JSON. If ``None``, nothing is inserted.
        :param scope: ASGI scope of the request
        :return: Tuple of the (possibly wrapped) `receive` channel and the scope to use with it.
            When `body` is ``None``, both are returned unchanged.
        """
        if body is None:
            return receive, scope

        bytes_body = json.dumps(body).encode(self._encoding)

        # Update the content-length header on a copy of the scope. The header list is
        # copied as well so that the caller's scope keeps its original headers.
        new_scope = scope.copy()
        new_scope["headers"] = copy.deepcopy(scope["headers"])
        headers = MutableHeaders(scope=new_scope)
        headers["content-length"] = str(len(bytes_body))

        # Wrap in new receive channel, splitting the body in chunks of at most
        # MAX_MESSAGE_LENGTH bytes. Only the last chunk has `more_body` set to False.
        messages = (
            {
                "type": "http.request",
                "body": bytes_body[i : i + self.MAX_MESSAGE_LENGTH],
                "more_body": i + self.MAX_MESSAGE_LENGTH < len(bytes_body),
            }
            for i in range(0, len(bytes_body), self.MAX_MESSAGE_LENGTH)
        )

        receive = self._insert_messages(receive, messages=messages)

        return receive, new_scope

    @staticmethod
    def _insert_messages(
        receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
    ) -> Receive:
        """
        Insert messages at the start of the `receive` channel.

        :param receive: The channel to fall back to once all `messages` have been returned
        :param messages: The messages to return first, in order
        :return: A new `receive` callable
        """
        # Ensure that messages is an iterator so each message is replayed once.
        message_iterator = iter(messages)

        async def receive_() -> t.MutableMapping[str, t.Any]:
            try:
                return next(message_iterator)
            except StopIteration:
                return await receive()

        return receive_

    async def wrap_receive(
        self, receive: Receive, *, scope: Scope
    ) -> t.Tuple[Receive, Scope]:
        """
        Wrap the provided `receive` channel with request body validation.

        The provided `scope` is not modified. When a body is inserted (a default body, or the
        validated body if `MUTABLE_VALIDATION` is enabled), the returned scope is a copy with
        the right `Content-Length` header; otherwise the provided scope is returned as is.

        :param receive: The `receive` channel of the request
        :param scope: ASGI scope of the request
        :return: Tuple of the `receive` channel and scope to pass on to the next application
        :raises: :class:`connexion.exceptions.BadRequestProblem` if the body is missing while
            it is required and the schema defines no default
        """
        # Handle missing bodies. A missing or zero Content-Length header is treated as
        # a missing body.
        # NOTE(review): a request that carries a body without a Content-Length header would
        # also take this branch - confirm that this is intended.
        headers = Headers(scope=scope)
        if not int(headers.get("content-length", 0)):
            body = self._schema.get("default")
            if body is None and self._required:
                raise BadRequestProblem("RequestBody is required")
            # The default body is encoded as a `receive` channel to mimic an incoming body
            receive, scope = self._insert_body(receive, body=body, scope=scope)

        # The receive channel is converted to a stream for convenient access.
        # Every consumed message is recorded so it can be replayed afterwards.
        messages = []

        async def stream() -> t.AsyncGenerator[bytes, None]:
            more_body = True
            while more_body:
                message = await receive()
                messages.append(message)
                more_body = message.get("more_body", False)
                yield message.get("body", b"")
            yield b""

        # The body is parsed and validated
        body = await self._parse(stream(), scope=scope)
        if not (body is None and self._nullable):
            self._validate(body)

        if self.MUTABLE_VALIDATION and body is not None:
            # Include changes made during validation
            receive, scope = self._insert_body(receive, body=body, scope=scope)
        else:
            # Replay the original messages. This branch is also taken for a `None` body when
            # MUTABLE_VALIDATION is enabled: `_insert_body` inserts nothing for `None`, so
            # without the replay the messages consumed by `stream()` would be lost for the
            # next application.
            receive = self._insert_messages(receive, messages=messages)

        return receive, scope
159
160
class AbstractResponseBodyValidator:
    """
    Base class for response body validators, meant to be subclassed by custom validators.

    The parsing and validation hooks (:meth:`_parse` and :meth:`_validate`) do nothing in this
    base class and return ``None``.

    .. note:: Validators load the whole body into memory, which can be a problem for large payloads.
    """

    def __init__(
        self,
        scope: Scope,
        *,
        schema: dict,
        nullable: bool = False,
        encoding: str,
    ) -> None:
        """
        :param scope: ASGI scope of the request being answered
        :param schema: Schema to validate the response body against
        :param nullable: Whether a ``None`` body is accepted without validation
        :param encoding: Encoding of the response body
        """
        self._encoding = encoding
        self._nullable = nullable
        self._schema = schema
        self._scope = scope

    def _parse(self, stream: t.Generator[bytes, None, None]) -> t.Any:
        """
        Parse the outgoing stream.

        :param stream: Generator yielding the ``body`` of each buffered message
        """

    def _validate(self, body: t.Any) -> t.Optional[dict]:
        """
        Validate the parsed body.

        :param body: The value returned by :meth:`_parse`
        :raises: :class:`connexion.exceptions.NonConformingResponse`
        """

    def wrap_send(self, send: Send) -> Send:
        """
        Wrap the provided send channel with response body validation.

        Messages are held back until the final body message arrives. The complete body is then
        parsed and validated, and only afterwards are the buffered messages forwarded to `send`
        in their original order.

        :param send: The `send` channel to forward the messages to
        :return: A new `send` callable
        """
        buffered: t.List[t.MutableMapping[str, t.Any]] = []

        async def validating_send(message: t.MutableMapping[str, t.Any]) -> None:
            buffered.append(message)

            # Keep buffering until the last body message has been received
            if message["type"] == "http.response.start":
                return
            if message.get("more_body", False):
                return

            # Messages without a body (such as the response start) contribute an empty chunk
            body = self._parse(item.get("body", b"") for item in buffered)

            accepted_as_null = body is None and self._nullable
            if not accepted_as_null:
                self._validate(body)

            # Forward everything in order, emptying the buffer along the way
            while buffered:
                await send(buffered.pop(0))

        return validating_send