Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/secrets/cache.py: 50%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
18from __future__ import annotations
20import datetime
21import multiprocessing
23from airflow.configuration import conf
24from airflow.utils import timezone
class SecretCache:
    """
    A static class to manage the global secret cache.

    All state lives in class attributes and every method is a classmethod, so the class is
    never instantiated. While the cache has not been created by :meth:`init`, lookups raise
    :class:`SecretCache.NotPresentException` and saves/invalidations are silently ignored.
    """

    __manager: multiprocessing.managers.SyncManager | None = None
    # ``None`` means "cache disabled / not initialized"; every accessor checks for it.
    _cache: dict[str, _CacheValue] | None = None
    # Only assigned by ``init``; always set before ``_cache`` is published.
    _ttl: datetime.timedelta

    class NotPresentException(Exception):
        """Raised when a key is not present in the cache."""

    class _CacheValue:
        """A cached value together with the time at which it was stored."""

        def __init__(self, value: str | None) -> None:
            self.value = value
            self.date = timezone.utcnow()

        def is_expired(self, ttl: datetime.timedelta) -> bool:
            """Return True when the entry was stored more than ``ttl`` ago."""
            return timezone.utcnow() - self.date > ttl

    # Variables and connections share one dict; the prefixes keep their keys apart.
    _VARIABLE_PREFIX = "__v_"
    _CONNECTION_PREFIX = "__c_"

    @classmethod
    def init(cls):
        """
        Initialize the cache, provided the configuration allows it.

        Safe to call several times.
        """
        if cls._cache is not None:
            return
        use_cache = conf.getboolean(section="secrets", key="use_cache", fallback=False)
        if not use_cache:
            return
        # Resolve the TTL *before* publishing the cache. If reading the configuration fails,
        # the cache stays uninitialized instead of being left usable-but-without-a-TTL
        # (which would make the first cache hit fail on the missing ``_ttl`` attribute).
        ttl_seconds = conf.getint(section="secrets", key="cache_ttl_seconds", fallback=15 * 60)
        cls._ttl = datetime.timedelta(seconds=ttl_seconds)
        if cls.__manager is None:
            # it is not really necessary to save the manager, but doing so allows to reuse it between tests,
            # making them run a lot faster because this operation takes ~300ms each time
            cls.__manager = multiprocessing.Manager()
        cls._cache = cls.__manager.dict()

    @classmethod
    def reset(cls):
        """Use for test purposes only."""
        # The manager is deliberately kept so that a later ``init`` can reuse it.
        cls._cache = None

    @classmethod
    def get_variable(cls, key: str) -> str | None:
        """
        Try to get the value associated with the key from the cache.

        :param key: The variable key.
        :return: The saved value (which can be None) if present in cache and not expired.
        :raises SecretCache.NotPresentException: If the cache is not initialized, or the key
            is absent or expired.
        """
        return cls._get(key, cls._VARIABLE_PREFIX)

    @classmethod
    def get_connection_uri(cls, conn_id: str) -> str:
        """
        Try to get the uri associated with the conn_id from the cache.

        :param conn_id: The connection id.
        :return: The saved uri if present in cache and not expired.
        :raises SecretCache.NotPresentException: If the cache is not initialized, the entry
            is absent or expired, or the stored uri is empty.
        """
        val = cls._get(conn_id, cls._CONNECTION_PREFIX)
        if val:  # there shouldn't be any empty entries in the connections cache, but we enforce it here.
            return val
        raise cls.NotPresentException

    @classmethod
    def _get(cls, key: str, prefix: str) -> str | None:
        """
        Look up ``prefix + key`` and return its stored value.

        :raises SecretCache.NotPresentException: If the cache is not initialized, or the
            entry is absent or expired.
        """
        if cls._cache is None:
            # using an exception for misses allow to meaningfully cache None values
            raise cls.NotPresentException

        val = cls._cache.get(f"{prefix}{key}")
        if val is not None and not val.is_expired(cls._ttl):
            return val.value
        raise cls.NotPresentException

    @classmethod
    def save_variable(cls, key: str, value: str | None):
        """Save the value for that key in the cache, if initialized."""
        cls._save(key, value, cls._VARIABLE_PREFIX)

    @classmethod
    def save_connection_uri(cls, conn_id: str, uri: str):
        """Save the uri representation for that connection in the cache, if initialized."""
        if uri is None:
            # connections raise exceptions if not present, so we shouldn't have any None value to save.
            return
        cls._save(conn_id, uri, cls._CONNECTION_PREFIX)

    @classmethod
    def _save(cls, key: str, value: str | None, prefix: str):
        """Store ``value`` under ``prefix + key``; does nothing when the cache is not initialized."""
        if cls._cache is not None:
            cls._cache[f"{prefix}{key}"] = cls._CacheValue(value)

    @classmethod
    def invalidate_variable(cls, key: str):
        """Invalidate (actually removes) the value stored in the cache for that Variable."""
        if cls._cache is not None:
            # second arg ensures no exception if key is absent
            cls._cache.pop(f"{cls._VARIABLE_PREFIX}{key}", None)