Skip to content

Commit 265cb47

Browse files
committed
type map
1 parent 6c1ae5b commit 265cb47

4 files changed

Lines changed: 86 additions & 12 deletions

File tree

accessor/cuda_helper.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,15 @@
1717
#include "utils.hpp"
1818

1919

20+
struct __half;
21+
22+
2023
namespace gko {
24+
25+
26+
class half;
27+
28+
2129
namespace acc {
2230
namespace detail {
2331

@@ -27,6 +35,11 @@ struct cuda_type {
2735
using type = T;
2836
};
2937

38+
template <>
39+
struct cuda_type<gko::half> {
40+
using type = __half;
41+
};
42+
3043
// Unpack cv and reference / pointer qualifiers
3144
template <typename T>
3245
struct cuda_type<const T> {
@@ -57,7 +70,7 @@ struct cuda_type<T&&> {
5770
// Transform std::complex to thrust::complex
5871
template <typename T>
5972
struct cuda_type<std::complex<T>> {
60-
using type = thrust::complex<T>;
73+
using type = thrust::complex<typename cuda_type<T>::type>;
6174
};
6275

6376

accessor/hip_helper.hpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,15 @@
1717
#include "utils.hpp"
1818

1919

20+
struct __half;
21+
22+
2023
namespace gko {
24+
25+
26+
class half;
27+
28+
2129
namespace acc {
2230
namespace detail {
2331

@@ -53,11 +61,15 @@ struct hip_type<T&&> {
5361
using type = typename hip_type<T>::type&&;
5462
};
5563

64+
template <>
65+
struct hip_type<gko::half> {
66+
using type = __half;
67+
};
5668

5769
// Transform std::complex to thrust::complex
5870
template <typename T>
5971
struct hip_type<std::complex<T>> {
60-
using type = thrust::complex<T>;
72+
using type = thrust::complex<typename hip_type<T>::type>;
6173
};
6274

6375

cuda/base/types.hpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,17 @@
1414
#include <cusparse.h>
1515
#include <thrust/complex.h>
1616

17+
#include <ginkgo/core/base/half.hpp>
1718
#include <ginkgo/core/base/matrix_data.hpp>
1819
#include <ginkgo/core/base/types.hpp>
1920

2021

2122
namespace gko {
2223

23-
2424
namespace kernels {
2525
namespace cuda {
26-
27-
2826
namespace detail {
2927

30-
3128
/**
3229
* @internal
3330
*
@@ -124,6 +121,17 @@ struct culibs_type_impl<std::complex<double>> {
124121
using type = cuDoubleComplex;
125122
};
126123

124+
125+
template <>
126+
struct culibs_type_impl<half> {
127+
using type = __half;
128+
};
129+
130+
template <>
131+
struct culibs_type_impl<std::complex<half>> {
132+
using type = __half2;
133+
};
134+
127135
template <typename T>
128136
struct culibs_type_impl<thrust::complex<T>> {
129137
using type = typename culibs_type_impl<std::complex<T>>::type;
@@ -154,9 +162,14 @@ struct cuda_type_impl<volatile T> {
154162
using type = volatile typename cuda_type_impl<T>::type;
155163
};
156164

165+
template <>
166+
struct cuda_type_impl<half> {
167+
using type = __half;
168+
};
169+
157170
template <typename T>
158171
struct cuda_type_impl<std::complex<T>> {
159-
using type = thrust::complex<T>;
172+
using type = thrust::complex<typename cuda_type_impl<T>::type>;
160173
};
161174

162175
template <>
@@ -169,14 +182,24 @@ struct cuda_type_impl<cuComplex> {
169182
using type = thrust::complex<float>;
170183
};
171184

185+
template <>
186+
struct cuda_type_impl<__half2> {
187+
using type = thrust::complex<__half>;
188+
};
189+
172190
template <typename T>
173191
struct cuda_struct_member_type_impl {
174192
using type = T;
175193
};
176194

177195
template <typename T>
178196
struct cuda_struct_member_type_impl<std::complex<T>> {
179-
using type = fake_complex<T>;
197+
using type = fake_complex<typename cuda_struct_member_type_impl<T>::type>;
198+
};
199+
200+
template <>
201+
struct cuda_struct_member_type_impl<gko::half> {
202+
using type = __half;
180203
};
181204

182205
template <typename ValueType, typename IndexType>
@@ -200,6 +223,7 @@ GKO_CUDA_DATA_TYPE(float, CUDA_R_32F);
200223
GKO_CUDA_DATA_TYPE(double, CUDA_R_64F);
201224
GKO_CUDA_DATA_TYPE(std::complex<float>, CUDA_C_32F);
202225
GKO_CUDA_DATA_TYPE(std::complex<double>, CUDA_C_64F);
226+
GKO_CUDA_DATA_TYPE(std::complex<float16>, CUDA_C_16F);
203227
GKO_CUDA_DATA_TYPE(int32, CUDA_R_32I);
204228
GKO_CUDA_DATA_TYPE(int8, CUDA_R_8I);
205229

hip/base/types.hip.hpp

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
#endif
2222
#include <thrust/complex.h>
2323

24+
#include <ginkgo/core/base/half.hpp>
2425
#include <ginkgo/core/base/matrix_data.hpp>
2526

2627
#include "common/cuda_hip/base/runtime.hpp"
2728

2829

2930
namespace gko {
30-
31-
3231
namespace kernels {
3332
namespace hip {
3433
namespace detail {
@@ -130,6 +129,17 @@ struct hiplibs_type_impl<std::complex<double>> {
130129
using type = hipDoubleComplex;
131130
};
132131

132+
template <>
133+
struct hiplibs_type_impl<half> {
134+
using type = __half;
135+
};
136+
137+
template <>
138+
struct hiplibs_type_impl<std::complex<half>> {
139+
using type = __half2;
140+
};
141+
142+
133143
template <typename T>
134144
struct hiplibs_type_impl<thrust::complex<T>> {
135145
using type = typename hiplibs_type_impl<std::complex<T>>::type;
@@ -202,9 +212,14 @@ struct hip_type_impl<volatile T> {
202212
using type = volatile typename hip_type_impl<T>::type;
203213
};
204214

215+
template <>
216+
struct hip_type_impl<gko::half> {
217+
using type = __half;
218+
};
219+
205220
template <typename T>
206221
struct hip_type_impl<std::complex<T>> {
207-
using type = thrust::complex<T>;
222+
using type = thrust::complex<typename hip_type_impl<T>::type>;
208223
};
209224

210225
template <>
@@ -217,14 +232,24 @@ struct hip_type_impl<hipComplex> {
217232
using type = thrust::complex<float>;
218233
};
219234

235+
template <>
236+
struct hip_type_impl<__half2> {
237+
using type = thrust::complex<__half>;
238+
};
239+
220240
template <typename T>
221241
struct hip_struct_member_type_impl {
222242
using type = T;
223243
};
224244

225245
template <typename T>
226246
struct hip_struct_member_type_impl<std::complex<T>> {
227-
using type = fake_complex<T>;
247+
using type = fake_complex<typename hip_struct_member_type_impl<T>::type>;
248+
};
249+
250+
template <>
251+
struct hip_struct_member_type_impl<gko::half> {
252+
using type = __half;
228253
};
229254

230255
template <typename ValueType, typename IndexType>

0 commit comments

Comments
 (0)