@@ -454,11 +454,19 @@ void update_g_and_u(std::shared_ptr<const DefaultExecutor> exec,
454454 if (nrhs > 1 || is_complex<ValueType>()) {
455455 components::fill_array (exec, alpha->get_values (), nrhs,
456456 zero<ValueType>());
457- multidot_kernel<<<grid_dim, block_dim, 0 , exec->get_stream ()>>>(
458- size, nrhs, as_device_type (p_i),
459- as_device_type (g_k->get_values ()), g_k->get_stride (),
460- as_device_type (alpha->get_values ()),
461- stop_status->get_const_data ());
457+ // not support 16 bit atomic
458+ #if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
459+ if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
460+ GKO_NOT_SUPPORTED (alpha);
461+ } else
462+ #endif
463+ {
464+ multidot_kernel<<<grid_dim, block_dim, 0 , exec->get_stream ()>>>(
465+ size, nrhs, as_device_type (p_i),
466+ as_device_type (g_k->get_values ()), g_k->get_stride (),
467+ as_device_type (alpha->get_values ()),
468+ stop_status->get_const_data ());
469+ }
462470 } else {
463471 blas::dot (exec->get_blas_handle (), size, p_i, 1 , g_k->get_values (),
464472 g_k->get_stride (), alpha->get_values ());
@@ -505,10 +513,18 @@ void update_m(std::shared_ptr<const DefaultExecutor> exec, const size_type nrhs,
505513 auto m_i = m->get_values () + i * m_stride + k * nrhs;
506514 if (nrhs > 1 || is_complex<ValueType>()) {
507515 components::fill_array (exec, m_i, nrhs, zero<ValueType>());
508- multidot_kernel<<<grid_dim, block_dim, 0 , exec->get_stream ()>>>(
509- size, nrhs, as_device_type (p_i),
510- as_device_type (g_k->get_const_values ()), g_k->get_stride (),
511- as_device_type (m_i), stop_status->get_const_data ());
516+ // not support 16 bit atomic
517+ #if !(defined(CUDA_VERSION) && (__CUDA_ARCH__ >= 700))
518+ if constexpr (std::is_same_v<remove_complex<ValueType>, half>) {
519+ GKO_NOT_SUPPORTED (m_i);
520+ } else
521+ #endif
522+ {
523+ multidot_kernel<<<grid_dim, block_dim, 0 , exec->get_stream ()>>>(
524+ size, nrhs, as_device_type (p_i),
525+ as_device_type (g_k->get_const_values ()), g_k->get_stride (),
526+ as_device_type (m_i), stop_status->get_const_data ());
527+ }
512528 } else {
513529 blas::dot (exec->get_blas_handle (), size, p_i, 1 ,
514530 g_k->get_const_values (), g_k->get_stride (), m_i);
0 commit comments