Rust并行获取ndarray中每个元素的可变引用

xlpyo6sf  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(122)

我正在Rust中编写一个并行矩阵乘法代码,我想并行计算乘积的每个元素。我使用ndarray s来存储我的数据。因此,我的代码将是一些单独的行

fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
   let N = lhs.raw_size()[0];
   let M = rhs.raw_size()[1];
   let mut result = Array2::zeros((N,M));
   
   range_2d(0..N,0..M).par_iter().map(|(i, j)| {
      // load the result for the (i,j) element into 'result'
   }).count();

   result
}

字符串
有没有办法做到这一点?

70gysomp

70gysomp1#

你可以用这种方式创建一个并行迭代器:

use rayon::prelude::*;

pub fn mul(lhs: &Array2<f32>, rhs: &Array2<f32>) -> Array2<f32> {
    let n = lhs.raw_dim()[0];
    let m = rhs.raw_dim()[1];
    let mut result = Array2::zeros((n, m));

    result
        .axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .flat_map(|(n, axis)| {
            axis.into_slice()
                .unwrap()
                .par_iter_mut()
                .enumerate()
                .map(move |(m, item)| (n, m, item))
        })
        .for_each(|(n, m, item)| {
            // Do the multiplication.
            *item = n as f32 * m as f32;
        });

    result
}

字符串

相关问题