Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/io/path.py: 29%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
18from __future__ import annotations
20import shutil
21from typing import TYPE_CHECKING, Any, ClassVar
22from urllib.parse import urlsplit
24from fsspec.utils import stringify_path
25from upath import UPath
26from upath.extensions import ProxyUPath
28from airflow.sdk.io.stat import stat_result
29from airflow.sdk.io.store import attach
31if TYPE_CHECKING:
32 from fsspec import AbstractFileSystem
33 from typing_extensions import Self
34 from upath.types import JoinablePathLike
37class _TrackingFileWrapper:
38 """Wrapper that tracks file operations to intercept lineage."""
40 def __init__(self, path: ObjectStoragePath, obj):
41 super().__init__()
42 self._path = path
43 self._obj = obj
45 def __getattr__(self, name):
46 from airflow.sdk.lineage import get_hook_lineage_collector
48 if not callable(attr := getattr(self._obj, name)):
49 return attr
51 # If the attribute is a method, wrap it in another method to intercept the call
52 def wrapper(*args, **kwargs):
53 if name == "read":
54 get_hook_lineage_collector().add_input_asset(context=self._path, uri=str(self._path))
55 elif name == "write":
56 get_hook_lineage_collector().add_output_asset(context=self._path, uri=str(self._path))
57 result = attr(*args, **kwargs)
58 return result
60 return wrapper
62 # We need to explicitly implement `__iter__`
63 # because otherwise python iteration logic could use `__getitem__`
64 def __iter__(self):
65 return iter(self._obj)
67 def __getitem__(self, key):
68 # Intercept item access
69 return self._obj[key]
71 def __enter__(self):
72 self._obj.__enter__()
73 return self
75 def __exit__(self, exc_type, exc_val, exc_tb):
76 self._obj.__exit__(exc_type, exc_val, exc_tb)
class ObjectStoragePath(ProxyUPath):
    """A path-like object for object storage."""

    __version__: ClassVar[int] = 1

    sep: ClassVar[str] = "/"
    root_marker: ClassVar[str] = "/"

    __slots__ = ("_hash_cached",)

    def __init__(
        self,
        *args: JoinablePathLike,
        protocol: str | None = None,
        conn_id: str | None = None,
        **storage_options: Any,
    ) -> None:
        """
        Build a path, extracting the connection id from the first argument if present.

        :param args: Path segments; the first may be another ``ObjectStoragePath`` or a
            URL whose userinfo part (``proto://conn_id@...``) carries the connection id.
        :param protocol: Protocol to use; falls back to the scheme of the first argument.
        :param conn_id: Connection id to store in ``storage_options``.
        :param storage_options: Additional options forwarded to the parent class.
        """
        # ensure conn_id is always set in storage_options
        storage_options.setdefault("conn_id", None)
        # parse conn_id from args if provided
        if args:
            arg0 = args[0]
            if isinstance(arg0, type(self)):
                storage_options["conn_id"] = arg0.storage_options.get("conn_id")
            else:
                parsed_url = urlsplit(stringify_path(arg0))
                userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
                if have_info:
                    conn_id = storage_options["conn_id"] = userinfo or None
                    # strip the userinfo so the parent class only sees the host part
                    parsed_url = parsed_url._replace(netloc=hostinfo)
                args = (parsed_url.geturl(),) + args[1:]
                protocol = protocol or parsed_url.scheme
        # override conn_id if explicitly provided
        if conn_id is not None:
            storage_options["conn_id"] = conn_id
        super().__init__(*args, protocol=protocol, **storage_options)

    @property
    def fs(self) -> AbstractFileSystem:
        """Return the filesystem for this path, using airflow's attach mechanism."""
        conn_id = self.storage_options.get("conn_id")
        return attach(self.protocol or "file", conn_id).fs

    def __hash__(self) -> int:
        """Return ``hash(str(self))``, computed once and cached in ``_hash_cached``."""
        self._hash_cached: int
        try:
            return self._hash_cached
        except AttributeError:
            self._hash_cached = hash(str(self))
            return self._hash_cached

    def __eq__(self, other: Any) -> bool:
        """Return True if ``other`` is on the same store and has the same string form."""
        return self.samestore(other) and str(self) == str(other)

    def samestore(self, other: Any) -> bool:
        """Return True if ``other`` is an ObjectStoragePath with the same protocol and conn_id."""
        return (
            isinstance(other, ObjectStoragePath)
            and self.protocol == other.protocol
            and self.storage_options.get("conn_id") == other.storage_options.get("conn_id")
        )

    @property
    def container(self) -> str:
        """Alias for :attr:`bucket`."""
        return self.bucket

    @property
    def bucket(self) -> str:
        """Return the network location of the parsed URL, or an empty string if there is none."""
        if self._url:
            return self._url.netloc
        return ""

    @property
    def key(self) -> str:
        """Return the URL path without leading separators, or an empty string if there is no URL."""
        if self._url:
            # per convention, we strip the leading slashes to ensure a relative key is returned
            # we keep the trailing slash to allow for directory-like semantics
            return self._url.path.lstrip(self.sep)
        return ""

    @property
    def namespace(self) -> str:
        """Return ``<protocol>://<bucket>`` when a bucket is set, otherwise just the protocol."""
        return f"{self.protocol}://{self.bucket}" if self.bucket else self.protocol

    def open(self, mode="r", **kwargs):
        """
        Open the file pointed to by this path.

        A ``buffering`` keyword is removed from ``kwargs`` and used as ``block_size``
        unless ``block_size`` was given explicitly. The returned object is wrapped so
        that ``read``/``write`` calls are reported as lineage.
        """
        kwargs.setdefault("block_size", kwargs.pop("buffering", None))
        return _TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs))

    def stat(self) -> stat_result:  # type: ignore[override]
        """Call ``stat`` and return the result."""
        return stat_result(
            self.fs.stat(self.path),
            protocol=self.protocol,
            conn_id=self.storage_options.get("conn_id"),
        )

    def samefile(self, other_path: Any) -> bool:
        """Return whether other_path is the same or not as this file."""
        if not isinstance(other_path, ObjectStoragePath):
            return False

        st = self.stat()
        other_st = other_path.stat()

        return (
            st["protocol"] == other_st["protocol"]
            and st["conn_id"] == other_st["conn_id"]
            and st["ino"] == other_st["ino"]
        )

    def replace(self, target) -> Self:
        """
        Rename this path to the target path, overwriting if that path exists.

        The target path may be absolute or relative. Relative paths are
        interpreted relative to the current working directory, *not* the
        directory of the Path object.

        Returns the new Path instance pointing to the target path.
        """
        return self.rename(target)

    @classmethod
    def cwd(cls) -> Self:
        """Return an instance built from ``UPath.cwd()``."""
        return cls._from_upath(UPath.cwd())

    @classmethod
    def home(cls) -> Self:
        """Return an instance built from ``UPath.home()``."""
        return cls._from_upath(UPath.home())

    # EXTENDED OPERATIONS

    def ukey(self) -> str:
        """Hash of file properties, to tell if it has changed."""
        return self.fs.ukey(self.path)

    def checksum(self) -> int:
        """Return the checksum of the file at this path."""
        # we directly access the fs here to avoid changing the abstract interface
        return self.fs.checksum(self.path)

    def read_block(self, offset: int, length: int | None, delimiter=None):
        r"""
        Read a block of bytes.

        Starting at ``offset`` of the file, read ``length`` bytes. If
        ``delimiter`` is set then we ensure that the read starts and stops at
        delimiter boundaries that follow the locations ``offset`` and ``offset
        + length``. If ``offset`` is zero then we start at zero. The
        bytestring returned WILL include the end delimiter string.

        If offset+length is beyond the eof, reads to eof.

        :param offset: int
            Byte offset to start read
        :param length: int
            Number of bytes to read. If None, read to the end.
        :param delimiter: bytes (optional)
            Ensure reading starts and stops at delimiter bytestring

        Examples
        --------
        .. code-block:: pycon

            # Read the first 13 bytes (no delimiter)
            >>> read_block(0, 13)
            b'Alice, 100\nBo'

            # Read first 13 bytes, but force newline boundaries
            >>> read_block(0, 13, delimiter=b"\n")
            b'Alice, 100\nBob, 200\n'

            # Read until EOF, but only stop at newline
            >>> read_block(0, None, delimiter=b"\n")
            b'Alice, 100\nBob, 200\nCharlie, 300'

        See Also
        --------
        :func:`fsspec.utils.read_block`
        """
        return self.fs.read_block(self.path, offset=offset, length=length, delimiter=delimiter)

    def sign(self, expiration: int = 100, **kwargs):
        """
        Create a signed URL representing this path.

        Some implementations allow temporary URLs to be generated, as a
        way of delegating credentials.

        :param expiration: int
            Number of seconds to enable the URL for (if supported)
        :param kwargs: Additional keyword arguments forwarded to the filesystem's ``sign``.

        :returns URL: str
            The signed URL

        :raises NotImplementedError: if the method is not implemented for a store
        """
        return self.fs.sign(self.path, expiration=expiration, **kwargs)

    def size(self) -> int:
        """Size in bytes of the file at this path."""
        return self.fs.size(self.path)

    def _cp_file(self, dst: ObjectStoragePath, **kwargs):
        """Copy a single file from this path to another location by streaming the data."""
        # create the directory or bucket if required
        if dst.key.endswith(self.sep) or not dst.key:
            dst.mkdir(exist_ok=True, parents=True)
            dst = dst / self.key
        elif dst.is_dir():
            dst = dst / self.key

        # streaming copy
        with self.open("rb") as f1, dst.open("wb") as f2:
            # make use of system dependent buffer size
            shutil.copyfileobj(f1, f2, **kwargs)

    def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:  # type: ignore[override]
        """
        Copy file(s) from this path to another location.

        For remote to remote copies, the key used for the destination will be the same as the source.
        So that s3://src_bucket/foo/bar will be copied to gcs://dst_bucket/foo/bar and not
        gcs://dst_bucket/bar.

        :param dst: Destination path
        :param recursive: If True, copy directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.

        :raises FileNotFoundError: if a remote to remote copy is requested and this path does not exist.
        :raises ValueError: if this path is a directory and ``dst`` is a file.
        """
        from airflow.sdk.lineage import get_hook_lineage_collector

        if isinstance(dst, str):
            dst = ObjectStoragePath(dst)

        if self.samestore(dst) or self.protocol == "file" or dst.protocol == "file":
            # only emit this in "optimized" variants - else lineage will be captured by file writes/reads
            get_hook_lineage_collector().add_input_asset(context=self, uri=str(self))
            get_hook_lineage_collector().add_output_asset(context=dst, uri=str(dst))

        # same -> same
        if self.samestore(dst):
            self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs)
            return

        # use optimized path for local -> remote or remote -> local
        if self.protocol == "file":
            dst.fs.put(self.path, dst.path, recursive=recursive, **kwargs)
            return

        if dst.protocol == "file":
            self.fs.get(self.path, dst.path, recursive=recursive, **kwargs)
            return

        if not self.exists():
            raise FileNotFoundError(f"{self} does not exist")

        # remote dir -> remote dir
        if self.is_dir():
            if dst.is_file():
                raise ValueError("Cannot copy directory to a file.")

            dst.mkdir(exist_ok=True, parents=True)

            out = self.fs.expand_path(self.path, recursive=True, **kwargs)

            for path in out:
                # this check prevents one extra call to is_dir() as
                # glob returns self as well
                if path == self.path:
                    continue

                src_obj = ObjectStoragePath(
                    path,
                    protocol=self.protocol,
                    conn_id=self.storage_options.get("conn_id"),
                )

                # skip directories, empty directories will not be created
                if src_obj.is_dir():
                    continue

                src_obj._cp_file(dst)
            return

        # remote file -> remote dir
        self._cp_file(dst, **kwargs)

    def copy_into(self, target_dir: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:  # type: ignore[override]
        """
        Copy file(s) from this path into another directory.

        :param target_dir: Destination directory
        :param recursive: If True, copy directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.

        :raises NotADirectoryError: if ``target_dir`` is not a directory.
        """
        if isinstance(target_dir, str):
            target_dir = ObjectStoragePath(target_dir)
        if not target_dir.is_dir():
            raise NotADirectoryError(f"Destination {target_dir} is not a directory.")
        dst_path = target_dir / self.name
        self.copy(dst_path, recursive=recursive, **kwargs)

    def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:  # type: ignore[override]
        """
        Move file(s) from this path to another location.

        :param path: Destination path
        :param recursive: bool
            If True, move directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.
        """
        from airflow.sdk.lineage import get_hook_lineage_collector

        if isinstance(path, str):
            path = ObjectStoragePath(path)

        if self.samestore(path):
            get_hook_lineage_collector().add_input_asset(context=self, uri=str(self))
            get_hook_lineage_collector().add_output_asset(context=path, uri=str(path))
            return self.fs.move(self.path, path.path, recursive=recursive, **kwargs)

        # non-local copy
        self.copy(path, recursive=recursive, **kwargs)
        self.unlink()

    def move_into(self, target_dir: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:  # type: ignore[override]
        """
        Move file(s) from this path into another directory.

        :param target_dir: Destination directory
        :param recursive: bool
            If True, move directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.

        :raises NotADirectoryError: if ``target_dir`` is not a directory.
        """
        if isinstance(target_dir, str):
            target_dir = ObjectStoragePath(target_dir)
        if not target_dir.is_dir():
            raise NotADirectoryError(f"Destination {target_dir} is not a directory.")
        dst_path = target_dir / self.name
        self.move(dst_path, recursive=recursive, **kwargs)

    def serialize(self) -> dict[str, Any]:
        """
        Return a dict describing this path.

        The dict holds the string form under ``"path"``, the connection id under
        ``"conn_id"`` and the remaining storage options under ``"kwargs"``.
        """
        _kwargs = {**self.storage_options}
        conn_id = _kwargs.pop("conn_id", None)

        return {
            "path": str(self),
            "conn_id": conn_id,
            "kwargs": _kwargs,
        }

    @classmethod
    def deserialize(cls, data: dict, version: int) -> ObjectStoragePath:
        """
        Rebuild a path from the dict produced by :meth:`serialize`.

        :param data: Mapping with ``"path"`` and ``"kwargs"`` keys and an optional
            ``"conn_id"`` key. It is read without being modified.
        :param version: Version the data was serialized with.
        :raises ValueError: if ``version`` is newer than this class' ``__version__``.
        """
        if version > cls.__version__:
            raise ValueError(f"Cannot deserialize version {version} with version {cls.__version__}.")

        # read instead of pop so the caller's dict is left untouched
        _kwargs = data["kwargs"]
        path = data["path"]
        conn_id = data.get("conn_id")

        return ObjectStoragePath(path, conn_id=conn_id, **_kwargs)

    def __str__(self):
        """Return ``<protocol>://<conn_id>@<path>`` when both are set, else the parent's string form."""
        conn_id = self.storage_options.get("conn_id")
        if self.protocol and conn_id:
            return f"{self.protocol}://{conn_id}@{self.path}"
        return super().__str__()