x86 Intrinsic:优化复杂浮点数的矩阵乘法

t5fffqht  于 2024-01-06  发布在  其他
关注(0)|答案(1)|浏览(105)

下面的代码用于将复数浮点矩阵(分离的真实的,Imag)乘以浮点矩阵。
我很确定它可以通过重新排序代码来优化,因为加载,存储和乘法的延迟。你能告诉我是否有规则如何优化代码来处理这种延迟吗?

/***************************************************************************************/
void CVector::MatrixMultiply(float* pReA, float* pImA,
                            float* pTranB,
                            float* pOutRe, float* pOutIm,
                            uint32_t RowsA, uint32_t ColsA,
                            uint32_t RowsB, uint32_t ColsB)
{
    float *pSrcReA;
    float* pSrcImA;
    float* pSrcB;
    float* pDstRe = pOutRe;
    float* pDstIm = pOutIm;
    float* pRowReA, * pRowImA;

    __m256 ReSum, ImSum, VecReA, VecImA;
    __m256 *pAvec, *pBvec;
    __m256 VecA, VecB;
    __m128 Low, High, Sum128;
    __m128 Zero128 = _mm_set_ps1(0);

    uint32_t Offset;

    for (int i = 0; i < RowsA; i++)
    {
        Offset = ColsA * i;
        pSrcReA = pReA + Offset;
        pSrcImA = pImA + Offset;
        for (int j = 0; j < ColsB; j++)
        {
            ReSum = _mm256_set1_ps(0);
            ImSum = ReSum;
            pRowReA = pSrcReA;
            pRowImA = pSrcImA;
            pSrcB = pTranB + RowsB * j;

            for (int k = 0; k < (ColsA >> 3); k++)
            {
                VecReA = _mm256_load_ps((float*)pRowReA);
                VecImA = _mm256_load_ps((float*)pRowImA);
                VecB = _mm256_load_ps((float*)pSrcB);

                ReSum = _mm256_fmadd_ps (VecReA, VecB, ReSum);
                ImSum = _mm256_fmadd_ps (VecImA, VecB, ImSum);

                pRowReA += 8;
                pRowImA += 8;
            }

            Low = _mm256_extractf128_ps(ReSum, 0);
            High = _mm256_extractf128_ps(ReSum, 1);
            Sum128 = _mm_add_ps(Low, High);
            Sum128 = _mm_hadd_ps(Sum128, Zero128);
            Sum128 = _mm_hadd_ps(Sum128, Zero128);
            *pDstRe = _mm_cvtss_f32(Sum128);

            Low = _mm256_extractf128_ps(ImSum, 0);
            High = _mm256_extractf128_ps(ImSum, 1);
            Sum128 = _mm_add_ps(Low, High);
            Sum128 = _mm_hadd_ps(Sum128, Zero128);
            Sum128 = _mm_hadd_ps(Sum128, Zero128);
            *pDstIm = _mm_cvtss_f32(Sum128);

            pDstRe++;
            pDstIm++;
        }
    }
}

字符串

llmtgqce

llmtgqce1#

代码的最大性能问题是(在大多数CPU上)fmadd的延迟为4-5个周期,但吞吐量倒数为0.5(即,可以同时执行两个独立的FMA)-- source:uops.info
要获得完整的吞吐量,需要执行8(或在某些CPU上10)内部循环中的独立FMA操作。例如,有8个独立的ReSum0..3ImSum0..3,并通过8个{VecReA, VecImA} * VecB0..3乘积累加到它们。我不写出来,因为我不完全理解你的代码,例如,为什么在k循环中不递增pSrcB?你确定ColsA==RowsB和它们是8的倍数吗?

相关问题