Skip to content

Commit f600023

Browse files
pratikvn and yhmtsai committed
Review updates
Co-authored-by: Yu-Hsiang Tsai <yhmtsai@gmail.com>
1 parent 1bc6d83 commit f600023

3 files changed

Lines changed: 70 additions & 117 deletions

File tree

dpcpp/solver/batch_bicgstab_kernels.dp.cpp

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -94,8 +94,8 @@ class KernelCaller {
9494
{}
9595

9696
template <typename StopType, const int subgroup_size,
97-
const int n_shared_total, const bool sg_kernel_all,
98-
typename PrecType, typename LogType, typename BatchMatrixType>
97+
const int n_shared_total, typename PrecType, typename LogType,
98+
typename BatchMatrixType>
9999
__dpct_inline__ void launch_apply_kernel(
100100
const gko::kernels::batch_bicgstab::storage_config& sconf,
101101
LogType& logger, PrecType& prec, const BatchMatrixType mat,
@@ -118,9 +118,10 @@ class KernelCaller {
118118
slm_values(sycl::range<1>(shared_size), cgh);
119119

120120
cgh.parallel_for(
121-
sycl_nd_range(grid, block),
122-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
123-
subgroup_size)]] [[intel::kernel_args_restrict]] {
121+
sycl_nd_range(grid, block), [=
122+
](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(
123+
subgroup_size)]] [
124+
[intel::kernel_args_restrict]] {
124125
auto batch_id = item_ct1.get_group_linear_id();
125126
const auto mat_global_entry =
126127
gko::batch::matrix::extract_batch_item(mat, batch_id);
@@ -130,7 +131,7 @@ class KernelCaller {
130131
ValueType* const x_global_entry =
131132
gko::batch::multi_vector::batch_item_ptr(
132133
x_values, 1, num_rows, batch_id);
133-
apply_kernel<StopType, n_shared_total, sg_kernel_all>(
134+
apply_kernel<StopType, n_shared_total>(
134135
sconf, max_iters, res_tol, logger, prec,
135136
mat_global_entry, b_global_entry, x_global_entry,
136137
num_rows, mat.get_single_item_num_nnz(),
@@ -197,67 +198,67 @@ class KernelCaller {
197198
// launch_apply_kernel<StopType, subgroup_size, n_shared_total,
198199
// sg_kernel_all>
199200
if (num_rows <= 32 && n_shared_total == 10) {
200-
launch_apply_kernel<StopType, 32, 10, true>(
201+
launch_apply_kernel<StopType, 32, 10>(
201202
sconf, logger, prec, mat, b.values, x.values, workspace_data,
202203
group_size, shared_size);
203204
} else if (num_rows <= 256 && n_shared_total == 10) {
204-
launch_apply_kernel<StopType, 32, 10, true>(
205+
launch_apply_kernel<StopType, 32, 10>(
205206
sconf, logger, prec, mat, b.values, x.values, workspace_data,
206207
group_size, shared_size);
207208
} else {
208209
switch (n_shared_total) {
209210
case 0:
210-
launch_apply_kernel<StopType, 32, 0, true>(
211+
launch_apply_kernel<StopType, 32, 0>(
211212
sconf, logger, prec, mat, b.values, x.values,
212213
workspace_data, group_size, shared_size);
213214
break;
214215
case 1:
215-
launch_apply_kernel<StopType, 32, 1, true>(
216+
launch_apply_kernel<StopType, 32, 1>(
216217
sconf, logger, prec, mat, b.values, x.values,
217218
workspace_data, group_size, shared_size);
218219
break;
219220
case 2:
220-
launch_apply_kernel<StopType, 32, 2, true>(
221+
launch_apply_kernel<StopType, 32, 2>(
221222
sconf, logger, prec, mat, b.values, x.values,
222223
workspace_data, group_size, shared_size);
223224
break;
224225
case 3:
225-
launch_apply_kernel<StopType, 32, 3, true>(
226+
launch_apply_kernel<StopType, 32, 3>(
226227
sconf, logger, prec, mat, b.values, x.values,
227228
workspace_data, group_size, shared_size);
228229
break;
229230
case 4:
230-
launch_apply_kernel<StopType, 32, 4, true>(
231+
launch_apply_kernel<StopType, 32, 4>(
231232
sconf, logger, prec, mat, b.values, x.values,
232233
workspace_data, group_size, shared_size);
233234
break;
234235
case 5:
235-
launch_apply_kernel<StopType, 32, 5, true>(
236+
launch_apply_kernel<StopType, 32, 5>(
236237
sconf, logger, prec, mat, b.values, x.values,
237238
workspace_data, group_size, shared_size);
238239
break;
239240
case 6:
240-
launch_apply_kernel<StopType, 32, 6, true>(
241+
launch_apply_kernel<StopType, 32, 6>(
241242
sconf, logger, prec, mat, b.values, x.values,
242243
workspace_data, group_size, shared_size);
243244
break;
244245
case 7:
245-
launch_apply_kernel<StopType, 32, 7, true>(
246+
launch_apply_kernel<StopType, 32, 7>(
246247
sconf, logger, prec, mat, b.values, x.values,
247248
workspace_data, group_size, shared_size);
248249
break;
249250
case 8:
250-
launch_apply_kernel<StopType, 32, 8, true>(
251+
launch_apply_kernel<StopType, 32, 8>(
251252
sconf, logger, prec, mat, b.values, x.values,
252253
workspace_data, group_size, shared_size);
253254
break;
254255
case 9:
255-
launch_apply_kernel<StopType, 32, 9, true>(
256+
launch_apply_kernel<StopType, 32, 9>(
256257
sconf, logger, prec, mat, b.values, x.values,
257258
workspace_data, group_size, shared_size);
258259
break;
259260
case 10:
260-
launch_apply_kernel<StopType, 32, 10, true>(
261+
launch_apply_kernel<StopType, 32, 10>(
261262
sconf, logger, prec, mat, b.values, x.values,
262263
workspace_data, group_size, shared_size);
263264
break;

dpcpp/solver/batch_bicgstab_kernels.hpp.inc

Lines changed: 50 additions & 90 deletions
Original file line number | Diff line number | Diff line change
@@ -30,8 +30,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
3030
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3131
******************************<GINKGO LICENSE>*******************************/
3232

33-
template <const bool sg_kernel_all, typename BatchMatrixType_entry,
34-
typename ValueType>
33+
template <typename BatchMatrixType_entry, typename ValueType>
3534
__dpct_inline__ void initialize(
3635
const int num_rows, const BatchMatrixType_entry& mat_global_entry,
3736
const ValueType* const b_global_entry,
@@ -68,17 +67,12 @@ __dpct_inline__ void initialize(
6867
r_shared_entry, item_ct1);
6968
item_ct1.barrier(sycl::access::fence_space::global_and_local);
7069

71-
if constexpr (sg_kernel_all) {
72-
if (sg_id == 0) {
73-
single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm,
74-
item_ct1);
75-
} else if (sg_id == 1) {
76-
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm,
77-
item_ct1);
78-
}
79-
} else {
80-
single_rhs_compute_norm2(num_rows, r_shared_entry, res_norm, item_ct1);
81-
single_rhs_compute_norm2(num_rows, b_global_entry, rhs_norm, item_ct1);
70+
if (sg_id == 0) {
71+
single_rhs_compute_norm2_sg(num_rows, r_shared_entry, res_norm,
72+
item_ct1);
73+
} else if (sg_id == 1) {
74+
single_rhs_compute_norm2_sg(num_rows, b_global_entry, rhs_norm,
75+
item_ct1);
8276
}
8377
item_ct1.barrier(sycl::access::fence_space::global_and_local);
8478

@@ -111,7 +105,7 @@ __dpct_inline__ void update_p(const int num_rows, const ValueType& rho_new,
111105
}
112106

113107

114-
template <const bool sg_kernel_all, typename ValueType>
108+
template <typename ValueType>
115109
__dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new,
116110
const ValueType* const r_hat_shared_entry,
117111
const ValueType* const v_shared_entry,
@@ -120,23 +114,15 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new,
120114
auto sg = item_ct1.get_sub_group();
121115
const auto sg_id = sg.get_group_id();
122116
const auto tid = item_ct1.get_local_linear_id();
123-
if constexpr (sg_kernel_all) {
124-
if (sg_id == 0) {
125-
single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry,
126-
v_shared_entry, alpha, item_ct1);
127-
}
128-
item_ct1.barrier(sycl::access::fence_space::global_and_local);
129-
if (tid == 0) {
130-
alpha = rho_new / alpha;
131-
}
132-
item_ct1.barrier(sycl::access::fence_space::global_and_local);
133-
} else {
134-
single_rhs_compute_conj_dot(num_rows, r_hat_shared_entry,
135-
v_shared_entry, alpha, item_ct1);
136-
if (tid == 0) {
137-
alpha = rho_new / alpha;
138-
}
117+
if (sg_id == 0) {
118+
single_rhs_compute_conj_dot_sg(num_rows, r_hat_shared_entry,
119+
v_shared_entry, alpha, item_ct1);
139120
}
121+
item_ct1.barrier(sycl::access::fence_space::global_and_local);
122+
if (tid == 0) {
123+
alpha = rho_new / alpha;
124+
}
125+
item_ct1.barrier(sycl::access::fence_space::global_and_local);
140126
}
141127

142128

@@ -155,7 +141,7 @@ __dpct_inline__ void update_s(const int num_rows,
155141
}
156142

157143

158-
template <const bool sg_kernel_all, typename ValueType>
144+
template <typename ValueType>
159145
__dpct_inline__ void compute_omega(const int num_rows,
160146
const ValueType* const t_shared_entry,
161147
const ValueType* const s_shared_entry,
@@ -165,28 +151,18 @@ __dpct_inline__ void compute_omega(const int num_rows,
165151
auto sg = item_ct1.get_sub_group();
166152
const auto sg_id = sg.get_group_id();
167153
const auto tid = item_ct1.get_local_linear_id();
168-
if constexpr (sg_kernel_all) {
169-
if (sg_id == 0) {
170-
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry,
171-
s_shared_entry, omega, item_ct1);
172-
} else if (sg_id == 1) {
173-
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry,
174-
t_shared_entry, temp, item_ct1);
175-
}
176-
item_ct1.barrier(sycl::access::fence_space::global_and_local);
177-
if (tid == 0) {
178-
omega /= temp;
179-
}
180-
item_ct1.barrier(sycl::access::fence_space::global_and_local);
181-
} else {
182-
single_rhs_compute_conj_dot(num_rows, t_shared_entry, s_shared_entry,
183-
omega, item_ct1);
184-
single_rhs_compute_conj_dot(num_rows, t_shared_entry, t_shared_entry,
185-
temp, item_ct1);
186-
if (tid == 0) {
187-
omega /= temp;
188-
}
154+
if (sg_id == 0) {
155+
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, s_shared_entry,
156+
omega, item_ct1);
157+
} else if (sg_id == 1) {
158+
single_rhs_compute_conj_dot_sg(num_rows, t_shared_entry, t_shared_entry,
159+
temp, item_ct1);
160+
}
161+
item_ct1.barrier(sycl::access::fence_space::global_and_local);
162+
if (tid == 0) {
163+
omega /= temp;
189164
}
165+
item_ct1.barrier(sycl::access::fence_space::global_and_local);
190166
}
191167

192168

@@ -220,9 +196,8 @@ __dpct_inline__ void update_x_middle(const int num_rows, const ValueType& alpha,
220196
}
221197

222198

223-
template <typename StopType, const int n_shared_total, const bool sg_kernel_all,
224-
typename PrecType, typename LogType, typename BatchMatrixType,
225-
typename ValueType>
199+
template <typename StopType, const int n_shared_total, typename PrecType,
200+
typename LogType, typename BatchMatrixType, typename ValueType>
226201
void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
227202
const int max_iter, const gko::remove_complex<ValueType> tol,
228203
LogType logger, PrecType prec_shared,
@@ -344,10 +319,10 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
344319
// p = 0
345320
// p_hat = 0
346321
// v = 0
347-
initialize<sg_kernel_all>(num_rows, mat_global_entry, b_global_entry,
348-
x_global_entry, rho_old_sh[0], omega_sh[0],
349-
alpha_sh[0], x_sh, r_sh, r_hat_sh, p_sh, p_hat_sh,
350-
v_sh, norms_rhs_sh[0], norms_res_sh[0], item_ct1);
322+
initialize(num_rows, mat_global_entry, b_global_entry, x_global_entry,
323+
rho_old_sh[0], omega_sh[0], alpha_sh[0], x_sh, r_sh, r_hat_sh,
324+
p_sh, p_hat_sh, v_sh, norms_rhs_sh[0], norms_res_sh[0],
325+
item_ct1);
351326
item_ct1.barrier(sycl::access::fence_space::global_and_local);
352327

353328
// stopping criterion object
@@ -361,16 +336,11 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
361336
}
362337

363338
// rho_new = < r_hat , r > = (r_hat)' * (r)
364-
if constexpr (sg_kernel_all) {
365-
if (sg_id == 0) {
366-
single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh,
367-
rho_new_sh[0], item_ct1);
368-
}
369-
item_ct1.barrier(sycl::access::fence_space::global_and_local);
370-
} else {
371-
single_rhs_compute_conj_dot(num_rows, r_hat_sh, r_sh, rho_new_sh[0],
372-
item_ct1);
339+
if (sg_id == 0) {
340+
single_rhs_compute_conj_dot_sg(num_rows, r_hat_sh, r_sh,
341+
rho_new_sh[0], item_ct1);
373342
}
343+
item_ct1.barrier(sycl::access::fence_space::global_and_local);
374344

375345
// beta = (rho_new / rho_old)*(alpha / omega)
376346
// p = r + beta*(p - omega * v)
@@ -387,24 +357,20 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
387357
item_ct1.barrier(sycl::access::fence_space::global_and_local);
388358

389359
// alpha = rho_new / < r_hat , v>
390-
compute_alpha<sg_kernel_all>(num_rows, rho_new_sh[0], r_hat_sh, v_sh,
391-
alpha_sh[0], item_ct1);
360+
compute_alpha(num_rows, rho_new_sh[0], r_hat_sh, v_sh, alpha_sh[0],
361+
item_ct1);
392362
item_ct1.barrier(sycl::access::fence_space::global_and_local);
393363

394364
// s = r - alpha*v
395365
update_s(num_rows, r_sh, alpha_sh[0], v_sh, s_sh, item_ct1);
396366
item_ct1.barrier(sycl::access::fence_space::global_and_local);
397367

398368
// an estimate of residual norms
399-
if constexpr (sg_kernel_all) {
400-
if (sg_id == 0) {
401-
single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0],
402-
item_ct1);
403-
}
404-
item_ct1.barrier(sycl::access::fence_space::global_and_local);
405-
} else {
406-
single_rhs_compute_norm2(num_rows, s_sh, norms_res_sh[0], item_ct1);
369+
if (sg_id == 0) {
370+
single_rhs_compute_norm2_sg(num_rows, s_sh, norms_res_sh[0],
371+
item_ct1);
407372
}
373+
item_ct1.barrier(sycl::access::fence_space::global_and_local);
408374

409375
if (stop.check_converged(norms_res_sh)) {
410376
update_x_middle(num_rows, alpha_sh[0], p_hat_sh, x_sh, item_ct1);
@@ -421,8 +387,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
421387
item_ct1.barrier(sycl::access::fence_space::global_and_local);
422388

423389
// omega = <t,s> / <t,t>
424-
compute_omega<sg_kernel_all>(num_rows, t_sh, s_sh, temp_sh[0],
425-
omega_sh[0], item_ct1);
390+
compute_omega(num_rows, t_sh, s_sh, temp_sh[0], omega_sh[0], item_ct1);
426391
item_ct1.barrier(sycl::access::fence_space::global_and_local);
427392

428393
// x = x + alpha*p_hat + omega *s_hat
@@ -431,18 +396,13 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
431396
s_sh, t_sh, x_sh, r_sh, item_ct1);
432397
item_ct1.barrier(sycl::access::fence_space::global_and_local);
433398

434-
if constexpr (sg_kernel_all) {
435-
if (sg_id == 0)
436-
single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0],
437-
item_ct1);
438-
if (tid == group_size - 1) {
439-
rho_old_sh[0] = rho_new_sh[0];
440-
}
441-
item_ct1.barrier(sycl::access::fence_space::global_and_local);
442-
} else {
443-
single_rhs_compute_norm2(num_rows, r_sh, norms_res_sh[0], item_ct1);
399+
if (sg_id == 0)
400+
single_rhs_compute_norm2_sg(num_rows, r_sh, norms_res_sh[0],
401+
item_ct1);
402+
if (tid == group_size - 1) {
444403
rho_old_sh[0] = rho_new_sh[0];
445404
}
405+
item_ct1.barrier(sycl::access::fence_space::global_and_local);
446406
}
447407

448408
logger.log_iteration(batch_id, iter, norms_res_sh[0]);

include/ginkgo/core/solver/batch_solver_base.hpp

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -177,14 +177,6 @@ class BatchSolver {
177177
};
178178

179179

180-
/**
181-
* The parameter type shared between all preconditioned iterative solvers,
182-
* excluding the parameters available in iterative_solver_factory_parameters.
183-
* @see GKO_CREATE_FACTORY_PARAMETERS
184-
*/
185-
struct preconditioned_iterative_solver_factory_parameters {};
186-
187-
188180
template <typename Parameters, typename Factory>
189181
struct enable_preconditioned_iterative_solver_factory_parameters
190182
: enable_parameters_type<Parameters, Factory> {

0 commit comments

Comments
 (0)