rust 如何简明扼要地比较ndarray的形状?

ozxc1zmp  于 2023-01-02  发布在  其他
关注(0)|答案(1)|浏览(118)

我刚到拉斯特。
假设矩阵a具有(n1, n2)的形状,b具有(m1, m2)c具有(k1, k2),我想检查ab是否可以相乘(作为矩阵),a * b的形状是否等于c,换句话说,(n2 == m1) && (n1 == k1) && (m2 == k2)

use ndarray::Array2;

// a : Array2<i64>
// b : Array2<i64>
// c : Array2<i64>

.shape方法将数组的形状作为切片返回。什么是简洁的方法?
.shape()返回的数组是否保证长度为2,或者我是否应该检查它?如果保证,是否有办法跳过None检查?

let n1 = a.shape().get(0);  // this is Optional<i64>
c9x0cxw0

c9x0cxw01#

对于Array2,有.ncols()和.nrows()方法,如果你只使用二维数组,那么这可能是最好的选择,它们返回usize,所以不需要None检查。

use ndarray::prelude::*;

fn is_valid_matmul(a: &Array2<i64>, b: &Array2<i64>, c: &Array2<i64>) -> bool {
    //nrows() and ncols() are only valid for Array2, 
    //[arr.nrows(), arr.ncols()] = [arr.shape()[0], arr.shape()[1]]
    return a.ncols() == b.nrows() && b.ncols() == c.ncols() && a.nrows() == c.nrows();
}
fn main() {
    let a = Array2::<i64>::zeros((3, 5));
    let b = Array2::<i64>::zeros((5, 6));
    let c_valid = Array2::<i64>::zeros((3, 6));
    let c_invalid = Array2::<i64>::zeros((8, 6));

    println!("is_valid_matmul(&a, &b, &c_valid) = {}", is_valid_matmul(&a, &b, &c_valid));
    println!("is_valid_matmul(&a, &b, &c_invalid) = {}", is_valid_matmul(&a, &b, &c_invalid));
}
/*
output:
is_valid_matmul(&a, &b, &c_valid) = true
is_valid_matmul(&a, &b, &c_invalid) = false
*/

相关问题