Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/upath/_chain.py: 24%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1from __future__ import annotations
3import sys
4import warnings
5from collections import defaultdict
6from collections import deque
7from collections.abc import Iterator
8from collections.abc import MutableMapping
9from collections.abc import Sequence
10from collections.abc import Set
11from itertools import zip_longest
12from typing import TYPE_CHECKING
13from typing import Any
14from typing import NamedTuple
16from upath._flavour import WrappedFileSystemFlavour
17from upath._protocol import get_upath_protocol
18from upath.registry import available_implementations
19from upath.types import UNSET_DEFAULT
21if TYPE_CHECKING:
22 if sys.version_info >= (3, 11):
23 from typing import Never
24 from typing import Self
25 else:
26 from typing_extensions import Never
27 from typing_extensions import Self
29__all__ = [
30 "ChainSegment",
31 "Chain",
32 "FSSpecChainParser",
33 "DEFAULT_CHAIN_PARSER",
34]
class ChainSegment(NamedTuple):
    """One link of a chained urlpath: a path, a protocol and its options."""

    # ``None`` marks a segment without a path of its own, i.e. one that
    # passes the path through to another segment (e.g. simplecache).
    path: str | None
    # Protocol name of the segment.
    protocol: str
    # Options stored alongside the protocol for this segment.
    storage_options: dict[str, Any]
class Chain:
    """Ordered chain segments together with a cursor marking the current one.

    The segments are stored as given. ``index`` selects the *current*
    segment; the *active path* is the path of the first segment at or after
    the cursor whose ``path`` is not ``None``.
    """

    __slots__ = ("_segments", "_index")

    def __init__(self, *segments: ChainSegment, index: int = 0) -> None:
        """Store ``segments`` and the cursor ``index``.

        Raises:
            ValueError: if ``index`` does not address one of ``segments``
                (which is always the case when no segments are given).
        """
        index_is_valid = 0 <= index < len(segments)
        if not index_is_valid:
            raise ValueError("index must be between 0 and len(segments)")
        self._segments = segments
        self._index = index

    def __repr__(self) -> str:
        rendered = ", ".join(repr(segment) for segment in self._segments)
        # The cursor is only shown when it differs from the default.
        suffix = f", index={self._index}" if self._index != 0 else ""
        return f"{type(self).__name__}({rendered}{suffix})"

    @property
    def current(self) -> ChainSegment:
        """The segment the cursor points at."""
        return self._segments[self._index]

    @property
    def _path_index(self) -> int:
        """Position of the first segment at/after the cursor that has a path.

        Raises:
            IndexError: if none of those segments carries a path.
        """
        for position in range(self._index, len(self._segments)):
            if self._segments[position].path is not None:
                return position
        raise IndexError("No target path found")

    @property
    def active_path(self) -> str:
        """Path of the segment located by ``_path_index``."""
        active = self._segments[self._path_index].path
        if active is None:
            raise RuntimeError
        return active

    @property
    def active_path_protocol(self) -> str:
        """Protocol of the segment located by ``_path_index``."""
        target = self._segments[self._path_index]
        return target.protocol

    def replace(
        self,
        *,
        path: str | None = None,
        protocol: str | None = None,
        storage_options: dict[str, Any] | None = None,
    ) -> Self:
        """Return a new chain with fields of the current segment replaced.

        ``protocol`` and ``storage_options`` are applied at the cursor
        position, ``path`` at the position given by ``_path_index``.
        Arguments left as ``None`` keep the existing values. The new chain
        is built from ``to_list()`` and keeps the same cursor.
        """
        flat = self.to_list()
        cursor = self._index

        # Collect the field changes per position so that each affected
        # segment is rebuilt exactly once.
        changes: MutableMapping[int, dict[str, Any]] = defaultdict(dict)
        cursor_fields = (
            ("protocol", protocol),
            ("storage_options", storage_options),
        )
        for field_name, value in cursor_fields:
            if value is not None:
                changes[cursor][field_name] = value
        if path is not None:
            changes[self._path_index]["path"] = path

        for position, fields in changes.items():
            flat[position] = flat[position]._replace(**fields)

        return type(self)(*flat, index=cursor)

    def to_list(self) -> list[ChainSegment]:
        """Return the segments as a list with nested ``target_*`` unpacked.

        A segment whose options contain both ``target_protocol`` and ``fo``
        is split into the segment itself (without the ``target_*``/``fo``
        keys) followed by the segment described by those keys, which is then
        processed the same way. A segment equal to the one just emitted is
        dropped.
        """
        pending = deque(self._segments)
        flat: list[ChainSegment] = []
        while pending:
            segment = pending.popleft()
            options = segment.storage_options
            if "target_protocol" in options and "fo" in options:
                own_options = options.copy()
                inner_options = own_options.pop("target_options", {})
                inner_protocol = own_options.pop("target_protocol")
                inner_fo = own_options.pop("fo")
                # The nested target is handled next, right after its parent.
                pending.appendleft(
                    ChainSegment(inner_fo, inner_protocol, inner_options)
                )
                flat.append(
                    ChainSegment(segment.path, segment.protocol, own_options)
                )
                continue
            if flat and segment == flat[-1]:
                continue
            flat.append(segment)
        return flat

    @classmethod
    def from_list(cls, segments: list[ChainSegment], index: int = 0) -> Self:
        """Build a chain from a list of segments and a cursor position."""
        return cls(*segments, index=index)

    def nest(self) -> ChainSegment:
        """Fold the chain into one segment with nested ``target_*`` options."""
        # see: fsspec.core.url_to_fs
        chain = self._segments
        outermost = len(chain) - 1
        nested: dict[str, Any] = {}
        # Segments without a path reuse the most recently seen path while
        # walking the chain from its end towards its start.
        carried_path = chain[-1].path
        for position, (urls, protocol, kw) in enumerate(reversed(chain)):
            if urls is None:
                urls = carried_path
            carried_path = urls
            merged = {**kw, **nested}
            if position == outermost:
                # The first segment of the chain holds the top-level options.
                nested = merged
            else:
                nested["target_options"] = merged
                nested["target_protocol"] = protocol
                nested["fo"] = urls  # codespell:ignore fo
        first_path, first_protocol, _ = chain[0]
        return ChainSegment(first_path, first_protocol, nested)
161def _iter_fileobject_protocol_options(
162 fileobject: str | None,
163 protocol: str,
164 storage_options: dict[str, Any],
165 /,
166) -> Iterator[tuple[str | None, str, dict[str, Any]]]:
167 """yields fileobject, protocol and remaining storage options"""
168 so = storage_options.copy()
169 while "target_protocol" in so:
170 t_protocol = so.pop("target_protocol", "")
171 t_fileobject = so.pop("fo", None) # codespell:ignore fo
172 t_so = so.pop("target_options", {})
173 yield fileobject, protocol, so
174 fileobject, protocol, so = t_fileobject, t_protocol, t_so
175 yield fileobject, protocol, so
class FSSpecChainParser:
    """Split fsspec chained urlpaths into segments and join them again."""

    def __init__(self) -> None:
        # Separator between the parts of a chained urlpath.
        self.link: str = "::"
        # Protocol names collected from ``available_implementations``;
        # filled lazily by ``unchain``.
        self.known_protocols: Set[str] = set()

    def unchain(
        self,
        path: str,
        _deprecated_storage_options: Never = UNSET_DEFAULT,
        /,
        *,
        protocol: str | None = None,
        storage_options: dict[str, Any] | None = None,
    ) -> list[ChainSegment]:
        """Split a chained urlpath into a list of chain segments.

        Implements the same behavior as ``fsspec.core._un_chain`` with two
        differences:

        1. the urlpath is set to ``None`` for upstream filesystems that
           pass the path through
        2. bare parts are checked against the known protocols for exact
           matches

        Passing storage options positionally is deprecated and emits a
        ``DeprecationWarning``.

        Raises:
            ValueError: if storage options are given both positionally and
                by keyword, or if a segment's protocol differs from the
                ``target_protocol`` found at the same level of the storage
                options.
        """
        if _deprecated_storage_options is not UNSET_DEFAULT:
            warnings.warn(
                "passing storage_options as positional argument is deprecated, "
                "pass as keyword argument instead",
                DeprecationWarning,
                stacklevel=2,
            )
            if storage_options is not None:
                raise ValueError(
                    "cannot pass storage_options both positionally and as keyword"
                )
            storage_options = _deprecated_storage_options
            protocol = protocol or storage_options.get("protocol")
        if storage_options is None:
            storage_options = {}

        def parse_with_flavour(
            part: str, hint: str | None
        ) -> tuple[str | None, str, dict[str, Any]]:
            # Let the protocol's flavour split the part into a stripped
            # path and the storage options encoded in the url.
            resolved = get_upath_protocol(part, protocol=hint)
            flavour = WrappedFileSystemFlavour.from_protocol(resolved)
            return (
                flavour.strip_protocol(part),
                resolved,
                flavour.get_kwargs_from_url(part),
            )

        segments: list[ChainSegment] = []
        path_bit: str | None
        # Path taken from a caching segment, waiting to be handed to the
        # next segment that is not a caching segment.
        carried_path: str | None = None
        # Only the first part is paired with the explicit protocol.
        for proto0, bit in zip_longest([protocol], path.split(self.link)):
            if "://" in bit or "/" in bit:
                # uri-like or path-like: skip the protocol-name lookups
                path_bit, proto, extra_so = parse_with_flavour(bit, proto0)
            elif bit in self.known_protocols and (
                proto0 is None or bit == proto0
            ):
                # exact match of an already known protocol name
                path_bit, proto, extra_so = None, bit, {}
            elif bit in (m := set(available_implementations(fallback=True))) and (
                proto0 is None or bit == proto0
            ):
                # exact match after refreshing the set of protocol names
                self.known_protocols = m
                path_bit, proto, extra_so = None, bit, {}
            else:
                path_bit, proto, extra_so = parse_with_flavour(bit, proto0)

            if proto in {"blockcache", "filecache", "simplecache"}:
                # Caching segments never keep a path of their own.
                if path_bit is not None:
                    carried_path = path_bit or "/"
                path_bit = None
            elif carried_path is not None:
                path_bit, carried_path = carried_path, None
            segments.append(ChainSegment(path_bit, proto, extra_so))

        # Options of the first segment; entries keyed by a protocol name are
        # moved from here into the segment with that protocol below.
        root_so = segments[0].storage_options
        option_levels = _iter_fileobject_protocol_options(
            path_bit if segments else None,
            protocol or "",
            storage_options,
        )
        for segment, level in zip_longest(segments, option_levels):
            t_fo, t_proto, t_so = level or (None, "", {})
            if segment is not None:
                segment_proto = segment.protocol
                # check if protocol is consistent with storage options
                if t_proto and t_proto != segment_proto:
                    raise ValueError(
                        f"protocol {segment_proto!r} collides with target_protocol {t_proto!r}"
                    )
                # update the storage_options
                segment.storage_options.update(root_so.pop(segment_proto, {}))
                segment.storage_options.update(t_so)
                continue
            # The storage options describe more levels than the urlpath has
            # parts: add a segment for each extra level.
            if carried_path is not None:
                t_fo, carried_path = carried_path, None
            segments.append(ChainSegment(t_fo, t_proto, t_so))

        return segments

    def chain(self, segments: Sequence[ChainSegment]) -> tuple[str, dict[str, Any]]:
        """Join segments into a chained urlpath plus per-protocol options.

        Returns the urlpath parts joined by ``self.link`` and a dict mapping
        a segment's protocol to its non-empty storage options. Segments with
        neither protocol nor path are skipped with a ``RuntimeWarning``.
        """
        urlpaths = []
        kwargs = {}
        for segment in segments:
            proto = segment.protocol
            seg_path = segment.path
            if proto and seg_path is not None:
                # FIXME: currently unstrip_protocol is only implemented by
                #   AbstractFileSystem, LocalFileSystem, and OSSFileSystem
                #   so to make this work we just implement it ourselves here.
                #   To do this properly we would need to instantiate the
                #   filesystem with its storage options and call
                #   fs.unstrip_protocol(segment.path)
                already_prefixed = seg_path.startswith(f"{proto}:/")
                urlpath = seg_path if already_prefixed else f"{proto}://{seg_path}"
            elif proto:
                urlpath = proto
            elif seg_path is not None:
                urlpath = seg_path
            else:
                warnings.warn(
                    f"skipping invalid segment {segment}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            urlpaths.append(urlpath)
            # TODO: ensure roundtrip with unchain behavior
            if segment.storage_options:
                kwargs[proto] = segment.storage_options
        return self.link.join(urlpaths), kwargs
# Module-level parser instance, exported through ``__all__``.
DEFAULT_CHAIN_PARSER = FSSpecChainParser()
if __name__ == "__main__":
    # Manual demo: print fsspec's unchaining next to this module's, then
    # re-chain the result and swap out the path of the zip segment.
    from pprint import pp

    from fsspec.core import _un_chain

    parser = FSSpecChainParser()

    demo_urlpath = "simplecache::zip://haha.csv::gcs://bucket/file.zip"
    demo_options = {"zip": {"allowZip64": False}}
    print(demo_urlpath, demo_options)

    fsspec_bits = _un_chain(demo_urlpath, demo_options)
    upath_segments = parser.unchain(demo_urlpath, storage_options=demo_options)
    pp(fsspec_bits)
    pp(upath_segments)

    roundtrip_urlpath, roundtrip_options = parser.chain(upath_segments)
    print(roundtrip_urlpath, roundtrip_options)

    # UPath should store segments and access the path to operate on
    # through segments.current.path
    zip_chain = Chain.from_list(segments=upath_segments, index=1)
    assert zip_chain.current.protocol == "zip"

    # try to switch out zip path
    swapped_chain = zip_chain.replace(path="/newfile.csv")
    swapped_urlpath, swapped_options = parser.chain(swapped_chain.to_list())
    print(swapped_urlpath, swapped_options)