gemm_bquant_pipeline_ag_bg_cr_v3.hpp Source File

gemm_bquant_pipeline_ag_bg_cr_v3.hpp Source File

Composable Kernel: gemm_bquant_pipeline_ag_bg_cr_v3.hpp Source File
gemm_bquant_pipeline_ag_bg_cr_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <string>
7#include <sstream>
8
9#include "ck_tile/core.hpp"
14
15namespace ck_tile {
16
17// Compute optimized pipeline
18// GlobalPrefetchStages: 2
19// LocalPreFillStages: 1
20// LocalPreFetchStages: 1
21// LocalSharedMemoryBuffer: 1
22
23template <typename Problem>
25{
// Dispatches a runtime (has_hot_loop, tail_number) pair to run_func.
//
// Six combinations are handled, each with its own `return run_func(...)`:
// has_hot_loop true/false crossed with TailNumber::Full, ::Odd and ::Even.
// Any other tail_number throws std::runtime_error.
//
// run_func      callable invoked for the selected combination; its result is returned.
// has_hot_loop  selects the outer branch.
// tail_number   selects the inner branch.
//
// NOTE(review): this listing omits the argument lines of every run_func call
// (embedded source lines 35-36, 41-42, 47-48, 60-61, 66-67, 72-73), so the exact
// arguments passed are not visible here.
// NOTE(review): the function is marked CK_TILE_HOST_DEVICE yet contains `throw`;
// confirm this is acceptable when the function is compiled for the device.
26 template <typename RunFunction>
27 CK_TILE_HOST_DEVICE static auto
28 TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
29 {
30 if(has_hot_loop)
31 {
32 if(tail_number == ck_tile::TailNumber::Full)
33 {
34 return run_func(
37 }
38 else if(tail_number == ck_tile::TailNumber::Odd)
39 {
40 return run_func(
43 }
44 else if(tail_number == ck_tile::TailNumber::Even)
45 {
46 return run_func(
49 }
50 else
51 {
// Only Full, Odd and Even are supported when a hot loop is present.
52 throw std::runtime_error("Unsupported tail number for this operation !!!");
53 }
54 }
55 else
56 {
57 if(tail_number == ck_tile::TailNumber::Full)
58 {
59 return run_func(
62 }
63 else if(tail_number == ck_tile::TailNumber::Odd)
64 {
65 return run_func(
68 }
69 else if(tail_number == ck_tile::TailNumber::Even)
70 {
71 return run_func(
74 }
75 else
76 {
// Same restriction as above for the no-hot-loop case.
77 throw std::runtime_error("Unsupported tail number for this operation !!!");
78 }
79 }
80 }
81};
82
83template <typename Problem, typename Policy = GemmBQuantPipelineAgBgCrDefaultPolicy>
85{
88
95
96 static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
97 using I0 = number<0>;
98 using I1 = number<1>;
99 using I2 = number<2>;
100
101 static constexpr index_t APackedSize =
103 static constexpr index_t BPackedSize =
105
106 static constexpr index_t BQPackedSize =
108
113
115
// Block size as declared by Problem, and the block tile extents taken from
// BlockGemmShape (kM, kN, kK).
116 static constexpr index_t BlockSize = Problem::kBlockSize;
117 static constexpr index_t MPerBlock = BlockGemmShape::kM;
118 static constexpr index_t NPerBlock = BlockGemmShape::kN;
119 static constexpr index_t KPerBlock = BlockGemmShape::kK;
120
// Block tile extents divided by the quantization group size along N and K.
// This is integer division, so any remainder is silently dropped.
// NOTE(review): presumably kN / kK are multiples of QuantGroupSize::kN / kK - confirm.
121 static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
122 static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
123
// Compile-time accessors that forward to the Policy's function templates of the same
// name, instantiated with Problem. They add no logic of their own.
124 static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
125 static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
126 static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
// Same forwarding pattern for the BQ operand.
127 static constexpr index_t GetVectorSizeBQ()
128 {
129 return Policy::template GetVectorSizeBQ<Problem>();
130 }
131
// Forward to Policy::GetSmemPackA/B<Problem>(); Print() uses these values as the
// A/B LDS read and write widths.
132 static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
133 static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
134
// Compile-time settings re-exported unchanged from Problem.
// The padding flags are also embedded in the string produced by GetName().
135 static constexpr bool kPadM = Problem::kPadM;
136 static constexpr bool kPadN = Problem::kPadN;
137 static constexpr bool kPadK = Problem::kPadK;
138
139 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
140
// HasHotLoop and TailNum are the template arguments used by the operator() overload
// that takes no runtime has_hot_loop / tail_number; Scheduler selects the PipelineImpl.
141 static constexpr bool HasHotLoop = Problem::HasHotLoop;
142 static constexpr auto TailNum = Problem::TailNum;
143 static constexpr auto Scheduler = Problem::Scheduler;
144
146
// Builds a host-side identifier string for this pipeline instantiation from the
// literal "bquant_pipeline_AgBgCrCompV3", BlockSize, the warp counts along M/N, the
// WarpGemm kM/kN/kK values, the padding flags and QuantGroupSize::GetName().
// The pieces are combined with concat(); presumably '_' and 'x' act as separators.
// NOTE(review): this listing omits embedded source line 153 (one concat argument).
// NOTE(review): returning `const std::string` by value prevents callers from moving
// out of the result; a plain `std::string` return may be preferable.
147 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
148 {
149 // clang-format off
// Warp counts per block along M (index I0) and N (index I1).
150 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
151 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
152 return concat('_', "bquant_pipeline_AgBgCrCompV3",
154 BlockSize,
155 concat('x', WaveNumM, WaveNumN),
156 concat('x', BlockGemm::WarpGemm::kM, BlockGemm::WarpGemm::kN, BlockGemm::WarpGemm::kK),
157 concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
158 // clang-format on
159 }
160
162 {
// Forwards to the Policy's GetSmemSize function template instantiated with Problem.
// NOTE(review): the signature line (embedded source line 161) is missing from this listing.
163 return Policy::template GetSmemSize<Problem>();
164 }
165
// Returns a multi-line, human-readable summary of this pipeline's configuration:
// vector sizes, LDS widths, per-block instruction-count estimates, the quantization
// group name, KPack and PrefetchStages. Host only; all values are compile-time constants.
// NOTE(review): this listing omits the initialisers of A/B/BQ_Buffer_Load_Inst_Num
// (embedded source lines 183, 185, 187).
166 CK_TILE_HOST static std::string Print()
167 {
168 constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
169 constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
170 constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
171
// NOTE(review): WaveSize is hard-coded to 64 here; confirm for targets with another wave size.
172 constexpr index_t WaveSize = 64;
173 constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
174 constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
175
// Read and write widths come from the same GetSmemPackA/B values, so they are equal.
176 constexpr index_t A_LDS_Read_Width = GetSmemPackA();
177 constexpr index_t B_LDS_Read_Width = GetSmemPackB();
178
179 constexpr index_t A_LDS_Write_Width = GetSmemPackA();
180 constexpr index_t B_LDS_Write_Width = GetSmemPackB();
181
182 constexpr index_t A_Buffer_Load_Inst_Num =
184 constexpr index_t B_Buffer_Load_Inst_Num =
186 constexpr index_t BQ_Buffer_Load_Inst_Num =
188
// Write counts: tile elements divided by (BlockSize * write width).
189 constexpr index_t A_LDS_Write_Inst_Num =
190 MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
191 constexpr index_t B_LDS_Write_Inst_Num =
192 NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
193
// Read counts use the same formula, additionally scaled by the warp count along the
// other dimension (WaveNumN for A, WaveNumM for B).
194 constexpr index_t A_LDS_Read_Inst_Num =
195 WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
196 constexpr index_t B_LDS_Read_Inst_Num =
197 WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
198
// Block tile volume divided by the number of waves (BlockSize / WaveSize) and by the
// per-instruction volume MPerXDL * NPerXDL * KPerXDL.
199 constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
200 (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
201
202 auto str = std::stringstream{};
203
// NOTE(review): the "A/B LDS read/write width" label prints only the two read widths;
// the write widths are equal to them (see above) and are not printed separately.
204 str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", "
205 << "BQ vector size: " << GetVectorSizeBQ() << "\n"
206 << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
207 << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
208 << ", " << "BQ buffer load inst: " << BQ_Buffer_Load_Inst_Num << "\n"
209 << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
210 << "\n"
211 << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
212 << "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
213 << "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
214 << "KPack: " << BlockGemm::Traits::KPack << "\n"
215 << "PrefetchStages: " << PrefetchStages << "\n";
216 return str.str();
217 }
218
// Primary template, parameterised on the scheduler kind; its body is empty.
// An explicit specialization (`template <>`) follows immediately below.
// NOTE(review): the struct declaration line (embedded source line 220) is missing
// from this listing.
219 template <GemmPipelineScheduler Scheduler>
221 {
222 };
223
224 template <>
226 {
228
// Pipeline body. Zero-initialises a C block tile, accumulates block_gemm results into
// it while stepping the A/B DRAM windows by KPerBlock and the BQ window by KPerBlockBQ,
// and returns the C block tile.
//
// a_dram_block_window_tmp / b_dram_block_window_tmp / bq_dram_block_window_tmp
//     input windows; their DataType and window lengths are checked by static_assert.
// a_element_func / b_element_func
//     forwarded to every Base::LocalPrefill call for A / B.
// num_loop   bounds the hot loop: it repeats while i < num_loop - tail_count.
// p_smem     passed to Base::GetABLdsTensorViews to obtain the A/B LDS views.
//
// NOTE(review): this listing omits several source lines (embedded 230, 247, 249, 324,
// 331, 343, 356, 369, 373, 384, 405, 426, 432, 443, 452). Among them are the second
// template parameter (callers pass <HasHotLoop, TailNum>), the declarations of
// a_shuffle_tmp / b_shuffle_tmp, and statements between prefetch stages.
229 template <bool HasHotLoop,
231 typename ADramBlockWindowTmp,
232 typename BDramBlockWindowTmp,
233 typename BQDramBlockWindowTmp,
234 typename AElementFunction,
235 typename BElementFunction>
236 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
237 const AElementFunction& a_element_func,
238 const BDramBlockWindowTmp& b_dram_block_window_tmp,
239 const BElementFunction& b_element_func,
240 const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
241 index_t num_loop,
242 void* p_smem) const
243 {
// The element type of each input window must match the Problem's data types.
244 static_assert(
245 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
246 std::is_same_v<BDataType,
248 std::is_same_v<BQDataType,
250 "A/B/BQ Dram block window should have the same data type as appropriate "
251 "([A|B|BQ]DataType) defined in Problem definition!");
252
253 constexpr bool is_a_col_major =
254 std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
255 constexpr bool is_bq_col_major =
256 std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
257 constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
258
// BQ must be column-major, and its window must be KPerBlockBQ x NPerBlockBQ.
259 static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
260 static_assert(KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
261 NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
262 "Bq block window has incorrect lengths for defined BqLayout!");
263
// The expected A/B window lengths depend on layout: K comes first for column-major A
// and for row-major B, otherwise M (or N) comes first.
264 static_assert(is_a_col_major
265 ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
266 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
267 : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
268 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
269 "A block window has incorrect lengths for defined ALayout!");
270 static_assert(is_b_row_major
271 ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
272 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
273 : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
274 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
275 "B block window has incorrect lengths for defined BLayout!");
276
277 using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
278 using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
279 using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
280
281 auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
282
283 constexpr auto a_lds_load_tile_distr =
284 make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
285 constexpr auto b_lds_load_tile_distr =
286 make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
287
// For A and B: a DRAM copy window, an LDS copy (store) window and an LDS window read
// by the block GEMM. BQ only gets a DRAM load window.
288 auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
289 Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
290 auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
291 Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
292 auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp);
293
294 using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
295 using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
296 using BQBlockTileDistr = decltype(bq_copy_dram_window.get_tile_distribution());
297
298 using ABlockTile =
299 decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
300 using BBlockTile =
301 decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
302 using BQBlockTile =
303 decltype(make_static_distributed_tensor<BQDataType>(BQBlockTileDistr{}));
304
305 auto block_gemm = BlockGemm();
306
307 ABlockTile a_block_tile;
308 BBlockTile b_block_tile;
// Two BQ tiles used alternately: block_gemm consumes bq_block_tile[currIdx] while the
// next tile is prefetched into bq_block_tile[(currIdx + 1) % 2].
309 BQBlockTile bq_block_tile[2];
310 int currIdx = 0;
311
312 auto c_block_tile = block_gemm.MakeCBlockTile();
313
// Per-prefetch window steps. The non-zero component is placed according to layout.
// Because of the static_assert above, is_bq_col_major is always true here, so only
// the first arm of the BQ ternary is ever selected.
314 constexpr ADramTileWindowStep a_dram_tile_window_step =
315 is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
316 constexpr BDramTileWindowStep b_dram_tile_window_step =
317 is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
318 constexpr BQDramTileWindowStep bq_dram_tile_window_step =
319 is_bq_col_major ? make_array(KPerBlockBQ, 0) : make_array(0, KPerBlockBQ);
320
321 // DRAM prefetch (global read 0)
322 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
323 Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
325 bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step);
326
// Clear the accumulator before the first block_gemm call.
327 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
328
// Prologue prefill: A is transposed through a_shuffle_tmp first when it is column-major.
// NOTE(review): this prologue names Policy::make_shuffled_2d_static_tile_distribution,
// whereas the hot loop and tail use MakeShuffledARegTileDistribution /
// MakeShuffledBRegTileDistribution - confirm the difference is intended.
329 if constexpr(is_a_col_major)
330 {
332 Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
333 transpose_tile2d(a_shuffle_tmp, a_block_tile);
334 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
335 }
336 else
337 {
338 Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
339 }
340
// Same for B, which is transposed when it is row-major.
341 if constexpr(is_b_row_major)
342 {
344 Policy::template make_shuffled_2d_static_tile_distribution<Problem>());
345 transpose_tile2d(b_shuffle_tmp, b_block_tile);
346 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
347 }
348 else
349 {
350 Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
351 }
352
// Second global prefetch of A and B (BQ is not prefetched again here).
353 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
354 Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
355
357
358 block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
359
360 __builtin_amdgcn_sched_barrier(0);
361
362 if constexpr(HasHotLoop)
363 {
// Iterations left for the tail section: 1 for Full/Odd, otherwise 2.
364 constexpr index_t tail_count =
365 ((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) ? 1 : 2;
366 index_t i = 0;
// do/while: the body runs at least once, even if num_loop - tail_count <= 0.
367 do
368 {
370
371 if constexpr(is_a_col_major)
372 {
374 Policy::template MakeShuffledARegTileDistribution<Problem>());
375 transpose_tile2d(a_shuffle_tmp, a_block_tile);
376 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
377 }
378 else
379 {
380 Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
381 }
382 if constexpr(is_b_row_major)
383 {
385 Policy::template MakeShuffledBRegTileDistribution<Problem>());
386 transpose_tile2d(b_shuffle_tmp, b_block_tile);
387 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
388 }
389 else
390 {
391 Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
392 }
393
// Prefetch the next A/B tiles, and the next BQ tile into the slot not in use.
394 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
395 Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);
396 Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
397 bq_copy_dram_window,
398 bq_dram_tile_window_step);
399
400 block_gemm(
401 c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
402
// Switch to the BQ tile that was just prefetched.
403 currIdx = (currIdx + 1) % 2;
404
406
407 block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
408 __builtin_amdgcn_sched_barrier(0);
409
410 i += 1;
411 } while(i < (num_loop - tail_count));
412 }
413 // tail
// Full/Odd: one final block_gemm with the current BQ tile.
414 if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
415 {
416 block_gemm(
417 c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
418 }
419 else
420 {
// Otherwise two block_gemm calls: prefetch one more BQ tile, run block_gemm with the
// current one, switch slots, prefill and LocalPrefetch again, then run the last one.
421 Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2],
422 bq_copy_dram_window,
423 bq_dram_tile_window_step);
424 block_gemm(
425 c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
427
428 currIdx = (currIdx + 1) % 2;
429
430 if constexpr(is_a_col_major)
431 {
433 Policy::template MakeShuffledARegTileDistribution<Problem>());
434 transpose_tile2d(a_shuffle_tmp, a_block_tile);
435 Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
436 }
437 else
438 {
439 Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
440 }
441 if constexpr(is_b_row_major)
442 {
444 Policy::template MakeShuffledBRegTileDistribution<Problem>());
445 transpose_tile2d(b_shuffle_tmp, b_block_tile);
446 Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
447 }
448 else
449 {
450 Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
451 }
453 block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
454 block_gemm(
455 c_block_tile, bq_block_tile[currIdx], a_lds_gemm_window, b_lds_gemm_window);
456 }
457 return c_block_tile;
458 }
459 };
// Compile-time-configured entry point. Runs PipelineImpl<Scheduler> with this struct's
// HasHotLoop and TailNum constants (both taken from Problem) as template arguments.
// A and B elements pass through identity lambdas, so no element-wise transform is
// applied to them. Returns whatever the PipelineImpl call operator returns.
460 template <typename ADramBlockWindowTmp,
461 typename BDramBlockWindowTmp,
462 typename BQDramBlockWindowTmp>
463 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
464 const BDramBlockWindowTmp& b_dram_block_window_tmp,
465 const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
466 index_t num_loop,
467 void* p_smem) const
468 {
469 return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
470 a_dram_block_window_tmp,
// Identity element function for A.
471 [](const ADataType& a) { return a; },
472 b_dram_block_window_tmp,
// Identity element function for B.
473 [](const BDataType& b) { return b; },
474 bq_dram_block_window_tmp,
475 num_loop,
476 p_smem);
477 }
478
// Runtime pipeline dispatch operator for grouped GEMM kernels.
// Unlike the overload without has_hot_loop / tail_number, the hot-loop flag and tail
// number arrive as runtime values. Base::TailHandler turns them into compile-time
// constants and calls RunPipeline, which instantiates the matching PipelineImpl
// call operator.
//
// has_hot_loop / tail_number  runtime values forwarded to Base::TailHandler.
// Other parameters            forwarded unchanged to the PipelineImpl call operator.
//
// NOTE(review): per this page's cross-reference index, `Base` at this scope appears to
// alias BaseGemmPipelineAgBgCrCompV3<Problem>, so this may call that class's
// TailHandler rather than the one defined at the top of this file - confirm.
496 template <typename ADramBlockWindowTmp,
497 typename BDramBlockWindowTmp,
498 typename BQDramBlockWindowTmp>
499 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
500 const BDramBlockWindowTmp& b_dram_block_window_tmp,
501 const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
502 index_t num_loop,
503 bool has_hot_loop,
504 TailNumber tail_number,
505 void* p_smem) const
506 {
// has_hot_loop_ / tail_number_ are constant wrappers; their `.value` members are read
// as constexpr and used as template arguments. The windows, num_loop and p_smem are
// captured by reference.
507 const auto RunPipeline = [&](auto has_hot_loop_, auto tail_number_) {
508 constexpr bool hot_loop = has_hot_loop_.value;
509 constexpr auto tail_num = tail_number_.value;
510 return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
511 a_dram_block_window_tmp,
512 [](const ADataType& a) { return a; },
513 b_dram_block_window_tmp,
514 [](const BDataType& b) { return b; },
515 bq_dram_block_window_tmp,
516 num_loop,
517 p_smem);
518 };
519 return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
520 }
521};
522
523} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
@ Full
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:39
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_DEVICE void transpose_tile2d(OutTensor &out, const InTensor &in)
Definition transpose_tile.hpp:195
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
CK_TILE_HOST_DEVICE constexpr details::return_type< D, Ts... > make_array(Ts &&... ts)
Definition tile/core/container/array.hpp:242
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BDramBlockWindowTmp &b_dram_block_window_tmp, const BElementFunction &b_element_func, const BQDramBlockWindowTmp &bq_dram_block_window_tmp, index_t num_loop, void *p_smem) const
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:236
PipelineImplBase Base
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:227
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:221
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:85
static constexpr auto TailNum
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:142
static constexpr index_t GetSmemPackB()
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:133
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:93
static constexpr index_t GetVectorSizeC()
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:126
remove_cvref_t< typename Problem::BQLayout > BQLayout
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:110
remove_cvref_t< typename Problem::CDataType > CDataType
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:92
remove_cvref_t< typename Problem::ALayout > ALayout
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:109
number< 1 > I1
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:98
static constexpr bool kPadM
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:135
remove_cvref_t< typename Problem::BQDataType > BQDataType
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:91
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:161
static constexpr index_t GetVectorSizeBQ()
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:127
static CK_TILE_HOST const std::string GetName()
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:147
static constexpr index_t GetSmemPackA()
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:132
remove_cvref_t< typename Problem::QuantGroupSize > QuantGroupSize
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:94
static constexpr index_t GetVectorSizeB()
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:125
GemmBQuantPipelineAgBgCrImplBase< Problem, Policy > PipelineImplBase
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:87
remove_cvref_t< typename Problem::BLayout > BLayout
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:111
static constexpr index_t BQPackedSize
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:106
remove_cvref_t< typename Problem::BDataType > BDataType
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:90
BaseGemmPipelineAgBgCrCompV3< Problem > Base
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:86
number< 0 > I0
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:97
remove_cvref_t< decltype(Policy::template GetBlockGemm< Problem >())> BlockGemm
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:114
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, const BQDramBlockWindowTmp &bq_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, void *p_smem) const
Runtime pipeline dispatch operator for grouped GEMM kernels.
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:499
static constexpr bool kPadN
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:136
number< 2 > I2
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:99
static CK_TILE_HOST std::string Print()
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:166
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BDramBlockWindowTmp &b_dram_block_window_tmp, const BQDramBlockWindowTmp &bq_dram_block_window_tmp, index_t num_loop, void *p_smem) const
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:463
static constexpr bool kPadK
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:137
remove_cvref_t< typename Problem::ADataType > ADataType
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:89
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:19
static constexpr index_t BPackedSize
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:103
static constexpr bool HasHotLoop
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:141
static constexpr index_t NPerBlock
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:118
remove_cvref_t< typename Problem::CLayout > CLayout
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:112
static constexpr auto Scheduler
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:143
static constexpr bool DoubleSmemBuffer
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:139
static constexpr index_t KPerBlockBQ
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:122
static constexpr index_t NPerBlockBQ
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:121
static constexpr index_t GetVectorSizeA()
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:124
static constexpr index_t BlockSize
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:116
static constexpr index_t KPerBlock
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:119
static constexpr index_t APackedSize
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:101
static constexpr index_t MPerBlock
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:117
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:25
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool has_hot_loop, TailNumber tail_number)
Definition gemm_bquant_pipeline_ag_bg_cr_v3.hpp:28
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:18
static constexpr index_t PrefetchStages
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:19
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool has_hot_loop, TailNumber tail_number)
Definition gemm_pipeline_ag_bg_cr_comp_v3.hpp:50
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:14
static constexpr index_t KPerBlock
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:27
static constexpr index_t KPerBlockBQ
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:30
static constexpr index_t NPerBlockBQ
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:29
static constexpr index_t NPerBlock
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:26
CK_TILE_DEVICE constexpr auto GetBQDramLoadWindow(const BQDramBlockWindowTmp &bq_dram_block_window_tmp) const
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:43
typename Base::BDataType BDataType
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:18
static constexpr index_t MPerBlock
Definition gemm_bquant_pipeline_ag_bg_cr_base.hpp:25
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp &b_dram_block_window_tmp, const BLdsTensorView &b_lds_block_view, const BLdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:225
CK_TILE_DEVICE auto GetABLdsTensorViews(void *p_smem) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:83
CK_TILE_DEVICE void LocalPrefill(DstTileWindow &lds_tile_window, const SrcBlockTile &src_block_tile, const ElementFunction &element_func) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:57
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp &a_dram_block_window_tmp, const ALdsTensorView &a_lds_block_view, const ALdsLoadTileDistr &, const array< index_t, 2 > &offset={0, 0}) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:190
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile &dst_block_tile, SrcTileWindow &dram_tile_window, const DramTileWindowStep &dram_tile_window_step) const
Definition gemm_pipeline_ag_bg_cr_base.hpp:39
Definition tile/core/numeric/integral_constant.hpp:30
Definition tile/core/numeric/numeric.hpp:81