1import asyncio
2import io
3import logging
4import re
5import weakref
6from copy import copy
7from urllib.parse import urlparse
8
9import aiohttp
10import yarl
11
12from fsspec.asyn import AbstractAsyncStreamedFile, AsyncFileSystem, sync, sync_wrapper
13from fsspec.callbacks import DEFAULT_CALLBACK
14from fsspec.exceptions import FSTimeoutError
15from fsspec.spec import AbstractBufferedFile
16from fsspec.utils import (
17 DEFAULT_BLOCK_SIZE,
18 glob_translate,
19 isfilelike,
20 nullcontext,
21 tokenize,
22)
23
24from ..caching import AllBytes
25
# https://stackoverflow.com/a/15926317/3821154
# Link-discovery patterns used by HTTPFileSystem._ls_real:
# - ``ex`` matches HTML ``<a ... href="...">`` tags; ``findall`` yields
#   3-tuples (tag letter, attribute name, url), so the URL is item [2].
# - ``ex2`` matches bare http(s) URLs anywhere in the page text and is only
#   consulted when ``simple_links`` is enabled.
ex = re.compile(r"""<(a|A)\s+(?:[^>]*?\s+)?(href|HREF)=["'](?P<url>[^"']+)""")
ex2 = re.compile(r"""(?P<url>http[s]?://[-a-zA-Z0-9@:%_+.~#?&/=]+)""")
# Module logger; request URLs and fetch ranges are logged at DEBUG level.
logger = logging.getLogger("fsspec.http")
30
31
async def get_client(**kwargs):
    """Build the default aiohttp session used by HTTPFileSystem.

    All keyword arguments are forwarded unchanged to ``aiohttp.ClientSession``.
    """
    session = aiohttp.ClientSession(**kwargs)
    return session
34
35
class HTTPFileSystem(AsyncFileSystem):
    """
    Simple File-System for fetching data via HTTP(S)

    ``ls()`` is implemented by loading the parent page and doing a regex
    match on the result. If ``simple_links=True``, anything of the form
    "http(s)://server.com/stuff?thing=other" is also treated as a link;
    otherwise only links within HTML href tags will be used.
    """

    sep = "/"

    def __init__(
        self,
        simple_links=True,
        block_size=None,
        same_scheme=True,
        size_policy=None,
        cache_type="bytes",
        cache_options=None,
        asynchronous=False,
        loop=None,
        client_kwargs=None,
        get_client=get_client,
        encoded=False,
        **storage_options,
    ):
        """
        NB: if this is called async, you must await set_session

        Parameters
        ----------
        block_size: int
            Blocks to read bytes; if 0, will default to raw requests file-like
            objects instead of HTTPFile instances
        simple_links: bool
            If True, will consider both HTML <a> tags and anything that looks
            like a URL; if False, will consider only the former.
        same_scheme: bool
            When doing ls/glob, if this is True, only consider paths that have
            http/https matching the input URLs.
        size_policy: this argument is deprecated
        client_kwargs: dict
            Passed to aiohttp.ClientSession, see
            https://docs.aiohttp.org/en/stable/client_reference.html
            For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}``
        get_client: Callable[..., Awaitable[aiohttp.ClientSession]]
            An async callable which takes keyword arguments and constructs
            an aiohttp.ClientSession. Its state will be managed by
            the HTTPFileSystem class.
        storage_options: key-value
            Any other parameters; apart from the listings-cache options, they
            are passed on to every HTTP request made by this instance.
        cache_type, cache_options: defaults used in open
        """
        super().__init__(asynchronous=asynchronous, loop=loop, **storage_options)
        self.block_size = block_size if block_size is not None else DEFAULT_BLOCK_SIZE
        self.simple_links = simple_links
        # NOTE: the attribute is spelled "same_schema"; kept for compatibility.
        self.same_schema = same_scheme
        self.cache_type = cache_type
        self.cache_options = cache_options
        self.client_kwargs = client_kwargs or {}
        self.get_client = get_client
        self.encoded = encoded
        self._session = None

        # Clean caching-related parameters from `storage_options`
        # before propagating them as `request_options` through `self.kwargs`.
        # TODO: Maybe rename `self.kwargs` to `self.request_options` to make
        # it clearer.
        request_options = copy(storage_options)
        self.use_listings_cache = request_options.pop("use_listings_cache", False)
        request_options.pop("listings_expiry_time", None)
        request_options.pop("max_paths", None)
        request_options.pop("skip_instance_cache", None)
        self.kwargs = request_options

    @property
    def fsid(self):
        """Constant identifier for this filesystem type."""
        return "http"

    def encode_url(self, url):
        """Wrap ``url`` in a ``yarl.URL``, honouring the ``encoded`` option."""
        return yarl.URL(url, encoded=self.encoded)

    @staticmethod
    def close_session(loop, session):
        """Best-effort close of ``session``, used as a finalizer.

        If ``loop`` is running, ``session.close`` is run on it with a short
        timeout; if that is not possible (or times out), the session's
        connector is closed directly.
        """
        if loop is not None and loop.is_running():
            try:
                sync(loop, session.close, timeout=0.1)
                return
            except (TimeoutError, FSTimeoutError, NotImplementedError):
                pass
        connector = getattr(session, "_connector", None)
        if connector is not None:
            # close after loop is dead
            connector._close()

    async def set_session(self):
        """Return the client session, creating it on first use.

        For synchronous instances a finalizer is registered so the session is
        closed when this filesystem object is garbage collected.
        """
        if self._session is None:
            self._session = await self.get_client(loop=self.loop, **self.client_kwargs)
            if not self.asynchronous:
                weakref.finalize(self, self.close_session, self.loop, self._session)
        return self._session

    @classmethod
    def _strip_protocol(cls, path):
        """For HTTP, we always want to keep the full URL"""
        return path

    @classmethod
    def _parent(cls, path):
        """Parent URL of ``path``, or "" if only the scheme would remain."""
        # override, since _strip_protocol is different for URLs
        par = super()._parent(path)
        if len(par) > 7:  # "http://..."
            return par
        return ""

    async def _ls_real(self, url, detail=True, **kwargs):
        """Fetch ``url`` and return the links found in its body.

        Per-call ``kwargs`` are accepted for signature compatibility but are
        deliberately not forwarded: listing callers (find/walk/glob) may pass
        options that are not valid request arguments. Only the instance-level
        request options in ``self.kwargs`` are used.

        Returns a sorted list of URLs, or (``detail=True``) a list of info
        dicts whose ``type`` is "directory" for URLs ending in "/".
        """
        logger.debug(url)
        session = await self.set_session()
        async with session.get(self.encode_url(url), **self.kwargs) as r:
            self._raise_not_found_for_status(r, url)
            try:
                text = await r.text()
            except UnicodeDecodeError:
                text = None  # binary, not HTML
        # The response is released before parsing and before any recursion.
        if text is None:
            links = []
        elif self.simple_links:
            links = ex2.findall(text) + [u[2] for u in ex.findall(text)]
        else:
            links = [u[2] for u in ex.findall(text)]

        out = set()
        parts = urlparse(url)
        prefix = url.rstrip("/") + "/"
        scheme_agnostic_prefix = url.replace("https", "http").rstrip("/") + "/"
        for link in links:
            if isinstance(link, tuple):
                link = link[1]
            if link.startswith("/") and len(link) > 1:
                # absolute URL on this server
                link = f"{parts.scheme}://{parts.netloc}{link}"
            if link.startswith("http"):
                if self.same_schema and link.startswith(prefix):
                    out.add(link)
                elif link.replace("https", "http").startswith(scheme_agnostic_prefix):
                    # allowed to cross http <-> https
                    # NOTE(review): this branch is also reached when
                    # same_schema is True, so cross-scheme links are kept
                    # regardless of ``same_scheme``; confirm before tightening.
                    out.add(link)
            elif link not in ["..", "../"]:
                # Ignore FTP-like "parent"
                out.add("/".join([url.rstrip("/"), link.lstrip("/")]))
        if not out and url.endswith("/"):
            out = await self._ls_real(url.rstrip("/"), detail=False)
        if detail:
            return [
                {
                    "name": u,
                    "size": None,
                    "type": "directory" if u.endswith("/") else "file",
                }
                for u in sorted(out)
            ]
        return sorted(out)

    async def _ls(self, url, detail=True, **kwargs):
        """List links on the page at ``url``, using the listings cache if enabled.

        The cache always stores the detailed listing, so a cached entry serves
        both ``detail=True`` and ``detail=False`` callers with the right shape.
        """
        if self.use_listings_cache and url in self.dircache:
            listing = self.dircache[url]
        else:
            listing = await self._ls_real(url, detail=True, **kwargs)
            self.dircache[url] = listing
        if detail:
            return listing
        return sorted(entry["name"] for entry in listing)

    ls = sync_wrapper(_ls)

    def _raise_not_found_for_status(self, response, url):
        """
        Raises FileNotFoundError for 404s, otherwise uses raise_for_status.
        """
        if response.status == 404:
            raise FileNotFoundError(url)
        response.raise_for_status()

    async def _cat_file(self, url, start=None, end=None, **kwargs):
        """Return the bytes of ``url``, optionally limited to ``[start, end)``."""
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(url)

        if start is not None or end is not None:
            if start == end:
                return b""
            headers = (kw.pop("headers", None) or {}).copy()
            headers["Range"] = await self._process_limits(url, start, end)
            kw["headers"] = headers
        session = await self.set_session()
        async with session.get(self.encode_url(url), **kw) as r:
            # Check the status first so an error body is not downloaded.
            self._raise_not_found_for_status(r, url)
            return await r.read()

    async def _get_file(
        self, rpath, lpath, chunk_size=5 * 2**20, callback=DEFAULT_CALLBACK, **kwargs
    ):
        """Stream ``rpath`` into ``lpath`` (a local path or open file-like)."""
        kw = self.kwargs.copy()
        kw.update(kwargs)
        logger.debug(rpath)
        session = await self.set_session()
        async with session.get(self.encode_url(rpath), **kw) as r:
            try:
                size = int(r.headers["content-length"])
            except (ValueError, KeyError):
                size = None

            callback.set_size(size)
            self._raise_not_found_for_status(r, rpath)
            if isfilelike(lpath):
                outfile = lpath
            else:
                outfile = open(lpath, "wb")  # noqa: ASYNC230

            try:
                chunk = True
                while chunk:
                    chunk = await r.content.read(chunk_size)
                    outfile.write(chunk)
                    callback.relative_update(len(chunk))
            finally:
                # Only close files this method opened itself.
                if not isfilelike(lpath):
                    outfile.close()

    async def _put_file(
        self,
        lpath,
        rpath,
        chunk_size=5 * 2**20,
        callback=DEFAULT_CALLBACK,
        method="post",
        mode="overwrite",
        **kwargs,
    ):
        """Upload ``lpath`` (path or ``io.IOBase``) to ``rpath`` via POST or PUT.

        Raises
        ------
        NotImplementedError
            If ``mode`` is not "overwrite".
        ValueError
            If ``method`` is not "post" or "put" (case-insensitive).
        """
        if mode != "overwrite":
            raise NotImplementedError("Exclusive write")

        # Validate before any session is created.
        method = method.lower()
        if method not in ("post", "put"):
            raise ValueError(
                f"method has to be either 'post' or 'put', not: {method!r}"
            )

        async def gen_chunks():
            # Support passing arbitrary file-like objects
            # and use them instead of streams.
            if isinstance(lpath, io.IOBase):
                context = nullcontext(lpath)
                use_seek = False  # might not support seeking
            else:
                context = open(lpath, "rb")  # noqa: ASYNC230
                use_seek = True

            with context as f:
                if use_seek:
                    callback.set_size(f.seek(0, 2))
                    f.seek(0)
                else:
                    callback.set_size(getattr(f, "size", None))

                chunk = f.read(chunk_size)
                while chunk:
                    yield chunk
                    callback.relative_update(len(chunk))
                    chunk = f.read(chunk_size)

        kw = self.kwargs.copy()
        kw.update(kwargs)
        session = await self.set_session()

        meth = getattr(session, method)
        async with meth(self.encode_url(rpath), data=gen_chunks(), **kw) as resp:
            self._raise_not_found_for_status(resp, rpath)

    async def _exists(self, path, **kwargs):
        """True if a GET on ``path`` returns a status below 400."""
        kw = self.kwargs.copy()
        kw.update(kwargs)
        try:
            logger.debug(path)
            session = await self.set_session()
            r = await session.get(self.encode_url(path), **kw)
            async with r:
                return r.status < 400
        except aiohttp.ClientError:
            return False

    async def _isfile(self, path, **kwargs):
        """Any reachable URL is considered a file."""
        return await self._exists(path, **kwargs)

    def _open(
        self,
        path,
        mode="rb",
        block_size=None,
        autocommit=None,  # XXX: This differs from the base class.
        cache_type=None,
        cache_options=None,
        size=None,
        **kwargs,
    ):
        """Make a file-like object

        Parameters
        ----------
        path: str
            Full URL with protocol
        mode: string
            must be "rb"
        block_size: int or None
            Bytes to download in one request; use instance value if None. If
            zero, will return a streaming Requests file-like instance.
        kwargs: key-value
            Any other parameters, passed to requests calls
        """
        if mode != "rb":
            raise NotImplementedError
        block_size = block_size if block_size is not None else self.block_size
        kw = self.kwargs.copy()
        kw["asynchronous"] = self.asynchronous
        kw.update(kwargs)
        info = {}
        if not size:
            # Ask the server; keep the full info so the "partial" hint
            # (whether range requests are discouraged) can be consulted.
            info.update(self.info(path, **kwargs))
            size = info["size"]
        session = sync(self.loop, self.set_session)
        if block_size and size and info.get("partial", True):
            return HTTPFile(
                self,
                path,
                session=session,
                block_size=block_size,
                mode=mode,
                size=size,
                cache_type=cache_type or self.cache_type,
                cache_options=cache_options or self.cache_options,
                loop=self.loop,
                **kw,
            )
        return HTTPStreamFile(
            self,
            path,
            mode=mode,
            loop=self.loop,
            session=session,
            **kw,
        )

    async def open_async(self, path, mode="rb", size=None, **kwargs):
        """Open ``path`` as an async streaming file; size is looked up if not given."""
        session = await self.set_session()
        if size is None:
            try:
                size = (await self._info(path, **kwargs))["size"]
            except FileNotFoundError:
                pass
        return AsyncStreamFile(
            self,
            path,
            loop=self.loop,
            session=session,
            size=size,
            **kwargs,
        )

    def ukey(self, url):
        """Unique identifier; assume HTTP files are static, unchanging"""
        return tokenize(url, self.kwargs, self.protocol)

    async def _info(self, url, **kwargs):
        """Get info of URL

        Tries to access location via HEAD, and then GET methods, but does
        not fetch the data.

        It is possible that the server does not supply any size information, in
        which case size will be given as None (and certain operations on the
        corresponding file will not work).

        Per-call ``kwargs`` override instance-level request options with the
        same name (e.g. ``headers``) instead of colliding with them.
        """
        info = {}
        kw = self.kwargs.copy()
        kw.update(kwargs)
        session = await self.set_session()

        for policy in ["head", "get"]:
            try:
                info.update(
                    await _file_info(
                        self.encode_url(url),
                        size_policy=policy,
                        session=session,
                        **kw,
                    )
                )
                if info.get("size") is not None:
                    break
            except Exception as exc:
                if policy == "get":
                    # If get failed, then raise a FileNotFoundError
                    raise FileNotFoundError(url) from exc
                logger.debug("", exc_info=exc)

        return {"name": url, "size": None, **info, "type": "file"}

    async def _glob(self, path, maxdepth=None, **kwargs):
        """
        Find files by glob-matching.

        This implementation is identical to the one in AbstractFileSystem,
        but "?" is not considered as a character for globbing, because it is
        so common in URLs, often identifying the "query" part.
        """
        if maxdepth is not None and maxdepth < 1:
            raise ValueError("maxdepth must be at least 1")

        ends_with_slash = path.endswith("/")  # _strip_protocol strips trailing slash
        path = self._strip_protocol(path)
        append_slash_to_dirname = ends_with_slash or path.endswith(("/**", "/*"))
        idx_star = path.find("*") if path.find("*") >= 0 else len(path)
        idx_brace = path.find("[") if path.find("[") >= 0 else len(path)

        min_idx = min(idx_star, idx_brace)

        detail = kwargs.pop("detail", False)

        if not has_magic(path):
            if await self._exists(path, **kwargs):
                if not detail:
                    return [path]
                else:
                    return {path: await self._info(path, **kwargs)}
            else:
                if not detail:
                    return []  # glob of non-existent returns empty
                else:
                    return {}
        elif "/" in path[:min_idx]:
            min_idx = path[:min_idx].rindex("/")
            root = path[: min_idx + 1]
            depth = path[min_idx + 1 :].count("/") + 1
        else:
            root = ""
            depth = path[min_idx + 1 :].count("/") + 1

        if "**" in path:
            if maxdepth is not None:
                idx_double_stars = path.find("**")
                depth_double_stars = path[idx_double_stars:].count("/") + 1
                depth = depth - depth_double_stars + maxdepth
            else:
                depth = None

        allpaths = await self._find(
            root, maxdepth=depth, withdirs=True, detail=True, **kwargs
        )

        pattern = glob_translate(path + ("/" if ends_with_slash else ""))
        pattern = re.compile(pattern)

        out = {
            (
                p.rstrip("/")
                if not append_slash_to_dirname
                and info["type"] == "directory"
                and p.endswith("/")
                else p
            ): info
            for p, info in sorted(allpaths.items())
            if pattern.match(p.rstrip("/"))
        }

        if detail:
            return out
        else:
            return list(out)

    async def _isdir(self, path):
        """True if listing ``path`` yields at least one link."""
        # override, since all URLs are (also) files
        try:
            return bool(await self._ls(path))
        except (FileNotFoundError, ValueError):
            return False

    async def _pipe_file(self, path, value, mode="overwrite", **kwargs):
        """
        Write bytes to a remote file over HTTP with a PUT request.

        Parameters
        ----------
        path : str
            Target URL where the data should be written
        value : bytes
            Data to be written
        mode : str
            Only "overwrite" is supported; a PUT always replaces the target
        **kwargs : dict
            Additional parameters to pass to the HTTP request; they override
            the instance-level request options

        Raises
        ------
        NotImplementedError
            If ``mode`` is anything other than "overwrite".
        """
        if mode != "overwrite":
            raise NotImplementedError(
                f"mode {mode!r} is not supported; only 'overwrite' is"
            )
        url = self._strip_protocol(path)
        kw = self.kwargs.copy()
        kw.update(kwargs)
        # Copy so the caller's (or instance's) headers dict is not mutated.
        headers = (kw.pop("headers", None) or {}).copy()
        headers["Content-Length"] = str(len(value))

        session = await self.set_session()

        async with session.put(
            self.encode_url(url), data=value, headers=headers, **kw
        ) as r:
            self._raise_not_found_for_status(r, url)
548
549
class HTTPFile(AbstractBufferedFile):
    """
    Read-only file-like object backed by a remote HTTP(S) resource.

    Data is read ahead in blocks of a predetermined size using ranged GET
    requests. When the server did not supply the file size, only reading
    the complete file in one go is supported.

    Parameters
    ----------
    url: str
        Full URL of the remote resource, including the protocol
    session: aiohttp.ClientSession or None
        Session through which every request of this file is made, so that
        connections can be reused where the server allows it
    block_size: int or None
        The amount of read-ahead to do, in bytes. Default is 5MB, or the value
        configured for the FileSystem creating this file
    size: None or int
        Size of the file in bytes, if already known; it is recorded in
        ``details`` so the server need not be asked for it.
    kwargs: all other key-values are passed to requests calls.
    """

    def __init__(
        self,
        fs,
        url,
        session=None,
        block_size=None,
        mode="rb",
        cache_type="bytes",
        cache_options=None,
        size=None,
        loop=None,
        asynchronous=False,
        **kwargs,
    ):
        if mode != "rb":
            raise NotImplementedError("File mode not supported")
        self.url = url
        self.session = session
        self.loop = loop
        self.asynchronous = asynchronous
        # Known metadata must be in place before the base initializer runs.
        self.details = {"name": url, "size": size, "type": "file"}
        super().__init__(
            fs=fs,
            path=url,
            mode=mode,
            block_size=block_size,
            cache_type=cache_type,
            cache_options=cache_options,
            **kwargs,
        )

    def read(self, length=-1):
        """Read up to ``length`` bytes; a negative length means "to the end".

        A read-everything call made at position 0 downloads the whole body in
        a single request, unless the size is known and fits in one block.
        When the server did not report a size, only whole-file reads work.
        """
        if length < 0 and self.loc == 0:
            fits_in_one_block = self.size is not None and self.size <= self.blocksize
            if not fits_in_one_block:
                self._fetch_all()
        # self.size may have just been filled in by _fetch_all.
        if self.size is not None:
            length = min(self.size - self.loc, length)
        elif length < 0:
            self._fetch_all()
        return super().read(length)

    async def async_fetch_all(self):
        """Download the entire body in one request and cache it as AllBytes.

        Does nothing if the cache already holds the whole file.
        """
        logger.debug(f"Fetch all for {self}")
        if isinstance(self.cache, AllBytes):
            return
        r = await self.session.get(self.fs.encode_url(self.url), **self.kwargs)
        async with r:
            r.raise_for_status()
            body = await r.read()
            self.cache = AllBytes(
                size=len(body), fetcher=None, blocksize=None, data=body
            )
            self.size = len(body)

    _fetch_all = sync_wrapper(async_fetch_all)

    def _parse_content_range(self, headers):
        """Return ``(start, end, total)`` from Content-Range; None for unknowns."""
        found = re.match(
            r"bytes (\d+-\d+|\*)/(\d+|\*)", headers.get("Content-Range", "")
        )
        if found is None:
            return None, None, None

        if found[1] == "*":
            first = last = None
        else:
            first, last = (int(piece) for piece in found[1].split("-"))
        total = int(found[2]) if found[2] != "*" else None
        return first, last, total

    async def async_fetch_range(self, start, end):
        """Download bytes ``[start, end)`` of the resource.

        The server is expected to answer with 206 Partial Content. Failing
        that, a Content-Range header starting at ``start`` or a Content-Length
        no larger than the requested span is also accepted as a ranged reply.
        A full-body reply is tolerated only when ``start`` is 0, in which case
        just enough of the stream is read; otherwise ValueError is raised.
        A 416 (range not satisfiable) reply yields ``b""``.
        """
        logger.debug(f"Fetch range for {self}: {start}-{end}")
        request_kwargs = self.kwargs.copy()
        headers = request_kwargs.pop("headers", {}).copy()
        headers["Range"] = f"bytes={start}-{end - 1}"
        logger.debug(f"{self.url} : {headers['Range']}")
        r = await self.session.get(
            self.fs.encode_url(self.url), headers=headers, **request_kwargs
        )
        async with r:
            if r.status == 416:
                # range request outside file
                return b""
            r.raise_for_status()

            wanted = end - start
            is_ranged_reply = (
                r.status == 206
                or self._parse_content_range(r.headers)[0] == start
                or int(r.headers.get("Content-Length", end + 1)) <= wanted
            )
            if is_ranged_reply:
                return await r.read()
            if start > 0:
                raise ValueError(
                    "The HTTP server doesn't appear to support range requests. "
                    "Only reading this file from the beginning is supported. "
                    "Open with block_size=0 for a streaming file interface."
                )

            # Full body, but the caller wants its beginning: stream just past
            # the requested amount, then trim.
            pieces = []
            received = 0
            while True:
                piece = await r.content.read(2**20)
                if not piece:
                    break
                pieces.append(piece)
                received += len(piece)
                if received > wanted:
                    break
            return b"".join(pieces)[:wanted]

    _fetch_range = sync_wrapper(async_fetch_range)
722
723
# Glob metacharacters recognised for URLs; "?" is intentionally absent since it
# usually introduces a query string.
magic_check = re.compile("([*[])")


def has_magic(s):
    """Return True when ``s`` contains a glob metacharacter ("*" or "[")."""
    return bool(magic_check.search(s))
730
731
class HTTPStreamFile(AbstractBufferedFile):
    """Forward-only, read-only stream over the body of a single GET response.

    The request is sent when the object is created; ``read`` then consumes
    the response body. Only seeks that leave the position unchanged are
    permitted.
    """

    def __init__(self, fs, url, mode="rb", loop=None, session=None, **kwargs):
        self.asynchronous = kwargs.pop("asynchronous", False)
        self.url = url
        self.loop = loop
        self.session = session
        if mode != "rb":
            raise ValueError
        self.details = {"name": url, "size": None}
        super().__init__(fs=fs, path=url, mode=mode, cache_type="none", **kwargs)

        async def cor():
            r = await self.session.get(self.fs.encode_url(url), **kwargs).__aenter__()
            try:
                self.fs._raise_not_found_for_status(r, url)
            except Exception:
                # The response was entered manually and nothing else will
                # reference it, so release its connection before re-raising.
                r.close()
                raise
            return r

        self.r = sync(self.loop, cor)
        self.loop = fs.loop

    def seek(self, loc, whence=0):
        """Accept only no-op seeks and return the (unchanged) position.

        Raises
        ------
        ValueError
            For any seek that would move the position.
        """
        if loc == 0 and whence == 1:
            return self.loc
        if loc == self.loc and whence == 0:
            return self.loc
        raise ValueError("Cannot seek streaming HTTP file")

    async def _read(self, num=-1):
        """Read up to ``num`` bytes (all remaining if negative) from the body."""
        out = await self.r.content.read(num)
        self.loc += len(out)
        return out

    read = sync_wrapper(_read)

    async def _close(self):
        # ``self.r`` is missing if __init__ failed before the request finished.
        r = getattr(self, "r", None)
        if r is not None:
            r.close()

    def close(self):
        """Schedule release of the response on the loop, then close the file."""
        asyncio.run_coroutine_threadsafe(self._close(), self.loop)
        super().close()
771
772
class AsyncStreamFile(AbstractAsyncStreamedFile):
    """Async, forward-only, read-only stream over the body of a GET response.

    The request is sent lazily on the first ``read``; extra keyword
    arguments given at construction are passed to that request.
    """

    def __init__(
        self, fs, url, mode="rb", loop=None, session=None, size=None, **kwargs
    ):
        self.url = url
        self.session = session
        self.r = None
        if mode != "rb":
            raise ValueError
        self.details = {"name": url, "size": None}
        super().__init__(fs=fs, path=url, mode=mode, cache_type="none")
        # NOTE: assigned after super().__init__(), which sets self.kwargs from
        # its own (here empty) keyword arguments and would otherwise discard
        # the request options (headers, auth, ...) given to this file.
        self.kwargs = kwargs
        self.size = size

    async def read(self, num=-1):
        """Read up to ``num`` bytes (all remaining if negative).

        Raises
        ------
        FileNotFoundError
            On a 404 reply to the initial request; other HTTP errors propagate
            from ``raise_for_status``.
        """
        if self.r is None:
            r = await self.session.get(
                self.fs.encode_url(self.url), **self.kwargs
            ).__aenter__()
            try:
                self.fs._raise_not_found_for_status(r, self.url)
            except Exception:
                # Release the connection of the manually-entered response.
                r.close()
                raise
            self.r = r
        out = await self.r.content.read(num)
        self.loc += len(out)
        return out

    async def close(self):
        """Release the response, if one was opened, and close the file."""
        if self.r is not None:
            self.r.close()
            self.r = None
        await super().close()
803
804
async def get_range(session, url, start, end, file=None, **kwargs):
    """Fetch bytes ``[start, end)`` of ``url`` with an explicit Range request.

    Intended for when the server is known to honour range requests.

    Parameters
    ----------
    session: aiohttp.ClientSession
    url: str or yarl.URL
    start, end: int
        Half-open byte range to request.
    file: str, optional
        If given, an existing local file into which the data is written at
        offset ``start``.
    kwargs:
        Passed to ``session.get``; a ``headers`` mapping is copied, not mutated.

    Returns
    -------
    bytes or None
        The downloaded bytes, or None when they were written to ``file``.
    """
    headers = (kwargs.pop("headers", None) or {}).copy()
    headers["Range"] = f"bytes={start}-{end - 1}"
    r = await session.get(url, headers=headers, **kwargs)
    async with r:
        # Check the status inside the context so the response is released
        # even when the request failed.
        r.raise_for_status()
        out = await r.read()
    if not file:
        return out
    with open(file, "r+b") as f:  # noqa: ASYNC230
        f.seek(start)
        f.write(out)
    return None
820
821
async def _file_info(url, session, size_policy="head", **kwargs):
    """Call HEAD (or GET) on the server to get details about the file.

    Default operation is to explicitly allow redirects and use encoding
    'identity' (no compression) to get the true size of the target.

    Parameters
    ----------
    url: str or yarl.URL
    session: aiohttp.ClientSession
    size_policy: "head" or "get"
        Which request method to use; the body of a GET is not read.
    kwargs:
        Passed to the request; ``allow_redirects`` defaults to True and any
        ``headers`` mapping is copied before ``Accept-Encoding`` is set.

    Returns
    -------
    dict
        May contain "size", "mimetype", "partial" (False when the server
        sends ``Accept-Ranges: none``), "url" (after redirects) and any of
        the "ETag", "Content-MD5" and "Digest" headers.

    Raises
    ------
    TypeError
        If ``size_policy`` is not "head" or "get".
    """
    logger.debug("Retrieve file size for %s", url)
    kwargs = kwargs.copy()
    ar = kwargs.pop("allow_redirects", True)
    head = (kwargs.get("headers") or {}).copy()
    head["Accept-Encoding"] = "identity"
    kwargs["headers"] = head

    info = {}
    if size_policy == "head":
        r = await session.head(url, allow_redirects=ar, **kwargs)
    elif size_policy == "get":
        r = await session.get(url, allow_redirects=ar, **kwargs)
    else:
        raise TypeError(f'size_policy must be "head" or "get", got {size_policy}')
    async with r:
        r.raise_for_status()

        if "Content-Length" in r.headers:
            # Some servers may choose to ignore Accept-Encoding and return
            # compressed content, in which case the returned size is unreliable.
            if r.headers.get("Content-Encoding", "") in ("identity", ""):
                info["size"] = int(r.headers["Content-Length"])
        elif "Content-Range" in r.headers:
            # Format is "bytes <range>/<total>"; <total> is "*" when the
            # server does not know the full length. Report no size then (or
            # for a malformed header) instead of failing the whole lookup.
            try:
                info["size"] = int(r.headers["Content-Range"].rpartition("/")[2])
            except ValueError:
                pass

        if "Content-Type" in r.headers:
            info["mimetype"] = r.headers["Content-Type"].partition(";")[0]

        if r.headers.get("Accept-Ranges") == "none":
            # Some servers may explicitly discourage partial content requests, but
            # the lack of "Accept-Ranges" does not always indicate they would fail
            info["partial"] = False

        info["url"] = str(r.url)

        for checksum_field in ["ETag", "Content-MD5", "Digest"]:
            if r.headers.get(checksum_field):
                info[checksum_field] = r.headers[checksum_field]

    return info
871
872
async def _file_size(url, session=None, *args, **kwargs):
    """Return the size in bytes reported by the server for ``url``, or None.

    Extra positional arguments follow ``_file_info``'s signature after
    ``session`` (i.e. ``size_policy``). If no session is given, a temporary
    one is created and closed again before returning.
    """
    if session is not None:
        info = await _file_info(url, session, *args, **kwargs)
        return info.get("size")

    own_session = await get_client()
    try:
        info = await _file_info(url, own_session, *args, **kwargs)
    finally:
        # Avoid leaking the session (and its connector) this call created.
        await own_session.close()
    return info.get("size")
878
879
def file_size(url, session=None, *args, **kwargs):
    """Blocking version of ``_file_size``: size of ``url`` in bytes, or None.

    ``sync_wrapper`` cannot be used here: it takes the event loop from the
    first positional argument's ``.loop``, which for this module-level
    function would be the URL string. fsspec's default IO loop is used instead.
    """
    from fsspec.asyn import get_loop

    return sync(get_loop(), _file_size, url, session, *args, **kwargs)