1import asyncio
2import io
3import logging
4import re
5import weakref
6from copy import copy
7from urllib.parse import urlparse
8
9import aiohttp
10import yarl
11
12from fsspec.asyn import AbstractAsyncStreamedFile, AsyncFileSystem, sync, sync_wrapper
13from fsspec.callbacks import DEFAULT_CALLBACK
14from fsspec.exceptions import FSTimeoutError
15from fsspec.spec import AbstractBufferedFile
16from fsspec.utils import (
17 DEFAULT_BLOCK_SIZE,
18 glob_translate,
19 isfilelike,
20 nullcontext,
21 tokenize,
22)
23
24from ..caching import AllBytes
25
# Module logger shared by every HTTP filesystem/file object in this file.
logger = logging.getLogger("fsspec.http")

# Link extraction patterns used when listing a page.
# ``ex`` captures the target of an HTML anchor's href attribute
# (regex adapted from https://stackoverflow.com/a/15926317/3821154);
# ``ex2`` captures anything in the text that looks like an absolute
# http(s) URL.
ex = re.compile(r"""<(a|A)\s+(?:[^>]*?\s+)?(href|HREF)=["'](?P<url>[^"']+)""")
ex2 = re.compile(r"""(?P<url>http[s]?://[-a-zA-Z0-9@:%_+.~#?&/=]+)""")
30
31
async def get_client(**kwargs):
    """Default session factory: build an ``aiohttp.ClientSession``.

    All keyword arguments are forwarded unchanged to the session constructor.
    """
    session = aiohttp.ClientSession(**kwargs)
    return session
34
35
class HTTPFileSystem(AsyncFileSystem):
    """
    Simple File-System for fetching data via HTTP(S)

    ``ls()`` is implemented by loading the parent page and doing a regex
    match on the result. If simple_links=True, anything of the form
    "http(s)://server.com/stuff?thing=other" is also treated as a link;
    otherwise only links within HTML href tags will be used.

    URLs are passed unfiltered to aiohttp, so all addresses are accessible. Where URLs are
    supplied by a user, the calling application may wish to filter to prevent scanning.
    """

    protocol = ("http", "https")
    sep = "/"

    def __init__(
        self,
        simple_links=True,
        block_size=None,
        same_scheme=True,
        size_policy=None,
        cache_type="bytes",
        cache_options=None,
        asynchronous=False,
        loop=None,
        client_kwargs=None,
        get_client=get_client,
        encoded=False,
        **storage_options,
    ):
        """
        NB: if this is called async, you must await set_client

        Parameters
        ----------
        block_size: int
            Blocks to read bytes; if 0, will default to raw requests file-like
            objects instead of HTTPFile instances
        simple_links: bool
            If True, will consider both HTML <a> tags and anything that looks
            like a URL; if False, will consider only the former.
        same_scheme: bool
            When doing ls/glob, if this is True, only consider paths that have
            http/https matching the input URLs.
        size_policy: this argument is deprecated
        client_kwargs: dict
            Passed to aiohttp.ClientSession, see
            https://docs.aiohttp.org/en/stable/client_reference.html
            For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}``
        get_client: Callable[..., aiohttp.ClientSession]
            A callable, which takes keyword arguments and constructs
            an aiohttp.ClientSession. Its state will be managed by
            the HTTPFileSystem class.
        encoded: bool
            Passed as ``encoded`` to ``yarl.URL`` when request URLs are built
            (see ``encode_url``).
        storage_options: key-value
            Any other parameters passed on to requests
        cache_type, cache_options: defaults used in open()
        """
        super().__init__(self, asynchronous=asynchronous, loop=loop, **storage_options)
        self.block_size = block_size if block_size is not None else DEFAULT_BLOCK_SIZE
        self.simple_links = simple_links
        self.same_schema = same_scheme
        self.cache_type = cache_type
        self.cache_options = cache_options
        self.client_kwargs = client_kwargs or {}
        self.get_client = get_client
        self.encoded = encoded
        self._session = None

        # Clean caching-related parameters from `storage_options`
        # before propagating them as `request_options` through `self.kwargs`.
        # TODO: Maybe rename `self.kwargs` to `self.request_options` to make
        #       it clearer.
        request_options = copy(storage_options)
        self.use_listings_cache = request_options.pop("use_listings_cache", False)
        request_options.pop("listings_expiry_time", None)
        request_options.pop("max_paths", None)
        request_options.pop("skip_instance_cache", None)
        self.kwargs = request_options

    @property
    def fsid(self):
        """Constant identifier for this filesystem type."""
        return "http"

    def encode_url(self, url):
        """Wrap ``url`` in a ``yarl.URL``, honouring the instance's ``encoded`` flag."""
        return yarl.URL(url, encoded=self.encoded)

    @staticmethod
    def close_session(loop, session):
        """Best-effort shutdown of ``session``, used as a finalizer.

        If the loop is still running, the session is closed on it with a short
        timeout. Otherwise (or if that attempt fails with one of the handled
        errors) the session's connector, when present, is closed directly.
        """
        if loop is not None and loop.is_running():
            try:
                sync(loop, session.close, timeout=0.1)
                return
            except (TimeoutError, FSTimeoutError, NotImplementedError):
                pass
        connector = getattr(session, "_connector", None)
        if connector is not None:
            # close after loop is dead
            connector._close()

    async def set_session(self):
        """Return the client session, creating it on first use.

        For synchronous instances a finalizer is registered so the session is
        closed when the filesystem object is garbage collected.
        """
        if self._session is None:
            self._session = await self.get_client(loop=self.loop, **self.client_kwargs)
            if not self.asynchronous:
                weakref.finalize(self, self.close_session, self.loop, self._session)
        return self._session

    @classmethod
    def _strip_protocol(cls, path):
        """For HTTP, we always want to keep the full URL"""
        return path

    @classmethod
    def _parent(cls, path):
        """Return the parent URL, or "" when only the scheme prefix would remain."""
        # override, since _strip_protocol is different for URLs
        par = super()._parent(path)
        if len(par) > 7:  # "http://..."
            return par
        return ""

    async def _ls_real(self, url, detail=True, **kwargs):
        """Fetch ``url`` and extract the links found in the page.

        Parameters
        ----------
        url: str
            Page to list.
        detail: bool
            If True return a list of info dicts (``name``, ``size``, ``type``);
            otherwise a sorted list of URLs.
        kwargs: key-value
            Per-call request options; these override the instance defaults.

        Only links located below ``url`` are returned. Links whose scheme
        differs from that of ``url`` (http vs https) are returned only when
        the filesystem was created with ``same_scheme=False``.

        Raises
        ------
        FileNotFoundError
            If the server answers 404.
        """
        # ignoring URL-encoded arguments
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(url)
        session = await self.set_session()
        # Use the merged options so per-call kwargs are actually applied.
        async with session.get(self.encode_url(url), **kw) as r:
            self._raise_not_found_for_status(r, url)

            if "Content-Type" in r.headers:
                mimetype = r.headers["Content-Type"].partition(";")[0]
            else:
                mimetype = None

            if mimetype in ("text/html", None):
                try:
                    text = await r.text(errors="ignore")
                    # ``ex`` has three groups; the third one is the href target.
                    links = [u[2] for u in ex.findall(text)]
                    if self.simple_links:
                        links = ex2.findall(text) + links
                except UnicodeDecodeError:
                    links = []  # binary, not HTML
            else:
                links = []

        out = set()
        parts = urlparse(url)
        prefix = url.rstrip("/") + "/"
        prefix_any_scheme = prefix.replace("https", "http")
        for link in links:
            if link.startswith("/") and len(link) > 1:
                # absolute URL on this server
                link = f"{parts.scheme}://{parts.netloc}{link}"
            if link.startswith("http"):
                if link.startswith(prefix):
                    out.add(link)
                elif not self.same_schema and link.replace("https", "http").startswith(
                    prefix_any_scheme
                ):
                    # allowed to cross http <-> https
                    out.add(link)
            else:
                if link not in ["..", "../"]:
                    # Ignore FTP-like "parent"
                    out.add("/".join([url.rstrip("/"), link.lstrip("/")]))
        if not out and url.endswith("/"):
            # Retry without the trailing slash, keeping the caller's options.
            out = await self._ls_real(url.rstrip("/"), detail=False, **kwargs)
        if detail:
            return [
                {
                    "name": u,
                    "size": None,
                    "type": "directory" if u.endswith("/") else "file",
                }
                for u in out
            ]
        else:
            return sorted(out)

    async def _ls(self, url, detail=True, **kwargs):
        """List ``url``, optionally served from the listings cache.

        The detailed listing is what gets cached, so a cached entry can satisfy
        both ``detail=True`` and ``detail=False`` requests.
        """
        if self.use_listings_cache and url in self.dircache:
            out = self.dircache[url]
        else:
            out = await self._ls_real(url, detail=True, **kwargs)
            self.dircache[url] = out
        if detail:
            return out
        return sorted(entry["name"] for entry in out)

    ls = sync_wrapper(_ls)

    def _raise_not_found_for_status(self, response, url):
        """
        Raises FileNotFoundError for 404s, otherwise uses raise_for_status.
        """
        if response.status == 404:
            raise FileNotFoundError(url)
        response.raise_for_status()

    async def _cat_file(self, url, start=None, end=None, **kwargs):
        """Return the bytes at ``url``, optionally restricted to ``start``/``end``.

        When either limit is given, a ``Range`` header is built with
        ``_process_limits``; an empty range (``start == end``) returns ``b""``
        without contacting the server.
        """
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(url)

        if start is not None or end is not None:
            if start == end:
                return b""
            headers = kw.pop("headers", {}).copy()

            headers["Range"] = await self._process_limits(url, start, end)
            kw["headers"] = headers
        session = await self.set_session()
        async with session.get(self.encode_url(url), **kw) as r:
            # Check the status first so an error body is not downloaded.
            self._raise_not_found_for_status(r, url)
            out = await r.read()
        return out

    async def _get_file(
        self, rpath, lpath, chunk_size=5 * 2**20, callback=DEFAULT_CALLBACK, **kwargs
    ):
        """Download ``rpath`` into ``lpath`` in chunks of ``chunk_size`` bytes.

        ``lpath`` may be a local path or an already-open file-like object; in
        the latter case the object is left open for the caller. Progress is
        reported through ``callback``.
        """
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(rpath)
        session = await self.set_session()
        async with session.get(self.encode_url(rpath), **kw) as r:
            try:
                size = int(r.headers["content-length"])
            except (ValueError, KeyError):
                size = None

            callback.set_size(size)
            self._raise_not_found_for_status(r, rpath)
            caller_owns_file = isfilelike(lpath)
            if caller_owns_file:
                outfile = lpath
            else:
                outfile = open(lpath, "wb")  # noqa: ASYNC230

            try:
                chunk = True
                while chunk:
                    chunk = await r.content.read(chunk_size)
                    outfile.write(chunk)
                    callback.relative_update(len(chunk))
            finally:
                if not caller_owns_file:
                    outfile.close()

    async def _put_file(
        self,
        lpath,
        rpath,
        chunk_size=5 * 2**20,
        callback=DEFAULT_CALLBACK,
        method="post",
        mode="overwrite",
        **kwargs,
    ):
        """Upload ``lpath`` to ``rpath`` with a streamed POST or PUT request.

        Parameters
        ----------
        lpath: str or io.IOBase
            Local path, or an open file-like object to read from.
        rpath: str
            Target URL.
        chunk_size: int
            Number of bytes read and sent per chunk.
        callback: fsspec callback
            Receives the total size (when known) and per-chunk updates.
        method: str
            "post" or "put" (case-insensitive).
        mode: str
            Only "overwrite" is supported.

        Raises
        ------
        NotImplementedError
            If ``mode`` is not "overwrite".
        ValueError
            If ``method`` is neither "post" nor "put".
        """
        if mode != "overwrite":
            raise NotImplementedError("Exclusive write")

        async def gen_chunks():
            # Support passing arbitrary file-like objects
            # and use them instead of streams.
            if isinstance(lpath, io.IOBase):
                context = nullcontext(lpath)
                use_seek = False  # might not support seeking
            else:
                context = open(lpath, "rb")  # noqa: ASYNC230
                use_seek = True

            with context as f:
                if use_seek:
                    callback.set_size(f.seek(0, 2))
                    f.seek(0)
                else:
                    callback.set_size(getattr(f, "size", None))

                chunk = f.read(chunk_size)
                while chunk:
                    yield chunk
                    callback.relative_update(len(chunk))
                    chunk = f.read(chunk_size)

        kw = self.kwargs.copy()
        kw.update(kwargs)
        session = await self.set_session()

        method = method.lower()
        if method not in ("post", "put"):
            raise ValueError(
                f"method has to be either 'post' or 'put', not: {method!r}"
            )

        meth = getattr(session, method)
        async with meth(self.encode_url(rpath), data=gen_chunks(), **kw) as resp:
            self._raise_not_found_for_status(resp, rpath)

    async def _exists(self, path, strict=False, **kwargs):
        """Report whether a GET on ``path`` answers with a status below 400.

        With ``strict=True`` error statuses other than 404 and client errors
        are raised instead of being reported as ``False``.
        """
        kw = self.kwargs.copy()
        kw.update(kwargs)
        try:
            logger.debug(path)
            session = await self.set_session()
            r = await session.get(self.encode_url(path), **kw)
            async with r:
                if strict:
                    self._raise_not_found_for_status(r, path)
                return r.status < 400
        except FileNotFoundError:
            return False
        except aiohttp.ClientError:
            if strict:
                raise
            return False

    async def _isfile(self, path, **kwargs):
        """Every reachable URL is treated as a file; see ``_exists``."""
        return await self._exists(path, **kwargs)

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=None,  # XXX: This differs from the base class.
        cache_type=None,
        cache_options=None,
        size=None,
        **kwargs,
    ):
        """Make a file-like object

        Parameters
        ----------
        path: str
            Full URL with protocol
        mode: string
            must be "rb"
        block_size: int or None
            Bytes to download in one request; use instance value if None. If
            zero, will return a streaming Requests file-like instance.
        cache_type, cache_options:
            Override the instance defaults for the returned ``HTTPFile``.
        size: int or None
            Known size of the target; when falsy it is looked up via ``info``.
        kwargs: key-value
            Any other parameters, passed to requests calls

        Returns
        -------
        ``HTTPFile`` when a block size and size are available and the server
        did not report ``partial=False``; otherwise ``HTTPStreamFile``.
        """
        if mode != "rb":
            raise NotImplementedError
        block_size = block_size if block_size is not None else self.block_size
        kw = self.kwargs.copy()
        kw["asynchronous"] = self.asynchronous
        kw.update(kwargs)
        info = {}
        if not size:
            # Size not supplied (or zero): ask the server.
            info = self.info(path, **kwargs)
            size = info["size"]
        session = sync(self.loop, self.set_session)
        if block_size and size and info.get("partial", True):
            return HTTPFile(
                self,
                path,
                session=session,
                block_size=block_size,
                mode=mode,
                size=size,
                cache_type=cache_type or self.cache_type,
                cache_options=cache_options or self.cache_options,
                loop=self.loop,
                **kw,
            )
        else:
            return HTTPStreamFile(
                self,
                path,
                mode=mode,
                loop=self.loop,
                session=session,
                **kw,
            )

    async def open_async(self, path, mode="rb", size=None, **kwargs):
        """Open ``path`` as an ``AsyncStreamFile``.

        ``size`` is looked up via ``_info`` when not given; a failed lookup
        leaves it as None. ``mode`` is forwarded to the file object, which
        accepts only "rb".
        """
        session = await self.set_session()
        if size is None:
            try:
                size = (await self._info(path, **kwargs))["size"]
            except FileNotFoundError:
                pass
        return AsyncStreamFile(
            self,
            path,
            mode=mode,
            loop=self.loop,
            session=session,
            size=size,
            **kwargs,
        )

    def ukey(self, url):
        """Unique identifier; assume HTTP files are static, unchanging"""
        return tokenize(url, self.kwargs, self.protocol)

    async def _info(self, url, **kwargs):
        """Get info of URL

        Tries to access location via HEAD, and then GET methods, but does
        not fetch the data.

        It is possible that the server does not supply any size information, in
        which case size will be given as None (and certain operations on the
        corresponding file will not work).

        Raises
        ------
        FileNotFoundError
            If the GET attempt fails for any reason.
        """
        info = {}
        session = await self.set_session()
        # Merge instead of double-unpacking: a key present in both the instance
        # options and the call (e.g. "headers") must not raise TypeError.
        kw = self.kwargs.copy()
        kw.update(kwargs)

        for policy in ["head", "get"]:
            try:
                info.update(
                    await _file_info(
                        self.encode_url(url),
                        size_policy=policy,
                        session=session,
                        **kw,
                    )
                )
                if info.get("size") is not None:
                    break
            except Exception as exc:
                if policy == "get":
                    # If get failed, then raise a FileNotFoundError
                    raise FileNotFoundError(url) from exc
                logger.debug("", exc_info=exc)

        return {"name": url, "size": None, **info, "type": "file"}

    async def _glob(self, path, maxdepth=None, **kwargs):
        """
        Find files by glob-matching.

        This implementation is identical to the one in AbstractFileSystem,
        but "?" is not considered as a character for globbing, because it is
        so common in URLs, often identifying the "query" part.
        """
        if maxdepth is not None and maxdepth < 1:
            raise ValueError("maxdepth must be at least 1")

        ends_with_slash = path.endswith("/")  # recorded before normalising path
        path = self._strip_protocol(path)
        append_slash_to_dirname = ends_with_slash or path.endswith(("/**", "/*"))
        idx_star = path.find("*") if path.find("*") >= 0 else len(path)
        idx_brace = path.find("[") if path.find("[") >= 0 else len(path)

        # Position of the first glob character; everything before it is literal.
        min_idx = min(idx_star, idx_brace)

        detail = kwargs.pop("detail", False)

        if not has_magic(path):
            if await self._exists(path, **kwargs):
                if not detail:
                    return [path]
                else:
                    return {path: await self._info(path, **kwargs)}
            else:
                if not detail:
                    return []  # glob of non-existent returns empty
                else:
                    return {}
        elif "/" in path[:min_idx]:
            min_idx = path[:min_idx].rindex("/")
            root = path[: min_idx + 1]
            depth = path[min_idx + 1 :].count("/") + 1
        else:
            root = ""
            depth = path[min_idx + 1 :].count("/") + 1

        if "**" in path:
            if maxdepth is not None:
                idx_double_stars = path.find("**")
                depth_double_stars = path[idx_double_stars:].count("/") + 1
                depth = depth - depth_double_stars + maxdepth
            else:
                depth = None

        allpaths = await self._find(
            root, maxdepth=depth, withdirs=True, detail=True, **kwargs
        )

        pattern = glob_translate(path + ("/" if ends_with_slash else ""))
        pattern = re.compile(pattern)

        out = {
            (
                p.rstrip("/")
                if not append_slash_to_dirname
                and info["type"] == "directory"
                and p.endswith("/")
                else p
            ): info
            for p, info in sorted(allpaths.items())
            if pattern.match(p.rstrip("/"))
        }

        if detail:
            return out
        else:
            return list(out)

    async def _isdir(self, path):
        """A URL counts as a directory when listing it yields at least one link."""
        # override, since all URLs are (also) files
        try:
            return bool(await self._ls(path))
        except (FileNotFoundError, ValueError):
            return False

    async def _pipe_file(self, path, value, mode="overwrite", **kwargs):
        """
        Write bytes to a remote file over HTTP.

        The data is sent with a single PUT request carrying an explicit
        ``Content-Length`` header.

        Parameters
        ----------
        path : str
            Target URL where the data should be written
        value : bytes
            Data to be written
        mode : str
            Accepted for interface compatibility; this implementation does
            not currently act on it.
        **kwargs : dict
            Additional parameters to pass to the HTTP request; these override
            the instance-level request options.
        """
        url = self._strip_protocol(path)
        kw = self.kwargs.copy()
        kw.update(kwargs)
        # Copy so the caller's (or the instance's) headers dict is not mutated.
        headers = kw.pop("headers", {}).copy()
        headers["Content-Length"] = str(len(value))

        session = await self.set_session()

        async with session.put(
            self.encode_url(url), data=value, headers=headers, **kw
        ) as r:
            r.raise_for_status()
570
571
class HTTPFile(AbstractBufferedFile):
    """
    Read-only, buffered file object backed by a remote HTTP(S) resource.

    Data is fetched block-wise with read-ahead. If the server does not report
    the file size, the only supported access pattern is reading the whole
    file in one go.

    Parameters
    ----------
    url: str
        Full URL of the remote resource, including the protocol
    session: aiohttp.ClientSession or None
        Session used for every request, so connections can be reused where
        the server allows it
    block_size: int or None
        Read-ahead size in bytes. Default is 5MB, or the value configured for
        the FileSystem creating this file
    size: None or int
        Size of the file in bytes, if already known; avoids asking the server
    kwargs: all other key-values are passed to requests calls.
    """

    def __init__(
        self,
        fs,
        url,
        session=None,
        block_size=None,
        mode="rb",
        cache_type="bytes",
        cache_options=None,
        size=None,
        loop=None,
        asynchronous=False,
        **kwargs,
    ):
        if mode != "rb":
            raise NotImplementedError("File mode not supported")
        self.url = url
        self.session = session
        self.loop = loop
        self.asynchronous = asynchronous
        # Pre-populated before the base initialiser runs.
        self.details = {"name": url, "size": size, "type": "file"}
        super().__init__(
            fs=fs,
            path=url,
            mode=mode,
            block_size=block_size,
            cache_type=cache_type,
            cache_options=cache_options,
            **kwargs,
        )

    def read(self, length=-1):
        """Read bytes from file

        Parameters
        ----------
        length: int
            Maximum number of bytes to return. A negative value reads through
            to the end of the file. If the server has not supplied the
            filesize, attempting to read only part of the data will raise a
            ValueError.
        """
        if length < 0 and self.loc == 0:
            # Explicit "read everything" from the start: fetch in one request,
            # unless the size is known and fits into a single block anyway.
            if self.size is None or self.size > self.blocksize:
                self._fetch_all()
        if self.size is not None:
            length = min(self.size - self.loc, length)
        elif length < 0:
            self._fetch_all()
        return super().read(length)

    async def async_fetch_all(self):
        """Download the complete file with a single request.

        Called when the position is still at zero and read() was invoked
        without a byte count. The result replaces the cache with an
        ``AllBytes`` instance and sets ``size``.
        """
        logger.debug(f"Fetch all for {self}")
        if isinstance(self.cache, AllBytes):
            # Everything is already held in memory.
            return
        response = await self.session.get(
            self.fs.encode_url(self.url), **self.kwargs
        )
        async with response:
            response.raise_for_status()
            payload = await response.read()
            self.cache = AllBytes(
                size=len(payload), fetcher=None, blocksize=None, data=payload
            )
            self.size = len(payload)

    _fetch_all = sync_wrapper(async_fetch_all)

    def _parse_content_range(self, headers):
        """Split a Content-Range header into ``(start, end, total)``.

        Components given as ``*`` are returned as None; a missing or
        unparseable header yields ``(None, None, None)``.
        """
        header_value = headers.get("Content-Range", "")
        match = re.match(r"bytes (\d+-\d+|\*)/(\d+|\*)", header_value)
        if match is None:
            return None, None, None

        span, size = match.group(1), match.group(2)
        if span == "*":
            start = end = None
        else:
            first, _, last = span.partition("-")
            start, end = int(first), int(last)
        total = int(size) if size != "*" else None
        return start, end, total

    async def async_fetch_range(self, start, end):
        """Download a block of data

        The server is expected to answer with HTTP code 206 and only the
        requested bytes. Otherwise the headers are inspected, and as a last
        resort the body is streamed; that fallback only works when the
        requested block starts at the beginning of the file, and a ValueError
        is raised for any other start.
        """
        logger.debug(f"Fetch range for {self}: {start}-{end}")
        wanted = end - start
        kwargs = self.kwargs.copy()
        headers = kwargs.pop("headers", {}).copy()
        headers["Range"] = f"bytes={start}-{end - 1}"
        logger.debug(f"{self.url} : {headers['Range']}")
        response = await self.session.get(
            self.fs.encode_url(self.url), headers=headers, **kwargs
        )
        async with response:
            if response.status == 416:
                # range request outside file
                return b""
            response.raise_for_status()

            # A 206 status is the proper signal of a handled range request,
            # but a matching Content-Range header, or a Content-Length no
            # larger than the requested span, is accepted as well.
            response_is_range = (
                response.status == 206
                or self._parse_content_range(response.headers)[0] == start
                or int(response.headers.get("Content-Length", end + 1)) <= wanted
            )

            if response_is_range:
                # partial content, as expected
                return await response.read()

            if start > 0:
                raise ValueError(
                    "The HTTP server doesn't appear to support range requests. "
                    "Only reading this file from the beginning is supported. "
                    "Open with block_size=0 for a streaming file interface."
                )

            # Full-body response, but the block starts at offset zero: stream
            # until enough bytes have arrived, then trim to the wanted size.
            received = 0
            pieces = []
            while True:
                chunk = await response.content.read(2**20)
                if not chunk:
                    break
                pieces.append(chunk)
                received += len(chunk)
                if received > wanted:
                    break
            return b"".join(pieces)[:wanted]

    _fetch_range = sync_wrapper(async_fetch_range)
744
745
# Glob metacharacters recognised for URLs: "*" and "[" only.
magic_check = re.compile("([*[])")


def has_magic(s):
    """Return True if ``s`` contains a glob character matched by ``magic_check``."""
    return magic_check.search(s) is not None
752
753
class HTTPStreamFile(AbstractBufferedFile):
    """Forward-only, unbuffered reader over a single HTTP response body.

    The GET request is issued when the object is constructed and the body is
    consumed sequentially; seeking anywhere other than the current position
    is not supported.
    """

    def __init__(self, fs, url, mode="rb", loop=None, session=None, **kwargs):
        self.asynchronous = kwargs.pop("asynchronous", False)
        self.url = url
        self.loop = loop
        self.session = session
        if mode != "rb":
            raise ValueError
        self.details = {"name": url, "size": None}
        super().__init__(fs=fs, path=url, mode=mode, cache_type="none", **kwargs)

        async def cor():
            r = await self.session.get(self.fs.encode_url(url), **kwargs).__aenter__()
            try:
                self.fs._raise_not_found_for_status(r, url)
            except BaseException:
                # Do not leave the response open when the status check fails.
                r.close()
                raise
            return r

        self.r = sync(self.loop, cor)
        self.loop = fs.loop

    def seek(self, loc, whence=0):
        """Accept only seeks that leave the position unchanged.

        Returns the current position for the permitted no-op seeks.

        Raises
        ------
        ValueError
            For any seek that would move the position.
        """
        if loc == 0 and whence == 1:
            return self.loc
        if loc == self.loc and whence == 0:
            return self.loc
        raise ValueError("Cannot seek streaming HTTP file")

    async def _read(self, num=-1):
        """Read up to ``num`` bytes from the response body and advance ``loc``."""
        out = await self.r.content.read(num)
        self.loc += len(out)
        return out

    read = sync_wrapper(_read)

    async def _close(self):
        """Close the underlying response."""
        self.r.close()

    def close(self):
        """Schedule closing of the response on the loop, then close the file."""
        asyncio.run_coroutine_threadsafe(self._close(), self.loop)
        super().close()
793
794
class AsyncStreamFile(AbstractAsyncStreamedFile):
    """Async, forward-only reader over a single HTTP response body.

    The GET request is issued lazily on the first ``read`` call.
    """

    def __init__(
        self, fs, url, mode="rb", loop=None, session=None, size=None, **kwargs
    ):
        self.url = url
        self.session = session
        self.r = None  # response, opened on first read
        if mode != "rb":
            raise ValueError
        self.details = {"name": url, "size": None}
        self.kwargs = kwargs
        super().__init__(fs=fs, path=url, mode=mode, cache_type="none")
        self.size = size

    async def read(self, num=-1):
        """Read up to ``num`` bytes from the response body and advance ``loc``.

        Raises
        ------
        FileNotFoundError
            If the server answers 404 when the request is first issued.
        """
        if self.r is None:
            r = await self.session.get(
                self.fs.encode_url(self.url), **self.kwargs
            ).__aenter__()
            try:
                self.fs._raise_not_found_for_status(r, self.url)
            except BaseException:
                # Do not leave the response open when the status check fails.
                r.close()
                raise
            self.r = r
        out = await self.r.content.read(num)
        self.loc += len(out)
        return out

    async def close(self):
        """Close the response, if one was opened, then the file itself."""
        if self.r is not None:
            self.r.close()
            self.r = None
        await super().close()
825
826
async def get_range(session, url, start, end, file=None, **kwargs):
    """Fetch bytes ``start`` to ``end - 1`` of ``url`` with a Range request.

    Parameters
    ----------
    session:
        Client session providing an awaitable ``get``.
    url:
        Target of the request.
    start, end: int
        Half-open byte interval; sent as ``bytes={start}-{end - 1}``.
    file: str or None
        If given, the data is written into this existing file at offset
        ``start`` and nothing is returned.
    kwargs:
        Extra request options; any ``headers`` given are copied, not mutated.

    Returns
    -------
    The downloaded bytes when ``file`` is not given, otherwise None.
    """
    # explicit get a range when we know it must be safe
    kwargs = kwargs.copy()
    headers = kwargs.pop("headers", {}).copy()
    headers["Range"] = f"bytes={start}-{end - 1}"
    r = await session.get(url, headers=headers, **kwargs)
    async with r:
        # Check the status inside the context so the response is released
        # even when the server reports an error.
        r.raise_for_status()
        out = await r.read()
    if file:
        with open(file, "r+b") as f:  # noqa: ASYNC230
            f.seek(start)
            f.write(out)
    else:
        return out
842
843
async def _file_info(url, session, size_policy="head", **kwargs):
    """Call HEAD on the server to get details about the file (size/checksum etc.)

    Default operation is to explicitly allow redirects and use encoding
    'identity' (no compression) to get the true size of the target.

    Parameters
    ----------
    url:
        Target of the request.
    session:
        Client session providing awaitable ``head`` and ``get``.
    size_policy: str
        "head" or "get": which HTTP method to use.
    kwargs:
        Extra request options. ``allow_redirects`` defaults to True and any
        ``headers`` are copied before ``Accept-Encoding`` is set.

    Returns
    -------
    dict with ``url`` and, when the response headers provide them, ``size``,
    ``mimetype``, ``partial`` (False if the server sent
    ``Accept-Ranges: none``) and the checksum-like headers ``ETag``,
    ``Content-MD5``, ``Digest`` and ``Last-Modified``.

    Raises
    ------
    TypeError
        If ``size_policy`` is neither "head" nor "get".
    """
    logger.debug("Retrieve file size for %s", url)
    kwargs = kwargs.copy()
    ar = kwargs.pop("allow_redirects", True)
    head = kwargs.get("headers", {}).copy()
    head["Accept-Encoding"] = "identity"
    kwargs["headers"] = head

    info = {}
    if size_policy == "head":
        r = await session.head(url, allow_redirects=ar, **kwargs)
    elif size_policy == "get":
        r = await session.get(url, allow_redirects=ar, **kwargs)
    else:
        raise TypeError(f'size_policy must be "head" or "get", got {size_policy}')
    async with r:
        r.raise_for_status()

        if "Content-Length" in r.headers:
            # Some servers may choose to ignore Accept-Encoding and return
            # compressed content, in which case the returned size is unreliable.
            if "Content-Encoding" not in r.headers or r.headers["Content-Encoding"] in [
                "identity",
                "",
            ]:
                info["size"] = int(r.headers["Content-Length"])
        elif "Content-Range" in r.headers:
            # The total may be "*" (unknown) or the header may be malformed;
            # in both cases leave the size unset instead of raising.
            total = r.headers["Content-Range"].rpartition("/")[2].strip()
            if total.isdigit():
                info["size"] = int(total)

        if "Content-Type" in r.headers:
            info["mimetype"] = r.headers["Content-Type"].partition(";")[0]

        if r.headers.get("Accept-Ranges") == "none":
            # Some servers may explicitly discourage partial content requests, but
            # the lack of "Accept-Ranges" does not always indicate they would fail
            info["partial"] = False

        info["url"] = str(r.url)

        for checksum_field in ["ETag", "Content-MD5", "Digest", "Last-Modified"]:
            if r.headers.get(checksum_field):
                info[checksum_field] = r.headers[checksum_field]

    return info
893
894
async def _file_size(url, session=None, *args, **kwargs):
    """Return the size of ``url`` as reported by ``_file_info``, or None.

    Parameters
    ----------
    url:
        Target of the request.
    session:
        Client session to use. When None, a temporary session is created with
        ``get_client`` and closed again before returning.
    args, kwargs:
        Forwarded to ``_file_info`` after ``url`` and ``session``.
    """
    owns_session = session is None
    if owns_session:
        session = await get_client()
    try:
        # Pass the session positionally so extra positional arguments bind to
        # the parameters that follow it rather than colliding with it.
        info = await _file_info(url, session, *args, **kwargs)
    finally:
        if owns_session:
            # A session created here has no other owner; close it.
            await session.close()
    return info.get("size")


file_size = sync_wrapper(_file_size)