1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import annotations
6
7import typing
8
9from cryptography.exceptions import InvalidTag, UnsupportedAlgorithm, _Reasons
10from cryptography.hazmat.primitives import ciphers
11from cryptography.hazmat.primitives.ciphers import algorithms, modes
12
13if typing.TYPE_CHECKING:
14 from cryptography.hazmat.backends.openssl.backend import Backend
15
16
class _CipherContext:
    """Cipher context backed by an OpenSSL ``EVP_CIPHER_CTX``.

    An instance is bound to one cipher/mode pair and one direction
    (``_ENCRYPT`` or ``_DECRYPT``).  Data is fed through :meth:`update` /
    :meth:`update_into`, and :meth:`finalize` (or :meth:`finalize_with_tag`)
    ends the operation.  Padding is disabled on the OpenSSL context; the
    comment in ``__init__`` notes that it is handled higher up in the API.
    """

    # Values passed as the ``enc`` argument of EVP_CipherInit_ex.
    _ENCRYPT = 1
    _DECRYPT = 0
    # update_into() never passes more than this many input bytes to a single
    # EVP_CipherUpdate call; larger payloads are processed in several calls.
    _MAX_CHUNK_SIZE = 2**29

    def __init__(self, backend: Backend, cipher, mode, operation: int) -> None:
        """Create and fully initialise the underlying EVP cipher context.

        :param backend: OpenSSL backend providing ``_lib``, ``_ffi``, the
            cipher registry and the error helpers.
        :param cipher: Algorithm instance; its ``key`` and ``name`` are read
            (and ``nonce`` for ChaCha20).
        :param mode: Mode instance, or a falsy value for ciphers used
            without a mode.
        :param operation: ``_ENCRYPT`` or ``_DECRYPT``.
        :raises UnsupportedAlgorithm: if the registry has no adapter for the
            cipher/mode pair, or the adapter returns ``NULL``.
        :raises ValueError: if OpenSSL rejects duplicated XTS keys.
        """
        self._backend = backend
        self._cipher = cipher
        self._mode = mode
        self._operation = operation
        self._tag: bytes | None = None

        if isinstance(self._cipher, ciphers.BlockCipherAlgorithm):
            self._block_size_bytes = self._cipher.block_size // 8
        else:
            self._block_size_bytes = 1

        ffi = self._backend._ffi
        lib = self._backend._lib

        # Tie the context's lifetime to the Python object so it is freed
        # even if initialisation fails below.
        ctx = lib.EVP_CIPHER_CTX_new()
        ctx = ffi.gc(ctx, lib.EVP_CIPHER_CTX_free)

        registry = self._backend._cipher_registry
        try:
            adapter = registry[type(cipher), type(mode)]
        except KeyError:
            # The KeyError is an implementation detail of the registry
            # lookup; keep it out of the user-facing traceback.
            raise UnsupportedAlgorithm(
                "cipher {} in {} mode is not supported "
                "by this backend.".format(
                    cipher.name, mode.name if mode else mode
                ),
                _Reasons.UNSUPPORTED_CIPHER,
            ) from None

        evp_cipher = adapter(self._backend, cipher, mode)
        if evp_cipher == ffi.NULL:
            msg = f"cipher {cipher.name} "
            if mode is not None:
                msg += f"in {mode.name} mode "
            msg += (
                "is not supported by this backend (Your version of OpenSSL "
                "may be too old. Current version: {}.)"
            ).format(self._backend.openssl_version_text())
            raise UnsupportedAlgorithm(msg, _Reasons.UNSUPPORTED_CIPHER)

        # Pick whichever per-message value (IV, tweak or nonce) applies.
        if isinstance(mode, modes.ModeWithInitializationVector):
            iv_nonce = ffi.from_buffer(mode.initialization_vector)
        elif isinstance(mode, modes.ModeWithTweak):
            iv_nonce = ffi.from_buffer(mode.tweak)
        elif isinstance(mode, modes.ModeWithNonce):
            iv_nonce = ffi.from_buffer(mode.nonce)
        elif isinstance(cipher, algorithms.ChaCha20):
            iv_nonce = ffi.from_buffer(cipher.nonce)
        else:
            iv_nonce = ffi.NULL

        # begin init with cipher and operation type
        res = lib.EVP_CipherInit_ex(
            ctx, evp_cipher, ffi.NULL, ffi.NULL, ffi.NULL, operation
        )
        self._backend.openssl_assert(res != 0)

        # set the key length to handle variable key ciphers
        res = lib.EVP_CIPHER_CTX_set_key_length(ctx, len(cipher.key))
        self._backend.openssl_assert(res != 0)

        if isinstance(mode, modes.GCM):
            # The IV length (and, when decrypting with a known tag, the tag)
            # is configured before the key/IV are supplied below.
            res = lib.EVP_CIPHER_CTX_ctrl(
                ctx, lib.EVP_CTRL_AEAD_SET_IVLEN, len(iv_nonce), ffi.NULL
            )
            self._backend.openssl_assert(res != 0)
            if mode.tag is not None:
                res = lib.EVP_CIPHER_CTX_ctrl(
                    ctx, lib.EVP_CTRL_AEAD_SET_TAG, len(mode.tag), mode.tag
                )
                self._backend.openssl_assert(res != 0)
                self._tag = mode.tag

        # pass key/iv
        res = lib.EVP_CipherInit_ex(
            ctx,
            ffi.NULL,
            ffi.NULL,
            ffi.from_buffer(cipher.key),
            iv_nonce,
            operation,
        )

        # Check for XTS mode duplicate keys error.  ``errors`` may be empty
        # even when the call failed, so guard before indexing it and let
        # openssl_assert report that case.
        errors = self._backend._consume_errors()
        if (
            res == 0
            and errors
            and (
                (
                    not lib.CRYPTOGRAPHY_IS_LIBRESSL
                    and errors[0]._lib_reason_match(
                        lib.ERR_LIB_EVP, lib.EVP_R_XTS_DUPLICATED_KEYS
                    )
                )
                or (
                    lib.Cryptography_HAS_PROVIDERS
                    and errors[0]._lib_reason_match(
                        lib.ERR_LIB_PROV, lib.PROV_R_XTS_DUPLICATED_KEYS
                    )
                )
            )
        ):
            raise ValueError("In XTS mode duplicated keys are not allowed")

        self._backend.openssl_assert(res != 0, errors=errors)

        # We purposely disable padding here as it's handled higher up in the
        # API.
        lib.EVP_CIPHER_CTX_set_padding(ctx, 0)
        self._ctx = ctx

    def update(self, data: bytes) -> bytes:
        """Process ``data`` and return the bytes produced so far.

        Allocates a scratch buffer of the size :meth:`update_into` requires
        and returns only the portion that was actually written.
        """
        buf = bytearray(len(data) + self._block_size_bytes - 1)
        n = self.update_into(data, buf)
        return bytes(buf[:n])

    def update_into(self, data: bytes, buf: bytes) -> int:
        """Process ``data``, writing the output into ``buf``.

        :param data: Input bytes.
        :param buf: Writable buffer of at least
            ``len(data) + block_size_bytes - 1`` bytes.
        :returns: Number of bytes written to ``buf``.
        :raises ValueError: if ``buf`` is too small, or if an XTS update
            fails (reported as "less than a full block supplied").
        """
        total_data_len = len(data)
        min_buf_len = total_data_len + self._block_size_bytes - 1
        if len(buf) < min_buf_len:
            raise ValueError(
                "buffer must be at least {} bytes for this payload".format(
                    min_buf_len
                )
            )

        ffi = self._backend._ffi
        lib = self._backend._lib

        data_processed = 0
        total_out = 0
        outlen = ffi.new("int *")
        baseoutbuf = ffi.from_buffer(buf, require_writable=True)
        baseinbuf = ffi.from_buffer(data)

        # Feed the input in chunks of at most _MAX_CHUNK_SIZE bytes,
        # advancing both the input and the output pointers as we go.
        while data_processed != total_data_len:
            outbuf = baseoutbuf + total_out
            inbuf = baseinbuf + data_processed
            inlen = min(self._MAX_CHUNK_SIZE, total_data_len - data_processed)

            res = lib.EVP_CipherUpdate(
                self._ctx, outbuf, outlen, inbuf, inlen
            )
            if res == 0 and isinstance(self._mode, modes.XTS):
                # Drain the OpenSSL error queue before raising our own error.
                self._backend._consume_errors()
                raise ValueError(
                    "In XTS mode you must supply at least a full block in the "
                    "first update call. For AES this is 16 bytes."
                )
            self._backend.openssl_assert(res != 0)
            data_processed += inlen
            total_out += outlen[0]

        return total_out

    def finalize(self) -> bytes:
        """Finish the operation and return any remaining output bytes.

        When encrypting in GCM mode the authentication tag is fetched and
        made available through :attr:`tag`.  The OpenSSL context is reset
        before returning.

        :raises ValueError: if decrypting in an authenticated mode without a
            tag, or if the data length is not a multiple of the block length.
        :raises InvalidTag: if the final step fails in GCM mode without
            OpenSSL queuing any error.
        """
        if (
            self._operation == self._DECRYPT
            and isinstance(self._mode, modes.ModeWithAuthenticationTag)
            and self.tag is None
        ):
            raise ValueError(
                "Authentication tag must be provided when decrypting."
            )

        ffi = self._backend._ffi
        lib = self._backend._lib

        buf = ffi.new("unsigned char[]", self._block_size_bytes)
        outlen = ffi.new("int *")
        res = lib.EVP_CipherFinal_ex(self._ctx, buf, outlen)
        if res == 0:
            errors = self._backend._consume_errors()

            if not errors and isinstance(self._mode, modes.GCM):
                raise InvalidTag

            # Anything other than a "wrong final block length" style error
            # is unexpected and is reported by openssl_assert.  An empty
            # error list is treated the same way rather than being indexed.
            self._backend.openssl_assert(
                bool(errors)
                and (
                    errors[0]._lib_reason_match(
                        lib.ERR_LIB_EVP,
                        lib.EVP_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH,
                    )
                    or (
                        lib.Cryptography_HAS_PROVIDERS
                        and errors[0]._lib_reason_match(
                            lib.ERR_LIB_PROV,
                            lib.PROV_R_WRONG_FINAL_BLOCK_LENGTH,
                        )
                    )
                    or (
                        lib.CRYPTOGRAPHY_IS_BORINGSSL
                        and errors[0].reason
                        == lib.CIPHER_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH
                    )
                ),
                errors=errors,
            )
            raise ValueError(
                "The length of the provided data is not a multiple of "
                "the block length."
            )

        if (
            isinstance(self._mode, modes.GCM)
            and self._operation == self._ENCRYPT
        ):
            # Retrieve a tag of block_size_bytes bytes from the context.
            tag_buf = ffi.new("unsigned char[]", self._block_size_bytes)
            res = lib.EVP_CIPHER_CTX_ctrl(
                self._ctx,
                lib.EVP_CTRL_AEAD_GET_TAG,
                self._block_size_bytes,
                tag_buf,
            )
            self._backend.openssl_assert(res != 0)
            self._tag = ffi.buffer(tag_buf)[:]

        res = lib.EVP_CIPHER_CTX_reset(self._ctx)
        self._backend.openssl_assert(res == 1)
        return ffi.buffer(buf)[: outlen[0]]

    def finalize_with_tag(self, tag: bytes) -> bytes:
        """Set the authentication tag, then finalize.

        :param tag: Tag to verify; its length must be between the mode's
            ``_min_tag_length`` and ``block_size_bytes`` inclusive.
        :returns: The result of :meth:`finalize`.
        :raises ValueError: if the tag length is out of range.
        """
        tag_len = len(tag)
        min_tag_len = self._mode._min_tag_length
        if tag_len < min_tag_len:
            raise ValueError(
                "Authentication tag must be {} bytes or longer.".format(
                    min_tag_len
                )
            )
        if tag_len > self._block_size_bytes:
            raise ValueError(
                "Authentication tag cannot be more than {} bytes.".format(
                    self._block_size_bytes
                )
            )

        res = self._backend._lib.EVP_CIPHER_CTX_ctrl(
            self._ctx, self._backend._lib.EVP_CTRL_AEAD_SET_TAG, tag_len, tag
        )
        self._backend.openssl_assert(res != 0)
        self._tag = tag
        return self.finalize()

    def authenticate_additional_data(self, data: bytes) -> None:
        """Feed ``data`` to the cipher as additional authenticated data.

        The data is passed to EVP_CipherUpdate with a ``NULL`` output
        buffer, so no output bytes are written.
        """
        outlen = self._backend._ffi.new("int *")
        res = self._backend._lib.EVP_CipherUpdate(
            self._ctx,
            self._backend._ffi.NULL,
            outlen,
            self._backend._ffi.from_buffer(data),
            len(data),
        )
        self._backend.openssl_assert(res != 0)

    @property
    def tag(self) -> bytes | None:
        """The authentication tag, or ``None`` if none has been set yet."""
        return self._tag