Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/models/taskmap.py: 67%

39 statements  

« prev     ^ index     » next       coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18"""Table to store information about mapped task instances (AIP-42).""" 

19from __future__ import annotations 

20 

21import collections.abc 

22import enum 

23from typing import TYPE_CHECKING, Any, Collection 

24 

25from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String 

26 

27from airflow.models.base import COLLATION_ARGS, ID_LEN, Base 

28from airflow.utils.sqlalchemy import ExtendedJSON 

29 

30if TYPE_CHECKING: 

31 from airflow.models.taskinstance import TaskInstance 

32 

33 

class TaskMapVariant(enum.Enum):
    """Shape of the collection a task map was recorded from.

    ``DICT`` marks a key-value mapping, while ``LIST`` marks an ordered
    sequence of values.
    """

    DICT = "dict"  # key-value mapping
    LIST = "list"  # ordered value sequence

43 

44 

class TaskMap(Base):
    """Dynamic task-mapping metadata recorded for one task instance.

    At present, rows are written only when an upstream TaskInstance pushes an
    XCom that a downstream task pulls in order to expand (map) itself.
    """

    __tablename__ = "task_map"

    # Identity of the upstream TaskInstance that produced this mapping
    # information; these four columns form the primary key and the FK below.
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    map_index = Column(Integer, primary_key=True)

    # Number of items in the recorded value; the check constraint below
    # rejects negative values.
    length = Column(Integer, nullable=False)
    # Stored keys of the recorded value. from_task_instance_xcom fills this
    # only for Mapping values and leaves it None otherwise.
    keys = Column(ExtendedJSON, nullable=True)

    __table_args__ = (
        CheckConstraint(length >= 0, name="task_map_length_not_negative"),
        # ondelete="CASCADE": removing the referenced task_instance row
        # removes the matching task_map row as well.
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="task_map_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int,
        length: int,
        keys: list[Any] | None,
    ) -> None:
        # Identity columns first, then the fields describing the mapped value.
        self.dag_id, self.task_id, self.run_id = dag_id, task_id, run_id
        self.map_index = map_index
        self.length, self.keys = length, keys

    @classmethod
    def from_task_instance_xcom(cls, ti: TaskInstance, value: Collection) -> TaskMap:
        """Build a row describing the XCom *value* pushed by *ti*.

        :param ti: The task instance that pushed the value; its identity
            fields (dag_id, task_id, run_id, map_index) are copied over.
        :param value: The pushed collection. Its ``len()`` becomes ``length``.
            If it is a ``Mapping``, its keys are stored in ``keys``; otherwise
            ``keys`` is ``None``.
        :raises ValueError: If ``ti.run_id`` is ``None``.
        """
        if ti.run_id is None:
            raise ValueError("cannot record task map for unrun task instance")
        if isinstance(value, collections.abc.Mapping):
            mapping_keys: list[Any] | None = list(value)
        else:
            mapping_keys = None
        return cls(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            run_id=ti.run_id,
            map_index=ti.map_index,
            length=len(value),
            keys=mapping_keys,
        )

    @property
    def variant(self) -> TaskMapVariant:
        """Return ``LIST`` when no keys are stored, otherwise ``DICT``."""
        return TaskMapVariant.LIST if self.keys is None else TaskMapVariant.DICT