1
#pragma once
2

            
3
#include <functional>
4
#include <memory>
5
#include <thread>
6
#include <type_traits>
7

            
8
#include "envoy/event/dispatcher.h"
9
#include "envoy/singleton/instance.h"
10

            
11
#include "source/common/common/assert.h"
12
#include "source/common/common/non_copyable.h"
13
#include "source/common/common/thread_synchronizer.h"
14

            
15
#include "absl/container/flat_hash_set.h"
16

            
17
namespace Envoy {
18
namespace SharedPool {
19

            
20
/**
21
 * Used to share objects that have the same content.
22
 * control the life cycle of shared objects by reference counting
23
 *
24
 * Note:  ObjectSharedPool needs to be created in the main thread,
25
 * all the member methods can only be called in the main thread,
26
 * it does not have the ownership of object stored internally, the internal storage is weak_ptr,
27
 * when the internal storage object destructor executes the custom deleter to remove its own
28
 * weak_ptr from the ObjectSharedPool.
29
 *
30
 * There is also a need to ensure that the thread where ObjectSharedPool's destructor is also in the
31
 * main thread, or that ObjectSharedPool destruct before the program exit
32
 */
33
template <typename T, typename HashFunc = std::hash<T>, typename EqualFunc = std::equal_to<T>,
34
          class = typename std::enable_if<std::is_copy_constructible<T>::value>::type>
35
class ObjectSharedPool
36
    : public Singleton::Instance,
37
      public std::enable_shared_from_this<ObjectSharedPool<T, HashFunc, EqualFunc>>,
38
      NonCopyable {
39
public:
40
  ObjectSharedPool(Event::Dispatcher& dispatcher)
41
33907
      : thread_id_(std::this_thread::get_id()), dispatcher_(dispatcher) {}
42

            
43
36049
  std::shared_ptr<T> getObject(const T& obj) {
44
36049
    ASSERT(std::this_thread::get_id() == thread_id_);
45

            
46
    // Return from the object pool if we find the object there.
47
36049
    if (auto iter = object_pool_.find(&obj); iter != object_pool_.end()) {
48
14224
      if (auto lock_object = iter->lock(); static_cast<bool>(lock_object) == true) {
49
14222
        return lock_object;
50
14222
      } else {
51
        // Remove the weak_ptr since all associated shared_ptrs have been
52
        // destroyed.
53
2
        object_pool_.erase(iter);
54
2
      }
55
14224
    }
56

            
57
    // Create a shared_ptr and add the object to the object_pool.
58
21827
    auto this_shared_ptr = this->shared_from_this();
59
21827
    std::shared_ptr<T> obj_shared(new T(obj), [this_shared_ptr](T* ptr) {
60
21827
      this_shared_ptr->sync().syncPoint(ObjectSharedPool<T>::ObjectDeleterEntry);
61
21827
      this_shared_ptr->deleteObject(ptr);
62
21827
    });
63
21827
    object_pool_.emplace(obj_shared);
64
21827
    return obj_shared;
65
36049
  }
66

            
67
8
  std::size_t poolSize() const {
68
8
    ASSERT(std::this_thread::get_id() == thread_id_);
69
8
    return object_pool_.size();
70
8
  }
71

            
72
  /**
73
   * @return a thread synchronizer object used for reproducing a race-condition in tests.
74
   */
75
21835
  Thread::ThreadSynchronizer& sync() { return sync_; }
76
  static const char DeleteObjectOnMainThread[];
77
  static const char ObjectDeleterEntry[];
78

            
79
  friend class SharedPoolTest;
80

            
81
private:
82
21829
  void deleteObject(T* ptr) {
83
21829
    if (std::this_thread::get_id() == thread_id_) {
84
21793
      deleteObjectOnMainThread(ptr);
85
21793
    } else {
86
      // Most of the time, the object's destructor occurs in the main thread, but with some
87
      // exceptions, it is destructed in the worker thread. In order to keep the object_pool_ thread
88
      // safe, the deleteObject needs to be delivered to the main thread.
89
36
      auto this_shared_ptr = this->shared_from_this();
90
      // Used for testing to simulate some race condition scenarios
91
36
      sync_.syncPoint(DeleteObjectOnMainThread);
92
36
      dispatcher_.post([ptr, this_shared_ptr] { this_shared_ptr->deleteObjectOnMainThread(ptr); });
93
36
    }
94
21829
  }
95

            
96
21829
  void deleteObjectOnMainThread(T* ptr) {
97
21829
    ASSERT(std::this_thread::get_id() == thread_id_);
98
21829
    if (auto iter = object_pool_.find(ptr); iter != object_pool_.end()) {
99
      // It is possible that the entry in object_pool_ corresponds to a
100
      // different weak_ptr, due to a race condition in a shared_ptr being
101
      // destroyed on another thread, and getObject() being called on the main
102
      // thread.
103
21827
      if (iter->use_count() == 0) {
104
21825
        object_pool_.erase(iter);
105
21825
      }
106
21827
    }
107
    // Wait till here to delete the pointer because we don't want the OS to
108
    // reallocate the memory location before this method completes to prevent
109
    // "hash collisions".
110
21829
    delete ptr;
111
21829
  }
112

            
113
  class Element {
114
  public:
115
22163
    Element(const std::shared_ptr<T>& ptr) : ptr_{ptr.get()}, weak_ptr_{ptr} {}
116

            
117
    Element() = delete;
118
    Element(const Element&) = delete;
119

            
120
147
    Element(Element&&) noexcept = default;
121

            
122
14224
    std::shared_ptr<T> lock() const { return weak_ptr_.lock(); }
123
21827
    long use_count() const { return weak_ptr_.use_count(); }
124

            
125
    friend struct Hash;
126
    friend struct Compare;
127

            
128
    struct Hash {
129
      using is_transparent = void; // NOLINT(readability-identifier-naming)
130
644
      constexpr size_t operator()(const T* ptr) const { return HashFunc{}(*ptr); }
131
327
      constexpr size_t operator()(const Element& element) const {
132
327
        return HashFunc{}(*element.ptr_);
133
327
      }
134
    };
135
    struct Compare {
136
      using is_transparent = void; // NOLINT(readability-identifier-naming)
137
141
      bool operator()(const Element& a, const Element& b) const {
138
141
        ASSERT(a.ptr_ != nullptr && b.ptr_ != nullptr);
139
141
        return a.ptr_ == b.ptr_ ||
140
141
               (a.ptr_ != nullptr && b.ptr_ != nullptr && EqualFunc{}(*a.ptr_, *b.ptr_));
141
141
      }
142
36195
      bool operator()(const Element& a, const T* ptr) const {
143
36195
        ASSERT(a.ptr_ != nullptr && ptr != nullptr);
144
36195
        return a.ptr_ == ptr || (a.ptr_ != nullptr && ptr != nullptr && EqualFunc{}(*a.ptr_, *ptr));
145
36195
      }
146
    };
147

            
148
  private:
149
    const T* const ptr_ = nullptr; ///< This is only used to speed up
150
                                   ///< comparisons and should never be
151
                                   ///< made available outside this class.
152
    std::weak_ptr<T> weak_ptr_;
153
  };
154

            
155
  const std::thread::id thread_id_;
156
  absl::flat_hash_set<Element, typename Element::Hash, typename Element::Compare> object_pool_;
157
  // Use a multimap to allow for multiple objects with the same hash key.
158
  // std::unordered_multimap<size_t, std::weak_ptr<T>> object_pool_;
159
  Event::Dispatcher& dispatcher_;
160
  Thread::ThreadSynchronizer sync_;
161
};
162

            
163
template <typename T, typename HashFunc, typename EqualFunc, class V>
164
const char ObjectSharedPool<T, HashFunc, EqualFunc, V>::DeleteObjectOnMainThread[] =
165
    "delete-object-on-main";
166

            
167
template <typename T, typename HashFunc, typename EqualFunc, class V>
168
const char ObjectSharedPool<T, HashFunc, EqualFunc, V>::ObjectDeleterEntry[] = "deleter-entry";
169

            
170
} // namespace SharedPool
171
} // namespace Envoy