Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/io/path.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

191 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import contextlib 

21import os 

22import shutil 

23from collections.abc import Mapping 

24from typing import TYPE_CHECKING, Any, ClassVar 

25from urllib.parse import urlsplit 

26 

27from fsspec.utils import stringify_path 

28from upath.implementations.cloud import CloudPath 

29from upath.registry import get_upath_class 

30 

31from airflow.sdk.io.stat import stat_result 

32from airflow.sdk.io.store import attach 

33 

34if TYPE_CHECKING: 

35 from fsspec import AbstractFileSystem 

36 

37 

38class _TrackingFileWrapper: 

39 """Wrapper that tracks file operations to intercept lineage.""" 

40 

41 def __init__(self, path: ObjectStoragePath, obj): 

42 super().__init__() 

43 self._path = path 

44 self._obj = obj 

45 

46 def __getattr__(self, name): 

47 from airflow.lineage.hook import get_hook_lineage_collector 

48 

49 if not callable(attr := getattr(self._obj, name)): 

50 return attr 

51 

52 # If the attribute is a method, wrap it in another method to intercept the call 

53 def wrapper(*args, **kwargs): 

54 if name == "read": 

55 get_hook_lineage_collector().add_input_asset(context=self._path, uri=str(self._path)) 

56 elif name == "write": 

57 get_hook_lineage_collector().add_output_asset(context=self._path, uri=str(self._path)) 

58 result = attr(*args, **kwargs) 

59 return result 

60 

61 return wrapper 

62 

63 # We need to explicitly implement `__iter__` 

64 # because otherwise python iteration logic could use `__getitem__` 

65 def __iter__(self): 

66 return iter(self._obj) 

67 

68 def __getitem__(self, key): 

69 # Intercept item access 

70 return self._obj[key] 

71 

72 def __enter__(self): 

73 self._obj.__enter__() 

74 return self 

75 

76 def __exit__(self, exc_type, exc_val, exc_tb): 

77 self._obj.__exit__(exc_type, exc_val, exc_tb) 

78 

79 

class ObjectStoragePath(CloudPath):
    """
    A path-like object for object storage.

    The backing filesystem is obtained from ``attach`` using the path's protocol
    and the ``conn_id`` storage option. A ``conn_id`` may also be embedded in the
    URL as its user-info part (``protocol://conn_id@bucket/key``); it is then
    extracted into the storage options when the path is constructed.
    """

    __version__: ClassVar[int] = 1

    _protocol_dispatch = False

    sep: ClassVar[str] = "/"
    root_marker: ClassVar[str] = "/"

    __slots__ = ("_hash_cached",)

    @classmethod
    def _transform_init_args(
        cls,
        args: tuple[str | os.PathLike, ...],
        protocol: str,
        storage_options: dict[str, Any],
    ) -> tuple[tuple[str | os.PathLike, ...], str, dict[str, Any]]:
        """
        Extract conn_id from the URL and set it as a storage option.

        :param args: Positional path arguments; only the first one is inspected.
        :param protocol: Explicit protocol; when empty the URL scheme is used.
        :param storage_options: Storage options; updated in place with ``conn_id``
            (unless already present) when the URL carries a user-info part.
        :return: The possibly rewritten ``(args, protocol, storage_options)``.
        """
        if args:
            arg0 = args[0]
            parsed_url = urlsplit(stringify_path(arg0))
            userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
            if have_info:
                # An explicitly passed conn_id wins over the one in the URL; an
                # empty user-info ("proto://@bucket") maps to ``None``.
                storage_options.setdefault("conn_id", userinfo or None)
                parsed_url = parsed_url._replace(netloc=hostinfo)
            args = (parsed_url.geturl(),) + args[1:]
            protocol = protocol or parsed_url.scheme
        return args, protocol, storage_options

    @classmethod
    def _fs_factory(
        cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
    ) -> AbstractFileSystem:
        """Return the filesystem attached for ``protocol`` (default ``"file"``) and ``conn_id``."""
        return attach(protocol or "file", storage_options.get("conn_id")).fs

    def __hash__(self) -> int:
        """Return the hash of ``str(self)``, computed once and cached on the instance."""
        self._hash_cached: int
        try:
            return self._hash_cached
        except AttributeError:
            self._hash_cached = hash(str(self))
            return self._hash_cached

    def __eq__(self, other: Any) -> bool:
        """Return True when ``other`` is on the same store and has the same string form."""
        return self.samestore(other) and str(self) == str(other)

    def samestore(self, other: Any) -> bool:
        """Return True when ``other`` is an ObjectStoragePath with equal protocol and conn_id."""
        return (
            isinstance(other, ObjectStoragePath)
            and self.protocol == other.protocol
            and self.storage_options.get("conn_id") == other.storage_options.get("conn_id")
        )

    @property
    def container(self) -> str:
        """Alias of :attr:`bucket`."""
        return self.bucket

    @property
    def bucket(self) -> str:
        """Network-location part of the URL, or ``""`` when there is no URL."""
        if self._url:
            return self._url.netloc
        return ""

    @property
    def key(self) -> str:
        """Path part of the URL without leading separators, or ``""`` when there is no URL."""
        if self._url:
            # per convention, we strip the leading slashes to ensure a relative key is returned
            # we keep the trailing slash to allow for directory-like semantics
            return self._url.path.lstrip(self.sep)
        return ""

    @property
    def namespace(self) -> str:
        """Return ``protocol://bucket``, or just the protocol when the bucket is empty."""
        return f"{self.protocol}://{self.bucket}" if self.bucket else self.protocol

    def open(self, mode="r", **kwargs):
        """
        Open the file pointed to by this path.

        :param mode: Mode passed to the filesystem's ``open``.
        :param kwargs: Passed to the filesystem's ``open``. ``buffering`` is
            translated to ``block_size`` unless ``block_size`` is given explicitly.
        :return: The opened file wrapped in a ``_TrackingFileWrapper``.
        """
        kwargs.setdefault("block_size", kwargs.pop("buffering", None))
        return _TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs))

    def stat(self) -> stat_result:  # type: ignore[override]
        """Call ``stat`` and return the result."""
        return stat_result(
            self.fs.stat(self.path),
            protocol=self.protocol,
            conn_id=self.storage_options.get("conn_id"),
        )

    def samefile(self, other_path: Any) -> bool:
        """Return whether other_path is the same or not as this file."""
        if not isinstance(other_path, ObjectStoragePath):
            return False

        st = self.stat()
        other_st = other_path.stat()

        return (
            st["protocol"] == other_st["protocol"]
            and st["conn_id"] == other_st["conn_id"]
            and st["ino"] == other_st["ino"]
        )

    def _scandir(self):
        # Emulate os.scandir(), which returns an object that can be used as a
        # context manager.
        return contextlib.nullcontext(self.iterdir())

    def replace(self, target) -> ObjectStoragePath:
        """
        Rename this path to the target path, overwriting if that path exists.

        The target path may be absolute or relative. Relative paths are
        interpreted relative to the current working directory, *not* the
        directory of the Path object.

        Returns the new Path instance pointing to the target path.
        """
        return self.rename(target)

    @classmethod
    def cwd(cls):
        """Return the current working directory; only implemented on ObjectStoragePath itself."""
        if cls is ObjectStoragePath:
            return get_upath_class("").cwd()
        raise NotImplementedError

    @classmethod
    def home(cls):
        """Return the home directory; only implemented on ObjectStoragePath itself."""
        if cls is ObjectStoragePath:
            return get_upath_class("").home()
        raise NotImplementedError

    # EXTENDED OPERATIONS

    def ukey(self) -> str:
        """Hash of file properties, to tell if it has changed."""
        return self.fs.ukey(self.path)

    def checksum(self) -> int:
        """Return the checksum of the file at this path."""
        # we directly access the fs here to avoid changing the abstract interface
        return self.fs.checksum(self.path)

    def read_block(self, offset: int, length: int | None, delimiter=None):
        r"""
        Read a block of bytes.

        Starting at ``offset`` of the file, read ``length`` bytes. If
        ``delimiter`` is set then we ensure that the read starts and stops at
        delimiter boundaries that follow the locations ``offset`` and ``offset
        + length``. If ``offset`` is zero then we start at zero. The
        bytestring returned WILL include the end delimiter string.

        If offset+length is beyond the eof, reads to eof.

        :param offset: int
            Byte offset to start read
        :param length: int
            Number of bytes to read. If None, read to the end.
        :param delimiter: bytes (optional)
            Ensure reading starts and stops at delimiter bytestring

        Examples
        --------
        .. code-block:: pycon

            # Read the first 13 bytes (no delimiter)
            >>> read_block(0, 13)
            b'Alice, 100\nBo'

            # Read first 13 bytes, but force newline boundaries
            >>> read_block(0, 13, delimiter=b"\n")
            b'Alice, 100\nBob, 200\n'

            # Read until EOF, but only stop at newline
            >>> read_block(0, None, delimiter=b"\n")
            b'Alice, 100\nBob, 200\nCharlie, 300'

        See Also
        --------
        :func:`fsspec.utils.read_block`
        """
        return self.fs.read_block(self.path, offset=offset, length=length, delimiter=delimiter)

    def sign(self, expiration: int = 100, **kwargs):
        """
        Create a signed URL representing this path.

        Some implementations allow temporary URLs to be generated, as a
        way of delegating credentials.

        :param expiration: int
            Number of seconds to enable the URL for (if supported)
        :param kwargs: Additional keyword arguments passed to the filesystem's ``sign``.

        :returns URL: str
            The signed URL

        :raises NotImplementedError: if the method is not implemented for a store
        """
        return self.fs.sign(self.path, expiration=expiration, **kwargs)

    def size(self) -> int:
        """Size in bytes of the file at this path."""
        return self.fs.size(self.path)

    def _cp_file(self, dst: ObjectStoragePath, **kwargs):
        """
        Copy a single file from this path to another location by streaming the data.

        When ``dst`` has an empty key, a key ending in the separator, or is an
        existing directory, the file is written to ``dst / self.key``.

        :param dst: Destination path.
        :param kwargs: Passed to :func:`shutil.copyfileobj`.
        """
        # create the directory or bucket if required
        if dst.key.endswith(self.sep) or not dst.key:
            dst.mkdir(exist_ok=True, parents=True)
            dst = dst / self.key
        elif dst.is_dir():
            dst = dst / self.key

        # streaming copy
        with self.open("rb") as f1, dst.open("wb") as f2:
            # make use of system dependent buffer size
            shutil.copyfileobj(f1, f2, **kwargs)

    def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:
        """
        Copy file(s) from this path to another location.

        For remote to remote copies, the key used for the destination will be the same as the source.
        So that s3://src_bucket/foo/bar will be copied to gcs://dst_bucket/foo/bar and not
        gcs://dst_bucket/bar.

        :param dst: Destination path
        :param recursive: If True, copy directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.

        :raises FileNotFoundError: on a remote to remote copy when this path does not exist.
        :raises ValueError: on a remote to remote copy of a directory onto an existing file.
        """
        from airflow.lineage.hook import get_hook_lineage_collector

        if isinstance(dst, str):
            dst = ObjectStoragePath(dst)

        if self.samestore(dst) or self.protocol == "file" or dst.protocol == "file":
            # only emit this in "optimized" variants - else lineage will be captured by file writes/reads
            get_hook_lineage_collector().add_input_asset(context=self, uri=str(self))
            get_hook_lineage_collector().add_output_asset(context=dst, uri=str(dst))

        # same -> same
        if self.samestore(dst):
            self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs)
            return

        # use optimized path for local -> remote or remote -> local
        if self.protocol == "file":
            dst.fs.put(self.path, dst.path, recursive=recursive, **kwargs)
            return

        if dst.protocol == "file":
            self.fs.get(self.path, dst.path, recursive=recursive, **kwargs)
            return

        if not self.exists():
            raise FileNotFoundError(f"{self} does not exist")

        # remote dir -> remote dir
        if self.is_dir():
            if dst.is_file():
                raise ValueError("Cannot copy directory to a file.")

            dst.mkdir(exist_ok=True, parents=True)

            out = self.fs.expand_path(self.path, recursive=True, **kwargs)

            for path in out:
                # this check prevents one extra call to is_dir() as
                # glob returns self as well
                if path == self.path:
                    continue

                src_obj = ObjectStoragePath(
                    path,
                    protocol=self.protocol,
                    conn_id=self.storage_options.get("conn_id"),
                )

                # skip directories, empty directories will not be created
                if src_obj.is_dir():
                    continue

                src_obj._cp_file(dst)
            return

        # remote file -> remote dir
        self._cp_file(dst, **kwargs)

    def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:
        """
        Move file(s) from this path to another location.

        Within the same store the filesystem's own ``move`` is used. Otherwise
        the data is copied with :meth:`copy` and the source is removed afterwards.

        :param path: Destination path
        :param recursive: bool
            If True, move directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.
        """
        from airflow.lineage.hook import get_hook_lineage_collector

        if isinstance(path, str):
            path = ObjectStoragePath(path)

        if self.samestore(path):
            get_hook_lineage_collector().add_input_asset(context=self, uri=str(self))
            get_hook_lineage_collector().add_output_asset(context=path, uri=str(path))
            return self.fs.move(self.path, path.path, recursive=recursive, **kwargs)

        # non-local copy
        self.copy(path, recursive=recursive, **kwargs)
        if recursive and self.is_dir():
            # ``unlink`` is a single-file removal; after a recursive copy the
            # whole source tree has to go, otherwise the "move" leaves the
            # source behind (or fails on filesystems that refuse to remove a
            # directory non-recursively).
            self.fs.rm(self.path, recursive=True)
        else:
            self.unlink()

    def serialize(self) -> dict[str, Any]:
        """Return a dict with ``path``, ``conn_id`` and the remaining storage options as ``kwargs``."""
        _kwargs = {**self.storage_options}
        conn_id = _kwargs.pop("conn_id", None)

        return {
            "path": str(self),
            "conn_id": conn_id,
            "kwargs": _kwargs,
        }

    @classmethod
    def deserialize(cls, data: dict, version: int) -> ObjectStoragePath:
        """
        Rebuild a path from the output of :meth:`serialize`.

        :param data: Mapping with ``path`` and ``kwargs`` keys and an optional
            ``conn_id`` key. It is not modified.
        :param version: Version the data was serialized with.
        :raises ValueError: if ``version`` is newer than this class's ``__version__``.
        :raises KeyError: if ``path`` or ``kwargs`` is missing from ``data``.
        """
        if version > cls.__version__:
            raise ValueError(f"Cannot deserialize version {version} with version {cls.__version__}.")

        # Read instead of pop so the caller's dict is left intact.
        _kwargs = data["kwargs"]
        path = data["path"]
        conn_id = data.get("conn_id")

        return ObjectStoragePath(path, conn_id=conn_id, **_kwargs)

    def __str__(self):
        """Return the URL form, embedding ``conn_id@`` when both a protocol and a conn_id are set."""
        conn_id = self.storage_options.get("conn_id")
        if self._protocol and conn_id:
            return f"{self._protocol}://{conn_id}@{self.path}"
        return super().__str__()