Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/routing.py: 20%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

474 statements  

1from __future__ import annotations 

2 

3import contextlib 

4import functools 

5import inspect 

6import re 

7import traceback 

8import types 

9import typing 

10import warnings 

11from contextlib import asynccontextmanager 

12from enum import Enum 

13 

14from starlette._exception_handler import wrap_app_handling_exceptions 

15from starlette._utils import get_route_path, is_async_callable 

16from starlette.concurrency import run_in_threadpool 

17from starlette.convertors import CONVERTOR_TYPES, Convertor 

18from starlette.datastructures import URL, Headers, URLPath 

19from starlette.exceptions import HTTPException 

20from starlette.middleware import Middleware 

21from starlette.requests import Request 

22from starlette.responses import PlainTextResponse, RedirectResponse, Response 

23from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send 

24from starlette.websockets import WebSocket, WebSocketClose 

25 

26 

class NoMatchFound(Exception):
    """
    Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
    if no matching route exists.
    """

    def __init__(self, name: str, path_params: dict[str, typing.Any]) -> None:
        # Only the parameter *names* are reported, in the order they were given.
        param_names = ", ".join(path_params)
        message = f'No route exists for name "{name}" and params "{param_names}".'
        super().__init__(message)

36 

37 

class Match(Enum):
    """Result of testing a route against an incoming connection scope."""

    NONE = 0  # the route does not apply to this scope
    PARTIAL = 1  # path matched, but the route can't fully serve it (e.g. wrong method)
    FULL = 2  # the route matched and should handle the scope

42 

43 

def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:  # pragma: no cover
    """
    Correctly determines if an object is a coroutine function,
    including those wrapped in functools.partial objects.

    Deprecated: every call emits a DeprecationWarning.
    """
    warnings.warn(
        "iscoroutinefunction_or_partial is deprecated, "
        "and will be removed in a future release.",
        DeprecationWarning,
        # Attribute the warning to the caller rather than this module, so the
        # deprecated call site is what shows up in the warning location.
        stacklevel=2,
    )
    # Peel off any (possibly nested) functools.partial layers.
    while isinstance(obj, functools.partial):
        obj = obj.func
    return inspect.iscoroutinefunction(obj)

57 

58 

def request_response(
    func: typing.Callable[[Request], typing.Awaitable[Response] | Response],
) -> ASGIApp:
    """
    Takes a function or coroutine `func(request) -> response`,
    and returns an ASGI application.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive, send)

        async def call_endpoint(scope: Scope, receive: Receive, send: Send) -> None:
            # Async endpoints are awaited directly; sync ones are delegated
            # to run_in_threadpool.
            if is_async_callable(func):
                response = await func(request)
            else:
                response = await run_in_threadpool(func, request)
            await response(scope, receive, send)

        guarded = wrap_app_handling_exceptions(call_endpoint, request)
        await guarded(scope, receive, send)

    return app

80 

81 

def websocket_session(
    func: typing.Callable[[WebSocket], typing.Awaitable[None]],
) -> ASGIApp:
    """
    Takes a coroutine `func(session)`, and returns an ASGI application.
    """

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        session = WebSocket(scope, receive=receive, send=send)

        async def run_session(scope: Scope, receive: Receive, send: Send) -> None:
            await func(session)

        guarded = wrap_app_handling_exceptions(run_session, session)
        await guarded(scope, receive, send)

    return app

99 

100 

def get_name(endpoint: typing.Callable[..., typing.Any]) -> str:
    """
    Return the default route name for `endpoint`.

    Functions, methods and classes use their own `__name__`; any other
    callable (e.g. an instance defining `__call__`) uses its class's name.
    """
    has_own_name = inspect.isroutine(endpoint) or inspect.isclass(endpoint)
    named_object = endpoint if has_own_name else endpoint.__class__
    return named_object.__name__

105 

106 

def replace_params(
    path: str,
    param_convertors: dict[str, Convertor[typing.Any]],
    path_params: dict[str, typing.Any],
) -> tuple[str, dict[str, typing.Any]]:
    """
    Fill `{name}` placeholders in `path` with converted parameter values.

    Every parameter whose placeholder occurs in `path` is rendered with the
    matching convertor's `to_string()` and consumed.

    Returns a two-tuple of (path with placeholders filled in, unused params).
    The caller's `path_params` dict is left unmodified.
    """
    # Work on a copy so callers can still report the arguments they passed.
    remaining = dict(path_params)
    for key, value in path_params.items():
        placeholder = "{" + key + "}"
        if placeholder in path:
            convertor = param_convertors[key]
            path = path.replace(placeholder, convertor.to_string(value))
            remaining.pop(key)
    return path, remaining

119 

120 

# Match parameters in URL paths, e.g. '{param}' and '{param:int}'.
# Group 1 is the parameter name; group 2 is the optional ':convertor' suffix
# (leading colon included), or None when no convertor is given.
PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")

123 

124 

def compile_path(
    path: str,
) -> tuple[typing.Pattern[str], str, dict[str, Convertor[typing.Any]]]:
    """
    Given a path string, like: "/{username:str}",
    or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
    of (regex, format, {param_name:convertor}).

    regex:      "/(?P<username>[^/]+)"
    format:     "/{username}"
    convertors: {"username": StringConvertor()}

    An unknown convertor type fails an assertion; a parameter name that
    appears more than once raises ValueError.
    """
    # Anything that does not start with "/" is compiled as a host pattern.
    is_host = not path.startswith("/")

    regex_parts = ["^"]
    format_parts = []
    param_convertors = {}
    duplicated_params = set()

    cursor = 0
    for param_match in PARAM_REGEX.finditer(path):
        # A missing ":type" suffix falls back to the "str" convertor.
        param_name, convertor_type = param_match.groups("str")
        convertor_type = convertor_type.lstrip(":")
        assert (
            convertor_type in CONVERTOR_TYPES
        ), f"Unknown path convertor '{convertor_type}'"
        convertor = CONVERTOR_TYPES[convertor_type]

        literal = path[cursor : param_match.start()]
        regex_parts.append(re.escape(literal))
        regex_parts.append(f"(?P<{param_name}>{convertor.regex})")
        format_parts.append(literal)
        format_parts.append("{%s}" % param_name)

        if param_name in param_convertors:
            duplicated_params.add(param_name)
        param_convertors[param_name] = convertor

        cursor = param_match.end()

    if duplicated_params:
        names = ", ".join(sorted(duplicated_params))
        ending = "s" if len(duplicated_params) > 1 else ""
        raise ValueError(f"Duplicated param name{ending} {names} at path {path}")

    tail = path[cursor:]
    if is_host:
        # Align with `Host.matches()` behavior, which ignores port.
        tail_regex = re.escape(tail.split(":")[0])
    else:
        tail_regex = re.escape(tail)
    regex_parts.append(tail_regex + "$")
    # The format keeps the literal tail verbatim (including any host port).
    format_parts.append(tail)

    path_regex = "".join(regex_parts)
    path_format = "".join(format_parts)
    return re.compile(path_regex), path_format, param_convertors

181 

182 

class BaseRoute:
    """Common interface for routing entries (Route, WebSocketRoute, Mount, Host)."""

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        """Return how well this route matches `scope`, plus child-scope updates."""
        raise NotImplementedError()  # pragma: no cover

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """Build the URL path for the route called `name`."""
        raise NotImplementedError()  # pragma: no cover

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Serve the connection once `matches()` has selected this route."""
        raise NotImplementedError()  # pragma: no cover

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        A route may be used in isolation as a stand-alone ASGI app.
        This is a somewhat contrived case, as they'll almost always be used
        within a Router, but could be useful for some tooling and minimal apps.
        """
        match, child_scope = self.matches(scope)
        if match != Match.NONE:
            scope.update(child_scope)
            await self.handle(scope, receive, send)
            return

        # No match: 404 for HTTP, close for websockets, ignore other scope types.
        scope_type = scope["type"]
        if scope_type == "http":
            response = PlainTextResponse("Not Found", status_code=404)
            await response(scope, receive, send)
        elif scope_type == "websocket":
            websocket_close = WebSocketClose()
            await websocket_close(scope, receive, send)

211 

212 

class Route(BaseRoute):
    """An HTTP route dispatching a path (and optional method set) to an endpoint."""

    def __init__(
        self,
        path: str,
        endpoint: typing.Callable[..., typing.Any],
        *,
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name
        self.include_in_schema = include_in_schema

        # Look through functools.partial layers to decide how to adapt the endpoint.
        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(request) -> response`.
            self.app = request_response(endpoint)
            if methods is None:
                methods = ["GET"]
        else:
            # Otherwise (e.g. a class) the endpoint is used directly as an ASGI app.
            self.app = endpoint

        if middleware is not None:
            # Wrap in reverse so the first middleware listed ends up outermost.
            for cls, args, kwargs in reversed(middleware):
                self.app = cls(app=self.app, *args, **kwargs)

        if methods is None:
            # No method restriction: every method is accepted.
            self.methods = None
        else:
            self.methods = {method.upper() for method in methods}
            if "GET" in self.methods:
                self.methods.add("HEAD")

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        path_params: dict[str, typing.Any]
        if scope["type"] == "http":
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                # Keep params captured by enclosing routes; this route's own
                # values win on name clashes.
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                if self.methods and scope["method"] not in self.methods:
                    # Path matched but the method didn't: candidate for a 405.
                    return Match.PARTIAL, child_scope
                else:
                    return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        # Only build a URL when the name matches and exactly the route's
        # parameters were supplied.
        if name != self.name or seen_params != expected_params:
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(
            self.path_format, self.param_convertors, path_params
        )
        assert not remaining_params
        return URLPath(path=path, protocol="http")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        if self.methods and scope["method"] not in self.methods:
            # Sort so the Allow header is identical across processes; iteration
            # order of a set of str varies between runs (hash randomization).
            headers = {"Allow": ", ".join(sorted(self.methods))}
            # Inside a Starlette application, raise so the configured exception
            # handlers build the response; plain ASGI use gets a direct reply.
            if "app" in scope:
                raise HTTPException(status_code=405, headers=headers)
            else:
                response = PlainTextResponse(
                    "Method Not Allowed", status_code=405, headers=headers
                )
                await response(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, Route)
            and self.path == other.path
            and self.endpoint == other.endpoint
            and self.methods == other.methods
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        methods = sorted(self.methods or [])
        path, name = self.path, self.name
        return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})"

312 

313 

class WebSocketRoute(BaseRoute):
    """A route dispatching websocket connections on a path to an endpoint."""

    def __init__(
        self,
        path: str,
        endpoint: typing.Callable[..., typing.Any],
        *,
        name: str | None = None,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = name if name is not None else get_name(endpoint)

        # Unwrap functools.partial layers to inspect the underlying callable.
        target = endpoint
        while isinstance(target, functools.partial):
            target = target.func
        is_plain_callable = inspect.isfunction(target) or inspect.ismethod(target)
        # Functions/methods become `func(websocket)` sessions; anything else
        # (e.g. a class) is used directly as an ASGI app.
        self.app = websocket_session(endpoint) if is_plain_callable else endpoint

        # Wrap in reverse so the first middleware listed ends up outermost.
        for cls, args, kwargs in reversed(middleware or []):
            self.app = cls(app=self.app, *args, **kwargs)

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        path_params: dict[str, typing.Any]
        if scope["type"] != "websocket":
            return Match.NONE, {}
        found = self.path_regex.match(get_route_path(scope))
        if not found:
            return Match.NONE, {}
        converted = {
            key: self.param_convertors[key].convert(raw)
            for key, raw in found.groupdict().items()
        }
        path_params = dict(scope.get("path_params", {}))
        path_params.update(converted)
        return Match.FULL, {"endpoint": self.endpoint, "path_params": path_params}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        # The name must match and exactly this route's parameters must be given.
        if name != self.name or set(path_params) != set(self.param_convertors):
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(
            self.path_format, self.param_convertors, path_params
        )
        assert not remaining_params
        return URLPath(path=path, protocol="websocket")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, WebSocketRoute):
            return False
        return self.path == other.path and self.endpoint == other.endpoint

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}(path={self.path!r}, name={self.name!r})"

384 

385 

class Mount(BaseRoute):
    """Route a path prefix to a sub-application or to a nested list of routes."""

    def __init__(
        self,
        path: str,
        app: ASGIApp | None = None,
        routes: typing.Sequence[BaseRoute] | None = None,
        name: str | None = None,
        *,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
        assert (
            app is not None or routes is not None
        ), "Either 'app=...', or 'routes=' must be specified"
        self.path = path.rstrip("/")
        if app is not None:
            self._base_app: ASGIApp = app
        else:
            self._base_app = Router(routes=routes)
        self.app = self._base_app
        if middleware is not None:
            # Wrap in reverse so the first middleware listed ends up outermost.
            for cls, args, kwargs in reversed(middleware):
                self.app = cls(app=self.app, *args, **kwargs)
        self.name = name
        # Everything after the prefix is captured as the "path" parameter.
        self.path_regex, self.path_format, self.param_convertors = compile_path(
            self.path + "/{path:path}"
        )

    @property
    def routes(self) -> list[BaseRoute]:
        # Routes of the wrapped app (before middleware), or [] if it has none.
        return getattr(self._base_app, "routes", [])

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        path_params: dict[str, typing.Any]
        if scope["type"] in ("http", "websocket"):
            root_path = scope.get("root_path", "")
            route_path = get_route_path(scope)
            match = self.path_regex.match(route_path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                # "path" is the part left for the mounted app; the prefix in
                # front of it is appended to root_path below.
                remaining_path = "/" + matched_params.pop("path")
                matched_path = route_path[: -len(remaining_path)]
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {
                    "path_params": path_params,
                    # app_root_path will only be set at the top level scope,
                    # initialized with the (optional) value of a root_path
                    # set above/before Starlette. And even though any
                    # mount will have its own child scope with its own respective
                    # root_path, the app_root_path will always be available in all
                    # the child scopes with the same top level value because it's
                    # set only once here with a default, any other child scope will
                    # just inherit that app_root_path default value stored in the
                    # scope. All this is needed to support Request.url_for(), as it
                    # uses the app_root_path to build the URL path.
                    "app_root_path": scope.get("app_root_path", root_path),
                    "root_path": root_path + matched_path,
                    "endpoint": self.app,
                }
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        # Snapshot the caller's arguments for the NoMatchFound message: the
        # branches below rewrite "path", and replace_params() may consume
        # entries of the dict it is given.
        requested_params = dict(path_params)
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path_params["path"] = path_params["path"].lstrip("/")
            path, remaining_params = replace_params(
                self.path_format, self.param_convertors, path_params
            )
            if not remaining_params:
                return URLPath(path=path)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            path_kwarg = path_params.get("path")
            # Render only the mount prefix; the child route builds the rest.
            path_params["path"] = ""
            path_prefix, remaining_params = replace_params(
                self.path_format, self.param_convertors, path_params
            )
            if path_kwarg is not None:
                # Forward a caller-supplied "path" to the child routes.
                remaining_params["path"] = path_kwarg
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(
                        path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
                    )
                except NoMatchFound:
                    pass
        raise NoMatchFound(name, requested_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, Mount)
            and self.path == other.path
            and self.app == other.app
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})"

498 

499 

class Host(BaseRoute):
    """Route requests whose Host header matches a pattern to a sub-application."""

    def __init__(self, host: str, app: ASGIApp, name: str | None = None) -> None:
        assert not host.startswith("/"), "Host must not start with '/'"
        self.host = host
        self.app = app
        self.name = name
        self.host_regex, self.host_format, self.param_convertors = compile_path(host)

    @property
    def routes(self) -> list[BaseRoute]:
        # Routes of the wrapped app, or [] if it has none.
        return getattr(self.app, "routes", [])

    def matches(self, scope: Scope) -> tuple[Match, Scope]:
        if scope["type"] in ("http", "websocket"):
            headers = Headers(scope=scope)
            # The port (anything after ":") is ignored when matching.
            host = headers.get("host", "").split(":")[0]
            match = self.host_regex.match(host)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"path_params": path_params, "endpoint": self.app}
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        # Snapshot the caller's arguments for the NoMatchFound message: the
        # branches below pop "path", and replace_params() may consume entries
        # of the dict it is given.
        requested_params = dict(path_params)
        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<host_name>".
            path = path_params.pop("path")
            host, remaining_params = replace_params(
                self.host_format, self.param_convertors, path_params
            )
            if not remaining_params:
                return URLPath(path=path, host=host)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No host name.
                remaining_name = name
            else:
                # 'name' matches "<host_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            host, remaining_params = replace_params(
                self.host_format, self.param_convertors, path_params
            )
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(path=str(url), protocol=url.protocol, host=host)
                except NoMatchFound:
                    pass
        raise NoMatchFound(name, requested_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return (
            isinstance(other, Host)
            and self.host == other.host
            and self.app == other.app
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"

568 

569 

570_T = typing.TypeVar("_T") 

571 

572 

573class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): 

574 def __init__(self, cm: typing.ContextManager[_T]): 

575 self._cm = cm 

576 

577 async def __aenter__(self) -> _T: 

578 return self._cm.__enter__() 

579 

580 async def __aexit__( 

581 self, 

582 exc_type: type[BaseException] | None, 

583 exc_value: BaseException | None, 

584 traceback: types.TracebackType | None, 

585 ) -> bool | None: 

586 return self._cm.__exit__(exc_type, exc_value, traceback) 

587 

588 

def _wrap_gen_lifespan_context(
    lifespan_context: typing.Callable[
        [typing.Any], typing.Generator[typing.Any, typing.Any, typing.Any]
    ],
) -> typing.Callable[[typing.Any], typing.AsyncContextManager[typing.Any]]:
    """
    Adapt a (deprecated) generator-function lifespan into a factory of async
    context managers, as expected for `Router.lifespan_context`.
    """
    sync_factory = contextlib.contextmanager(lifespan_context)

    @functools.wraps(sync_factory)
    def wrapper(app: typing.Any) -> _AsyncLiftContextManager[typing.Any]:
        # Build the sync context manager for this app, then lift it to async.
        sync_cm = sync_factory(app)
        return _AsyncLiftContextManager(sync_cm)

    return wrapper

601 

602 

603class _DefaultLifespan: 

604 def __init__(self, router: Router): 

605 self._router = router 

606 

607 async def __aenter__(self) -> None: 

608 await self._router.startup() 

609 

610 async def __aexit__(self, *exc_info: object) -> None: 

611 await self._router.shutdown() 

612 

613 def __call__(self: _T, app: object) -> _T: 

614 return self 

615 

616 

class Router:
    """
    ASGI app that dispatches each connection to the first matching route,
    handles lifespan events, and can redirect unmatched HTTP paths to their
    trailing-slash / no-trailing-slash variant.
    """

    def __init__(
        self,
        routes: typing.Sequence[BaseRoute] | None = None,
        redirect_slashes: bool = True,
        default: ASGIApp | None = None,
        on_startup: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
        on_shutdown: typing.Sequence[typing.Callable[[], typing.Any]] | None = None,
        # the generic to Lifespan[AppType] is the type of the top level application
        # which the router cannot know statically, so we use typing.Any
        lifespan: Lifespan[typing.Any] | None = None,
        *,
        middleware: typing.Sequence[Middleware] | None = None,
    ) -> None:
        self.routes = [] if routes is None else list(routes)
        self.redirect_slashes = redirect_slashes
        self.default = self.not_found if default is None else default
        self.on_startup = [] if on_startup is None else list(on_startup)
        self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

        if on_startup or on_shutdown:
            warnings.warn(
                "The on_startup and on_shutdown parameters are deprecated, and they "
                "will be removed on version 1.0. Use the lifespan parameter instead. "
                "See more about it on https://www.starlette.io/lifespan/.",
                DeprecationWarning,
            )
            if lifespan:
                warnings.warn(
                    "The `lifespan` parameter cannot be used with `on_startup` or "
                    "`on_shutdown`. Both `on_startup` and `on_shutdown` will be "
                    "ignored."
                )

        # Normalize every supported lifespan form into an async context
        # manager factory called as `self.lifespan_context(app)`.
        if lifespan is None:
            self.lifespan_context: Lifespan[typing.Any] = _DefaultLifespan(self)

        elif inspect.isasyncgenfunction(lifespan):
            warnings.warn(
                "async generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = asynccontextmanager(
                lifespan,
            )
        elif inspect.isgeneratorfunction(lifespan):
            warnings.warn(
                "generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = _wrap_gen_lifespan_context(
                lifespan,
            )
        else:
            self.lifespan_context = lifespan

        # Wrap the dispatcher in reverse so the first middleware listed is outermost.
        self.middleware_stack = self.app
        if middleware:
            for cls, args, kwargs in reversed(middleware):
                self.middleware_stack = cls(self.middleware_stack, *args, **kwargs)

    async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Default fallback: close websockets, answer HTTP with a 404."""
        if scope["type"] == "websocket":
            websocket_close = WebSocketClose()
            await websocket_close(scope, receive, send)
            return

        # If we're running inside a starlette application then raise an
        # exception, so that the configurable exception handler can deal with
        # returning the response. For plain ASGI apps, just return the response.
        if "app" in scope:
            raise HTTPException(status_code=404)
        else:
            response = PlainTextResponse("Not Found", status_code=404)
            await response(scope, receive, send)

    def url_path_for(self, name: str, /, **path_params: typing.Any) -> URLPath:
        """Return the URL path built by the first route that recognizes `name`."""
        for route in self.routes:
            try:
                return route.url_path_for(name, **path_params)
            except NoMatchFound:
                pass
        raise NoMatchFound(name, path_params)

    async def startup(self) -> None:
        """
        Run any `.on_startup` event handlers.

        Async handlers are awaited; plain callables are called directly.
        """
        for handler in self.on_startup:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def shutdown(self) -> None:
        """
        Run any `.on_shutdown` event handlers.

        Async handlers are awaited; plain callables are called directly.
        """
        for handler in self.on_shutdown:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle ASGI lifespan messages, which allows us to manage application
        startup and shutdown events.
        """
        started = False
        app: typing.Any = scope.get("app")
        # Wait for the server's first lifespan message (expected: startup).
        await receive()
        try:
            async with self.lifespan_context(app) as maybe_state:
                if maybe_state is not None:
                    if "state" not in scope:
                        raise RuntimeError(
                            'The server does not support "state" in the lifespan scope.'
                        )
                    scope["state"].update(maybe_state)
                await send({"type": "lifespan.startup.complete"})
                started = True
                # Block until the next message (expected: shutdown).
                await receive()
        except BaseException:
            # Report the failure for whichever phase was in progress, then re-raise.
            exc_text = traceback.format_exc()
            if started:
                await send({"type": "lifespan.shutdown.failed", "message": exc_text})
            else:
                await send({"type": "lifespan.startup.failed", "message": exc_text})
            raise
        else:
            await send({"type": "lifespan.shutdown.complete"})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The main entry point to the Router class.
        """
        await self.middleware_stack(scope, receive, send)

    async def app(self, scope: Scope, receive: Receive, send: Send) -> None:
        assert scope["type"] in ("http", "websocket", "lifespan")

        # The first router to see this scope records itself.
        if "router" not in scope:
            scope["router"] = self

        if scope["type"] == "lifespan":
            await self.lifespan(scope, receive, send)
            return

        partial = None

        for route in self.routes:
            # Determine if any route matches the incoming scope,
            # and hand over to the matching route if found.
            match, child_scope = route.matches(scope)
            if match == Match.FULL:
                scope.update(child_scope)
                await route.handle(scope, receive, send)
                return
            elif match == Match.PARTIAL and partial is None:
                partial = route
                partial_scope = child_scope

        if partial is not None:
            # Handle partial matches. These are cases where an endpoint is
            # able to handle the request, but is not a preferred option.
            # We use this in particular to deal with "405 Method Not Allowed".
            scope.update(partial_scope)
            await partial.handle(scope, receive, send)
            return

        route_path = get_route_path(scope)
        if scope["type"] == "http" and self.redirect_slashes and route_path != "/":
            # Retry with the trailing slash toggled; if any route would accept
            # that path, redirect to it instead of falling through to default.
            redirect_scope = dict(scope)
            if route_path.endswith("/"):
                redirect_scope["path"] = redirect_scope["path"].rstrip("/")
            else:
                redirect_scope["path"] = redirect_scope["path"] + "/"

            for route in self.routes:
                match, child_scope = route.matches(redirect_scope)
                if match != Match.NONE:
                    redirect_url = URL(scope=redirect_scope)
                    response = RedirectResponse(url=str(redirect_url))
                    await response(scope, receive, send)
                    return

        await self.default(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, Router) and self.routes == other.routes

    def mount(
        self, path: str, app: ASGIApp, name: str | None = None
    ) -> None:  # pragma: no cover
        """Append a `Mount` for `app` at `path`."""
        route = Mount(path, app=app, name=name)
        self.routes.append(route)

    def host(
        self, host: str, app: ASGIApp, name: str | None = None
    ) -> None:  # pragma: no cover
        """Append a `Host` route for `app` matching `host`."""
        route = Host(host, app=app, name=name)
        self.routes.append(route)

    def add_route(
        self,
        path: str,
        endpoint: typing.Callable[[Request], typing.Awaitable[Response] | Response],
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
    ) -> None:  # pragma: no cover
        """Append an HTTP `Route` for `endpoint` at `path`."""
        route = Route(
            path,
            endpoint=endpoint,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )
        self.routes.append(route)

    def add_websocket_route(
        self,
        path: str,
        endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]],
        name: str | None = None,
    ) -> None:  # pragma: no cover
        """Append a `WebSocketRoute` for `endpoint` at `path`."""
        route = WebSocketRoute(path, endpoint=endpoint, name=name)
        self.routes.append(route)

    def route(
        self,
        path: str,
        methods: list[str] | None = None,
        name: str | None = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:  # type: ignore[type-arg]
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [Route(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            "The `route` decorator is deprecated, and will be removed in version 1.0.0. "
            "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",  # noqa: E501
            DeprecationWarning,
        )

        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg] # noqa: E501
            self.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(self, path: str, name: str | None = None) -> typing.Callable:  # type: ignore[type-arg]
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [WebSocketRoute(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "  # noqa: E501
            "https://www.starlette.io/routing/#websocket-routing for the recommended approach.",  # noqa: E501
            DeprecationWarning,
        )

        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg] # noqa: E501
            self.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def add_event_handler(
        self, event_type: str, func: typing.Callable[[], typing.Any]
    ) -> None:  # pragma: no cover
        """Register `func` as a "startup" or "shutdown" handler."""
        assert event_type in ("startup", "shutdown")

        if event_type == "startup":
            self.on_startup.append(func)
        else:
            self.on_shutdown.append(func)

    def on_event(self, event_type: str) -> typing.Callable:  # type: ignore[type-arg]
        warnings.warn(
            "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
            "Refer to https://www.starlette.io/lifespan/ for recommended approach.",
            DeprecationWarning,
        )

        def decorator(func: typing.Callable) -> typing.Callable:  # type: ignore[type-arg] # noqa: E501
            self.add_event_handler(event_type, func)
            return func

        return decorator