Skip to content

Commit da19a97

Browse files
authored
Merge Add dpcpp cooperative group, ConfigSet
This PR adds dpcpp cooperative group, ConfigSet, some helper function Summary: - dim3 and sycl_nd_range: a cuda-like usage for sycl range and nd_range with tests - helper gives default implementation macro for simple kernel cases (no explicit template parameter and 1d block) `__WG_BOUND__` gives something like `__launch_bound__` but it needs the 3d information not the product `__WG_BOUND_CONFIG__` can use ConfigSet for easy unpack - cooperative group implementation and set the test result individually - another selection for config (it allows bool, int, size_type template by roughly go through all kernel template) - update format_header such that it can handle the generated dpcpp file (the script is not yet here) - add ConfigSet and related decode/encode information Related PR: #757
2 parents 95c7652 + 7f72418 commit da19a97

16 files changed

Lines changed: 1582 additions & 7 deletions

File tree

cmake/create_test.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ function(ginkgo_create_dpcpp_test test_name)
3333
add_executable(${TEST_TARGET_NAME} ${test_name}.dp.cpp)
3434
target_compile_features("${TEST_TARGET_NAME}" PUBLIC cxx_std_17)
3535
target_compile_options("${TEST_TARGET_NAME}" PRIVATE "${GINKGO_DPCPP_FLAGS}")
36+
target_link_options("${TEST_TARGET_NAME}" PRIVATE -fsycl-device-code-split=per_kernel)
3637
if (GINKGO_DPCPP_SINGLE_MODE)
3738
target_compile_definitions("${TEST_TARGET_NAME}" PRIVATE GINKGO_DPCPP_SINGLE_MODE=1)
3839
endif()

core/base/types.hpp

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
/*******************************<GINKGO LICENSE>******************************
2+
Copyright (c) 2017-2021, the Ginkgo authors
3+
All rights reserved.
4+
5+
Redistribution and use in source and binary forms, with or without
6+
modification, are permitted provided that the following conditions
7+
are met:
8+
9+
1. Redistributions of source code must retain the above copyright
10+
notice, this list of conditions and the following disclaimer.
11+
12+
2. Redistributions in binary form must reproduce the above copyright
13+
notice, this list of conditions and the following disclaimer in the
14+
documentation and/or other materials provided with the distribution.
15+
16+
3. Neither the name of the copyright holder nor the names of its
17+
contributors may be used to endorse or promote products derived from
18+
this software without specific prior written permission.
19+
20+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21+
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22+
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
23+
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24+
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
******************************<GINKGO LICENSE>*******************************/
32+
33+
#ifndef GKO_CORE_BASE_TYPES_HPP_
34+
#define GKO_CORE_BASE_TYPES_HPP_
35+
36+
37+
#include <array>
38+
#include <cstdint>
39+
#include <type_traits>
40+
41+
42+
namespace gko {
43+
namespace detail {
44+
45+
46+
/**
47+
* mask gives the integer with Size activated bits in the end
48+
*
49+
* @tparam Size the number of activated bits
50+
* @tparam ValueType the type of mask, which uses std::uint32_t as default
51+
*
52+
* @return the ValueType with Size activated bits in the end
53+
*/
54+
template <int Size, typename ValueType = std::uint32_t>
55+
constexpr std::enable_if_t<(Size < sizeof(ValueType) * 8), ValueType> mask()
56+
{
57+
return (ValueType{1} << Size) - 1;
58+
}
59+
60+
/**
61+
* @copydoc mask()
62+
*
63+
* @note this is special case for the Size = the number of bits of ValueType
64+
*/
65+
template <int Size, typename ValueType = std::uint32_t>
66+
constexpr std::enable_if_t<Size == sizeof(ValueType) * 8, ValueType> mask()
67+
{
68+
return ~ValueType{};
69+
}
70+
71+
72+
/**
73+
* shift calculates the number of bits for shifting
74+
*
75+
* @tparam current_shift the current position of shifting
76+
* @tparam num_groups the number of elements in array
77+
*
78+
* @return the number of shifting bits
79+
*
80+
* @note this is the last case of nested template
81+
*/
82+
template <int current_shift, int num_groups>
83+
constexpr std::enable_if_t<(num_groups == current_shift + 1), int> shift(
84+
const std::array<unsigned char, num_groups> &bits)
85+
{
86+
return 0;
87+
}
88+
89+
/**
90+
* @copydoc shift(const std::array<char, num_groups>)
91+
*
92+
* @note this is the usual case of nested template
93+
*/
94+
template <int current_shift, int num_groups>
95+
constexpr std::enable_if_t<(num_groups > current_shift + 1), int> shift(
96+
const std::array<unsigned char, num_groups> &bits)
97+
{
98+
return bits[current_shift + 1] +
99+
shift<(current_shift + 1), num_groups>(bits);
100+
}
101+
102+
103+
} // namespace detail
104+
105+
106+
/**
107+
* ConfigSet is a way to embed several information into one integer by given
108+
* certain bits.
109+
*
110+
* The usage will be the following
111+
* Set the method with bits Cfg = ConfigSet<b_0, b_1, ..., b_k>
112+
* Encode the given infomation encoded = Cfg::encode(x_0, x_1, ..., x_k)
113+
* Decode the specific position information x_t = Cfg::decode<t>(encoded)
114+
* The encoded result will use 32 bits to record
115+
* rrrrr0..01....1...k..k, which 1/2/.../k means the bits store the information
116+
* for 1/2/.../k position and r is for rest of unused bits.
117+
*
118+
* Denote $B_t = \sum_{i = t+1}^k b_i$ and $F(X) = Cfg::encode(x_0, ..., x_k)$.
119+
* Have $F(X) = \sum_{i = 0}^k (x_i << B_i) = \sum_{i = 0}^k (x_i * 2^{B_i})$.
120+
* For all i, we have $0 <= x_i < 2^{b_i}$.
121+
* $x_i$, $2^{B_i}$ are non-negative, so
122+
* $F(X) = 0$ <=> $X = \{0\}$, $x_i = 0$ for all i.
123+
* Assume $F(X) = F(Y)$, then
124+
* $0 = |F(X) - F(Y)| = |F(X-Y)| = F(|X - Y|)$.
125+
* $|x_i - y_i|$ is still in the same range $0 <= |x_i - y_i| < 2^{b_i}$.
126+
* Thus, $F(|X - Y|) = 0$ -> $|X - Y| = \{0\}$, $x_i - y_i = 0$ -> $X = Y$.
127+
* F is one-to-one function if $0 <= x_i < 2^{b_i}$ for all i.
128+
* For any encoded result R, we can use the following to get the decoded series.
129+
* for i = k to 0;
130+
* $x_i = R % b_i$;
131+
* $R = R / bi$;
132+
* endfor;
133+
* For any R in the range $[0, 2^{B_0})$, we have X such that $F(X) = R$.
134+
* F is onto function.
135+
* Thus, F is bijection.
136+
*
137+
* @tparam num_bits... the number of bits for each position.
138+
*
139+
* @note the num_bit is required at least $ceil(log_2(maxval) + 1)$
140+
*/
141+
template <unsigned char... num_bits>
142+
class ConfigSet {
143+
public:
144+
static constexpr unsigned num_groups = sizeof...(num_bits);
145+
static constexpr std::array<unsigned char, num_groups> bits{num_bits...};
146+
147+
/**
148+
* Decodes the `position` information from encoded
149+
*
150+
* @tparam position the position of desired information
151+
*
152+
* @param encoded the encoded integer
153+
*
154+
* @return the decoded information at position
155+
*/
156+
template <int position>
157+
static constexpr std::uint32_t decode(std::uint32_t encoded)
158+
{
159+
static_assert(position < num_groups,
160+
"This position is over the bounds.");
161+
constexpr int shift = detail::shift<position, num_groups>(bits);
162+
constexpr auto mask = detail::mask<bits[position]>();
163+
return (encoded >> shift) & mask;
164+
}
165+
166+
/**
167+
* Encodes the information with given bit set to encoded integer.
168+
*
169+
* @note the last case of nested template.
170+
*/
171+
template <unsigned current_iter>
172+
static constexpr std::enable_if_t<(current_iter == num_groups),
173+
std::uint32_t>
174+
encode()
175+
{
176+
return 0;
177+
}
178+
179+
/**
180+
* Encodes the information with given bit set to encoded integer.
181+
*
182+
* @tparam current_iter the encoded place
183+
* @tparam Rest... the rest type
184+
*
185+
* @param first the current encoded information
186+
* @param rest... the rest of other information waiting for encoding
187+
*
188+
* @return the encoded integer
189+
*/
190+
template <unsigned current_iter = 0, typename... Rest>
191+
static constexpr std::enable_if_t<(current_iter < num_groups),
192+
std::uint32_t>
193+
encode(std::uint32_t first, Rest &&... rest)
194+
{
195+
constexpr int shift = detail::shift<current_iter, num_groups>(bits);
196+
if (current_iter == 0) {
197+
static_assert(
198+
bits[current_iter] + shift <= sizeof(std::uint32_t) * 8,
199+
"the total bits usage is larger than std::uint32_t bits");
200+
}
201+
return (first << shift) |
202+
encode<current_iter + 1>(std::forward<Rest>(rest)...);
203+
}
204+
};
205+
206+
207+
} // namespace gko
208+
209+
#endif // GKO_CORE_BASE_TYPES_HPP_

core/synthesizer/implementation_selection.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,38 @@ namespace syn {
7070
} \
7171
}
7272

73+
#define GKO_ENABLE_IMPLEMENTATION_CONFIG_SELECTION(_name, _callable) \
74+
template <typename Predicate, bool... BoolArgs, int... IntArgs, \
75+
gko::size_type... SizeTArgs, typename... TArgs, \
76+
typename... InferredArgs> \
77+
inline void _name(::gko::syn::value_list<std::uint32_t>, Predicate, \
78+
::gko::syn::value_list<bool, BoolArgs...>, \
79+
::gko::syn::value_list<int, IntArgs...>, \
80+
::gko::syn::value_list<gko::size_type, SizeTArgs...>, \
81+
::gko::syn::type_list<TArgs...>, InferredArgs...) \
82+
GKO_KERNEL_NOT_FOUND; \
83+
\
84+
template <std::uint32_t K, std::uint32_t... Rest, typename Predicate, \
85+
bool... BoolArgs, int... IntArgs, gko::size_type... SizeTArgs, \
86+
typename... TArgs, typename... InferredArgs> \
87+
inline void _name( \
88+
::gko::syn::value_list<std::uint32_t, K, Rest...>, \
89+
Predicate is_eligible, \
90+
::gko::syn::value_list<bool, BoolArgs...> bool_args, \
91+
::gko::syn::value_list<int, IntArgs...> int_args, \
92+
::gko::syn::value_list<gko::size_type, SizeTArgs...> size_args, \
93+
::gko::syn::type_list<TArgs...> type_args, InferredArgs... args) \
94+
{ \
95+
if (is_eligible(K)) { \
96+
_callable<BoolArgs..., IntArgs..., SizeTArgs..., TArgs..., K>( \
97+
std::forward<InferredArgs>(args)...); \
98+
} else { \
99+
_name(::gko::syn::value_list<std::uint32_t, Rest...>(), \
100+
is_eligible, bool_args, int_args, size_args, type_args, \
101+
std::forward<InferredArgs>(args)...); \
102+
} \
103+
}
104+
73105

74106
} // namespace syn
75107
} // namespace gko

core/test/base/types.cpp

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,18 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3333
#include <ginkgo/core/base/types.hpp>
3434

3535

36+
#include <array>
37+
#include <cstdint>
38+
#include <stdexcept>
39+
#include <type_traits>
40+
41+
3642
#include <gtest/gtest.h>
3743

3844

45+
#include "core/base/types.hpp"
46+
47+
3948
namespace {
4049

4150

@@ -100,4 +109,108 @@ TEST(PrecisionReduction, ComputesCommonEncoding)
100109
}
101110

102111

112+
TEST(ConfigSet, MaskCorrectly)
113+
{
114+
constexpr auto mask3_u = gko::detail::mask<3>();
115+
constexpr auto fullmask_u = gko::detail::mask<32>();
116+
constexpr auto mask3_u64 = gko::detail::mask<3, std::uint64_t>();
117+
constexpr auto fullmask_u64 = gko::detail::mask<64, std::uint64_t>();
118+
119+
ASSERT_EQ(mask3_u, 7u);
120+
ASSERT_EQ(fullmask_u, 0xffffffffu);
121+
ASSERT_TRUE((std::is_same<decltype(mask3_u), const std::uint32_t>::value));
122+
ASSERT_TRUE(
123+
(std::is_same<decltype(fullmask_u), const std::uint32_t>::value));
124+
ASSERT_EQ(mask3_u64, 7ull);
125+
ASSERT_EQ(fullmask_u64, 0xffffffffffffffffull);
126+
ASSERT_TRUE(
127+
(std::is_same<decltype(mask3_u64), const std::uint64_t>::value));
128+
ASSERT_TRUE(
129+
(std::is_same<decltype(fullmask_u64), const std::uint64_t>::value));
130+
}
131+
132+
133+
TEST(ConfigSet, ShiftCorrectly)
134+
{
135+
constexpr std::array<unsigned char, 3> bits{3, 5, 7};
136+
137+
138+
constexpr auto shift0 = gko::detail::shift<0, 3>(bits);
139+
constexpr auto shift1 = gko::detail::shift<1, 3>(bits);
140+
constexpr auto shift2 = gko::detail::shift<2, 3>(bits);
141+
142+
ASSERT_EQ(shift0, 12);
143+
ASSERT_EQ(shift1, 7);
144+
ASSERT_EQ(shift2, 0);
145+
}
146+
147+
148+
TEST(ConfigSet, ConfigSet1Correctly)
149+
{
150+
using Cfg = gko::ConfigSet<3>;
151+
152+
constexpr auto encoded = Cfg::encode(2);
153+
constexpr auto decoded = Cfg::decode<0>(encoded);
154+
155+
ASSERT_EQ(encoded, 2);
156+
ASSERT_EQ(decoded, 2);
157+
}
158+
159+
160+
TEST(ConfigSet, ConfigSet1FullCorrectly)
161+
{
162+
using Cfg = gko::ConfigSet<32>;
163+
164+
constexpr auto encoded = Cfg::encode(0xffffffff);
165+
constexpr auto decoded = Cfg::decode<0>(encoded);
166+
167+
ASSERT_EQ(encoded, 0xffffffff);
168+
ASSERT_EQ(decoded, 0xffffffff);
169+
}
170+
171+
172+
TEST(ConfigSet, ConfigSet2FullCorrectly)
173+
{
174+
using Cfg = gko::ConfigSet<1, 31>;
175+
176+
constexpr auto encoded = Cfg::encode(1, 33);
177+
178+
ASSERT_EQ(encoded, (1u << 31) + 33);
179+
}
180+
181+
182+
TEST(ConfigSet, ConfigSetSomeCorrectly)
183+
{
184+
using Cfg = gko::ConfigSet<3, 5, 7>;
185+
186+
constexpr auto encoded = Cfg::encode(2, 11, 13);
187+
constexpr auto decoded_0 = Cfg::decode<0>(encoded);
188+
constexpr auto decoded_1 = Cfg::decode<1>(encoded);
189+
constexpr auto decoded_2 = Cfg::decode<2>(encoded);
190+
191+
ASSERT_EQ(encoded, (2 << 12) + (11 << 7) + 13);
192+
ASSERT_EQ(decoded_0, 2);
193+
ASSERT_EQ(decoded_1, 11);
194+
ASSERT_EQ(decoded_2, 13);
195+
}
196+
197+
198+
TEST(ConfigSet, ConfigSetSomeFullCorrectly)
199+
{
200+
using Cfg = gko::ConfigSet<2, 6, 7, 17>;
201+
202+
constexpr auto encoded = Cfg::encode(2, 11, 13, 19);
203+
constexpr auto decoded_0 = Cfg::decode<0>(encoded);
204+
constexpr auto decoded_1 = Cfg::decode<1>(encoded);
205+
constexpr auto decoded_2 = Cfg::decode<2>(encoded);
206+
constexpr auto decoded_3 = Cfg::decode<3>(encoded);
207+
208+
ASSERT_EQ(encoded, (2 << 30) + (11 << 24) + (13 << 17) + 19);
209+
ASSERT_EQ(decoded_0, 2);
210+
ASSERT_EQ(decoded_1, 11);
211+
ASSERT_EQ(decoded_2, 13);
212+
ASSERT_EQ(decoded_3, 19);
213+
}
214+
215+
103216
} // namespace

dev_tools/scripts/config

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
- RemoveTest: "true"
3333
- "_builder\.hpp"
3434
- CoreSuffix: "_builder"
35+
- "dpcpp/test/base/dim3\.dp\.cpp"
36+
- FixInclude: "dpcpp/base/dim3.dp.hpp"
3537
- "components.*_kernels(\.hip|\.dp)?\.(cu|cpp|hpp|cuh)"
3638
- CoreSuffix: "_kernels"
3739
- RemoveTest: "true"

0 commit comments

Comments
 (0)