我正在尝试用Rust编写一个用于深度学习的并行数据加载器,任务是编写一个迭代器,它在底层执行以下操作
1.从磁盘读取文件并对它们应用一些计算量大的预处理,结果通常是一个数字数组(或多个)
1.将上一步的结果分组为大小为B
的批处理,并对其进行“整理”-这通常意味着仅连接数组-计算量适中
1.产生步骤2的结果。
步骤1可能同时受到IO和计算的限制,这取决于网络延迟、文件大小和预处理的复杂性。它必须由许多工作线程并行运行。步骤2应该脱离主线程,但可能不需要工作线程池。步骤3发生在主线程上(暴露给Python)。
我用Rust编写它的原因是Python提供了两个选项:PyTorch附带了一个纯Python实现,基于multiprocessing
,速度有点慢,但非常灵活(任意用户定义的数据预处理和批处理)和Tensorflow附带的C++实现,后者的速度要快得多,但对于我希望进行的数据处理来说限制太大。我希望Rust能给予我Tensorflow的速度和PyTorch中任意代码的灵活性。
我的问题纯粹是关于如何实现并行性的。理想的设置是有N个worker执行步骤1)-〉channel -〉worker执行步骤2)-〉channel -〉step 3。因为迭代器对象可能随时被丢弃,所以有一个严格的要求,即能够在Drop
之后终止整个方案。另一方面,存在以任意顺序加载文件的灵活性:例如,如果批量大小为B == 16
和max_n_threads == 32
,则启动32个工作进程并生成包含16个首先返回的示例的第一个批量是非常好的。
我的简单实现通过3个步骤创建DataLoader
:
1.创建n_working: Arc<AtomicUsize>
以控制活动工作线程的数量,创建should_shutdown: Arc<AtomicBool>
以发出关闭信号(当调用Drop
时)
1.创建一个负责维护池的线程。它在n_working < max_n_threads
上旋转,并继续繁殖在should_shutdown
上终止的工作线程,否则获取一个示例,将其发送到工作线程-〉批处理器通道,并递减n_working
1.创建一个批处理线程,它轮询worker-〉batcher通道,在接收到B
对象后,将它们连接成一个批处理,并向下发送batcher-〉yielder通道
# [pyclass]
struct DataLoader {
collate_worker: Option<thread::JoinHandle<()>>,
example_worker: Option<thread::JoinHandle<()>>,
should_shut_down: Arc<AtomicBool>,
receiver: Receiver<Batch>,
length: usize,
}
impl DataLoader {
fn new(
dataset: Dataset,
batch_size: usize,
capacity: usize,
) -> Self {
let n_batches = dataset.len() / batch_size;
let max_n_threads = capacity * batch_size;
let (example_sender, collate_receiver) = bounded((batch_size - 1) * capacity);
let should_shut_down = Arc::new(AtomicBool::new(false));
let shutdown_flag = should_shut_down.clone();
let example_worker = thread::spawn(move || {
rayon::scope_fifo(|s| {
let dataset = &dataset;
let n_working = Arc::new(AtomicUsize::new(0));
let mut current_index = 0;
while current_index < n_batches * batch_size {
if n_working.load(Ordering::Relaxed) == max_n_threads {
continue;
}
if shutdown_flag.load(Ordering::Relaxed) {
break;
}
let index = current_index.clone();
let sender = example_sender.clone();
let counter = n_working.clone();
let shutdown_flag = shutdown_flag.clone();
s.spawn_fifo(move |_s| {
let example = dataset.get_example(index);
if !shutdown_flag.load(Ordering::Relaxed) {
_ = sender.send(example);
} // if we should shut down, skip sending
counter.fetch_sub(1, Ordering::Relaxed);
});
current_index += 1;
n_working.fetch_add(1, Ordering::Relaxed);
};
});
});
let (batch_sender, final_receiver) = bounded(capacity);
let shutdown_flag = should_shut_down.clone();
let collate_worker = thread::spawn(move || {
'outer: loop {
let mut batch = vec![];
for _ in 0..batch_size {
if let Ok(example) = collate_receiver.recv() {
batch.push(example);
} else {
break 'outer;
}
};
let collated = collate(batch);
if shutdown_flag.load(Ordering::Relaxed) {
break; // skip sending
}
_ = batch_sender.send(collated);
};
});
Self {
collate_worker: Some(collate_worker),
example_worker: Some(example_worker),
should_shut_down: should_shut_down,
receiver: final_receiver,
length: n_batches,
}
}
}
# [pymethods]
impl DataLoader {
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> { slf }
fn __next__(&mut self) -> Option<Batch> {
self.receiver.recv().ok()
}
fn __len__(&self) -> usize {
self.length
}
}
impl Drop for DataLoader {
fn drop(&mut self) {
self.should_shut_down.store(true, Ordering::Relaxed);
if self.collate_worker.take().unwrap().join().is_err() {
println!("Panic in collate worker");
};
if self.example_worker.take().unwrap().join().is_err() {
println!("Panic in example_worker");
};
println!("dropped the dataloader");
}
}
这个实现工作正常,性能与PyTorch大致相当,但没有提供显著的加速比。但我认为,如果能够以窃取工作的方式自动实现负载平衡,并根据IO和计算时间的比例灵活地产生工作线程,将会有所帮助。我还预计到由于旋转池管理器和处理Drop
时可能出现的极端情况而导致的性能问题。
我的问题是如何最好地解决这个问题,我通常不确定是否应该用类似rayon
的并行板条箱、类似tokio
的异步板条箱或者两者的混合。我也有一种预感,如果正确使用它们的组合子/高阶API,我的实现会简单得多。我尝试使用rayon
,但我无法得到一个不包含“Don“不要浪费地执行原始的顺序返回顺序,并遵守Drop
要求。
1条答案
按热度按时间vohkndzv1#
好的,我想我已经为你想出了一个解决方案,它使用了人造丝并行迭代器。
技巧是在rayon迭代器中使用
Results
,如果设置了取消标志,则返回Err
。我首先创建了一个实用程序类型来创建一个可取消的线程,你可以在其中执行人造丝迭代器。你可以通过传入线程闭包来使用它,线程闭包将原子取消令牌作为参数。然后你必须检查取消令牌是否为
true
,如果是,就提前退出。我发现创建一个返回
Result<(), Cancelled>
的闭包很有用,这样我就可以使用try操作符(?
)提前退出。然后,我在
DataLoader
中使用了CancellableThread
抽象。不需要为它创建特殊的Drop
impl,因为默认情况下,它将在每个字段上调用drop
,这将处理取消操作。我用的是
unbounded
频道,因为它少了一件麻烦的事情。换用bounded
频道应该不难。我不确定您是否总能保证数据集中的项数是
batch_size
的倍数,所以我决定显式地处理它。我首先创建了排序工作者,尽管这可能不是必需的。正如您所看到的,我不得不复制一点来处理部分批处理。
在示例worker中,事情变得非常简单,因为我们可以只使用rayon并行迭代器。正如你所看到的,我们在每次繁重的计算之前检查取消。
然后我们构造
DataLoader
。你可以看到Python的实现是相同的:playground