import asyncio
import io
import logging
import re
import weakref
from copy import copy
from urllib.parse import urlparse

import aiohttp
import yarl

from fsspec.asyn import AbstractAsyncStreamedFile, AsyncFileSystem, sync, sync_wrapper
from fsspec.callbacks import DEFAULT_CALLBACK
from fsspec.exceptions import FSTimeoutError
from fsspec.spec import AbstractBufferedFile
from fsspec.utils import (
    DEFAULT_BLOCK_SIZE,
    glob_translate,
    isfilelike,
    nullcontext,
    tokenize,
)

from ..caching import AllBytes

# https://stackoverflow.com/a/15926317/3821154
ex = re.compile(r"""<(a|A)\s+(?:[^>]*?\s+)?(href|HREF)=["'](?P<url>[^"']+)""")
ex2 = re.compile(r"""(?P<url>http[s]?://[-a-zA-Z0-9@:%_+.~#?&/=]+)""")
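# ``ex`` extracts the href target from HTML anchor tags, e.g.
# '<a href="data.csv">' yields 'data.csv'; ``ex2`` matches bare absolute URLs
# such as 'https://server.com/stuff?thing=other'.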
logger = logging.getLogger("fsspec.http")


async def get_client(**kwargs):
    return aiohttp.ClientSession(**kwargs)
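
# A custom session factory may be supplied via HTTPFileSystem(get_client=...),
# for example to pre-configure authentication (a sketch; the credentials are
# placeholders):
#
#     async def get_client_with_auth(**kwargs):
#         return aiohttp.ClientSession(
#             auth=aiohttp.BasicAuth("user", "pass"), **kwargs
#         )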


class HTTPFileSystem(AsyncFileSystem):
    """
    Simple File-System for fetching data via HTTP(S)

    ``ls()`` is implemented by loading the parent page and doing a regex
    match on the result. If simple_links=True, anything of the form
    "http(s)://server.com/stuff?thing=other" is considered a link; otherwise
    only links within HTML href tags are used.
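
    Example (an illustrative sketch; the URL and listing are placeholders):

    >>> fs = HTTPFileSystem()
    >>> fs.ls("https://server.com/files/")  # doctest: +SKIP
    ['https://server.com/files/a.csv', 'https://server.com/files/b.csv']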
    """

    sep = "/"

    def __init__(
        self,
        simple_links=True,
        block_size=None,
        same_scheme=True,
        size_policy=None,
        cache_type="bytes",
        cache_options=None,
        asynchronous=False,
        loop=None,
        client_kwargs=None,
        get_client=get_client,
        encoded=False,
        **storage_options,
    ):
        """
        NB: if this is called async, you must await set_session

        Parameters
        ----------
        block_size: int
            Blocks to read bytes; if 0, will return streaming file-like
            objects instead of HTTPFile instances
        simple_links: bool
            If True, will consider both HTML <a> tags and anything that looks
            like a URL; if False, will consider only the former.
        same_scheme: bool
            When doing ls/glob, if this is True, only consider paths that have
            http/https matching the input URLs.
        size_policy: this argument is deprecated
        client_kwargs: dict
            Passed to aiohttp.ClientSession, see
            https://docs.aiohttp.org/en/stable/client_reference.html
            For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}``
        get_client: Callable[..., aiohttp.ClientSession]
            A callable, which takes keyword arguments and constructs
            an aiohttp.ClientSession. Its state will be managed by
            the HTTPFileSystem class.
        storage_options: key-value
            Any other parameters passed on to the HTTP requests
        cache_type, cache_options: defaults used in open()
        """
        super().__init__(asynchronous=asynchronous, loop=loop, **storage_options)
        self.block_size = block_size if block_size is not None else DEFAULT_BLOCK_SIZE
        self.simple_links = simple_links
        self.same_schema = same_scheme
        self.cache_type = cache_type
        self.cache_options = cache_options
        self.client_kwargs = client_kwargs or {}
        self.get_client = get_client
        self.encoded = encoded
        self._session = None

        # Clean caching-related parameters from `storage_options`
        # before propagating them as `request_options` through `self.kwargs`.
        # TODO: Maybe rename `self.kwargs` to `self.request_options` to make
        # it clearer.
        request_options = copy(storage_options)
        self.use_listings_cache = request_options.pop("use_listings_cache", False)
        request_options.pop("listings_expiry_time", None)
        request_options.pop("max_paths", None)
        request_options.pop("skip_instance_cache", None)
        self.kwargs = request_options

    @property
    def fsid(self):
        return "http"

    def encode_url(self, url):
        return yarl.URL(url, encoded=self.encoded)

    @staticmethod
    def close_session(loop, session):
        if loop is not None and loop.is_running():
            try:
                sync(loop, session.close, timeout=0.1)
                return
            except (TimeoutError, FSTimeoutError, NotImplementedError):
                pass
        connector = getattr(session, "_connector", None)
        if connector is not None:
            # close after loop is dead
            connector._close()

    async def set_session(self):
        if self._session is None:
            self._session = await self.get_client(loop=self.loop, **self.client_kwargs)
            if not self.asynchronous:
                weakref.finalize(self, self.close_session, self.loop, self._session)
        return self._session

    @classmethod
    def _strip_protocol(cls, path):
        """For HTTP, we always want to keep the full URL"""
        return path

    @classmethod
    def _parent(cls, path):
        # override, since _strip_protocol is different for URLs
        par = super()._parent(path)
        if len(par) > 7:  # "http://..."
            return par
        return ""

    async def _ls_real(self, url, detail=True, **kwargs):
        # ignoring URL-encoded arguments
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(url)
        session = await self.set_session()
        async with session.get(self.encode_url(url), **kw) as r:
            self._raise_not_found_for_status(r, url)

            if "Content-Type" in r.headers:
                mimetype = r.headers["Content-Type"].partition(";")[0]
            else:
                mimetype = None

            if mimetype in ("text/html", None):
                try:
                    text = await r.text(errors="ignore")
                    if self.simple_links:
                        links = ex2.findall(text) + [u[2] for u in ex.findall(text)]
                    else:
                        links = [u[2] for u in ex.findall(text)]
                except UnicodeDecodeError:
                    links = []  # binary, not HTML
            else:
                links = []

        out = set()
        parts = urlparse(url)
        for l in links:
            if isinstance(l, tuple):
                l = l[1]
            if l.startswith("/") and len(l) > 1:
                # absolute URL on this server
                l = f"{parts.scheme}://{parts.netloc}{l}"
            if l.startswith("http"):
                if self.same_schema and l.startswith(url.rstrip("/") + "/"):
                    out.add(l)
                elif l.replace("https", "http").startswith(
                    url.replace("https", "http").rstrip("/") + "/"
                ):
                    # allowed to cross http <-> https
                    out.add(l)
            else:
                if l not in ["..", "../"]:
                    # Ignore FTP-like "parent"
                    out.add("/".join([url.rstrip("/"), l.lstrip("/")]))
        if not out and url.endswith("/"):
            out = await self._ls_real(url.rstrip("/"), detail=False)
        if detail:
            return [
                {
                    "name": u,
                    "size": None,
                    "type": "directory" if u.endswith("/") else "file",
                }
                for u in out
            ]
        else:
            return sorted(out)

    async def _ls(self, url, detail=True, **kwargs):
        if self.use_listings_cache and url in self.dircache:
            out = self.dircache[url]
        else:
            out = await self._ls_real(url, detail=detail, **kwargs)
            self.dircache[url] = out
        return out

    ls = sync_wrapper(_ls)

    def _raise_not_found_for_status(self, response, url):
        """
        Raises FileNotFoundError for 404s, otherwise uses raise_for_status.
        """
        if response.status == 404:
            raise FileNotFoundError(url)
        response.raise_for_status()

    async def _cat_file(self, url, start=None, end=None, **kwargs):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(url)

        if start is not None or end is not None:
            if start == end:
                # zero-length range: no request needed
                return b""
            headers = kw.pop("headers", {}).copy()
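            # _process_limits builds the RFC 7233 Range header value, e.g.
            # "bytes=0-999" (HTTP end offsets are inclusive)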
            headers["Range"] = await self._process_limits(url, start, end)
            kw["headers"] = headers
        session = await self.set_session()
        async with session.get(self.encode_url(url), **kw) as r:
            out = await r.read()
            self._raise_not_found_for_status(r, url)
        return out

    async def _get_file(
        self, rpath, lpath, chunk_size=5 * 2**20, callback=DEFAULT_CALLBACK, **kwargs
    ):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(rpath)
        session = await self.set_session()
        async with session.get(self.encode_url(rpath), **kw) as r:
            try:
                size = int(r.headers["content-length"])
            except (ValueError, KeyError):
                size = None

            callback.set_size(size)
            self._raise_not_found_for_status(r, rpath)
            if isfilelike(lpath):
                outfile = lpath
            else:
                outfile = open(lpath, "wb")  # noqa: ASYNC230

            try:
                chunk = True
                while chunk:
                    chunk = await r.content.read(chunk_size)
                    outfile.write(chunk)
                    callback.relative_update(len(chunk))
            finally:
                if not isfilelike(lpath):
                    outfile.close()

    async def _put_file(
        self,
        lpath,
        rpath,
        chunk_size=5 * 2**20,
        callback=DEFAULT_CALLBACK,
        method="post",
        mode="overwrite",
        **kwargs,
    ):
        if mode != "overwrite":
            raise NotImplementedError("Exclusive write")

        async def gen_chunks():
            # Support passing arbitrary file-like objects
            # and use them instead of streams.
            if isinstance(lpath, io.IOBase):
                context = nullcontext(lpath)
                use_seek = False  # might not support seeking
            else:
                context = open(lpath, "rb")  # noqa: ASYNC230
                use_seek = True

            with context as f:
                if use_seek:
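                    # seek(0, 2) moves to end-of-file and returns the offset
                    # there, i.e. the total size; then rewind before streaming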
                    callback.set_size(f.seek(0, 2))
                    f.seek(0)
                else:
                    callback.set_size(getattr(f, "size", None))

                chunk = f.read(chunk_size)
                while chunk:
                    yield chunk
                    callback.relative_update(len(chunk))
                    chunk = f.read(chunk_size)

        kw = self.kwargs.copy()
        kw.update(kwargs)
        session = await self.set_session()

        method = method.lower()
        if method not in ("post", "put"):
            raise ValueError(
                f"method has to be either 'post' or 'put', not: {method!r}"
            )

        meth = getattr(session, method)
        async with meth(self.encode_url(rpath), data=gen_chunks(), **kw) as resp:
            self._raise_not_found_for_status(resp, rpath)

    async def _exists(self, path, **kwargs):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        try:
            logger.debug(path)
            session = await self.set_session()
            r = await session.get(self.encode_url(path), **kw)
            async with r:
                return r.status < 400
        except aiohttp.ClientError:
            return False

    async def _isfile(self, path, **kwargs):
        return await self._exists(path, **kwargs)

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=None,  # XXX: This differs from the base class.
        cache_type=None,
        cache_options=None,
        size=None,
        **kwargs,
    ):
        """Make a file-like object

        Parameters
        ----------
        path: str
            Full URL with protocol
        mode: string
            must be "rb"
        block_size: int or None
            Bytes to download in one request; use instance value if None. If
            zero, will return a streaming file-like instance.
        kwargs: key-value
            Any other parameters, passed to the HTTP requests
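
        Example (a sketch; the URL is a placeholder):

        >>> f = fs.open("https://server.com/data.bin", block_size=0)  # doctest: +SKIP
        >>> header = f.read(1024)  # streams from the start of the file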
        """
        if mode != "rb":
            raise NotImplementedError
        block_size = block_size if block_size is not None else self.block_size
        kw = self.kwargs.copy()
        kw["asynchronous"] = self.asynchronous
        kw.update(kwargs)
        info = {}
        if not size:
            # look up the size (and other details) from the server
            info = self.info(path, **kwargs)
            size = info["size"]
        session = sync(self.loop, self.set_session)
        if block_size and size and info.get("partial", True):
            return HTTPFile(
                self,
                path,
                session=session,
                block_size=block_size,
                mode=mode,
                size=size,
                cache_type=cache_type or self.cache_type,
                cache_options=cache_options or self.cache_options,
                loop=self.loop,
                **kw,
            )
        else:
            return HTTPStreamFile(
                self,
                path,
                mode=mode,
                loop=self.loop,
                session=session,
                **kw,
            )

    async def open_async(self, path, mode="rb", size=None, **kwargs):
        session = await self.set_session()
        if size is None:
            try:
                size = (await self._info(path, **kwargs))["size"]
            except FileNotFoundError:
                pass
        return AsyncStreamFile(
            self,
            path,
            loop=self.loop,
            session=session,
            size=size,
            **kwargs,
        )

    def ukey(self, url):
        """Unique identifier; assume HTTP files are static, unchanging"""
        return tokenize(url, self.kwargs, self.protocol)

    async def _info(self, url, **kwargs):
        """Get info of URL

        Tries to access location via HEAD, and then GET methods, but does
        not fetch the data.

        It is possible that the server does not supply any size information, in
        which case size will be given as None (and certain operations on the
        corresponding file will not work).
        """
        info = {}
        session = await self.set_session()

        for policy in ["head", "get"]:
            try:
                info.update(
                    await _file_info(
                        self.encode_url(url),
                        size_policy=policy,
                        session=session,
                        **self.kwargs,
                        **kwargs,
                    )
                )
                if info.get("size") is not None:
                    break
            except Exception as exc:
                if policy == "get":
                    # If get failed, then raise a FileNotFoundError
                    raise FileNotFoundError(url) from exc
                logger.debug("", exc_info=exc)

        return {"name": url, "size": None, **info, "type": "file"}

    async def _glob(self, path, maxdepth=None, **kwargs):
        """
        Find files by glob-matching.

        This implementation is identical to the one in AbstractFileSystem,
        but "?" is not considered a wildcard character, because it is so
        common in URLs, often identifying the "query" part.
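
        For example, ``fs.glob("https://server.com/files/*.csv")`` would match
        sibling CSV links, while a literal "?" in a URL is left untouched
        (a sketch; the URL is a placeholder).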
        """
        if maxdepth is not None and maxdepth < 1:
            raise ValueError("maxdepth must be at least 1")

        ends_with_slash = path.endswith("/")  # _strip_protocol strips trailing slash
        path = self._strip_protocol(path)
        append_slash_to_dirname = ends_with_slash or path.endswith(("/**", "/*"))
        idx_star = path.find("*") if path.find("*") >= 0 else len(path)
        idx_brace = path.find("[") if path.find("[") >= 0 else len(path)

        min_idx = min(idx_star, idx_brace)

        detail = kwargs.pop("detail", False)

        if not has_magic(path):
            if await self._exists(path, **kwargs):
                if not detail:
                    return [path]
                else:
                    return {path: await self._info(path, **kwargs)}
            else:
                if not detail:
                    return []  # glob of non-existent returns empty
                else:
                    return {}
        elif "/" in path[:min_idx]:
            min_idx = path[:min_idx].rindex("/")
            root = path[: min_idx + 1]
            depth = path[min_idx + 1 :].count("/") + 1
        else:
            root = ""
            depth = path[min_idx + 1 :].count("/") + 1

        if "**" in path:
            if maxdepth is not None:
                idx_double_stars = path.find("**")
                depth_double_stars = path[idx_double_stars:].count("/") + 1
                depth = depth - depth_double_stars + maxdepth
            else:
                depth = None

        allpaths = await self._find(
            root, maxdepth=depth, withdirs=True, detail=True, **kwargs
        )

        pattern = glob_translate(path + ("/" if ends_with_slash else ""))
        pattern = re.compile(pattern)

        out = {
            (
                p.rstrip("/")
                if not append_slash_to_dirname
                and info["type"] == "directory"
                and p.endswith("/")
                else p
            ): info
            for p, info in sorted(allpaths.items())
            if pattern.match(p.rstrip("/"))
        }

        if detail:
            return out
        else:
            return list(out)

    async def _isdir(self, path):
        # override, since all URLs are (also) files
        try:
            return bool(await self._ls(path))
        except (FileNotFoundError, ValueError):
            return False

    async def _pipe_file(self, path, value, mode="overwrite", **kwargs):
        """
        Write bytes to a remote file over HTTP.

        Parameters
        ----------
        path : str
            Target URL where the data should be written
        value : bytes
            Data to be written
        mode : str
            How to write to the file - only 'overwrite' is supported
        **kwargs : dict
            Additional parameters to pass to the HTTP request
        """
        if mode != "overwrite":
            raise NotImplementedError("Exclusive write")
        url = self._strip_protocol(path)
        headers = kwargs.pop("headers", {})
        headers["Content-Length"] = str(len(value))

        session = await self.set_session()

        async with session.put(url, data=value, headers=headers, **kwargs) as r:
            r.raise_for_status()


class HTTPFile(AbstractBufferedFile):
    """
    A file-like object pointing to a remote HTTP(S) resource

    Supports only reading, with read-ahead of a predetermined block-size.

    In the case that the server does not supply the filesize, only reading of
    the complete file in one go is supported.

    Parameters
    ----------
    url: str
        Full URL of the remote resource, including the protocol
    session: aiohttp.ClientSession or None
        All calls will be made within this session, to avoid restarting
        connections where the server allows this
    block_size: int or None
        The amount of read-ahead to do, in bytes. Default is 5MB, or the value
        configured for the FileSystem creating this file
    size: None or int
        If given, this is the size of the file in bytes, and we don't attempt
        to call the server to find the value.
    kwargs: all other key-values are passed to the HTTP requests.
    """

    def __init__(
        self,
        fs,
        url,
        session=None,
        block_size=None,
        mode="rb",
        cache_type="bytes",
        cache_options=None,
        size=None,
        loop=None,
        asynchronous=False,
        **kwargs,
    ):
        if mode != "rb":
            raise NotImplementedError("File mode not supported")
        self.asynchronous = asynchronous
        self.loop = loop
        self.url = url
        self.session = session
        self.details = {"name": url, "size": size, "type": "file"}
        super().__init__(
            fs=fs,
            path=url,
            mode=mode,
            block_size=block_size,
            cache_type=cache_type,
            cache_options=cache_options,
            **kwargs,
        )

    def read(self, length=-1):
        """Read bytes from file

        Parameters
        ----------
        length: int
            Read up to this many bytes. If negative, read all content to end of
            file. If the server has not supplied the filesize, attempting to
            read only part of the data will raise a ValueError.
        """
        if (
            (length < 0 and self.loc == 0)  # explicit read all
            # but not when the size is known and fits into a block anyways
            and not (self.size is not None and self.size <= self.blocksize)
        ):
            self._fetch_all()
        if self.size is None:
            if length < 0:
                self._fetch_all()
        else:
            length = min(self.size - self.loc, length)
        return super().read(length)

    async def async_fetch_all(self):
        """Read whole file in one shot, without caching

        This is only called when position is still at zero,
        and read() is called without a byte-count.
        """
        logger.debug(f"Fetch all for {self}")
        if not isinstance(self.cache, AllBytes):
            r = await self.session.get(self.fs.encode_url(self.url), **self.kwargs)
            async with r:
                r.raise_for_status()
                out = await r.read()
                self.cache = AllBytes(
                    size=len(out), fetcher=None, blocksize=None, data=out
                )
                self.size = len(out)

    _fetch_all = sync_wrapper(async_fetch_all)

    def _parse_content_range(self, headers):
        """Parse the Content-Range header"""
        s = headers.get("Content-Range", "")
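        # typical value: "bytes 0-1023/146515"; "*" stands for an unknown
        # range or total, e.g. "bytes */146515"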
        m = re.match(r"bytes (\d+-\d+|\*)/(\d+|\*)", s)
        if not m:
            return None, None, None

        if m[1] == "*":
            start = end = None
        else:
            start, end = [int(x) for x in m[1].split("-")]
        total = None if m[2] == "*" else int(m[2])
        return start, end, total

    async def async_fetch_range(self, start, end):
        """Download a block of data

        The expectation is that the server returns only the requested bytes,
        with HTTP code 206. If this is not the case, we first check the headers,
        and then stream the output - if the data size is bigger than we
        requested, an exception is raised.
        """
        logger.debug(f"Fetch range for {self}: {start}-{end}")
        kwargs = self.kwargs.copy()
        headers = kwargs.pop("headers", {}).copy()
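        # HTTP Range end offsets are inclusive, hence end - 1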
        headers["Range"] = f"bytes={start}-{end - 1}"
        logger.debug(f"{self.url} : {headers['Range']}")
        r = await self.session.get(
            self.fs.encode_url(self.url), headers=headers, **kwargs
        )
        async with r:
            if r.status == 416:
                # range request outside file
                return b""
            r.raise_for_status()

            # If the server has handled the range request, it should reply
            # with status 206 (partial content). But we'll guess that a suitable
            # Content-Range header or a Content-Length no more than the
            # requested range also mean we have got the desired range.
            response_is_range = (
                r.status == 206
                or self._parse_content_range(r.headers)[0] == start
                or int(r.headers.get("Content-Length", end + 1)) <= end - start
            )

            if response_is_range:
                # partial content, as expected
                out = await r.read()
            elif start > 0:
                raise ValueError(
                    "The HTTP server doesn't appear to support range requests. "
                    "Only reading this file from the beginning is supported. "
                    "Open with block_size=0 for a streaming file interface."
                )
            else:
                # Response is not a range, but we want the start of the file,
                # so we can read the required amount anyway.
                cl = 0
                out = []
                while True:
                    chunk = await r.content.read(2**20)
                    # data size unknown, let's read until we have enough
                    if chunk:
                        out.append(chunk)
                        cl += len(chunk)
                        if cl > end - start:
                            break
                    else:
                        break
                out = b"".join(out)[: end - start]
        return out

    _fetch_range = sync_wrapper(async_fetch_range)


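# Note: "?" is deliberately absent from this pattern (unlike fnmatch), since
# it usually introduces the query part of a URL; see HTTPFileSystem._glob.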
magic_check = re.compile("([*[])")


def has_magic(s):
    match = magic_check.search(s)
    return match is not None


class HTTPStreamFile(AbstractBufferedFile):
    def __init__(self, fs, url, mode="rb", loop=None, session=None, **kwargs):
        self.asynchronous = kwargs.pop("asynchronous", False)
        self.url = url
        self.loop = loop
        self.session = session
        if mode != "rb":
            raise ValueError("Streaming HTTP files must be opened in 'rb' mode")
        self.details = {"name": url, "size": None}
        super().__init__(fs=fs, path=url, mode=mode, cache_type="none", **kwargs)

        async def cor():
            r = await self.session.get(self.fs.encode_url(url), **kwargs).__aenter__()
            self.fs._raise_not_found_for_status(r, url)
            return r

        self.r = sync(self.loop, cor)
        self.loop = fs.loop

    def seek(self, loc, whence=0):
        if loc == 0 and whence == 1:
            return
        if loc == self.loc and whence == 0:
            return
        raise ValueError("Cannot seek streaming HTTP file")

    async def _read(self, num=-1):
        out = await self.r.content.read(num)
        self.loc += len(out)
        return out

    read = sync_wrapper(_read)

    async def _close(self):
        self.r.close()

    def close(self):
        asyncio.run_coroutine_threadsafe(self._close(), self.loop)
        super().close()


class AsyncStreamFile(AbstractAsyncStreamedFile):
    def __init__(
        self, fs, url, mode="rb", loop=None, session=None, size=None, **kwargs
    ):
        self.url = url
        self.session = session
        self.r = None
        if mode != "rb":
            raise ValueError("Streaming HTTP files must be opened in 'rb' mode")
        self.details = {"name": url, "size": None}
        self.kwargs = kwargs
        super().__init__(fs=fs, path=url, mode=mode, cache_type="none")
        self.size = size

    async def read(self, num=-1):
        if self.r is None:
            r = await self.session.get(
                self.fs.encode_url(self.url), **self.kwargs
            ).__aenter__()
            self.fs._raise_not_found_for_status(r, self.url)
            self.r = r
        out = await self.r.content.read(num)
        self.loc += len(out)
        return out

    async def close(self):
        if self.r is not None:
            self.r.close()
            self.r = None
        await super().close()


async def get_range(session, url, start, end, file=None, **kwargs):
    # explicitly get a range when we know it must be safe
    kwargs = kwargs.copy()
    headers = kwargs.pop("headers", {}).copy()
    headers["Range"] = f"bytes={start}-{end - 1}"
    r = await session.get(url, headers=headers, **kwargs)
    r.raise_for_status()
    async with r:
        out = await r.read()
    if file:
        with open(file, "r+b") as f:  # noqa: ASYNC230
            f.seek(start)
            f.write(out)
    else:
        return out


async def _file_info(url, session, size_policy="head", **kwargs):
    """Call HEAD on the server to get details about the file (size/checksum etc.)

    Default operation is to explicitly allow redirects and use encoding
    'identity' (no compression) to get the true size of the target.
    """
    logger.debug("Retrieve file size for %s", url)
    kwargs = kwargs.copy()
    ar = kwargs.pop("allow_redirects", True)
    head = kwargs.get("headers", {}).copy()
    head["Accept-Encoding"] = "identity"
    kwargs["headers"] = head

    info = {}
    if size_policy == "head":
        r = await session.head(url, allow_redirects=ar, **kwargs)
    elif size_policy == "get":
        r = await session.get(url, allow_redirects=ar, **kwargs)
    else:
        raise TypeError(f'size_policy must be "head" or "get", got {size_policy}')
    async with r:
        r.raise_for_status()

        if "Content-Length" in r.headers:
            # Some servers may choose to ignore Accept-Encoding and return
            # compressed content, in which case the returned size is unreliable.
            if "Content-Encoding" not in r.headers or r.headers[
                "Content-Encoding"
            ] in ["identity", ""]:
                info["size"] = int(r.headers["Content-Length"])
        elif "Content-Range" in r.headers:
            info["size"] = int(r.headers["Content-Range"].split("/")[1])

        if "Content-Type" in r.headers:
            info["mimetype"] = r.headers["Content-Type"].partition(";")[0]

        if r.headers.get("Accept-Ranges") == "none":
            # Some servers may explicitly discourage partial content requests, but
            # the lack of "Accept-Ranges" does not always indicate they would fail
            info["partial"] = False

        info["url"] = str(r.url)

        for checksum_field in ["ETag", "Content-MD5", "Digest"]:
            if r.headers.get(checksum_field):
                info[checksum_field] = r.headers[checksum_field]

    return info


async def _file_size(url, session=None, *args, **kwargs):
    if session is None:
        session = await get_client()
    info = await _file_info(url, *args, session=session, **kwargs)
    return info.get("size")


file_size = sync_wrapper(_file_size)