1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import logging
21import operator
22import os
23import urllib.parse
24import warnings
25from collections.abc import Callable
26from typing import TYPE_CHECKING, Any, ClassVar, overload
27
28import attrs
29
if TYPE_CHECKING:
    from collections.abc import Collection
    from urllib.parse import SplitResult

    from pydantic.types import JsonValue
    from typing_extensions import Self

    from airflow.sdk.api.datamodels._generated import AssetProfile
    from airflow.sdk.io.path import ObjectStoragePath
    from airflow.triggers.base import BaseEventTrigger

    # For static type checkers, classes listing ``AttrsInstance`` as a base are
    # seen as implementing the attrs instance protocol.
    AttrsInstance = attrs.AttrsInstance
else:
    # At runtime the alias is plain ``object``, so using it as a base class adds
    # nothing to the MRO.
    AttrsInstance = object
44
45
# Names exported by ``from <this module> import *``.
__all__ = [
    "Asset",
    "Dataset",
    "Model",
    "AssetAlias",
    "AssetAll",
    "AssetAny",
    "AssetNameRef",
    "AssetRef",
    "AssetUriRef",
    "AssetWatcher",
]

from airflow.sdk.configuration import conf

log = logging.getLogger(__name__)


# Read once, when this module is imported. ``_validate_identifier`` checks
# whether it starts with "mysql" to decide whether identifiers must be ASCII.
# NOTE(review): the fallback presumably applies when the option is not
# configured; it does not start with "mysql", so the ASCII check is then skipped.
SQL_ALCHEMY_CONN = conf.get("database", "SQL_ALCHEMY_CONN", fallback="NOT AVAILABLE")
65
66
@attrs.define(frozen=True)
class AssetUniqueKey(AttrsInstance):
    """
    Immutable (name, uri) pair that uniquely identifies an asset.

    :meta private:
    """

    name: str
    uri: str

    @classmethod
    def from_asset(cls, asset: Asset) -> Self:
        """Build a key from the ``name`` and ``uri`` of *asset*."""
        asset_name, asset_uri = asset.name, asset.uri
        return cls(name=asset_name, uri=asset_uri)

    def to_asset(self) -> Asset:
        """Turn this key back into an :class:`Asset` with the same name and URI."""
        return Asset(name=self.name, uri=self.uri)

    @staticmethod
    def from_profile(profile: AssetProfile) -> AssetUniqueKey:
        """
        Build a key from an asset profile.

        When only one of ``name`` / ``uri`` is set on the profile, it is used
        for both parts of the key.

        :raises ValueError: if the profile has neither a name nor a URI.
        """
        name, uri = profile.name, profile.uri
        if not name and not uri:
            raise ValueError("name and uri cannot both be empty")
        return AssetUniqueKey(name=name or uri, uri=uri or name)
96
97
@attrs.define(frozen=True)
class AssetAliasUniqueKey:
    """
    Immutable key that uniquely identifies an asset alias by its name.

    :meta private:
    """

    name: str

    @classmethod
    def from_asset_alias(cls, asset_alias: AssetAlias) -> Self:
        """Build a key from the name of *asset_alias*."""
        alias_name = asset_alias.name
        return cls(name=alias_name)

    def to_asset_alias(self) -> AssetAlias:
        """Turn this key back into an :class:`AssetAlias` with the same name."""
        alias_name = self.name
        return AssetAlias(name=alias_name)
114
115
# Either kind of unique key: one identifying an asset or one identifying an asset alias.
BaseAssetUniqueKey = AssetUniqueKey | AssetAliasUniqueKey
117
118
def normalize_noop(parts: SplitResult) -> SplitResult:
    """
    Return a :class:`~urllib.parse.SplitResult` unchanged.

    Placeholder normalizer for schemes that need no extra normalization; it is
    what ``_get_uri_normalizer`` returns for ``file`` URIs.

    :param parts: An already split URI.
    :return: The very same *parts* object.

    :meta private:
    """
    return parts
126
127
def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
    """
    Look up the URI normalizer to use for *scheme*.

    ``file`` URIs get the no-op normalizer. Any other scheme is looked up in
    ``ProvidersManager().asset_uri_handlers``; presumably a dict-like mapping,
    so ``None`` comes back when no provider registers the scheme.
    """
    if scheme != "file":
        from airflow.providers_manager import ProvidersManager

        handlers = ProvidersManager().asset_uri_handlers
        return handlers.get(scheme)
    return normalize_noop
134
135
def _get_normalized_scheme(uri: str) -> str:
    """Return the scheme of *uri* in lower case, or ``""`` if it has none."""
    return urllib.parse.urlsplit(uri).scheme.lower()
139
140
def _sanitize_uri(inp: str | ObjectStoragePath) -> str:
    """
    Validate and normalize an asset URI.

    The input is returned unchanged when it does not look like a URI (no scheme
    and no netloc), when its scheme is empty, or when the scheme starts with
    ``x-``. Otherwise, the following steps apply:

    * the scheme is lower-cased;
    * user info is dropped (with a warning) if a password is present;
    * query parameters are sorted;
    * trailing slashes and the fragment are removed;
    * the scheme-specific normalizer, if there is one, is applied.

    :raises ValueError: if the URI uses the reserved ``airflow`` scheme.
    """
    uri = str(inp)
    split = urllib.parse.urlsplit(uri)
    if not (split.scheme or split.netloc):
        # Does not look like a URI.
        return uri
    scheme = split.scheme.lower()
    if not scheme or scheme.startswith("x-"):
        return uri
    if scheme == "airflow":
        raise ValueError("Asset scheme 'airflow' is reserved")

    netloc = split.netloc
    if split.password:
        # TODO: Collect this into a DagWarning.
        warnings.warn(
            "An Asset URI should not contain a password. User info has been automatically dropped.",
            UserWarning,
            stacklevel=3,
        )
        netloc = netloc.rpartition("@")[2]

    query = ""
    if split.query:
        query = urllib.parse.urlencode(sorted(urllib.parse.parse_qsl(split.query)))

    split = split._replace(
        scheme=scheme,
        netloc=netloc,
        path=split.path.rstrip("/") or "/",  # Remove all trailing slashes.
        query=query,
        fragment="",  # Ignore any fragments.
    )
    normalizer = _get_uri_normalizer(scheme)
    if normalizer is not None:
        split = normalizer(split)
    return urllib.parse.urlunsplit(split)
182
183
def _validate_identifier(instance, attribute, value):
    """
    attrs validator for identifier-like string fields.

    Empty strings pass this check.

    :raises ValueError: if *value* is not a string, is longer than 1500
        characters, consists only of whitespace, or (when the configured
        connection string starts with ``mysql``) contains non-ASCII characters.
    :return: *value*, unchanged.
    """
    if not isinstance(value, str):
        problem = "must be a string"
    elif len(value) > 1500:
        problem = "cannot exceed 1500 characters"
    elif value.isspace():
        problem = "cannot be just whitespace"
    elif SQL_ALCHEMY_CONN.startswith("mysql") and not value.isascii():
        # On MySQL the name (and group, asset values etc.) are stored with the
        # latin1_general_cs collation, so non-ASCII values are only rejected there.
        problem = "must only consist of ASCII characters"
    else:
        return value
    raise ValueError(f"{type(instance).__name__} {attribute.name} {problem}")
196
197
def _validate_non_empty_identifier(instance, attribute, value):
    """
    attrs validator: the ``_validate_identifier`` checks, plus non-emptiness.

    :raises ValueError: on any ``_validate_identifier`` failure, or when the
        validated value is empty.
    :return: *value*, unchanged.
    """
    checked = _validate_identifier(instance, attribute, value)
    if checked:
        return value
    raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be empty")
202
203
def _validate_asset_name(instance, attribute, value):
    """
    attrs validator for ``Asset.name``.

    Applies the non-empty identifier checks, then rejects the reserved names
    ``"self"`` and ``"context"``.

    :raises ValueError: if a check fails or the name is reserved.
    :return: *value*, unchanged.
    """
    _validate_non_empty_identifier(instance, attribute, value)
    if value in ("self", "context"):
        raise ValueError(f"prohibited name for asset: {value}")
    return value
209
210
def _set_extra_default(extra: dict[str, JsonValue] | None) -> dict:
    """
    Replace a ``None`` extra with an empty dict.

    Callers can keep writing ``Asset(uri, extra=None)`` while the ``extra``
    attribute itself is never ``None``. Any other value is passed through as
    the same object (no copy is made).
    """
    return {} if extra is None else extra
221
222
class BaseAsset:
    """
    Protocol for all asset triggers to use in ``DAG(schedule=...)``.

    Supports ``|`` (either side satisfies the condition) and ``&`` (both sides
    must be satisfied) between asset expressions.

    :meta private:
    """

    def __or__(self, other: BaseAsset) -> BaseAsset:
        """Combine with *other* into an ``AssetAny`` condition."""
        if isinstance(other, BaseAsset):
            return AssetAny(self, other)
        return NotImplemented

    def __and__(self, other: BaseAsset) -> BaseAsset:
        """Combine with *other* into an ``AssetAll`` condition."""
        if isinstance(other, BaseAsset):
            return AssetAll(self, other)
        return NotImplemented
239
240
def _validate_asset_watcher_trigger(instance, attribute, value):
    """
    attrs validator for ``AssetWatcher.trigger``.

    :raises ValueError: if *value* is not a ``BaseEventTrigger`` instance.
    :return: *value*, unchanged.
    """
    # Imported at call time rather than module import time (presumably to
    # avoid an import cycle).
    from airflow.triggers.base import BaseEventTrigger

    if isinstance(value, BaseEventTrigger):
        return value
    raise ValueError("Asset watcher trigger must inherit BaseEventTrigger")
247
248
@attrs.define
class AssetWatcher:
    """
    A representation of an asset watcher.

    ``name`` uniquely identifies the watch. ``trigger`` is validated by
    ``_validate_asset_watcher_trigger`` and must be a ``BaseEventTrigger``.
    """

    name: str = attrs.field()
    trigger: BaseEventTrigger = attrs.field(validator=_validate_asset_watcher_trigger)
255
256
@attrs.define(init=False, unsafe_hash=False)
class Asset(os.PathLike, BaseAsset):
    """
    A representation of data asset dependencies between workflows.

    An asset is identified by its ``name`` and ``uri``. When only one of them is
    passed to the constructor, the other is set from it. ``uri`` is converted
    with ``_sanitize_uri``, ``group`` defaults to the class's ``asset_type``, and
    ``extra`` / ``watchers`` default to an empty dict / list.
    """

    name: str = attrs.field(
        validator=[_validate_asset_name],
    )
    uri: str = attrs.field(
        validator=[_validate_non_empty_identifier],
        converter=_sanitize_uri,
    )
    group: str = attrs.field(
        default=attrs.Factory(operator.attrgetter("asset_type"), takes_self=True),
        validator=[_validate_identifier],
    )
    extra: dict[str, JsonValue] = attrs.field(
        factory=dict,
        converter=_set_extra_default,
    )
    watchers: list[AssetWatcher] = attrs.field(
        factory=list,
    )

    asset_type: ClassVar[str] = "asset"
    __version__: ClassVar[int] = 1

    @overload
    def __init__(
        self,
        name: str,
        uri: str | ObjectStoragePath,
        *,
        group: str = ...,
        extra: dict[str, JsonValue] | None = None,
        watchers: list[AssetWatcher] = ...,
    ) -> None:
        """Canonical; both name and uri are provided."""

    @overload
    def __init__(
        self,
        name: str,
        *,
        group: str = ...,
        extra: dict[str, JsonValue] | None = None,
        watchers: list[AssetWatcher] = ...,
    ) -> None:
        """It's possible to only provide the name, either by keyword or as the only positional argument."""

    @overload
    def __init__(
        self,
        *,
        uri: str | ObjectStoragePath,
        group: str = ...,
        extra: dict[str, JsonValue] | None = None,
        watchers: list[AssetWatcher] = ...,
    ) -> None:
        """It's possible to only provide the URI as a keyword argument."""

    def __init__(
        self,
        name: str | None = None,
        uri: str | ObjectStoragePath | None = None,
        *,
        group: str | None = None,
        extra: dict[str, JsonValue] | None = None,
        watchers: list[AssetWatcher] | None = None,
    ) -> None:
        """
        Create an asset from a name, a URI, or both.

        :raises TypeError: if neither ``name`` nor ``uri`` is given.
        """
        if name is None and uri is None:
            raise TypeError("Asset() requires either 'name' or 'uri'")
        if name is None:
            name = str(uri)
        elif uri is None:
            uri = name

        if TYPE_CHECKING:
            assert name is not None
            assert uri is not None

        # attrs defaults (and factories) do not kick in if any value is passed
        # for the argument, so only forward the optional arguments that were
        # actually given; the others fall back to the field defaults above.
        kwargs: dict[str, Any] = {}
        if group is not None:
            kwargs["group"] = group
        if extra is not None:
            kwargs["extra"] = extra
        if watchers is not None:
            kwargs["watchers"] = watchers

        self.__attrs_init__(name=name, uri=uri, **kwargs)

    @overload
    @staticmethod
    def ref(*, name: str) -> AssetNameRef: ...

    @overload
    @staticmethod
    def ref(*, uri: str) -> AssetUriRef: ...

    @staticmethod
    def ref(*, name: str = "", uri: str = "") -> AssetRef:
        """
        Create a reference to an asset by name or by URI.

        Exactly one of ``name`` and ``uri`` must be non-empty.

        :raises TypeError: if both or neither of them are given.
        """
        if name and uri:
            raise TypeError("Asset reference must be made to either name or URI, not both")
        if name:
            return AssetNameRef(name)
        if uri:
            return AssetUriRef(uri)
        raise TypeError("Asset reference expects keyword argument 'name' or 'uri'")

    def __fspath__(self) -> str:
        return self.uri

    def __eq__(self, other: Any) -> bool:
        # The Asset class can be subclassed, and we don't want fields added by a
        # subclass to break equality. This explicitly filters out only fields
        # defined by the Asset class for comparison.
        if not isinstance(other, Asset):
            return NotImplemented
        f = attrs.filters.include(*attrs.fields_dict(Asset))
        return attrs.asdict(self, filter=f) == attrs.asdict(other, filter=f)

    def __hash__(self) -> int:
        """
        Hash the asset by its identifying scalar fields.

        ``attrs.asdict`` returns a dict, which cannot be hashed, and the
        ``extra`` (dict) and ``watchers`` (list) fields are unhashable too.
        Hashing only ``name``, ``uri`` and ``group`` is still consistent with
        ``__eq__``. Every one of those fields is compared there, so assets
        that compare equal always hash equal.
        """
        return hash((self.name, self.uri, self.group))

    @property
    def normalized_uri(self) -> str | None:
        """
        Returns the normalized and AIP-60 compliant URI whenever possible.

        If we can't retrieve the scheme from URI or no normalizer is provided or if parsing fails,
        it returns None.

        If a normalizer for the scheme exists and parsing is successful we return the normalizer result.
        """
        if not (normalized_scheme := _get_normalized_scheme(self.uri)):
            return None

        if (normalizer := _get_uri_normalizer(normalized_scheme)) is None:
            return None
        parsed = urllib.parse.urlsplit(self.uri)
        try:
            normalized_uri = normalizer(parsed)
            return urllib.parse.urlunsplit(normalized_uri)
        except ValueError:
            return None
404
405
class AssetRef(BaseAsset, AttrsInstance):
    """
    Reference to an asset.

    Do not instantiate this class directly; call ``Asset.ref`` to obtain one of
    the concrete subclasses below.

    :meta private:
    """


@attrs.define(unsafe_hash=True)
class AssetNameRef(AssetRef):
    """Reference to an asset by its name."""

    name: str


@attrs.define(unsafe_hash=True)
class AssetUriRef(AssetRef):
    """Reference to an asset by its URI."""

    uri: str
429
430
class Dataset(Asset):
    """
    A representation of dataset dependencies between workflows.

    Behaves like :class:`Asset`; only ``asset_type`` differs, and that value is
    also the default ``group`` of new instances.
    """

    asset_type: ClassVar[str] = "dataset"


class Model(Asset):
    """
    A representation of model dependencies between workflows.

    Behaves like :class:`Asset`; only ``asset_type`` differs, and that value is
    also the default ``group`` of new instances.
    """

    asset_type: ClassVar[str] = "model"
441
442
@attrs.define(unsafe_hash=True)
class AssetAlias(BaseAsset):
    """
    Named stand-in for assets.

    An asset alias can be used to create assets at task execution time.
    ``name`` must be a non-empty identifier. ``group`` is keyword-only and
    defaults to ``"asset"``.
    """

    name: str = attrs.field(validator=_validate_non_empty_identifier)
    group: str = attrs.field(default="asset", kw_only=True, validator=_validate_identifier)
453
454
class AssetBooleanCondition(BaseAsset):
    """
    Shared base of the asset boolean conditions (``AssetAny`` / ``AssetAll``).

    Stores its operands as a tuple. Two conditions are equal when the other
    one is an instance of this condition's own type and holds equal operands.

    :meta private:
    """

    objects: Collection[BaseAsset]

    def __init__(self, *objects: BaseAsset) -> None:
        for operand in objects:
            if not isinstance(operand, BaseAsset):
                raise TypeError("expect asset expressions in condition")
        self.objects = objects

    def __eq__(self, other: object) -> bool:
        if isinstance(other, type(self)):
            return self.objects == other.objects
        return NotImplemented

    def __hash__(self) -> int:
        operands = tuple(self.objects)
        return hash(operands)
476
477
class AssetAny(AssetBooleanCondition):
    """Use to combine assets schedule references in an "or" relationship."""

    def __or__(self, other: BaseAsset) -> BaseAsset:
        if isinstance(other, BaseAsset):
            # Flatten rather than nest: (X | Y) | Z becomes AssetAny(X, Y, Z).
            return AssetAny(*self.objects, other)
        return NotImplemented

    def __repr__(self) -> str:
        inner = ", ".join(str(operand) for operand in self.objects)
        return f"AssetAny({inner})"
489
490
class AssetAll(AssetBooleanCondition):
    """Use to combine assets schedule references in an "and" relationship."""

    # NOTE(review): nothing in this module reads ``agg_func``; kept as-is for
    # any external users of the attribute.
    agg_func = all  # type: ignore[assignment]

    def __and__(self, other: BaseAsset) -> BaseAsset:
        if isinstance(other, BaseAsset):
            # Flatten rather than nest: (X & Y) & Z becomes AssetAll(X, Y, Z).
            return AssetAll(*self.objects, other)
        return NotImplemented

    def __repr__(self) -> str:
        inner = ", ".join(str(operand) for operand in self.objects)
        return f"AssetAll({inner})"
504
505
# Use the module-level ``AttrsInstance`` alias, as ``AssetUniqueKey`` does. Type
# checkers see the attrs protocol, while at runtime the base is plain ``object``
# instead of pulling ``typing.Protocol`` machinery into the MRO.
@attrs.define
class AssetAliasEvent(AttrsInstance):
    """Representation of asset event to be triggered by an asset alias."""

    # Presumably the name of the ``AssetAlias`` the event is emitted through.
    source_alias_name: str
    # Key of the destination asset the event is attached to.
    dest_asset_key: AssetUniqueKey
    dest_asset_extra: dict[str, JsonValue]
    extra: dict[str, JsonValue]