c++ 优化写入/更新数组的新值

8ehkhllq  于 2023-04-08  发布在  其他
关注(0)|答案(2)|浏览(126)

我正在编写一个代码,计算所有粒子之间的距离,并计算它们之间的力。代码如下:

void calcCoulombicRepForces(double xc[], double yc[], double zc[], double fx[], double fy[], double fz[], int Np, double q1q2){
    double x1,y1,z1,xhat,yhat,zhat,dist,distcube;
    for (int i  = 0; i < Np; i++){
        double sumX = 0.0; double sumY = 0.0; double sumZ = 0.0; 
        x1 = xc[i];
        y1 = yc[i];
        z1 = zc[i];
        for (int j = 0; j < Np; j++){
            xhat = xc[j]-x1;
            yhat = yc[j]-y1;
            zhat = zc[j]-z1;
            dist = sqrt(xhat*xhat + yhat*yhat + zhat*zhat);
            distcube = dist*dist*dist;
            sumX += xhat/distcube;
            sumY += yhat/distcube;
            sumZ += zhat/distcube;
        }
        fx[i] = -q1q2*sumX;
        fy[i] = -q1q2*sumY;
        fz[i] = -q1q2*sumZ;
    }
}

xc、yc和zc是坐标。fx、fy和fz是力。这里每个数组的大小为Np ~ 58000个元素。
代码工作正常,但该函数所花费的99%的时间都花在写入/更新数组内部的值上,即这三行

fx[i] = -q1q2*sumX;
    fy[i] = -q1q2*sumY;
    fz[i] = -q1q2*sumZ;

我在某个地方读到它可能是因为它试图从更高级别的缓存或RAM访问数组fx,fy和fz。如果是这样,有没有一种方法可以优化这三行,使数组写入/更新变得更快?如果没有,我在这里做错了什么?如果我错过了一些我应该包括的东西,请让我知道。

7fyelxc5

7fyelxc51#

你的基准测试方法只会产生无意义的结果。
如果你去掉了这三个赋值,那么函数就没有任何副作用了,它也不会返回任何东西,所以编译器会优化掉整个函数体。
这就是为什么显然所有的时间都花在了三个赋值上。

rkue9o1l

rkue9o1l2#

你计算距离(和其他值)的次数太多了。你应该只为每对点计算一次。这段代码应该会对你有所帮助。它应该会在~2x时间内运行得更快,因为你的函数处理的是n * (n -1)对,而我的函数将处理(n * (n -1) /2)对(所有唯一的对只进动一次)。
此外,正如有人在评论中提到的,制作一个结构体来保存这些xyz和每个点的剩余值,然后用这些结构体的数组来加速代码。

void calcCoulombicRepForces(double xc[], double yc[], double zc[], double fx[], double fy[], double fz[], int Np, double q1q2){
    double x1,y1,z1,xhat,yhat,zhat,dist,distcube;
    for (int i  = 0; i < Np; i++){
        
        x1 = xc[i];
        y1 = yc[i];
        z1 = zc[i];
        for (int j = i + 1; j < Np; j++){
            xhat = xc[j]-x1;
            yhat = yc[j]-y1;
            zhat = zc[j]-z1;
            dist = sqrt(xhat*xhat + yhat*yhat + zhat*zhat);
            distcube = dist*dist*dist;

            double xVal = xhat/distcube;
            double yVal =yhat/distcube;
            double zVal =zhat/distcube;

            fx[i]+= xVal;
            fy[i]+= yVal;
            fz[i]+= zVal;

            fx[j] -= xVal;
            fy[j] -= yVal;
            fz[j] -= zVal;
        }
    }
    for (int i  = 0; i < Np; i++){
        fx[i] = -q1q2*fy[i];
        fy[i] = -q1q2*fy[i];
        fz[i] = -q1q2*fz[i];
    }
}

相关问题