#ifndef UNORDEREDPARALLELMAP_H
#define UNORDEREDPARALLELMAP_H

#include <atomic>
#include <cstdint>
#include <functional>
#include <optional>
#include <shared_mutex>
#include <type_traits>
#include <utility>
#include <vector>

#include "optimization.h"

#ifdef __AVX2__
#include <immintrin.h>
#endif

// Open-addressed hash map with linear probing. Each Entry carries a
// seqlock-style version counter (even = stable, odd = write in progress)
// so readers can validate that a slot did not change under them.
// Mutating operations are serialized against resizing via resize_mutex;
// find() and the const iteration helpers read without taking a lock.
template <typename K, typename V>
class LockFreeMap {
    static_assert(std::is_trivially_copyable_v<K>);
    static_assert(std::is_trivially_copyable_v<V>);

    // alignas(64) pads each Entry out to a full cache line, which avoids
    // false sharing between adjacent slots.
    struct alignas(64) Entry {
        std::atomic<uint64_t> version{0}; // even: stable, odd: writing
        bool occupied = false;
        bool deleted = false;
        K key{};
        V value{};
    };

    std::vector<Entry> buckets;
    mutable std::shared_mutex resize_mutex;
    std::atomic<size_t> count{0};
    size_t capacity;

    static constexpr float MAX_LOAD = 0.7f;

    size_t hash(const K& key) const {
        return std::hash<K>{}(key) % this->capacity;
    }

    size_t probe(size_t h, size_t i) const {
        return (h + i) % this->capacity;
    }

    // Caller must hold resize_mutex exclusively.
    void resize() {
        size_t new_capacity = this->capacity * 2;
        std::vector<Entry> new_buckets(new_capacity);

        for (size_t i = 0; i < this->buckets.size(); ++i) {
            auto& e = this->buckets[i];
            prefetch_for_read(&e);
            if (i + 2 < this->buckets.size())
                prefetch_for_read(&this->buckets[i + 2]);
            cpu_relax();

            if (e.occupied && !e.deleted) {
                size_t h = std::hash<K>{}(e.key) % new_capacity;
                for (size_t j = 0; j < new_capacity; ++j) {
                    size_t idx = (h + j) % new_capacity;
                    Entry& ne = new_buckets[idx];
                    prefetch_for_write(&ne);
                    cpu_relax();
                    if (!ne.occupied) {
                        ne.version.store(1);
                        ne.key = e.key;
                        ne.value = e.value;
                        ne.occupied = true;
                        ne.deleted = false;
                        ne.version.store(2);
                        break;
                    }
                }
            }
        }

        this->buckets = std::move(new_buckets);
        this->capacity = new_capacity;
    }

public:
    explicit LockFreeMap(size_t init_cap = 1024)
        : buckets(init_cap), capacity(init_cap) {}

    bool insert(const K& key, const V& val) {
        std::unique_lock lock(this->resize_mutex);

        if ((float)(this->count.load() + 1) / this->capacity > MAX_LOAD)
            resize();

        size_t h = hash(key);
        for (size_t i = 0; i < this->capacity; ++i) {
            size_t idx = probe(h, i);
            Entry& e = this->buckets[idx];
            prefetch_for_write(&e);
            cpu_relax();

            if (!e.occupied || (e.deleted && e.key == key)) {
                uint64_t v = e.version.load();
                if (v % 2 != 0) continue;
                if (e.version.compare_exchange_weak(v, v + 1)) {
                    e.key = key;
                    e.value = val;
                    e.occupied = true;
                    e.deleted = false;
                    e.version.store(v + 2);
                    this->count.fetch_add(1);
                    return true;
                }
                --i; // CAS failed: retry the same slot
            } else if (!e.deleted && e.key == key) {
                return false; // key already present
            }
        }
        return false;
    }

    std::optional<V> find(const K& key) const {
        size_t h = hash(key);
        for (size_t i = 0; i < this->capacity; ++i) {
            size_t idx = probe(h, i);
            const Entry& e = this->buckets[idx];
            prefetch_for_read(&e);
            cpu_relax();

            uint64_t v1 = e.version.load(std::memory_order_acquire);
            if (v1 % 2 != 0) continue; // writer in progress, skip slot
            if (e.occupied && !e.deleted && e.key == key) {
                V val = e.value;
                cpu_relax();
                uint64_t v2 = e.version.load(std::memory_order_acquire);
                if (v1 == v2 && v2 % 2 == 0) // seqlock validation
                    return val;
            }
        }
        return std::nullopt;
    }

    bool erase(const K& key) {
        std::unique_lock lock(this->resize_mutex);

        size_t h = hash(key);
        for (size_t i = 0; i < this->capacity; ++i) {
            size_t idx = probe(h, i);
            Entry& e = this->buckets[idx];
            prefetch_for_write(&e);
            cpu_relax();

            if (e.occupied && e.key == key) {
                if (e.deleted) return false;
                uint64_t v = e.version.load();
                if (v % 2 != 0) continue;
                if (e.version.compare_exchange_strong(v, v + 1)) {
                    if (!e.deleted) {
                        e.deleted = true;
                        this->count.fetch_sub(1);
                    }
                    e.version.store(v + 2);
                    return true;
                }
                --i;
            }
        }
        return false;
    }

    bool update(const K& key, const V& new_val) {
        // Shared lock: updates may run concurrently with each other, but
        // not with a resize, which swaps the bucket array underneath us.
        std::shared_lock lock(this->resize_mutex);

        size_t h = hash(key);
        for (size_t i = 0; i < this->capacity; ++i) {
            size_t idx = probe(h, i);
            Entry& e = this->buckets[idx];
            prefetch_for_write(&e);
            cpu_relax();

            uint64_t v = e.version.load();
            if (v % 2 != 0) continue;
            if (e.occupied && !e.deleted && e.key == key) {
                if (e.version.compare_exchange_strong(v, v + 1)) {
                    e.value = new_val;
                    e.version.store(v + 2);
                    return true;
                }
                --i;
            }
        }
        return false;
    }

    std::vector<K> keys() const {
        std::vector<K> result;
        for (const auto& e : this->buckets) {
            prefetch_for_read(&e);
            cpu_relax();
            uint64_t v1 = e.version.load(std::memory_order_acquire);
            if (v1 % 2 != 0 || !e.occupied || e.deleted) continue;
            K key = e.key;
            cpu_relax();
            uint64_t v2 = e.version.load(std::memory_order_acquire);
            if (v1 == v2 && v2 % 2 == 0)
                result.push_back(key);
        }
        return result;
    }

    std::vector<std::pair<K, V>> entries() const {
        std::vector<std::pair<K, V>> result;
        for (const auto& e : this->buckets) {
            prefetch_for_read(&e);
            cpu_relax();
            uint64_t v1 = e.version.load(std::memory_order_acquire);
            if (v1 % 2 != 0 || !e.occupied || e.deleted) continue;
            K key = e.key;
            V val = e.value;
            cpu_relax();
            uint64_t v2 = e.version.load(std::memory_order_acquire);
            if (v1 == v2 && v2 % 2 == 0)
                result.emplace_back(key, val);
        }
        return result;
    }

    void for_each(const std::function<void(const K&, const V&)>& cb) const {
        const size_t N = this->buckets.size();
        for (size_t i = 0; i < N; ++i) {
            const auto& e = this->buckets[i];
            prefetch_for_read(&e);
            if (i + 2 < N)
                prefetch_for_read(&this->buckets[i + 2]);
            cpu_relax();

            uint64_t v1 = e.version.load(std::memory_order_acquire);
            if (v1 % 2 != 0 || !e.occupied || e.deleted) continue;
            K key = e.key;
            V val = e.value;
            cpu_relax();
            uint64_t v2 = e.version.load(std::memory_order_acquire);
            if (v1 == v2 && v2 % 2 == 0)
                cb(key, val);
        }
    }

    void for_each_mut(const std::function<void(const K&, V&)>& cb) {
        // Shared lock for the same reason as update().
        std::shared_lock lock(this->resize_mutex);

        for (auto& e : this->buckets) {
            prefetch_for_write(&e);
            cpu_relax();
            if (!e.occupied || e.deleted) continue;
            uint64_t v = e.version.load();
            if (v % 2 != 0) continue;
            if (e.version.compare_exchange_strong(v, v + 1)) {
                cb(e.key, e.value);
                e.version.store(v + 2);
            } else {
                cpu_relax();
            }
        }
    }

    void clear() {
        std::unique_lock lock(this->resize_mutex);
        for (auto& e : this->buckets) {
            prefetch_for_write(&e);
            cpu_relax();
            uint64_t v = e.version.load();
            if (v % 2 != 0) continue;
            if (e.version.compare_exchange_strong(v, v + 1)) {
                e.occupied = false;
                e.deleted = false;
                e.version.store(v + 2);
            }
        }
        this->count.store(0);
    }

    void reserve(size_t desired_capacity) {
        std::unique_lock lock(this->resize_mutex);
        if (desired_capacity <= this->capacity) return;

        // Round up to the next power of two.
        size_t new_capacity = 1;
        while (new_capacity < desired_capacity) new_capacity <<= 1;

        std::vector<Entry> new_buckets(new_capacity);
        for (auto& e : this->buckets) {
            prefetch_for_read(&e);
            cpu_relax();
            if (e.occupied && !e.deleted) {
                size_t h = std::hash<K>{}(e.key) % new_capacity;
                for (size_t i = 0; i < new_capacity; ++i) {
                    size_t idx = (h + i) % new_capacity;
                    Entry& ne = new_buckets[idx];
                    prefetch_for_write(&ne);
                    cpu_relax();
                    if (!ne.occupied) {
                        ne.version.store(1);
                        ne.key = e.key;
                        ne.value = e.value;
                        ne.occupied = true;
                        ne.deleted = false;
                        ne.version.store(2);
                        break;
                    }
                }
            }
        }
        this->buckets = std::move(new_buckets);
        this->capacity = new_capacity;
    }

    // The batch operations below assume K and V are int (they load keys and
    // values eight at a time with AVX2). They are only instantiated if used.
    std::vector<std::optional<V>> batch_find(const int* keys, size_t n) const {
        std::vector<std::optional<V>> result(n);
#ifdef __AVX2__
        constexpr size_t stride = 8;
        size_t i = 0;
        for (; i + stride <= n; i += stride) {
            __m256i vkeys = _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(keys + i));
            alignas(32) int key_arr[8];
            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
            for (int j = 0; j < 8; ++j) {
                result[i + j] = find(key_arr[j]);
            }
        }
        for (; i < n; ++i)
            result[i] = find(keys[i]);
#else
        for (size_t i = 0; i < n; ++i)
            result[i] = find(keys[i]);
#endif
        return result;
    }

    void batch_insert(const int* keys, const int* values, size_t n) {
#ifdef __AVX2__
        constexpr size_t stride = 8;
        size_t i = 0;
        for (; i + stride <= n; i += stride) {
            __m256i vkeys = _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(keys + i));
            __m256i vvals = _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(values + i));
            alignas(32) int key_arr[8];
            alignas(32) int val_arr[8];
            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
            _mm256_store_si256(reinterpret_cast<__m256i*>(val_arr), vvals);
            for (int j = 0; j < 8; ++j) {
                insert(key_arr[j], val_arr[j]);
            }
        }
        for (; i < n; ++i)
            insert(keys[i], values[i]);
#else
        for (size_t i = 0; i < n; ++i)
            insert(keys[i], values[i]);
#endif
    }

    void batch_insert_or_update(const int* keys, const int* values, size_t n) {
#ifdef __AVX2__
        constexpr size_t stride = 8;
        size_t i = 0;
        for (; i + stride <= n; i += stride) {
            __m256i vkeys = _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(keys + i));
            __m256i vvals = _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(values + i));
            alignas(32) int key_arr[8];
            alignas(32) int val_arr[8];
            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
            _mm256_store_si256(reinterpret_cast<__m256i*>(val_arr), vvals);
            for (int j = 0; j < 8; ++j) {
                if (!insert(key_arr[j], val_arr[j]))
                    update(key_arr[j], val_arr[j]);
            }
        }
        for (; i < n; ++i) {
            if (!insert(keys[i], values[i]))
                update(keys[i], values[i]);
        }
#else
        for (size_t i = 0; i < n; ++i) {
            if (!insert(keys[i], values[i]))
                update(keys[i], values[i]);
        }
#endif
    }

    void batch_erase(const int* keys, size_t n) {
#ifdef __AVX2__
        constexpr size_t stride = 8;
        size_t i = 0;
        for (; i + stride <= n; i += stride) {
            __m256i vkeys = _mm256_loadu_si256(
                reinterpret_cast<const __m256i*>(keys + i));
            alignas(32) int key_arr[8];
            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
            for (int j = 0; j < 8; ++j) {
                erase(key_arr[j]);
            }
        }
        for (; i < n; ++i) {
            erase(keys[i]);
        }
#else
        for (size_t i = 0; i < n; ++i) {
            erase(keys[i]);
        }
#endif
    }

    size_t size() const { return this->count.load(); }
};

#endif // UNORDEREDPARALLELMAP_H
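
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, kept in a comment so the header stays
// include-only). It assumes this file is saved as "UnorderedParallelMap.h",
// that optimization.h is on the include path, and uses K = V = int; the
// thread count and key range are arbitrary.
//
//     #include "UnorderedParallelMap.h"
//     #include <thread>
//
//     int main() {
//         LockFreeMap<int, int> map(1 << 16);
//
//         std::thread writer([&] {
//             for (int i = 0; i < 1000; ++i)
//                 map.insert(i, i * i);        // serialized by resize_mutex
//         });
//
//         std::thread reader([&] {
//             for (int i = 0; i < 1000; ++i)
//                 if (auto v = map.find(i))    // lock-free seqlock read
//                     (void)*v;                // value validated by version
//         });
//
//         writer.join();
//         reader.join();
//         return static_cast<int>(map.size());
//     }
// ---------------------------------------------------------------------------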