From 8a49a004ce7c169eaccaa8d9fc7e4b4e2d859668 Mon Sep 17 00:00:00 2001
From: g2px1
Date: Fri, 2 May 2025 12:51:00 +0300
Subject: [PATCH] added batch methods and some optimizations

---
 UnorderedParallelMap.h | 154 +++++++++++++++++++++++++++++++++++++----
 main.cpp               |  96 +++++++++++++++++++++++++
 2 files changed, 238 insertions(+), 12 deletions(-)

diff --git a/UnorderedParallelMap.h b/UnorderedParallelMap.h
index 28c0ef8..260bf27 100644
--- a/UnorderedParallelMap.h
+++ b/UnorderedParallelMap.h
@@ -8,6 +8,9 @@
 #include <atomic>
 #include <optional>
 #include "optimization.h"
+#ifdef __AVX2__
+#include <immintrin.h>
+#endif
 
 template <typename K, typename V>
 class LockFreeMap
@@ -15,13 +18,14 @@ class LockFreeMap
     static_assert(std::is_trivially_copyable_v<K>);
     static_assert(std::is_trivially_copyable_v<V>);
 
-    struct Entry
+    struct alignas(64) Entry
     {
         std::atomic<uint64_t> version{0}; // even: stable, odd: writing
         bool occupied = false;
         bool deleted = false;
         K key{};
         V value{};
+        char padding[64 - sizeof(std::atomic<uint64_t>) - sizeof(bool) * 2 - sizeof(K) - sizeof(V)]{};
     };
 
     std::vector<Entry> buckets;
@@ -45,17 +49,19 @@ class LockFreeMap
         size_t new_capacity = this->capacity * 2;
         std::vector<Entry> new_buckets(new_capacity);
 
-        for (auto& e : this->buckets)
+        for (size_t i = 0; i < this->buckets.size(); ++i)
         {
+            auto& e = this->buckets[i];
             prefetch_for_read(&e);
+            if (i + 2 < this->buckets.size()) prefetch_for_read(&this->buckets[i + 2]);
             cpu_relax();
 
             if (e.occupied && !e.deleted)
             {
                 size_t h = std::hash<K>{}(e.key) % new_capacity;
-                for (size_t i = 0; i < new_capacity; ++i)
+                for (size_t j = 0; j < new_capacity; ++j)
                 {
-                    size_t idx = (h + i) % new_capacity;
+                    size_t idx = (h + j) % new_capacity;
                     Entry& ne = new_buckets[idx];
                     prefetch_for_write(&ne);
                     cpu_relax();
@@ -105,7 +111,7 @@ public:
                 uint64_t v = e.version.load();
                 if (v % 2 != 0) continue;
 
-                if (e.version.compare_exchange_strong(v, v + 1))
+                if (e.version.compare_exchange_weak(v, v + 1))
                 {
                     e.key = key;
                     e.value = val;
@@ -151,25 +157,30 @@ public:
         return std::nullopt;
     }
 
-    bool erase(const K& key) {
+    bool erase(const K& key)
+    {
         std::unique_lock lock(this->resize_mutex);
         size_t h = hash(key);
 
-        for (size_t i = 0; i < this->capacity; ++i) {
+        for (size_t i = 0; i < this->capacity; ++i)
+        {
             size_t idx = probe(h, i);
             Entry& e = this->buckets[idx];
             prefetch_for_write(&e);
             cpu_relax();
 
-            if (e.occupied && e.key == key) {
+            if (e.occupied && e.key == key)
+            {
                 if (e.deleted) return false;
 
                 uint64_t v = e.version.load();
                 if (v % 2 != 0) continue;
 
-                if (e.version.compare_exchange_strong(v, v + 1)) {
-                    if (!e.deleted) {
+                if (e.version.compare_exchange_strong(v, v + 1))
+                {
+                    if (!e.deleted)
+                    {
                         e.deleted = true;
                         this->count.fetch_sub(1);
                     }
@@ -253,9 +264,13 @@ public:
     void for_each(const std::function<void(const K&, const V&)>& cb) const
     {
-        for (const auto& e : this->buckets)
+        const size_t N = this->buckets.size();
+        for (size_t i = 0; i < N; ++i)
         {
+            const auto& e = this->buckets[i];
             prefetch_for_read(&e);
+            if (i + 2 < N) prefetch_for_read(&this->buckets[i + 2]);
             cpu_relax();
 
             uint64_t v1 = e.version.load(std::memory_order_acquire);
@@ -263,6 +278,7 @@ public:
                 K key = e.key;
                 V val = e.value;
 
+                cpu_relax();
                 uint64_t v2 = e.version.load(std::memory_order_acquire);
 
                 if (v1 == v2 && v2 % 2 == 0)
@@ -313,7 +329,7 @@ public:
                 e.version.store(v + 2);
             }
         }
-        count.store(0);
+        this->count.store(0);
     }
 
     void reserve(size_t desired_capacity)
@@ -359,6 +375,120 @@ public:
         this->capacity = new_capacity;
     }
 
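+    // NOTE: the batch_* methods below are hard-coded for K = V = int. AVX2
+    // only stages the key/value blocks 8 lanes at a time; each lane then
+    // goes through the scalar insert()/find()/update()/erase() path, so the
+    // per-entry versioning protocol is unchanged.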
+    std::vector<std::optional<V>> batch_find(const int* keys, size_t n) const
+    {
+        std::vector<std::optional<V>> result(n);
+
+#ifdef __AVX2__
+        constexpr size_t stride = 8;
+        size_t i = 0;
+        for (; i + stride <= n; i += stride) {
+            __m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
+            alignas(32) int key_arr[8];
+            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
+
+            for (int j = 0; j < 8; ++j) {
+                result[i + j] = find(key_arr[j]);
+            }
+        }
+
+        for (; i < n; ++i)
+            result[i] = find(keys[i]);
+
+#else
+        for (size_t i = 0; i < n; ++i)
+            result[i] = find(keys[i]);
+#endif
+
+        return result;
+    }
+
+    void batch_insert(const int* keys, const int* values, size_t n) {
+#ifdef __AVX2__
+        constexpr size_t stride = 8;
+        size_t i = 0;
+
+        for (; i + stride <= n; i += stride) {
+            __m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
+            __m256i vvals = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(values + i));
+
+            alignas(32) int key_arr[8];
+            alignas(32) int val_arr[8];
+
+            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
+            _mm256_store_si256(reinterpret_cast<__m256i*>(val_arr), vvals);
+
+            for (int j = 0; j < 8; ++j) {
+                insert(key_arr[j], val_arr[j]);
+            }
+        }
+
+        for (; i < n; ++i)
+            insert(keys[i], values[i]);
+#else
+        for (size_t i = 0; i < n; ++i)
+            insert(keys[i], values[i]);
+#endif
+    }
+
+    void batch_insert_or_update(const int* keys, const int* values, size_t n) {
+#ifdef __AVX2__
+        constexpr size_t stride = 8;
+        size_t i = 0;
+
+        for (; i + stride <= n; i += stride) {
+            __m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
+            __m256i vvals = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(values + i));
+
+            alignas(32) int key_arr[8];
+            alignas(32) int val_arr[8];
+
+            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
+            _mm256_store_si256(reinterpret_cast<__m256i*>(val_arr), vvals);
+
+            for (int j = 0; j < 8; ++j) {
+                if (!insert(key_arr[j], val_arr[j]))
+                    update(key_arr[j], val_arr[j]);
+            }
+        }
+
+        for (; i < n; ++i) {
+            if (!insert(keys[i], values[i]))
+                update(keys[i], values[i]);
+        }
+#else
+        for (size_t i = 0; i < n; ++i) {
+            if (!insert(keys[i], values[i]))
+                update(keys[i], values[i]);
+        }
+#endif
+    }
+
+    void batch_erase(const int* keys, size_t n) {
+#ifdef __AVX2__
+        constexpr size_t stride = 8;
+        size_t i = 0;
+
+        for (; i + stride <= n; i += stride) {
+            __m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
+            alignas(32) int key_arr[8];
+            _mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
+
+            for (int j = 0; j < 8; ++j) {
+                erase(key_arr[j]);
+            }
+        }
+
+        for (; i < n; ++i) {
+            erase(keys[i]);
+        }
+#else
+        for (size_t i = 0; i < n; ++i) {
+            erase(keys[i]);
+        }
+#endif
+    }
+
     size_t size() const
     {
         return this->count.load();
diff --git a/main.cpp b/main.cpp
index 4ba2893..89735ae 100644
--- a/main.cpp
+++ b/main.cpp
@@ -10,6 +10,9 @@
 constexpr int THREADS = 8;
 constexpr int OPS_PER_THREAD = 10000;
 
+constexpr int BATCH_THREADS = 4;
+constexpr int BATCH_SIZE = 1024;
+
 std::atomic<uint64_t> insert_ns{0};
 std::atomic<uint64_t> find_ns{0};
 std::atomic<uint64_t> update_ns{0};
@@ -55,6 +58,96 @@ void worker(LockFreeMap<int, int>& map, int thread_id)
     }
 }
 
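+// Each batch worker gets its own disjoint key range [base_key, base_key +
+// BATCH_SIZE), so test_batch_operations_parallel() can run BATCH_THREADS
+// workers side by side without overlapping keys.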
+void batch_worker_insert(LockFreeMap<int, int>& map, int base_key) {
+    std::vector<int> keys(BATCH_SIZE);
+    std::vector<int> values(BATCH_SIZE);
+
+    for (int i = 0; i < BATCH_SIZE; ++i) {
+        keys[i] = base_key + i;
+        values[i] = i;
+    }
+
+    map.batch_insert(keys.data(), values.data(), BATCH_SIZE);
+}
+
+void batch_worker_update(LockFreeMap<int, int>& map, int base_key) {
+    std::vector<int> keys(BATCH_SIZE);
+    std::vector<int> values(BATCH_SIZE);
+
+    for (int i = 0; i < BATCH_SIZE; ++i) {
+        keys[i] = base_key + i;
+        values[i] = 10000 + i;
+    }
+
+    map.batch_insert_or_update(keys.data(), values.data(), BATCH_SIZE);
+}
+
+void batch_worker_find(const LockFreeMap<int, int>& map, int base_key) {
+    std::vector<int> keys(BATCH_SIZE);
+    for (int i = 0; i < BATCH_SIZE; ++i) {
+        keys[i] = base_key + i;
+    }
+
+    auto results = map.batch_find(keys.data(), BATCH_SIZE);
+    for (size_t i = 0; i < results.size(); ++i) {
+        assert(results[i].has_value());
+    }
+}
+
+void batch_worker_erase(LockFreeMap<int, int>& map, int base_key) {
+    std::vector<int> keys(BATCH_SIZE);
+    for (int i = 0; i < BATCH_SIZE; ++i) {
+        keys[i] = base_key + i;
+    }
+
+    map.batch_erase(keys.data(), BATCH_SIZE);
+}
+
+void test_batch_operations_parallel() {
+    LockFreeMap<int, int> map;
+    std::vector<std::thread> threads;
+
+    auto benchmark = [&](const std::string& label, auto&& fn) {
+        uint64_t start = now_ns();
+        fn();
+        uint64_t end = now_ns();
+        double ms = (end - start) / 1e6;
+        double ops = BATCH_THREADS * BATCH_SIZE;
+        double throughput = ops / (ms / 1000.0);
+
+        std::cout << "✓ " << label << " passed in " << ms << " ms, throughput = "
+                  << static_cast<uint64_t>(throughput) << " ops/sec\n";
+    };
+
+    benchmark("batch_insert (parallel)", [&]() {
+        for (int i = 0; i < BATCH_THREADS; ++i)
+            threads.emplace_back(batch_worker_insert, std::ref(map), i * BATCH_SIZE);
+        for (auto& t : threads) t.join();
+        threads.clear();
+    });
+
+    benchmark("batch_insert_or_update (parallel)", [&]() {
+        for (int i = 0; i < BATCH_THREADS; ++i)
+            threads.emplace_back(batch_worker_update, std::ref(map), i * BATCH_SIZE);
+        for (auto& t : threads) t.join();
+        threads.clear();
+    });
+
+    benchmark("batch_find (parallel)", [&]() {
+        for (int i = 0; i < BATCH_THREADS; ++i)
+            threads.emplace_back(batch_worker_find, std::cref(map), i * BATCH_SIZE);
+        for (auto& t : threads) t.join();
+        threads.clear();
+    });
+
+    benchmark("batch_erase (parallel)", [&]() {
+        for (int i = 0; i < BATCH_THREADS; ++i)
+            threads.emplace_back(batch_worker_erase, std::ref(map), i * BATCH_SIZE);
+        for (auto& t : threads) t.join();
+        threads.clear();
+    });
+}
+
 void verify_map(LockFreeMap<int, int>& map)
 {
     std::cout << "Final size: " << map.size() << "\n";
@@ -114,5 +207,8 @@ int main()
     std::cout << "Erase avg: " << erase_ns / 1e3 << " μs total (" << (erase_ns / (double)total_ops) << " ns/op)\n";
     std::cout << "--------------------------\n";
 
+    std::cout << "=== Batch Operations: Parallel Test ===\n";
+    test_batch_operations_parallel();
+    std::cout << "=== Done ===\n";
     return 0;
 }
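--
Usage sketch (not part of the patch): a minimal single-threaded caller of the
new batch API, assuming LockFreeMap<int, int> as exercised in main.cpp. The
contract is raw pointer plus element count; the AVX2 staging path is used
automatically when n >= 8 and the binary is built with -mavx2, otherwise the
scalar fallback runs.

    #include "UnorderedParallelMap.h"
    #include <cstdio>
    #include <vector>

    int main()
    {
        LockFreeMap<int, int> map;

        std::vector<int> keys{1, 2, 3, 4, 5, 6, 7, 8};
        std::vector<int> vals{10, 20, 30, 40, 50, 60, 70, 80};

        // Insert 8 pairs in one call; each lane lands on the scalar insert().
        map.batch_insert(keys.data(), vals.data(), keys.size());

        // Look all 8 keys up again; every slot should come back engaged.
        auto found = map.batch_find(keys.data(), keys.size());
        for (size_t i = 0; i < found.size(); ++i)
            std::printf("%d -> %d\n", keys[i], found[i].value_or(-1));

        map.batch_erase(keys.data(), keys.size());
        return 0;
    }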