added batch methods and some optimization
This commit is contained in:
parent
a3f7f01476
commit
8a49a004ce
@ -8,6 +8,9 @@
|
||||
#include <functional>
|
||||
#include <cstdint>
|
||||
#include "optimization.h"
|
||||
#ifdef __AVX2__
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
template <typename K, typename V>
|
||||
class LockFreeMap
|
||||
@ -15,13 +18,14 @@ class LockFreeMap
|
||||
static_assert(std::is_trivially_copyable_v<K>);
|
||||
static_assert(std::is_trivially_copyable_v<V>);
|
||||
|
||||
struct Entry
|
||||
struct alignas(64) Entry
|
||||
{
|
||||
std::atomic<uint64_t> version{0}; // even: stable, odd: writing
|
||||
bool occupied = false;
|
||||
bool deleted = false;
|
||||
K key{};
|
||||
V value{};
|
||||
char padding[64 - sizeof(std::atomic<uint64_t>) - sizeof(bool) * 2 - sizeof(K) - sizeof(V)]{};
|
||||
};
|
||||
|
||||
std::vector<Entry> buckets;
|
||||
@ -45,17 +49,19 @@ class LockFreeMap
|
||||
size_t new_capacity = this->capacity * 2;
|
||||
std::vector<Entry> new_buckets(new_capacity);
|
||||
|
||||
for (auto& e : this->buckets)
|
||||
for (size_t i = 0; i < this->buckets.size(); ++i)
|
||||
{
|
||||
auto& e = this->buckets[i];
|
||||
prefetch_for_read(&e);
|
||||
if (i + 2 < this->buckets.size()) prefetch_for_read(&this->buckets[i + 2]);
|
||||
cpu_relax();
|
||||
|
||||
if (e.occupied && !e.deleted)
|
||||
{
|
||||
size_t h = std::hash<K>{}(e.key) % new_capacity;
|
||||
for (size_t i = 0; i < new_capacity; ++i)
|
||||
for (size_t j = 0; j < new_capacity; ++j)
|
||||
{
|
||||
size_t idx = (h + i) % new_capacity;
|
||||
size_t idx = (h + j) % new_capacity;
|
||||
Entry& ne = new_buckets[idx];
|
||||
prefetch_for_write(&ne);
|
||||
cpu_relax();
|
||||
@ -105,7 +111,7 @@ public:
|
||||
uint64_t v = e.version.load();
|
||||
if (v % 2 != 0) continue;
|
||||
|
||||
if (e.version.compare_exchange_strong(v, v + 1))
|
||||
if (e.version.compare_exchange_weak(v, v + 1))
|
||||
{
|
||||
e.key = key;
|
||||
e.value = val;
|
||||
@ -151,25 +157,30 @@ public:
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
bool erase(const K& key) {
|
||||
bool erase(const K& key)
|
||||
{
|
||||
std::unique_lock lock(this->resize_mutex);
|
||||
|
||||
size_t h = hash(key);
|
||||
for (size_t i = 0; i < this->capacity; ++i) {
|
||||
for (size_t i = 0; i < this->capacity; ++i)
|
||||
{
|
||||
size_t idx = probe(h, i);
|
||||
Entry& e = this->buckets[idx];
|
||||
|
||||
prefetch_for_write(&e);
|
||||
cpu_relax();
|
||||
|
||||
if (e.occupied && e.key == key) {
|
||||
if (e.occupied && e.key == key)
|
||||
{
|
||||
if (e.deleted) return false;
|
||||
|
||||
uint64_t v = e.version.load();
|
||||
if (v % 2 != 0) continue;
|
||||
|
||||
if (e.version.compare_exchange_strong(v, v + 1)) {
|
||||
if (!e.deleted) {
|
||||
if (e.version.compare_exchange_strong(v, v + 1))
|
||||
{
|
||||
if (!e.deleted)
|
||||
{
|
||||
e.deleted = true;
|
||||
this->count.fetch_sub(1);
|
||||
}
|
||||
@ -253,9 +264,13 @@ public:
|
||||
|
||||
void for_each(const std::function<void(const K&, const V&)>& cb) const
|
||||
{
|
||||
for (const auto& e : this->buckets)
|
||||
const size_t N = this->buckets.size();
|
||||
for (size_t i = 0; i < N; ++i)
|
||||
{
|
||||
const auto& e = this->buckets[i];
|
||||
|
||||
prefetch_for_read(&e);
|
||||
if (i + 2 < N) prefetch_for_read(&this->buckets[i + 2]);
|
||||
cpu_relax();
|
||||
|
||||
uint64_t v1 = e.version.load(std::memory_order_acquire);
|
||||
@ -263,6 +278,7 @@ public:
|
||||
|
||||
K key = e.key;
|
||||
V val = e.value;
|
||||
|
||||
cpu_relax();
|
||||
uint64_t v2 = e.version.load(std::memory_order_acquire);
|
||||
if (v1 == v2 && v2 % 2 == 0)
|
||||
@ -313,7 +329,7 @@ public:
|
||||
e.version.store(v + 2);
|
||||
}
|
||||
}
|
||||
count.store(0);
|
||||
this->count.store(0);
|
||||
}
|
||||
|
||||
void reserve(size_t desired_capacity)
|
||||
@ -359,6 +375,120 @@ public:
|
||||
this->capacity = new_capacity;
|
||||
}
|
||||
|
||||
std::vector<std::optional<int>> batch_find(const int* keys, size_t n) const
|
||||
{
|
||||
std::vector<std::optional<int>> result(n);
|
||||
|
||||
#ifdef __AVX2__
|
||||
constexpr size_t stride = 8;
|
||||
size_t i = 0;
|
||||
for (; i + stride <= n; i += stride) {
|
||||
__m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
|
||||
alignas(32) int key_arr[8];
|
||||
_mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
|
||||
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
result[i + j] = find(key_arr[j]);
|
||||
}
|
||||
}
|
||||
|
||||
for (; i < n; ++i)
|
||||
result[i] = find(keys[i]);
|
||||
|
||||
#else
|
||||
for (size_t i = 0; i < n; ++i)
|
||||
result[i] = find(keys[i]);
|
||||
#endif
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void batch_insert(const int* keys, const int* values, size_t n) {
|
||||
#ifdef __AVX2__
|
||||
constexpr size_t stride = 8;
|
||||
size_t i = 0;
|
||||
|
||||
for (; i + stride <= n; i += stride) {
|
||||
__m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
|
||||
__m256i vvals = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(values + i));
|
||||
|
||||
alignas(32) int key_arr[8];
|
||||
alignas(32) int val_arr[8];
|
||||
|
||||
_mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
|
||||
_mm256_store_si256(reinterpret_cast<__m256i*>(val_arr), vvals);
|
||||
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
insert(key_arr[j], val_arr[j]);
|
||||
}
|
||||
}
|
||||
|
||||
for (; i < n; ++i)
|
||||
insert(keys[i], values[i]);
|
||||
#else
|
||||
for (size_t i = 0; i < n; ++i)
|
||||
insert(keys[i], values[i]);
|
||||
#endif
|
||||
}
|
||||
|
||||
void batch_insert_or_update(const int* keys, const int* values, size_t n) {
|
||||
#ifdef __AVX2__
|
||||
constexpr size_t stride = 8;
|
||||
size_t i = 0;
|
||||
|
||||
for (; i + stride <= n; i += stride) {
|
||||
__m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
|
||||
__m256i vvals = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(values + i));
|
||||
|
||||
alignas(32) int key_arr[8];
|
||||
alignas(32) int val_arr[8];
|
||||
|
||||
_mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
|
||||
_mm256_store_si256(reinterpret_cast<__m256i*>(val_arr), vvals);
|
||||
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
if (!insert(key_arr[j], val_arr[j]))
|
||||
update(key_arr[j], val_arr[j]);
|
||||
}
|
||||
}
|
||||
|
||||
for (; i < n; ++i) {
|
||||
if (!insert(keys[i], values[i]))
|
||||
update(keys[i], values[i]);
|
||||
}
|
||||
#else
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
if (!insert(keys[i], values[i]))
|
||||
update(keys[i], values[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void batch_erase(const int* keys, size_t n) {
|
||||
#ifdef __AVX2__
|
||||
constexpr size_t stride = 8;
|
||||
size_t i = 0;
|
||||
|
||||
for (; i + stride <= n; i += stride) {
|
||||
__m256i vkeys = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(keys + i));
|
||||
alignas(32) int key_arr[8];
|
||||
_mm256_store_si256(reinterpret_cast<__m256i*>(key_arr), vkeys);
|
||||
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
erase(key_arr[j]);
|
||||
}
|
||||
}
|
||||
|
||||
for (; i < n; ++i) {
|
||||
erase(keys[i]);
|
||||
}
|
||||
#else
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
erase(keys[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
size_t size() const
|
||||
{
|
||||
return this->count.load();
|
||||
|
96
main.cpp
96
main.cpp
@ -10,6 +10,9 @@
|
||||
constexpr int THREADS = 8;
|
||||
constexpr int OPS_PER_THREAD = 10000;
|
||||
|
||||
constexpr int BATCH_THREADS = 4;
|
||||
constexpr int BATCH_SIZE = 1024;
|
||||
|
||||
std::atomic<uint64_t> insert_ns{0};
|
||||
std::atomic<uint64_t> find_ns{0};
|
||||
std::atomic<uint64_t> update_ns{0};
|
||||
@ -55,6 +58,96 @@ void worker(LockFreeMap<int, int>& map, int thread_id)
|
||||
}
|
||||
}
|
||||
|
||||
void batch_worker_insert(LockFreeMap<int, int>& map, int base_key) {
|
||||
std::vector<int> keys(BATCH_SIZE);
|
||||
std::vector<int> values(BATCH_SIZE);
|
||||
|
||||
for (int i = 0; i < BATCH_SIZE; ++i) {
|
||||
keys[i] = base_key + i;
|
||||
values[i] = i;
|
||||
}
|
||||
|
||||
map.batch_insert(keys.data(), values.data(), BATCH_SIZE);
|
||||
}
|
||||
|
||||
void batch_worker_update(LockFreeMap<int, int>& map, int base_key) {
|
||||
std::vector<int> keys(BATCH_SIZE);
|
||||
std::vector<int> values(BATCH_SIZE);
|
||||
|
||||
for (int i = 0; i < BATCH_SIZE; ++i) {
|
||||
keys[i] = base_key + i;
|
||||
values[i] = 10000 + i;
|
||||
}
|
||||
|
||||
map.batch_insert_or_update(keys.data(), values.data(), BATCH_SIZE);
|
||||
}
|
||||
|
||||
void batch_worker_find(const LockFreeMap<int, int>& map, int base_key) {
|
||||
std::vector<int> keys(BATCH_SIZE);
|
||||
for (int i = 0; i < BATCH_SIZE; ++i) {
|
||||
keys[i] = base_key + i;
|
||||
}
|
||||
|
||||
auto results = map.batch_find(keys.data(), BATCH_SIZE);
|
||||
for (size_t i = 0; i < results.size(); ++i) {
|
||||
assert(results[i].has_value());
|
||||
}
|
||||
}
|
||||
|
||||
void batch_worker_erase(LockFreeMap<int, int>& map, int base_key) {
|
||||
std::vector<int> keys(BATCH_SIZE);
|
||||
for (int i = 0; i < BATCH_SIZE; ++i) {
|
||||
keys[i] = base_key + i;
|
||||
}
|
||||
|
||||
map.batch_erase(keys.data(), BATCH_SIZE);
|
||||
}
|
||||
|
||||
void test_batch_operations_parallel() {
|
||||
LockFreeMap<int, int> map;
|
||||
std::vector<std::thread> threads;
|
||||
|
||||
auto benchmark = [&](const std::string& label, auto&& fn) {
|
||||
uint64_t start = now_ns();
|
||||
fn();
|
||||
uint64_t end = now_ns();
|
||||
double ms = (end - start) / 1e6;
|
||||
double ops = BATCH_THREADS * BATCH_SIZE;
|
||||
double throughput = ops / (ms / 1000.0);
|
||||
|
||||
std::cout << "✓ " << label << " passed in " << ms << " ms, throughput = "
|
||||
<< static_cast<size_t>(throughput) << " ops/sec\n";
|
||||
};
|
||||
|
||||
benchmark("batch_insert (parallel)", [&]() {
|
||||
for (int i = 0; i < BATCH_THREADS; ++i)
|
||||
threads.emplace_back(batch_worker_insert, std::ref(map), i * BATCH_SIZE);
|
||||
for (auto& t : threads) t.join();
|
||||
threads.clear();
|
||||
});
|
||||
|
||||
benchmark("batch_insert_or_update (parallel)", [&]() {
|
||||
for (int i = 0; i < BATCH_THREADS; ++i)
|
||||
threads.emplace_back(batch_worker_update, std::ref(map), i * BATCH_SIZE);
|
||||
for (auto& t : threads) t.join();
|
||||
threads.clear();
|
||||
});
|
||||
|
||||
benchmark("batch_find (parallel)", [&]() {
|
||||
for (int i = 0; i < BATCH_THREADS; ++i)
|
||||
threads.emplace_back(batch_worker_find, std::cref(map), i * BATCH_SIZE);
|
||||
for (auto& t : threads) t.join();
|
||||
threads.clear();
|
||||
});
|
||||
|
||||
benchmark("batch_erase (parallel)", [&]() {
|
||||
for (int i = 0; i < BATCH_THREADS; ++i)
|
||||
threads.emplace_back(batch_worker_erase, std::ref(map), i * BATCH_SIZE);
|
||||
for (auto& t : threads) t.join();
|
||||
threads.clear();
|
||||
});
|
||||
}
|
||||
|
||||
void verify_map(LockFreeMap<int, int>& map)
|
||||
{
|
||||
std::cout << "Final size: " << map.size() << "\n";
|
||||
@ -114,5 +207,8 @@ int main()
|
||||
std::cout << "Erase avg: " << erase_ns / 1e3 << " μs total (" << (erase_ns / (double)total_ops) << " ns/op)\n";
|
||||
std::cout << "--------------------------\n";
|
||||
|
||||
std::cout << "=== Batch Operations: Parallel Test ===\n";
|
||||
test_batch_operations_parallel();
|
||||
std::cout << "=== Done ===\n";
|
||||
return 0;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user