Apache Ignite: updating a previously trained ML model

ws51t4hk posted on 2021-07-06 in Java

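I have a dataset that I use to train a KNN model. Later I want to update the model with new training data. What I see is that the updated model takes only the new training data into account and ignores the data it was trained on before.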

    Vectorizer vec = new DummyVectorizer<Integer>(1, 2).labeled(0);
    DatasetTrainer<KNNClassificationModel, Double> trainer = new KNNClassificationTrainer();
    KNNClassificationModel model;
    KNNClassificationModel modelUpdated;
    Map<Integer, Vector> trainingData = new HashMap<Integer, Vector>();
    Map<Integer, Vector> trainingDataNew = new HashMap<Integer, Vector>();
    Double[][] data1 = new Double[][] {
        {0.136,0.644,0.154},
        {0.302,0.634,0.779},
        {0.806,0.254,0.211},
        {0.241,0.951,0.744},
        {0.542,0.893,0.612},
        {0.334,0.277,0.486},
        {0.616,0.259,0.121},
        {0.738,0.585,0.017},
        {0.124,0.567,0.358},
        {0.934,0.346,0.863}};
    Double[][] data2 = new Double[][] {
        {0.300,0.236,0.193}};
    Double[] observationData = new Double[] { 0.8, 0.7 };
    // fill dataset (in cache)
    for (int i = 0; i < data1.length; i++)
        trainingData.put(i, new DenseVector(data1[i]));
    // first training / prediction
    model = trainer.fit(trainingData, 1, vec);
    System.out.println("First prediction : " + model.predict(new DenseVector(observationData)));
    // new training data
    for (int i = 0; i < data2.length; i++)
        trainingDataNew.put(data1.length + i, new DenseVector(data2[i]));
    // second training / prediction
    modelUpdated = trainer.update(model, trainingDataNew, 1, vec);
    System.out.println("Second prediction: " + modelUpdated.predict(new DenseVector(observationData)));

The output I get is:

    First prediction : 0.124
    Second prediction: 0.3

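It looks as if the second prediction uses only data2, which can only yield 0.3 as the prediction.
How does the model update work? If I have to add data2 to data1 and then train on data1 again, what is the difference compared with a completely fresh training on all the combined data?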

mspsb9vt  #1

How does the model update work?
For KNN specifically: add data2 to data1 and call the model update on the combined data.
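To see why the merge matters, here is a minimal, library-free sketch. It is not Ignite's implementation: KnnUpdateSketch and its predict method are hypothetical names, and it uses k = 1 for simplicity, so its numbers are not expected to match Ignite's output. The rows are {label, feature1, feature2}, the layout used in the question, with two rows taken from data1 and the single row of data2. A nearest-neighbour model can only answer from the rows it is given, so a model that sees only data2 can only return 0.3.

```java
import java.util.ArrayList;
import java.util.List;

/** Library-free 1-nearest-neighbour sketch; an illustration, not Ignite's code. */
final class KnnUpdateSketch {
    /** Each row is {label, feature1, feature2}; returns the label of the closest row. */
    static double predict(List<double[]> rows, double[] features) {
        double bestDist = Double.POSITIVE_INFINITY;
        double bestLabel = Double.NaN;
        for (double[] row : rows) {
            double dx = row[1] - features[0];
            double dy = row[2] - features[1];
            double dist = dx * dx + dy * dy; // squared Euclidean distance
            if (dist < bestDist) {
                bestDist = dist;
                bestLabel = row[0];
            }
        }
        return bestLabel;
    }

    public static void main(String[] args) {
        // Two rows taken from data1 in the question.
        List<double[]> data1 = new ArrayList<>();
        data1.add(new double[] {0.542, 0.893, 0.612});
        data1.add(new double[] {0.124, 0.567, 0.358});

        // The single row of data2.
        List<double[]> data2 = new ArrayList<>();
        data2.add(new double[] {0.300, 0.236, 0.193});

        // Merge the old and the new rows before predicting.
        List<double[]> combined = new ArrayList<>(data1);
        combined.addAll(data2);

        double[] observation = {0.8, 0.7};
        System.out.println("new rows only: " + predict(data2, observation));    // prints 0.3
        System.out.println("combined rows: " + predict(combined, observation)); // prints 0.542
    }
}
```

With only the new row available, the sketch has a single candidate neighbour; with the merged rows, the closest point to (0.8, 0.7) is again one of the original rows.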
Take this test as an example: https://github.com/apache/ignite/blob/635dafb7742673494efa6e8e91e236820156d38f/modules/ml/src/test/java/org/apache/ignite/ml/knn/knnclassificationtest.java#l167
Follow the steps shown in the test. First, set up the trainer:

    KNNClassificationTrainer trainer = new KNNClassificationTrainer()
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

Then set up the vectorizer (note how the label coordinate is specified):

    model = trainer.fit(
        trainingData,
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
    );
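The label coordinate decides which column of a row is treated as the class label. As a rough, library-free illustration, the sketch below splits the single row of data2 from the question in two ways: with the label at index 0, which matches the question's labeled(0), and with the label in the last position. LabelCoordinateSketch and its methods are hypothetical helpers, not part of Ignite, and the sketch assumes that every remaining coordinate becomes a feature.

```java
import java.util.Arrays;

/** Hypothetical helper showing what a label coordinate selects; not an Ignite class. */
final class LabelCoordinateSketch {
    /** Returns the value stored at the label position. */
    static double label(double[] row, int labelIdx) {
        return row[labelIdx];
    }

    /** Returns all other coordinates, in their original order, as the features. */
    static double[] features(double[] row, int labelIdx) {
        double[] out = new double[row.length - 1];
        for (int i = 0, j = 0; i < row.length; i++) {
            if (i != labelIdx)
                out[j++] = row[i];
        }
        return out;
    }

    public static void main(String[] args) {
        double[] row = {0.300, 0.236, 0.193}; // the single row of data2
        int first = 0;
        int last = row.length - 1;

        // Label first: label 0.3, features [0.236, 0.193]
        System.out.println(label(row, first) + " " + Arrays.toString(features(row, first)));
        // Label last: label 0.193, features [0.3, 0.236]
        System.out.println(label(row, last) + " " + Arrays.toString(features(row, last)));
    }
}
```

The same row therefore yields a different label depending on the coordinate, so the vectorizer passed to the update should describe the data layout the same way as the one used for the initial fit.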

Then call the model update as needed:

    KNNClassificationModel updatedOnData = trainer.update(
        originalMdlOnEmptyDataset,
        newData,
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
    );

KNN classification docs: https://ignite.apache.org/docs/latest/machine-learning/binary-classification/knn-classification
KNN classification example: https://github.com/apache/ignite/blob/master/examples/src/main/java/org/apache/ignite/examples/ml/knn/knnclassificationexample.java
