#ifndef UNORDEREDPARALLELMAP_H
#define UNORDEREDPARALLELMAP_H

#include <vector>
#include <optional>
#include <shared_mutex>
#include <atomic>
#include <functional>
#include <cstdint>
#include <type_traits>    // std::is_trivially_copyable_v
#include <utility>        // std::pair
#include "optimization.h" // prefetch_for_read / prefetch_for_write / cpu_relax
#ifdef __AVX2__
#include <immintrin.h>
#endif

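// LockFreeMap: open-addressed hash map with linear probing.
//
// Concurrency model: every Entry carries a seqlock-style version counter
// (even = stable, odd = write in progress). Structural writers (insert,
// erase, clear, reserve, and the internal resize) serialize on resize_mutex;
// readers and in-place updaters take it in shared mode so the bucket array
// cannot be swapped out from under them, then validate per-entry versions.
// Despite the name, only lookups and iteration avoid exclusive locking.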
template <typename K, typename V>
class LockFreeMap
{
    static_assert(std::is_trivially_copyable_v<K>);
    static_assert(std::is_trivially_copyable_v<V>);

    struct alignas(64) Entry
    {
        std::atomic<uint64_t> version{0}; // even: stable, odd: writing
        bool occupied = false;
        bool deleted = false;
        K key{};
        V value{};
        // No manual padding: alignas(64) already rounds sizeof(Entry) up to
        // a cache-line multiple, and a fixed char[] would underflow once
        // sizeof(K) + sizeof(V) exceeds the leftover space.
    };

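    // Per-slot seqlock protocol:
    //   writer: CAS version even -> odd, mutate fields, store version + 2
    //   reader: load version (must be even), copy fields, re-load version;
    //           the copy is valid only if both loads match.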
    std::vector<Entry> buckets;
    mutable std::shared_mutex resize_mutex;
    std::atomic<size_t> count{0};
    size_t capacity;
    static constexpr float MAX_LOAD = 0.7f;

    size_t hash(const K& key) const
    {
        return std::hash<K>{}(key) % this->capacity;
    }

    size_t probe(size_t h, size_t i) const
    {
        // Linear probing.
        return (h + i) % this->capacity;
    }

    // Called with resize_mutex held exclusively; tombstones are dropped here.
    void resize()
    {
        size_t new_capacity = this->capacity * 2;
        std::vector<Entry> new_buckets(new_capacity);

        for (size_t i = 0; i < this->buckets.size(); ++i)
        {
            auto& e = this->buckets[i];
            prefetch_for_read(&e);
            if (i + 2 < this->buckets.size()) prefetch_for_read(&this->buckets[i + 2]);
            cpu_relax();

            if (e.occupied && !e.deleted)
            {
                size_t h = std::hash<K>{}(e.key) % new_capacity;
                for (size_t j = 0; j < new_capacity; ++j)
                {
                    size_t idx = (h + j) % new_capacity;
                    Entry& ne = new_buckets[idx];
                    prefetch_for_write(&ne);
                    cpu_relax();

                    if (!ne.occupied)
                    {
                        ne.version.store(1);
                        ne.key = e.key;
                        ne.value = e.value;
                        ne.occupied = true;
                        ne.deleted = false;
                        ne.version.store(2);
                        break;
                    }
                }
            }
        }

        this->buckets = std::move(new_buckets);
        this->capacity = new_capacity;
    }

public:
    explicit LockFreeMap(size_t init_cap = 1024)
        : buckets(init_cap), capacity(init_cap)
    {
    }

    bool insert(const K& key, const V& val)
    {
        // Exclusive lock: serializes structural writers and excludes
        // shared-mode readers for the duration of the write.
        std::unique_lock lock(this->resize_mutex);

        if ((float)(this->count.load() + 1) / this->capacity > this->MAX_LOAD)
            resize();

        size_t h = hash(key);
        for (size_t i = 0; i < this->capacity; ++i)
        {
            size_t idx = probe(h, i);
            Entry& e = this->buckets[idx];

            prefetch_for_write(&e);
            cpu_relax();

            if (!e.occupied || (e.deleted && e.key == key))
            {
                uint64_t v = e.version.load();
                if (v % 2 != 0) continue;

                if (e.version.compare_exchange_weak(v, v + 1))
                {
                    e.key = key;
                    e.value = val;
                    e.occupied = true;
                    e.deleted = false;
                    e.version.store(v + 2);
                    this->count.fetch_add(1);
                    return true;
                }
                --i; // CAS lost a race on this slot: retry it.
            }
            else if (!e.deleted && e.key == key)
            {
                return false; // key already present
            }
        }
        return false; // table full (should not happen below MAX_LOAD)
    }

    std::optional<V> find(const K& key) const
    {
        // Shared lock: blocks a concurrent resize from swapping the bucket
        // array, but allows any number of concurrent readers.
        std::shared_lock lock(this->resize_mutex);

        size_t h = hash(key);
        for (size_t i = 0; i < this->capacity; ++i)
        {
            size_t idx = probe(h, i);
            const Entry& e = this->buckets[idx];

            prefetch_for_read(&e);
            cpu_relax();

            uint64_t v1 = e.version.load(std::memory_order_acquire);
            if (v1 % 2 != 0) continue;

            if (e.occupied && !e.deleted && e.key == key)
            {
                V val = e.value;
                cpu_relax();
                uint64_t v2 = e.version.load(std::memory_order_acquire);
                // Seqlock validation: the copy is good only if the version
                // stayed even and unchanged while we read.
                if (v1 == v2 && v2 % 2 == 0)
                    return val;
            }
        }
        return std::nullopt;
    }

    bool erase(const K& key)
    {
        std::unique_lock lock(this->resize_mutex);

        size_t h = hash(key);
        for (size_t i = 0; i < this->capacity; ++i)
        {
            size_t idx = probe(h, i);
            Entry& e = this->buckets[idx];

            prefetch_for_write(&e);
            cpu_relax();

            if (e.occupied && e.key == key)
            {
                if (e.deleted) return false;

                uint64_t v = e.version.load();
                if (v % 2 != 0) continue;

                if (e.version.compare_exchange_strong(v, v + 1))
                {
                    if (!e.deleted)
                    {
                        e.deleted = true; // tombstone; reclaimed on resize
                        this->count.fetch_sub(1);
                    }
                    e.version.store(v + 2);
                    return true;
                }
                --i; // retry the slot after a lost CAS
            }
        }
        return false;
    }

    bool update(const K& key, const V& new_val)
    {
        // Shared lock is enough here: an update does not change occupancy,
        // and the per-entry CAS serializes writers on the same slot.
        std::shared_lock lock(this->resize_mutex);

        size_t h = hash(key);
        for (size_t i = 0; i < this->capacity; ++i)
        {
            size_t idx = probe(h, i);
            Entry& e = this->buckets[idx];

            prefetch_for_write(&e);
            cpu_relax();

            uint64_t v = e.version.load();
            if (v % 2 != 0) continue;

            if (e.occupied && !e.deleted && e.key == key)
            {
                if (e.version.compare_exchange_strong(v, v + 1))
                {
                    e.value = new_val;
                    e.version.store(v + 2);
                    return true;
                }
                --i;
            }
        }
        return false;
    }

    std::vector<K> keys() const
    {
        std::shared_lock lock(this->resize_mutex);

        std::vector<K> result;
        for (const auto& e : this->buckets)
        {
            prefetch_for_read(&e);
            cpu_relax();

            uint64_t v1 = e.version.load(std::memory_order_acquire);
            if (v1 % 2 != 0 || !e.occupied || e.deleted) continue;

            K key = e.key;
            cpu_relax();
            uint64_t v2 = e.version.load(std::memory_order_acquire);
            if (v1 == v2 && v2 % 2 == 0)
                result.push_back(key);
        }
        return result;
    }

    std::vector<std::pair<K, V>> entries() const
    {
        std::shared_lock lock(this->resize_mutex);

        std::vector<std::pair<K, V>> result;
        for (const auto& e : this->buckets)
        {
            prefetch_for_read(&e);
            cpu_relax();

            uint64_t v1 = e.version.load(std::memory_order_acquire);
            if (v1 % 2 != 0 || !e.occupied || e.deleted) continue;

            K key = e.key;
            V val = e.value;
            cpu_relax();
            uint64_t v2 = e.version.load(std::memory_order_acquire);
            if (v1 == v2 && v2 % 2 == 0)
                result.emplace_back(key, val);
        }
        return result;
    }

    void for_each(const std::function<void(const K&, const V&)>& cb) const
    {
        std::shared_lock lock(this->resize_mutex);

        const size_t N = this->buckets.size();
        for (size_t i = 0; i < N; ++i)
        {
            const auto& e = this->buckets[i];

            prefetch_for_read(&e);
            if (i + 2 < N) prefetch_for_read(&this->buckets[i + 2]);
            cpu_relax();

            uint64_t v1 = e.version.load(std::memory_order_acquire);
            if (v1 % 2 != 0 || !e.occupied || e.deleted) continue;

            K key = e.key;
            V val = e.value;

            cpu_relax();
            uint64_t v2 = e.version.load(std::memory_order_acquire);
            if (v1 == v2 && v2 % 2 == 0)
                cb(key, val);
        }
    }

    void for_each_mut(const std::function<void(const K&, V&)>& cb)
    {
        // Shared lock blocks resize; the per-entry CAS below serializes
        // against the other shared-mode writers (update, for_each_mut).
        std::shared_lock lock(this->resize_mutex);

        for (auto& e : this->buckets)
        {
            prefetch_for_write(&e);
            cpu_relax();

            if (!e.occupied || e.deleted) continue;

            uint64_t v = e.version.load();
            if (v % 2 != 0) continue;

            if (e.version.compare_exchange_strong(v, v + 1))
            {
                cb(e.key, e.value);
                e.version.store(v + 2);
            }
            else
            {
                cpu_relax(); // lost the slot; skip it this pass
            }
        }
    }

    void clear()
    {
        std::unique_lock lock(this->resize_mutex);

        for (auto& e : this->buckets)
        {
            prefetch_for_write(&e);
            cpu_relax();

            uint64_t v = e.version.load();
            if (v % 2 != 0) continue;

            if (e.version.compare_exchange_strong(v, v + 1))
            {
                e.occupied = false;
                e.deleted = false;
                e.version.store(v + 2);
            }
        }
        this->count.store(0);
    }

    void reserve(size_t desired_capacity)
    {
        std::unique_lock lock(this->resize_mutex);
        if (desired_capacity <= this->capacity) return;

        // Round the requested capacity up to the next power of two.
        size_t new_capacity = 1;
        while (new_capacity < desired_capacity) new_capacity <<= 1;

        std::vector<Entry> new_buckets(new_capacity);

        for (auto& e : this->buckets)
        {
            prefetch_for_read(&e);
            cpu_relax();

            if (e.occupied && !e.deleted)
            {
                size_t h = std::hash<K>{}(e.key) % new_capacity;
                for (size_t i = 0; i < new_capacity; ++i)
                {
                    size_t idx = (h + i) % new_capacity;
                    Entry& ne = new_buckets[idx];
                    prefetch_for_write(&ne);
                    cpu_relax();

                    if (!ne.occupied)
                    {
                        ne.version.store(1);
                        ne.key = e.key;
                        ne.value = e.value;
                        ne.occupied = true;
                        ne.deleted = false;
                        ne.version.store(2);
                        break;
                    }
                }
            }
        }

        this->buckets = std::move(new_buckets);
        this->capacity = new_capacity;
    }

    // Note: the batch_* helpers are hard-coded to int keys/values and only
    // make sense when the template is instantiated with compatible types.
    std::vector<std::optional<int>> batch_find(const int* keys, size_t n) const
    {
        std::vector<std::optional<int>> result(n);

#ifdef __AVX2__
        // The AVX2 path stages eight keys through a vector register, then
        // falls back to scalar probes; the lookup itself is not vectorized.
        constexpr size_t stride = 8;
        size_t i = 0;
        for (; i + stride <= n; i += stride) {
            __m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
            alignas(32) int key_arr[8];
            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);

            for (int j = 0; j < 8; ++j) {
                result[i + j] = find(key_arr[j]);
            }
        }

        // Scalar tail for the remaining n % 8 keys.
        for (; i < n; ++i)
            result[i] = find(keys[i]);
#else
        for (size_t i = 0; i < n; ++i)
            result[i] = find(keys[i]);
#endif

        return result;
    }

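    // Example (illustrative), for an int/int instantiation:
    //
    //   LockFreeMap<int, int> m;
    //   int ks[3] = {1, 2, 3};
    //   auto hits = m.batch_find(ks, 3); // hits[i] is std::nullopt on a miss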
    void batch_insert(const int* keys, const int* values, size_t n) {
#ifdef __AVX2__
        constexpr size_t stride = 8;
        size_t i = 0;

        for (; i + stride <= n; i += stride) {
            __m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
            __m256i vvals = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(values + i));

            alignas(32) int key_arr[8];
            alignas(32) int val_arr[8];

            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
            _mm256_store_si256(reinterpret_cast<__m256i*>(val_arr), vvals);

            for (int j = 0; j < 8; ++j) {
                insert(key_arr[j], val_arr[j]);
            }
        }

        for (; i < n; ++i)
            insert(keys[i], values[i]);
#else
        for (size_t i = 0; i < n; ++i)
            insert(keys[i], values[i]);
#endif
    }

    void batch_insert_or_update(const int* keys, const int* values, size_t n) {
#ifdef __AVX2__
        constexpr size_t stride = 8;
        size_t i = 0;

        for (; i + stride <= n; i += stride) {
            __m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
            __m256i vvals = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(values + i));

            alignas(32) int key_arr[8];
            alignas(32) int val_arr[8];

            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
            _mm256_store_si256(reinterpret_cast<__m256i*>(val_arr), vvals);

            for (int j = 0; j < 8; ++j) {
                if (!insert(key_arr[j], val_arr[j]))
                    update(key_arr[j], val_arr[j]);
            }
        }

        for (; i < n; ++i) {
            if (!insert(keys[i], values[i]))
                update(keys[i], values[i]);
        }
#else
        for (size_t i = 0; i < n; ++i) {
            if (!insert(keys[i], values[i]))
                update(keys[i], values[i]);
        }
#endif
    }

    void batch_erase(const int* keys, size_t n) {
#ifdef __AVX2__
        constexpr size_t stride = 8;
        size_t i = 0;

        for (; i + stride <= n; i += stride) {
            __m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
            alignas(32) int key_arr[8];
            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);

            for (int j = 0; j < 8; ++j) {
                erase(key_arr[j]);
            }
        }

        for (; i < n; ++i) {
            erase(keys[i]);
        }
#else
        for (size_t i = 0; i < n; ++i) {
            erase(keys[i]);
        }
#endif
    }

    size_t size() const
    {
        return this->count.load();
    }
};

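// Usage sketch (illustrative only; K and V must be trivially copyable):
//
//   LockFreeMap<int, int> map;
//   map.insert(42, 7);                  // false if the key already exists
//   map.update(42, 8);                  // in-place value change
//   if (auto v = map.find(42)) { /* *v == 8 */ }
//   map.erase(42);                      // tombstones the slot until resize
//
// Call reserve() up front when the expected size is known, since growth
// rehashes the whole table while holding the write lock.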
#endif // UNORDEREDPARALLELMAP_H