Skip to content

Commit 0f3427a

Browse files
committed
add ELL mixed-precision support
1 parent a248953 commit 0f3427a

10 files changed

Lines changed: 274 additions & 104 deletions

File tree

benchmark/utils/formats.hpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ namespace formats {
5959

6060

6161
std::string available_format =
62-
"coo, csr, ell, sellp, hybrid, hybrid0, hybrid25, hybrid33, hybrid40, "
62+
"coo, csr, ell, fell, sellp, hybrid, hybrid0, hybrid25, hybrid33, "
63+
"hybrid40, "
6364
"hybrid60, hybrid80, hybridlimit0, hybridlimit25, hybridlimit33, "
6465
"hybridminstorage"
6566
#ifdef HAS_CUDA
@@ -90,6 +91,9 @@ std::string format_description =
9091
"csrm: Ginkgo's CSR implementation with merge_path strategy.\n"
9192
"ell: Ellpack format according to Bell and Garland: Efficient Sparse "
9293
"Matrix-Vector Multiplication on CUDA.\n"
94+
"fell: float Ellpack format according to Bell and Garland: Efficient "
95+
"Sparse "
96+
"Matrix-Vector Multiplication on CUDA.\n"
9397
"sellp: Sliced Ellpack uses a default block size of 32.\n"
9498
"hybrid: Hybrid uses ell and coo to represent the matrix.\n"
9599
"hybrid0, hybrid25, hybrid33, hybrid40, hybrid60, hybrid80: Hybrid uses "
@@ -204,6 +208,23 @@ const std::map<std::string, std::function<std::unique_ptr<gko::LinOp>(
204208
{"csrc", READ_MATRIX(csr, std::make_shared<csr::classical>())},
205209
{"coo", read_matrix_from_data<gko::matrix::Coo<etype>>},
206210
{"ell", read_matrix_from_data<gko::matrix::Ell<etype>>},
211+
{"fell",
 // Reader for the "fell" benchmark format: copies the incoming matrix data
 // entry by entry into a float-valued matrix_data and reads that into a
 // gko::matrix::Ell<float> created on the given executor.
 [](std::shared_ptr<const gko::Executor> exec,
    const gko::matrix_data<> &data) {
     const auto num_entries = data.nonzeros.size();

     gko::matrix_data<float> conv_data;
     conv_data.size = data.size;
     conv_data.nonzeros.resize(num_entries);

     for (decltype(data.nonzeros.size()) i = 0; i < num_entries; ++i) {
         const auto &src = data.nonzeros[i];
         auto &dst = conv_data.nonzeros[i];
         dst.row = src.row;
         dst.column = src.column;
         // The value is narrowed to float on purpose; that is the point
         // of this format.
         dst.value = static_cast<float>(src.value);
     }

     auto mat = gko::matrix::Ell<float>::create(std::move(exec));
     mat->read(conv_data);
     return mat;
 }},
207228
#ifdef HAS_CUDA
208229
#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000)
209230
{"cusp_csr", read_matrix_from_data<cusp_csr>},
@@ -212,8 +233,8 @@ const std::map<std::string, std::function<std::unique_ptr<gko::LinOp>(
212233
{"cusp_hybrid", read_matrix_from_data<cusp_hybrid>},
213234
{"cusp_coo", read_matrix_from_data<cusp_coo>},
214235
{"cusp_ell", read_matrix_from_data<cusp_ell>},
215-
#else // CUDA_VERSION >= 11000
216-
// cusp_csr, cusp_coo use the generic ones from CUDA 11
236+
#else // CUDA_VERSION >= 11000
237+
// cusp_csr, cusp_coo use the generic ones from CUDA 11
217238
{"cusp_csr", read_matrix_from_data<cusp_gcsr>},
218239
{"cusp_coo", read_matrix_from_data<cusp_gcoo>},
219240
#endif
@@ -260,7 +281,8 @@ const std::map<std::string, std::function<std::unique_ptr<gko::LinOp>(
260281
{"hybridminstorage",
261282
READ_MATRIX(hybrid,
262283
std::make_shared<hybrid::minimal_storage_limit>())},
263-
{"sellp", read_matrix_from_data<gko::matrix::Sellp<etype>>}};
284+
{"sellp", read_matrix_from_data<gko::matrix::Sellp<etype>>}
285+
};
264286
// clang-format on
265287

266288

common/matrix/ell_kernels.hpp.inc

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,25 @@ namespace kernel {
3434
namespace {
3535

3636

37-
template <int num_thread_per_worker, bool atomic, typename ValueType,
37+
template <int num_thread_per_worker, bool atomic, typename InputValueType,
38+
typename MatrixValueType, typename OutputValueType,
3839
typename IndexType, typename Closure>
3940
__device__ void spmv_kernel(
4041
const size_type num_rows, const int num_worker_per_row,
41-
const ValueType *__restrict__ val, const IndexType *__restrict__ col,
42+
const MatrixValueType *__restrict__ val, const IndexType *__restrict__ col,
4243
const size_type stride, const size_type num_stored_elements_per_row,
43-
const ValueType *__restrict__ b, const size_type b_stride,
44-
ValueType *__restrict__ c, const size_type c_stride, Closure op)
44+
const InputValueType *__restrict__ b, const size_type b_stride,
45+
OutputValueType *__restrict__ c, const size_type c_stride, Closure op)
4546
{
4647
const auto tidx = thread::get_thread_id_flat();
4748
const auto column_id = blockIdx.y;
49+
using compute_type =
50+
decltype(InputValueType{} + MatrixValueType{} + OutputValueType{});
4851
if (num_thread_per_worker == 1) {
4952
// Specialize the num_thread_per_worker = 1. It doesn't need the shared
5053
// memory, __syncthreads, and atomic_add
5154
if (tidx < num_rows) {
52-
ValueType temp = zero<ValueType>();
55+
auto temp = zero<compute_type>();
5356
for (size_type idx = 0; idx < num_stored_elements_per_row; idx++) {
5457
const auto ind = tidx + idx * stride;
5558
const auto col_idx = col[ind];
@@ -68,14 +71,14 @@ __device__ void spmv_kernel(
6871
const auto x = tidx % num_rows;
6972
const auto worker_id = tidx / num_rows;
7073
const auto step_size = num_worker_per_row * num_thread_per_worker;
71-
__shared__ UninitializedArray<ValueType, default_block_size /
72-
num_thread_per_worker>
74+
__shared__ UninitializedArray<
75+
compute_type, default_block_size / num_thread_per_worker>
7376
storage;
7477
if (idx_in_worker == 0) {
7578
storage[threadIdx.x] = 0;
7679
}
7780
__syncthreads();
78-
ValueType temp = zero<ValueType>();
81+
auto temp = zero<compute_type>();
7982
for (size_type idx =
8083
worker_id * num_thread_per_worker + idx_in_worker;
8184
idx < num_stored_elements_per_row; idx += step_size) {
@@ -102,35 +105,41 @@ __device__ void spmv_kernel(
102105
}
103106

104107

105-
template <int num_thread_per_worker, bool atomic = false, typename ValueType,
106-
typename IndexType>
108+
template <int num_thread_per_worker, bool atomic = false,
109+
typename InputValueType, typename MatrixValueType,
110+
typename OutputValueType, typename IndexType>
107111
__global__ __launch_bounds__(default_block_size) void spmv(
108112
const size_type num_rows, const int num_worker_per_row,
109-
const ValueType *__restrict__ val, const IndexType *__restrict__ col,
113+
const MatrixValueType *__restrict__ val, const IndexType *__restrict__ col,
110114
const size_type stride, const size_type num_stored_elements_per_row,
111-
const ValueType *__restrict__ b, const size_type b_stride,
112-
ValueType *__restrict__ c, const size_type c_stride)
115+
const InputValueType *__restrict__ b, const size_type b_stride,
116+
OutputValueType *__restrict__ c, const size_type c_stride)
113117
{
118+
using compute_type =
119+
decltype(InputValueType{} + MatrixValueType{} + OutputValueType{});
114120
spmv_kernel<num_thread_per_worker, atomic>(
115121
num_rows, num_worker_per_row, val, col, stride,
116122
num_stored_elements_per_row, b, b_stride, c, c_stride,
117-
[](const ValueType &x, const ValueType &y) { return x; });
123+
[](const compute_type &x, const OutputValueType &y) { return x; });
118124
}
119125

120126

121-
template <int num_thread_per_worker, bool atomic = false, typename ValueType,
122-
typename IndexType>
127+
template <int num_thread_per_worker, bool atomic = false,
128+
typename InputValueType, typename MatrixValueType,
129+
typename OutputValueType, typename IndexType>
123130
__global__ __launch_bounds__(default_block_size) void spmv(
124131
const size_type num_rows, const int num_worker_per_row,
125-
const ValueType *__restrict__ alpha, const ValueType *__restrict__ val,
126-
const IndexType *__restrict__ col, const size_type stride,
127-
const size_type num_stored_elements_per_row,
128-
const ValueType *__restrict__ b, const size_type b_stride,
129-
const ValueType *__restrict__ beta, ValueType *__restrict__ c,
132+
const MatrixValueType *__restrict__ alpha,
133+
const MatrixValueType *__restrict__ val, const IndexType *__restrict__ col,
134+
const size_type stride, const size_type num_stored_elements_per_row,
135+
const InputValueType *__restrict__ b, const size_type b_stride,
136+
const OutputValueType *__restrict__ beta, OutputValueType *__restrict__ c,
130137
const size_type c_stride)
131138
{
132-
const ValueType alpha_val = alpha[0];
133-
const ValueType beta_val = beta[0];
139+
using compute_type =
140+
decltype(InputValueType{} + MatrixValueType{} + OutputValueType{});
141+
const compute_type alpha_val = alpha[0];
142+
const compute_type beta_val = beta[0];
134143
// Because the atomic operation changes the values of c during computation,
135144
// it can not do the right alpha * a * b + beta * c operation.
136145
// Thus, the cuda kernel only computes alpha * a * b when it uses atomic
@@ -139,15 +148,16 @@ __global__ __launch_bounds__(default_block_size) void spmv(
139148
spmv_kernel<num_thread_per_worker, atomic>(
140149
num_rows, num_worker_per_row, val, col, stride,
141150
num_stored_elements_per_row, b, b_stride, c, c_stride,
142-
[&alpha_val](const ValueType &x, const ValueType &y) {
151+
[&alpha_val](const compute_type &x, const OutputValueType &y) {
143152
return alpha_val * x;
144153
});
145154
} else {
146155
spmv_kernel<num_thread_per_worker, atomic>(
147156
num_rows, num_worker_per_row, val, col, stride,
148157
num_stored_elements_per_row, b, b_stride, c, c_stride,
149-
[&alpha_val, &beta_val](const ValueType &x, const ValueType &y) {
150-
return alpha_val * x + beta_val * y;
158+
[&alpha_val, &beta_val](const compute_type &x,
159+
const OutputValueType &y) {
160+
return alpha_val * x + beta_val * compute_type{y};
151161
});
152162
}
153163
}

core/base/precision_dispatch.hpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,57 @@ void precision_dispatch_spmv(Function fn, const LinOp *alpha, const LinOp *in,
147147
}
148148
}
149149

150+
151+
template <typename ValueType, typename Function>
152+
void mixed_precision_dispatch(Function fn, const LinOp *in, LinOp *out)
153+
{
154+
if (auto dense_in = dynamic_cast<const matrix::Dense<ValueType> *>(in)) {
155+
if (auto dense_out = dynamic_cast<matrix::Dense<ValueType> *>(out)) {
156+
fn(dense_in, dense_out);
157+
} else if (auto dense_out =
158+
dynamic_cast<matrix::Dense<next_precision<ValueType>> *>(
159+
out)) {
160+
fn(dense_in, dense_out);
161+
} else {
162+
GKO_NOT_SUPPORTED(out);
163+
}
164+
} else if (auto dense_in = dynamic_cast<
165+
const matrix::Dense<next_precision<ValueType>> *>(in)) {
166+
if (auto dense_out = dynamic_cast<matrix::Dense<ValueType> *>(out)) {
167+
fn(dense_in, dense_out);
168+
} else if (auto dense_out =
169+
dynamic_cast<matrix::Dense<next_precision<ValueType>> *>(
170+
out)) {
171+
fn(dense_in, dense_out);
172+
} else {
173+
GKO_NOT_SUPPORTED(out);
174+
}
175+
} else {
176+
GKO_NOT_SUPPORTED(in);
177+
}
178+
}
179+
180+
template <typename ValueType, typename Function>
181+
void mixed_precision_dispatch_spmv(Function fn, const LinOp *in, LinOp *out)
182+
{
183+
// do we need to convert complex Dense to real Dense?
184+
auto complex_to_real =
185+
!(is_complex<ValueType>() ||
186+
dynamic_cast<const ConvertibleTo<matrix::Dense<>> *>(in));
187+
if (complex_to_real) {
188+
auto dense_in = make_temporary_conversion<to_complex<ValueType>>(in);
189+
auto dense_out = make_temporary_conversion<to_complex<ValueType>>(out);
190+
using Dense = matrix::Dense<ValueType>;
191+
// These dynamic_casts are only needed to make the code compile
192+
// If ValueType is complex, this branch will never be taken
193+
// If ValueType is real, the cast is a no-op
194+
fn(dynamic_cast<const Dense *>(dense_in->create_real_view().get()),
195+
dynamic_cast<Dense *>(dense_out->create_real_view().get()));
196+
} else {
197+
mixed_precision_dispatch<ValueType>(fn, in, out);
198+
}
199+
}
200+
150201
} // namespace gko
151202

152203

core/device_hooks/common_kernels.inc.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -849,15 +849,20 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
849849
namespace ell {
850850

851851

852-
template <typename ValueType, typename IndexType>
853-
GKO_DECLARE_ELL_SPMV_KERNEL(ValueType, IndexType)
852+
template <typename InputValueType, typename MatrixValueType,
853+
typename OutputValueType, typename IndexType>
854+
GKO_DECLARE_ELL_SPMV_KERNEL(InputValueType, MatrixValueType, OutputValueType,
855+
IndexType)
854856
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
855-
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_ELL_SPMV_KERNEL);
857+
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE(
858+
GKO_DECLARE_ELL_SPMV_KERNEL);
856859

857-
template <typename ValueType, typename IndexType>
858-
GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL(ValueType, IndexType)
860+
template <typename InputValueType, typename MatrixValueType,
861+
typename OutputValueType, typename IndexType>
862+
GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL(InputValueType, MatrixValueType,
863+
OutputValueType, IndexType)
859864
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
860-
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
865+
GKO_INSTANTIATE_FOR_EACH_MIXED_VALUE_AND_INDEX_TYPE(
861866
GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL);
862867

863868
template <typename ValueType, typename IndexType>

core/matrix/ell.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ size_type calculate_max_nnz_per_row(
102102
template <typename ValueType, typename IndexType>
103103
void Ell<ValueType, IndexType>::apply_impl(const LinOp *b, LinOp *x) const
104104
{
105-
precision_dispatch_spmv<ValueType>(
105+
mixed_precision_dispatch_spmv<ValueType>(
106106
[&](auto dense_b, auto dense_x) {
107107
this->get_executor()->run(ell::make_spmv(this, dense_b, dense_x));
108108
},
@@ -114,12 +114,17 @@ template <typename ValueType, typename IndexType>
114114
void Ell<ValueType, IndexType>::apply_impl(const LinOp *alpha, const LinOp *b,
115115
const LinOp *beta, LinOp *x) const
116116
{
117-
precision_dispatch_spmv<ValueType>(
118-
[&](auto dense_alpha, auto dense_b, auto dense_beta, auto dense_x) {
119-
this->get_executor()->run(ell::make_advanced_spmv(
120-
dense_alpha, this, dense_b, dense_beta, dense_x));
117+
mixed_precision_dispatch_spmv<ValueType>(
118+
[&](auto dense_b, auto dense_x) {
119+
auto converted_alpha = make_temporary_conversion<ValueType>(alpha);
120+
auto converted_beta =
121+
make_temporary_conversion<typename std::remove_reference_t<
122+
decltype(*dense_x)>::value_type>(beta);
123+
this->get_executor()->run(
124+
ell::make_advanced_spmv(converted_alpha.get(), this, dense_b,
125+
converted_beta.get(), dense_x));
121126
},
122-
alpha, b, beta, x);
127+
b, x);
123128
}
124129

125130

core/matrix/ell_kernels.hpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,21 @@ namespace gko {
4646
namespace kernels {
4747

4848

49-
#define GKO_DECLARE_ELL_SPMV_KERNEL(ValueType, IndexType) \
50-
void spmv(std::shared_ptr<const DefaultExecutor> exec, \
51-
const matrix::Ell<ValueType, IndexType> *a, \
52-
const matrix::Dense<ValueType> *b, matrix::Dense<ValueType> *c)
53-
54-
#define GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL(ValueType, IndexType) \
55-
void advanced_spmv(std::shared_ptr<const DefaultExecutor> exec, \
56-
const matrix::Dense<ValueType> *alpha, \
57-
const matrix::Ell<ValueType, IndexType> *a, \
58-
const matrix::Dense<ValueType> *b, \
59-
const matrix::Dense<ValueType> *beta, \
60-
matrix::Dense<ValueType> *c)
49+
#define GKO_DECLARE_ELL_SPMV_KERNEL(InputValueType, MatrixValueType, \
50+
OutputValueType, IndexType) \
51+
void spmv(std::shared_ptr<const DefaultExecutor> exec, \
52+
const matrix::Ell<MatrixValueType, IndexType> *a, \
53+
const matrix::Dense<InputValueType> *b, \
54+
matrix::Dense<OutputValueType> *c)
55+
56+
#define GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL(InputValueType, MatrixValueType, \
57+
OutputValueType, IndexType) \
58+
void advanced_spmv(std::shared_ptr<const DefaultExecutor> exec, \
59+
const matrix::Dense<MatrixValueType> *alpha, \
60+
const matrix::Ell<MatrixValueType, IndexType> *a, \
61+
const matrix::Dense<InputValueType> *b, \
62+
const matrix::Dense<OutputValueType> *beta, \
63+
matrix::Dense<OutputValueType> *c)
6164

6265
#define GKO_DECLARE_ELL_CONVERT_TO_DENSE_KERNEL(ValueType, IndexType) \
6366
void convert_to_dense(std::shared_ptr<const DefaultExecutor> exec, \
@@ -87,10 +90,14 @@ namespace kernels {
8790
matrix::Diagonal<ValueType> *diag)
8891

8992
#define GKO_DECLARE_ALL_AS_TEMPLATES \
90-
template <typename ValueType, typename IndexType> \
91-
GKO_DECLARE_ELL_SPMV_KERNEL(ValueType, IndexType); \
92-
template <typename ValueType, typename IndexType> \
93-
GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL(ValueType, IndexType); \
93+
template <typename InputValueType, typename MatrixValueType, \
94+
typename OutputValueType, typename IndexType> \
95+
GKO_DECLARE_ELL_SPMV_KERNEL(InputValueType, MatrixValueType, \
96+
OutputValueType, IndexType); \
97+
template <typename InputValueType, typename MatrixValueType, \
98+
typename OutputValueType, typename IndexType> \
99+
GKO_DECLARE_ELL_ADVANCED_SPMV_KERNEL(InputValueType, MatrixValueType, \
100+
OutputValueType, IndexType); \
94101
template <typename ValueType, typename IndexType> \
95102
GKO_DECLARE_ELL_CONVERT_TO_DENSE_KERNEL(ValueType, IndexType); \
96103
template <typename ValueType, typename IndexType> \

0 commit comments

Comments
 (0)