Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/upath/_chain.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

207 statements  

1from __future__ import annotations 

2 

3import sys 

4import warnings 

5from collections import defaultdict 

6from collections import deque 

7from collections.abc import Iterator 

8from collections.abc import MutableMapping 

9from collections.abc import Sequence 

10from collections.abc import Set 

11from itertools import zip_longest 

12from typing import TYPE_CHECKING 

13from typing import Any 

14from typing import NamedTuple 

15 

16from upath._flavour import WrappedFileSystemFlavour 

17from upath._protocol import get_upath_protocol 

18from upath.registry import available_implementations 

19from upath.types import UNSET_DEFAULT 

20 

21if TYPE_CHECKING: 

22 if sys.version_info >= (3, 11): 

23 from typing import Never 

24 from typing import Self 

25 else: 

26 from typing_extensions import Never 

27 from typing_extensions import Self 

28 

# Names exported by ``from <this module> import *``.
__all__ = [
    "ChainSegment",
    "Chain",
    "FSSpecChainParser",
    "DEFAULT_CHAIN_PARSER",
]

35 

36 

class ChainSegment(NamedTuple):
    """One link of a chained urlpath: a path, a protocol and its options."""

    # ``None`` means the segment has no path of its own; this supports
    # path passthrough (i.e. simplecache).
    path: str | None
    # protocol name of this segment
    protocol: str
    # storage options that belong to this segment
    storage_options: dict[str, Any]

41 

42 

class Chain:
    """Hold the chain segments together with the index of the current one."""

    __slots__ = ("_segments", "_index")

    def __init__(self, *segments: ChainSegment, index: int = 0) -> None:
        """Store ``segments`` and the position of the current segment.

        Raises:
            ValueError: if ``index`` does not address one of ``segments``
                (which is always the case when no segments are given).
        """
        index_is_valid = 0 <= index < len(segments)
        if not index_is_valid:
            raise ValueError("index must be between 0 and len(segments)")
        self._segments = segments
        self._index = index

    def __repr__(self) -> str:
        rendered = ", ".join(repr(segment) for segment in self._segments)
        # the index is only shown when it differs from the default
        suffix = f", index={self._index}" if self._index != 0 else ""
        return f"{type(self).__name__}({rendered}{suffix})"

    @property
    def current(self) -> ChainSegment:
        """The segment at the current index."""
        return self._segments[self._index]

    @property
    def _path_index(self) -> int:
        """Index of the first segment at or after the current one with a path.

        Raises:
            IndexError: if none of those segments carries a path.
        """
        position = self._index
        for segment in self._segments[position:]:
            if segment.path is not None:
                return position
            position += 1
        raise IndexError("No target path found")

    @property
    def active_path(self) -> str:
        """Path of the segment selected by ``_path_index``."""
        target = self._segments[self._path_index]
        if target.path is None:
            # ``_path_index`` only returns segments that have a path
            raise RuntimeError
        return target.path

    @property
    def active_path_protocol(self) -> str:
        """Protocol of the segment selected by ``_path_index``."""
        return self._segments[self._path_index].protocol

    def replace(
        self,
        *,
        path: str | None = None,
        protocol: str | None = None,
        storage_options: dict[str, Any] | None = None,
    ) -> Self:
        """replace the current chain segment keeping remaining chain segments

        ``protocol`` and ``storage_options`` are applied to the segment at
        the current index, ``path`` to the segment at ``_path_index``.
        Arguments left as ``None`` keep the existing values.
        """
        # NOTE(review): the indices below are derived from ``self._segments``
        # but applied to the ``to_list()`` result; these only line up when
        # unnesting/deduplication does not change the segment count - confirm.
        segments = self.to_list()

        current_fields: dict[str, Any] = {}
        if protocol is not None:
            current_fields["protocol"] = protocol
        if storage_options is not None:
            current_fields["storage_options"] = storage_options

        updates: dict[int, dict[str, Any]] = {}
        if current_fields:
            updates[self._index] = current_fields
        if path is not None:
            updates.setdefault(self._path_index, {})["path"] = path

        for position, fields in updates.items():
            segments[position] = segments[position]._replace(**fields)

        return type(self)(*segments, index=self._index)

    def to_list(self) -> list[ChainSegment]:
        """return a list of chain segments unnesting target_* segments"""
        pending = deque(self._segments)
        flat: list[ChainSegment] = []
        while pending:
            segment = pending.popleft()
            options = segment.storage_options
            if "target_protocol" in options and "fo" in options:
                # split the nested target off into its own segment and
                # process it next; work on a copy of the options
                remaining = options.copy()
                target_options = remaining.pop("target_options", {})
                target_protocol = remaining.pop("target_protocol")
                fo = remaining.pop("fo")
                pending.appendleft(ChainSegment(fo, target_protocol, target_options))
                flat.append(ChainSegment(segment.path, segment.protocol, remaining))
            elif not flat or segment != flat[-1]:
                # skip a segment equal to the one appended just before it
                flat.append(segment)
        return flat

    @classmethod
    def from_list(cls, segments: list[ChainSegment], index: int = 0) -> Self:
        """Alternate constructor taking the segments as a list."""
        return cls(*segments, index=index)

    def nest(self) -> ChainSegment:
        """return a nested target_* structure"""
        # see: fsspec.core.url_to_fs
        chain = self._segments
        outermost = len(chain) - 1
        nested: dict[str, Any] = {}
        carried_path = chain[-1].path
        # Reverse iterate the chain, creating a nested target_* structure
        for position, (urls, protocol, options) in enumerate(reversed(chain)):
            if urls is None:
                # a segment without a path takes over the most recent one
                urls = carried_path
            carried_path = urls
            if position == outermost:
                nested = {**options, **nested}
            else:
                nested["target_options"] = {**options, **nested}
                nested["target_protocol"] = protocol
                nested["fo"] = urls  # codespell:ignore fo
        head_path, head_protocol, _ = chain[0]
        return ChainSegment(head_path, head_protocol, nested)

159 

160 

161def _iter_fileobject_protocol_options( 

162 fileobject: str | None, 

163 protocol: str, 

164 storage_options: dict[str, Any], 

165 /, 

166) -> Iterator[tuple[str | None, str, dict[str, Any]]]: 

167 """yields fileobject, protocol and remaining storage options""" 

168 so = storage_options.copy() 

169 while "target_protocol" in so: 

170 t_protocol = so.pop("target_protocol", "") 

171 t_fileobject = so.pop("fo", None) # codespell:ignore fo 

172 t_so = so.pop("target_options", {}) 

173 yield fileobject, protocol, so 

174 fileobject, protocol, so = t_fileobject, t_protocol, t_so 

175 yield fileobject, protocol, so 

176 

177 

class FSSpecChainParser:
    """parse an fsspec chained urlpath"""

    def __init__(self) -> None:
        # separator between the parts of a chained urlpath
        self.link: str = "::"
        # protocol names collected from the registry; filled in by ``unchain``
        self.known_protocols: Set[str] = set()

    def unchain(
        self,
        path: str,
        _deprecated_storage_options: Never = UNSET_DEFAULT,
        /,
        *,
        protocol: str | None = None,
        storage_options: dict[str, Any] | None = None,
    ) -> list[ChainSegment]:
        """implements same behavior as fsspec.core._un_chain

        two differences:
        1. it sets the urlpath to None for upstream filesystems that passthrough
        2. it checks against the known protocols for exact matches

        Raises:
            ValueError: if storage options are passed both positionally and
                as keyword, or if a segment's protocol disagrees with the
                ``target_protocol`` found in the storage options.
        """
        if _deprecated_storage_options is not UNSET_DEFAULT:
            deprecation_message = (
                "passing storage_options as positional argument is deprecated, "
                "pass as keyword argument instead"
            )
            warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
            if storage_options is not None:
                raise ValueError(
                    "cannot pass storage_options both positionally and as keyword"
                )
            storage_options = _deprecated_storage_options
            protocol = protocol or storage_options.get("protocol")
        if storage_options is None:
            storage_options = {}

        segments: list[ChainSegment] = []
        path_bit: str | None
        next_path_overwrite: str | None = None
        # only the first bit is paired with the explicitly given protocol
        for hinted_proto, bit in zip_longest([protocol], path.split(self.link)):
            hint_allows_bit = hinted_proto is None or bit == hinted_proto
            # decide whether the bit is a bare protocol name or a urlpath
            if "://" in bit or "/" in bit:
                # uri-like or path-like, fast-path
                is_bare_protocol = False
            elif bit in self.known_protocols and hint_allows_bit:
                # exact match a fsspec protocol
                is_bare_protocol = True
            elif (
                bit in (implementations := set(available_implementations(fallback=True)))
                and hint_allows_bit
            ):
                # remember the registry contents for subsequent lookups
                self.known_protocols = implementations
                is_bare_protocol = True
            else:
                is_bare_protocol = False

            # get protocol and path_bit
            if is_bare_protocol:
                proto = bit
                path_bit = None
                extra_so = {}
            else:
                proto = get_upath_protocol(bit, protocol=hinted_proto)
                flavour = WrappedFileSystemFlavour.from_protocol(proto)
                path_bit = flavour.strip_protocol(bit)
                extra_so = flavour.get_kwargs_from_url(bit)

            if proto in {"blockcache", "filecache", "simplecache"}:
                # these segments do not keep a path themselves: hand it on
                # to the next segment that is not one of these protocols
                if path_bit is not None:
                    next_path_overwrite = path_bit or "/"
                    path_bit = None
            elif next_path_overwrite is not None:
                path_bit = next_path_overwrite
                next_path_overwrite = None
            segments.append(ChainSegment(path_bit, proto, extra_so))

        root_so = segments[0].storage_options
        nested_levels = _iter_fileobject_protocol_options(
            path_bit if segments else None,
            protocol or "",
            storage_options,
        )
        for segment, level in zip_longest(segments, nested_levels):
            target_fo, target_proto, target_so = level or (None, "", {})
            if segment is None:
                # the storage options describe more levels than the urlpath
                if next_path_overwrite is not None:
                    target_fo = next_path_overwrite
                    next_path_overwrite = None
                segments.append(ChainSegment(target_fo, target_proto, target_so))
                continue

            proto = segment.protocol
            # check if protocol is consistent with storage options
            if target_proto and target_proto != proto:
                raise ValueError(
                    f"protocol {proto!r} collides with target_protocol {target_proto!r}"
                )
            # update the storage_options
            segment.storage_options.update(root_so.pop(proto, {}))
            segment.storage_options.update(target_so)

        return segments

    def chain(self, segments: Sequence[ChainSegment]) -> tuple[str, dict[str, Any]]:
        """returns a chained urlpath from the segments

        The second element of the result maps each segment's protocol to its
        storage options; segments without storage options are left out.
        Segments with neither protocol nor path are skipped with a
        ``RuntimeWarning``.
        """
        urlpaths: list[str] = []
        kwargs: dict[str, Any] = {}
        for segment in segments:
            seg_path = segment.path
            seg_protocol = segment.protocol
            has_path = seg_path is not None

            if not seg_protocol and not has_path:
                warnings.warn(
                    f"skipping invalid segment {segment}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue

            if not has_path:
                urlpath = seg_protocol
            elif not seg_protocol or seg_path.startswith(f"{seg_protocol}:/"):
                urlpath = seg_path
            else:
                # FIXME: currently unstrip_protocol is only implemented by
                #   AbstractFileSystem, LocalFileSystem, and OSSFileSystem
                #   so to make this work we just implement it ourselves here.
                #   To do this properly we would need to instantiate the
                #   filesystem with its storage options and call
                #   fs.unstrip_protocol(segment.path)
                urlpath = f"{seg_protocol}://{seg_path}"

            urlpaths.append(urlpath)
            # TODO: ensure roundtrip with unchain behavior
            if segment.storage_options:
                kwargs[seg_protocol] = segment.storage_options
        return self.link.join(urlpaths), kwargs

317 

318 

# Module-level parser instance, created once at import time.
DEFAULT_CHAIN_PARSER = FSSpecChainParser()

320 

321 

if __name__ == "__main__":
    # Manual demo: compare this parser's output with fsspec's own unchaining,
    # then chain the result back together.
    from pprint import pp

    from fsspec.core import _un_chain

    demo_parser = FSSpecChainParser()

    chained_path = "simplecache::zip://haha.csv::gcs://bucket/file.zip"
    chained_kw = {"zip": {"allowZip64": False}}
    print(chained_path, chained_kw)

    fsspec_result = _un_chain(chained_path, chained_kw)
    upath_result = demo_parser.unchain(chained_path, storage_options=chained_kw)
    pp(fsspec_result)
    pp(upath_result)

    rechained_path, rechained_kw = demo_parser.chain(upath_result)
    print(rechained_path, rechained_kw)

    # UPath should store segments and access the path to operate on
    # through segments.current.path
    zip_chain = Chain.from_list(segments=upath_result, index=1)
    assert zip_chain.current.protocol == "zip"

    # try to switch out zip path
    updated_chain = zip_chain.replace(path="/newfile.csv")
    new_path, new_kw = demo_parser.chain(updated_chain.to_list())
    print(new_path, new_kw)

347 print(new_path, new_kw)