blockwise_gemm_smfmac_xdlops.hpp Source File

blockwise_gemm_smfmac_xdlops.hpp Source File

Composable Kernel: blockwise_gemm_smfmac_xdlops.hpp Source File
blockwise_gemm_smfmac_xdlops.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
// Judging by its name, reshapes a K0 x MN x K1 tile descriptor into an MN0/MN1/MN2/K view
// used to feed the xdlops instructions. The visible lines read K0 from dimension 0 and K1
// from dimension 2 of the descriptor type; because the type is default-constructed inside a
// constexpr initializer, the descriptor must be known at compile time.
// NOTE(review): the transform itself (original lines 21-27) is missing from this listing;
// confirm the exact output layout against the full header.
14template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
15__host__ __device__ static constexpr auto
16MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
17{
18 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
19 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
20
22 TileDesc_K0_MN_K1{},
28}
29
30template <index_t BlockSize,
31 typename FloatA,
32 typename FloatB,
33 typename FloatAcc,
34 typename AK0MK1BlockDesc,
35 typename BK0NK1BlockDesc,
36 index_t MPerXDL,
37 index_t NPerXDL,
38 index_t MRepeat,
39 index_t NRepeat,
40 index_t KPack,
41 typename ComputeTypeA = FloatA,
42 typename ComputeTypeB = FloatB>
44{
45 static constexpr auto I0 = Number<0>{};
46 static constexpr auto I1 = Number<1>{};
47 static constexpr auto I2 = Number<2>{};
48 static constexpr auto I3 = Number<3>{};
49
51
52 static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
53 static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
54 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
55
56 static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
57 static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
58 static constexpr index_t KPerBlock =
59 BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
60
61 static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
62 static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
63 static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
64 static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
65
68
69 static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
70
72 FloatAcc,
73 MRepeat * NRepeat,
74 xdlops_gemm.GetRegSizePerXdlops(),
75 true>
77
78 __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
79
// Maps this thread's id within the block to a wave index tuple through a compile-time
// single-stage tensor adaptor. Callers in this class read component I0 as the wave index
// along M and component I1 as the wave index along N.
// NOTE(review): the adaptor's transform arguments (original lines 85-87) are missing from
// this listing, so the exact decomposition cannot be confirmed here.
80 __device__ static auto GetWaveIdx()
81 {
82 const index_t thread_id = ThisThreadBlock::GetThreadId();
83
84 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
88
89 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
90 }
91
92 __device__ static auto CalculateAThreadOriginDataIndex()
93 {
94 const auto wave_idx = GetWaveIdx();
95 const auto waveId_m = wave_idx[I0];
96 const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
97
98 return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
99 }
100
101 __device__ static auto CalculateBThreadOriginDataIndex()
102 {
103 const auto wave_idx = GetWaveIdx();
104 const auto waveId_n = wave_idx[I1];
105 const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
106
107 return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
108 }
109
// Returns this thread's (m, n) C coordinate for repeat indices m0/n0 and xdlops output block
// (xdlops_i, blk_i). Each of the tuples (m0, waveId_m, blk_idx[I0]) and
// (n0, waveId_n, blk_idx[I1]) is collapsed to a single index by a compile-time single-stage
// adaptor.
// NOTE(review): the parameter list (original line 112) and the adaptor transforms (lines
// 121-123, 126-128) are missing from this listing. Judging by the adaptor names, they merge
// (MRepeat, MWaves, MPerXDL) -> M and (NRepeat, NWaves, NPerXDL) -> N; confirm.
110 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
111 __device__ static auto
113 {
114 const auto wave_idx = GetWaveIdx();
115 const auto waveId_m = wave_idx[I0];
116 const auto waveId_n = wave_idx[I1];
117
118 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
119
120 constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
124
125 constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
129
130 const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
131 make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
132 const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
133 make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
134
135 return make_tuple(c_thread_m, c_thread_n);
136 }
137
// Returns an 8-component C origin for this thread, consisting of:
//   - (m0, n0) as compile-time Numbers;
//   - the wave indices along M and N from GetWaveIdx();
//   - the four components returned by xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i).
// NOTE(review): the parameter list (original line 140) is missing here. Per the generated
// index at the end of this page, it takes Number<m0>, Number<n0>, Number<xdlops_i> and
// Number<blk_i> tags.
138 template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
139 __device__ static auto
141 {
142 const auto wave_idx = GetWaveIdx();
143 const auto waveId_m = wave_idx[I0];
144 const auto waveId_n = wave_idx[I1];
145
146 const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
147
148 return make_tuple(Number<m0>{},
149 Number<n0>{},
150 waveId_m,
151 waveId_n,
152 blk_idx[I0],
153 blk_idx[I1],
154 blk_idx[I2],
155 blk_idx[I3]);
156 }
157
159 {
// Compile-time configuration checks. They are compiled only in the device pass
// (__HIP_DEVICE_COMPILE__), so host-side instantiations are not validated here.
160#if defined(__HIP_DEVICE_COMPILE__)
161 static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
162 BK0NK1BlockDesc::IsKnownAtCompileTime(),
163 "wrong! Desc should be known at compile-time");
164
166 "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
167
168 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0,
169 "MPerBlock must be divisible by MPerXDL * MRepeat");
170 static_assert(NPerBlock % (NPerXDL * NRepeat) == 0,
171 "NPerBlock must be divisible by NPerXDL * NRepeat");
172
// 16 * sizeof(ComputeTypeA) is the same quantity as elems_per_idx in Run(): the number of
// dense A elements described by one 32-bit sparse-index register.
// NOTE(review): the message below misspells "divisible" ("divisbile").
173 static_assert(
174 KPack % (16 * sizeof(ComputeTypeA)) == 0,
175 "KPack must be divisbile by number of elements processed in single smfmac instruction");
176#endif
177 }
178
// Describes this thread's C accumulator registers with lengths
// (MRepeat, NRepeat, 1, 1, M0, M1, M2, N), where M0/M1/M2/N are the per-thread block
// lengths from xdlops_gemm.GetCM0M1M2NThreadBlkLengths().
// NOTE(review): the descriptor-constructor call (original line 188) is missing from this
// listing. It is presumably make_naive_tensor_descriptor_packed; confirm.
179 __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
180 {
181 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
182
183 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
184 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
185 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
186 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
187
189 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
190 }
191
// Batched variant of the per-thread C descriptor. Its lengths are
// (1, MRepeat, NRepeat, 1, 1, M0, M1, M2, N): a unit G dimension followed by the same layout
// as the non-batched version, with M0/M1/M2/N from GetCM0M1M2NThreadBlkLengths().
// NOTE(review): the descriptor-constructor call (original line 201) is missing from this listing.
192 __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
193 {
194 constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
195
196 constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
197 constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
198 constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
199 constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
200
202 make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
203 }
204
// Builds a block-level C descriptor and passes it to
// xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2 for the final 8-D layout.
// NOTE(review): most of the initializer (original lines 208-212) is missing from this listing.
// Only its last length, Number<NPerXDL>, is visible.
205 __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
206 {
207 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
213 Number<NPerXDL>{}));
214
215 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
216 }
217
// Batched (G-prefixed) variant. It builds a block-level C descriptor and passes it to
// xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2.
// NOTE(review): most of the initializer (original lines 221-226) is missing from this listing.
220 __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
219 {
220 constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
227 Number<NPerXDL>{}));
228
229 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
230 c_block_desc_g_m0_n0_m1_n1_m2_n2);
231 }
232
// Splits the M and N dimensions of a C grid descriptor with unmerge transforms:
//   M -> (M / (MWaves * MPerXDL), MWaves, MPerXDL)
//   N -> (N / (NWaves * NPerXDL), NWaves, NPerXDL)
// It then hands the result to xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2, which
// (judging by its name) expands the per-xdlops part into the M2/M3/M4 sub-dimensions.
// NOTE(review): the leading lengths use integer division, so M and N are presumably required
// to be multiples of MWaves * MPerXDL and NWaves * NPerXDL; nothing here checks that.
// The dimension-id tuples (original lines 244-245) are missing from this listing.
233 template <typename CGridDesc_M_N>
234 __host__ __device__ static constexpr auto
235 MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
236 {
237 const auto M = c_grid_desc_m_n.GetLength(I0);
238 const auto N = c_grid_desc_m_n.GetLength(I1);
239
240 const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
241 c_grid_desc_m_n,
242 make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
243 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
246
247 return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
248 }
249
// Batched variant of the C grid split. It reads G, M and N from dimensions 0, 1 and 2, then
// unmerges:
//   M -> (M / (MWaves * MPerXDL), MWaves, MPerXDL)
//   N -> (N / (NWaves * NPerXDL), NWaves, NPerXDL)
// Finally it calls xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2.
// NOTE(review): the transform for G (original line 260, presumably a pass-through) and the
// dimension-id tuples (lines 263-264) are missing from this listing. M and N are presumably
// required to divide evenly, since integer division truncates.
250 template <typename CGridDesc_G_M_N>
251 __host__ __device__ static constexpr auto
252 MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
253 {
254 const auto G = c_grid_desc_g_m_n.GetLength(I0);
255 const auto M = c_grid_desc_g_m_n.GetLength(I1);
256 const auto N = c_grid_desc_g_m_n.GetLength(I2);
257
258 const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
259 c_grid_desc_g_m_n,
261 make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
262 make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
265
266 return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
267 c_grid_desc_g_m0_n0_m1_n1_m2_n2);
268 }
269
281
293
296
297 // Prepares data in a_thread_buf by squeezing out zeros (omitting them) to match 2:4
298 // structural sparsity. The indexes of the non-zero elements are stored in idx_buf and used
299 // later by the smfmac instruction.
// How each group of 4 consecutive elements is squeezed:
//   - Positions 0..2 are scanned for non-zeros, and at most two values are kept.
//   - Unused slots fall back to elements 2 and 3. This matches the initial idx 0b11101110,
//     whose 2-bit fields default to indices 2 and 3.
//   - Each kept value's 2-bit position is written into `idx`.
//   - The packed values are written back to a_thread_buf at half the original offset
//     (i / 2 + j), so the squeezed operand occupies the first half of the thread slice.
// NOTE(review): the bit layout assumes processed_elems == 8 (2-byte ComputeTypeA):
//   - `idx` is a uint8_t holding four 2-bit fields and bit_clear_masks has 4 entries, but for a
//     1-byte type j / 2 reaches 6 and indexes past bit_clear_masks.
//   - `(i % 32)` only equals i % (16 * sizeof(ComputeTypeA)) when sizeof(ComputeTypeA) == 2.
// NOTE(review): if all of elements 0..2 in a group are non-zero (input not 2:4 sparse),
// nonzero_pos reaches 2 and nonzero_elems[nonzero_pos] writes out of bounds.
// NOTE(review): further problems visible here:
//   - `IdxBuf[idx_reg_num]` subscripts the template type parameter, not the idx_buf argument.
//   - a_pos is computed but never used.
//   - The function is not const, yet it is called from the const Run() with three function
//     arguments, whereas num_elems is a template parameter.
// This template would not compile if instantiated as written.
300 template <typename AThreadBuf, typename IdxBuf, int32_t num_elems>
301 __device__ void SetIdxSqueezeA(AThreadBuf& a_thread_buf, IdxBuf& idx_buf)
302 {
303 static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000};
304 static constexpr int32_t processed_elems = 16 / sizeof(ComputeTypeA);
305
307 constexpr int idx_reg_num = i / (16 * sizeof(ComputeTypeA));
308 constexpr int idx_reg_part = (i % 32) / processed_elems;
309
312 a_thread_vec.template AsType<ComputeTypeA>()(j) = a_thread_buf
313 [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i + j))>{}];
314 });
315
316 uint8_t idx = 0b11101110; // set to last 2 elems for both 4-elems subgroups by default
317 for(int j = 0; j < processed_elems; j += 4)
318 {
319 int32_t a_pos = idx_reg_part * processed_elems + j;
320 int32_t nonzero_pos = 0;
321 ComputeTypeA nonzero_elems[2] = {a_thread_vec[j + 2], a_thread_vec[j + 3]};
322 for(int k = 0; k < 3; k += 1)
323 {
324 if(a_thread_vec[j + k] != 0.0f)
325 {
326 nonzero_elems[nonzero_pos] = a_thread_vec[j + k];
327 idx &= ~bit_clear_masks[j / 2 + nonzero_pos];
328 idx |= k << 2 * (j / 2 + nonzero_pos);
329 ++nonzero_pos;
330 }
331 }
// Packed values overwrite slots j/2 and j/2+1, which earlier groups have already consumed.
332 a_thread_vec[j / 2] = nonzero_elems[0];
333 a_thread_vec[j / 2 + 1] = nonzero_elems[1];
334 }
335 IdxBuf[idx_reg_num].AsType<int8x4_t>()[Number<idx_reg_part>{}] = idx;
336
337 static_for<0, processed_elems / 2, 1>{}([&](auto j) {
338 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
339 make_tuple(0, 0, 0, i / 2 + j))>{}] = a_thread_vec[j];
340 });
341 });
342 }
343
// Sparse blockwise GEMM main loop:
//   - For each MRepeat index m0, load this thread's A slice from a_block_buf and squeeze it to
//     2:4 form (SetIdxSqueezeA). The squeezed A is reused across all n0.
//   - For each NRepeat index n0, load the B slice.
//   - Call xdlops_gemm.Run with the packed A (KPack / 2 values), the B values, and the
//     sparse-index registers, accumulating into c_thread_buf at c_thread_desc_ offset
//     (m0, n0, 0).
// NOTE(review): several pieces are missing from this listing:
//   - the K-loop header that defines `k` (original line 377; presumably it steps over
//     KPerThread in units of KPack);
//   - the thread-buffer / idx_buf declarations (lines 349-355);
//   - parts of the thread-copy calls (lines 359, 362, 370, 373).
344 template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
345 __device__ void Run(const ABlockBuffer& a_block_buf,
346 const BBlockBuffer& b_block_buf,
347 CThreadBuffer& c_thread_buf) const
348 {
350 a_thread_desc_.GetElementSpaceSize());
352 b_thread_desc_.GetElementSpaceSize());
// Number of dense A elements described by one 32-bit sparse-index register.
353 static constexpr int32_t elems_per_idx = 16 * sizeof(ComputeTypeA);
355 (a_thread_desc_.GetElementSpaceSize() + elems_per_idx - 1) / elems_per_idx);
356
357 static_for<0, MRepeat, 1>{}([&](auto m0) {
358 // read A
360 make_tuple(m0, I0, I0, I0),
361 a_block_buf,
363 make_tuple(I0, I0, I0, I0),
364 a_thread_buf);
365
// NOTE(review): SetIdxSqueezeA takes two function parameters (num_elems is a template
// parameter) and is not const, but Run is const and passes three arguments here, so this call
// would not compile as written. The element count is presumably meant to be the num_elems
// template argument.
366 SetIdxSqueezeA(a_thread_buf, idx_buf, a_thread_desc_.GetElementSpaceSize());
367
368 static_for<0, NRepeat, 1>{}([&](auto n0) {
369 // read B
371 make_tuple(n0, I0, I0, I0),
372 b_block_buf,
374 make_tuple(I0, I0, I0, I0),
375 b_thread_buf);
376
378 // a_thread_vec is smaller because it's structurally sparse 2:4
379 vector_type<ComputeTypeA, KPack / 2> a_thread_vec;
381 vector_type<int32_t, KPack / elems_per_idx> idx_vec;
382
// Squeezed A lives at half offsets (see SetIdxSqueezeA), hence k / 2 + i.
383 static_for<0, KPack / 2, 1>{}([&](auto i) {
384 a_thread_vec.template AsType<ComputeTypeA>()(i) =
385 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
386 make_tuple(0, 0, 0, k / 2 + i))>{}];
387 });
388
// NOTE(review): this writes b_thread_vec element 2 * i for every i in [0, KPack). That means
// indices up to 2 * KPack - 2 are written and odd slots are left untouched. The declaration of
// b_thread_vec (line 380) is not visible; confirm it is sized for this, or whether index i was
// intended.
389 static_for<0, KPack, 1>{}([&](auto i) {
390 b_thread_vec.template AsType<ComputeTypeB>()(2 * i) = b_thread_buf
391 [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
392 });
393
// Gather the index registers covering this KPack chunk (one register per elems_per_idx).
394 static_for<0, KPack / elems_per_idx, 1>{}([&](auto i) {
395 idx_vec.template AsType<int32_t>()(i) = idx_buf[k / elems_per_idx + i];
396 });
397
398 // A is smaller because it's structurally sparse 2:4
399 using mfma_input_type_a =
400 typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops / 2>::type;
401 using mfma_input_type_b =
402 typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
403 using mfma_input_type_idx = typename vector_type<int32_t, 1>::type;
404
405 constexpr index_t c_offset =
406 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
407
408 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
409 b_thread_vec.template AsType<mfma_input_type_b>(),
410 idx_vec.template AsType<mfma_input_type_idx>(),
411 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
412 });
413 });
414 });
415 }
416
417 protected:
418 // A[M0, M1, M2, KPerThread]: per-thread A registers. The dense KPerThread slice is loaded
// here, then compacted in place by SetIdxSqueezeA into its first half (offsets i / 2 + j).
// NOTE(review): the initializer (original line 420) is missing from this listing.
419 static constexpr auto a_thread_desc_ =
421
422 // B[N0, N1, N2, KPerThread]: per-thread B registers (dense).
// NOTE(review): the initializer (original line 424) is missing from this listing.
423 static constexpr auto b_thread_desc_ =
425
426 // C[M, N, NumRegXdlops]
428 make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
429
431 ComputeTypeA,
432 decltype(a_block_desc_m0_m1_m2_k),
433 decltype(a_thread_desc_),
436 3,
437 A_K1,
438 A_K1>;
439
441 ComputeTypeB,
442 decltype(b_block_desc_n0_n1_n2_k),
443 decltype(b_thread_desc_),
446 3,
447 B_K1,
448 B_K1>;
449
452};
453
454} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
typename vector_type< int8_t, 4 >::type int8x4_t
Definition dtype_vector.hpp:2177
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
signed int int32_t
Definition stdint.h:123
unsigned char uint8_t
Definition stdint.h:124
__host__ static __device__ constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
Definition blockwise_gemm_smfmac_xdlops.hpp:270
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
Definition blockwise_gemm_smfmac_xdlops.hpp:158
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_smfmac_xdlops.hpp:78
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_smfmac_xdlops.hpp:252
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_smfmac_xdlops.hpp:101
__host__ static __device__ constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
Definition blockwise_gemm_smfmac_xdlops.hpp:282
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_smfmac_xdlops.hpp:92
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition blockwise_gemm_smfmac_xdlops.hpp:76
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_smfmac_xdlops.hpp:218
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_smfmac_xdlops.hpp:140
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_smfmac_xdlops.hpp:192
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_smfmac_xdlops.hpp:205
__device__ void SetIdxSqueezeA(AThreadBuf &a_thread_buf, IdxBuf &idx_buf)
Definition blockwise_gemm_smfmac_xdlops.hpp:301
ThreadwiseTensorSliceTransfer_v4< FloatB, ComputeTypeB, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPerThread >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_smfmac_xdlops.hpp:440
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition blockwise_gemm_smfmac_xdlops.hpp:50
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_smfmac_xdlops.hpp:235
ThreadwiseTensorSliceTransfer_v4< FloatA, ComputeTypeA, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPerThread >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_smfmac_xdlops.hpp:430
__device__ void Run(const ABlockBuffer &a_block_buf, const BBlockBuffer &b_block_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_smfmac_xdlops.hpp:345
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_smfmac_xdlops.hpp:112
static __device__ auto GetWaveIdx()
Definition blockwise_gemm_smfmac_xdlops.hpp:80
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_smfmac_xdlops.hpp:179
Definition utility/sequence.hpp:43
Definition smfmac_xdlops_gemm.hpp:215
Definition static_buffer.hpp:75
static __device__ constexpr index_t GetNumOfThread()
Definition thread_group.hpp:15
static __device__ index_t GetThreadId()
Definition thread_group.hpp:19
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10