blockwise_gemm_pipeline_wmmaops_v3.hpp Source File

blockwise_gemm_pipeline_wmmaops_v3.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_wmmaops_v3.hpp Source File
blockwise_gemm_pipeline_wmmaops_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 2
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 1
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeTypeA,
21 typename ComputeTypeB,
22 typename AccDataType,
23 typename AWmmaTileDesc,
24 typename BWmmaTileDesc,
25 index_t ABlockTransferSrcScalarPerVector,
26 index_t BBlockTransferSrcScalarPerVector,
27 index_t MPerBlock,
28 index_t NPerBlock,
29 index_t KPerBlock,
30 index_t MPerWmma,
31 index_t NPerWmma,
32 index_t MRepeat,
33 index_t NRepeat,
34 index_t KPack,
35 bool TransposeC = false>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeTypeA,
44 typename ComputeTypeB,
45 typename AccDataType,
46 typename AWmmaTileDesc,
47 typename BWmmaTileDesc,
48 index_t ABlockTransferSrcScalarPerVector,
49 index_t BBlockTransferSrcScalarPerVector,
50 index_t MPerBlock,
51 index_t NPerBlock,
52 index_t KPerBlock,
53 index_t MPerWmma,
54 index_t NPerWmma,
55 index_t MRepeat,
56 index_t NRepeat,
57 index_t KPack,
58 bool TransposeC>
60 BlockSize,
61 ADataType,
62 BDataType,
63 ComputeTypeA,
64 ComputeTypeB,
65 AccDataType,
66 AWmmaTileDesc,
67 BWmmaTileDesc,
68 ABlockTransferSrcScalarPerVector,
69 BBlockTransferSrcScalarPerVector,
70 MPerBlock,
71 NPerBlock,
72 KPerBlock,
73 MPerWmma,
74 NPerWmma,
75 MRepeat,
76 NRepeat,
77 KPack,
78 TransposeC>
80 ADataType,
81 BDataType,
82 ComputeTypeA,
83 ComputeTypeB,
84 AccDataType,
85 AWmmaTileDesc,
86 BWmmaTileDesc,
87 ABlockTransferSrcScalarPerVector,
88 BBlockTransferSrcScalarPerVector,
89 MPerBlock,
90 NPerBlock,
91 KPerBlock,
92 MPerWmma,
93 NPerWmma,
94 MRepeat,
95 NRepeat,
96 KPack,
97 TransposeC>
98{
100 ADataType,
101 BDataType,
102 ComputeTypeA,
103 ComputeTypeB,
104 AccDataType,
105 AWmmaTileDesc,
106 BWmmaTileDesc,
107 ABlockTransferSrcScalarPerVector,
108 BBlockTransferSrcScalarPerVector,
109 MPerBlock,
110 NPerBlock,
111 KPerBlock,
112 MPerWmma,
113 NPerWmma,
114 MRepeat,
115 NRepeat,
116 KPack,
117 TransposeC>;
118 using Base::I0;
119
120 using Base::A_K1;
121 using Base::A_KRow;
122 using Base::B_K1;
123 using Base::B_KRow;
124 using Base::KRepeat;
125 using Base::WmmaK;
126
127 using Base::wmma_gemm;
128 using typename Base::HotLoopInstList;
129
131 using Base::
132 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
134 using Base::
135 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
136 using Base::
137 GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;
138
141
142 using typename Base::Empty;
143
144 static constexpr index_t PrefetchStages = 2;
145 static constexpr index_t PrefillStages = 1;
146 static constexpr index_t GlobalBufferNum = 1;
147
148 __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
149 {
150 return num_loop > PrefetchStages;
151 }
152
153 __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
154 {
155 if(BlockHasHotloop(num_loop))
156 {
157 return TailNumber::Full;
158 }
159 else
160 {
161 if(num_loop == 1)
162 {
163 return TailNumber::Odd;
164 }
165 else
166 {
167 return TailNumber::Even;
168 }
169 }
170 }
171
172 __device__ static constexpr auto HotLoopScheduler()
173 {
174 // TODO: Calculation of the number of instructions may require changes for WMMA
175 /*
176 // A/B split schedule
177 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
178 constexpr auto num_ds_read_inst_a =
179 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
180 ? HotLoopInstList::A_LDS_Read_Inst_Num
181 : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
182 constexpr auto num_ds_read_inst_b =
183 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
184 ? HotLoopInstList::B_LDS_Read_Inst_Num
185 : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
186
187 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
188 constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
189
190 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
191 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
192
193 constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num;
194
195 constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32;
196 constexpr auto ds_read_a_issue_cycle =
197 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
198 constexpr auto ds_read_b_issue_cycle =
199 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
200 constexpr auto ds_read_a_wmma_rate =
201 (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
202 constexpr auto ds_read_b_wmma_rate =
203 (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
204
205 constexpr auto num_dsread_a_wmma =
206 (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate;
207 constexpr auto num_dsread_b_wmma =
208 (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate;
209
210 // stage 1
211 // Separate this part?
212 // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
213 // sizeof(ComputeDataType) / sizeof(BDataType)
214 // ? sizeof(ComputeDataType) / sizeof(ADataType)
215 // : sizeof(ComputeDataType) / sizeof(BDataType);
216 constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma);
217 constexpr auto num_wmma_per_issue =
218 num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
219 constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
220 constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
221
222 static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
223 ignore = i;
224 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
225 ignore = idswrite;
226 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
227 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
228 });
229 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
230 __builtin_amdgcn_sched_group_barrier(
231 0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA
232 });
233 static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
234 ignore = i;
235 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
236 ignore = idswrite;
237 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
238 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
239 });
240 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
241 __builtin_amdgcn_sched_group_barrier(
242 0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA
243 });
244
245 // stage 2
246 static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) {
247 if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >=
248 ds_read_a_wmma_rate)
249 {
250 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read
251 }
252 else
253 {
254 __builtin_amdgcn_sched_group_barrier(0x100,
255 num_ds_read_inst_a - (num_dsread_a_wmma - 1) *
256 ds_read_a_wmma_rate,
257 0); // DS read
258 }
259 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
260 });
261
262 static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) {
263 if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >=
264 ds_read_b_wmma_rate)
265 {
266 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read
267 }
268 else
269 {
270 __builtin_amdgcn_sched_group_barrier(0x100,
271 num_ds_read_inst_b - (num_dsread_b_wmma - 1) *
272 ds_read_b_wmma_rate,
273 0); // DS read
274 }
275 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
276 });
277 */
278 }
279
// LocalLoad: for every k0 in [0, KRepeat), copies the A sub-tiles for all m0 in
// [0, MRepeat) and the B sub-tiles for all n0 in [0, NRepeat) from the block buffers
// (a_block_buf / b_block_buf) into the per-thread buffers (a_thread_buf / b_thread_buf).
// The destination index is (I0, m0 or n0, k0, I0, I0, I0).
// B is copied along one of two paths:
//   - a plain b_thread_copy_.Run;
//   - a variant that also passes one element of b_scale_struct.b_scale_thread_bufs(I0),
//     selected by n0 * BScaleStruct::num_scale_k_block + k0 / BScaleStruct::num_scale_krepeat.
//     This is presumably a per-(n0, k0-group) dequantization scale -- TODO confirm.
// NOTE(review): this listing is a Doxygen rendering. Original lines 294-295, 297, 302,
// 306-307, 309, 318-319 and 324 are missing: the source descriptors/origins, the thread
// descriptors and the branch condition. Do not rebuild the function from this text.
280 template <typename ABlockBuffer,
281 typename AThreadBuffer,
282 typename BBlockBuffer,
283 typename BThreadBuffer,
284 typename BScaleStruct>
285 __device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
286 AThreadBuffer& a_thread_buf,
287 BBlockBuffer& b_block_buf,
288 BThreadBuffer& b_thread_buf,
289 BScaleStruct& b_scale_struct) const
290 {
291 static_for<0, KRepeat, 1>{}([&](auto k0) {
// A: one copy per (k0, m0) sub-tile.
292 static_for<0, MRepeat, 1>{}([&](auto m0) {
293 a_thread_copy_.Run(
296 a_block_buf,
298 make_tuple(I0, m0, k0, I0, I0, I0),
299 a_thread_buf);
300 });
301
// B, unscaled path. The condition sits on missing line 302; presumably it tests
// is_same_v<BScaleStruct, Empty> (Empty is imported from Base) -- TODO confirm.
303 {
304 static_for<0, NRepeat, 1>{}([&](auto n0) {
305 b_thread_copy_.Run(
308 b_block_buf,
310 make_tuple(I0, n0, k0, I0, I0, I0),
311 b_thread_buf);
312 });
313 }
314 else
315 {
// B, scaled path: the copy also receives the scale element for this (n0, k0).
316 static_for<0, NRepeat, 1>{}([&](auto n0) {
317 b_thread_copy_.Run(
320 b_block_buf,
321 b_scale_struct.b_scale_thread_bufs(
322 I0)[Number<n0 * BScaleStruct::num_scale_k_block +
323 k0 / BScaleStruct::num_scale_krepeat>{}],
325 make_tuple(I0, n0, k0, I0, I0, I0),
326 b_thread_buf);
327 });
328 }
329 });
330 }
331
// Run: executes the whole K loop for this block tile and accumulates the result into
// c_thread_buf. It uses one shared block buffer per operand and a global prefetch
// depth of 2 (see the pipeline header comment).
//
// Stages, in order:
//   1. Global prefetch 1: RunRead + MoveSrcSliceWindow for A and B, plus a B-scale GlobalLoad.
//   2. Local prefill 1: RunWrite of that tile into a_block_buf / b_block_buf.
//   3. Global prefetch 2: only when TailNum is Even or Full, i.e. at least 2 iterations.
//   4. c_thread_buf.Clear(), then LocalLoad of the first tile into the thread buffers.
//   5. Hot loop, only if HasMainLoop: a do-while that runs num_loop - 2 times, so it
//      requires num_loop > 2. Each iteration:
//        - RunWrite the prefetched tile into the block buffers;
//        - issue the next global read and advance the source window;
//        - run the WMMA sweep on the current thread buffers;
//        - LocalLoad the next tile.
//   6. Pre-tail, for Even/Full: the last RunWrite, a WMMA sweep and a LocalLoad.
//      No further global reads are issued.
//   7. Tail, always: the final WMMA sweep.
// Result: exactly num_loop WMMA sweeps for Odd (num_loop == 1), Even (2) and Full (> 2),
// provided HasMainLoop/TailNum agree with BlockHasHotloop()/BlockLoopTailNum().
//
// Scale loads: the k-th call to b_scale_struct.GlobalLoad<0>(flag), counting from 0,
// passes flag == ((k + 1) % num_loop_per_scale == 0). Presumably the flag means
// "advance to the next scale block"; verify against BScaleStruct.
//
// NOTE(review): this is a Doxygen rendering with hyperlinked lines missing:
//   - 367/369: the a_thread_buf/b_thread_buf declarations, presumably make_static_buffer;
//   - 399, 411, 466, 470, 480, 521, 525: presumably block_sync_lds()/HotLoopScheduler() calls;
//   - 433, 443: the first CalculateOffset index, presumably Number<ik / K1>{} as in the pre-tail.
// Do not rebuild the function from this text.
332 template <bool HasMainLoop,
333 TailNumber TailNum,
334 typename AGridDesc,
335 typename ABlockDesc,
336 typename ABlockTransfer,
337 typename AGridBuffer,
338 typename ABlockBuffer,
339 typename ABlockTransferStep,
340 typename BGridDesc,
341 typename BBlockDesc,
342 typename BBlockTransfer,
343 typename BGridBuffer,
344 typename BBlockBuffer,
345 typename BBlockTransferStep,
346 typename CThreadBuffer,
347 typename BScaleStruct>
348 __device__ void Run(const AGridDesc& a_grid_desc,
349 const ABlockDesc& a_block_desc,
350 ABlockTransfer& a_blockwise_copy,
351 const AGridBuffer& a_grid_buf,
352 ABlockBuffer& a_block_buf,
353 const ABlockTransferStep& a_block_copy_step,
354 const BGridDesc& b_grid_desc,
355 const BBlockDesc& b_block_desc,
356 BBlockTransfer& b_blockwise_copy,
357 const BGridBuffer& b_grid_buf,
358 BBlockBuffer& b_block_buf,
359 const BBlockTransferStep& b_block_copy_step,
360 CThreadBuffer& c_thread_buf,
361 // BScaleThreadCopy
362 BScaleStruct& b_scale_struct,
363 index_t num_loop,
364 index_t num_loop_per_scale) const
365 {
// Full scheduling fence: the compiler may not move instructions across this point.
366 __builtin_amdgcn_sched_barrier(0);
368 a_thread_desc_.GetElementSpaceSize());
370 b_thread_desc_.GetElementSpaceSize());
371
372 // Global prefetch 1
373 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
374 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
375
376 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
377 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
378
// Scale load #0: flag is (0 + 1) % num_loop_per_scale == 0, i.e. num_loop_per_scale == 1.
379 b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
380
381 // Local prefill 1
382 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
383 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
384
385 // Global prefetch 2, perform when at least 2 loops exist.
386 if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
387 {
388 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
389 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
390
391 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
392 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
393 }
394
395 // Initialize C
396 c_thread_buf.Clear();
397
398 // Local prefetch 1
400
401 LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
402
403 __builtin_amdgcn_sched_barrier(0);
404
405 // Main body, perform when at least 3 loops exist.
// The do-while body always runs at least once; it needs num_loop >= 3.
406 if constexpr(HasMainLoop)
407 {
408 index_t i = 0;
409 do
410 {
// NOTE(review): original line 411 is missing. It is presumably an LDS barrier
// (block_sync_lds) so the previous LocalLoad finishes before this RunWrite -- confirm.
412
413 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
414 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
415
416 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
417 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
418
419 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
420 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
421
// Scale load #(i + 1): flag is (i + 2) % num_loop_per_scale == 0.
422 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
423
// WMMA sweep. For each (k0, m0, n0):
//   - gather KPack/A_KRow A values and KPack/B_KRow B values from the thread buffers;
//   - reinterpret them as the WmmaK/KRow-wide WMMA input vector types;
//   - accumulate via wmma_gemm.Run into the C sub-tile at offset (m0, n0, I0).
424 static_for<0, KRepeat, 1>{}([&](auto k0) {
425 static_for<0, MRepeat, 1>{}([&](auto m0) {
426 static_for<0, NRepeat, 1>{}([&](auto n0) {
427 vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
428 vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
429
430 static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
431 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
432 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
434 m0,
435 k0,
436 I0,
437 I0,
438 Number<ik % A_K1>{}))>{}];
439 });
440 static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
441 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
442 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
444 n0,
445 k0,
446 I0,
447 I0,
448 Number<ik % B_K1>{}))>{}];
449 });
450
451 using wmma_input_type_a =
452 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
453 using wmma_input_type_b =
454 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
455
456 constexpr index_t c_offset =
457 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
458
459 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
460 b_thread_vec.template AsType<wmma_input_type_b>(),
461 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
462 });
463 });
464 });
465
467
// Load the tile just written to the block buffers into the thread buffers.
// It is consumed by the next iteration's sweep (or by the pre-tail).
468 LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
469
471 __builtin_amdgcn_sched_barrier(0);
472
473 i += 1;
474 } while(i < (num_loop - 2));
475 }
476
477 // Pre-tail, perform when at least 2 loops exist.
478 if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
479 {
481
482 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
483 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
484
485 // No RunRead or MoveSrcSliceWindow here, already finished them all!
486
// Last scale load, #(num_loop - 1): flag is num_loop % num_loop_per_scale == 0.
487 b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
488
// WMMA sweep over the second-to-last tile (same structure as in the hot loop).
489 static_for<0, KRepeat, 1>{}([&](auto k0) {
490 static_for<0, MRepeat, 1>{}([&](auto m0) {
491 static_for<0, NRepeat, 1>{}([&](auto n0) {
492 vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
493 vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
494
495 static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
496 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
497 a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
498 Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
499 });
500 static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
501 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
502 b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
503 Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
504 });
505
506 using wmma_input_type_a =
507 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
508 using wmma_input_type_b =
509 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
510
511 constexpr index_t c_offset =
512 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
513
514 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
515 b_thread_vec.template AsType<wmma_input_type_b>(),
516 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
517 });
518 });
519 });
520
522
523 LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
524
526 __builtin_amdgcn_sched_barrier(0);
527 }
528
529 // Tail, always perform.
// Final WMMA sweep over the last tile already held in the thread buffers.
530 {
531 static_for<0, KRepeat, 1>{}([&](auto k0) {
532 static_for<0, MRepeat, 1>{}([&](auto m0) {
533 static_for<0, NRepeat, 1>{}([&](auto n0) {
534 vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
535 vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
536
537 static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
538 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
539 a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
540 Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
541 });
542 static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
543 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
544 b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
545 Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
546 });
547
548 using wmma_input_type_a =
549 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
550 using wmma_input_type_b =
551 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
552
553 constexpr index_t c_offset =
554 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
555
556 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
557 b_thread_vec.template AsType<wmma_input_type_b>(),
558 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
559 });
560 });
561 });
562 // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
563 // latency
564 // __builtin_amdgcn_sched_barrier(0);
565 }
566 }
567
568 protected:
569 using Base::a_thread_copy_;
570 using Base::a_thread_desc_;
571 using Base::b_thread_copy_;
572 using Base::b_thread_desc_;
573 using Base::c_thread_desc_;
574};
575
576} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
ck::BlockwiseGemmWmmaops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerWmma, NPerWmma, wmma_gemm.wmma_instr.k_per_wmma > HotLoopInstList
Definition blockwise_gemm_pipeline_wmmaops_base.hpp:70
BlockwiseGemmWmmaops_pipeline_base< BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, TransposeC > Base
Definition blockwise_gemm_pipeline_wmmaops_v3.hpp:99
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition blockwise_gemm_pipeline_wmmaops_v3.hpp:348
Definition blockwise_gemm_pipeline_wmmaops_v3.hpp:37
Definition functional2.hpp:33
Definition dtype_vector.hpp:10