Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/datasets/manager.py: 36%


73 statements  

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from sqlalchemy import exc, select
from sqlalchemy.orm import joinedload

from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.listeners.listener import get_listener_manager
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
    from sqlalchemy.orm.session import Session

    from airflow.models.dag import DagModel
    from airflow.models.taskinstance import TaskInstance


class DatasetManager(LoggingMixin):
    """
    A pluggable class that manages operations for datasets.

    The intent is to have one place to handle all Dataset-related operations, so different
    Airflow deployments can use plugins that broadcast dataset events to each other.
    """


    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def create_datasets(self, dataset_models: list[DatasetModel], session: Session) -> None:
        """Create new datasets."""
        for dataset_model in dataset_models:
            session.add(dataset_model)
        session.flush()

        for dataset_model in dataset_models:
            self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra))

    def register_dataset_change(
        self,
        *,
        task_instance: TaskInstance | None = None,
        dataset: Dataset,
        extra=None,
        session: Session,
        **kwargs,
    ) -> DatasetEvent | None:
        """
        Register dataset-related changes.

        For local datasets, look them up, record the dataset event, queue dagruns, and broadcast
        the dataset event.
        """
        dataset_model = session.scalar(
            select(DatasetModel)
            .where(DatasetModel.uri == dataset.uri)
            .options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag))
        )
        if not dataset_model:
            self.log.warning("DatasetModel %s not found", dataset)
            return None

        event_kwargs = {
            "dataset_id": dataset_model.id,
            "extra": extra,
        }
        if task_instance:
            event_kwargs.update(
                {
                    "source_task_id": task_instance.task_id,
                    "source_dag_id": task_instance.dag_id,
                    "source_run_id": task_instance.run_id,
                    "source_map_index": task_instance.map_index,
                }
            )
        dataset_event = DatasetEvent(**event_kwargs)
        session.add(dataset_event)
        session.flush()

        self.notify_dataset_changed(dataset=dataset)

        Stats.incr("dataset.updates")
        self._queue_dagruns(dataset_model, session)
        session.flush()
        return dataset_event
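
    # Illustrative call (hedged): inside Airflow this is typically invoked for each outlet
    # of a finishing task instance, roughly along these lines; `ti`, `session`, and the
    # URI are assumed/hypothetical here:
    #
    #     dataset_manager.register_dataset_change(
    #         task_instance=ti,
    #         dataset=Dataset(uri="s3://bucket/my-data"),
    #         session=session,
    #     )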


    def notify_dataset_created(self, dataset: Dataset):
        """Run applicable notification actions when a dataset is created."""
        get_listener_manager().hook.on_dataset_created(dataset=dataset)

    def notify_dataset_changed(self, dataset: Dataset):
        """Run applicable notification actions when a dataset is changed."""
        get_listener_manager().hook.on_dataset_changed(dataset=dataset)
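
    # Illustrative listener (hedged): the notify_* methods above fan out through the
    # listener plugin mechanism, so a plugin could implement the matching hook roughly
    # as follows (registration details depend on the deployment):
    #
    #     from airflow.listeners import hookimpl
    #
    #     @hookimpl
    #     def on_dataset_changed(dataset):
    #         print(f"dataset updated: {dataset.uri}")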


    def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
        # Possible race condition: if multiple dags or multiple (usually
        # mapped) tasks update the same dataset, this can fail with a unique
        # constraint violation.
        #
        # If the dialect supports it, use ON CONFLICT to do nothing; otherwise
        # fall back to running this in a nested transaction. This is needed
        # so that the adding of these rows happens in the same transaction
        # where `ti.state` is changed.

        if session.bind.dialect.name == "postgresql":
            return self._postgres_queue_dagruns(dataset, session)
        return self._slow_path_queue_dagruns(dataset, session)
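
    # For reference (illustrative): on PostgreSQL, the fast path below issues a bulk
    # insert roughly equivalent to
    #
    #     INSERT INTO dataset_dag_run_queue (dataset_id, target_dag_id)
    #     VALUES (:dataset_id, :target_dag_id)
    #     ON CONFLICT DO NOTHING;
    #
    # while the slow path wraps each row in a SAVEPOINT and swallows duplicate-key errors.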


    def _slow_path_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
        def _queue_dagrun_if_needed(dag: DagModel) -> str | None:
            if not dag.is_active or dag.is_paused:
                return None
            item = DatasetDagRunQueue(target_dag_id=dag.dag_id, dataset_id=dataset.id)
            # Don't error the whole transaction when a single RunQueue item conflicts.
            # https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#using-savepoint
            try:
                with session.begin_nested():
                    session.merge(item)
            except exc.IntegrityError:
                self.log.debug("Skipping record %s", item, exc_info=True)
            return dag.dag_id

        queued_results = (_queue_dagrun_if_needed(ref.dag) for ref in dataset.consuming_dags)
        if queued_dag_ids := [r for r in queued_results if r is not None]:
            self.log.debug("consuming dag ids %s", queued_dag_ids)

    def _postgres_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
        from sqlalchemy.dialects.postgresql import insert

        values = [
            {"target_dag_id": dag.dag_id}
            for dag in (r.dag for r in dataset.consuming_dags)
            if dag.is_active and not dag.is_paused
        ]
        if not values:
            return
        stmt = insert(DatasetDagRunQueue).values(dataset_id=dataset.id).on_conflict_do_nothing()
        session.execute(stmt, values)


def resolve_dataset_manager() -> DatasetManager:
    """Retrieve the dataset manager."""
    _dataset_manager_class = conf.getimport(
        section="core",
        key="dataset_manager_class",
        fallback="airflow.datasets.manager.DatasetManager",
    )
    _dataset_manager_kwargs = conf.getjson(
        section="core",
        key="dataset_manager_kwargs",
        fallback={},
    )
    return _dataset_manager_class(**_dataset_manager_kwargs)
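
# Illustrative configuration (values are hypothetical): a deployment can point Airflow at a
# custom manager through airflow.cfg, which resolve_dataset_manager() reads above:
#
#     [core]
#     dataset_manager_class = my_company.datasets.CustomDatasetManager
#     dataset_manager_kwargs = {"broker_url": "kafka://broker:9092"}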



dataset_manager = resolve_dataset_manager()