rust 使用泛型输出实现枚举,泛型输出的类型在实现中确定

66bbxpm5  于 2023-03-08  发布在  其他
关注(0)|答案(1)|浏览(138)

我遇到了这样的问题:我想创建一个具有多个字段的结构体,其中一个是枚举,它决定如何处理许多操作。枚举决定了一个重要函数具有什么inputoutput类型。我不希望主结构包含。input/的初始化结构output,因为它们往往变化很大,并且值对于主结构来说并不重要。
给定设置:

struct Model<I,O> {
    name: String,
    model_type: ModelType
    _data_format: PhantomData::<(I,O)>
    //... many other fields
}

enum ModelType{
    Onnx,
    DecisiomTree,
}

impl ModelType{
   pub fn get_model<I,O>(self) -> Model<I,O>{
      match self {
         Self::Onnx => {
            Model<I,O> {
               name: "onnx".into(),
               model_type: Self::Onnx,
               PhantomData::<(OnnxInput,OnnxOutput)> //  <--- Cant do it like this, what should be done here?? without initializing the given structures.
            }
         },
         //.. other models
     }
   }
}

struct OnnxInput {
    matrix: Vec<Vec<f64>>
}

struct DtInput {
    color: f32,
    size: f32,
    //... more... 
}

impl<I,O> Model<I,O> {
    pub fn score(&self, input: I) -> O{
        
        self.model_type.score::<I>(input)
    }
    
}

impl ModelType{
    pub fn score<I>(&self, input: I) -> O
}
cu6pst1q

cu6pst1q1#

我认为你想要的是用单元结构实现一个trait,而不是ModelType枚举,如下所示:

use std::marker::PhantomData;

trait ModelScore {
    type Input;
    type Output;
    
    fn score(i: Self::Input) -> Self::Output;
}

struct Model<T: ModelScore> {
    marker: PhantomData<T>,
}

impl<T: ModelScore> Model<T> {
    fn new() -> Self {
        Self {
            marker: PhantomData,
        }
    }
    
    fn score(&self, i: T::Input) -> T::Output {
        T::score(i)
    }
}

struct Onnx;

struct OnnxInput;

#[derive(PartialEq, Debug)]
struct OnnxOutput;

struct DecisionTree;

struct DecisionTreeInput;

#[derive(PartialEq, Debug)]
struct DecisionTreeOutput;

impl ModelScore for Onnx {
    type Input = OnnxInput;
    type Output = OnnxOutput;
    
    fn score(_: Self::Input) -> Self::Output {
        OnnxOutput
    }
}

impl ModelScore for DecisionTree {
    type Input = DecisionTreeInput;
    type Output = DecisionTreeOutput;
    
    fn score(_: Self::Input) -> Self::Output {
        DecisionTreeOutput
    }
}

fn main() {
    let m1: Model<Onnx> = Model::new();
    let m2: Model<DecisionTree> = Model::new();
    
    assert_eq!(m1.score(OnnxInput), OnnxOutput);
    assert_eq!(m2.score(DecisionTreeInput), DecisionTreeOutput);
}

Playground.
我们现在有一个trait ModelScore,它与InputOutput相关联,ModelType的变体被单元结构OnnxDecisionTree所取代。这允许我们定义Model示例,如main函数中所示。您的输入和输出类型对于编译器来说是已知的,但是没有按照您的要求进行初始化。

相关问题