请教大家一个 C++线程池的问题

14次阅读

共计 4025 个字符,预计需要花费 11 分钟才能阅读完成。

最近在找一个简单的 C++11 线程池实现,发现网上有很多相关的代码,在 CSDN 网上看到一个比较简洁的。但是总感觉是不是实现错了。

  1. Any 类 noncopyable 的,仅仅支持移动语义,
  2. Result 类使用了 Any 实例作为成员变量,那么 Result 类应该也是 noncopyable 的,
  3. Result SubmitTask(std::shared_ptr taskPtr); 直接使用了复制语义,应该是有问题吧,可是代码能够被 vs2022 正常编译。

threadpool.h

#pragma once
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 


// Any 类型:可以接收任意数据的类型
// 任意其他类型 template
// 能让一个类型指向其他类型,基类指针可以指向子类
class Any
{
public:
	Any() = default;
	~Any() = default;
	Any(const Any&) = delete;
	Any& operator=(const Any&) = delete;
	Any(Any&&) = default;
	Any& operator=(Any&&) = default;

	template
	Any(T data) : m_base(std::make_unique>(data)) {}

	template
	T cast_()
	{Derive* pd = dynamic_cast*>(m_base.get());

		if (pd == nullptr) {throw "type is unmath!!";}

		return pd->m_data;
	}

private:
	// 基类
	class Base
	{
	public:
		virtual ~Base() = default;};

	// 派生类
	template
	class Derive : public Base
	{
	public:
		Derive(T data) : m_data(data) {}
	public:
		T m_data;
	};

private:
	std::unique_ptr m_base;
};


// 实现一个信号量类
class Semaphore
{
public:
	Semaphore(int limit = 0) : m_resLimit(limit)
	{}

	~Semaphore() = default;

	// 获取一个信号量资源
	void wait()
	{std::unique_lock lock(m_mtx);
		// 如果没有资源,阻塞线程
		while (m_resLimit < 1) {m_cond.wait(lock);
		}

		m_resLimit--;
	}

	// 增加一个信号量资源
	void post()
	{std::unique_lock lock(m_mtx);
		m_resLimit++;
		m_cond.notify_all();}
private:
	int m_resLimit;  // 资源量
	std::mutex m_mtx;
	std::condition_variable m_cond;
};


// Task 类型前置声明
class Task;

// 实现接收提交到线程池的 task 任务执行完成后的返回值类型
class Result
{
public:
	Result(std::shared_ptr task, bool isValid = true);
	~Result() = default;

	// setVal
	void setVal(Any result);

	// get 方法,用户调用这个方法获取 task 的返回值
	Any get();
private:
	Any m_any;
	Semaphore m_sem;
	std::shared_ptr m_task;
	std::atomic_bool m_isValid;
};


// 任务抽象基类
class Task
{
public:
	void exec();
	void setResult(Result* res);
	virtual Any run() = 0;

private:
	Result* m_result{nullptr};  // 不要用智能指针,task 含有 Result  Result 含有 task,可能导致问题
};

class MyTask : public Task
{
public:
	MyTask(int start, int end) : m_start(start), m_end(end) {}

	Any run()
	{
		std::ostringstream ostr;
		ostr <;

	Thread(ThreadFunc func);
	~Thread();

	void Start();
	int GetId() { return m_threadId;}
private:
	ThreadFunc m_func;
	static int generateId;
	int m_threadId;
};


class ThreadPool
{
public:
	ThreadPool();
	~ThreadPool();

	// 设置线程池工作模式
	void SetMode(ThreadPoolMode mode);

	// 设置任务数量上限
	void SetTaskQueMaxThreshold(int value);

	// 给线程池提交任务
	Result SubmitTask(std::shared_ptr taskPtr);

	// 开启线程池
	void Start(int initThreadSize = std::thread::hardware_concurrency());

private:
	ThreadPool(const ThreadPool&) = delete;
	ThreadPool& operator=(const ThreadPool&) = delete;

	// 定义线程函数
	void ThreadFunc(int threadId);
	bool CheckRunningState() const;

private:
	std::unordered_map> m_threadMap;  // 线程列表
	int m_initThreadSize;  // 初始的线程数量
	std::atomic_int m_curThreadSize;  // 当前线程数量

	std::queue> m_taskQue;  // 任务队列
	std::atomic_int m_taskSize;  // 任务的数量
	int m_taskQueMaxThreshold;  // 任务队列的数量上限

	std::mutex m_taskQueMtx;  // 保证任务队列的线程安全
	std::condition_variable m_taskQueNotFullCv;  // 表示任务队列不满
	std::condition_variable m_taskQueNotEmptyCv;  // 表示任务队列不空
	std::condition_variable m_exitCv;  // 退出线程池

	ThreadPoolMode m_poolMode;  // 当前线程池的工作模式
	std::atomic_bool m_isPoolRuning;  // 当前线程工作状态
};

threadpool.cpp

#include "threadpool.h"
#include 
#include 

constexpr int TASK_MAX_THRESHOLD = 1024;

ThreadPool::ThreadPool() : m_initThreadSize(4), m_taskSize(0),
m_taskQueMaxThreshold(TASK_MAX_THRESHOLD),
m_poolMode(ThreadPoolMode::MODE_FIXED)
{
}

ThreadPool::~ThreadPool()
{
	m_isPoolRuning = false;
	std::unique_lock lock(m_taskQueMtx);

	// 线程 要么在阻塞中 要么在工作中
	while (m_threadMap.size() > 0) {m_taskQueNotEmptyCv.notify_all();  // 唤醒等待的工作线程
		m_exitCv.wait(lock);
	}
}

void ThreadPool::SetMode(ThreadPoolMode mode)
{if (m_isPoolRuning) {return;}  // 线程池启动后,不允许设置线程池一些参数

	m_poolMode = mode;
}

void ThreadPool::SetTaskQueMaxThreshold(int value)
{if (m_isPoolRuning) {return;}

	m_taskQueMaxThreshold = value;
}

Result ThreadPool::SubmitTask(std::shared_ptr taskPtr)
{
	// 获取锁
	std::unique_lock lock(m_taskQueMtx);

	// 线程通信,检查任务队列是否有空余
	while (m_taskQue.size() >= m_taskQueMaxThreshold) {

		// 用于提交任务,不能阻塞太长时间,如果超过 1s,给用户返回提交失败
		if (m_taskQueNotFullCv.wait_for(lock, std::chrono::seconds(1)) == std::cv_status::timeout) {return Result(taskPtr, false);
		}
	}

	// 如果有空余,把任务提交到任务队列中
	m_taskQue.emplace(taskPtr);
	m_taskSize++;

	// 因为新放了任务,任务队列肯定不为空了,在 m_taskQueNotEmptyCv 进行通知,赶快分配线程执行这个任务
	m_taskQueNotEmptyCv.notify_all();

	return Result(taskPtr);
}

void ThreadPool::Start(int initThreadSize)
{
	m_initThreadSize = initThreadSize;
	m_curThreadSize = initThreadSize;
    m_isPoolRuning = true;

	// 创建线程对象
	for (int i = 0; i < m_initThreadSize; i++) {auto ptr = std::make_unique(std::bind(&ThreadPool::ThreadFunc, this, std::placeholders::_1));
		int threadId = ptr->GetId();
		m_threadMap.emplace(threadId, std::move(ptr));
	}

	// 启动所有线程
	for (auto iter = m_threadMap.cbegin(); iter != m_threadMap.end(); iter++) {iter->second->Start();}
}

void ThreadPool::ThreadFunc(int threadId)
{while (true) {

		// 获取锁
		std::unique_lock lock(m_taskQueMtx);

		std::ostringstream ostr;
		ostr < 0) {m_taskQueNotEmptyCv.notify_all();
		}

		// 通知队列已经不满
		m_taskQueNotFullCv.notify_all();

		taskPtr->exec();

		if (!m_isPoolRuning) {m_threadMap.erase(threadId);
			m_exitCv.notify_all();

			printf("deconstructor thread exit, id = %dn", threadId);
			return;
		}

	}
}

bool ThreadPool::CheckRunningState() const
{if (m_isPoolRuning) {return true;}

	return false;
}

// 线程方法
int Thread::generateId = 0;

Thread::Thread(ThreadFunc func) : m_func(func),
								m_threadId(generateId++)
{
}

Thread::~Thread()
{
}

void Thread::Start()
{std::thread t(m_func, m_threadId);
	t.detach();}

Result::Result(std::shared_ptr task, bool isValid) : m_task(task), m_isValid(isValid)
{m_task->setResult(this);
}

void Result::setVal(Any result)
{m_any = std::move(result);
	m_sem.post();  // 通知已经获得结果}

Any Result::get()
{if (!m_isValid) {return "";}

	m_sem.wait();  // 等待结果
	return std::move(m_any);
}


void Task::exec()
{if (m_result != nullptr) {Any result = run();  // 这里发生多态调用

		m_result->setVal(std::move(result));
	}
}

void Task::setResult(Result* res)
{m_result = res;}

main.cpp

#include "threadpool.h"

#include 
#include 

using std::cout;
using std::endl;


int main(int argc, char* argv[])
{
	{
		ThreadPool pool;
		pool.Start(4);

		Result res1 = pool.SubmitTask(std::make_shared(1, 100000000));
		Result res2 = pool.SubmitTask(std::make_shared(100000001, 200000000));
		Result res3 = pool.SubmitTask(std::make_shared(200000001, 300000000));

		//uint64_t sum1 = res1.get().cast_();
		//uint64_t sum2 = res2.get().cast_();
		//uint64_t sum3 = res3.get().cast_();

		//cout <<(sum1 + sum2 + sum3) <
正文完
 0