block_gemm_areg_bsmem_creg_one_warp_v1.hpp Source File

block_gemm_areg_bsmem_creg_one_warp_v1.hpp Source File#

Composable Kernel: block_gemm_areg_bsmem_creg_one_warp_v1.hpp Source File
block_gemm_areg_bsmem_creg_one_warp_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A is block distributed tensor
12// B is block window on shared memory
13// C is block distributed tensor
14template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV1DefaultPolicy>
16{
23
24 static constexpr index_t kBlockSize = Problem::kBlockSize; // threads per block, from Problem
// Single-warp variant: the whole block must consist of exactly one warp.
25 static_assert(kBlockSize == get_warp_size(), "Check failed!");
26
27 // C += A * B
// Accumulating block GEMM: updates c_block_tensor in place.
//   c_block_tensor     - C distributed tensor, read and written slice by slice
//   a_block_tensor_tmp - A distributed tensor; only its thread buffer is copied
//   b_block_window_tmp - tile window over B; warp-sized sub-windows are derived from it
28 template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
29 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
30 const ABlockTensorTmp& a_block_tensor_tmp,
31 const BBlockWindowTmp& b_block_window_tmp) const
32 {
// The element types of all three operands must match the Problem's A/B/C data types.
33 static_assert(
34 std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
35 std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
36 std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
37 "wrong!");
38
// Block extents come from BlockGemmShape; deducing them from the operands
// (the commented-out lines below) is currently disabled.
39 // constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
40 // constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
41 // constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
42 constexpr index_t MPerBlock = BlockGemmShape::kM;
43 constexpr index_t NPerBlock = BlockGemmShape::kN;
44 constexpr index_t KPerBlock = BlockGemmShape::kK;
45
46 // static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
47 // KPerBlock == BlockGemmShape::kK,
48 // "wrong!");
49
// The policy returns a tuple: element 0 is the warp-GEMM object, elements 1 and 2
// are the warp counts along M and N.
50 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
51
52 using WG = remove_cvref_t<decltype(config.template at<0>())>;
53
54 constexpr index_t MWarp = config.template at<1>();
55 constexpr index_t NWarp = config.template at<2>();
56
// This implementation only supports a 1 x 1 warp layout.
57 static_assert(MWarp == 1 && NWarp == 1, "Check failed!");
58
// Number of warp-GEMM steps needed along each dimension to cover the block.
59 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
60 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
61 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
62
// Window offset between consecutive iterations along N and K.
63 constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
64 constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
65
66 const index_t iNWarp = 0; // fixed at 0; NWarp is asserted to be 1 above
67
68 constexpr auto c_block_outer_dstr_encoding =
71 tuple<>,
72 tuple<>,
75
// Embed the warp-level C encoding into the block-level outer encoding.
76 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
77 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
78
79 // construct A-block-tensor from A-block-tensor-tmp
80 // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
81 // distribution
84
// Only the per-thread buffer is copied; the distributions are not compared (see FIXME).
85 a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
86
87 // construct B-warp-window
88 auto b_warp_window_tmp = make_tile_window(
89 b_block_window_tmp.get_bottom_tensor_view(),
91 b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
92 make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
93
94#if 0 // FIXME: using array will cause register spill
95 array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
96 {b_warp_window_tmp}};
97
98 for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
99 {
100 for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
101 {
102 move_tile_window(b_warp_windows(nIter)(kIter),
103 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
104 }
105 }
106#else
108 statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
109 NIterPerWarp>
110 b_warp_windows;
111
// Active path: one window per (nIter, kIter), each a copy of the base window
// shifted by {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}.
112 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
113 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
114 b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
115
116 move_tile_window(b_warp_windows(nIter)(kIter),
117 {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
118 });
119 });
120#endif
121
122 // check C-block-distribution
// The caller's C tensor must use exactly the encoding type built above.
123 static_assert(
124 std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
125 remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
126 .get_static_tile_distribution_encoding())>>,
127 "wrong!");
128
129 using AWarpDstr = typename WG::AWarpDstr;
130 using CWarpDstr = typename WG::CWarpDstr;
131
132 using AWarpTensor = typename WG::AWarpTensor;
133 using CWarpTensor = typename WG::CWarpTensor;
134
// Lengths of the warp-level Y dimensions, used as slice extents below.
135 constexpr auto a_warp_y_lengths =
136 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
137 constexpr auto c_warp_y_lengths =
138 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
139
// All-zero start indices for the warp-level Y dimensions.
140 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
141 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
142
143 // hot loop:
// Loop order is K (outer), M, N (inner). The A slice is extracted once per
// (kIter, mIter) and reused for every nIter.
144 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
145 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
146 // read A warp tensor from A block tensor
147 AWarpTensor a_warp_tensor;
148
149 a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
150 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
151 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
152
153 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
154 // read B warp tensor from B Block window
// NOTE(review): this load sits inside the mIter loop, so the same
// (nIter, kIter) B tile is loaded again for every mIter.
155 const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
156
157 // read C warp tensor from C block tensor
158 CWarpTensor c_warp_tensor;
159
160 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
161 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
162 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
163
164 // warp GEMM
165 WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
166
167 // write C warp tensor into C block tensor
168 c_block_tensor.set_y_sliced_thread_data(
169 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
170 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
171 c_warp_tensor.get_thread_buffer());
172 });
173 });
174 });
175 }
176
// Builds the static tile distribution for the A block tensor: an outer block-level
// encoding with the warp-level A encoding (WG::AWarpDstrEncoding) embedded in it.
// MPerBlock and KPerBlock default to the BlockGemmShape extents.
177 template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
179 {
// Policy tuple: element 0 is the warp-GEMM object, elements 1 and 2 are MWarp and NWarp.
180 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
181
182 using WG = remove_cvref_t<decltype(config.template at<0>())>;
183
184 constexpr index_t MWarp = config.template at<1>();
185 constexpr index_t NWarp = config.template at<2>();
186
// Number of warp-GEMM steps along M and K needed to cover the block.
187 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
188 constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
189
190 constexpr auto a_block_outer_dstr_encoding =
197
198 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
199 a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
200
201 return make_static_tile_distribution(a_block_dstr_encode);
202 }
203
// Creates and returns the C block tensor: a static distributed tensor of CDataType
// whose distribution embeds WG::CWarpDstrEncoding into an outer block-level encoding.
// NOTE(review): nothing here shows whether the returned storage is initialized - confirm.
204 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
205 {
206 constexpr index_t MPerBlock = BlockGemmShape::kM;
207 constexpr index_t NPerBlock = BlockGemmShape::kN;
208
// Policy tuple: element 0 is the warp-GEMM object, elements 1 and 2 are MWarp and NWarp.
209 constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
210
211 using WG = remove_cvref_t<decltype(config.template at<0>())>;
212
213 constexpr index_t MWarp = config.template at<1>();
214 constexpr index_t NWarp = config.template at<2>();
215
// Only a 1 x 1 warp layout is supported.
216 static_assert(MWarp == 1 && NWarp == 1, "Check failed!");
217
// Number of warp-GEMM steps along M and N needed to cover the block.
218 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
219 constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
220 // constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
221
222 constexpr auto c_block_outer_dstr_encoding =
225 tuple<>,
226 tuple<>,
229
230 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
231 c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
232
// The combined encoding must have exactly one P dimension.
233 static_assert(decltype(c_block_dstr_encode)::NDimP == 1, "Check failed!");
234
235 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
236 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
237 return c_block_tensor;
238 }
239
240 // C = A * B
// Convenience overload: creates a C block tile with MakeCBlockTile(), passes it to the
// accumulating overload (C += A * B) together with A and B, and returns it by value.
// NOTE(review): the accumulating overload reads c_block_tensor before writing it, so
// "C = A * B" only holds if MakeCBlockTile() yields zeroed storage - confirm, or clear
// the tile before accumulating.
241 template <typename ABlockTensorTmp, typename BBlockWindowTmp>
242 CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
243 const BBlockWindowTmp& b_block_window_tmp) const
244 {
245 auto c_block_tensor = MakeCBlockTile();
246 operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
247 return c_block_tensor;
248 }
249};
250
251} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
array< index_t, N > multi_index
Definition tile/core/container/multi_index.hpp:17
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:16
remove_cvref_t< Problem_ > Problem
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:17
static constexpr index_t kBlockSize
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:24
remove_cvref_t< typename Problem::ADataType > ADataType
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:19
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:204
remove_cvref_t< typename Problem::BDataType > BDataType
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:20
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindowTmp &b_block_window_tmp) const
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:29
remove_cvref_t< Policy_ > Policy
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:18
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:22
remove_cvref_t< typename Problem::CDataType > CDataType
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:21
static CK_TILE_DEVICE constexpr auto MakeABlockTileDistribution()
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:178
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp &a_block_tensor_tmp, const BBlockWindowTmp &b_block_window_tmp) const
Definition block_gemm_areg_bsmem_creg_one_warp_v1.hpp:242
A fixed-size array container similar to std::array with additional utilities.
Definition tile/core/container/array.hpp:43
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192