// UnorderedParallelMap.h — concurrent hash map with background resize (header-only).
#ifndef UNORDEREDPARALLELMAP_H
|
|
#define UNORDEREDPARALLELMAP_H
|
|
|
|
#include <atomic>
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <type_traits>
#include <utility>
#include <vector>

#include "optimization.h"
|
template <typename K, typename V>
class LockFreeMap
{
    // Hash map with one Bucket per slot and a heap-allocated overflow chain
    // per slot. Keys and values live in std::atomic<>, so K and V must be
    // trivially copyable. When the load factor passes MAX_LOAD_FACTOR the map
    // migrates entries to a larger table on a background worker thread.
    //
    // NOTE(review): despite the name, this structure is NOT fully lock-free
    // and some cross-thread interleavings remain racy: the resize worker
    // frees an old table's overflow chains that a concurrent reader may still
    // be traversing, and a freshly claimed head bucket publishes `occupied`
    // before `key`/`value` are stored. A complete fix needs hazard pointers
    // or epoch-based reclamation; treat this as best-effort concurrent.
private:
    static_assert(std::is_trivially_copyable<K>::value,
                  "std::atomic<K> requires a trivially copyable key type");
    static_assert(std::is_trivially_copyable<V>::value,
                  "std::atomic<V> requires a trivially copyable value type");

    struct Bucket
    {
        std::atomic<bool> occupied{false};   // head slot claimed / node in use
        std::atomic<bool> deleted{false};    // tombstone flag (slot reusable)
        std::atomic<K> key;                  // valid only once occupied is set
        std::atomic<V> value;
        std::atomic<Bucket*> next{nullptr};  // heap-allocated overflow chain
    };

    struct Table
    {
        size_t capacity;
        std::vector<Bucket> buckets;

        explicit Table(size_t cap) : capacity(cap), buckets(cap)
        {
        }
    };

    // Both shared_ptr members are read/written from multiple threads, so
    // every access goes through std::atomic_load / std::atomic_store.
    std::shared_ptr<Table> table;               // current table
    std::shared_ptr<Table> new_table{nullptr};  // destination while resizing, else null
    std::atomic<size_t> size_counter{0};
    std::mutex resize_mutex;                    // serializes start_resize
    std::thread resize_worker;                  // background migration; joined, never detached
    static constexpr float MAX_LOAD_FACTOR = 0.75;

    // Map a key to a slot index in a table of the given capacity.
    size_t hash(const K& key, size_t capacity) const
    {
        return std::hash<K>{}(key) % capacity;
    }

    // Delete every heap-allocated overflow node of `t`. The head buckets are
    // owned by the Table's vector and are not deleted here.
    static void free_chains(const std::shared_ptr<Table>& t)
    {
        if (!t) return;
        for (auto& bucket : t->buckets)
        {
            Bucket* current = bucket.next.load();
            while (current)
            {
                Bucket* next = current->next.load();
                delete current;
                current = next;
            }
        }
    }

    // Copy one live (occupied, non-tombstoned) entry into `dest`.
    // First writer wins: if the key already exists in dest, keep it.
    void migrate_entry(Bucket& src, const std::shared_ptr<Table>& dest)
    {
        if (!src.occupied.load() || src.deleted.load()) return;

        const K key = src.key.load();
        const V val = src.value.load();
        size_t idx = hash(key, dest->capacity);
        Bucket& head = dest->buckets[idx];

        // Fast path: claim an empty head slot.
        if (!head.occupied.load())
        {
            bool expected = false;
            if (head.occupied.compare_exchange_strong(expected, true))
            {
                head.key.store(key);
                head.value.store(val);
                head.deleted.store(false);
                return;
            }
        }

        // Slow path: walk the chain; append only if the key is absent.
        Bucket* current = &head;
        while (true)
        {
            // Check occupied first — an unclaimed head's key is uninitialized.
            if (current->occupied.load() && !current->deleted.load() &&
                current->key.load() == key)
                return;  // already migrated by another thread

            Bucket* next = current->next.load();
            if (next)
            {
                current = next;
            }
            else
            {
                Bucket* new_node = new Bucket;
                new_node->occupied.store(true);
                new_node->key.store(key);
                new_node->value.store(val);
                new_node->deleted.store(false);
                if (current->next.compare_exchange_strong(next, new_node))
                    return;
                delete new_node;  // lost the append race; re-examine from here
            }
        }
    }

    // Kick off a background migration to a table of `new_capacity` slots.
    // No-op if a resize is already in flight. The worker is a joined member
    // thread (the previous implementation detached it, which is a
    // use-after-free if the map is destroyed while a resize is running).
    void start_resize(size_t new_capacity)
    {
        std::lock_guard<std::mutex> lock(resize_mutex);
        if (std::atomic_load(&new_table)) return;  // resize already running

        // Reap the previous (finished) worker before launching a new one.
        if (resize_worker.joinable())
            resize_worker.join();

        auto old_table = std::atomic_load(&table);
        auto next = std::make_shared<Table>(new_capacity);
        std::atomic_store(&new_table, next);

        resize_worker = std::thread([this, old_table, next]()
        {
            for (auto& bucket : old_table->buckets)
            {
                for (Bucket* current = &bucket; current;
                     current = current->next.load())
                {
                    migrate_entry(*current, next);
                }
            }
            // Publish the new table, then retire the resize marker.
            std::atomic_store(&table, next);
            std::atomic_store(&new_table, std::shared_ptr<Table>{});

            // NOTE(review): readers that grabbed old_table before the publish
            // may still be walking these chains — freeing them here is the
            // reclamation race described on the class comment.
            free_chains(old_table);
        });
    }

    // Table that new operations should target: the in-progress destination
    // during a resize, otherwise the current table.
    std::shared_ptr<Table> active_table() const
    {
        auto nt = std::atomic_load(&new_table);
        return nt ? nt : std::atomic_load(&table);
    }

public:
    // Create a map with `initial_capacity` head slots (default 1024).
    explicit LockFreeMap(size_t initial_capacity = 1024)
    {
        table = std::make_shared<Table>(initial_capacity);
    }

    // Joins any in-flight resize, then reclaims all overflow nodes.
    // Callers must ensure no other thread uses the map during destruction.
    ~LockFreeMap()
    {
        if (resize_worker.joinable())
            resize_worker.join();
        free_chains(std::atomic_load(&new_table));  // defensive: null after join
        free_chains(std::atomic_load(&table));
    }

    // Insert key -> value. Returns false if the key is already present.
    // A tombstoned entry with the same key is revived in place instead of
    // growing the chain. May trigger a background resize.
    bool insert(const K& key, const V& value)
    {
        auto snapshot = std::atomic_load(&table);
        if (static_cast<float>(size_counter.load() + 1) / snapshot->capacity >
            MAX_LOAD_FACTOR)
            start_resize(snapshot->capacity * 2);

        auto t = active_table();
        size_t idx = hash(key, t->capacity);
        Bucket& head = t->buckets[idx];

        // Fast path: claim an empty head slot.
        if (!head.occupied.load())
        {
            bool expected = false;
            if (head.occupied.compare_exchange_strong(expected, true))
            {
                head.key.store(key);
                head.value.store(value);
                head.deleted.store(false);
                size_counter.fetch_add(1);
                return true;
            }
        }

        Bucket* current = &head;
        while (true)
        {
            // Check occupied before reading key: an unclaimed head bucket's
            // key atomic has never been stored.
            if (current->occupied.load() && current->key.load() == key)
            {
                // Try to revive a tombstone; if the CAS fails the entry is
                // live (or was revived by someone else) => duplicate.
                bool was_deleted = true;
                if (current->deleted.compare_exchange_strong(was_deleted, false))
                {
                    // NOTE(review): the stale value is briefly visible
                    // between the CAS and this store.
                    current->value.store(value);
                    size_counter.fetch_add(1);
                    return true;
                }
                return false;
            }

            Bucket* next = current->next.load();
            if (next)
            {
                current = next;
            }
            else
            {
                Bucket* new_node = new Bucket;
                new_node->occupied.store(true);
                new_node->key.store(key);
                new_node->value.store(value);
                new_node->deleted.store(false);
                if (current->next.compare_exchange_strong(next, new_node))
                {
                    size_counter.fetch_add(1);
                    return true;
                }
                delete new_node;  // lost the append race; re-examine from here
            }
        }
    }

    // Look up key; returns the value or std::nullopt.
    std::optional<V> find(const K& key)
    {
        auto t = active_table();
        size_t idx = hash(key, t->capacity);
        for (Bucket* current = &t->buckets[idx]; current;
             current = current->next.load())
        {
            if (current->occupied.load() &&
                !current->deleted.load() &&
                current->key.load() == key)
            {
                return current->value.load();
            }
        }
        return std::nullopt;
    }

    // Tombstone the entry for key. Returns false if not present.
    // Uses CAS so two racing erases of the same key cannot both succeed
    // (the old blind store could double-decrement size_counter).
    bool erase(const K& key)
    {
        auto t = active_table();
        size_t idx = hash(key, t->capacity);
        for (Bucket* current = &t->buckets[idx]; current;
             current = current->next.load())
        {
            if (current->occupied.load() &&
                !current->deleted.load() &&
                current->key.load() == key)
            {
                bool was_deleted = false;
                if (current->deleted.compare_exchange_strong(was_deleted, true))
                {
                    size_counter.fetch_sub(1);
                    return true;
                }
                return false;  // another thread erased it first
            }
        }
        return false;
    }

    // Overwrite the value for an existing key. Returns false if not present.
    bool update(const K& key, const V& new_value)
    {
        auto t = active_table();
        size_t idx = hash(key, t->capacity);
        for (Bucket* current = &t->buckets[idx]; current;
             current = current->next.load())
        {
            if (current->occupied.load() &&
                !current->deleted.load() &&
                current->key.load() == key)
            {
                current->value.store(new_value);
                return true;
            }
        }
        return false;
    }

    // Start a background shrink when occupancy drops below 1/4 of capacity
    // (never below 1024 slots).
    void shrink()
    {
        size_t current_size = size_counter.load();
        auto cap = std::atomic_load(&table)->capacity;
        if (current_size < cap / 4 && cap > 1024)
            start_resize(cap / 2);
    }

    // Approximate number of live entries.
    size_t size() const
    {
        return size_counter.load();
    }

    // Eagerly migrate one key into the in-progress destination table.
    // Returns false if no resize is running or the key is not found.
    bool rehash_one(const K& key)
    {
        auto src = std::atomic_load(&table);
        auto dst = std::atomic_load(&new_table);
        if (!dst) return false;

        size_t idx = hash(key, src->capacity);
        for (Bucket* current = &src->buckets[idx]; current;
             current = current->next.load())
        {
            if (current->occupied.load() &&
                !current->deleted.load() &&
                current->key.load() == key)
            {
                migrate_entry(*current, dst);
                return true;
            }
        }
        return false;
    }

    // Snapshot of all live keys (best-effort under concurrent mutation).
    std::vector<K> keys()
    {
        std::vector<K> result;
        auto t = active_table();
        for (auto& bucket : t->buckets)
        {
            for (Bucket* current = &bucket; current;
                 current = current->next.load())
            {
                if (current->occupied.load() && !current->deleted.load())
                    result.push_back(current->key.load());
            }
        }
        return result;
    }

    // Snapshot of all live key/value pairs (best-effort under concurrency).
    std::vector<std::pair<K, V>> entries()
    {
        std::vector<std::pair<K, V>> result;
        auto t = active_table();
        for (auto& bucket : t->buckets)
        {
            for (Bucket* current = &bucket; current;
                 current = current->next.load())
            {
                if (current->occupied.load() && !current->deleted.load())
                    result.emplace_back(current->key.load(),
                                        current->value.load());
            }
        }
        return result;
    }

    // Invoke cb(key, value) for every live entry (best-effort snapshot).
    void for_each(const std::function<void(const K&, const V&)>& cb)
    {
        auto t = active_table();
        for (auto& bucket : t->buckets)
        {
            for (Bucket* current = &bucket; current;
                 current = current->next.load())
            {
                if (current->occupied.load() && !current->deleted.load())
                    cb(current->key.load(), current->value.load());
            }
        }
    }
};
|
|
|
|
#endif // UNORDEREDPARALLELMAP_H
|