Skip to content

Add CUDA, HIP and DPCPP batch bicgstab kernels#1443

Merged
pratikvn merged 28 commits into
developfrom
batch-bicgstab-device
Nov 5, 2023
Merged

Add CUDA, HIP and DPCPP batch bicgstab kernels#1443
pratikvn merged 28 commits into
developfrom
batch-bicgstab-device

Conversation

@pratikvn

@pratikvn pratikvn commented Oct 26, 2023

Copy link
Copy Markdown
Member

This PR adds the batch bicgstab solver kernels for CUDA, HIP and DPCPP backends. Some additional single rhs vector kernels are also added into the batch multivector kernels.

TODO

  • Add DPCPP kernels

@pratikvn pratikvn added 1:ST:WIP This PR is a work in progress. Not ready for review. type:batched-functionality This is related to the batched functionality in Ginkgo labels Oct 26, 2023
@pratikvn pratikvn added this to the Release 1.7.0 milestone Oct 26, 2023
@pratikvn pratikvn self-assigned this Oct 26, 2023
@ginkgo-bot ginkgo-bot added reg:build This is related to the build system. reg:testing This is related to testing. mod:core This is related to the core module. mod:cuda This is related to the CUDA module. type:solver This is related to the solvers mod:hip This is related to the HIP module. labels Oct 26, 2023

@MarcelKoch MarcelKoch left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the kernels look good so far. I have mostly comments outside of those.

Here are some things to be tackled later:

  • use dispatch instead of manual switch
  • make reductions work with more than 1 warp

Comment thread cuda/solver/batch_bicgstab_kernels.cu Outdated
// Compute norms of rhs
single_rhs_compute_norm2(subgroup, num_rows, b_global_entry, rhs_norm);
}
__syncthreads();

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary? The above code writes only to the norm.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Diverging paths between subwarps. To ensure consistency, I think it is good to synchronize them.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, they diverge, but I don't see how that would affect the following code. But I'm no expert on this, so I won't push anything here.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not requesting any changes, but I wanted to elaborate on this a bit. I agree here, I think we could take a page from CUB's book, where they ensure synchronization always happens inside functions that require it (i.e. SpMVs and reductions) and are entirely absent from the code otherwise.
To make this work, you need a "default" work assignment (like the default for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) loop) and every time you read from values outside your own assigned set, you have a threadsync before, and if you write to values outside your set (also computing reductions), you have a threadsync after. This may even allow you to keep all values in registers most of the time, as long as you don't have huge blocks. But that is an optional detail.

Outside of this, there is also some potential for "kernel fusion" (i.e. removing the __syncthreads and computing directly on values in registers) by computing the dot product on the result of the SpMV, but I don't have a clear idea how large the runtime impact of that would be.

}
__syncthreads();

for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: in the other kernels you are using r as index variable.

Comment thread core/solver/batch_bicgstab_kernels.hpp
Comment thread core/solver/batch_bicgstab_kernels.hpp
Comment thread dpcpp/solver/batch_bicgstab_kernels.dp.cpp Outdated

// template
// launch_apply_kernel<StopType, SIMDLEN, n_shared_total, sg_kernel_all>
if (num_rows <= 32 && n_shared_total == 10)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cuda/hip uses 9 vectors in shmem. Why does this check for 10? Also the kernel only checks until n_shared_total == 9

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the strategy is slightly different. Here the count includes the prec_shared vector. The number of shared vectors is always 9, so you can only check until 9. If it is greater than 9, then you know that the prec is also in shared memory.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but isn't that what storage_config::prec_shared is there for?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is a bit easier with looking at n_shared as 10 vectors. Otherwise, prec_shared will need to be a template parameter as well. But I understand your point that it makes the cuda/dpcpp kernels more confusing to compare.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer the additional template parameter then. But that might also be done later.

Comment thread dpcpp/solver/batch_bicgstab_kernels.dp.cpp Outdated
Comment thread core/solver/batch_bicgstab_kernels.hpp
Comment thread dpcpp/solver/batch_bicgstab_kernels.hpp.inc Outdated
@pratikvn

Copy link
Copy Markdown
Member Author

format!

Comment thread common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc Outdated
Comment thread core/test/utils/batch_helpers.hpp Outdated
Comment thread cuda/base/kernel_config.cuh Outdated
Comment on lines +53 to +57
if (sizeof(ValueType) == 4) {
cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeFourByte);
} else if (sizeof(ValueType) % 8 == 0) {
cudaDeviceSetSharedMemConfig(cudaSharedMemBankSizeEightByte);
}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do they have TwoByte? Otherwise, it may introduce some troubles when adding half

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I dont think that is necessary. Only a value of 8 is recommended for double to avoid bank conflicts. You can just set it to 4 for half I think .

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is kind of problematic - it configures the entire device, but we only run on a single stream. At the very least, we need to revert it after the kernel finished, otherwise we interfere with other applications' performance

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess a scope guard similar to the one for the device id could work here.

}
}
x.values[tidx * x.stride] = temp;
x[tidx] = temp;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete stride?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just use the plain pointers as arguments here. I guess technically we should have another stride parameter to the function, but I think that is unnecessary for now and we can add that when we support stride later.

Comment thread dpcpp/preconditioner/batch_identity.hpp.inc Outdated
Comment thread dpcpp/solver/batch_bicgstab_kernels.hpp.inc Outdated
Comment thread dpcpp/solver/batch_bicgstab_kernels.hpp.inc Outdated
Comment thread dpcpp/solver/batch_bicgstab_kernels.hpp.inc
Comment on lines +264 to +272
ValueType values[5];
real_type reals[2];
rho_old_sh = &values[0];
rho_new_sh = &values[1];
alpha_sh = &values[2];
omega_sh = &values[3];
temp_sh = &values[4];
norms_rhs_sh = &reals[0];
norms_res_sh = &reals[1];

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

segfault.
values and reals will be destroies after else.

Comment thread dpcpp/solver/batch_bicgstab_kernels.hpp.inc Outdated
Comment thread core/solver/batch_bicgstab_kernels.hpp Outdated
{
using real_type = gko::remove_complex<value_type>;
const size_type num_batch_items = mat.num_batch_items;
constexpr int align_multiple = 8;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, that alignment is only relevant if the vectors are stored in global memory, right?

@yhmtsai yhmtsai left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

except for the shared_memory in dpcpp and storage computation (not reviewed yet), others LGTM

Comment thread common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc
Comment thread common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc Outdated
Comment thread dpcpp/base/batch_multi_vector_kernels.dp.cpp
Comment thread dpcpp/solver/batch_bicgstab_kernels.dp.cpp Outdated
Comment on lines +35 to +44
__dpct_inline__ void initialize(
const int num_rows, const BatchMatrixType_entry& mat_global_entry,
const ValueType* const b_global_entry,
const ValueType* const x_global_entry, ValueType& rho_old, ValueType& omega,
ValueType& alpha, ValueType* const x_shared_entry,
ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry,
ValueType* const p_shared_entry, ValueType* const v_shared_entry,
typename gko::remove_complex<ValueType>& rhs_norm,
typename gko::remove_complex<ValueType>& res_norm,
sycl::nd_item<3> item_ct1)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think from CUDA, it will use __ldg() automatically if it is const __restrict__*. That's why we do not need to use __ldg

Comment thread dpcpp/solver/batch_bicgstab_kernels.hpp.inc Outdated
Comment thread test/solver/batch_bicgstab_kernels.cpp Outdated
Comment thread core/solver/batch_bicgstab_kernels.hpp Outdated
@pratikvn pratikvn force-pushed the batch-bicgstab-device branch from b8def5b to b653d3b Compare October 30, 2023 13:38
Comment thread common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc
Comment thread cuda/solver/batch_bicgstab_kernels.cu
Comment thread cuda/solver/batch_bicgstab_kernels.cu Outdated
Comment thread dpcpp/base/batch_multi_vector_kernels.hpp.inc Outdated
Comment thread dpcpp/solver/batch_bicgstab_kernels.dp.cpp Outdated
Comment thread hip/solver/batch_bicgstab_kernels.hip.cpp Outdated
Comment thread hip/solver/batch_bicgstab_kernels.hip.cpp
Comment thread hip/solver/batch_bicgstab_kernels.hip.cpp Outdated
Comment on lines +95 to +96
inline batch::matrix::ell::uniform_batch<const hip_type<ValueType>,
const IndexType>

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the to_const usually face this issue.
Could you check the other const version also correct?
If all related to this issue are not in public interface, it are not urgent before release

Comment thread test/solver/batch_bicgstab_kernels.cpp Outdated
@pratikvn pratikvn force-pushed the batch-bicgstab-device branch from b653d3b to fb50eaf Compare October 30, 2023 21:37
@pratikvn pratikvn added 1:ST:ready-for-review This PR is ready for review and removed 1:ST:WIP This PR is a work in progress. Not ready for review. labels Oct 31, 2023
@pratikvn pratikvn force-pushed the batch-bicgstab branch 3 times, most recently from 8982811 to 28560a5 Compare October 31, 2023 14:04
Comment thread hip/base/exception.hip.hpp Outdated
Base automatically changed from batch-bicgstab to develop November 1, 2023 09:06
Comment thread common/cuda_hip/stop/batch_criteria.hpp.inc Outdated
Comment thread core/base/batch_utilities.hpp Outdated
Comment thread core/device_hooks/common_kernels.inc.cpp
@pratikvn pratikvn force-pushed the batch-bicgstab-device branch from fb50eaf to d21d5fd Compare November 1, 2023 10:57
@pratikvn

pratikvn commented Nov 1, 2023

Copy link
Copy Markdown
Member Author

format!

@pratikvn pratikvn force-pushed the batch-bicgstab-device branch from f48179b to f600023 Compare November 5, 2023 16:11
Co-authored-by: Yu-Hsiang Tsai <yhmtsai@gmail.com>
@pratikvn pratikvn force-pushed the batch-bicgstab-device branch from f600023 to 79e68b3 Compare November 5, 2023 16:17
@pratikvn

pratikvn commented Nov 5, 2023

Copy link
Copy Markdown
Member Author

format!

Co-authored-by: Pratik Nayak <pratikvn@pm.me>
@pratikvn

pratikvn commented Nov 5, 2023

Copy link
Copy Markdown
Member Author

Turns out the no-circular-deps job is terribly slow. I verified (with the same config and flags as the job, inside the same image with a docker container) that it builds successfully with GINKGO_CHECK_CIRCULAR_DEPS=on, so I will go ahead and merge this.

@pratikvn pratikvn merged commit 47b3267 into develop Nov 5, 2023
@pratikvn pratikvn deleted the batch-bicgstab-device branch November 5, 2023 23:44
@sonarqubecloud

sonarqubecloud Bot commented Nov 6, 2023

Copy link
Copy Markdown

Kudos, SonarCloud Quality Gate passed!    Quality Gate passed

Bug A 0 Bugs
Vulnerability A 0 Vulnerabilities
Security Hotspot A 0 Security Hotspots
Code Smell A 7 Code Smells

98.6% 98.6% Coverage
17.7% 17.7% Duplication

warning The version of Java (11.0.3) you have used to run this analysis is deprecated and we will stop accepting it soon. Please update to at least Java 17.
Read more here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

1:ST:ready-for-review This PR is ready for review 1:ST:run-full-test mod:core This is related to the core module. mod:cuda This is related to the CUDA module. mod:hip This is related to the HIP module. reg:build This is related to the build system. reg:testing This is related to testing. type:batched-functionality This is related to the batched functionality in Ginkgo type:solver This is related to the solvers