rust 如何将&mut ndarray::Array传递给函数并使用它执行元素算术?

j13ufse2  于 2023-06-06  发布在  其他
关注(0)|答案(2)|浏览(215)

动机

我正在尝试用Rust为一个学校项目创建我的第一个真实的的交易程序(不是必需的..我只是被Rust迷住了,并决定冒险一试)。
该项目是一个简单的模拟机器人的决策基于一些传感器数据,一些概率,预测未来的回报,和一些其他的东西。该程序由一个主循环组成,其中在未来的某个时间范围内,每个时间步都会发生大量的数学运算。被携带到每个后续时间步的数据由矩阵 Y 表示,矩阵 * Y * 由一组线性约束的两列线性系数(其在每个时间步被修改)组成(其中更多约束/系数行在每个时间步被添加到该组)。
由于该程序需要大量的元素矩阵运算,而我在NumPy方面经验丰富,因此ndarray crate似乎非常适合这项工作。我对这个程序的想法是为 Y 创建一个可变的2D数组,它会随着每次循环迭代而修改,而不是每次都分配一个新数组。从那以后,我开始意识到,每次迭代的行数也会以未知的数量增长,所以也许这种方法不是最好的想法,但我对错误的问题仍然存在。

提问

我的问题是:如果我想在循环的每次迭代中修改一个数组,方法是将数组的引用传递给几个函数,这些函数将修改数组的数据,那么我如何在基本的元素算术运算中使用同一个数组?
下面是我的代码的一个基本示例:

extern crate ndarray;

use ndarray::prelude::*;

fn main() {
    let pz = array![[0.7, 0.3], [0.3, 0.7]]; // measurement probabilities

    let mut Y = Array2::<f64>::zeros((1, 2));

    for i in 1..10 {
        do_some_maths(&mut Y, pz);
        // other functions that will modify Y
    }
    
    println!("Result: {}", Y);
}

fn do_some_maths(Y: &mut Array2<f64>, pz: Array2<f64>) {

    let Yp = Y * pz.slice(s![.., 0]);  // <-- this is the problem

    // do lots of matrix math with Yp
    // ...
    // then modify Y's data using Yp (hence Y needs to be &mut)
}

这会产生以下编译错误:

error[E0369]: binary operation `*` cannot be applied to type `&mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>`
  --> src/main2.rs:21:16
   |
21 |     let Yp = Y * pz.slice(s![.., 0]);  // <-- this is the problem
   |              - ^ ------------------- ndarray::ArrayBase<ndarray::ViewRepr<&f64>, _>
   |              |
   |              &mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>
   |
   = note: an implementation of `std::ops::Mul` might be missing for `&mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>`

我花了很多时间试图理解
1.我的用例的正确方法是什么,以及
1.为什么我写的代码不起作用
我在这个网站上读到了几个有点相关的问题,但没有一个真正涉及到将数组引用作为函数参数处理并对其执行二进制操作的情况。
我已经努力学习了Rust这本书的前5章,并深入研究了ndarray的文档,但仍然找不到答案。ndarrayArrayBase文档包含以下解释,我不完全理解:

两个数组的二进制操作

设A是一个数组或任意类型的视图。设B为拥有存储的阵列(Array或ArcArray)。假设C是一个具有可变数据的数组(Array、ArcArray或ArrayViewMut)。由@表示的任意二元运算符支持以下操作数组合(可以是+、-、*、/等)。

  • &A @ &A生成一个新数组
  • B @ A,它消耗B,用结果更新它,并返回它
  • B @ &A,它消耗B,用结果更新它,并返回它
  • C @= &A,它在适当的位置执行算术运算

给出这个描述,并搜索AddMul等的许多trait实现,在我看来,**可变的ndarray::Array不能是二进制操作中的操作数,除非在复合赋值的情况下。
是真的吗,还是我漏掉了什么?我不想简单地记住这个小花絮,然后继续前进;我真的很想了解这里到底发生了什么,以及我的理解在哪里有所欠缺。请帮助我把我的C++/Python训练过的大脑围绕这个。:)

ubof19bj

ubof19bj1#

你已经回答了自己的问题:您尝试执行的乘法是&C @ B,它不是ndarray支持的四个乘法之一。另外,您将pz作为值传递给函数。它在循环的第一轮就被消耗掉了,在其余的循环中不再可用。所以也编译不了。
这个方法的作用是:

extern crate ndarray;
use ndarray::prelude::*;

fn main() {
    let pz = array![[0.7, 0.3], [0.3, 0.7]];
    let mut y = Array2::<f64>::zeros((1, 2));

    for _ in 1..10 {
        do_some_maths(&mut y, &pz);
    }

    println!("Result: {}", y);
}

fn do_some_maths(y: &mut Array2<f64>, pz: &Array2<f64>) {
    *y *= &pz.slice(s![.., 0]);
}
agxfikkp

agxfikkp2#

可变的引用比不可变的引用“更强大”,你总是可以让一个可变的引用充当不可变的引用,所以这不是问题。
正如edwardw所指出的,你可能不想在每个循环中都使用数组pz(编译器也不会让你这么做)。实际上,如果你考虑do_some_maths函数的签名,你会得到:

  • 要修改的可变数组
  • 一个不可变的,你用另外

因此,将签名设为:

fn do_some_maths(y: &mut Array2<f64>, pz: &Array2<f64>) {
   ...
}

现在,ndarray机箱可以让您:

  • 就地修改值或
  • 为您的运营创建新的

一般来说,它对它的输入是非常敏感的,只要有可能就接受引用,以免消耗你的输入数组。这意味着可能需要大量的引用(去引用),您可以随意使用。在numpy中,几乎所有的东西都是引用,所以你不必担心,但逻辑是一样的。
如果你想从Y创建Yp,你可以通过分配一个新的Yp值来实现:

// EDIT: does not compiles as of 2023
// fn do_some_maths(y: &mut Array2<f64>, pz: &Array2<f64>) {
//     // yp is a new Array2<f64>
//     let yp: Array2<f64> = y * pz;
//     // We may want to modify `y` now
//     y.scaled_add(-2.3, yp);
//     y *= pz;
// }
fn do_some_maths(y: &mut Array2<f64>, pz: &Array2<f64>) {
    // yp is a new Array2<f64>
    let yp: Array2<f64> = &*y * &*pz;
    // We may want to modify `y` now
    y.scaled_add(-2.3, &yp);
    *y *= pz;
}

这里进行的各种操作是:

  • &Array2 * &Array2 -> Array2
  • scaled_add(self:&mut Array 2,f64,&Array2)->(),就地修改数组
  • 就地标量运算&mut Array 2 *= &Array2

一般来说,尽可能多地使用引用(不管是不是可变的),除非你知道输入 * 应该 * 被使用。
为了阐明与numpy的相似之处:numpy数组本质上是 all 引用。Rust提供了直接传入值(因此被消耗-将它们视为使用一次,然后它们被销毁)或引用(可变或不可变,取决于您是否需要改变它们)的粒度。Numpy基本上在任何地方都使用可变引用(除非显式切换WRITEABLE标志)。

相关问题