// threadpool/src/lib.rs
// 2024-01-13 12:07:15 +08:00
//
// 105 lines
// 2.9 KiB
// Rust

use std::{
collections::LinkedList,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Mutex,
},
thread,
thread::JoinHandle,
time::Duration,
};
use log::info;
/// A fixed-size pool of worker threads, each serving its own task queue.
///
/// [`ThreadPool::push`] spreads tasks over the queues round-robin, and each
/// worker only takes tasks from its own queue. Dropping the pool joins every
/// worker thread.
pub struct ThreadPool {
    /// One join handle per worker. Wrapped in `Option` so that `Drop` can
    /// `take` each handle out and join it.
    thrd: Vec<Option<JoinHandle<()>>>,
    /// One FIFO task queue per worker; worker `id` serves `queue[id]`.
    /// The `Arc` shares the queues with the worker threads.
    queue: Arc<Vec<Mutex<LinkedList<Box<Task>>>>>,
    /// Round-robin counter `push` uses to pick the next queue (taken modulo
    /// the number of queues).
    index: Arc<AtomicUsize>,
    /// Number of worker threads that have started running; compared against
    /// the number of spawned workers by `started`.
    cnt: Arc<AtomicUsize>,
    /// Set to `true` to ask the workers to run what is left in their queues
    /// and exit.
    stop: Arc<AtomicBool>,
}
/// The unit of work queued by the pool: a one-shot closure that can be moved
/// to a worker thread. The `'static` bound is the alias's default object
/// lifetime, written out because queued tasks outlive the `push` call.
type Task = dyn FnOnce() + Send + 'static;
impl ThreadPool {
pub fn new(n: u8) -> Self {
let mut w = Vec::new();
for _ in 0..n {
w.push(Mutex::new(LinkedList::new()));
}
let mut r = ThreadPool {
thrd: vec![],
queue: Arc::new(w),
index: Arc::new(0.into()),
cnt: Arc::new(0.into()),
stop: Arc::new(false.into()),
};
for id in 0..n as usize {
let q = r.queue.clone();
let stop = r.stop.clone();
let cnt = r.cnt.clone();
r.thrd.push(Some(thread::spawn(move || {
cnt.fetch_add(1, Ordering::SeqCst);
let n = n as usize;
let lock = &q[id % n];
let mut l = LinkedList::new();
let process = |l: &mut LinkedList<Box<Task>>| {
while !l.is_empty() {
let t = l.pop_front().unwrap();
t();
}
};
while !stop.load(Ordering::SeqCst) {
info!("thread {}", id);
thread::sleep(Duration::from_micros(500));
{
let mut tmp = lock.lock().unwrap();
l.append(&mut tmp);
}
info!("thread {} working", id);
process(&mut l);
}
let mut l = lock.lock().unwrap();
process(&mut l);
info!("thread {} exit", id);
})));
}
r
}
pub fn push(&mut self, t: impl FnOnce() + Send + 'static) {
let idx = self.index.fetch_add(1, Ordering::SeqCst) % self.queue.len();
let lock = &self.queue[idx];
info!("add task to thread {}", idx);
let l = &mut lock.lock().unwrap();
l.push_back(Box::new(t));
info!("notify thread {}", idx);
}
pub fn started(&self) -> bool {
return self.cnt.load(Ordering::SeqCst) == self.thrd.len();
}
pub fn stop(&mut self) {
self.stop.store(true, Ordering::SeqCst);
let mut i = 0;
for lock in self.queue.as_ref() {
let _ = lock.lock();
info!("notify exit {}", i);
i += 1;
}
}
}
impl Drop for ThreadPool {
fn drop(&mut self) {
for task in &mut self.thrd {
if let Some(t) = task.take() {
let id = t.thread().id();
t.join().expect(&format!("can't join thrd {:?}", id));
}
}
}
}