1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
20import json
21import logging
22import operator
23import os
24import urllib.parse
25import warnings
26from collections.abc import Callable
27from typing import TYPE_CHECKING, Any, ClassVar, overload
28
29import attrs
30
31from airflow.sdk.providers_manager_runtime import ProvidersManagerTaskRuntime
32
33if TYPE_CHECKING:
34 from collections.abc import Collection
35 from urllib.parse import SplitResult
36
37 from pydantic.types import JsonValue
38 from typing_extensions import Self
39
40 from airflow.sdk.api.datamodels._generated import AssetProfile
41 from airflow.sdk.io.path import ObjectStoragePath
42 from airflow.triggers.base import BaseEventTrigger
43
44 AttrsInstance = attrs.AttrsInstance
45else:
46 AttrsInstance = object
47
48
# Public names exported by ``from ... import *``. Helpers such as the ``*UniqueKey``
# classes and the ``_validate_*`` functions are intentionally not listed here.
__all__ = [
    "Asset",
    "Dataset",
    "Model",
    "AssetAlias",
    "AssetAll",
    "AssetAny",
    "AssetNameRef",
    "AssetRef",
    "AssetUriRef",
    "AssetWatcher",
]
61
62from airflow.sdk.configuration import conf
63
log = logging.getLogger(__name__)


# Read once at import time. ``_validate_identifier`` checks its prefix to apply the
# MySQL-only ASCII restriction; "NOT AVAILABLE" is used when the option is not set.
SQL_ALCHEMY_CONN = conf.get("database", "SQL_ALCHEMY_CONN", fallback="NOT AVAILABLE")
68
69
@attrs.define(frozen=True)
class AssetUniqueKey(AttrsInstance):
    """
    Identifying columns (``name`` and ``uri``) of a unique asset.

    :meta private:
    """

    name: str
    uri: str

    @classmethod
    def from_asset(cls, asset: Asset) -> Self:
        """Build the key from an existing asset's name and URI."""
        return cls(name=asset.name, uri=asset.uri)

    def to_asset(self) -> Asset:
        """Build an ``Asset`` carrying this key's name and URI."""
        return Asset(name=self.name, uri=self.uri)

    @staticmethod
    def from_profile(profile: AssetProfile) -> AssetUniqueKey:
        """
        Build the key from an asset profile.

        When the profile has only one of ``name``/``uri`` set, that value is used
        for both columns.

        :raises ValueError: If the profile has neither a name nor a URI.
        """
        name, uri = profile.name, profile.uri
        if name and uri:
            return AssetUniqueKey(name=name, uri=uri)
        # At most one of them is truthy here; the name takes precedence.
        fallback = name or uri
        if not fallback:
            raise ValueError("name and uri cannot both be empty")
        return AssetUniqueKey(name=fallback, uri=fallback)
99
100
@attrs.define(frozen=True)
class AssetAliasUniqueKey:
    """
    Identifying column (``name``) of a unique asset alias.

    :meta private:
    """

    name: str

    @classmethod
    def from_asset_alias(cls, asset_alias: AssetAlias) -> Self:
        """Build the key from an asset alias's name."""
        alias_name = asset_alias.name
        return cls(name=alias_name)

    def to_asset_alias(self) -> AssetAlias:
        """Build an ``AssetAlias`` with this key's name and the default group."""
        return AssetAlias(name=self.name)
117
118
# Either kind of unique key: an asset key (name + uri) or an asset-alias key (name only).
BaseAssetUniqueKey = AssetUniqueKey | AssetAliasUniqueKey
120
121
def normalize_noop(parts: SplitResult) -> SplitResult:
    """
    Return *parts* unchanged; a placeholder :class:`~urllib.parse.SplitResult` normalizer.

    Used for schemes (such as ``file``) that need no normalization beyond the
    generic URI sanitization.

    :meta private:
    """
    return parts
129
130
def _get_uri_normalizer(scheme: str) -> Callable[[SplitResult], SplitResult] | None:
    """
    Look up the URI normalizer for *scheme*.

    ``file`` is handled in-tree by a no-op normalizer. Every other scheme is
    looked up in the providers' registered asset URI handlers, and ``None`` is
    returned when no handler is registered.
    """
    if scheme != "file":
        return ProvidersManagerTaskRuntime().asset_uri_handlers.get(scheme)
    return normalize_noop
136
137
138def _get_normalized_scheme(uri: str) -> str:
139 parsed = urllib.parse.urlsplit(uri)
140 return parsed.scheme.lower()
141
142
143def _sanitize_uri(inp: str | ObjectStoragePath) -> str:
144 """
145 Sanitize an asset URI.
146
147 This checks for URI validity, and normalizes the URI if needed. A fully
148 normalized URI is returned.
149 """
150 uri = str(inp)
151 parsed = urllib.parse.urlsplit(uri)
152 if not parsed.scheme and not parsed.netloc: # Does not look like a URI.
153 return uri
154 if not (normalized_scheme := _get_normalized_scheme(uri)):
155 return uri
156 if normalized_scheme.startswith("x-"):
157 return uri
158 if normalized_scheme == "airflow":
159 raise ValueError("Asset scheme 'airflow' is reserved")
160 if parsed.password:
161 # TODO: Collect this into a DagWarning.
162 warnings.warn(
163 "An Asset URI should not contain a password. User info has been automatically dropped.",
164 UserWarning,
165 stacklevel=3,
166 )
167 _, _, normalized_netloc = parsed.netloc.rpartition("@")
168 else:
169 normalized_netloc = parsed.netloc
170 if parsed.query:
171 normalized_query = urllib.parse.urlencode(sorted(urllib.parse.parse_qsl(parsed.query)))
172 else:
173 normalized_query = ""
174 parsed = parsed._replace(
175 scheme=normalized_scheme,
176 netloc=normalized_netloc,
177 path=parsed.path.rstrip("/") or "/", # Remove all trailing slashes.
178 query=normalized_query,
179 fragment="", # Ignore any fragments.
180 )
181 if (normalizer := _get_uri_normalizer(normalized_scheme)) is not None:
182 parsed = normalizer(parsed)
183 return urllib.parse.urlunsplit(parsed)
184
185
186def _validate_identifier(instance, attribute, value):
187 if not isinstance(value, str):
188 raise ValueError(f"{type(instance).__name__} {attribute.name} must be a string")
189 if len(value) > 1500:
190 raise ValueError(f"{type(instance).__name__} {attribute.name} cannot exceed 1500 characters")
191 if value.isspace():
192 raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be just whitespace")
193 # We use latin1_general_cs to store the name (and group, asset values etc.) on MySQL.
194 # relaxing this check for non mysql backend
195 if SQL_ALCHEMY_CONN.startswith("mysql") and not value.isascii():
196 raise ValueError(f"{type(instance).__name__} {attribute.name} must only consist of ASCII characters")
197 return value
198
199
def _validate_non_empty_identifier(instance, attribute, value):
    """
    Run ``_validate_identifier`` and additionally reject empty values.

    :return: *value*, unchanged.
    :raises ValueError: On any ``_validate_identifier`` failure, or if *value* is empty.
    """
    checked = _validate_identifier(instance, attribute, value)
    if checked:
        return value
    raise ValueError(f"{type(instance).__name__} {attribute.name} cannot be empty")
204
205
def _validate_asset_name(instance, attribute, value):
    """
    Validate an asset name: a non-empty identifier that is not a reserved word.

    :return: *value*, unchanged.
    :raises ValueError: If the identifier checks fail, or if *value* is ``"self"``
        or ``"context"``.
    """
    _validate_non_empty_identifier(instance, attribute, value)
    # NOTE(review): presumably reserved because these names clash with names used
    # elsewhere (e.g. task callables) -- confirm the rationale.
    if value in ("self", "context"):
        raise ValueError(f"prohibited name for asset: {value}")
    return value
211
212
213def _set_extra_default(extra: dict[str, JsonValue] | None) -> dict:
214 """
215 Automatically convert None to an empty dict.
216
217 This allows the caller site to continue doing ``Asset(uri, extra=None)``,
218 but still allow the ``extra`` attribute to always be a dict.
219 """
220 if extra is None:
221 return {}
222 return extra
223
224
class BaseAsset:
    """
    Common base for everything usable in ``DAG(schedule=...)`` asset expressions.

    ``|`` combines two expressions into an ``AssetAny`` and ``&`` into an
    ``AssetAll``. Non-asset operands yield ``NotImplemented``.

    :meta private:
    """

    def __or__(self, other: BaseAsset) -> BaseAsset:
        if isinstance(other, BaseAsset):
            return AssetAny(self, other)
        return NotImplemented

    def __and__(self, other: BaseAsset) -> BaseAsset:
        if isinstance(other, BaseAsset):
            return AssetAll(self, other)
        return NotImplemented
241
242
def _validate_asset_watcher_trigger(instance, attribute, value):
    """
    Validate that an asset watcher's trigger is a ``BaseEventTrigger``.

    :return: *value*, unchanged.
    :raises ValueError: If *value* is not a ``BaseEventTrigger`` instance.
    """
    # Imported at call time rather than module level; presumably to avoid an
    # import cycle or a heavy import -- TODO confirm.
    from airflow.triggers.base import BaseEventTrigger

    if isinstance(value, BaseEventTrigger):
        return value
    raise ValueError("Asset watcher trigger must inherit BaseEventTrigger")
249
250
@attrs.define
class AssetWatcher:
    """
    Watch an asset through an event trigger.

    ``name`` uniquely identifies the watcher. ``trigger`` must be a ``BaseEventTrigger``.
    """

    name: str = attrs.field()
    trigger: BaseEventTrigger = attrs.field(validator=_validate_asset_watcher_trigger)
257
258
@attrs.define(init=False, unsafe_hash=False)
class Asset(os.PathLike, BaseAsset):
    """
    A representation of data asset dependencies between workflows.

    An asset is identified by ``name`` and ``uri``. If only one of them is passed to
    the constructor, it is used for both. ``uri`` goes through ``_sanitize_uri``
    on assignment. ``group`` defaults to the class's ``asset_type``, and ``extra``
    is normalized so that ``None`` becomes an empty dict.
    """

    # Non-empty identifier; "self" and "context" are rejected.
    name: str = attrs.field(
        validator=[_validate_asset_name],
    )
    # Converted by _sanitize_uri first, then validated as a non-empty identifier.
    uri: str = attrs.field(
        validator=[_validate_non_empty_identifier],
        converter=_sanitize_uri,
    )
    # Defaults to the class-level ``asset_type``, so Dataset/Model get their own group.
    group: str = attrs.field(
        default=attrs.Factory(operator.attrgetter("asset_type"), takes_self=True),
        validator=[_validate_identifier],
    )
    # ``None`` is converted to ``{}`` so the attribute is always a dict.
    extra: dict[str, JsonValue] = attrs.field(
        factory=dict,
        converter=_set_extra_default,
    )
    watchers: list[AssetWatcher] = attrs.field(
        factory=list,
    )

    # Default value for ``group``; subclasses (Dataset, Model) override it.
    asset_type: ClassVar[str] = "asset"
    # NOTE(review): presumably a serialization/schema version marker -- confirm.
    __version__: ClassVar[int] = 1

    @overload
    def __init__(
        self,
        name: str,
        uri: str | ObjectStoragePath,
        *,
        group: str = ...,
        extra: dict[str, JsonValue] | None = None,
        watchers: list[AssetWatcher] = ...,
    ) -> None:
        """Canonical; both name and uri are provided."""

    @overload
    def __init__(
        self,
        name: str,
        *,
        group: str = ...,
        extra: dict[str, JsonValue] | None = None,
        watchers: list[AssetWatcher] = ...,
    ) -> None:
        """It's possible to only provide the name, either by keyword or as the only positional argument."""

    @overload
    def __init__(
        self,
        *,
        uri: str | ObjectStoragePath,
        group: str = ...,
        extra: dict[str, JsonValue] | None = None,
        watchers: list[AssetWatcher] = ...,
    ) -> None:
        """It's possible to only provide the URI as a keyword argument."""

    def __init__(
        self,
        name: str | None = None,
        uri: str | ObjectStoragePath | None = None,
        *,
        group: str | None = None,
        extra: dict[str, JsonValue] | None = None,
        watchers: list[AssetWatcher] | None = None,
    ) -> None:
        """
        Create an asset from a name, a URI, or both.

        :param name: Asset name. Defaults to ``str(uri)``, taken before URI sanitization.
        :param uri: Asset URI or object-storage path. Defaults to ``name``.
        :param group: Asset group. Defaults to the class's ``asset_type``.
        :param extra: Extra JSON-compatible metadata. Defaults to an empty dict.
        :param watchers: Asset watchers. Defaults to an empty list.
        :raises TypeError: If neither ``name`` nor ``uri`` is given.
        """
        if name is None and uri is None:
            raise TypeError("Asset() requires either 'name' or 'uri'")
        if name is None:
            name = str(uri)
        elif uri is None:
            uri = name

        if TYPE_CHECKING:
            assert name is not None
            assert uri is not None

        # attrs defaults (and factories) do not apply when any value is passed for
        # the argument, so only forward the arguments the caller actually set
        # (not the ``None`` defaults of this custom __init__).
        kwargs: dict[str, Any] = {}
        if group is not None:
            kwargs["group"] = group
        if extra is not None:
            kwargs["extra"] = extra
        if watchers is not None:
            kwargs["watchers"] = watchers

        # The attrs-generated initializer runs the converters and validators.
        self.__attrs_init__(name=name, uri=uri, **kwargs)

    @overload
    @staticmethod
    def ref(*, name: str) -> AssetNameRef: ...

    @overload
    @staticmethod
    def ref(*, uri: str) -> AssetUriRef: ...

    @staticmethod
    def ref(*, name: str = "", uri: str = "") -> AssetRef:
        """
        Create a reference to an asset by name or by URI, but not both.

        :raises TypeError: If both or neither of ``name`` and ``uri`` are given
            (an empty string counts as not given).
        """
        if name and uri:
            raise TypeError("Asset reference must be made to either name or URI, not both")
        if name:
            return AssetNameRef(name)
        if uri:
            return AssetUriRef(uri)
        raise TypeError("Asset reference expects keyword argument 'name' or 'uri'")

    def __fspath__(self) -> str:
        # os.PathLike protocol: the (sanitized) URI acts as the path.
        return self.uri

    def __eq__(self, other: Any) -> bool:
        # The Asset class can be subclassed, and we don't want fields added by a
        # subclass to break equality. This explicitly filters out only fields
        # defined by the Asset class for comparison.
        if not isinstance(other, Asset):
            return NotImplemented
        f = attrs.filters.include(*attrs.fields_dict(Asset))
        return attrs.asdict(self, filter=f) == attrs.asdict(other, filter=f)

    def __hash__(self) -> int:
        # Mirrors __eq__: only Asset-defined fields are hashed. The JSON dump uses
        # sort_keys so dict key order does not affect the hash.
        # NOTE(review): the name-based filter is also applied to nested attrs
        # objects (e.g. AssetWatcher), so fields not named like Asset's are
        # dropped there too -- confirm this is intended.
        f = attrs.filters.include(*attrs.fields_dict(Asset))
        return hash(json.dumps(attrs.asdict(self, filter=f), sort_keys=True))

    @property
    def normalized_uri(self) -> str | None:
        """
        Returns the normalized and AIP-60 compliant URI whenever possible.

        If we can't retrieve the scheme from URI or no normalizer is provided or if parsing fails,
        it returns None.

        If a normalizer for the scheme exists and parsing is successful we return the normalizer result.
        """
        if not (normalized_scheme := _get_normalized_scheme(self.uri)):
            return None

        if (normalizer := _get_uri_normalizer(normalized_scheme)) is None:
            return None
        parsed = urllib.parse.urlsplit(self.uri)
        try:
            normalized_uri = normalizer(parsed)
            return urllib.parse.urlunsplit(normalized_uri)
        except ValueError:
            return None
406
407
class AssetRef(BaseAsset, AttrsInstance):
    """
    Base class for references to an asset.

    Do not instantiate this class directly. Use ``Asset.ref(name=...)`` or
    ``Asset.ref(uri=...)``, which return an ``AssetNameRef`` or an ``AssetUriRef``.

    :meta private:
    """
417
418
@attrs.define(unsafe_hash=True)
class AssetNameRef(AssetRef):
    """Reference to an asset by its name."""

    name: str
424
425
@attrs.define(unsafe_hash=True)
class AssetUriRef(AssetRef):
    """Reference to an asset by its URI."""

    uri: str
431
432
class Dataset(Asset):
    """A dataset dependency between workflows; its ``group`` defaults to ``"dataset"``."""

    asset_type: ClassVar[str] = "dataset"
437
438
class Model(Asset):
    """A model dependency between workflows; its ``group`` defaults to ``"model"``."""

    asset_type: ClassVar[str] = "model"
443
444
@attrs.define(unsafe_hash=True)
class AssetAlias(BaseAsset):
    """
    An alias that can stand in for assets.

    Assets can be attached to an alias when a task executes, so the concrete
    assets do not have to be known in advance.
    """

    name: str = attrs.field(validator=_validate_non_empty_identifier)
    group: str = attrs.field(default="asset", kw_only=True, validator=_validate_identifier)
455
456
class AssetBooleanCondition(BaseAsset):
    """
    Shared base for asset expressions combined with a boolean operator.

    Holds the operand expressions in ``objects``. Two conditions are equal when
    the other is an instance of this condition's type and their operands are equal.

    :meta private:
    """

    objects: Collection[BaseAsset]

    def __init__(self, *objects: BaseAsset) -> None:
        if any(not isinstance(obj, BaseAsset) for obj in objects):
            raise TypeError("expect asset expressions in condition")
        self.objects = objects

    def __eq__(self, other: object) -> bool:
        if isinstance(other, type(self)):
            return self.objects == other.objects
        return NotImplemented

    def __hash__(self) -> int:
        return hash(tuple(self.objects))
478
479
class AssetAny(AssetBooleanCondition):
    """Use to combine assets schedule references in an "or" relationship."""

    def __or__(self, other: BaseAsset) -> BaseAsset:
        """
        Combine with *other*, flattening left-nested ``|`` chains.

        ``a | b | c`` parses as ``(a | b) | c``. Appending *other* to this
        condition's operands yields ``AssetAny(a, b, c)`` rather than
        ``AssetAny(AssetAny(a, b), c)``. The right operand is never unpacked,
        so ``a | (b | c)`` keeps its nested ``AssetAny``.
        """
        if not isinstance(other, BaseAsset):
            return NotImplemented
        return AssetAny(*self.objects, other)

    def __repr__(self) -> str:
        return f"AssetAny({', '.join(map(str, self.objects))})"
491
492
class AssetAll(AssetBooleanCondition):
    """Use to combine assets schedule references in an "and" relationship."""

    # NOTE(review): nothing in this module reads ``agg_func``, and AssetAny defines no
    # counterpart. It is presumably kept for external callers -- confirm before removing.
    agg_func = all  # type: ignore[assignment]

    def __and__(self, other: BaseAsset) -> BaseAsset:
        """
        Combine with *other*, flattening left-nested ``&`` chains.

        ``a & b & c`` parses as ``(a & b) & c``. Appending *other* to this
        condition's operands yields ``AssetAll(a, b, c)`` rather than
        ``AssetAll(AssetAll(a, b), c)``. The right operand is never unpacked,
        so ``a & (b & c)`` keeps its nested ``AssetAll``.
        """
        if not isinstance(other, BaseAsset):
            return NotImplemented
        return AssetAll(*self.objects, other)

    def __repr__(self) -> str:
        return f"AssetAll({', '.join(map(str, self.objects))})"
506
507
@attrs.define
class AssetAliasEvent(attrs.AttrsInstance):
    """
    An asset event to be emitted through an asset alias.

    Records the name of the source alias, the unique key of the destination
    asset, and two JSON-compatible extra payloads (``dest_asset_extra`` and ``extra``).
    """

    source_alias_name: str = attrs.field()
    dest_asset_key: AssetUniqueKey = attrs.field()
    dest_asset_extra: dict[str, JsonValue] = attrs.field()
    extra: dict[str, JsonValue] = attrs.field()