gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp Source File

gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp Source File#

Composable Kernel: gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp Source File
gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
16
18
19#define DEBUG_LOG 0
20
21namespace ck {
22
// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
// pipelines in the same kernel function. Blockers:
// 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses
//    operate on two distinct LDS chunks.
// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. AB and C may
//    not use the same LDS buffer when we declare __shared__ inside blkgemmpipe.
29template <typename GridwiseGemm,
30 bool HasMainKBlockLoop,
31 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
32 index_t MinimumOccupancy = 1,
34__global__ void
35#if CK_USE_LAUNCH_BOUNDS
36__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
37#endif
38 // __attribute__((amdgpu_waves_per_eu(1, 1)))
40 typename GridwiseGemm::Argument karg)
41{
42#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
43 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
44 {
45 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
46
47 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
48 karg.p_a_grid,
49 karg.p_b_grid,
50 karg.p_ds_grid,
51 karg.p_c_grid,
52 karg.p_a_scale_grid,
53 karg.p_b_scale_grid,
54 p_shared,
55 karg,
56 karg.a_element_op,
57 karg.b_element_op,
58 karg.c_element_op);
59 }
60#else
61 ignore = karg;
62#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
63}
64
65template <typename GridwiseGemm,
66 bool HasMainKBlockLoop,
67 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
68 index_t MinimumOccupancy = 1,
70__global__ void
71#if CK_USE_LAUNCH_BOUNDS
72__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
73#endif
74 // __attribute__((amdgpu_waves_per_eu(1, 1)))
76 typename GridwiseGemm::Argument karg)
77{
78#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
79 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
80 {
81 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
82 __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
83
84 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
85 karg.p_a_grid,
86 karg.p_b_grid,
87 karg.p_ds_grid,
88 karg.p_c_grid,
89 karg.p_a_scale_grid,
90 karg.p_b_scale_grid,
91 p_shared,
92 p_shared1,
93 karg,
94 karg.a_element_op,
95 karg.b_element_op,
96 karg.c_element_op);
97 }
98#else
99 ignore = karg;
100#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
101}
102
103template <typename ALayout,
104 typename BLayout,
105 typename DsLayout,
106 typename CLayout,
107 typename ADataType,
108 typename BDataType,
109 typename AccDataType,
110 typename CShuffleDataType,
111 typename DsDataType,
112 typename CDataType,
113 typename AElementwiseOperation,
114 typename BElementwiseOperation,
115 typename CElementwiseOperation,
117 index_t BlockSize,
118 index_t ScaleBlockM,
119 index_t ScaleBlockN,
120 index_t ScaleBlockK,
121 index_t MPerBlock,
122 index_t NPerBlock,
123 index_t KPerBlock,
124 index_t AK1Value,
125 index_t BK1Value,
126 index_t MPerXdl,
127 index_t NPerXdl,
128 index_t MXdlPerWave,
129 index_t NXdlPerWave,
130 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
131 typename ABlockTransferThreadClusterArrangeOrder,
132 typename ABlockTransferSrcAccessOrder,
133 index_t ABlockTransferSrcVectorDim,
134 index_t ABlockTransferSrcScalarPerVector,
135 index_t ABlockTransferDstScalarPerVector_AK1,
136 bool AThreadTransferSrcResetCoordinateAfterRun,
137 index_t ABlockLdsExtraM,
138 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
139 typename BBlockTransferThreadClusterArrangeOrder,
140 typename BBlockTransferSrcAccessOrder,
141 index_t BBlockTransferSrcVectorDim,
142 index_t BBlockTransferSrcScalarPerVector,
143 index_t BBlockTransferDstScalarPerVector_BK1,
144 bool BThreadTransferSrcResetCoordinateAfterRun,
145 index_t BBlockLdsExtraN,
146 index_t CShuffleMXdlPerWavePerShuffle,
147 index_t CShuffleNXdlPerWavePerShuffle,
148 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
149 typename CDEShuffleBlockTransferScalarPerVectors,
152 typename ComputeTypeA = CDataType,
153 typename ComputeTypeB = ComputeTypeA,
154 typename LDSTypeA = ADataType,
155 typename LDSTypeB = BDataType>
157{
158 using AScaleType = float;
159 using BScaleType = float;
160
161 static constexpr auto I0 = Number<0>{};
162 static constexpr auto I1 = Number<1>{};
163 static constexpr auto I2 = Number<2>{};
164 static constexpr auto I3 = Number<3>{};
165 static constexpr auto I4 = Number<4>{};
166 static constexpr auto I5 = Number<5>{};
167 static constexpr auto I6 = Number<6>{};
168 static constexpr auto I7 = Number<7>{};
169
171 CDEShuffleBlockTransferScalarPerVectors{}[I0];
172 // K1 should be Number<...>
173 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
174 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
175 static constexpr auto AK1Number = Number<AK1Value>{};
176 static constexpr auto BK1Number = Number<BK1Value>{};
177 static constexpr auto BlockSizeNumber = Number<BlockSize>{};
178
179 static constexpr index_t NumDTensor = DsDataType::Size();
180
182 static constexpr index_t KPack =
184 static constexpr index_t KGroup = []() {
            // On gfx950 there is an MFMA instruction that takes 32 f8 elements as input,
            // split into 2 groups of 16 f8 elements.
            // The 2 groups are not contiguous in the B preshuffled layout,
            // and we do not want them to be contiguous there,
            // because a single memory instruction can only read 16 f8 elements at a time.
191 return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
192 else
193 return 1;
194 }();
195 static constexpr index_t KLane =
197 static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup);
198 static constexpr index_t NLane = NPerXdl;
199 static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
200
201 static constexpr auto MakeDsGridPointer()
202 {
203 return generate_tuple(
204 [&](auto i) {
205 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
206
207 return static_cast<const DDataType*>(nullptr);
208 },
210 }
211
212 using DsGridPointer = decltype(MakeDsGridPointer());
213
215
216 __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
217 {
218 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
219 }
220
221 __host__ __device__ static auto CalculateMPadded(index_t M)
222 {
223 return math::integer_least_multiple(M, MPerBlock);
224 }
225
226 __host__ __device__ static auto CalculateNPadded(index_t N)
227 {
228 return math::integer_least_multiple(N, NPerBlock);
229 }
230
231 __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
232 {
234 }
235 __host__ __device__ static auto CalculateBK0Shuffled(index_t K)
236 {
238 }
239
240 __host__ __device__ static auto CalculateKPadded(index_t K)
241 {
242 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
243 }
244
245 __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
246 {
247 auto K_t = K_Batch * KPerBlock;
248 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
249 }
250
251 __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
252 {
253 auto K_t = K_Batch * KPerBlock;
254 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
255 }
256
257 __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
258 {
259 auto K_t = K_Batch * KPerBlock;
260 return (K + K_t - 1) / K_t * KPerBlock;
261 }
262
263 __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
264 {
265 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
266 auto K_t = K_Batch * KReadVec;
267 return (K + K_t - 1) / K_t * KReadVec;
268 }
269
270 __host__ __device__ static auto CalculateMBlock(index_t M)
271 {
272 return math::integer_divide_ceil(M, MPerBlock);
273 }
274
275 __host__ __device__ static auto CalculateNBlock(index_t N)
276 {
277 return math::integer_divide_ceil(N, NPerBlock);
278 }
279
280 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
281 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
282 {
283 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
284 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
285
287 TileDesc_K0_MN_K1{},
293 }
294
295 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
296 index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
297 {
298 const auto a_grid_desc_mraw_kraw = [&]() {
300 {
301 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
302 }
304 {
305 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
306 }
307 }();
308
309 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
310
311 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
312 GemmSpec == GemmSpecialization::MNKPadding)
313 {
314 // pad both M and K
315 const auto a_grid_desc_m_k =
316 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
318 make_right_pad_transform(K, KPad - K)),
321
322 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
323 a_grid_desc_m_k,
328
329 return a_grid_desc_ak0_m_ak1;
330 }
331 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
332 GemmSpec == GemmSpecialization::MNPadding)
333 {
334 // pad M, but not K
335 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
336 a_grid_desc_mraw_kraw,
338 make_right_pad_transform(M, MPad - M)),
341
342 return a_grid_desc_ak0_m_ak1;
343 }
344 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
345 GemmSpec == GemmSpecialization::NKPadding)
346 {
347 // pad K, but not M
348 const auto a_grid_desc_m_k = transform_tensor_descriptor(
349 a_grid_desc_mraw_kraw,
353
354 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
355 a_grid_desc_m_k,
360
361 return a_grid_desc_ak0_m_ak1;
362 }
363 else
364 {
365 // not pad M or K
366 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
367 a_grid_desc_mraw_kraw,
372
373 return a_grid_desc_ak0_m_ak1;
374 }
375 }
376
377 __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
378 {
379 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
380 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
381 constexpr index_t NkSwizzleNumber = Number<WaveSize * KPack / KGroup>{};
383 make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber),
384 make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
385 }
386
387 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
388 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
389 {
390 const auto b_grid_desc_nraw_kraw = [&]() {
392 {
393 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
394 }
396 {
397 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
398 }
399 }();
400
401 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
402
403 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
404 GemmSpec == GemmSpecialization::MNKPadding)
405 {
406 // pad both N and K
407 const auto b_grid_desc_n_k =
408 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
410 make_right_pad_transform(K, KPad - K)),
413
414 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
415 b_grid_desc_n_k,
420
421 return b_grid_desc_bk0_n_bk1;
422 }
423 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
424 GemmSpec == GemmSpecialization::MNPadding)
425 {
426 // pad N, but not K
427 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
428 b_grid_desc_nraw_kraw,
430 make_right_pad_transform(N, NPad - N)),
433
434 return b_grid_desc_bk0_n_bk1;
435 }
436 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
437 GemmSpec == GemmSpecialization::MKPadding)
438 {
439 // pad K, but not N
440 const auto b_grid_desc_n_k = transform_tensor_descriptor(
441 b_grid_desc_nraw_kraw,
445
446 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
447 b_grid_desc_n_k,
452
453 return b_grid_desc_bk0_n_bk1;
454 }
455 else
456 {
457 // not pad N or K
458 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
459 b_grid_desc_nraw_kraw,
464
465 return b_grid_desc_bk0_n_bk1;
466 }
467 }
468
469 template <typename ABlockDesc_AK0_M_AK1>
470 __host__ __device__ static constexpr auto
471 MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
472 {
473 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
474
475 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
476 }
477
478 template <typename BBlockDesc_BK0_N_BK1>
479 __host__ __device__ static constexpr auto
480 MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
481 {
482 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
483 }
484
485 template <typename ELayout>
486 __host__ __device__ static auto
488 {
489 const auto c_grid_desc_mraw_nraw = [&]() {
491 {
492 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
493 }
495 {
496 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
497 }
498 }();
499
500 // pad M and N
501 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
503 make_right_pad_transform(N, NPad - N)),
506 }
507
508 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
509 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
510 {
511 return generate_tuple(
512 [&](auto i) {
513 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
514 return MakeCGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
515 },
517 }
518
519 template <typename DsGridDesc>
521 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
522 {
523 return generate_tuple(
524 [&](auto i) {
526 ds_grid_desc_m_n[i], MBlock, NBlock);
527 },
529 }
530
531 using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))>;
532
533 struct Problem
534 {
535 __host__ __device__ Problem(index_t M_,
536 index_t N_,
537 index_t K_,
538 index_t StrideA_,
539 index_t StrideB_,
540 std::array<index_t, NumDTensor> StrideDs_,
541 index_t StrideC_,
542 index_t KBatch_)
543 : M{M_},
544 N{N_},
545 K{K_},
546 StrideA{StrideA_},
547 StrideB{StrideB_},
548 StrideDs{StrideDs_},
549 StrideC{StrideC_},
550 KBatch{KBatch_},
553 KRead{CalculateKRead(K_, KBatch_)},
554 KPadded{CalculateKPadded(K_, KBatch_)},
555 AK0{CalculateAK0Padded(K_, KBatch_)},
556 BK0{CalculateBK0Padded(K_, KBatch_)},
559 {
560 }
561
562 __host__ void Print() const
563 {
564 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
565 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
566 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
567 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
568 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
569 << "NBlock: " << NBlock << "}" << std::endl;
570 }
571
577 std::array<index_t, NumDTensor> StrideDs;
588 };
589
590 // Argument
592 {
593 __host__ Argument(const ADataType* p_a_grid_,
594 const BDataType* p_b_grid_,
595 std::array<const void*, NumDTensor> p_ds_grid_,
596 CDataType* p_c_grid_,
597 index_t M_,
598 index_t N_,
599 index_t K_,
600 index_t StrideA_,
601 index_t StrideB_,
602 std::array<index_t, NumDTensor> StrideDs_,
603 index_t StrideC_,
604 const AScaleType* p_a_scale_grid_,
605 const BScaleType* p_b_scale_grid_,
606 index_t k_batch_,
607 AElementwiseOperation a_element_op_,
608 BElementwiseOperation b_element_op_,
609 CElementwiseOperation c_element_op_)
610 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
611 p_a_grid{p_a_grid_},
612 p_b_grid{p_b_grid_},
613 p_ds_grid{},
614 p_c_grid{p_c_grid_},
615 p_a_scale_grid{p_a_scale_grid_},
616 p_b_scale_grid{p_b_scale_grid_},
617 a_element_op{a_element_op_},
618 b_element_op{b_element_op_},
619 c_element_op{c_element_op_}
620 {
621
622 // populate pointer, desc for Ds
623 static_for<0, NumDTensor, 1>{}([&](auto i) {
624 using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
625
626 // D pointer
627 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
628 });
629 }
630
631 const ADataType* p_a_grid;
632 const BDataType* p_b_grid;
634 CDataType* p_c_grid;
635
638
639 const AElementwiseOperation a_element_op;
640 const BElementwiseOperation b_element_op;
641 const CElementwiseOperation c_element_op;
642 };
643
645 {
646 __device__ SplitKBatchOffset(Argument& karg)
647 {
649 {
650 a_k_split_offset = blockIdx.z * karg.KRead;
651 }
653 {
654 a_k_split_offset = blockIdx.z * karg.KRead * karg.M;
655 }
656
658 {
659 b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
660 }
662 {
663 b_k_split_offset = blockIdx.z * karg.KRead;
664 }
665
666 if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
667 {
668 karg.K = karg.KRead;
669 }
670 else
671 {
672 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
673 }
674 }
675
678 };
679
680 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
681 {
682 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
683 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
684 // A matrix in LDS memory, dst of blockwise copy
685 if constexpr(ABlockLdsExtraM)
686 {
690 }
691 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
692 // in some cases.
694 {
695 constexpr auto a_lds_block_desc =
698
699 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
700 a_lds_block_desc,
706
707 return a_lds_block_desc_permuted;
708 }
709 else // ColumnMajor A
710 {
711 // kfold and mpair dimension is not always required.
712 // more dimension in merge_transform increase the difficulty of generating immarg offset
713 // for compiler.
714 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
715 constexpr auto M1 = MPerBlock / M0;
716
717 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
718 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
719 constexpr auto KThreadRead = WaveSize / MPerXdl;
720 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
721
722 constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
723 ? 1
724 : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
725 constexpr auto KThreadReadPerm =
726 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
727 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
728 : KThreadRead;
729
730 // 1<=mpair<=n0
731 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
732 ? 1
733 : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
734 ? M0
735 : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
736
737 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
741 Number<kfold * M0 / mpair>{},
743 AK1Number));
744
745 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
746 a_lds_block_desc,
751 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
758
759 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
760 a_lds_block_desc_permuted,
769 Sequence<1>{},
770 Sequence<2>{},
771 Sequence<3>{},
772 Sequence<4>{},
773 Sequence<5>{}),
775 Sequence<2>{},
778 Sequence<6>{},
779 Sequence<7>{}));
780
781 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
782 a_lds_block_desc_unmerged,
785 Number<KThreadWrite / kfold / KThreadReadPerm>{},
793
794 return a_lds_block_desc_ak0_m_ak1;
795 }
796 }
797
798 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
799 {
800 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
803 }
804
806 {
807 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
808
809 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
813 I1,
815
816 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
817 }
818
821 BlkGemmPipelineVer,
822 BlkGemmPipeSched,
823 BlockSize,
824 LDSTypeA,
825 LDSTypeB,
826 ComputeTypeA,
827 AccDataType,
834 ABlockTransferSrcScalarPerVector,
835 BBlockTransferSrcScalarPerVector,
836 MPerBlock,
837 NPerBlock,
838 KPerBlock,
839 ScaleBlockM,
840 ScaleBlockN,
841 ScaleBlockK,
842 MPerXdl,
843 NPerXdl,
844 MXdlPerWave,
845 NXdlPerWave,
846 KPack>())>;
847
848 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
849 {
850 // LDS allocation for A and B: be careful of alignment
851 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
852 // lds max alignment
853 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
854
855 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
856 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
857
858 // LDS allocation for C shuffle in LDS
859 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
861
862 constexpr auto c_block_size =
863 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
864
865 return math::max(a_block_space_size_aligned * sizeof(LDSTypeA),
866 c_block_size * sizeof(CShuffleDataType));
867 }
868
870
871 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
872 __host__ static constexpr bool CheckValidity(const Argument& karg)
873 {
874 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
875 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
876 "Invalid tuning param!");
877
883 {
884 if(!(karg.M % MPerBlock == 0))
885 {
886#if DEBUG_LOG
887 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
888 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
889 << std::endl;
890
891#endif // DEBUG_LOG
892 return false;
893 }
894 }
895
901 {
902 if(!(karg.N % NPerBlock == 0))
903 {
904#if DEBUG_LOG
905 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
906 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
907 << std::endl;
908
909#endif // DEBUG_LOG
910 return false;
911 }
912 }
913
918 {
919
920 auto K_t = karg.KBatch * KPerBlock;
921 if(!(karg.K % K_t == 0))
922 {
923#if DEBUG_LOG
924 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
925 << karg.K << " " << __FILE__ << ":" << __LINE__
926 << ", in function: " << __func__ << std::endl;
927
928#endif // DEBUG_LOG
929 return false;
930 }
931 }
932 else
933 {
934 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
935 auto K_t = karg.KBatch * KReadVec;
936 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
937 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
938 {
939 return false;
940 }
941 }
942
944 {
945 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
946 {
947#if DEBUG_LOG
948 std::cout << "Arg K (" << karg.K
949 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
950 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
951 << __LINE__ << ", in function: " << __func__ << std::endl;
952
953#endif // DEBUG_LOG
954 return false;
955 }
956 }
957 else
958 {
959 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
960 {
961#if DEBUG_LOG
962 std::cout << "Arg M (" << karg.M
963 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
964 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
965 << __LINE__ << ", in function: " << __func__ << std::endl;
966
967#endif // DEBUG_LOG
968 return false;
969 }
970 }
971
973 {
974 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
975 {
976#if DEBUG_LOG
977 std::cout << "Arg N (" << karg.N
978 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
979 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
980 << __LINE__ << ", in function: " << __func__ << std::endl;
981
982#endif // DEBUG_LOG
983 return false;
984 }
985 }
986 else
987 {
988 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
989 {
990#if DEBUG_LOG
991 std::cout << "Arg K (" << karg.K
992 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
993 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
994 << __LINE__ << ", in function: " << __func__ << std::endl;
995
996#endif // DEBUG_LOG
997 return false;
998 }
999 }
1000
1002 {
1004 {
1005#if DEBUG_LOG
1006 std::cout << "Arg N (" << karg.N
1007 << ") value is not a multiple of "
1008 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1009 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1010 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1011
1012#endif // DEBUG_LOG
1013 return false;
1014 }
1015 }
1016 else
1017 {
1019 {
1020#if DEBUG_LOG
1021 std::cout << "Arg M (" << karg.M
1022 << ") value is not a multiple of "
1023 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1024 << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
1025 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1026
1027#endif // DEBUG_LOG
1028 return false;
1029 }
1030 }
1031
1032 // check gridwise gemm pipeline
1033 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1034
1035 if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1036 {
1037 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1038 {
1039 return false;
1040 }
1041 }
1042
1043 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1044 return true;
1045 }
1046
1047 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1048 {
1049 const index_t num_loop = K / KPerBlock;
1050
1051 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1052 }
1053
1054 __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1055 {
1056 const index_t num_loop = K / KPerBlock;
1057
1058 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1059 }
1060
1061 template <typename CGridDesc>
1063 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1064 {
1065 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1066 c_grid_desc_m_n,
1071
1072 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1073 }
1074
1075 // return block_id to C matrix tile idx (m0, n0) mapping
1076 // if arch = gfx942
1078
1079 template <bool HasMainKBlockLoop,
1080 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1081 TailNumber TailNum = TailNumber::Odd>
1082 __device__ static void Run(const ADataType* p_a_grid,
1083 const BDataType* p_b_grid,
1084 DsGridPointer& p_ds_grid,
1085 CDataType* p_c_grid,
1086 const AScaleType* p_a_scale_grid,
1087 const BScaleType* p_b_scale_grid,
1088 void* p_shared,
1089 const Problem& problem,
1090 AElementwiseOperation a_element_op,
1091 BElementwiseOperation b_element_op,
1092 CElementwiseOperation c_element_op)
1093 {
1094 ignore = b_element_op;
1095 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1096 index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1097 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1098 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1099
1100 const auto b_grid_desc_bpreshuffled =
1101 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1102 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1103 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1104
1105 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1106 make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
1107 math::integer_divide_ceil(problem.K, ScaleBlockK)),
1108 make_tuple(1, math::integer_divide_ceil(problem.M, ScaleBlockM)));
1109 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1110 make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1111 math::integer_divide_ceil(problem.K, ScaleBlockK)),
1112 make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1113
1114 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1116 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1117
1118 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1119 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1120 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1121 p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1123 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1124
1125 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1126 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1127
1128 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1129 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1130
1131 // divide block work by [M, N]
1132 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1133
1134 const auto block_work_idx =
1135 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1136
1137 if(!block_2_ctile_map.ValidCTileIndex(
1138 block_work_idx,
1139 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1140 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1141 {
1142 return;
1143 }
1144
1145 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1146 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1147
1148 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1149 const index_t m_block_data_idx_on_grid =
1150 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1151
1152 const index_t n_block_data_idx_on_grid =
1153 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1154
1155 // A matrix in LDS memory, dst of blockwise copy
1156 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1157
1158 // B matrix in LDS memory, dst of blockwise copy
1159 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1160
1161 // A matrix blockwise copy
1162 auto a_blockwise_copy =
1164 AElementwiseOperation,
1168 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1169 ABlockTransferThreadClusterArrangeOrder,
1170 ADataType,
1171 LDSTypeA,
1172 decltype(a_grid_desc_ak0_m_ak1),
1173 decltype(a_block_desc_ak0_m_ak1),
1174 ABlockTransferSrcAccessOrder,
1176 ABlockTransferSrcVectorDim,
1177 2,
1178 ABlockTransferSrcScalarPerVector,
1179 ABlockTransferDstScalarPerVector_AK1,
1180 1,
1181 1,
1182 AThreadTransferSrcResetCoordinateAfterRun,
1183 true,
1184 BlockwiseGemmPipe::GlobalBufferNum>(
1185 a_grid_desc_ak0_m_ak1,
1186 make_multi_index(0, m_block_data_idx_on_grid, 0),
1187 a_element_op,
1188 a_block_desc_ak0_m_ak1,
1189 make_multi_index(0, 0, 0),
1191
1192 // Thread-wise copy
1193 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1195 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1196
1197 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1198 BDataType,
1199 BDataType,
1200 decltype(b_grid_desc_bpreshuffled),
1201 decltype(b_block_desc_bk0_n_bk1),
1204 3,
1205 BBlockTransferSrcScalarPerVector,
1206 BThreadTransferSrcResetCoordinateAfterRun,
1207 true>(b_grid_desc_bpreshuffled,
1208 make_multi_index(n_block_data_idx_on_grid,
1210 0,
1211 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1212
1213 // LDS allocation for A and B: be careful of alignment
1214 // Cast after lds
1216 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1217
1218 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1219 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1220
1221 // Blockwise GEMM pipeline
1222 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1223 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1224 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1225
1226 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1227 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1228 KPerBlock);
1229
1230 constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1231 constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
1232 constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
1233
1234 // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
1235 // ScaleSliceSizeK is first dimension in C scale for packed math
1236 constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1238
1239 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1240 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1241 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
1242
1243 auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
1244 (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
1245
1246 constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1248
1249 constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
1251
1252 auto a_scale_thread_copy =
1254 AScaleType,
1255 decltype(a_scale_grid_desc_am_ak),
1256 decltype(a_scale_thread_desc),
1259 0,
1260 1,
1261 1,
1262 true>(
1263 a_scale_grid_desc_am_ak,
1264 make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));
1265
1266 auto b_scale_thread_copy =
1268 BScaleType,
1269 decltype(b_scale_grid_desc_bn_ak),
1270 decltype(b_scale_thread_desc),
1273 1,
1274 ScaleSliceSizeK,
1275 1,
1276 true>(
1277 b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1278
1279 // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
1280 constexpr auto a_scale_thread_slice_copy_step =
1281 make_tuple(make_multi_index(MWaves * MPerXdl, 0),
1282 make_multi_index(-MPerBlock, 0),
1283 make_multi_index(-MPerBlock, ScaleSliceSizeK));
1284 constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1285
1286 constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
1287
1288 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1289 a_grid_desc_ak0_m_ak1,
1290 a_block_desc_ak0_m_ak1,
1291 a_blockwise_copy,
1292 a_grid_buf,
1293 a_block_buf,
1294 a_block_slice_copy_step,
1295 b_grid_desc_bpreshuffled,
1296 b_block_desc_bk0_n_bk1,
1297 b_blockwise_copy,
1298 b_grid_buf,
1299 b_block_buf,
1300 b_block_slice_copy_step,
1301
1302 c_scale_thread_desc,
1303 c_thread_buf,
1304
1305 a_scale_grid_desc_am_ak,
1306 a_scale_thread_desc,
1307 a_scale_thread_copy,
1308 a_scale_grid_buf,
1309 a_scale_thread_slice_copy_step,
1310
1311 b_scale_grid_desc_bn_ak,
1312 b_scale_thread_desc,
1313 b_scale_thread_copy,
1314 b_scale_grid_buf,
1315 b_scale_thread_slice_copy_step,
1316
1317 num_k_block_main_loop);
1318
1319 // shuffle C and write out
1320 {
1321 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1322 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1323 "wrong!");
1324
1325 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1326
1327 // transposed XDL
1328 // // TODO: hacky, fix it!
1329 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1330 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1331
1332 // // TODO: hacky, fix it!
1333 // only used to get lengths
1334 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1335 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1336
1337 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1338 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1339 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1340 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1341 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1342 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1343 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1344 constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1345
1346 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1348
1349 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1350 static_cast<CShuffleDataType*>(p_shared),
1351 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1352
1353 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1354 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1355 make_tuple(
1358 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1359 M1, // M1 = MWave
1360 M2)), // M2 = MPerXdl
1363 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1364 N1, // N1 = NWave
1365 N2, // N2 * N3 * N4 = NPerXdl
1366 N3,
1367 N4))),
1369 make_tuple(
1371
1372 // calculate origin of thread output tensor on global memory
1373 // blockwise GEMM c matrix starting index
1374 const auto c_thread_mtx_on_block =
1375 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1376
1377 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1378 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1379
1380 const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1385
1386 const auto m_thread_data_on_block_idx =
1387 m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1388 make_multi_index(m_thread_data_on_block));
1389
1390 const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1392 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1395
1396 const auto n_thread_data_on_block_idx =
1397 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1398 make_multi_index(n_thread_data_on_block));
1399
1400 // shuffle: threadwise copy C from VGPR to LDS
1401 auto c_thread_copy_vgpr_to_lds =
1403 CShuffleDataType,
1404 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1405 decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1407 Sequence<CShuffleMXdlPerWavePerShuffle,
1408 CShuffleNXdlPerWavePerShuffle,
1409 I1,
1410 I1,
1411 I1,
1412 N2,
1413 I1,
1414 N4>,
1416 7,
1417 1,
1419 1,
1420 true>{
1421 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1423 0,
1424 m_thread_data_on_block_idx[I1],
1425 n_thread_data_on_block_idx[I1],
1426 m_thread_data_on_block_idx[I2],
1427 n_thread_data_on_block_idx[I2],
1428 n_thread_data_on_block_idx[I3],
1429 n_thread_data_on_block_idx[I4]),
1431
1432 using EDataType = CDataType;
1433
1434 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1435 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1436
1437 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1439 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1440
1441 const auto ds_grid_buf = generate_tuple(
1442 [&](auto i) {
1444 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1445 },
1447
1448 // tuple of reference to C/Ds tensor descriptors
1449 const auto c_ds_desc_refs = concat_tuple_of_reference(
1450 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1451 generate_tie([&](auto i) -> const auto& // return type should be reference
1452 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1454
1455 // tuple of reference to C/Ds tensor descriptors
1456 const auto c_ds_buf_refs = concat_tuple_of_reference(
1457 tie(c_shuffle_block_buf),
1458 generate_tie([&](auto i) -> const auto& // return type should be reference
1459 { return ds_grid_buf[i]; },
1461
1462 // tuple of starting index of C/Ds blockwise copy
1463 const auto idx_c_ds_block_begin = container_concat(
1464 make_tuple(make_multi_index(0, 0, 0, 0)),
1466 [&](auto) {
1467 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
1468 },
1470
1471 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1472 c_grid_desc_mblock_mperblock_nblock_nperblock;
1473
1474 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
1475 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1476 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1477
1478 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
1480 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1482 decltype(c_ds_desc_refs),
1483 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1484 CElementwiseOperation,
1485 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1486 // support arbitray type
1487 Sequence<1,
1488 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1489 1,
1490 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1491 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1492 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1493 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1494 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1495 3, // index_t SrcVectorDim,
1496 3, // index_t DstVectorDim,
1497 CDEShuffleBlockTransferScalarPerVectors,
1502 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
1503 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
1504 {c_ds_desc_refs,
1505 idx_c_ds_block_begin,
1506 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1507 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
1508 c_element_op};
1509
1510 constexpr auto sfc_c_vgpr =
1513 Sequence<CShuffleMXdlPerWavePerShuffle,
1514 CShuffleNXdlPerWavePerShuffle,
1515 1,
1516 1,
1517 1,
1518 N2,
1519 1,
1520 N4>>{};
1521
1522 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1523
1524 // space filling curve for shuffled blockwise C/D/E
1525 constexpr auto sfc_cde_block =
1528 Sequence<1,
1529 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1530 1,
1531 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1532
1533 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1534
1535 static_for<0, num_access, 1>{}([&](auto access_id) {
1536 // make sure it's safe to write to LDS
1538
1539 // each thread write its data from VGPR to LDS
1540 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1541 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1542 c_thread_buf,
1543 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1544 c_shuffle_block_buf);
1545
1546 // make sure it's safe to read from LDS
1548
1549 // each block copy its data from LDS to global
1550 cde_block_copy_lds_and_global.Run(
1551 c_ds_desc_refs,
1552 c_ds_buf_refs,
1553 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1554 tie(c_grid_buf));
1555
1556 if constexpr(access_id < num_access - 1)
1557 {
1558 constexpr auto cde_lds_and_global_step =
1559 sfc_cde_block.GetForwardStep(access_id);
1560
1561 // move on Ds
1562 static_for<0, NumDTensor, 1>{}([&](auto i) {
1563 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1564 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1565 });
1566
1567 // move on E
1568 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1569 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1570 I0,
1571 cde_lds_and_global_step);
1572 }
1573 });
1574 }
1575 }
1576
// Run_2Lds: device-side body of the block-scaled GEMM with a pre-shuffled B operand,
// in the variant that receives two LDS allocations (p_shared, p_shared1) and uses them
// as ping/pong buffers for the A tile.
//
// Flow:
//   1. Build the global descriptors/buffers for A, pre-shuffled B, C and the A/B scale
//      tensors. Map get_block_1d_id() to an (M-block, N-block) tile, and return early
//      if that tile index is not valid.
//   2. Create the A blockwise copy (LDS destination), the B threadwise copy, and the
//      A/B scale threadwise copies. Then run BlockwiseGemmPipe::Run, which accumulates
//      into c_thread_buf.
//   3. Shuffle the accumulators through LDS (reusing p_shared) and write the output
//      tile to global memory, applying c_element_op to (shuffled C, Ds...).
//
// Template parameters:
//   HasMainKBlockLoop          - forwarded to BlockwiseGemmPipe::Run.
//   CGlobalMemoryDataOperation - in-memory operation used when storing the output.
//   TailNum                    - forwarded to BlockwiseGemmPipe::Run (default: Odd).
// Parameters:
//   p_a_grid / p_b_grid        - global A and pre-shuffled B.
//   p_ds_grid                  - pointers to the auxiliary D tensors.
//   p_c_grid                   - global output tensor.
//   p_a_scale_grid / p_b_scale_grid - global A / B block scales.
//   p_shared / p_shared1       - the two LDS allocations.
//   problem                    - sizes, padded sizes, strides, and MBlock/NBlock.
//   a_element_op               - passed to the A blockwise copy.
//   b_element_op               - unused.
//   c_element_op               - passed to the C/Ds -> output copy.
1577 template <bool HasMainKBlockLoop,
1578 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1579 TailNumber TailNum = TailNumber::Odd>
1580 __device__ static void Run_2Lds(const ADataType* p_a_grid,
1581 const BDataType* p_b_grid,
1582 DsGridPointer& p_ds_grid,
1583 CDataType* p_c_grid,
1584 const AScaleType* p_a_scale_grid,
1585 const BScaleType* p_b_scale_grid,
1586 void* p_shared,
1587 void* p_shared1,
1588 const Problem& problem,
1589 AElementwiseOperation a_element_op,
1590 BElementwiseOperation b_element_op,
1591 CElementwiseOperation c_element_op)
1592 {
// b_element_op is not applied anywhere in this function.
1593 ignore = b_element_op;
// Extents of the pre-shuffled B layout, derived from N and K.
1594 index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
1595 index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
1596 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1597 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1598 const auto b_grid_desc_bpreshuffled =
1599 MakeBGridDescriptor_Preshuffled(BN0Shuffled, BK0Shuffled);
1600 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1601 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1602
// A scales: lengths [ceil(M / ScaleBlockM), ceil(K / ScaleBlockK)] with strides
// (1, ceil(M / ScaleBlockM)), i.e. consecutive M scale rows are contiguous.
1603 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
1604 make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
1605 math::integer_divide_ceil(problem.K, ScaleBlockK)),
1606 make_tuple(1, math::integer_divide_ceil(problem.M, ScaleBlockM)));
// B scales: lengths [ceil(N / ScaleBlockN), ceil(K / ScaleBlockK)] with strides
// (ceil(K / ScaleBlockK), 1), i.e. consecutive K scale blocks are contiguous.
1607 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
1608 make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
1609 math::integer_divide_ceil(problem.K, ScaleBlockK)),
1610 make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
1611
1612 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1614 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1615
1616 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1617 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1618 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1619 p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize());
1621 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1622
1623 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1624 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1625
1626 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1627 p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1628
1629 // divide block work by [M, N]
// NOTE(review): the meaning of the third Block2CTileMap argument (4) is not visible
// here; presumably a tile-grouping factor. Confirm against Block2CTileMap.
1630 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1631
1632 const auto block_work_idx =
1633 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1634
// Skip workgroups whose mapped tile is not valid for C's MBlock x NBlock tile grid.
1635 if(!block_2_ctile_map.ValidCTileIndex(
1636 block_work_idx,
1637 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1638 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1639 {
1640 return;
1641 }
1642
1643 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1644 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1645
1646 // HACK: this force m/n_block_data_idx_on_grid into SGPR
1647 const index_t m_block_data_idx_on_grid =
1648 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1649
// NOTE(review): scaled by NXdlPerWave rather than NPerBlock. It is used as the first
// coordinate into b_grid_desc_bpreshuffled below, presumably because that descriptor
// indexes N in pre-shuffled groups rather than single columns. Confirm against
// MakeBGridDescriptor_Preshuffled.
1650 const index_t n_block_data_idx_on_grid =
1651 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1652
1653 // A matrix in LDS memory, dst of blockwise copy
1654 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1655
1656 // B matrix in LDS memory, dst of blockwise copy
1657 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1658
1659 // A matrix blockwise copy
1660 auto a_blockwise_copy =
1662 AElementwiseOperation,
1666 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1667 ABlockTransferThreadClusterArrangeOrder,
1668 ADataType,
1669 LDSTypeA,
1670 decltype(a_grid_desc_ak0_m_ak1),
1671 decltype(a_block_desc_ak0_m_ak1),
1672 ABlockTransferSrcAccessOrder,
1674 ABlockTransferSrcVectorDim,
1675 2,
1676 ABlockTransferSrcScalarPerVector,
1677 ABlockTransferDstScalarPerVector_AK1,
1678 1,
1679 1,
1680 AThreadTransferSrcResetCoordinateAfterRun,
1681 true,
1682 BlockwiseGemmPipe::GlobalBufferNum>(
1683 a_grid_desc_ak0_m_ak1,
1684 make_multi_index(0, m_block_data_idx_on_grid, 0),
1685 a_element_op,
1686 a_block_desc_ak0_m_ak1,
1687 make_multi_index(0, 0, 0),
1689
// NOTE(review): B is read per-thread with ThreadwiseTensorSliceTransfer_v2, and its
// ping/pong buffers are bundled into b_block_bufs. Only the A buffers are explicitly
// placed in p_shared / p_shared1. Confirm whether the B buffers live in LDS, as the
// comment on b_block_desc_bk0_n_bk1 says, or in registers.
1690 // Thread-wise copy
1691 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
1693 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1695 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1696 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1697
// The last start coordinate is (KPack / KGroup) * (lane index within the warp).
1698 auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
1699 BDataType,
1700 BDataType,
1701 decltype(b_grid_desc_bpreshuffled),
1702 decltype(b_block_desc_bk0_n_bk1),
1705 3,
1706 BBlockTransferSrcScalarPerVector,
1707 BThreadTransferSrcResetCoordinateAfterRun,
1708 true>(b_grid_desc_bpreshuffled,
1709 make_multi_index(n_block_data_idx_on_grid,
1711 0,
1712 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
1713
1714 // LDS allocation for A and B: be careful of alignment
1715 // Cast after lds
// A is double-buffered across the two LDS allocations:
// the ping buffer is in p_shared and the pong buffer is in p_shared1.
1716 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1717 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1718 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1719 static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1720 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1721
1722 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1723 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
1724
1725 // Blockwise GEMM pipeline
1726 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1727 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1728 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1729
// Number of KPerBlock steps over A's K extent (AK0 * AK1, read from the A descriptor).
1730 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1731 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1732 KPerBlock);
1733
// ScaleSliceSizeK: number of K scale blocks touched by one KPerBlock tile (rounded up).
1734 constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1735 constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
1736 constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
1737
1738 // ScaleSliceSizeK is last dimension in A/B scale for vector memory access
1739 // ScaleSliceSizeK is first dimension in C scale for packed math
1740 constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
1742
// Wave layout of the workgroup: MWaves x NWaves waves, each WaveSize threads wide.
1743 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1744 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1745 constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
// Per-thread M offset used for the A scales: (tid % MPerXdl) plus MPerXdl times the
// M-wave index, which is (tid / WaveSize) / NWaves.
// NOTE(review): this offset is in M elements, but it is added below to
// block_m_id * MPerBlock / ScaleBlockM, which is a scale-row index. The two only
// agree when ScaleBlockM == 1. Confirm that this kernel requires per-row A scales.
1746 auto a_thread_offset = get_thread_local_1d_id() % MPerXdl +
1747 (get_thread_local_1d_id() / WaveSize) / NWaves * MPerXdl;
1748
1749 constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
1751
1752 constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
1754
1755 auto a_scale_thread_copy =
1757 AScaleType,
1758 decltype(a_scale_grid_desc_am_ak),
1759 decltype(a_scale_thread_desc),
1762 0,
1763 1,
1764 1,
1765 true>(
1766 a_scale_grid_desc_am_ak,
1767 make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0));
1768
1769 auto b_scale_thread_copy =
1771 BScaleType,
1772 decltype(b_scale_grid_desc_bn_ak),
1773 decltype(b_scale_thread_desc),
1776 1,
1777 ScaleSliceSizeK,
1778 1,
1779 true>(
1780 b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1781
1782 // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1);
// Three A-scale window steps are handed to the pipeline. How and when each one is
// applied is decided inside BlockwiseGemmPipe::Run, which is not visible here.
1783 constexpr auto a_scale_thread_slice_copy_step =
1784 make_tuple(make_multi_index(MWaves * MPerXdl, 0),
1785 make_multi_index(-MPerBlock, 0),
1786 make_multi_index(-MPerBlock, ScaleSliceSizeK));
// The B-scale window advances by ScaleSliceSizeK along K per step.
1787 constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1788
// Number of consecutive KPerBlock iterations covered by one ScaleBlockK-wide scale
// (rounded up).
1789 constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
1790
// Main loop: the pipeline consumes the A/B tiles and their scales over K and
// accumulates the result into c_thread_buf.
1791 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1792 a_grid_desc_ak0_m_ak1,
1793 a_block_desc_ak0_m_ak1,
1794 a_blockwise_copy,
1795 a_grid_buf,
1796 a_block_bufs,
1797 a_block_slice_copy_step,
1798 b_grid_desc_bpreshuffled,
1799 b_block_desc_bk0_n_bk1,
1800 b_blockwise_copy,
1801 b_grid_buf,
1802 b_block_bufs,
1803 b_block_slice_copy_step,
1804
1805 c_scale_thread_desc,
1806 c_thread_buf,
1807
1808 a_scale_grid_desc_am_ak,
1809 a_scale_thread_desc,
1810 a_scale_thread_copy,
1811 a_scale_grid_buf,
1812 a_scale_thread_slice_copy_step,
1813
1814 b_scale_grid_desc_bn_ak,
1815 b_scale_thread_desc,
1816 b_scale_thread_copy,
1817 b_scale_grid_buf,
1818 b_scale_thread_slice_copy_step,
1819
1820 num_k_block_main_loop);
1821
1822 // shuffle C and write out
1823 {
1824 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1825 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1826 "wrong!");
1827
// Same formula as MWaves computed above.
1828 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1829
1830 // transposed XDL
1831 // // TODO: hacky, fix it!
1832 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1833 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1834
1835 // // TODO: hacky, fix it!
1836 // only used to get lengths
1837 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1838 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1839
1840 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1841 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1842 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1843 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1844 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1845 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1846 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1847 constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1848
1849 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1851
// The C shuffle buffer reuses the first LDS allocation (p_shared), which held the
// A ping buffer during the main loop.
1852 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1853 static_cast<CShuffleDataType*>(p_shared),
1854 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1855
1856 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1857 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1858 make_tuple(
1861 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1862 M1, // M1 = MWave
1863 M2)), // M2 = MPerXdl
1866 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1867 N1, // N1 = NWave
1868 N2, // N2 * N3 * N4 = NPerXdl
1869 N3,
1870 N4))),
1872 make_tuple(
1874
1875 // calculate origin of thread output tensor on global memory
1876 // blockwise GEMM c matrix starting index
1877 const auto c_thread_mtx_on_block =
1878 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1879
1880 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1881 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1882
1883 const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1888
1889 const auto m_thread_data_on_block_idx =
1890 m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1891 make_multi_index(m_thread_data_on_block));
1892
1893 const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1895 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1898
1899 const auto n_thread_data_on_block_idx =
1900 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1901 make_multi_index(n_thread_data_on_block));
1902
1903 // shuffle: threadwise copy C from VGPR to LDS
1904 auto c_thread_copy_vgpr_to_lds =
1906 CShuffleDataType,
1907 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1908 decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1910 Sequence<CShuffleMXdlPerWavePerShuffle,
1911 CShuffleNXdlPerWavePerShuffle,
1912 I1,
1913 I1,
1914 I1,
1915 N2,
1916 I1,
1917 N4>,
1919 7,
1920 1,
1922 1,
1923 true>{
1924 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1926 0,
1927 m_thread_data_on_block_idx[I1],
1928 n_thread_data_on_block_idx[I1],
1929 m_thread_data_on_block_idx[I2],
1930 n_thread_data_on_block_idx[I2],
1931 n_thread_data_on_block_idx[I3],
1932 n_thread_data_on_block_idx[I4]),
1934
1935 using EDataType = CDataType;
1936
1937 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1938 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1939
1940 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1942 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1943
1944 const auto ds_grid_buf = generate_tuple(
1945 [&](auto i) {
1947 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1948 },
1950
1951 // tuple of reference to C/Ds tensor descriptors
1952 const auto c_ds_desc_refs = concat_tuple_of_reference(
1953 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1954 generate_tie([&](auto i) -> const auto& // return type should be reference
1955 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1957
1958 // tuple of reference to C/Ds tensor buffers
1959 const auto c_ds_buf_refs = concat_tuple_of_reference(
1960 tie(c_shuffle_block_buf),
1961 generate_tie([&](auto i) -> const auto& // return type should be reference
1962 { return ds_grid_buf[i]; },
1964
1965 // tuple of starting index of C/Ds blockwise copy
1966 const auto idx_c_ds_block_begin = container_concat(
1967 make_tuple(make_multi_index(0, 0, 0, 0)),
1969 [&](auto) {
1970 return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
1971 },
1973
1974 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1975 c_grid_desc_mblock_mperblock_nblock_nperblock;
1976
1977 using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock =
1978 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1979 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1980
1981 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
1983 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1985 decltype(c_ds_desc_refs),
1986 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1987 CElementwiseOperation,
1988 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
1989 // support arbitrary type
1990 Sequence<1,
1991 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1992 1,
1993 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1994 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1995 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1996 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
1997 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
1998 3, // index_t SrcVectorDim,
1999 3, // index_t DstVectorDim,
2000 CDEShuffleBlockTransferScalarPerVectors,
2005 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2006 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
2007 {c_ds_desc_refs,
2008 idx_c_ds_block_begin,
2009 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2010 make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)),
2011 c_element_op};
2012
2013 constexpr auto sfc_c_vgpr =
2016 Sequence<CShuffleMXdlPerWavePerShuffle,
2017 CShuffleNXdlPerWavePerShuffle,
2018 1,
2019 1,
2020 1,
2021 N2,
2022 1,
2023 N4>>{};
2024
2025 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2026
2027 // space filling curve for shuffled blockwise C/D/E
2028 constexpr auto sfc_cde_block =
2031 Sequence<1,
2032 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2033 1,
2034 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2035
2036 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2037
// For each shuffle chunk:
//   1. copy a slice of the accumulators from VGPR to LDS;
//   2. copy it, combined with the Ds through c_element_op, to global memory;
//   3. advance the Ds/E windows along sfc_cde_block.
2038 static_for<0, num_access, 1>{}([&](auto access_id) {
2039 // make sure it's safe to write to LDS
2041
2042 // each thread write its data from VGPR to LDS
2043 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2044 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2045 c_thread_buf,
2046 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2047 c_shuffle_block_buf);
2048
2049 // make sure it's safe to read from LDS
2051
2052 // each block copy its data from LDS to global
2053 cde_block_copy_lds_and_global.Run(
2054 c_ds_desc_refs,
2055 c_ds_buf_refs,
2056 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2057 tie(c_grid_buf));
2058
2059 if constexpr(access_id < num_access - 1)
2060 {
2061 constexpr auto cde_lds_and_global_step =
2062 sfc_cde_block.GetForwardStep(access_id);
2063
2064 // move on Ds
2065 static_for<0, NumDTensor, 1>{}([&](auto i) {
2066 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2067 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2068 });
2069
2070 // move on E
2071 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2072 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2073 I0,
2074 cde_lds_and_global_step);
2075 }
2076 });
2077 }
2078 }
2079};
2080
2081} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__device__ index_t get_warp_local_1d_id()
Definition get_id.hpp:45
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:75
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__global__ void kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:39
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
constexpr auto BlockGemmBlockScaleBPreshufflePipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp:34
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
unsigned int uint32_t
Definition stdint.h:126
Definition block_to_ctile_map.hpp:271
__host__ static __device__ constexpr index_t CalculateGridSize(index_t M, index_t N)
Definition block_to_ctile_map.hpp:283
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:592
const CElementwiseOperation c_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:641
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, const AScaleType *p_a_scale_grid_, const BScaleType *p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:593
const ADataType * p_a_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:631
const BScaleType * p_b_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:637
const BElementwiseOperation b_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:640
const BDataType * p_b_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:632
CDataType * p_c_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:634
const AScaleType * p_a_scale_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:636
DsGridPointer p_ds_grid
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:633
const AElementwiseOperation a_element_op
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:639
index_t BK0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:585
index_t N
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:573
index_t NPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:581
index_t AK0
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:584
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:577
index_t M
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:572
__host__ void Print() const
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:562
index_t StrideB
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:576
index_t StrideC
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:578
index_t MBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:586
index_t K
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:574
index_t MPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:580
index_t StrideA
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:575
index_t KBatch
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:579
index_t KRead
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:582
index_t NBlock
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:587
__host__ __device__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:535
index_t KPadded
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:583
index_t b_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:677
index_t a_k_split_offset
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:676
__device__ SplitKBatchOffset(Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:646
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:157
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, const AScaleType *p_a_scale_grid, const BScaleType *p_b_scale_grid, void *p_shared, void *p_shared1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:1580
remove_cvref_t< decltype(BlockGemmBlockScaleBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, LDSTypeA, LDSTypeB, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, ScaleBlockM, ScaleBlockN, ScaleBlockK, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:819
static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:1062
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, const AScaleType *p_a_scale_grid, const BScaleType *p_b_scale_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp:1082
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
static constexpr index_t GetK1PerXdlops()
Definition xdlops_gemm.hpp:1810
static constexpr auto selected_mfma
Definition xdlops_gemm.hpp:1757
static constexpr index_t GetKPerXdlops()
Definition xdlops_gemm.hpp:1804
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7r3.hpp:48
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340