Python和numpy中的简单梯度下降

disho6za  于 2023-05-21  发布在  Python
关注(0)|答案(1)|浏览(202)

我试图在python中实现一个简单的梯度下降,只使用numpy,但是缺少了一些东西,我找不到它。我过去又做过一次,但不知何故,我一直盯着这个问题,过去的一天,没有能够使它的工作。

def init_param():
    w1 = np.random.rand(10, 64)
    b1 = np.random.rand(10, 1)
    w2 = np.random.rand(10, 10)
    b2 = np.random.rand(10, 1)
    return w1, b1, w2, b2

def ReLU(z):
    return np.maximum(0, z)

def dReLU(z):
    return z > 0

def dtanh(z):
    return 1 - np.tanh(z)**2

def tanh(z):
    return np.tanh(z)

def forward_prop(w1, b1, w2, b2, x):
    z1 = w1.dot(x).reshape(-1,1) + b1
    a1 = ReLU(z1)
    z2 = w2.dot(a1) + b2
    a2 = tanh(z2)
    return z1, a1, z2, a2

def one_hot(y):
    one_hot_y = np.zeros((y.size, 10))
    one_hot_y[np.arange(y.size), y] = 1
    return one_hot_y.T

def back_prop(z1, a1, z2, a2, w2, x, y):
    one_hot_y = one_hot(y)
    
    dz2 = (a2 - one_hot_y) * tanh(z2)
    dw2 = dz2.dot(a1.T)
    db2 = np.sum(dz2, 1, keepdims=True)

    dz1 = w2.T.dot(dz2) * dReLU(z1)
    dw1 = dz1.dot(x.T)
    db1 = np.sum(dz1, 1, keepdims=True)

    return dw1, db1, dw2, db2

def update_param(w1, b1, w2, b2, dw1, db1, dw2, db2, lr):
    w1 = w1 - lr*dw1
    b1 = b1 - lr*db1
    w2 = w2 - lr*dw2
    b2 = b2 - lr*db2
    return w1, b1, w2, b2

def gradient_descent(X, Y, max_iter, lr):
    w1, b1, w2, b2 = init_param()
    m = len(X)
    for _ in range(max_iter):
        y_hat = []
        dw1, db1, dw2, db2 = None, None, None, None
        for x, y in zip(X, Y):
            x = x.reshape(-1,1)
            z1, a1, z2, a2 = forward_prop(w1, b1, w2, b2, x)
            if dw1 is None:
                dw1, db1, dw2, db2 = back_prop(z1, a1, z2, a2, w2, x, y)
            else:
                _dw1, _db1, _dw2, _db2 = back_prop(z1, a1, z2, a2, w2, x, y)
                dw1 += _dw1
                db1 += _db1
                dw2 += _dw2
                db2 += _db2
            y_hat.append(np.argmax(a2))
        w1, b1, w2, b2 = update_param(w1, b1, w2, b2, dw1, db1, dw2, db2, lr*(1/m))
        print(accuracy_score(Y, y_hat), end='\r')
    return w1, b1, w2, b2

模型不是训练。它停留在特定的错误值上。我已经检查了矩阵和数组的尺寸,这些都是正确的。问题一定是在模型的数学上,但不幸的是我不能弄清楚。

q35jwt9p

q35jwt9p1#

我可以在你的代码中看到一些潜在的问题:

  • 你在你的back_prop函数中使用了tanh,而你应该使用dtanh
  • 你不用损失函数。你只需要使用a2 - one_hot_y,它不能作为损失函数,因为(在非常高的级别上)它有一个符号,而你只需要dz2.dot(a1.T)的符号,否则第一行的符号经常会抵消第二行的符号。一个可能有效的损失函数的例子是简单地平方你的当前值,但是你最好使用分类交叉熵损失之类的东西。
  • 您的one_hot函数与您的tanh激活函数不匹配:tanh产生-1和1(其间有平滑过渡)。one_hot函数生成0和1。你应该使用softmax而不是tanh作为你的激活函数。你可以修改你的one_hot函数,但通常不会这样做。

另一个潜在的问题是初始权重可能太大:你的很多计算结果都是绝对值很大的数字。也许你可以试着减轻你的初始体重。在文献中还存在用于初始化权重的特定方法。看看这些可能会有好处。此外,您可能会更好地将权重初始化为均值为零,因此您的网络也具有负权重。
我现在不能实际测试它以确保它工作正常,但如果这不能解决问题,我将尝试稍后再看一看。

相关问题