@NHZlX,
Is it really necessary to reallocate the temporary buffer and copy these pointers to the GPU on every enqueue call? If the pointers stay the same, the following code (emb_eltwise_layernorm_plugin.cu#L173) only needs to run once:
// Current implementation (excerpt). On every enqueue call it allocates two
// device buffers of input_num int64_t elements each. It then packs the input
// pointers and the embedding pointers into host vectors and copies both
// vectors to the device asynchronously.
// NOTE(review): embs_gpu_ is a data member that does not change after
// creation, so its pointer array could be built and copied once (e.g. in
// initialize()) instead of on every call.
int EmbEltwiseLayernormPluginDynamic<T>::enqueue(...)
{
// allocate GPU buffer
// NOTE(review): in_ptr_tensor / emb_ptr_tensor are locals destroyed when
// enqueue returns, while the async copies below (and presumably the kernel that
// reads these buffers) may still be pending on the stream -- confirm the
// allocator keeps this memory valid until that work completes.
framework::Tensor in_ptr_tensor, emb_ptr_tensor;
in_ptr_tensor.Resize({input_num});
emb_ptr_tensor.Resize({input_num});
int64_t *in_ptr_gpu_d =
in_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));
int64_t *emb_ptr_gpu_d =
emb_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));
// allocate CPU buffer and assign GPU pointers to it
// Each device pointer is stored as an integer value so the array can be
// transferred as plain data.
std::vector<int64_t> in_ptr, emb_ptr;
for (int i = 0; i < input_num; i++) {
in_ptr.push_back(reinterpret_cast<uintptr_t>(inputs[i]));
emb_ptr.push_back(reinterpret_cast<uintptr_t>(embs_gpu_[i]));
}
// copy CPU buffer to GPU buffer
// NOTE(review): the cudaError_t results of these calls are not checked here.
cudaMemcpyAsync(in_ptr_gpu_d, in_ptr.data(), ...);
cudaMemcpyAsync(emb_ptr_gpu_d, emb_ptr.data(), ...);
...
}
For example, `embs_gpu_` is a data member: once created, it does not change during the plugin's entire lifetime, so copying its pointers to the GPU can be done once in `EmbEltwiseLayernormPluginDynamic<T>::initialize`.
On the other hand, if the pointers in `inputs` don't change between calls, the code can be modified like this:
// Proposal 1: assumes the input pointer values never change between enqueue
// calls. The device-side pointer array is allocated and filled only when the
// input count differs from the current buffer size (i.e. on the first call).
// NOTE(review): if the pointer values change while input_num stays the same,
// this branch is skipped and the device buffer keeps stale pointers.
int EmbEltwiseLayernormPluginDynamic<T>::enqueue(...)
{
// in_ptr_tensor_ should be a data member
// NOTE(review): confirm framework::Tensor provides size(); numel() may be
// what is intended here.
if (in_ptr_tensor_.size() != input_num) { // only run once
in_ptr_tensor_.Resize({input_num});
int64_t *in_ptr_gpu_d =
in_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id));
// Copies the host array `inputs` (which holds the device pointers) directly,
// with no intermediate std::vector.
// NOTE(review): the byte count uses sizeof(size_t) while the buffer holds
// int64_t; this assumes both match the pointer size -- confirm, or use
// sizeof(int64_t).
cudaMemcpyAsync(in_ptr_gpu_d, inputs, input_num*sizeof(size_t), ...);
}
...
}
Or, if only the number of pointers stays constant (their values may change between calls):
// Proposal 2: assumes only the number of input pointers is stable. The buffer
// is resized only when input_num changes, but the current pointer values are
// copied to the device on every call, so changed pointers are picked up.
int EmbEltwiseLayernormPluginDynamic<T>::enqueue(...)
{
// in_ptr_tensor_ should be a data member
// NOTE(review): confirm framework::Tensor provides size(); numel() may be
// what is intended here.
if (in_ptr_tensor_.size() != input_num) { // only run once
in_ptr_tensor_.Resize({input_num});
}
// Called every time; presumably mutable_data reuses the existing allocation
// when its size is already sufficient -- confirm against the Tensor API.
int64_t *in_ptr_gpu_d =
in_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id));
// NOTE(review): sizeof(size_t) is assumed to equal sizeof(int64_t) and the
// pointer size -- confirm, or use sizeof(int64_t).
cudaMemcpyAsync(in_ptr_gpu_d, inputs, input_num*sizeof(size_t), ...);
...
}
1 answer
Sorted by: popularity / time
mrphzbgm1#
The patch for this issue is PR #25205