c++ CUBLAS_STATUS_INVALID_VALUE

zsbz8rwp  于 2023-06-25  发布在  其他
关注(0)|答案(1)|浏览(136)

我正在使用cublasSgemmStridedBatched API进行所谓的“Tensor收缩”。我有形状为60000*20*9的TensorA和形状为9*32的TensorB,它们都是行主序(row-major)的。根据定义,C = A * B应该得到形状为60000*20*32的结果TensorC。我写的代码如下:

  1. int batch_count = 60000;
  2. int M = 20;
  3. int K = 9;
  4. int N = 32;
  5. cublasHandle_t handle;
  6. cublasCreate(&handle);
  7. float alpha = 1.0;
  8. float beta = 0.0;
  9. int strideA = 20 * 9;
  10. int strideB = 0;
  11. int strideC = 20 * 32;
  12. // A(60000 * 20 * 9) * B(9 * 32) = C(60000 * 20 * 32)
  13. cublasStatus_t ret = cublasSgemmStridedBatched(
  14. handle,
  15. CUBLAS_OP_N, //transposed, since in row-major
  16. CUBLAS_OP_N, //transposed, since in row-major
  17. N,
  18. M,
  19. K,
  20. &alpha,
  21. B.data<float>(), //already in GPU
  22. N, // lda, transposed
  23. strideB,
  24. A.data<float>(), //already in GPU
  25. K, // ldb, transposed
  26. strideA,
  27. &beta,
  28. C.data<float>(),//already in GPU
  29. N, // ldc
  30. strideC,
  31. batchCount);
  32. cublasDestroy(handle);
  33. if(ret != CUBLAS_STATUS_SUCCESS){
  34. printf("cublasSgemmStridedBatched failed %d line (%d)\n", ret, __LINE__);
  35. }

上面的代码无法完成工作,并一直显示cublasSgemmStridedBatched failed 7,根据手册,错误码7代表CUBLAS_STATUS_INVALID_VALUE。任何帮助或建议都不胜感激!

2ul0zpep

2ul0zpep1#

下面是一个最小的版本,它可以工作并测试结果:

  1. #include <cuda_runtime.h>
  2. #include <cublas_v2.h>
  3. #include <cstdio>
  4. #include <Eigen/Dense>
  5. int main()
  6. {
  7. cublasHandle_t cubl;
  8. cublasCreate(&cubl);
  9. int batch_count = 60000;
  10. int M = 20;
  11. int K = 9;
  12. int N = 32;
  13. float* A, *B, *C;
  14. cudaMallocManaged(&A, sizeof(float) * batch_count * M * K);
  15. cudaMallocManaged(&B, sizeof(float) * K * N);
  16. cudaMallocManaged(&C, sizeof(float) * batch_count * M * N);
  17. for(int b = 0; b < batch_count; ++b)
  18. for(int m = 0; m < M; ++m)
  19. for(int k = 0; k < K; ++k)
  20. A[((b * M) + m) * K + k] = (float)(b + 1) * (m + 2) * (k + 3) / (M*N*K);
  21. for(int k = 0; k < K; ++k)
  22. for(int n = 0; n < N; ++n)
  23. B[k * N + n] = (float) (k + 1) * (n + 2) / (N*K);
  24. const float alpha = 1.f, beta = 0.f;
  25. const int strideA = K * M, strideB = 0, strideC = M * N;
  26. cublasStatus_t ret = cublasSgemmStridedBatched(
  27. cubl, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
  28. B, N /*lda*/, strideB,
  29. A, K /*ldb*/, strideA, &beta,
  30. C, N /*ldc*/, strideC, batch_count);
  31. if(ret != CUBLAS_STATUS_SUCCESS)
  32. std::printf("cublasSgemmStridedBatched failed %d line (%d)\n",
  33. ret, __LINE__);
  34. cudaError_t curet = cudaDeviceSynchronize();
  35. std::printf("Device sync: %d\n", ret);
  36. Eigen::ArrayXXf reference, error = Eigen::ArrayXXf::Zero(N, M);
  37. const auto B_map = Eigen::MatrixXf::Map(B, N, K);
  38. for(int b = 0; b < batch_count; ++b) {
  39. const auto A_map = Eigen::MatrixXf::Map(A + strideA * b, K, M);
  40. reference.matrix().noalias() = B_map * A_map;
  41. const auto C_map = Eigen::ArrayXXf::Map(C + strideC * b, N, M);
  42. const auto rel_error = (C_map - reference).abs() /
  43. C_map.abs().max(reference.abs());
  44. error = error.max(rel_error);
  45. }
  46. std::printf("Max relative error %g\n", error.maxCoeff());
  47. }

报告最大相对误差为2.5e-7

展开查看全部

相关问题