gcc 搜索14个整数的数组,构建掩码,并使用氖在ARMv8a上返回匹配项

wfveoks0  于 2022-11-12  发布在  其他
关注(0)|答案(1)|浏览(198)

对于我的开源项目cachegrand,我们正在实现AARCH 64支持,虽然大部分端口已经完成,但我们正在整理一个特性,使用氖指令执行加速数组搜索。
我们使用的逻辑非常简单:

  • 在input中,有一个包含14个uint 32元素的数组、要查找的值和用于忽略某些匹配项的掩码
  • 代码必须查找与特定uint 32匹配的任何值
  • 生成位掩码
  • 位掩码的最低有效位与数组的开始匹配
  • 然后,位掩码与跳过索引掩码一起&
  • 然后对尾随零进行计数以确定第一次出现的索引

这是一个非常罕见的情况下,跳过索引掩码实际上是使用,我会说,99.9%的情况下将是零。
我已经提出了以下实现,但我没有ARMv 8氖指令的经验,感觉有点笨重,特别是我想知道是否有一种方法可以使它更快和/或更好。
仅供参考,目前该代码仅使用GCC编译。

uint8_t hashtable_mcmp_support_hash_search_armv8a_neon_14(
        uint32_t hash,
        volatile uint32_t* hashes,
        uint32_t skip_indexes_mask) {
    uint32x4_t tmp;
    uint32_t compacted_result_mask = 0;
    uint32_t skip_indexes_mask_inv = ~skip_indexes_mask;
    static const int32x4_t shift = {0, 1, 2, 3};

    uint32x4_t cmp_vector = vdupq_n_u32(hash);

    uint32x4_t ring_vector_0_3 = vld1q_u32((hashtable_hash_half_t*)hashes + 0);
    uint32x4_t cmp_vector_0_3 = vceqq_u32(ring_vector_0_3, cmp_vector);
    tmp = vshrq_n_u32(cmp_vector_0_3, 31);
    compacted_result_mask |=  vaddvq_u32(vshlq_u32(tmp, shift)) << 0;

    uint32x4_t ring_vector_4_7 = vld1q_u32((hashtable_hash_half_t*)hashes + 4);
    uint32x4_t cmp_vector_4_7 = vceqq_u32(ring_vector_4_7, cmp_vector);
    tmp = vshrq_n_u32(cmp_vector_4_7, 31);
    compacted_result_mask |=  vaddvq_u32(vshlq_u32(tmp, shift)) << 4;

    uint32x4_t ring_vector_8_11 = vld1q_u32((hashtable_hash_half_t*)hashes + 8);
    uint32x4_t cmp_vector_8_11 = vceqq_u32(ring_vector_8_11, cmp_vector);
    tmp = vshrq_n_u32(cmp_vector_8_11, 31);
    compacted_result_mask |=  vaddvq_u32(vshlq_u32(tmp, shift)) << 8;

    uint32x4_t ring_vector_10_13 = vld1q_u32((hashtable_hash_half_t*)hashes + 10);
    uint32x4_t cmp_vector_10_13 = vceqq_u32(ring_vector_10_13, cmp_vector);
    tmp = vshrq_n_u32(cmp_vector_10_13, 31);
    compacted_result_mask |=  vaddvq_u32(vshlq_u32(tmp, shift)) << 10;

    return __builtin_ctz(compacted_result_mask & skip_indexes_mask_inv);
}

仅供参考,此处为AVX 2代码

static inline uint8_t hashtable_mcmp_support_hash_search_avx2_14(
        uint32_t hash,
        volatile uint32_t* hashes,
        uint32_t skip_indexes_mask) {
    uint32_t compacted_result_mask = 0;
    uint32_t skip_indexes_mask_inv = ~skip_indexes_mask;
    __m256i cmp_vector = _mm256_set1_epi32(hash);

    // The second load, load from the 6th uint32 to the 14th uint32, _mm256_loadu_si256 always loads 8 x uint32
    for(uint8_t base_index = 0; base_index < 12; base_index += 6) {
        __m256i ring_vector = _mm256_loadu_si256((__m256i*) (hashes + base_index));
        __m256i result_mask_vector = _mm256_cmpeq_epi32(ring_vector, cmp_vector);

        // Uses _mm256_movemask_ps to reduce the bandwidth
        compacted_result_mask |= (uint32_t)_mm256_movemask_ps(_mm256_castsi256_ps(result_mask_vector)) << (base_index);
    }

    return _tzcnt_u32(compacted_result_mask & skip_indexes_mask_inv);
}

另一个附带问题是,您认为支持SVE 2指令值得吗?特别是考虑到这是一个非常简单的操作,而且看起来可能不强制支持256位寄存器(这可能是在此特定环境中使用SVE 2的最大好处)

xdnvmnnf

xdnvmnnf1#

每个布尔值不需要32位:在进行进一步操作之前,通过vuzp1vomovn将它们尽快缩小到8位。

uint8_t hashtable_mcmp_support_hash_search_armv8a_neon_14(
        uint32_t hash,
        volatile uint32_t* hashes,
        uint32_t skip_indexes_mask)
{

    uint16x8_t tmp16a, tmp16b;
    uint8x8_t tmp8a, tmp8b;
    uint32_t tmp;
    static const uint8x8_t mask = {1, 2, 4, 8, 16, 32, 64, 128};
    uint32x4_t cmp_vector = vdupq_n_u32(hash);

    uint32x4x3_t ring_vector_0_11 = vld1q_u32_x3((uint32_t *)hashes);
    uint32x4_t ring_vector_10_13 = vld1q_u32((uint32_t *)hashes+10);

    ring_vector_0_11.val[0] = vceqq_u32(ring_vector_0_11.val[0], cmp_vector);
    ring_vector_0_11.val[1] = vceqq_u32(ring_vector_0_11.val[1], cmp_vector);
    ring_vector_0_11.val[2] = vceqq_u32(ring_vector_0_11.val[2], cmp_vector);
    ring_vector_10_13 = vceqq_u32(ring_vector_10_13, cmp_vector);

    tmp16a = vuzp1q_u16(ring_vector_0_11.val[0], ring_vector_0_11.val[1]);
    tmp16b = vuzp1q_u16(ring_vector_0_11.val[2], ring_vector_10_13);

    tmp8a = vmovn_u16(tmp16a);
    tmp8b = vmovn_u16(tmp16b);

    tmp8a = vand_u8(tmp8a, mask);
    tmp8b = vand_u8(tmp8b, mask);

    tmp = (uint32_t)vaddv_u8(tmp8a) | (uint32_t)(vaddv_u8(tmp8b)<<8);

    return __builtin_ctz(tmp &~ skip_indexes_mask);
}

而且我不认为sve会带来有意义的性能提升,因为性能在最后或多或少会受到影响(vaddv,尤其是向arm寄存器的传输)
如果您要处理数千个14入口数组,您应该考虑重新设计您的函数以写入8位数组,而不是每次都返回arm寄存器。这将消除由氖到arm传输引起的最耗时的流水线危险。

#include <arm_neon.h>
#include <arm_acle.h>

void hashtable_mcmp_support_hash_search_armv8a_neon_14_b(
        uint8_t *pDst,
        uint32_t hash,
        volatile uint32_t* hashes,
        uint32_t skip_indexes_mask, uint32_t number_of_arrays)
{

    uint16x8_t tmp16a, tmp16b;
    uint16x4_t tmp;
    uint8x8_t tmp8a, tmp8b;
    static const uint8x8_t mask = {128, 64, 32, 16, 8, 4, 2, 1};
    uint32x4_t cmp_vector = vdupq_n_u32(hash);

    skip_indexes_mask = __rbit(skip_indexes_mask)>>16;
    uint16x4_t index_mask = vdup_n_u16((uint16_t) skip_indexes_mask);
    uint32x4x4_t ring_vector;

    while (number_of_arrays--)
    {
        ring_vector = vld1q_u32_x4((uint32_t *)hashes);
        hashes += 16;

        ring_vector.val[0] = vceqq_u32(ring_vector.val[0], cmp_vector);
        ring_vector.val[1] = vceqq_u32(ring_vector.val[1], cmp_vector);
        ring_vector.val[2] = vceqq_u32(ring_vector.val[2], cmp_vector);
        ring_vector.val[3] = vceqq_u32(ring_vector.val[3], cmp_vector);

        tmp16a = vuzp1q_u16(vreinterpretq_u16_u32(ring_vector.val[0]), vreinterpretq_u16_u32(ring_vector.val[1]));
        tmp16b = vuzp1q_u16(vreinterpretq_u16_u32(ring_vector.val[2]), vreinterpretq_u16_u32(ring_vector.val[3]));

        tmp8a = vmovn_u16(tmp16a);
        tmp8b = vmovn_u16(tmp16b);

        tmp8a = vand_u8(tmp8a, mask);
        tmp8b = vand_u8(tmp8b, mask);

        tmp8a[1] = vaddv_u8(tmp8a);
        tmp8a[0] = vaddv_u8(tmp8b);

        tmp = vbic_u16(vreinterpret_u16_u8(tmp8a), index_mask);
        tmp = vclz_u16(tmp);

        vst1_lane_u8(pDst++,vreinterpret_u8_u16(tmp), 0);
    }
}

以上是一个“改进”版本

  • 它假定数组位于连续内存中,并具有8字节填充,这该高速缓存效率是优选的,除非内存需求是一个问题。
  • 它不返回8位结果,而是直接将结果写入存储器,避免了氖到臂传输引起的流水线危险。
  • 它仍然存在vaddv延迟(8个周期)。您可以展开循环,使其每次迭代处理2个甚至4个数组,以隐藏延迟。

相关问题