Skip to content

Commit 3c2c0af

Browse files
committed
add GPU permutation kernels
1 parent b854c30 commit 3c2c0af

11 files changed

Lines changed: 580 additions & 252 deletions

File tree

common/matrix/csr_kernels.hpp.inc

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -946,3 +946,120 @@ __global__ __launch_bounds__(default_block_size) void conjugate_kernel(
946946

947947

948948
} // namespace
949+
950+
951+
// Writes, for every position i in [0, size), the value i into
// inv_permutation[permutation[i]].
//
// Each thread handles the single position given by its flat thread id;
// threads whose id is not below `size` do nothing.
//
// size             number of positions to process
// permutation      indices read at positions [0, size)
// inv_permutation  output, written at the indices read from `permutation`
template <typename IndexType>
__global__ __launch_bounds__(default_block_size) void inv_permutation_kernel(
    size_type size, const IndexType *__restrict__ permutation,
    IndexType *__restrict__ inv_permutation)
{
    const auto idx = thread::get_thread_id_flat();
    if (idx < size) {
        const auto target = permutation[idx];
        inv_permutation[target] = idx;
    }
}
962+
963+
964+
// Copies a CSR matrix while relabelling its column indices.
//
// The thread with flat id t
//   - copies in_row_ptrs[t] to out_row_ptrs[t] if t <= num_rows, and
//   - copies in_vals[t] to out_vals[t] and stores permutation[in_cols[t]]
//     in out_cols[t] if t < num_nonzeros.
//
// Every thread touches only its own index t, so the launch has to provide at
// least max(num_rows + 1, num_nonzeros) threads for all entries to be
// written.
//
// num_rows      number of rows (num_rows + 1 row pointers are copied)
// num_nonzeros  number of column index / value entries
// permutation   lookup table indexed by the input column indices
template <typename ValueType, typename IndexType>
__global__ __launch_bounds__(default_block_size) void col_permute_kernel(
    size_type num_rows, size_type num_nonzeros,
    const IndexType *__restrict__ permutation,
    const IndexType *__restrict__ in_row_ptrs,
    const IndexType *__restrict__ in_cols,
    const ValueType *__restrict__ in_vals, IndexType *__restrict__ out_row_ptrs,
    IndexType *__restrict__ out_cols, ValueType *__restrict__ out_vals)
{
    const auto idx = thread::get_thread_id_flat();
    // The row structure is unchanged: row pointers are copied verbatim,
    // including the final entry at index num_rows.
    if (idx <= num_rows) {
        out_row_ptrs[idx] = in_row_ptrs[idx];
    }
    // Values keep their position; only the column label is looked up.
    if (idx < num_nonzeros) {
        const auto old_col = in_cols[idx];
        out_vals[idx] = in_vals[idx];
        out_cols[idx] = permutation[old_col];
    }
}
982+
983+
984+
// Computes the length of every output row of a row permutation.
//
// The thread with flat id r (r < num_rows) reads the input row
// permutation[r] and stores its length, i.e. the difference of the two
// surrounding row pointers, in out_nnz[r]. Threads with r >= num_rows do
// nothing.
//
// num_rows     number of rows
// permutation  for each output row, the input row it is taken from
// in_row_ptrs  row pointers of the input matrix
// out_nnz      per-row lengths, written at positions [0, num_rows)
template <typename IndexType>
__global__ __launch_bounds__(default_block_size) void row_ptr_permute_kernel(
    size_type num_rows, const IndexType *__restrict__ permutation,
    const IndexType *__restrict__ in_row_ptrs, IndexType *__restrict__ out_nnz)
{
    const auto dst_row = thread::get_thread_id_flat();
    if (dst_row < num_rows) {
        const auto src_row = permutation[dst_row];
        out_nnz[dst_row] = in_row_ptrs[src_row + 1] - in_row_ptrs[src_row];
    }
}
997+
998+
999+
// Computes the length of every output row of an inverse row permutation.
//
// The thread with flat id r (r < num_rows) takes the length of input row r,
// i.e. the difference of its two surrounding row pointers, and stores it in
// out_nnz[permutation[r]]. Threads with r >= num_rows do nothing.
//
// num_rows     number of rows
// permutation  for each input row, the output row it is moved to
// in_row_ptrs  row pointers of the input matrix
// out_nnz      per-row lengths, written at the positions read from
//              `permutation`
template <typename IndexType>
__global__
    __launch_bounds__(default_block_size) void inv_row_ptr_permute_kernel(
        size_type num_rows, const IndexType *__restrict__ permutation,
        const IndexType *__restrict__ in_row_ptrs,
        IndexType *__restrict__ out_nnz)
{
    const auto src_row = thread::get_thread_id_flat();
    if (src_row < num_rows) {
        const auto dst_row = permutation[src_row];
        out_nnz[dst_row] = in_row_ptrs[src_row + 1] - in_row_ptrs[src_row];
    }
}
1014+
1015+
1016+
// Copies the column indices and values of permuted rows into the output
// matrix.
//
// The id r returned by thread::get_subwarp_id_flat<subwarp_size>() selects
// the output row; the input row is permutation[r]. Threads with r >= num_rows
// do nothing. Each thread starts at offset threadIdx.x % subwarp_size inside
// the row and advances in steps of subwarp_size.
// NOTE(review): presumably subwarp_size consecutive threads share one id r and
// therefore cooperate on one row -- confirm against get_subwarp_id_flat.
//
// num_rows      number of rows
// permutation   for each output row, the input row it is taken from
// in_row_ptrs   row pointers of the input matrix
// in_cols       column indices of the input matrix
// in_vals       values of the input matrix
// out_row_ptrs  row pointers of the output matrix (read only here)
// out_cols      column indices of the output matrix
// out_vals      values of the output matrix
template <int subwarp_size, typename ValueType, typename IndexType>
__global__ __launch_bounds__(default_block_size) void row_permute_kernel(
    size_type num_rows, const IndexType *__restrict__ permutation,
    const IndexType *__restrict__ in_row_ptrs,
    const IndexType *__restrict__ in_cols,
    const ValueType *__restrict__ in_vals,
    const IndexType *__restrict__ out_row_ptrs,
    IndexType *__restrict__ out_cols, ValueType *__restrict__ out_vals)
{
    const auto dst_row = thread::get_subwarp_id_flat<subwarp_size>();
    if (dst_row < num_rows) {
        const auto src_row = permutation[dst_row];
        const auto src_begin = in_row_ptrs[src_row];
        const auto row_length = in_row_ptrs[src_row + 1] - src_begin;
        const auto dst_begin = out_row_ptrs[dst_row];
        const IndexType first = threadIdx.x % subwarp_size;
        for (IndexType k = first; k < row_length; k += subwarp_size) {
            const auto src = src_begin + k;
            const auto dst = dst_begin + k;
            out_cols[dst] = in_cols[src];
            out_vals[dst] = in_vals[src];
        }
    }
}
1040+
1041+
1042+
// Copies the column indices and values of the input rows to their inversely
// permuted position in the output matrix.
//
// The id r returned by thread::get_subwarp_id_flat<subwarp_size>() selects
// the input row; it is written to output row permutation[r]. Threads with
// r >= num_rows do nothing. Each thread starts at offset
// threadIdx.x % subwarp_size inside the row and advances in steps of
// subwarp_size.
// NOTE(review): presumably subwarp_size consecutive threads share one id r and
// therefore cooperate on one row -- confirm against get_subwarp_id_flat.
//
// num_rows      number of rows
// permutation   for each input row, the output row it is moved to
// in_row_ptrs   row pointers of the input matrix
// in_cols       column indices of the input matrix
// in_vals       values of the input matrix
// out_row_ptrs  row pointers of the output matrix (read only here)
// out_cols      column indices of the output matrix
// out_vals      values of the output matrix
template <int subwarp_size, typename ValueType, typename IndexType>
__global__ __launch_bounds__(default_block_size) void inv_row_permute_kernel(
    size_type num_rows, const IndexType *__restrict__ permutation,
    const IndexType *__restrict__ in_row_ptrs,
    const IndexType *__restrict__ in_cols,
    const ValueType *__restrict__ in_vals,
    const IndexType *__restrict__ out_row_ptrs,
    IndexType *__restrict__ out_cols, ValueType *__restrict__ out_vals)
{
    const auto src_row = thread::get_subwarp_id_flat<subwarp_size>();
    if (src_row < num_rows) {
        const auto dst_row = permutation[src_row];
        const auto src_begin = in_row_ptrs[src_row];
        const auto row_length = in_row_ptrs[src_row + 1] - src_begin;
        const auto dst_begin = out_row_ptrs[dst_row];
        const IndexType first = threadIdx.x % subwarp_size;
        for (IndexType k = first; k < row_length; k += subwarp_size) {
            const auto src = src_begin + k;
            const auto dst = dst_begin + k;
            out_cols[dst] = in_cols[src];
            out_vals[dst] = in_vals[src];
        }
    }
}

core/device_hooks/common_kernels.inc.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -647,22 +647,21 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
647647
GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL);
648648

649649
template <typename ValueType, typename IndexType>
650-
GKO_DECLARE_CSR_COLUMN_PERMUTE_KERNEL(ValueType, IndexType)
650+
GKO_DECLARE_CSR_INVERSE_COLUMN_PERMUTE_KERNEL(ValueType, IndexType)
651651
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
652652
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
653-
GKO_DECLARE_CSR_COLUMN_PERMUTE_KERNEL);
653+
GKO_DECLARE_CSR_INVERSE_COLUMN_PERMUTE_KERNEL);
654654

655655
template <typename ValueType, typename IndexType>
656656
GKO_DECLARE_CSR_INVERSE_ROW_PERMUTE_KERNEL(ValueType, IndexType)
657657
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
658658
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
659659
GKO_DECLARE_CSR_INVERSE_ROW_PERMUTE_KERNEL);
660660

661-
template <typename ValueType, typename IndexType>
662-
GKO_DECLARE_CSR_INVERSE_COLUMN_PERMUTE_KERNEL(ValueType, IndexType)
661+
template <typename IndexType>
662+
GKO_DECLARE_INVERT_PERMUTATION_KERNEL(IndexType)
663663
GKO_NOT_COMPILED(GKO_HOOK_MODULE);
664-
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
665-
GKO_DECLARE_CSR_INVERSE_COLUMN_PERMUTE_KERNEL);
664+
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_INVERT_PERMUTATION_KERNEL);
666665

667666
template <typename ValueType, typename IndexType>
668667
GKO_DECLARE_CSR_CALCULATE_MAX_NNZ_PER_ROW_KERNEL(ValueType, IndexType)

core/matrix/csr.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ GKO_REGISTER_OPERATION(convert_to_hybrid, csr::convert_to_hybrid);
7070
GKO_REGISTER_OPERATION(transpose, csr::transpose);
7171
GKO_REGISTER_OPERATION(conj_transpose, csr::conj_transpose);
7272
GKO_REGISTER_OPERATION(row_permute, csr::row_permute);
73-
GKO_REGISTER_OPERATION(column_permute, csr::column_permute);
7473
GKO_REGISTER_OPERATION(inverse_row_permute, csr::inverse_row_permute);
7574
GKO_REGISTER_OPERATION(inverse_column_permute, csr::inverse_column_permute);
75+
GKO_REGISTER_OPERATION(invert_permutation, csr::invert_permutation);
7676
GKO_REGISTER_OPERATION(calculate_max_nnz_per_row,
7777
csr::calculate_max_nnz_per_row);
7878
GKO_REGISTER_OPERATION(calculate_nonzeros_per_row,
@@ -394,8 +394,8 @@ std::unique_ptr<LinOp> Csr<ValueType, IndexType>::row_permute(
394394
Csr::create(exec, this->get_size(), this->get_num_stored_elements(),
395395
this->get_strategy());
396396

397-
exec->run(
398-
csr::make_row_permute(permutation_indices, this, permute_cpy.get()));
397+
exec->run(csr::make_row_permute(permutation_indices->get_const_data(), this,
398+
permute_cpy.get()));
399399
permute_cpy->make_srow();
400400
return std::move(permute_cpy);
401401
}
@@ -410,10 +410,15 @@ std::unique_ptr<LinOp> Csr<ValueType, IndexType>::column_permute(
410410
auto permute_cpy =
411411
Csr::create(exec, this->get_size(), this->get_num_stored_elements(),
412412
this->get_strategy());
413+
Array<IndexType> inv_permutation(exec, this->get_size()[1]);
413414

414-
exec->run(
415-
csr::make_column_permute(permutation_indices, this, permute_cpy.get()));
415+
exec->run(csr::make_invert_permutation(
416+
this->get_size()[1], permutation_indices->get_const_data(),
417+
inv_permutation.get_data()));
418+
exec->run(csr::make_inverse_column_permute(inv_permutation.get_const_data(),
419+
this, permute_cpy.get()));
416420
permute_cpy->make_srow();
421+
permute_cpy->sort_by_column_index();
417422
return std::move(permute_cpy);
418423
}
419424

@@ -429,8 +434,9 @@ std::unique_ptr<LinOp> Csr<ValueType, IndexType>::inverse_row_permute(
429434
Csr::create(exec, this->get_size(), this->get_num_stored_elements(),
430435
this->get_strategy());
431436

432-
exec->run(csr::make_inverse_row_permute(inverse_permutation_indices, this,
433-
inverse_permute_cpy.get()));
437+
exec->run(csr::make_inverse_row_permute(
438+
inverse_permutation_indices->get_const_data(), this,
439+
inverse_permute_cpy.get()));
434440
inverse_permute_cpy->make_srow();
435441
return std::move(inverse_permute_cpy);
436442
}
@@ -448,8 +454,10 @@ std::unique_ptr<LinOp> Csr<ValueType, IndexType>::inverse_column_permute(
448454
this->get_strategy());
449455

450456
exec->run(csr::make_inverse_column_permute(
451-
inverse_permutation_indices, this, inverse_permute_cpy.get()));
457+
inverse_permutation_indices->get_const_data(), this,
458+
inverse_permute_cpy.get()));
452459
inverse_permute_cpy->make_srow();
460+
inverse_permute_cpy->sort_by_column_index();
453461
return std::move(inverse_permute_cpy);
454462
}
455463

core/matrix/csr_kernels.hpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,29 +131,28 @@ namespace kernels {
131131

132132
#define GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL(ValueType, IndexType) \
133133
void row_permute(std::shared_ptr<const DefaultExecutor> exec, \
134-
const Array<IndexType> *permutation_indices, \
134+
const IndexType *permutation_indices, \
135135
const matrix::Csr<ValueType, IndexType> *orig, \
136136
matrix::Csr<ValueType, IndexType> *row_permuted)
137137

138-
#define GKO_DECLARE_CSR_COLUMN_PERMUTE_KERNEL(ValueType, IndexType) \
139-
void column_permute(std::shared_ptr<const DefaultExecutor> exec, \
140-
const Array<IndexType> *permutation_indices, \
141-
const matrix::Csr<ValueType, IndexType> *orig, \
142-
matrix::Csr<ValueType, IndexType> *column_permuted)
143-
144138
#define GKO_DECLARE_CSR_INVERSE_ROW_PERMUTE_KERNEL(ValueType, IndexType) \
145139
void inverse_row_permute(std::shared_ptr<const DefaultExecutor> exec, \
146-
const Array<IndexType> *permutation_indices, \
140+
const IndexType *permutation_indices, \
147141
const matrix::Csr<ValueType, IndexType> *orig, \
148142
matrix::Csr<ValueType, IndexType> *row_permuted)
149143

150144
#define GKO_DECLARE_CSR_INVERSE_COLUMN_PERMUTE_KERNEL(ValueType, IndexType) \
151145
void inverse_column_permute( \
152146
std::shared_ptr<const DefaultExecutor> exec, \
153-
const Array<IndexType> *permutation_indices, \
147+
const IndexType *permutation_indices, \
154148
const matrix::Csr<ValueType, IndexType> *orig, \
155149
matrix::Csr<ValueType, IndexType> *column_permuted)
156150

151+
// Expands to the declaration of the invert_permutation kernel entry point
// for the given IndexType. The function receives the executor, the number of
// entries `size`, a read-only pointer to the permutation indices and a
// writable pointer that receives the result.
// NOTE(review): both buffers presumably have to hold at least `size` entries
// -- confirm against the backend implementations.
#define GKO_DECLARE_INVERT_PERMUTATION_KERNEL(IndexType) \
152+
void invert_permutation( \
153+
std::shared_ptr<const DefaultExecutor> exec, size_type size, \
154+
const IndexType *permutation_indices, IndexType *inv_permutation)
155+
157156
#define GKO_DECLARE_CSR_CALCULATE_MAX_NNZ_PER_ROW_KERNEL(ValueType, IndexType) \
158157
void calculate_max_nnz_per_row( \
159158
std::shared_ptr<const DefaultExecutor> exec, \
@@ -210,11 +209,11 @@ namespace kernels {
210209
template <typename ValueType, typename IndexType> \
211210
GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL(ValueType, IndexType); \
212211
template <typename ValueType, typename IndexType> \
213-
GKO_DECLARE_CSR_COLUMN_PERMUTE_KERNEL(ValueType, IndexType); \
214-
template <typename ValueType, typename IndexType> \
215212
GKO_DECLARE_CSR_INVERSE_ROW_PERMUTE_KERNEL(ValueType, IndexType); \
216213
template <typename ValueType, typename IndexType> \
217214
GKO_DECLARE_CSR_INVERSE_COLUMN_PERMUTE_KERNEL(ValueType, IndexType); \
215+
template <typename IndexType> \
216+
GKO_DECLARE_INVERT_PERMUTATION_KERNEL(IndexType); \
218217
template <typename ValueType, typename IndexType> \
219218
GKO_DECLARE_CSR_CALCULATE_MAX_NNZ_PER_ROW_KERNEL(ValueType, IndexType); \
220219
template <typename ValueType, typename IndexType> \

cuda/matrix/csr_kernels.cu

Lines changed: 58 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,45 +1143,86 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
11431143
GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL);
11441144

11451145

1146+
// Launches inv_permutation_kernel so that, for every i in [0, size),
// inv_permutation[permutation_indices[i]] is set to i.
//
// exec                 executor the operation is run for (not used here)
// size                 number of permutation entries
// permutation_indices  device pointer to the permutation to invert
// inv_permutation      device pointer receiving the result
template <typename IndexType>
void invert_permutation(std::shared_ptr<const DefaultExecutor> exec,
                        size_type size, const IndexType *permutation_indices,
                        IndexType *inv_permutation)
{
    auto num_blocks = ceildiv(size, default_block_size);
    // A grid with zero blocks is an invalid launch configuration; for an
    // empty permutation there is nothing to do, so skip the launch instead of
    // leaving a pending CUDA error behind.
    if (num_blocks > 0) {
        inv_permutation_kernel<<<num_blocks, default_block_size>>>(
            size, permutation_indices, inv_permutation);
    }
}
1155+
1156+
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_INVERT_PERMUTATION_KERNEL);
1157+
1158+
11461159
// Builds `row_permuted` from `orig` by taking output row i from input row
// perm[i].
//
// Steps:
//   1. row_ptr_permute_kernel stores the length of every output row in the
//      first num_rows entries of the output row pointer array,
//   2. components::prefix_sum is run over num_rows + 1 entries of that array
//      (presumably an exclusive scan turning lengths into row pointers --
//      confirm against its implementation),
//   3. row_permute_kernel copies column indices and values row by row, with
//      config::warp_size as its subwarp size.
//
// exec          executor the operation is run for
// perm          device pointer to the row permutation
// orig          input matrix
// row_permuted  output matrix; its row pointers, column indices and values
//               are overwritten
template <typename ValueType, typename IndexType>
void row_permute(std::shared_ptr<const CudaExecutor> exec,
                 const IndexType *perm,
                 const matrix::Csr<ValueType, IndexType> *orig,
                 matrix::Csr<ValueType, IndexType> *row_permuted)
{
    auto num_rows = orig->get_size()[0];
    auto count_num_blocks = ceildiv(num_rows, default_block_size);
    // A grid with zero blocks is an invalid launch configuration, so both
    // kernels are skipped for a matrix without rows.
    if (count_num_blocks > 0) {
        row_ptr_permute_kernel<<<count_num_blocks, default_block_size>>>(
            num_rows, perm, orig->get_const_row_ptrs(),
            row_permuted->get_row_ptrs());
    }
    components::prefix_sum(exec, row_permuted->get_row_ptrs(), num_rows + 1);
    // The copy kernel is launched with one warp-sized subwarp per row, hence
    // default_block_size / warp_size rows per block.
    auto copy_num_blocks =
        ceildiv(num_rows, default_block_size / config::warp_size);
    if (copy_num_blocks > 0) {
        row_permute_kernel<config::warp_size>
            <<<copy_num_blocks, default_block_size>>>(
                num_rows, perm, orig->get_const_row_ptrs(),
                orig->get_const_col_idxs(),
                as_cuda_type(orig->get_const_values()),
                row_permuted->get_row_ptrs(), row_permuted->get_col_idxs(),
                as_cuda_type(row_permuted->get_values()));
    }
}
11521180

11531181
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
11541182
GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL);
11551183

11561184

1157-
template <typename ValueType, typename IndexType>
1158-
void column_permute(std::shared_ptr<const CudaExecutor> exec,
1159-
const Array<IndexType> *permutation_indices,
1160-
const matrix::Csr<ValueType, IndexType> *orig,
1161-
matrix::Csr<ValueType, IndexType> *column_permuted)
1162-
GKO_NOT_IMPLEMENTED;
1163-
1164-
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
1165-
GKO_DECLARE_CSR_COLUMN_PERMUTE_KERNEL);
1166-
1167-
11681185
// Builds `row_permuted` from `orig` by moving input row i to output row
// perm[i].
//
// Steps:
//   1. inv_row_ptr_permute_kernel stores the length of every output row in
//      the first num_rows entries of the output row pointer array,
//   2. components::prefix_sum is run over num_rows + 1 entries of that array
//      (presumably an exclusive scan turning lengths into row pointers --
//      confirm against its implementation),
//   3. inv_row_permute_kernel copies column indices and values row by row,
//      with config::warp_size as its subwarp size.
//
// exec          executor the operation is run for
// perm          device pointer to the row permutation
// orig          input matrix
// row_permuted  output matrix; its row pointers, column indices and values
//               are overwritten
template <typename ValueType, typename IndexType>
void inverse_row_permute(std::shared_ptr<const CudaExecutor> exec,
                         const IndexType *perm,
                         const matrix::Csr<ValueType, IndexType> *orig,
                         matrix::Csr<ValueType, IndexType> *row_permuted)
{
    auto num_rows = orig->get_size()[0];
    auto count_num_blocks = ceildiv(num_rows, default_block_size);
    // A grid with zero blocks is an invalid launch configuration, so both
    // kernels are skipped for a matrix without rows.
    if (count_num_blocks > 0) {
        inv_row_ptr_permute_kernel<<<count_num_blocks, default_block_size>>>(
            num_rows, perm, orig->get_const_row_ptrs(),
            row_permuted->get_row_ptrs());
    }
    components::prefix_sum(exec, row_permuted->get_row_ptrs(), num_rows + 1);
    // The copy kernel is launched with one warp-sized subwarp per row, hence
    // default_block_size / warp_size rows per block.
    auto copy_num_blocks =
        ceildiv(num_rows, default_block_size / config::warp_size);
    if (copy_num_blocks > 0) {
        inv_row_permute_kernel<config::warp_size>
            <<<copy_num_blocks, default_block_size>>>(
                num_rows, perm, orig->get_const_row_ptrs(),
                orig->get_const_col_idxs(),
                as_cuda_type(orig->get_const_values()),
                row_permuted->get_row_ptrs(), row_permuted->get_col_idxs(),
                as_cuda_type(row_permuted->get_values()));
    }
}
11741206

11751207
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
11761208
GKO_DECLARE_CSR_INVERSE_ROW_PERMUTE_KERNEL);
11771209

11781210

11791211
// Builds `column_permuted` from `orig` by replacing every stored column
// index c with perm[c]; row pointers and values are copied unchanged.
//
// A single launch of col_permute_kernel does all the work: the thread with
// flat id t copies row pointer t (for t <= num_rows) and nonzero t
// (for t < nnz).
//
// exec             executor the operation is run for (not used here)
// perm             device pointer to the lookup table applied to the column
//                  indices
// orig             input matrix
// column_permuted  output matrix; its row pointers, column indices and
//                  values are overwritten
template <typename ValueType, typename IndexType>
void inverse_column_permute(std::shared_ptr<const CudaExecutor> exec,
                            const IndexType *perm,
                            const matrix::Csr<ValueType, IndexType> *orig,
                            matrix::Csr<ValueType, IndexType> *column_permuted)
{
    auto num_rows = orig->get_size()[0];
    auto nnz = orig->get_num_stored_elements();
    // The kernel copies num_rows + 1 row pointers, so the grid has to cover
    // index num_rows as well. Sizing it from num_rows alone drops the last
    // row pointer whenever nnz <= num_rows and num_rows is a multiple of the
    // block size. Using num_rows + 1 also guarantees at least one block, so
    // an empty matrix never triggers a zero-sized (invalid) launch.
    auto num_blocks = ceildiv(std::max(num_rows + 1, nnz), default_block_size);
    col_permute_kernel<<<num_blocks, default_block_size>>>(
        num_rows, nnz, perm, orig->get_const_row_ptrs(),
        orig->get_const_col_idxs(), as_cuda_type(orig->get_const_values()),
        column_permuted->get_row_ptrs(), column_permuted->get_col_idxs(),
        as_cuda_type(column_permuted->get_values()));
}
11851226

11861227
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
11871228
GKO_DECLARE_CSR_INVERSE_COLUMN_PERMUTE_KERNEL);

0 commit comments

Comments
 (0)