C++:如何把运行时的枚举变量作为模板实参,传给按枚举值特化的结构体模板

doinxwow  于 12个月前  发布在  其他
关注(0)|答案(2)|浏览(145)

我为枚举 DataType 设计了一个按枚举值进行模板特化的结构体,如下:

// TypeTrait<type>::size is the byte width of the element type that a DataType
// enumerator stands for. Only the primary template is declared (never defined),
// so naming an enumerator without a specialization below, e.g. DATA_TYPE_UNKOWN,
// is a compile error (incomplete type).
template<DataType type>
struct TypeTrait;

// Common base for the specializations: exposes sizeof(Elem) as the uint32_t
// constant `size`.
template<typename Elem>
struct TypeTraitSizeOf {
    static constexpr uint32_t size = sizeof(Elem);
};

template<> struct TypeTrait<DATA_TYPE_INT8>   : TypeTraitSizeOf<int8_t>   {};
template<> struct TypeTrait<DATA_TYPE_INT16>  : TypeTraitSizeOf<int16_t>  {};
// FP16 is measured with uint16_t, i.e. as 16-bit storage.
template<> struct TypeTrait<DATA_TYPE_FP16>   : TypeTraitSizeOf<uint16_t> {};
template<> struct TypeTrait<DATA_TYPE_UINT8>  : TypeTraitSizeOf<uint8_t>  {};
template<> struct TypeTrait<DATA_TYPE_UINT16> : TypeTraitSizeOf<uint16_t> {};
template<> struct TypeTrait<DATA_TYPE_INT32>  : TypeTraitSizeOf<int32_t>  {};
template<> struct TypeTrait<DATA_TYPE_UINT32> : TypeTraitSizeOf<uint32_t> {};
template<> struct TypeTrait<DATA_TYPE_FP32>   : TypeTraitSizeOf<float>    {};

字符串
enum DataType是这样定义的:

// Element-type tags with fixed numeric values.
// DATA_TYPE_UNKOWN (spelling kept as-is) sits right after the last real type,
// so its value, 8, equals the number of known types.
enum DataType {
    DATA_TYPE_INT8   = 0,
    DATA_TYPE_INT16  = 1,
    DATA_TYPE_FP16   = 2,
    DATA_TYPE_UINT8  = 3,
    DATA_TYPE_UINT16 = 4,
    DATA_TYPE_INT32  = 5,
    DATA_TYPE_UINT32 = 6,
    DATA_TYPE_FP32   = 7,
    DATA_TYPE_UNKOWN = 8
};


我想把一个 DataType 类型的变量(成员 type_)作为模板实参传给 TypeTrait,像这样:

// Question's attempt (elided parts shown as `...`): look up the element size
// for a DataType stored in a member variable.
class Test {
public:
    ...

    void Convert() {
        ...
        // Does not compile: a non-type template argument must be a constant
        // expression, but type_ is a data member read through `this` at run
        // time ("use of 'this' in a constant expression").
        uint32_t size = TypeTrait<type_>::size;
        ...
    }
private:
    // Element type, only known at run time.
    DataType type_;
};


当我这样做时,编译程序时会出现问题:

main.cc: In member function ‘void Test::Convert()’:
main.cc:63:35: error: use of ‘this’ in a constant expression
   63 |         uint32_t size = TypeTrait<type_>::size;
      |                                   ^~~~~
main.cc:63:40: error: use of ‘this’ in a constant expression
   63 |         uint32_t size = TypeTrait<type_>::size;
      |                                        ^
main.cc:63:35: note: in template argument for type ‘DataType’
   63 |         uint32_t size = TypeTrait<type_>::size;
      |                                   ^~~~~                              ^


我尝试了很多方法,比如将type_转换为const值,如下所示:

// Still fails: `const` only makes the copy read-only. Its initializer is a
// run-time value, so it is not a constant expression and cannot be a
// template argument.
const DataType dataType = type_;
uint32_t size = TypeTrait<dataType>::size;


结果又出现了下面的错误(报错中的变量名是 type、初始值来自 GetType(),与上面的代码略有出入,但问题相同):

main.cc: In member function ‘void Test::Convert()’:
main.cc:63:39: error: the value of ‘type’ is not usable in a constant expression
   63 |         uint32_t size = TypeTrait<type>::size;
      |                                       ^
main.cc:62:24: note: ‘type’ was not initialized with a constant expression
   62 |         const DataType type = GetType();
      |                        ^~~~
main.cc:63:39: note: in template argument for type ‘DataType’
   63 |         uint32_t size = TypeTrait<type>::size;
      |


我知道,如果像下面这样直接传入枚举常量,程序就能正常编译:

// Compiles: an enumerator is a constant expression, so it is a valid template argument.
uint32_t size = TypeTrait<DataType::DATA_TYPE_UINT32>::size;


我找不到解决办法,只能用 switch case 来处理,但这并不是我想要的——我希望把代码里的 switch case 去掉。下面是需要重构的代码:

// byteSize = elemCnt * (byte width of the run-time type dataType_).
// Values with no case here (e.g. DATA_TYPE_UNKOWN) leave byteSize untouched.
switch (dataType_) {
    case DATA_TYPE_INT8:   byteSize = elemCnt * sizeof(int8_t);   break;
    case DATA_TYPE_INT16:  byteSize = elemCnt * sizeof(int16_t);  break;
    case DATA_TYPE_FP16:   byteSize = elemCnt * sizeof(uint16_t); break;
    case DATA_TYPE_UINT8:  byteSize = elemCnt * sizeof(uint8_t);  break;
    case DATA_TYPE_UINT16: byteSize = elemCnt * sizeof(uint16_t); break;
    case DATA_TYPE_INT32:  byteSize = elemCnt * sizeof(int32_t);  break;
    case DATA_TYPE_UINT32: byteSize = elemCnt * sizeof(uint32_t); break;
    case DATA_TYPE_FP32:   byteSize = elemCnt * sizeof(float);    break;
}

rjjhvcjd

rjjhvcjd1#

下面是一种在运行时获取 TypeTrait<type>::size 的方法,无需 switch(需要C++17):

uint32_t datatypeSize(DataType type) {
    return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
        return ((static_cast<std::size_t>(type) == Is ? TypeTrait<static_cast<DataType>(Is)>::size : 0) + ...);
    }(std::make_index_sequence<DATA_TYPE_UNKOWN>{});
}

字符串
Demo
另一种写法,使用 std::array(来自 Jarod42 的评论,C++17):

uint32_t datatypeSize(DataType type) {
    return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
        return std::array{TypeTrait<static_cast<DataType>(Is)>::size...}[type];
    }(std::make_index_sequence<DATA_TYPE_UNKOWN>{});
}

eaf3rand

eaf3rand2#

再补充一个不使用模板特化的答案。
我认为这种实现方式更加方便、简洁:

// X-macro with one entry per known DataType: XX(enumerator, byte size, name).
// Expanding it with a three-argument macro generates per-type code such as
// switch cases. Each name string must match its enumerator, and every
// enumerator except DATA_TYPE_UNKOWN needs an entry (UINT32 and FP32 were
// missing, and every name read "DATA_TYPE_FP32").
#define DefineHelper(XX)\
  XX(DATA_TYPE_INT8, sizeof(int8_t), "DATA_TYPE_INT8")\
  XX(DATA_TYPE_INT16, sizeof(int16_t), "DATA_TYPE_INT16")\
  XX(DATA_TYPE_FP16, sizeof(uint16_t), "DATA_TYPE_FP16")\
  XX(DATA_TYPE_UINT8, sizeof(uint8_t), "DATA_TYPE_UINT8")\
  XX(DATA_TYPE_UINT16, sizeof(uint16_t), "DATA_TYPE_UINT16")\
  XX(DATA_TYPE_INT32, sizeof(int32_t), "DATA_TYPE_INT32")\
  XX(DATA_TYPE_UINT32, sizeof(uint32_t), "DATA_TYPE_UINT32")\
  XX(DATA_TYPE_FP32, sizeof(float), "DATA_TYPE_FP32")

// Returns the byte size recorded for `e` in DefineHelper, or 0 if `e` has no
// entry there. The case labels are generated by expanding DefineHelper with a
// local macro that emits `case <enumerator>: return <size>;`.
int32_t GetDataTypeSize(DataType e) {
#define EMIT_SIZE_CASE(enumerator, bytes, unused_name) case enumerator: return bytes;
  switch (e) {
    DefineHelper(EMIT_SIZE_CASE)
    default:
      break;
  }
#undef EMIT_SIZE_CASE
  return 0;
}

// Returns the name string recorded for `e` in DefineHelper, or "unknowntype" if
// `e` has no entry there. The case labels are generated by expanding
// DefineHelper with a local macro that emits `case <enumerator>: return <name>;`.
const char* GetDataTypeStr(DataType e) {
#define EMIT_NAME_CASE(enumerator, unused_bytes, name) case enumerator: return name;
  switch (e) {
    DefineHelper(EMIT_NAME_CASE)
    default:
      break;
  }
#undef EMIT_NAME_CASE
  return "unknowntype";
}

int main(int argc, const char* argv[]) {
  DataType type = static_cast<DataType>(atoi(argv[1]));
  printf("typesize:%d desc:%s\n", 
         GetDataTypeSize(type), GetDataTypeStr(type));
  return 0;
}

字符串

相关问题