@@ -34,22 +34,25 @@ namespace kernel {
3434namespace {
3535
3636
37- template < int num_thread_per_worker, bool atomic, typename ValueType,
37+ template < int num_thread_per_worker, bool atomic, typename InputValueType,
38+ typename MatrixValueType, typename OutputValueType,
3839 typename IndexType, typename Closure>
3940__device__ void spmv_kernel (
4041 const size_type num_rows, const int num_worker_per_row,
41- const ValueType * __restrict__ val, const IndexType * __restrict__ col,
42+ const MatrixValueType * __restrict__ val, const IndexType * __restrict__ col,
4243 const size_type stride, const size_type num_stored_elements_per_row,
43- const ValueType * __restrict__ b, const size_type b_stride,
44- ValueType * __restrict__ c, const size_type c_stride, Closure op)
44+ const InputValueType * __restrict__ b, const size_type b_stride,
45+ OutputValueType * __restrict__ c, const size_type c_stride, Closure op)
4546{
4647 const auto tidx = thread ::get_thread_id_flat ();
4748 const auto column_id = blockIdx.y;
49+ using compute_type =
50+ decltype (InputValueType{} + MatrixValueType{} + OutputValueType{});
4851 if (num_thread_per_worker == 1 ) {
4952 // Specialize the num_thread_per_worker = 1. It doesn't need the shared
5053 // memory, __syncthreads, and atomic_add
5154 if (tidx < num_rows) {
52- ValueType temp = zero< ValueType > ();
55+ auto temp = zero< compute_type > ();
5356 for (size_type idx = 0 ; idx < num_stored_elements_per_row; idx++ ) {
5457 const auto ind = tidx + idx * stride;
5558 const auto col_idx = col[ind];
@@ -68,14 +71,14 @@ __device__ void spmv_kernel(
6871 const auto x = tidx % num_rows;
6972 const auto worker_id = tidx / num_rows;
7073 const auto step_size = num_worker_per_row * num_thread_per_worker;
71- __shared__ UninitializedArray< ValueType, default_block_size /
72- num_thread_per_worker>
74+ __shared__ UninitializedArray<
75+ compute_type, default_block_size / num_thread_per_worker>
7376 storage;
7477 if (idx_in_worker == 0 ) {
7578 storage[threadIdx.x] = 0 ;
7679 }
7780 __syncthreads ();
78- ValueType temp = zero< ValueType > ();
81+ auto temp = zero< compute_type > ();
7982 for (size_type idx =
8083 worker_id * num_thread_per_worker + idx_in_worker;
8184 idx < num_stored_elements_per_row; idx += step_size) {
@@ -102,35 +105,41 @@ __device__ void spmv_kernel(
102105}
103106
104107
105- template < int num_thread_per_worker, bool atomic = false , typename ValueType,
106- typename IndexType>
108+ template < int num_thread_per_worker, bool atomic = false ,
109+ typename InputValueType, typename MatrixValueType,
110+ typename OutputValueType, typename IndexType>
107111__global__ __launch_bounds__ (default_block_size) void spmv (
108112 const size_type num_rows, const int num_worker_per_row,
109- const ValueType * __restrict__ val, const IndexType * __restrict__ col,
113+ const MatrixValueType * __restrict__ val, const IndexType * __restrict__ col,
110114 const size_type stride, const size_type num_stored_elements_per_row,
111- const ValueType * __restrict__ b, const size_type b_stride,
112- ValueType * __restrict__ c, const size_type c_stride)
115+ const InputValueType * __restrict__ b, const size_type b_stride,
116+ OutputValueType * __restrict__ c, const size_type c_stride)
113117{
118+ using compute_type =
119+ decltype (InputValueType{} + MatrixValueType{} + OutputValueType{});
114120 spmv_kernel< num_thread_per_worker, atomic> (
115121 num_rows, num_worker_per_row, val, col, stride,
116122 num_stored_elements_per_row, b, b_stride, c, c_stride,
117- [](const ValueType & x, const ValueType & y) { return x; });
123+ [](const compute_type & x, const OutputValueType & y) { return x; });
118124}
119125
120126
121- template < int num_thread_per_worker, bool atomic = false , typename ValueType,
122- typename IndexType>
127+ template < int num_thread_per_worker, bool atomic = false ,
128+ typename InputValueType, typename MatrixValueType,
129+ typename OutputValueType, typename IndexType>
123130__global__ __launch_bounds__ (default_block_size) void spmv (
124131 const size_type num_rows, const int num_worker_per_row,
125- const ValueType * __restrict__ alpha, const ValueType * __restrict__ val ,
126- const IndexType * __restrict__ col , const size_type stride ,
127- const size_type num_stored_elements_per_row,
128- const ValueType * __restrict__ b, const size_type b_stride,
129- const ValueType * __restrict__ beta, ValueType * __restrict__ c,
132+ const MatrixValueType * __restrict__ alpha,
133+ const MatrixValueType * __restrict__ val , const IndexType * __restrict__ col ,
134+ const size_type stride, const size_type num_stored_elements_per_row,
135+ const InputValueType * __restrict__ b, const size_type b_stride,
136+ const OutputValueType * __restrict__ beta, OutputValueType * __restrict__ c,
130137 const size_type c_stride)
131138{
132- const ValueType alpha_val = alpha[0 ];
133- const ValueType beta_val = beta[0 ];
139+ using compute_type =
140+ decltype (InputValueType{} + MatrixValueType{} + OutputValueType{});
141+ const compute_type alpha_val = alpha[0 ];
142+ const compute_type beta_val = beta[0 ];
134143 // Because the atomic operation changes the values of c during computation,
135144 // it can not do the right alpha * a * b + beta * c operation.
136145 // Thus, the cuda kernel only computes alpha * a * b when it uses atomic
@@ -139,15 +148,16 @@ __global__ __launch_bounds__(default_block_size) void spmv(
139148 spmv_kernel< num_thread_per_worker, atomic> (
140149 num_rows, num_worker_per_row, val, col, stride,
141150 num_stored_elements_per_row, b, b_stride, c, c_stride,
142- [& alpha_val](const ValueType & x, const ValueType & y) {
151+ [& alpha_val](const compute_type & x, const OutputValueType & y) {
143152 return alpha_val * x;
144153 });
145154 } else {
146155 spmv_kernel< num_thread_per_worker, atomic> (
147156 num_rows, num_worker_per_row, val, col, stride,
148157 num_stored_elements_per_row, b, b_stride, c, c_stride,
149- [& alpha_val, & beta_val](const ValueType & x, const ValueType & y) {
150- return alpha_val * x + beta_val * y;
158+ [& alpha_val, & beta_val](const compute_type & x,
159+ const OutputValueType & y) {
160+ return alpha_val * x + beta_val * compute_type{y};
151161 });
152162 }
153163}
0 commit comments