c++ 如何将枚举变量传递给具有枚举模板专门化的结构

doinxwow  于 2024-01-09  发布在  其他
关注(0)|答案(2)|浏览(163)

我为 enum 设计了一个带有模板特化(template specialization)的结构体,如下所示:

  1. template<DataType type>
  2. struct TypeTrait;
  3. template<>
  4. struct TypeTrait<DATA_TYPE_INT8> {
  5. static constexpr uint32_t size = sizeof(int8_t);
  6. };
  7. template<>
  8. struct TypeTrait<DATA_TYPE_INT16> {
  9. static constexpr uint32_t size = sizeof(int16_t);
  10. };
  11. template<>
  12. struct TypeTrait<DATA_TYPE_FP16> {
  13. static constexpr uint32_t size = sizeof(uint16_t);
  14. };
  15. template<>
  16. struct TypeTrait<DATA_TYPE_UINT8> {
  17. static constexpr uint32_t size = sizeof(uint8_t);
  18. };
  19. template<>
  20. struct TypeTrait<DATA_TYPE_UINT16> {
  21. static constexpr uint32_t size = sizeof(uint16_t);
  22. };
  23. template<>
  24. struct TypeTrait<DATA_TYPE_INT32> {
  25. static constexpr uint32_t size = sizeof(int32_t);
  26. };
  27. template<>
  28. struct TypeTrait<DATA_TYPE_UINT32> {
  29. static constexpr uint32_t size = sizeof(uint32_t);
  30. };
  31. template<>
  32. struct TypeTrait<DATA_TYPE_FP32> {
  33. static constexpr uint32_t size = sizeof(float);
  34. };

enum DataType 的定义如下:

  // Element data types. The known types have explicit, contiguous values
  // 0..7; DATA_TYPE_UNKOWN comes right after them (value 8), so it can also be
  // used as the count of known types.
  enum DataType {
    DATA_TYPE_INT8   = 0,
    DATA_TYPE_INT16  = 1,
    DATA_TYPE_FP16   = 2,
    DATA_TYPE_UINT8  = 3,
    DATA_TYPE_UINT16 = 4,
    DATA_TYPE_INT32  = 5,
    DATA_TYPE_UINT32 = 6,
    DATA_TYPE_FP32   = 7,
    DATA_TYPE_UNKOWN = 8,  // (sic) spelling kept: existing code refers to this name
  };


我想把一个 DataType 类型的变量作为模板参数传给结构体 TypeTrait,像这样:

  1. class Test {  // illustrative fragment; "..." marks elided code
  2. public:
  3. ...
  4. void Convert() {
  5. ...
  6. uint32_t size = TypeTrait<type_>::size;  // error: type_ is a run-time member value, not a constant expression, so it cannot be a template argument
  7. ...
  8. }
  9. private:
  10. DataType type_;
  11. };


但这样写时,编译会报错:

  1. main.cc: In member function ‘void Test::Convert()’:
  2. main.cc:63:35: error: use of ‘this’ in a constant expression
  3. 63 | uint32_t size = TypeTrait<type_>::size;
  4. | ^~~~~
  5. main.cc:63:40: error: use of ‘this’ in a constant expression
  6. 63 | uint32_t size = TypeTrait<type_>::size;
  7. | ^
  8. main.cc:63:35: note: in template argument for type ‘DataType’
  9. 63 | uint32_t size = TypeTrait<type_>::size;
  10. | ^~~~~ ^


我尝试了很多方法,比如先把 type_ 赋值给一个 const 局部变量,如下所示:

  1. const DataType dataType = type_;  // const, but initialized from a run-time value
  2. uint32_t size = TypeTrait<dataType>::size;  // still an error: dataType is not initialized with a constant expression


结果出现了下面的错误(报错时的代码中,该 const 局部变量名为 type,由 GetType() 初始化):

  1. main.cc: In member function ‘void Test::Convert()’:
  2. main.cc:63:39: error: the value of ‘type’ is not usable in a constant expression
  3. 63 | uint32_t size = TypeTrait<type>::size;
  4. | ^
  5. main.cc:62:24: note: ‘type’ was not initialized with a constant expression
  6. 62 | const DataType type = GetType();
  7. | ^~~~
  8. main.cc:63:39: note: in template argument for type ‘DataType’
  9. 63 | uint32_t size = TypeTrait<type>::size;
  10. |


我知道如果像下面这样直接传入枚举常量,程序就能正常编译:

  1. uint32_t size = TypeTrait<DataType::DATA_TYPE_UINT32>::size;  // OK: an enumerator is a compile-time constant


我找不到解决办法,只能用 switch case 来处理,但这并不是我想要的——我只是想去掉代码中的 switch case。要重构的代码如下:

  1. switch (dataType_) {  // byteSize = elemCnt * (element size of dataType_)
  2. case DATA_TYPE_INT8:
  3. byteSize = elemCnt * sizeof(int8_t);
  4. break;
  5. case DATA_TYPE_INT16:
  6. byteSize = elemCnt * sizeof(int16_t);
  7. break;
  8. case DATA_TYPE_FP16:
  9. byteSize = elemCnt * sizeof(uint16_t);
  10. break;
  11. case DATA_TYPE_UINT8:
  12. byteSize = elemCnt * sizeof(uint8_t);
  13. break;
  14. case DATA_TYPE_UINT16:
  15. byteSize = elemCnt * sizeof(uint16_t);
  16. break;
  17. case DATA_TYPE_INT32:
  18. byteSize = elemCnt * sizeof(int32_t);
  19. break;
  20. case DATA_TYPE_UINT32:
  21. byteSize = elemCnt * sizeof(uint32_t);
  22. break;
  23. case DATA_TYPE_FP32:
  24. byteSize = elemCnt * sizeof(float);
  25. break;
  26. }  // no default: for any other value (e.g. DATA_TYPE_UNKOWN) byteSize is left unchanged

rjjhvcjd

rjjhvcjd1#

下面是一种无需 switch、在运行时获取 TypeTrait<type>::size 的方法(需要 C++20,因为用到了带显式模板参数列表的 lambda `[&]<std::size_t... Is>(...)`):

  1. uint32_t datatypeSize(DataType type) {
  2. return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
  3. return ((static_cast<std::size_t>(type) == Is ? TypeTrait<static_cast<DataType>(Is)>::size : 0) + ...);
  4. }(std::make_index_sequence<DATA_TYPE_UNKOWN>{});
  5. }

Demo
另一种写法,使用 std::array(来自 Jarod42 的评论;同样需要 C++20):

  1. uint32_t datatypeSize(DataType type) {
  2. return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
  3. return std::array{TypeTrait<static_cast<DataType>(Is)>::size...}[type];
  4. }(std::make_index_sequence<DATA_TYPE_UNKOWN>{});
  5. }

展开查看全部
eaf3rand

eaf3rand2#

再补充一个不使用模板特化的答案(X-macro 写法)。
我认为这种实现方式更方便、更简洁:

  1. #define DefineHelper(XX)\
  2. XX(DATA_TYPE_INT8, sizeof(int8_t), "DATA_TYPE_FP32")\
  3. XX(DATA_TYPE_INT16, sizeof(int16_t), "DATA_TYPE_FP32")\
  4. XX(DATA_TYPE_FP16, sizeof(uint16_t), "DATA_TYPE_FP32")\
  5. XX(DATA_TYPE_UINT8, sizeof(uint8_t), "DATA_TYPE_FP32")\
  6. XX(DATA_TYPE_UINT16, sizeof(uint16_t), "DATA_TYPE_FP32")\
  7. XX(DATA_TYPE_INT32, sizeof(int32_t), "DATA_TYPE_FP32")
  8. int32_t GetDataTypeSize(DataType e){
  9. #define TypeSize(e, n, _) case e: return n;
  10. switch(e) {
  11. DefineHelper(TypeSize)
  12. default:
  13. return 0;
  14. }
  15. #undef TypeSize
  16. }
  17. const char* GetDataTypeStr(DataType e) {
  18. #define TypeStr(e, _, s) case e: return s;
  19. switch(e)
  20. {
  21. DefineHelper(TypeStr)
  22. default:
  23. return "unknowntype";
  24. }
  25. #undef TypeStr
  26. }
  27. int main(int argc, const char* argv[]) {
  28. DataType type = static_cast<DataType>(atoi(argv[1]));
  29. printf("typesize:%d desc:%s\n",
  30. GetDataTypeSize(type), GetDataTypeStr(type));
  31. return 0;
  32. }

字符串

展开查看全部

相关问题