有没有办法把PyTorch中的预训练模型下载到一个特定的路径上?

y1aodyip  于 2022-11-29  发布在  其他
关注(0)|答案(4)|浏览(224)

我指的是可以在这里找到的模型:https://pytorch.org/docs/stable/torchvision/models.html#torchvision-models

sg24os4d

sg24os4d1#

正如,@dennlinger在他的answer中提到的:torch.utils.model_zoo,当您载入预先训练的模型时,会在内部呼叫。
更具体地,该方法:torch.utils.model_zoo.load_url()会在每次加载预先训练的模型时被调用。
model_dir的默认值为$TORCH_HOME/models,其中$TORCH_HOME的默认值为~/.torch
可以使用$TORCH_HOME环境变量覆盖默认目录。
这可以通过以下方式完成:

import torch 
import torchvision
import os

# Suppose you are trying to load pre-trained resnet model in directory- models\resnet

os.environ['TORCH_HOME'] = 'models\\resnet' #setting the environment variable
resnet = torchvision.models.resnet18(pretrained=True)

我在PyTorch的GitHub仓库中提出了一个问题,从而遇到了上述解决方案:https://github.com/pytorch/vision/issues/616
这导致了文档的改进,即上述解决方案。

6mw9ycah

6mw9ycah2#

是的,你可以简单地复制url,然后使用wget将其下载到所需的路径。
对于AlexNet

$ wget -c https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth

对于Google创业(v3)

$ wget -c https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth

对于挤压网

$ wget -c https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth

对于移动网络V2

$ wget -c https://download.pytorch.org/models/mobilenet_v2-b0353104.pth

对于高密度网络201

$ wget -c https://download.pytorch.org/models/densenet201-c1103571.pth

对于MNASNet1_0

$ wget -c https://download.pytorch.org/models/mnasnet1.0_top1_73.512-f206786ef8.pth

对于随机网络v2_x1.0

$ wget -c https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth

如果你想在Python中实现它,那么可以使用如下代码:

In [11]: from six.moves import urllib

# resnet 101 host url
In [12]: url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"

# download and rename the file to `resnet_101.pth`
In [13]: urllib.request.urlretrieve(url, "resnet_101.pth")
Out[13]: ('resnet_101.pth', <http.client.HTTPMessage at 0x7f7fd7f53438>)

附言:您可以在torchvision.models的各个python模块中找到下载URL

x9ybnkn6

x9ybnkn63#

有一个脚本可以输出整个包中的URL列表。
pytorch/vision软件包中执行以下命令:

python scripts/collect_model_urls.py .

# ...
# https://download.pytorch.org/models/swin_v2_b-781e5279.pth
# https://download.pytorch.org/models/swin_v2_s-637d8ceb.pth
# https://download.pytorch.org/models/swin_v2_t-b137f0e2.pth
# https://download.pytorch.org/models/vgg11-8a719046.pth
# https://download.pytorch.org/models/vgg11_bn-6002323d.pth
# ...
rqcrx0a6

rqcrx0a64#

TL;DR:不,这是不可能的,但你可以很容易地适应它。
我想你要做的是查看torch.utils.model_zoo,当你加载一个预先训练好的模型时,它会被内部调用:
如果我们查看预先训练好的模型的代码,例如AlexNet here,我们可以看到它只是调用前面提到的model_zoo函数,但是没有保存的位置。(这实际上将是一个伟大的补充海事组织,所以也许打开一个拉请求),或者简单地根据自己的喜好采用第二个链接中的代码(并将其保存到一个不同名称下的自定义位置),然后在那里手动插入相关位置。
如果您希望定期更新PyTorch,我强烈推荐第二种方法,因为它不涉及直接更改PyTorch的代码库,并且在更新过程中可能会抛出错误。

相关问题