// ParallelUnorderedMap/UnorderedParallelMap.h
// Last modified: 2025-05-01 12:47:29 +03:00
//
// Repository viewer metadata: 367 lines, 9.4 KiB, C++
//

#ifndef UNORDEREDPARALLELMAP_H
#define UNORDEREDPARALLELMAP_H
#include <atomic>
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <thread>
#include <utility>
#include <vector>
#include "optimization.h"
/// Concurrent hash map using separate chaining: each bucket has an inline head slot
/// plus a singly linked overflow chain of heap nodes.
///
/// Concurrency model:
///  - insert / erase / update may run concurrently with one another. They coordinate via
///    compare-and-swap on head slots, chain links and tombstones, and hold `resize_mutex`
///    only in *shared* mode, so they do not exclude one another.
///  - find / keys / entries / for_each hold the lock only long enough to copy the table
///    pointer, then traverse that snapshot without a lock. The snapshot's shared_ptr keeps
///    the table, and every chain node it owns, alive until the traversal finishes.
///  - A resize holds `resize_mutex` exclusively, copies every live entry into a fresh
///    table and publishes it before unlocking. No writer can therefore modify a table
///    that was already migrated (no lost inserts/erases), and lookups never see a
///    half-migrated table.
///
/// Requirements: K and V are stored in std::atomic, so both must be trivially copyable.
/// K must be hashable with std::hash<K> and comparable with ==.
/// Erased entries are tombstoned rather than unlinked. Their memory is reclaimed when a
/// resize replaces the table holding them, or when the map is destroyed.
template <typename K, typename V>
class LockFreeMap
{
private:
    /// One slot of the table. Head slots live inline in Table::buckets; overflow entries
    /// are heap-allocated nodes linked through `next` and owned by their Table.
    struct Bucket
    {
        std::atomic<bool> occupied{false}; // head slot claimed by an inserter; always true for chain nodes
        std::atomic<bool> ready{false};    // key/value fully written; readers must check this first
        std::atomic<bool> deleted{false};  // tombstone set by erase()
        std::atomic<K> key;
        std::atomic<V> value;
        std::atomic<Bucket*> next{nullptr};
    };

    /// Fixed-capacity bucket array. Owns every overflow node reachable from its buckets
    /// and frees them on destruction, i.e. once the last snapshot of it is released.
    struct Table
    {
        size_t capacity;
        std::vector<Bucket> buckets;

        explicit Table(size_t cap) : capacity(cap), buckets(cap)
        {
        }

        Table(const Table&) = delete;
        Table& operator=(const Table&) = delete;

        ~Table()
        {
            for (Bucket& bucket : buckets)
            {
                Bucket* node = bucket.next.load();
                while (node)
                {
                    Bucket* following = node->next.load();
                    delete node;
                    node = following;
                }
            }
        }
    };

    // Current table. It is reassigned only while resize_mutex is held exclusively, and
    // read only while holding at least a shared lock: concurrently reading and assigning
    // the same shared_ptr object is a data race.
    std::shared_ptr<Table> table;
    std::atomic<size_t> size_counter{0};
    // Shared mode: insert/erase/update and taking table snapshots. Exclusive mode: resize.
    mutable std::shared_mutex resize_mutex;
    static constexpr float MAX_LOAD_FACTOR = 0.75f;

    size_t hash(const K& key, size_t capacity) const
    {
        return std::hash<K>{}(key) % capacity;
    }

    /// True when the bucket holds a fully published, non-erased entry.
    static bool is_live(const Bucket& bucket)
    {
        return bucket.ready.load() && !bucket.deleted.load();
    }

    /// True when the bucket holds a live entry for `key`.
    static bool matches(const Bucket& bucket, const K& key)
    {
        return is_live(bucket) && bucket.key.load() == key;
    }

    /// Writes key/value into a claimed slot and marks it ready. `ready` is stored last
    /// (all accesses are seq_cst), so a thread that sees ready == true also sees the
    /// key and value.
    static void publish(Bucket& slot, const K& key, const V& value)
    {
        slot.key.store(key);
        slot.value.store(value);
        slot.deleted.store(false);
        slot.ready.store(true);
    }

    /// Allocates a fully initialised overflow node. Other threads can see it only after
    /// it is linked into a chain.
    static Bucket* make_node(const K& key, const V& value)
    {
        Bucket* node = new Bucket;
        node->occupied.store(true);
        publish(*node, key, value);
        return node;
    }

    /// Copies one live entry from `src` into `dest`, skipping keys already present there.
    /// Called only from start_resize while resize_mutex is held exclusively, so no other
    /// thread can be touching `dest`.
    void migrate_entry(Bucket& src, const std::shared_ptr<Table>& dest)
    {
        if (!is_live(src))
            return;
        const K key = src.key.load();
        const V val = src.value.load();
        Bucket& head = dest->buckets[hash(key, dest->capacity)];
        if (!head.occupied.load())
        {
            head.occupied.store(true);
            publish(head, key, val);
            return;
        }
        Bucket* tail = &head;
        while (true)
        {
            if (matches(*tail, key))
                return;
            Bucket* next = tail->next.load();
            if (!next)
                break;
            tail = next;
        }
        tail->next.store(make_node(key, val));
    }

    /// Rebuilds the map into a table of `new_capacity` buckets (clamped to at least 1).
    ///
    /// Steps:
    ///  1. Wait for every in-flight insert/erase/update to finish (exclusive lock).
    ///  2. Migrate all live entries.
    ///  3. Publish the new table.
    ///
    /// The old table is freed when the last reader holding a snapshot of it lets go.
    /// Returns early if the table already has the requested capacity. This happens when
    /// several inserts cross the load-factor threshold at the same time.
    void start_resize(size_t new_capacity)
    {
        if (new_capacity == 0)
            new_capacity = 1;
        std::unique_lock<std::shared_mutex> exclusive(resize_mutex);
        if (table->capacity == new_capacity)
            return;
        auto next = std::make_shared<Table>(new_capacity);
        for (Bucket& bucket : table->buckets)
            for (Bucket* current = &bucket; current; current = current->next.load())
                migrate_entry(*current, next);
        table = std::move(next);
    }

    /// Returns a snapshot of the current table; holding it keeps that table alive.
    std::shared_ptr<Table> active_table() const
    {
        std::shared_lock<std::shared_mutex> shared(resize_mutex);
        return table;
    }

public:
    /// @param initial_capacity number of head buckets. 0 is treated as 1, because the
    ///        hash is reduced modulo the capacity and so the capacity must be non-zero.
    explicit LockFreeMap(size_t initial_capacity = 1024)
        : table(std::make_shared<Table>(initial_capacity == 0 ? 1 : initial_capacity))
    {
    }

    /// Table's destructor frees every overflow node, so nothing else is needed here.
    ~LockFreeMap() = default;

    /// Inserts `key` -> `value` if `key` is not already present.
    /// Grows the table first when the load factor would exceed MAX_LOAD_FACTOR.
    /// @return true if inserted, false if a live entry for `key` already exists.
    bool insert(const K& key, const V& value)
    {
        const size_t cap = active_table()->capacity;
        if (static_cast<float>(size_counter.load() + 1) / static_cast<float>(cap) > MAX_LOAD_FACTOR)
            start_resize(cap * 2);

        std::shared_lock<std::shared_mutex> guard(resize_mutex);
        Table& t = *table;
        Bucket& head = t.buckets[hash(key, t.capacity)];
        prefetch_for_read(&head);

        bool expected = false;
        if (!head.occupied.load() && head.occupied.compare_exchange_strong(expected, true))
        {
            publish(head, key, value);
            size_counter.fetch_add(1);
            return true;
        }

        Bucket* current = &head;
        while (true)
        {
            // The head slot may have been claimed by another inserter that has not yet
            // written its key. Wait for it, or the duplicate check below could miss it.
            while (!current->ready.load())
                cpu_relax();
            if (!current->deleted.load() && current->key.load() == key)
                return false;
            Bucket* next = current->next.load();
            if (next)
            {
                current = next;
                continue;
            }
            Bucket* node = make_node(key, value);
            if (current->next.compare_exchange_strong(next, node))
            {
                size_counter.fetch_add(1);
                return true;
            }
            // Lost the race for the tail. Discard our node and re-scan from `current`,
            // whose successor is now the winner's node (it may carry the same key).
            delete node;
            cpu_relax();
        }
    }

    /// Looks up `key` in a snapshot of the current table (no lock held while traversing).
    /// @return the value, or std::nullopt when no live entry exists.
    std::optional<V> find(const K& key)
    {
        const std::shared_ptr<Table> t = active_table();
        for (const Bucket* current = &t->buckets[hash(key, t->capacity)]; current;
             current = current->next.load())
        {
            if (matches(*current, key))
                return current->value.load();
        }
        return std::nullopt;
    }

    /// Tombstones the live entry for `key`.
    /// @return true if this call erased it, false if no live entry was found.
    bool erase(const K& key)
    {
        std::shared_lock<std::shared_mutex> guard(resize_mutex);
        Table& t = *table;
        for (Bucket* current = &t.buckets[hash(key, t.capacity)]; current;
             current = current->next.load())
        {
            if (!matches(*current, key))
                continue;
            // CAS so that only one of several concurrent erasers decrements the size.
            bool expected = false;
            if (current->deleted.compare_exchange_strong(expected, true))
            {
                size_counter.fetch_sub(1);
                return true;
            }
            // Another thread erased this entry first; a newer live copy may still follow.
        }
        return false;
    }

    /// Replaces the value of the live entry for `key`.
    /// @return true if an entry was updated, false if `key` is absent.
    bool update(const K& key, const V& new_value)
    {
        std::shared_lock<std::shared_mutex> guard(resize_mutex);
        Table& t = *table;
        for (Bucket* current = &t.buckets[hash(key, t.capacity)]; current;
             current = current->next.load())
        {
            if (matches(*current, key))
            {
                current->value.store(new_value);
                return true;
            }
        }
        return false;
    }

    /// Halves the bucket count when the table has more than 1024 buckets and fewer than
    /// capacity / 4 live entries remain.
    void shrink()
    {
        const size_t current_size = size_counter.load();
        const size_t cap = active_table()->capacity;
        if (current_size < cap / 4 && cap > 1024)
            start_resize(cap / 2);
    }

    /// Number of live entries. May be momentarily stale while writers are active.
    size_t size() const
    {
        return size_counter.load();
    }

    /// Retained for API compatibility. Resizes complete atomically inside start_resize,
    /// so there is never a pending migration to help with; always returns false.
    bool rehash_one([[maybe_unused]] const K& key)
    {
        return false;
    }

    /// Keys of all live entries in a snapshot of the current table (weakly consistent).
    std::vector<K> keys()
    {
        std::vector<K> result;
        const std::shared_ptr<Table> t = active_table();
        for (const Bucket& bucket : t->buckets)
            for (const Bucket* current = &bucket; current; current = current->next.load())
                if (is_live(*current))
                    result.push_back(current->key.load());
        return result;
    }

    /// (key, value) pairs of all live entries in a snapshot of the current table.
    std::vector<std::pair<K, V>> entries()
    {
        std::vector<std::pair<K, V>> result;
        const std::shared_ptr<Table> t = active_table();
        for (const Bucket& bucket : t->buckets)
            for (const Bucket* current = &bucket; current; current = current->next.load())
                if (is_live(*current))
                    result.emplace_back(current->key.load(), current->value.load());
        return result;
    }

    /// Invokes `cb(key, value)` for every live entry of a table snapshot. No lock is held
    /// while `cb` runs, so the callback may safely call back into this map.
    void for_each(const std::function<void(const K&, const V&)>& cb)
    {
        const std::shared_ptr<Table> t = active_table();
        for (const Bucket& bucket : t->buckets)
            for (const Bucket* current = &bucket; current; current = current->next.load())
                if (is_live(*current))
                    cb(current->key.load(), current->value.load());
    }
};
#endif // UNORDEREDPARALLELMAP_H