added batch methods and some optimization
This commit is contained in:
parent
a3f7f01476
commit
8a49a004ce
@ -8,6 +8,9 @@
|
|||||||
#include <functional>
|
#include <functional>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include "optimization.h"
|
#include "optimization.h"
|
||||||
|
#ifdef __AVX2__
|
||||||
|
#include <immintrin.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
template <typename K, typename V>
|
template <typename K, typename V>
|
||||||
class LockFreeMap
|
class LockFreeMap
|
||||||
@ -15,13 +18,14 @@ class LockFreeMap
|
|||||||
static_assert(std::is_trivially_copyable_v<K>);
|
static_assert(std::is_trivially_copyable_v<K>);
|
||||||
static_assert(std::is_trivially_copyable_v<V>);
|
static_assert(std::is_trivially_copyable_v<V>);
|
||||||
|
|
||||||
struct Entry
|
struct alignas(64) Entry
|
||||||
{
|
{
|
||||||
std::atomic<uint64_t> version{0}; // even: stable, odd: writing
|
std::atomic<uint64_t> version{0}; // even: stable, odd: writing
|
||||||
bool occupied = false;
|
bool occupied = false;
|
||||||
bool deleted = false;
|
bool deleted = false;
|
||||||
K key{};
|
K key{};
|
||||||
V value{};
|
V value{};
|
||||||
|
char padding[64 - sizeof(std::atomic<uint64_t>) - sizeof(bool) * 2 - sizeof(K) - sizeof(V)]{};
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<Entry> buckets;
|
std::vector<Entry> buckets;
|
||||||
@ -45,17 +49,19 @@ class LockFreeMap
|
|||||||
size_t new_capacity = this->capacity * 2;
|
size_t new_capacity = this->capacity * 2;
|
||||||
std::vector<Entry> new_buckets(new_capacity);
|
std::vector<Entry> new_buckets(new_capacity);
|
||||||
|
|
||||||
for (auto& e : this->buckets)
|
for (size_t i = 0; i < this->buckets.size(); ++i)
|
||||||
{
|
{
|
||||||
|
auto& e = this->buckets[i];
|
||||||
prefetch_for_read(&e);
|
prefetch_for_read(&e);
|
||||||
|
if (i + 2 < this->buckets.size()) prefetch_for_read(&this->buckets[i + 2]);
|
||||||
cpu_relax();
|
cpu_relax();
|
||||||
|
|
||||||
if (e.occupied && !e.deleted)
|
if (e.occupied && !e.deleted)
|
||||||
{
|
{
|
||||||
size_t h = std::hash<K>{}(e.key) % new_capacity;
|
size_t h = std::hash<K>{}(e.key) % new_capacity;
|
||||||
for (size_t i = 0; i < new_capacity; ++i)
|
for (size_t j = 0; j < new_capacity; ++j)
|
||||||
{
|
{
|
||||||
size_t idx = (h + i) % new_capacity;
|
size_t idx = (h + j) % new_capacity;
|
||||||
Entry& ne = new_buckets[idx];
|
Entry& ne = new_buckets[idx];
|
||||||
prefetch_for_write(&ne);
|
prefetch_for_write(&ne);
|
||||||
cpu_relax();
|
cpu_relax();
|
||||||
@ -105,7 +111,7 @@ public:
|
|||||||
uint64_t v = e.version.load();
|
uint64_t v = e.version.load();
|
||||||
if (v % 2 != 0) continue;
|
if (v % 2 != 0) continue;
|
||||||
|
|
||||||
if (e.version.compare_exchange_strong(v, v + 1))
|
if (e.version.compare_exchange_weak(v, v + 1))
|
||||||
{
|
{
|
||||||
e.key = key;
|
e.key = key;
|
||||||
e.value = val;
|
e.value = val;
|
||||||
@ -151,25 +157,30 @@ public:
|
|||||||
return std::nullopt;
|
return std::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool erase(const K& key) {
|
bool erase(const K& key)
|
||||||
|
{
|
||||||
std::unique_lock lock(this->resize_mutex);
|
std::unique_lock lock(this->resize_mutex);
|
||||||
|
|
||||||
size_t h = hash(key);
|
size_t h = hash(key);
|
||||||
for (size_t i = 0; i < this->capacity; ++i) {
|
for (size_t i = 0; i < this->capacity; ++i)
|
||||||
|
{
|
||||||
size_t idx = probe(h, i);
|
size_t idx = probe(h, i);
|
||||||
Entry& e = this->buckets[idx];
|
Entry& e = this->buckets[idx];
|
||||||
|
|
||||||
prefetch_for_write(&e);
|
prefetch_for_write(&e);
|
||||||
cpu_relax();
|
cpu_relax();
|
||||||
|
|
||||||
if (e.occupied && e.key == key) {
|
if (e.occupied && e.key == key)
|
||||||
|
{
|
||||||
if (e.deleted) return false;
|
if (e.deleted) return false;
|
||||||
|
|
||||||
uint64_t v = e.version.load();
|
uint64_t v = e.version.load();
|
||||||
if (v % 2 != 0) continue;
|
if (v % 2 != 0) continue;
|
||||||
|
|
||||||
if (e.version.compare_exchange_strong(v, v + 1)) {
|
if (e.version.compare_exchange_strong(v, v + 1))
|
||||||
if (!e.deleted) {
|
{
|
||||||
|
if (!e.deleted)
|
||||||
|
{
|
||||||
e.deleted = true;
|
e.deleted = true;
|
||||||
this->count.fetch_sub(1);
|
this->count.fetch_sub(1);
|
||||||
}
|
}
|
||||||
@ -253,9 +264,13 @@ public:
|
|||||||
|
|
||||||
void for_each(const std::function<void(const K&, const V&)>& cb) const
|
void for_each(const std::function<void(const K&, const V&)>& cb) const
|
||||||
{
|
{
|
||||||
for (const auto& e : this->buckets)
|
const size_t N = this->buckets.size();
|
||||||
|
for (size_t i = 0; i < N; ++i)
|
||||||
{
|
{
|
||||||
|
const auto& e = this->buckets[i];
|
||||||
|
|
||||||
prefetch_for_read(&e);
|
prefetch_for_read(&e);
|
||||||
|
if (i + 2 < N) prefetch_for_read(&this->buckets[i + 2]);
|
||||||
cpu_relax();
|
cpu_relax();
|
||||||
|
|
||||||
uint64_t v1 = e.version.load(std::memory_order_acquire);
|
uint64_t v1 = e.version.load(std::memory_order_acquire);
|
||||||
@ -263,6 +278,7 @@ public:
|
|||||||
|
|
||||||
K key = e.key;
|
K key = e.key;
|
||||||
V val = e.value;
|
V val = e.value;
|
||||||
|
|
||||||
cpu_relax();
|
cpu_relax();
|
||||||
uint64_t v2 = e.version.load(std::memory_order_acquire);
|
uint64_t v2 = e.version.load(std::memory_order_acquire);
|
||||||
if (v1 == v2 && v2 % 2 == 0)
|
if (v1 == v2 && v2 % 2 == 0)
|
||||||
@ -313,7 +329,7 @@ public:
|
|||||||
e.version.store(v + 2);
|
e.version.store(v + 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
count.store(0);
|
this->count.store(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void reserve(size_t desired_capacity)
|
void reserve(size_t desired_capacity)
|
||||||
@ -359,6 +375,120 @@ public:
|
|||||||
this->capacity = new_capacity;
|
this->capacity = new_capacity;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<std::optional<int>> batch_find(const int* keys, size_t n) const
|
||||||
|
{
|
||||||
|
std::vector<std::optional<int>> result(n);
|
||||||
|
|
||||||
|
#ifdef __AVX2__
|
||||||
|
constexpr size_t stride = 8;
|
||||||
|
size_t i = 0;
|
||||||
|
for (; i + stride <= n; i += stride) {
|
||||||
|
__m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
|
||||||
|
alignas(32) int key_arr[8];
|
||||||
|
_mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
|
||||||
|
|
||||||
|
for (int j = 0; j < 8; ++j) {
|
||||||
|
result[i + j] = find(key_arr[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < n; ++i)
|
||||||
|
result[i] = find(keys[i]);
|
||||||
|
|
||||||
|
#else
|
||||||
|
for (size_t i = 0; i < n; ++i)
|
||||||
|
result[i] = find(keys[i]);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void batch_insert(const int* keys, const int* values, size_t n) {
|
||||||
|
#ifdef __AVX2__
|
||||||
|
constexpr size_t stride = 8;
|
||||||
|
size_t i = 0;
|
||||||
|
|
||||||
|
for (; i + stride <= n; i += stride) {
|
||||||
|
__m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
|
||||||
|
__m256i vvals = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(values + i));
|
||||||
|
|
||||||
|
alignas(32) int key_arr[8];
|
||||||
|
alignas(32) int val_arr[8];
|
||||||
|
|
||||||
|
_mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
|
||||||
|
_mm256_store_si256(reinterpret_cast<__m256i*>(val_arr), vvals);
|
||||||
|
|
||||||
|
for (int j = 0; j < 8; ++j) {
|
||||||
|
insert(key_arr[j], val_arr[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < n; ++i)
|
||||||
|
insert(keys[i], values[i]);
|
||||||
|
#else
|
||||||
|
for (size_t i = 0; i < n; ++i)
|
||||||
|
insert(keys[i], values[i]);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void batch_insert_or_update(const int* keys, const int* values, size_t n) {
|
||||||
|
#ifdef __AVX2__
|
||||||
|
constexpr size_t stride = 8;
|
||||||
|
size_t i = 0;
|
||||||
|
|
||||||
|
for (; i + stride <= n; i += stride) {
|
||||||
|
__m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
|
||||||
|
__m256i vvals = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(values + i));
|
||||||
|
|
||||||
|
alignas(32) int key_arr[8];
|
||||||
|
alignas(32) int val_arr[8];
|
||||||
|
|
||||||
|
_mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
|
||||||
|
_mm256_store_si256(reinterpret_cast<__m256i*>(val_arr), vvals);
|
||||||
|
|
||||||
|
for (int j = 0; j < 8; ++j) {
|
||||||
|
if (!insert(key_arr[j], val_arr[j]))
|
||||||
|
update(key_arr[j], val_arr[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < n; ++i) {
|
||||||
|
if (!insert(keys[i], values[i]))
|
||||||
|
update(keys[i], values[i]);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
for (size_t i = 0; i < n; ++i) {
|
||||||
|
if (!insert(keys[i], values[i]))
|
||||||
|
update(keys[i], values[i]);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void batch_erase(const int* keys, size_t n) {
|
||||||
|
#ifdef __AVX2__
|
||||||
|
constexpr size_t stride = 8;
|
||||||
|
size_t i = 0;
|
||||||
|
|
||||||
|
for (; i + stride <= n; i += stride) {
|
||||||
|
__m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
|
||||||
|
alignas(32) int key_arr[8];
|
||||||
|
_mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
|
||||||
|
|
||||||
|
for (int j = 0; j < 8; ++j) {
|
||||||
|
erase(key_arr[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (; i < n; ++i) {
|
||||||
|
erase(keys[i]);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
for (size_t i = 0; i < n; ++i) {
|
||||||
|
erase(keys[i]);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
size_t size() const
|
size_t size() const
|
||||||
{
|
{
|
||||||
return this->count.load();
|
return this->count.load();
|
||||||
|
96
main.cpp
96
main.cpp
@ -10,6 +10,9 @@
|
|||||||
constexpr int THREADS = 8;
|
constexpr int THREADS = 8;
|
||||||
constexpr int OPS_PER_THREAD = 10000;
|
constexpr int OPS_PER_THREAD = 10000;
|
||||||
|
|
||||||
|
constexpr int BATCH_THREADS = 4;
|
||||||
|
constexpr int BATCH_SIZE = 1024;
|
||||||
|
|
||||||
std::atomic<uint64_t> insert_ns{0};
|
std::atomic<uint64_t> insert_ns{0};
|
||||||
std::atomic<uint64_t> find_ns{0};
|
std::atomic<uint64_t> find_ns{0};
|
||||||
std::atomic<uint64_t> update_ns{0};
|
std::atomic<uint64_t> update_ns{0};
|
||||||
@ -55,6 +58,96 @@ void worker(LockFreeMap<int, int>& map, int thread_id)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void batch_worker_insert(LockFreeMap<int, int>& map, int base_key) {
|
||||||
|
std::vector<int> keys(BATCH_SIZE);
|
||||||
|
std::vector<int> values(BATCH_SIZE);
|
||||||
|
|
||||||
|
for (int i = 0; i < BATCH_SIZE; ++i) {
|
||||||
|
keys[i] = base_key + i;
|
||||||
|
values[i] = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
map.batch_insert(keys.data(), values.data(), BATCH_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
void batch_worker_update(LockFreeMap<int, int>& map, int base_key) {
|
||||||
|
std::vector<int> keys(BATCH_SIZE);
|
||||||
|
std::vector<int> values(BATCH_SIZE);
|
||||||
|
|
||||||
|
for (int i = 0; i < BATCH_SIZE; ++i) {
|
||||||
|
keys[i] = base_key + i;
|
||||||
|
values[i] = 10000 + i;
|
||||||
|
}
|
||||||
|
|
||||||
|
map.batch_insert_or_update(keys.data(), values.data(), BATCH_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
void batch_worker_find(const LockFreeMap<int, int>& map, int base_key) {
|
||||||
|
std::vector<int> keys(BATCH_SIZE);
|
||||||
|
for (int i = 0; i < BATCH_SIZE; ++i) {
|
||||||
|
keys[i] = base_key + i;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto results = map.batch_find(keys.data(), BATCH_SIZE);
|
||||||
|
for (size_t i = 0; i < results.size(); ++i) {
|
||||||
|
assert(results[i].has_value());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void batch_worker_erase(LockFreeMap<int, int>& map, int base_key) {
|
||||||
|
std::vector<int> keys(BATCH_SIZE);
|
||||||
|
for (int i = 0; i < BATCH_SIZE; ++i) {
|
||||||
|
keys[i] = base_key + i;
|
||||||
|
}
|
||||||
|
|
||||||
|
map.batch_erase(keys.data(), BATCH_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_batch_operations_parallel() {
|
||||||
|
LockFreeMap<int, int> map;
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
|
||||||
|
auto benchmark = [&](const std::string& label, auto&& fn) {
|
||||||
|
uint64_t start = now_ns();
|
||||||
|
fn();
|
||||||
|
uint64_t end = now_ns();
|
||||||
|
double ms = (end - start) / 1e6;
|
||||||
|
double ops = BATCH_THREADS * BATCH_SIZE;
|
||||||
|
double throughput = ops / (ms / 1000.0);
|
||||||
|
|
||||||
|
std::cout << "✓ " << label << " passed in " << ms << " ms, throughput = "
|
||||||
|
<< static_cast<size_t>(throughput) << " ops/sec\n";
|
||||||
|
};
|
||||||
|
|
||||||
|
benchmark("batch_insert (parallel)", [&]() {
|
||||||
|
for (int i = 0; i < BATCH_THREADS; ++i)
|
||||||
|
threads.emplace_back(batch_worker_insert, std::ref(map), i * BATCH_SIZE);
|
||||||
|
for (auto& t : threads) t.join();
|
||||||
|
threads.clear();
|
||||||
|
});
|
||||||
|
|
||||||
|
benchmark("batch_insert_or_update (parallel)", [&]() {
|
||||||
|
for (int i = 0; i < BATCH_THREADS; ++i)
|
||||||
|
threads.emplace_back(batch_worker_update, std::ref(map), i * BATCH_SIZE);
|
||||||
|
for (auto& t : threads) t.join();
|
||||||
|
threads.clear();
|
||||||
|
});
|
||||||
|
|
||||||
|
benchmark("batch_find (parallel)", [&]() {
|
||||||
|
for (int i = 0; i < BATCH_THREADS; ++i)
|
||||||
|
threads.emplace_back(batch_worker_find, std::cref(map), i * BATCH_SIZE);
|
||||||
|
for (auto& t : threads) t.join();
|
||||||
|
threads.clear();
|
||||||
|
});
|
||||||
|
|
||||||
|
benchmark("batch_erase (parallel)", [&]() {
|
||||||
|
for (int i = 0; i < BATCH_THREADS; ++i)
|
||||||
|
threads.emplace_back(batch_worker_erase, std::ref(map), i * BATCH_SIZE);
|
||||||
|
for (auto& t : threads) t.join();
|
||||||
|
threads.clear();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void verify_map(LockFreeMap<int, int>& map)
|
void verify_map(LockFreeMap<int, int>& map)
|
||||||
{
|
{
|
||||||
std::cout << "Final size: " << map.size() << "\n";
|
std::cout << "Final size: " << map.size() << "\n";
|
||||||
@ -114,5 +207,8 @@ int main()
|
|||||||
std::cout << "Erase avg: " << erase_ns / 1e3 << " μs total (" << (erase_ns / (double)total_ops) << " ns/op)\n";
|
std::cout << "Erase avg: " << erase_ns / 1e3 << " μs total (" << (erase_ns / (double)total_ops) << " ns/op)\n";
|
||||||
std::cout << "--------------------------\n";
|
std::cout << "--------------------------\n";
|
||||||
|
|
||||||
|
std::cout << "=== Batch Operations: Parallel Test ===\n";
|
||||||
|
test_batch_operations_parallel();
|
||||||
|
std::cout << "=== Done ===\n";
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user