Coverage Report

Created: 2025-07-23 06:53

/src/wolfssl/wolfcrypt/src/kdf.c
Line
Count
Source (jump to first uncovered line)
1
/* kdf.c
2
 *
3
 * Copyright (C) 2006-2025 wolfSSL Inc.
4
 *
5
 * This file is part of wolfSSL.
6
 *
7
 * wolfSSL is free software; you can redistribute it and/or modify
8
 * it under the terms of the GNU General Public License as published by
9
 * the Free Software Foundation; either version 3 of the License, or
10
 * (at your option) any later version.
11
 *
12
 * wolfSSL is distributed in the hope that it will be useful,
13
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15
 * GNU General Public License for more details.
16
 *
17
 * You should have received a copy of the GNU General Public License
18
 * along with this program; if not, write to the Free Software
19
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335, USA
20
 */
21
22
#include <wolfssl/wolfcrypt/libwolfssl_sources.h>
23
24
#ifndef NO_KDF
25
26
#if FIPS_VERSION3_GE(5,0,0)
27
    /* set NO_WRAPPERS before headers, use direct internal f()s not wrappers */
28
    #define FIPS_NO_WRAPPERS
29
30
    #ifdef USE_WINDOWS_API
31
        #pragma code_seg(".fipsA$h")
32
        #pragma const_seg(".fipsB$h")
33
    #endif
34
#endif
35
36
37
#ifdef NO_INLINE
38
    #include <wolfssl/wolfcrypt/misc.h>
39
#else
40
    #define WOLFSSL_MISC_INCLUDED
41
    #include <wolfcrypt/src/misc.c>
42
#endif
43
44
#include <wolfssl/wolfcrypt/hmac.h>
45
#include <wolfssl/wolfcrypt/kdf.h>
46
#ifdef WC_SRTP_KDF
47
#include <wolfssl/wolfcrypt/aes.h>
48
#endif
49
50
#if FIPS_VERSION3_GE(6,0,0)
    /* Read-only marker data for the FIPS KDF module. */
    const unsigned int wolfCrypt_FIPS_kdf_ro_sanity[2] = {
        0x1a2b3c4d,
        0x00000009
    };

    /* FIPS sanity hook for the KDF module. Always reports success (0). */
    int wolfCrypt_FIPS_KDF_sanity(void)
    {
        const int ok = 0;

        return ok;
    }
#endif
58
59
#if defined(WOLFSSL_HAVE_PRF) && !defined(NO_HMAC)
60
61
#ifdef WOLFSSL_SHA512
62
0
    #define P_HASH_MAX_SIZE WC_SHA512_DIGEST_SIZE
63
#elif defined(WOLFSSL_SHA384)
64
    #define P_HASH_MAX_SIZE WC_SHA384_DIGEST_SIZE
65
#else
66
    #define P_HASH_MAX_SIZE WC_SHA256_DIGEST_SIZE
67
#endif
68
69
/* Pseudo Random Function for MD5, SHA-1, SHA-256, SHA-384, or SHA-512 */
70
int wc_PRF(byte* result, word32 resLen, const byte* secret,
71
                  word32 secLen, const byte* seed, word32 seedLen, int hash,
72
                  void* heap, int devId)
73
0
{
74
0
    word32 len = P_HASH_MAX_SIZE;
75
0
    word32 times;
76
0
    word32 lastLen;
77
0
    word32 lastTime;
78
0
    int    ret = 0;
79
#ifdef WOLFSSL_SMALL_STACK
80
    byte*  current;
81
    Hmac*  hmac;
82
#else
83
0
    byte   current[P_HASH_MAX_SIZE];   /* max size */
84
0
    Hmac   hmac[1];
85
0
#endif
86
87
0
    switch (hash) {
88
    #ifndef NO_MD5
89
        case md5_mac:
90
            hash = WC_MD5;
91
            len  = WC_MD5_DIGEST_SIZE;
92
        break;
93
    #endif
94
95
0
    #ifndef NO_SHA256
96
0
        case sha256_mac:
97
0
            hash = WC_SHA256;
98
0
            len  = WC_SHA256_DIGEST_SIZE;
99
0
        break;
100
0
    #endif
101
102
0
    #ifdef WOLFSSL_SHA384
103
0
        case sha384_mac:
104
0
            hash = WC_SHA384;
105
0
            len  = WC_SHA384_DIGEST_SIZE;
106
0
        break;
107
0
    #endif
108
109
0
    #ifdef WOLFSSL_SHA512
110
0
        case sha512_mac:
111
0
            hash = WC_SHA512;
112
0
            len  = WC_SHA512_DIGEST_SIZE;
113
0
        break;
114
0
    #endif
115
116
    #ifdef WOLFSSL_SM3
117
        case sm3_mac:
118
            hash = WC_SM3;
119
            len  = WC_SM3_DIGEST_SIZE;
120
        break;
121
    #endif
122
123
0
    #ifndef NO_SHA
124
0
        case sha_mac:
125
0
            hash = WC_SHA;
126
0
            len  = WC_SHA_DIGEST_SIZE;
127
0
        break;
128
0
    #endif
129
0
        default:
130
0
            return HASH_TYPE_E;
131
0
    }
132
133
0
    times   = resLen / len;
134
0
    lastLen = resLen % len;
135
136
0
    if (lastLen)
137
0
        times += 1;
138
139
    /* times == 0 if resLen == 0, but times == 0 abides clang static analyzer
140
       while resLen == 0 doesn't */
141
0
    if (times == 0)
142
0
        return BAD_FUNC_ARG;
143
144
0
    lastTime = times - 1;
145
146
#ifdef WOLFSSL_SMALL_STACK
147
    current = (byte*)XMALLOC(P_HASH_MAX_SIZE, heap, DYNAMIC_TYPE_DIGEST);
148
    hmac    = (Hmac*)XMALLOC(sizeof(Hmac),    heap, DYNAMIC_TYPE_HMAC);
149
    if (current == NULL || hmac == NULL) {
150
        XFREE(current, heap, DYNAMIC_TYPE_DIGEST);
151
        XFREE(hmac, heap, DYNAMIC_TYPE_HMAC);
152
        return MEMORY_E;
153
    }
154
#endif
155
#ifdef WOLFSSL_CHECK_MEM_ZERO
156
    XMEMSET(current, 0xff, P_HASH_MAX_SIZE);
157
    wc_MemZero_Add("wc_PRF current", current, P_HASH_MAX_SIZE);
158
    wc_MemZero_Add("wc_PRF hmac", hmac, sizeof(Hmac));
159
#endif
160
161
0
    ret = wc_HmacInit(hmac, heap, devId);
162
0
    if (ret == 0) {
163
0
        ret = wc_HmacSetKey(hmac, hash, secret, secLen);
164
0
        if (ret == 0)
165
0
            ret = wc_HmacUpdate(hmac, seed, seedLen); /* A0 = seed */
166
0
        if (ret == 0)
167
0
            ret = wc_HmacFinal(hmac, current);        /* A1 */
168
0
        if (ret == 0) {
169
0
            word32 i;
170
0
            word32 idx = 0;
171
172
0
            for (i = 0; i < times; i++) {
173
0
                ret = wc_HmacUpdate(hmac, current, len);
174
0
                if (ret != 0)
175
0
                    break;
176
0
                ret = wc_HmacUpdate(hmac, seed, seedLen);
177
0
                if (ret != 0)
178
0
                    break;
179
0
                if ((i != lastTime) || !lastLen) {
180
0
                    ret = wc_HmacFinal(hmac, &result[idx]);
181
0
                    if (ret != 0)
182
0
                        break;
183
0
                    idx += len;
184
185
0
                    ret = wc_HmacUpdate(hmac, current, len);
186
0
                    if (ret != 0)
187
0
                        break;
188
0
                    ret = wc_HmacFinal(hmac, current);
189
0
                    if (ret != 0)
190
0
                        break;
191
0
                }
192
0
                else {
193
0
                    ret = wc_HmacFinal(hmac, current);
194
0
                    if (ret != 0)
195
0
                        break;
196
0
                    XMEMCPY(&result[idx], current,
197
0
                                             min(lastLen, P_HASH_MAX_SIZE));
198
0
                }
199
0
            }
200
0
        }
201
0
        wc_HmacFree(hmac);
202
0
    }
203
204
0
    ForceZero(current, P_HASH_MAX_SIZE);
205
0
    ForceZero(hmac,    sizeof(Hmac));
206
207
#if defined(WOLFSSL_CHECK_MEM_ZERO)
208
    wc_MemZero_Check(current, P_HASH_MAX_SIZE);
209
    wc_MemZero_Check(hmac,    sizeof(Hmac));
210
#endif
211
212
#ifdef WOLFSSL_SMALL_STACK
213
    XFREE(current, heap, DYNAMIC_TYPE_DIGEST);
214
    XFREE(hmac,     heap, DYNAMIC_TYPE_HMAC);
215
#endif
216
217
0
    return ret;
218
0
}
219
#undef P_HASH_MAX_SIZE
220
221
/* compute PRF (pseudo random function) using SHA1 and MD5 for TLSv1 */
222
int wc_PRF_TLSv1(byte* digest, word32 digLen, const byte* secret,
223
           word32 secLen, const byte* label, word32 labLen,
224
           const byte* seed, word32 seedLen, void* heap, int devId)
225
0
{
226
0
    int         ret  = 0;
227
0
    word32      half = (secLen + 1) / 2;
228
0
    const byte* md5_half;
229
0
    const byte* sha_half;
230
0
    byte*      md5_result;
231
#ifdef WOLFSSL_SMALL_STACK
232
    byte*      sha_result;
233
    byte*      labelSeed;
234
#else
235
0
    byte       sha_result[MAX_PRF_DIG];    /* digLen is real size */
236
0
    byte       labelSeed[MAX_PRF_LABSEED];
237
0
#endif
238
239
0
    if (half > MAX_PRF_HALF ||
240
0
        labLen + seedLen > MAX_PRF_LABSEED ||
241
0
        digLen > MAX_PRF_DIG)
242
0
    {
243
0
        return BUFFER_E;
244
0
    }
245
246
#ifdef WOLFSSL_SMALL_STACK
247
    sha_result = (byte*)XMALLOC(MAX_PRF_DIG, heap, DYNAMIC_TYPE_DIGEST);
248
    labelSeed = (byte*)XMALLOC(MAX_PRF_LABSEED, heap, DYNAMIC_TYPE_DIGEST);
249
    if (sha_result == NULL || labelSeed == NULL) {
250
        XFREE(sha_result, heap, DYNAMIC_TYPE_DIGEST);
251
        XFREE(labelSeed, heap, DYNAMIC_TYPE_DIGEST);
252
        return MEMORY_E;
253
    }
254
#endif
255
256
0
    md5_half = secret;
257
0
    sha_half = secret + half - secLen % 2;
258
0
    md5_result = digest;
259
260
0
    XMEMCPY(labelSeed, label, labLen);
261
0
    XMEMCPY(labelSeed + labLen, seed, seedLen);
262
263
0
    if ((ret = wc_PRF(md5_result, digLen, md5_half, half, labelSeed,
264
0
                                labLen + seedLen, md5_mac, heap, devId)) == 0) {
265
0
        if ((ret = wc_PRF(sha_result, digLen, sha_half, half, labelSeed,
266
0
                                labLen + seedLen, sha_mac, heap, devId)) == 0) {
267
        #ifdef WOLFSSL_CHECK_MEM_ZERO
268
            wc_MemZero_Add("wc_PRF_TLSv1 sha_result", sha_result, digLen);
269
        #endif
270
            /* calculate XOR for TLSv1 PRF */
271
            /* md5 result is placed directly in digest */
272
0
            xorbuf(digest, sha_result, digLen);
273
0
            ForceZero(sha_result, digLen);
274
0
        }
275
0
    }
276
277
#if defined(WOLFSSL_CHECK_MEM_ZERO)
278
    wc_MemZero_Check(sha_result, MAX_PRF_DIG);
279
#endif
280
281
#ifdef WOLFSSL_SMALL_STACK
282
    XFREE(sha_result, heap, DYNAMIC_TYPE_DIGEST);
283
    XFREE(labelSeed, heap, DYNAMIC_TYPE_DIGEST);
284
#endif
285
286
0
    return ret;
287
0
}
288
289
/* Wrapper for TLS 1.2 and TLSv1 cases to calculate PRF */
290
/* In TLS 1.2 case call straight thru to wc_PRF */
291
int wc_PRF_TLS(byte* digest, word32 digLen, const byte* secret, word32 secLen,
292
            const byte* label, word32 labLen, const byte* seed, word32 seedLen,
293
            int useAtLeastSha256, int hash_type, void* heap, int devId)
294
0
{
295
0
    int ret = 0;
296
297
#ifdef WOLFSSL_DEBUG_TLS
298
    WOLFSSL_MSG("  secret");
299
    WOLFSSL_BUFFER(secret, secLen);
300
    WOLFSSL_MSG("  label");
301
    WOLFSSL_BUFFER(label, labLen);
302
    WOLFSSL_MSG("  seed");
303
    WOLFSSL_BUFFER(seed, seedLen);
304
#endif
305
306
307
0
    if (useAtLeastSha256) {
308
    #ifdef WOLFSSL_SMALL_STACK
309
        byte* labelSeed;
310
    #else
311
0
        byte  labelSeed[MAX_PRF_LABSEED];
312
0
    #endif
313
314
0
        if (labLen + seedLen > MAX_PRF_LABSEED) {
315
0
            return BUFFER_E;
316
0
        }
317
318
    #ifdef WOLFSSL_SMALL_STACK
319
        labelSeed = (byte*)XMALLOC(MAX_PRF_LABSEED, heap, DYNAMIC_TYPE_DIGEST);
320
        if (labelSeed == NULL) {
321
            return MEMORY_E;
322
        }
323
    #endif
324
325
0
        XMEMCPY(labelSeed, label, labLen);
326
0
        XMEMCPY(labelSeed + labLen, seed, seedLen);
327
328
        /* If a cipher suite wants an algorithm better than sha256, it
329
         * should use better. */
330
0
        if (hash_type < sha256_mac || hash_type == blake2b_mac) {
331
0
            hash_type = sha256_mac;
332
0
        }
333
        /* compute PRF for MD5, SHA-1, SHA-256, or SHA-384 for TLSv1.2 PRF */
334
0
        ret = wc_PRF(digest, digLen, secret, secLen, labelSeed,
335
0
                     labLen + seedLen, hash_type, heap, devId);
336
337
    #ifdef WOLFSSL_SMALL_STACK
338
        XFREE(labelSeed, heap, DYNAMIC_TYPE_DIGEST);
339
    #endif
340
0
    }
341
0
    else {
342
#ifndef NO_OLD_TLS
343
        /* compute TLSv1 PRF (pseudo random function using HMAC) */
344
        ret = wc_PRF_TLSv1(digest, digLen, secret, secLen, label, labLen, seed,
345
                          seedLen, heap, devId);
346
#else
347
0
        ret = BAD_FUNC_ARG;
348
0
#endif
349
0
    }
350
351
#ifdef WOLFSSL_DEBUG_TLS
352
    WOLFSSL_MSG("  digest");
353
    WOLFSSL_BUFFER(digest, digLen);
354
    WOLFSSL_MSG_EX("hash_type %d", hash_type);
355
#endif
356
357
0
    return ret;
358
0
}
359
#endif /* WOLFSSL_HAVE_PRF && !NO_HMAC */
360
361
362
#if defined(HAVE_HKDF) && !defined(NO_HMAC)
363
364
    /* Extract data using HMAC, salt and input.
365
     * RFC 5869 - HMAC-based Extract-and-Expand Key Derivation Function (HKDF)
366
     */
367
    int wc_Tls13_HKDF_Extract_ex(byte* prk, const byte* salt, word32 saltLen,
368
        byte* ikm, word32 ikmLen, int digest, void* heap, int devId)
369
0
    {
370
0
        int ret;
371
0
        word32 len = 0;
372
373
0
        switch (digest) {
374
0
            #ifndef NO_SHA256
375
0
            case WC_SHA256:
376
0
                len = WC_SHA256_DIGEST_SIZE;
377
0
                break;
378
0
            #endif
379
380
0
            #ifdef WOLFSSL_SHA384
381
0
            case WC_SHA384:
382
0
                len = WC_SHA384_DIGEST_SIZE;
383
0
                break;
384
0
            #endif
385
386
            #ifdef WOLFSSL_TLS13_SHA512
387
            case WC_SHA512:
388
                len = WC_SHA512_DIGEST_SIZE;
389
                break;
390
            #endif
391
392
            #ifdef WOLFSSL_SM3
393
            case WC_SM3:
394
                len = WC_SM3_DIGEST_SIZE;
395
                break;
396
            #endif
397
398
0
            default:
399
0
                return BAD_FUNC_ARG;
400
0
        }
401
402
        /* When length is 0 then use zeroed data of digest length. */
403
0
        if (ikmLen == 0) {
404
0
            ikmLen = len;
405
0
            XMEMSET(ikm, 0, len);
406
0
        }
407
408
#ifdef WOLFSSL_DEBUG_TLS
409
        WOLFSSL_MSG("  Salt");
410
        WOLFSSL_BUFFER(salt, saltLen);
411
        WOLFSSL_MSG("  IKM");
412
        WOLFSSL_BUFFER(ikm, ikmLen);
413
#endif
414
415
0
#if !defined(HAVE_SELFTEST) && (!defined(HAVE_FIPS) || \
416
0
    (defined(FIPS_VERSION_GE) && FIPS_VERSION_GE(5,3)))
417
0
        ret = wc_HKDF_Extract_ex(digest, salt, saltLen, ikm, ikmLen, prk, heap,
418
0
            devId);
419
#else
420
        ret = wc_HKDF_Extract(digest, salt, saltLen, ikm, ikmLen, prk);
421
        (void)heap;
422
        (void)devId;
423
#endif
424
425
#ifdef WOLFSSL_DEBUG_TLS
426
        WOLFSSL_MSG("  PRK");
427
        WOLFSSL_BUFFER(prk, len);
428
#endif
429
430
0
        return ret;
431
0
    }
432
433
    /* HKDF-Extract for TLS v1.3 with no heap hint and no crypto device.
     * Parameters and return values are as for wc_Tls13_HKDF_Extract_ex(). */
    int wc_Tls13_HKDF_Extract(byte* prk, const byte* salt, word32 saltLen,
                                 byte* ikm, word32 ikmLen, int digest)
    {
        void* heap  = NULL;
        int   devId = INVALID_DEVID;

        return wc_Tls13_HKDF_Extract_ex(prk, salt, saltLen, ikm, ikmLen,
                                        digest, heap, devId);
    }
439
440
    /* Expand data using HMAC, salt and label and info.
441
     * TLS v1.3 defines this function. */
442
    int wc_Tls13_HKDF_Expand_Label_ex(byte* okm, word32 okmLen,
443
                                 const byte* prk, word32 prkLen,
444
                                 const byte* protocol, word32 protocolLen,
445
                                 const byte* label, word32 labelLen,
446
                                 const byte* info, word32 infoLen,
447
                                 int digest, void* heap, int devId)
448
0
    {
449
0
        int    ret = 0;
450
0
        word32 idx = 0;
451
    #ifdef WOLFSSL_SMALL_STACK
452
        byte*  data;
453
    #else
454
0
        byte   data[MAX_TLS13_HKDF_LABEL_SZ];
455
0
    #endif
456
457
        /* okmLen (2) + protocol|label len (1) + info len(1) + protocollen +
458
         * labellen + infolen */
459
0
        idx = 4 + protocolLen + labelLen + infoLen;
460
0
        if (idx > MAX_TLS13_HKDF_LABEL_SZ) {
461
0
            return BUFFER_E;
462
0
        }
463
464
    #ifdef WOLFSSL_SMALL_STACK
465
        data = (byte*)XMALLOC(idx, NULL, DYNAMIC_TYPE_TMP_BUFFER);
466
        if (data == NULL) {
467
            return MEMORY_E;
468
        }
469
    #endif
470
0
        idx = 0;
471
472
        /* Output length. */
473
0
        data[idx++] = (byte)(okmLen >> 8);
474
0
        data[idx++] = (byte)okmLen;
475
        /* Length of protocol | label. */
476
0
        data[idx++] = (byte)(protocolLen + labelLen);
477
0
        if (protocolLen > 0) {
478
            /* Protocol */
479
0
            XMEMCPY(&data[idx], protocol, protocolLen);
480
0
            idx += protocolLen;
481
0
        }
482
0
        if (labelLen > 0) {
483
            /* Label */
484
0
            XMEMCPY(&data[idx], label, labelLen);
485
0
            idx += labelLen;
486
0
        }
487
        /* Length of hash of messages */
488
0
        data[idx++] = (byte)infoLen;
489
0
        if (infoLen > 0) {
490
            /* Hash of messages */
491
0
            XMEMCPY(&data[idx], info, infoLen);
492
0
            idx += infoLen;
493
0
        }
494
495
    #ifdef WOLFSSL_CHECK_MEM_ZERO
496
        wc_MemZero_Add("wc_Tls13_HKDF_Expand_Label data", data, idx);
497
    #endif
498
499
#ifdef WOLFSSL_DEBUG_TLS
500
        WOLFSSL_MSG("  PRK");
501
        WOLFSSL_BUFFER(prk, prkLen);
502
        WOLFSSL_MSG("  Info");
503
        WOLFSSL_BUFFER(data, idx);
504
        WOLFSSL_MSG_EX("  Digest %d", digest);
505
#endif
506
507
0
#if !defined(HAVE_SELFTEST) && (!defined(HAVE_FIPS) || \
508
0
    (defined(FIPS_VERSION_GE) && FIPS_VERSION_GE(5,3)))
509
0
        ret = wc_HKDF_Expand_ex(digest, prk, prkLen, data, idx, okm, okmLen,
510
0
            heap, devId);
511
#else
512
        ret = wc_HKDF_Expand(digest, prk, prkLen, data, idx, okm, okmLen);
513
        (void)heap;
514
        (void)devId;
515
#endif
516
517
#ifdef WOLFSSL_DEBUG_TLS
518
        WOLFSSL_MSG("  OKM");
519
        WOLFSSL_BUFFER(okm, okmLen);
520
#endif
521
522
0
        ForceZero(data, idx);
523
524
    #ifdef WOLFSSL_CHECK_MEM_ZERO
525
        wc_MemZero_Check(data, idx);
526
    #endif
527
    #ifdef WOLFSSL_SMALL_STACK
528
        XFREE(data, NULL, DYNAMIC_TYPE_TMP_BUFFER);
529
    #endif
530
0
        return ret;
531
0
    }
532
533
    /* TLS v1.3 HKDF-Expand-Label with no heap hint and no crypto device.
     * Parameters and return values are as for
     * wc_Tls13_HKDF_Expand_Label_ex(). */
    int wc_Tls13_HKDF_Expand_Label(byte* okm, word32 okmLen,
                                 const byte* prk, word32 prkLen,
                                 const byte* protocol, word32 protocolLen,
                                 const byte* label, word32 labelLen,
                                 const byte* info, word32 infoLen,
                                 int digest)
    {
        void* heap  = NULL;
        int   devId = INVALID_DEVID;

        return wc_Tls13_HKDF_Expand_Label_ex(okm, okmLen, prk, prkLen,
            protocol, protocolLen, label, labelLen, info, infoLen, digest,
            heap, devId);
    }
544
545
#if defined(WOLFSSL_TICKET_NONCE_MALLOC) &&                                    \
546
    (!defined(HAVE_FIPS) || (defined(FIPS_VERSION_GE) && FIPS_VERSION_GE(5,3)))
547
    /* Expand data using HMAC, salt and label and info.
548
     * TLS v1.3 defines this function. */
549
    int wc_Tls13_HKDF_Expand_Label_Alloc(byte* okm, word32 okmLen,
550
        const byte* prk, word32 prkLen, const byte* protocol,
551
        word32 protocolLen, const byte* label, word32 labelLen,
552
        const byte* info, word32 infoLen, int digest, void* heap)
553
    {
554
        int    ret = 0;
555
        word32 idx = 0;
556
        size_t len;
557
        byte   *data;
558
559
        (void)heap;
560
        /* okmLen (2) + protocol|label len (1) + info len(1) + protocollen +
561
         * labellen + infolen */
562
        len = 4U + protocolLen + labelLen + infoLen;
563
564
        data = (byte*)XMALLOC(len, heap, DYNAMIC_TYPE_TMP_BUFFER);
565
        if (data == NULL)
566
            return BUFFER_E;
567
568
        /* Output length. */
569
        data[idx++] = (byte)(okmLen >> 8);
570
        data[idx++] = (byte)okmLen;
571
        /* Length of protocol | label. */
572
        data[idx++] = (byte)(protocolLen + labelLen);
573
        /* Protocol */
574
        XMEMCPY(&data[idx], protocol, protocolLen);
575
        idx += protocolLen;
576
        /* Label */
577
        XMEMCPY(&data[idx], label, labelLen);
578
        idx += labelLen;
579
        /* Length of hash of messages */
580
        data[idx++] = (byte)infoLen;
581
        /* Hash of messages */
582
        XMEMCPY(&data[idx], info, infoLen);
583
        idx += infoLen;
584
585
    #ifdef WOLFSSL_CHECK_MEM_ZERO
586
        wc_MemZero_Add("wc_Tls13_HKDF_Expand_Label data", data, idx);
587
    #endif
588
589
#ifdef WOLFSSL_DEBUG_TLS
590
        WOLFSSL_MSG("  PRK");
591
        WOLFSSL_BUFFER(prk, prkLen);
592
        WOLFSSL_MSG("  Info");
593
        WOLFSSL_BUFFER(data, idx);
594
        WOLFSSL_MSG_EX("  Digest %d", digest);
595
#endif
596
597
        ret = wc_HKDF_Expand(digest, prk, prkLen, data, idx, okm, okmLen);
598
599
#ifdef WOLFSSL_DEBUG_TLS
600
        WOLFSSL_MSG("  OKM");
601
        WOLFSSL_BUFFER(okm, okmLen);
602
#endif
603
604
        ForceZero(data, idx);
605
606
    #ifdef WOLFSSL_CHECK_MEM_ZERO
607
        wc_MemZero_Check(data, len);
608
    #endif
609
        XFREE(data, heap, DYNAMIC_TYPE_TMP_BUFFER);
610
        return ret;
611
    }
612
613
#endif
614
/* defined(WOLFSSL_TICKET_NONCE_MALLOC) && (!defined(HAVE_FIPS) ||
615
 *  FIPS_VERSION_GE(5,3)) */
616
617
#endif /* HAVE_HKDF && !NO_HMAC */
618
619
620
#ifdef WOLFSSL_WOLFSSH
621
622
/* hash union */
623
typedef union {
624
#ifndef NO_MD5
625
    wc_Md5 md5;
626
#endif
627
#ifndef NO_SHA
628
    wc_Sha sha;
629
#endif
630
#ifdef WOLFSSL_SHA224
631
    wc_Sha224 sha224;
632
#endif
633
#ifndef NO_SHA256
634
    wc_Sha256 sha256;
635
#endif
636
#ifdef WOLFSSL_SHA384
637
    wc_Sha384 sha384;
638
#endif
639
#ifdef WOLFSSL_SHA512
640
    wc_Sha512 sha512;
641
#endif
642
#ifdef WOLFSSL_SHA3
643
    wc_Sha3 sha3;
644
#endif
645
} _hash;
646
647
static
648
int _HashInit(byte hashId, _hash* hash)
649
{
650
    int ret = WC_NO_ERR_TRACE(BAD_FUNC_ARG);
651
652
    switch (hashId) {
653
    #ifndef NO_SHA
654
        case WC_SHA:
655
            ret = wc_InitSha(&hash->sha);
656
            break;
657
    #endif /* !NO_SHA */
658
659
    #ifndef NO_SHA256
660
        case WC_SHA256:
661
            ret = wc_InitSha256(&hash->sha256);
662
            break;
663
    #endif /* !NO_SHA256 */
664
665
    #ifdef WOLFSSL_SHA384
666
        case WC_SHA384:
667
            ret = wc_InitSha384(&hash->sha384);
668
            break;
669
    #endif /* WOLFSSL_SHA384 */
670
    #ifdef WOLFSSL_SHA512
671
        case WC_SHA512:
672
            ret = wc_InitSha512(&hash->sha512);
673
            break;
674
    #endif /* WOLFSSL_SHA512 */
675
        default:
676
            ret = BAD_FUNC_ARG;
677
            break;
678
    }
679
680
    return ret;
681
}
682
683
static
684
int _HashUpdate(byte hashId, _hash* hash,
685
        const byte* data, word32 dataSz)
686
{
687
    int ret = WC_NO_ERR_TRACE(BAD_FUNC_ARG);
688
689
    switch (hashId) {
690
    #ifndef NO_SHA
691
        case WC_SHA:
692
            ret = wc_ShaUpdate(&hash->sha, data, dataSz);
693
            break;
694
    #endif /* !NO_SHA */
695
696
    #ifndef NO_SHA256
697
        case WC_SHA256:
698
            ret = wc_Sha256Update(&hash->sha256, data, dataSz);
699
            break;
700
    #endif /* !NO_SHA256 */
701
702
    #ifdef WOLFSSL_SHA384
703
        case WC_SHA384:
704
            ret = wc_Sha384Update(&hash->sha384, data, dataSz);
705
            break;
706
    #endif /* WOLFSSL_SHA384 */
707
    #ifdef WOLFSSL_SHA512
708
        case WC_SHA512:
709
            ret = wc_Sha512Update(&hash->sha512, data, dataSz);
710
            break;
711
    #endif /* WOLFSSL_SHA512 */
712
        default:
713
            ret = BAD_FUNC_ARG;
714
            break;
715
    }
716
717
    return ret;
718
}
719
720
static
721
int _HashFinal(byte hashId, _hash* hash, byte* digest)
722
{
723
    int ret = WC_NO_ERR_TRACE(BAD_FUNC_ARG);
724
725
    switch (hashId) {
726
    #ifndef NO_SHA
727
        case WC_SHA:
728
            ret = wc_ShaFinal(&hash->sha, digest);
729
            break;
730
    #endif /* !NO_SHA */
731
732
    #ifndef NO_SHA256
733
        case WC_SHA256:
734
            ret = wc_Sha256Final(&hash->sha256, digest);
735
            break;
736
    #endif /* !NO_SHA256 */
737
738
    #ifdef WOLFSSL_SHA384
739
        case WC_SHA384:
740
            ret = wc_Sha384Final(&hash->sha384, digest);
741
            break;
742
    #endif /* WOLFSSL_SHA384 */
743
    #ifdef WOLFSSL_SHA512
744
        case WC_SHA512:
745
            ret = wc_Sha512Final(&hash->sha512, digest);
746
            break;
747
    #endif /* WOLFSSL_SHA512 */
748
        default:
749
            ret = BAD_FUNC_ARG;
750
            break;
751
    }
752
753
    return ret;
754
}
755
756
static
757
void _HashFree(byte hashId, _hash* hash)
758
{
759
    switch (hashId) {
760
    #ifndef NO_SHA
761
        case WC_SHA:
762
            wc_ShaFree(&hash->sha);
763
            break;
764
    #endif /* !NO_SHA */
765
766
    #ifndef NO_SHA256
767
        case WC_SHA256:
768
            wc_Sha256Free(&hash->sha256);
769
            break;
770
    #endif /* !NO_SHA256 */
771
772
    #ifdef WOLFSSL_SHA384
773
        case WC_SHA384:
774
            wc_Sha384Free(&hash->sha384);
775
            break;
776
    #endif /* WOLFSSL_SHA384 */
777
    #ifdef WOLFSSL_SHA512
778
        case WC_SHA512:
779
            wc_Sha512Free(&hash->sha512);
780
            break;
781
    #endif /* WOLFSSL_SHA512 */
782
    }
783
}
784
785
786
#define LENGTH_SZ 4
787
788
int wc_SSH_KDF(byte hashId, byte keyId, byte* key, word32 keySz,
789
        const byte* k, word32 kSz, const byte* h, word32 hSz,
790
        const byte* sessionId, word32 sessionIdSz)
791
{
792
    word32 blocks, remainder;
793
    _hash hash;
794
    enum wc_HashType enmhashId = (enum wc_HashType)hashId;
795
    byte kPad = 0;
796
    byte pad = 0;
797
    byte kSzFlat[LENGTH_SZ];
798
    word32 digestSz;
799
    int ret;
800
801
    if (key == NULL || keySz == 0 ||
802
        k == NULL || kSz == 0 ||
803
        h == NULL || hSz == 0 ||
804
        sessionId == NULL || sessionIdSz == 0) {
805
806
        return BAD_FUNC_ARG;
807
    }
808
809
    ret = wc_HmacSizeByType((int)enmhashId);
810
    if (ret <= 0) {
811
        return BAD_FUNC_ARG;
812
    }
813
    digestSz = (word32)ret;
814
815
    if (k[0] & 0x80) kPad = 1;
816
    c32toa(kSz + kPad, kSzFlat);
817
818
    blocks = keySz / digestSz;
819
    remainder = keySz % digestSz;
820
821
    ret = _HashInit(enmhashId, &hash);
822
    if (ret == 0)
823
        ret = _HashUpdate(enmhashId, &hash, kSzFlat, LENGTH_SZ);
824
    if (ret == 0 && kPad)
825
        ret = _HashUpdate(enmhashId, &hash, &pad, 1);
826
    if (ret == 0)
827
        ret = _HashUpdate(enmhashId, &hash, k, kSz);
828
    if (ret == 0)
829
        ret = _HashUpdate(enmhashId, &hash, h, hSz);
830
    if (ret == 0)
831
        ret = _HashUpdate(enmhashId, &hash, &keyId, sizeof(keyId));
832
    if (ret == 0)
833
        ret = _HashUpdate(enmhashId, &hash, sessionId, sessionIdSz);
834
835
    if (ret == 0) {
836
        if (blocks == 0) {
837
            if (remainder > 0) {
838
                byte lastBlock[WC_MAX_DIGEST_SIZE];
839
                ret = _HashFinal(enmhashId, &hash, lastBlock);
840
                if (ret == 0)
841
                    XMEMCPY(key, lastBlock, remainder);
842
            }
843
        }
844
        else {
845
            word32 runningKeySz, curBlock;
846
847
            runningKeySz = digestSz;
848
            ret = _HashFinal(enmhashId, &hash, key);
849
850
            for (curBlock = 1; curBlock < blocks; curBlock++) {
851
                ret = _HashInit(enmhashId, &hash);
852
                if (ret != 0) break;
853
                ret = _HashUpdate(enmhashId, &hash, kSzFlat, LENGTH_SZ);
854
                if (ret != 0) break;
855
                if (kPad)
856
                    ret = _HashUpdate(enmhashId, &hash, &pad, 1);
857
                if (ret != 0) break;
858
                ret = _HashUpdate(enmhashId, &hash, k, kSz);
859
                if (ret != 0) break;
860
                ret = _HashUpdate(enmhashId, &hash, h, hSz);
861
                if (ret != 0) break;
862
                ret = _HashUpdate(enmhashId, &hash, key, runningKeySz);
863
                if (ret != 0) break;
864
                ret = _HashFinal(enmhashId, &hash, key + runningKeySz);
865
                if (ret != 0) break;
866
                runningKeySz += digestSz;
867
            }
868
869
            if (remainder > 0) {
870
                byte lastBlock[WC_MAX_DIGEST_SIZE];
871
                if (ret == 0)
872
                    ret = _HashInit(enmhashId, &hash);
873
                if (ret == 0)
874
                    ret = _HashUpdate(enmhashId, &hash, kSzFlat, LENGTH_SZ);
875
                if (ret == 0 && kPad)
876
                    ret = _HashUpdate(enmhashId, &hash, &pad, 1);
877
                if (ret == 0)
878
                    ret = _HashUpdate(enmhashId, &hash, k, kSz);
879
                if (ret == 0)
880
                    ret = _HashUpdate(enmhashId, &hash, h, hSz);
881
                if (ret == 0)
882
                    ret = _HashUpdate(enmhashId, &hash, key, runningKeySz);
883
                if (ret == 0)
884
                    ret = _HashFinal(enmhashId, &hash, lastBlock);
885
                if (ret == 0)
886
                    XMEMCPY(key + runningKeySz, lastBlock, remainder);
887
            }
888
        }
889
    }
890
891
    _HashFree(enmhashId, &hash);
892
893
    return ret;
894
}
895
896
#endif /* WOLFSSL_WOLFSSH */
897
898
#ifdef WC_SRTP_KDF
899
/* Calculate first block to encrypt.
900
 *
901
 * @param [in]  salt     Random value to XOR in.
902
 * @param [in]  saltSz   Size of random value in bytes.
903
 * @param [in]  kdrIdx   Key derivation rate. kdr = 0 when -1, otherwise
904
 *                       kdr = 2^kdrIdx.
905
 * @param [in]  idx      Index value to XOR in.
906
 * @param [in]  idxSz    Size of index value in bytes.
907
 * @param [out] block    First block to encrypt.
908
 */
909
static void wc_srtp_kdf_first_block(const byte* salt, word32 saltSz, int kdrIdx,
910
        const byte* idx, int idxSz, unsigned char* block)
911
{
912
    int i;
913
914
    /* XOR salt into zeroized buffer. */
915
    for (i = 0; i < WC_SRTP_MAX_SALT - (int)saltSz; i++) {
916
        block[i] = 0;
917
    }
918
    XMEMCPY(block + WC_SRTP_MAX_SALT - saltSz, salt, saltSz);
919
    block[WC_SRTP_MAX_SALT] = 0;
920
    /* block[15] is counter. */
921
922
    /* When kdrIdx is -1, don't XOR in index. */
923
    if (kdrIdx >= 0) {
924
        /* Get the number of bits to shift index by. */
925
        word32 bits = kdrIdx & 0x7;
926
        /* Reduce index size by number of bytes to remove. */
927
        idxSz -= kdrIdx >> 3;
928
929
        if ((kdrIdx & 0x7) == 0) {
930
            /* Just XOR in as no bit shifting. */
931
            for (i = 0; i < idxSz; i++) {
932
                block[i + WC_SRTP_MAX_SALT - idxSz] ^= idx[i];
933
            }
934
        }
935
        else {
936
            /* XOR in as bit shifted index. */
937
            block[WC_SRTP_MAX_SALT - idxSz] ^= (byte)(idx[0] >> bits);
938
            for (i = 1; i < idxSz; i++) {
939
                block[i + WC_SRTP_MAX_SALT - idxSz] ^=
940
                    (byte)((idx[i-1] << (8 - bits)) |
941
                           (idx[i+0] >>      bits ));
942
            }
943
        }
944
    }
945
}
946
947
/* Derive a key given the first block.
948
 *
949
 * @param [in, out] block    First block to encrypt. Need label XORed in.
950
 * @param [in]      indexSz  Size of index in bytes to calculate where label is
951
 *                           XORed into.
952
 * @param [in]      label    Label byte that differs for each key.
953
 * @param [out]     key      Derived key.
954
 * @param [in]      keySz    Size of key to derive in bytes.
955
 * @param [in]      aes      AES object to encrypt with.
956
 * @return  0 on success.
957
 */
958
static int wc_srtp_kdf_derive_key(byte* block, int idxSz, byte label,
959
        byte* key, word32 keySz, Aes* aes)
960
{
961
    int i;
962
    int ret = 0;
963
    /* Calculate the number of full blocks needed for derived key. */
964
    int blocks = (int)(keySz / WC_AES_BLOCK_SIZE);
965
966
    /* XOR in label. */
967
    block[WC_SRTP_MAX_SALT - idxSz - 1] ^= label;
968
    for (i = 0; (ret == 0) && (i < blocks); i++) {
969
        /* Set counter. */
970
        block[15] = (byte)i;
971
        /* Encrypt block into key buffer. */
972
        ret = wc_AesEcbEncrypt(aes, key, block, WC_AES_BLOCK_SIZE);
973
        /* Reposition for more derived key. */
974
        key += WC_AES_BLOCK_SIZE;
975
        /* Reduce the count of key bytes required. */
976
        keySz -= WC_AES_BLOCK_SIZE;
977
    }
978
    /* Do any partial blocks. */
979
    if ((ret == 0) && (keySz > 0)) {
980
        byte enc[WC_AES_BLOCK_SIZE];
981
        /* Set counter. */
982
        block[15] = (byte)i;
983
        /* Encrypt block into temporary. */
984
        ret = wc_AesEcbEncrypt(aes, enc, block, WC_AES_BLOCK_SIZE);
985
        if (ret == 0) {
986
            /* Copy into key required amount. */
987
            XMEMCPY(key, enc, keySz);
988
        }
989
    }
990
    /* XOR out label. */
991
    block[WC_SRTP_MAX_SALT - idxSz - 1] ^= label;
992
993
    return ret;
994
}
995
996
/* Derive keys using SRTP KDF algorithm.
997
 *
998
 * SP 800-135 (RFC 3711).
999
 *
1000
 * @param [in]  key      Key to use with encryption.
1001
 * @param [in]  keySz    Size of key in bytes.
1002
 * @param [in]  salt     Random non-secret value.
1003
 * @param [in]  saltSz   Size of random in bytes.
1004
 * @param [in]  kdrIdx   Key derivation rate. kdr = 0 when -1, otherwise
1005
 *                       kdr = 2^kdrIdx.
1006
 * @param [in]  index    Index value to XOR in.
1007
 * @param [out] key1     First key. Label value of 0x00.
1008
 * @param [in]  key1Sz   Size of first key in bytes.
1009
 * @param [out] key2     Second key. Label value of 0x01.
1010
 * @param [in]  key2Sz   Size of second key in bytes.
1011
 * @param [out] key3     Third key. Label value of 0x02.
1012
 * @param [in]  key3Sz   Size of third key in bytes.
1013
 * @return  BAD_FUNC_ARG when key or salt is NULL.
1014
 * @return  BAD_FUNC_ARG when key length is not 16, 24 or 32.
1015
 * @return  BAD_FUNC_ARG when saltSz is larger than 14.
1016
 * @return  BAD_FUNC_ARG when kdrIdx is less than -1 or larger than 24.
1017
 * @return  MEMORY_E on dynamic memory allocation failure.
1018
 * @return  0 on success.
1019
 */
1020
int wc_SRTP_KDF(const byte* key, word32 keySz, const byte* salt, word32 saltSz,
1021
        int kdrIdx, const byte* idx, byte* key1, word32 key1Sz, byte* key2,
1022
        word32 key2Sz, byte* key3, word32 key3Sz)
1023
{
1024
    int ret = 0;
1025
    byte block[WC_AES_BLOCK_SIZE];
1026
#ifdef WOLFSSL_SMALL_STACK
1027
    Aes* aes = NULL;
1028
#else
1029
    Aes aes[1];
1030
#endif
1031
    int aes_inited = 0;
1032
1033
    /* Validate parameters. */
1034
    if ((key == NULL) || (keySz > AES_256_KEY_SIZE) || (salt == NULL) ||
1035
            (saltSz > WC_SRTP_MAX_SALT) || (kdrIdx < -1) || (kdrIdx > 24)) {
1036
        ret = BAD_FUNC_ARG;
1037
    }
1038
1039
#ifdef WOLFSSL_SMALL_STACK
1040
    if (ret == 0) {
1041
        aes = (Aes*)XMALLOC(sizeof(Aes), NULL, DYNAMIC_TYPE_CIPHER);
1042
        if (aes == NULL) {
1043
            ret = MEMORY_E;
1044
        }
1045
    }
1046
#endif
1047
1048
    /* Setup AES object. */
1049
    if (ret == 0) {
1050
        ret = wc_AesInit(aes, NULL, INVALID_DEVID);
1051
    }
1052
    if (ret == 0) {
1053
        aes_inited = 1;
1054
        ret = wc_AesSetKey(aes, key, keySz, NULL, AES_ENCRYPTION);
1055
    }
1056
1057
    /* Calculate first block that can be used in each derivation. */
1058
    if (ret == 0) {
1059
        wc_srtp_kdf_first_block(salt, saltSz, kdrIdx, idx, WC_SRTP_INDEX_LEN,
1060
            block);
1061
    }
1062
1063
    /* Calculate first key if required. */
1064
    if ((ret == 0) && (key1 != NULL)) {
1065
        ret = wc_srtp_kdf_derive_key(block, WC_SRTP_INDEX_LEN,
1066
            WC_SRTP_LABEL_ENCRYPTION, key1, key1Sz, aes);
1067
    }
1068
    /* Calculate second key if required. */
1069
    if ((ret == 0) && (key2 != NULL)) {
1070
        ret = wc_srtp_kdf_derive_key(block, WC_SRTP_INDEX_LEN,
1071
            WC_SRTP_LABEL_MSG_AUTH, key2, key2Sz, aes);
1072
    }
1073
    /* Calculate third key if required. */
1074
    if ((ret == 0) && (key3 != NULL)) {
1075
        ret = wc_srtp_kdf_derive_key(block, WC_SRTP_INDEX_LEN,
1076
            WC_SRTP_LABEL_SALT, key3, key3Sz, aes);
1077
    }
1078
1079
    if (aes_inited)
1080
        wc_AesFree(aes);
1081
#ifdef WOLFSSL_SMALL_STACK
1082
    XFREE(aes, NULL, DYNAMIC_TYPE_CIPHER);
1083
#endif
1084
    return ret;
1085
}
1086
1087
/* Derive keys using SRTCP KDF algorithm.
1088
 *
1089
 * SP 800-135 (RFC 3711).
1090
 *
1091
 * @param [in]  key      Key to use with encryption.
1092
 * @param [in]  keySz    Size of key in bytes.
1093
 * @param [in]  salt     Random non-secret value.
1094
 * @param [in]  saltSz   Size of random in bytes.
1095
 * @param [in]  kdrIdx   Key derivation rate index. kdr = 0 when -1, otherwise
1096
 *                       kdr = 2^kdrIdx. See wc_SRTP_KDF_kdr_to_idx()
1097
 * @param [in]  index    Index value to XOR in.
1098
 * @param [out] key1     First key. Label value of 0x03.
1099
 * @param [in]  key1Sz   Size of first key in bytes.
1100
 * @param [out] key2     Second key. Label value of 0x04.
1101
 * @param [in]  key2Sz   Size of second key in bytes.
1102
 * @param [out] key3     Third key. Label value of 0x05.
1103
 * @param [in]  key3Sz   Size of third key in bytes.
1104
 * @return  BAD_FUNC_ARG when key or salt is NULL.
1105
 * @return  BAD_FUNC_ARG when key length is not 16, 24 or 32.
1106
 * @return  BAD_FUNC_ARG when saltSz is larger than 14.
1107
 * @return  BAD_FUNC_ARG when kdrIdx is less than -1 or larger than 24.
1108
 * @return  MEMORY_E on dynamic memory allocation failure.
1109
 * @return  0 on success.
1110
 */
1111
int wc_SRTCP_KDF_ex(const byte* key, word32 keySz, const byte* salt, word32 saltSz,
1112
        int kdrIdx, const byte* idx, byte* key1, word32 key1Sz, byte* key2,
1113
        word32 key2Sz, byte* key3, word32 key3Sz, int idxLenIndicator)
1114
{
1115
    int ret = 0;
1116
    byte block[WC_AES_BLOCK_SIZE];
1117
#ifdef WOLFSSL_SMALL_STACK
1118
    Aes* aes = NULL;
1119
#else
1120
    Aes aes[1];
1121
#endif
1122
    int aes_inited = 0;
1123
    int idxLen;
1124
1125
    if (idxLenIndicator == WC_SRTCP_32BIT_IDX) {
1126
        idxLen = WC_SRTCP_INDEX_LEN;
1127
    } else if (idxLenIndicator == WC_SRTCP_48BIT_IDX) {
1128
        idxLen = WC_SRTP_INDEX_LEN;
1129
    } else {
1130
        return BAD_FUNC_ARG; /* bad or invalid idxLenIndicator */
1131
    }
1132
1133
    /* Validate parameters. */
1134
    if ((key == NULL) || (keySz > AES_256_KEY_SIZE) || (salt == NULL) ||
1135
            (saltSz > WC_SRTP_MAX_SALT) || (kdrIdx < -1) || (kdrIdx > 24)) {
1136
        ret = BAD_FUNC_ARG;
1137
    }
1138
1139
#ifdef WOLFSSL_SMALL_STACK
1140
    if (ret == 0) {
1141
        aes = (Aes*)XMALLOC(sizeof(Aes), NULL, DYNAMIC_TYPE_CIPHER);
1142
        if (aes == NULL) {
1143
            ret = MEMORY_E;
1144
        }
1145
    }
1146
#endif
1147
1148
    /* Setup AES object. */
1149
    if (ret == 0) {
1150
        ret = wc_AesInit(aes, NULL, INVALID_DEVID);
1151
    }
1152
    if (ret == 0) {
1153
        aes_inited = 1;
1154
        ret = wc_AesSetKey(aes, key, keySz, NULL, AES_ENCRYPTION);
1155
    }
1156
1157
    /* Calculate first block that can be used in each derivation. */
1158
    if (ret == 0) {
1159
        wc_srtp_kdf_first_block(salt, saltSz, kdrIdx, idx, idxLen, block);
1160
    }
1161
1162
    /* Calculate first key if required. */
1163
    if ((ret == 0) && (key1 != NULL)) {
1164
        ret = wc_srtp_kdf_derive_key(block, idxLen,
1165
            WC_SRTCP_LABEL_ENCRYPTION, key1, key1Sz, aes);
1166
    }
1167
    /* Calculate second key if required. */
1168
    if ((ret == 0) && (key2 != NULL)) {
1169
        ret = wc_srtp_kdf_derive_key(block, idxLen,
1170
            WC_SRTCP_LABEL_MSG_AUTH, key2, key2Sz, aes);
1171
    }
1172
    /* Calculate third key if required. */
1173
    if ((ret == 0) && (key3 != NULL)) {
1174
        ret = wc_srtp_kdf_derive_key(block, idxLen,
1175
            WC_SRTCP_LABEL_SALT, key3, key3Sz, aes);
1176
    }
1177
1178
    if (aes_inited)
1179
        wc_AesFree(aes);
1180
#ifdef WOLFSSL_SMALL_STACK
1181
    XFREE(aes, NULL, DYNAMIC_TYPE_CIPHER);
1182
#endif
1183
    return ret;
1184
}
1185
1186
/* Derive keys using SRTCP KDF algorithm with a 32-bit SRTCP index.
 *
 * Wrapper around wc_SRTCP_KDF_ex() that selects WC_SRTCP_32BIT_IDX, the
 * index length expected by many implementations. Parameters and return values
 * are those of wc_SRTCP_KDF_ex().
 */
int wc_SRTCP_KDF(const byte* key, word32 keySz, const byte* salt, word32 saltSz,
        int kdrIdx, const byte* idx, byte* key1, word32 key1Sz, byte* key2,
        word32 key2Sz, byte* key3, word32 key3Sz)
{
    const int idxLenIndicator = WC_SRTCP_32BIT_IDX;

    return wc_SRTCP_KDF_ex(key, keySz, salt, saltSz, kdrIdx, idx, key1,
        key1Sz, key2, key2Sz, key3, key3Sz, idxLenIndicator);
}
1195
/* Derive key with label using SRTP KDF algorithm.
1196
 *
1197
 * SP 800-135 (RFC 3711).
1198
 *
1199
 * @param [in]  key       Key to use with encryption.
1200
 * @param [in]  keySz     Size of key in bytes.
1201
 * @param [in]  salt      Random non-secret value.
1202
 * @param [in]  saltSz    Size of random in bytes.
1203
 * @param [in]  kdrIdx    Key derivation rate index. kdr = 0 when -1, otherwise
1204
 *                        kdr = 2^kdrIdx. See wc_SRTP_KDF_kdr_to_idx()
1205
 * @param [in]  index     Index value to XOR in.
1206
 * @param [in]  label     Label to use when deriving key.
1207
 * @param [out] outKey    Derived key.
1208
 * @param [in]  outKeySz  Size of derived key in bytes.
1209
 * @return  BAD_FUNC_ARG when key, salt or outKey is NULL.
1210
 * @return  BAD_FUNC_ARG when key length is not 16, 24 or 32.
1211
 * @return  BAD_FUNC_ARG when saltSz is larger than 14.
1212
 * @return  BAD_FUNC_ARG when kdrIdx is less than -1 or larger than 24.
1213
 * @return  MEMORY_E on dynamic memory allocation failure.
1214
 * @return  0 on success.
1215
 */
1216
int wc_SRTP_KDF_label(const byte* key, word32 keySz, const byte* salt,
1217
        word32 saltSz, int kdrIdx, const byte* idx, byte label, byte* outKey,
1218
        word32 outKeySz)
1219
{
1220
    int ret = 0;
1221
    byte block[WC_AES_BLOCK_SIZE];
1222
#ifdef WOLFSSL_SMALL_STACK
1223
    Aes* aes = NULL;
1224
#else
1225
    Aes aes[1];
1226
#endif
1227
    int aes_inited = 0;
1228
1229
    /* Validate parameters. */
1230
    if ((key == NULL) || (keySz > AES_256_KEY_SIZE) || (salt == NULL) ||
1231
            (saltSz > WC_SRTP_MAX_SALT) || (kdrIdx < -1) || (kdrIdx > 24) ||
1232
            (outKey == NULL)) {
1233
        ret = BAD_FUNC_ARG;
1234
    }
1235
1236
#ifdef WOLFSSL_SMALL_STACK
1237
    if (ret == 0) {
1238
        aes = (Aes*)XMALLOC(sizeof(Aes), NULL, DYNAMIC_TYPE_CIPHER);
1239
        if (aes == NULL) {
1240
            ret = MEMORY_E;
1241
        }
1242
    }
1243
#endif
1244
1245
    /* Setup AES object. */
1246
    if (ret == 0) {
1247
        ret = wc_AesInit(aes, NULL, INVALID_DEVID);
1248
    }
1249
    if (ret == 0) {
1250
        aes_inited = 1;
1251
        ret = wc_AesSetKey(aes, key, keySz, NULL, AES_ENCRYPTION);
1252
    }
1253
1254
    /* Calculate first block that can be used in each derivation. */
1255
    if (ret == 0) {
1256
        wc_srtp_kdf_first_block(salt, saltSz, kdrIdx, idx, WC_SRTP_INDEX_LEN,
1257
            block);
1258
    }
1259
    if (ret == 0) {
1260
        /* Calculate key. */
1261
        ret = wc_srtp_kdf_derive_key(block, WC_SRTP_INDEX_LEN, label, outKey,
1262
            outKeySz, aes);
1263
    }
1264
1265
    if (aes_inited)
1266
        wc_AesFree(aes);
1267
#ifdef WOLFSSL_SMALL_STACK
1268
    XFREE(aes, NULL, DYNAMIC_TYPE_CIPHER);
1269
#endif
1270
    return ret;
1271
1272
}
1273
1274
/* Derive key with label using SRTCP KDF algorithm.
1275
 *
1276
 * SP 800-135 (RFC 3711).
1277
 *
1278
 * @param [in]  key       Key to use with encryption.
1279
 * @param [in]  keySz     Size of key in bytes.
1280
 * @param [in]  salt      Random non-secret value.
1281
 * @param [in]  saltSz    Size of random in bytes.
1282
 * @param [in]  kdrIdx    Key derivation rate index. kdr = 0 when -1, otherwise
1283
 *                        kdr = 2^kdrIdx. See wc_SRTP_KDF_kdr_to_idx()
1284
 * @param [in]  index     Index value to XOR in.
1285
 * @param [in]  label     Label to use when deriving key.
1286
 * @param [out] outKey    Derived key.
1287
 * @param [in]  outKeySz  Size of derived key in bytes.
1288
 * @return  BAD_FUNC_ARG when key, salt or outKey is NULL.
1289
 * @return  BAD_FUNC_ARG when key length is not 16, 24 or 32.
1290
 * @return  BAD_FUNC_ARG when saltSz is larger than 14.
1291
 * @return  BAD_FUNC_ARG when kdrIdx is less than -1 or larger than 24.
1292
 * @return  MEMORY_E on dynamic memory allocation failure.
1293
 * @return  0 on success.
1294
 */
1295
int wc_SRTCP_KDF_label(const byte* key, word32 keySz, const byte* salt,
1296
        word32 saltSz, int kdrIdx, const byte* idx, byte label, byte* outKey,
1297
        word32 outKeySz)
1298
{
1299
    int ret = 0;
1300
    byte block[WC_AES_BLOCK_SIZE];
1301
#ifdef WOLFSSL_SMALL_STACK
1302
    Aes* aes = NULL;
1303
#else
1304
    Aes aes[1];
1305
#endif
1306
    int aes_inited = 0;
1307
1308
    /* Validate parameters. */
1309
    if ((key == NULL) || (keySz > AES_256_KEY_SIZE) || (salt == NULL) ||
1310
            (saltSz > WC_SRTP_MAX_SALT) || (kdrIdx < -1) || (kdrIdx > 24) ||
1311
            (outKey == NULL)) {
1312
        ret = BAD_FUNC_ARG;
1313
    }
1314
1315
#ifdef WOLFSSL_SMALL_STACK
1316
    if (ret == 0) {
1317
        aes = (Aes*)XMALLOC(sizeof(Aes), NULL, DYNAMIC_TYPE_CIPHER);
1318
        if (aes == NULL) {
1319
            ret = MEMORY_E;
1320
        }
1321
    }
1322
#endif
1323
1324
    /* Setup AES object. */
1325
    if (ret == 0) {
1326
        ret = wc_AesInit(aes, NULL, INVALID_DEVID);
1327
    }
1328
    if (ret == 0) {
1329
        aes_inited = 1;
1330
        ret = wc_AesSetKey(aes, key, keySz, NULL, AES_ENCRYPTION);
1331
    }
1332
1333
    /* Calculate first block that can be used in each derivation. */
1334
    if (ret == 0) {
1335
        wc_srtp_kdf_first_block(salt, saltSz, kdrIdx, idx, WC_SRTCP_INDEX_LEN,
1336
            block);
1337
    }
1338
    if (ret == 0) {
1339
        /* Calculate key. */
1340
        ret = wc_srtp_kdf_derive_key(block, WC_SRTCP_INDEX_LEN, label, outKey,
1341
            outKeySz, aes);
1342
    }
1343
1344
    if (aes_inited)
1345
        wc_AesFree(aes);
1346
#ifdef WOLFSSL_SMALL_STACK
1347
    XFREE(aes, NULL, DYNAMIC_TYPE_CIPHER);
1348
#endif
1349
    return ret;
1350
1351
}
1352
1353
/* Converts a kdr value to an index to use in SRTP/SRTCP KDF API.
1354
 *
1355
 * @param [in] kdr  Key derivation rate to convert.
1356
 * @return  Key derivation rate as an index.
1357
 */
1358
int wc_SRTP_KDF_kdr_to_idx(word32 kdr)
1359
{
1360
    int idx = -1;
1361
1362
    /* Keep shifting value down and incrementing index until top bit is gone. */
1363
    while (kdr != 0) {
1364
        kdr >>= 1;
1365
        idx++;
1366
    }
1367
1368
    /* Index of top bit set. */
1369
    return idx;
1370
}
1371
#endif /* WC_SRTP_KDF */
1372
1373
#ifdef WC_KDF_NIST_SP_800_56C
1374
/* Compute one SP 800-56C one-step KDF block:
 * H(counter || z || fixedInfo), with counter encoded as 4 big-endian bytes.
 *
 * @param [in]  z            Shared secret.
 * @param [in]  zSz          Size of z in bytes.
 * @param [in]  counter      Iteration counter.
 * @param [in]  fixedInfo    Fixed information; only hashed when fixedInfoSz > 0.
 * @param [in]  fixedInfoSz  Size of fixedInfo in bytes.
 * @param [in]  hashType     Hash algorithm.
 * @param [out] output       Buffer receiving one full digest.
 * @return  0 on success, otherwise the error from the hash API.
 */
static int wc_KDA_KDF_iteration(const byte* z, word32 zSz, word32 counter,
    const byte* fixedInfo, word32 fixedInfoSz, enum wc_HashType hashType,
    byte* output)
{
    wc_HashAlg hash;
    byte ctr[4];
    int err;

    err = wc_HashInit(&hash, hashType);
    if (err == 0) {
        /* Only free the hash object once it has been initialized. */
        c32toa(counter, ctr);
        err = wc_HashUpdate(&hash, hashType, ctr, (word32)sizeof(ctr));
        if (err == 0) {
            err = wc_HashUpdate(&hash, hashType, z, zSz);
        }
        if ((err == 0) && (fixedInfoSz > 0)) {
            err = wc_HashUpdate(&hash, hashType, fixedInfo, fixedInfoSz);
        }
        if (err == 0) {
            err = wc_HashFinal(&hash, hashType, output);
        }
        wc_HashFree(&hash, hashType);
    }

    return err;
}
1399
1400
/**
1401
 * \brief Performs the single-step key derivation function (KDF) as specified in
1402
 * SP800-56C option 1.
1403
 *
1404
 * \param [in] z The input keying material.
1405
 * \param [in] zSz The size of the input keying material.
1406
 * \param [in] fixedInfo The fixed information to be included in the KDF.
1407
 * \param [in] fixedInfoSz The size of the fixed information.
1408
 * \param [in] derivedSecretSz The desired size of the derived secret.
1409
 * \param [in] hashType The hash algorithm to be used in the KDF.
1410
 * \param [out] output The buffer to store the derived secret.
1411
 * \param [in] outputSz The size of the output buffer.
1412
 *
1413
 * \return 0 if the KDF operation is successful.
1414
 * \return BAD_FUNC_ARG if the input parameters are invalid.
1415
 * \return negative error code if the KDF operation fails.
1416
 */
1417
int wc_KDA_KDF_onestep(const byte* z, word32 zSz, const byte* fixedInfo,
1418
    word32 fixedInfoSz, word32 derivedSecretSz, enum wc_HashType hashType,
1419
    byte* output, word32 outputSz)
1420
{
1421
    byte hashTempBuf[WC_MAX_DIGEST_SIZE];
1422
    word32 counter, outIdx;
1423
    int hashOutSz;
1424
    int ret;
1425
1426
    if (output == NULL || outputSz < derivedSecretSz)
1427
        return BAD_FUNC_ARG;
1428
    if (z == NULL || zSz == 0 || (fixedInfoSz > 0 && fixedInfo == NULL))
1429
        return BAD_FUNC_ARG;
1430
    if (derivedSecretSz == 0)
1431
        return BAD_FUNC_ARG;
1432
1433
    hashOutSz = wc_HashGetDigestSize(hashType);
1434
    if (hashOutSz == WC_NO_ERR_TRACE(HASH_TYPE_E))
1435
        return BAD_FUNC_ARG;
1436
1437
    /* According to SP800_56C, table 1, the max input size (max_H_inputBits)
1438
     * depends on the HASH algo. The smaller value in the table is (2**64-1)/8.
1439
     * This is larger than the possible length using word32 integers. */
1440
1441
    counter = 1;
1442
    outIdx = 0;
1443
    ret = 0;
1444
1445
    /* According to SP800_56C the number of iterations shall not be greater than
1446
     * 2**32-1. This is not possible using word32 integers.*/
1447
    while (outIdx + hashOutSz <= derivedSecretSz) {
1448
        ret = wc_KDA_KDF_iteration(z, zSz, counter, fixedInfo, fixedInfoSz,
1449
            hashType, output + outIdx);
1450
        if (ret != 0)
1451
            break;
1452
        counter++;
1453
        outIdx += hashOutSz;
1454
    }
1455
1456
    if (ret == 0 && outIdx < derivedSecretSz) {
1457
        ret = wc_KDA_KDF_iteration(z, zSz, counter, fixedInfo, fixedInfoSz,
1458
            hashType, hashTempBuf);
1459
        if (ret == 0) {
1460
            XMEMCPY(output + outIdx, hashTempBuf, derivedSecretSz - outIdx);
1461
        }
1462
        ForceZero(hashTempBuf, hashOutSz);
1463
    }
1464
1465
    if (ret != 0) {
1466
        ForceZero(output, derivedSecretSz);
1467
    }
1468
1469
    return ret;
1470
}
1471
#endif /* WC_KDF_NIST_SP_800_56C */
1472
1473
#endif /* NO_KDF */