重排序模型 ScoringModel 目前仅支持两种在线模型。能否参照 OnnxEmbeddingModel 的做法,提供一个 OnnxScoringModel,以便使用本地的重排序模型?
jecbmhm31#
我写了一个类似的模型,基于OnnxEmbeddingModel,并使用本地的bge-reranker-large进行了正常测试,请随时指正我。 public class OnnxScoringModel extends AbstractInProcessScoringModel { } import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.output.Response; import dev.langchain4j.model.scoring.ScoringModel; import java.util.ArrayList; import java.util.List; public abstract class AbstractInProcessScoringModel implements ScoringModel { } import ai.djl.huggingface.tokenizers.Encoding; import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import java.nio.LongBuffer; import java.nio.file.Paths; import java.util.*; public class OnnxScoringBertBiEncoder { private final OrtEnvironment environment; private final OrtSession session; private final HuggingFaceTokenizer tokenizer;
public OnnxScoringBertBiEncoder(String modelPath, String pathToTokenizer) {try {this.environment = OrtEnvironment.getEnvironment();this.session = this.environment.createSession(modelPath);this.tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(pathToTokenizer), Map.of("truncation", "LONGEST_FIRST")); // LONGEST_FIRST 优先截断最长的那部分} catch (Exception var4) {throw new RuntimeException(var4);}}
/**
 * Scores a text against a query by running the model on the pair.
 *
 * <p>The model result is closed once the score has been extracted.
 *
 * @param text  the candidate text
 * @param query the query the text is scored against
 * @return the value produced by {@code toScore} for the model output
 * @throws RuntimeException wrapping any {@link OrtException} raised while encoding or reading the output
 */
public double score(String text, String query) {
    try (OrtSession.Result result = this.encode(text, query)) {
        return this.toScore(result);
    } catch (OrtException e) {
        throw new RuntimeException(e);
    }
}
private OrtSession.Result encode(String text, String query) throws OrtException {String[] pairs = new String[]{query, text};Encoding encoding = this.tokenizer.encode(pairs, true, false);long[] inputIds = encoding.getIds();long[] attentionMask = encoding.getAttentionMask();long[] shape = new long[]{1L, (long)inputIds.length};OnnxTensor inputIdsTensor = OnnxTensor.createTensor(this.environment, LongBuffer.wrap(inputIds), shape);Throwable var8 = null;
OrtSession.Result var14; try { OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(this.environment, LongBuffer.wrap(attentionMask), shape); Throwable var10 = null; try { Map<String, OnnxTensor> inputs = new HashMap(); inputs.put("input_ids", inputIdsTensor); inputs.put("attention_mask", attentionMaskTensor); var14 = this.session.run(inputs); } catch (Throwable var60) { var10 = var60; throw var60; } finally { if (attentionMaskTensor != null) { if (var10 != null) { try { attentionMaskTensor.close(); } catch (Throwable var56) { var10.addSuppressed(var56); } } else { attentionMaskTensor.close(); } } } } catch (Throwable var62) { var8 = var62; throw var62; } finally { if (inputIdsTensor != null) { if (var8 != null) { try { inputIdsTensor.close(); } catch (Throwable var55) { var8.addSuppressed(var55); } } else { inputIdsTensor.close(); } } } return var14;
/**
 * Extracts the score from a model result.
 *
 * <p>Reads the first output, treats its value as a {@code float[][]}, and passes
 * element {@code [0][0]} through {@code sigmoid}.
 *
 * @param result the session result to read
 * @return the sigmoid of the first value of the first output
 * @throws OrtException if the output value cannot be retrieved
 */
private double toScore(OrtSession.Result result) throws OrtException {
    Object rawValue = result.get(0).getValue();
    float[][] values = (float[][]) rawValue;
    float first = values[0][0];
    return this.sigmoid(first);
}
/**
 * Computes the logistic function {@code 1 / (1 + e^(-x))}.
 *
 * @param x the input value
 * @return the logistic function evaluated at {@code x}
 */
private double sigmoid(float x) {
    double decay = Math.exp(-x);
    double denominator = 1 + decay;
    return 1 / denominator;
}
1条答案
按热度按时间jecbmhm31#
public OnnxScoringBertBiEncoder(String modelPath, String pathToTokenizer) {
try {
this.environment = OrtEnvironment.getEnvironment();
this.session = this.environment.createSession(modelPath);
this.tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(pathToTokenizer), Map.of("truncation", "LONGEST_FIRST")); // LONGEST_FIRST 优先截断最长的那部分
} catch (Exception var4) {
throw new RuntimeException(var4);
}
}
public double score(String text, String query) {
double score;
try {
OrtSession.Result result = this.encode(text, query);
Throwable var8 = null;
}
/**
 * Tokenizes the (query, text) pair and runs the ONNX session on it.
 *
 * <p>The inputs are fed as a single-row batch under the names {@code input_ids}
 * and {@code attention_mask}. Both input tensors are closed before returning;
 * the caller owns the returned result and must close it.
 *
 * @param text  the candidate text
 * @param query the query
 * @return the raw session result
 * @throws OrtException if tensor creation or inference fails
 */
private OrtSession.Result encode(String text, String query) throws OrtException {
    String[] pairs = new String[] {query, text};
    // NOTE(review): confirm that this String[] overload encodes a sentence pair
    // (with separator tokens) rather than one pre-tokenized sequence.
    Encoding encoding = this.tokenizer.encode(pairs, true, false);
    long[] inputIds = encoding.getIds();
    long[] attentionMask = encoding.getAttentionMask();
    long[] shape = new long[] {1L, (long) inputIds.length};

    // Resources are closed in reverse order: attention mask first, then input ids.
    try (OnnxTensor inputIdsTensor =
                 OnnxTensor.createTensor(this.environment, LongBuffer.wrap(inputIds), shape);
         OnnxTensor attentionMaskTensor =
                 OnnxTensor.createTensor(this.environment, LongBuffer.wrap(attentionMask), shape)) {
        Map<String, OnnxTensor> inputs = new HashMap<>();
        inputs.put("input_ids", inputIdsTensor);
        inputs.put("attention_mask", attentionMaskTensor);
        return this.session.run(inputs);
    }
}
/**
 * Extracts the score from a model result.
 *
 * <p>Reads the first output, treats its value as a {@code float[][]}, and passes
 * element {@code [0][0]} through {@code sigmoid}.
 *
 * @param result the session result to read
 * @return the sigmoid of the first value of the first output
 * @throws OrtException if the output value cannot be retrieved
 */
private double toScore(OrtSession.Result result) throws OrtException {
    Object rawValue = result.get(0).getValue();
    float[][] values = (float[][]) rawValue;
    float first = values[0][0];
    return this.sigmoid(first);
}
/**
 * Computes the logistic function {@code 1 / (1 + e^(-x))}.
 *
 * @param x the input value
 * @return the logistic function evaluated at {@code x}
 */
private double sigmoid(float x) {
    double decay = Math.exp(-x);
    double denominator = 1 + decay;
    return 1 / denominator;
}