Line data Source code
1 : #pragma once 2 : 3 : #include <functional> 4 : 5 : #include "envoy/event/dispatcher.h" 6 : #include "envoy/thread_local/thread_local.h" 7 : 8 : #include "source/common/config/datasource.h" 9 : 10 : #include "zstd.h" 11 : 12 : namespace Envoy { 13 : namespace Extensions { 14 : namespace Compression { 15 : namespace Zstd { 16 : namespace Common { 17 : 18 : // Dictionary manager for `Zstd` compression. 19 : template <class T, size_t (*deleter)(T*), unsigned (*getDictId)(const T*)> class DictionaryManager { 20 : public: 21 : using DictionaryBuilder = std::function<T*(const void*, size_t)>; 22 : 23 : DictionaryManager( 24 : const Protobuf::RepeatedPtrField<envoy::config::core::v3::DataSource> dictionaries, 25 : Event::Dispatcher& dispatcher, Api::Api& api, ThreadLocal::SlotAllocator& tls, 26 : bool replace_mode, DictionaryBuilder builder) 27 : : api_(api), tls_slot_(ThreadLocal::TypedSlot<DictionaryThreadLocalMap>::makeUnique(tls)), 28 0 : replace_mode_(replace_mode), builder_(builder) { 29 0 : bool is_watch_added = false; 30 0 : watcher_ = dispatcher.createFilesystemWatcher(); 31 : 32 0 : auto dictionary_map = std::make_shared<DictionaryThreadLocalMap>(); 33 0 : dictionary_map->reserve(dictionaries.size()); 34 : 35 0 : for (const auto& source : dictionaries) { 36 0 : const auto data = Config::DataSource::read(source, false, api); 37 0 : auto dictionary = DictionarySharedPtr(builder_(data.data(), data.length())); 38 0 : auto id = getDictId(dictionary.get()); 39 : // If id == 0, the dictionary is not conform to Zstd specification, or empty. 40 0 : RELEASE_ASSERT(id != 0, "Illegal Zstd dictionary"); 41 0 : dictionary_map->emplace(id, std::move(dictionary)); 42 0 : if (source.specifier_case() == 43 0 : envoy::config::core::v3::DataSource::SpecifierCase::kFilename) { 44 0 : is_watch_added = true; 45 0 : const auto& filename = source.filename(); 46 0 : watcher_->addWatch( 47 0 : filename, Filesystem::Watcher::Events::Modified | Filesystem::Watcher::Events::MovedTo, 48 0 : [this, id, filename](uint32_t) { onDictionaryUpdate(id, filename); }); 49 0 : } 50 0 : } 51 : 52 0 : tls_slot_->set([dictionary_map](Event::Dispatcher&) { 53 0 : auto map = std::make_shared<DictionaryThreadLocalMap>(); 54 0 : map->insert(dictionary_map->begin(), dictionary_map->end()); 55 0 : return map; 56 0 : }); 57 : 58 0 : if (!is_watch_added) { 59 0 : watcher_.reset(); 60 0 : } 61 0 : }; 62 : 63 0 : T* getDictionary(bool first_only, unsigned id) { 64 0 : auto dictionary_map = tls_slot_->get(); 65 : 66 0 : typename absl::flat_hash_map<unsigned, DictionarySharedPtr>::iterator it; 67 0 : if (first_only) { 68 0 : it = dictionary_map->begin(); 69 0 : } else { 70 0 : it = dictionary_map->find(id); 71 0 : } 72 0 : if (it != dictionary_map->end()) { 73 0 : return it->second.get(); 74 0 : } 75 : 76 0 : return nullptr; 77 0 : }; 78 : 79 0 : T* getDictionaryById(unsigned id) { return getDictionary(false, id); }; 80 : 81 0 : T* getFirstDictionary() { return getDictionary(true, 0); }; 82 : 83 : private: 84 : class DictionarySharedPtr : public std::shared_ptr<T> { 85 : public: 86 0 : DictionarySharedPtr(T* object) : std::shared_ptr<T>(object, deleter) {} 87 : }; 88 : class DictionaryThreadLocalMap : public absl::flat_hash_map<unsigned, DictionarySharedPtr>, 89 : public ThreadLocal::ThreadLocalObject {}; 90 : 91 0 : void onDictionaryUpdate(unsigned origin_id, const std::string& filename) { 92 0 : auto file_or_error = api_.fileSystem().fileReadToEnd(filename); 93 0 : THROW_IF_STATUS_NOT_OK(file_or_error, throw); 94 0 : const auto data = file_or_error.value(); 95 0 : if (!data.empty()) { 96 0 : auto dictionary = DictionarySharedPtr(builder_(data.data(), data.length())); 97 0 : auto id = getDictId(dictionary.get()); 98 : // Keep origin dictionary if the new is illegal 99 0 : if (id != 0) { 100 0 : tls_slot_->runOnAllThreads( 101 0 : [dictionary = std::move(dictionary), id, origin_id, 102 0 : replace_mode = replace_mode_](OptRef<DictionaryThreadLocalMap> dictionary_map) { 103 0 : if (replace_mode) { 104 0 : dictionary_map->erase(origin_id); 105 0 : } 106 0 : dictionary_map->emplace(id, dictionary); 107 0 : }); 108 0 : } 109 0 : } 110 0 : } 111 : 112 : Api::Api& api_; 113 : ThreadLocal::TypedSlotPtr<DictionaryThreadLocalMap> tls_slot_; 114 : bool replace_mode_; 115 : DictionaryBuilder builder_; 116 : std::unique_ptr<Filesystem::Watcher> watcher_; 117 : }; 118 : 119 : } // namespace Common 120 : } // namespace Zstd 121 : } // namespace Compression 122 : } // namespace Extensions 123 : } // namespace Envoy