diff --git a/include/gcache/ghost_cache.h b/include/gcache/ghost_cache.h index 23dc4c4..b4b3fa0 100644 --- a/include/gcache/ghost_cache.h +++ b/include/gcache/ghost_cache.h @@ -64,6 +64,8 @@ class GhostCache { Handle_t access_impl(SizeType block_id, HashType hash, AccessMode mode); + Handle_t lookup_no_refresh(SizeType block_id, HashType hash); + template friend class SampledGhostKvCache; @@ -291,6 +293,13 @@ GhostCache::access_impl(SizeType block_id, return h; } +template +inline typename GhostCache::Handle_t +GhostCache::lookup_no_refresh(SizeType block_id, + HashType hash) { + return cache.lookup_no_refresh(block_id, hash); +} + template inline void GhostCache::build_caches_stat() { size_t accum_hit_cnt = 0; diff --git a/include/gcache/ghost_kv_cache.h b/include/gcache/ghost_kv_cache.h index f188946..85c7f14 100644 --- a/include/gcache/ghost_kv_cache.h +++ b/include/gcache/ghost_kv_cache.h @@ -40,8 +40,7 @@ class SampledGhostKvCache { void access(const std::string_view key, SizeType kv_size, AccessMode mode = AccessMode::DEFAULT) { - HashType key_hash = Hash{}(key); - access(key_hash, kv_size, mode); + access(Hash{}(key), kv_size, mode); } void access(HashType key_hash, SizeType kv_size, @@ -55,6 +54,20 @@ class SampledGhostKvCache { h->kv_size = kv_size; } + void update_size(const std::string_view key, SizeType kv_size) { + update_size(Hash{}(key), kv_size); + } + + void update_size(HashType key_hash, SizeType kv_size) { + if constexpr (SampleShift > 0) { + // only sample keys with certain number of leading zeros in hash + if (key_hash >> (std::numeric_limits::digits - SampleShift)) + return; + } + auto h = ghost_cache.lookup_no_refresh(key_hash, key_hash); + if (h) h->kv_size = kv_size; + } + // for compatibility with GhostCache: APIs to query by keys count [[nodiscard]] SizeType get_tick() const { return ghost_cache.get_tick(); } [[nodiscard]] SizeType get_min_count() const { diff --git a/include/gcache/lru_cache.h b/include/gcache/lru_cache.h index 4e0af63..67a5b44 100644 --- a/include/gcache/lru_cache.h +++ b/include/gcache/lru_cache.h @@ -136,6 +136,9 @@ class LRUCache { /* Below are intrusive functions that should only be called by GhostCache */ /****************************************************************************/ + // Similar to lookup but won't refresh LRU. + Handle_t lookup_no_refresh(Key_t key, uint32_t hash); + // Similar to insert but 1) the targeted node must be in LRU list and this // function never pins it, 2) return `successor`: the node with the same order // as the returned node after LRU operations (nullptr if newly inserted). @@ -374,6 +377,12 @@ LRUCache::lookup(Key_t key, bool pin) { return lookup_impl(key, Hash{}(key), pin); } +template +inline typename LRUCache::Handle_t +LRUCache::lookup_no_refresh(Key_t key, uint32_t hash) { + return table_->lookup(key, hash); +} + template inline typename LRUCache::Node_t* LRUCache::lookup_impl(Key_t key, uint32_t hash, diff --git a/tests/test_ghost_kv.cpp b/tests/test_ghost_kv.cpp index b38d88b..c3e636a 100644 --- a/tests/test_ghost_kv.cpp +++ b/tests/test_ghost_kv.cpp @@ -99,7 +99,113 @@ void bench1() { std::cout << std::endl; } std::cout << "==============================================================" - << "================================================" << std::endl; + << "================================================\n" + << std::endl; } -int main() { bench1(); } +void test_update_size() { + uint32_t tick = bench_size / 64; + SampledGhostKvCache sampled_ghost_kv_cache(tick, tick, + bench_size); + SampledGhostKvCache sampled_ghost_kv_cache2(tick, tick, + bench_size); + + // filling the cache + for (uint32_t i = 0; i < bench_size; ++i) { + auto k = make_key(i); + sampled_ghost_kv_cache.access(k, i > bench_size / 4 ? 500 : 2000, + AccessMode::NOOP); + sampled_ghost_kv_cache2.access(k, i > bench_size / 4 ? 500 : 2000, + AccessMode::NOOP); + } + + std::vector reqs; + std::vector> reqs2; + for (uint32_t i = 0; i < num_ops; ++i) reqs.emplace_back(rand() % bench_size); + std::shuffle(reqs.begin(), reqs.end(), urbg); + for (auto i : reqs) reqs2.emplace_back(i, make_key(i)); + + for (const auto& [i, k] : reqs2) { + sampled_ghost_kv_cache.access(k, i > bench_size / 4 ? 500 : 2000); + sampled_ghost_kv_cache2.access(k, i > bench_size / 4 ? 500 : 2000); + } + std::shuffle(reqs2.begin(), reqs2.end(), urbg); + for (const auto& [i, k] : reqs2) + sampled_ghost_kv_cache2.update_size(k, (2000 * 1 + 500 * 3) / 4); + + // Dump keys using for_each_lru and compare + std::vector> keys1, keys2; + sampled_ghost_kv_cache.for_each_lru( + [&keys1](const auto& h) { keys1.emplace_back(h.get_key(), h->kv_size); }); + sampled_ghost_kv_cache2.for_each_lru( + [&keys2](const auto& h) { keys2.emplace_back(h.get_key(), h->kv_size); }); + + // Compare vectors + if (keys1.size() != keys2.size()) + throw std::runtime_error("Key count mismatch"); + for (size_t i = 0; i < keys1.size(); ++i) { + if (keys1[i].first != keys2[i].first) { + throw std::runtime_error( + "Key mismatch: lhs=" + std::to_string(keys1[i].first) + + ", rhs=" + std::to_string(keys2[i].first)); + } + } + + // Print comparison table + auto curve1 = sampled_ghost_kv_cache.get_cache_stat_curve(); + auto curve2 = sampled_ghost_kv_cache2.get_cache_stat_curve(); + + std::cout << "=== Update Size Test ===\n" + << "===============================================================" + "==========================================\n" + << " size | w/ kv sampling | kv memory " + "| w/ kv sample update | updated kv memory \n" + << "---------------------------------------------------------------" + "------------------------------------------\n"; + + for (uint32_t s = tick; s <= bench_size; s += tick) { + std::cout << std::setw(5) << s / 1024 << "K|"; + sampled_ghost_kv_cache.get_stat(s).print(std::cout, 8); + std::cout << '|'; + + auto idx = s / tick - 1; + if (idx < curve1.size()) { + auto [count, size, cache_stat] = curve1[idx]; + assert(count == s); + if (cache_stat.hit_cnt == 0) { + std::cout << " NAN"; + } else { + std::cout << std::setw(5) << std::fixed << std::setprecision(1) + << cache_stat.get_hit_rate() * 100 << '%'; + } + std::cout << " @" << std::setw(7) << std::fixed << size / 1024 / 1024 + << 'M' << std::setw(5) << std::fixed << size / count; + } + std::cout << '|'; + + sampled_ghost_kv_cache2.get_stat(s).print(std::cout, 8); + std::cout << '|'; + + if (idx < curve2.size()) { + auto [count, size, cache_stat] = curve2[idx]; + assert(count == s); + if (cache_stat.hit_cnt == 0) { + std::cout << " NAN"; + } else { + std::cout << std::setw(5) << std::fixed << std::setprecision(1) + << cache_stat.get_hit_rate() * 100 << '%'; + } + std::cout << " @" << std::setw(7) << std::fixed << size / 1024 / 1024 + << 'M' << std::setw(5) << std::fixed << size / count; + } + std::cout << std::endl; + } + std::cout << "===============================================================" + "==========================================\n" + << std::endl; +} + +int main() { + bench1(); + test_update_size(); +}