关于多线程的三两事

多线程一直是编程中的重要的工具,它可以分充分的利用硬件资源,是我们用更少的时间去完成更多的事情。在之前的博客中,我有介绍了OpenMP的基本使用,OpenMP可以理解为多线程的一个合理和高效的一套抽象工具。这次,打算仔细的介绍多线程编程中的常见的概念和典型的案例。

典型的案例

说到多线程,最核心的问题就是保证数据的读写安全。为了达到此目的,我们需要多很多常见的数据结构做一些改造,从而适应多线程的场景。以下是我工作中比较常见到的一些使用场景:

  1. 线程池
  2. 读写锁
  3. 消息队列
  4. ConcurrentCache
  5. PingPang Buffer

在具体介绍这些使用场景之前,我们还是需要了解需要使用到的一些基本的工具:互斥量、条件变量、原子操作等。

互斥量

互斥量,顾名思义,就是互斥的数据,一个线程持有的时候,其他线程就必须等待。

在C++11中,使用<mutex>头文件引入。以下是一个简单的计数器的实现。

emit函数通过mutex_进行加锁,使得同一时间仅有一个线程可以执行++ x_的操作,从而保证了计数的正确性。

std::lock_guard是个工具类,lck在构造时,调用了lock函数,在析构时调用了unlock,从而避免我们自行调用的时候忘记unlock。

#include <mutex>
#include <thread>
#include <iostream>

class Counter {
public:
    Counter(): x_(0) {}
    void emit() {
        mutex_.lock();
        ++ x_;
        mutex_.unlock();
        // or
        // std::lock_guard<std::mutex> lck(mutex_);
        // ++ x_;
    }
    int count() {
        return x_;
    }
private:
    int x_;
    std::mutex mutex_;
};

int main() {
    Counter c;
    std::thread t1([&c]{
        for (int i = 0; i < 10000000; ++ i) {
            c.emit();
        }
    });
    std::thread t2([&c]{
        for (int i = 0; i < 10000000; ++ i) {
            c.emit();
        }
    });
    t1.join();
    t2.join();
    std::cout << c.count() << std::endl; // 20000000
}

基于Mutex,我们可以方便的实现读写锁。读写锁的作用是,保证数据可以供多个线程并发读,仅一个线程允许写。在存在线程读的情况下,写的线程会阻塞,直到没有任何线程有读操作。

读写锁

首先读写锁会存在一个write_mutex,读线程和写线程都需要抢占这个mutex,从而保证读和写不会同时进行。但是只需要第一个读线程抢占write_mutex即可,其他的读线程不需要再抢占(抢占的话,就不支持并发读了)。当不存在读线程的时候,需要释放write_mutex,这才运行写线程抢占。

因此我们还需要一个计数器,记录当前读线程的个数,并使用另一个read_mutex保证计数器的准确。

#include <mutex>
#include <thread>
#include <iostream>
#include <vector>

class ReadWriteLock {
public:
    ReadWriteLock():reader_count_(0) {}
    void lock_read() {
        read_mutex_.lock();
        if (reader_count_ == 0) {
            write_mutex_.lock();
        }
        ++ reader_count_;
        read_mutex_.unlock();
    }
    void unlock_read() {
        read_mutex_.lock();
        -- reader_count_;
        if (reader_count_ == 0) {
            write_mutex_.unlock();
        }
        read_mutex_.unlock();
    }
    void lock_write() {
        write_mutex_.lock();
    }
    void unlock_write() {
        write_mutex_.unlock();
    }
private:
    std::mutex read_mutex_;
    std::mutex write_mutex_;
    int64_t reader_count_;
};

ReadWriteLock rw_lock;

void read_fn(int idx, int start, int end) {
    std::this_thread::sleep_for(std::chrono::seconds(start));
    rw_lock.lock_read();
    std::cout << "read thread #" << idx << ": read data" << std::endl;
    std::this_thread::sleep_for (std::chrono::seconds(end - start));
    std::cout << "read thread #" << idx << ": read over" << std::endl;
    rw_lock.unlock_read();
}

void write_fn(int idx, int start, int end) {
    std::this_thread::sleep_for(std::chrono::seconds(start));
    rw_lock.lock_write();
    std::cout << "write thread #" << idx << ": write data" << std::endl;
    std::this_thread::sleep_for (std::chrono::seconds(end - start));
    std::cout << "write thread #" << idx << ": write over" << std::endl;
    rw_lock.unlock_write();
}

int main() {

    std::vector<std::thread> threads;
    threads.push_back(std::thread([](){read_fn(1, 0, 3);}));
    threads.push_back(std::thread([](){read_fn(2, 2, 4);}));
    threads.push_back(std::thread([](){read_fn(3, 6, 10);}));

    threads.push_back(std::thread([](){write_fn(1, 1, 4);}));
    threads.push_back(std::thread([](){write_fn(2, 5, 7);}));
    
    for (auto &&t : threads) {
        t.join();
    }
}

// output
// read thread #1: read data
// read thread #2: read data
// read thread #1: read over
// read thread #2: read over
// write thread #1: write data
// write thread #1: write over
// write thread #2: write data
// write thread #2: write over
// read thread #3: read data
// read thread #3: read over

可以看到读线程1和2同时进行的读操作,而写线程1在读线程结束之后才进行。

条件变量

条件变量主要作用是是线程间进行通信,用于一个线程告知其他线程当前的状态。

一般可以用于控制线程的执行顺序,告知其他线程资源是否可用等。

条件变量的使用需要搭配互斥量。

#include <mutex>
#include <iostream>
#include <thread>
#include <condition_variable>

std::mutex mutex_;
std::condition_variable cv_;
bool ready_ = false;

void print_id(int id) {
    std::unique_lock<std::mutex> lck(mutex_);
    while (!ready_) {
        cv_.wait(lck);
    }
    std::cout << "thread -- " << id << std::endl;
}

void go() {
    std::unique_lock<std::mutex> lck(mutex_);
    ready_ = true;
    cv_.notify_all();
}

int main() {
    std::thread threads[10];
    // spawn 10 threads:
    for (int i = 0; i < 10; ++i) {
        threads[i] = std::thread(print_id, i);
    }

    std::cout << "10 threads ready to race...\n";
    go(); // go!

    for (auto &th : threads) {
        th.join();
    }

    return 0;
}

这里使用std::unique_lock来管理互斥量。相对于lock_guardunique_lock的功能更丰富,可以通过它来对mutex进行lockunlock,具体的使用可以查看相关的文档。

condition_variable通过wait操作,可以等待唤醒。wait操作有两个行为:

  1. 将当前线程加入条件变量的等待队列
  2. 释放锁

唤醒条件变量的方法有两个:notify_onenotify_all,分别唤醒一个和所有的线程。

当一个wait中的线程被唤醒时,它会抢占住mutex,因此后续的操作均是线程安全的。

为什么condition_variable需要一个mutex呢?

  1. 一方面是有些变量的访问,我们需要保证它的互斥性,比如这里的ready_字段
  2. 保证wait的两个操作(等待和锁释放)是原子的。

可以参考下面这篇文章:

C++面试问题:为什么条件变量要和互斥锁一起使用?

那么使用条件变量,我们可以创造哪些有意思的工具呢?阻塞队列就是一个巧妙的应用。

BlockQueue

阻塞队列是一种非常常见的数据结构,它允许一个或多个生产者向Queue中写入数据,如果Queue满了,则阻塞住。允许一个或多个消费者读取Queue的数据,如果Queue为空,则一直阻塞直至Queue中有数据。

根据BlockQueue的两种阻塞行为,我们可以大胆的推测,这里可以用两个条件变量,分别控制写入阻塞和读取阻塞。

#include <deque>
#include <mutex>
#include <condition_variable>

template<typename TaskType>
class BlockQueue {
public:
    BlockQueue(size_t capacity): capacity_(capacity) {}
    size_t capacity() {
        std::lock_guard<std::mutex> lck(this->mutex_);
        return this->capacity_;
    }
    size_t size() {
        std::lock_guard<std::mutex> lck(this->mutex_);
        return this->task_queue_.size();
    }
    void push(TaskType *task) {
        std::unique_lock<std::mutex> lck(this->mutex_);
        while (this->task_queue_.size() >= this->capacity_) {
            this->full_cv_.wait(lck);
        }
        this->task_queue_.push_back(task);
        this->empty_cv_.notify_all();
    }
    void get(TaskType **task) {
        std::unique_lock<std::mutex> lck(this->mutex_);
        while (this->task_queue_.empty()) {
            this->empty_cv_.wait(lck);
        }
        *task = task_queue_.front();
        task_queue_.pop_front();
        this->full_cv_.notify_all();
    }

private:
    std::deque<TaskType *> task_queue_;
    size_t capacity_;

    std::mutex mutex_;
    std::condition_variable full_cv_;
    std::condition_variable empty_cv_;
};

上述的例子,如果将wait改为wait_for的话,还可以方便的实现带timeout的BlockQueue,感兴趣的同学可以自己尝试一下。

原子类型与原子操作

C++中的原子类型的定义和使用十分简单。仅需要包含头文件<atomic>即可。使用std::atomic<T>的方式即可构造原子类型的变量。

#include <atomic>
std::atomic<int32_t> i32_count;
std::atomic<uint64_t> u64_count;

针对原子类型的变量,有许多的操作可用。最常用到的就是++用来计数。比如我们前面的使用mutex完成计数器的例子,其实使用原子类型会更加的简单和高效。

#include <atomic>

class Counter {
public:
    Counter(): x_(0) {}
    void emit() {
        ++ x_;
    }
    int count() {
        return x_;
    }
private:
    std::atomic<int> x_;
};

以下是具体的几个方法:

函数 功能
store 用非原子对象替换当前对象的值。相等于线程安全的=操作
load 原子地获取原子对象的值
fetch_add/fetch_sub 原子地对原子做加减操作,返回操作之前的值
+= / -= 同上
fetch_and/fetch_or/fetch_xor 原子地对原子对象做与/或/异或地操作,返回操作之前的值
&= / |= / ^= 同上

另外,atomic类型的函数可以指定memory_order参数,用于约束atomic类型数据在多线程中的视图。感兴趣可以看这篇文章:https://zhuanlan.zhihu.com/p/31386431

一般我们使用默认的memory_order就已经足够了。

之后我们再介绍一个复杂但十分有用的原子操作:CAS(Compare And Swap)

看名字就知道,他的作用是,比较两个值,如果相同就交换。

百度上给了一个比较直观的解释:

  • compare and swap,解决多线程并行情况下使用锁造成性能损耗的一种机制,CAS操作包含三个操作数——内存位置(V)、预期原值(A)和新值(B)。如果内存位置的值与预期原值相匹配,那么处理器会自动将该位置值更新为新值。否则,处理器不做任何操作。无论哪种情况,它都会在CAS指令之前返回该位置的值。CAS有效地说明了“我认为位置V应该包含值A;如果包含该值,则将B放到这个位置;否则,不要更改该位置,只告诉我这个位置现在的值即可。

通过CAS操作,我们可以方便的实现无锁的线程安全队列:

#include <atomic>

class Node {
public:
    Node(int val): val_(val), next_(nullptr) {}
public:
    int val_;
    class Node *next_;
};

void push(std::atomic<Node *> &head, Node *new_node) {
    new_node->next_ = head;
    while (head.compare_exchange_weak(new_node->next_, new_node));
}

int main() {
    std::atomic<Node *> head;
    Node *new_node = new Node(100);
    push(head, new_node);
}

当我们插入一个节点的时候,首先尝试加入它,也就是new_node->next_ = head; 然后如果head没有变化的话,那么就更新head为我们新的节点,如果变化的话就不断重试。也就是while (head.compare_exchange_weak(new_node->next_, new_node)); 的逻辑。

上面这个例子是所有的CAS介绍都会说到的,可以非常容易地帮助我们理解CAS地功能,但是对于POP操作,并不好实现。

另外其实还存在一个ABA地问题,需要解决。这里就不展开了。感兴趣地可以搜一下相关的资料,这里仅做简单地介绍。

其他

最后我们看几个非常有意思地设计。

PingPang Buffer

PingPang Buffer也被称为双Buffer。它的核心是这样地,由于一些系统配置需要不断地更新,而更新地过程中也会被不断地读取。如果使用之前的读写锁,就可能永远都更新不了(读线程一直占着锁),同时线程同步也是非常低效地一个过程。然后就诞生了PingPang Buffer这么个结构。

它的核心是有两块内存,一块用来给所有线程进行读操作,另一块用来给写线程进行更新,在更新完毕之后,交换这两个内存。新的内存变成了读内存,旧内存变成了写内存。

以下是一个简单的实现,和网上的其他版本可能略有不同,看思路即可。

#include <atomic>
#include <memory>
#include <mutex>

template<typename T>
class PingPangBuffer {
public:
    PingPangBuffer(std::shared_ptr<T> read_buffer, std::shared_ptr<T> write_buffer) {
        data_[0] = read_buffer;
        data_[1] = write_buffer;
        read_idx_ = 0;
    }
    std::shared_ptr<T> read_data() {
        std::lock_guard<std::mutex> lock(mutex_);
        return data_[read_idx_];
    }
    std::shared_ptr<T> write_data() {
        int write_idx = 1 - read_idx_;
        while (data_[write_idx].use_count() > 1) {
            // sleep 1s
            continue;
        }
        return data_[write_idx];
    }
    bool update() {
        std::lock_guard<std::mutex> lock(mutex_);
        read_idx_ = 1 - read_idx_;
    }
private:
    std::shared_ptr<T> data_[2];
    int read_idx_;
    std::mutex mutex_;
};

这里read_data函数被多个读线程去调用。而write_dataupdate只有一个写线程进行调用。

使用一个read_idx_记录读的Buffer的下标,那么交换读写Buffer的操作就可以简化为read_idx_ = 1 - read_idx_ 。不过下标切换之后,切换之前的读线程还在读旧数据。

而获取写数据的操作需要等待当前Buffer不再被使用了才可以再次被使用(反正早晚它都是可以被使用的),这里就直接使用了shared_ptruse_count

线程安全的LRUCache

一般Cache是使用std::unordered_map来实现的。和前面的读写锁类似,map支持多线程的读,但是仅支持单线程写入。这就会造成这个map的写入性能可能会较差。因此这里一般采用分shard的方式进行库的拆分。

一个简单的实现,先根据key分shard,然后每个分片都使用读写锁。(多线程的测试不太好写,这里只测试了过期时间和容量)

#include <mutex>
#include <thread>
#include <iostream>
#include <chrono>
#include <vector>
#include <list>
#include <unordered_map>

// 读写锁,就是前面原封不动的代码
class ReadWriteLock {
public:
    ReadWriteLock():reader_count_(0) {}
    void lock_read() {
        read_mutex_.lock();
        if (reader_count_ == 0) {
            write_mutex_.lock();
        }
        ++ reader_count_;
        read_mutex_.unlock();
    }
    void unlock_read() {
        read_mutex_.lock();
        -- reader_count_;
        if (reader_count_ == 0) {
            write_mutex_.unlock();
        }
        read_mutex_.unlock();
    }
    void lock_write() {
        write_mutex_.lock();
    }
    void unlock_write() {
        write_mutex_.unlock();
    }
private:
    std::mutex read_mutex_;
    std::mutex write_mutex_;
    int64_t reader_count_;
};

template<typename KeyType, typename ValType>
class ConcurrentLRUCache {
public:
    class Node {
    public:
        Node(const KeyType& key, const ValType& val, size_t time_ms): key_(key), val_(val), time_ms_(time_ms) {}
        KeyType key_;
        ValType val_;
        size_t time_ms_;
    };
    using node_iter_type = typename std::list<Node>::iterator;
public:
    ConcurrentLRUCache(size_t capacity, size_t shard, size_t expire_time /* ms */) {
        capacity_ = capacity;
        shard_ = shard;
        capacity_per_cache_ = capacity_ / shard_;
        expire_time_ = expire_time;
        cache_shard_list_.resize(shard_);
        node_data_list_shard_list_.resize(shard_);
    }
    bool get(const KeyType& key, ValType& val) {
        auto &cache = cache_shard_list_[get_shard_idx(key)];
        rw_lock_.lock_read();
        bool ok = false;
        do {
            auto iter = cache.find(key);
            if (iter == cache.end()) { // not found
                break;
            }
            size_t cur_ms = get_cur_time_ms();
            size_t record_ms = iter->second->time_ms_;
            if (cur_ms - record_ms > expire_time_) { // found but expired
                break;
            }
            val = iter->second->val_;
            ok = true;
        } while (0);
        rw_lock_.unlock_read();
        return ok;
    }
    void set(const KeyType& key, ValType& val) {
        size_t shard_idx = get_shard_idx(key);
        auto &cache = cache_shard_list_[shard_idx];
        auto &data_list = node_data_list_shard_list_[shard_idx];
        rw_lock_.lock_write();

        do {
            // when found, del the older
            auto iter = cache.find(key);
            if (iter != cache.end()) {
                data_list.erase(iter->second);
                cache.erase(iter);
            }

            // when cache full, del the oldest
            while (cache.size() >= capacity_per_cache_) {
                cache.erase(data_list.front().key_);
                data_list.pop_front();
            }

            size_t cur_ms = get_cur_time_ms();
            data_list.emplace_back(key, val, cur_ms);
            cache[key] = --data_list.end();
        } while (0);

        rw_lock_.unlock_write();
    }
private:
    static size_t get_cur_time_ms() {
        return std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
    }
    size_t get_shard_idx(const KeyType& key) {
        static std::hash<KeyType> hash_func;
        return hash_func(key) % shard_;
    }

    ReadWriteLock rw_lock_;

    size_t capacity_;
    size_t shard_;
    size_t capacity_per_cache_;
    size_t expire_time_;

    std::vector<std::unordered_map<KeyType, node_iter_type>> cache_shard_list_;
    std::vector<std::list<Node>> node_data_list_shard_list_;
};

int main() {
    ConcurrentLRUCache<int, int> cache(20, 2, 1000 /* 1s */);
    for (int i = 0; i < 20; ++ i) {
        cache.set(i, i);
        std::cout << "set: (" << i << ", " << i << ") " << std::endl;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(500));
    for (int i = 20; i < 30; ++ i) {
        cache.set(i, i);
        std::cout << "set: (" << i << ", " << i << ") " << std::endl;
    }

    // 此时0-9已经被覆盖(容量),10-19已经过去500ms,20-29是最新时间

    for (int i = 0; i < 30; ++ i) {
        int data = -1;
        bool is_ok = cache.get(i, data);
		// 这里只有10-29被查到了
        std::cout << "get: (" << i << ", " << data << ") " << is_ok << std::endl;
    }

    // 总共过去800ms,10-29都没过期
	std::this_thread::sleep_for(std::chrono::milliseconds(300));
    for (int i = 0; i < 30; ++ i) {
        int data = -1;
        bool is_ok = cache.get(i, data);
        // 只有20-29被查到
        std::cout << "get: (" << i << ", " << data << ") " << is_ok << std::endl;
    }

    // 总共过去1100ms,20-29没过期
    std::this_thread::sleep_for(std::chrono::milliseconds(300));
    for (int i = 0; i < 30; ++ i) {
        int data = -1;
        bool is_ok = cache.get(i, data);
        // 20-29
        std::cout << "get: (" << i << ", " << data << ") " << is_ok << std::endl;
    }
}

写在最后

知识的总结一直是一件令人愉悦的事情,时隔1年多有一次捡起技术博客。