langchain4j [特性]类似于OnnxEmbeddingModel的OnnxScoringModel

c8ib6hqw  于 4个月前  发布在  其他
关注(0)|答案(1)|浏览(55)

重排序模型 ScoringModel 目前仅支持两种在线模型。能否参照 OnnxEmbeddingModel 提供一个 OnnxScoringModel，从而支持使用本地重排序模型？

jecbmhm3

jecbmhm31#

我参照 OnnxEmbeddingModel 写了一个类似的实现，并用本地的 bge-reranker-large 模型做了测试，运行正常。欢迎指正。
/**
 * In-process scoring (re-ranking) model backed by ONNX, intended to mirror OnnxEmbeddingModel
 * so that local re-ranking models can be used through {@code ScoringModel}.
 *
 * <p>NOTE(review): the class body is empty as posted. It presumably should construct an
 * OnnxScoringBertBiEncoder (model path + tokenizer path) and delegate per-segment scoring to
 * it; the implementation appears to have been lost from this snippet. TODO restore.
 */
public class OnnxScoringModel extends AbstractInProcessScoringModel {
}
import dev.langchain4j.data.segment.TextSegment;
 import dev.langchain4j.model.output.Response;
 import dev.langchain4j.model.scoring.ScoringModel;
import java.util.ArrayList;
 import java.util.List;
/**
 * Base class for {@code ScoringModel} implementations that score segments in the current
 * process (presumably with a locally loaded model, per the "InProcess" name), as opposed to
 * the remote scoring models.
 *
 * <p>NOTE(review): the body is empty as posted, although this file imports TextSegment,
 * Response, ArrayList and List. Presumably a scoreAll(...) implementation that scores each
 * segment against the query and wraps the scores in a Response was dropped. TODO confirm
 * and restore.
 */
public abstract class AbstractInProcessScoringModel implements ScoringModel {
}
import ai.djl.huggingface.tokenizers.Encoding;
 import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
 import ai.onnxruntime.OnnxTensor;
 import ai.onnxruntime.OrtEnvironment;
 import ai.onnxruntime.OrtException;
 import ai.onnxruntime.OrtSession;
import java.nio.LongBuffer;
 import java.nio.file.Paths;
 import java.util.*;
/**
 * Scores a (query, passage) pair with a re-ranking model exported to ONNX, using a HuggingFace
 * {@code tokenizer.json} for tokenization.
 *
 * <p>Despite the class name, the model is used as a cross-encoder. Query and passage are
 * tokenized together as one sentence pair, run through the network in a single pass, and the
 * first output logit is mapped into {@code (0, 1)} with a sigmoid.
 *
 * <p>Instances own a native {@link OrtSession} and a native tokenizer. Call {@link #close()}
 * when the encoder is no longer needed.
 */
public class OnnxScoringBertBiEncoder implements AutoCloseable {

    private static final String INPUT_IDS = "input_ids";
    private static final String ATTENTION_MASK = "attention_mask";
    private static final String TOKEN_TYPE_IDS = "token_type_ids";

    private final OrtEnvironment environment;
    private final OrtSession session;
    private final HuggingFaceTokenizer tokenizer;

    /**
     * Whether the ONNX graph declares a {@code token_type_ids} input. Some exports declare it
     * and others do not. It is only fed when declared, because ONNX Runtime rejects inputs the
     * graph does not know.
     */
    private final boolean needsTokenTypeIds;

    /**
     * Loads the ONNX model and the tokenizer.
     *
     * @param modelPath path to the {@code .onnx} model file
     * @param pathToTokenizer path to the HuggingFace tokenizer file
     * @throws RuntimeException wrapping any failure to load the model or the tokenizer
     */
    public OnnxScoringBertBiEncoder(String modelPath, String pathToTokenizer) {
        OrtSession createdSession = null;
        try {
            this.environment = OrtEnvironment.getEnvironment();
            createdSession = this.environment.createSession(modelPath);
            // LONGEST_FIRST: if the pair is too long, truncate the longer of query/passage first.
            this.tokenizer = HuggingFaceTokenizer.newInstance(
                    Paths.get(pathToTokenizer), Map.of("truncation", "LONGEST_FIRST"));
            this.needsTokenTypeIds = createdSession.getInputNames().contains(TOKEN_TYPE_IDS);
            this.session = createdSession;
        } catch (Exception e) {
            // Don't leak the native session if the tokenizer (or anything after it) fails.
            if (createdSession != null) {
                try {
                    createdSession.close();
                } catch (OrtException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Computes the relevance score of {@code text} for {@code query}.
     *
     * @param text the passage to score
     * @param query the query the passage is scored against
     * @return the sigmoid of the model's first output logit, in {@code (0, 1)}
     * @throws RuntimeException wrapping an {@link OrtException} raised by ONNX Runtime
     */
    public double score(String text, String query) {
        try (OrtSession.Result result = encode(text, query)) {
            return toScore(result);
        } catch (OrtException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Tokenizes the (query, text) pair and runs the model on it.
     *
     * <p>The input tensors are closed before returning. The returned {@link OrtSession.Result}
     * is owned by the caller and must be closed.
     */
    private OrtSession.Result encode(String text, String query) throws OrtException {
        // Encode as a real sentence pair so the tokenizer inserts the separator tokens between
        // query and passage, and so LONGEST_FIRST truncation can act on either side. The
        // String[] overload would treat the two strings as pre-tokenized words of ONE sequence.
        Encoding encoding = tokenizer.encode(query, text, true, false);
        long[] inputIds = encoding.getIds();
        long[] shape = {1L, inputIds.length};

        // A null resource is permitted in try-with-resources; it is simply not closed.
        try (OnnxTensor inputIdsTensor =
                        OnnxTensor.createTensor(environment, LongBuffer.wrap(inputIds), shape);
                OnnxTensor attentionMaskTensor = OnnxTensor.createTensor(
                        environment, LongBuffer.wrap(encoding.getAttentionMask()), shape);
                OnnxTensor tokenTypeIdsTensor = needsTokenTypeIds
                        ? OnnxTensor.createTensor(
                                environment, LongBuffer.wrap(encoding.getTypeIds()), shape)
                        : null) {
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put(INPUT_IDS, inputIdsTensor);
            inputs.put(ATTENTION_MASK, attentionMaskTensor);
            if (tokenTypeIdsTensor != null) {
                inputs.put(TOKEN_TYPE_IDS, tokenTypeIdsTensor);
            }
            return session.run(inputs);
        }
    }

    /**
     * Converts the model output into a score.
     *
     * <p>Assumes the first output is a 2-D float tensor of logits (batch of 1) whose element
     * {@code [0][0]} is the relevance logit. TODO confirm for models with other output layouts.
     */
    private double toScore(OrtSession.Result result) throws OrtException {
        float[][] logits = (float[][]) result.get(0).getValue();
        return sigmoid(logits[0][0]);
    }

    /** Logistic function {@code 1 / (1 + e^-x)}. */
    private static double sigmoid(float x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Releases the native ONNX session and the tokenizer.
     *
     * <p>The {@link OrtEnvironment} is not closed. It is the shared environment obtained from
     * {@link OrtEnvironment#getEnvironment()}.
     *
     * @throws RuntimeException wrapping an {@link OrtException} raised while closing the session
     */
    @Override
    public void close() {
        try {
            session.close();
        } catch (OrtException e) {
            throw new RuntimeException(e);
        } finally {
            tokenizer.close();
        }
    }
}

相关问题