Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/taskmap.py: 67%
39 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18"""Table to store information about mapped task instances (AIP-42)."""
19from __future__ import annotations
21import collections.abc
22import enum
23from typing import TYPE_CHECKING, Any, Collection
25from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String
27from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
28from airflow.utils.sqlalchemy import ExtendedJSON
30if TYPE_CHECKING:
31 from airflow.models.taskinstance import TaskInstance
class TaskMapVariant(enum.Enum):
    """Shape of the value a task map was recorded from.

    ``DICT`` (value ``"dict"``) marks a key-value mapping; ``LIST`` (value
    ``"list"``) marks an ordered sequence of values.
    """

    # Declaration order (DICT first, then LIST) is the enum's iteration order.
    DICT = "dict"
    LIST = "list"
class TaskMap(Base):
    """Database row describing dynamic task-mapping information for one task instance.

    Rows are currently only produced from an XCom value that an upstream
    TaskInstance pushed and that a downstream task pulls for mapping purposes.
    Each row records how many items that value had and, when the value was a
    mapping, its keys.
    """

    __tablename__ = "task_map"

    # Composite primary key identifying the upstream TaskInstance that created
    # this dynamic mapping information (also the foreign key declared below).
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    map_index = Column(Integer, primary_key=True)

    # Number of items in the recorded value.
    length = Column(Integer, nullable=False)
    # Keys of the recorded value when it was a mapping; None otherwise
    # (see ``from_task_instance_xcom``).
    keys = Column(ExtendedJSON, nullable=True)

    __table_args__ = (
        # Reject negative lengths at the database level.
        CheckConstraint(length >= 0, name="task_map_length_not_negative"),
        # Tie each row to its task_instance row; deleting that task instance
        # deletes this row too (ON DELETE CASCADE).
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="task_map_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int,
        length: int,
        keys: list[Any] | None,
    ) -> None:
        # Identity of the upstream task instance this row belongs to.
        self.dag_id, self.task_id, self.run_id = dag_id, task_id, run_id
        self.map_index = map_index
        # Summary of the recorded value.
        self.length, self.keys = length, keys

    @classmethod
    def from_task_instance_xcom(cls, ti: TaskInstance, value: Collection) -> TaskMap:
        """Build a ``TaskMap`` describing *value*, an XCom pushed by *ti*.

        The row is only constructed here; persisting it is left to the caller.

        :param ti: The upstream task instance that pushed the value.
        :param value: The pushed value; it must support ``len()``.
        :return: A new ``TaskMap`` keyed on *ti*'s dag_id/task_id/run_id/map_index.
        :raises ValueError: If *ti* has no ``run_id``.
        """
        if ti.run_id is None:
            raise ValueError("cannot record task map for unrun task instance")
        identity = {
            "dag_id": ti.dag_id,
            "task_id": ti.task_id,
            "run_id": ti.run_id,
            "map_index": ti.map_index,
        }
        item_count = len(value)
        # Iterating a mapping yields its keys; other collections record no keys.
        item_keys = list(value) if isinstance(value, collections.abc.Mapping) else None
        return cls(**identity, length=item_count, keys=item_keys)

    @property
    def variant(self) -> TaskMapVariant:
        """``TaskMapVariant.DICT`` when keys were recorded, ``TaskMapVariant.LIST`` otherwise."""
        return TaskMapVariant.LIST if self.keys is None else TaskMapVariant.DICT