Composable Kernel: blockwise_gemm_pipeline_xdlops_v5.hpp Source File

Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Compute optimized pipeline
11// GlobalPrefetchStages: 3
12// LocalPreFillStages: 1
13// LocalPreFetchStages: 1
14// LocalSharedMemoryBuffer: 2
15
16template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17 index_t BlockSize,
18 typename ADataType,
19 typename BDataType,
20 typename ComputeDataType,
21 typename AccDataType,
22 typename ATileDesc,
23 typename BTileDesc,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
28 index_t MPerBlock,
29 index_t NPerBlock,
30 index_t KPerBlock,
31 index_t MPerXDL,
32 index_t NPerXDL,
33 index_t MRepeat,
34 index_t NRepeat,
35 index_t KPacks>
39
40template <index_t BlockSize,
41 typename ADataType,
42 typename BDataType,
43 typename ComputeDataType,
44 typename AccDataType,
45 typename ATileDesc,
46 typename BTileDesc,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
51 index_t MPerBlock,
52 index_t NPerBlock,
53 index_t KPerBlock,
54 index_t MPerXDL,
55 index_t NPerXDL,
56 index_t MRepeat,
57 index_t NRepeat,
58 index_t KPack
59 // ,bool TransposeC //disable transposec right now...
60 >
62 BlockSize,
63 ADataType,
64 BDataType,
65 ComputeDataType,
66 AccDataType,
67 ATileDesc,
68 BTileDesc,
69 AMmaTileDesc,
70 BMmaTileDesc,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
73 MPerBlock,
74 NPerBlock,
75 KPerBlock,
76 MPerXDL,
77 NPerXDL,
78 MRepeat,
79 NRepeat,
80 KPack>
82 ADataType,
83 BDataType,
84 ComputeDataType,
85 AccDataType,
86 ATileDesc,
87 BTileDesc,
88 AMmaTileDesc,
89 BMmaTileDesc,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
92 MPerBlock,
93 NPerBlock,
94 KPerBlock,
95 MPerXDL,
96 NPerXDL,
97 MRepeat,
98 NRepeat,
99 KPack>
100
101{
103 ADataType,
104 BDataType,
105 ComputeDataType,
106 AccDataType,
107 ATileDesc,
108 BTileDesc,
109 AMmaTileDesc,
110 BMmaTileDesc,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
113 MPerBlock,
114 NPerBlock,
115 KPerBlock,
116 MPerXDL,
117 NPerXDL,
118 MRepeat,
119 NRepeat,
120 KPack>;
121 using Base::A_K1;
122 using Base::B_K1;
123 using Base::I0;
124 using Base::I1;
125 using Base::KRepeat;
126 using Base::xdlops_gemm;
127 using typename Base::HotLoopInstList;
128
140
143
144 using Base::AMmaKStride;
145 using Base::BMmaKStride;
146
148
// Pipeline depth constants, as used by Run():
//   PrefetchStages  - number of A/B global tile reads (RunRead) issued in the prologue
//                     before the hot loop; BlockHasHotloop compares num_loop against it
//                     and the hot loop stops at num_loop - PrefetchStages.
//   PrefillStages   - number of prologue tiles written from registers to LDS (RunWrite)
//                     before the first local prefetch.
//   GlobalBufferNum - number of register staging slots for global reads; the slot index
//                     is the I0 / I1 passed as vmem_buf to RunRead / RunWrite.
//   HotloopUnroll   - tiles consumed per hot-loop iteration (LoopFunc(I0) then LoopFunc(I1));
//                     also the modulus BlockLoopTailNum uses to pick the tail.
149 static constexpr index_t PrefetchStages = 3;
150 static constexpr index_t PrefillStages = 1;
151 static constexpr index_t GlobalBufferNum = 2;
152 static constexpr index_t HotloopUnroll = 2;
153
// Host-side query: returns true when Run()'s hot loop should execute, i.e. when
// num_loop (one A/B tile read and one tile computed per loop step) exceeds the
// PrefetchStages tiles already read by the prologue. When false, everything is
// handled by the tail sequences. Presumably the caller feeds this result into
// Run()'s HasMainLoop template parameter — confirm at the call site.
154 __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
155 {
156 return num_loop > PrefetchStages;
157 }
158
// Host-side query: selects which tail Run() executes after the hot loop, based on
// the parity of num_loop (the hot loop consumes HotloopUnroll = 2 tiles per step):
//   Odd  -> ReadWriteCompFunc(I0), ReadWriteCompFunc(I1), ReadCompFunc()  (3 tiles)
//   Even -> ReadWriteCompFunc(I0), ReadCompFunc()                         (2 tiles)
// NOTE(review): tracing Run(), for even num_loop the prologue (3 RunReads) plus the
// hot loop (2 RunReads per iteration) issue num_loop + 1 global reads; the last
// one lands in staging slot I1 and is never written to LDS or computed. This looks
// like it relies on the extra out-of-range read being harmless (e.g. bounds-checked
// buffer loads) — confirm.
// NOTE(review): num_loop == 1 yields Odd, whose tail computes 3 tiles; callers
// presumably guarantee num_loop >= 2 — confirm.
159 __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
160 {
161 if(num_loop % HotloopUnroll == 1)
162 {
163 return TailNumber::Odd;
164 }
165 else
166 {
167 return TailNumber::Even;
168 }
169 }
170
171 __device__ static constexpr auto HotLoopScheduler()
172 {
173 // TODO: Take data type into consideration as pipe ver 3
174 // A/B split schedule
175 // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
176 constexpr auto num_ds_read_inst_a =
177 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
180 constexpr auto num_ds_read_inst_b =
181 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
184
185 constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
186 constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
187
188 constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
189 constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
190
191 constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
192
193 constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
194 constexpr auto ds_read_a_issue_cycle =
195 HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
196 constexpr auto ds_read_b_issue_cycle =
197 HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
198 constexpr auto ds_read_a_mfma_rate =
199 (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
200 constexpr auto ds_read_b_mfma_rate =
201 (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
202
203 constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
204 constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
205 constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
206 constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
207
208 constexpr auto num_dsread_stage1_a_mfma =
209 (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
210 constexpr auto num_dsread_stage1_b_mfma =
211 (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
212 constexpr auto num_dsread_stage3_a_mfma =
213 (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
214 constexpr auto num_dsread_stage3_b_mfma =
215 (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
216
217 constexpr auto num_mfma_stage2 = num_mfma_inst - num_ds_read_inst_a / ds_read_a_mfma_rate -
218 num_ds_read_inst_b / ds_read_b_mfma_rate;
219 constexpr auto num_mfma_per_issue =
220 num_mfma_stage2 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
221 constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
222 constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
223
224 // stage 1
226 ignore = i;
227 if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
228 ds_read_a_mfma_rate)
229 {
230 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
231 }
232 else
233 {
234 __builtin_amdgcn_sched_group_barrier(
235 0x100,
236 num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
237 0); // DS read
238 }
239 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
240 });
242 ignore = i;
243 if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
244 ds_read_b_mfma_rate)
245 {
246 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
247 }
248 else
249 {
250 __builtin_amdgcn_sched_group_barrier(
251 0x100,
252 num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
253 0); // DS read
254 }
255 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
256 });
257
258 // stage 2
260 ignore = i;
261 static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
262 ignore = idswrite;
263 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
264 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
265 });
266 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
267 __builtin_amdgcn_sched_group_barrier(
268 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
269 });
271 ignore = i;
272 static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
273 ignore = idswrite;
274 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
275 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
276 });
277 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
278 __builtin_amdgcn_sched_group_barrier(
279 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
280 });
281
282 // stage 3
284 ignore = i;
285 if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
286 ds_read_a_mfma_rate)
287 {
288 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
289 }
290 else
291 {
292 __builtin_amdgcn_sched_group_barrier(
293 0x100,
294 num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
295 0); // DS read
296 }
297 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
298 });
300 ignore = i;
301 if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
302 ds_read_b_mfma_rate)
303 {
304 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
305 }
306 else
307 {
308 __builtin_amdgcn_sched_group_barrier(
309 0x100,
310 num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
311 0); // DS read
312 }
313 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
314 });
315
316 // IGLP COMPILER BUG:
317 // If comment out following scheduler barrier would cause sanity fail.
318 __builtin_amdgcn_sched_barrier(0);
319 }
320
321 template <bool HasMainLoop,
322 TailNumber TailNum,
323 typename AGridDesc,
324 typename ABlockDesc,
325 typename ABlockTransfer,
326 typename AGridBuffer,
327 typename ABlockBuffer,
328 typename ABlockTransferStep,
329 typename BGridDesc,
330 typename BBlockDesc,
331 typename BBlockTransfer,
332 typename BGridBuffer,
333 typename BBlockBuffer,
334 typename BBlockTransferStep,
335 typename CThreadBuffer>
336 __device__ void Run(const AGridDesc& a_grid_desc,
337 const ABlockDesc& a_block_desc,
338 ABlockTransfer& a_blockwise_copy,
339 const AGridBuffer& a_grid_buf,
340 ABlockBuffer& a_block_buf,
341 const ABlockTransferStep& a_block_copy_step,
342 const BGridDesc& b_grid_desc,
343 const BBlockDesc& b_block_desc,
344 BBlockTransfer& b_blockwise_copy,
345 const BGridBuffer& b_grid_buf,
346 BBlockBuffer& b_block_buf,
347 const BBlockTransferStep& b_block_copy_step,
348 CThreadBuffer& c_thread_buf,
349 index_t num_loop) const
350 {
352 a_thread_desc_.GetElementSpaceSize());
354 b_thread_desc_.GetElementSpaceSize());
355
356 // Global prefetch 1
357 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
358 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
359
360 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
361 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
362
363 // Local prefill 1
364 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
365 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
366
367 // Global prefetch 2
368 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
369 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
370
371 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
372 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
373
374 // Global prefetch 3
375 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
376 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
377
378 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
379 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
380
381 // Initialize C
382 c_thread_buf.Clear();
383
384 // Local prefetch 1
386 static_for<0, MRepeat, 1>{}([&](auto m0) {
388 make_tuple(m0, I0, I0, I0),
389 a_block_buf,
391 make_tuple(m0, I0, I0, I0),
392 a_thread_buf);
393 });
394 static_for<0, NRepeat, 1>{}([&](auto n0) {
396 make_tuple(n0, I0, I0, I0),
397 b_block_buf,
399 make_tuple(n0, I0, I0, I0),
400 b_thread_buf);
401 });
402
403 // main body
404 if constexpr(HasMainLoop)
405 {
406 index_t i = 0;
407 do
408 {
409 auto LoopFunc = [&](auto vmem_buf) {
412
413 static_for<0, KRepeat, 1>{}([&](auto k0) {
414 if constexpr(k0 == (KRepeat - 1))
415 {
417
418 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
419 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);
420
421 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf);
422 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf);
423
424 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
425 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
426
428 }
429 static_for<0, MRepeat, 1>{}([&](auto m0) {
430 static_for<0, NRepeat, 1>{}([&](auto n0) {
431 static_for<0, KPack, 1>{}([&](auto ik) {
432 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
433 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
434 make_tuple(m0, I0, I0, ik))>{}];
435 });
436 static_for<0, KPack, 1>{}([&](auto ik) {
437 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
438 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
439 make_tuple(n0, I0, I0, ik))>{}];
440 });
441
442 using mfma_input_type =
444 xdlops_gemm.K1PerXdlops>::type;
445
446 constexpr index_t c_offset =
447 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
448
449 xdlops_gemm.Run(
450 a_thread_vec.template AsType<mfma_input_type>(),
451 b_thread_vec.template AsType<mfma_input_type>(),
452 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
453 });
454
455 a_thread_copy_.Run(
457 make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
458 a_block_buf,
460 make_tuple(m0, I0, I0, I0),
461 a_thread_buf);
462 });
463
464 static_for<0, NRepeat, 1>{}([&](auto n0) {
465 b_thread_copy_.Run(
467 make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
468 b_block_buf,
470 make_tuple(n0, I0, I0, I0),
471 b_thread_buf);
472 });
473 });
474
476 };
477
478 LoopFunc(I0);
479 LoopFunc(I1);
480
481 i += HotloopUnroll;
482 } while(i < (num_loop - PrefetchStages));
483 }
484 // tail
485 auto ReadWriteCompFunc = [&](auto vmem_buf) {
488
489 static_for<0, KRepeat, 1>{}([&](auto k0) {
490 if constexpr(k0 == (KRepeat - 1))
491 {
493
494 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
495 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);
496
498 }
499 static_for<0, MRepeat, 1>{}([&](auto m0) {
500 static_for<0, NRepeat, 1>{}([&](auto n0) {
501 static_for<0, KPack, 1>{}([&](auto ik) {
502 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
503 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
504 make_tuple(m0, I0, I0, ik))>{}];
505 });
506 static_for<0, KPack, 1>{}([&](auto ik) {
507 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
508 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
509 make_tuple(n0, I0, I0, ik))>{}];
510 });
511
512 using mfma_input_type =
513 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
514
515 constexpr index_t c_offset =
516 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
517
518 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
519 b_thread_vec.template AsType<mfma_input_type>(),
520 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
521 });
522 a_thread_copy_.Run(
524 make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
525 a_block_buf,
527 make_tuple(m0, I0, I0, I0),
528 a_thread_buf);
529 });
530
531 static_for<0, NRepeat, 1>{}([&](auto n0) {
532 b_thread_copy_.Run(
534 make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
535 b_block_buf,
537 make_tuple(n0, I0, I0, I0),
538 b_thread_buf);
539 });
540 });
541
543 };
544 auto ReadCompFunc = [&]() {
547
548 static_for<0, KRepeat - 1, 1>{}([&](auto k0) {
549 static_for<0, MRepeat, 1>{}([&](auto m0) {
550 static_for<0, NRepeat, 1>{}([&](auto n0) {
551 static_for<0, KPack, 1>{}([&](auto ik) {
552 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
553 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
554 make_tuple(m0, I0, I0, ik))>{}];
555 });
556 static_for<0, KPack, 1>{}([&](auto ik) {
557 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
558 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
559 make_tuple(n0, I0, I0, ik))>{}];
560 });
561
562 using mfma_input_type =
563 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
564
565 constexpr index_t c_offset =
566 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
567
568 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
569 b_thread_vec.template AsType<mfma_input_type>(),
570 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
571 });
572
573 a_thread_copy_.Run(
575 make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
576 a_block_buf,
578 make_tuple(m0, I0, I0, I0),
579 a_thread_buf);
580 });
581
582 static_for<0, NRepeat, 1>{}([&](auto n0) {
583 b_thread_copy_.Run(
585 make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
586 b_block_buf,
588 make_tuple(n0, I0, I0, I0),
589 b_thread_buf);
590 });
591 });
592
593 static_for<0, MRepeat, 1>{}([&](auto m0) {
594 static_for<0, NRepeat, 1>{}([&](auto n0) {
595 static_for<0, KPack, 1>{}([&](auto ik) {
596 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
597 [Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I0, ik))>{}];
598 });
599 static_for<0, KPack, 1>{}([&](auto ik) {
600 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
601 [Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I0, ik))>{}];
602 });
603
604 using mfma_input_type =
605 typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
606
607 constexpr index_t c_offset =
608 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
609
610 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
611 b_thread_vec.template AsType<mfma_input_type>(),
612 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
613 });
614 });
615
617 };
618
619 if constexpr(TailNum == TailNumber::Odd)
620 {
621 ReadWriteCompFunc(I0);
622 ReadWriteCompFunc(I1);
623 ReadCompFunc();
624 }
625 else if constexpr(TailNum == TailNumber::Even)
626 {
627 ReadWriteCompFunc(I0);
628 ReadCompFunc();
629 }
630 }
631
632 protected:
633 // A[MRepeat, I1, I1, KPack]
634 static constexpr auto a_thread_desc_ =
636
637 // B[NRepeat, N1, N2, KPack]
638 static constexpr auto b_thread_desc_ =
640
643 decltype(a_block_desc_m0_m1_m2_k),
644 decltype(a_thread_desc_),
647 3,
648 A_K1,
649 A_K1>;
650
653 decltype(b_block_desc_n0_n1_n2_k),
654 decltype(b_thread_desc_),
657 3,
658 B_K1,
659 B_K1>;
660
663 using Base::c_thread_desc_;
664};
665
666} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static __device__ auto CalculateBThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:147
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
static constexpr auto I1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
static __device__ auto CalculateAThreadOriginDataIndex()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:125
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
static constexpr index_t B_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:51
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< BlockSize, MPerBlock, NPerBlock, KPerBlock, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, A_K1, B_K1, A_K1, B_K1, MRepeat, NRepeat, MPerXDL, NPerXDL, xdlops_gemm.KPerXdlops > HotLoopInstList
Definition blockwise_gemm_pipeline_xdlops_base.hpp:82
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t A_K1
Definition blockwise_gemm_pipeline_xdlops_base.hpp:50
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v5.hpp:102
ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataTypeBuf, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 > AThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v5.hpp:641
ThreadwiseTensorSliceTransfer_v4< BDataType, ComputeDataTypeBuf, decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, B_K1, B_K1 > BThreadCopy
Definition blockwise_gemm_pipeline_xdlops_v5.hpp:651
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition blockwise_gemm_pipeline_xdlops_v5.hpp:336
Definition blockwise_gemm_pipeline_xdlops_v5.hpp:37
Definition utility/sequence.hpp:43
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition functional2.hpp:33
Definition dtype_vector.hpp:10