我试图在python中实现一个简单的梯度下降,只使用numpy,但是缺少了一些东西,我找不到它。我过去又做过一次,但不知何故,我一直盯着这个问题,过去的一天,没有能够使它的工作。
def init_param():
w1 = np.random.rand(10, 64)
b1 = np.random.rand(10, 1)
w2 = np.random.rand(10, 10)
b2 = np.random.rand(10, 1)
return w1, b1, w2, b2
def ReLU(z):
return np.maximum(0, z)
def dReLU(z):
return z > 0
def dtanh(z):
return 1 - np.tanh(z)**2
def tanh(z):
return np.tanh(z)
def forward_prop(w1, b1, w2, b2, x):
z1 = w1.dot(x).reshape(-1,1) + b1
a1 = ReLU(z1)
z2 = w2.dot(a1) + b2
a2 = tanh(z2)
return z1, a1, z2, a2
def one_hot(y):
one_hot_y = np.zeros((y.size, 10))
one_hot_y[np.arange(y.size), y] = 1
return one_hot_y.T
def back_prop(z1, a1, z2, a2, w2, x, y):
one_hot_y = one_hot(y)
dz2 = (a2 - one_hot_y) * tanh(z2)
dw2 = dz2.dot(a1.T)
db2 = np.sum(dz2, 1, keepdims=True)
dz1 = w2.T.dot(dz2) * dReLU(z1)
dw1 = dz1.dot(x.T)
db1 = np.sum(dz1, 1, keepdims=True)
return dw1, db1, dw2, db2
def update_param(w1, b1, w2, b2, dw1, db1, dw2, db2, lr):
w1 = w1 - lr*dw1
b1 = b1 - lr*db1
w2 = w2 - lr*dw2
b2 = b2 - lr*db2
return w1, b1, w2, b2
def gradient_descent(X, Y, max_iter, lr):
w1, b1, w2, b2 = init_param()
m = len(X)
for _ in range(max_iter):
y_hat = []
dw1, db1, dw2, db2 = None, None, None, None
for x, y in zip(X, Y):
x = x.reshape(-1,1)
z1, a1, z2, a2 = forward_prop(w1, b1, w2, b2, x)
if dw1 is None:
dw1, db1, dw2, db2 = back_prop(z1, a1, z2, a2, w2, x, y)
else:
_dw1, _db1, _dw2, _db2 = back_prop(z1, a1, z2, a2, w2, x, y)
dw1 += _dw1
db1 += _db1
dw2 += _dw2
db2 += _db2
y_hat.append(np.argmax(a2))
w1, b1, w2, b2 = update_param(w1, b1, w2, b2, dw1, db1, dw2, db2, lr*(1/m))
print(accuracy_score(Y, y_hat), end='\r')
return w1, b1, w2, b2
模型不是训练。它停留在特定的错误值上。我已经检查了矩阵和数组的尺寸,这些都是正确的。问题一定是在模型的数学上,但不幸的是我不能弄清楚。
1条答案
按热度按时间q35jwt9p1#
我可以在你的代码中看到一些潜在的问题:
back_prop
函数中使用了tanh
,而你应该使用dtanh
。a2 - one_hot_y
,它不能作为损失函数,因为(在非常高的级别上)它有一个符号,而你只需要dz2.dot(a1.T)
的符号,否则第一行的符号经常会抵消第二行的符号。一个可能有效的损失函数的例子是简单地平方你的当前值,但是你最好使用分类交叉熵损失之类的东西。one_hot
函数与您的tanh
激活函数不匹配:tanh
产生-1和1(其间有平滑过渡)。one_hot
函数生成0和1。你应该使用softmax而不是tanh作为你的激活函数。你可以修改你的one_hot
函数,但通常不会这样做。另一个潜在的问题是初始权重可能太大:你的很多计算结果都是绝对值很大的数字。也许你可以试着减轻你的初始体重。在文献中还存在用于初始化权重的特定方法。看看这些可能会有好处。此外,您可能会更好地将权重初始化为均值为零,因此您的网络也具有负权重。
我现在不能实际测试它以确保它工作正常,但如果这不能解决问题,我将尝试稍后再看一看。