c++ 使用SSE的水平最小值和最大值

zaqlnxep  于 2023-01-15  发布在  其他
关注(0)|答案(3)|浏览(251)

我有一个函数使用SSE来做很多事情,分析器显示我用来计算水平最小值和最大值的代码部分占用了大部分时间。
我一直在使用以下实现作为最低限度的例子:

static inline int16_t hMin(__m128i buffer) {
    buffer = _mm_min_epi8(buffer, _mm_shuffle_epi8(buffer, m1));
    buffer = _mm_min_epi8(buffer, _mm_shuffle_epi8(buffer, m2));
    buffer = _mm_min_epi8(buffer, _mm_shuffle_epi8(buffer, m3));
    buffer = _mm_min_epi8(buffer, _mm_shuffle_epi8(buffer, m4));
    return ((int8_t*) ((void *) &buffer))[0];
}

我需要计算16个1字节整数的最小值和最大值,如您所见。
任何好的建议都非常感谢:)
谢谢

lb3vh1jj

lb3vh1jj1#

SSE 4.1有一条指令几乎可以做你想做的事情。它的名字是PHMINPOSUW,C/C++内部函数是_mm_minpos_epu16。它被限制为16位无符号值,并且不能给予最大值,但是这些问题可以很容易地解决。
1.如果你需要找到非负字节的最小值,什么也不做。如果字节可能是负的,给每个字节加上128。如果你需要最大值,从127减去每个字节。
1.使用_mm_srli_pi16_mm_shuffle_epi8,然后使用_mm_min_epu8来获取某个XMM寄存器的偶数字节中的8个成对最小值和奇数字节中的零。(这些零由移位/混洗指令产生,并且应该保留在_mm_min_epu8之后的位置)。
1.使用_mm_minpos_epu16查找这些值中的最小值。
1.用_mm_cvtsi128_si32提取得到的最小值。
1.撤消步骤1的效果以获得原始字节值。
下面是返回最多16个有符号字节的示例:

static inline int16_t hMax(__m128i buffer)
{
    __m128i tmp1 = _mm_sub_epi8(_mm_set1_epi8(127), buffer);
    __m128i tmp2 = _mm_min_epu8(tmp1, _mm_srli_epi16(tmp1, 8));
    __m128i tmp3 = _mm_minpos_epu16(tmp2);
    return (int8_t)(127 - _mm_cvtsi128_si32(tmp3));
}
w6lpcovy

w6lpcovy2#

我建议作出两项修改:

  • ((int8_t*) ((void *) &buffer))[0]替换为_mm_cvtsi128_si32
  • _mm_shuffle_epi8替换为_mm_shuffle_epi32/_mm_shufflelo_epi16,它们在最新的AMD处理器和Intel Atom上具有更低的延迟,并将保存内存加载操作:
static inline int16_t hMin(__m128i buffer)
{
    buffer = _mm_min_epi8(buffer, _mm_shuffle_epi32(buffer, _MM_SHUFFLE(3, 2, 3, 2)));
    buffer = _mm_min_epi8(buffer, _mm_shuffle_epi32(buffer, _MM_SHUFFLE(1, 1, 1, 1)));
    buffer = _mm_min_epi8(buffer, _mm_shufflelo_epi16(buffer, _MM_SHUFFLE(1, 1, 1, 1)));
    buffer = _mm_min_epi8(buffer, _mm_srli_epi16(buffer, 8));
    return (int8_t)_mm_cvtsi128_si32(buffer);
}
lokaqttq

lokaqttq3#

这是一个没有shuffle实现,由于某种原因,Shuffle在AMD5000Ryzen7上运行速度很慢

float max_elem3() const {
        __m128 a = _mm_unpacklo_ps(mm, mm); // x x y y
        __m128 b = _mm_unpackhi_ps(mm, mm); // z z w w
        __m128 c = _mm_max_ps(a, b); // ..., max(x, z), ..., ...
        Vector4 res = _mm_max_ps(mm, c); // ..., max(y, max(x, z)), ..., ...
        return res.y;
    }

    float min_elem3() const {
        __m128 a = _mm_unpacklo_ps(mm, mm); // x x y y
        __m128 b = _mm_unpackhi_ps(mm, mm); // z z w w
        __m128 c = _mm_min_ps(a, b); // ..., min(x, z), ..., ...
        Vector4 res = _mm_min_ps(mm, c); // ..., min(y, min(x, z)), ..., ...
        return res.y;
    }

相关问题