Skip to content

Commit b653d3b

Browse files
pratikvnyhmtsaiMarcelKoch
committed
Review updates
Co-authored-by: Yu-Hsiang Tsai <yhmtsai@gmail.com> Co-authored-by: Marcel Koch <marcel.koch@kit.edu>
1 parent 6459e2f commit b653d3b

14 files changed

Lines changed: 177 additions & 183 deletions

common/cuda_hip/base/batch_multi_vector_kernels.hpp.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ __global__ __launch_bounds__(
104104

105105

106106
template <typename Group, typename ValueType>
107-
__device__ __forceinline__ void single_rhs_compute_dot(Group subgroup,
107+
__device__ __forceinline__ void single_rhs_compute_conj_dot(Group subgroup,
108108
const int num_rows,
109109
const ValueType* x,
110110
const ValueType* y,

common/cuda_hip/preconditioner/batch_identity.hpp.inc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ public:
4747

4848
__device__ __forceinline__ void generate(
4949
size_type,
50-
const gko::batch::matrix::ell::batch_item<const ValueType, gko::int32>&,
50+
const gko::batch::matrix::ell::batch_item<const ValueType,
51+
const gko::int32>&,
5152
ValueType*)
5253
{}
5354

common/cuda_hip/solver/batch_bicgstab_kernels.hpp.inc

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ __device__ __forceinline__ void initialize(
3838
const ValueType* const x_global_entry, ValueType& rho_old, ValueType& omega,
3939
ValueType& alpha, ValueType* const x_shared_entry,
4040
ValueType* const r_shared_entry, ValueType* const r_hat_shared_entry,
41-
ValueType* const p_shared_entry, ValueType* const v_shared_entry,
41+
ValueType* const p_shared_entry, ValueType* const p_hat_shared_entry,
42+
ValueType* const v_shared_entry,
4243
typename gko::remove_complex<ValueType>& rhs_norm,
4344
typename gko::remove_complex<ValueType>& res_norm)
4445
{
@@ -70,6 +71,7 @@ __device__ __forceinline__ void initialize(
7071
for (int iz = threadIdx.x; iz < num_rows; iz += blockDim.x) {
7172
r_hat_shared_entry[iz] = r_shared_entry[iz];
7273
p_shared_entry[iz] = zero<ValueType>();
74+
p_hat_shared_entry[iz] = zero<ValueType>();
7375
v_shared_entry[iz] = zero<ValueType>();
7476
}
7577
}
@@ -82,8 +84,8 @@ __device__ __forceinline__ void update_p(
8284
const ValueType* const r_shared_entry,
8385
const ValueType* const v_shared_entry, ValueType* const p_shared_entry)
8486
{
87+
const ValueType beta = (rho_new / rho_old) * (alpha / omega);
8588
for (int r = threadIdx.x; r < num_rows; r += blockDim.x) {
86-
const ValueType beta = (rho_new / rho_old) * (alpha / omega);
8789
p_shared_entry[r] =
8890
r_shared_entry[r] +
8991
beta * (p_shared_entry[r] - omega * v_shared_entry[r]);
@@ -97,8 +99,8 @@ __device__ __forceinline__ void compute_alpha(
9799
const ValueType* const v_shared_entry, ValueType& alpha)
98100
{
99101
if (threadIdx.x / config::warp_size == 0) {
100-
single_rhs_compute_dot(subgroup, num_rows, r_hat_shared_entry,
101-
v_shared_entry, alpha);
102+
single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_shared_entry,
103+
v_shared_entry, alpha);
102104
}
103105
__syncthreads();
104106
if (threadIdx.x == 0) {
@@ -126,11 +128,11 @@ __device__ __forceinline__ void compute_omega(
126128
const ValueType* const s_shared_entry, ValueType& temp, ValueType& omega)
127129
{
128130
if (threadIdx.x / config::warp_size == 0) {
129-
single_rhs_compute_dot(subgroup, num_rows, t_shared_entry,
130-
s_shared_entry, omega);
131+
single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
132+
s_shared_entry, omega);
131133
} else if (threadIdx.x / config::warp_size == 1) {
132-
single_rhs_compute_dot(subgroup, num_rows, t_shared_entry,
133-
t_shared_entry, temp);
134+
single_rhs_compute_conj_dot(subgroup, num_rows, t_shared_entry,
135+
t_shared_entry, temp);
134136
}
135137

136138
__syncthreads();
@@ -278,10 +280,12 @@ __global__ void apply_kernel(
278280
// compute residual norms
279281
// r_hat = r
280282
// p = 0
283+
// p_hat = 0
281284
// v = 0
282285
initialize(subgroup, num_rows, mat_entry, b_entry_ptr, x_gl_entry_ptr,
283286
rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh,
284-
r_hat_sh, p_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0]);
287+
r_hat_sh, p_sh, p_hat_sh, v_sh, norms_rhs_sh[0],
288+
norms_res_sh[0]);
285289
__syncthreads();
286290

287291
// stopping criterion object
@@ -296,8 +300,8 @@ __global__ void apply_kernel(
296300

297301
// rho_new = < r_hat , r > = (r_hat)' * (r)
298302
if (threadIdx.x / config::warp_size == 0) {
299-
single_rhs_compute_dot(subgroup, num_rows, r_hat_sh, r_sh,
300-
rho_new_sh[0]);
303+
single_rhs_compute_conj_dot(subgroup, num_rows, r_hat_sh, r_sh,
304+
rho_new_sh[0]);
301305
}
302306
__syncthreads();
303307

core/solver/batch_bicgstab_kernels.hpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,7 @@ void set_gmem_stride_bytes(storage_config& sconf,
115115
}
116116
// align global memory chunks
117117
sconf.gmem_stride_bytes =
118-
gmem_stride > 0 ? ((gmem_stride - 1) / align_bytes + 1) * align_bytes
119-
: 0;
118+
gmem_stride > 0 ? ceildiv(gmem_stride, align_bytes) * align_bytes : 0;
120119
}
121120

122121

@@ -143,8 +142,8 @@ void set_gmem_stride_bytes(storage_config& sconf,
143142
* - rhs_norms
144143
* - res_norms
145144
*
146-
* @param shared_mem_per_blk The amount of shared memory per block to use for
147-
* keeping intermediate vectors. In case keeping the matrix in L1 cache etc.
145+
* @param available_shared_mem The amount of shared memory per block to use
146+
* for keeping intermediate vectors. In case keeping the matrix in L1 cache etc.
148147
* should be prioritized, the cache configuration must be updated separately
149148
* and the needed space should be subtracted before passing to this
150149
* function.
@@ -154,7 +153,7 @@ void set_gmem_stride_bytes(storage_config& sconf,
154153
* @return A struct containing allocation information specific to Bicgstab.
155154
*/
156155
template <typename Prectype, typename ValueType, int align_bytes = 32>
157-
storage_config compute_shared_storage(const int shared_mem_per_blk,
156+
storage_config compute_shared_storage(const int available_shared_mem,
158157
const int num_rows, const int num_nz,
159158
const int num_rhs)
160159
{
@@ -163,10 +162,11 @@ storage_config compute_shared_storage(const int shared_mem_per_blk,
163162
const int num_main_vecs = 9;
164163
const int prec_storage =
165164
Prectype::dynamic_work_size(num_rows, num_nz) * sizeof(ValueType);
166-
int rem_shared = shared_mem_per_blk;
167-
// Set default values. All vecs are in global.
165+
int rem_shared = available_shared_mem;
166+
// Set default values. Initially all vecs are in global memory.
167+
// {prec_shared, n_shared, n_global, gmem_stride_bytes, padded_vec_len}
168168
storage_config sconf{false, 0, num_main_vecs, 0, num_rows};
169-
// If available shared mem, is zero, set all vecs to global.
169+
// If available shared mem is zero, set all vecs to global.
170170
if (rem_shared <= 0) {
171171
set_gmem_stride_bytes<align_bytes>(sconf, vec_size, prec_storage);
172172
return sconf;
@@ -177,13 +177,13 @@ storage_config compute_shared_storage(const int shared_mem_per_blk,
177177
const int num_vecs_shared = min(initial_vecs_available, num_main_vecs);
178178
sconf.n_shared += num_vecs_shared;
179179
sconf.n_global -= num_vecs_shared;
180+
rem_shared -= num_vecs_shared * vec_size;
180181
// Set the storage configuration with preconditioner workspace in global if
181182
// there are any vectors in global memory.
182183
if (sconf.n_global > 0) {
183184
set_gmem_stride_bytes<align_bytes>(sconf, vec_size, prec_storage);
184185
return sconf;
185186
}
186-
rem_shared -= num_vecs_shared * vec_size;
187187
// If more shared memory space is available and preconditioner workspace is
188188
// needed, enable preconditioner workspace to use shared memory.
189189
if (rem_shared >= prec_storage && prec_storage > 0) {

cuda/matrix/batch_struct.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ get_batch_struct(batch::matrix::Dense<ValueType>* const op)
9292
* Generates an immutable uniform batch struct from a batch of ell matrices.
9393
*/
9494
template <typename ValueType, typename IndexType>
95-
inline batch::matrix::ell::uniform_batch<const cuda_type<ValueType>, IndexType>
95+
inline batch::matrix::ell::uniform_batch<const cuda_type<ValueType>,
96+
const IndexType>
9697
get_batch_struct(const batch::matrix::Ell<ValueType, IndexType>* const op)
9798
{
9899
return {as_cuda_type(op->get_const_values()),

cuda/solver/batch_bicgstab_kernels.cu

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,7 @@ int get_num_threads_per_block(std::shared_ptr<const DefaultExecutor> exec,
101101
cudaDeviceGetAttribute(&max_regs_blk, cudaDevAttrMaxRegistersPerBlock,
102102
exec->get_device_id());
103103
const int max_threads_regs =
104-
((max_regs_blk /
105-
static_cast<int>((static_cast<double>(num_regs_used)))) /
106-
warp_sz) *
107-
warp_sz;
104+
((max_regs_blk / static_cast<int>(num_regs_used)) / warp_sz) * warp_sz;
108105
int max_threads = std::min(max_threads_regs, device_max_threads);
109106
max_threads = max_threads <= 1024 ? max_threads : 1024;
110107
return std::min(num_warps * warp_sz, max_threads);

dpcpp/base/batch_multi_vector_kernels.dp.cpp

Lines changed: 65 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
8787
long max_group_size =
8888
device.get_info<sycl::info::device::max_work_group_size>();
8989
int group_size =
90-
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
90+
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
9191
max_group_size);
9292

9393
const dim3 block(group_size);
@@ -141,7 +141,7 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
141141
long max_group_size =
142142
device.get_info<sycl::info::device::max_work_group_size>();
143143
int group_size =
144-
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
144+
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
145145
max_group_size);
146146

147147
const dim3 block(group_size);
@@ -202,49 +202,45 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
202202
long max_group_size =
203203
device.get_info<sycl::info::device::max_work_group_size>();
204204
int group_size =
205-
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
205+
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
206206
max_group_size);
207207

208208
const dim3 block(group_size);
209209
const dim3 grid(num_batches);
210210
if (x->get_common_size()[1] == 1) {
211211
exec->get_queue()->submit([&](sycl::handler& cgh) {
212212
cgh.parallel_for(
213-
sycl_nd_range(grid, block),
214-
[=](sycl::nd_item<3> item_ct1)
215-
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
216-
auto group = item_ct1.get_group();
217-
auto group_id = group.get_group_linear_id();
218-
const auto x_b =
219-
batch::extract_batch_item(x_ub, group_id);
220-
const auto y_b =
221-
batch::extract_batch_item(y_ub, group_id);
222-
const auto res_b =
223-
batch::extract_batch_item(res_ub, group_id);
224-
single_rhs_compute_dot_sg(x_b.num_rows, x_b.values,
225-
y_b.values, res_b.values[0],
226-
item_ct1);
227-
});
213+
sycl_nd_range(grid, block), [=
214+
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
215+
max_subgroup_size)]] {
216+
auto group = item_ct1.get_group();
217+
auto group_id = group.get_group_linear_id();
218+
const auto x_b = batch::extract_batch_item(x_ub, group_id);
219+
const auto y_b = batch::extract_batch_item(y_ub, group_id);
220+
const auto res_b =
221+
batch::extract_batch_item(res_ub, group_id);
222+
single_rhs_compute_conj_dot_sg(x_b.num_rows, x_b.values,
223+
y_b.values, res_b.values[0],
224+
item_ct1);
225+
});
228226
});
229227
} else {
230228
// TODO: Remove reqd_sub_group size and use sycl::reduce_over_group
231229
exec->get_queue()->submit([&](sycl::handler& cgh) {
232230
cgh.parallel_for(
233-
sycl_nd_range(grid, block),
234-
[=](sycl::nd_item<3> item_ct1)
235-
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
236-
auto group = item_ct1.get_group();
237-
auto group_id = group.get_group_linear_id();
238-
const auto x_b =
239-
batch::extract_batch_item(x_ub, group_id);
240-
const auto y_b =
241-
batch::extract_batch_item(y_ub, group_id);
242-
const auto res_b =
243-
batch::extract_batch_item(res_ub, group_id);
244-
compute_gen_dot_product_kernel(
245-
x_b, y_b, res_b, item_ct1,
246-
[](auto val) { return val; });
247-
});
231+
sycl_nd_range(grid, block), [=
232+
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
233+
max_subgroup_size)]] {
234+
auto group = item_ct1.get_group();
235+
auto group_id = group.get_group_linear_id();
236+
const auto x_b = batch::extract_batch_item(x_ub, group_id);
237+
const auto y_b = batch::extract_batch_item(y_ub, group_id);
238+
const auto res_b =
239+
batch::extract_batch_item(res_ub, group_id);
240+
compute_gen_dot_product_kernel(
241+
x_b, y_b, res_b, item_ct1,
242+
[](auto val) { return val; });
243+
});
248244
});
249245
}
250246
}
@@ -270,27 +266,26 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
270266
long max_group_size =
271267
device.get_info<sycl::info::device::max_work_group_size>();
272268
int group_size =
273-
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
269+
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
274270
max_group_size);
275271

276272
const dim3 block(group_size);
277273
const dim3 grid(num_batches);
278274

279275
exec->get_queue()->submit([&](sycl::handler& cgh) {
280276
cgh.parallel_for(
281-
sycl_nd_range(grid, block),
282-
[=](sycl::nd_item<3> item_ct1)
283-
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
284-
auto group = item_ct1.get_group();
285-
auto group_id = group.get_group_linear_id();
286-
const auto x_b = batch::extract_batch_item(x_ub, group_id);
287-
const auto y_b = batch::extract_batch_item(y_ub, group_id);
288-
const auto res_b =
289-
batch::extract_batch_item(res_ub, group_id);
290-
compute_gen_dot_product_kernel(
291-
x_b, y_b, res_b, item_ct1,
292-
[](auto val) { return conj(val); });
293-
});
277+
sycl_nd_range(grid, block), [=
278+
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
279+
max_subgroup_size)]] {
280+
auto group = item_ct1.get_group();
281+
auto group_id = group.get_group_linear_id();
282+
const auto x_b = batch::extract_batch_item(x_ub, group_id);
283+
const auto y_b = batch::extract_batch_item(y_ub, group_id);
284+
const auto res_b = batch::extract_batch_item(res_ub, group_id);
285+
compute_gen_dot_product_kernel(
286+
x_b, y_b, res_b, item_ct1,
287+
[](auto val) { return conj(val); });
288+
});
294289
});
295290
}
296291

@@ -314,41 +309,39 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
314309
long max_group_size =
315310
device.get_info<sycl::info::device::max_work_group_size>();
316311
int group_size =
317-
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
312+
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
318313
max_group_size);
319314

320315
const dim3 block(group_size);
321316
const dim3 grid(num_batches);
322317
if (x->get_common_size()[1] == 1) {
323318
exec->get_queue()->submit([&](sycl::handler& cgh) {
324319
cgh.parallel_for(
325-
sycl_nd_range(grid, block),
326-
[=](sycl::nd_item<3> item_ct1)
327-
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
328-
auto group = item_ct1.get_group();
329-
auto group_id = group.get_group_linear_id();
330-
const auto x_b =
331-
batch::extract_batch_item(x_ub, group_id);
332-
const auto res_b =
333-
batch::extract_batch_item(res_ub, group_id);
334-
single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values,
335-
res_b.values[0], item_ct1);
336-
});
320+
sycl_nd_range(grid, block), [=
321+
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
322+
max_subgroup_size)]] {
323+
auto group = item_ct1.get_group();
324+
auto group_id = group.get_group_linear_id();
325+
const auto x_b = batch::extract_batch_item(x_ub, group_id);
326+
const auto res_b =
327+
batch::extract_batch_item(res_ub, group_id);
328+
single_rhs_compute_norm2_sg(x_b.num_rows, x_b.values,
329+
res_b.values[0], item_ct1);
330+
});
337331
});
338332
} else {
339333
exec->get_queue()->submit([&](sycl::handler& cgh) {
340334
cgh.parallel_for(
341-
sycl_nd_range(grid, block),
342-
[=](sycl::nd_item<3> item_ct1)
343-
[[sycl::reqd_sub_group_size(max_subgroup_size)]] {
344-
auto group = item_ct1.get_group();
345-
auto group_id = group.get_group_linear_id();
346-
const auto x_b =
347-
batch::extract_batch_item(x_ub, group_id);
348-
const auto res_b =
349-
batch::extract_batch_item(res_ub, group_id);
350-
compute_norm2_kernel(x_b, res_b, item_ct1);
351-
});
335+
sycl_nd_range(grid, block), [=
336+
](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
337+
max_subgroup_size)]] {
338+
auto group = item_ct1.get_group();
339+
auto group_id = group.get_group_linear_id();
340+
const auto x_b = batch::extract_batch_item(x_ub, group_id);
341+
const auto res_b =
342+
batch::extract_batch_item(res_ub, group_id);
343+
compute_norm2_kernel(x_b, res_b, item_ct1);
344+
});
352345
});
353346
}
354347
}
@@ -372,7 +365,7 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
372365
long max_group_size =
373366
device.get_info<sycl::info::device::max_work_group_size>();
374367
int group_size =
375-
std::max(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
368+
std::min(ceildiv(num_rows, max_subgroup_size) * max_subgroup_size,
376369
max_group_size);
377370

378371
const dim3 block(group_size);

0 commit comments

Comments
 (0)