Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/routing.py: 21%

446 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:12 +0000

1import contextlib 

2import functools 

3import inspect 

4import re 

5import traceback 

6import types 

7import typing 

8import warnings 

9from contextlib import asynccontextmanager 

10from enum import Enum 

11 

12from starlette._utils import is_async_callable 

13from starlette.concurrency import run_in_threadpool 

14from starlette.convertors import CONVERTOR_TYPES, Convertor 

15from starlette.datastructures import URL, Headers, URLPath 

16from starlette.exceptions import HTTPException 

17from starlette.middleware import Middleware 

18from starlette.requests import Request 

19from starlette.responses import PlainTextResponse, RedirectResponse 

20from starlette.types import ASGIApp, Receive, Scope, Send 

21from starlette.websockets import WebSocket, WebSocketClose 

22 

23 

class NoMatchFound(Exception):
    """
    Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
    if no matching route exists.

    The message lists the requested route name and the *names* of the supplied
    path parameters (never their values).
    """

    def __init__(self, name: str, path_params: typing.Dict[str, typing.Any]) -> None:
        # Iterating a dict yields its keys, in insertion order.
        params = ", ".join(path_params)
        super().__init__(f'No route exists for name "{name}" and params "{params}".')

33 

34 

class Match(Enum):
    """
    Result of testing a route against an incoming scope, as returned by
    `BaseRoute.matches()`.
    """

    # The route does not apply to the scope.
    NONE = 0
    # The route recognises the scope but is not the preferred handler;
    # `Route.matches()` reports this when the path matches but the HTTP
    # method is not one of the route's methods.
    PARTIAL = 1
    # The route should handle the scope.
    FULL = 2

39 

40 

def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:  # pragma: no cover
    """
    Correctly determines if an object is a coroutine function,
    including those wrapped in functools.partial objects.

    Deprecated: every call emits a `DeprecationWarning`.
    """
    warnings.warn(
        "iscoroutinefunction_or_partial is deprecated, "
        "and will be removed in a future release.",
        DeprecationWarning,
        # Attribute the warning to the code calling this helper.
        stacklevel=2,
    )
    # Peel off any number of nested `functools.partial` wrappers.
    while isinstance(obj, functools.partial):
        obj = obj.func
    return inspect.iscoroutinefunction(obj)

54 

55 

def request_response(func: typing.Callable) -> ASGIApp:
    """
    Takes a function or coroutine `func(request) -> response`,
    and returns an ASGI application.

    Async callables (per `is_async_callable`) are awaited directly; anything
    else is run through `run_in_threadpool`. Whatever `func` returns is then
    awaited as an ASGI callable with `(scope, receive, send)`.
    """
    # Decided once, when the endpoint is wrapped, not on every request.
    is_coroutine = is_async_callable(func)

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        request = Request(scope, receive=receive, send=send)
        if is_coroutine:
            response = await func(request)
        else:
            response = await run_in_threadpool(func, request)
        await response(scope, receive, send)

    return app

72 

73 

def websocket_session(func: typing.Callable) -> ASGIApp:
    """
    Takes a coroutine `func(session)`, and returns an ASGI application.

    The returned app wraps the connection in a `WebSocket` and awaits
    `func(session)`; `func` is not checked to be async beforehand.
    """
    # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"

    async def app(scope: Scope, receive: Receive, send: Send) -> None:
        session = WebSocket(scope, receive=receive, send=send)
        await func(session)

    return app

85 

86 

def get_name(endpoint: typing.Callable) -> str:
    """
    Return the default route name for `endpoint`.

    Functions, methods, builtins and classes are named after their own
    `__name__`. Any other callable object is named after its class, so a
    `functools.partial` endpoint is named "partial".
    """
    if inspect.isroutine(endpoint) or inspect.isclass(endpoint):
        return endpoint.__name__
    return endpoint.__class__.__name__

91 

92 

def replace_params(
    path: str,
    param_convertors: typing.Dict[str, Convertor],
    path_params: typing.Dict[str, str],
) -> typing.Tuple[str, dict]:
    """
    Substitute `{name}` placeholders in `path` with the supplied parameters.

    For each entry of `path_params` whose `{key}` placeholder occurs in
    `path`, the value is rendered with `param_convertors[key].to_string()`
    and every occurrence of the placeholder is replaced.

    NOTE: `path_params` is mutated in place. Consumed keys are removed and the
    very same dict is returned as the second tuple item, holding whatever was
    left over.

    Returns `(substituted_path, remaining_params)`.
    Raises `KeyError` if a placeholder has no entry in `param_convertors`.
    """
    # Iterate over a snapshot, because consumed keys are popped as we go.
    for key, value in list(path_params.items()):
        placeholder = "{" + key + "}"
        if placeholder in path:
            convertor = param_convertors[key]
            path = path.replace(placeholder, convertor.to_string(value))
            path_params.pop(key)
    return path, path_params

105 

106 

# Match parameters in URL paths, eg. '{param}', and '{param:int}'
PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")


def compile_path(
    path: str,
) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
    """
    Given a path string, like: "/{username:str}",
    or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
    of (regex, format, {param_name:convertor}).

    regex:      "/(?P<username>[^/]+)"
    format:     "/{username}"
    convertors: {"username": StringConvertor()}

    A string that does not start with "/" is treated as a host pattern; any
    ":port" suffix after the last parameter is left out of the regex but kept
    in the format string.

    Raises `AssertionError` for an unknown convertor type and `ValueError` if
    a parameter name is used more than once.
    """
    is_host = not path.startswith("/")

    path_regex = "^"
    path_format = ""
    duplicated_params = set()

    # `idx` marks the end of the previously processed parameter.
    idx = 0
    param_convertors = {}
    for match in PARAM_REGEX.finditer(path):
        # A parameter without an explicit ":type" falls back to "str".
        param_name, convertor_type = match.groups("str")
        convertor_type = convertor_type.lstrip(":")
        assert (
            convertor_type in CONVERTOR_TYPES
        ), f"Unknown path convertor '{convertor_type}'"
        convertor = CONVERTOR_TYPES[convertor_type]

        # Literal text between the previous parameter and this one.
        literal = path[idx : match.start()]

        path_regex += re.escape(literal)
        path_regex += f"(?P<{param_name}>{convertor.regex})"

        path_format += literal
        path_format += "{%s}" % param_name

        if param_name in param_convertors:
            duplicated_params.add(param_name)

        param_convertors[param_name] = convertor

        idx = match.end()

    if duplicated_params:
        names = ", ".join(sorted(duplicated_params))
        ending = "s" if len(duplicated_params) > 1 else ""
        raise ValueError(f"Duplicated param name{ending} {names} at path {path}")

    if is_host:
        # Align with `Host.matches()` behavior, which ignores port.
        hostname = path[idx:].split(":")[0]
        path_regex += re.escape(hostname) + "$"
    else:
        path_regex += re.escape(path[idx:]) + "$"

    path_format += path[idx:]

    return re.compile(path_regex), path_format, param_convertors

167 

168 

class BaseRoute:
    """
    Interface shared by all route types.

    Subclasses implement `matches()`, `url_path_for()` and `handle()`;
    `__call__()` combines them so a route can be used as an ASGI app.
    """

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        """Return how well this route matches `scope`, plus scope updates."""
        raise NotImplementedError()  # pragma: no cover

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        """Build the URL path for `name` with the given path parameters."""
        raise NotImplementedError()  # pragma: no cover

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Serve a scope that has already been matched against this route."""
        raise NotImplementedError()  # pragma: no cover

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        A route may be used in isolation as a stand-alone ASGI app.
        This is a somewhat contrived case, as they'll almost always be used
        within a Router, but could be useful for some tooling and minimal apps.
        """
        match, child_scope = self.matches(scope)
        if match == Match.NONE:
            # Nothing matched: answer 404 for HTTP, close for WebSocket.
            # Any other scope type is ignored.
            if scope["type"] == "http":
                response = PlainTextResponse("Not Found", status_code=404)
                await response(scope, receive, send)
            elif scope["type"] == "websocket":
                websocket_close = WebSocketClose()
                await websocket_close(scope, receive, send)
            return

        # Both FULL and PARTIAL matches are handed to `handle()`.
        scope.update(child_scope)
        await self.handle(scope, receive, send)

197 

198 

class Route(BaseRoute):
    """
    Routes HTTP requests whose path matches `path` to `endpoint`.
    """

    def __init__(
        self,
        path: str,
        endpoint: typing.Callable,
        *,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> None:
        """
        `path` must start with "/". When `name` is omitted it is derived from
        the endpoint via `get_name()`. Function and method endpoints default
        to `methods=["GET"]`; any other endpoint is used as an ASGI app and,
        with `methods=None`, accepts every method. "HEAD" is added whenever
        "GET" is allowed.
        """
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name
        self.include_in_schema = include_in_schema

        # Look through `functools.partial` wrappers to classify the endpoint.
        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(request) -> response`.
            self.app = request_response(endpoint)
            if methods is None:
                methods = ["GET"]
        else:
            # Endpoint is a class. Treat it as ASGI.
            self.app = endpoint

        if methods is None:
            self.methods = None
        else:
            self.methods = {method.upper() for method in methods}
            if "GET" in self.methods:
                self.methods.add("HEAD")

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        """
        Match an HTTP scope by path.

        Returns `Match.FULL` when path and method both match, `Match.PARTIAL`
        when only the path matches, and `Match.NONE` otherwise. The child
        scope carries the endpoint and the converted path parameters, merged
        over any `path_params` already in `scope`.
        """
        if scope["type"] == "http":
            match = self.path_regex.match(scope["path"])
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                if self.methods and scope["method"] not in self.methods:
                    return Match.PARTIAL, child_scope
                else:
                    return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        """
        Build the URL path for this route.

        Raises `NoMatchFound` unless `name` equals this route's name and the
        supplied parameter names are exactly the route's parameters.
        """
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if name != self.name or seen_params != expected_params:
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(
            self.path_format, self.param_convertors, path_params
        )
        assert not remaining_params
        return URLPath(path=path, protocol="http")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Serve the request, or answer 405 if the method is not allowed.

        Inside an application (an "app" key in `scope`) the 405 is raised as
        an `HTTPException`; otherwise a plain-text response is sent.
        """
        if self.methods and scope["method"] not in self.methods:
            # `self.methods` is a set; sort it so the header is deterministic.
            headers = {"Allow": ", ".join(sorted(self.methods))}
            if "app" in scope:
                raise HTTPException(status_code=405, headers=headers)
            else:
                response = PlainTextResponse(
                    "Method Not Allowed", status_code=405, headers=headers
                )
                await response(scope, receive, send)
        else:
            await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        # Routes are equal when path, endpoint and methods all agree.
        return (
            isinstance(other, Route)
            and self.path == other.path
            and self.endpoint == other.endpoint
            and self.methods == other.methods
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        methods = sorted(self.methods or [])
        path, name = self.path, self.name
        return f"{class_name}(path={path!r}, name={name!r}, methods={methods!r})"

291 

292 

class WebSocketRoute(BaseRoute):
    """
    Routes WebSocket connections whose path matches `path` to `endpoint`.
    """

    def __init__(
        self, path: str, endpoint: typing.Callable, *, name: typing.Optional[str] = None
    ) -> None:
        """
        `path` must start with "/". When `name` is omitted it is derived from
        the endpoint via `get_name()`. Function and method endpoints are
        wrapped with `websocket_session()`; anything else is used as an ASGI
        app directly.
        """
        assert path.startswith("/"), "Routed paths must start with '/'"
        self.path = path
        self.endpoint = endpoint
        self.name = get_name(endpoint) if name is None else name

        # Look through `functools.partial` wrappers to classify the endpoint.
        endpoint_handler = endpoint
        while isinstance(endpoint_handler, functools.partial):
            endpoint_handler = endpoint_handler.func
        if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
            # Endpoint is function or method. Treat it as `func(websocket)`.
            self.app = websocket_session(endpoint)
        else:
            # Endpoint is a class. Treat it as ASGI.
            self.app = endpoint

        self.path_regex, self.path_format, self.param_convertors = compile_path(path)

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        """
        Match a WebSocket scope by path.

        Returns `Match.FULL` with the endpoint and converted path parameters
        on a match, `Match.NONE` otherwise.
        """
        if scope["type"] == "websocket":
            match = self.path_regex.match(scope["path"])
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"endpoint": self.endpoint, "path_params": path_params}
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        """
        Build the URL path for this route.

        Raises `NoMatchFound` unless `name` equals this route's name and the
        supplied parameter names are exactly the route's parameters.
        """
        seen_params = set(path_params.keys())
        expected_params = set(self.param_convertors.keys())

        if name != self.name or seen_params != expected_params:
            raise NoMatchFound(name, path_params)

        path, remaining_params = replace_params(
            self.path_format, self.param_convertors, path_params
        )
        assert not remaining_params
        return URLPath(path=path, protocol="websocket")

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Hand the connection to the wrapped ASGI app."""
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        # Routes are equal when path and endpoint agree.
        return (
            isinstance(other, WebSocketRoute)
            and self.path == other.path
            and self.endpoint == other.endpoint
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(path={self.path!r}, name={self.name!r})"

352 

353 

class Mount(BaseRoute):
    """
    Delegates every path below a prefix to a sub-application.
    """

    def __init__(
        self,
        path: str,
        app: typing.Optional[ASGIApp] = None,
        routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
        name: typing.Optional[str] = None,
        *,
        middleware: typing.Optional[typing.Sequence[Middleware]] = None,
    ) -> None:
        """
        `path` is "" or starts with "/"; trailing slashes are stripped.
        Either `app` or `routes` must be given; with only `routes`, a `Router`
        is created for them. Each `middleware` entry is unpacked as
        `(cls, options)` and wrapped around the app, so the first entry ends
        up outermost.
        """
        assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
        assert (
            app is not None or routes is not None
        ), "Either 'app=...', or 'routes=' must be specified"
        self.path = path.rstrip("/")
        if app is not None:
            self._base_app: ASGIApp = app
        else:
            self._base_app = Router(routes=routes)
        self.app = self._base_app
        if middleware is not None:
            # Wrap in reverse so the first middleware listed is the outermost.
            for cls, options in reversed(middleware):
                self.app = cls(app=self.app, **options)
        self.name = name
        # The trailing `{path:path}` parameter captures everything below the mount.
        self.path_regex, self.path_format, self.param_convertors = compile_path(
            self.path + "/{path:path}"
        )

    @property
    def routes(self) -> typing.List[BaseRoute]:
        """Routes of the unwrapped app, or an empty list if it has none."""
        return getattr(self._base_app, "routes", [])

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        """
        Match an HTTP or WebSocket scope whose path lies below the mount.

        On a match the child scope moves the matched prefix onto `root_path`
        and leaves the remainder, with a leading "/", as `path`.
        """
        if scope["type"] in ("http", "websocket"):
            path = scope["path"]
            match = self.path_regex.match(path)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                remaining_path = "/" + matched_params.pop("path")
                # `remaining_path` is never empty, so the negative slice is safe.
                matched_path = path[: -len(remaining_path)]
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                root_path = scope.get("root_path", "")
                child_scope = {
                    "path_params": path_params,
                    "app_root_path": scope.get("app_root_path", root_path),
                    "root_path": root_path + matched_path,
                    "path": remaining_path,
                    "endpoint": self.app,
                }
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        """
        Build a URL path for the mount itself or for one of its child routes.

        `name` is either "<mount_name>" (a `path` parameter is then required)
        or "<mount_name>:<child_name>". A mount without a name passes `name`
        to its children unchanged.

        Raises `NoMatchFound` if nothing matches.
        """
        # `path_params` is mutated below; keep the caller's view of it so the
        # error message reports what was actually requested.
        requested_params = dict(path_params)

        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path_params["path"] = path_params["path"].lstrip("/")
            path, remaining_params = replace_params(
                self.path_format, self.param_convertors, path_params
            )
            if not remaining_params:
                return URLPath(path=path)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            # Render only the mount prefix; a caller-supplied `path` belongs
            # to the child route and is handed back afterwards.
            path_kwarg = path_params.get("path")
            path_params["path"] = ""
            path_prefix, remaining_params = replace_params(
                self.path_format, self.param_convertors, path_params
            )
            if path_kwarg is not None:
                remaining_params["path"] = path_kwarg
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(
                        path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
                    )
                except NoMatchFound:
                    pass
        raise NoMatchFound(name, requested_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Hand the request to the mounted (middleware-wrapped) app."""
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        # Mounts are equal when path and app agree.
        return (
            isinstance(other, Mount)
            and self.path == other.path
            and self.app == other.app
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(path={self.path!r}, name={name!r}, app={self.app!r})"

456 

457 

class Host(BaseRoute):
    """
    Delegates requests to a sub-application based on the "host" header.
    """

    def __init__(
        self, host: str, app: ASGIApp, name: typing.Optional[str] = None
    ) -> None:
        """`host` is a host pattern such as "{subdomain}.example.org"."""
        assert not host.startswith("/"), "Host must not start with '/'"
        self.host = host
        self.app = app
        self.name = name
        self.host_regex, self.host_format, self.param_convertors = compile_path(host)

    @property
    def routes(self) -> typing.List[BaseRoute]:
        """Routes of the wrapped app, or an empty list if it has none."""
        return getattr(self.app, "routes", [])

    def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
        """
        Match an HTTP or WebSocket scope by its "host" header.

        Everything from the first ":" in the header (the port) is ignored.
        """
        if scope["type"] in ("http", "websocket"):
            headers = Headers(scope=scope)
            host = headers.get("host", "").split(":")[0]
            match = self.host_regex.match(host)
            if match:
                matched_params = match.groupdict()
                for key, value in matched_params.items():
                    matched_params[key] = self.param_convertors[key].convert(value)
                path_params = dict(scope.get("path_params", {}))
                path_params.update(matched_params)
                child_scope = {"path_params": path_params, "endpoint": self.app}
                return Match.FULL, child_scope
        return Match.NONE, {}

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        """
        Build a URL path for the host itself or for one of its child routes.

        `name` is either "<host_name>" (a `path` parameter is then required)
        or "<host_name>:<child_name>". A host without a name passes `name` to
        its children unchanged.

        Raises `NoMatchFound` if nothing matches.
        """
        # `path_params` is mutated below; keep the caller's view of it so the
        # error message reports what was actually requested.
        requested_params = dict(path_params)

        if self.name is not None and name == self.name and "path" in path_params:
            # 'name' matches "<mount_name>".
            path = path_params.pop("path")
            host, remaining_params = replace_params(
                self.host_format, self.param_convertors, path_params
            )
            if not remaining_params:
                return URLPath(path=path, host=host)
        elif self.name is None or name.startswith(self.name + ":"):
            if self.name is None:
                # No mount name.
                remaining_name = name
            else:
                # 'name' matches "<mount_name>:<child_name>".
                remaining_name = name[len(self.name) + 1 :]
            host, remaining_params = replace_params(
                self.host_format, self.param_convertors, path_params
            )
            for route in self.routes or []:
                try:
                    url = route.url_path_for(remaining_name, **remaining_params)
                    return URLPath(path=str(url), protocol=url.protocol, host=host)
                except NoMatchFound:
                    pass
        raise NoMatchFound(name, requested_params)

    async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Hand the request to the wrapped app."""
        await self.app(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        # Hosts are equal when host pattern and app agree.
        return (
            isinstance(other, Host)
            and self.host == other.host
            and self.app == other.app
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        name = self.name or ""
        return f"{class_name}(host={self.host!r}, name={name!r}, app={self.app!r})"

528 

529 

530_T = typing.TypeVar("_T") 

531 

532 

533class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): 

534 def __init__(self, cm: typing.ContextManager[_T]): 

535 self._cm = cm 

536 

537 async def __aenter__(self) -> _T: 

538 return self._cm.__enter__() 

539 

540 async def __aexit__( 

541 self, 

542 exc_type: typing.Optional[typing.Type[BaseException]], 

543 exc_value: typing.Optional[BaseException], 

544 traceback: typing.Optional[types.TracebackType], 

545 ) -> typing.Optional[bool]: 

546 return self._cm.__exit__(exc_type, exc_value, traceback) 

547 

548 

549def _wrap_gen_lifespan_context( 

550 lifespan_context: typing.Callable[[typing.Any], typing.Generator] 

551) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: 

552 cmgr = contextlib.contextmanager(lifespan_context) 

553 

554 @functools.wraps(cmgr) 

555 def wrapper(app: typing.Any) -> _AsyncLiftContextManager: 

556 return _AsyncLiftContextManager(cmgr(app)) 

557 

558 return wrapper 

559 

560 

561class _DefaultLifespan: 

562 def __init__(self, router: "Router"): 

563 self._router = router 

564 

565 async def __aenter__(self) -> None: 

566 await self._router.startup() 

567 

568 async def __aexit__(self, *exc_info: object) -> None: 

569 await self._router.shutdown() 

570 

571 def __call__(self: _T, app: object) -> _T: 

572 return self 

573 

574 

class Router:
    """
    ASGI application that dispatches a scope to the first matching route and
    handles the "lifespan" scope for startup and shutdown.
    """

    def __init__(
        self,
        routes: typing.Optional[typing.Sequence[BaseRoute]] = None,
        redirect_slashes: bool = True,
        default: typing.Optional[ASGIApp] = None,
        on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None,
        on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None,
        lifespan: typing.Optional[
            typing.Callable[[typing.Any], typing.AsyncContextManager]
        ] = None,
    ) -> None:
        """
        `routes`, `on_startup` and `on_shutdown` are copied into lists.
        `default` is the app used when no route matches; it falls back to
        `self.not_found`. Without `lifespan`, the startup and shutdown
        handlers are run through `_DefaultLifespan`. Async-generator and
        generator function lifespans are still accepted, with a
        `DeprecationWarning`.
        """
        self.routes = [] if routes is None else list(routes)
        self.redirect_slashes = redirect_slashes
        self.default = self.not_found if default is None else default
        self.on_startup = [] if on_startup is None else list(on_startup)
        self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)

        if lifespan is None:
            self.lifespan_context: typing.Callable[
                [typing.Any], typing.AsyncContextManager
            ] = _DefaultLifespan(self)

        elif inspect.isasyncgenfunction(lifespan):
            warnings.warn(
                "async generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = asynccontextmanager(
                lifespan,  # type: ignore[arg-type]
            )
        elif inspect.isgeneratorfunction(lifespan):
            warnings.warn(
                "generator function lifespans are deprecated, "
                "use an @contextlib.asynccontextmanager function instead",
                DeprecationWarning,
            )
            self.lifespan_context = _wrap_gen_lifespan_context(
                lifespan,  # type: ignore[arg-type]
            )
        else:
            self.lifespan_context = lifespan

    async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Default handler for scopes that no route matched."""
        if scope["type"] == "websocket":
            websocket_close = WebSocketClose()
            await websocket_close(scope, receive, send)
            return

        # If we're running inside a starlette application then raise an
        # exception, so that the configurable exception handler can deal with
        # returning the response. For plain ASGI apps, just return the response.
        if "app" in scope:
            raise HTTPException(status_code=404)
        else:
            response = PlainTextResponse("Not Found", status_code=404)
        await response(scope, receive, send)

    def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath:
        """
        Return the URL path from the first route that can build one for
        `name` and `path_params`; raise `NoMatchFound` if none can.
        """
        for route in self.routes:
            try:
                return route.url_path_for(name, **path_params)
            except NoMatchFound:
                pass
        raise NoMatchFound(name, path_params)

    async def startup(self) -> None:
        """
        Run any `.on_startup` event handlers.
        """
        for handler in self.on_startup:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def shutdown(self) -> None:
        """
        Run any `.on_shutdown` event handlers.
        """
        for handler in self.on_shutdown:
            if is_async_callable(handler):
                await handler()
            else:
                handler()

    async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Handle ASGI lifespan messages, which allows us to manage application
        startup and shutdown events.
        """
        # `started` records which phase failed, to pick the failure message.
        started = False
        app = scope.get("app")
        # The received messages are not inspected; presumably these are the
        # "lifespan.startup" / "lifespan.shutdown" events — confirm against
        # the ASGI lifespan spec.
        await receive()
        try:
            async with self.lifespan_context(app):
                await send({"type": "lifespan.startup.complete"})
                started = True
                await receive()
        except BaseException:
            exc_text = traceback.format_exc()
            if started:
                await send({"type": "lifespan.shutdown.failed", "message": exc_text})
            else:
                await send({"type": "lifespan.startup.failed", "message": exc_text})
            raise
        else:
            await send({"type": "lifespan.shutdown.complete"})

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The main entry point to the Router class.
        """
        assert scope["type"] in ("http", "websocket", "lifespan")

        # Only the outermost router registers itself in the scope.
        if "router" not in scope:
            scope["router"] = self

        if scope["type"] == "lifespan":
            await self.lifespan(scope, receive, send)
            return

        partial = None

        for route in self.routes:
            # Determine if any route matches the incoming scope,
            # and hand over to the matching route if found.
            match, child_scope = route.matches(scope)
            if match == Match.FULL:
                scope.update(child_scope)
                await route.handle(scope, receive, send)
                return
            elif match == Match.PARTIAL and partial is None:
                # Remember only the first partial match.
                partial = route
                partial_scope = child_scope

        if partial is not None:
            # Handle partial matches. These are cases where an endpoint is
            # able to handle the request, but is not a preferred option.
            # We use this in particular to deal with "405 Method Not Allowed".
            scope.update(partial_scope)
            await partial.handle(scope, receive, send)
            return

        if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/":
            # Retry with the trailing slash toggled; redirect if that matches.
            redirect_scope = dict(scope)
            if scope["path"].endswith("/"):
                redirect_scope["path"] = redirect_scope["path"].rstrip("/")
            else:
                redirect_scope["path"] = redirect_scope["path"] + "/"

            for route in self.routes:
                match, child_scope = route.matches(redirect_scope)
                if match != Match.NONE:
                    redirect_url = URL(scope=redirect_scope)
                    response = RedirectResponse(url=str(redirect_url))
                    await response(scope, receive, send)
                    return

        await self.default(scope, receive, send)

    def __eq__(self, other: typing.Any) -> bool:
        # Routers are equal when their route lists are equal.
        return isinstance(other, Router) and self.routes == other.routes

    def mount(
        self, path: str, app: ASGIApp, name: typing.Optional[str] = None
    ) -> None:  # pragma: no cover
        """Append a `Mount` for `app` at `path`."""
        route = Mount(path, app=app, name=name)
        self.routes.append(route)

    def host(
        self, host: str, app: ASGIApp, name: typing.Optional[str] = None
    ) -> None:  # pragma: no cover
        """Append a `Host` route for `app`."""
        route = Host(host, app=app, name=name)
        self.routes.append(route)

    def add_route(
        self,
        path: str,
        endpoint: typing.Callable,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> None:  # pragma: no cover
        """Append a `Route` built from the given arguments."""
        route = Route(
            path,
            endpoint=endpoint,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )
        self.routes.append(route)

    def add_websocket_route(
        self, path: str, endpoint: typing.Callable, name: typing.Optional[str] = None
    ) -> None:  # pragma: no cover
        """Append a `WebSocketRoute` built from the given arguments."""
        route = WebSocketRoute(path, endpoint=endpoint, name=name)
        self.routes.append(route)

    def route(
        self,
        path: str,
        methods: typing.Optional[typing.List[str]] = None,
        name: typing.Optional[str] = None,
        include_in_schema: bool = True,
    ) -> typing.Callable:
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [Route(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            "The `route` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
            "Refer to https://www.starlette.io/routing/#http-routing for the recommended approach.",  # noqa: E501
            DeprecationWarning,
            # Attribute the warning to the code using the decorator.
            stacklevel=2,
        )

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_route(
                path,
                func,
                methods=methods,
                name=name,
                include_in_schema=include_in_schema,
            )
            return func

        return decorator

    def websocket_route(
        self, path: str, name: typing.Optional[str] = None
    ) -> typing.Callable:
        """
        We no longer document this decorator style API, and its usage is discouraged.
        Instead you should use the following approach:

        >>> routes = [WebSocketRoute(path, endpoint=...), ...]
        >>> app = Starlette(routes=routes)
        """
        warnings.warn(
            "The `websocket_route` decorator is deprecated, and will be removed in version 1.0.0. Refer to "  # noqa: E501
            "https://www.starlette.io/routing/#websocket-routing for the recommended approach.",  # noqa: E501
            DeprecationWarning,
            # Attribute the warning to the code using the decorator.
            stacklevel=2,
        )

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_websocket_route(path, func, name=name)
            return func

        return decorator

    def add_event_handler(
        self, event_type: str, func: typing.Callable
    ) -> None:  # pragma: no cover
        """Register `func` for the "startup" or "shutdown" event."""
        assert event_type in ("startup", "shutdown")

        if event_type == "startup":
            self.on_startup.append(func)
        else:
            self.on_shutdown.append(func)

    def on_event(self, event_type: str) -> typing.Callable:
        """
        Deprecated decorator form of `add_event_handler()`; emits a
        `DeprecationWarning` when called.
        """
        warnings.warn(
            "The `on_event` decorator is deprecated, and will be removed in version 1.0.0. "  # noqa: E501
            "Refer to https://www.starlette.io/events/#registering-events for recommended approach.",  # noqa: E501
            DeprecationWarning,
            # Attribute the warning to the code using the decorator.
            stacklevel=2,
        )

        def decorator(func: typing.Callable) -> typing.Callable:
            self.add_event_handler(event_type, func)
            return func

        return decorator