@@ -87,7 +87,7 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
8787 long max_group_size =
8888 device.get_info <sycl::info::device::max_work_group_size>();
8989 int group_size =
90- std::max (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
90+ std::min (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
9191 max_group_size);
9292
9393 const dim3 block (group_size);
@@ -141,7 +141,7 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
141141 long max_group_size =
142142 device.get_info <sycl::info::device::max_work_group_size>();
143143 int group_size =
144- std::max (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
144+ std::min (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
145145 max_group_size);
146146
147147 const dim3 block (group_size);
@@ -202,49 +202,45 @@ void compute_dot(std::shared_ptr<const DefaultExecutor> exec,
202202 long max_group_size =
203203 device.get_info <sycl::info::device::max_work_group_size>();
204204 int group_size =
205- std::max (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
205+ std::min (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
206206 max_group_size);
207207
208208 const dim3 block (group_size);
209209 const dim3 grid (num_batches);
210210 if (x->get_common_size ()[1 ] == 1 ) {
211211 exec->get_queue ()->submit ([&](sycl::handler& cgh) {
212212 cgh.parallel_for (
213- sycl_nd_range (grid, block),
214- [=](sycl::nd_item<3 > item_ct1)
215- [[sycl::reqd_sub_group_size(max_subgroup_size)]] {
216- auto group = item_ct1.get_group ();
217- auto group_id = group.get_group_linear_id ();
218- const auto x_b =
219- batch::extract_batch_item (x_ub, group_id);
220- const auto y_b =
221- batch::extract_batch_item (y_ub, group_id);
222- const auto res_b =
223- batch::extract_batch_item (res_ub, group_id);
224- single_rhs_compute_dot_sg (x_b.num_rows , x_b.values ,
225- y_b.values , res_b.values [0 ],
226- item_ct1);
227- });
213+ sycl_nd_range (grid, block), [=
214+ ](sycl::nd_item<3 > item_ct1) [[sycl::reqd_sub_group_size (
215+ max_subgroup_size)]] {
216+ auto group = item_ct1.get_group ();
217+ auto group_id = group.get_group_linear_id ();
218+ const auto x_b = batch::extract_batch_item (x_ub, group_id);
219+ const auto y_b = batch::extract_batch_item (y_ub, group_id);
220+ const auto res_b =
221+ batch::extract_batch_item (res_ub, group_id);
222+ single_rhs_compute_conj_dot_sg (x_b.num_rows , x_b.values ,
223+ y_b.values , res_b.values [0 ],
224+ item_ct1);
225+ });
228226 });
229227 } else {
230228 // TODO: Remove reqd_sub_group size and use sycl::reduce_over_group
231229 exec->get_queue ()->submit ([&](sycl::handler& cgh) {
232230 cgh.parallel_for (
233- sycl_nd_range (grid, block),
234- [=](sycl::nd_item<3 > item_ct1)
235- [[sycl::reqd_sub_group_size(max_subgroup_size)]] {
236- auto group = item_ct1.get_group ();
237- auto group_id = group.get_group_linear_id ();
238- const auto x_b =
239- batch::extract_batch_item (x_ub, group_id);
240- const auto y_b =
241- batch::extract_batch_item (y_ub, group_id);
242- const auto res_b =
243- batch::extract_batch_item (res_ub, group_id);
244- compute_gen_dot_product_kernel (
245- x_b, y_b, res_b, item_ct1,
246- [](auto val) { return val; });
247- });
231+ sycl_nd_range (grid, block), [=
232+ ](sycl::nd_item<3 > item_ct1) [[sycl::reqd_sub_group_size (
233+ max_subgroup_size)]] {
234+ auto group = item_ct1.get_group ();
235+ auto group_id = group.get_group_linear_id ();
236+ const auto x_b = batch::extract_batch_item (x_ub, group_id);
237+ const auto y_b = batch::extract_batch_item (y_ub, group_id);
238+ const auto res_b =
239+ batch::extract_batch_item (res_ub, group_id);
240+ compute_gen_dot_product_kernel (
241+ x_b, y_b, res_b, item_ct1,
242+ [](auto val) { return val; });
243+ });
248244 });
249245 }
250246}
@@ -270,27 +266,26 @@ void compute_conj_dot(std::shared_ptr<const DefaultExecutor> exec,
270266 long max_group_size =
271267 device.get_info <sycl::info::device::max_work_group_size>();
272268 int group_size =
273- std::max (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
269+ std::min (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
274270 max_group_size);
275271
276272 const dim3 block (group_size);
277273 const dim3 grid (num_batches);
278274
279275 exec->get_queue ()->submit ([&](sycl::handler& cgh) {
280276 cgh.parallel_for (
281- sycl_nd_range (grid, block),
282- [=](sycl::nd_item<3 > item_ct1)
283- [[sycl::reqd_sub_group_size(max_subgroup_size)]] {
284- auto group = item_ct1.get_group ();
285- auto group_id = group.get_group_linear_id ();
286- const auto x_b = batch::extract_batch_item (x_ub, group_id);
287- const auto y_b = batch::extract_batch_item (y_ub, group_id);
288- const auto res_b =
289- batch::extract_batch_item (res_ub, group_id);
290- compute_gen_dot_product_kernel (
291- x_b, y_b, res_b, item_ct1,
292- [](auto val) { return conj (val); });
293- });
277+ sycl_nd_range (grid, block), [=
278+ ](sycl::nd_item<3 > item_ct1) [[sycl::reqd_sub_group_size (
279+ max_subgroup_size)]] {
280+ auto group = item_ct1.get_group ();
281+ auto group_id = group.get_group_linear_id ();
282+ const auto x_b = batch::extract_batch_item (x_ub, group_id);
283+ const auto y_b = batch::extract_batch_item (y_ub, group_id);
284+ const auto res_b = batch::extract_batch_item (res_ub, group_id);
285+ compute_gen_dot_product_kernel (
286+ x_b, y_b, res_b, item_ct1,
287+ [](auto val) { return conj (val); });
288+ });
294289 });
295290}
296291
@@ -314,41 +309,39 @@ void compute_norm2(std::shared_ptr<const DefaultExecutor> exec,
314309 long max_group_size =
315310 device.get_info <sycl::info::device::max_work_group_size>();
316311 int group_size =
317- std::max (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
312+ std::min (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
318313 max_group_size);
319314
320315 const dim3 block (group_size);
321316 const dim3 grid (num_batches);
322317 if (x->get_common_size ()[1 ] == 1 ) {
323318 exec->get_queue ()->submit ([&](sycl::handler& cgh) {
324319 cgh.parallel_for (
325- sycl_nd_range (grid, block),
326- [=](sycl::nd_item<3 > item_ct1)
327- [[sycl::reqd_sub_group_size(max_subgroup_size)]] {
328- auto group = item_ct1.get_group ();
329- auto group_id = group.get_group_linear_id ();
330- const auto x_b =
331- batch::extract_batch_item (x_ub, group_id);
332- const auto res_b =
333- batch::extract_batch_item (res_ub, group_id);
334- single_rhs_compute_norm2_sg (x_b.num_rows , x_b.values ,
335- res_b.values [0 ], item_ct1);
336- });
320+ sycl_nd_range (grid, block), [=
321+ ](sycl::nd_item<3 > item_ct1) [[sycl::reqd_sub_group_size (
322+ max_subgroup_size)]] {
323+ auto group = item_ct1.get_group ();
324+ auto group_id = group.get_group_linear_id ();
325+ const auto x_b = batch::extract_batch_item (x_ub, group_id);
326+ const auto res_b =
327+ batch::extract_batch_item (res_ub, group_id);
328+ single_rhs_compute_norm2_sg (x_b.num_rows , x_b.values ,
329+ res_b.values [0 ], item_ct1);
330+ });
337331 });
338332 } else {
339333 exec->get_queue ()->submit ([&](sycl::handler& cgh) {
340334 cgh.parallel_for (
341- sycl_nd_range (grid, block),
342- [=](sycl::nd_item<3 > item_ct1)
343- [[sycl::reqd_sub_group_size(max_subgroup_size)]] {
344- auto group = item_ct1.get_group ();
345- auto group_id = group.get_group_linear_id ();
346- const auto x_b =
347- batch::extract_batch_item (x_ub, group_id);
348- const auto res_b =
349- batch::extract_batch_item (res_ub, group_id);
350- compute_norm2_kernel (x_b, res_b, item_ct1);
351- });
335+ sycl_nd_range (grid, block), [=
336+ ](sycl::nd_item<3 > item_ct1) [[sycl::reqd_sub_group_size (
337+ max_subgroup_size)]] {
338+ auto group = item_ct1.get_group ();
339+ auto group_id = group.get_group_linear_id ();
340+ const auto x_b = batch::extract_batch_item (x_ub, group_id);
341+ const auto res_b =
342+ batch::extract_batch_item (res_ub, group_id);
343+ compute_norm2_kernel (x_b, res_b, item_ct1);
344+ });
352345 });
353346 }
354347}
@@ -372,7 +365,7 @@ void copy(std::shared_ptr<const DefaultExecutor> exec,
372365 long max_group_size =
373366 device.get_info <sycl::info::device::max_work_group_size>();
374367 int group_size =
375- std::max (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
368+ std::min (ceildiv (num_rows, max_subgroup_size) * max_subgroup_size,
376369 max_group_size);
377370
378371 const dim3 block (group_size);
0 commit comments