@@ -30,8 +30,7 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
3030OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3131******************************<GINKGO LICENSE>*******************************/
3232
33- template <const bool sg_kernel_all, typename BatchMatrixType_entry,
34- typename ValueType>
33+ template <typename BatchMatrixType_entry, typename ValueType>
3534__dpct_inline__ void initialize (
3635 const int num_rows, const BatchMatrixType_entry& mat_global_entry,
3736 const ValueType* const b_global_entry,
@@ -68,17 +67,12 @@ __dpct_inline__ void initialize(
6867 r_shared_entry, item_ct1);
6968 item_ct1.barrier (sycl::access::fence_space::global_and_local);
7069
71- if constexpr (sg_kernel_all) {
72- if (sg_id == 0 ) {
73- single_rhs_compute_norm2_sg (num_rows, r_shared_entry, res_norm,
74- item_ct1);
75- } else if (sg_id == 1 ) {
76- single_rhs_compute_norm2_sg (num_rows, b_global_entry, rhs_norm,
77- item_ct1);
78- }
79- } else {
80- single_rhs_compute_norm2 (num_rows, r_shared_entry, res_norm, item_ct1);
81- single_rhs_compute_norm2 (num_rows, b_global_entry, rhs_norm, item_ct1);
70+ if (sg_id == 0 ) {
71+ single_rhs_compute_norm2_sg (num_rows, r_shared_entry, res_norm,
72+ item_ct1);
73+ } else if (sg_id == 1 ) {
74+ single_rhs_compute_norm2_sg (num_rows, b_global_entry, rhs_norm,
75+ item_ct1);
8276 }
8377 item_ct1.barrier (sycl::access::fence_space::global_and_local);
8478
@@ -111,7 +105,7 @@ __dpct_inline__ void update_p(const int num_rows, const ValueType& rho_new,
111105}
112106
113107
114- template <const bool sg_kernel_all, typename ValueType>
108+ template <typename ValueType>
115109__dpct_inline__ void compute_alpha (const int num_rows, const ValueType& rho_new,
116110 const ValueType* const r_hat_shared_entry,
117111 const ValueType* const v_shared_entry,
@@ -120,23 +114,15 @@ __dpct_inline__ void compute_alpha(const int num_rows, const ValueType& rho_new,
120114 auto sg = item_ct1.get_sub_group ();
121115 const auto sg_id = sg.get_group_id ();
122116 const auto tid = item_ct1.get_local_linear_id ();
123- if constexpr (sg_kernel_all) {
124- if (sg_id == 0 ) {
125- single_rhs_compute_conj_dot_sg (num_rows, r_hat_shared_entry,
126- v_shared_entry, alpha, item_ct1);
127- }
128- item_ct1.barrier (sycl::access::fence_space::global_and_local);
129- if (tid == 0 ) {
130- alpha = rho_new / alpha;
131- }
132- item_ct1.barrier (sycl::access::fence_space::global_and_local);
133- } else {
134- single_rhs_compute_conj_dot (num_rows, r_hat_shared_entry,
135- v_shared_entry, alpha, item_ct1);
136- if (tid == 0 ) {
137- alpha = rho_new / alpha;
138- }
117+ if (sg_id == 0 ) {
118+ single_rhs_compute_conj_dot_sg (num_rows, r_hat_shared_entry,
119+ v_shared_entry, alpha, item_ct1);
139120 }
121+ item_ct1.barrier (sycl::access::fence_space::global_and_local);
122+ if (tid == 0 ) {
123+ alpha = rho_new / alpha;
124+ }
125+ item_ct1.barrier (sycl::access::fence_space::global_and_local);
140126}
141127
142128
@@ -155,7 +141,7 @@ __dpct_inline__ void update_s(const int num_rows,
155141}
156142
157143
158- template <const bool sg_kernel_all, typename ValueType>
144+ template <typename ValueType>
159145__dpct_inline__ void compute_omega (const int num_rows,
160146 const ValueType* const t_shared_entry,
161147 const ValueType* const s_shared_entry,
@@ -165,28 +151,18 @@ __dpct_inline__ void compute_omega(const int num_rows,
165151 auto sg = item_ct1.get_sub_group ();
166152 const auto sg_id = sg.get_group_id ();
167153 const auto tid = item_ct1.get_local_linear_id ();
168- if constexpr (sg_kernel_all) {
169- if (sg_id == 0 ) {
170- single_rhs_compute_conj_dot_sg (num_rows, t_shared_entry,
171- s_shared_entry, omega, item_ct1);
172- } else if (sg_id == 1 ) {
173- single_rhs_compute_conj_dot_sg (num_rows, t_shared_entry,
174- t_shared_entry, temp, item_ct1);
175- }
176- item_ct1.barrier (sycl::access::fence_space::global_and_local);
177- if (tid == 0 ) {
178- omega /= temp;
179- }
180- item_ct1.barrier (sycl::access::fence_space::global_and_local);
181- } else {
182- single_rhs_compute_conj_dot (num_rows, t_shared_entry, s_shared_entry,
183- omega, item_ct1);
184- single_rhs_compute_conj_dot (num_rows, t_shared_entry, t_shared_entry,
185- temp, item_ct1);
186- if (tid == 0 ) {
187- omega /= temp;
188- }
154+ if (sg_id == 0 ) {
155+ single_rhs_compute_conj_dot_sg (num_rows, t_shared_entry, s_shared_entry,
156+ omega, item_ct1);
157+ } else if (sg_id == 1 ) {
158+ single_rhs_compute_conj_dot_sg (num_rows, t_shared_entry, t_shared_entry,
159+ temp, item_ct1);
160+ }
161+ item_ct1.barrier (sycl::access::fence_space::global_and_local);
162+ if (tid == 0 ) {
163+ omega /= temp;
189164 }
165+ item_ct1.barrier (sycl::access::fence_space::global_and_local);
190166}
191167
192168
@@ -220,9 +196,8 @@ __dpct_inline__ void update_x_middle(const int num_rows, const ValueType& alpha,
220196}
221197
222198
223- template <typename StopType, const int n_shared_total, const bool sg_kernel_all,
224- typename PrecType, typename LogType, typename BatchMatrixType,
225- typename ValueType>
199+ template <typename StopType, const int n_shared_total, typename PrecType,
200+ typename LogType, typename BatchMatrixType, typename ValueType>
226201void apply_kernel (const gko::kernels::batch_bicgstab::storage_config sconf,
227202 const int max_iter, const gko::remove_complex<ValueType> tol,
228203 LogType logger, PrecType prec_shared,
@@ -344,10 +319,10 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
344319 // p = 0
345320 // p_hat = 0
346321 // v = 0
347- initialize<sg_kernel_all> (num_rows, mat_global_entry, b_global_entry,
348- x_global_entry, rho_old_sh[0 ], omega_sh[0 ],
349- alpha_sh [0 ], x_sh, r_sh, r_hat_sh, p_sh, p_hat_sh ,
350- v_sh, norms_rhs_sh[ 0 ], norms_res_sh[ 0 ], item_ct1);
322+ initialize (num_rows, mat_global_entry, b_global_entry, x_global_entry ,
323+ rho_old_sh[0 ], omega_sh[0 ], alpha_sh[ 0 ], x_sh, r_sh, r_hat_sh ,
324+ p_sh, p_hat_sh, v_sh, norms_rhs_sh [0 ], norms_res_sh[ 0 ] ,
325+ item_ct1);
351326 item_ct1.barrier (sycl::access::fence_space::global_and_local);
352327
353328 // stopping criterion object
@@ -361,16 +336,11 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
361336 }
362337
363338 // rho_new = < r_hat , r > = (r_hat)' * (r)
364- if constexpr (sg_kernel_all) {
365- if (sg_id == 0 ) {
366- single_rhs_compute_conj_dot_sg (num_rows, r_hat_sh, r_sh,
367- rho_new_sh[0 ], item_ct1);
368- }
369- item_ct1.barrier (sycl::access::fence_space::global_and_local);
370- } else {
371- single_rhs_compute_conj_dot (num_rows, r_hat_sh, r_sh, rho_new_sh[0 ],
372- item_ct1);
339+ if (sg_id == 0 ) {
340+ single_rhs_compute_conj_dot_sg (num_rows, r_hat_sh, r_sh,
341+ rho_new_sh[0 ], item_ct1);
373342 }
343+ item_ct1.barrier (sycl::access::fence_space::global_and_local);
374344
375345 // beta = (rho_new / rho_old)*(alpha / omega)
376346 // p = r + beta*(p - omega * v)
@@ -387,24 +357,20 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
387357 item_ct1.barrier (sycl::access::fence_space::global_and_local);
388358
389359 // alpha = rho_new / < r_hat , v>
390- compute_alpha<sg_kernel_all> (num_rows, rho_new_sh[0 ], r_hat_sh, v_sh,
391- alpha_sh[ 0 ], item_ct1);
360+ compute_alpha (num_rows, rho_new_sh[0 ], r_hat_sh, v_sh, alpha_sh[ 0 ] ,
361+ item_ct1);
392362 item_ct1.barrier (sycl::access::fence_space::global_and_local);
393363
394364 // s = r - alpha*v
395365 update_s (num_rows, r_sh, alpha_sh[0 ], v_sh, s_sh, item_ct1);
396366 item_ct1.barrier (sycl::access::fence_space::global_and_local);
397367
398368 // an estimate of residual norms
399- if constexpr (sg_kernel_all) {
400- if (sg_id == 0 ) {
401- single_rhs_compute_norm2_sg (num_rows, s_sh, norms_res_sh[0 ],
402- item_ct1);
403- }
404- item_ct1.barrier (sycl::access::fence_space::global_and_local);
405- } else {
406- single_rhs_compute_norm2 (num_rows, s_sh, norms_res_sh[0 ], item_ct1);
369+ if (sg_id == 0 ) {
370+ single_rhs_compute_norm2_sg (num_rows, s_sh, norms_res_sh[0 ],
371+ item_ct1);
407372 }
373+ item_ct1.barrier (sycl::access::fence_space::global_and_local);
408374
409375 if (stop.check_converged (norms_res_sh)) {
410376 update_x_middle (num_rows, alpha_sh[0 ], p_hat_sh, x_sh, item_ct1);
@@ -421,8 +387,7 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
421387 item_ct1.barrier (sycl::access::fence_space::global_and_local);
422388
423389 // omega = <t,s> / <t,t>
424- compute_omega<sg_kernel_all>(num_rows, t_sh, s_sh, temp_sh[0 ],
425- omega_sh[0 ], item_ct1);
390+ compute_omega (num_rows, t_sh, s_sh, temp_sh[0 ], omega_sh[0 ], item_ct1);
426391 item_ct1.barrier (sycl::access::fence_space::global_and_local);
427392
428393 // x = x + alpha*p_hat + omega *s_hat
@@ -431,18 +396,13 @@ void apply_kernel(const gko::kernels::batch_bicgstab::storage_config sconf,
431396 s_sh, t_sh, x_sh, r_sh, item_ct1);
432397 item_ct1.barrier (sycl::access::fence_space::global_and_local);
433398
434- if constexpr (sg_kernel_all) {
435- if (sg_id == 0 )
436- single_rhs_compute_norm2_sg (num_rows, r_sh, norms_res_sh[0 ],
437- item_ct1);
438- if (tid == group_size - 1 ) {
439- rho_old_sh[0 ] = rho_new_sh[0 ];
440- }
441- item_ct1.barrier (sycl::access::fence_space::global_and_local);
442- } else {
443- single_rhs_compute_norm2 (num_rows, r_sh, norms_res_sh[0 ], item_ct1);
399+ if (sg_id == 0 )
400+ single_rhs_compute_norm2_sg (num_rows, r_sh, norms_res_sh[0 ],
401+ item_ct1);
402+ if (tid == group_size - 1 ) {
444403 rho_old_sh[0 ] = rho_new_sh[0 ];
445404 }
405+ item_ct1.barrier (sycl::access::fence_space::global_and_local);
446406 }
447407
448408 logger.log_iteration (batch_id, iter, norms_res_sh[0 ]);
0 commit comments