Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/gcache/ghost_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class GhostCache {

Handle_t access_impl(SizeType block_id, HashType hash, AccessMode mode);

Handle_t lookup_no_refresh(SizeType block_id, HashType hash);

template <uint32_t S, typename H, typename ST, typename HT>
friend class SampledGhostKvCache;

Expand Down Expand Up @@ -291,6 +293,13 @@ GhostCache<Hash, Meta, SizeType, HashType>::access_impl(SizeType block_id,
return h;
}

template <typename Hash, typename Meta, typename SizeType, typename HashType>
inline typename GhostCache<Hash, Meta, SizeType, HashType>::Handle_t
GhostCache<Hash, Meta, SizeType, HashType>::lookup_no_refresh(SizeType block_id,
HashType hash) {
return cache.lookup_no_refresh(block_id, hash);
}

template <typename Hash, typename Meta, typename SizeType, typename HashType>
inline void GhostCache<Hash, Meta, SizeType, HashType>::build_caches_stat() {
size_t accum_hit_cnt = 0;
Expand Down
17 changes: 15 additions & 2 deletions include/gcache/ghost_kv_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ class SampledGhostKvCache {

void access(const std::string_view key, SizeType kv_size,
AccessMode mode = AccessMode::DEFAULT) {
HashType key_hash = Hash{}(key);
access(key_hash, kv_size, mode);
access(Hash{}(key), kv_size, mode);
}

void access(HashType key_hash, SizeType kv_size,
Expand All @@ -55,6 +54,20 @@ class SampledGhostKvCache {
h->kv_size = kv_size;
}

void update_size(const std::string_view key, SizeType kv_size) {
update_size(Hash{}(key), kv_size);
}

void update_size(HashType key_hash, SizeType kv_size) {
if constexpr (SampleShift > 0) {
// only sample keys with certain number of leading zeros in hash
if (key_hash >> (std::numeric_limits<SizeType>::digits - SampleShift))
return;
}
auto h = ghost_cache.lookup_no_refresh(key_hash, key_hash);
if (h) h->kv_size = kv_size;
}

// for compatibility with GhostCache: APIs to query by keys count
[[nodiscard]] SizeType get_tick() const { return ghost_cache.get_tick(); }
[[nodiscard]] SizeType get_min_count() const {
Expand Down
9 changes: 9 additions & 0 deletions include/gcache/lru_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ class LRUCache {
/* Below are intrusive functions that should only be called by GhostCache */
/****************************************************************************/

// Similar to lookup but won't refresh LRU.
Handle_t lookup_no_refresh(Key_t key, uint32_t hash);

// Similar to insert but 1) the targeted node must be in LRU list and this
// function never pins it, 2) return `successor`: the node with the same order
// as the returned node after LRU operations (nullptr if newly inserted).
Expand Down Expand Up @@ -374,6 +377,12 @@ LRUCache<Key_t, Value_t, Hash>::lookup(Key_t key, bool pin) {
return lookup_impl(key, Hash{}(key), pin);
}

template <typename Key_t, typename Value_t, typename Hash>
inline typename LRUCache<Key_t, Value_t, Hash>::Handle_t
LRUCache<Key_t, Value_t, Hash>::lookup_no_refresh(Key_t key, uint32_t hash) {
return table_->lookup(key, hash);
}

template <typename Key_t, typename Value_t, typename Hash>
inline typename LRUCache<Key_t, Value_t, Hash>::Node_t*
LRUCache<Key_t, Value_t, Hash>::lookup_impl(Key_t key, uint32_t hash,
Expand Down
110 changes: 108 additions & 2 deletions tests/test_ghost_kv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,113 @@ void bench1() {
std::cout << std::endl;
}
std::cout << "=============================================================="
<< "================================================" << std::endl;
<< "================================================\n"
<< std::endl;
}

int main() { bench1(); }
void test_update_size() {
uint32_t tick = bench_size / 64;
SampledGhostKvCache<sample_shift> sampled_ghost_kv_cache(tick, tick,
bench_size);
SampledGhostKvCache<sample_shift> sampled_ghost_kv_cache2(tick, tick,
bench_size);

// filling the cache
for (uint32_t i = 0; i < bench_size; ++i) {
auto k = make_key(i);
sampled_ghost_kv_cache.access(k, i > bench_size / 4 ? 500 : 2000,
AccessMode::NOOP);
sampled_ghost_kv_cache2.access(k, i > bench_size / 4 ? 500 : 2000,
AccessMode::NOOP);
}

std::vector<uint32_t> reqs;
std::vector<std::pair<uint32_t, std::string>> reqs2;
for (uint32_t i = 0; i < num_ops; ++i) reqs.emplace_back(rand() % bench_size);
std::shuffle(reqs.begin(), reqs.end(), urbg);
for (auto i : reqs) reqs2.emplace_back(i, make_key(i));

for (const auto& [i, k] : reqs2) {
sampled_ghost_kv_cache.access(k, i > bench_size / 4 ? 500 : 2000);
sampled_ghost_kv_cache2.access(k, i > bench_size / 4 ? 500 : 2000);
}
std::shuffle(reqs2.begin(), reqs2.end(), urbg);
for (const auto& [i, k] : reqs2)
sampled_ghost_kv_cache2.update_size(k, (2000 * 1 + 500 * 3) / 4);

// Dump keys using for_each_lru and compare
std::vector<std::pair<uint32_t, uint32_t>> keys1, keys2;
sampled_ghost_kv_cache.for_each_lru(
[&keys1](const auto& h) { keys1.emplace_back(h.get_key(), h->kv_size); });
sampled_ghost_kv_cache2.for_each_lru(
[&keys2](const auto& h) { keys2.emplace_back(h.get_key(), h->kv_size); });

// Compare vectors
if (keys1.size() != keys2.size())
throw std::runtime_error("Key count mismatch");
for (size_t i = 0; i < keys1.size(); ++i) {
if (keys1[i].first != keys2[i].first) {
throw std::runtime_error(
"Key mismatch: lhs=" + std::to_string(keys1[i].first) +
", rhs=" + std::to_string(keys2[i].first));
}
}

// Print comparison table
auto curve1 = sampled_ghost_kv_cache.get_cache_stat_curve();
auto curve2 = sampled_ghost_kv_cache2.get_cache_stat_curve();

std::cout << "=== Update Size Test ===\n"
<< "==============================================================="
"==========================================\n"
<< " size | w/ kv sampling | kv memory "
"| w/ kv sample update | updated kv memory \n"
<< "---------------------------------------------------------------"
"------------------------------------------\n";

for (uint32_t s = tick; s <= bench_size; s += tick) {
std::cout << std::setw(5) << s / 1024 << "K|";
sampled_ghost_kv_cache.get_stat(s).print(std::cout, 8);
std::cout << '|';

auto idx = s / tick - 1;
if (idx < curve1.size()) {
auto [count, size, cache_stat] = curve1[idx];
assert(count == s);
if (cache_stat.hit_cnt == 0) {
std::cout << " NAN";
} else {
std::cout << std::setw(5) << std::fixed << std::setprecision(1)
<< cache_stat.get_hit_rate() * 100 << '%';
}
std::cout << " @" << std::setw(7) << std::fixed << size / 1024 / 1024
<< 'M' << std::setw(5) << std::fixed << size / count;
}
std::cout << '|';

sampled_ghost_kv_cache2.get_stat(s).print(std::cout, 8);
std::cout << '|';

if (idx < curve2.size()) {
auto [count, size, cache_stat] = curve2[idx];
assert(count == s);
if (cache_stat.hit_cnt == 0) {
std::cout << " NAN";
} else {
std::cout << std::setw(5) << std::fixed << std::setprecision(1)
<< cache_stat.get_hit_rate() * 100 << '%';
}
std::cout << " @" << std::setw(7) << std::fixed << size / 1024 / 1024
<< 'M' << std::setw(5) << std::fixed << size / count;
}
std::cout << std::endl;
}
std::cout << "==============================================================="
"==========================================\n"
<< std::endl;
}

int main() {
bench1();
test_update_size();
}