Coverage Report

Created: 2025-06-24 06:49

/src/nss/lib/freebl/rsapkcs.c
Line
Count
Source (jump to first uncovered line)
1
/* This Source Code Form is subject to the terms of the Mozilla Public
2
 * License, v. 2.0. If a copy of the MPL was not distributed with this
3
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5
/*
6
 * RSA PKCS#1 v2.1 (RFC 3447) operations
7
 */
8
9
#ifdef FREEBL_NO_DEPEND
10
#include "stubs.h"
11
#endif
12
13
#include "secerr.h"
14
15
#include "blapi.h"
16
#include "secitem.h"
17
#include "blapii.h"
18
19
92.2k
/* Minimum number of padding octets required in a formatted PKCS #1 block. */
#define RSA_BLOCK_MIN_PAD_LEN 8
/* First octet of every formatted PKCS #1 block. */
#define RSA_BLOCK_FIRST_OCTET 0x00
/* Padding octet used in blocks formatted for private-key operations. */
#define RSA_BLOCK_PRIVATE_PAD_OCTET 0xff
/* Separator octet between the padding and the payload. */
#define RSA_BLOCK_AFTER_PAD_OCTET 0x00

/*
 * RSA block types.
 *
 * RSA_BlockPrivate and RSA_BlockPublic are the block-type octets fixed by
 * PKCS #1. RSA_BlockRaw has no value fixed by definition; it keeps the value
 * NSS has been using in the past.
 */
typedef enum {
    RSA_BlockPrivate = 1, /* pad for a private-key operation */
    RSA_BlockPublic = 2,  /* pad for a public-key operation */
    RSA_BlockRaw = 4      /* no padding: simply right-justify the data */
} RSA_BlockType;

/*
 * Eight zero octets needed by the RSA-PSS functions (presumably the
 * padding1 prefix of M' in RFC 3447, section 9.1.1 -- confirm at the use
 * site).
 */
static const unsigned char eightZeros[8] = { 0 };
39
40
/*
 * Constant-time equality test for a single byte.
 *
 * Returns 1 when a == b and 0 otherwise, without branching on either value.
 * For ranges of bytes, use constantTimeCompare.
 */
static unsigned char
constantTimeEQ8(unsigned char a, unsigned char b)
{
    /*
     * After integer promotion both differences are 0 when a == b. Otherwise
     * one of them is negative, which sets bit 7 of their OR. Inverting and
     * keeping only bit 7 therefore yields the equality flag.
     */
    unsigned char eq = (unsigned char)~((a - b) | (b - a));
    return (unsigned char)(eq >> 7);
}
51
52
/*
 * Constant-time comparison of two byte ranges.
 *
 * Returns 1 when the first len bytes of a and b are identical, otherwise 0.
 * Every byte is examined no matter where the first difference is, so the
 * amount of work depends only on len.
 */
static unsigned char
constantTimeCompare(const unsigned char *a,
                    const unsigned char *b,
                    unsigned int len)
{
    unsigned char diff = 0;
    unsigned int k = 0;

    /* Accumulate every differing bit; diff stays 0 only if all bytes match. */
    while (k < len) {
        diff |= a[k] ^ b[k];
        k++;
    }
    return constantTimeEQ8(0x00, diff);
}
67
68
/*
 * Constant-time select.
 *
 * Returns a when c is 1 and b when c is 0, without branching on c.
 * The result is undefined if c is neither 0 nor 1.
 */
static unsigned int
constantTimeCondition(unsigned int c,
                      unsigned int a,
                      unsigned int b)
{
    /* c - 1 is all ones when c == 0 and all zeros when c == 1. */
    unsigned int selectB = c - 1;
    return (a & ~selectB) | (b & selectB);
}
79
80
static unsigned int
81
rsa_modulusLen(SECItem *modulus)
82
126k
{
83
126k
    if (modulus->len == 0) {
84
0
        return 0;
85
0
    }
86
87
126k
    unsigned char byteZero = modulus->data[0];
88
126k
    unsigned int modLen = modulus->len - !byteZero;
89
126k
    return modLen;
90
126k
}
91
92
static unsigned int
93
rsa_modulusBits(SECItem *modulus)
94
6.00k
{
95
6.00k
    if (modulus->len == 0) {
96
0
        return 0;
97
0
    }
98
99
6.00k
    unsigned char byteZero = modulus->data[0];
100
6.00k
    unsigned int numBits = (modulus->len - 1) * 8;
101
102
6.00k
    if (byteZero == 0 && modulus->len == 1) {
103
0
        return 0;
104
0
    }
105
106
6.00k
    if (byteZero == 0) {
107
0
        numBits -= 8;
108
0
        byteZero = modulus->data[1];
109
0
    }
110
111
53.1k
    while (byteZero > 0) {
112
47.1k
        numBits++;
113
47.1k
        byteZero >>= 1;
114
47.1k
    }
115
116
6.00k
    return numBits;
117
6.00k
}
118
119
/*
120
 * Format one block of data for public/private key encryption using
121
 * the rules defined in PKCS #1.
122
 */
123
static unsigned char *
124
rsa_FormatOneBlock(unsigned modulusLen,
125
                   RSA_BlockType blockType,
126
                   SECItem *data)
127
30.2k
{
128
30.2k
    unsigned char *block;
129
30.2k
    unsigned char *bp;
130
30.2k
    unsigned int padLen;
131
30.2k
    unsigned int i, j;
132
30.2k
    SECStatus rv;
133
134
30.2k
    block = (unsigned char *)PORT_Alloc(modulusLen);
135
30.2k
    if (block == NULL)
136
0
        return NULL;
137
138
30.2k
    bp = block;
139
140
    /*
141
     * All RSA blocks start with two octets:
142
     *  0x00 || BlockType
143
     */
144
30.2k
    *bp++ = RSA_BLOCK_FIRST_OCTET;
145
30.2k
    *bp++ = (unsigned char)blockType;
146
147
30.2k
    switch (blockType) {
148
149
        /*
150
         * Blocks intended for private-key operation.
151
         */
152
17.7k
        case RSA_BlockPrivate: /* preferred method */
153
            /*
154
             * 0x00 || BT || Pad || 0x00 || ActualData
155
             *   1      1   padLen    1      data->len
156
             * padLen must be at least RSA_BLOCK_MIN_PAD_LEN (8) bytes.
157
             * Pad is either all 0x00 or all 0xff bytes, depending on blockType.
158
             */
159
17.7k
            padLen = modulusLen - data->len - 3;
160
17.7k
            PORT_Assert(padLen >= RSA_BLOCK_MIN_PAD_LEN);
161
17.7k
            if (padLen < RSA_BLOCK_MIN_PAD_LEN) {
162
0
                PORT_ZFree(block, modulusLen);
163
0
                return NULL;
164
0
            }
165
17.7k
            PORT_Memset(bp, RSA_BLOCK_PRIVATE_PAD_OCTET, padLen);
166
17.7k
            bp += padLen;
167
17.7k
            *bp++ = RSA_BLOCK_AFTER_PAD_OCTET;
168
17.7k
            PORT_Memcpy(bp, data->data, data->len);
169
17.7k
            break;
170
171
        /*
172
         * Blocks intended for public-key operation.
173
         */
174
12.4k
        case RSA_BlockPublic:
175
            /*
176
             * 0x00 || BT || Pad || 0x00 || ActualData
177
             *   1      1   padLen    1      data->len
178
             * Pad is 8 or more non-zero random bytes.
179
             *
180
             * Build the block left to right.
181
             * Fill the entire block from Pad to the end with random bytes.
182
             * Use the bytes after Pad as a supply of extra random bytes from
183
             * which to find replacements for the zero bytes in Pad.
184
             * If we need more than that, refill the bytes after Pad with
185
             * new random bytes as necessary.
186
             */
187
188
12.4k
            padLen = modulusLen - (data->len + 3);
189
12.4k
            PORT_Assert(padLen >= RSA_BLOCK_MIN_PAD_LEN);
190
12.4k
            if (padLen < RSA_BLOCK_MIN_PAD_LEN) {
191
0
                PORT_ZFree(block, modulusLen);
192
0
                return NULL;
193
0
            }
194
12.4k
            j = modulusLen - 2;
195
12.4k
            rv = RNG_GenerateGlobalRandomBytes(bp, j);
196
12.4k
            if (rv == SECSuccess) {
197
1.09M
                for (i = 0; i < padLen;) {
198
1.08M
                    unsigned char repl;
199
                    /* Pad with non-zero random data. */
200
1.08M
                    if (bp[i] != RSA_BLOCK_AFTER_PAD_OCTET) {
201
1.08M
                        ++i;
202
1.08M
                        continue;
203
1.08M
                    }
204
3.66k
                    if (j <= padLen) {
205
0
                        rv = RNG_GenerateGlobalRandomBytes(bp + padLen,
206
0
                                                           modulusLen - (2 + padLen));
207
0
                        if (rv != SECSuccess)
208
0
                            break;
209
0
                        j = modulusLen - 2;
210
0
                    }
211
3.66k
                    do {
212
3.66k
                        repl = bp[--j];
213
3.66k
                    } while (repl == RSA_BLOCK_AFTER_PAD_OCTET && j > padLen);
214
3.66k
                    if (repl != RSA_BLOCK_AFTER_PAD_OCTET) {
215
3.66k
                        bp[i++] = repl;
216
3.66k
                    }
217
3.66k
                }
218
12.4k
            }
219
12.4k
            if (rv != SECSuccess) {
220
0
                PORT_ZFree(block, modulusLen);
221
0
                PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
222
0
                return NULL;
223
0
            }
224
12.4k
            bp += padLen;
225
12.4k
            *bp++ = RSA_BLOCK_AFTER_PAD_OCTET;
226
12.4k
            PORT_Memcpy(bp, data->data, data->len);
227
12.4k
            break;
228
229
0
        default:
230
0
            PORT_Assert(0);
231
0
            PORT_ZFree(block, modulusLen);
232
0
            return NULL;
233
30.2k
    }
234
235
30.2k
    return block;
236
30.2k
}
237
238
/* modulusLen has to be larger than RSA_BLOCK_MIN_PAD_LEN + 3, and data has to be smaller than modulus - (RSA_BLOCK_MIN_PAD_LEN + 3) */
239
static SECStatus
240
rsa_FormatBlock(SECItem *result,
241
                unsigned modulusLen,
242
                RSA_BlockType blockType,
243
                SECItem *data)
244
30.2k
{
245
30.2k
    switch (blockType) {
246
17.7k
        case RSA_BlockPrivate:
247
30.2k
        case RSA_BlockPublic:
248
            /*
249
             * 0x00 || BT || Pad || 0x00 || ActualData
250
             *
251
             * The "3" below is the first octet + the second octet + the 0x00
252
             * octet that always comes just before the ActualData.
253
             */
254
30.2k
            if (modulusLen < (3 + RSA_BLOCK_MIN_PAD_LEN) || data->len > (modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN))) {
255
0
                return SECFailure;
256
0
            }
257
30.2k
            result->data = rsa_FormatOneBlock(modulusLen, blockType, data);
258
30.2k
            if (result->data == NULL) {
259
0
                result->len = 0;
260
0
                return SECFailure;
261
0
            }
262
30.2k
            result->len = modulusLen;
263
264
30.2k
            break;
265
266
0
        case RSA_BlockRaw:
267
            /*
268
             * Pad || ActualData
269
             * Pad is zeros. The application is responsible for recovering
270
             * the actual data.
271
             */
272
0
            if (data->len > modulusLen) {
273
0
                return SECFailure;
274
0
            }
275
0
            result->data = (unsigned char *)PORT_ZAlloc(modulusLen);
276
0
            result->len = modulusLen;
277
0
            PORT_Memcpy(result->data + (modulusLen - data->len),
278
0
                        data->data, data->len);
279
0
            break;
280
281
0
        default:
282
0
            PORT_Assert(0);
283
0
            result->data = NULL;
284
0
            result->len = 0;
285
0
            return SECFailure;
286
30.2k
    }
287
288
30.2k
    return SECSuccess;
289
30.2k
}
290
291
/*
292
 * Mask generation function MGF1 as defined in PKCS #1 v2.1 / RFC 3447.
293
 */
294
static SECStatus
295
MGF1(HASH_HashType hashAlg,
296
     unsigned char *mask,
297
     unsigned int maskLen,
298
     const unsigned char *mgfSeed,
299
     unsigned int mgfSeedLen)
300
2.45k
{
301
2.45k
    unsigned int digestLen;
302
2.45k
    PRUint32 counter;
303
2.45k
    PRUint32 rounds;
304
2.45k
    unsigned char *tempHash;
305
2.45k
    unsigned char *temp;
306
2.45k
    const SECHashObject *hash;
307
2.45k
    void *hashContext;
308
2.45k
    unsigned char C[4];
309
2.45k
    SECStatus rv = SECSuccess;
310
311
2.45k
    hash = HASH_GetRawHashObject(hashAlg);
312
2.45k
    if (hash == NULL) {
313
0
        return SECFailure;
314
0
    }
315
316
2.45k
    hashContext = (*hash->create)();
317
2.45k
    rounds = (maskLen + hash->length - 1) / hash->length;
318
15.5k
    for (counter = 0; counter < rounds; counter++) {
319
13.1k
        C[0] = (unsigned char)((counter >> 24) & 0xff);
320
13.1k
        C[1] = (unsigned char)((counter >> 16) & 0xff);
321
13.1k
        C[2] = (unsigned char)((counter >> 8) & 0xff);
322
13.1k
        C[3] = (unsigned char)(counter & 0xff);
323
324
        /* This could be optimized when the clone functions in
325
         * rawhash.c are implemented. */
326
13.1k
        (*hash->begin)(hashContext);
327
13.1k
        (*hash->update)(hashContext, mgfSeed, mgfSeedLen);
328
13.1k
        (*hash->update)(hashContext, C, sizeof C);
329
330
13.1k
        tempHash = mask + counter * hash->length;
331
13.1k
        if (counter != (rounds - 1)) {
332
10.6k
            (*hash->end)(hashContext, tempHash, &digestLen, hash->length);
333
10.6k
        } else { /* we're in the last round and need to cut the hash */
334
2.45k
            temp = (unsigned char *)PORT_Alloc(hash->length);
335
2.45k
            if (!temp) {
336
0
                rv = SECFailure;
337
0
                goto done;
338
0
            }
339
2.45k
            (*hash->end)(hashContext, temp, &digestLen, hash->length);
340
2.45k
            PORT_Memcpy(tempHash, temp, maskLen - counter * hash->length);
341
2.45k
            PORT_Free(temp);
342
2.45k
        }
343
13.1k
    }
344
345
2.45k
done:
346
2.45k
    (*hash->destroy)(hashContext, PR_TRUE);
347
2.45k
    return rv;
348
2.45k
}
349
350
/* XXX Doesn't set error code */
351
SECStatus
352
RSA_SignRaw(RSAPrivateKey *key,
353
            unsigned char *output,
354
            unsigned int *outputLen,
355
            unsigned int maxOutputLen,
356
            const unsigned char *data,
357
            unsigned int dataLen)
358
0
{
359
0
    SECStatus rv = SECSuccess;
360
0
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
361
0
    SECItem formatted;
362
0
    SECItem unformatted;
363
364
0
    if (maxOutputLen < modulusLen)
365
0
        return SECFailure;
366
367
0
    unformatted.len = dataLen;
368
0
    unformatted.data = (unsigned char *)data;
369
0
    formatted.data = NULL;
370
0
    rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockRaw, &unformatted);
371
0
    if (rv != SECSuccess)
372
0
        goto done;
373
374
0
    rv = RSA_PrivateKeyOpDoubleChecked(key, output, formatted.data);
375
0
    *outputLen = modulusLen;
376
377
0
done:
378
0
    if (formatted.data != NULL)
379
0
        PORT_ZFree(formatted.data, modulusLen);
380
0
    return rv;
381
0
}
382
383
/* XXX Doesn't set error code */
384
SECStatus
385
RSA_CheckSignRaw(RSAPublicKey *key,
386
                 const unsigned char *sig,
387
                 unsigned int sigLen,
388
                 const unsigned char *hash,
389
                 unsigned int hashLen)
390
0
{
391
0
    SECStatus rv;
392
0
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
393
0
    unsigned char *buffer;
394
395
0
    if (sigLen != modulusLen)
396
0
        goto failure;
397
0
    if (hashLen > modulusLen)
398
0
        goto failure;
399
400
0
    buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
401
0
    if (!buffer)
402
0
        goto failure;
403
404
0
    rv = RSA_PublicKeyOp(key, buffer, sig);
405
0
    if (rv != SECSuccess)
406
0
        goto loser;
407
408
    /*
409
     * make sure we get the same results
410
     */
411
    /* XXX(rsleevi): Constant time */
412
    /* NOTE: should we verify the leading zeros? */
413
0
    if (PORT_Memcmp(buffer + (modulusLen - hashLen), hash, hashLen) != 0)
414
0
        goto loser;
415
416
0
    PORT_Free(buffer);
417
0
    return SECSuccess;
418
419
0
loser:
420
0
    PORT_Free(buffer);
421
0
failure:
422
0
    return SECFailure;
423
0
}
424
425
/* XXX Doesn't set error code */
426
SECStatus
427
RSA_CheckSignRecoverRaw(RSAPublicKey *key,
428
                        unsigned char *data,
429
                        unsigned int *dataLen,
430
                        unsigned int maxDataLen,
431
                        const unsigned char *sig,
432
                        unsigned int sigLen)
433
0
{
434
0
    SECStatus rv;
435
0
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
436
437
0
    if (sigLen != modulusLen)
438
0
        goto failure;
439
0
    if (maxDataLen < modulusLen)
440
0
        goto failure;
441
442
0
    rv = RSA_PublicKeyOp(key, data, sig);
443
0
    if (rv != SECSuccess)
444
0
        goto failure;
445
446
0
    *dataLen = modulusLen;
447
0
    return SECSuccess;
448
449
0
failure:
450
0
    return SECFailure;
451
0
}
452
453
/* XXX Doesn't set error code */
454
SECStatus
455
RSA_EncryptRaw(RSAPublicKey *key,
456
               unsigned char *output,
457
               unsigned int *outputLen,
458
               unsigned int maxOutputLen,
459
               const unsigned char *input,
460
               unsigned int inputLen)
461
0
{
462
0
    SECStatus rv;
463
0
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
464
0
    SECItem formatted;
465
0
    SECItem unformatted;
466
467
0
    formatted.data = NULL;
468
0
    if (maxOutputLen < modulusLen)
469
0
        goto failure;
470
471
0
    unformatted.len = inputLen;
472
0
    unformatted.data = (unsigned char *)input;
473
0
    formatted.data = NULL;
474
0
    rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockRaw, &unformatted);
475
0
    if (rv != SECSuccess)
476
0
        goto failure;
477
478
0
    rv = RSA_PublicKeyOp(key, output, formatted.data);
479
0
    if (rv != SECSuccess)
480
0
        goto failure;
481
482
0
    PORT_ZFree(formatted.data, modulusLen);
483
0
    *outputLen = modulusLen;
484
0
    return SECSuccess;
485
486
0
failure:
487
0
    if (formatted.data != NULL)
488
0
        PORT_ZFree(formatted.data, modulusLen);
489
0
    return SECFailure;
490
0
}
491
492
/* XXX Doesn't set error code */
493
SECStatus
494
RSA_DecryptRaw(RSAPrivateKey *key,
495
               unsigned char *output,
496
               unsigned int *outputLen,
497
               unsigned int maxOutputLen,
498
               const unsigned char *input,
499
               unsigned int inputLen)
500
0
{
501
0
    SECStatus rv;
502
0
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
503
504
0
    if (modulusLen > maxOutputLen)
505
0
        goto failure;
506
0
    if (inputLen != modulusLen)
507
0
        goto failure;
508
509
0
    rv = RSA_PrivateKeyOp(key, output, input);
510
0
    if (rv != SECSuccess)
511
0
        goto failure;
512
513
0
    *outputLen = modulusLen;
514
0
    return SECSuccess;
515
516
0
failure:
517
0
    return SECFailure;
518
0
}
519
520
/*
521
 * Decodes an EME-OAEP encoded block, validating the encoding in constant
522
 * time.
523
 * Described in RFC 3447, section 7.1.2.
524
 * input contains the encoded block, after decryption.
525
 * label is the optional value L that was associated with the message.
526
 * On success, the original message and message length will be stored in
527
 * output and outputLen.
528
 */
529
static SECStatus
530
eme_oaep_decode(unsigned char *output,
531
                unsigned int *outputLen,
532
                unsigned int maxOutputLen,
533
                const unsigned char *input,
534
                unsigned int inputLen,
535
                HASH_HashType hashAlg,
536
                HASH_HashType maskHashAlg,
537
                const unsigned char *label,
538
                unsigned int labelLen)
539
0
{
540
0
    const SECHashObject *hash;
541
0
    void *hashContext;
542
0
    SECStatus rv = SECFailure;
543
0
    unsigned char labelHash[HASH_LENGTH_MAX];
544
0
    unsigned int i;
545
0
    unsigned int maskLen;
546
0
    unsigned int paddingOffset;
547
0
    unsigned char *mask = NULL;
548
0
    unsigned char *tmpOutput = NULL;
549
0
    unsigned char isGood;
550
0
    unsigned char foundPaddingEnd;
551
552
0
    hash = HASH_GetRawHashObject(hashAlg);
553
554
    /* 1.c */
555
0
    if (inputLen < (hash->length * 2) + 2) {
556
0
        PORT_SetError(SEC_ERROR_INPUT_LEN);
557
0
        return SECFailure;
558
0
    }
559
560
    /* Step 3.a - Generate lHash */
561
0
    hashContext = (*hash->create)();
562
0
    if (hashContext == NULL) {
563
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
564
0
        return SECFailure;
565
0
    }
566
0
    (*hash->begin)(hashContext);
567
0
    if (labelLen > 0)
568
0
        (*hash->update)(hashContext, label, labelLen);
569
0
    (*hash->end)(hashContext, labelHash, &i, sizeof(labelHash));
570
0
    (*hash->destroy)(hashContext, PR_TRUE);
571
572
0
    tmpOutput = (unsigned char *)PORT_Alloc(inputLen);
573
0
    if (tmpOutput == NULL) {
574
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
575
0
        goto done;
576
0
    }
577
578
0
    maskLen = inputLen - hash->length - 1;
579
0
    mask = (unsigned char *)PORT_Alloc(maskLen);
580
0
    if (mask == NULL) {
581
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
582
0
        goto done;
583
0
    }
584
585
0
    PORT_Memcpy(tmpOutput, input, inputLen);
586
587
    /* 3.c - Generate seedMask */
588
0
    MGF1(maskHashAlg, mask, hash->length, &tmpOutput[1 + hash->length],
589
0
         inputLen - hash->length - 1);
590
    /* 3.d - Unmask seed */
591
0
    for (i = 0; i < hash->length; ++i)
592
0
        tmpOutput[1 + i] ^= mask[i];
593
594
    /* 3.e - Generate dbMask */
595
0
    MGF1(maskHashAlg, mask, maskLen, &tmpOutput[1], hash->length);
596
    /* 3.f - Unmask DB */
597
0
    for (i = 0; i < maskLen; ++i)
598
0
        tmpOutput[1 + hash->length + i] ^= mask[i];
599
600
    /* 3.g - Compare Y, lHash, and PS in constant time
601
     * Warning: This code is timing dependent and must not disclose which of
602
     * these were invalid.
603
     */
604
0
    paddingOffset = 0;
605
0
    isGood = 1;
606
0
    foundPaddingEnd = 0;
607
608
    /* Compare Y */
609
0
    isGood &= constantTimeEQ8(0x00, tmpOutput[0]);
610
611
    /* Compare lHash and lHash' */
612
0
    isGood &= constantTimeCompare(&labelHash[0],
613
0
                                  &tmpOutput[1 + hash->length],
614
0
                                  hash->length);
615
616
    /* Compare that the padding is zero or more zero octets, followed by a
617
     * 0x01 octet */
618
0
    for (i = 1 + (hash->length * 2); i < inputLen; ++i) {
619
0
        unsigned char isZero = constantTimeEQ8(0x00, tmpOutput[i]);
620
0
        unsigned char isOne = constantTimeEQ8(0x01, tmpOutput[i]);
621
        /* non-constant time equivalent:
622
         * if (tmpOutput[i] == 0x01 && !foundPaddingEnd)
623
         *     paddingOffset = i;
624
         */
625
0
        paddingOffset = constantTimeCondition(isOne & ~foundPaddingEnd, i,
626
0
                                              paddingOffset);
627
        /* non-constant time equivalent:
628
         * if (tmpOutput[i] == 0x01)
629
         *    foundPaddingEnd = true;
630
         *
631
         * Note: This may yield false positives, as it will be set whenever
632
         * a 0x01 byte is encountered. If there was bad padding (eg:
633
         * 0x03 0x02 0x01), foundPaddingEnd will still be set to true, and
634
         * paddingOffset will still be set to 2.
635
         */
636
0
        foundPaddingEnd = constantTimeCondition(isOne, 1, foundPaddingEnd);
637
        /* non-constant time equivalent:
638
         * if (tmpOutput[i] != 0x00 && tmpOutput[i] != 0x01 &&
639
         *     !foundPaddingEnd) {
640
         *    isGood = false;
641
         * }
642
         *
643
         * Note: This may yield false positives, as a message (and padding)
644
         * that is entirely zeros will result in isGood still being true. Thus
645
         * it's necessary to check foundPaddingEnd is positive below.
646
         */
647
0
        isGood = constantTimeCondition(~foundPaddingEnd & ~isZero, 0, isGood);
648
0
    }
649
650
    /* While both isGood and foundPaddingEnd may have false positives, they
651
     * cannot BOTH have false positives. If both are not true, then an invalid
652
     * message was received. Note, this comparison must still be done in constant
653
     * time so as not to leak either condition.
654
     */
655
0
    if (!(isGood & foundPaddingEnd)) {
656
0
        PORT_SetError(SEC_ERROR_BAD_DATA);
657
0
        goto done;
658
0
    }
659
660
    /* End timing dependent code */
661
662
0
    ++paddingOffset; /* Skip the 0x01 following the end of PS */
663
664
0
    *outputLen = inputLen - paddingOffset;
665
0
    if (*outputLen > maxOutputLen) {
666
0
        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
667
0
        goto done;
668
0
    }
669
670
0
    if (*outputLen)
671
0
        PORT_Memcpy(output, &tmpOutput[paddingOffset], *outputLen);
672
0
    rv = SECSuccess;
673
674
0
done:
675
0
    if (mask)
676
0
        PORT_ZFree(mask, maskLen);
677
0
    if (tmpOutput)
678
0
        PORT_ZFree(tmpOutput, inputLen);
679
0
    return rv;
680
0
}
681
682
/*
683
 * Generate an EME-OAEP encoded block for encryption
684
 * Described in RFC 3447, section 7.1.1
685
 * We use input instead of M for the message to be encrypted
686
 * label is the optional value L to be associated with the message.
687
 */
688
static SECStatus
689
eme_oaep_encode(unsigned char *em,
690
                unsigned int emLen,
691
                const unsigned char *input,
692
                unsigned int inputLen,
693
                HASH_HashType hashAlg,
694
                HASH_HashType maskHashAlg,
695
                const unsigned char *label,
696
                unsigned int labelLen,
697
                const unsigned char *seed,
698
                unsigned int seedLen)
699
0
{
700
0
    const SECHashObject *hash;
701
0
    void *hashContext;
702
0
    SECStatus rv;
703
0
    unsigned char *mask;
704
0
    unsigned int reservedLen;
705
0
    unsigned int dbMaskLen;
706
0
    unsigned int i;
707
708
0
    hash = HASH_GetRawHashObject(hashAlg);
709
0
    PORT_Assert(seed == NULL || seedLen == hash->length);
710
711
    /* Step 1.b */
712
0
    reservedLen = (2 * hash->length) + 2;
713
0
    if (emLen < reservedLen || inputLen > (emLen - reservedLen)) {
714
0
        PORT_SetError(SEC_ERROR_INPUT_LEN);
715
0
        return SECFailure;
716
0
    }
717
718
    /*
719
     * From RFC 3447, Section 7.1
720
     *                      +----------+---------+-------+
721
     *                 DB = |  lHash   |    PS   |   M   |
722
     *                      +----------+---------+-------+
723
     *                                     |
724
     *           +----------+              V
725
     *           |   seed   |--> MGF ---> xor
726
     *           +----------+              |
727
     *                 |                   |
728
     *        +--+     V                   |
729
     *        |00|    xor <----- MGF <-----|
730
     *        +--+     |                   |
731
     *          |      |                   |
732
     *          V      V                   V
733
     *        +--+----------+----------------------------+
734
     *  EM =  |00|maskedSeed|          maskedDB          |
735
     *        +--+----------+----------------------------+
736
     *
737
     * We use mask to hold the result of the MGF functions, and all other
738
     * values are generated in their final resting place.
739
     */
740
0
    *em = 0x00;
741
742
    /* Step 2.a - Generate lHash */
743
0
    hashContext = (*hash->create)();
744
0
    if (hashContext == NULL) {
745
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
746
0
        return SECFailure;
747
0
    }
748
0
    (*hash->begin)(hashContext);
749
0
    if (labelLen > 0)
750
0
        (*hash->update)(hashContext, label, labelLen);
751
0
    (*hash->end)(hashContext, &em[1 + hash->length], &i, hash->length);
752
0
    (*hash->destroy)(hashContext, PR_TRUE);
753
754
    /* Step 2.b - Generate PS */
755
0
    if (emLen - reservedLen - inputLen > 0) {
756
0
        PORT_Memset(em + 1 + (hash->length * 2), 0x00,
757
0
                    emLen - reservedLen - inputLen);
758
0
    }
759
760
    /* Step 2.c. - Generate DB
761
     * DB = lHash || PS || 0x01 || M
762
     * Note that PS and lHash have already been placed into em at their
763
     * appropriate offsets. This just copies M into place
764
     */
765
0
    em[emLen - inputLen - 1] = 0x01;
766
0
    if (inputLen)
767
0
        PORT_Memcpy(em + emLen - inputLen, input, inputLen);
768
769
0
    if (seed == NULL) {
770
        /* Step 2.d - Generate seed */
771
0
        rv = RNG_GenerateGlobalRandomBytes(em + 1, hash->length);
772
0
        if (rv != SECSuccess) {
773
0
            return rv;
774
0
        }
775
0
    } else {
776
        /* For Known Answer Tests, copy the supplied seed. */
777
0
        PORT_Memcpy(em + 1, seed, seedLen);
778
0
    }
779
780
    /* Step 2.e - Generate dbMask*/
781
0
    dbMaskLen = emLen - hash->length - 1;
782
0
    mask = (unsigned char *)PORT_Alloc(dbMaskLen);
783
0
    if (mask == NULL) {
784
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
785
0
        return SECFailure;
786
0
    }
787
0
    MGF1(maskHashAlg, mask, dbMaskLen, em + 1, hash->length);
788
    /* Step 2.f - Compute maskedDB*/
789
0
    for (i = 0; i < dbMaskLen; ++i)
790
0
        em[1 + hash->length + i] ^= mask[i];
791
792
    /* Step 2.g - Generate seedMask */
793
0
    MGF1(maskHashAlg, mask, hash->length, &em[1 + hash->length], dbMaskLen);
794
    /* Step 2.h - Compute maskedSeed */
795
0
    for (i = 0; i < hash->length; ++i)
796
0
        em[1 + i] ^= mask[i];
797
798
0
    PORT_ZFree(mask, dbMaskLen);
799
0
    return SECSuccess;
800
0
}
801
802
SECStatus
803
RSA_EncryptOAEP(RSAPublicKey *key,
804
                HASH_HashType hashAlg,
805
                HASH_HashType maskHashAlg,
806
                const unsigned char *label,
807
                unsigned int labelLen,
808
                const unsigned char *seed,
809
                unsigned int seedLen,
810
                unsigned char *output,
811
                unsigned int *outputLen,
812
                unsigned int maxOutputLen,
813
                const unsigned char *input,
814
                unsigned int inputLen)
815
0
{
816
0
    SECStatus rv = SECFailure;
817
0
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
818
0
    unsigned char *oaepEncoded = NULL;
819
820
0
    if (maxOutputLen < modulusLen) {
821
0
        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
822
0
        return SECFailure;
823
0
    }
824
825
0
    if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
826
0
        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
827
0
        return SECFailure;
828
0
    }
829
830
0
    if ((labelLen == 0 && label != NULL) ||
831
0
        (labelLen > 0 && label == NULL)) {
832
0
        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
833
0
        return SECFailure;
834
0
    }
835
836
0
    oaepEncoded = (unsigned char *)PORT_Alloc(modulusLen);
837
0
    if (oaepEncoded == NULL) {
838
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
839
0
        return SECFailure;
840
0
    }
841
0
    rv = eme_oaep_encode(oaepEncoded, modulusLen, input, inputLen,
842
0
                         hashAlg, maskHashAlg, label, labelLen, seed, seedLen);
843
0
    if (rv != SECSuccess)
844
0
        goto done;
845
846
0
    rv = RSA_PublicKeyOp(key, output, oaepEncoded);
847
0
    if (rv != SECSuccess)
848
0
        goto done;
849
0
    *outputLen = modulusLen;
850
851
0
done:
852
0
    PORT_Free(oaepEncoded);
853
0
    return rv;
854
0
}
855
856
SECStatus
857
RSA_DecryptOAEP(RSAPrivateKey *key,
858
                HASH_HashType hashAlg,
859
                HASH_HashType maskHashAlg,
860
                const unsigned char *label,
861
                unsigned int labelLen,
862
                unsigned char *output,
863
                unsigned int *outputLen,
864
                unsigned int maxOutputLen,
865
                const unsigned char *input,
866
                unsigned int inputLen)
867
0
{
868
0
    SECStatus rv = SECFailure;
869
0
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
870
0
    unsigned char *oaepEncoded = NULL;
871
872
0
    if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
873
0
        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
874
0
        return SECFailure;
875
0
    }
876
877
0
    if (inputLen != modulusLen) {
878
0
        PORT_SetError(SEC_ERROR_INPUT_LEN);
879
0
        return SECFailure;
880
0
    }
881
882
0
    if ((labelLen == 0 && label != NULL) ||
883
0
        (labelLen > 0 && label == NULL)) {
884
0
        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
885
0
        return SECFailure;
886
0
    }
887
888
0
    oaepEncoded = (unsigned char *)PORT_Alloc(modulusLen);
889
0
    if (oaepEncoded == NULL) {
890
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
891
0
        return SECFailure;
892
0
    }
893
894
0
    rv = RSA_PrivateKeyOpDoubleChecked(key, oaepEncoded, input);
895
0
    if (rv != SECSuccess) {
896
0
        goto done;
897
0
    }
898
0
    rv = eme_oaep_decode(output, outputLen, maxOutputLen, oaepEncoded,
899
0
                         modulusLen, hashAlg, maskHashAlg, label,
900
0
                         labelLen);
901
902
0
done:
903
0
    if (oaepEncoded)
904
0
        PORT_ZFree(oaepEncoded, modulusLen);
905
0
    return rv;
906
0
}
907
908
/* XXX Doesn't set error code */
909
SECStatus
910
RSA_EncryptBlock(RSAPublicKey *key,
911
                 unsigned char *output,
912
                 unsigned int *outputLen,
913
                 unsigned int maxOutputLen,
914
                 const unsigned char *input,
915
                 unsigned int inputLen)
916
12.4k
{
917
12.4k
    SECStatus rv;
918
12.4k
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
919
12.4k
    SECItem formatted;
920
12.4k
    SECItem unformatted;
921
922
12.4k
    formatted.data = NULL;
923
12.4k
    if (maxOutputLen < modulusLen)
924
0
        goto failure;
925
926
12.4k
    unformatted.len = inputLen;
927
12.4k
    unformatted.data = (unsigned char *)input;
928
12.4k
    formatted.data = NULL;
929
12.4k
    rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockPublic,
930
12.4k
                         &unformatted);
931
12.4k
    if (rv != SECSuccess)
932
0
        goto failure;
933
934
12.4k
    rv = RSA_PublicKeyOp(key, output, formatted.data);
935
12.4k
    if (rv != SECSuccess)
936
0
        goto failure;
937
938
12.4k
    PORT_ZFree(formatted.data, modulusLen);
939
12.4k
    *outputLen = modulusLen;
940
12.4k
    return SECSuccess;
941
942
0
failure:
943
0
    if (formatted.data != NULL)
944
0
        PORT_ZFree(formatted.data, modulusLen);
945
0
    return SECFailure;
946
12.4k
}
947
948
static HMACContext *
949
rsa_GetHMACContext(const SECHashObject *hash, RSAPrivateKey *key,
950
                   const unsigned char *input, unsigned int inputLen)
951
2.77k
{
952
2.77k
    unsigned char keyHash[HASH_LENGTH_MAX];
953
2.77k
    void *hashContext;
954
2.77k
    HMACContext *hmac = NULL;
955
2.77k
    unsigned int privKeyLen = key->privateExponent.len;
956
2.77k
    unsigned int keyLen;
957
2.77k
    SECStatus rv;
958
959
    /* first get the key hash (should store in the key structure) */
960
2.77k
    PORT_Memset(keyHash, 0, sizeof(keyHash));
961
2.77k
    hashContext = (*hash->create)();
962
2.77k
    if (hashContext == NULL) {
963
0
        return NULL;
964
0
    }
965
2.77k
    (*hash->begin)(hashContext);
966
2.77k
    if (privKeyLen < inputLen) {
967
0
        int padLen = inputLen - privKeyLen;
968
0
        while (padLen > sizeof(keyHash)) {
969
0
            (*hash->update)(hashContext, keyHash, sizeof(keyHash));
970
0
            padLen -= sizeof(keyHash);
971
0
        }
972
0
        (*hash->update)(hashContext, keyHash, padLen);
973
0
    }
974
2.77k
    (*hash->update)(hashContext, key->privateExponent.data, privKeyLen);
975
2.77k
    (*hash->end)(hashContext, keyHash, &keyLen, sizeof(keyHash));
976
2.77k
    (*hash->destroy)(hashContext, PR_TRUE);
977
978
    /* now create the hmac key */
979
2.77k
    hmac = HMAC_Create(hash, keyHash, keyLen, PR_TRUE);
980
2.77k
    if (hmac == NULL) {
981
0
        PORT_SafeZero(keyHash, sizeof(keyHash));
982
0
        return NULL;
983
0
    }
984
2.77k
    HMAC_Begin(hmac);
985
2.77k
    HMAC_Update(hmac, input, inputLen);
986
2.77k
    rv = HMAC_Finish(hmac, keyHash, &keyLen, sizeof(keyHash));
987
2.77k
    if (rv != SECSuccess) {
988
0
        PORT_SafeZero(keyHash, sizeof(keyHash));
989
0
        HMAC_Destroy(hmac, PR_TRUE);
990
0
        return NULL;
991
0
    }
992
    /* Finally set the new key into the hash context. We
993
     * reuse the original context allocated above so we don't
994
     * need to allocate and free another one */
995
2.77k
    rv = HMAC_ReInit(hmac, hash, keyHash, keyLen, PR_TRUE);
996
2.77k
    PORT_SafeZero(keyHash, sizeof(keyHash));
997
2.77k
    if (rv != SECSuccess) {
998
0
        HMAC_Destroy(hmac, PR_TRUE);
999
0
        return NULL;
1000
0
    }
1001
1002
2.77k
    return hmac;
1003
2.77k
}
1004
1005
static SECStatus
1006
rsa_HMACPrf(HMACContext *hmac, const char *label, int labelLen,
1007
            int hashLength, unsigned char *output, int length)
1008
5.55k
{
1009
5.55k
    unsigned char iterator[2] = { 0, 0 };
1010
5.55k
    unsigned char encodedLen[2] = { 0, 0 };
1011
5.55k
    unsigned char hmacLast[HASH_LENGTH_MAX];
1012
5.55k
    unsigned int left = length;
1013
5.55k
    unsigned int hashReturn;
1014
5.55k
    SECStatus rv = SECSuccess;
1015
1016
    /* encodedLen is in bits, length is in bytes, thus the shifts
1017
     * do an implied multiply by 8 */
1018
5.55k
    encodedLen[0] = (length >> 5) & 0xff;
1019
5.55k
    encodedLen[1] = (length << 3) & 0xff;
1020
1021
44.4k
    while (left > hashLength) {
1022
38.8k
        HMAC_Begin(hmac);
1023
38.8k
        HMAC_Update(hmac, iterator, 2);
1024
38.8k
        HMAC_Update(hmac, (const unsigned char *)label, labelLen);
1025
38.8k
        HMAC_Update(hmac, encodedLen, 2);
1026
38.8k
        rv = HMAC_Finish(hmac, output, &hashReturn, hashLength);
1027
38.8k
        if (rv != SECSuccess) {
1028
0
            return rv;
1029
0
        }
1030
38.8k
        iterator[1]++;
1031
38.8k
        if (iterator[1] == 0)
1032
0
            iterator[0]++;
1033
38.8k
        left -= hashLength;
1034
38.8k
        output += hashLength;
1035
38.8k
    }
1036
5.55k
    if (left) {
1037
5.55k
        HMAC_Begin(hmac);
1038
5.55k
        HMAC_Update(hmac, iterator, 2);
1039
5.55k
        HMAC_Update(hmac, (const unsigned char *)label, labelLen);
1040
5.55k
        HMAC_Update(hmac, encodedLen, 2);
1041
5.55k
        rv = HMAC_Finish(hmac, hmacLast, &hashReturn, sizeof(hmacLast));
1042
5.55k
        if (rv != SECSuccess) {
1043
0
            return rv;
1044
0
        }
1045
5.55k
        PORT_Memcpy(output, hmacLast, left);
1046
5.55k
        PORT_SafeZero(hmacLast, sizeof(hmacLast));
1047
5.55k
    }
1048
5.55k
    return rv;
1049
5.55k
}
1050
1051
/*
 * Returns the smallest all-ones mask (2^k - 1) that covers every set bit of
 * a 16-bit value, i.e. smears the highest set bit of |len| downwards.
 * Examples:
 *     0x81  -> 0xff
 *     0x1af -> 0x1ff
 *     0x4d1 -> 0x7ff
 *     0     -> 0
 * Shifts of 1, 2, 4 and 8 are enough to fill all 16 low bits.
 */
static int
makeMask16(int len)
{
    int shift;

    for (shift = 1; shift <= 8; shift <<= 1) {
        len |= len >> shift;
    }
    return len;
}
1068
1069
5.55k
/* Expands a string literal into two function arguments: the literal and its
 * length in bytes excluding the terminating NUL. It relies on sizeof, so it
 * is only meaningful for string literals / char arrays, never for pointers.
 * Both halves are parenthesized so an operator following the expansion
 * applies to the whole length, not just to the trailing "1". */
#define STRING_AND_LENGTH(s) (s), (sizeof(s) - 1)
1070
static int
1071
rsa_GetErrorLength(HMACContext *hmac, int hashLen, int maxLegalLen)
1072
2.77k
{
1073
2.77k
    unsigned char out[128 * 2];
1074
2.77k
    unsigned char *outp;
1075
2.77k
    int outLength = 0;
1076
2.77k
    int lengthMask;
1077
2.77k
    SECStatus rv;
1078
1079
2.77k
    lengthMask = makeMask16(maxLegalLen);
1080
2.77k
    rv = rsa_HMACPrf(hmac, STRING_AND_LENGTH("length"), hashLen,
1081
2.77k
                     out, sizeof(out));
1082
2.77k
    if (rv != SECSuccess) {
1083
0
        return -1;
1084
0
    }
1085
358k
    for (outp = out; outp < out + sizeof(out); outp += 2) {
1086
355k
        int candidate = outp[0] << 8 | outp[1];
1087
355k
        candidate = candidate & lengthMask;
1088
355k
        outLength = PORT_CT_SEL(PORT_CT_LT(candidate, maxLegalLen),
1089
355k
                                candidate, outLength);
1090
355k
    }
1091
2.77k
    PORT_SafeZero(out, sizeof(out));
1092
2.77k
    return outLength;
1093
2.77k
}
1094
1095
/*
1096
 * This function can only fail in environmental cases: Programming errors
1097
 * and out of memory situations. It can't fail if the keys are valid and
1098
 * the inputs are the proper size. If the actual RSA decryption fails, a
1099
 * fake value and a fake length, both of which have already been generated
1100
 * based on the key and input, are returned.
1101
 * Applications are expected to detect decryption failures based on the fact
1102
 * that the decrypted value (usually a key) doesn't validate. The prevents
1103
 * Blecheinbaucher style attacks against the key. */
1104
SECStatus
1105
RSA_DecryptBlock(RSAPrivateKey *key,
1106
                 unsigned char *output,
1107
                 unsigned int *outputLen,
1108
                 unsigned int maxOutputLen,
1109
                 const unsigned char *input,
1110
                 unsigned int inputLen)
1111
72.6k
{
1112
72.6k
    SECStatus rv;
1113
72.6k
    PRUint32 fail;
1114
72.6k
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1115
72.6k
    unsigned int i;
1116
72.6k
    unsigned char *buffer = NULL;
1117
72.6k
    unsigned char *errorBuffer = NULL;
1118
72.6k
    unsigned char *bp = NULL;
1119
72.6k
    unsigned char *ep = NULL;
1120
72.6k
    unsigned int outLen = modulusLen;
1121
72.6k
    unsigned int maxLegalLen = modulusLen - 10;
1122
72.6k
    unsigned int errorLength;
1123
72.6k
    const SECHashObject *hashObj;
1124
72.6k
    HMACContext *hmac = NULL;
1125
1126
    /* failures in the top section indicate failures in the environment
1127
     * (memory) or the library. OK to return errors in these cases because
1128
     * it doesn't provide any oracle information to attackers. */
1129
72.6k
    if (inputLen != modulusLen || modulusLen < 10) {
1130
69.8k
        PORT_SetError(SEC_ERROR_INVALID_ARGS);
1131
69.8k
        return SECFailure;
1132
69.8k
    }
1133
1134
    /* Allocate enough space to decrypt */
1135
2.77k
    buffer = PORT_ZAlloc(modulusLen);
1136
2.77k
    if (!buffer) {
1137
0
        goto loser;
1138
0
    }
1139
2.77k
    errorBuffer = PORT_ZAlloc(modulusLen);
1140
2.77k
    if (!errorBuffer) {
1141
0
        goto loser;
1142
0
    }
1143
2.77k
    hashObj = HASH_GetRawHashObject(HASH_AlgSHA256);
1144
2.77k
    if (hashObj == NULL) {
1145
0
        goto loser;
1146
0
    }
1147
1148
    /* calculate the values to return in the error case rather than
1149
     * the actual returned values. This data is the same for the
1150
     * same input and private key. */
1151
2.77k
    hmac = rsa_GetHMACContext(hashObj, key, input, inputLen);
1152
2.77k
    if (hmac == NULL) {
1153
0
        goto loser;
1154
0
    }
1155
2.77k
    errorLength = rsa_GetErrorLength(hmac, hashObj->length, maxLegalLen);
1156
2.77k
    if (((int)errorLength) < 0) {
1157
0
        goto loser;
1158
0
    }
1159
    /* we always have to generate a full moduluslen error string. Otherwise
1160
     * we create a timing dependency on errorLength, which could be used to
1161
     * determine the difference between errorLength and outputLen and tell
1162
     * us that there was a pkcs1 decryption failure */
1163
2.77k
    rv = rsa_HMACPrf(hmac, STRING_AND_LENGTH("message"),
1164
2.77k
                     hashObj->length, errorBuffer, modulusLen);
1165
2.77k
    if (rv != SECSuccess) {
1166
0
        goto loser;
1167
0
    }
1168
1169
2.77k
    HMAC_Destroy(hmac, PR_TRUE);
1170
2.77k
    hmac = NULL;
1171
1172
    /* From here on out, we will always return success. If there is
1173
     * an error, we will return deterministic output based on the key
1174
     * and the input data. */
1175
2.77k
    rv = RSA_PrivateKeyOp(key, buffer, input);
1176
1177
2.77k
    fail = PORT_CT_NE(rv, SECSuccess);
1178
2.77k
    fail |= PORT_CT_NE(buffer[0], RSA_BLOCK_FIRST_OCTET) | PORT_CT_NE(buffer[1], RSA_BlockPublic);
1179
1180
    /* There have to be at least 8 bytes of padding. */
1181
25.0k
    for (i = 2; i < 10; i++) {
1182
22.2k
        fail |= PORT_CT_EQ(buffer[i], RSA_BLOCK_AFTER_PAD_OCTET);
1183
22.2k
    }
1184
1185
686k
    for (i = 10; i < modulusLen; i++) {
1186
683k
        unsigned int newLen = modulusLen - i - 1;
1187
683k
        PRUint32 condition = PORT_CT_EQ(buffer[i], RSA_BLOCK_AFTER_PAD_OCTET) & PORT_CT_EQ(outLen, modulusLen);
1188
683k
        outLen = PORT_CT_SEL(condition, newLen, outLen);
1189
683k
    }
1190
    // this can only happen if a zero wasn't found above
1191
2.77k
    fail |= PORT_CT_GE(outLen, modulusLen);
1192
1193
2.77k
    outLen = PORT_CT_SEL(fail, errorLength, outLen);
1194
1195
    /* index into the correct buffer. Do it before we truncate outLen if the
1196
     * application was asking for less data than we can return */
1197
2.77k
    bp = buffer + modulusLen - outLen;
1198
2.77k
    ep = errorBuffer + modulusLen - outLen;
1199
1200
    /* at this point, outLen returns no information about decryption failures,
1201
     * no need to hide its value. maxOutputLen is how much data the
1202
     * application is expecting, which is also not sensitive. */
1203
2.77k
    if (outLen > maxOutputLen) {
1204
0
        outLen = maxOutputLen;
1205
0
    }
1206
1207
    /* we can't use PORT_Memcpy because caching could create a time dependency
1208
     * on the status of fail. */
1209
356k
    for (i = 0; i < outLen; i++) {
1210
353k
        output[i] = PORT_CT_SEL(fail, ep[i], bp[i]);
1211
353k
    }
1212
1213
2.77k
    *outputLen = outLen;
1214
1215
2.77k
    PORT_Free(buffer);
1216
2.77k
    PORT_Free(errorBuffer);
1217
1218
2.77k
    return SECSuccess;
1219
1220
0
loser:
1221
0
    if (hmac) {
1222
0
        HMAC_Destroy(hmac, PR_TRUE);
1223
0
    }
1224
0
    PORT_Free(buffer);
1225
0
    PORT_Free(errorBuffer);
1226
1227
0
    return SECFailure;
1228
2.77k
}
1229
1230
/*
1231
 * Encode a RSA-PSS signature.
1232
 * Described in RFC 3447, section 9.1.1.
1233
 * We use mHash instead of M as input.
1234
 * emBits from the RFC is just modBits - 1, see section 8.1.1.
1235
 * We only support MGF1 as the MGF.
1236
 */
1237
SECStatus
1238
RSA_EMSAEncodePSS(unsigned char *em,
1239
                  unsigned int emLen,
1240
                  unsigned int emBits,
1241
                  const unsigned char *mHash,
1242
                  HASH_HashType hashAlg,
1243
                  HASH_HashType maskHashAlg,
1244
                  const unsigned char *salt,
1245
                  unsigned int saltLen)
1246
2.23k
{
1247
2.23k
    const SECHashObject *hash;
1248
2.23k
    void *hash_context;
1249
2.23k
    unsigned char *dbMask;
1250
2.23k
    unsigned int dbMaskLen;
1251
2.23k
    unsigned int i;
1252
2.23k
    SECStatus rv;
1253
1254
2.23k
    hash = HASH_GetRawHashObject(hashAlg);
1255
2.23k
    dbMaskLen = emLen - hash->length - 1;
1256
1257
    /* Step 3 */
1258
2.23k
    if (emLen < hash->length + saltLen + 2) {
1259
0
        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1260
0
        return SECFailure;
1261
0
    }
1262
1263
    /* Step 4 */
1264
2.23k
    if (salt == NULL) {
1265
2.23k
        rv = RNG_GenerateGlobalRandomBytes(&em[dbMaskLen - saltLen], saltLen);
1266
2.23k
        if (rv != SECSuccess) {
1267
0
            return rv;
1268
0
        }
1269
2.23k
    } else {
1270
0
        PORT_Memcpy(&em[dbMaskLen - saltLen], salt, saltLen);
1271
0
    }
1272
1273
    /* Step 5 + 6 */
1274
    /* Compute H and store it at its final location &em[dbMaskLen]. */
1275
2.23k
    hash_context = (*hash->create)();
1276
2.23k
    if (hash_context == NULL) {
1277
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
1278
0
        return SECFailure;
1279
0
    }
1280
2.23k
    (*hash->begin)(hash_context);
1281
2.23k
    (*hash->update)(hash_context, eightZeros, 8);
1282
2.23k
    (*hash->update)(hash_context, mHash, hash->length);
1283
2.23k
    (*hash->update)(hash_context, &em[dbMaskLen - saltLen], saltLen);
1284
2.23k
    (*hash->end)(hash_context, &em[dbMaskLen], &i, hash->length);
1285
2.23k
    (*hash->destroy)(hash_context, PR_TRUE);
1286
1287
    /* Step 7 + 8 */
1288
2.23k
    PORT_Memset(em, 0, dbMaskLen - saltLen - 1);
1289
2.23k
    em[dbMaskLen - saltLen - 1] = 0x01;
1290
1291
    /* Step 9 */
1292
2.23k
    dbMask = (unsigned char *)PORT_Alloc(dbMaskLen);
1293
2.23k
    if (dbMask == NULL) {
1294
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
1295
0
        return SECFailure;
1296
0
    }
1297
2.23k
    MGF1(maskHashAlg, dbMask, dbMaskLen, &em[dbMaskLen], hash->length);
1298
1299
    /* Step 10 */
1300
471k
    for (i = 0; i < dbMaskLen; i++)
1301
468k
        em[i] ^= dbMask[i];
1302
2.23k
    PORT_Free(dbMask);
1303
1304
    /* Step 11 */
1305
2.23k
    em[0] &= 0xff >> (8 * emLen - emBits);
1306
1307
    /* Step 12 */
1308
2.23k
    em[emLen - 1] = 0xbc;
1309
1310
2.23k
    return SECSuccess;
1311
2.23k
}
1312
1313
/*
1314
 * Verify a RSA-PSS signature.
1315
 * Described in RFC 3447, section 9.1.2.
1316
 * We use mHash instead of M as input.
1317
 * emBits from the RFC is just modBits - 1, see section 8.1.2.
1318
 * We only support MGF1 as the MGF.
1319
 */
1320
static SECStatus
1321
emsa_pss_verify(const unsigned char *mHash,
1322
                const unsigned char *em,
1323
                unsigned int emLen,
1324
                unsigned int emBits,
1325
                HASH_HashType hashAlg,
1326
                HASH_HashType maskHashAlg,
1327
                unsigned int saltLen)
1328
3.70k
{
1329
3.70k
    const SECHashObject *hash;
1330
3.70k
    void *hash_context;
1331
3.70k
    unsigned char *db;
1332
3.70k
    unsigned char *H_; /* H' from the RFC */
1333
3.70k
    unsigned int i;
1334
3.70k
    unsigned int dbMaskLen;
1335
3.70k
    unsigned int zeroBits;
1336
3.70k
    SECStatus rv;
1337
1338
3.70k
    hash = HASH_GetRawHashObject(hashAlg);
1339
3.70k
    dbMaskLen = emLen - hash->length - 1;
1340
1341
    /* Step 3 + 4 */
1342
3.70k
    if ((emLen < (hash->length + saltLen + 2)) ||
1343
3.70k
        (em[emLen - 1] != 0xbc)) {
1344
3.46k
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1345
3.46k
        return SECFailure;
1346
3.46k
    }
1347
1348
    /* Step 6 */
1349
245
    zeroBits = 8 * emLen - emBits;
1350
245
    if (em[0] >> (8 - zeroBits)) {
1351
22
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1352
22
        return SECFailure;
1353
22
    }
1354
1355
    /* Step 7 */
1356
223
    db = (unsigned char *)PORT_Alloc(dbMaskLen);
1357
223
    if (db == NULL) {
1358
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
1359
0
        return SECFailure;
1360
0
    }
1361
    /* &em[dbMaskLen] points to H, used as mgfSeed */
1362
223
    MGF1(maskHashAlg, db, dbMaskLen, &em[dbMaskLen], hash->length);
1363
1364
    /* Step 8 */
1365
37.9k
    for (i = 0; i < dbMaskLen; i++) {
1366
37.6k
        db[i] ^= em[i];
1367
37.6k
    }
1368
1369
    /* Step 9 */
1370
223
    db[0] &= 0xff >> zeroBits;
1371
1372
    /* Step 10 */
1373
16.7k
    for (i = 0; i < (dbMaskLen - saltLen - 1); i++) {
1374
16.6k
        if (db[i] != 0) {
1375
107
            PORT_Free(db);
1376
107
            PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1377
107
            return SECFailure;
1378
107
        }
1379
16.6k
    }
1380
116
    if (db[dbMaskLen - saltLen - 1] != 0x01) {
1381
1
        PORT_Free(db);
1382
1
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1383
1
        return SECFailure;
1384
1
    }
1385
1386
    /* Step 12 + 13 */
1387
115
    H_ = (unsigned char *)PORT_Alloc(hash->length);
1388
115
    if (H_ == NULL) {
1389
0
        PORT_Free(db);
1390
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
1391
0
        return SECFailure;
1392
0
    }
1393
115
    hash_context = (*hash->create)();
1394
115
    if (hash_context == NULL) {
1395
0
        PORT_Free(db);
1396
0
        PORT_Free(H_);
1397
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
1398
0
        return SECFailure;
1399
0
    }
1400
115
    (*hash->begin)(hash_context);
1401
115
    (*hash->update)(hash_context, eightZeros, 8);
1402
115
    (*hash->update)(hash_context, mHash, hash->length);
1403
115
    (*hash->update)(hash_context, &db[dbMaskLen - saltLen], saltLen);
1404
115
    (*hash->end)(hash_context, H_, &i, hash->length);
1405
115
    (*hash->destroy)(hash_context, PR_TRUE);
1406
1407
115
    PORT_Free(db);
1408
1409
    /* Step 14 */
1410
115
    if (PORT_Memcmp(H_, &em[dbMaskLen], hash->length) != 0) {
1411
76
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1412
76
        rv = SECFailure;
1413
76
    } else {
1414
39
        rv = SECSuccess;
1415
39
    }
1416
1417
115
    PORT_Free(H_);
1418
115
    return rv;
1419
115
}
1420
1421
SECStatus
1422
RSA_SignPSS(RSAPrivateKey *key,
1423
            HASH_HashType hashAlg,
1424
            HASH_HashType maskHashAlg,
1425
            const unsigned char *salt,
1426
            unsigned int saltLength,
1427
            unsigned char *output,
1428
            unsigned int *outputLen,
1429
            unsigned int maxOutputLen,
1430
            const unsigned char *input,
1431
            unsigned int inputLen)
1432
2.23k
{
1433
2.23k
    SECStatus rv = SECSuccess;
1434
2.23k
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1435
2.23k
    unsigned int modulusBits = rsa_modulusBits(&key->modulus);
1436
2.23k
    unsigned int emLen = modulusLen;
1437
2.23k
    unsigned char *pssEncoded, *em;
1438
1439
2.23k
    if (maxOutputLen < modulusLen) {
1440
0
        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1441
0
        return SECFailure;
1442
0
    }
1443
1444
2.23k
    if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
1445
0
        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
1446
0
        return SECFailure;
1447
0
    }
1448
1449
2.23k
    pssEncoded = em = (unsigned char *)PORT_Alloc(modulusLen);
1450
2.23k
    if (pssEncoded == NULL) {
1451
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
1452
0
        return SECFailure;
1453
0
    }
1454
1455
    /* len(em) == ceil((modulusBits - 1) / 8). */
1456
2.23k
    if (modulusBits % 8 == 1) {
1457
0
        em[0] = 0;
1458
0
        emLen--;
1459
0
        em++;
1460
0
    }
1461
2.23k
    rv = RSA_EMSAEncodePSS(em, emLen, modulusBits - 1, input, hashAlg,
1462
2.23k
                           maskHashAlg, salt, saltLength);
1463
2.23k
    if (rv != SECSuccess)
1464
0
        goto done;
1465
1466
    // This sets error codes upon failure.
1467
2.23k
    rv = RSA_PrivateKeyOpDoubleChecked(key, output, pssEncoded);
1468
2.23k
    *outputLen = modulusLen;
1469
1470
2.23k
done:
1471
2.23k
    PORT_Free(pssEncoded);
1472
2.23k
    return rv;
1473
2.23k
}
1474
1475
SECStatus
1476
RSA_CheckSignPSS(RSAPublicKey *key,
1477
                 HASH_HashType hashAlg,
1478
                 HASH_HashType maskHashAlg,
1479
                 unsigned int saltLength,
1480
                 const unsigned char *sig,
1481
                 unsigned int sigLen,
1482
                 const unsigned char *hash,
1483
                 unsigned int hashLen)
1484
3.77k
{
1485
3.77k
    SECStatus rv;
1486
3.77k
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1487
3.77k
    unsigned int modulusBits = rsa_modulusBits(&key->modulus);
1488
3.77k
    unsigned int emLen = modulusLen;
1489
3.77k
    unsigned char *buffer, *em;
1490
1491
3.77k
    if (sigLen != modulusLen) {
1492
45
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1493
45
        return SECFailure;
1494
45
    }
1495
1496
3.72k
    if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
1497
0
        PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
1498
0
        return SECFailure;
1499
0
    }
1500
1501
3.72k
    buffer = em = (unsigned char *)PORT_Alloc(modulusLen);
1502
3.72k
    if (!buffer) {
1503
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
1504
0
        return SECFailure;
1505
0
    }
1506
1507
3.72k
    rv = RSA_PublicKeyOp(key, buffer, sig);
1508
3.72k
    if (rv != SECSuccess) {
1509
21
        PORT_Free(buffer);
1510
21
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1511
21
        return SECFailure;
1512
21
    }
1513
1514
    /* len(em) == ceil((modulusBits - 1) / 8). */
1515
3.70k
    if (modulusBits % 8 == 1) {
1516
1
        emLen--;
1517
1
        em++;
1518
1
    }
1519
3.70k
    rv = emsa_pss_verify(hash, em, emLen, modulusBits - 1, hashAlg,
1520
3.70k
                         maskHashAlg, saltLength);
1521
1522
3.70k
    PORT_Free(buffer);
1523
3.70k
    return rv;
1524
3.72k
}
1525
1526
SECStatus
1527
RSA_Sign(RSAPrivateKey *key,
1528
         unsigned char *output,
1529
         unsigned int *outputLen,
1530
         unsigned int maxOutputLen,
1531
         const unsigned char *input,
1532
         unsigned int inputLen)
1533
17.7k
{
1534
17.7k
    SECStatus rv = SECFailure;
1535
17.7k
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1536
17.7k
    SECItem formatted = { siBuffer, NULL, 0 };
1537
17.7k
    SECItem unformatted = { siBuffer, (unsigned char *)input, inputLen };
1538
1539
17.7k
    if (maxOutputLen < modulusLen) {
1540
0
        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1541
0
        goto done;
1542
0
    }
1543
1544
17.7k
    rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockPrivate,
1545
17.7k
                         &unformatted);
1546
17.7k
    if (rv != SECSuccess) {
1547
0
        PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
1548
0
        goto done;
1549
0
    }
1550
1551
    // This sets error codes upon failure.
1552
17.7k
    rv = RSA_PrivateKeyOpDoubleChecked(key, output, formatted.data);
1553
17.7k
    *outputLen = modulusLen;
1554
1555
17.7k
done:
1556
17.7k
    if (formatted.data != NULL) {
1557
17.7k
        PORT_ZFree(formatted.data, modulusLen);
1558
17.7k
    }
1559
17.7k
    return rv;
1560
17.7k
}
1561
1562
SECStatus
1563
RSA_CheckSign(RSAPublicKey *key,
1564
              const unsigned char *sig,
1565
              unsigned int sigLen,
1566
              const unsigned char *data,
1567
              unsigned int dataLen)
1568
1.46k
{
1569
1.46k
    SECStatus rv = SECFailure;
1570
1.46k
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1571
1.46k
    unsigned int i;
1572
1.46k
    unsigned char *buffer = NULL;
1573
1574
1.46k
    if (sigLen != modulusLen) {
1575
39
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1576
39
        goto done;
1577
39
    }
1578
1579
    /*
1580
     * 0x00 || BT || Pad || 0x00 || ActualData
1581
     *
1582
     * The "3" below is the first octet + the second octet + the 0x00
1583
     * octet that always comes just before the ActualData.
1584
     */
1585
1.42k
    if (dataLen > modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN)) {
1586
0
        PORT_SetError(SEC_ERROR_BAD_DATA);
1587
0
        goto done;
1588
0
    }
1589
1590
1.42k
    buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
1591
1.42k
    if (!buffer) {
1592
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
1593
0
        goto done;
1594
0
    }
1595
1596
1.42k
    if (RSA_PublicKeyOp(key, buffer, sig) != SECSuccess) {
1597
22
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1598
22
        goto done;
1599
22
    }
1600
1601
    /*
1602
     * check the padding that was used
1603
     */
1604
1.40k
    if (buffer[0] != RSA_BLOCK_FIRST_OCTET ||
1605
1.40k
        buffer[1] != (unsigned char)RSA_BlockPrivate) {
1606
1.32k
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1607
1.32k
        goto done;
1608
1.32k
    }
1609
6.80k
    for (i = 2; i < modulusLen - dataLen - 1; i++) {
1610
6.75k
        if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) {
1611
26
            PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1612
26
            goto done;
1613
26
        }
1614
6.75k
    }
1615
56
    if (buffer[i] != RSA_BLOCK_AFTER_PAD_OCTET) {
1616
1
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1617
1
        goto done;
1618
1
    }
1619
1620
    /*
1621
     * make sure we get the same results
1622
     */
1623
55
    if (PORT_Memcmp(buffer + modulusLen - dataLen, data, dataLen) == 0) {
1624
6
        rv = SECSuccess;
1625
6
    }
1626
1627
1.46k
done:
1628
1.46k
    if (buffer) {
1629
1.42k
        PORT_Free(buffer);
1630
1.42k
    }
1631
1.46k
    return rv;
1632
55
}
1633
1634
SECStatus
1635
RSA_CheckSignRecover(RSAPublicKey *key,
1636
                     unsigned char *output,
1637
                     unsigned int *outputLen,
1638
                     unsigned int maxOutputLen,
1639
                     const unsigned char *sig,
1640
                     unsigned int sigLen)
1641
16.0k
{
1642
16.0k
    SECStatus rv = SECFailure;
1643
16.0k
    unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1644
16.0k
    unsigned int i;
1645
16.0k
    unsigned char *buffer = NULL;
1646
16.0k
    unsigned int padLen;
1647
1648
16.0k
    if (sigLen != modulusLen) {
1649
454
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1650
454
        goto done;
1651
454
    }
1652
1653
15.5k
    buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
1654
15.5k
    if (!buffer) {
1655
0
        PORT_SetError(SEC_ERROR_NO_MEMORY);
1656
0
        goto done;
1657
0
    }
1658
1659
15.5k
    if (RSA_PublicKeyOp(key, buffer, sig) != SECSuccess) {
1660
100
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1661
100
        goto done;
1662
100
    }
1663
1664
15.4k
    *outputLen = 0;
1665
1666
    /*
1667
     * check the padding that was used
1668
     */
1669
15.4k
    if (buffer[0] != RSA_BLOCK_FIRST_OCTET ||
1670
15.4k
        buffer[1] != (unsigned char)RSA_BlockPrivate) {
1671
15.2k
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1672
15.2k
        goto done;
1673
15.2k
    }
1674
18.3k
    for (i = 2; i < modulusLen; i++) {
1675
18.3k
        if (buffer[i] == RSA_BLOCK_AFTER_PAD_OCTET) {
1676
165
            *outputLen = modulusLen - i - 1;
1677
165
            break;
1678
165
        }
1679
18.2k
        if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET) {
1680
46
            PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1681
46
            goto done;
1682
46
        }
1683
18.2k
    }
1684
166
    padLen = i - 2;
1685
166
    if (padLen < RSA_BLOCK_MIN_PAD_LEN) {
1686
11
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1687
11
        goto done;
1688
11
    }
1689
155
    if (*outputLen == 0) {
1690
1
        PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1691
1
        goto done;
1692
1
    }
1693
154
    if (*outputLen > maxOutputLen) {
1694
0
        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1695
0
        goto done;
1696
0
    }
1697
1698
154
    PORT_Memcpy(output, buffer + modulusLen - *outputLen, *outputLen);
1699
154
    rv = SECSuccess;
1700
1701
16.0k
done:
1702
16.0k
    if (buffer) {
1703
15.5k
        PORT_Free(buffer);
1704
15.5k
    }
1705
16.0k
    return rv;
1706
154
}