1414#include < cusparse.h>
1515#include < thrust/complex.h>
1616
17+ #include < ginkgo/core/base/half.hpp>
1718#include < ginkgo/core/base/matrix_data.hpp>
1819#include < ginkgo/core/base/types.hpp>
1920
2021
2122namespace gko {
2223
23-
2424namespace kernels {
2525namespace cuda {
26-
27-
2826namespace detail {
2927
30-
3128/* *
3229 * @internal
3330 *
@@ -124,6 +121,17 @@ struct culibs_type_impl<std::complex<double>> {
124121 using type = cuDoubleComplex;
125122};
126123
124+
125+ template <>
126+ struct culibs_type_impl <half> {
127+ using type = __half;
128+ };
129+
130+ template <>
131+ struct culibs_type_impl <std::complex <half>> {
132+ using type = __half2;
133+ };
134+
127135template <typename T>
128136struct culibs_type_impl <thrust::complex <T>> {
129137 using type = typename culibs_type_impl<std::complex <T>>::type;
@@ -154,9 +162,14 @@ struct cuda_type_impl<volatile T> {
154162 using type = volatile typename cuda_type_impl<T>::type;
155163};
156164
165+ template <>
166+ struct cuda_type_impl <half> {
167+ using type = __half;
168+ };
169+
157170template <typename T>
158171struct cuda_type_impl <std::complex <T>> {
159- using type = thrust::complex <T >;
172+ using type = thrust::complex <typename cuda_type_impl<T>::type >;
160173};
161174
162175template <>
@@ -169,14 +182,24 @@ struct cuda_type_impl<cuComplex> {
169182 using type = thrust::complex <float >;
170183};
171184
185+ template <>
186+ struct cuda_type_impl <__half2> {
187+ using type = thrust::complex <__half>;
188+ };
189+
172190template <typename T>
173191struct cuda_struct_member_type_impl {
174192 using type = T;
175193};
176194
177195template <typename T>
178196struct cuda_struct_member_type_impl <std::complex <T>> {
179- using type = fake_complex<T>;
197+ using type = fake_complex<typename cuda_struct_member_type_impl<T>::type>;
198+ };
199+
200+ template <>
201+ struct cuda_struct_member_type_impl <gko::half> {
202+ using type = __half;
180203};
181204
182205template <typename ValueType, typename IndexType>
@@ -200,6 +223,7 @@ GKO_CUDA_DATA_TYPE(float, CUDA_R_32F);
200223GKO_CUDA_DATA_TYPE (double , CUDA_R_64F );
201224GKO_CUDA_DATA_TYPE (std::complex <float >, CUDA_C_32F );
202225GKO_CUDA_DATA_TYPE (std::complex <double >, CUDA_C_64F );
226+ GKO_CUDA_DATA_TYPE (std::complex <float16>, CUDA_C_16F );
203227GKO_CUDA_DATA_TYPE (int32, CUDA_R_32I );
204228GKO_CUDA_DATA_TYPE (int8, CUDA_R_8I );
205229
0 commit comments