keras 如何在本地使用TF-Hub模型

8qgya5xd  于 11个月前  发布在  其他
关注(0)|答案(5)|浏览(108)

我一直在尝试使用来自tf-hub https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2的BERT模型。

import tensorflow_hub as hub
bert_layer = hub.keras_layer('./bert_en_uncased_L-12_H-768_A-12_2', trainable=True)

字符串
但问题是,它是下载数据后,每次运行。
所以我从tf-hub https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2下载了.tar文件
现在我尝试使用这个下载的tar文件(untar之后)
我已经遵循本教程https://medium.com/@xianbao.qian/how-to-run-tf-hub-locally-without-internet-connection-4506b850a915
但它并没有很好地工作,并没有进一步的信息或脚本是在这篇博客文章提供
如果有人可以提供完整的脚本,在本地使用下载的模型(没有互联网)或可以改善上述博客文章(中等)。
我也试过

untarredFilePath = './bert_en_uncased_L-12_H-768_A-12_2'
bert_lyr = hub.load(untarredFilePath)
print(bert_lyr)


输出

<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject object at 0x7f05c46e6a10>


好像不管用。
有没有其他方法可以做到这一点?

du7egjpx

du7egjpx1#

嗯,我不能重现你的问题。对我有效的是:
第一个月

# download the model file using the 'wget' program
wget "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2?tf-hub-format=compressed"

# rename the downloaded file name to 'tar_file.tar.gz'
mv 2\?tf-hub-format\=compressed tar_file.tar.gz

# extract tar_file.tar.gz to the local directory 
tar -zxvf tar_file.tar.gz

# turn off internet

# run a test script
python3 test.py

# running the last command prints some tensorflow warnings, and then '<tensorflow_hub.keras_layer.KerasLayer object at 0x7fd702a7d8d0>'

字符串
test.py

import tensorflow_hub as hub
print(hub.KerasLayer('.'))

sgtfey8w

sgtfey8w2#

我使用这篇中型文章(https://medium.com/@xianbao.qian/how-to-run-tf-hub-locally-without-internet-connection-4506b850a915)作为参考编写了这个脚本。我在我的项目中创建了一个缓存目录,tensorflow模型被缓存在这个缓存目录中,我可以在本地加载模型。希望这对你有帮助。

import os
os.environ["TFHUB_CACHE_DIR"] = r'C:\Users\USERX\PycharmProjects\PROJECTX\tf_hub'

import tensorflow as tf
import tensorflow_hub as hub
import hashlib

handle = "https://tfhub.dev/google/universal-sentence-encoder/4"
hashlib.sha1(handle.encode("utf8")).hexdigest()

embed = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
def get_sentence_embeddings(paragraph_array):
    embeddings=embed(paragraph_array)
    return embeddings

字符串

sh7euo9m

sh7euo9m3#

在从tf-hub团队获得信息后,他们提供了这个解决方案。假设你已经从tf-hub官方模型页面下载了. tar.gz文件。你已经解压缩了它。你得到了一个包含资产,变量和模型的文件夹。你把它放在你的工作目录中。
在脚本中,只需添加该文件夹的路径:

import tensroflow-hub as hub

model_path ='./bert_en_uncased_L-12_H-768_A-12_2' # in my case
# one thing the path you have to provide is for folder which contain assets, variable and model
# not of the model.pb itself

lyr = hub.KerasLayer(model_path, trainable=True)

字符串
希望它也能对你起作用。给予试试吧

slhcrj9b

slhcrj9b4#

tensorflow_hub库将下载和未压缩的模型缓存在磁盘上,以避免重复上传。tensorflow.org/hub/caching上的文档已经扩展到讨论这种情况和其他情况。

mnemlml8

mnemlml85#

我发现模块下载到:/tmp/tfhub_modules

相关问题