1# Copyright 2021, New York University and the TUF contributors
2# SPDX-License-Identifier: MIT OR Apache-2.0
3
4"""Provides an interface for network IO abstraction."""
5
6# Imports
7import abc
8import logging
9import tempfile
10from collections.abc import Iterator
11from contextlib import contextmanager
12from typing import IO
13
14from tuf.api import exceptions
15
16logger = logging.getLogger(__name__)
17
18
19# Classes
class FetcherInterface(metaclass=abc.ABCMeta):
    """Defines an interface for abstract network download.

    By providing a concrete implementation of the abstract interface,
    users of the framework can plug-in their preferred/customized
    network stack.

    Implementations of FetcherInterface only need to implement ``_fetch()``.
    The public API of the class is already implemented.
    """

    @abc.abstractmethod
    def _fetch(self, url: str) -> Iterator[bytes]:
        """Fetch the contents of HTTP/HTTPS ``url`` from a remote server.

        Implementations must raise ``DownloadHTTPError`` if they receive
        an HTTP error code.

        Implementations may raise any errors but the ones that are not
        ``DownloadErrors`` will be wrapped in a ``DownloadError`` by
        ``fetch()``.

        Args:
            url: URL string that represents a file location.

        Raises:
            exceptions.DownloadHTTPError: HTTP error code was received.

        Returns:
            Bytes iterator
        """
        raise NotImplementedError  # pragma: no cover

    def fetch(self, url: str) -> Iterator[bytes]:
        """Fetch the contents of HTTP/HTTPS ``url`` from a remote server.

        Args:
            url: URL string that represents a file location.

        Raises:
            exceptions.DownloadError: An error occurred during download.
            exceptions.DownloadHTTPError: An HTTP error code was received.

        Returns:
            Bytes iterator
        """
        # Ensure that fetch() only raises DownloadErrors, regardless of the
        # fetcher implementation
        try:
            return self._fetch(url)
        except exceptions.DownloadError:
            # Already the contract's exception type: re-raise as-is with a
            # bare `raise` to preserve the original traceback untouched.
            raise
        except Exception as e:
            # Wrap any implementation-specific error, chaining the cause so
            # debuggers still see the underlying failure.
            raise exceptions.DownloadError(f"Failed to download {url}") from e

    @contextmanager
    def download_file(self, url: str, max_length: int) -> Iterator[IO[bytes]]:
        """Download file from given ``url``.

        It is recommended to use ``download_file()`` within a ``with``
        block to guarantee that allocated file resources will always
        be released even if download fails.

        Args:
            url: URL string that represents the location of the file.
            max_length: Upper bound of file size in bytes.

        Raises:
            exceptions.DownloadError: An error occurred during download.
            exceptions.DownloadLengthMismatchError: Downloaded bytes exceed
                ``max_length``.
            exceptions.DownloadHTTPError: An HTTP error code was received.

        Yields:
            ``TemporaryFile`` object (binary mode) that points to the
            contents of ``url``.
        """
        logger.debug("Downloading: %s", url)

        number_of_bytes_received = 0

        # TemporaryFile is deleted automatically when the context exits,
        # even if the download raises partway through.
        with tempfile.TemporaryFile() as temp_file:
            chunks = self.fetch(url)
            for chunk in chunks:
                number_of_bytes_received += len(chunk)
                # Enforce the length bound incrementally so an oversized
                # response is aborted as soon as it exceeds max_length.
                if number_of_bytes_received > max_length:
                    raise exceptions.DownloadLengthMismatchError(
                        f"Downloaded {number_of_bytes_received} bytes exceeding"
                        f" the maximum allowed length of {max_length}"
                    )

                temp_file.write(chunk)

            logger.debug(
                "Downloaded %d out of %d bytes",
                number_of_bytes_received,
                max_length,
            )

            # Rewind so the caller can read the file from the beginning.
            temp_file.seek(0)
            yield temp_file

    def download_bytes(self, url: str, max_length: int) -> bytes:
        """Download bytes from given ``url``.

        Returns the downloaded bytes, otherwise like ``download_file()``.

        Args:
            url: URL string that represents the location of the file.
            max_length: Upper bound of data size in bytes.

        Raises:
            exceptions.DownloadError: An error occurred during download.
            exceptions.DownloadLengthMismatchError: Downloaded bytes exceed
                ``max_length``.
            exceptions.DownloadHTTPError: An HTTP error code was received.

        Returns:
            Content of the file in bytes.
        """
        with self.download_file(url, max_length) as dl_file:
            return dl_file.read()