thread_buffer.hpp Source File

thread_buffer.hpp Source File

Composable Kernel: thread_buffer.hpp Source File
thread_buffer.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
9
10namespace ck_tile {
11
12#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE
// Tuple-backed thread_buffer, used when
// CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE.
// Listing line 14 (the alias itself) was dropped from this page; it is restored
// here from the page's own cross-reference entry
// "tuple_array< T, N > thread_buffer — Definition thread_buffer.hpp:14".
13template <typename T, index_t N>
14using thread_buffer = tuple_array<T, N>;
15
// Builds a thread buffer from the given values by handing them to make_tuple.
// NOTE(review): `ts` is declared as a forwarding-reference pack (Ts&&...) but is
// expanded as plain lvalues (`ts...`), without std::forward, so make_tuple always
// receives lvalue arguments — confirm this is intended.
16template <typename... Ts>
17CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
18{
19 return make_tuple(ts...);
20}
21#else
22
23#if 0
// Disabled alternative (inside `#if 0`): thread_buffer as an alias of array<T, N>.
24template <typename T, index_t N>
25using thread_buffer = array<T, N>;
26
// Disabled alternative (inside `#if 0`): builds the buffer through make_array.
// As written, the arguments are passed on as lvalues (`ts...`), not forwarded.
27template <typename... Ts>
28CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
29{
30 return make_array(ts...);
31}
32
33#endif
34
35// clang-format off
// Fixed-size buffer of N_ elements of type remove_cvref_t<T_>, stored in a plain
// C array (`data`). Besides element access it offers typed "views" of the same
// storage (get_as / set_as) that reinterpret the elements as another type Tx.
36template<typename T_, index_t N_>
37struct thread_buffer {
38 using value_type = remove_cvref_t<T_>;
39 static constexpr index_t N = N_;
40
 // Underlying storage; public, and also reachable through get().
41 value_type data[N];
42
43 // TODO: this ctor can't ignore
 // Default constructor: value-initializes every element (`data{}`).
44 CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
 // Broadcast constructor: value-initializes, then assigns `o` to each element
 // through static_for<0, N, 1>. It is not `explicit`, so a value_type converts
 // implicitly to a thread_buffer.
45 CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{} {
46 static_for<0, N, 1>{}(
47 [&](auto i) { data[i] = o; }
48 );
49 }
50
 // Number of elements (N).
51 CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
 // get(): reference to the whole array. get(i): reference to element i; no
 // bounds check is performed here.
 // NOTE(review): the get overloads are not marked constexpr, yet the constexpr
 // operator[] / operator() / at accessors below all call get(i) — confirm
 // whether get should be constexpr as well.
52 CK_TILE_HOST_DEVICE auto & get() {return data; }
53 CK_TILE_HOST_DEVICE const auto & get() const {return data; }
54 CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; }
55 CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; }
 // Index-based accessors; every one of them forwards to get(i).
56 CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
57 CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
 // operator() exists only as a non-const (mutable) accessor.
58 CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: compatible
59 CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
60 CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
 // Compile-time-index variants: the index comes from the template argument I
 // (explicitly, or deduced from a number<I> tag) and is passed to get(I).
61 template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
62 template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
63 template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
64 template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
65
 // Returns, by value, the whole buffer reinterpreted as
 // thread_buffer<X_, N / kSPerX>, where kSPerX = vector_traits<X>::vector_size.
 // Only participates in overload resolution when
 // has_same_scalar_type<value_type, X_>::value is true; N must be a multiple of
 // kSPerX (static_assert).
 // Implementation: all N elements are copied into the `sub_data` member of a
 // local union and the result is read back through the union's other member
 // (`data`), i.e. the conversion relies on union-based type punning.
66 template <typename X_,
67 typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
68 CK_TILE_HOST_DEVICE constexpr auto _get_as() const
69 {
70 using X = remove_cvref_t<X_>;
71
72 constexpr index_t kSPerX = vector_traits<X>::vector_size;
73 static_assert(N % kSPerX == 0);
74
75 union {
76 thread_buffer<X_, N / kSPerX> data {};
77 // tuple_array<value_type, kSPerX> sub_data;
78 value_type sub_data[N];
79 } vx;
 // Inside the lambda, `data[j]` is this buffer's own storage; the union's
 // member of the same name is only reachable as `vx.data`.
80 static_for<0, N, 1>{}(
81 [&](auto j) { vx.sub_data[j] = data[j]; });
82 return vx.data;
83 }
84
 // Returns, by value, a single X_ assembled from kSPerX consecutive elements.
 // The first element read is at index is * (sizeof(X_) / sizeof(value_type)).
 // Same enable_if constraint as the overload above.
85 template <typename X_,
86 index_t Is,
87 typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
88 CK_TILE_HOST_DEVICE const constexpr remove_reference_t<X_> _get_as(number<Is> is) const
89 {
90 using X = remove_cvref_t<X_>;
91
92 constexpr index_t kSPerX = vector_traits<X>::vector_size;
93
 // NOTE(review): this listing omits line 96; it presumably declares the
 // `sub_data` union member that the lambda below writes through
 // `vx.sub_data(j)` — confirm against the original header.
94 union {
95 X_ data {};
97 } vx;
98 static_for<0, kSPerX, 1>{}(
99 [&](auto j) { vx.sub_data(j) = operator[]((is * number<sizeof(X_)/sizeof(value_type)>{}) + j); });
100 return vx.data;
101 }
102
 // Disabled (`#if 0`) counterpart of _get_as(number<Is>): it would copy the
 // pieces of `x` back into this buffer one element at a time.
 // NOTE(review): listing line 115 is missing here as well; it presumably declares
 // the union's `sub_data` member — confirm against the original header.
103#if 0
104 template <typename X_,
105 index_t Is,
106 typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
107 CK_TILE_HOST_DEVICE constexpr void _set_as(number<Is> is, X_ x)
108 {
109 using X = remove_cvref_t<X_>;
110
111 constexpr index_t kSPerX = vector_traits<X>::vector_size;
112
113 union {
114 X_ data;
116 } vx {x};
117
118 static_for<0, kSPerX, 1>{}(
119 [&](auto j) { operator()((is * number<sizeof(X_)/sizeof(value_type)>{}) + j) = vx.sub_data[j]; });
120 }
121#endif
122
123
 // Shared prologue for get_as / set_as. It expects a type named `Tx` in scope,
 // checks that the buffer's total byte size (sizeof(value_type) * N) is a
 // multiple of sizeof(Tx), and defines `vx`, the number of Tx elements that fit
 // in the buffer. The macro is #undef'd at the end of the struct.
124#define TB_COMMON_AS() \
125 static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
126 constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
127
 // Mutable view: reinterprets the storage in place as thread_buffer<Tx, vx> and
 // returns a reference to it (no copy).
128 template<typename Tx>
129 CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
130 return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
 // Const overload. The return type is plain `auto`, so the result is returned
 // by value in both branches. When sizeof(value_type) <= 1 the union-based
 // _get_as path is used instead of reinterpret_cast (see the TODO below).
131 template<typename Tx>
132 CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS();
133 if constexpr(sizeof(value_type) <= 1 )
134 return _get_as<Tx>(); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
135 else
136 return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
 // Element I of the Tx view, by reference.
 // NOTE(review): `.get(number<I>{})` is called although only get() and
 // get(index_t) are declared in this struct; presumably number<I> converts
 // implicitly to index_t — confirm.
137 template<typename Tx, index_t I>
138 CK_TILE_HOST_DEVICE auto & get_as(number<I>) {TB_COMMON_AS();
139 return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(number<I>{});}
 // Const overload: element I of the Tx view, returned by value (plain `auto`);
 // it uses _get_as when sizeof(value_type) <= 1.
140 template<typename Tx, index_t I>
141 CK_TILE_HOST_DEVICE constexpr auto get_as(number<I>) const {TB_COMMON_AS();
142 if constexpr(sizeof(value_type) <= 1 )
143 return _get_as<Tx>(number<I>{}); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future
144 else
145 return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(number<I>{});}
146
 // Writes `x` into element i (run-time index) or element I (compile-time index)
 // of the Tx view of this buffer.
147 template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
148 { TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(i) = x; }
149 template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
150 { TB_COMMON_AS(); reinterpret_cast<thread_buffer<Tx, vx>&>(data).at(number<I>{}) = x; }
151
152#undef TB_COMMON_AS
153};
154// clang-format on
155
// Forward declaration of the two-parameter vector_traits template; the
// thread_buffer specializations below use the second parameter as an
// enable_if slot.
156template <typename, typename>
157struct vector_traits;
158
159// specialization for thread_buffer<T, N> when T is not a class type:
// the scalar type is T itself and the vector size is N.
160template <typename T, index_t N>
161struct vector_traits<thread_buffer<T, N>, std::enable_if_t<!std::is_class_v<T>>>
162{
163 using scalar_type = T;
164 static constexpr index_t vector_size = N;
165};
166
// Specialization for thread_buffer<T, N> when T is a class type: the scalar
// type is taken from the nested typedef T::type and the vector size is N.
167template <typename T, index_t N>
168struct vector_traits<thread_buffer<T, N>, std::enable_if_t<std::is_class_v<T>>>
169{
170 using scalar_type = typename T::type;
171 static constexpr index_t vector_size = N;
172};
173
174#endif
175
176} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts &&... ts)
Definition thread_buffer.hpp:17
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
typename std::remove_reference< T >::type remove_reference_t
Definition type_traits.hpp:15
tuple_array< T, N > thread_buffer
Definition thread_buffer.hpp:14
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
typename impl::tuple_array_impl< T, N >::type tuple_array
Definition tile/core/container/tuple.hpp:28
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Definition tile/core/utility/debug.hpp:67
Definition vector_type.hpp:90
static constexpr index_t vector_size
Definition vector_type.hpp:98
std::conditional_t< std::is_same_v< remove_cvref_t< T >, pk_int4_t >, int8_t, std::conditional_t< std::is_same_v< remove_cvref_t< T >, pk_fp4_t >|| std::is_same_v< remove_cvref_t< T >, e8m0_t >, uint8_t, remove_cvref_t< T > > > scalar_type
Definition vector_type.hpp:91