Skip to content

Commit bbf7883

Browse files
committed
add cusparse csrsort bindings
1 parent bb8b830 commit bbf7883

3 files changed

Lines changed: 178 additions & 7 deletions

File tree

cuda/base/cusparse_bindings.hpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,80 @@ GKO_BIND_CUSPARSE64_CSRSM_SOLVE(ValueType, detail::not_implemented);
968968
#endif
969969

970970

971+
template <typename IndexType>
972+
void create_identity_permutation(cusparseHandle_t handle, IndexType size,
973+
IndexType *permutation) GKO_NOT_IMPLEMENTED;
974+
975+
template <>
976+
inline void create_identity_permutation<int32>(cusparseHandle_t handle,
977+
int32 size, int32 *permutation)
978+
{
979+
GKO_ASSERT_NO_CUSPARSE_ERRORS(
980+
cusparseCreateIdentityPermutation(handle, size, permutation));
981+
}
982+
983+
984+
template <typename IndexType>
985+
void csrsort_buffer_size(cusparseHandle_t handle, IndexType m, IndexType n,
986+
IndexType nnz, const IndexType *row_ptrs,
987+
const IndexType *col_idxs,
988+
size_type &buffer_size) GKO_NOT_IMPLEMENTED;
989+
990+
template <>
991+
inline void csrsort_buffer_size<int32>(cusparseHandle_t handle, int32 m,
992+
int32 n, int32 nnz,
993+
const int32 *row_ptrs,
994+
const int32 *col_idxs,
995+
size_type &buffer_size)
996+
{
997+
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseXcsrsort_bufferSizeExt(
998+
handle, m, n, nnz, row_ptrs, col_idxs, &buffer_size));
999+
}
1000+
1001+
1002+
template <typename IndexType>
1003+
void csrsort(cusparseHandle_t handle, IndexType m, IndexType n, IndexType nnz,
1004+
const cusparseMatDescr_t descr, const IndexType *row_ptrs,
1005+
IndexType *col_idxs, IndexType *permutation,
1006+
void *buffer) GKO_NOT_IMPLEMENTED;
1007+
1008+
template <>
1009+
inline void csrsort<int32>(cusparseHandle_t handle, int32 m, int32 n, int32 nnz,
1010+
const cusparseMatDescr_t descr,
1011+
const int32 *row_ptrs, int32 *col_idxs,
1012+
int32 *permutation, void *buffer)
1013+
{
1014+
GKO_ASSERT_NO_CUSPARSE_ERRORS(cusparseXcsrsort(
1015+
handle, m, n, nnz, descr, row_ptrs, col_idxs, permutation, buffer));
1016+
}
1017+
1018+
1019+
template <typename IndexType, typename ValueType>
1020+
void gather(cusparseHandle_t handle, IndexType nnz, const ValueType *in,
1021+
ValueType *out, const IndexType *permutation) GKO_NOT_IMPLEMENTED;
1022+
1023+
#define GKO_BIND_CUSPARSE_GATHER(ValueType, CusparseName) \
1024+
template <> \
1025+
inline void gather<int32, ValueType>(cusparseHandle_t handle, int32 nnz, \
1026+
const ValueType *in, ValueType *out, \
1027+
const int32 *permutation) \
1028+
{ \
1029+
GKO_ASSERT_NO_CUSPARSE_ERRORS( \
1030+
CusparseName(handle, nnz, as_culibs_type(in), as_culibs_type(out), \
1031+
permutation, CUSPARSE_INDEX_BASE_ZERO)); \
1032+
} \
1033+
static_assert(true, \
1034+
"This assert is used to counter the false positive extra " \
1035+
"semi-colon warnings")
1036+
1037+
GKO_BIND_CUSPARSE_GATHER(float, cusparseSgthr);
1038+
GKO_BIND_CUSPARSE_GATHER(double, cusparseDgthr);
1039+
GKO_BIND_CUSPARSE_GATHER(std::complex<float>, cusparseCgthr);
1040+
GKO_BIND_CUSPARSE_GATHER(std::complex<double>, cusparseZgthr);
1041+
1042+
#undef GKO_BIND_CUSPARSE_GATHER
1043+
1044+
9711045
} // namespace cusparse
9721046
} // namespace cuda
9731047
} // namespace kernels

cuda/matrix/csr_kernels.cu

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,46 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
931931
template <typename ValueType, typename IndexType>
932932
void sort_by_column_index(std::shared_ptr<const CudaExecutor> exec,
933933
matrix::Csr<ValueType, IndexType> *to_sort)
934-
GKO_NOT_IMPLEMENTED;
934+
{
935+
if (cusparse::is_supported<ValueType, IndexType>::value) {
936+
auto handle = exec->get_cusparse_handle();
937+
auto descr = cusparse::create_mat_descr();
938+
auto m = IndexType(to_sort->get_size()[0]);
939+
auto n = IndexType(to_sort->get_size()[1]);
940+
auto nnz = IndexType(to_sort->get_num_stored_elements());
941+
auto row_ptrs = to_sort->get_const_row_ptrs();
942+
auto col_idxs = to_sort->get_col_idxs();
943+
auto vals = to_sort->get_values();
944+
945+
// copy values
946+
Array<ValueType> tmp_vals_array(exec, nnz);
947+
exec->copy_from(exec.get(), nnz, vals, tmp_vals_array.get_data());
948+
auto tmp_vals = tmp_vals_array.get_const_data();
949+
950+
// init identity permutation
951+
Array<IndexType> permutation_array(exec, nnz);
952+
auto permutation = permutation_array.get_data();
953+
cusparse::create_identity_permutation(handle, nnz, permutation);
954+
955+
// allocate buffer
956+
size_type buffer_size{};
957+
cusparse::csrsort_buffer_size(handle, m, n, nnz, row_ptrs, col_idxs,
958+
buffer_size);
959+
Array<char> buffer_array{exec, buffer_size};
960+
auto buffer = buffer_array.get_data();
961+
962+
// sort column indices
963+
cusparse::csrsort(handle, m, n, nnz, descr, row_ptrs, col_idxs,
964+
permutation, buffer);
965+
966+
// sort values
967+
cusparse::gather(handle, nnz, tmp_vals, vals, permutation);
968+
969+
cusparse::destroy(descr);
970+
} else {
971+
GKO_NOT_IMPLEMENTED;
972+
}
973+
}
935974

936975
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
937976
GKO_DECLARE_CSR_SORT_BY_COLUMN_INDEX);

cuda/test/matrix/csr_kernels.cpp

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class Csr : public ::testing::Test {
6363
using ComplexVec = gko::matrix::Dense<std::complex<double>>;
6464
using ComplexMtx = gko::matrix::Csr<std::complex<double>>;
6565

66-
// All generated test matrices share the 532 x 231 shape stored in `mtx_size`;
// the random engine uses a fixed seed (42).
Csr() : mtx_size(532, 231), rand_engine(42) {}
6767

6868
void SetUp()
6969
{
@@ -93,11 +93,11 @@ class Csr : public ::testing::Test {
9393
int num_vectors = 1)
9494
{
9595
mtx = Mtx::create(ref, strategy);
96-
mtx->copy_from(gen_mtx<Vec>(532, 231, 1));
96+
mtx->copy_from(gen_mtx<Vec>(mtx_size[0], mtx_size[1], 1));
9797
square_mtx = Mtx::create(ref, strategy);
98-
square_mtx->copy_from(gen_mtx<Vec>(532, 532, 1));
99-
expected = gen_mtx<Vec>(532, num_vectors, 1);
100-
y = gen_mtx<Vec>(231, num_vectors, 1);
98+
square_mtx->copy_from(gen_mtx<Vec>(mtx_size[0], mtx_size[0], 1));
99+
expected = gen_mtx<Vec>(mtx_size[0], num_vectors, 1);
100+
y = gen_mtx<Vec>(mtx_size[1], num_vectors, 1);
101101
alpha = gko::initialize<Vec>({2.0}, ref);
102102
beta = gko::initialize<Vec>({-1.0}, ref);
103103
dmtx = Mtx::create(cuda, strategy);
@@ -118,14 +118,48 @@ class Csr : public ::testing::Test {
118118
std::shared_ptr<ComplexMtx::strategy_type> strategy)
119119
{
120120
complex_mtx = ComplexMtx::create(ref, strategy);
121-
complex_mtx->copy_from(gen_mtx<ComplexVec>(532, 231, 1));
121+
complex_mtx->copy_from(
122+
gen_mtx<ComplexVec>(mtx_size[0], mtx_size[1], 1));
122123
complex_dmtx = ComplexMtx::create(cuda, strategy);
123124
complex_dmtx->copy_from(complex_mtx.get());
124125
}
125126

127+
struct matrix_pair {
128+
std::unique_ptr<Mtx> ref;
129+
std::unique_ptr<Mtx> cuda;
130+
};
131+
132+
/**
 * Generates a random matrix of size `mtx_size` whose entries are shuffled
 * within each row, together with a copy of it created on the `cuda` executor.
 *
 * Shuffling swaps column index and value together, so the matrix content is
 * unchanged; only the storage order inside the rows is perturbed.
 *
 * @return the shuffled matrix (`ref`) and its copy on `cuda` (`cuda`)
 */
matrix_pair gen_unsorted_mtx()
{
    constexpr int min_nnz_per_row = 2;  // Must be larger/equal than 2
    static_assert(min_nnz_per_row >= 2,
                  "Shuffling a row requires at least two stored elements");
    auto local_mtx_ref =
        gen_mtx<Mtx>(mtx_size[0], mtx_size[1], min_nnz_per_row);
    // The array pointers do not change while shuffling, fetch them once.
    const auto row_ptrs = local_mtx_ref->get_const_row_ptrs();
    const auto all_col_idxs = local_mtx_ref->get_col_idxs();
    const auto all_vals = local_mtx_ref->get_values();
    for (size_t row = 0; row < mtx_size[0]; ++row) {
        const auto start_row = row_ptrs[row];
        const auto nnz_in_this_row = row_ptrs[row + 1] - start_row;
        if (nnz_in_this_row <= 0) {
            // uniform_int_distribution requires lower <= upper, so an empty
            // row must not reach the distribution below
            continue;
        }
        auto col_idx = all_col_idxs + start_row;
        auto vals = all_vals + start_row;
        std::uniform_int_distribution<> swap_idx_dist(0, nnz_in_this_row - 1);
        // shuffle `nnz_in_this_row / 2` (rounded up) times; counting in the
        // index type avoids a signed/unsigned comparison
        for (auto remaining = (nnz_in_this_row + 1) / 2; remaining > 0;
             --remaining) {
            const auto idx1 = swap_idx_dist(rand_engine);
            const auto idx2 = swap_idx_dist(rand_engine);
            std::swap(col_idx[idx1], col_idx[idx2]);
            std::swap(vals[idx1], vals[idx2]);
        }
    }
    auto local_mtx_cuda = Mtx::create(cuda);
    local_mtx_cuda->copy_from(local_mtx_ref.get());

    return {std::move(local_mtx_ref), std::move(local_mtx_cuda)};
}
158+
126159
std::shared_ptr<gko::ReferenceExecutor> ref;
127160
std::shared_ptr<const gko::CudaExecutor> cuda;
128161

162+
const gko::dim<2> mtx_size;
129163
std::ranlux48 rand_engine;
130164

131165
std::unique_ptr<Mtx> mtx;
@@ -576,4 +610,28 @@ TEST_F(Csr, MoveToHybridIsEquivalentToRef)
576610
}
577611

578612

613+
// Sorting the matrices produced by set_up_apply_data() must give the same
// result on the reference and the CUDA side.
TEST_F(Csr, SortSortedMatrixIsEquivalentToRef)
{
    auto strategy = std::make_shared<Mtx::automatical>();
    set_up_apply_data(strategy);

    mtx->sort_by_column_index();
    dmtx->sort_by_column_index();

    // Sorting only reorders entries, so the comparison uses tolerance `0`
    GKO_ASSERT_MTX_NEAR(mtx, dmtx, 0);
}
623+
624+
625+
// Sorting a matrix with shuffled rows must give the same result on the
// reference and the CUDA side.
TEST_F(Csr, SortUnsortedMatrixIsEquivalentToRef)
{
    auto unsorted = gen_unsorted_mtx();

    unsorted.ref->sort_by_column_index();
    unsorted.cuda->sort_by_column_index();

    // Sorting only reorders entries, so the comparison uses tolerance `0`
    GKO_ASSERT_MTX_NEAR(unsorted.ref, unsorted.cuda, 0);
}
635+
636+
579637
} // namespace

0 commit comments

Comments
 (0)