Coverage Report

Created: 2024-05-20 06:11

/src/FreeRDP/libfreerdp/core/rdstls.c
Line
Count
Source (jump to first uncovered line)
1
/**
2
 * FreeRDP: A Remote Desktop Protocol Implementation
3
 * RDSTLS Security protocol
4
 *
5
 * Copyright 2023 Joan Torres <joan.torres@suse.com>
6
 *
7
 * Licensed under the Apache License, Version 2.0 (the "License");
8
 * you may not use this file except in compliance with the License.
9
 * You may obtain a copy of the License at
10
 *
11
 *     http://www.apache.org/licenses/LICENSE-2.0
12
 *
13
 * Unless required by applicable law or agreed to in writing, software
14
 * distributed under the License is distributed on an "AS IS" BASIS,
15
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
 * See the License for the specific language governing permissions and
17
 * limitations under the License.
18
 */
19
20
#include <freerdp/config.h>
21
22
#include "settings.h"
23
24
#include <freerdp/log.h>
25
#include <freerdp/error.h>
26
#include <freerdp/settings.h>
27
28
#include <winpr/assert.h>
29
#include <winpr/stream.h>
30
#include <winpr/wlog.h>
31
32
#include "rdstls.h"
33
#include "transport.h"
34
#include "utils.h"
35
36
0
#define RDSTLS_VERSION_1 0x01
37
38
0
#define RDSTLS_TYPE_CAPABILITIES 0x01
39
0
#define RDSTLS_TYPE_AUTHREQ 0x02
40
0
#define RDSTLS_TYPE_AUTHRSP 0x04
41
42
0
#define RDSTLS_DATA_CAPABILITIES 0x01
43
0
#define RDSTLS_DATA_PASSWORD_CREDS 0x01
44
0
#define RDSTLS_DATA_AUTORECONNECT_COOKIE 0x02
45
0
#define RDSTLS_DATA_RESULT_CODE 0x01
46
47
typedef enum
48
{
49
  RDSTLS_STATE_INITIAL,
50
  RDSTLS_STATE_CAPABILITIES,
51
  RDSTLS_STATE_AUTH_REQ,
52
  RDSTLS_STATE_AUTH_RSP,
53
  RDSTLS_STATE_FINAL,
54
} RDSTLS_STATE;
55
56
struct rdp_rdstls
57
{
58
  BOOL server;
59
  RDSTLS_STATE state;
60
  rdpContext* context;
61
  rdpTransport* transport;
62
63
  UINT32 resultCode;
64
  wLog* log;
65
};
66
67
/**
68
 * Create new RDSTLS state machine.
69
 *
70
 * @param context A pointer to the rdp context to use
71
 *
72
 * @return new RDSTLS state machine.
73
 */
74
75
rdpRdstls* rdstls_new(rdpContext* context, rdpTransport* transport)
76
0
{
77
0
  WINPR_ASSERT(context);
78
0
  WINPR_ASSERT(transport);
79
80
0
  rdpSettings* settings = context->settings;
81
0
  WINPR_ASSERT(settings);
82
83
0
  rdpRdstls* rdstls = (rdpRdstls*)calloc(1, sizeof(rdpRdstls));
84
85
0
  if (!rdstls)
86
0
    return NULL;
87
0
  rdstls->log = WLog_Get(FREERDP_TAG("core.rdstls"));
88
0
  rdstls->context = context;
89
0
  rdstls->transport = transport;
90
0
  rdstls->server = settings->ServerMode;
91
92
0
  rdstls->state = RDSTLS_STATE_INITIAL;
93
94
0
  return rdstls;
95
0
}
96
97
/**
98
 * Free RDSTLS state machine.
99
 * @param rdstls The RDSTLS instance to free
100
 */
101
102
void rdstls_free(rdpRdstls* rdstls)
103
0
{
104
0
  free(rdstls);
105
0
}
106
107
static const char* rdstls_get_state_str(RDSTLS_STATE state)
108
0
{
109
0
  switch (state)
110
0
  {
111
0
    case RDSTLS_STATE_INITIAL:
112
0
      return "RDSTLS_STATE_INITIAL";
113
0
    case RDSTLS_STATE_CAPABILITIES:
114
0
      return "RDSTLS_STATE_CAPABILITIES";
115
0
    case RDSTLS_STATE_AUTH_REQ:
116
0
      return "RDSTLS_STATE_AUTH_REQ";
117
0
    case RDSTLS_STATE_AUTH_RSP:
118
0
      return "RDSTLS_STATE_AUTH_RSP";
119
0
    case RDSTLS_STATE_FINAL:
120
0
      return "RDSTLS_STATE_FINAL";
121
0
    default:
122
0
      return "UNKNOWN";
123
0
  }
124
0
}
125
126
static RDSTLS_STATE rdstls_get_state(rdpRdstls* rdstls)
127
0
{
128
0
  WINPR_ASSERT(rdstls);
129
0
  return rdstls->state;
130
0
}
131
132
static BOOL check_transition(wLog* log, RDSTLS_STATE current, RDSTLS_STATE expected,
133
                             RDSTLS_STATE requested)
134
0
{
135
0
  if (requested != expected)
136
0
  {
137
0
    WLog_Print(log, WLOG_ERROR,
138
0
               "Unexpected rdstls state transition from %s [%d] to %s [%d], expected %s [%d]",
139
0
               rdstls_get_state_str(current), current, rdstls_get_state_str(requested),
140
0
               requested, rdstls_get_state_str(expected), expected);
141
0
    return FALSE;
142
0
  }
143
0
  return TRUE;
144
0
}
145
146
static BOOL rdstls_set_state(rdpRdstls* rdstls, RDSTLS_STATE state)
147
0
{
148
0
  BOOL rc = FALSE;
149
0
  WINPR_ASSERT(rdstls);
150
151
0
  WLog_Print(rdstls->log, WLOG_DEBUG, "-- %s\t--> %s", rdstls_get_state_str(rdstls->state),
152
0
             rdstls_get_state_str(state));
153
154
0
  switch (rdstls->state)
155
0
  {
156
0
    case RDSTLS_STATE_INITIAL:
157
0
      rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state);
158
0
      break;
159
0
    case RDSTLS_STATE_CAPABILITIES:
160
0
      rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state);
161
0
      break;
162
0
    case RDSTLS_STATE_AUTH_REQ:
163
0
      rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state);
164
0
      break;
165
0
    case RDSTLS_STATE_AUTH_RSP:
166
0
      rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_FINAL, state);
167
0
      break;
168
0
    case RDSTLS_STATE_FINAL:
169
0
      rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state);
170
0
      break;
171
0
    default:
172
0
      WLog_Print(rdstls->log, WLOG_ERROR,
173
0
                 "Invalid rdstls state %s [%d], requested transition to %s [%d]",
174
0
                 rdstls_get_state_str(rdstls->state), rdstls->state,
175
0
                 rdstls_get_state_str(state), state);
176
0
      break;
177
0
  }
178
0
  if (rc)
179
0
    rdstls->state = state;
180
181
0
  return rc;
182
0
}
183
184
static BOOL rdstls_write_capabilities(rdpRdstls* rdstls, wStream* s)
185
0
{
186
0
  if (!Stream_EnsureRemainingCapacity(s, 6))
187
0
    return FALSE;
188
189
0
  Stream_Write_UINT16(s, RDSTLS_TYPE_CAPABILITIES);
190
0
  Stream_Write_UINT16(s, RDSTLS_DATA_CAPABILITIES);
191
0
  Stream_Write_UINT16(s, RDSTLS_VERSION_1);
192
193
0
  return TRUE;
194
0
}
195
196
static SSIZE_T rdstls_write_string(wStream* s, const char* str)
197
0
{
198
0
  const size_t pos = Stream_GetPosition(s);
199
200
0
  if (!Stream_EnsureRemainingCapacity(s, 2))
201
0
    return -1;
202
203
0
  if (!str)
204
0
  {
205
    /* Write unicode null */
206
0
    Stream_Write_UINT16(s, 2);
207
0
    if (!Stream_EnsureRemainingCapacity(s, 2))
208
0
      return -1;
209
210
0
    Stream_Write_UINT16(s, 0);
211
0
    return (SSIZE_T)(Stream_GetPosition(s) - pos);
212
0
  }
213
214
0
  const size_t length = (strlen(str) + 1);
215
216
0
  Stream_Write_UINT16(s, (UINT16)length * sizeof(WCHAR));
217
218
0
  if (!Stream_EnsureRemainingCapacity(s, length * sizeof(WCHAR)))
219
0
    return -1;
220
221
0
  if (Stream_Write_UTF16_String_From_UTF8(s, length, str, length, TRUE) < 0)
222
0
    return -1;
223
224
0
  return (SSIZE_T)(Stream_GetPosition(s) - pos);
225
0
}
226
227
static BOOL rdstls_write_data(wStream* s, UINT32 length, const BYTE* data)
228
0
{
229
0
  WINPR_ASSERT(data || (length == 0));
230
231
0
  if (!Stream_EnsureRemainingCapacity(s, 2))
232
0
    return FALSE;
233
234
0
  Stream_Write_UINT16(s, length);
235
236
0
  if (!Stream_EnsureRemainingCapacity(s, length))
237
0
    return FALSE;
238
239
0
  Stream_Write(s, data, length);
240
241
0
  return TRUE;
242
0
}
243
244
static BOOL rdstls_write_authentication_request_with_password(rdpRdstls* rdstls, wStream* s)
245
0
{
246
0
  rdpSettings* settings = rdstls->context->settings;
247
0
  WINPR_ASSERT(settings);
248
249
0
  if (!Stream_EnsureRemainingCapacity(s, 4))
250
0
    return FALSE;
251
252
0
  Stream_Write_UINT16(s, RDSTLS_TYPE_AUTHREQ);
253
0
  Stream_Write_UINT16(s, RDSTLS_DATA_PASSWORD_CREDS);
254
255
0
  if (!rdstls_write_data(s, settings->RedirectionGuidLength, settings->RedirectionGuid))
256
0
    return FALSE;
257
258
0
  if (rdstls_write_string(s, settings->Username) < 0)
259
0
    return FALSE;
260
261
0
  if (rdstls_write_string(s, settings->Domain) < 0)
262
0
    return FALSE;
263
264
0
  if (!rdstls_write_data(s, settings->RedirectionPasswordLength, settings->RedirectionPassword))
265
0
    return FALSE;
266
267
0
  return TRUE;
268
0
}
269
270
static BOOL rdstls_write_authentication_request_with_cookie(rdpRdstls* rdstls, wStream* s)
271
0
{
272
  // TODO
273
0
  return FALSE;
274
0
}
275
276
static BOOL rdstls_write_authentication_response(rdpRdstls* rdstls, wStream* s)
277
0
{
278
0
  if (!Stream_EnsureRemainingCapacity(s, 8))
279
0
    return FALSE;
280
281
0
  Stream_Write_UINT16(s, RDSTLS_TYPE_AUTHRSP);
282
0
  Stream_Write_UINT16(s, RDSTLS_DATA_RESULT_CODE);
283
0
  Stream_Write_UINT32(s, rdstls->resultCode);
284
285
0
  return TRUE;
286
0
}
287
288
static BOOL rdstls_process_capabilities(rdpRdstls* rdstls, wStream* s)
289
0
{
290
0
  UINT16 dataType = 0;
291
0
  UINT16 supportedVersions = 0;
292
293
0
  if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 4))
294
0
    return FALSE;
295
296
0
  Stream_Read_UINT16(s, dataType);
297
0
  if (dataType != RDSTLS_DATA_CAPABILITIES)
298
0
  {
299
0
    WLog_Print(rdstls->log, WLOG_ERROR,
300
0
               "received invalid DataType=0x%04" PRIX16 ", expected 0x%04" PRIX16, dataType,
301
0
               RDSTLS_DATA_CAPABILITIES);
302
0
    return FALSE;
303
0
  }
304
305
0
  Stream_Read_UINT16(s, supportedVersions);
306
0
  if ((supportedVersions & RDSTLS_VERSION_1) == 0)
307
0
  {
308
0
    WLog_Print(rdstls->log, WLOG_ERROR,
309
0
               "received invalid supportedVersions=0x%04" PRIX16 ", expected 0x%04" PRIX16,
310
0
               supportedVersions, RDSTLS_VERSION_1);
311
0
    return FALSE;
312
0
  }
313
314
0
  return TRUE;
315
0
}
316
317
static BOOL rdstls_read_unicode_string(wLog* log, wStream* s, char** str)
318
0
{
319
0
  UINT16 length = 0;
320
321
0
  WINPR_ASSERT(str);
322
323
0
  if (!Stream_CheckAndLogRequiredLengthWLog(log, s, 2))
324
0
    return FALSE;
325
326
0
  Stream_Read_UINT16(s, length);
327
328
0
  if (!Stream_CheckAndLogRequiredLengthWLog(log, s, length))
329
0
    return FALSE;
330
331
0
  if (length <= 2)
332
0
  {
333
0
    Stream_Seek(s, length);
334
0
    return TRUE;
335
0
  }
336
337
0
  *str = Stream_Read_UTF16_String_As_UTF8(s, length / sizeof(WCHAR), NULL);
338
0
  if (!*str)
339
0
    return FALSE;
340
341
0
  return TRUE;
342
0
}
343
344
static BOOL rdstls_read_data(wLog* log, wStream* s, UINT16* pLength, const BYTE** pData)
345
0
{
346
0
  UINT16 length = 0;
347
348
0
  WINPR_ASSERT(pLength);
349
0
  WINPR_ASSERT(pData);
350
351
0
  *pData = NULL;
352
0
  *pLength = 0;
353
0
  if (!Stream_CheckAndLogRequiredLengthWLog(log, s, 2))
354
0
    return FALSE;
355
356
0
  Stream_Read_UINT16(s, length);
357
358
0
  if (!Stream_CheckAndLogRequiredLengthWLog(log, s, length))
359
0
    return FALSE;
360
361
0
  if (length <= 2)
362
0
  {
363
0
    Stream_Seek(s, length);
364
0
    return TRUE;
365
0
  }
366
367
0
  *pData = Stream_ConstPointer(s);
368
0
  *pLength = length;
369
0
  return Stream_SafeSeek(s, length);
370
0
}
371
372
static BOOL rdstls_cmp_data(wLog* log, const char* field, const BYTE* serverData,
373
                            const UINT32 serverDataLength, const BYTE* clientData,
374
                            const UINT16 clientDataLength)
375
0
{
376
0
  if (serverDataLength > 0)
377
0
  {
378
0
    if (clientDataLength == 0)
379
0
    {
380
0
      WLog_Print(log, WLOG_ERROR, "expected %s", field);
381
0
      return FALSE;
382
0
    }
383
384
0
    if (serverDataLength > UINT16_MAX || serverDataLength != clientDataLength ||
385
0
        memcmp(serverData, clientData, serverDataLength) != 0)
386
0
    {
387
0
      WLog_Print(log, WLOG_ERROR, "%s verification failed", field);
388
0
      return FALSE;
389
0
    }
390
0
  }
391
392
0
  return TRUE;
393
0
}
394
395
static BOOL rdstls_cmp_str(wLog* log, const char* field, const char* serverStr,
396
                           const char* clientStr)
397
0
{
398
0
  if (!utils_str_is_empty(serverStr))
399
0
  {
400
0
    if (utils_str_is_empty(clientStr))
401
0
    {
402
0
      WLog_Print(log, WLOG_ERROR, "expected %s", field);
403
0
      return FALSE;
404
0
    }
405
406
0
    if (strcmp(serverStr, clientStr) != 0)
407
0
    {
408
0
      WLog_Print(log, WLOG_ERROR, "%s verification failed", field);
409
0
      return FALSE;
410
0
    }
411
0
  }
412
413
0
  return TRUE;
414
0
}
415
416
static BOOL rdstls_process_authentication_request_with_password(rdpRdstls* rdstls, wStream* s)
417
0
{
418
0
  BOOL rc = FALSE;
419
420
0
  const BYTE* clientRedirectionGuid = NULL;
421
0
  UINT16 clientRedirectionGuidLength = 0;
422
0
  char* clientPassword = NULL;
423
0
  char* clientUsername = NULL;
424
0
  char* clientDomain = NULL;
425
426
0
  const BYTE* serverRedirectionGuid = NULL;
427
0
  UINT16 serverRedirectionGuidLength = 0;
428
0
  const char* serverPassword = NULL;
429
0
  const char* serverUsername = NULL;
430
0
  const char* serverDomain = NULL;
431
432
0
  rdpSettings* settings = rdstls->context->settings;
433
0
  WINPR_ASSERT(settings);
434
435
0
  if (!rdstls_read_data(rdstls->log, s, &clientRedirectionGuidLength, &clientRedirectionGuid))
436
0
    goto fail;
437
438
0
  if (!rdstls_read_unicode_string(rdstls->log, s, &clientUsername))
439
0
    goto fail;
440
441
0
  if (!rdstls_read_unicode_string(rdstls->log, s, &clientDomain))
442
0
    goto fail;
443
444
0
  if (!rdstls_read_unicode_string(rdstls->log, s, &clientPassword))
445
0
    goto fail;
446
447
0
  serverRedirectionGuid = freerdp_settings_get_pointer(settings, FreeRDP_RedirectionGuid);
448
0
  serverRedirectionGuidLength =
449
0
      freerdp_settings_get_uint32(settings, FreeRDP_RedirectionGuidLength);
450
0
  serverUsername = freerdp_settings_get_string(settings, FreeRDP_Username);
451
0
  serverDomain = freerdp_settings_get_string(settings, FreeRDP_Domain);
452
0
  serverPassword = freerdp_settings_get_string(settings, FreeRDP_Password);
453
454
0
  rdstls->resultCode = ERROR_SUCCESS;
455
456
0
  if (!rdstls_cmp_data(rdstls->log, "RedirectionGuid", serverRedirectionGuid,
457
0
                       serverRedirectionGuidLength, clientRedirectionGuid,
458
0
                       clientRedirectionGuidLength))
459
0
    rdstls->resultCode = ERROR_LOGON_FAILURE;
460
461
0
  if (!rdstls_cmp_str(rdstls->log, "UserName", serverUsername, clientUsername))
462
0
    rdstls->resultCode = ERROR_LOGON_FAILURE;
463
464
0
  if (!rdstls_cmp_str(rdstls->log, "Domain", serverDomain, clientDomain))
465
0
    rdstls->resultCode = ERROR_LOGON_FAILURE;
466
467
0
  if (!rdstls_cmp_str(rdstls->log, "Password", serverPassword, clientPassword))
468
0
    rdstls->resultCode = ERROR_LOGON_FAILURE;
469
470
0
  rc = TRUE;
471
0
fail:
472
0
  return rc;
473
0
}
474
475
static BOOL rdstls_process_authentication_request_with_cookie(rdpRdstls* rdstls, wStream* s)
476
0
{
477
  // TODO
478
0
  return FALSE;
479
0
}
480
481
static BOOL rdstls_process_authentication_request(rdpRdstls* rdstls, wStream* s)
482
0
{
483
0
  UINT16 dataType = 0;
484
485
0
  if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 2))
486
0
    return FALSE;
487
488
0
  Stream_Read_UINT16(s, dataType);
489
0
  switch (dataType)
490
0
  {
491
0
    case RDSTLS_DATA_PASSWORD_CREDS:
492
0
      if (!rdstls_process_authentication_request_with_password(rdstls, s))
493
0
        return FALSE;
494
0
      break;
495
0
    case RDSTLS_DATA_AUTORECONNECT_COOKIE:
496
0
      if (!rdstls_process_authentication_request_with_cookie(rdstls, s))
497
0
        return FALSE;
498
0
      break;
499
0
    default:
500
0
      WLog_Print(rdstls->log, WLOG_ERROR,
501
0
                 "received invalid DataType=0x%04" PRIX16 ", expected 0x%04" PRIX16
502
0
                 " or 0x%04" PRIX16,
503
0
                 dataType, RDSTLS_DATA_PASSWORD_CREDS, RDSTLS_DATA_AUTORECONNECT_COOKIE);
504
0
      return FALSE;
505
0
  }
506
507
0
  return TRUE;
508
0
}
509
510
static BOOL rdstls_process_authentication_response(rdpRdstls* rdstls, wStream* s)
511
0
{
512
0
  UINT16 dataType = 0;
513
0
  UINT32 resultCode = 0;
514
515
0
  if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 6))
516
0
    return FALSE;
517
518
0
  Stream_Read_UINT16(s, dataType);
519
0
  if (dataType != RDSTLS_DATA_RESULT_CODE)
520
0
  {
521
0
    WLog_Print(rdstls->log, WLOG_ERROR,
522
0
               "received invalid DataType=0x%04" PRIX16 ", expected 0x%04" PRIX16, dataType,
523
0
               RDSTLS_DATA_RESULT_CODE);
524
0
    return FALSE;
525
0
  }
526
527
0
  Stream_Read_UINT32(s, resultCode);
528
0
  if (resultCode != ERROR_SUCCESS)
529
0
  {
530
0
    WLog_Print(rdstls->log, WLOG_ERROR, "resultCode: %s [0x%08" PRIX32 "] %s",
531
0
               freerdp_get_last_error_name(resultCode), resultCode,
532
0
               freerdp_get_last_error_string(resultCode));
533
0
    return FALSE;
534
0
  }
535
536
0
  return TRUE;
537
0
}
538
539
static BOOL rdstls_send(rdpTransport* transport, wStream* s, void* extra)
540
0
{
541
0
  rdpRdstls* rdstls = (rdpRdstls*)extra;
542
0
  rdpSettings* settings = NULL;
543
544
0
  WINPR_ASSERT(transport);
545
0
  WINPR_ASSERT(s);
546
0
  WINPR_ASSERT(rdstls);
547
548
0
  settings = rdstls->context->settings;
549
0
  WINPR_ASSERT(settings);
550
551
0
  if (!Stream_EnsureRemainingCapacity(s, 2))
552
0
    return FALSE;
553
554
0
  Stream_Write_UINT16(s, RDSTLS_VERSION_1);
555
556
0
  const RDSTLS_STATE state = rdstls_get_state(rdstls);
557
0
  switch (state)
558
0
  {
559
0
    case RDSTLS_STATE_CAPABILITIES:
560
0
      if (!rdstls_write_capabilities(rdstls, s))
561
0
        return FALSE;
562
0
      break;
563
0
    case RDSTLS_STATE_AUTH_REQ:
564
0
      if (settings->RedirectionFlags & LB_PASSWORD_IS_PK_ENCRYPTED)
565
0
      {
566
0
        if (!rdstls_write_authentication_request_with_password(rdstls, s))
567
0
          return FALSE;
568
0
      }
569
0
      else if (settings->ServerAutoReconnectCookie != NULL)
570
0
      {
571
0
        if (!rdstls_write_authentication_request_with_cookie(rdstls, s))
572
0
          return FALSE;
573
0
      }
574
0
      else
575
0
      {
576
0
        WLog_Print(rdstls->log, WLOG_ERROR,
577
0
                   "cannot authenticate with password or auto-reconnect cookie");
578
0
        return FALSE;
579
0
      }
580
0
      break;
581
0
    case RDSTLS_STATE_AUTH_RSP:
582
0
      if (!rdstls_write_authentication_response(rdstls, s))
583
0
        return FALSE;
584
0
      break;
585
0
    default:
586
0
      WLog_Print(rdstls->log, WLOG_ERROR, "Invalid rdstls state %s [%d]",
587
0
                 rdstls_get_state_str(state), state);
588
0
      return FALSE;
589
0
  }
590
591
0
  if (transport_write(rdstls->transport, s) < 0)
592
0
    return FALSE;
593
594
0
  return TRUE;
595
0
}
596
597
static int rdstls_recv(rdpTransport* transport, wStream* s, void* extra)
598
0
{
599
0
  UINT16 version = 0;
600
0
  UINT16 pduType = 0;
601
0
  rdpRdstls* rdstls = (rdpRdstls*)extra;
602
603
0
  WINPR_ASSERT(transport);
604
0
  WINPR_ASSERT(s);
605
0
  WINPR_ASSERT(rdstls);
606
607
0
  if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 4))
608
0
    return FALSE;
609
610
0
  Stream_Read_UINT16(s, version);
611
0
  if (version != RDSTLS_VERSION_1)
612
0
  {
613
0
    WLog_Print(rdstls->log, WLOG_ERROR,
614
0
               "received invalid RDSTLS Version=0x%04" PRIX16 ", expected 0x%04" PRIX16,
615
0
               version, RDSTLS_VERSION_1);
616
0
    return -1;
617
0
  }
618
619
0
  Stream_Read_UINT16(s, pduType);
620
0
  switch (pduType)
621
0
  {
622
0
    case RDSTLS_TYPE_CAPABILITIES:
623
0
      if (!rdstls_process_capabilities(rdstls, s))
624
0
        return -1;
625
0
      break;
626
0
    case RDSTLS_TYPE_AUTHREQ:
627
0
      if (!rdstls_process_authentication_request(rdstls, s))
628
0
        return -1;
629
0
      break;
630
0
    case RDSTLS_TYPE_AUTHRSP:
631
0
      if (!rdstls_process_authentication_response(rdstls, s))
632
0
        return -1;
633
0
      break;
634
0
    default:
635
0
      WLog_Print(rdstls->log, WLOG_ERROR, "unknown RDSTLS PDU type [0x%04" PRIx16 "]",
636
0
                 pduType);
637
0
      return -1;
638
0
  }
639
640
0
  return 1;
641
0
}
642
643
#define rdstls_check_state_requirements(rdstls, expected) \
644
0
  rdstls_check_state_requirements_((rdstls), (expected), __FILE__, __func__, __LINE__)
645
static BOOL rdstls_check_state_requirements_(rdpRdstls* rdstls, RDSTLS_STATE expected,
646
                                             const char* file, const char* fkt, size_t line)
647
0
{
648
0
  const RDSTLS_STATE current = rdstls_get_state(rdstls);
649
0
  if (current == expected)
650
0
    return TRUE;
651
652
0
  const DWORD log_level = WLOG_ERROR;
653
0
  if (WLog_IsLevelActive(rdstls->log, log_level))
654
0
    WLog_PrintMessage(rdstls->log, WLOG_MESSAGE_TEXT, log_level, line, file, fkt,
655
0
                      "Unexpected rdstls state %s [%d], expected %s [%d]",
656
0
                      rdstls_get_state_str(current), current, rdstls_get_state_str(expected),
657
0
                      expected);
658
659
0
  return FALSE;
660
0
}
661
662
static BOOL rdstls_send_capabilities(rdpRdstls* rdstls)
663
0
{
664
0
  BOOL rc = FALSE;
665
0
  wStream* s = NULL;
666
667
0
  if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES))
668
0
    goto fail;
669
670
0
  s = Stream_New(NULL, 512);
671
0
  if (!s)
672
0
    goto fail;
673
674
0
  if (!rdstls_send(rdstls->transport, s, rdstls))
675
0
    goto fail;
676
677
0
  rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_REQ);
678
0
fail:
679
0
  Stream_Free(s, TRUE);
680
0
  return rc;
681
0
}
682
683
static BOOL rdstls_recv_authentication_request(rdpRdstls* rdstls)
684
0
{
685
0
  BOOL rc = FALSE;
686
0
  int status = 0;
687
0
  wStream* s = NULL;
688
689
0
  if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ))
690
0
    goto fail;
691
692
0
  s = Stream_New(NULL, 4096);
693
0
  if (!s)
694
0
    goto fail;
695
696
0
  status = transport_read_pdu(rdstls->transport, s);
697
698
0
  if (status < 0)
699
0
    goto fail;
700
701
0
  status = rdstls_recv(rdstls->transport, s, rdstls);
702
703
0
  if (status < 0)
704
0
    goto fail;
705
706
0
  rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_RSP);
707
0
fail:
708
0
  Stream_Free(s, TRUE);
709
0
  return rc;
710
0
}
711
712
static BOOL rdstls_send_authentication_response(rdpRdstls* rdstls)
713
0
{
714
0
  BOOL rc = FALSE;
715
0
  wStream* s = NULL;
716
717
0
  if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP))
718
0
    goto fail;
719
720
0
  s = Stream_New(NULL, 512);
721
0
  if (!s)
722
0
    goto fail;
723
724
0
  if (!rdstls_send(rdstls->transport, s, rdstls))
725
0
    goto fail;
726
727
0
  rc = rdstls_set_state(rdstls, RDSTLS_STATE_FINAL);
728
0
fail:
729
0
  Stream_Free(s, TRUE);
730
0
  return rc;
731
0
}
732
733
static BOOL rdstls_recv_capabilities(rdpRdstls* rdstls)
734
0
{
735
0
  BOOL rc = FALSE;
736
0
  int status = 0;
737
0
  wStream* s = NULL;
738
739
0
  if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES))
740
0
    goto fail;
741
742
0
  s = Stream_New(NULL, 512);
743
0
  if (!s)
744
0
    goto fail;
745
746
0
  status = transport_read_pdu(rdstls->transport, s);
747
748
0
  if (status < 0)
749
0
    goto fail;
750
751
0
  status = rdstls_recv(rdstls->transport, s, rdstls);
752
753
0
  if (status < 0)
754
0
    goto fail;
755
756
0
  rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_REQ);
757
0
fail:
758
0
  Stream_Free(s, TRUE);
759
0
  return rc;
760
0
}
761
762
static BOOL rdstls_send_authentication_request(rdpRdstls* rdstls)
763
0
{
764
0
  BOOL rc = FALSE;
765
0
  wStream* s = NULL;
766
767
0
  if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ))
768
0
    goto fail;
769
770
0
  s = Stream_New(NULL, 4096);
771
0
  if (!s)
772
0
    goto fail;
773
774
0
  if (!rdstls_send(rdstls->transport, s, rdstls))
775
0
    goto fail;
776
777
0
  rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_RSP);
778
0
fail:
779
0
  Stream_Free(s, TRUE);
780
0
  return rc;
781
0
}
782
783
static BOOL rdstls_recv_authentication_response(rdpRdstls* rdstls)
784
0
{
785
0
  BOOL rc = FALSE;
786
0
  int status = 0;
787
0
  wStream* s = NULL;
788
789
0
  WINPR_ASSERT(rdstls);
790
791
0
  if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP))
792
0
    goto fail;
793
794
0
  s = Stream_New(NULL, 512);
795
0
  if (!s)
796
0
    goto fail;
797
798
0
  status = transport_read_pdu(rdstls->transport, s);
799
800
0
  if (status < 0)
801
0
    goto fail;
802
803
0
  status = rdstls_recv(rdstls->transport, s, rdstls);
804
805
0
  if (status < 0)
806
0
    goto fail;
807
808
0
  rc = rdstls_set_state(rdstls, RDSTLS_STATE_FINAL);
809
0
fail:
810
0
  Stream_Free(s, TRUE);
811
0
  return rc;
812
0
}
813
814
static int rdstls_server_authenticate(rdpRdstls* rdstls)
815
0
{
816
0
  if (!rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES))
817
0
    return -1;
818
819
0
  if (!rdstls_send_capabilities(rdstls))
820
0
    return -1;
821
822
0
  if (!rdstls_recv_authentication_request(rdstls))
823
0
    return -1;
824
825
0
  if (!rdstls_send_authentication_response(rdstls))
826
0
    return -1;
827
828
0
  if (rdstls->resultCode != 0)
829
0
    return -1;
830
831
0
  return 1;
832
0
}
833
834
static int rdstls_client_authenticate(rdpRdstls* rdstls)
835
0
{
836
0
  if (!rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES))
837
0
    return -1;
838
839
0
  if (!rdstls_recv_capabilities(rdstls))
840
0
    return -1;
841
842
0
  if (!rdstls_send_authentication_request(rdstls))
843
0
    return -1;
844
845
0
  if (!rdstls_recv_authentication_response(rdstls))
846
0
    return -1;
847
848
0
  return 1;
849
0
}
850
851
/**
852
 * Authenticate using RDSTLS.
853
 * @param rdstls The RDSTLS instance to use
854
 *
855
 * @return 1 if authentication is successful
856
 */
857
858
int rdstls_authenticate(rdpRdstls* rdstls)
859
0
{
860
0
  WINPR_ASSERT(rdstls);
861
862
0
  if (rdstls->server)
863
0
    return rdstls_server_authenticate(rdstls);
864
0
  else
865
0
    return rdstls_client_authenticate(rdstls);
866
0
}
867
868
static SSIZE_T rdstls_parse_pdu_data_type(wLog* log, UINT16 dataType, wStream* s)
869
0
{
870
0
  switch (dataType)
871
0
  {
872
0
    case RDSTLS_DATA_PASSWORD_CREDS:
873
0
    {
874
0
      UINT16 redirGuidLength = 0;
875
0
      if (Stream_GetRemainingLength(s) < 2)
876
0
        return 0;
877
0
      Stream_Read_UINT16(s, redirGuidLength);
878
879
0
      if (Stream_GetRemainingLength(s) < redirGuidLength)
880
0
        return 0;
881
0
      Stream_Seek(s, redirGuidLength);
882
883
0
      UINT16 usernameLength = 0;
884
0
      if (Stream_GetRemainingLength(s) < 2)
885
0
        return 0;
886
0
      Stream_Read_UINT16(s, usernameLength);
887
888
0
      if (Stream_GetRemainingLength(s) < usernameLength)
889
0
        return 0;
890
0
      Stream_Seek(s, usernameLength);
891
892
0
      UINT16 domainLength = 0;
893
0
      if (Stream_GetRemainingLength(s) < 2)
894
0
        return 0;
895
0
      Stream_Read_UINT16(s, domainLength);
896
897
0
      if (Stream_GetRemainingLength(s) < domainLength)
898
0
        return 0;
899
0
      Stream_Seek(s, domainLength);
900
901
0
      UINT16 passwordLength = 0;
902
0
      if (Stream_GetRemainingLength(s) < 2)
903
0
        return 0;
904
0
      Stream_Read_UINT16(s, passwordLength);
905
906
0
      return Stream_GetPosition(s) + passwordLength;
907
0
    }
908
0
    case RDSTLS_DATA_AUTORECONNECT_COOKIE:
909
0
    {
910
0
      if (Stream_GetRemainingLength(s) < 4)
911
0
        return 0;
912
0
      Stream_Seek(s, 4);
913
914
0
      UINT16 cookieLength = 0;
915
0
      if (Stream_GetRemainingLength(s) < 2)
916
0
        return 0;
917
0
      Stream_Read_UINT16(s, cookieLength);
918
919
0
      return 12u + cookieLength;
920
0
    }
921
0
    default:
922
0
      WLog_Print(log, WLOG_ERROR, "invalid RDSLTS dataType");
923
0
      return -1;
924
0
  }
925
0
}
926
927
SSIZE_T rdstls_parse_pdu(wLog* log, wStream* stream)
928
0
{
929
0
  SSIZE_T pduLength = -1;
930
0
  wStream sbuffer = { 0 };
931
0
  wStream* s = Stream_StaticConstInit(&sbuffer, Stream_Buffer(stream), Stream_Length(stream));
932
933
0
  UINT16 version = 0;
934
0
  if (Stream_GetRemainingLength(s) < 2)
935
0
    return 0;
936
0
  Stream_Read_UINT16(s, version);
937
0
  if (version != RDSTLS_VERSION_1)
938
0
  {
939
0
    WLog_Print(log, WLOG_ERROR, "invalid RDSTLS version");
940
0
    return -1;
941
0
  }
942
943
0
  UINT16 pduType = 0;
944
0
  if (Stream_GetRemainingLength(s) < 2)
945
0
    return 0;
946
0
  Stream_Read_UINT16(s, pduType);
947
0
  switch (pduType)
948
0
  {
949
0
    case RDSTLS_TYPE_CAPABILITIES:
950
0
      pduLength = 8;
951
0
      break;
952
0
    case RDSTLS_TYPE_AUTHREQ:
953
0
      if (Stream_GetRemainingLength(s) < 2)
954
0
        return 0;
955
0
      UINT16 dataType = 0;
956
0
      Stream_Read_UINT16(s, dataType);
957
0
      pduLength = rdstls_parse_pdu_data_type(log, dataType, s);
958
959
0
      break;
960
0
    case RDSTLS_TYPE_AUTHRSP:
961
0
      pduLength = 10;
962
0
      break;
963
0
    default:
964
0
      WLog_Print(log, WLOG_ERROR, "invalid RDSTLS PDU type");
965
0
      return -1;
966
0
  }
967
968
0
  return pduLength;
969
0
}