Skip to content

Commit f052a55

Browse files
committed
CUDA with compute capability < 7.0 and HIP do not support 16-bit atomics; throw a not-supported error for half precision in IDR
1 parent 3903625 commit f052a55

1 file changed

Lines changed: 25 additions & 9 deletions

File tree

common/cuda_hip/solver/idr_kernels.cpp

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -454,11 +454,19 @@ void update_g_and_u(std::shared_ptr<const DefaultExecutor> exec,
454454
if (nrhs > 1 || is_complex<ValueType>()) {
455455
components::fill_array(exec, alpha->get_values(), nrhs,
456456
zero<ValueType>());
457-
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
458-
size, nrhs, as_device_type(p_i),
459-
as_device_type(g_k->get_values()), g_k->get_stride(),
460-
as_device_type(alpha->get_values()),
461-
stop_status->get_const_data());
457+
// 16-bit atomics are not supported by HIP or by CUDA with compute capability < 7.0
458+
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
459+
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
460+
GKO_NOT_SUPPORTED(alpha);
461+
} else
462+
#endif
463+
{
464+
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
465+
size, nrhs, as_device_type(p_i),
466+
as_device_type(g_k->get_values()), g_k->get_stride(),
467+
as_device_type(alpha->get_values()),
468+
stop_status->get_const_data());
469+
}
462470
} else {
463471
blas::dot(exec->get_blas_handle(), size, p_i, 1, g_k->get_values(),
464472
g_k->get_stride(), alpha->get_values());
@@ -505,10 +513,18 @@ void update_m(std::shared_ptr<const DefaultExecutor> exec, const size_type nrhs,
505513
auto m_i = m->get_values() + i * m_stride + k * nrhs;
506514
if (nrhs > 1 || is_complex<ValueType>()) {
507515
components::fill_array(exec, m_i, nrhs, zero<ValueType>());
508-
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
509-
size, nrhs, as_device_type(p_i),
510-
as_device_type(g_k->get_const_values()), g_k->get_stride(),
511-
as_device_type(m_i), stop_status->get_const_data());
516+
// 16-bit atomics are not supported by HIP or by CUDA with compute capability < 7.0
517+
#if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
518+
if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
519+
GKO_NOT_SUPPORTED(m_i);
520+
} else
521+
#endif
522+
{
523+
multidot_kernel<<<grid_dim, block_dim, 0, exec->get_stream()>>>(
524+
size, nrhs, as_device_type(p_i),
525+
as_device_type(g_k->get_const_values()), g_k->get_stride(),
526+
as_device_type(m_i), stop_status->get_const_data());
527+
}
512528
} else {
513529
blas::dot(exec->get_blas_handle(), size, p_i, 1,
514530
g_k->get_const_values(), g_k->get_stride(), m_i);

0 commit comments

Comments
 (0)