Skip to content

Commit 6ab97cd

Browse files
yhmtsaiSlaedrupsj
committed
switch shift order, add ConfigSet check, review update
Co-authored-by: Aditya Kashi <aditya.kashi@kit.edu> Co-authored-by: Tobias Ribizel <ribizel@kit.edu>
1 parent ee3b630 commit 6ab97cd

3 files changed

Lines changed: 76 additions & 32 deletions

File tree

core/test/base/types.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3535

3636
#include <array>
3737
#include <cstdint>
38+
#include <stdexcept>
3839
#include <type_traits>
3940

4041

@@ -114,27 +115,26 @@ TEST(ConfigSet, MaskCorrectly)
114115

115116
ASSERT_EQ(mask3_u, 7u);
116117
ASSERT_EQ(fullmask_u, 0xffffffffu);
117-
ASSERT_EQ((std::is_same<decltype(mask3_u), const unsigned int>::value),
118-
true);
119-
ASSERT_EQ((std::is_same<decltype(fullmask_u), const unsigned int>::value),
120-
true);
118+
ASSERT_TRUE((std::is_same<decltype(mask3_u), const unsigned int>::value));
119+
ASSERT_TRUE(
120+
(std::is_same<decltype(fullmask_u), const unsigned int>::value));
121121
ASSERT_EQ(mask3_u64, 7ull);
122122
ASSERT_EQ(fullmask_u64, 0xffffffffffffffffull);
123-
ASSERT_EQ((std::is_same<decltype(mask3_u64), const std::uint64_t>::value),
124-
true);
125-
ASSERT_EQ(
126-
(std::is_same<decltype(fullmask_u64), const std::uint64_t>::value),
127-
true);
123+
ASSERT_TRUE(
124+
(std::is_same<decltype(mask3_u64), const std::uint64_t>::value));
125+
ASSERT_TRUE(
126+
(std::is_same<decltype(fullmask_u64), const std::uint64_t>::value));
128127
}
129128

130129

131130
TEST(ConfigSet, ShiftCorrectly)
132131
{
133-
constexpr std::array<char, 3> bits{3, 5, 7};
132+
constexpr std::array<unsigned char, 3> bits{3, 5, 7};
134133

135-
constexpr auto shift0 = gko::detail::shift<3, 0>(bits);
136-
constexpr auto shift1 = gko::detail::shift<3, 1>(bits);
137-
constexpr auto shift2 = gko::detail::shift<3, 2>(bits);
134+
135+
constexpr auto shift0 = gko::detail::shift<0, 3>(bits);
136+
constexpr auto shift1 = gko::detail::shift<1, 3>(bits);
137+
constexpr auto shift2 = gko::detail::shift<2, 3>(bits);
138138

139139
ASSERT_EQ(shift0, 12);
140140
ASSERT_EQ(shift1, 7);
@@ -210,4 +210,14 @@ TEST(ConfigSet, ConfigSetSomeFullCorrectly)
210210
}
211211

212212

213+
TEST(ConfigSet, ThrowOutOfBoundWhenExceedRepresentation)
214+
{
215+
using Cfg = gko::ConfigSet<3, 2, 1>;
216+
217+
ASSERT_THROW(auto a = Cfg::encode(0, 0, 2), std::out_of_range);
218+
ASSERT_THROW(auto a = Cfg::encode(0, 4, 0), std::out_of_range);
219+
ASSERT_THROW(auto a = Cfg::encode(8, 0, 0), std::out_of_range);
220+
}
221+
222+
213223
} // namespace

dpcpp/test/components/cooperative_groups_kernels.dp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ namespace {
5959

6060

6161
using namespace gko::kernels::dpcpp;
62-
using KCfg = gko::ConfigSet<12, 7>;
62+
using KCfg = gko::ConfigSet<11, 7>;
6363
constexpr auto default_config_list =
6464
::gko::syn::value_list<::gko::ConfigSetType, KCfg::encode(64, 64),
6565
KCfg::encode(32, 32), KCfg::encode(16, 16),

include/ginkgo/core/base/types.hpp

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4141
#include <cstddef>
4242
#include <cstdint>
4343
#include <limits>
44+
#include <stdexcept>
45+
#include <string>
4446
#include <type_traits>
4547

4648

@@ -226,16 +228,16 @@ constexpr std::enable_if_t<Size == sizeof(ValueType) * 8, ValueType> mask()
226228
/**
227229
* shift calculates the number of bits for shifting
228230
*
229-
* @tparam num_groups the number of elements in array
230231
* @tparam current_shift the current position of shifting
232+
* @tparam num_groups the number of elements in array
231233
*
232234
* @return the number of shifting bits
233235
*
234236
* @note this is the last case of nested template
235237
*/
236-
template <int num_groups, int current_shift>
238+
template <int current_shift, int num_groups>
237239
constexpr std::enable_if_t<(num_groups == current_shift + 1), int> shift(
238-
const std::array<char, num_groups> &bits)
240+
const std::array<unsigned char, num_groups> &bits)
239241
{
240242
return 0;
241243
}
@@ -245,12 +247,12 @@ constexpr std::enable_if_t<(num_groups == current_shift + 1), int> shift(
245247
*
246248
* @note this is the usual case of nested template
247249
*/
248-
template <int num_groups, int current_shift>
250+
template <int current_shift, int num_groups>
249251
constexpr std::enable_if_t<(num_groups > current_shift + 1), int> shift(
250-
const std::array<char, num_groups> &bits)
252+
const std::array<unsigned char, num_groups> &bits)
251253
{
252254
return bits[current_shift + 1] +
253-
shift<num_groups, (current_shift + 1)>(bits);
255+
shift<(current_shift + 1), num_groups>(bits);
254256
}
255257

256258

@@ -262,23 +264,42 @@ using ConfigSetType = unsigned int;
262264

263265
/**
264266
* ConfigSet is a way to embed several information into one integer by given
265-
* certain bits. The usage will be the following
266-
* Set the method with bits Cfg = ConfigSet<nb_1, nb_2, ..., nb_k>
267-
* Encode the given infomation encoded = Cfg::encode(i_1, i_2, ..., i_k)
268-
* Decode the specific position information i_t = Cfg::decode<t>(encoded)
267+
* certain bits.
268+
*
269+
* The usage will be the following
270+
* Set the method with bits Cfg = ConfigSet<b_0, b_1, ..., b_k>
271+
* Encode the given infomation encoded = Cfg::encode(x_0, x_1, ..., x_k)
272+
* Decode the specific position information x_t = Cfg::decode<t>(encoded)
269273
* The encoded result will use 32 bits to record
270-
* rrrrr1..12....2...k..k, which 1/2/k means the bits store the information for
271-
* 1/2/k position and r is for rest of unused bits.
274+
* rrrrr0..01....1...k..k, which 1/2/.../k means the bits store the information
275+
* for 1/2/.../k position and r is for rest of unused bits.
276+
*
277+
* Denote B_t = sum_(t+1)^(k) b_i and F(X) = Cfg::encode(x_0, ..., x_k)
278+
* We can write F(X) = sum_0^k (x_i << B_i)
279+
* for all i, we have 0 <= x_i < 2^(b_i)
280+
* x_i, B_i are non-negative, so the F(X) = 0 <=> X = {0}, x_i = 0 for all i
281+
* Assume F(X) = F(Y), then
282+
* 0 = |F(X) - F(Y)| = |F(X-Y)| = F(|X - Y|)
283+
* |x_i - y_i| is still in the same range 0 <= |x_i - y_i| < 2^(b_i)
284+
* Thus, F(|X - Y|) = 0 -> |X - Y| = {0}, x_i - y_i = 0 -> X = Y
285+
* F is one-to-one function if 0 <= x_i < 2^(b_i) for all i
286+
* For any encoded result R, we can use the following to get the decoded series.
287+
* for i = k to 0
288+
* x_i = R % b_i
289+
* R = R / bi
290+
* Thus, any R in the range [0, 2^(B_0)) we have a series X such that F(X) = R
291+
* F is onto function.
292+
* Thus, F is bijection
272293
*
273294
* @tparam num_bits... the number of bits for each position.
274295
*
275-
* @note the num_bit is required at least $log_2(maxval) + 1$
296+
* @note the num_bit is required at least $ceil(log_2(maxval) + 1)$
276297
*/
277-
template <int... num_bits>
298+
template <unsigned char... num_bits>
278299
class ConfigSet {
279300
public:
280301
static constexpr size_type num_groups = sizeof...(num_bits);
281-
static constexpr std::array<char, num_groups> bits{num_bits...};
302+
static constexpr std::array<unsigned char, num_groups> bits{num_bits...};
282303

283304
/**
284305
* Decodes the `position` information from encoded
@@ -294,7 +315,7 @@ class ConfigSet {
294315
{
295316
static_assert(position < num_groups,
296317
"This position is over the bounds.");
297-
constexpr int shift = detail::shift<num_groups, position>(bits);
318+
constexpr int shift = detail::shift<position, num_groups>(bits);
298319
constexpr auto mask = detail::mask<bits[position]>();
299320
return (encoded >> shift) & mask;
300321
}
@@ -319,7 +340,7 @@ class ConfigSet {
319340
* @tparam Rest... the rest type
320341
*
321342
* @param first the current encoded information
322-
* @param rest... the rest of others information waits for encoding
343+
* @param rest... the rest of other information waiting for encoding
323344
*
324345
* @return the encoded integer
325346
*/
@@ -328,7 +349,20 @@ class ConfigSet {
328349
ConfigSetType>
329350
encode(ConfigSetType first, Rest &&... rest)
330351
{
331-
constexpr int shift = detail::shift<num_groups, current_iter>(bits);
352+
constexpr auto bound = detail::mask<bits[current_iter]>();
353+
if (first > bound) {
354+
throw std::out_of_range(
355+
std::to_string(first) + " at " + std::to_string(current_iter) +
356+
" postion is out of range of " +
357+
std::to_string(detail::mask<bits[current_iter]>()) +
358+
" representation");
359+
}
360+
constexpr int shift = detail::shift<current_iter, num_groups>(bits);
361+
if (current_iter == 0) {
362+
static_assert(
363+
bits[current_iter] + shift <= sizeof(ConfigSetType) * 8,
364+
"the total bits usage is larger than ConfigSetType bits");
365+
}
332366
return (first << shift) |
333367
encode<current_iter + 1>(std::forward<Rest>(rest)...);
334368
}

0 commit comments

Comments
 (0)