#ifndef LFSKIPLIST_H
#define LFSKIPLIST_H

// Lock-free, multi-versioned skip list. Writers never block: erase() is
// modelled as inserting a tombstone node, and readers resolve a key to the
// node carrying the highest version.

#include <atomic>
#include <array>
#include <random>
#include <optional>
#include <limits>
#include <memory>
#include <cassert>
#include <thread>
#include <cstdint>

#include "utils/io/VersionManager.h"

#if defined(__x86_64__) || defined(_M_X64) || defined(__i386) || defined(_M_IX86)
#include <immintrin.h>

// Spin-wait hint: eases pipeline pressure while a CAS retry loop spins.
inline void cpu_relax() noexcept { _mm_pause(); }

// Pulls the node we are about to inspect into cache ahead of the dereference.
inline void prefetch_for_read(const void* ptr) noexcept
{
    _mm_prefetch(reinterpret_cast<const char*>(ptr), _MM_HINT_T0);
}
#else
// Portable no-op fallbacks for non-x86 targets.
inline void cpu_relax() noexcept {}
inline void prefetch_for_read(const void*) noexcept {}
#endif

namespace usub::utils
{
    constexpr int MAX_LEVEL = 16;

    // Coin-flip tower height: a node reaches level k with probability
    // 2^-(k-1), so the expected height is 2 and the MAX_LEVEL cap is hit
    // roughly once per 2^15 inserts.
    inline int random_level()
    {
        static thread_local std::mt19937 rng(std::random_device{}());
        static thread_local std::uniform_int_distribution<int> dist(0, 1);
        int lvl = 1;
        while (lvl < MAX_LEVEL && dist(rng)) ++lvl;
        return lvl;
    }

    template <typename Key, typename Value>
    class LFSkipList
    {
        struct Node
        {
            Key key;
            Value value;
            int topLevel;       // number of levels this node is linked into
            bool is_tombstone;  // true when this node records a deletion
            uint64_t version;   // write version; the highest version wins on read
            std::array<std::atomic<Node*>, MAX_LEVEL> next;
            std::atomic<bool> marked{false}; // logical-deletion flag

            Node(const Key& k, const Value& v, int level, bool tombstone, uint64_t ver)
                : key(k), value(v), topLevel(level), is_tombstone(tombstone), version(ver)
            {
                for (int i = 0; i < MAX_LEVEL; ++i)
                    next[i].store(nullptr, std::memory_order_relaxed);
            }
        };

        Node* head;
        usub::utils::VersionManager& version_manager;

    public:
        using key_type = Key;
        using value_type = Value;

        // The head sentinel sorts below every real key, so Key must
        // specialize std::numeric_limits.
        explicit LFSkipList(usub::utils::VersionManager& vm)
            : version_manager(vm)
        {
            head = new Node(std::numeric_limits<Key>::min(), Value{}, MAX_LEVEL, false,
                            version_manager.next_version());
        }

        // Not thread-safe: run only after all readers and writers are done.
        ~LFSkipList()
        {
            Node* curr = head;
            while (curr)
            {
                Node* next = next_node(curr->next[0].load(std::memory_order_relaxed));
                delete curr;
                curr = next;
            }
        }

        bool insert(const Key& key, const Value& value)
        {
            return insert_internal(key, value, false);
        }

        // Logical erase: inserts a tombstone whose fresher version shadows
        // any earlier node with the same key. No memory is reclaimed here.
        bool erase(const Key& key)
        {
            return insert_internal(key, Value{}, true);
        }

        // Walks the entire bottom level and resolves `key` to its unmarked
        // node with the highest version; a tombstone there means "deleted".
        std::optional<Value> find(const Key& key) const
        {
            Node* best = nullptr;
            Node* node = head->next[0].load(std::memory_order_acquire);
            while (node)
            {
                prefetch_for_read(node);
                if (node->key == key && !node->marked.load(std::memory_order_acquire))
                {
                    if (!best || node->version > best->version)
                    {
                        best = node;
                    }
                }
                node = next_node(node->next[0].load(std::memory_order_acquire));
            }
            if (best && !best->is_tombstone)
                return best->value;
            return std::nullopt;
        }

        // Visits live (unmarked, non-tombstone) nodes in key order. A key
        // that was overwritten is visited once per surviving version.
        template <typename F>
        void for_each(F&& func) const
        {
            Node* node = head->next[0].load(std::memory_order_acquire);
            while (node)
            {
                prefetch_for_read(node);
                if (!node->marked.load(std::memory_order_acquire) && !node->is_tombstone)
                {
                    func(node->key, node->value);
                }
                node = next_node(node->next[0].load(std::memory_order_acquire));
            }
        }

        // Like for_each, but also exposes tombstones and versions so the
        // caller can deduplicate or persist the raw history.
        template <typename F>
        void for_each_raw(F&& func) const
        {
            Node* node = head->next[0].load(std::memory_order_acquire);
            while (node)
            {
                prefetch_for_read(node);
                if (!node->marked.load(std::memory_order_acquire))
                {
                    func(node->key, node->value, node->is_tombstone, node->version);
                }
                node = next_node(node->next[0].load(std::memory_order_acquire));
            }
        }

        // Same linking protocol as insert_internal, but with a
        // caller-supplied version instead of one from the VersionManager.
        bool insert_raw(const Key& key, const Value& value, bool tombstone, uint64_t version)
        {
            Node* preds[MAX_LEVEL]{};
            Node* succs[MAX_LEVEL]{};
            while (true)
            {
                bool found = find_internal(key, preds, succs);
                int topLevel = random_level();
                Node* newNode = new Node(key, value, topLevel, tombstone, version);
                for (int i = 0; i < topLevel; ++i)
                    newNode->next[i].store(succs[i], std::memory_order_relaxed);

                // Linearization point: splice into the bottom level.
                if (!preds[0]->next[0].compare_exchange_strong(
                        succs[0], newNode, std::memory_order_acq_rel, std::memory_order_relaxed))
                {
                    delete newNode; // never published, safe to free
                    cpu_relax();
                    continue;
                }

                // Link the upper levels. After a failed CAS and re-find,
                // refresh newNode's forward pointer so it targets the new
                // successor before retrying (otherwise intervening nodes
                // would be bypassed at this level).
                for (int i = 1; i < topLevel; ++i)
                {
                    while (true)
                    {
                        newNode->next[i].store(succs[i], std::memory_order_relaxed);
                        if (preds[i]->next[i].compare_exchange_strong(
                                succs[i], newNode, std::memory_order_acq_rel, std::memory_order_relaxed))
                            break;
                        cpu_relax();
                        find_internal(key, preds, succs);
                    }
                }
                return !found || tombstone;
            }
        }

        // O(n) relaxed scan; the count is only meaningful while no writers
        // are active, hence "unsafe".
        [[nodiscard]] size_t unsafe_size() const
        {
            size_t count = 0;
            Node* node = head->next[0].load(std::memory_order_relaxed);
            while (node)
            {
                if (!node->marked.load(std::memory_order_relaxed) && !node->is_tombstone)
                    ++count;
                node = next_node(node->next[0].load(std::memory_order_relaxed));
            }
            return count;
        }

    private:
        bool insert_internal(const Key& key, const Value& value, bool tombstone)
        {
            Node* preds[MAX_LEVEL]{};
            Node* succs[MAX_LEVEL]{};
            while (true)
            {
                bool found = find_internal(key, preds, succs);
                int topLevel = random_level();
                Node* newNode = new Node(key, value, topLevel, tombstone,
                                         version_manager.next_version());
                for (int i = 0; i < topLevel; ++i)
                    newNode->next[i].store(succs[i], std::memory_order_relaxed);

                // Linearization point: splice into the bottom level.
                if (!preds[0]->next[0].compare_exchange_strong(
                        succs[0], newNode, std::memory_order_acq_rel, std::memory_order_relaxed))
                {
                    delete newNode; // never published, safe to free
                    cpu_relax();
                    continue;
                }

                // Link the upper levels. After a failed CAS and re-find,
                // refresh newNode's forward pointer so it targets the new
                // successor before retrying (otherwise intervening nodes
                // would be bypassed at this level).
                for (int i = 1; i < topLevel; ++i)
                {
                    while (true)
                    {
                        newNode->next[i].store(succs[i], std::memory_order_relaxed);
                        if (preds[i]->next[i].compare_exchange_strong(
                                succs[i], newNode, std::memory_order_acq_rel, std::memory_order_relaxed))
                            break;
                        cpu_relax();
                        find_internal(key, preds, succs);
                    }
                }
                // True when a brand-new key was added, or when a tombstone
                // was recorded for an existing key.
                return !found || tombstone;
            }
        }

        // Fills preds/succs with the insertion window for `key` at every
        // level. A set low bit on a node's forward pointer tags that node
        // as logically deleted at that level; tagged nodes are skipped.
        bool find_internal(const Key& key, Node** preds, Node** succs) const
        {
            bool found = false;
            Node* pred = head;
            for (int level = MAX_LEVEL - 1; level >= 0; --level)
            {
                Node* curr = pred->next[level].load(std::memory_order_acquire);
                while (curr)
                {
                    prefetch_for_read(curr);
                    Node* next = curr->next[level].load(std::memory_order_acquire);
                    if (reinterpret_cast<uintptr_t>(next) & 1)
                    {
                        curr = next_node(next); // step over the deleted node
                        continue;
                    }
                    if (curr->key < key)
                    {
                        pred = curr;
                        curr = next;
                    }
                    else
                    {
                        break;
                    }
                }
                preds[level] = pred;
                succs[level] = curr;
            }
            if (succs[0] && succs[0]->key == key && !succs[0]->marked.load(std::memory_order_acquire))
                found = true;
            return found;
        }

        // Strips the deletion tag (low bit) so traversal lands on the real
        // successor node.
        static Node* next_node(Node* n)
        {
            return reinterpret_cast<Node*>(reinterpret_cast<uintptr_t>(n) & ~uintptr_t(1));
        }
    };
} // namespace usub::utils
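
// A minimal usage sketch (kept as a comment so the header stays
// declaration-only). It assumes usub::utils::VersionManager is
// default-constructible and that next_version() yields increasing
// uint64_t values; adjust construction to the real VersionManager API.
//
//     usub::utils::VersionManager vm;                // assumed ctor
//     usub::utils::LFSkipList<int, int> list(vm);
//
//     list.insert(1, 100);
//     list.insert(1, 111);   // newer version shadows the first write
//     list.erase(1);         // tombstone: find() now reports "absent"
//     list.insert(2, 200);
//
//     assert(!list.find(1));
//     assert(list.find(2).value() == 200);
//
//     // Raw iteration still sees every version plus the tombstone:
//     list.for_each_raw([](int k, int v, bool dead, uint64_t ver) {
//         std::printf("k=%d v=%d tombstone=%d ver=%llu\n",
//                     k, v, int(dead), (unsigned long long)ver);
//     });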
#endif //LFSKIPLIST_H