Coverage Report

Created: 2024-07-27 06:39

/src/openssl31/crypto/rsa/rsa_pss.c
Line
Count
Source (jump to first uncovered line)
1
/*
2
 * Copyright 2005-2022 The OpenSSL Project Authors. All Rights Reserved.
3
 *
4
 * Licensed under the Apache License 2.0 (the "License").  You may not use
5
 * this file except in compliance with the License.  You can obtain a copy
6
 * in the file LICENSE in the source distribution or at
7
 * https://www.openssl.org/source/license.html
8
 */
9
10
/*
11
 * RSA low level APIs are deprecated for public use, but still ok for
12
 * internal use.
13
 */
14
#include "internal/deprecated.h"
15
16
#include <stdio.h>
17
#include "internal/cryptlib.h"
18
#include <openssl/bn.h>
19
#include <openssl/rsa.h>
20
#include <openssl/evp.h>
21
#include <openssl/rand.h>
22
#include <openssl/sha.h>
23
#include "rsa_local.h"
24
25
static const unsigned char zeroes[] = { 0, 0, 0, 0, 0, 0, 0, 0 };
26
27
#if defined(_MSC_VER) && defined(_ARM_)
28
# pragma optimize("g", off)
29
#endif
30
31
int RSA_verify_PKCS1_PSS(RSA *rsa, const unsigned char *mHash,
32
                         const EVP_MD *Hash, const unsigned char *EM,
33
                         int sLen)
34
0
{
35
0
    return RSA_verify_PKCS1_PSS_mgf1(rsa, mHash, Hash, NULL, EM, sLen);
36
0
}
37
38
int RSA_verify_PKCS1_PSS_mgf1(RSA *rsa, const unsigned char *mHash,
39
                              const EVP_MD *Hash, const EVP_MD *mgf1Hash,
40
                              const unsigned char *EM, int sLen)
41
7.88k
{
42
7.88k
    int i;
43
7.88k
    int ret = 0;
44
7.88k
    int hLen, maskedDBLen, MSBits, emLen;
45
7.88k
    const unsigned char *H;
46
7.88k
    unsigned char *DB = NULL;
47
7.88k
    EVP_MD_CTX *ctx = EVP_MD_CTX_new();
48
7.88k
    unsigned char H_[EVP_MAX_MD_SIZE];
49
50
7.88k
    if (ctx == NULL)
51
0
        goto err;
52
53
7.88k
    if (mgf1Hash == NULL)
54
0
        mgf1Hash = Hash;
55
56
7.88k
    hLen = EVP_MD_get_size(Hash);
57
7.88k
    if (hLen < 0)
58
0
        goto err;
59
    /*-
60
     * Negative sLen has special meanings:
61
     *      -1      sLen == hLen
62
     *      -2      salt length is autorecovered from signature
63
     *      -3      salt length is maximized
64
     *      -4      salt length is autorecovered from signature
65
     *      -N      reserved
66
     */
67
7.88k
    if (sLen == RSA_PSS_SALTLEN_DIGEST) {
68
5.32k
        sLen = hLen;
69
5.32k
    } else if (sLen < RSA_PSS_SALTLEN_AUTO_DIGEST_MAX) {
70
0
        ERR_raise(ERR_LIB_RSA, RSA_R_SLEN_CHECK_FAILED);
71
0
        goto err;
72
0
    }
73
74
7.88k
    MSBits = (BN_num_bits(rsa->n) - 1) & 0x7;
75
7.88k
    emLen = RSA_size(rsa);
76
7.88k
    if (EM[0] & (0xFF << MSBits)) {
77
2.28k
        ERR_raise(ERR_LIB_RSA, RSA_R_FIRST_OCTET_INVALID);
78
2.28k
        goto err;
79
2.28k
    }
80
5.59k
    if (MSBits == 0) {
81
882
        EM++;
82
882
        emLen--;
83
882
    }
84
5.59k
    if (emLen < hLen + 2) {
85
61
        ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE);
86
61
        goto err;
87
61
    }
88
5.53k
    if (sLen == RSA_PSS_SALTLEN_MAX) {
89
0
        sLen = emLen - hLen - 2;
90
5.53k
    } else if (sLen > emLen - hLen - 2) { /* sLen can be small negative */
91
71
        ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE);
92
71
        goto err;
93
71
    }
94
5.46k
    if (EM[emLen - 1] != 0xbc) {
95
5.18k
        ERR_raise(ERR_LIB_RSA, RSA_R_LAST_OCTET_INVALID);
96
5.18k
        goto err;
97
5.18k
    }
98
281
    maskedDBLen = emLen - hLen - 1;
99
281
    H = EM + maskedDBLen;
100
281
    DB = OPENSSL_malloc(maskedDBLen);
101
281
    if (DB == NULL) {
102
0
        ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
103
0
        goto err;
104
0
    }
105
281
    if (PKCS1_MGF1(DB, maskedDBLen, H, hLen, mgf1Hash) < 0)
106
0
        goto err;
107
52.4k
    for (i = 0; i < maskedDBLen; i++)
108
52.1k
        DB[i] ^= EM[i];
109
281
    if (MSBits)
110
237
        DB[0] &= 0xFF >> (8 - MSBits);
111
5.45k
    for (i = 0; DB[i] == 0 && i < (maskedDBLen - 1); i++) ;
112
281
    if (DB[i++] != 0x1) {
113
217
        ERR_raise(ERR_LIB_RSA, RSA_R_SLEN_RECOVERY_FAILED);
114
217
        goto err;
115
217
    }
116
64
    if (sLen != RSA_PSS_SALTLEN_AUTO
117
64
            && sLen != RSA_PSS_SALTLEN_AUTO_DIGEST_MAX
118
64
            && (maskedDBLen - i) != sLen) {
119
31
        ERR_raise_data(ERR_LIB_RSA, RSA_R_SLEN_CHECK_FAILED,
120
31
                       "expected: %d retrieved: %d", sLen,
121
31
                       maskedDBLen - i);
122
31
        goto err;
123
31
    }
124
33
    if (!EVP_DigestInit_ex(ctx, Hash, NULL)
125
33
        || !EVP_DigestUpdate(ctx, zeroes, sizeof(zeroes))
126
33
        || !EVP_DigestUpdate(ctx, mHash, hLen))
127
0
        goto err;
128
33
    if (maskedDBLen - i) {
129
33
        if (!EVP_DigestUpdate(ctx, DB + i, maskedDBLen - i))
130
0
            goto err;
131
33
    }
132
33
    if (!EVP_DigestFinal_ex(ctx, H_, NULL))
133
0
        goto err;
134
33
    if (memcmp(H_, H, hLen)) {
135
27
        ERR_raise(ERR_LIB_RSA, RSA_R_BAD_SIGNATURE);
136
27
        ret = 0;
137
27
    } else {
138
6
        ret = 1;
139
6
    }
140
141
7.88k
 err:
142
7.88k
    OPENSSL_free(DB);
143
7.88k
    EVP_MD_CTX_free(ctx);
144
145
7.88k
    return ret;
146
147
33
}
148
149
int RSA_padding_add_PKCS1_PSS(RSA *rsa, unsigned char *EM,
150
                              const unsigned char *mHash,
151
                              const EVP_MD *Hash, int sLen)
152
0
{
153
0
    return RSA_padding_add_PKCS1_PSS_mgf1(rsa, EM, mHash, Hash, NULL, sLen);
154
0
}
155
156
int RSA_padding_add_PKCS1_PSS_mgf1(RSA *rsa, unsigned char *EM,
157
                                   const unsigned char *mHash,
158
                                   const EVP_MD *Hash, const EVP_MD *mgf1Hash,
159
                                   int sLen)
160
647
{
161
647
    int i;
162
647
    int ret = 0;
163
647
    int hLen, maskedDBLen, MSBits, emLen;
164
647
    unsigned char *H, *salt = NULL, *p;
165
647
    EVP_MD_CTX *ctx = NULL;
166
647
    int sLenMax = -1;
167
168
647
    if (mgf1Hash == NULL)
169
0
        mgf1Hash = Hash;
170
171
647
    hLen = EVP_MD_get_size(Hash);
172
647
    if (hLen < 0)
173
0
        goto err;
174
    /*-
175
     * Negative sLen has special meanings:
176
     *      -1      sLen == hLen
177
     *      -2      salt length is maximized
178
     *      -3      same as above (on signing)
179
     *      -4      salt length is min(hLen, maximum salt length)
180
     *      -N      reserved
181
     */
182
    /* FIPS 186-4 section 5 "The RSA Digital Signature Algorithm", subsection
183
     * 5.5 "PKCS #1" says: "For RSASSA-PSS […] the length (in bytes) of the
184
     * salt (sLen) shall satisfy 0 <= sLen <= hLen, where hLen is the length of
185
     * the hash function output block (in bytes)."
186
     *
187
     * Provide a way to use at most the digest length, so that the default does
188
     * not violate FIPS 186-4. */
189
647
    if (sLen == RSA_PSS_SALTLEN_DIGEST) {
190
647
        sLen = hLen;
191
647
    } else if (sLen == RSA_PSS_SALTLEN_MAX_SIGN
192
0
            || sLen == RSA_PSS_SALTLEN_AUTO) {
193
0
        sLen = RSA_PSS_SALTLEN_MAX;
194
0
    } else if (sLen == RSA_PSS_SALTLEN_AUTO_DIGEST_MAX) {
195
0
        sLen = RSA_PSS_SALTLEN_MAX;
196
0
        sLenMax = hLen;
197
0
    } else if (sLen < RSA_PSS_SALTLEN_AUTO_DIGEST_MAX) {
198
0
        ERR_raise(ERR_LIB_RSA, RSA_R_SLEN_CHECK_FAILED);
199
0
        goto err;
200
0
    }
201
202
647
    MSBits = (BN_num_bits(rsa->n) - 1) & 0x7;
203
647
    emLen = RSA_size(rsa);
204
647
    if (MSBits == 0) {
205
0
        *EM++ = 0;
206
0
        emLen--;
207
0
    }
208
647
    if (emLen < hLen + 2) {
209
0
        ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
210
0
        goto err;
211
0
    }
212
647
    if (sLen == RSA_PSS_SALTLEN_MAX) {
213
0
        sLen = emLen - hLen - 2;
214
0
        if (sLenMax >= 0 && sLen > sLenMax)
215
0
            sLen = sLenMax;
216
647
    } else if (sLen > emLen - hLen - 2) {
217
0
        ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE);
218
0
        goto err;
219
0
    }
220
647
    if (sLen > 0) {
221
647
        salt = OPENSSL_malloc(sLen);
222
647
        if (salt == NULL) {
223
0
            ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
224
0
            goto err;
225
0
        }
226
647
        if (RAND_bytes_ex(rsa->libctx, salt, sLen, 0) <= 0)
227
0
            goto err;
228
647
    }
229
647
    maskedDBLen = emLen - hLen - 1;
230
647
    H = EM + maskedDBLen;
231
647
    ctx = EVP_MD_CTX_new();
232
647
    if (ctx == NULL)
233
0
        goto err;
234
647
    if (!EVP_DigestInit_ex(ctx, Hash, NULL)
235
647
        || !EVP_DigestUpdate(ctx, zeroes, sizeof(zeroes))
236
647
        || !EVP_DigestUpdate(ctx, mHash, hLen))
237
0
        goto err;
238
647
    if (sLen && !EVP_DigestUpdate(ctx, salt, sLen))
239
0
        goto err;
240
647
    if (!EVP_DigestFinal_ex(ctx, H, NULL))
241
0
        goto err;
242
243
    /* Generate dbMask in place then perform XOR on it */
244
647
    if (PKCS1_MGF1(EM, maskedDBLen, H, hLen, mgf1Hash))
245
0
        goto err;
246
247
647
    p = EM;
248
249
    /*
250
     * Initial PS XORs with all zeroes which is a NOP so just update pointer.
251
     * Note from a test above this value is guaranteed to be non-negative.
252
     */
253
647
    p += emLen - sLen - hLen - 2;
254
647
    *p++ ^= 0x1;
255
647
    if (sLen > 0) {
256
26.5k
        for (i = 0; i < sLen; i++)
257
25.8k
            *p++ ^= salt[i];
258
647
    }
259
647
    if (MSBits)
260
647
        EM[0] &= 0xFF >> (8 - MSBits);
261
262
    /* H is already in place so just set final 0xbc */
263
264
647
    EM[emLen - 1] = 0xbc;
265
266
647
    ret = 1;
267
268
647
 err:
269
647
    EVP_MD_CTX_free(ctx);
270
647
    OPENSSL_clear_free(salt, (size_t)sLen); /* salt != NULL implies sLen > 0 */
271
272
647
    return ret;
273
274
647
}
275
276
/*
277
 * The defaults for PSS restrictions are defined in RFC 8017, A.2.3 RSASSA-PSS
278
 * (https://tools.ietf.org/html/rfc8017#appendix-A.2.3):
279
 *
280
 * If the default values of the hashAlgorithm, maskGenAlgorithm, and
281
 * trailerField fields of RSASSA-PSS-params are used, then the algorithm
282
 * identifier will have the following value:
283
 *
284
 *     rSASSA-PSS-Default-Identifier    RSASSA-AlgorithmIdentifier ::= {
285
 *         algorithm   id-RSASSA-PSS,
286
 *         parameters  RSASSA-PSS-params : {
287
 *             hashAlgorithm       sha1,
288
 *             maskGenAlgorithm    mgf1SHA1,
289
 *             saltLength          20,
290
 *             trailerField        trailerFieldBC
291
 *         }
292
 *     }
293
 *
294
 *     RSASSA-AlgorithmIdentifier ::= AlgorithmIdentifier {
295
 *         {PKCS1Algorithms}
296
 *     }
297
 */
298
static const RSA_PSS_PARAMS_30 default_RSASSA_PSS_params = {
299
    NID_sha1,                    /* default hashAlgorithm */
300
    {
301
        NID_mgf1,                /* default maskGenAlgorithm */
302
        NID_sha1                 /* default MGF1 hash */
303
    },
304
    20,                          /* default saltLength */
305
    1                            /* default trailerField (0xBC) */
306
};
307
308
int ossl_rsa_pss_params_30_set_defaults(RSA_PSS_PARAMS_30 *rsa_pss_params)
309
49.1k
{
310
49.1k
    if (rsa_pss_params == NULL)
311
0
        return 0;
312
49.1k
    *rsa_pss_params = default_RSASSA_PSS_params;
313
49.1k
    return 1;
314
49.1k
}
315
316
int ossl_rsa_pss_params_30_is_unrestricted(const RSA_PSS_PARAMS_30 *rsa_pss_params)
317
50.2k
{
318
50.2k
    static RSA_PSS_PARAMS_30 pss_params_cmp = { 0, };
319
320
50.2k
    return rsa_pss_params == NULL
321
50.2k
        || memcmp(rsa_pss_params, &pss_params_cmp,
322
50.2k
                  sizeof(*rsa_pss_params)) == 0;
323
50.2k
}
324
325
int ossl_rsa_pss_params_30_copy(RSA_PSS_PARAMS_30 *to,
326
                                const RSA_PSS_PARAMS_30 *from)
327
0
{
328
0
    memcpy(to, from, sizeof(*to));
329
0
    return 1;
330
0
}
331
332
int ossl_rsa_pss_params_30_set_hashalg(RSA_PSS_PARAMS_30 *rsa_pss_params,
333
                                       int hashalg_nid)
334
22.1k
{
335
22.1k
    if (rsa_pss_params == NULL)
336
0
        return 0;
337
22.1k
    rsa_pss_params->hash_algorithm_nid = hashalg_nid;
338
22.1k
    return 1;
339
22.1k
}
340
341
int ossl_rsa_pss_params_30_set_maskgenalg(RSA_PSS_PARAMS_30 *rsa_pss_params,
342
                                          int maskgenalg_nid)
343
0
{
344
0
    if (rsa_pss_params == NULL)
345
0
        return 0;
346
0
    rsa_pss_params->mask_gen.algorithm_nid = maskgenalg_nid;
347
0
    return 1;
348
0
}
349
350
int ossl_rsa_pss_params_30_set_maskgenhashalg(RSA_PSS_PARAMS_30 *rsa_pss_params,
351
                                              int maskgenhashalg_nid)
352
22.1k
{
353
22.1k
    if (rsa_pss_params == NULL)
354
0
        return 0;
355
22.1k
    rsa_pss_params->mask_gen.hash_algorithm_nid = maskgenhashalg_nid;
356
22.1k
    return 1;
357
22.1k
}
358
359
int ossl_rsa_pss_params_30_set_saltlen(RSA_PSS_PARAMS_30 *rsa_pss_params,
360
                                       int saltlen)
361
22.1k
{
362
22.1k
    if (rsa_pss_params == NULL)
363
0
        return 0;
364
22.1k
    rsa_pss_params->salt_len = saltlen;
365
22.1k
    return 1;
366
22.1k
}
367
368
int ossl_rsa_pss_params_30_set_trailerfield(RSA_PSS_PARAMS_30 *rsa_pss_params,
369
                                            int trailerfield)
370
22.1k
{
371
22.1k
    if (rsa_pss_params == NULL)
372
0
        return 0;
373
22.1k
    rsa_pss_params->trailer_field = trailerfield;
374
22.1k
    return 1;
375
22.1k
}
376
377
int ossl_rsa_pss_params_30_hashalg(const RSA_PSS_PARAMS_30 *rsa_pss_params)
378
44.7k
{
379
44.7k
    if (rsa_pss_params == NULL)
380
22.1k
        return default_RSASSA_PSS_params.hash_algorithm_nid;
381
22.5k
    return rsa_pss_params->hash_algorithm_nid;
382
44.7k
}
383
384
int ossl_rsa_pss_params_30_maskgenalg(const RSA_PSS_PARAMS_30 *rsa_pss_params)
385
44.4k
{
386
44.4k
    if (rsa_pss_params == NULL)
387
22.1k
        return default_RSASSA_PSS_params.mask_gen.algorithm_nid;
388
22.2k
    return rsa_pss_params->mask_gen.algorithm_nid;
389
44.4k
}
390
391
int ossl_rsa_pss_params_30_maskgenhashalg(const RSA_PSS_PARAMS_30 *rsa_pss_params)
392
44.7k
{
393
44.7k
    if (rsa_pss_params == NULL)
394
22.1k
        return default_RSASSA_PSS_params.hash_algorithm_nid;
395
22.5k
    return rsa_pss_params->mask_gen.hash_algorithm_nid;
396
44.7k
}
397
398
int ossl_rsa_pss_params_30_saltlen(const RSA_PSS_PARAMS_30 *rsa_pss_params)
399
48.7k
{
400
48.7k
    if (rsa_pss_params == NULL)
401
6
        return default_RSASSA_PSS_params.salt_len;
402
48.6k
    return rsa_pss_params->salt_len;
403
48.7k
}
404
405
int ossl_rsa_pss_params_30_trailerfield(const RSA_PSS_PARAMS_30 *rsa_pss_params)
406
26.2k
{
407
26.2k
    if (rsa_pss_params == NULL)
408
6
        return default_RSASSA_PSS_params.trailer_field;
409
26.2k
    return rsa_pss_params->trailer_field;
410
26.2k
}
411
412
#if defined(_MSC_VER)
413
# pragma optimize("",on)
414
#endif