1
#pragma once
2

            
3
#include <functional>
4

            
5
#include "envoy/event/dispatcher.h"
6
#include "envoy/thread_local/thread_local.h"
7

            
8
#include "source/common/config/datasource.h"
9

            
10
#include "zstd.h"
11

            
12
namespace Envoy {
13
namespace Extensions {
14
namespace Compression {
15
namespace Zstd {
16
namespace Common {
17

            
18
// Dictionary manager for `Zstd` compression.
19
template <class T, size_t (*deleter)(T*), unsigned (*getDictId)(const T*)> class DictionaryManager {
20
public:
21
  using DictionaryBuilder = std::function<T*(const void*, size_t)>;
22

            
23
  DictionaryManager(
24
      const Protobuf::RepeatedPtrField<envoy::config::core::v3::DataSource> dictionaries,
25
      Event::Dispatcher& dispatcher, Api::Api& api, ThreadLocal::SlotAllocator& tls,
26
      bool replace_mode, DictionaryBuilder builder)
27
338
      : api_(api), tls_slot_(ThreadLocal::TypedSlot<DictionaryThreadLocalMap>::makeUnique(tls)),
28
338
        replace_mode_(replace_mode), builder_(builder) {
29
338
    bool is_watch_added = false;
30
338
    watcher_ = dispatcher.createFilesystemWatcher();
31

            
32
338
    auto dictionary_map = std::make_shared<DictionaryThreadLocalMap>();
33
338
    dictionary_map->reserve(dictionaries.size());
34

            
35
340
    for (const auto& source : dictionaries) {
36
340
      const auto data =
37
340
          THROW_OR_RETURN_VALUE(Config::DataSource::read(source, false, api), std::string);
38
340
      auto dictionary = DictionarySharedPtr(builder_(data.data(), data.length()));
39
340
      auto id = getDictId(dictionary.get());
40
      // If id == 0, the dictionary is not conform to Zstd specification, or empty.
41
340
      RELEASE_ASSERT(id != 0, "Illegal Zstd dictionary");
42
340
      dictionary_map->emplace(id, std::move(dictionary));
43
340
      if (source.specifier_case() ==
44
340
          envoy::config::core::v3::DataSource::SpecifierCase::kFilename) {
45
338
        is_watch_added = true;
46
338
        const auto& filename = source.filename();
47
338
        THROW_IF_NOT_OK(watcher_->addWatch(
48
338
            filename, Filesystem::Watcher::Events::Modified | Filesystem::Watcher::Events::MovedTo,
49
338
            [this, id, filename](uint32_t) {
50
338
              onDictionaryUpdate(id, filename);
51
338
              return absl::OkStatus();
52
338
            }));
53
338
      }
54
340
    }
55

            
56
338
    tls_slot_->set([dictionary_map](Event::Dispatcher&) {
57
336
      auto map = std::make_shared<DictionaryThreadLocalMap>();
58
336
      map->insert(dictionary_map->begin(), dictionary_map->end());
59
336
      return map;
60
336
    });
61

            
62
338
    if (!is_watch_added) {
63
      watcher_.reset();
64
    }
65
338
  };
66

            
67
348
  T* getDictionary(bool first_only, unsigned id) {
68
348
    auto dictionary_map = tls_slot_->get();
69

            
70
348
    typename absl::flat_hash_map<unsigned, DictionarySharedPtr>::iterator it;
71
348
    if (first_only) {
72
174
      it = dictionary_map->begin();
73
174
    } else {
74
174
      it = dictionary_map->find(id);
75
174
    }
76
348
    if (it != dictionary_map->end()) {
77
345
      return it->second.get();
78
345
    }
79

            
80
3
    return nullptr;
81
348
  };
82

            
83
174
  T* getDictionaryById(unsigned id) { return getDictionary(false, id); };
84

            
85
174
  T* getFirstDictionary() { return getDictionary(true, 0); };
86

            
87
private:
88
  class DictionarySharedPtr : public std::shared_ptr<T> {
89
  public:
90
344
    DictionarySharedPtr(T* object) : std::shared_ptr<T>(object, deleter) {}
91
  };
92
  class DictionaryThreadLocalMap : public absl::flat_hash_map<unsigned, DictionarySharedPtr>,
93
                                   public ThreadLocal::ThreadLocalObject {};
94

            
95
6
  void onDictionaryUpdate(unsigned origin_id, const std::string& filename) {
96
6
    auto file_or_error = api_.fileSystem().fileReadToEnd(filename);
97
6
    THROW_IF_NOT_OK_REF(file_or_error.status());
98
6
    const auto data = file_or_error.value();
99
6
    if (!data.empty()) {
100
6
      auto dictionary = DictionarySharedPtr(builder_(data.data(), data.length()));
101
6
      auto id = getDictId(dictionary.get());
102
      // Keep origin dictionary if the new is illegal
103
6
      if (id != 0) {
104
6
        tls_slot_->runOnAllThreads(
105
6
            [dictionary = std::move(dictionary), id, origin_id,
106
6
             replace_mode = replace_mode_](OptRef<DictionaryThreadLocalMap> dictionary_map) {
107
6
              if (replace_mode) {
108
3
                dictionary_map->erase(origin_id);
109
3
              }
110
6
              dictionary_map->emplace(id, dictionary);
111
6
            });
112
6
      }
113
6
    }
114
6
  }
115

            
116
  Api::Api& api_;
117
  ThreadLocal::TypedSlotPtr<DictionaryThreadLocalMap> tls_slot_;
118
  bool replace_mode_;
119
  DictionaryBuilder builder_;
120
  std::unique_ptr<Filesystem::Watcher> watcher_;
121
};
122

            
123
} // namespace Common
124
} // namespace Zstd
125
} // namespace Compression
126
} // namespace Extensions
127
} // namespace Envoy