@@ -1143,45 +1143,86 @@ GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(
11431143 GKO_DECLARE_CSR_CONJ_TRANSPOSE_KERNEL );
11441144
11451145
1146+ template <typename IndexType>
1147+ void invert_permutation (std::shared_ptr<const DefaultExecutor> exec,
1148+ size_type size, const IndexType *permutation_indices,
1149+ IndexType *inv_permutation)
1150+ {
1151+ auto num_blocks = ceildiv (size, default_block_size);
1152+ inv_permutation_kernel<<<num_blocks, default_block_size>>> (
1153+ size, permutation_indices, inv_permutation);
1154+ }
1155+
1156+ GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE (GKO_DECLARE_INVERT_PERMUTATION_KERNEL );
1157+
1158+
11461159template <typename ValueType, typename IndexType>
11471160void row_permute (std::shared_ptr<const CudaExecutor> exec,
1148- const Array< IndexType> *permutation_indices ,
1161+ const IndexType *perm ,
11491162 const matrix::Csr<ValueType, IndexType> *orig,
11501163 matrix::Csr<ValueType, IndexType> *row_permuted)
1151- GKO_NOT_IMPLEMENTED;
1164+ {
1165+ auto num_rows = orig->get_size ()[0 ];
1166+ auto count_num_blocks = ceildiv (num_rows, default_block_size);
1167+ row_ptr_permute_kernel<<<count_num_blocks, default_block_size>>> (
1168+ num_rows, perm, orig->get_const_row_ptrs (),
1169+ row_permuted->get_row_ptrs ());
1170+ components::prefix_sum (exec, row_permuted->get_row_ptrs (), num_rows + 1 );
1171+ auto copy_num_blocks =
1172+ ceildiv (num_rows, default_block_size / config::warp_size);
1173+ row_permute_kernel<config::warp_size>
1174+ <<<copy_num_blocks, default_block_size>>> (
1175+ num_rows, perm, orig->get_const_row_ptrs (),
1176+ orig->get_const_col_idxs (), as_cuda_type (orig->get_const_values ()),
1177+ row_permuted->get_row_ptrs (), row_permuted->get_col_idxs (),
1178+ as_cuda_type (row_permuted->get_values ()));
1179+ }
11521180
11531181GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE (
11541182 GKO_DECLARE_CSR_ROW_PERMUTE_KERNEL );
11551183
11561184
1157- template <typename ValueType, typename IndexType>
1158- void column_permute (std::shared_ptr<const CudaExecutor> exec,
1159- const Array<IndexType> *permutation_indices,
1160- const matrix::Csr<ValueType, IndexType> *orig,
1161- matrix::Csr<ValueType, IndexType> *column_permuted)
1162- GKO_NOT_IMPLEMENTED;
1163-
1164- GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE (
1165- GKO_DECLARE_CSR_COLUMN_PERMUTE_KERNEL );
1166-
1167-
11681185template <typename ValueType, typename IndexType>
11691186void inverse_row_permute (std::shared_ptr<const CudaExecutor> exec,
1170- const Array< IndexType> *permutation_indices ,
1187+ const IndexType *perm ,
11711188 const matrix::Csr<ValueType, IndexType> *orig,
11721189 matrix::Csr<ValueType, IndexType> *row_permuted)
1173- GKO_NOT_IMPLEMENTED;
1190+ {
1191+ auto num_rows = orig->get_size ()[0 ];
1192+ auto count_num_blocks = ceildiv (num_rows, default_block_size);
1193+ inv_row_ptr_permute_kernel<<<count_num_blocks, default_block_size>>> (
1194+ num_rows, perm, orig->get_const_row_ptrs (),
1195+ row_permuted->get_row_ptrs ());
1196+ components::prefix_sum (exec, row_permuted->get_row_ptrs (), num_rows + 1 );
1197+ auto copy_num_blocks =
1198+ ceildiv (num_rows, default_block_size / config::warp_size);
1199+ inv_row_permute_kernel<config::warp_size>
1200+ <<<copy_num_blocks, default_block_size>>> (
1201+ num_rows, perm, orig->get_const_row_ptrs (),
1202+ orig->get_const_col_idxs (), as_cuda_type (orig->get_const_values ()),
1203+ row_permuted->get_row_ptrs (), row_permuted->get_col_idxs (),
1204+ as_cuda_type (row_permuted->get_values ()));
1205+ }
11741206
11751207GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE (
11761208 GKO_DECLARE_CSR_INVERSE_ROW_PERMUTE_KERNEL );
11771209
11781210
11791211template <typename ValueType, typename IndexType>
11801212void inverse_column_permute (std::shared_ptr<const CudaExecutor> exec,
1181- const Array< IndexType> *permutation_indices ,
1213+ const IndexType *perm ,
11821214 const matrix::Csr<ValueType, IndexType> *orig,
11831215 matrix::Csr<ValueType, IndexType> *column_permuted)
1184- GKO_NOT_IMPLEMENTED;
1216+ {
1217+ auto num_rows = orig->get_size ()[0 ];
1218+ auto nnz = orig->get_num_stored_elements ();
1219+ auto num_blocks = ceildiv (std::max (num_rows, nnz), default_block_size);
1220+ col_permute_kernel<<<num_blocks, default_block_size>>> (
1221+ num_rows, nnz, perm, orig->get_const_row_ptrs (),
1222+ orig->get_const_col_idxs (), as_cuda_type (orig->get_const_values ()),
1223+ column_permuted->get_row_ptrs (), column_permuted->get_col_idxs (),
1224+ as_cuda_type (column_permuted->get_values ()));
1225+ }
11851226
11861227GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE (
11871228 GKO_DECLARE_CSR_INVERSE_COLUMN_PERMUTE_KERNEL );
0 commit comments