Error when using a nalgebra DMatrix with smartcore's train_test_split in Rust

Posted by zqdjd7g9 on 2023-10-20 in Other

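I'm trying to build a small machine-learning application. I read the data from a CSV file and convert it into a `DMatrix` from the nalgebra crate. To split the dataset into training and test subsets, I want to use smartcore's `train_test_split` function.
When I call this function with the `DMatrix` built from the CSV, compilation fails. Can you tell me why these errors happen and how to fix them?
Here is the code: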

```rust
use std::error::Error;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::str::FromStr;

use nalgebra::DMatrix;
use smartcore::model_selection::train_test_split;

// Reads a CSV file (first line is a header) into a row-major DMatrix<f64>.
fn read_csv(input: &mut dyn BufRead) -> Result<DMatrix<f64>, Box<dyn Error>> {
    let mut samples = Vec::new();
    let mut rows = 0;

    for line in input.lines().skip(1) {
        rows += 1;

        for data in line?.split_terminator(",") {
            match f64::from_str(data.trim()) {
                Ok(value) => samples.push(value),
                Err(e) => println!("Error parsing data in row {}: {}", rows, e),
            }
        }
    }

    let cols = samples.len() / rows;

    Ok(DMatrix::from_row_slice(rows, cols, &samples[..]))
}

fn main() -> Result<(), Box<dyn Error>> {
    // Load CSV
    let file = File::open("dataset/heart.csv")?;
    let data: DMatrix<f64> = read_csv(&mut BufReader::new(file))?;

    // The first 13 columns are the features, column index 13 is the target.
    let x = data.columns(0, 13).into_owned();
    let y = data.column(13).into_owned();

    // ERROR
    let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y.transpose(), 0.2, true);

    println!("{:?}", x_train);

    Ok(())
}
```
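Separately from the compile error, `read_csv` above has a latent shape bug: a field that fails to parse is printed and then dropped, so `samples.len() / rows` can yield a wrong column count, and `DMatrix::from_row_slice` panics when the slice length is not exactly `rows * cols`. A file with no data rows also divides by zero. A minimal std-only sketch of a stricter parser follows, assuming a header row and comma-separated numeric fields. `read_csv_checked` is a hypothetical name, and it returns the raw `(rows, cols, values)` parts so it compiles without nalgebra:

```rust
use std::error::Error;
use std::io::BufRead;

/// Hypothetical stricter variant of `read_csv`: skips the header line and
/// returns (rows, cols, row-major values), failing on any bad field or ragged row.
fn read_csv_checked(input: &mut dyn BufRead) -> Result<(usize, usize, Vec<f64>), Box<dyn Error>> {
    let mut samples = Vec::new();
    let mut rows = 0;
    let mut cols: Option<usize> = None;

    for (i, line) in input.lines().skip(1).enumerate() {
        let line = line?;
        if line.trim().is_empty() {
            continue; // ignore blank lines
        }

        let mut fields = 0;
        for data in line.split_terminator(',') {
            // Propagate parse errors instead of printing and dropping the field.
            let value = data
                .trim()
                .parse::<f64>()
                .map_err(|e| format!("data row {}: {:?}: {}", i + 1, data, e))?;
            samples.push(value);
            fields += 1;
        }

        // Every row must have as many fields as the first data row.
        match cols {
            None => cols = Some(fields),
            Some(c) if c != fields => {
                return Err(format!("data row {} has {} fields, expected {}", i + 1, fields, c).into());
            }
            Some(_) => {}
        }
        rows += 1;
    }

    // Without any data rows the original code would divide by zero.
    let cols = cols.ok_or("CSV contains no data rows")?;
    Ok((rows, cols, samples))
}
```

The returned parts can then be handed to nalgebra as in the question, for example `DMatrix::from_row_slice(rows, cols, &values)`, and the length check inside that call can no longer fail.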

Here are the errors I get:

```text
error[E0277]: the trait bound `nalgebra::Matrix<f64, Dyn, Dyn, VecStorage<f64, Dyn, Dyn>>: smartcore::linalg::Matrix<_>` is not satisfied
   --> src/main.rs:53:63
    |
53  |     let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y.transpose(), 0.2, true);
    |                                              ---------------- ^^ the trait `smartcore::linalg::Matrix<_>` is not implemented for `nalgebra::Matrix<f64, Dyn, Dyn, VecStorage<f64, Dyn, Dyn>>`
    |                                              |
    |                                              required by a bound introduced by this call
    |
    = help: the following other types implement trait `smartcore::linalg::Matrix<T>`:
              DenseMatrix<T>
              nalgebra::base::matrix::Matrix<T, nalgebra::base::dimension::Dynamic, nalgebra::base::dimension::Dynamic, nalgebra::base::vec_storage::VecStorage<T, nalgebra::base::dimension::Dynamic, nalgebra::base::dimension::Dynamic>>
note: required by a bound in `train_test_split`
    |
133 | pub fn train_test_split<T: RealNumber, M: Matrix<T>>(
    |                                           ^^^^^^^^^ required by this bound in `train_test_split`

error[E0277]: the trait bound `nalgebra::Matrix<f64, Dyn, Dyn, VecStorage<f64, Dyn, Dyn>>: BaseMatrix<_>` is not satisfied
  --> src/main.rs:53:46
   |
53 |     let (x_train, x_test, y_train, y_test) = train_test_split(&x, &y.transpose(), 0.2, true);
   |                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `BaseMatrix<_>` is not implemented for `nalgebra::Matrix<f64, Dyn, Dyn, VecStorage<f64, Dyn, Dyn>>`
   |
   = help: the following other types implement trait `BaseMatrix<T>`:
             DenseMatrix<T>
             nalgebra::base::matrix::Matrix<T, nalgebra::base::dimension::Dynamic, nalgebra::base::dimension::Dynamic, nalgebra::base::vec_storage::VecStorage<T, nalgebra::base::dimension::Dynamic, nalgebra::base::dimension::Dynamic>>

For more information about this error, try `rustc --explain E0277`.
error: could not compile `logistic-regression` due to 2 previous errors
```
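For reference, the operation requested from `train_test_split(&x, &y.transpose(), 0.2, true)` is conceptually simple: shuffle the row indices, send a `test_size` fraction of them to the test set and keep the rest for training. The following is a minimal std-only sketch of that idea, not smartcore's implementation. `split_indices` and `take_rows` are hypothetical helpers, the data is assumed to be a row-major `Vec<f64>`, and a tiny xorshift generator seeded by the caller stands in for a real RNG crate so the shuffle is reproducible:

```rust
/// Hypothetical std-only helper (not smartcore's code): splits row indices
/// 0..n_rows into (train, test), putting a `test_size` fraction into test.
fn split_indices(n_rows: usize, test_size: f32, shuffle: bool, seed: u64) -> (Vec<usize>, Vec<usize>) {
    let mut idx: Vec<usize> = (0..n_rows).collect();

    if shuffle {
        // xorshift64 PRNG; its state must never be zero.
        let mut state = seed.max(1);
        // Fisher-Yates: swap each position i with a random j in 0..=i.
        for i in (1..n_rows).rev() {
            state ^= state << 13;
            state ^= state >> 7;
            state ^= state << 17;
            let j = (state % (i as u64 + 1)) as usize;
            idx.swap(i, j);
        }
    }

    // Number of test rows, clamped so the slices below stay in bounds.
    let n_test = ((n_rows as f32 * test_size).round() as usize).min(n_rows);
    let test = idx[..n_test].to_vec();
    let train = idx[n_test..].to_vec();
    (train, test)
}

/// Copies the listed rows out of a row-major matrix with `cols` columns.
fn take_rows(data: &[f64], cols: usize, rows: &[usize]) -> Vec<f64> {
    rows.iter()
        .flat_map(|&r| data[r * cols..(r + 1) * cols].iter().copied())
        .collect()
}
```

The same index lists should be applied to both the features and the labels so that rows stay paired, for example `take_rows(&x, 13, &test)` together with `take_rows(&y, 1, &test)`.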
Answer by ubof19bj

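The problem is that smartcore 0.2.0 depends on nalgebra 0.23, so it only implements its traits for that nalgebra version's types. Cargo builds semver-incompatible versions of a crate as separate crates, so the `DMatrix` from your newer nalgebra is a different type even though it has the same name. You can see this in the error: your matrix is printed with `Dyn`, while the type listed as implementing the trait uses `nalgebra::base::dimension::Dynamic`. Until a smartcore release depends on a newer nalgebra, you will have to downgrade to the same nalgebra version: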

```toml
[dependencies]
nalgebra = "0.23.2"
smartcore = "0.2.0"
```
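The version mismatch can be reproduced without any external crates. In the toy model below, the modules `nalgebra_old` and `nalgebra_new` are hypothetical stand-ins for two nalgebra versions that Cargo builds as separate crates, while `Matrix` and `count` stand in for smartcore's trait and for `train_test_split`. It is a sketch of the mechanism only, not smartcore's actual code:

```rust
// Hypothetical stand-ins for two nalgebra versions. Cargo builds each
// semver-incompatible version as its own crate, so these are distinct types
// even though both are called `DMatrix`.
mod nalgebra_old {
    pub struct DMatrix(pub Vec<f64>);
}

#[allow(dead_code)]
mod nalgebra_new {
    pub struct DMatrix(pub Vec<f64>);
}

// Stand-in for smartcore's `Matrix` trait.
trait Matrix {
    fn n_elements(&self) -> usize;
}

// The library implements its trait only for the version it depends on.
impl Matrix for nalgebra_old::DMatrix {
    fn n_elements(&self) -> usize {
        self.0.len()
    }
}

// Stand-in for `train_test_split`: accepts any `M: Matrix`.
fn count<M: Matrix>(m: &M) -> usize {
    m.n_elements()
}
```

`count(&nalgebra_old::DMatrix(vec![1.0, 2.0, 3.0]))` compiles, while `count(&nalgebra_new::DMatrix(vec![1.0]))` is rejected with the same kind of E0277 "trait bound is not satisfied" error, because no impl exists for the newer type. To check whether a project pulls in two versions of a crate, `cargo tree --duplicates` (short form `-d`) lists every package that appears in more than one version.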
