// ParallelUnorderedMap/UnorderedParallelMap.h
//
// (Source-viewer metadata captured with this file: 499 lines, 14 KiB, C++.)

#ifndef UNORDEREDPARALLELMAP_H
#define UNORDEREDPARALLELMAP_H
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <type_traits>
#include <utility>
#include <vector>
#include "optimization.h"
#ifdef __AVX2__
#include <immintrin.h>
#endif
/// Concurrent open-addressing hash map (linear probing, tombstone deletion).
///
/// Concurrency model:
///  - Structural mutations (insert, erase, clear, reserve, and the automatic
///    growth inside insert) hold `resize_mutex` exclusively.
///  - Lookups and value-only mutations (find, update, keys, entries, for_each,
///    for_each_mut) hold `resize_mutex` shared. They can never observe `buckets`
///    being reallocated, and the occupied/deleted/key fields are stable for them.
///  - Concurrent value writers (update / for_each_mut) and value readers of the
///    same slot coordinate through a per-slot sequence counter
///    (`Entry::version`). A writer makes it odd while storing `value`. A reader
///    retries until it copied `value` between two identical even readings.
///
/// NOTE(review): the seqlock reads `value` non-atomically while a writer may be
/// storing it. This is the usual seqlock idiom but formally a data race under the
/// C++ memory model. K and V must be trivially copyable, so a torn copy is only
/// ever discarded, never returned.
template <typename K, typename V>
class LockFreeMap
{
    static_assert(std::is_trivially_copyable_v<K>);
    static_assert(std::is_trivially_copyable_v<V>);

    /// One table slot. alignas(64) rounds sizeof(Entry) up to a multiple of 64
    /// bytes, so each slot starts on its own 64-byte boundary. No manual padding
    /// array is needed, which also lets K/V of any size compile.
    struct alignas(64) Entry
    {
        std::atomic<uint64_t> version{0}; // even: stable, odd: a writer is storing `value`
        bool occupied = false;            // used since the last clear/rehash (live or tombstone)
        bool deleted = false;             // tombstone: key erased, slot reusable by insert
        K key{};
        V value{};
    };

    /// RAII writer claim on one slot. Construction spins until it flips the
    /// version from even to odd. Destruction publishes the next even version, so
    /// the slot is released even if the code in between throws.
    class SlotWriter
    {
        Entry& entry;
        uint64_t start;

        static uint64_t claim(Entry& e)
        {
            for (;;)
            {
                uint64_t v = e.version.load(std::memory_order_relaxed);
                if (v % 2 == 0 && e.version.compare_exchange_weak(v, v + 1))
                {
                    // Order the odd marker before the value stores that follow.
                    std::atomic_thread_fence(std::memory_order_release);
                    return v;
                }
                cpu_relax();
            }
        }

    public:
        explicit SlotWriter(Entry& e) : entry(e), start(claim(e)) {}
        ~SlotWriter() { entry.version.store(start + 2, std::memory_order_release); }
        SlotWriter(const SlotWriter&) = delete;
        SlotWriter& operator=(const SlotWriter&) = delete;
    };

    std::vector<Entry> buckets;
    mutable std::shared_mutex resize_mutex;
    std::atomic<size_t> count{0}; // number of live (non-deleted) keys
    size_t capacity;              // == buckets.size(), always >= 1
    static constexpr float MAX_LOAD = 0.7f;

    /// Home slot of `key` for the current capacity.
    size_t hash(const K& key) const
    {
        return std::hash<K>{}(key) % this->capacity;
    }

    /// i-th slot of the linear probe sequence that starts at home slot `h`.
    size_t probe(size_t h, size_t i) const
    {
        return (h + i) % this->capacity;
    }

    /// Index of the live slot holding `key`, or `capacity` when the key is absent.
    /// The caller must hold resize_mutex (shared or exclusive).
    /// The scan stops at the first never-used slot. insert always places a key at
    /// or before the first such slot on its probe path, and only clear/rehash turn
    /// a used slot back into an unused one, so nothing can lie beyond it.
    size_t locate(const K& key) const
    {
        const size_t h = hash(key);
        for (size_t i = 0; i < this->capacity; ++i)
        {
            const size_t idx = probe(h, i);
            const Entry& e = this->buckets[idx];
            if (!e.occupied) break;
            if (!e.deleted && e.key == key) return idx;
        }
        return this->capacity;
    }

    /// Returns a consistent copy of e.value. It retries (instead of giving up)
    /// while a writer owns the slot, or when the value changed mid-copy.
    static V read_value(const Entry& e)
    {
        for (;;)
        {
            const uint64_t before = e.version.load(std::memory_order_acquire);
            if (before % 2 == 0)
            {
                V copy = e.value;
                // Keep the copy above from being reordered after the re-check.
                std::atomic_thread_fence(std::memory_order_acquire);
                if (e.version.load(std::memory_order_relaxed) == before)
                    return copy;
            }
            cpu_relax();
        }
    }

    /// Rebuilds the table with `new_capacity` slots, dropping all tombstones.
    /// The caller must hold resize_mutex exclusively and guarantee that
    /// new_capacity > number of live keys, so a free slot always exists.
    void rehash(size_t new_capacity)
    {
        std::vector<Entry> fresh(new_capacity);
        const size_t old_size = this->buckets.size();
        for (size_t i = 0; i < old_size; ++i)
        {
            if (i + 2 < old_size) prefetch_for_read(&this->buckets[i + 2]);
            const Entry& e = this->buckets[i];
            if (!e.occupied || e.deleted) continue;
            size_t idx = std::hash<K>{}(e.key) % new_capacity;
            while (fresh[idx].occupied) idx = (idx + 1) % new_capacity;
            Entry& ne = fresh[idx];
            ne.key = e.key;
            ne.value = e.value;
            ne.occupied = true;
        }
        this->buckets = std::move(fresh);
        this->capacity = new_capacity;
    }

    /// Shared body of insert and batch_insert_or_update. It runs under the
    /// exclusive lock, so it writes slots directly without the seqlock.
    /// If `key` is absent, adds it and returns true. If `key` is present,
    /// overwrites its value when `overwrite` is set, and returns false.
    bool insert_impl(const K& key, const V& val, bool overwrite)
    {
        std::unique_lock lock(this->resize_mutex);
        if (static_cast<float>(this->count.load() + 1) / this->capacity > MAX_LOAD)
            rehash(this->capacity * 2);

        const size_t h = hash(key);
        size_t target = this->capacity; // first reusable slot (tombstone or unused) on the path
        for (size_t i = 0; i < this->capacity; ++i)
        {
            const size_t idx = probe(h, i);
            Entry& e = this->buckets[idx];
            if (!e.occupied)
            {
                // End of the probe path: the key cannot be stored further on.
                if (target == this->capacity) target = idx;
                break;
            }
            if (e.deleted)
            {
                // Remember the first tombstone of ANY key, but keep scanning,
                // since the key may still be live further along the path.
                if (target == this->capacity) target = idx;
            }
            else if (e.key == key)
            {
                if (overwrite) e.value = val;
                return false;
            }
        }
        // Unreachable in practice: the load check keeps count < capacity, so an
        // unused slot or a tombstone always exists on the path.
        if (target == this->capacity) return false;

        Entry& slot = this->buckets[target];
        slot.key = key;
        slot.value = val;
        slot.occupied = true;
        slot.deleted = false;
        this->count.fetch_add(1);
        return true;
    }

public:
    /// @param init_cap initial number of slots. 0 is treated as 1 so hashing
    ///        never takes a modulo by zero.
    explicit LockFreeMap(size_t init_cap = 1024)
        : buckets(init_cap == 0 ? 1 : init_cap), capacity(init_cap == 0 ? 1 : init_cap)
    {
    }

    /// Adds key -> val if `key` is absent.
    /// @return true if inserted, false if the key already exists (value untouched).
    /// Doubles the table first when the load factor would exceed MAX_LOAD.
    bool insert(const K& key, const V& val)
    {
        return insert_impl(key, val, false);
    }

    /// @return the value stored for `key`, or std::nullopt if absent. A concurrent
    ///         update of the same key makes this wait for a consistent value
    ///         rather than report the key as missing.
    std::optional<V> find(const K& key) const
    {
        std::shared_lock lock(this->resize_mutex);
        const size_t idx = locate(key);
        if (idx == this->capacity) return std::nullopt;
        return read_value(this->buckets[idx]);
    }

    /// Removes `key`, leaving a tombstone that later inserts may reuse.
    /// @return true if the key was present.
    bool erase(const K& key)
    {
        std::unique_lock lock(this->resize_mutex);
        const size_t idx = locate(key);
        if (idx == this->capacity) return false;
        this->buckets[idx].deleted = true;
        this->count.fetch_sub(1);
        return true;
    }

    /// Overwrites the value of an existing key. Concurrent updates of the same
    /// key are serialized by the slot's version counter.
    /// @return false if the key is absent.
    bool update(const K& key, const V& new_val)
    {
        std::shared_lock lock(this->resize_mutex);
        const size_t idx = locate(key);
        if (idx == this->capacity) return false;
        Entry& e = this->buckets[idx];
        SlotWriter writer(e);
        e.value = new_val;
        return true;
    }

    /// Snapshot of all live keys, in bucket order.
    std::vector<K> keys() const
    {
        std::shared_lock lock(this->resize_mutex);
        std::vector<K> result;
        result.reserve(this->count.load());
        for (const Entry& e : this->buckets)
            if (e.occupied && !e.deleted) result.push_back(e.key);
        return result;
    }

    /// Snapshot of all live (key, value) pairs, in bucket order.
    std::vector<std::pair<K, V>> entries() const
    {
        std::shared_lock lock(this->resize_mutex);
        std::vector<std::pair<K, V>> result;
        result.reserve(this->count.load());
        for (const Entry& e : this->buckets)
            if (e.occupied && !e.deleted) result.emplace_back(e.key, read_value(e));
        return result;
    }

    /// Calls cb(key, value) for every live entry of a snapshot taken by entries().
    /// The callback runs without any lock held, so it may call back into the map.
    void for_each(const std::function<void(const K&, const V&)>& cb) const
    {
        for (const auto& [key, val] : entries())
            cb(key, val);
    }

    /// Calls cb(key, value&) on every live entry, letting it modify the value in place.
    /// The callback runs while holding the map's shared lock and the slot's writer
    /// claim, so it must NOT call back into this map (deadlock / undefined behavior).
    /// If cb throws, the slot is still released (SlotWriter destructor).
    void for_each_mut(const std::function<void(const K&, V&)>& cb)
    {
        std::shared_lock lock(this->resize_mutex);
        for (Entry& e : this->buckets)
        {
            if (!e.occupied || e.deleted) continue;
            SlotWriter writer(e);
            cb(e.key, e.value);
        }
    }

    /// Removes every key. Capacity is unchanged.
    void clear()
    {
        std::unique_lock lock(this->resize_mutex);
        for (Entry& e : this->buckets)
        {
            e.occupied = false;
            e.deleted = false;
        }
        this->count.store(0);
    }

    /// Grows the table to the smallest power of two >= desired_capacity.
    /// Does nothing if the current capacity is already at least that large.
    void reserve(size_t desired_capacity)
    {
        std::unique_lock lock(this->resize_mutex);
        if (desired_capacity <= this->capacity) return;
        size_t new_capacity = 1;
        while (new_capacity < desired_capacity) new_capacity <<= 1;
        rehash(new_capacity);
    }

    // The batch helpers below are plain per-key loops. Each key needs its own
    // probe sequence and lock, so loading keys through a SIMD register first
    // (the previous AVX2 path) performed no vectorized work.

    /// result[i] = find(keys[i]) for i in [0, n).
    std::vector<std::optional<int>> batch_find(const int* keys, size_t n) const
    {
        std::vector<std::optional<int>> result(n);
        for (size_t i = 0; i < n; ++i)
            result[i] = find(keys[i]);
        return result;
    }

    /// insert(keys[i], values[i]) for i in [0, n). Existing keys are left unchanged.
    void batch_insert(const int* keys, const int* values, size_t n)
    {
        for (size_t i = 0; i < n; ++i)
            insert(keys[i], values[i]);
    }

    /// Inserts or overwrites keys[i] -> values[i] for i in [0, n). Each element is
    /// applied atomically under one exclusive lock, so a concurrent erase between
    /// the "exists" check and the overwrite cannot drop the write.
    void batch_insert_or_update(const int* keys, const int* values, size_t n)
    {
        for (size_t i = 0; i < n; ++i)
            insert_impl(keys[i], values[i], true);
    }

    /// erase(keys[i]) for i in [0, n).
    void batch_erase(const int* keys, size_t n)
    {
        for (size_t i = 0; i < n; ++i)
            erase(keys[i]);
    }

    /// Number of live keys.
    size_t size() const
    {
        return this->count.load();
    }
};
#endif // UNORDEREDPARALLELMAP_H