xgboost
cache.h
Go to the documentation of this file.
1 
4 #ifndef XGBOOST_CACHE_H_
5 #define XGBOOST_CACHE_H_
6 
7 #include <xgboost/logging.h> // for CHECK_EQ, CHECK
8 
9 #include <cstddef> // for size_t
10 #include <memory> // for weak_ptr, shared_ptr, make_shared
11 #include <mutex> // for mutex, lock_guard
12 #include <queue> // for queue
13 #include <thread> // for thread
14 #include <unordered_map> // for unordered_map
15 #include <utility> // for move
16 #include <vector> // for vector
17 
18 namespace xgboost {
19 class DMatrix;
25 template <typename CacheT>
26 class DMatrixCache {
27  public:
28  struct Item {
29  // A weak pointer for checking whether the DMatrix object has expired.
30  std::weak_ptr<DMatrix> ref;
31  // The cached item
32  std::shared_ptr<CacheT> value;
33 
34  CacheT const& Value() const { return *value; }
35  CacheT& Value() { return *value; }
36 
37  Item(std::shared_ptr<DMatrix> m, std::shared_ptr<CacheT> v) : ref{m}, value{std::move(v)} {}
38  };
39 
40  static constexpr std::size_t DefaultSize() { return 32; }
41 
42  private:
43  mutable std::mutex lock_;
44 
45  protected:
46  struct Key {
47  DMatrix const* ptr;
48  std::thread::id const thread_id;
49 
50  bool operator==(Key const& that) const {
51  return ptr == that.ptr && thread_id == that.thread_id;
52  }
53  };
54  struct Hash {
55  std::size_t operator()(Key const& key) const noexcept {
56  std::size_t f = std::hash<DMatrix const*>()(key.ptr);
57  std::size_t s = std::hash<std::thread::id>()(key.thread_id);
58  if (f == s) {
59  return f;
60  }
61  return f ^ s;
62  }
63  };
64 
65  std::unordered_map<Key, Item, Hash> container_;
66  std::queue<Key> queue_;
67  std::size_t max_size_;
68 
69  void CheckConsistent() const { CHECK_EQ(queue_.size(), container_.size()); }
70 
71  void ClearExpired() {
72  // Clear expired entries
73  this->CheckConsistent();
74  std::vector<Key> expired;
75  std::queue<Key> remained;
76 
77  while (!queue_.empty()) {
78  auto p_fmat = queue_.front();
79  auto it = container_.find(p_fmat);
80  CHECK(it != container_.cend());
81  if (it->second.ref.expired()) {
82  expired.push_back(it->first);
83  } else {
84  remained.push(it->first);
85  }
86  queue_.pop();
87  }
88  CHECK(queue_.empty());
89  CHECK_EQ(remained.size() + expired.size(), container_.size());
90 
91  for (auto const& key : expired) {
92  container_.erase(key);
93  }
94  while (!remained.empty()) {
95  auto p_fmat = remained.front();
96  queue_.push(p_fmat);
97  remained.pop();
98  }
99  this->CheckConsistent();
100  }
101 
102  void ClearExcess() {
103  this->CheckConsistent();
104  // clear half of the entries to prevent repeatingly clearing cache.
105  std::size_t half_size = max_size_ / 2;
106  while (queue_.size() >= half_size && !queue_.empty()) {
107  auto p_fmat = queue_.front();
108  queue_.pop();
109  container_.erase(p_fmat);
110  }
111  this->CheckConsistent();
112  }
113 
114  public:
118  explicit DMatrixCache(std::size_t cache_size) : max_size_{cache_size} {}
119 
121  CHECK(lock_.try_lock());
122  lock_.unlock();
123  CHECK(that.lock_.try_lock());
124  that.lock_.unlock();
125  std::swap(this->container_, that.container_);
126  std::swap(this->queue_, that.queue_);
127  std::swap(this->max_size_, that.max_size_);
128  return *this;
129  }
130 
144  template <typename... Args>
145  std::shared_ptr<CacheT> CacheItem(std::shared_ptr<DMatrix> m, Args const&... args) {
146  CHECK(m);
147  std::lock_guard<std::mutex> guard{lock_};
148 
149  this->ClearExpired();
150  if (container_.size() >= max_size_) {
151  this->ClearExcess();
152  }
153  // after clear, cache size < max_size
154  CHECK_LT(container_.size(), max_size_);
155  auto key = Key{m.get(), std::this_thread::get_id()};
156  auto it = container_.find(key);
157  if (it == container_.cend()) {
158  // after the new DMatrix, cache size is at most max_size
159  container_.emplace(key, Item{m, std::make_shared<CacheT>(args...)});
160  queue_.emplace(key);
161  }
162  return container_.at(key).value;
163  }
173  template <typename... Args>
174  std::shared_ptr<CacheT> ResetItem(std::shared_ptr<DMatrix> m, Args const&... args) {
175  std::lock_guard<std::mutex> guard{lock_};
176  CheckConsistent();
177  auto key = Key{m.get(), std::this_thread::get_id()};
178  auto it = container_.find(key);
179  CHECK(it != container_.cend());
180  it->second = {m, std::make_shared<CacheT>(args...)};
181  CheckConsistent();
182  return it->second.value;
183  }
188  decltype(container_) const& Container() {
189  std::lock_guard<std::mutex> guard{lock_};
190 
191  this->ClearExpired();
192  return container_;
193  }
194 
195  std::shared_ptr<CacheT> Entry(DMatrix const* m) const {
196  std::lock_guard<std::mutex> guard{lock_};
197  auto key = Key{m, std::this_thread::get_id()};
198  CHECK(container_.find(key) != container_.cend());
199  CHECK(!container_.at(key).ref.expired());
200  return container_.at(key).value;
201  }
202 };
203 } // namespace xgboost
204 #endif // XGBOOST_CACHE_H_
Thread-aware FIFO cache for DMatrix related data.
Definition: cache.h:26
decltype(container_) const & Container()
Get a const reference to the underlying hash map. Clear expired caches before returning.
Definition: cache.h:188
void ClearExpired()
Definition: cache.h:71
std::unordered_map< Key, Item, Hash > container_
Definition: cache.h:65
DMatrixCache(std::size_t cache_size)
Definition: cache.h:118
DMatrixCache & operator=(DMatrixCache &&that)
Definition: cache.h:120
std::queue< Key > queue_
Definition: cache.h:66
std::shared_ptr< CacheT > ResetItem(std::shared_ptr< DMatrix > m, Args const &... args)
Re-initialize the item in cache.
Definition: cache.h:174
void CheckConsistent() const
Definition: cache.h:69
std::shared_ptr< CacheT > Entry(DMatrix const *m) const
Definition: cache.h:195
std::size_t max_size_
Definition: cache.h:67
static constexpr std::size_t DefaultSize()
Definition: cache.h:40
void ClearExcess()
Definition: cache.h:102
std::shared_ptr< CacheT > CacheItem(std::shared_ptr< DMatrix > m, Args const &... args)
Cache a new DMatrix if it's not in the cache already.
Definition: cache.h:145
Internal data structure used by XGBoost during training.
Definition: data.h:509
Definition: intrusive_ptr.h:207
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
namespace of xgboost
Definition: base.h:90
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:316
Definition: cache.h:54
std::size_t operator()(Key const &key) const noexcept
Definition: cache.h:55
Definition: cache.h:28
Item(std::shared_ptr< DMatrix > m, std::shared_ptr< CacheT > v)
Definition: cache.h:37
std::shared_ptr< CacheT > value
Definition: cache.h:32
CacheT const & Value() const
Definition: cache.h:34
CacheT & Value()
Definition: cache.h:35
std::weak_ptr< DMatrix > ref
Definition: cache.h:30
Definition: cache.h:46
std::thread::id const thread_id
Definition: cache.h:48
bool operator==(Key const &that) const
Definition: cache.h:50
DMatrix const * ptr
Definition: cache.h:47