1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import annotations
6
7import base64
8import binascii
9import os
10import time
11import typing
12from collections.abc import Iterable
13
14from cryptography import utils
15from cryptography.exceptions import InvalidSignature
16from cryptography.hazmat.primitives import hashes, padding
17from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
18from cryptography.hazmat.primitives.hmac import HMAC
19
20
class InvalidToken(Exception):
    """Raised when a Fernet token cannot be accepted.

    Covers malformed encoding, a wrong version byte, a failed HMAC check,
    bad padding, and tokens that are expired or too far in the future.
    """
23
24
25_MAX_CLOCK_SKEW = 60
26
27
class Fernet:
    """Authenticated symmetric encryption producing Fernet tokens.

    A token is the url-safe base64 encoding of the concatenation::

        0x80               version byte
        timestamp          8 bytes, big-endian, seconds (int(time.time()))
        IV                 16 random bytes for AES-CBC
        ciphertext         AES-CBC encryption of the PKCS7-padded plaintext
        HMAC               32 bytes, HMAC-SHA256 over all preceding bytes

    The 32-byte key is split in half: the first 16 bytes sign (HMAC) and the
    last 16 bytes encrypt (AES).
    """

    def __init__(
        self,
        key: bytes | str,
        backend: typing.Any = None,
    ) -> None:
        """Create a Fernet instance from a url-safe base64-encoded key.

        :param key: url-safe base64 encoding of exactly 32 bytes, given as
            ``bytes`` or an ASCII ``str``.
        :param backend: unused by this implementation (presumably kept so
            older callers that pass it keep working).
        :raises ValueError: if ``key`` is not valid base64 (including a
            ``str`` with non-ASCII characters) or does not decode to exactly
            32 bytes.
        """
        try:
            key = Fernet._urlsafe_b64decode(key)
        except binascii.Error as exc:
            raise ValueError(
                "Fernet key must be 32 url-safe base64-encoded bytes."
            ) from exc
        if len(key) != 32:
            raise ValueError(
                "Fernet key must be 32 url-safe base64-encoded bytes."
            )

        self._signing_key = key[:16]
        self._encryption_key = key[16:]

    @classmethod
    def generate_key(cls) -> bytes:
        """Return a fresh random key: 32 bytes from os.urandom, url-safe
        base64-encoded."""
        return base64.urlsafe_b64encode(os.urandom(32))

    def encrypt(self, data: bytes) -> bytes:
        """Encrypt ``data`` and stamp the token with the current time.

        :returns: the url-safe base64-encoded token.
        """
        return self.encrypt_at_time(data, int(time.time()))

    def encrypt_at_time(self, data: bytes, current_time: int) -> bytes:
        """Encrypt ``data`` with an explicit timestamp and a fresh random IV.

        :param current_time: timestamp stored in the token; it must be a
            non-negative int that fits in 8 bytes, or ``int.to_bytes``
            raises OverflowError.
        :returns: the url-safe base64-encoded token.
        """
        iv = os.urandom(16)
        return self._encrypt_from_parts(data, current_time, iv)

    def _encrypt_from_parts(
        self, data: bytes, current_time: int, iv: bytes
    ) -> bytes:
        """Build a token from caller-supplied timestamp and IV.

        This is split out of encrypt_at_time so that MultiFernet.rotate can
        re-encrypt while preserving the original timestamp.
        """
        # Reject non-bytes plaintext before doing any cryptographic work.
        utils._check_bytes("data", data)

        padder = padding.PKCS7(algorithms.AES.block_size).padder()
        padded_data = padder.update(data) + padder.finalize()
        encryptor = Cipher(
            algorithms.AES(self._encryption_key),
            modes.CBC(iv),
        ).encryptor()
        ciphertext = encryptor.update(padded_data) + encryptor.finalize()

        basic_parts = (
            b"\x80"
            + current_time.to_bytes(length=8, byteorder="big")
            + iv
            + ciphertext
        )

        # Encrypt-then-MAC: the HMAC authenticates the version, timestamp,
        # IV and ciphertext.
        h = HMAC(self._signing_key, hashes.SHA256())
        h.update(basic_parts)
        hmac = h.finalize()
        return base64.urlsafe_b64encode(basic_parts + hmac)

    def decrypt(self, token: bytes | str, ttl: int | None = None) -> bytes:
        """Authenticate and decrypt ``token``.

        :param ttl: if not None, the maximum token age in seconds, measured
            against ``int(time.time())``. When ttl is given, tokens stamped
            more than ``_MAX_CLOCK_SKEW`` seconds in the future are also
            rejected.
        :returns: the original plaintext bytes.
        :raises TypeError: if ``token`` is neither bytes nor str.
        :raises InvalidToken: if the token is malformed, expired, from the
            future, fails the HMAC check, or has invalid padding.
        """
        timestamp, data = Fernet._get_unverified_token_data(token)
        if ttl is None:
            time_info = None
        else:
            time_info = (ttl, int(time.time()))
        return self._decrypt_data(data, timestamp, time_info)

    def decrypt_at_time(
        self, token: bytes | str, ttl: int, current_time: int
    ) -> bytes:
        """Like decrypt(), but evaluates ``ttl`` against ``current_time``.

        :raises ValueError: if ``ttl`` is None.
        """
        if ttl is None:
            raise ValueError(
                "decrypt_at_time() can only be used with a non-None ttl"
            )
        timestamp, data = Fernet._get_unverified_token_data(token)
        return self._decrypt_data(data, timestamp, (ttl, current_time))

    def extract_timestamp(self, token: bytes | str) -> int:
        """Return the timestamp embedded in an authenticated token.

        Only the HMAC is checked. The token is not decrypted and no
        ttl/clock-skew check is applied.

        :raises InvalidToken: if the token is malformed or fails the HMAC.
        """
        timestamp, data = Fernet._get_unverified_token_data(token)
        # Verify the token was not tampered with.
        self._verify_signature(data)
        return timestamp

    @staticmethod
    def _urlsafe_b64decode(value: bytes | str) -> bytes:
        """Url-safe base64 decode that signals all bad data as binascii.Error.

        base64.urlsafe_b64decode raises binascii.Error for malformed input,
        but a plain ValueError for a ``str`` containing non-ASCII
        characters. Normalizing the latter lets callers handle one
        exception type.
        """
        if isinstance(value, str):
            try:
                value = value.encode("ascii")
            except UnicodeEncodeError as exc:
                raise binascii.Error(
                    "base64 data must contain only ASCII characters"
                ) from exc
        return base64.urlsafe_b64decode(value)

    @staticmethod
    def _get_unverified_token_data(token: bytes | str) -> tuple[int, bytes]:
        """Decode ``token`` and parse its timestamp WITHOUT checking the HMAC.

        :returns: ``(timestamp, raw_token_bytes)``.
        :raises TypeError: if ``token`` is neither bytes nor str.
        :raises InvalidToken: if the token is not valid base64 (including
            non-ASCII text), has the wrong version byte, or is shorter than
            the 9-byte version+timestamp header.
        """
        if not isinstance(token, (str, bytes)):
            raise TypeError("token must be bytes or str")

        try:
            data = Fernet._urlsafe_b64decode(token)
        except (TypeError, binascii.Error):
            raise InvalidToken

        if not data or data[0] != 0x80:
            raise InvalidToken

        if len(data) < 9:
            raise InvalidToken

        timestamp = int.from_bytes(data[1:9], byteorder="big")
        return timestamp, data

    def _verify_signature(self, data: bytes) -> None:
        """Check the trailing 32-byte HMAC-SHA256 against the rest of ``data``.

        :raises InvalidToken: on mismatch, including when ``data`` is too
            short to contain a full 32-byte tag.
        """
        h = HMAC(self._signing_key, hashes.SHA256())
        h.update(data[:-32])
        try:
            h.verify(data[-32:])
        except InvalidSignature:
            raise InvalidToken

    def _decrypt_data(
        self,
        data: bytes,
        timestamp: int,
        time_info: tuple[int, int] | None,
    ) -> bytes:
        """Check freshness (optional), authenticate, decrypt and unpad.

        :param time_info: ``(ttl, current_time)``, or None to skip the
            expiry and future-timestamp checks.
        :raises InvalidToken: on any failed check.
        """
        if time_info is not None:
            ttl, current_time = time_info
            # Expired: older than ttl seconds.
            if timestamp + ttl < current_time:
                raise InvalidToken

            # Stamped too far in the future relative to current_time.
            if current_time + _MAX_CLOCK_SKEW < timestamp:
                raise InvalidToken

        self._verify_signature(data)

        # Layout after the 9-byte header: 16-byte IV, ciphertext, 32-byte HMAC.
        iv = data[9:25]
        ciphertext = data[25:-32]
        decryptor = Cipher(
            algorithms.AES(self._encryption_key), modes.CBC(iv)
        ).decryptor()
        plaintext_padded = decryptor.update(ciphertext)
        try:
            # Raises ValueError when ciphertext is not block-aligned.
            plaintext_padded += decryptor.finalize()
        except ValueError:
            raise InvalidToken
        unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()

        unpadded = unpadder.update(plaintext_padded)
        try:
            unpadded += unpadder.finalize()
        except ValueError:
            raise InvalidToken
        return unpadded
169
170
class MultiFernet:
    """Group several Fernet instances to support key rotation.

    New tokens are always produced by the first (primary) instance. Every
    decryption-side operation tries the instances in order and returns the
    result from the first one that does not raise InvalidToken.
    """

    def __init__(self, fernets: Iterable[Fernet]):
        """Materialize ``fernets`` into a list.

        :raises ValueError: if ``fernets`` yields no instances.
        """
        candidates = list(fernets)
        if len(candidates) == 0:
            raise ValueError(
                "MultiFernet requires at least one Fernet instance"
            )
        self._fernets = candidates

    def _first_accepting(
        self, attempt: typing.Callable[[Fernet], typing.Any]
    ) -> typing.Any:
        """Return ``attempt(f)`` for the first instance ``f`` that accepts it.

        InvalidToken from one instance moves on to the next; any other
        exception propagates immediately. Raises InvalidToken if every
        instance rejects.
        """
        for candidate in self._fernets:
            try:
                return attempt(candidate)
            except InvalidToken:
                continue
        raise InvalidToken

    def encrypt(self, msg: bytes) -> bytes:
        """Encrypt ``msg`` with the primary instance at the current time."""
        now = int(time.time())
        return self.encrypt_at_time(msg, now)

    def encrypt_at_time(self, msg: bytes, current_time: int) -> bytes:
        """Encrypt ``msg`` with the primary instance at ``current_time``."""
        primary = self._fernets[0]
        return primary.encrypt_at_time(msg, current_time)

    def rotate(self, msg: bytes | str) -> bytes:
        """Re-encrypt ``msg`` under the primary instance, keeping its timestamp.

        No ttl check is applied. Raises TypeError if ``msg`` is neither
        bytes nor str, and InvalidToken if no instance can decrypt it.
        """
        timestamp, data = Fernet._get_unverified_token_data(msg)
        plaintext = self._first_accepting(
            lambda fernet: fernet._decrypt_data(data, timestamp, None)
        )

        fresh_iv = os.urandom(16)
        primary = self._fernets[0]
        return primary._encrypt_from_parts(plaintext, timestamp, fresh_iv)

    def decrypt(self, msg: bytes | str, ttl: int | None = None) -> bytes:
        """Decrypt ``msg`` with the first instance that accepts it."""
        return self._first_accepting(lambda fernet: fernet.decrypt(msg, ttl))

    def decrypt_at_time(
        self, msg: bytes | str, ttl: int, current_time: int
    ) -> bytes:
        """Like decrypt(), evaluating ``ttl`` against ``current_time``."""
        return self._first_accepting(
            lambda fernet: fernet.decrypt_at_time(msg, ttl, current_time)
        )

    def extract_timestamp(self, msg: bytes | str) -> int:
        """Return the timestamp from the first instance that authenticates
        ``msg``."""
        return self._first_accepting(
            lambda fernet: fernet.extract_timestamp(msg)
        )