Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/secrets/cache.py: 50%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

66 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import datetime 

21import multiprocessing 

22 

23from airflow.configuration import conf 

24from airflow.utils import timezone 

25 

26 

class SecretCache:
    """
    A static class to manage the global secret cache.

    Entries live in a dict created by a ``multiprocessing.Manager`` (see :meth:`init`). They are
    keyed by a type prefix (variable or connection) followed by the variable key / connection id.
    A cache miss is signalled by raising :class:`NotPresentException` rather than returning
    ``None``, so that ``None`` itself can be cached as a legitimate variable value.
    """

    # Kept across ``reset()`` calls so that the expensive manager start-up is only paid once.
    __manager: multiprocessing.managers.SyncManager | None = None
    # Dict obtained from ``__manager.dict()`` in ``init()``; ``None`` means the cache is not in use.
    _cache: dict[str, _CacheValue] | None = None
    # Entry time-to-live; assigned by ``init()`` and read by ``_get()`` whenever ``_cache`` is set.
    _ttl: datetime.timedelta

    class NotPresentException(Exception):
        """Raised when a key is not present in the cache, or its entry has expired."""

    class _CacheValue:
        """A cached value together with the time (``timezone.utcnow()``) at which it was stored."""

        def __init__(self, value: str | None) -> None:
            self.value = value
            self.date = timezone.utcnow()

        def is_expired(self, ttl: datetime.timedelta) -> bool:
            """Return True when strictly more than ``ttl`` has elapsed since the value was stored."""
            return timezone.utcnow() - self.date > ttl

    _VARIABLE_PREFIX = "__v_"
    _CONNECTION_PREFIX = "__c_"

    @classmethod
    def init(cls):
        """
        Initialize the cache, provided the configuration allows it.

        Reads ``[secrets] use_cache`` (fallback: False) and ``[secrets] cache_ttl_seconds``
        (fallback: 15 minutes). Safe to call several times: once the cache exists this is a no-op.
        """
        if cls._cache is not None:
            return
        use_cache = conf.getboolean(section="secrets", key="use_cache", fallback=False)
        if not use_cache:
            return
        # Resolve the TTL *before* publishing the cache. ``_get`` reads ``_ttl`` whenever
        # ``_cache`` is set, and a non-None ``_cache`` makes every later ``init()`` return early.
        # If the TTL config read failed after ``_cache`` was set, the cache would be left
        # permanently half-initialized, with lookups failing on the missing ``_ttl``.
        ttl_seconds = conf.getint(section="secrets", key="cache_ttl_seconds", fallback=15 * 60)
        cls._ttl = datetime.timedelta(seconds=ttl_seconds)
        if cls.__manager is None:
            # it is not really necessary to save the manager, but doing so allows to reuse it between tests,
            # making them run a lot faster because this operation takes ~300ms each time
            cls.__manager = multiprocessing.Manager()
        cls._cache = cls.__manager.dict()

    @classmethod
    def reset(cls):
        """
        Use for test purposes only.

        Drops the cache so that the next ``init()`` creates a fresh one. The manager is kept
        so that it can be reused.
        """
        cls._cache = None

    @classmethod
    def get_variable(cls, key: str) -> str | None:
        """
        Try to get the value associated with the key from the cache.

        :param key: the Variable key.
        :return: The saved value (which can be None) if present in cache and not expired.
        :raises NotPresentException: if the cache is not initialized, or the key is absent or expired.
        """
        return cls._get(key, cls._VARIABLE_PREFIX)

    @classmethod
    def get_connection_uri(cls, conn_id: str) -> str:
        """
        Try to get the uri associated with the conn_id from the cache.

        :param conn_id: the connection id.
        :return: The saved uri if present in cache and not expired.
        :raises NotPresentException: if the cache is not initialized, or the entry is absent,
            expired, or empty.
        """
        val = cls._get(conn_id, cls._CONNECTION_PREFIX)
        if val:  # there shouldn't be any empty entries in the connections cache, but we enforce it here.
            return val
        raise cls.NotPresentException

    @classmethod
    def _get(cls, key: str, prefix: str) -> str | None:
        """
        Look up ``prefix + key`` and return the stored value if the entry has not expired.

        :raises NotPresentException: on any miss (cache disabled, key absent, or entry expired).
        """
        if cls._cache is None:
            # using an exception for misses allows us to meaningfully cache None values
            raise cls.NotPresentException

        entry = cls._cache.get(f"{prefix}{key}")
        if entry is not None and not entry.is_expired(cls._ttl):
            return entry.value
        raise cls.NotPresentException

    @classmethod
    def save_variable(cls, key: str, value: str | None):
        """Save the value (``None`` allowed) for that key in the cache, if initialized."""
        cls._save(key, value, cls._VARIABLE_PREFIX)

    @classmethod
    def save_connection_uri(cls, conn_id: str, uri: str):
        """Save the uri representation for that connection in the cache, if initialized."""
        if uri is None:
            # connections raise exceptions if not present, so we shouldn't have any None value to save.
            return
        cls._save(conn_id, uri, cls._CONNECTION_PREFIX)

    @classmethod
    def _save(cls, key: str, value: str | None, prefix: str):
        """Store ``value`` under ``prefix + key`` with the current timestamp; no-op if uninitialized."""
        if cls._cache is not None:
            cls._cache[f"{prefix}{key}"] = cls._CacheValue(value)

    @classmethod
    def invalidate_variable(cls, key: str):
        """Invalidate (actually removes) the value stored in the cache for that Variable."""
        if cls._cache is not None:
            # second arg ensures no exception if key is absent
            cls._cache.pop(f"{cls._VARIABLE_PREFIX}{key}", None)