1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""
18Pod generator.
19
20This module provides an interface between the previous Pod
21API and outputs a kubernetes.client.models.V1Pod.
22The advantage being that the full Kubernetes API
23is supported and no serialization need be written.
24"""
25
26from __future__ import annotations
27
28import copy
29import logging
30import os
31import warnings
32from functools import reduce
33from typing import TYPE_CHECKING
34
35import re2
36from dateutil import parser
37from deprecated import deprecated
38from kubernetes.client import models as k8s
39from kubernetes.client.api_client import ApiClient
40
41from airflow.exceptions import (
42 AirflowConfigException,
43 AirflowException,
44 AirflowProviderDeprecationWarning,
45)
46from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
47 POD_NAME_MAX_LENGTH,
48 add_unique_suffix,
49 rand_str,
50)
51from airflow.providers.cncf.kubernetes.pod_generator_deprecated import (
52 PodDefaults as PodDefaultsDeprecated,
53 PodGenerator as PodGeneratorDeprecated,
54)
55from airflow.utils import yaml
56from airflow.utils.hashlib_wrapper import md5
57from airflow.version import version as airflow_version
58
59if TYPE_CHECKING:
60 import datetime
61
# Module-level logger used for pod-generation diagnostics (e.g. missing template files).
log = logging.getLogger(__name__)

# Kubernetes restricts label values to at most 63 characters.
MAX_LABEL_LEN = 63
65
66
class PodMutationHookException(AirflowException):
    """Raised when exception happens during Pod Mutation Hook execution."""
    # Raised by PodGenerator.construct_pod when the configured
    # ``airflow.settings.pod_mutation_hook`` itself raises; the original
    # exception is chained as the cause.
69
70
class PodReconciliationError(AirflowException):
    """Raised when an error is encountered while trying to merge pod configs."""
    # Raised by PodGenerator.construct_pod when reducing the base worker pod,
    # the dynamically built pod, and the executor_config override fails; the
    # original exception is chained as the cause.
73
74
def make_safe_label_value(string: str) -> str:
    """
    Normalize a provided label to be of valid length and characters.

    Valid label values must be 63 characters or less and must be empty or begin and
    end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_),
    dots (.), and alphanumerics between.

    If the label value is greater than 63 chars once made safe, or differs in any
    way from the original value sent to this function, then we need to truncate to
    53 chars, and append it with a unique hash.
    """
    # Strip leading/trailing non-alphanumerics and drop any character that is
    # not valid inside a Kubernetes label value.
    cleaned = re2.sub(r"^[^a-z0-9A-Z]*|[^a-zA-Z0-9_\-\.]|[^a-z0-9A-Z]*$", "", string)

    # If cleaning changed nothing and the value fits, it is already safe.
    if string == cleaned and len(cleaned) <= MAX_LABEL_LEN:
        return cleaned

    # Otherwise append a short digest of the original so distinct inputs stay distinct.
    digest = md5(string.encode()).hexdigest()[:9]
    return cleaned[: MAX_LABEL_LEN - len(digest) - 1] + "-" + digest
94
95
def datetime_to_label_safe_datestring(datetime_obj: datetime.datetime) -> str:
    """
    Transform a datetime string to use as a label.

    Kubernetes doesn't like ":" in labels, since ISO datetime format uses ":" but
    not "_" let's
    replace ":" with "_"

    :param datetime_obj: datetime.datetime object
    :return: ISO-like string representing the datetime
    """
    safe = datetime_obj.isoformat()
    # ":" and "+" are not permitted in label values; substitute reversible markers.
    for forbidden, marker in ((":", "_"), ("+", "_plus_")):
        safe = safe.replace(forbidden, marker)
    return safe
108
109
def label_safe_datestring_to_datetime(string: str) -> datetime.datetime:
    """
    Transform a label back to a datetime object.

    Kubernetes doesn't permit ":" in labels. ISO datetime format uses ":" but not
    "_", let's
    replace ":" with "_"

    :param string: str
    :return: datetime.datetime object
    """
    # Undo the substitutions made by datetime_to_label_safe_datestring, then parse.
    restored = string.replace("_plus_", "+").replace("_", ":")
    return parser.parse(restored)
122
123
class PodGenerator:
    """
    Contains Kubernetes Airflow Worker configuration logic.

    Represents a kubernetes pod and manages execution of a single pod.
    Any configuration that is container specific gets applied to
    the first container in the list of containers.

    :param pod: The fully specified pod. Mutually exclusive with `pod_template_file`
    :param pod_template_file: Path to YAML file. Mutually exclusive with `pod`
    :param extract_xcom: Whether to bring up a container for xcom
    """

    def __init__(
        self,
        pod: k8s.V1Pod | None = None,
        pod_template_file: str | None = None,
        extract_xcom: bool = True,
    ):
        if not pod_template_file and not pod:
            raise AirflowConfigException(
                "Podgenerator requires either a `pod` or a `pod_template_file` argument"
            )
        if pod_template_file and pod:
            raise AirflowConfigException("Cannot pass both `pod` and `pod_template_file` arguments")

        if pod_template_file:
            self.ud_pod = self.deserialize_model_file(pod_template_file)
        else:
            self.ud_pod = pod

        # Attach sidecar
        self.extract_xcom = extract_xcom

    @deprecated(
        reason="This method is deprecated and will be removed in the future releases",
        category=AirflowProviderDeprecationWarning,
    )
    def gen_pod(self) -> k8s.V1Pod:
        """Generate pod."""
        result = self.ud_pod

        result.metadata.name = add_unique_suffix(name=result.metadata.name)

        if self.extract_xcom:
            result = self.add_xcom_sidecar(result)

        return result

    @staticmethod
    @deprecated(
        reason=(
            "This function is deprecated. "
            "Please use airflow.providers.cncf.kubernetes.utils.xcom_sidecar.add_xcom_sidecar instead"
        ),
        category=AirflowProviderDeprecationWarning,
    )
    def add_xcom_sidecar(pod: k8s.V1Pod) -> k8s.V1Pod:
        """Add sidecar."""
        pod_cp = copy.deepcopy(pod)
        pod_cp.spec.volumes = pod.spec.volumes or []
        pod_cp.spec.volumes.insert(0, PodDefaultsDeprecated.VOLUME)
        pod_cp.spec.containers[0].volume_mounts = pod_cp.spec.containers[0].volume_mounts or []
        pod_cp.spec.containers[0].volume_mounts.insert(0, PodDefaultsDeprecated.VOLUME_MOUNT)
        pod_cp.spec.containers.append(PodDefaultsDeprecated.SIDECAR_CONTAINER)

        return pod_cp

    @staticmethod
    def from_obj(obj) -> dict | k8s.V1Pod | None:
        """Convert to pod from obj."""
        if obj is None:
            return None

        k8s_legacy_object = obj.get("KubernetesExecutor", None)
        k8s_object = obj.get("pod_override", None)

        if k8s_legacy_object and k8s_object:
            # NOTE: fragment-separating spaces are required here; without them the
            # implicit string concatenation produces run-together words.
            raise AirflowConfigException(
                "Can not have both a legacy and new "
                "executor_config object. Please delete the KubernetesExecutor "
                "dict and only use the pod_override kubernetes.client.models.V1Pod "
                "object."
            )
        if not k8s_object and not k8s_legacy_object:
            return None

        if isinstance(k8s_object, k8s.V1Pod):
            return k8s_object
        elif isinstance(k8s_legacy_object, dict):
            warnings.warn(
                "Using a dictionary for the executor_config is deprecated and will soon be removed. "
                'Please use a `kubernetes.client.models.V1Pod` class with a "pod_override" key'
                " instead. ",
                category=AirflowProviderDeprecationWarning,
                stacklevel=2,
            )
            return PodGenerator.from_legacy_obj(obj)
        else:
            raise TypeError(
                "Cannot convert a non-kubernetes.client.models.V1Pod object into a KubernetesExecutorConfig"
            )

    @staticmethod
    def from_legacy_obj(obj) -> k8s.V1Pod | None:
        """Convert to pod from obj."""
        if obj is None:
            return None

        # We do not want to extract constant here from ExecutorLoader because it is just
        # A name in dictionary rather than executor selection mechanism and it causes cyclic import
        namespaced = obj.get("KubernetesExecutor", {})

        if not namespaced:
            return None

        resources = namespaced.get("resources")

        if resources is None:
            requests = {
                "cpu": namespaced.pop("request_cpu", None),
                "memory": namespaced.pop("request_memory", None),
                "ephemeral-storage": namespaced.get("ephemeral-storage"),  # We pop this one in limits
            }
            limits = {
                "cpu": namespaced.pop("limit_cpu", None),
                "memory": namespaced.pop("limit_memory", None),
                "ephemeral-storage": namespaced.pop("ephemeral-storage", None),
            }
            all_resources = list(requests.values()) + list(limits.values())
            if all(r is None for r in all_resources):
                resources = None
            else:
                # remove None's so they don't become 0's
                requests = {k: v for k, v in requests.items() if v is not None}
                limits = {k: v for k, v in limits.items() if v is not None}
                resources = k8s.V1ResourceRequirements(requests=requests, limits=limits)
        namespaced["resources"] = resources
        return PodGeneratorDeprecated(**namespaced).gen_pod()

    @staticmethod
    def reconcile_pods(base_pod: k8s.V1Pod, client_pod: k8s.V1Pod | None) -> k8s.V1Pod:
        """
        Merge Kubernetes Pod objects.

        :param base_pod: has the base attributes which are overwritten if they exist
            in the client pod and remain if they do not exist in the client_pod
        :param client_pod: the pod that the client wants to create.
        :return: the merged pods

        This can't be done recursively as certain fields are overwritten and some are concatenated.
        """
        if client_pod is None:
            return base_pod

        client_pod_cp = copy.deepcopy(client_pod)
        client_pod_cp.spec = PodGenerator.reconcile_specs(base_pod.spec, client_pod_cp.spec)
        client_pod_cp.metadata = PodGenerator.reconcile_metadata(base_pod.metadata, client_pod_cp.metadata)
        client_pod_cp = merge_objects(base_pod, client_pod_cp)

        return client_pod_cp

    @staticmethod
    def reconcile_metadata(base_meta, client_meta):
        """
        Merge Kubernetes Metadata objects.

        :param base_meta: has the base attributes which are overwritten if they exist
            in the client_meta and remain if they do not exist in the client_meta
        :param client_meta: the spec that the client wants to create.
        :return: the merged specs
        """
        if base_meta and not client_meta:
            return base_meta
        if not base_meta and client_meta:
            return client_meta
        elif client_meta and base_meta:
            client_meta.labels = merge_objects(base_meta.labels, client_meta.labels)
            client_meta.annotations = merge_objects(base_meta.annotations, client_meta.annotations)
            extend_object_field(base_meta, client_meta, "managed_fields")
            extend_object_field(base_meta, client_meta, "finalizers")
            extend_object_field(base_meta, client_meta, "owner_references")
            return merge_objects(base_meta, client_meta)

        return None

    @staticmethod
    def reconcile_specs(
        base_spec: k8s.V1PodSpec | None, client_spec: k8s.V1PodSpec | None
    ) -> k8s.V1PodSpec | None:
        """
        Merge Kubernetes PodSpec objects.

        :param base_spec: has the base attributes which are overwritten if they exist
            in the client_spec and remain if they do not exist in the client_spec
        :param client_spec: the spec that the client wants to create.
        :return: the merged specs
        """
        if base_spec and not client_spec:
            return base_spec
        if not base_spec and client_spec:
            return client_spec
        elif client_spec and base_spec:
            client_spec.containers = PodGenerator.reconcile_containers(
                base_spec.containers, client_spec.containers
            )
            merged_spec = extend_object_field(base_spec, client_spec, "init_containers")
            merged_spec = extend_object_field(base_spec, merged_spec, "volumes")
            return merge_objects(base_spec, merged_spec)

        return None

    @staticmethod
    def reconcile_containers(
        base_containers: list[k8s.V1Container], client_containers: list[k8s.V1Container]
    ) -> list[k8s.V1Container]:
        """
        Merge Kubernetes Container objects.

        :param base_containers: has the base attributes which are overwritten if they exist
            in the client_containers and remain if they do not exist in the client_containers
        :param client_containers: the containers that the client wants to create.
        :return: the merged containers

        The runs recursively over the list of containers.
        """
        if not base_containers:
            return client_containers
        if not client_containers:
            return base_containers

        client_container = client_containers[0]
        base_container = base_containers[0]
        client_container = extend_object_field(base_container, client_container, "volume_mounts")
        client_container = extend_object_field(base_container, client_container, "env")
        client_container = extend_object_field(base_container, client_container, "env_from")
        client_container = extend_object_field(base_container, client_container, "ports")
        client_container = extend_object_field(base_container, client_container, "volume_devices")
        client_container = merge_objects(base_container, client_container)

        return [
            client_container,
            *PodGenerator.reconcile_containers(base_containers[1:], client_containers[1:]),
        ]

    @classmethod
    def construct_pod(
        cls,
        dag_id: str,
        task_id: str,
        pod_id: str,
        try_number: int,
        kube_image: str,
        date: datetime.datetime | None,
        args: list[str],
        pod_override_object: k8s.V1Pod | None,
        base_worker_pod: k8s.V1Pod,
        namespace: str,
        scheduler_job_id: str,
        run_id: str | None = None,
        map_index: int = -1,
        *,
        with_mutation_hook: bool = False,
    ) -> k8s.V1Pod:
        """
        Create a Pod.

        Construct a pod by gathering and consolidating the configuration from 3 places:
        - airflow.cfg
        - executor_config
        - dynamic arguments
        """
        if len(pod_id) > POD_NAME_MAX_LENGTH:
            warnings.warn(
                f"pod_id supplied is longer than {POD_NAME_MAX_LENGTH} characters; "
                f"truncating and adding unique suffix.",
                UserWarning,
                stacklevel=2,
            )
            pod_id = add_unique_suffix(name=pod_id, max_len=POD_NAME_MAX_LENGTH)
        try:
            # Best effort: prefer the image from the executor_config override,
            # falling back to the configured worker image when absent/unset.
            image = pod_override_object.spec.containers[0].image  # type: ignore
            if not image:
                image = kube_image
        except Exception:
            image = kube_image

        annotations = {
            "dag_id": dag_id,
            "task_id": task_id,
            "try_number": str(try_number),
        }
        if map_index >= 0:
            annotations["map_index"] = str(map_index)
        if date:
            annotations["execution_date"] = date.isoformat()
        if run_id:
            annotations["run_id"] = run_id

        dynamic_pod = k8s.V1Pod(
            metadata=k8s.V1ObjectMeta(
                namespace=namespace,
                annotations=annotations,
                name=pod_id,
                labels=cls.build_labels_for_k8s_executor_pod(
                    dag_id=dag_id,
                    task_id=task_id,
                    try_number=try_number,
                    airflow_worker=scheduler_job_id,
                    map_index=map_index,
                    execution_date=date,
                    run_id=run_id,
                ),
            ),
            spec=k8s.V1PodSpec(
                containers=[
                    k8s.V1Container(
                        name="base",
                        args=args,
                        image=image,
                        env=[k8s.V1EnvVar(name="AIRFLOW_IS_K8S_EXECUTOR_POD", value="True")],
                    )
                ]
            ),
        )

        # Reconcile the pods starting with the first chronologically,
        # Pod from the pod_template_File -> Pod from the K8s executor -> Pod from executor_config arg
        pod_list = [base_worker_pod, dynamic_pod, pod_override_object]

        try:
            pod = reduce(PodGenerator.reconcile_pods, pod_list)
        except Exception as e:
            raise PodReconciliationError from e

        if with_mutation_hook:
            from airflow.settings import pod_mutation_hook

            try:
                pod_mutation_hook(pod)
            except Exception as e:
                raise PodMutationHookException from e

        return pod

    @classmethod
    def build_selector_for_k8s_executor_pod(
        cls,
        *,
        dag_id,
        task_id,
        try_number,
        map_index=None,
        execution_date=None,
        run_id=None,
        airflow_worker=None,
    ):
        """
        Generate selector for kubernetes executor pod.

        :meta private:
        """
        labels = cls.build_labels_for_k8s_executor_pod(
            dag_id=dag_id,
            task_id=task_id,
            try_number=try_number,
            map_index=map_index,
            execution_date=execution_date,
            run_id=run_id,
            airflow_worker=airflow_worker,
        )
        label_strings = [f"{label_id}={label}" for label_id, label in sorted(labels.items())]
        selector = ",".join(label_strings)
        if not airflow_worker:  # this filters out KPO pods even when we don't know the scheduler job id
            selector += ",airflow-worker"
        return selector

    @classmethod
    def build_labels_for_k8s_executor_pod(
        cls,
        *,
        dag_id,
        task_id,
        try_number,
        airflow_worker=None,
        map_index=None,
        execution_date=None,
        run_id=None,
    ):
        """
        Generate labels for kubernetes executor pod.

        :meta private:
        """
        labels = {
            "dag_id": make_safe_label_value(dag_id),
            "task_id": make_safe_label_value(task_id),
            "try_number": str(try_number),
            "kubernetes_executor": "True",
            "airflow_version": airflow_version.replace("+", "-"),
        }
        if airflow_worker is not None:
            labels["airflow-worker"] = make_safe_label_value(str(airflow_worker))
        if map_index is not None and map_index >= 0:
            labels["map_index"] = str(map_index)
        if execution_date:
            labels["execution_date"] = datetime_to_label_safe_datestring(execution_date)
        if run_id:
            labels["run_id"] = make_safe_label_value(run_id)
        return labels

    @staticmethod
    def serialize_pod(pod: k8s.V1Pod) -> dict:
        """
        Convert a k8s.V1Pod into a json serializable dictionary.

        :param pod: k8s.V1Pod object
        :return: Serialized version of the pod returned as dict
        """
        api_client = ApiClient()
        return api_client.sanitize_for_serialization(pod)

    @staticmethod
    def deserialize_model_file(path: str) -> k8s.V1Pod:
        """
        Generate a Pod from a file.

        :param path: Path to the file
        :return: a kubernetes.client.models.V1Pod
        """
        if os.path.exists(path):
            with open(path) as stream:
                pod = yaml.safe_load(stream)
        else:
            pod = None
            log.warning("Model file %s does not exist", path)

        return PodGenerator.deserialize_model_dict(pod)

    @staticmethod
    def deserialize_model_dict(pod_dict: dict | None) -> k8s.V1Pod:
        """
        Deserializes a Python dictionary to k8s.V1Pod.

        Unfortunately we need access to the private method
        ``_ApiClient__deserialize_model`` from the kubernetes client.
        This issue is tracked here; https://github.com/kubernetes-client/python/issues/977.

        :param pod_dict: Serialized dict of k8s.V1Pod object
        :return: De-serialized k8s.V1Pod
        """
        api_client = ApiClient()
        return api_client._ApiClient__deserialize_model(pod_dict, k8s.V1Pod)

    @staticmethod
    @deprecated(
        reason="This method is deprecated. Use `add_pod_suffix` in `kubernetes_helper_functions`.",
        category=AirflowProviderDeprecationWarning,
    )
    def make_unique_pod_id(pod_id: str) -> str | None:
        r"""
        Generate a unique Pod name.

        Kubernetes pod names must consist of one or more lowercase
        rfc1035/rfc1123 labels separated by '.' with a maximum length of 253
        characters.

        Name must pass the following regex for validation
        ``^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$``

        For more details, see:
        https://github.com/kubernetes/kubernetes/blob/release-1.1/docs/design/identifiers.md

        :param pod_id: requested pod name
        :return: ``str`` valid Pod name of appropriate length
        """
        if not pod_id:
            return None

        max_pod_id_len = 100  # arbitrarily chosen
        suffix = rand_str(8)  # 8 seems good enough
        base_pod_id_len = max_pod_id_len - len(suffix) - 1  # -1 for separator
        trimmed_pod_id = pod_id[:base_pod_id_len].rstrip("-.")
        return f"{trimmed_pod_id}-{suffix}"
608
609
def merge_objects(base_obj, client_obj):
    """
    Merge objects.

    :param base_obj: has the base attributes which are overwritten if they exist
        in the client_obj and remain if they do not exist in the client_obj
    :param client_obj: the object that the client wants to create.
    :return: the merged objects
    """
    # If either side is missing/empty, the other wins outright.
    if not base_obj:
        return client_obj
    if not client_obj:
        return base_obj

    merged = copy.deepcopy(client_obj)

    # Plain dicts merge via dict.update: client keys override base keys.
    if isinstance(base_obj, dict) and isinstance(merged, dict):
        combined = copy.deepcopy(base_obj)
        combined.update(merged)
        return combined

    # Kubernetes model objects: fill in any attribute the client left unset
    # (falsy) with the base object's truthy value.
    for attr in base_obj.to_dict():
        base_val = getattr(base_obj, attr, None)
        if base_val and not getattr(client_obj, attr, None):
            if isinstance(merged, dict):
                merged[attr] = base_val
            else:
                setattr(merged, attr, base_val)
    return merged
639
640
def extend_object_field(base_obj, client_obj, field_name):
    """
    Add field values to existing objects.

    :param base_obj: an object which has a property `field_name` that is a list
    :param client_obj: an object which has a property `field_name` that is a list.
        A copy of this object is returned with `field_name` modified
    :param field_name: the name of the list field
    :return: the client_obj with the property `field_name` being the two properties appended
    """
    result = copy.deepcopy(client_obj)
    base_val = getattr(base_obj, field_name, None)
    client_val = getattr(client_obj, field_name, None)

    # Only lists (or None) can be concatenated; anything else is a caller error.
    base_bad = base_val is not None and not isinstance(base_val, list)
    client_bad = client_val is not None and not isinstance(client_val, list)
    if base_bad or client_bad:
        raise ValueError(
            f"The chosen field must be a list. Got {type(base_val)} base_object_field "
            f"and {type(client_val)} client_object_field."
        )

    # Empty/missing base: keep whatever the client already has.
    if not base_val:
        return result
    # Empty/missing client: carry over the base list.
    if not client_val:
        setattr(result, field_name, base_val)
        return result

    # Both present: base entries first, then client entries.
    setattr(result, field_name, base_val + client_val)
    return result