Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/build/lib/airflow/models/taskmap.py: 67%
39 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Table to store information about mapped task instances (AIP-42)."""
19from __future__ import annotations
21import collections.abc
22import enum
23from typing import TYPE_CHECKING, Any, Collection
25from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String
27from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
28from airflow.utils.sqlalchemy import ExtendedJSON
30if TYPE_CHECKING:
31 from airflow.models.taskinstance import TaskInstance
class TaskMapVariant(enum.Enum):
    """The two shapes a mapped upstream value can take.

    ``DICT`` denotes a key-value mapping, while ``LIST`` denotes an ordered
    sequence of values.
    """

    DICT = "dict"  # key-value mapping
    LIST = "list"  # ordered value sequence
class TaskMap(Base):
    """Persisted record describing how an upstream XCom fans out downstream.

    Rows are currently written only when an upstream TaskInstance pushes an
    XCom that a downstream task pulls for mapping purposes. Each row records
    how many items the pushed value held and, for mapping values, its keys.
    """

    __tablename__ = "task_map"

    # Primary key: identifies the upstream TaskInstance that produced this
    # dynamic-mapping information.
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    map_index = Column(Integer, primary_key=True)

    # Number of items in the pushed value; the CHECK constraint below keeps it
    # non-negative.
    length = Column(Integer, nullable=False)
    # Keys of the pushed value when it was a mapping (see
    # from_task_instance_xcom); None otherwise.
    keys = Column(ExtendedJSON, nullable=True)

    __table_args__ = (
        CheckConstraint(length >= 0, name="task_map_length_not_negative"),
        # Every row belongs to a task instance; deletes and key updates on the
        # task_instance row cascade into this table.
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="task_map_task_instance_fkey",
            ondelete="CASCADE",
            onupdate="CASCADE",
        ),
    )

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int,
        length: int,
        keys: list[Any] | None,
    ) -> None:
        """Populate every column of the row.

        :param dag_id: DAG id of the upstream task instance.
        :param task_id: Task id of the upstream task instance.
        :param run_id: Run id of the upstream task instance.
        :param map_index: Map index of the upstream task instance.
        :param length: Number of items in the pushed value.
        :param keys: Keys of the pushed value if it was a mapping, else None.
        """
        # Identity of the upstream task instance (primary-key / FK columns).
        self.dag_id, self.task_id, self.run_id = dag_id, task_id, run_id
        self.map_index = map_index
        # Shape of the pushed value.
        self.length, self.keys = length, keys

    @classmethod
    def from_task_instance_xcom(cls, ti: TaskInstance, value: Collection) -> TaskMap:
        """Build a row describing ``value``, an XCom pushed by ``ti``.

        :param ti: The upstream task instance that pushed the XCom; it must
            have a ``run_id``.
        :param value: The pushed collection. Its length is always recorded;
            its keys are recorded only when it is a mapping.
        :raises ValueError: If ``ti.run_id`` is None.
        """
        upstream_run_id = ti.run_id
        if upstream_run_id is None:
            raise ValueError("cannot record task map for unrun task instance")

        # Only mappings carry keys; any other collection is stored keyless,
        # which makes ``variant`` report LIST.
        if isinstance(value, collections.abc.Mapping):
            mapping_keys = list(value)
        else:
            mapping_keys = None

        return cls(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            run_id=upstream_run_id,
            map_index=ti.map_index,
            length=len(value),
            keys=mapping_keys,
        )

    @property
    def variant(self) -> TaskMapVariant:
        """Report whether the recorded value was a mapping or not.

        Derived solely from ``keys``: no keys means ``LIST``, otherwise ``DICT``.
        """
        return TaskMapVariant.LIST if self.keys is None else TaskMapVariant.DICT