Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/io/path.py: 29%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

202 statements  

1# Licensed to the Apache Software Foundation (ASF) under one 

2# or more contributor license agreements. See the NOTICE file 

3# distributed with this work for additional information 

4# regarding copyright ownership. The ASF licenses this file 

5# to you under the Apache License, Version 2.0 (the 

6# "License"); you may not use this file except in compliance 

7# with the License. You may obtain a copy of the License at 

8# 

9# http://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, 

12# software distributed under the License is distributed on an 

13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

14# KIND, either express or implied. See the License for the 

15# specific language governing permissions and limitations 

16# under the License. 

17 

18from __future__ import annotations 

19 

20import shutil 

21from typing import TYPE_CHECKING, Any, ClassVar 

22from urllib.parse import urlsplit 

23 

24from fsspec.utils import stringify_path 

25from upath import UPath 

26from upath.extensions import ProxyUPath 

27 

28from airflow.sdk.io.stat import stat_result 

29from airflow.sdk.io.store import attach 

30 

31if TYPE_CHECKING: 

32 from fsspec import AbstractFileSystem 

33 from typing_extensions import Self 

34 from upath.types import JoinablePathLike 

35 

36 

37class _TrackingFileWrapper: 

38 """Wrapper that tracks file operations to intercept lineage.""" 

39 

40 def __init__(self, path: ObjectStoragePath, obj): 

41 super().__init__() 

42 self._path = path 

43 self._obj = obj 

44 

45 def __getattr__(self, name): 

46 from airflow.sdk.lineage import get_hook_lineage_collector 

47 

48 if not callable(attr := getattr(self._obj, name)): 

49 return attr 

50 

51 # If the attribute is a method, wrap it in another method to intercept the call 

52 def wrapper(*args, **kwargs): 

53 if name == "read": 

54 get_hook_lineage_collector().add_input_asset(context=self._path, uri=str(self._path)) 

55 elif name == "write": 

56 get_hook_lineage_collector().add_output_asset(context=self._path, uri=str(self._path)) 

57 result = attr(*args, **kwargs) 

58 return result 

59 

60 return wrapper 

61 

62 # We need to explicitly implement `__iter__` 

63 # because otherwise python iteration logic could use `__getitem__` 

64 def __iter__(self): 

65 return iter(self._obj) 

66 

67 def __getitem__(self, key): 

68 # Intercept item access 

69 return self._obj[key] 

70 

71 def __enter__(self): 

72 self._obj.__enter__() 

73 return self 

74 

75 def __exit__(self, exc_type, exc_val, exc_tb): 

76 self._obj.__exit__(exc_type, exc_val, exc_tb) 

77 

78 

class ObjectStoragePath(ProxyUPath):
    """A path-like object for object storage.

    Extends ``ProxyUPath`` with Airflow-specific behaviour: the filesystem is
    resolved through Airflow's :func:`attach` mechanism, and an optional Airflow
    connection id can be supplied either via ``conn_id=`` or embedded in the
    URL userinfo part (``scheme://conn_id@bucket/key``).
    """

    # serialization format version; bump when serialize()/deserialize() change
    __version__: ClassVar[int] = 1

    sep: ClassVar[str] = "/"
    root_marker: ClassVar[str] = "/"

    __slots__ = ("_hash_cached",)

    def __init__(
        self,
        *args: JoinablePathLike,
        protocol: str | None = None,
        conn_id: str | None = None,
        **storage_options: Any,
    ) -> None:
        """Create a path, extracting protocol and conn_id from the first argument when present.

        :param args: path segments; the first may carry ``scheme://conn_id@netloc/key``
        :param protocol: filesystem protocol; defaults to the scheme of the first argument
        :param conn_id: Airflow connection id; overrides any conn_id parsed from the URL
        :param storage_options: extra options passed through to the underlying filesystem
        """
        # ensure conn_id is always set in storage_options
        storage_options.setdefault("conn_id", None)
        # parse conn_id from args if provided
        if args:
            arg0 = args[0]
            if isinstance(arg0, type(self)):
                # copying from another ObjectStoragePath: inherit its conn_id
                storage_options["conn_id"] = arg0.storage_options.get("conn_id")
            else:
                parsed_url = urlsplit(stringify_path(arg0))
                # the userinfo part (before '@') in the netloc is treated as the conn_id
                userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
                if have_info:
                    conn_id = storage_options["conn_id"] = userinfo or None
                    # rebuild the URL without the userinfo part
                    parsed_url = parsed_url._replace(netloc=hostinfo)
                    args = (parsed_url.geturl(),) + args[1:]
                protocol = protocol or parsed_url.scheme
        # override conn_id if explicitly provided
        if conn_id is not None:
            storage_options["conn_id"] = conn_id
        super().__init__(*args, protocol=protocol, **storage_options)

    @property
    def fs(self) -> AbstractFileSystem:
        """Return the filesystem for this path, using airflow's attach mechanism."""
        conn_id = self.storage_options.get("conn_id")
        return attach(self.protocol or "file", conn_id).fs

    def __hash__(self) -> int:
        """Hash on the string form; cached per instance in the ``_hash_cached`` slot."""
        self._hash_cached: int
        try:
            return self._hash_cached
        except AttributeError:
            self._hash_cached = hash(str(self))
            return self._hash_cached

    def __eq__(self, other: Any) -> bool:
        """Equal when *other* is in the same store and stringifies identically."""
        return self.samestore(other) and str(self) == str(other)

    def samestore(self, other: Any) -> bool:
        """Return True if *other* uses the same protocol and Airflow connection id."""
        return (
            isinstance(other, ObjectStoragePath)
            and self.protocol == other.protocol
            and self.storage_options.get("conn_id") == other.storage_options.get("conn_id")
        )

    @property
    def container(self) -> str:
        """Alias for :attr:`bucket`."""
        return self.bucket

    @property
    def bucket(self) -> str:
        """The bucket/container part of the URL (its netloc), or '' when absent."""
        if self._url:
            return self._url.netloc
        return ""

    @property
    def key(self) -> str:
        """The object key: the URL path made relative to the bucket."""
        if self._url:
            # per convention, we strip the leading slashes to ensure a relative key is returned
            # we keep the trailing slash to allow for directory-like semantics
            return self._url.path.lstrip(self.sep)
        return ""

    @property
    def namespace(self) -> str:
        """Namespace string: ``protocol://bucket``, or just the protocol when bucketless."""
        return f"{self.protocol}://{self.bucket}" if self.bucket else self.protocol

    def open(self, mode="r", **kwargs):
        """Open the file pointed to by this path."""
        # map stdlib-style `buffering` onto fsspec's `block_size` option
        kwargs.setdefault("block_size", kwargs.pop("buffering", None))
        # wrap the handle so read/write calls emit lineage events
        return _TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs))

    def stat(self) -> stat_result:  # type: ignore[override]
        """Call ``stat`` and return the result."""
        # the airflow stat_result also records protocol and conn_id for samefile()
        return stat_result(
            self.fs.stat(self.path),
            protocol=self.protocol,
            conn_id=self.storage_options.get("conn_id"),
        )

    def samefile(self, other_path: Any) -> bool:
        """Return whether other_path is the same or not as this file."""
        if not isinstance(other_path, ObjectStoragePath):
            return False

        st = self.stat()
        other_st = other_path.stat()

        # same store (protocol + conn_id) and same inode-like identifier
        return (
            st["protocol"] == other_st["protocol"]
            and st["conn_id"] == other_st["conn_id"]
            and st["ino"] == other_st["ino"]
        )

    def replace(self, target) -> Self:
        """
        Rename this path to the target path, overwriting if that path exists.

        The target path may be absolute or relative. Relative paths are
        interpreted relative to the current working directory, *not* the
        directory of the Path object.

        Returns the new Path instance pointing to the target path.
        """
        return self.rename(target)

    @classmethod
    def cwd(cls) -> Self:
        """Return the current working directory as an ObjectStoragePath."""
        return cls._from_upath(UPath.cwd())

    @classmethod
    def home(cls) -> Self:
        """Return the user's home directory as an ObjectStoragePath."""
        return cls._from_upath(UPath.home())

    # EXTENDED OPERATIONS

    def ukey(self) -> str:
        """Hash of file properties, to tell if it has changed."""
        return self.fs.ukey(self.path)

    def checksum(self) -> int:
        """Return the checksum of the file at this path."""
        # we directly access the fs here to avoid changing the abstract interface
        return self.fs.checksum(self.path)

    def read_block(self, offset: int, length: int, delimiter=None):
        r"""
        Read a block of bytes.

        Starting at ``offset`` of the file, read ``length`` bytes. If
        ``delimiter`` is set then we ensure that the read starts and stops at
        delimiter boundaries that follow the locations ``offset`` and ``offset
        + length``. If ``offset`` is zero then we start at zero. The
        bytestring returned WILL include the end delimiter string.

        If offset+length is beyond the eof, reads to eof.

        :param offset: int
            Byte offset to start read
        :param length: int
            Number of bytes to read. If None, read to the end.
        :param delimiter: bytes (optional)
            Ensure reading starts and stops at delimiter bytestring

        Examples
        --------
        .. code-block:: pycon

            # Read the first 13 bytes (no delimiter)
            >>> read_block(0, 13)
            b'Alice, 100\nBo'

            # Read first 13 bytes, but force newline boundaries
            >>> read_block(0, 13, delimiter=b"\n")
            b'Alice, 100\nBob, 200\n'

            # Read until EOF, but only stop at newline
            >>> read_block(0, None, delimiter=b"\n")
            b'Alice, 100\nBob, 200\nCharlie, 300'

        See Also
        --------
        :func:`fsspec.utils.read_block`
        """
        return self.fs.read_block(self.path, offset=offset, length=length, delimiter=delimiter)

    def sign(self, expiration: int = 100, **kwargs):
        """
        Create a signed URL representing the given path.

        Some implementations allow temporary URLs to be generated, as a
        way of delegating credentials.

        :param path: str
            The path on the filesystem
        :param expiration: int
            Number of seconds to enable the URL for (if supported)

        :returns URL: str
            The signed URL

        :raises NotImplementedError: if the method is not implemented for a store
        """
        return self.fs.sign(self.path, expiration=expiration, **kwargs)

    def size(self) -> int:
        """Size in bytes of the file at this path."""
        return self.fs.size(self.path)

    def _cp_file(self, dst: ObjectStoragePath, **kwargs):
        """Copy a single file from this path to another location by streaming the data."""
        # create the directory or bucket if required
        if dst.key.endswith(self.sep) or not dst.key:
            dst.mkdir(exist_ok=True, parents=True)
            dst = dst / self.key
        elif dst.is_dir():
            dst = dst / self.key

        # streaming copy
        with self.open("rb") as f1, dst.open("wb") as f2:
            # make use of system dependent buffer size
            shutil.copyfileobj(f1, f2, **kwargs)

    def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:  # type: ignore[override]
        """
        Copy file(s) from this path to another location.

        For remote to remote copies, the key used for the destination will be the same as the source.
        So that s3://src_bucket/foo/bar will be copied to gcs://dst_bucket/foo/bar and not
        gcs://dst_bucket/bar.

        :param dst: Destination path
        :param recursive: If True, copy directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.
        """
        from airflow.sdk.lineage import get_hook_lineage_collector

        if isinstance(dst, str):
            dst = ObjectStoragePath(dst)

        if self.samestore(dst) or self.protocol == "file" or dst.protocol == "file":
            # only emit this in "optimized" variants - else lineage will be captured by file writes/reads
            get_hook_lineage_collector().add_input_asset(context=self, uri=str(self))
            get_hook_lineage_collector().add_output_asset(context=dst, uri=str(dst))

        # same -> same
        if self.samestore(dst):
            self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs)
            return

        # use optimized path for local -> remote or remote -> local
        if self.protocol == "file":
            dst.fs.put(self.path, dst.path, recursive=recursive, **kwargs)
            return

        if dst.protocol == "file":
            self.fs.get(self.path, dst.path, recursive=recursive, **kwargs)
            return

        if not self.exists():
            raise FileNotFoundError(f"{self} does not exist")

        # remote dir -> remote dir
        if self.is_dir():
            if dst.is_file():
                raise ValueError("Cannot copy directory to a file.")

            dst.mkdir(exist_ok=True, parents=True)

            out = self.fs.expand_path(self.path, recursive=True, **kwargs)

            for path in out:
                # this check prevents one extra call to is_dir() as
                # glob returns self as well
                if path == self.path:
                    continue

                src_obj = ObjectStoragePath(
                    path,
                    protocol=self.protocol,
                    conn_id=self.storage_options.get("conn_id"),
                )

                # skip directories, empty directories will not be created
                if src_obj.is_dir():
                    continue

                src_obj._cp_file(dst)
            return

        # remote file -> remote dir
        self._cp_file(dst, **kwargs)

    def copy_into(self, target_dir: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:  # type: ignore[override]
        """
        Copy file(s) from this path into another directory.

        :param target_dir: Destination directory
        :param recursive: If True, copy directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.

        :raises NotADirectoryError: if ``target_dir`` is not an existing directory.
        """
        if isinstance(target_dir, str):
            target_dir = ObjectStoragePath(target_dir)
        if not target_dir.is_dir():
            raise NotADirectoryError(f"Destination {target_dir} is not a directory.")
        dst_path = target_dir / self.name
        self.copy(dst_path, recursive=recursive, **kwargs)

    def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:  # type: ignore[override]
        """
        Move file(s) from this path to another location.

        :param path: Destination path
        :param recursive: bool
            If True, move directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.
        """
        from airflow.sdk.lineage import get_hook_lineage_collector

        if isinstance(path, str):
            path = ObjectStoragePath(path)

        if self.samestore(path):
            # same store: delegate to the filesystem's native move
            get_hook_lineage_collector().add_input_asset(context=self, uri=str(self))
            get_hook_lineage_collector().add_output_asset(context=path, uri=str(path))
            return self.fs.move(self.path, path.path, recursive=recursive, **kwargs)

        # non-local copy
        self.copy(path, recursive=recursive, **kwargs)
        self.unlink()

    def move_into(self, target_dir: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:  # type: ignore[override]
        """
        Move file(s) from this path into another directory.

        :param target_dir: Destination directory
        :param recursive: bool
            If True, move directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.

        :raises NotADirectoryError: if ``target_dir`` is not an existing directory.
        """
        if isinstance(target_dir, str):
            target_dir = ObjectStoragePath(target_dir)
        if not target_dir.is_dir():
            raise NotADirectoryError(f"Destination {target_dir} is not a directory.")
        dst_path = target_dir / self.name
        self.move(dst_path, recursive=recursive, **kwargs)

    def serialize(self) -> dict[str, Any]:
        """Serialize to a dict of path, conn_id and remaining storage options."""
        _kwargs = {**self.storage_options}
        # conn_id is serialized as its own top-level field, not inside kwargs
        conn_id = _kwargs.pop("conn_id", None)

        return {
            "path": str(self),
            "conn_id": conn_id,
            "kwargs": _kwargs,
        }

    @classmethod
    def deserialize(cls, data: dict, version: int) -> ObjectStoragePath:
        """Recreate a path from :meth:`serialize` output.

        :raises ValueError: if *version* is newer than this class supports.
        """
        if version > cls.__version__:
            raise ValueError(f"Cannot deserialize version {version} with version {cls.__version__}.")

        _kwargs = data.pop("kwargs")
        path = data.pop("path")
        conn_id = data.pop("conn_id", None)

        return ObjectStoragePath(path, conn_id=conn_id, **_kwargs)

    def __str__(self):
        """Render as ``protocol://conn_id@path`` when a conn_id is attached."""
        conn_id = self.storage_options.get("conn_id")
        if self.protocol and conn_id:
            return f"{self.protocol}://{conn_id}@{self.path}"
        return super().__str__()