正则化Numpy线性回归

t0ybt7op  于 2024-01-08  发布在  其他
关注(0)|答案(2)|浏览(187)

我没有看到我的正则化线性回归代码有什么问题。Unregularized我只有这个,我有理由肯定是正确的:

import numpy as np

def get_model(features, labels):
    return np.linalg.pinv(features).dot(labels)

字符串
下面是我的正则化解决方案的代码,我没有看到它有什么问题:

def get_model(features, labels, lamb=0.0):
    n_cols = features.shape[1]
    return linalg.inv(features.transpose().dot(features) + lamb * np.identity(n_cols))\
            .dot(features.transpose()).dot(labels)


对于lamb的默认值0. 0,我的意图是它应该给予与(正确的)未正则化版本相同的结果,但实际上差异相当大。
有人知道问题出在哪吗?

qmb5sa22

qmb5sa221#

问题是:
features.transpose().dot(features)可能是不可逆的。根据文献,numpy.linalg.inv仅适用于满秩矩阵。然而,(非零)正则化项总是使方程非奇异。
顺便说一下,你对实现的看法是对的。但是它效率不高。解决这个方程的一个有效方法是最小二乘法。
np.linalg.lstsq(features, labels)可以为np.linalg.pinv(features).dot(labels)工作。
一般来说,你可以这样做

def get_model(A, y, lamb=0):
    n_col = A.shape[1]
    return np.linalg.lstsq(A.T.dot(A) + lamb * np.identity(n_col), A.T.dot(y))

字符串

holgip5t

holgip5t2#

由于接受的答案使用np.linalg.lstsq,它在幕后使用SVD(如果我记得LAPACK dgelsd是如何正确工作的),我将展示另一种使用SVD进行正则化回归的方法:

def lsqL2(A, y, lamb=1e-10):
    U,S,Vt = np.linalg.svd(A, full_matrices=False)
    return Vt.T@((U.T@y)*(S/(S**2+lamb)))

字符串
注解中提到使用np.linalg.solve,因此这里是另一个求解法方程的方法

def lsqL2b(A, y, lamb=1e-10):
    AtA = A.T@A
    AtA.flat[::AtA.shape[0]+1] += lamb
    return np.linalg.solve(AtA, A.T@y)

相关问题