Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/airflow/sdk/io/path.py: 32%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
18from __future__ import annotations
20import contextlib
21import os
22import shutil
23from collections.abc import Mapping
24from typing import TYPE_CHECKING, Any, ClassVar
25from urllib.parse import urlsplit
27from fsspec.utils import stringify_path
28from upath.implementations.cloud import CloudPath
29from upath.registry import get_upath_class
31from airflow.sdk.io.stat import stat_result
32from airflow.sdk.io.store import attach
34if TYPE_CHECKING:
35 from fsspec import AbstractFileSystem
38class _TrackingFileWrapper:
39 """Wrapper that tracks file operations to intercept lineage."""
41 def __init__(self, path: ObjectStoragePath, obj):
42 super().__init__()
43 self._path = path
44 self._obj = obj
46 def __getattr__(self, name):
47 from airflow.lineage.hook import get_hook_lineage_collector
49 if not callable(attr := getattr(self._obj, name)):
50 return attr
52 # If the attribute is a method, wrap it in another method to intercept the call
53 def wrapper(*args, **kwargs):
54 if name == "read":
55 get_hook_lineage_collector().add_input_asset(context=self._path, uri=str(self._path))
56 elif name == "write":
57 get_hook_lineage_collector().add_output_asset(context=self._path, uri=str(self._path))
58 result = attr(*args, **kwargs)
59 return result
61 return wrapper
63 # We need to explicitly implement `__iter__`
64 # because otherwise python iteration logic could use `__getitem__`
65 def __iter__(self):
66 return iter(self._obj)
68 def __getitem__(self, key):
69 # Intercept item access
70 return self._obj[key]
72 def __enter__(self):
73 self._obj.__enter__()
74 return self
76 def __exit__(self, exc_type, exc_val, exc_tb):
77 self._obj.__exit__(exc_type, exc_val, exc_tb)
class ObjectStoragePath(CloudPath):
    """A path-like object for object storage."""

    # Serialization format version; checked by deserialize() for compatibility.
    __version__: ClassVar[int] = 1

    # Disable upath's protocol-based subclass dispatch; this class serves all
    # protocols itself and resolves filesystems via the Airflow store registry.
    _protocol_dispatch = False

    sep: ClassVar[str] = "/"
    root_marker: ClassVar[str] = "/"

    # Slot for the lazily computed hash cache (see __hash__).
    __slots__ = ("_hash_cached",)

    @classmethod
    def _transform_init_args(
        cls,
        args: tuple[str | os.PathLike, ...],
        protocol: str,
        storage_options: dict[str, Any],
    ) -> tuple[tuple[str | os.PathLike, ...], str, dict[str, Any]]:
        """Extract conn_id from the URL and set it as a storage option."""
        if args:
            arg0 = args[0]
            parsed_url = urlsplit(stringify_path(arg0))
            # The userinfo portion of the netloc (before "@") carries the
            # Airflow connection id, e.g. "s3://my_conn@bucket/key".
            userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
            if have_info:
                # Empty userinfo ("@bucket") maps to conn_id=None.
                storage_options.setdefault("conn_id", userinfo or None)
                # Strip the conn_id so the filesystem sees a clean netloc.
                parsed_url = parsed_url._replace(netloc=hostinfo)
                args = (parsed_url.geturl(),) + args[1:]
            protocol = protocol or parsed_url.scheme
        return args, protocol, storage_options

    @classmethod
    def _fs_factory(
        cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
    ) -> AbstractFileSystem:
        # Resolve the filesystem through the Airflow object-store registry so
        # that connection-specific configuration is applied.
        return attach(protocol or "file", storage_options.get("conn_id")).fs

    def __hash__(self) -> int:
        """Hash of the full string form, computed lazily and cached."""
        self._hash_cached: int
        try:
            return self._hash_cached
        except AttributeError:
            # EAFP: first call computes and caches; later calls hit the cache.
            self._hash_cached = hash(str(self))
            return self._hash_cached

    def __eq__(self, other: Any) -> bool:
        """Paths are equal when they share a store and render identically."""
        return self.samestore(other) and str(self) == str(other)

    def samestore(self, other: Any) -> bool:
        """Return True if *other* uses the same protocol and connection id."""
        return (
            isinstance(other, ObjectStoragePath)
            and self.protocol == other.protocol
            and self.storage_options.get("conn_id") == other.storage_options.get("conn_id")
        )

    @property
    def container(self) -> str:
        # Alias of bucket (container is the Azure-style term).
        return self.bucket

    @property
    def bucket(self) -> str:
        """Bucket/container name (the URL netloc), or "" when unset."""
        if self._url:
            return self._url.netloc
        return ""

    @property
    def key(self) -> str:
        """Object key of this path, relative to the bucket."""
        if self._url:
            # per convention, we strip the leading slashes to ensure a relative key is returned
            # we keep the trailing slash to allow for directory-like semantics
            return self._url.path.lstrip(self.sep)
        return ""

    @property
    def namespace(self) -> str:
        """Namespace string for this path, e.g. "s3://bucket"."""
        return f"{self.protocol}://{self.bucket}" if self.bucket else self.protocol

    def open(self, mode="r", **kwargs):
        """Open the file pointed to by this path."""
        # fsspec expects "block_size" where pathlib callers pass "buffering".
        kwargs.setdefault("block_size", kwargs.pop("buffering", None))
        # Wrap the handle so read/write operations emit lineage events.
        return _TrackingFileWrapper(self, self.fs.open(self.path, mode=mode, **kwargs))

    def stat(self) -> stat_result:  # type: ignore[override]
        """Call ``stat`` and return the result."""
        # stat_result additionally records protocol and conn_id so results
        # from different stores can be told apart (see samefile()).
        return stat_result(
            self.fs.stat(self.path),
            protocol=self.protocol,
            conn_id=self.storage_options.get("conn_id"),
        )

    def samefile(self, other_path: Any) -> bool:
        """Return whether other_path is the same or not as this file."""
        if not isinstance(other_path, ObjectStoragePath):
            return False

        st = self.stat()
        other_st = other_path.stat()

        # Same store (protocol + conn_id) and same inode.
        return (
            st["protocol"] == other_st["protocol"]
            and st["conn_id"] == other_st["conn_id"]
            and st["ino"] == other_st["ino"]
        )

    def _scandir(self):
        # Emulate os.scandir(), which returns an object that can be used as a
        # context manager.
        return contextlib.nullcontext(self.iterdir())

    def replace(self, target) -> ObjectStoragePath:
        """
        Rename this path to the target path, overwriting if that path exists.

        The target path may be absolute or relative. Relative paths are
        interpreted relative to the current working directory, *not* the
        directory of the Path object.

        Returns the new Path instance pointing to the target path.
        """
        return self.rename(target)

    @classmethod
    def cwd(cls):
        """Return the current working directory as a local path object.

        Only meaningful on the base class; subclasses bound to a remote
        protocol have no working-directory notion.
        """
        if cls is ObjectStoragePath:
            return get_upath_class("").cwd()
        raise NotImplementedError

    @classmethod
    def home(cls):
        """Return the user's home directory as a local path object.

        Only meaningful on the base class, like cwd().
        """
        if cls is ObjectStoragePath:
            return get_upath_class("").home()
        raise NotImplementedError

    # EXTENDED OPERATIONS

    def ukey(self) -> str:
        """Hash of file properties, to tell if it has changed."""
        return self.fs.ukey(self.path)

    def checksum(self) -> int:
        """Return the checksum of the file at this path."""
        # we directly access the fs here to avoid changing the abstract interface
        return self.fs.checksum(self.path)

    def read_block(self, offset: int, length: int, delimiter=None):
        r"""
        Read a block of bytes.

        Starting at ``offset`` of the file, read ``length`` bytes. If
        ``delimiter`` is set then we ensure that the read starts and stops at
        delimiter boundaries that follow the locations ``offset`` and ``offset
        + length``. If ``offset`` is zero then we start at zero. The
        bytestring returned WILL include the end delimiter string.

        If offset+length is beyond the eof, reads to eof.

        :param offset: int
            Byte offset to start read
        :param length: int
            Number of bytes to read. If None, read to the end.
        :param delimiter: bytes (optional)
            Ensure reading starts and stops at delimiter bytestring

        Examples
        --------
        .. code-block:: pycon

            # Read the first 13 bytes (no delimiter)
            >>> read_block(0, 13)
            b'Alice, 100\nBo'

            # Read first 13 bytes, but force newline boundaries
            >>> read_block(0, 13, delimiter=b"\n")
            b'Alice, 100\nBob, 200\n'

            # Read until EOF, but only stop at newline
            >>> read_block(0, None, delimiter=b"\n")
            b'Alice, 100\nBob, 200\nCharlie, 300'

        See Also
        --------
        :func:`fsspec.utils.read_block`
        """
        return self.fs.read_block(self.path, offset=offset, length=length, delimiter=delimiter)

    def sign(self, expiration: int = 100, **kwargs):
        """
        Create a signed URL representing the given path.

        Some implementations allow temporary URLs to be generated, as a
        way of delegating credentials.

        :param path: str
            The path on the filesystem
        :param expiration: int
            Number of seconds to enable the URL for (if supported)

        :returns URL: str
            The signed URL

        :raises NotImplementedError: if the method is not implemented for a store
        """
        return self.fs.sign(self.path, expiration=expiration, **kwargs)

    def size(self) -> int:
        """Size in bytes of the file at this path."""
        return self.fs.size(self.path)

    def _cp_file(self, dst: ObjectStoragePath, **kwargs):
        """Copy a single file from this path to another location by streaming the data."""
        # create the directory or bucket if required
        if dst.key.endswith(self.sep) or not dst.key:
            dst.mkdir(exist_ok=True, parents=True)
            # Destination is directory-like: copy under it with the source key.
            dst = dst / self.key
        elif dst.is_dir():
            dst = dst / self.key

        # streaming copy
        with self.open("rb") as f1, dst.open("wb") as f2:
            # make use of system dependent buffer size
            shutil.copyfileobj(f1, f2, **kwargs)

    def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:
        """
        Copy file(s) from this path to another location.

        For remote to remote copies, the key used for the destination will be the same as the source.
        So that s3://src_bucket/foo/bar will be copied to gcs://dst_bucket/foo/bar and not
        gcs://dst_bucket/bar.

        :param dst: Destination path
        :param recursive: If True, copy directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.
        """
        from airflow.lineage.hook import get_hook_lineage_collector

        if isinstance(dst, str):
            dst = ObjectStoragePath(dst)

        if self.samestore(dst) or self.protocol == "file" or dst.protocol == "file":
            # only emit this in "optimized" variants - else lineage will be captured by file writes/reads
            get_hook_lineage_collector().add_input_asset(context=self, uri=str(self))
            get_hook_lineage_collector().add_output_asset(context=dst, uri=str(dst))

        # same -> same
        if self.samestore(dst):
            self.fs.copy(self.path, dst.path, recursive=recursive, **kwargs)
            return

        # use optimized path for local -> remote or remote -> local
        if self.protocol == "file":
            dst.fs.put(self.path, dst.path, recursive=recursive, **kwargs)
            return

        if dst.protocol == "file":
            self.fs.get(self.path, dst.path, recursive=recursive, **kwargs)
            return

        if not self.exists():
            raise FileNotFoundError(f"{self} does not exist")

        # remote dir -> remote dir
        if self.is_dir():
            if dst.is_file():
                raise ValueError("Cannot copy directory to a file.")

            dst.mkdir(exist_ok=True, parents=True)

            out = self.fs.expand_path(self.path, recursive=True, **kwargs)

            for path in out:
                # this check prevents one extra call to is_dir() as
                # glob returns self as well
                if path == self.path:
                    continue

                src_obj = ObjectStoragePath(
                    path,
                    protocol=self.protocol,
                    conn_id=self.storage_options.get("conn_id"),
                )

                # skip directories, empty directories will not be created
                if src_obj.is_dir():
                    continue

                src_obj._cp_file(dst)
            return

        # remote file -> remote dir
        self._cp_file(dst, **kwargs)

    def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs) -> None:
        """
        Move file(s) from this path to another location.

        :param path: Destination path
        :param recursive: bool
            If True, move directories recursively.

        kwargs: Additional keyword arguments to be passed to the underlying implementation.
        """
        from airflow.lineage.hook import get_hook_lineage_collector

        if isinstance(path, str):
            path = ObjectStoragePath(path)

        if self.samestore(path):
            # Same-store move is handled natively by the filesystem; lineage is
            # emitted here because no read/write pass through open().
            get_hook_lineage_collector().add_input_asset(context=self, uri=str(self))
            get_hook_lineage_collector().add_output_asset(context=path, uri=str(path))
            return self.fs.move(self.path, path.path, recursive=recursive, **kwargs)

        # non-local copy
        self.copy(path, recursive=recursive, **kwargs)
        self.unlink()

    def serialize(self) -> dict[str, Any]:
        """Return a JSON-compatible representation (path, conn_id, remaining kwargs)."""
        _kwargs = {**self.storage_options}
        # conn_id is stored separately from the other storage options.
        conn_id = _kwargs.pop("conn_id", None)

        return {
            "path": str(self),
            "conn_id": conn_id,
            "kwargs": _kwargs,
        }

    @classmethod
    def deserialize(cls, data: dict, version: int) -> ObjectStoragePath:
        """Rebuild a path from serialize() output.

        :raises ValueError: if *version* is newer than this class supports.

        NOTE: the pops below mutate the *data* dict passed by the caller.
        """
        if version > cls.__version__:
            raise ValueError(f"Cannot deserialize version {version} with version {cls.__version__}.")

        _kwargs = data.pop("kwargs")
        path = data.pop("path")
        conn_id = data.pop("conn_id", None)

        return ObjectStoragePath(path, conn_id=conn_id, **_kwargs)

    def __str__(self):
        conn_id = self.storage_options.get("conn_id")
        if self._protocol and conn_id:
            # Render the conn_id back into the URL userinfo position so the
            # string form round-trips through _transform_init_args.
            return f"{self._protocol}://{conn_id}@{self.path}"
        return super().__str__()