我有一个经过训练的tf模型,我想把它应用到hdfs中的大数据集,这个数据集大约有十亿个样本。重点是我需要将tf模型的预测写入hdfs文件。但是我在tensorflow中找不到关于如何在hdfs文件中保存数据的相关api,只能找到关于读取hdfs文件的api
到目前为止,我的方法是将经过训练的tf模型保存到本地的pb文件中,然后使用spark或mapreduce代码中的javaapi加载pb文件。spark和mapreduce的问题都是运行速度很慢,出现了内存错误。这是我的演示:
public class TF_model implements Serializable{
public Session session;
public TF_model(String model_path){
try{
Graph graph = new Graph();
InputStream stream = this.getClass().getClassLoader().getResourceAsStream(model_path);
byte[] graphBytes = IOUtils.toByteArray(stream);
graph.importGraphDef(graphBytes);
this.session = new Session(graph);
}
catch (Exception e){
System.out.println("failed to load tensorflow model");
}
}
// this is the function to predict a sample in hdfs
public int[][] predict(int[] token_id_array){
Tensor z = session.runner()
.feed("words_ids_placeholder", Tensor.create(new int[][]{token_id_array}))
.fetch("softmax_prediction").run().get(0);
double[][][] softmax_prediction = new double[1][token_id_array.length][2];
z.copyTo(softmax_prediction);
return softmax_prediction[0];
}}
下面是我的Spark代码:
val rdd = spark.sparkContext.textFile(file_path)
val predct_result= rdd.mapPartitions(pa=>{
val tf_model = new TF_model("model.pb")
pa.map(line=>{
val transformed = transform(line) // omitted the transform code
val rs = tf_model .predict(transformed)
rs
})
})
我也尝试了部署在hadoop中的tensorflow,但是找不到将大数据集写入hdfs的方法。
1条答案
按热度按时间hpxqektj1#
您可以从hdfs读取一次模型文件,然后使用sc.broadcast将图形的字节数组广播到分区。最后,启动负荷图并进行预测。只是为了避免从hdfs多次读取文件。