pytorch 找到aX1+bX2+cX3+.....mXn=d的N个解,Python

hwamh0ep  于 2023-04-06  发布在  Python
关注(0)|答案(2)|浏览(116)

我怎样才能找到任意N个(比如30个)aX1+bX2+cX3+.....mXn=d的解(其中n,也称为这个空间的维数,可以是大于2的整数,并且0〈=Xn〈=1)。

weights = torch.tensor([a,b,c....m])
# X is a  tensor with the same size of w
# What I want do is to find a tensor X that qualified for:
(weights*X).sum() = d

当维数为2时,我随机生成一个Tensor,如下所示:

u = 0.5
t = torch.rand(2)
if t*weights == d:
   return t

当维数大于2时,这个方法变得非常慢。有更好的解决方案吗?

jyztefdp

jyztefdp1#

随机设置权重中除m以外的所有值,然后根据这些值只找到m,怎么样?
验证码:

import torch

N,d = 10, 10

# define X, W except m
X = torch.rand(N)
W = torch.rand(N-1)*4

# find m based on other weights
m = ((d - torch.dot(X[:N-1], W))/X[N-1]).unsqueeze(dim=0)
W = torch.cat((W, m))

print(X)
print(W)
print(torch.dot(X, W))

结果:

tensor([0.1062, 0.0361, 0.9462, 0.0534, 0.0591, 0.5729, 0.1521, 0.9087, 0.1210,
        0.7654])
tensor([2.3289, 0.4069, 3.8243, 2.2443, 1.1903, 0.6269, 0.2839, 3.9864, 0.4654,
        2.4148])
tensor(10.)
vsikbqxv

vsikbqxv2#

我自己用线性代数找到了一个简单的解,这个解适合空间[0,1]^N

class A:
    def __init__(dim,weights):
        self.dim = dim
        self.weights = weights

    def gen_solution(self)->torch.Tensor:
        w = self.weights
        s = self.target

        v = torch.rand(self.dim)
        v = v -(w*v).sum() * w / (w **2).sum()
        vmin = v.min()
        if vmin < -s:
            v = v * s / (-vmin)
        vmax = v.max()
        if vmax > (1-s):
            v = v * (1-s) / vmax
        solution =  v + s
        return solution

相关问题