Coverage Report

Created: 2026-05-11 06:44

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/freeradius-server/src/lib/util/trie.c
Line
Count
Source
1
/*
2
 *   This library is free software; you can redistribute it and/or
3
 *   modify it under the terms of the GNU Lesser General Public
4
 *   License as published by the Free Software Foundation; either
5
 *   version 2.1 of the License, or (at your option) any later version.
6
 *
7
 *   This library is distributed in the hope that it will be useful,
8
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
9
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10
 *   Lesser General Public License for more details.
11
 *
12
 *   You should have received a copy of the GNU Lesser General Public
13
 *   License along with this library; if not, write to the Free Software
14
 *   Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301, USA
15
 */
16
17
/** Path-compressed prefix tries
18
 *
19
 * @file src/lib/util/trie.c
20
 *
21
 * @copyright 2017 Alan DeKok (aland@freeradius.org)
22
 */
23
RCSID("$Id: 0c2e975b6d8169f07c309a70593953b7f3bcd440 $")
24
25
#include <freeradius-devel/util/dict.h>
26
#include <freeradius-devel/util/skip.h>
27
#include <freeradius-devel/util/syserror.h>
28
#include <freeradius-devel/util/trie.h>
29
30
31
/*
32
 *  This file implements path-compressed patricia tries, based on
33
 *  the level- and path-compressed tries described in:
34
 *
35
 *  https://www.nada.kth.se/~snilsson/publications/Dynamic-trie-compression-implementation/
36
 *
37
 *  The functionality has been extended to include intermediate
38
 *  nodes which consume 0 bits, but which hold user context data.
39
 *  These intermediate nodes allow for "longest prefix" matching.
40
 *  For example, in networking, you can have a routing table entry
41
 *  with 0/0 leading to one destination, and 10/8 leading to a
42
 *  different one.  Looking up an address in the 10/8 network will
43
 *  return the 10/8 destination.  Looking up any other address
44
 *  will return the default destination.
45
 *
46
 *  In addition, we desire the ability to add and delete nodes
47
 *  dynamically.  In the example given above, this means that
48
 *  after deleting 10/8, the trie should contain only the 0/0
49
 *  network and associated destination.
50
 *
51
 *  As of yet, it does not do level compression.  This can be
52
 *  added without (hopefully) too much work.  That would require
53
 *  an additional step to "normalize" the trie.
54
 *
55
 *  This code could be extended to do packet matching, through the
56
 *  inclusion of "don't care" paths.  e.g. parsing an IP header,
57
 *  where the src/dst IP addresses are 32-bit "don't care" fields.
58
 *
59
 *  It could also be extended via "count" paths, where the path
60
 *  holds a count that is used in another part of the trie.  For
61
 *  example, in RADIUS.  The attribute encoding is one byte
62
 *  attribute, one byte length, followed by "length - 2" bytes of
63
 *  data.  At that point though, you might as well just use Ragel.
64
 */
65
66
/** Enable path compression (or not)
67
 *
68
 *  With path compression, long sequences of bits are stored as a
69
 *  path, e.g. "abcdef".  Without path compression, we would have to
70
 *  create a large number of intermediate 2^N-way nodes, all of which
71
 *  would have only one edge.
72
 */
73
#if !defined(NO_PATH_COMPRESSION) && !defined(WITH_PATH_COMPRESSION)
74
#define WITH_PATH_COMPRESSION
75
#endif
76
77
//#define WITH_NODE_COMPRESSION
78
79
#ifdef WITH_NODE_COMPRESSION
80
#ifndef WITH_PATH_COMPRESSION
81
#define WITH_PATH_COMPRESSION
82
#endif
83
84
#ifndef MAX_COMP_BITS
85
#define MAX_COMP_BITS (8)
86
#endif
87
88
#ifndef MAX_COMP_EDGES
89
#define MAX_COMP_EDGES (4)
90
#endif
91
92
#endif  /* WITH_NODE_COMPRESSION */
93
94
0
#define MAX_KEY_BYTES (256)
95
0
#define MAX_KEY_BITS (MAX_KEY_BYTES * 8)
96
97
#ifndef MAX_NODE_BITS
98
0
#define MAX_NODE_BITS (4)
99
#endif
100
101
/**  Internal sanity checks for debugging.
102
 *
103
 *  Tries are complex.  So we have verification routines for every
104
 *  type of node.  These routines are called from within the trie
105
 *  manipulation functions.  If the trie manipulation has a bug, the
106
 *  verification routines are likely to catch some of the more
107
 *  egregious issues.
108
 */
109
DIAG_OFF(unused-macros)
110
#ifdef TESTING
111
#define WITH_TRIE_VERIFY (1)
112
#  define MPRINT(...) fprintf(stderr, ## __VA_ARGS__)
113
114
   /* define this to be MPRINT for additional debugging */
115
#  define MPRINT2(...)
116
#  define MPRINT3(...)
117
static void trie_sprint(fr_trie_t *trie, uint8_t const *key, int start_bit, int lineno);
118
#else
119
#  define MPRINT(...)
120
#  define MPRINT2(...)
121
#  define MPRINT3(...)
122
#define trie_sprint(_trie, _key, _start_bit, _lineno)
123
#endif
124
125
#ifdef WITH_TRIE_VERIFY
126
static int trie_verify(fr_trie_t *trie);
127
//#define VERIFY(_x) fr_cond_assert(trie_verify((fr_trie_t *) _x) == 0)
128
/* The whole check lives inside do { } while (0), so "VERIFY(x);" is exactly one statement and cannot capture a following "else". */
#define VERIFY(_x) do { if (trie_verify((fr_trie_t *) (_x)) < 0) { fprintf(stderr, "FAIL VERIFY at %d - %s\n", __LINE__, fr_strerror()); fr_cond_assert(0); } } while (0)
129
#else
130
#define VERIFY(_x)
131
#endif
132
133
/*
134
 *  Macros to swap one for the other.
135
 */
136
#define BITSOF(_x)  ((_x) * 8)
137
0
#define BYTEOF(_x)  ((_x) >> 3)
138
#define BYTES(_x) (((_x) + 0x07) >> 3)
139
DIAG_ON(unused-macros)
140
141
142
// @todo - do level compression
143
// stop merging nodes if a key ends at the top of the level
144
// otherwise merge so we have at least 2^4 way fan-out, but no more than 2^8
145
// that should be a decent trade-off between memory and speed
146
147
// @todo - generalized function to normalize the trie.
148
149
150
/*
 *  start_bit_mask[n] clears the first (most significant) n bits of
 *  a byte and keeps the remaining (8 - n) bits.  It is used to
 *  drop the bits which come before a key starting at bit offset n.
 *
 *  The table is only ever read, so it is const.
 */
static uint8_t const start_bit_mask[8] = {
	0xff, 0x7f, 0x3f, 0x1f,
	0x0f, 0x07, 0x03, 0x01
};
154
155
/*
 *  used_bit_mask[n] keeps the first (most significant) n + 1 bits
 *  of a byte and clears the rest.
 *
 *  The table is only ever read, so it is const.
 */
static uint8_t const used_bit_mask[8] = {
	0x80, 0xc0, 0xe0, 0xf0,
	0xf8, 0xfc, 0xfe, 0xff,
};
159
160
#if 0
161
/*
162
 *  For testing and debugging.
163
 */
164
static char const *spaces = "                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                ";
165
#endif
166
167
168
#if defined(WITH_PATH_COMPRESSION) || defined(TESTING)
169
/*
 *  xor2lcp[KEY1 ^ KEY2] is the number of leading (most significant)
 *  bits which are the same in the bytes KEY1 and KEY2, i.e. the
 *  number of leading zero bits in the XOR of the two bytes.
 *
 *  All 256 entries are written out explicitly, instead of relying
 *  on implicit zero-fill for the tail of the array.  The table is
 *  only ever read, so it is const.
 */
static uint8_t const xor2lcp[256] = {
	8, 7, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4,		/* 0x00..0x0f */
	3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,		/* 16x 3 */
	2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,		/* 32x 2 */
	2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
	1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,		/* 64x 1 */
	1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
	1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
	1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,		/* 128x 0 */
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
222
223
224
/** Get the longest prefix of the two keys.
225
 *
226
 */
227
static int fr_trie_key_lcp(uint8_t const *key1, int keylen1, uint8_t const *key2, int keylen2, int start_bit)
228
0
{
229
0
  int lcp, end_bit;
230
231
0
  if (!keylen1 || !keylen2) return 0;
232
0
  fr_cond_assert((start_bit & 0x07) == start_bit);
233
234
0
  end_bit = keylen1;
235
0
  if (end_bit > keylen2) end_bit = keylen2;
236
0
  end_bit += start_bit;
237
238
0
  MPRINT2("%.*sLCP %02x%02x %02x%02x start %d length %d, %d\n",
239
0
    start_bit, spaces, key1[0], key1[1], key2[0], key2[1], start_bit, keylen1, keylen2);
240
241
0
  lcp = 0;
242
243
0
  while (end_bit > 0) {
244
0
    int num_bits;
245
0
    uint8_t cmp1, cmp2, xor;
246
247
0
    MPRINT2("END %d\n", end_bit);
248
249
    /*
250
     *  Default to grabbing the whole byte.
251
     */
252
0
    cmp1 = key1[0];
253
0
    cmp2 = key2[0];
254
0
    num_bits = 8;
255
256
    /*
257
     *  The LCP ends in this byte.  Mask off the
258
     *  trailing bits so that they don't affect the
259
     *  result.
260
     */
261
0
    if (end_bit < 8) {
262
0
      cmp1 &= used_bit_mask[end_bit - 1];
263
0
      cmp2 &= used_bit_mask[end_bit - 1];
264
0
      num_bits = end_bit;
265
0
    }
266
267
    /*
268
     *  The key doesn't start on the leading bit.
269
     *  Shift the data left until it does start there.
270
     */
271
0
    if ((start_bit & 0x07) != 0) {
272
0
      cmp1 <<= start_bit;
273
0
      cmp2 <<= start_bit;
274
0
      num_bits -= start_bit;
275
0
      end_bit -= start_bit;
276
277
      /*
278
       *  For subsequent bytes we start on a
279
       *  byte boundary.
280
       */
281
0
      start_bit = 0;
282
0
    }
283
284
0
    xor = cmp1 ^ cmp2;
285
286
    /*
287
     *  A table lookup is faster than looping through
288
     *  the bits.  If the LCP is smaller than the
289
     *  number of bits we're looking up, we can stop.
290
     *
291
     *  On the other hand, if it returns the same or
292
     *  too many bits, just do another round through
293
     *  the loop, so that we can update the pointers
294
     *  and check the exit conditions.
295
     */
296
0
    if (xor2lcp[xor] < num_bits) {
297
0
      MPRINT2("RETURN %d + %d\n", lcp, xor2lcp[xor]);
298
0
      return lcp + xor2lcp[xor];
299
0
    }
300
301
    /*
302
     *  The LCP may be longer than num_bits if we're
303
     *  checking the first byte, which has only
304
     *  "start_bit" things we care about.  Ignore that
305
     *  case, and just keep going.
306
     */
307
308
0
    lcp += num_bits;
309
0
    end_bit -= num_bits;
310
0
    key1++;
311
0
    key2++;
312
0
  }
313
314
0
  return lcp;
315
0
}
316
#endif
317
318
//#define HEX_DUMP
319
320
#ifdef HEX_DUMP
321
/** Print a key as hex bytes, for debugging
 *
 * Writes the message, the start and end bit offsets, and then every byte
 * of the key from key[0] through the byte which holds bit (end_bit - 1).
 *
 * @param fp		where to write the output.
 * @param msg		prefix for the line.
 * @param key		the key to print.
 * @param start_bit	start bit offset (printed only).
 * @param end_bit	end bit offset, which determines how many bytes are printed.
 */
static void hex_dump(FILE *fp, char const *msg, uint8_t const *key, int start_bit, int end_bit)
{
	int i;

	/*
	 *	The offsets are "int", so they are printed with %d.  %zd is for
	 *	ssize_t, and is a type mismatch here.
	 */
	fprintf(fp, "%s\ts=%d e=%d\t\t", msg, start_bit, end_bit);

	for (i = 0; i < BYTES(end_bit); i++) {
		fprintf(fp, "%02x ", key[i]);
	}
	fprintf(fp, "\n");
}
332
#endif
333
334
/** Return a chunk of a key (in the low bits) for use in 2^N node de-indexing
335
 *
336
 */
337
static CC_HINT(nonnull) uint16_t get_chunk(uint8_t const *key, uint32_t start_bit, uint32_t num_bits)
338
{
339
  uint16_t chunk;
340
  int end_bit;
341
342
  fr_cond_assert(num_bits > 0);
343
  fr_cond_assert(num_bits <= 16);
344
345
  /*
346
   *  Normalize it so that the caller doesn't have to.
347
   */
348
  if (start_bit > 7) {
349
    key += BYTEOF(start_bit);
350
    start_bit &= 0x07;
351
  }
352
353
  /*
354
   *  Special-case 1-bit lookups.
355
   */
356
  if (num_bits == 1) {
357
    chunk = key[0] >> (7 - start_bit);
358
    chunk &= 0x01;
359
    return chunk;
360
  }
361
362
  /*
363
   *  Catch some simple use-cases.
364
   */
365
  if (start_bit == 0) {
366
    if (num_bits < 7) return key[0] >> (8 - num_bits);
367
    if (num_bits == 8) return key[0];
368
369
    chunk = (key[0] << 8) | key[1];
370
    if (num_bits < 16) chunk >>= (16 - num_bits);
371
    fr_cond_assert(chunk < (1 << num_bits));
372
    return chunk;
373
  }
374
375
  /*
376
   *  Load the first byte and mask off the bits we don't
377
   *  want.
378
   */
379
  chunk = key[0] & start_bit_mask[start_bit & 0x07];
380
381
  fr_cond_assert(BYTEOF(start_bit + num_bits - 1) <= 1);
382
383
  if (BYTEOF(start_bit + num_bits - 1) != 0) {
384
    chunk <<= 8;
385
    chunk |= key[1];
386
  }
387
388
  /*
389
   *  The bits we want are now all in the higher bits
390
   *  of "chunk".  But we only want some of them.
391
   *
392
   *  Shift the chunk so that the bits we want are now in
393
   *  the low bits.
394
   */
395
  end_bit = (start_bit + num_bits) & 0x07;
396
  if (end_bit != 0) chunk >>= 8 - end_bit;
397
398
  fr_cond_assert(chunk < (1 << num_bits));
399
400
  return chunk;
401
}
402
403
404
static void write_chunk(uint8_t *out, int start_bit, int num_bits, uint16_t chunk) CC_HINT(nonnull);
405
406
/** Write a chunk to an output buffer
407
 *
408
 */
409
static void write_chunk(uint8_t *out, int start_bit, int num_bits, uint16_t chunk)
410
0
{
411
0
  fr_cond_assert(chunk < (1 << num_bits));
412
413
  /*
414
   *  Normalize it so that the caller doesn't have to.
415
   */
416
0
  if (start_bit > 7) {
417
0
    out += BYTEOF(start_bit);
418
0
    start_bit &= 0x07;
419
0
  }
420
421
  /*
422
   *  Special-case 1-bit writes.
423
   */
424
0
  if (num_bits == 1) {
425
0
    out[0] &= ~(1 << (7 - start_bit));
426
0
    out[0] |= chunk << (7 - start_bit);
427
0
    return;
428
0
  }
429
430
  /*
431
   *  Ensure that we don't write to more than 2 octets at
432
   *  the same time.
433
   */
434
0
  fr_cond_assert((start_bit + num_bits) <= 16);
435
436
  /*
437
   *  Shift the chunk to the high bits, but leave room for
438
   *  start_bit
439
   */
440
0
  if ((start_bit + num_bits) < 16) chunk <<= (16 - (start_bit + num_bits));
441
442
  /*
443
   *  Mask off the first bits that are already in the
444
   *  output.  Then OR in the relevant bits of "chunk".
445
   */
446
0
  out[0] &= (used_bit_mask[start_bit] << 1);
447
0
  out[0] |= chunk >> 8;
448
449
0
  if ((start_bit + num_bits) > 8) {
450
0
    out[1] = chunk & 0xff;
451
0
  }
452
0
}
453
454
/*
 *  The kind of node, stored in the "type" field of TRIE_HEADER.
 *  Which values exist depends on the compression options chosen at
 *  build time.
 */
typedef enum fr_trie_type_t {
455
  FR_TRIE_INVALID = 0,  /* zero, so it is what a zero-filled node starts with */
456
  FR_TRIE_USER,   /* fr_trie_user_t: holds user data, plus an optional child */
457
#ifdef WITH_PATH_COMPRESSION
458
  FR_TRIE_PATH,   /* fr_trie_path_t: a run of key bits leading to one child */
459
#endif
460
#ifdef WITH_NODE_COMPRESSION
461
  FR_TRIE_COMP,   /* 4-way, N bits deep */
462
#endif
463
  FR_TRIE_NODE,   /* fr_trie_node_t: 2^N-way node, indexed by N bits of the key */
464
} fr_trie_type_t;
465
466
#define FR_TRIE_MAX (FR_TRIE_NODE + 1)
467
468
#ifdef TESTING
469
static int trie_number = 0;
470
471
#define TRIE_HEADER uint8_t type; uint8_t bits; int number
472
#define TRIE_TYPE_CHECK(_x, _r) do { if ((trie->type == FR_TRIE_INVALID) || \
473
           (trie->type >= FR_TRIE_MAX) || \
474
           !trie_ ## _x ##_table [trie->type]) { \
475
            fr_strerror_printf("unknown trie type %d", trie->type); \
476
            return _r; \
477
             } } while (0)
478
479
#else
480
#define TRIE_HEADER uint8_t type; uint8_t bits
481
#define TRIE_TYPE_CHECK(_x, _r)
482
#endif
483
484
/*
 *  Generic view of a node.  Every node type starts with
 *  TRIE_HEADER, so code can read "type" through this struct and
 *  then cast to the specific node type, as trie_free() does.
 */
struct fr_trie_s {
485
  TRIE_HEADER;
486
487
  fr_trie_t *trie;  /* for USER and PATH nodes*/
488
};
489
490
/*
 *  A 2^N-way node, where N is "bits".  trie_node_alloc() allocates
 *  it with (1 << bits) entries in the trailing array.
 */
typedef struct {
491
  TRIE_HEADER;
492
493
  /* number of entries of trie[] which are in use */
  int   used;
494
  /* one child per possible chunk value; NULL means "no edge" */
  fr_trie_t *trie[];
495
} fr_trie_node_t;
496
497
/*
 *  A node which consumes no key bits, and which holds user data.
 */
typedef struct {
498
  TRIE_HEADER;
499
500
  /* the rest of the trie below this node, for longer keys */
  fr_trie_t *trie;
501
  /* what a lookup returns when it matches at this node */
  void      *data;
502
} fr_trie_user_t;
503
504
#ifdef WITH_PATH_COMPRESSION
505
/*
 *  A path of "bits" key bits which leads to exactly one child.
 *  fr_trie_path_alloc() limits a path to 16 bits.
 */
typedef struct {
506
  TRIE_HEADER;
507
508
  /* the node at the end of the path */
  fr_trie_t *trie;
509
510
  /* the path bits, right-aligned, as returned by get_chunk() */
  uint16_t  chunk;
511
  /* the same bits, written by write_chunk() at their original offset within the byte */
  uint8_t   key[2];
512
} fr_trie_path_t;
513
#endif
514
515
#ifdef WITH_NODE_COMPRESSION
516
/*
 *  A compressed node: "bits" bits deep, but holding at most
 *  MAX_COMP_EDGES edges.  index[i] is the chunk value for the edge
 *  in trie[i].
 */
typedef struct {
517
  TRIE_HEADER;
518
519
  int   used;   //!< number of used entries
520
  /* chunk value of each used edge */
  uint8_t   index[MAX_COMP_EDGES];
521
  /* child for each used edge */
  fr_trie_t *trie[MAX_COMP_EDGES];
522
} fr_trie_comp_t;
523
#endif
524
525
526
/* ALLOC FUNCTIONS */
527
528
/** Allocate an empty 2^N-way node
 *
 * @param ctx	the talloc ctx.
 * @param bits	the number of key bits (1..8) which the node consumes.
 * @return
 *	- NULL on error, with the reason available via fr_strerror()
 *	- a zero-filled node with (1 << bits) empty edges on success
 */
static fr_trie_node_t *trie_node_alloc(TALLOC_CTX *ctx, int bits)
{
	fr_trie_node_t	*node;
	int		num_edges;

	if ((bits < 1) || (bits > 8)) {
		fr_strerror_printf("Invalid bit size %d passed to node alloc", bits);
		return NULL;
	}

	num_edges = 1 << bits;

	/*
	 *	The edges live in a flexible array at the end of the
	 *	structure, so the node is allocated as one zero-filled
	 *	array of bytes.
	 */
	node = (fr_trie_node_t *) talloc_zero_array(ctx, uint8_t,
						    sizeof(fr_trie_node_t) + sizeof(node->trie[0]) * num_edges);
	if (node == NULL) {
		fr_strerror_const("failed allocating node trie");
		return NULL;
	}

	talloc_set_name_const(node, "fr_trie_node_t");
	node->bits = bits;
	node->type = FR_TRIE_NODE;

#ifdef TESTING
	node->number = trie_number++;
#endif
	return node;
}
555
556
/** Free a fr_trie_t
557
 *
558
 *  We can't use talloc_free(), because we can't talloc_parent the
559
 *  nodes from each other, as talloc_steal() is O(N).  So, we just
560
 *  recurse manually.
561
 */
562
static void trie_free(fr_trie_t *trie)
563
0
{
564
0
  if (!trie) return;
565
566
0
  if (trie->type == FR_TRIE_USER) {
567
0
    fr_trie_user_t *user = (fr_trie_user_t *) trie;
568
569
0
    trie_free(user->trie);
570
0
    talloc_free(user);
571
0
    return;
572
0
  }
573
574
0
  if (trie->type == FR_TRIE_NODE) {
575
0
    fr_trie_node_t *node = (fr_trie_node_t *) trie;
576
0
    int i;
577
578
0
    for (i = 0; i < (1 << node->bits); i++) {
579
0
      if (!node->trie[i]) continue; /* save a function call in the common case */
580
581
0
      trie_free(node->trie[i]);
582
0
    }
583
584
0
    talloc_free(node);
585
0
    return;
586
0
  }
587
588
0
#ifdef WITH_PATH_COMPRESSION
589
0
  if (trie->type == FR_TRIE_PATH) {
590
0
    fr_trie_path_t *path = (fr_trie_path_t *) trie;
591
592
0
    trie_free(path->trie);
593
0
    talloc_free(path);
594
0
    return;
595
0
  }
596
0
#endif
597
598
#ifdef WITH_NODE_COMPRESSION
599
  if (trie->type == FR_TRIE_COMP) {
600
    fr_trie_comp_t *comp = (fr_trie_comp_t *) trie;
601
    int i;
602
603
    for (i = 0; i < comp->used; i++) {
604
      trie_free(comp->trie[i]);
605
    }
606
607
    talloc_free(comp);
608
    return;
609
  }
610
#endif
611
0
}
612
613
/** Allocate a user node
 *
 * @param ctx	the talloc ctx.
 * @param data	the user data to store in the node.  The pointer is stored
 *		as-is, with the "const" removed.
 * @return
 *	- NULL on error, with the reason available via fr_strerror()
 *	- a user node with no child on success
 */
static CC_HINT(nonnull(2)) fr_trie_user_t *fr_trie_user_alloc(TALLOC_CTX *ctx, void const *data)
{
	fr_trie_user_t *node = talloc_zero(ctx, fr_trie_user_t);

	if (node == NULL) {
		fr_strerror_const("failed allocating user trie");
		return NULL;
	}

	node->data = UNCONST(void *, data);
	node->type = FR_TRIE_USER;

#ifdef TESTING
	node->number = trie_number++;
#endif

	return node;
}
632
633
#ifdef WITH_PATH_COMPRESSION
634
/** Allocate a path node for the key bits [start_bit, end_bit)
 *
 * @param ctx		the talloc ctx.
 * @param key		the key to copy the bits from.
 * @param start_bit	offset of the first bit of the path.  Offsets larger
 *			than 7 are normalized by advancing "key".
 * @param end_bit	offset of the first bit after the path.
 * @return
 *	- NULL on error, with the reason available via fr_strerror().  It is an
 *	  error for the path to be empty, to be longer than 16 bits, or to cover
 *	  more than two bytes of the key.
 *	- a path node with no child on success
 */
static CC_HINT(nonnull(2)) fr_trie_path_t *fr_trie_path_alloc(TALLOC_CTX *ctx, uint8_t const *key, int start_bit, int end_bit)
{
	fr_trie_path_t *path;

	if (end_bit <= start_bit) {
		fr_strerror_printf("path asked for start >= end, %d >= %d", start_bit, end_bit);
		return NULL;
	}

	/*
	 *	Normalize it so that the caller doesn't have to.
	 */
	if (start_bit > 7) {
		key += (start_bit >> 3);
		end_bit -= 8 * (start_bit >> 3);
		start_bit -= 8 * (start_bit >> 3);
	}

	if ((end_bit - start_bit) > 16) {
		fr_strerror_printf("path asked for too many bits (%d)", end_bit - start_bit);
		return NULL;
	}

	/*
	 *	The "end_bit" is the bit we're not using, so it's
	 *	allowed to point past the end of path->key.
	 *
	 *	The last bit we do use is (end_bit - 1).  It has to be in
	 *	the same byte as start_bit, or in the next one, because
	 *	path->key holds only two bytes.  The subtraction is "last
	 *	byte - first byte"; the other way around is never positive,
	 *	and the check would never fail.
	 */
	if ((BYTEOF(end_bit - 1) - BYTEOF(start_bit)) > 1) {
		fr_strerror_printf("path asked for too many bits / bytes (%d)", end_bit - start_bit);
		return NULL;
	}

	path = talloc_zero(ctx, fr_trie_path_t);
	if (!path) {
		fr_strerror_const("failed allocating path trie");
		return NULL;
	}

	path->type = FR_TRIE_PATH;
	path->bits = end_bit - start_bit;
	path->chunk = get_chunk(key, start_bit, path->bits);

	/*
	 *	Write the chunk back to the key.  This leaves path->key
	 *	holding only the bits of the path, at their original
	 *	offset.
	 */
	write_chunk(&path->key[0], start_bit, path->bits, path->chunk);

#if 0
	fprintf(stderr, "PATH ALLOC key %02x%02x start %d end %d bits %d == chunk %04x key %02x%02x\n",
		key[0], key[1],
		start_bit, end_bit, path->bits,
		path->chunk, path->key[0], path->key[1]);
#endif

#ifdef TESTING
	path->number = trie_number++;
#endif

	return path;
}
694
#endif  /* WITH_PATH_COMPRESSION */
695
696
#ifdef WITH_NODE_COMPRESSION
697
/** Allocate an empty compressed node
 *
 * @param ctx	the talloc ctx.
 * @param bits	the number of key bits which the node consumes, from 3 to
 *		MAX_COMP_BITS.
 * @return
 *	- NULL on error, with the reason available via fr_strerror()
 *	- a compressed node with no edges on success
 */
static fr_trie_comp_t *fr_trie_comp_alloc(TALLOC_CTX *ctx, int bits)
{
	fr_trie_comp_t *node;

	/*
	 *	1 and 2 bit nodes are rejected here.  Callers should
	 *	allocate a fr_trie_node_t for those.
	 */
	if ((bits < 3) || (bits > MAX_COMP_BITS)) {
		fr_strerror_printf("Invalid bit size %d passed to comp alloc", bits);
		return NULL;
	}

	node = talloc_zero(ctx, fr_trie_comp_t);
	if (node == NULL) {
		fr_strerror_const("failed allocating comp trie");
		return NULL;
	}

	node->used = 0;
	node->bits = bits;
	node->type = FR_TRIE_COMP;

#ifdef TESTING
	node->number = trie_number++;
#endif
	return node;
}
724
#endif  /* WITH_NODE_COMPRESSION */
725
726
/*
 *  Per-trie context.  fr_trie_alloc() allocates one of these, and
 *  stores it as the user data of the top-level node.
 */
typedef struct {
727
  uint8_t   buffer[16]; /* for get_key callbacks */
728
  /* the "get key from object" function passed to fr_trie_alloc() */
  fr_trie_key_t get_key;
729
  /* the "free data" callback passed to fr_trie_alloc() */
  fr_free_t free_data;
730
} fr_trie_ctx_t;
731
732
/** Allocate a trie
733
 *
734
 * @param ctx   The talloc ctx.
735
 * @param get_key The "get key from object" function.
736
 * @param free_data Callback to free data.
737
 * @return
738
 *  - NULL on error
739
 *  - fr_trie_node_t on success
740
 */
741
fr_trie_t *fr_trie_alloc(TALLOC_CTX *ctx, fr_trie_key_t get_key, fr_free_t free_data)
742
0
{
743
0
  fr_trie_user_t *user;
744
0
  fr_trie_ctx_t *uctx;
745
746
  /*
747
   *  The trie itself is just a user node with user data
748
   *  that is the get_key function.
749
   */
750
0
  user = (fr_trie_user_t *) fr_trie_user_alloc(ctx, "");
751
0
  if (!user) return NULL;
752
753
  /*
754
   *  Only the top-level node here can have 'user->data == NULL'
755
   */
756
0
  user->data = uctx = talloc_zero(user, fr_trie_ctx_t);
757
0
  if (!user->data) {
758
0
    talloc_free(user);
759
0
    return NULL;
760
0
  }
761
762
0
  uctx->get_key = get_key;
763
0
  uctx->free_data = free_data;
764
765
0
  return (fr_trie_t *) user;
766
0
}
767
768
/* SPLIT FUNCTIONS */
769
770
/** Split a node at bits
771
 *
772
 */
773
static CC_HINT(nonnull(2)) fr_trie_node_t *trie_node_split(TALLOC_CTX *ctx, fr_trie_node_t *node, int bits)
774
0
{
775
0
  fr_trie_node_t *split;
776
0
  int i, remaining_bits;
777
778
  /*
779
   *  Can't split zero bits, more bits than the node has, or
780
   *  a node which has 1 bit.
781
   */
782
0
  if ((bits == 0) || (bits >= node->bits) || (node->bits == 1)) {
783
0
    fr_strerror_printf("invalid value for node split (%d / %d)", bits, node->bits);
784
0
    return NULL;
785
0
  }
786
787
0
  split = trie_node_alloc(ctx, bits);
788
0
  if (!split) return NULL;
789
790
0
  remaining_bits = node->bits - bits;
791
792
  /*
793
   *  Allocate the children.  For now, just brute-force all
794
   *  of the children.  We take a later pass at optimizing this.
795
   */
796
0
  for (i = 0; i < (1 << bits); i++) {
797
0
    int j;
798
0
    fr_trie_node_t *child;
799
800
0
    child = trie_node_alloc(ctx, remaining_bits);
801
0
    if (!child) {
802
0
      trie_free((fr_trie_t *) split);
803
0
      return NULL;
804
0
    }
805
806
0
    for (j = 0; j < (1 << remaining_bits); j++) {
807
0
      if (!node->trie[(i << remaining_bits) + j]) continue;
808
809
0
      child->trie[j] = node->trie[(i << remaining_bits) + j];
810
0
      node->trie[(i << remaining_bits) + j] = NULL; /* so we don't free it when freeing 'node' */
811
0
      child->used++;
812
0
    }
813
814
0
    if (!child->used) {
815
0
      talloc_free(child); /* no children, so no need to recurse */
816
0
      continue;
817
0
    }
818
819
0
    split->trie[i] = (fr_trie_t *) child;
820
0
    split->used++;
821
0
  }
822
823
  /*
824
   *  Note that we do NOT free "node".  The caller still
825
   *  needs it for some activities.
826
   */
827
0
  return split;
828
0
}
829
830
#ifdef WITH_PATH_COMPRESSION
831
/** Split a path into two paths, after "lcp" bits
 *
 *  The first new path holds the leading "lcp" bits of the original
 *  path.  Its child is the second new path, which holds the rest of
 *  the bits, and which points to the child of the original path.
 *
 *  The original path is neither changed nor freed.  The caller frees
 *  it once the new key has been inserted successfully.
 *
 * @param ctx		the talloc ctx.
 * @param path		the path to split.
 * @param start_bit	bit offset where the path starts.  Only the offset
 *			within the byte is used.
 * @param lcp		the number of bits for the first of the two paths.
 * @return
 *	- NULL on error, with the reason available via fr_strerror()
 *	- the first of the two new paths on success
 */
static CC_HINT(nonnull(2)) fr_trie_path_t *trie_path_split(TALLOC_CTX *ctx, fr_trie_path_t *path, int start_bit, int lcp)
{
	fr_trie_path_t	*head, *tail;
	int		split_bit, last_bit;
#ifdef TESTING
	uint8_t		check[2] = { 0, 0 };
#endif

	if (!((lcp > 0) && (lcp <= path->bits) && (start_bit >= 0))) {
		fr_strerror_printf("invalid parameter %d %d to path split", lcp, start_bit);
		return NULL;
	}

	MPRINT3("%.*sSPLIT start %d\n", start_bit, spaces, start_bit);

	/*
	 *	path->key holds the bits at their offset within the
	 *	byte, so that is the only part of start_bit which
	 *	matters.
	 */
	start_bit &= 0x07;
	split_bit = start_bit + lcp;
	last_bit = start_bit + path->bits;

	head = fr_trie_path_alloc(ctx, &path->key[0], start_bit, split_bit);
	if (head == NULL) return NULL;

	tail = fr_trie_path_alloc(ctx, &path->key[0], split_bit, last_bit);
	if (tail == NULL) {
		talloc_free(head);
		return NULL;
	}

	/*
	 *	Chain them together: head -> tail -> original child.
	 */
	tail->trie = (fr_trie_t *) path->trie;
	head->trie = (fr_trie_t *) tail;

#ifdef TESTING
	/*
	 *	The two chunks must add up to the original chunk.
	 */
	fr_cond_assert(path->chunk == ((head->chunk << (path->bits - lcp)) | tail->chunk));

	/*
	 *	And writing both of them out must reproduce the
	 *	original key.
	 */
	write_chunk(&check[0], start_bit, head->bits, head->chunk);
	write_chunk(&check[0], start_bit + head->bits, tail->bits, tail->chunk);

	fr_cond_assert(check[0] == path->key[0]);
	fr_cond_assert(check[1] == path->key[1]);

	MPRINT3("%.*ssplit %02x%02x start %d split %d -> %02x%02x %02x%02x\n",
		start_bit, spaces,
		path->key[0], path->key[1],
		start_bit, head->bits,
		head->key[0], head->key[1],
		tail->key[0], tail->key[1]);
#endif

	return head;
}
889
890
/** Allocate the chain of nodes which holds the key bits [start_bit, end_bit)
 *
 *  The chain is a series of path nodes, each of which ends on a
 *  16-bit boundary of the key (or at end_bit), followed by a user
 *  node which holds "data".
 *
 * @param ctx		the talloc ctx.
 * @param key		the key.
 * @param start_bit	offset of the first bit to store.
 * @param end_bit	offset of the first bit after the key.
 * @param data		the user data to store at the end of the chain.
 * @return
 *	- NULL on error, with the reason available via fr_strerror()
 *	- the first node of the chain on success
 */
static CC_HINT(nonnull(2)) fr_trie_t *trie_key_alloc(TALLOC_CTX *ctx, uint8_t const *key, int start_bit, int end_bit, void *data)
{
	fr_trie_path_t *path;
	int next_bit;

	/*
	 *	No more bits, so all that is left is the user data.
	 */
	if (start_bit == end_bit) return (fr_trie_t *) fr_trie_user_alloc(ctx, data);

	if (start_bit > end_bit) {
		fr_strerror_printf("key_alloc asked for start >= end, %d >= %d", start_bit, end_bit);
		return NULL;
	}

	/*
	 *	Grab some more bits.  Try to grab 16 bits at a time,
	 *	but never go past the end of the key.
	 */
	next_bit = start_bit + 16 - (start_bit & 0x07);
	if (next_bit > end_bit) next_bit = end_bit;

	path = fr_trie_path_alloc(ctx, key, start_bit, next_bit);
	if (!path) return NULL;

	/*
	 *	Allocate the rest of the key.  If this path reached
	 *	the end of the key, that is just the user node.
	 *
	 *	Either way, a failure must not return a path which
	 *	leads nowhere.
	 */
	path->trie = trie_key_alloc(ctx, key, next_bit, end_bit, data);
	if (!path->trie) {
		talloc_free(path); /* no children */
		return NULL;
	}

	return (fr_trie_t *) path;
}
927
#else  /* WITH_PATH_COMPRESSION */
928
/** Allocate the chain of nodes which holds the key bits [start_bit, end_bit)
 *
 *  Without path compression, the chain is a series of 2^N-way
 *  nodes with one edge each, followed by a user node which holds
 *  "data".  Each node consumes at most MAX_NODE_BITS bits.
 *
 * @param ctx		the talloc ctx.
 * @param key		the key.
 * @param start_bit	offset of the first bit to store.
 * @param end_bit	offset of the first bit after the key.
 * @param data		the user data to store at the end of the chain.
 * @return
 *	- NULL on error
 *	- the first node of the chain on success
 */
static CC_HINT(nonnull(2)) fr_trie_t *trie_key_alloc(TALLOC_CTX *ctx, uint8_t const *key, int start_bit, int end_bit, void *data)
{
	fr_trie_node_t	*node;
	fr_trie_t	*child;
	uint16_t	chunk;
	int		bits;

	if (start_bit == end_bit) return (fr_trie_t *) fr_trie_user_alloc(ctx, data);

	/*
	 *	Take as many bits as a node can hold, or whatever is
	 *	left of the key.
	 */
	bits = end_bit - start_bit;
	if (bits > MAX_NODE_BITS) bits = MAX_NODE_BITS;

	node = trie_node_alloc(ctx, bits);
	if (node == NULL) return NULL;

	/*
	 *	We only want one edge here.
	 */
	chunk = get_chunk(key, start_bit, node->bits);
	child = trie_key_alloc(ctx, key, start_bit + node->bits, end_bit, data);
	if (child == NULL) {
		talloc_free(node); /* no children */
		return NULL;
	}

	node->trie[chunk] = child;
	node->used++;

	return (fr_trie_t *) node;
}
957
#endif
958
959
960
#if 0
961
/** Split a compressed at bits
962
 *
963
 */
964
#ifdef WITH_NODE_COMPRESSION
965
/*
 *  Split a compressed node, so that a new compressed node handles
 *  the first "bits" bits of each edge.
 *
 *  Returns the new node, or NULL on error.  On error, the path
 *  nodes allocated here and the new node are freed.
 *
 *  NOTE(review): this function sits inside "#if 0", so it is not
 *  compiled.  It also looks unfinished:
 *    - the branch for a prefix which is already in the new node is empty
 *    - "after" is computed, but never used
 *    - "key" is passed to write_chunk() without being initialized,
 *      and write_chunk() reads the existing contents of the buffer
 *    - the loop index "i" is used as the chunk value, and not comp->index[i]
 *  Confirm the intended behavior before enabling it.
 */
static fr_trie_t *trie_comp_split(TALLOC_CTX *ctx, fr_trie_comp_t *comp, int start_bit, int bits)
966
{
967
  int i;
968
  fr_trie_comp_t *split;
969
970
  /*
971
   *  Can't split zero bits, more bits than the node has, or
972
   *  a node which has 1 bit.
973
   */
974
  if ((bits == 0) || (bits >= comp->bits)) {
975
    fr_strerror_printf("invalid value for comp split (%d / %d)", bits, comp->bits);
976
    return NULL;
977
  }
978
979
  /* fr_trie_comp_alloc() also rejects bits <= 2 and bits > MAX_COMP_BITS */
  split = fr_trie_comp_alloc(ctx, bits);
980
  if (!split) return NULL;
981
982
  /* only the offset within the byte is used from here on */
  if (start_bit > 7) start_bit &= 0x07;
983
984
  // walk over the edges, seeing how many edges have the same before bits
985
  //
986
  // if all have the same bits, then split by creating a path
987
  // node, and then a child split node.
988
989
  /*
990
   *  Walk over each edge, inserting the first chunk into
991
   *  the new node, and the split node...
992
   */
993
  for (i = 0; i < comp->used; i++) {
994
    int j, where;
995
    uint16_t before, after;
996
    uint8_t key[2];
997
    fr_trie_path_t *path;
998
999
    /* "before" is the leading part of the chunk, which goes into the new node */
    before = i >> (comp->bits - bits);
1000
    after = i & ((1 << bits) - 1);
1001
1002
    write_chunk(&key[0], start_bit, comp->bits, i);
1003
1004
    // see if "before" was already used in the newly created node.
1005
1006
    where = 0;
1007
1008
    for (j = 0; j < split->used; j++) {
1009
      if (before == split->index[j]) {
1010
        where = j;
1011
        break;
1012
      }
1013
    }
1014
1015
    /* NOTE(review): this tests the stored chunk value, not whether the search above found a match */
    if (split->index[where]) {
1016
      // the children MUST be different
1017
      // create another compressed node as a child, and go from there.
1018
1019
    } else {
1020
      split->index[split->used] = before;
1021
      path = fr_trie_path_alloc(ctx, &key[0], start_bit, start_bit + bits);
1022
      if (!path) goto fail;
1023
1024
      split->trie[split->used++] = (fr_trie_t *) path;
1025
      path->trie = comp->trie[i];
1026
    }
1027
  }
1028
1029
  return (fr_trie_t *) split;
1030
1031
fail:
1032
  /* free only the path nodes allocated above, and not the children they point to */
  for (i = 0; i < split->used; i++) {
1033
    talloc_free(split->trie[i]);
1034
  }
1035
  talloc_free(split);
1036
  return NULL;
1037
}
1038
#endif  /* WITH_NODE_COMPRESSION */
1039
#endif
1040
1041
/* ADD EDGES */
1042
1043
#ifdef WITH_PATH_COMPRESSION
/** Add an edge to a node.
 *
 *  This function is so that we can abstract 2^N-way nodes, or
 *  compressed edge nodes.
 *
 *  For compressed nodes the edges are kept ordered from the smallest
 *  chunk to the largest, so the new edge is inserted at its sorted
 *  position and any later edges are moved up by one slot.
 *
 * @param trie	the node to modify.  Must be an FR_TRIE_NODE, or
 *		(with node compression) an FR_TRIE_COMP.
 * @param chunk	the edge value, must fit into the bits of the node
 * @param child	the sub-trie to hang off the edge
 * @return
 *	- <0 on error (wrong node type, chunk out of range, edge
 *	  already in use, or no room left in a compressed node)
 *	- 0 on success
 */
static int trie_add_edge(fr_trie_t *trie, uint16_t chunk, fr_trie_t *child)
{
	fr_trie_node_t *node;

#ifdef WITH_NODE_COMPRESSION
	if (trie->type == FR_TRIE_COMP) {
		fr_trie_comp_t *comp = (fr_trie_comp_t *) trie;
		int i, edge;

		if (chunk >= (1 << comp->bits)) return -1;

		if (comp->used >= MAX_COMP_EDGES) return -1;

		/*
		 *	Find the sorted position of the new edge.  By
		 *	default it is appended after the last used
		 *	edge.
		 */
		edge = comp->used;
		for (i = 0; i < comp->used; i++) {
			if (comp->index[i] < chunk) continue;

			/*
			 *	The edge already exists.  Compare
			 *	against the edge we are looking at,
			 *	not against the (unused) slot at
			 *	"comp->used".
			 */
			if (comp->index[i] == chunk) return -1;

			edge = i;
			break;
		}

		if (edge == MAX_COMP_EDGES) return -1;

		/*
		 *	Move the nodes up so that we have room for the
		 *	new edge.  Walk from the top down, so that each
		 *	entry is copied before it is overwritten.
		 */
		for (i = comp->used; i > edge; i--) {
			comp->index[i] = comp->index[i - 1];
			comp->trie[i] = comp->trie[i - 1];
		}

		comp->index[edge] = chunk;
		comp->trie[edge] = child;

		comp->used++;
		VERIFY(comp);
		return 0;
	}
#endif

	if (trie->type != FR_TRIE_NODE) return -1;

	node = (fr_trie_node_t *) trie;

	if (chunk >= (1 << node->bits)) return -1;

	/*
	 *	Refuse to overwrite an existing edge.
	 */
	if (node->trie[chunk] != NULL) return -1;

	node->used++;
	node->trie[chunk] = child;

	return 0;
}
#endif
1106
1107
/* MATCH FUNCTIONS */
1108
1109
typedef void *(*trie_key_match_t)(fr_trie_t *trie, uint8_t const *key, int start_bit, int end_bit, bool exact);
1110
1111
static void *trie_key_match(fr_trie_t *trie, uint8_t const *key, int start_bit, int end_bit, bool exact);
1112
1113
/** Match a key against a user node
 *
 *  A user node consumes no bits of the key.  It holds user data, and
 *  may have a sub-trie below it.
 *
 * @param trie		the user node
 * @param key		the key
 * @param start_bit	the start bit
 * @param end_bit	the end bit
 * @param exact		if true, only data for the complete key is returned
 * @return
 *	- NULL on not found
 *	- void* user ctx on found
 */
static void *trie_user_match(fr_trie_t *trie, uint8_t const *key, int start_bit, int end_bit, bool exact)
{
	fr_trie_user_t	*user = (fr_trie_user_t *) trie;

	/*
	 *	There are still key bits left over.  Prefer whatever
	 *	is found deeper in the trie.
	 */
	if (start_bit != end_bit) {
		void *deeper;

		deeper = trie_key_match(user->trie, key, start_bit, end_bit, exact);
		if (deeper) return deeper;

		/*
		 *	Nothing deeper, and the caller will not
		 *	accept a prefix of the key.
		 */
		if (exact) {
			MPRINT2("no exact match at %d\n", __LINE__);
			return NULL;
		}
	}

	/*
	 *	Either the key ended exactly here, or this node is the
	 *	closest (i.e. inexact) match.
	 */
	return user->data;
}
1147
1148
/** Match a key against a 2^N-way node
 *
 *  The next "node->bits" bits of the key select the edge to follow.
 *
 * @return
 *	- NULL if there is no edge for the chunk, or nothing matched below it
 *	- void* user ctx on found
 */
static void *trie_node_match(fr_trie_t *trie, uint8_t const *key, int start_bit, int end_bit, bool exact)
{
	fr_trie_node_t	*node = (fr_trie_node_t *) trie;
	uint16_t	chunk = get_chunk(key, start_bit, node->bits);
	fr_trie_t	*next = node->trie[chunk];

	if (next) return trie_key_match(next, key, start_bit + node->bits, end_bit, exact);

	MPRINT2("no match for node chunk %02x at %d\n", chunk, __LINE__);
	return NULL;
}
1161
1162
#ifdef WITH_PATH_COMPRESSION
/** Match a key against a path node
 *
 *  The next "path->bits" bits of the key must be identical to the
 *  chunk stored in the path, otherwise there is no match.
 *
 * @return
 *	- NULL on not found
 *	- void* user ctx on found
 */
static void *trie_path_match(fr_trie_t *trie, uint8_t const *key, int start_bit, int end_bit, bool exact)
{
	fr_trie_path_t	*path = (fr_trie_path_t *) trie;
	uint16_t	key_chunk = get_chunk(key, start_bit, path->bits);

	if (key_chunk == path->chunk) {
		return trie_key_match(path->trie, key, start_bit + path->bits, end_bit, exact);
	}

	return NULL;
}
#endif
1174
1175
#ifdef WITH_NODE_COMPRESSION
/** Match a key against a compressed N-way node
 *
 *  The node holds a short list of (chunk, sub-trie) edges, ordered
 *  from the smallest chunk to the largest.
 *
 * @return
 *	- NULL on not found
 *	- void* user ctx on found
 */
static void *trie_comp_match(fr_trie_t *trie, uint8_t const *key, int start_bit, int end_bit, bool exact)
{
	fr_trie_comp_t	*comp = (fr_trie_comp_t *) trie;
	uint16_t	chunk = get_chunk(key, start_bit, comp->bits);
	int		edge;

	/*
	 *	Skip over the edges which sort before the chunk.
	 */
	for (edge = 0; (edge < comp->used) && (comp->index[edge] < chunk); edge++) {
		/* nothing */
	}

	/*
	 *	Either we ran off the end of the list, or the first
	 *	edge which isn't smaller than the chunk is larger than
	 *	it.  As the edges are ordered, NO edge will match.
	 */
	if ((edge >= comp->used) || (comp->index[edge] != chunk)) return NULL;

	return trie_key_match(comp->trie[edge], key, start_bit + comp->bits, end_bit, exact);
}
#endif
1202
1203
static trie_key_match_t trie_match_table[FR_TRIE_MAX] = {
1204
  [ FR_TRIE_USER ] = trie_user_match,
1205
  [ FR_TRIE_NODE ] = trie_node_match,
1206
#ifdef WITH_PATH_COMPRESSION
1207
  [ FR_TRIE_PATH ] = trie_path_match,
1208
#endif
1209
#ifdef WITH_NODE_COMPRESSION
1210
  [ FR_TRIE_COMP ] = trie_comp_match,
1211
#endif
1212
};
1213
1214
1215
/** Match a key in a trie and return user ctx, if any
1216
 *
1217
 *  The key may be LONGER than entries in the trie.  In which case the
1218
 *  closest match is returned.
1219
 *
1220
 * @param trie    the trie
1221
 * @param key   the key
1222
 * @param start_bit the start bit
1223
 * @param end_bit the end bit
1224
 * @param exact   do we return an exact match, or the shortest one.
1225
 * @return
1226
 *  - NULL on not found
1227
 *  - void* user ctx on found
1228
 */
1229
static void *trie_key_match(fr_trie_t *trie, uint8_t const *key, int start_bit, int end_bit, bool exact)
1230
0
{
1231
0
  if (!trie) return NULL;
1232
1233
  /*
1234
   *  We've run out of trie, so it's not a match.
1235
   */
1236
0
  if ((start_bit + trie->bits) > end_bit) {
1237
0
    MPRINT2("%d + %d = %d > %d\n",
1238
0
           start_bit, trie->bits, start_bit + trie->bits, end_bit);
1239
#ifdef TESTING
1240
    MPRINT2("no match for key too short for trie NODE-%d at %d\n", trie->number, __LINE__);
1241
#endif
1242
0
    return NULL;
1243
0
  }
1244
1245
0
  TRIE_TYPE_CHECK(match, NULL);
1246
1247
  /*
1248
   *  Recursively match each type.
1249
   */
1250
0
  return trie_match_table[trie->type](trie, key, start_bit, end_bit, exact);
1251
0
}
1252
1253
/** Lookup a key in a trie and return user ctx, if any
1254
 *
1255
 *  The key may be LONGER than entries in the trie.  In which case the
1256
 *  closest match is returned.
1257
 *
1258
 * @param ft   the trie
1259
 * @param key  the key bytes
1260
 * @param keylen length in bits of the key
1261
 * @return
1262
 *  - NULL on not found
1263
 *  - void* user ctx on found
1264
 */
1265
void *fr_trie_lookup_by_key(fr_trie_t const *ft, void const *key, size_t keylen)
1266
0
{
1267
0
  fr_trie_user_t *user;
1268
1269
0
  if (keylen > MAX_KEY_BITS) return NULL;
1270
1271
0
  if (!ft->trie) return NULL;
1272
1273
0
  user = UNCONST(fr_trie_user_t *, ft);
1274
1275
0
  return trie_key_match(user->trie, key, 0, keylen, false);
1276
0
}
1277
1278
/** Match a key and length in a trie and return user ctx, if any
1279
 *
1280
 * Only the exact match is returned.
1281
 *
1282
 * @param ft   the trie
1283
 * @param key  the key bytes
1284
 * @param keylen length in bits of the key
1285
 * @return
1286
 *  - NULL on not found
1287
 *  - void* user ctx on found
1288
 */
1289
void *fr_trie_match_by_key(fr_trie_t const *ft, void const *key, size_t keylen)
1290
0
{
1291
0
  fr_trie_user_t *user;
1292
1293
0
  if (keylen > MAX_KEY_BITS) return NULL;
1294
1295
0
  if (!ft->trie) return NULL;
1296
1297
0
  user = UNCONST(fr_trie_user_t *, ft);
1298
1299
0
  return trie_key_match(user->trie, key, 0, keylen, true);
1300
0
}
1301
1302
/* INSERT FUNCTIONS */
1303
1304
#ifdef TESTING
/** Sanity check an insert (TESTING builds only)
 *
 *  Prints the trie, then does an exact match of the key and asserts
 *  that it finds the data which was just inserted.
 *
 * @param trie		the trie to check
 * @param key		the key which was inserted
 * @param start_bit	the start bit
 * @param end_bit	the end bit
 * @param data		the user data which is expected to be found
 * @param lineno	line number of the caller, for the messages
 */
static void trie_check(fr_trie_t *trie, uint8_t const *key, int start_bit, int end_bit, void *data, int lineno)
{
	void *found;

	trie_sprint(trie, key, start_bit, lineno);

	found = trie_key_match(trie, key, start_bit, end_bit, true);

	/*
	 *	The key isn't in the trie at all.
	 */
	if (found == NULL) {
		fr_strerror_printf("Failed trie check answer at %d", lineno);

		MPRINT3("%.*sFailed to find user data %s from start %d end %d at %d\n", start_bit, spaces, data,
			start_bit, end_bit, lineno);
		fr_cond_assert(0);
	}

	/*
	 *	The key is there, but leads to someone else's data.
	 */
	if (found != data) {
		fr_strerror_printf("Failed trie check answer == data at %d", lineno);

		MPRINT3("%.*sFound wrong user data %s != %s, from start %d end %d at %d\n", start_bit, spaces,
			found, data, start_bit, end_bit, lineno);
		fr_cond_assert(0);
	}
}
#else
#define trie_check(_trie, _key, _start_bit, _end_bit, _data, _lineno)
#endif
1332
1333
static int trie_key_insert(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit, void *data) CC_HINT(nonnull(2,3,6));
1334
1335
typedef int (*trie_key_insert_t)(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit, void *data);
1336
1337
/** Insert a key below a user node
 *
 *  The user node itself is left alone, the key is inserted into the
 *  sub-trie hanging off it.
 *
 * @return
 *	- <0 on error
 *	- 0 on success
 */
static int trie_user_insert(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit, void *data)
{
	fr_trie_user_t *user = (fr_trie_user_t *) *trie_p;

	MPRINT3("user insert to start %d end %d with data %s\n", start_bit, end_bit, (char *) data);

	MPRINT3("%.*srecurse at %d\n", start_bit, spaces, __LINE__);

	return trie_key_insert(ctx, &user->trie, key, start_bit, end_bit, data);
}
1350
1351
/** Insert a key below a 2^N-way node
 *
 *  If the key has fewer bits left than the node consumes, the node is
 *  first split into a smaller node which fits the key.  The key is
 *  then inserted into the edge selected by its next chunk.
 *
 * @param ctx		the talloc ctx
 * @param[in,out] trie_p the node, which is replaced if it was split
 * @param key		the binary key
 * @param start_bit	the start bit
 * @param end_bit	the end bit
 * @param data		user data to insert
 * @return
 *	- <0 on error
 *	- 0 on success
 */
static int trie_node_insert(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit, void *data)
{
	fr_trie_t *trie = *trie_p;
	fr_trie_node_t *node = (fr_trie_node_t *) trie;
	fr_trie_t *trie_to_free = NULL;
	uint32_t chunk;

	MPRINT3("%.*snode insert end %d with data %s\n",
		start_bit, spaces, end_bit, (char *) data);

	/*
	 *	The current node is longer than the input bits
	 *	for the key.  Split the node into a smaller
	 *	N-way node, and insert the key into the (now
	 *	fitting) node.
	 */
	if ((start_bit + node->bits) > end_bit) {
		fr_trie_node_t *split;

		/*
		 *	Print the width we split at, i.e. the number
		 *	of bits left in the key.
		 */
		MPRINT3("%.*snode insert splitting %d at %d start %d end %d with data %s\n",
			start_bit, spaces,
			node->bits, end_bit - start_bit,
			start_bit, end_bit, (char *) data);

		split = trie_node_split(ctx, node, end_bit - start_bit);
		if (!split) {
			fr_strerror_printf("Failed splitting node at %d\n", __LINE__);
			return -1;
		}

		/*
		 *	Remember the old node, it is freed once we
		 *	are done with it.
		 */
		trie_to_free = (fr_trie_t *) node;
		node = split;
	}

	chunk = get_chunk(key, start_bit, node->bits);

	/*
	 *	No existing trie, create a brand new trie from
	 *	the key.
	 */
	if (!node->trie[chunk]) {
		node->trie[chunk] = trie_key_alloc(ctx, key, start_bit + node->bits, end_bit, data);
		if (!node->trie[chunk]) {
			fr_strerror_printf("Failed key_alloc at %d\n", __LINE__);

			/*
			 *	NOTE(review): this frees the original
			 *	node while *trie_p still points to it,
			 *	and the split node is not released.
			 *	Confirm against trie_node_split().
			 */
			if (trie_to_free) trie_free(trie_to_free);
			return -1;
		}
		node->used++;

	} else {
		/*
		 *	Recurse in order to insert the key
		 *	into the current node.
		 */
		MPRINT3("%.*srecurse at %d\n", start_bit, spaces, __LINE__);
		if (trie_key_insert(ctx, &node->trie[chunk], key, start_bit + node->bits, end_bit, data) < 0) {
			MPRINT("Failed recursing at %d\n", __LINE__);

			/*
			 *	NOTE(review): same concern as for the
			 *	trie_key_alloc() failure above.
			 */
			if (trie_to_free) trie_free(trie_to_free);
			return -1;
		}
	}

	trie_check((fr_trie_t *) node, key, start_bit, end_bit, data, __LINE__);

	MPRINT3("%.*snode insert returning at %d\n",
		start_bit, spaces, __LINE__);

	/*
	 *	Only update the parent once the insert has succeeded.
	 */
	if (trie_to_free) trie_free(trie_to_free);
	*trie_p = (fr_trie_t *) node;
	VERIFY(node);
	return 0;
}
1423
1424
#ifdef WITH_PATH_COMPRESSION
/** Insert a key below a path node
 *
 *  There are three cases:
 *
 *  - the key contains the complete path.  Recurse into the child of
 *    the path.
 *  - the key shares a prefix (LCP) with the path.  Split the path at
 *    the LCP, and insert the key after the first half.
 *  - the key shares nothing with the path.  Replace the path with a
 *    fanout node which has two edges: one for the old path, and one
 *    for the new key.
 *
 *  The parent pointer is only updated when the insert succeeds.  On
 *  failure, everything which was allocated here is freed again.
 *
 * @param ctx		the talloc ctx
 * @param[in,out] trie_p the path node, which may be replaced
 * @param key		the binary key
 * @param start_bit	the start bit
 * @param end_bit	the end bit
 * @param data		user data to insert
 * @return
 *	- <0 on error
 *	- 0 on success
 */
static CC_HINT(nonnull(2,3,6)) int trie_path_insert(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit, void *data)
{
	fr_trie_t *trie = *trie_p;
	fr_trie_path_t *path = (fr_trie_path_t *) trie;
	uint32_t chunk;
	int lcp, bits;
	uint8_t const *key2;
	int start_bit2;
	fr_trie_t *node;
	fr_trie_t *child;

	MPRINT3("%.*spath insert start %d end %d with key %02x%02x data %s\n",
		start_bit, spaces, start_bit, end_bit, key[0], key[1], (char *) data);

	VERIFY(path);
	trie_sprint((fr_trie_t *) path, key, start_bit, __LINE__);

	/*
	 *	The key is at least as long as the path.  It may
	 *	contain the whole path.
	 */
	if (start_bit + path->bits <= end_bit) {
		chunk = get_chunk(key, start_bit, path->bits);

		/*
		 *	The chunk matches exactly.  Recurse to
		 *	insert the key into the child trie.
		 */
		if (chunk == path->chunk) {
			MPRINT3("%.*spath chunk matches %04x bits of %d\n",
				start_bit, spaces, chunk, path->bits);
			MPRINT3("%.*srecurse at %d\n", start_bit, spaces, __LINE__);
			if (trie_key_insert(ctx, &path->trie, key, start_bit + path->bits, end_bit, data) < 0) {
				return -1;
			}

			trie_check((fr_trie_t *) path, key, start_bit, end_bit, data, __LINE__);

			MPRINT3("%.*spath returning at %d\n", start_bit, spaces, __LINE__);
			VERIFY(path);
			return 0;
		}

		bits = path->bits;
		MPRINT3("%.*spath using %d\n", start_bit, spaces, path->bits);

	} else {
		/*
		 *	Limit the number of bits we check to
		 *	the number of bits left in the key.
		 */
		bits = end_bit - start_bit;
		MPRINT3("%.*spath limiting %d to %d\n", start_bit, spaces, path->bits, bits);
	}

	/*
	 *	Figure out what part of the key we need to
	 *	look at for LCP.  i.e. skip the whole bytes before
	 *	start_bit, and keep the bit offset into the first
	 *	byte.
	 */
	key2 = key;
	start_bit2 = start_bit;
	if (start_bit2 > 7) {
		key2 += (start_bit2 >> 3);
		start_bit2 -= 8 * (start_bit2 >> 3);
	}

	/*
	 *	Get the LCP.  If we have one, split the path
	 *	node at the LCP.  Replace the parent with the
	 *	first half of the path, and build an N-way
	 *	node for the second half.
	 */
	lcp = fr_trie_key_lcp(&path->key[0], bits, key2, bits, start_bit2);
	MPRINT3("%.*spath lcp %d\n", start_bit, spaces, lcp);

	/*
	 *	This should have been caught above.
	 */
	if (lcp == path->bits) {
		fr_strerror_const("found lcp which should have been previously found");
		return -1;
	}

	if (lcp > 0) {
		fr_trie_path_t *split;

		/*
		 *	Note that "path" is still valid after this
		 *	call.  We will rewrite things on the way back
		 *	up the stack.
		 */
		MPRINT3("%.*spath split depth %d bits %d at lcp %d with data %s\n",
			start_bit, spaces, start_bit, path->bits, lcp, (char *) data);

		MPRINT3("%.*spath key %02x%02x input key %02x%02x, offset %d\n",
			start_bit, spaces,
			path->key[0],path->key[1],
			key[0], key[1],
			start_bit2);

		split = trie_path_split(ctx, path, start_bit2, lcp);
		if (!split) {
			fr_strerror_printf("failed path split at %d\n", __LINE__);
			return -1;
		}

		trie_sprint((fr_trie_t *) path, key, start_bit, __LINE__);
		trie_sprint((fr_trie_t *) split, key, start_bit, __LINE__);
		trie_sprint((fr_trie_t *) split->trie, key, start_bit + split->bits, __LINE__);

		/*
		 *	Recurse to insert the key into the child node.
		 *	Note that if "bits > MAX_NODE_BITS", we will
		 *	have to split "path" again.
		 */
		MPRINT3("%.*srecurse at %d\n", start_bit, spaces, __LINE__);
		if (trie_key_insert(ctx, &split->trie, key, start_bit + split->bits, end_bit, data) < 0) {
			talloc_free(split->trie);
			talloc_free(split);
			return -1;
		}

		/*
		 *	We can't have two LCPs in a row here, as we
		 *	SHOULD have found the LCP above!
		 */
		fr_cond_assert(split->type == FR_TRIE_PATH);
		fr_cond_assert(split->trie->type != FR_TRIE_PATH);

		trie_check((fr_trie_t *) split, key, start_bit, end_bit, data, __LINE__);

		MPRINT3("%.*spath returning at %d\n", start_bit, spaces, __LINE__);
		talloc_free(path);
		*trie_p = (fr_trie_t *) split;
		VERIFY(split);
		return 0;
	}

	/*
	 *	Else there's no common prefix.  Just create an
	 *	fanout node.
	 */
	/*
	 *	We only want two edges here. Try to create a
	 *	compressed N-way node if possible.
	 */
#ifdef WITH_NODE_COMPRESSION
	if (bits > 2) {
		if (bits > MAX_COMP_BITS) bits = MAX_COMP_BITS;

		MPRINT3("%.*sFanout to comp %d at depth %d data %s\n", start_bit, spaces, bits, start_bit, (char *) data);
		node = (fr_trie_t *) fr_trie_comp_alloc(ctx, bits);
	} else
#endif
	{
		/*
		 *	Without path compression create no more than a
		 *	16-way node.
		 */
		if (bits > MAX_NODE_BITS) bits = MAX_NODE_BITS;

		MPRINT3("%.*sFanout to node %d at depth %d data %s\n", start_bit, spaces, bits, start_bit, (char *) data);
		node = (fr_trie_t *) trie_node_alloc(ctx, bits);
	}
	if (!node) return -1;

	/*
	 *	Get the chunk from the path, and insert the child trie
	 *	into the node at that chunk.
	 */
	chunk = get_chunk(&path->key[0], start_bit2, node->bits);

	if (node->bits == path->bits) {
		/*
		 *	The fanout node consumes the whole path, so
		 *	the edge goes straight to the child of the
		 *	path.
		 */
		child = path->trie;

	} else {
		/*
		 *	Skip the bits consumed by the fanout node, and
		 *	keep the rest of the path in a new path node.
		 */
		child = (fr_trie_t *) fr_trie_path_alloc(ctx, &path->key[0], start_bit2 + node->bits, start_bit2 + path->bits);
		if (!child) {
			fr_strerror_printf("failed allocating path child at %d", __LINE__);
			talloc_free(node);	/* don't leak the fanout node */
			return -1;
		}

		/*
		 *	Patch in the child trie.
		 */
		((fr_trie_path_t *)child)->trie = path->trie;

		VERIFY(child);
	}

	trie = NULL;

	/*
	 *	Recurse to insert the key into the second edge.  If
	 *	this fails, then we haven't changed anything.  So just
	 *	free memory and return.
	 *
	 *	Note that if "bits > DEFAULT_BITS", we will have to
	 *	split "path" again.
	 */
	MPRINT3("%.*srecurse at %d\n", start_bit, spaces, __LINE__);
	if (trie_key_insert(ctx, &trie, key, start_bit + node->bits, end_bit, data) < 0) {
		talloc_free(node);
		if (child != path->trie) talloc_free(child);
		return -1;
	}

	/*
	 *	Copy the first edge over to the first chunk.
	 */
	if (trie_add_edge(node, chunk, child) < 0) {
		fr_strerror_printf("chunk failure in insert node %d at %d", node->bits, __LINE__);
		talloc_free(node);
		if (child != path->trie) talloc_free(child);
		trie_free(trie);	/* the new sub-trie was never linked in */
		return -1;
	}

	/*
	 *	Copy the second edge from the new chunk.
	 */
	chunk = get_chunk(key, start_bit, node->bits);
	if (trie_add_edge(node, chunk, trie) < 0) {
		fr_strerror_printf("chunk failure in insert node %d at %d", node->bits, __LINE__);
		talloc_free(node);
		if (child != path->trie) talloc_free(child);	/* only the new path node, not the old child trie */
		trie_free(trie);
		return -1;
	}

	trie_check((fr_trie_t *) node, key, start_bit, end_bit, data, __LINE__);

	MPRINT3("%.*spath returning at %d\n", start_bit, spaces, __LINE__);

	/*
	 *	Only update the answer if the insert succeeded.
	 */
	*trie_p = node;
	talloc_free(path);
	VERIFY(node);
	return 0;
}
#endif
1668
1669
#ifdef WITH_NODE_COMPRESSION
/** Insert a key below a compressed N-way node
 *
 *  - if an edge already exists for the chunk, recurse into it.
 *  - else if there is room for another edge, build the new sub-trie
 *    and add it as a new edge.
 *  - else all edges are used, and the node is converted to a full
 *    2^N-way node.
 *
 *  Keys which end inside of the node are not supported.
 *
 * @param ctx		the talloc ctx
 * @param[in,out] trie_p the compressed node, which may be replaced
 * @param key		the binary key
 * @param start_bit	the start bit
 * @param end_bit	the end bit
 * @param data		user data to insert
 * @return
 *	- <0 on error
 *	- 0 on success
 */
static CC_HINT(nonnull(2,3,6)) int trie_comp_insert(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit, void *data)
{
	int i, bits;
	fr_trie_t *trie = *trie_p;
	fr_trie_comp_t *comp = (fr_trie_comp_t *) trie;
	fr_trie_node_t *node;
	uint16_t chunk;

	MPRINT3("%.*scomp insert start %d end %d with key %02x%02x data %s\n",
		start_bit, spaces, start_bit, end_bit, key[0], key[1], (char *) data);

	/*
	 *	The key ends in the middle of this node.  We would
	 *	have to split the node, which isn't done yet.
	 */
	if ((end_bit - start_bit) < comp->bits) {
		fr_strerror_printf("Not implemented at %d", __LINE__);
		return -1;
	}

	chunk = get_chunk(key, start_bit, comp->bits);

	/*
	 *	Search for a matching edge.  If found, recurse and
	 *	insert the key there.
	 */
	for (i = 0; i < comp->used; i++) {
		if (comp->index[i] < chunk) continue;

		/*
		 *	We've found a matching chunk, recurse.
		 */
		if (comp->index[i] == chunk) {
			MPRINT3("%.*srecurse at %d\n", start_bit, spaces, __LINE__);
			if (trie_key_insert(ctx, &comp->trie[i], key, start_bit + comp->bits, end_bit, data) < 0) {
				MPRINT3("%.*scomp failed recursing at %d", start_bit, spaces, __LINE__);
				return -1;
			}

			trie_check((fr_trie_t *) comp, key, start_bit, end_bit, data, __LINE__);

			MPRINT3("%.*scomp returning at %d", start_bit, spaces, __LINE__);
			VERIFY(comp);
			return 0;
		}

		/*
		 *	The current edge is larger than the chunk.  As
		 *	the edges are ordered, no later edge can
		 *	match.  Stop.
		 */
		break;
	}

	/*
	 *	No edge matches the chunk from the key.  Insert the
	 *	child trie into a place-holder entry, so that we don't
	 *	modify the current node on failure.
	 */
	if (comp->used < MAX_COMP_EDGES) {
		MPRINT3("%.*srecurse at %d\n", start_bit, spaces, __LINE__);
		trie = NULL;
		if (trie_key_insert(ctx, &trie, key, start_bit + comp->bits, end_bit, data) < 0) {
			MPRINT3("%.*scomp failed recursing at %d", start_bit, spaces, __LINE__);
			return -1;
		}
		fr_cond_assert(trie != NULL);

		if (trie_add_edge((fr_trie_t *) comp, chunk, trie) < 0) {
			/*
			 *	The new sub-trie may be made of
			 *	multiple nodes, so free all of it, as
			 *	the other insert functions do.
			 */
			trie_free(trie);
			return -1;
		}

		trie_check((fr_trie_t *) comp, key, start_bit, end_bit, data, __LINE__);

		VERIFY(comp);
		return 0;
	}

	/*
	 *	All edges are used.  Create an N-way node.
	 */

	/*
	 *	@todo - limit bits by calling
	 *	trie_comp_split()?
	 */
	bits = comp->bits;

	MPRINT3("%.*scomp swapping to node bits %d at %d\n", start_bit, spaces, bits, __LINE__);

	node = trie_node_alloc(ctx, bits);
	if (!node) return -1;

	/*
	 *	Copy the existing edges over.  The compressed node is
	 *	left untouched until the insert has succeeded.
	 */
	for (i = 0; i < comp->used; i++) {
		fr_cond_assert(node->trie[comp->index[i]] == NULL);
		node->trie[comp->index[i]] = comp->trie[i];
	}
	node->used = comp->used;
	node->used += (node->trie[chunk] == NULL); /* will get set if the recursive insert succeeds */

	/*
	 *	Insert the new chunk, which may or may not overlap
	 *	with an existing one.
	 */
	MPRINT3("%.*srecurse at %d\n", start_bit, spaces, __LINE__);
	if (trie_key_insert(ctx, &node->trie[chunk], key, start_bit + node->bits, end_bit, data) < 0) {
		MPRINT3("%.*scomp failed recursing at %d", start_bit, spaces, __LINE__);
		talloc_free(node);
		return -1;
	}

	trie_check((fr_trie_t *) node, key, start_bit, end_bit, data, __LINE__);

	MPRINT3("%.*scomp returning at %d", start_bit, spaces, __LINE__);

	/*
	 *	The 2^N-way node now owns the edges, so only the
	 *	compressed node itself is freed.
	 */
	talloc_free(comp);
	*trie_p = (fr_trie_t *) node;
	VERIFY(node);
	return 0;
}
#endif
1787
1788
static trie_key_insert_t trie_insert_table[FR_TRIE_MAX] = {
1789
  [ FR_TRIE_USER ] = trie_user_insert,
1790
  [ FR_TRIE_NODE ] = trie_node_insert,
1791
#ifdef WITH_PATH_COMPRESSION
1792
  [ FR_TRIE_PATH ] = trie_path_insert,
1793
#endif
1794
#ifdef WITH_NODE_COMPRESSION
1795
  [ FR_TRIE_COMP ] = trie_comp_insert,
1796
#endif
1797
};
1798
1799
/** Insert a binary key into the trie
1800
 *
1801
 *  The key must have at least ((start_bit + keylen) >> 3) bytes
1802
 *
1803
 * @param ctx   the talloc ctx
1804
 * @param[in,out] trie_p the trie where things are inserted
1805
 * @param key   the binary key
1806
 * @param start_bit the start bit
1807
 * @param end_bit the end bit
1808
 * @param data    user data to insert
1809
 * @return
1810
 *  - <0 on error
1811
 *  - 0 on success
1812
 */
1813
static int trie_key_insert(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit, void *data)
1814
0
{
1815
0
  fr_trie_t *trie = *trie_p;
1816
1817
  /*
1818
   *  We've reached the end of the trie, but may still have
1819
   *  key bits to insert.
1820
   */
1821
0
  if (!trie) {
1822
0
    *trie_p = trie_key_alloc(ctx, key, start_bit, end_bit, data);
1823
0
    if (!*trie_p) return -1;
1824
0
    return 0;
1825
0
  }
1826
1827
0
  MPRINT3("%.*sIN recurse at %d\n", start_bit, spaces, __LINE__);
1828
0
  trie_sprint(trie, key, start_bit, __LINE__);
1829
1830
  /*
1831
   *  We've reached the end of the key.  Insert a user node
1832
   *  here, and push the remaining bits of the trie to after
1833
   *  the user node.
1834
   */
1835
0
  if (start_bit == end_bit) {
1836
0
    fr_trie_user_t *user;
1837
1838
0
    if (trie->type == FR_TRIE_USER) {
1839
0
      fr_strerror_printf("already has a user node at %d\n", __LINE__);
1840
0
      return -1;
1841
0
    }
1842
1843
0
    user = fr_trie_user_alloc(ctx, data);
1844
0
    if (!user) return -1;
1845
1846
0
    user->trie = trie;
1847
0
    *trie_p = (fr_trie_t *) user;
1848
0
    return 0;
1849
0
  }
1850
1851
0
  TRIE_TYPE_CHECK(insert, -1);
1852
1853
0
#ifndef TESTING
1854
0
  return trie_insert_table[trie->type](ctx, trie_p, key, start_bit, end_bit, data);
1855
#else
1856
  MPRINT3("%.*srecurse at start %d end %d with data %s\n", start_bit, spaces, start_bit, end_bit, (char *) data);
1857
1858
  if (trie_insert_table[trie->type](ctx, trie_p, key, start_bit, end_bit, data) < 0) {
1859
    return -1;
1860
  }
1861
1862
  trie_check(*trie_p, key, start_bit, end_bit, data, __LINE__);
1863
1864
  return 0;
1865
#endif
1866
0
}
1867
1868
/** Insert a key and user ctx into a trie
1869
 *
1870
 * @param ft   the trie
1871
 * @param key  the key
1872
 * @param keylen key length in bits
1873
 * @param data   user ctx information to associated with the key
1874
 * @return
1875
 *  - <0 on error
1876
 *  - 0 on success
1877
 */
1878
int fr_trie_insert_by_key(fr_trie_t *ft, void const *key, size_t keylen, void const *data)
1879
0
{
1880
0
  void *my_data;
1881
0
  fr_trie_user_t *user;
1882
1883
0
  if (keylen > MAX_KEY_BITS) {
1884
0
    fr_strerror_printf("keylen too long (%u > %d)", (unsigned int) keylen, MAX_KEY_BITS);
1885
0
    return -1;
1886
0
  }
1887
1888
0
  user = (fr_trie_user_t *) ft;
1889
1890
  /*
1891
   *  Do a lookup before insertion.  If we tried to insert
1892
   *  the key with new nodes and then discovered a conflict,
1893
   *  we would not be able to undo the process.  This check
1894
   *  ensures that the insertion can modify the trie in
1895
   *  place without worry.
1896
   */
1897
0
  if (trie_key_match(user->trie, key, 0, keylen, true) != NULL) {
1898
0
    fr_strerror_const("Cannot insert due to pre-existing key");
1899
0
    return -1;
1900
0
  }
1901
1902
0
  my_data = UNCONST(void *, data);
1903
0
  MPRINT2("No match for data, inserting...\n");
1904
1905
0
  MPRINT3("%.*srecurse STARTS at %d with %.*s=%s\n", 0, spaces, __LINE__,
1906
0
    (int) keylen, key, my_data);
1907
0
  return trie_key_insert(user->data, &user->trie, key, 0, keylen, my_data);
1908
0
}
1909
1910
/* REMOVE FUNCTIONS */
1911
static void *trie_key_remove(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit);
1912
1913
typedef void *(*trie_key_remove_t)(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit);
1914
1915
/** Remove a key at, or below, a user node
 *
 * @return
 *	- NULL on not found
 *	- the user data which was removed
 */
static void *trie_user_remove(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit)
{
	fr_trie_user_t	*user = (fr_trie_user_t *) *trie_p;
	void		*data;

	/*
	 *	The key continues past this node, so the thing to
	 *	remove is somewhere in the sub-trie.
	 */
	if (start_bit != end_bit) {
		return trie_key_remove(ctx, &user->trie, key, start_bit, end_bit);
	}

	/*
	 *	The key ends here.  Unlink this node by pointing the
	 *	parent at our sub-trie, then free the node and return
	 *	its data.
	 */
	data = user->data;
	*trie_p = user->trie;
	talloc_free(user);

	return data;
}
1936
1937
/** Remove a key below a 2^N-way node
 *
 *  If removing the key empties the last edge of the node, then the
 *  node is freed, and the parent pointer is set to NULL.
 *
 * @return
 *	- NULL on not found
 *	- the user data which was removed
 */
static void *trie_node_remove(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit)
{
	fr_trie_node_t	*node = (fr_trie_node_t *) *trie_p;
	uint32_t	chunk = get_chunk(key, start_bit, node->bits);
	void		*data;

	/*
	 *	No edge for this chunk, so the key isn't here.
	 */
	if (node->trie[chunk] == NULL) return NULL;

	data = trie_key_remove(ctx, &node->trie[chunk], key, start_bit + node->bits, end_bit);
	if (data == NULL) return NULL;

	/*
	 *	The edge still leads somewhere, nothing more to do.
	 */
	if (node->trie[chunk] != NULL) return data;

	/*
	 *	The edge is gone.  If other edges are still in use,
	 *	then the node stays.
	 *
	 *	@todo - if we have path compression, and
	 *	node->used==1, then create a fr_trie_path_t from the
	 *	chunk, and concatenate it (if necessary) to any
	 *	trailing path compression node.
	 */
	node->used--;
	if (node->used > 0) return data;

	/*
	 *	That was the last edge.  Delete the node as we walk
	 *	back up the trie.
	 */
	*trie_p = NULL;
	talloc_free(node); /* no children */
	return data;
}
1975
1976
#ifdef WITH_PATH_COMPRESSION
/** Remove a key below a path node
 *
 *  If removing the key leaves the path without a child, then the
 *  path is freed, and the parent pointer is set to NULL.
 *
 * @return
 *	- NULL on not found
 *	- the user data which was removed
 */
static void *trie_path_remove(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit)
{
	fr_trie_path_t	*path = (fr_trie_path_t *) *trie_p;
	uint32_t	chunk = get_chunk(key, start_bit, path->bits);
	void		*data;

	/*
	 *	The key doesn't follow this path, so it can't be
	 *	below it.
	 */
	if (chunk != path->chunk) return NULL;

	data = trie_key_remove(ctx, &path->trie, key, start_bit + path->bits, end_bit);
	if (data == NULL) return NULL;

	/*
	 *	Nothing is left below the path.  Delete it as we walk
	 *	back up the trie.
	 */
	if (path->trie == NULL) {
		*trie_p = NULL;
		talloc_free(path); /* no children */
	}

	return data;
}
#endif
2007
2008
#ifdef WITH_NODE_COMPRESSION
/** Remove a key below a compressed N-way node
 *
 *  If removing the key empties an edge, the edge is dropped from the
 *  node.  A node which is left with one edge is converted to a path
 *  node, and a node which is left with no edges is freed.
 *
 * @return
 *	- NULL on not found
 *	- the user data which was removed
 */
static void *trie_comp_remove(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit)
{
	fr_trie_comp_t	*comp = *(fr_trie_comp_t **) trie_p;
	fr_trie_path_t	*path;
	uint16_t	chunk;
	void		*data;
	int		edge, next;

	chunk = get_chunk(key, start_bit, comp->bits);

	MPRINT3("%.*sremove at %d\n", start_bit, spaces, __LINE__);
	trie_sprint(*trie_p, key, start_bit, __LINE__);

	/*
	 *	The edges are ordered smallest to largest.  Skip the
	 *	ones which sort before the chunk.
	 */
	for (edge = 0; (edge < comp->used) && (comp->index[edge] < chunk); edge++) {
		/* nothing */
	}

	/*
	 *	Ran off the end of the list, or hit an edge which is
	 *	larger than the chunk.  Either way, the key isn't
	 *	here.
	 */
	if ((edge >= comp->used) || (comp->index[edge] > chunk)) return NULL;

	fr_cond_assert(chunk == comp->index[edge]);

	data = trie_key_remove(ctx, &comp->trie[edge], key, start_bit + comp->bits, end_bit);
	if (data == NULL) return NULL;

	/*
	 *	The edge still leads somewhere, nothing more to do.
	 */
	if (comp->trie[edge] != NULL) {
		MPRINT3("%.*sremove at %d\n", start_bit, spaces, __LINE__);
		trie_sprint((fr_trie_t *) comp, key, start_bit, __LINE__);
		VERIFY(comp);
		return data;
	}

	/*
	 *	Drop the now-empty edge by copying the later edges
	 *	down over it.  If it was the last edge, there's
	 *	nothing to copy.
	 */
	for (next = edge + 1; next < comp->used; next++) {
		comp->index[next - 1] = comp->index[next];
		comp->trie[next - 1] = comp->trie[next];
	}
	comp->used--;

	/*
	 *	Still worth keeping as a compressed node.
	 */
	if (comp->used >= 2) {
		VERIFY(comp);
		return data;
	}

	/*
	 *	No edges are left at all.  Delete the node as we walk
	 *	back up the trie.  We hope that this doesn't happen.
	 */
	if (comp->used == 0) {
		*trie_p = NULL;
		talloc_free(comp); /* no children */
		MPRINT3("%.*sremove at %d\n", start_bit, spaces, __LINE__);
		return data;
	}

	/*
	 *	Exactly one edge is left, so turn the node back into a
	 *	path node.  The path is allocated from "key", which is
	 *	the key being removed and not the key left in the
	 *	node.  The chunk and key are corrected below.
	 *
	 *	If the allocation fails, the single-edge compressed
	 *	node is left in place.
	 *
	 *	@todo - check the child. If it's also a path node, try
	 *	to concatenate the nodes together.
	 */
	path = fr_trie_path_alloc(ctx, key, start_bit, start_bit + comp->bits);
	if (path == NULL) return data;

	/*
	 *	Tie the new node in.
	 */
	path->trie = comp->trie[0];

	/*
	 *	Overwrite the chunk and key with the ones for the
	 *	remaining edge.
	 */
	path->chunk = comp->index[0];
	write_chunk(&path->key[0], start_bit & 0x07, path->bits, path->chunk);

	*trie_p = (fr_trie_t *) path;
	talloc_free(comp);
	VERIFY(path);
	return data;
}
#endif
2114
2115
static trie_key_remove_t trie_remove_table[FR_TRIE_MAX] = { /* remove handlers, indexed by trie->type (see trie_key_remove) */
2116
  [ FR_TRIE_USER ] = trie_user_remove,
2117
  [ FR_TRIE_NODE ] = trie_node_remove,
2118
#ifdef WITH_PATH_COMPRESSION
2119
  [ FR_TRIE_PATH ] = trie_path_remove,
2120
#endif /* WITH_PATH_COMPRESSION */
2121
#ifdef WITH_NODE_COMPRESSION
2122
  [ FR_TRIE_COMP ] = trie_comp_remove,
2123
#endif /* WITH_NODE_COMPRESSION */
2124
};
2125
2126
/** Remove a key from a trie, and return the user data.
2127
 *
2128
 */
2129
static void *trie_key_remove(TALLOC_CTX *ctx, fr_trie_t **trie_p, uint8_t const *key, int start_bit, int end_bit)
2130
0
{
2131
0
  fr_trie_t *trie = *trie_p;
2132
2133
0
  if (!trie) return NULL;
2134
2135
  /*
2136
   *  We can't remove a key which is shorter than the
2137
   *  current trie.
2138
   */
2139
0
  if ((start_bit + trie->bits) > end_bit) return NULL;
2140
2141
0
  TRIE_TYPE_CHECK(remove, NULL);
2142
2143
0
  return trie_remove_table[trie->type](ctx, trie_p, key, start_bit, end_bit);
2144
0
}
2145
2146
/** Remove a key and return the associated user ctx
2147
 *
2148
 *  The key must match EXACTLY.  This is not a prefix match.
2149
 *
2150
 * @param ft   the trie
2151
 * @param key  the key
2152
 * @param keylen key length in bits
2153
 * @return
2154
 *  - NULL on not found
2155
 *  - user ctx data on success
2156
 */
2157
void *fr_trie_remove_by_key(fr_trie_t *ft, void const *key, size_t keylen)
2158
0
{
2159
0
  fr_trie_user_t *user;
2160
2161
0
  if (keylen > MAX_KEY_BITS) return NULL;
2162
2163
0
  if (!ft->trie) return NULL;
2164
2165
0
  user = (fr_trie_user_t *) ft;
2166
2167
  /*
2168
   *  Remove the user trie, not ft->trie.
2169
   */
2170
0
  return trie_key_remove(user->data, &user->trie, key, 0, (int) keylen);
2171
0
}
2172
2173
/* WALK FUNCTIONS */
2174
2175
typedef struct fr_trie_callback_s fr_trie_callback_t;
2176
2177
typedef int (*fr_trie_key_walk_t)(fr_trie_t *trie, fr_trie_callback_t *cb, int depth, bool more);
2178
2179
static int trie_key_walk(fr_trie_t *trie, fr_trie_callback_t *cb, int depth, bool more);
2180
2181
struct fr_trie_callback_s {
2182
  uint8_t     *start; /* key buffer; the walk functions write each node's chunk into it */
2183
  uint8_t const   *end; /* end of the key buffer; trie_key_walk() stops descending before reaching it */
2184
2185
  void      *ctx; /* opaque pointer for the callback (e.g. a FILE *, a sprint ctx, or the caller's ctx) */
2186
2187
  fr_trie_key_walk_t  callback; /* internal callback, run by trie_key_walk() on each node before descending */
2188
  fr_trie_walk_t    user_callback; /* caller's function, invoked by _trie_user_cb(); NULL for internal walks */
2189
};
2190
2191
/** Walk below a user node
 *
 *  A user node consumes no key bits, so the child is walked at the
 *  same depth.  A user node without a child ends the walk here.
 */
static int trie_user_walk(fr_trie_t *trie, fr_trie_callback_t *cb, int depth, bool more)
{
  fr_trie_t *child = ((fr_trie_user_t *) trie)->trie;

  return child ? trie_key_walk(child, cb, depth, more) : 0;
}
2199
2200
/** Walk every populated edge of an N-way node
 *
 *  For each non-NULL child, the edge index is written into the key
 *  buffer at the current depth before the child is walked.  The child
 *  is told that "more" entries follow when the caller said so, or when
 *  this node still has unvisited edges.
 *
 * @return 0 on success, -1 as soon as a child walk fails.
 */
static int trie_node_walk(fr_trie_t *trie, fr_trie_callback_t *cb, int depth, bool more)
{
  fr_trie_node_t *node = (fr_trie_node_t *) trie;
  int slot;
  int seen = 0;

  for (slot = 0; slot < (1 << node->bits); slot++) {
    fr_trie_t *child = node->trie[slot];
    bool siblings_left;

    if (child == NULL) continue;

    write_chunk(cb->start, depth, node->bits, (uint16_t) slot);
    seen++;

    siblings_left = more || (seen < node->used);
    if (trie_key_walk(child, cb, depth + node->bits, siblings_left) < 0) return -1;
  }

  return 0;
}
2220
2221
#ifdef WITH_PATH_COMPRESSION
2222
/** Walk through a path node
 *
 *  The path's chunk is written into the key buffer at the current
 *  depth, and the walk continues with the single child, "bits" deeper.
 */
static int trie_path_walk(fr_trie_t *trie, fr_trie_callback_t *cb, int depth, bool more)
{
  fr_trie_path_t *path = (fr_trie_path_t *) trie;
  int child_depth = depth + path->bits;

  write_chunk(cb->start, depth, path->bits, path->chunk);

  fr_cond_assert(path->trie != NULL);

  return trie_key_walk(path->trie, cb, child_depth, more);
}
2231
#endif
2232
2233
#ifdef WITH_NODE_COMPRESSION
2234
/** Walk every edge of a compressed node
 *
 *  Edges are visited in array order.  Each edge's index is written into
 *  the key buffer at the current depth before its child is walked.
 *
 * @return 0 on success, -1 as soon as a child walk fails.
 */
static int trie_comp_walk(fr_trie_t *trie, fr_trie_callback_t *cb, int depth, bool more)
{
  fr_trie_comp_t *comp = (fr_trie_comp_t *) trie;
  int edge;

  for (edge = 0; edge < comp->used; edge++) {
    bool siblings_left;

    write_chunk(cb->start, depth, comp->bits, comp->index[edge]);

    fr_cond_assert(comp->trie[edge] != NULL);

    /*
     *  Edges after this one count as "more", as does
     *  anything the caller still has to visit.
     */
    siblings_left = more || ((edge + 1) < comp->used);
    if (trie_key_walk(comp->trie[edge], cb, depth + comp->bits, siblings_left) < 0) return -1;
  }

  return 0;
}
2254
#endif
2255
2256
static fr_trie_key_walk_t trie_walk_table[FR_TRIE_MAX] = { /* walk handlers, indexed by trie->type (see trie_key_walk) */
2257
  [ FR_TRIE_USER ] = trie_user_walk,
2258
  [ FR_TRIE_NODE ] = trie_node_walk,
2259
#ifdef WITH_PATH_COMPRESSION
2260
  [ FR_TRIE_PATH ] = trie_path_walk,
2261
#endif /* WITH_PATH_COMPRESSION */
2262
#ifdef WITH_NODE_COMPRESSION
2263
  [ FR_TRIE_COMP ] = trie_comp_walk,
2264
#endif /* WITH_NODE_COMPRESSION */
2265
};
2266
2267
/** Visit a node, then walk whatever is underneath it
 *
 * @param trie   node to visit; may be NULL, in which case only the callback runs.
 * @param cb     walk state: key buffer, callback, and callback ctx.
 * @param depth  number of key bits consumed before reaching this node.
 * @param more   whether further entries follow this one in the walk.
 * @return
 *  - 0 on success, or when the key buffer is too small to go deeper.
 *  - -1 if the callback, the node type check, or a child walk fails.
 */
static int trie_key_walk(fr_trie_t *trie, fr_trie_callback_t *cb, int depth, bool more)
{
  int rcode;

  /*
   *  The callback always runs first, even for a NULL node.
   */
  rcode = cb->callback(trie, cb, depth, more);
  if (rcode < 0) return -1;

  if (trie == NULL) {
    fr_cond_assert(depth == 0);
    return 0;
  }

  TRIE_TYPE_CHECK(walk, -1);

  /*
   *  Descending would run past the end of the key buffer, so
   *  stop here.
   */
  if (cb->end <= (cb->start + BYTEOF(depth + trie->bits + 8))) return 0;

  return (*trie_walk_table[trie->type])(trie, cb, depth, more);
}
2288
2289
#ifdef WITH_TRIE_VERIFY
2290
/* VERIFY FUNCTIONS */
2291
2292
typedef int (*trie_verify_t)(fr_trie_t *trie);
2293
2294
/** Verify a user node and everything underneath it
 *
 * @return
 *  - 0 if the node carries user data and its child (if any) verifies.
 *  - -1 otherwise, with an error message set.
 */
static int trie_user_verify(fr_trie_t *trie)
{
  fr_trie_user_t *user = (fr_trie_user_t *) trie;

  if (user->data == NULL) {
    fr_strerror_const("user node has no user data");
    return -1;
  }

  /*
   *  A user node is allowed to have nothing underneath it.
   */
  return user->trie ? trie_verify(user->trie) : 0;
}
2307
2308
/** Verify an N-way node and all of its children
 *
 *  Checks that "bits" and "used" are within range, that every
 *  populated edge verifies, and that "used" equals the number of
 *  populated edges.
 *
 * @return 0 if the node is consistent, -1 (with an error message set) if not.
 */
static int trie_node_verify(fr_trie_t *trie)
{
  fr_trie_node_t *node = (fr_trie_node_t *) trie;
  int slot;
  int populated = 0;

  if ((node->bits > MAX_NODE_BITS) || (node->bits == 0)) {
    fr_strerror_printf("N-way node has invalid bits %d",
           node->bits);
    return -1;
  }

  if ((node->used > (1 << node->bits)) || (node->used == 0)) {
    fr_strerror_printf("N-way node has invalid used %d for bits %d",
           node->used, node->bits);
    return -1;
  }

  for (slot = 0; slot < (1 << node->bits); slot++) {
    fr_trie_t *child = node->trie[slot];

    if (child == NULL) continue;

    if (trie_verify(child) < 0) return -1;

    populated++;
  }

  if (populated == node->used) return 0;

  fr_strerror_printf("N-way node has incorrect used %d when actually used %d",
         node->used, populated);
  return -1;
}
2342
2343
#ifdef WITH_PATH_COMPRESSION
2344
/** Verify a path node and its child
 *
 *  A path node must cover between 1 and 16 bits, and must have a child.
 *
 * @return 0 if the node and its child verify, -1 (with an error message set) if not.
 */
static int trie_path_verify(fr_trie_t *trie)
{
  fr_trie_path_t *path = (fr_trie_path_t *) trie;

  if ((path->bits > 16) || (path->bits == 0)) {
    fr_strerror_printf("path node has invalid bits %d",
           path->bits);
    return -1;
  }

  if (path->trie == NULL) {
    fr_strerror_const("path node has no child trie");
    return -1;
  }

  return trie_verify(path->trie);
}
2361
#endif
2362
2363
#ifdef WITH_NODE_COMPRESSION
2364
/** Verify a compressed node and all of its children
 *
 *  Checks that "bits" is within range, that every used edge has a
 *  child, that edge indexes are strictly increasing, and that each
 *  child verifies.
 *
 * @return 0 if the node is consistent, -1 (with an error message set) if not.
 */
static int trie_comp_verify(fr_trie_t *trie)
{
  fr_trie_comp_t *comp = (fr_trie_comp_t *) trie;
  int edge;
  int counted = 0;

  if ((comp->bits > MAX_COMP_BITS) || (comp->bits == 0)) {
    fr_strerror_printf("comp node has invalid bits %d",
           comp->bits);
    return -1;
  }

  for (edge = 0; edge < comp->used; edge++) {
    if (comp->trie[edge] == NULL) {
      fr_strerror_printf("comp node has no child trie at %d", edge);
      return -1;
    }

    /*
     *  Every edge except the last is compared with its successor.
     */
    if (((edge + 1) < comp->used) &&
        (comp->index[edge] >= comp->index[edge + 1])) {
      fr_strerror_printf("comp node has inverted edges at %d (%04x >= %04x)",
             edge, comp->index[edge], comp->index[edge + 1]);
      return -1;
    }

    if (trie_verify(comp->trie[edge]) < 0) return -1;

    counted++;
  }

  if (counted == comp->used) return 0;

  fr_strerror_printf("Compressed node has incorrect used %d when actually used %d",
         comp->used, counted);
  return -1;
}
2402
#endif
2403
2404
static trie_verify_t trie_verify_table[FR_TRIE_MAX] = { /* verify handlers, indexed by trie->type (see trie_verify) */
2405
  [ FR_TRIE_USER ] = trie_user_verify,
2406
  [ FR_TRIE_NODE ] = trie_node_verify,
2407
#ifdef WITH_PATH_COMPRESSION
2408
  [ FR_TRIE_PATH ] = trie_path_verify,
2409
#endif /* WITH_PATH_COMPRESSION */
2410
#ifdef WITH_NODE_COMPRESSION
2411
  [ FR_TRIE_COMP ] = trie_comp_verify,
2412
#endif /* WITH_NODE_COMPRESSION */
2413
};
2414
2415
2416
/**  Verify the trie nodes
2417
 *
2418
 */
2419
static int trie_verify(fr_trie_t *trie)
2420
{
2421
  if (!trie) return 0;
2422
2423
  TRIE_TYPE_CHECK(verify, -1);
2424
2425
  return trie_verify_table[trie->type](trie);
2426
}
2427
#endif  /* WITH_TRIE_VERIFY */
2428
2429
/* MISCELLANEOUS FUNCTIONS */
2430
2431
#ifdef TESTING
2432
/** Dump a trie edge in canonical form.
2433
 *
2434
 */
2435
static void trie_dump_edge(FILE *fp, fr_trie_t *trie)
2436
{
2437
  fr_cond_assert(trie != NULL);
2438
2439
  fprintf(fp, "NODE-%d\n", trie->number);
2440
  return;
2441
}
2442
2443
2444
typedef void (*fr_trie_dump_t)(FILE *fp, fr_trie_t *trie, char const *key, int keylen);
2445
2446
/** Dump a user node: number, type, input key so far, data, and child (if any)
 *
 *  The user data is printed as a C string.
 */
static void trie_user_dump(FILE *fp, fr_trie_t *trie, char const *key, int keylen)
{
  fr_trie_user_t *user = (fr_trie_user_t *) trie;
  int bytes = BYTES(keylen);

  fprintf(fp, "{ NODE-%d\n", user->number);
  fputs("\ttype\tUSER\n", fp);
  fprintf(fp, "\tinput\t{%d}%.*s\n", keylen, bytes, key);

  fprintf(fp, "\tdata\t\"%s\"\n", (char const *) user->data);

  /*
   *  Only print the "next" edge when there is a child.
   */
  if (user->trie) {
    fputs("\tnext\t", fp);
    trie_dump_edge(fp, user->trie);
  }

  fputs("}\n\n", fp);
}
2465
2466
/** Dump an N-way node: number, type, input key so far, bits, used, and each populated edge
 */
static void trie_node_dump(FILE *fp, fr_trie_t *trie, char const *key, int keylen)
{
  fr_trie_node_t *node = (fr_trie_node_t *) trie;
  int bytes = BYTES(keylen);
  int slot;

  fprintf(fp, "{ NODE-%d\n", node->number);
  fputs("\ttype\tNODE\n", fp);
  fprintf(fp, "\tinput\t{%d}%.*s\n", keylen, bytes, key);

  fprintf(fp, "\tbits\t%d\n", node->bits);
  fprintf(fp, "\tused\t%d\n", node->used);

  for (slot = 0; slot < (1 << node->bits); slot++) {
    if (node->trie[slot] == NULL) continue;

    fprintf(fp, "\t%02x\t", (int) slot);
    trie_dump_edge(fp, node->trie[slot]);
  }

  fputs("}\n\n", fp);
}
2487
2488
#ifdef WITH_PATH_COMPRESSION
2489
/** Dump a path node: number, type, input key so far, bits, the two key bytes, and the child
 */
static void trie_path_dump(FILE *fp, fr_trie_t *trie, char const *key, int keylen)
{
  fr_trie_path_t *path = (fr_trie_path_t *) trie;
  int bytes = BYTES(keylen);

  fprintf(fp, "{ NODE-%d\n", path->number);
  fputs("\ttype\tPATH\n", fp);
  fprintf(fp, "\tinput\t{%d}%.*s\n", keylen, bytes, key);

  fprintf(fp, "\tbits\t%d\n", (int) path->bits);

  fputs("\tpath\t", fp);
  fprintf(fp, "%02x %02x", path->key[0], path->key[1]);
  fputs("\n", fp);

  fputs("\tnext\t", fp);
  trie_dump_edge(fp, path->trie);

  fputs("}\n\n", fp);
}
2510
#endif
2511
2512
#ifdef WITH_NODE_COMPRESSION
2513
/** Dump a compressed node: number, type, input key so far, bits, used, and each edge
 */
static void trie_comp_dump(FILE *fp, fr_trie_t *trie, char const *key, int keylen)
{
  fr_trie_comp_t *comp = (fr_trie_comp_t *) trie;
  int bytes = BYTES(keylen);
  int edge;

  fprintf(fp, "{ NODE-%d\n", comp->number);
  fputs("\ttype\tCOMP\n", fp);
  fprintf(fp, "\tinput\t{%d}%.*s\n", keylen, bytes, key);

  fprintf(fp, "\tbits\t%d\n", comp->bits);
  fprintf(fp, "\tused\t%d\n", comp->used);

  for (edge = 0; edge < comp->used; edge++) {
    fprintf(fp, "\t%d = %02x\t", edge, comp->index[edge]);
    trie_dump_edge(fp, comp->trie[edge]);
  }

  fputs("}\n\n", fp);
}
2532
2533
#endif
2534
2535
static fr_trie_dump_t trie_dump_table[FR_TRIE_MAX] = { /* dump handlers, indexed by trie->type (see _trie_dump_cb) */
2536
  [ FR_TRIE_USER ] = trie_user_dump,
2537
  [ FR_TRIE_NODE ] = trie_node_dump,
2538
#ifdef WITH_PATH_COMPRESSION
2539
  [ FR_TRIE_PATH ] = trie_path_dump,
2540
#endif /* WITH_PATH_COMPRESSION */
2541
#ifdef WITH_NODE_COMPRESSION
2542
  [ FR_TRIE_COMP ] = trie_comp_dump,
2543
#endif /* WITH_NODE_COMPRESSION */
2544
};
2545
2546
2547
/**  Dump the trie nodes
2548
 *
2549
 */
2550
static int _trie_dump_cb(fr_trie_t *trie, fr_trie_callback_t *cb, int keylen, UNUSED bool more)
2551
{
2552
  FILE *fp = cb->ctx;
2553
2554
  if (!trie) return 0;
2555
2556
  TRIE_TYPE_CHECK(dump, -1);
2557
2558
  trie_dump_table[trie->type](fp, trie, (char const *) cb->start, keylen);
2559
  return 0;
2560
}
2561
2562
/**  Print the strings accepted by a trie to a file
2563
 *
2564
 */
2565
static int _trie_print_cb(fr_trie_t *trie, fr_trie_callback_t *cb, int keylen, UNUSED bool more)
2566
{
2567
  int bytes;
2568
  FILE *fp = cb->ctx;
2569
  fr_trie_user_t *user;
2570
2571
  if (!trie || (trie->type != FR_TRIE_USER)) {
2572
    return 0;
2573
  }
2574
2575
  bytes = BYTES(keylen);
2576
  user = (fr_trie_user_t *) trie;
2577
2578
  if ((keylen & 0x07) != 0) {
2579
    fprintf(fp, "{%d}%.*s\t%s\n", keylen, bytes, cb->start, (char const *) user->data);
2580
  } else {
2581
    fprintf(fp, "%.*s\t%s\n", bytes, cb->start, (char const *) user->data);
2582
  }
2583
2584
  return 0;
2585
}
2586
#endif  /* TESTING */
2587
2588
2589
/**  Implement the user-visible side of the walk callback.
2590
 *
2591
 */
2592
static int _trie_user_cb(fr_trie_t *trie, fr_trie_callback_t *cb, int keylen, UNUSED bool more)
2593
0
{
2594
0
  fr_trie_user_t *user;
2595
2596
0
  if (!trie || (trie->type != FR_TRIE_USER)) return 0;
2597
2598
0
  user = (fr_trie_user_t *) trie;
2599
2600
  /*
2601
   *  Call the user function with the key, key length, and data.
2602
   */
2603
0
  if (cb->user_callback(cb->start, keylen, UNCONST(void *, user->data), cb->ctx) < 0) {
2604
0
    return -1;
2605
0
  }
2606
2607
0
  return 0;
2608
0
}
2609
2610
/** Call a function for every user entry in the trie
 *
 * @param ft        the trie to walk.
 * @param ctx       opaque pointer handed to the callback.
 * @param callback  called with the key, key length, user data and ctx.
 * @return
 *  - 0 on success.
 *  - -1 if the walk failed, e.g. because the callback returned < 0.
 */
int fr_trie_walk(fr_trie_t *ft, void *ctx, fr_trie_walk_t callback)
{
  uint8_t buffer[MAX_KEY_BYTES + 1];
  fr_trie_callback_t my_cb;

  /*
   *  Keys are assembled in a zeroed local buffer.
   */
  memset(buffer, 0, sizeof(buffer));

  my_cb.start = buffer;
  my_cb.end = buffer + sizeof(buffer);
  my_cb.ctx = ctx;
  my_cb.callback = _trie_user_cb;
  my_cb.user_callback = callback;

  /*
   *  The internal walk function does the real work.
   */
  return trie_key_walk(ft->trie, &my_cb, 0, false);
}
2628
2629
2630
/**********************************************************************/
2631
2632
/*
2633
 *  Object API.
2634
 */
2635
2636
/** Find an element in the trie, returning the data.
2637
 *
2638
 * @param[in] ft to search in.
2639
 * @param[in] data to find.
2640
 * @return
2641
 *  - User data matching the data passed in.
2642
 *  - NULL if nothing matched passed data.
2643
 */
2644
CC_NO_UBSAN(function) /* UBSAN: false positive - htrie call with first argument of void * trips --fsanitize=function */
2645
void *fr_trie_find(fr_trie_t *ft, void const *data)
2646
0
{
2647
0
  fr_trie_user_t *user = (fr_trie_user_t *) ft;
2648
0
  fr_trie_ctx_t *uctx = talloc_get_type_abort(user->data, fr_trie_ctx_t);
2649
0
  uint8_t *key;
2650
0
  size_t keylen;
2651
2652
0
  key = &uctx->buffer[0];
2653
0
  keylen = sizeof(uctx->buffer) * 8;
2654
2655
0
  if (!uctx->get_key) return NULL;
2656
2657
0
  if (uctx->get_key(&key, &keylen, data) < 0) return NULL;
2658
2659
0
  return fr_trie_lookup_by_key(ft, key, keylen);
2660
0
}
2661
2662
/** Match an element exactly in the trie, returning the data.
2663
 *
2664
 * @param[in] ft to search in.
2665
 * @param[in] data to find.
2666
 * @return
2667
 *  - User data matching the data passed in.
2668
 *  - NULL if nothing matched passed data.
2669
 */
2670
CC_NO_UBSAN(function) /* UBSAN: false positive - htrie call with first argument of void * trips --fsanitize=function */
2671
void *fr_trie_match(fr_trie_t *ft, void const *data)
2672
0
{
2673
0
  fr_trie_user_t *user = (fr_trie_user_t *) ft;
2674
0
  fr_trie_ctx_t *uctx = talloc_get_type_abort(user->data, fr_trie_ctx_t);
2675
0
  uint8_t *key;
2676
0
  size_t keylen;
2677
2678
0
  key = &uctx->buffer[0];
2679
0
  keylen = sizeof(uctx->buffer) * 8;
2680
2681
0
  if (!uctx->get_key) return NULL;
2682
2683
0
  if (uctx->get_key(&key, &keylen, data) < 0) return NULL;
2684
2685
0
  return fr_trie_match_by_key(ft, key, keylen);
2686
0
}
2687
2688
/** Insert data into a trie
2689
 *
2690
 * @param[in] ft  to insert data into.
2691
 * @param[in] data  to insert.
2692
 * @return
2693
 *  - true if data was inserted.
2694
 *  - false if data already existed and was not inserted.
2695
 */
2696
CC_NO_UBSAN(function) /* UBSAN: false positive - htrie call with first argument of void * trips --fsanitize=function */
2697
bool fr_trie_insert(fr_trie_t *ft, void const *data)
2698
0
{
2699
0
  fr_trie_user_t *user = (fr_trie_user_t *) ft;
2700
0
  fr_trie_ctx_t *uctx = talloc_get_type_abort(user->data, fr_trie_ctx_t);
2701
0
  uint8_t *key;
2702
0
  size_t keylen;
2703
2704
0
  key = &uctx->buffer[0];
2705
0
  keylen = sizeof(uctx->buffer) * 8;
2706
2707
0
  if (!uctx->get_key) return false;
2708
2709
0
  if (uctx->get_key(&key, &keylen, data) < 0) return false;
2710
2711
0
  if (fr_trie_insert_by_key(ft, key, keylen, data) < 0) return false;
2712
2713
0
  return true;
2714
0
}
2715
2716
/** Replace old data with new data, OR insert if there is no old
2717
 *
2718
 * @param[out] old  data that was replaced.  If this argument
2719
 *      is not NULL, then the old data will not
2720
 *      be freed, even if a free function is
2721
 *      configured.
2722
 * @param[in] ft  to insert data into.
2723
 * @param[in] data  to replace.
2724
 * @return
2725
 *  - 1 if data was replaced.
2726
 *      - 0 if data was inserted.
2727
 *      - -1 if we failed to replace data
2728
 */
2729
CC_NO_UBSAN(function) /* UBSAN: false positive - htrie call with first argument of void * trips --fsanitize=function */
2730
int fr_trie_replace(void **old, fr_trie_t *ft, void const *data)
2731
0
{
2732
0
  fr_trie_user_t  *user = (fr_trie_user_t *) ft;
2733
0
  fr_trie_ctx_t *uctx = talloc_get_type_abort(user->data, fr_trie_ctx_t);
2734
0
  uint8_t   *key;
2735
0
  size_t    keylen;
2736
0
  void    *found;
2737
2738
0
  key = &uctx->buffer[0];
2739
0
  keylen = sizeof(uctx->buffer) * 8;
2740
2741
0
  if (!uctx->get_key) return -1;
2742
2743
0
  if (uctx->get_key(&key, &keylen, data) < 0) return -1;
2744
2745
0
  found = trie_key_match(ft, key, 0, keylen, true); /* do exact match */
2746
0
  if (found) {
2747
0
    if (old) *old = found;
2748
0
    if (fr_trie_remove_by_key(ft, key, keylen) != found) return -1;
2749
0
  } else {
2750
0
    if (old) *old = NULL;
2751
0
  }
2752
2753
  /*
2754
   *  Insert the new key.
2755
   */
2756
0
  if (fr_trie_insert_by_key(ft, key, keylen, data) < 0) return -1;
2757
2758
0
  return found ? 1 : 0;
2759
0
}
2760
2761
/** Remove an entry, without freeing the data
2762
 *
2763
 * @param[in] ft  to remove data from.
2764
 * @param[in] data  to remove.
2765
 * @return
2766
 *      - The user data we removed.
2767
 *  - NULL if we couldn't find any matching data.
2768
 */
2769
CC_NO_UBSAN(function) /* UBSAN: false positive - htrie call with first argument of void * trips --fsanitize=function */
2770
void *fr_trie_remove(fr_trie_t *ft, void const *data)
2771
0
{
2772
0
  fr_trie_user_t *user = (fr_trie_user_t *) ft;
2773
0
  fr_trie_ctx_t *uctx = talloc_get_type_abort(user->data, fr_trie_ctx_t);
2774
0
  uint8_t *key;
2775
0
  size_t keylen;
2776
2777
0
  key = &uctx->buffer[0];
2778
0
  keylen = sizeof(uctx->buffer) * 8;
2779
2780
0
  if (!uctx->get_key) return NULL;
2781
2782
0
  if (uctx->get_key(&key, &keylen, data) < 0) return NULL;
2783
2784
0
  return fr_trie_remove_by_key(ft, key, keylen);
2785
0
}
2786
2787
/** Remove node and free data (if a free function was specified)
2788
 *
2789
 * @param[in] ft  to remove data from.
2790
 * @param[in] data  to remove/free.
2791
 * @return
2792
 *  - true if we removed data.
2793
 *      - false if we couldn't find any matching data.
2794
 */
2795
CC_NO_UBSAN(function) /* UBSAN: false positive - htrie call with first argument of void * trips --fsanitize=function */
2796
bool fr_trie_delete(fr_trie_t *ft, void const *data)
2797
0
{
2798
0
  fr_trie_user_t *user = (fr_trie_user_t *) ft;
2799
0
  fr_trie_ctx_t *uctx = talloc_get_type_abort(user->data, fr_trie_ctx_t);
2800
0
  uint8_t *key;
2801
0
  size_t keylen;
2802
0
  void *found;
2803
2804
0
  key = &uctx->buffer[0];
2805
0
  keylen = sizeof(uctx->buffer) * 8;
2806
2807
0
  if (!uctx->get_key) return false;
2808
2809
0
  if (uctx->get_key(&key, &keylen, data) < 0) return false;
2810
2811
0
  found = fr_trie_remove_by_key(ft, key, keylen);
2812
0
  if (!found) return false;
2813
2814
0
  if (!uctx->free_data) return true;
2815
2816
0
  uctx->free_data(found);
2817
0
  return true;
2818
0
}
2819
2820
/** Return how many nodes there are in a trie (stub: this implementation always returns 0)
2821
 *
2822
 * @param[in] ft  to return node count for (currently unused).
2823
 */
2824
CC_NO_UBSAN(function) /* UBSAN: false positive - htrie call with first argument of void * trips --fsanitize=function */
2825
unsigned int fr_trie_num_elements(UNUSED fr_trie_t *ft)
2826
0
{
2827
0
  return 0;
2828
0
}
2829
2830
#ifdef TESTING
2831
static bool print_lineno = false;
2832
2833
typedef struct {
2834
  char  *start; /* beginning of the output buffer */
2835
  char  *buffer; /* current write position; _trie_sprint_cb() advances it after each print */
2836
  size_t  buflen; /* space left at the write position; _trie_sprint_cb() shrinks it after each print */
2837
} trie_sprint_ctx_t;
2838
2839
2840
/**  Print the strings accepted by a trie to one line
2841
 */
2842
static int _trie_sprint_cb(fr_trie_t *trie, fr_trie_callback_t *cb, int keylen, bool more)
2843
{
2844
  int bytes, len;
2845
  trie_sprint_ctx_t *ctx;
2846
  fr_trie_user_t *user;
2847
2848
  ctx = cb->ctx;
2849
2850
  if (!trie) {
2851
    len = snprintf(ctx->buffer, ctx->buflen, "{}");
2852
    goto done;
2853
  }
2854
2855
  if (trie->type != FR_TRIE_USER) return 0;
2856
2857
  bytes = BYTES(keylen);
2858
  user = (fr_trie_user_t *) trie;
2859
2860
  if (!user->trie && !more) {
2861
    len = snprintf(ctx->buffer, ctx->buflen, "%.*s=%s",
2862
        bytes, cb->start, (char const *) user->data);
2863
  } else {
2864
    len = snprintf(ctx->buffer, ctx->buflen, "%.*s=%s,",
2865
             bytes, cb->start, (char const *) user->data);
2866
  }
2867
2868
done:
2869
  ctx->buffer += len;
2870
  ctx->buflen -= len;
2871
2872
  return 0;
2873
}
2874
2875
2876
/** Print a (sub)trie on one line of debug output
 *
 *  The bytes of "key" which lead up to start_bit seed the key buffer,
 *  so that entries are printed with their full key.  The walk result
 *  is ignored.
 */
static void trie_sprint(fr_trie_t *trie, uint8_t const *key, int start_bit, UNUSED int lineno)
{
  uint8_t buffer[MAX_KEY_BYTES + 1];
  char out[8192];
  trie_sprint_ctx_t my_sprint = {
    .start = out,
    .buffer = out,
    .buflen = sizeof(out),
  };
  fr_trie_callback_t my_cb = {
    .start = buffer,
    .end = buffer + sizeof(buffer),
    .ctx = &my_sprint,
    .callback = _trie_sprint_cb,
    .user_callback = NULL,
  };

  /*
   *  Both the key buffer and the output start out zeroed.
   */
  memset(buffer, 0, sizeof(buffer));
  memset(out, 0, sizeof(out));

  if (key != NULL) memcpy(buffer, key, BYTES(start_bit) + 1);

  /*
   *  The internal walk function does the real work.
   */
  (void) trie_key_walk(trie, &my_cb, start_bit, false);

  MPRINT3("%.*s%s at %d\n", start_bit, spaces, out, lineno);
}
2915
2916
2917
/**  Parse a string into bits + key
 *
 *  The format is one of:
 *
 *  - string such as "abcdef"
 *  - string prefixed with a bit length, {4}a
 *
 * @param[in] arg      the argument to parse.
 * @param[out] key     set to the start of the key within arg.
 * @param[out] length  set to the key length in bits.
 * @return
 *  - 0 on success.
 *  - -1 if the closing '}' or the key is missing, or the bit length
 *    is negative or longer than the key which follows it.
 */
static int arg2key(char *arg, char **key, int *length)
{
  char *p;
  int bits, size;

  if (*arg != '{') {
    *key = arg;
    *length = BITSOF(strlen(arg));
    return 0;
  }

  p = strchr(arg, '}');
  if (!p) {
    MPRINT("Failed to find end '}' for {bits}\n");
    return -1;
  }

  bits = BITSOF(strlen(p + 1));
  if (!bits) {
    MPRINT("No key found in in '%s'\n", arg);
    return -1;
  }

  size = atoi(arg + 1); /* ignore end character... */
  if (size < 0) {
    MPRINT("Length '%d' cannot be negative\n", size);
    return -1;
  }

  /*
   *  Using a length which is longer than the key would
   *  make the trie code read past the end of the string.
   */
  if (size > bits) {
    MPRINT("Length '%d' is longer than bits in key %s\n",
      size, p + 1);
    return -1;
  }

  *key = p + 1;
  *length = size;

  return 0;
}
2958
2959
/**  Our TALLOC_CTX for the data we put into the trie.
2960
 *
2961
 *  Most people don't need to do this, they can just insert their own
2962
 *  data.
2963
 */
2964
static void *data_ctx = NULL;
2965
2966
/**  Insert a key + data into a trie.
2967
 *
2968
 */
2969
static int command_insert(fr_trie_t *ft, UNUSED int argc, char **argv, UNUSED char *out, UNUSED size_t outlen)
2970
{
2971
  int bits;
2972
  void *data;
2973
  char *key;
2974
2975
  if (arg2key(argv[0], &key, &bits) < 0) {
2976
    return -1;
2977
  }
2978
2979
  /*
2980
   *  This has to stick around in between command
2981
   *  invocations.
2982
   */
2983
  data = talloc_strdup(data_ctx, argv[1]);
2984
  if (!data) {
2985
    MPRINT("OOM\n");
2986
    return -1;
2987
  }
2988
2989
  if (fr_trie_insert_by_key(ft, key, bits, data) < 0) {
2990
    MPRINT("Failed inserting key %s=%s - %s\n", key, argv[1], fr_strerror());
2991
    return -1;
2992
  }
2993
2994
  return 0;
2995
}
2996
2997
/**  Verify a trie recursively
2998
 *
2999
 *  For sanity reasons, this command runs but doesn't do anything if
3000
 *  the code is built with no trie verification.
3001
 */
3002
static int command_verify(fr_trie_t *ft, UNUSED int argc, UNUSED char **argv, UNUSED char *out, UNUSED size_t outlen)
3003
{
3004
  fr_cond_assert(ft != NULL);
3005
3006
  /*
3007
   *  The top-level node may have a NULL talloc ctx, which
3008
   *  is OK.  So we skip that.
3009
   */
3010
  if (ft->type != FR_TRIE_USER) {
3011
    fprintf(stderr, "Verify failed: trie is malformed\n");
3012
    return -1;
3013
  }
3014
3015
  if (trie_verify(ft->trie) < 0) {
3016
    fprintf(stderr, "Verify failed: %s\n", fr_strerror());
3017
    return -1;
3018
  }
3019
3020
  return 0;
3021
}
3022
3023
/** Print the keys accepted by a trie
3024
 *
3025
 *  The strings are printed to stdout.
3026
 *
3027
 *  @todo - allow printing to a file.
3028
 */
3029
static int command_keys(fr_trie_t *ft, UNUSED int argc, UNUSED char **argv, char *out, size_t outlen)
3030
{
3031
  fr_trie_callback_t my_cb;
3032
3033
  my_cb.start = (uint8_t *) out;
3034
  my_cb.end = (uint8_t *) (out + outlen);
3035
  my_cb.callback = _trie_print_cb;
3036
  my_cb.user_callback = NULL;
3037
  my_cb.ctx = stdout;
3038
3039
  /*
3040
   *  Call the internal walk function to do the work.
3041
   */
3042
  return trie_key_walk(ft->trie, &my_cb, 0, false);
3043
}
3044
3045
3046
/** Dump the trie in internal format
3047
 *
3048
 *  The information is printed to stdout.
3049
 *
3050
 *  For sanity reasons, this command runs but doesn't do anything if
3051
 *  the code is built with no trie dumping.
3052
 *
3053
 *  @todo - allow printing to a file.
3054
 */
3055
static int command_dump(fr_trie_t *ft, UNUSED int argc, UNUSED char **argv, char *out, size_t outlen)
3056
{
3057
  fr_trie_callback_t my_cb;
3058
3059
  my_cb.start = (uint8_t *) out;
3060
  my_cb.end = (uint8_t *) (out + outlen);
3061
  my_cb.callback = _trie_dump_cb;
3062
  my_cb.user_callback = NULL;
3063
  my_cb.ctx = stdout;
3064
3065
  /*
3066
   *  Call the internal walk function to do the work.
3067
   */
3068
  return trie_key_walk(ft->trie, &my_cb, 0, false);
3069
}
3070
3071
3072
/**  Clear the entire trie without caring what's in it.
3073
 *
3074
 */
3075
static int command_clear(fr_trie_t *ft, UNUSED int argc, UNUSED char **argv, UNUSED char *out, UNUSED size_t outlen)
3076
{
3077
  if (!ft->trie) return 0;
3078
3079
  trie_free(ft->trie);
3080
  ft->trie = NULL;
3081
3082
  /*
3083
   *  Clean up our internal data ctx, too.
3084
   */
3085
  talloc_free(data_ctx);
3086
  data_ctx = talloc_init_const("data_ctx");
3087
3088
  return 0;
3089
}
3090
3091
3092
/**  Turn on line number debugging.
3093
 *
3094
 *  @todo - add general "debug" functionality.
3095
 */
3096
static int command_lineno(UNUSED fr_trie_t *ft, UNUSED int argc, char **argv, UNUSED char *out, UNUSED size_t outlen)
3097
{
3098
  if (strcmp(argv[0], "true") == 0) {
3099
    print_lineno = true;
3100
  } else {
3101
    print_lineno = false;
3102
  }
3103
3104
  return 0;
3105
}
3106
3107
3108
/**  Match an exact key + length
3109
 *
3110
 *  Normally, the "lookup" returns the longest prefix match, so that
3111
 *  *long* key lookups can return *short* matches.
3112
 *
3113
 *  In some cases, we want to know if an exact key is in the trie.
3114
 *  For those cases, we use this function.
3115
 */
3116
static int command_match(fr_trie_t *ft, UNUSED int argc, char **argv, char *out, size_t outlen)
3117
{
3118
  int bits;
3119
  void *answer;
3120
  char *key;
3121
3122
  if (arg2key(argv[0], &key, &bits) < 0) {
3123
    return -1;
3124
  }
3125
3126
  answer = trie_key_match(ft->trie, (uint8_t *) key, 0, bits, true);
3127
  if (!answer) {
3128
    strlcpy(out, "{}", outlen);
3129
    return 0;
3130
  }
3131
3132
  strlcpy(out, answer, outlen);
3133
3134
  return 0;
3135
}
3136
3137
3138
/**  Look up a key and return user ctx data.
3139
 *
3140
 *  This is done by longest prefix match, not exact match.
3141
 */
3142
static int command_lookup(fr_trie_t *ft, UNUSED int argc, char **argv, char *out, size_t outlen)
3143
{
3144
  int bits;
3145
  void *answer;
3146
  char *key;
3147
3148
  if (arg2key(argv[0], &key, &bits) < 0) {
3149
    return -1;
3150
  }
3151
3152
  answer = fr_trie_lookup_by_key(ft, key, bits);
3153
  if (!answer) {
3154
    strlcpy(out, "{}", outlen);
3155
    return 0;
3156
  }
3157
3158
  strlcpy(out, answer, outlen);
3159
3160
  return 0;
3161
}
3162
3163
3164
/**  Remove a key from the trie.
3165
 *
3166
 *  The key has to match exactly.
3167
 */
3168
static int command_remove(fr_trie_t *ft, UNUSED int argc, char **argv, char *out, size_t outlen)
3169
{
3170
  int bits;
3171
  void *answer;
3172
  char *key;
3173
3174
  if (arg2key(argv[0], &key, &bits) < 0) {
3175
    return -1;
3176
  }
3177
3178
  answer = fr_trie_remove_by_key(ft, key, bits);
3179
  if (!answer) {
3180
    MPRINT("Could not remove key %s\n", key);
3181
    return -1;
3182
  }
3183
3184
  strlcpy(out, answer, outlen);
3185
3186
  talloc_free(answer);
3187
3188
  /*
3189
   *  We now try to find an exact match.  i.e. we don't want
3190
   *  to find a shorter prefix.
3191
   */
3192
  answer = trie_key_match(ft->trie, (uint8_t *) key, 0, bits, true);
3193
  if (answer) {
3194
    MPRINT("Still in trie after 'remove' for key %s, found data %s\n", key, (char const *) answer);
3195
    return -1;
3196
  }
3197
3198
  return 0;
3199
}
3200
3201
3202
/**  Remove a key from the trie.
3203
 *
3204
 *  Try to remove a key, but don't error if we can't.
3205
 */
3206
static int command_try_to_remove(fr_trie_t *ft, UNUSED int argc, char **argv, char *out, size_t outlen)
3207
{
3208
  int bits;
3209
  void *answer;
3210
  char *key;
3211
3212
  if (arg2key(argv[0], &key, &bits) < 0) {
3213
    return -1;
3214
  }
3215
3216
  answer = fr_trie_remove_by_key(ft, key, bits);
3217
  if (!answer) {
3218
    strlcpy(out, ".", outlen);
3219
    return 0;
3220
  }
3221
3222
  strlcpy(out, answer, outlen);
3223
3224
  talloc_free(answer);
3225
3226
  /*
3227
   *  We now try to find an exact match.  i.e. we don't want
3228
   *  to find a shorter prefix.
3229
   */
3230
  answer = trie_key_match(ft->trie, (uint8_t *) key, 0, bits, true);
3231
  if (answer) {
3232
    MPRINT("Still in trie after 'remove' for key %s, found data %s\n", key, (char const *) answer);
3233
    return -1;
3234
  }
3235
3236
  return 0;
3237
}
3238
3239
3240
/** Print a trie to a string
3241
 *
3242
 *  The trie is printed one one line.  If the trie contains keys which
3243
 *  are not on a byte boundary, well... too bad.  It gets printed
3244
 *  terribly.
3245
 */
3246
static int command_print(fr_trie_t *ft, UNUSED int argc, UNUSED char **argv, char *out, size_t outlen)
3247
{
3248
  fr_trie_callback_t my_cb;
3249
  trie_sprint_ctx_t my_sprint;
3250
  uint8_t buffer[MAX_KEY_BYTES + 1];
3251
3252
  /*
3253
   *  Where the output data goes.
3254
   */
3255
  my_sprint.start = out;
3256
  my_sprint.buffer = out;
3257
  my_sprint.buflen = outlen;
3258
3259
  /*
3260
   *  Where the keys are built.
3261
   */
3262
  my_cb.start = buffer;
3263
  my_cb.end = buffer + sizeof(buffer);
3264
  my_cb.callback = _trie_sprint_cb;
3265
  my_cb.user_callback = NULL;
3266
  my_cb.ctx = &my_sprint;
3267
3268
  memset(buffer, 0, sizeof(buffer));
3269
3270
  /*
3271
   *  Call the internal walk function to do the work.
3272
   */
3273
  return trie_key_walk(ft->trie, &my_cb, 0, false);
3274
}
3275
3276
3277
/**  Do insert / lookup / remove all at once.
3278
 *
3279
 *  Sometimes it's more useful to do insert / lookup / remove for
3280
 *  simple keys.
3281
 */
3282
static int command_path(fr_trie_t *ft, int argc, char **argv, char *out, size_t outlen)
3283
{
3284
  void *data;
3285
  void *answer;
3286
3287
  data = talloc_strdup(ft, argv[1]); /* has to be malloc'd data, sorry */
3288
  if (!data) {
3289
    MPRINT("OOM\n");
3290
    return -1;
3291
  }
3292
3293
  if (fr_trie_insert_by_key(ft, argv[0], BITSOF(strlen(argv[0])), data) < 0) {
3294
    MPRINT("Could not insert key %s=%s - %s\n", argv[0], argv[1], fr_strerror());
3295
    return -1;
3296
  }
3297
3298
  answer = fr_trie_lookup_by_key(ft, argv[0], BITSOF(strlen(argv[0])));
3299
  if (!answer) {
3300
    MPRINT("Could not look up key %s\n", argv[0]);
3301
    return -1;
3302
  }
3303
3304
  if (answer != data) {
3305
    MPRINT("Expected to find %s, got %s\n", argv[1], (char const *) answer);
3306
    return -1;
3307
  }
3308
3309
  /*
3310
   *  Call the command 'print' to print out the key.
3311
   */
3312
  (void) command_print(ft, argc, argv, out, outlen);
3313
3314
  answer = fr_trie_remove_by_key(ft, (uint8_t const *) argv[0], BITSOF(strlen(argv[0])));
3315
  if (!answer) {
3316
    MPRINT("Could not remove key %s\n", argv[0]);
3317
    return -1;
3318
  }
3319
3320
  if (answer != data) {
3321
    MPRINT("Expected to remove %s, got %s\n", argv[1], (char const *) answer);
3322
    return -1;
3323
  }
3324
3325
  talloc_free(answer);
3326
3327
  return 0;
3328
}
3329
3330
3331
/**  Return the longest common prefix of two bit strings.
3332
 *
3333
 */
3334
static int command_lcp(UNUSED fr_trie_t *ft, int argc, char **argv, char *out, size_t outlen)
3335
{
3336
  int lcp;
3337
  int keylen1, keylen2;
3338
  int start_bit;
3339
  uint8_t const *key1, *key2;
3340
3341
  if (argc == 2) {
3342
    key1 = (uint8_t const *) argv[0];
3343
    keylen1 = BITSOF(strlen(argv[0]));
3344
3345
    key2 = (uint8_t const *) argv[1];
3346
    keylen2 = BITSOF(strlen(argv[1]));
3347
    start_bit = 0;
3348
3349
  } else if (argc == 5) {
3350
    key1 = (uint8_t const *) argv[0];
3351
    keylen1 = atoi(argv[1]);
3352
    if ((keylen1 < 0) || (keylen1 > (int) BITSOF(strlen(argv[0])))) {
3353
      MPRINT("length of key1 %s is larger than string length %ld\n",
3354
        argv[1], BITSOF(strlen(argv[0])));
3355
      return -1;
3356
    }
3357
3358
    key2 = (uint8_t const *) argv[2];
3359
    keylen2 = atoi(argv[3]);
3360
    if ((keylen2 < 0) || (keylen2 > (int) BITSOF(strlen(argv[2])))) {
3361
      MPRINT("length of key2 %s is larger than string length %ld\n",
3362
        argv[3], BITSOF(strlen(argv[2])));
3363
      return -1;
3364
    }
3365
3366
    start_bit = atoi(argv[4]);
3367
    if ((start_bit < 0) || (start_bit > 7)) {
3368
      MPRINT("start_bit has invalid value %s\n", argv[4]);
3369
      return -1;
3370
    }
3371
3372
  } else {
3373
    MPRINT("Invalid number of arguments\n");
3374
    return -1;
3375
  }
3376
3377
  lcp = fr_trie_key_lcp(key1, keylen1, key2, keylen2, start_bit);
3378
3379
  snprintf(out, outlen, "%d", lcp);
3380
  return 0;
3381
}
3382
3383
/** Get chunks from raw data
3384
 *
3385
 */
3386
static int command_chunk(UNUSED fr_trie_t *ft, UNUSED int argc, char **argv, char *out, size_t outlen)
3387
{
3388
  int start_bit, num_bits;
3389
  uint16_t chunk;
3390
3391
  start_bit = atoi(argv[1]);
3392
  num_bits = atoi(argv[2]);
3393
3394
  chunk = get_chunk((uint8_t const *) argv[0], start_bit, num_bits);
3395
3396
  snprintf(out, outlen, "%04x", chunk);
3397
  return 0;
3398
}
3399
3400
/**  A function to parse a trie command line.
3401
 *
3402
 */
3403
typedef int (*fr_trie_function_t)(fr_trie_t *ft, int argc, char **argv, char *out, size_t outlen);
3404
3405
/**  Data structure which holds the trie command name, function, etc.
3406
 *
3407
 */
3408
typedef struct {
3409
  char const    *name;
3410
  fr_trie_function_t  function;
3411
  int     min_argc;
3412
  int     max_argc;
3413
  bool      output;
3414
} fr_trie_command_t;
3415
3416
3417
/**  The trie commands for debugging.
3418
 *
3419
 */
3420
static fr_trie_command_t commands[] = {
3421
  { "lcp",  command_lcp,  2, 5, true },
3422
  { "chunk",  command_chunk,  3, 3, true },
3423
  { "path", command_path, 2, 2, true },
3424
  { "insert", command_insert, 2, 2, false },
3425
  { "match",  command_match,  1, 1, true },
3426
  { "lookup", command_lookup, 1, 1, true },
3427
  { "remove", command_remove, 1, 1, true },
3428
  { "-remove",  command_try_to_remove, 1, 1, true },
3429
  { "print",  command_print,  0, 0, true },
3430
  { "dump", command_dump, 0, 0, false },
3431
  { "keys", command_keys, 0, 0, false },
3432
  { "verify", command_verify, 0, 0, false },
3433
  { "lineno", command_lineno, 1, 1, false },
3434
  { "clear",  command_clear,  0, 0, false },
3435
  { .name = NULL }
3436
};
3437
3438
#define MAX_ARGC (16)
3439
int main(int argc, char **argv)
3440
{
3441
  int lineno = 0;
3442
  int ret = 0;
3443
  fr_trie_t *ft;
3444
  FILE *fp;
3445
  int my_argc;
3446
  char *my_argv[MAX_ARGC];
3447
  char buffer[8192];
3448
  char output[8192];
3449
3450
  if (argc < 2) {
3451
    fprintf(stderr, "Please specify filename\n");
3452
    fr_exit_now(EXIT_SUCCESS);
3453
  }
3454
3455
  fp = fopen(argv[1], "r");
3456
  if (!fp) {
3457
    fprintf(stderr, "Failed opening %s: %s\n", argv[1], fr_syserror(errno));
3458
    fr_exit_now(1);
3459
  }
3460
3461
  /*
3462
   *  Tell us if we leaked memory.
3463
   */
3464
  talloc_enable_null_tracking();
3465
3466
  data_ctx = talloc_init_const("data_ctx");
3467
3468
  ft = fr_trie_alloc(NULL, NULL, NULL);
3469
  if (!ft) {
3470
    fprintf(stderr, "Failed creating trie\n");
3471
    fr_exit_now(1);
3472
  }
3473
3474
  while (fgets(buffer, sizeof(buffer), fp) != NULL) {
3475
    int i, cmd;
3476
    char *p;
3477
3478
    lineno++;
3479
3480
    /*
3481
     *  Remove comments.
3482
     */
3483
    for (p = buffer; *p != '\0'; p++) {
3484
      if (*p == '#') {
3485
        *p = '\0';
3486
        break;
3487
      }
3488
    }
3489
3490
    /*
3491
     *  Skip leading whitespace.
3492
     */
3493
    p = buffer;
3494
    fr_skip_whitespace(p);
3495
3496
    /*
3497
     *  Skip (now) blank lines.
3498
     */
3499
    if (!*p) continue;
3500
3501
    my_argc = fr_dict_str_to_argv(p, my_argv, MAX_ARGC);
3502
3503
    cmd = -1;
3504
    for (i = 0; commands[i].name != NULL; i++) {
3505
      if (strcmp(my_argv[0], commands[i].name) != 0) continue;
3506
3507
      cmd = i;
3508
      break;
3509
    }
3510
3511
    if (cmd < 0) {
3512
      fprintf(stderr, "Unknown command '%s' at line %d\n",
3513
        my_argv[0], lineno);
3514
      ret = 1;
3515
      break;
3516
    }
3517
3518
    /*
3519
     *  argv[0] is the command.
3520
     *  argv[argc-1] is the output.
3521
     */
3522
    if (((commands[cmd].min_argc + 1 + commands[cmd].output) > my_argc) ||
3523
        ((commands[cmd].max_argc + 1 + commands[cmd].output) < my_argc)) {
3524
      fprintf(stderr, "Invalid number of arguments to %s at line %d.  Expected %d, got %d\n",
3525
        my_argv[0], lineno, commands[cmd].min_argc + 1, my_argc - 1);
3526
      fr_exit_now(1);
3527
    }
3528
3529
    if (print_lineno) {
3530
      printf("%d ", lineno);
3531
      fflush(stdout);
3532
    }
3533
3534
    MPRINT3("[%d] %s\n", lineno, my_argv[0]);
3535
    if (commands[cmd].function(ft, my_argc - 1 - commands[cmd].output, &my_argv[1], output, sizeof(output)) < 0) {
3536
      fprintf(stderr, "Failed running %s at line %d\n",
3537
        my_argv[0], lineno);
3538
      fr_exit_now(1);
3539
    }
3540
3541
    if (!commands[cmd].output) continue;
3542
3543
    if (strcmp(output, my_argv[my_argc - 1]) != 0) {
3544
      fprintf(stderr, "Failed running %s at line %d: Expected '%s' got '%s'\n",
3545
        my_argv[0], lineno, my_argv[my_argc - 1], output);
3546
      fr_exit_now(1);
3547
    }
3548
  }
3549
3550
  fclose(fp);
3551
3552
  trie_free(ft);
3553
  talloc_free(data_ctx);
3554
3555
  talloc_report_full(NULL, stdout); /* Print details of any leaked memory */
3556
  talloc_disable_null_tracking();   /* Cleanup talloc null tracking context */
3557
3558
  return ret;
3559
}
3560
#endif