Paddle version: 2.0.0-rc0 使用paddle.inference报:UnimplementedError: Invalid dimension to be accessed

2cmtqfgy  于 2022-10-20  发布在  其他
关注(0)|答案(4)|浏览(106)

环境************

Paddle version: 2.0.0-rc0
Paddle With CUDA: True
OS: Windows 10
Python version: 3.6.5
CUDA version: 10.1.243
cuDNN version: None.None.None
Nvidia driver version: 441.87

-代码如下:****************

from paddle.inference import Config
from paddle.inference import Predictor
import paddle
from paddle.vision.datasets import MNIST
import paddle.vision.transforms as trans
import numpy as np
from paddle.io import DataLoader
import matplotlib.pyplot as plt

%%

def draw_imgs(imgs):

print(imgs.shape)

imgs=imgs.reshape(imgs.shape[0],imgs.shape[2],imgs.shape[3])
print("draw_imgs--shape:{}".format(imgs.shape))
for i in range(imgs.shape[0]):
plt.subplot(2,4,i+1)
plt.imshow(imgs[i],cmap="gray")
plt.show()
#%%
train_ds=MNIST(mode="train")
train_loader=DataLoader(train_ds,batch_size=1)

train_it=iter(train_loader)
data=next(train_it)
imgs,labels=data

%%

print(imgs)
print(labels)

%%

draw_imgs(imgs.numpy())

%%

imgs
#%%
model_path="./mnist/"
config=Config(model_path+"mnist.pdmodel",model_path+"mnist.pdiparams")
predictor=Predictor(config)
input_names=predictor.get_input_names()
input_tensor=predictor.get_input_handle(input_names[0])

input_tensor.reshape([1,1,28,28])
input_tensor.copy_from_cpu(imgs) #此行报错
predictor.run()

output_names=predictor.get_output_names()
output_tensor=predictor.get_output_handle(output_names[0])
output_data=output_tensor.copy_to_cpu()
print(output_data)

报错:*************

UnimplementedError: Invalid dimension to be accessed. Now only supports access to dimension 0 to 9, but received dimension is 32. (at D:\2.0.0-rc0\paddle\paddle/fluid/framework/ddim.h:54)

!上面用到的模型为以下代码生成:***********************

#%%

import paddle
import paddle.nn as nn
import paddle.vision as vision
import paddle.metric as metric
import paddle.vision.transforms as trans
import paddle.optimizer as opt
import numpy as np
import matplotlib.pyplot as plt
from paddle.io import DataLoader
import paddle.nn.functional as F
import paddle.static
import paddle.jit

%%

transform=trans.Normalize(mean=[0.0],std=[1.0]) #mean /std的维度与数据维度一致
train_ds=vision.datasets.MNIST(mode="train",transform=transform)
test_ds=vision.datasets.MNIST(mode="test",transform=transform)

%%

Batch_size=64
epochs=5
LR=0.001

%%

train_loader=DataLoader(train_ds,batch_size=Batch_size,shuffle=True)
test_loader=DataLoader(test_ds,batch_size=Batch_size,shuffle=False)

%%

class MyNet(nn.Layer):
definit(self):
super(MyNet,self).init()
self.flatten=nn.Flatten()
self.fc1=nn.Linear(784,1024)
self.fc2=nn.Linear(1024,512)
self.fc3=nn.Linear(512,10)

@paddle.jit.to_static(input_spec=[paddle.static.InputSpec(shape=[None,1,28,28],name="inputs"),])  #注意书写格式
def forward(self, inputs):
    out=self.flatten(inputs)
    out=F.relu(self.fc1(out))
    out=F.relu(self.fc2(out))
    out=self.fc3(out)
    return out

%%

model=MyNet()

model.train()

optimizer=opt.Adam(learning_rate=LR,parameters=model.parameters())
loss_fun=nn.CrossEntropyLoss()

%%

for epoch in range(1,epochs+1):
correct_count=0
for idx,data in enumerate(train_loader):
inputs,labels=data
y_preds=model(inputs)
loss=loss_fun(y_preds,labels)
y_preds=y_preds.numpy()
y_max=np.argmax(y_preds,axis=1)
y_max=y_max.reshape(y_max.shape[0],1)
correct_count+=(y_max==labels.numpy()).sum()
loss.backward()
optimizer.step()
optimizer.clear_grad()
if idx%100==0:
print("epoch:%1d idx:%4d loss:%.4f"%(epoch,idx,loss.numpy()))
print("epoch:%1d Accuracy:%.2f"%(epoch,correct_count/len(train_ds)*100))

%%

paddle.jit.save(model,"./mnist/mnist")

【注意】*******************

DataLoader中batchsize设置不生效,总是使用32的batchsize,也请确认
   

xxb16uws

xxb16uws1#

您好,我们已经收到了您的问题,会安排技术人员在一天之内解答您的疑惑,请耐心等待。请您再次检查是否提供了清晰的问题描述、复现代码、环境&版本、报错信息等。同时,您也可以通过查看 官网API文档常见问题历史IssueAI社区 来寻求解答。祝您生活愉快~

Hi! We've received your issue and please be patient to get responded.The average response time is expected to be with in one day.Please make sure that you have posted enough message to demo your request. You may also check out the APIFAQGithub Issue and AI community to get the answer.Have a nice day!

9udxz4iz

9udxz4iz3#

这个问题进展如何?

uklbhaso

uklbhaso4#

同样遇到这个问题,两年了居然还没解决。。。

相关问题