Rust中任意顺序的并行工作窃取

iaqfqrcu  于 2022-11-12  发布在  其他
关注(0)|答案(1)|浏览(133)

我正在尝试用Rust编写一个用于深度学习的并行数据加载器,任务是编写一个迭代器,它在底层执行以下操作
1.从磁盘读取文件并对它们应用一些计算量大的预处理,结果通常是一个数字数组(或多个)
1.将上一步的结果分组为大小为B的批处理,并对其进行“整理”-这通常意味着仅连接数组-计算量适中
1.产生步骤2的结果。
步骤1可能同时受到IO和计算的限制,这取决于网络延迟、文件大小和预处理的复杂性。它必须由许多工作线程并行运行。步骤2应该脱离主线程,但可能不需要工作线程池。步骤3发生在主线程上(暴露给Python)。
我用Rust编写它的原因是Python提供了两个选项:PyTorch附带了一个纯Python实现,基于multiprocessing,速度有点慢,但非常灵活(任意用户定义的数据预处理和批处理)和Tensorflow附带的C++实现,后者的速度要快得多,但对于我希望进行的数据处理来说限制太大。我希望Rust能给予我Tensorflow的速度和PyTorch中任意代码的灵活性。
我的问题纯粹是关于如何实现并行性的。理想的设置是有N个worker执行步骤1)-〉channel -〉worker执行步骤2)-〉channel -〉step 3。因为迭代器对象可能随时被丢弃,所以有一个严格的要求,即能够在Drop之后终止整个方案。另一方面,存在以任意顺序加载文件的灵活性:例如,如果批量大小为B == 16max_n_threads == 32,则启动32个工作进程并生成包含16个首先返回的示例的第一个批量是非常好的。
我的简单实现通过3个步骤创建DataLoader
1.创建n_working: Arc<AtomicUsize>以控制活动工作线程的数量,创建should_shutdown: Arc<AtomicBool>以发出关闭信号(当调用Drop时)
1.创建一个负责维护池的线程。它在n_working < max_n_threads上旋转,并继续繁殖在should_shutdown上终止的工作线程,否则获取一个示例,将其发送到工作线程-〉批处理器通道,并递减n_working
1.创建一个批处理线程,它轮询worker-〉batcher通道,在接收到B对象后,将它们连接成一个批处理,并向下发送batcher-〉yielder通道


# [pyclass]

struct DataLoader {
    collate_worker: Option<thread::JoinHandle<()>>,
    example_worker: Option<thread::JoinHandle<()>>,
    should_shut_down: Arc<AtomicBool>,
    receiver: Receiver<Batch>,
    length: usize,
}

impl DataLoader {
    fn new(
        dataset: Dataset,
        batch_size: usize,
        capacity: usize,
    ) -> Self {
        let n_batches = dataset.len() / batch_size;

        let max_n_threads = capacity * batch_size;
        let (example_sender, collate_receiver) = bounded((batch_size - 1) * capacity);

        let should_shut_down = Arc::new(AtomicBool::new(false));

        let shutdown_flag = should_shut_down.clone();
        let example_worker = thread::spawn(move || {
            rayon::scope_fifo(|s| {
                let dataset = &dataset;
                let n_working = Arc::new(AtomicUsize::new(0));
                let mut current_index = 0;

                while current_index < n_batches * batch_size {
                    if n_working.load(Ordering::Relaxed) == max_n_threads {
                        continue;
                    }
                    if shutdown_flag.load(Ordering::Relaxed) {
                        break;
                    }

                    let index = current_index.clone();
                    let sender = example_sender.clone();
                    let counter = n_working.clone();
                    let shutdown_flag = shutdown_flag.clone();
                    s.spawn_fifo(move |_s| {
                        let example = dataset.get_example(index);
                        if !shutdown_flag.load(Ordering::Relaxed) {
                            _ = sender.send(example);
                        } // if we should shut down, skip sending
                        counter.fetch_sub(1, Ordering::Relaxed);
                    });

                    current_index += 1;
                    n_working.fetch_add(1, Ordering::Relaxed);
                };

            });
        });

        let (batch_sender, final_receiver) = bounded(capacity);

        let shutdown_flag = should_shut_down.clone();
        let collate_worker = thread::spawn(move || {
            'outer: loop {
                let mut batch = vec![];
                for _ in 0..batch_size {
                    if let Ok(example) = collate_receiver.recv() {
                        batch.push(example);
                    } else {
                        break 'outer;
                    }
                };
                let collated = collate(batch);
                if shutdown_flag.load(Ordering::Relaxed) {
                    break; // skip sending
                }
                _ = batch_sender.send(collated);
            };
        });

        Self {
            collate_worker: Some(collate_worker),
            example_worker: Some(example_worker),
            should_shut_down: should_shut_down,
            receiver: final_receiver,
            length: n_batches,
        }
    }
}

# [pymethods]

impl DataLoader {
    fn __iter__(slf: PyRef<Self>) -> PyRef<Self> { slf }

    fn __next__(&mut self) -> Option<Batch> {
        self.receiver.recv().ok() 
    }

    fn __len__(&self) -> usize {
        self.length
    }
}

impl Drop for DataLoader {
    fn drop(&mut self) {
        self.should_shut_down.store(true, Ordering::Relaxed);
        if self.collate_worker.take().unwrap().join().is_err() {
            println!("Panic in collate worker");
        };
        if self.example_worker.take().unwrap().join().is_err() {
            println!("Panic in example_worker");
        };

        println!("dropped the dataloader");
    }
}

这个实现工作正常,性能与PyTorch大致相当,但没有提供显著的加速比。但我认为,如果能够以窃取工作的方式自动实现负载平衡,并根据IO和计算时间的比例灵活地产生工作线程,将会有所帮助。我还预计到由于旋转池管理器和处理Drop时可能出现的极端情况而导致的性能问题。
我的问题是如何最好地解决这个问题,我通常不确定是否应该用类似rayon的并行板条箱、类似tokio的异步板条箱或者两者的混合。我也有一种预感,如果正确使用它们的组合子/高阶API,我的实现会简单得多。我尝试使用rayon,但我无法得到一个不包含“Don“不要浪费地执行原始的顺序返回顺序,并遵守Drop要求。

vohkndzv

vohkndzv1#

好的,我想我已经为你想出了一个解决方案,它使用了人造丝并行迭代器。
技巧是在rayon迭代器中使用Results,如果设置了取消标志,则返回Err
我首先创建了一个实用程序类型来创建一个可取消的线程,你可以在其中执行人造丝迭代器。你可以通过传入线程闭包来使用它,线程闭包将原子取消令牌作为参数。然后你必须检查取消令牌是否为true,如果是,就提前退出。

use std::sync::Arc;
use std::sync::atomic::{Ordering, AtomicBool};
use std::thread::JoinHandle;

fn collate(batch: &[Computed]) -> Batch {
    batch.iter().map(|&x| i128::from(x)).sum()
}

# [derive(Debug)]

struct Cancelled;

struct CancellableThread<Output: Send + 'static> {
    cancel_token: Arc<AtomicBool>,
    thread: Option<JoinHandle<Result<Output, Cancelled>>>,
}
impl<Output: Send + 'static> CancellableThread<Output> {
    fn new<F: FnOnce(Arc<AtomicBool>) -> Result<Output, Cancelled> + Send + 'static>(init: F) -> Self {
        let cancel_token = Arc::new(AtomicBool::new(false));
        let thread_cancel_token = Arc::clone(&cancel_token);

        CancellableThread {
            thread: Some(std::thread::spawn(move || init(thread_cancel_token))),
            cancel_token,
        }
    }

    fn output(mut self) -> Output {
        self.thread.take().unwrap().join().unwrap().unwrap()
    }
}
impl<Output: Send + 'static> Drop for CancellableThread<Output> {
    fn drop(&mut self) {
        self.cancel_token.store(true, Ordering::Relaxed);

        if let Some(thread) = self.thread.take() {
            let _ = thread.join().unwrap();
        }
    }
}

我发现创建一个返回Result<(), Cancelled>的闭包很有用,这样我就可以使用try操作符(?)提前退出。

CancellableThread::new(move |cancel_token| {
    let cancelled = || if cancel_token.load(Ordering::Relaxed) {
        Err(Cancelled)
    } else {
        Ok(())
    };

    loop {
        // was the thread dropped?
        // if so, stop what we're doing
        cancelled?;

        // do stuff and 
        // eventually return a result
    }
});

然后,我在DataLoader中使用了CancellableThread抽象。不需要为它创建特殊的Drop impl,因为默认情况下,它将在每个字段上调用drop,这将处理取消操作。

type Data = Vec<u8>;
type Dataset = Vec<Data>;
type Computed = u64;
type Batch = i128;

use rayon::prelude::*;
use crossbeam::channel::{unbounded, Receiver};

struct DataLoader {
    example_worker: CancellableThread<()>,
    collate_worker: CancellableThread<()>,
    receiver: Receiver<Batch>,
    length: usize,
}

我用的是unbounded频道,因为它少了一件麻烦的事情。换用bounded频道应该不难。

impl DataLoader {
    fn new(dataset: Dataset, batch_size: usize) -> Self {
        let (example_sender, collate_receiver) = unbounded();
        let (batch_sender, final_receiver) = unbounded();

我不确定您是否总能保证数据集中的项数是batch_size的倍数,所以我决定显式地处理它。

let length = if dataset.len() % batch_size == 0 {
            dataset.len() / batch_size
        } else {
            dataset.len() / batch_size + 1
        };

我首先创建了排序工作者,尽管这可能不是必需的。正如您所看到的,我不得不复制一点来处理部分批处理。

let collate_worker = CancellableThread::new(move |cancel_token| {
            let cancelled = || if cancel_token.load(Ordering::Relaxed) {
                Err(Cancelled)
            } else {
                Ok(())
            };

            'outer: loop {
                let mut batch = Vec::with_capacity(batch_size);
                for _ in 0..batch_size {
                    cancelled()?;

                    if let Ok(data) = collate_receiver.recv() {
                        batch.push(data);
                    } else {
                        if !batch.is_empty() {
                            // handle the last batch, if there
                            // weren't enough items to fill it
                            let collated = collate(&batch);
                            cancelled()?;
                            batch_sender.send(collated).unwrap();
                        }

                        break 'outer;
                    }
                }

                let collated = collate(&batch);
                cancelled()?;
                batch_sender.send(collated).unwrap();
            }

            Ok(())
        });

在示例worker中,事情变得非常简单,因为我们可以只使用rayon并行迭代器。正如你所看到的,我们在每次繁重的计算之前检查取消。

let example_worker = CancellableThread::new(move |cancel_token| {
            let cancelled = || if cancel_token.load(Ordering::Relaxed) {
                Err(Cancelled)
            } else {
                Ok(())
            };

            let heavy_compute = |data: Data| -> Result<Computed, Cancelled> {
                cancelled()?;

                Ok(data.iter().map(|&x| u64::from(x)).product())
            };

            dataset
                .into_par_iter()
                .map(heavy_compute)
                .try_for_each(|computed| {
                    example_sender.send(computed?).unwrap();

                    Ok(())
                })
        });

然后我们构造DataLoader。你可以看到Python的实现是相同的:

DataLoader {
            example_worker,
            collate_worker,
            receiver: final_receiver,
            length,
        }
    }
}

// #[pymethods]
impl DataLoader {
    fn __iter__(this: Self /* PyRef<Self> */) -> Self /* PyRef<Self> */ { this }

    fn __next__(&mut self) -> Option<Batch> {
        self.receiver.recv().ok() 
    }

    fn __len__(&self) -> usize {
        self.length
    }
}

playground

相关问题