1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import annotations
19
import inspect
import logging
import traceback
from collections.abc import AsyncIterator
from typing import Any

from airflow._shared.module_loading import import_string, qualname
from airflow.models.callback import CallbackState
from airflow.triggers.base import BaseTrigger, TriggerEvent
28
log: logging.Logger = logging.getLogger(__name__)

# Keys of the dict payload carried by each TriggerEvent that CallbackTrigger emits.
PAYLOAD_STATUS_KEY: str = "state"  # holds a CallbackState value
PAYLOAD_BODY_KEY: str = "body"  # callback result on success; error message + traceback on failure
33
34
class _CallableNotAwaitableError(TypeError):
    """Raised when calling the imported callback returns something that cannot be awaited."""


class CallbackTrigger(BaseTrigger):
    """
    Trigger that imports a callback by dotted path and awaits it on the triggerer.

    :param callback_path: Dotted import path of the callable, resolved with ``import_string``.
        Calling it with ``callback_kwargs`` must return an awaitable (e.g. an ``async def`` function).
    :param callback_kwargs: Keyword arguments passed to the callable; ``None`` is treated as ``{}``.
    """

    def __init__(self, callback_path: str, callback_kwargs: dict[str, Any] | None = None):
        super().__init__()
        self.callback_path = callback_path
        self.callback_kwargs = callback_kwargs or {}

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return this trigger's qualified class name and the kwargs needed to reconstruct it."""
        return (
            qualname(self),
            {attr: getattr(self, attr) for attr in ("callback_path", "callback_kwargs")},
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Run the callback and report its progress as trigger events.

        Yields a ``RUNNING`` event first, then exactly one of:

        * ``SUCCESS`` with the awaited result of the callback under ``PAYLOAD_BODY_KEY``;
        * ``FAILED`` with an error message and the formatted traceback under ``PAYLOAD_BODY_KEY``.

        Any ``Exception`` raised while importing, calling, or awaiting the callback is logged and
        turned into the ``FAILED`` event instead of propagating.
        """
        # Set once the import succeeds, so an ImportError raised *inside* the callback (e.g. by a
        # lazy import) is not misreported as a failure to import the callback itself.
        callback_imported = False
        try:
            yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.RUNNING})
            callback = import_string(self.callback_path)
            callback_imported = True

            # TODO: get full context and run template rendering. Right now, a simple context is included in `callback_kwargs`
            awaitable = callback(**self.callback_kwargs)
            # Check explicitly rather than matching "await" in a TypeError's text: a TypeError raised
            # while the callback's own coroutine runs must not be reported as "not awaitable".
            if not inspect.isawaitable(awaitable):
                raise _CallableNotAwaitableError(
                    f"{self.callback_path} returned an object of type {type(awaitable).__name__!r}, "
                    "which can't be used in 'await' expression"
                )
            result = await awaitable
            yield TriggerEvent({PAYLOAD_STATUS_KEY: CallbackState.SUCCESS, PAYLOAD_BODY_KEY: result})

        except Exception as e:
            if isinstance(e, ImportError) and not callback_imported:
                message = "Failed to import the callable on the triggerer"
            elif isinstance(e, _CallableNotAwaitableError):
                message = "Failed to run the callable because it's not awaitable"
            else:
                message = "An error occurred during execution of the callable"

            log.exception("%s: %s; kwargs: %s\n%s", message, self.callback_path, self.callback_kwargs, e)
            # format_exception returns a list of lines: join them so the payload carries a readable
            # traceback instead of a list repr. The 3-argument form also works before Python 3.10.
            formatted_traceback = "".join(traceback.format_exception(type(e), e, e.__traceback__))
            yield TriggerEvent(
                {
                    PAYLOAD_STATUS_KEY: CallbackState.FAILED,
                    PAYLOAD_BODY_KEY: f"{message}: {formatted_traceback}",
                }
            )