gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp Source File

gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp Source File

Composable Kernel: gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp Source File
gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
18
19namespace ck {
20
21template <typename FloatAB,
22 typename FloatGemmAcc,
23 typename FloatCShuffle,
24 typename FloatC,
25 typename D0sDataType,
26 typename AElementwiseOperation,
27 typename BElementwiseOperation,
28 typename C0DEElementwiseOperation,
29 typename B1ElementwiseOperation,
30 typename C1DEElementwiseOperation,
31 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
32 typename AGridDesc_AK0_M_AK1,
33 typename BGridDesc_BK0_N_BK1,
34 typename B1GridDesc_BK0_N_BK1,
35 typename C1GridDesc_M_N,
36 typename D0sGridDesc_M_N,
37 index_t NumGemmKPrefetchStage,
38 index_t BlockSize,
39 index_t MPerBlock,
40 index_t NPerBlock,
41 index_t KPerBlock,
42 index_t Gemm1NPerBlock,
43 index_t Gemm1KPerBlock,
44 index_t AK1Value,
45 index_t BK1Value,
46 index_t B1K1Value,
47 index_t MPerXdl,
48 index_t NPerXdl,
49 index_t MXdlPerWave,
50 index_t NXdlPerWave,
51 index_t Gemm1NXdlPerWave,
52 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
53 typename ABlockTransferThreadClusterArrangeOrder,
54 typename ABlockTransferSrcAccessOrder,
55 index_t ABlockTransferSrcVectorDim,
56 index_t ABlockTransferSrcScalarPerVector,
57 index_t ABlockTransferDstScalarPerVector_AK1,
58 bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
59 index_t ABlockLdsExtraM,
60 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
61 typename BBlockTransferThreadClusterArrangeOrder,
62 typename BBlockTransferSrcAccessOrder,
63 index_t BBlockTransferSrcVectorDim,
64 index_t BBlockTransferSrcScalarPerVector,
65 index_t BBlockTransferDstScalarPerVector_BK1,
66 bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
67 index_t BBlockLdsExtraN,
68 typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
69 typename B1BlockTransferThreadClusterArrangeOrder,
70 typename B1BlockTransferSrcAccessOrder,
71 index_t B1BlockTransferSrcVectorDim,
72 index_t B1BlockTransferSrcScalarPerVector,
73 index_t B1BlockTransferDstScalarPerVector_BK1,
74 bool B1ThreadTransferSrcResetCoordinateAfterRun,
75 index_t B1BlockLdsExtraN,
76 index_t CShuffleMXdlPerWavePerShuffle,
77 index_t CShuffleNXdlPerWavePerShuffle,
78 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
79 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
80 LoopScheduler LoopSched,
81 bool PadN,
83 int D0sTransferSrcScalarPerVector = 4,
86{
87 static_assert(LoopSched == LoopScheduler::Default,
88 "Non-default loop scheduler is currently not supported");
89
90 static constexpr index_t NumD0Tensor = D0sDataType::Size();
91
92 static constexpr auto I0 = Number<0>{};
93 static constexpr auto I1 = Number<1>{};
94 static constexpr auto I2 = Number<2>{};
95 static constexpr auto I3 = Number<3>{};
96 static constexpr auto I4 = Number<4>{};
97 static constexpr auto I5 = Number<5>{};
98 static constexpr auto I6 = Number<6>{};
99 static constexpr auto I7 = Number<7>{};
100
101 // K1 should be Number<...>
102 // Gemm0
103 static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
104 static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
105 static constexpr auto AK1 = Number<AK1Value>{};
106 static constexpr auto BK1 = Number<BK1Value>{};
107
108 static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
109 static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
110
111 // Gemm1
112 static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
113 static constexpr auto B1K1 = Number<B1K1Value>{};
114
116
119
120 template <typename ABlockDesc_AK0_M_AK1>
121 __host__ __device__ static constexpr auto
122 MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
123 {
124 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
125
126 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
127 ABlockDesc_AK0_M_AK1{});
128 }
129
130 template <typename BBlockDesc_BK0_N_BK1>
131 __host__ __device__ static constexpr auto
132 MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
133 {
134 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
135
136 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
137 BBlockDesc_BK0_N_BK1{});
138 }
139
140 template <typename ABlockDesc_AK0_M_AK1>
141 __host__ __device__ static constexpr auto
142 MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
143 {
144 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, 1, 1>(ABlockDesc_AK0_M_AK1{});
145 }
146
147 template <typename BBlockDesc_BK0_N_BK1>
148 __host__ __device__ static constexpr auto
149 MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
150 {
151 constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
152 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
153 BBlockDesc_BK0_N_BK1{});
154 }
155
156 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
157 {
158 // A matrix in LDS memory, dst of blockwise copy
162 }
163
164 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
165 {
166 // B matrix in LDS memory, dst of blockwise copy
170 }
171
172 __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
173 {
174 // B1 matrix in LDS memory, dst of blockwise copy
178 }
179
180 __host__ __device__ static constexpr auto
182 {
183 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
184 constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
185
186 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
190 I1,
192
193 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
194 }
195
196 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
197 {
198 const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
200 sizeof(FloatAB);
201 const index_t gemm1_bytes_end =
203 sizeof(FloatAB);
204 const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
206 sizeof(FloatGemmAcc);
207 const index_t c_block_bytes_end =
208 SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
209
210 return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
211 }
212
213 template <
214 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
215 __device__ static bool constexpr IsValidCompilationParameter()
216 {
217 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
218 BlockSize,
219 MPerBlock,
220 NPerBlock,
221 MPerXdl,
222 NPerXdl,
223 MXdlPerWave,
224 NXdlPerWave,
225 FloatC,
226 CGlobalMemoryDataOperation>();
227 }
228
229 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
230 template <typename Block2CTileMap>
231 __host__ __device__ static constexpr bool
232 CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
233 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
234 const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
235 const C1GridDesc_M_N& c1_grid_desc_m_n,
236 const Block2CTileMap& block_2_ctile_map)
237 {
238 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
239 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
240 "Invalid tuning param!");
241
242 const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
243 const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
244 const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
245 const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
246
247 if(!(M == c1_grid_desc_m_n.GetLength(I0) && Gemm1N == c1_grid_desc_m_n.GetLength(I1)))
248 {
249 return false;
250 }
251
252 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
253 Gemm1N % Gemm1NPerBlock == 0))
254 {
255 return false;
256 }
257
258 // check gemm0 gridwise gemm pipeline
259 const auto num_gemm0_k_loop = K / KPerBlock;
260 if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
261 {
262 return false;
263 }
264
265 // check gemm1 gridwise gemm pipeline
266 if(!(NPerBlock % Gemm1KPerBlock == 0))
267 {
268 return false;
269 }
270
271 const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
272 if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
273 {
274 return false;
275 }
276
277 if(!block_2_ctile_map.CheckValidity(c1_grid_desc_m_n))
278 {
279 return false;
280 }
281
282 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
283 return true;
284 }
285
286 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
287 {
288 const index_t num_loop = K / KPerBlock;
289
290 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
291 }
292
293 __host__ __device__ static constexpr auto
294 MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const C1GridDesc_M_N& c1_grid_desc_m_n)
295 {
296 const auto M = c1_grid_desc_m_n.GetLength(I0);
297 const auto N = c1_grid_desc_m_n.GetLength(I1);
298
299 const auto MBlock = M / MPerBlock;
300 const auto NBlock = N / Gemm1NPerBlock;
301
302 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
303 c1_grid_desc_m_n,
308
309 return c_grid_desc_mblock_mperblock_nblock_nperblock;
310 }
311
312 // return block_id to C matrix tile idx (m0, n0) mapping
313 __host__ __device__ static constexpr auto
314 MakeDefaultBlock2CTileMap(const C1GridDesc_M_N& c1_grid_desc_m_n)
315 {
317 c1_grid_desc_m_n);
318 }
319
320 __device__ static auto GetGemm0WaveIdx()
321 {
322 const index_t thread_id = get_thread_local_1d_id();
323 constexpr auto WaveSize = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.wave_size;
324
325 constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
329
330 return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
331 }
332
333 __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
334 {
335 constexpr auto WaveSize = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.wave_size;
336 constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
337 make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))),
340
341 return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
342 }
343
344 static constexpr auto MakeD0sGridPointer()
345 {
346 return generate_tuple(
347 [&](auto i) {
348 using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
349
350 return static_cast<const D0DataType*>(nullptr);
351 },
353 }
354 // D0 desc for source in blockwise copy
355 template <typename D0GridDesc_M_N>
356 __host__ __device__ static constexpr auto
357 MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n)
358 {
359 const auto M = d0_grid_desc_m_n.GetLength(I0);
360 const auto N = d0_grid_desc_m_n.GetLength(I1);
361
362 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
363 constexpr bool is_single_rate_mfma =
365 lcm_AK1_BK1 <= 4) ||
366 (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
368 lcm_AK1_BK1 < 32))
369 ? true
370 : false;
371 constexpr auto is_scale_mfma = false;
372 constexpr auto mfma =
374 selected_mfma;
375 constexpr auto N3 = mfma.num_groups_per_blk;
376 constexpr auto N4 = mfma.num_input_blks;
377 constexpr auto N5 = mfma.group_size;
379 d0_grid_desc_m_n,
381 make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)),
383 make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))),
386 }
387
388 // D0s desc for source in blockwise copy
389 __host__ __device__ static constexpr auto
390 MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n)
391 {
392 return generate_tuple(
393 [&](auto i) {
395 },
397 }
398
402 D0sGridDesc_M_N{}))>;
403
406 C1GridDesc_M_N{}))>;
407
409 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))>;
410
412 {
413 // LDS allocation for A and B: be careful of alignment
414 static constexpr auto a_block_desc_ak0_m_ak1 =
416 static constexpr auto b_block_desc_bk0_n_bk1 =
418 static constexpr auto b1_block_desc_bk0_n_bk1 =
420
421 static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
422
424 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
426 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
428 b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
429
430 static constexpr auto a_block_space_offset = 0;
432 static constexpr auto b1_block_space_offset = 0;
433
434 // LDS allocation for reduction
437
438 static constexpr auto reduction_space_offset = 0;
439
440 // LDS allocation for C shuffle in LDS
443 static constexpr auto c_block_space_size =
445 };
446
447 template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
448 __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
449 const FloatAB* __restrict__ p_b_grid,
450 const FloatAB* __restrict__ p_b1_grid,
451 FloatC* __restrict__ p_c_grid,
452 D0sGridPointer p_d0s_grid,
453 void* __restrict__ p_shared,
454 const AElementwiseOperation& a_element_op,
455 const BElementwiseOperation& b_element_op,
456 const C0DEElementwiseOperation& c0de_element_op,
457 const B1ElementwiseOperation& b1_element_op,
458 const C1DEElementwiseOperation& c1de_element_op,
459 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
460 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
461 const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
463 c_grid_desc_mblock_mperblock_nblock_nperblock,
464 const D0sGridDesc_M_N& d0s_griddesc_m_n,
465 const Block2CTileMap& block_2_ctile_map,
466 const C0MatrixMask& c0_matrix_mask)
467 {
468 const auto d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
470 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
471 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
472 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
473 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
474 const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
475 p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
477 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
478 const auto d0s_grid_buf = generate_tuple(
479 [&](auto i) {
481 p_d0s_grid[i],
482 d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize());
483 },
485
486 // divide block work by [M, N]
487 const auto block_work_idx =
488 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
489
490 if(!block_2_ctile_map.ValidCTileIndex(
491 block_work_idx,
492 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
493 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
494 {
495 return;
496 }
497
498 // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
499 const index_t m_block_data_idx_on_grid =
500 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
501
502 const index_t gemm1_n_block_data_idx_on_grid =
503 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
504
505 // A matrix in LDS memory, dst of blockwise copy
506 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
507
508 // B matrix in LDS memory, dst of blockwise copy
509 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
510
511 //
512 // set up Gemm0
513 //
514
515 // A matrix blockwise copy
516 auto a_blockwise_copy =
518 AElementwiseOperation,
522 ABlockTransferThreadClusterLengths_AK0_M_AK1,
523 ABlockTransferThreadClusterArrangeOrder,
524 FloatAB,
525 FloatAB,
526 decltype(a_grid_desc_ak0_m_ak1),
527 decltype(a_block_desc_ak0_m_ak1),
528 ABlockTransferSrcAccessOrder,
530 ABlockTransferSrcVectorDim,
531 2,
532 ABlockTransferSrcScalarPerVector,
533 ABlockTransferDstScalarPerVector_AK1,
534 1,
535 1,
536 true, // SrcResetCoord
537 true, // DstResetCoord
538 NumGemmKPrefetchStage>(
539 a_grid_desc_ak0_m_ak1,
540 make_multi_index(0, m_block_data_idx_on_grid, 0),
541 a_element_op,
542 a_block_desc_ak0_m_ak1,
543 make_multi_index(0, 0, 0),
545
546 // B matrix blockwise copy
547 auto b_blockwise_copy =
549 BElementwiseOperation,
553 BBlockTransferThreadClusterLengths_BK0_N_BK1,
554 BBlockTransferThreadClusterArrangeOrder,
555 FloatAB,
556 FloatAB,
557 decltype(b_grid_desc_bk0_n_bk1),
558 decltype(b_block_desc_bk0_n_bk1),
559 BBlockTransferSrcAccessOrder,
561 BBlockTransferSrcVectorDim,
562 2,
563 BBlockTransferSrcScalarPerVector,
564 BBlockTransferDstScalarPerVector_BK1,
565 1,
566 1,
567 true, // SrcResetCoord
568 true, // DstResetCoord
569 NumGemmKPrefetchStage>(
570 b_grid_desc_bk0_n_bk1,
571 make_multi_index(0, 0, 0), // will loop over GemmN dimension
572 b_element_op,
573 b_block_desc_bk0_n_bk1,
574 make_multi_index(0, 0, 0),
576
577 // Fused Gemm+Gemm pipeline
578 // for n in N0:
579 // for k in K0:
580 // acc[m][n] += A[m][k] * B0[k][n]
581 // acc1[m][o] += acc[m][n] * B1[n][o]
582
583 // sanity check
584 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
585 constexpr bool is_single_rate_mfma =
587 lcm_AK1_BK1 <= 4) ||
588 (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
590 lcm_AK1_BK1 < 32))
591 ? true
592 : false;
593 constexpr auto is_scale_mfma = false;
594 constexpr index_t KPack = math::max(
595 lcm_AK1_BK1,
597 selected_mfma.k_per_blk);
598
599 auto blockwise_gemm = BlockwiseGemmXdlops_v2<
600 BlockSize,
601 FloatAB,
602 FloatGemmAcc,
603 decltype(a_block_desc_ak0_m_ak1),
604 decltype(b_block_desc_bk0_n_bk1),
605 decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
606 decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
607 MPerBlock,
608 NPerBlock,
609 KPerBlock,
610 MPerXdl,
611 NPerXdl,
612 MXdlPerWave,
613 NXdlPerWave,
614 KPack,
615 true>{}; // TransposeC
616
617 auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
618
619 // LDS allocation for A and B: be careful of alignment
621 static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
622 a_block_desc_ak0_m_ak1.GetElementSpaceSize());
623
625 static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
626 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
627
628 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
629 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
630 const auto a_block_reset_copy_step =
631 make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
632 const auto b_block_reset_copy_step =
633 make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
634
635 // gridwise GEMM pipeline
636 // Only supports LoopScheduler::Default
637 const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
638 NumGemmKPrefetchStage,
640
641 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
642 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
643 KPerBlock);
644
645 //
646 // set up Gemm1
647 //
648
649 // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
650 constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
651 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
652
653 constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
654 constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
655 constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
656 constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
657 constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
658 constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
659 constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
660 constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
661
662 constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
663
664 // d0 matrix threadwise copy
665 constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
667 I1, // NBlockID
668 m0, // MRepeat
669 n0, // NRepeat
670 m1, // MWaveId
671 n1, // NWaveId
672 m2, // MPerXdl
673 n2, // NGroupNum
674 n3, // NInputNum
675 n4)); // registerNum
676
677 auto d0s_thread_buf = generate_tuple(
678 [&](auto i) {
679 using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
680 return StaticBuffer<
682 D0DataType,
683 d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(),
684 true>{};
685 },
687
688 const auto wave_id = GetGemm0WaveIdx();
689 const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
690
691 auto d0s_threadwise_copy = generate_tuple(
692 [&](auto i) {
693 using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
695 D0DataType,
696 D0DataType,
697 decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
698 decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
699 Sequence<I1, // MBlockId
700 I1, // NBlockID
701 m0, // MRepeat
702 n0, // NRepeat
703 m1, // MWaveId
704 n1, // NWaveId
705 m2, // MPerXdl
706 n2, // NGroupNum
707 n3, // NInputNum
708 n4>,
710 9,
711 D0sTransferSrcScalarPerVector,
712 1,
713 false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
714 make_multi_index(block_work_idx[I0], // MBlockId
715 0, // NBlockId
716 0, // mrepeat
717 0, // nrepeat
718 wave_id[I0], // MWaveId
719 wave_id[I1], // NWaveId
720 wave_m_n_id[I1], // MPerXdl
721 0, // group
722 wave_m_n_id[I0], // NInputIndex
723 0)); // register number
724 },
726 // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
727 // n0_n1_n2_n3 -> k0
728 // m0_m1_m2 -> m
729 // n4 -> k1
730 // NOTE: had to use merge_v3 or will spit out compilation errors
731 constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
732 acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
738
739 // A1 matrix in AccVGPR
740 // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
741 constexpr auto AccN3 =
742 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
743
744 constexpr auto A1ThreadSlice_K0_M_K1 =
746
747 constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
748 constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1];
749 constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
750#if defined(__gfx11__)
751 constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
752 make_tuple(A1ThreadSliceK0, A1ThreadSliceM, Number<A1ThreadSliceK1 * 2>{}));
754 FloatGemmAcc,
755 FloatAB,
756 decltype(acc_thread_desc_k0_m_k1),
757 decltype(a1_thread_desc_k0_m_k1),
761 2,
762 n4,
763 0x76543210,
764 0xfedcba98,
765 false>{make_tuple(0, 0, 0)};
766#else
767 constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
768 A1ThreadSlice_K0_M_K1,
769 make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
770
771 // A1 matrix blockwise copy
772 auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
773 FloatGemmAcc,
774 FloatAB,
775 decltype(acc_thread_desc_k0_m_k1),
776 decltype(a1_thread_desc_k0_m_k1),
780 2,
782#endif
783 // B1 matrix in LDS memory, dst of blockwise copy
784 constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
785 // B1 matrix blockwise copy
786 auto b1_blockwise_copy =
788 BElementwiseOperation,
792 B1BlockTransferThreadClusterLengths_BK0_N_BK1,
793 B1BlockTransferThreadClusterArrangeOrder,
794 FloatAB,
795 FloatAB,
796 decltype(b1_grid_desc_bk0_n_bk1),
797 decltype(b1_block_desc_bk0_n_bk1),
798 B1BlockTransferSrcAccessOrder,
800 B1BlockTransferSrcVectorDim,
801 2,
802 B1BlockTransferSrcScalarPerVector,
803 B1BlockTransferDstScalarPerVector_BK1,
804 1,
805 1,
806 B1ThreadTransferSrcResetCoordinateAfterRun,
807 true, // DstResetCoord
808 NumGemmKPrefetchStage>(
809 b1_grid_desc_bk0_n_bk1,
810 make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
811 b1_element_op,
812 b1_block_desc_bk0_n_bk1,
813 make_multi_index(0, 0, 0),
815
817 a1_thread_desc_k0_m_k1.GetElementSpaceSize());
818
819 // reuse LDS space for gemm0's b_block_buf
821 static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
822 b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
823
824 // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
825 // selected_mfma.k_per_blk <= Gemm1KPack
826 //
827 // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
828 // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
829 // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
830 // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
831 // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
832 // therefore we may just as well assign Gemm1KPack = group_size
833#if defined(__gfx11__)
834 constexpr index_t Gemm1KPack =
836#else
837 constexpr index_t Gemm1KPack =
839#endif
840
841 auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
842 BlockSize,
843 FloatAB,
844 FloatGemmAcc,
845 decltype(a1_thread_desc_k0_m_k1),
846 decltype(b1_block_desc_bk0_n_bk1),
847 decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
848 decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
849 MPerBlock,
850 Gemm1NPerBlock,
851 Gemm1KPerBlock,
852 MPerXdl,
853 NPerXdl,
854 MXdlPerWave,
855 Gemm1NXdlPerWave,
856 Gemm1KPack,
857 true, // TransposeC
858 Gemm1KPack, // AMmaKStride
859 Gemm1KPack *
861 // BMmaKStride
862 make_tuple(0, 0, 0, 0)}; // A_origin
863
864 auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
865
866 //
867 // Blockwise softmax
868 //
870 static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
872
873 // get acc0 8D thread cluster
874 constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
875 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
876 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
877 constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
878 constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
879 constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
880 constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3);
881 constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4);
882 constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5);
883 constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6);
884 constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7);
885
886 // get acc0 thread map
887 constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
892 constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor(
894 make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))),
897 const auto threadid_to_m_n_thread_cluster_adaptor =
898 chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);
899
900 // get acc0 2D thread cluster & 2D thread slice
901 constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
902 make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
903 constexpr auto thread_slice_desc_m_n =
904 make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4));
905
906 auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
907 FloatGemmAcc,
908 decltype(threadid_to_m_n_thread_cluster_adaptor),
909 decltype(thread_cluster_desc_m_n),
910 decltype(thread_slice_desc_m_n)>{};
911
912 const index_t num_gemm1_k_block_outer_loop =
913 b_grid_desc_bk0_n_bk1.GetLength(I1) / (NPerBlock / Gemm0NWaves);
914 constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm0NWaves / Gemm1KPerBlock;
915
916 // Initialize C
917 StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, acc1_thread_buf.Size(), true>
918 c_thread_buf;
919 c_thread_buf.Clear();
920
921 // Initialize running sum and max of exponentiating row vectors
922 using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
923 SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
924 running_sum = 0;
925 running_sum_new = 0;
927 running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
928
929 // gemm1 K loop
930 index_t gemm1_k_block_outer_index = 0;
931 do
932 {
933 auto n_block_data_idx_on_grid =
934 __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
935 if(c0_matrix_mask.IsTileSkippable(
936 m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
937 {
938 continue;
939 }
940 // gemm0
941 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
942 a_block_desc_ak0_m_ak1,
943 a_blockwise_copy,
944 a_grid_buf,
945 a_block_buf,
946 a_block_slice_copy_step,
947 b_grid_desc_bk0_n_bk1,
948 b_block_desc_bk0_n_bk1,
949 b_blockwise_copy,
950 b_grid_buf,
951 b_block_buf,
952 b_block_slice_copy_step,
953 blockwise_gemm,
954 acc_thread_buf,
955 num_k_block_main_loop);
956 // multiple d
957 if constexpr(NumD0Tensor)
958 {
959 static_assert(NXdlPerWave == n0);
960 static_assert(MXdlPerWave == m0);
961
962 static_for<0, NumD0Tensor, 1>{}([&](auto i) {
963 d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
964 d0s_grid_buf[i],
965 d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
966 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
967 d0s_thread_buf(i));
968 });
970 // get reference to src data
971 const auto src_data_refs = generate_tie(
972 // return type should be lvalue
973 [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
975
976 // get reference to dst data
977 auto dst_data_refs = generate_tie(
978 // return type should be lvalue
979 [&](auto) -> auto& { return acc_thread_buf(i); },
980 Number<2>{});
981
982 unpack2(c0de_element_op, dst_data_refs, src_data_refs);
983 });
984 static_for<0, NumD0Tensor, 1>{}([&](auto i) {
985 d0s_threadwise_copy(i).MoveSrcSliceWindow(
986 d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
987 make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
988 });
989 }
990 else
991 {
992 static_for<0, acc_thread_buf.Size(), 1>{}(
993 [&](auto i) { c0de_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
994 }
995
996 // do MNK padding or upper triangular masking
997 if constexpr(MaskOutUpperTriangle || PadN)
998 {
999 // 8d thread_desc in thread scope
1000 constexpr auto c_thread_lengths =
1001 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
1002
1003 // 8d block_desc in block scope
1004 constexpr auto c_block_lengths =
1005 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
1006
1007 constexpr auto M0 = c_block_lengths[I0];
1008 constexpr auto N0 = c_block_lengths[I1];
1009 constexpr auto M1 = c_block_lengths[I2];
1010 constexpr auto N1 = c_block_lengths[I3];
1011 constexpr auto M2 = c_block_lengths[I4];
1012 constexpr auto N2 = c_block_lengths[I5];
1013 constexpr auto N3 = c_block_lengths[I6];
1014 constexpr auto N4 = c_block_lengths[I7];
1015
1016 // works like multi-dimension static_for (static_ford), but provides both the linear
1017 // index as well as n-d index
1018 using Acc0TileIterator = SpaceFillingCurve<
1019 decltype(c_thread_lengths),
1020 typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
1021 typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
1022 false>; // SnakeCurved
1023
1024 auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
1026
1027 constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
1029 make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
1032
1033 static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
1034 auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
1035 auto m_local =
1036 block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
1037 auto n_local =
1038 block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
1039 auto m_global = m_local + m_block_data_idx_on_grid;
1040 auto n_global = n_local + n_block_data_idx_on_grid;
1041 if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
1042 {
1043 acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
1044 }
1045 });
1046 }
1047
1048 block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
1049
1050 // softmax
1051 SoftmaxBuf& max = blockwise_softmax.max_value_buf;
1052 SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
1053
1054 blockwise_softmax.Run(acc_thread_buf, workspace_buf);
1055
1056 // TODO: may convert to log domain
1057 running_max_new = mathext::max(max, running_max);
1058 running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
1059 mathext::exp(max - running_max_new) * sum;
1060
1061 // gemm1
1062 {
1063 // TODO: explore using dynamic buffer for a1 thread buffer
1064 // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
1065 // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
1066 // the A1 source buffer is static buffer holding the output of first GEMM and
1067 // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
1068 // explicitly in Run() below.
1069
1070 // Initialize acc1
1071 acc1_thread_buf.Clear();
1072
1073 // preload data into LDS
1074 b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
1075
1076 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
1077 b1_block_slice_copy_step);
1078
1079 block_sync_lds(); // wait for reduction LDS read
1080
1081 b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
1082
1083 // main body
1084 if constexpr(num_gemm1_k_block_inner_loop > 1)
1085 {
1086
1087 static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
1088 a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
1090 acc_thread_buf,
1091 a1_thread_desc_k0_m_k1,
1092 make_tuple(I0, I0, I0),
1093 a1_thread_buf);
1094
1095 b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
1096
1098
1099 gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
1100
1102
1103 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
1104 b1_block_slice_copy_step);
1105
1106 b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
1107 });
1108 }
1109 // tail
1110 {
1111 a1_blockwise_copy.Run(
1112 acc_thread_desc_k0_m_k1,
1113 make_tuple(
1114 Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
1115 acc_thread_buf,
1116 a1_thread_desc_k0_m_k1,
1117 make_tuple(I0, I0, I0),
1118 a1_thread_buf);
1119
1121
1122 gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
1123 }
1124 } // end gemm1
1125
1126 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1127 gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1128 constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
1129 constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
1130 constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
1131 constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
1132 constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
1133 constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
1134 constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
1135 constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
1136 constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
1137 make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
1138 constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
1139 constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
1140
1143 auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
1144 FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
1145 FloatGemmAcc c = c_thread_buf[I]; // O
1146 FloatGemmAcc c_new =
1147 (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
1148 math::exp(max[iM] - running_max_new[iM]) * acc1) /
1149 running_sum_new[iM]; // Formula by Dao et al.,
1150 // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
1151
1152 c_thread_buf(I) = c_new; // O_new
1153 });
1154 });
1155
1156 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
1157 a_block_reset_copy_step); // rewind K
1158 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
1159 b_block_reset_copy_step); // rewind K and step N
1160
1161 // update before next j iteration
1162 running_max = running_max_new;
1163 running_sum = running_sum_new;
1164
1165 block_sync_lds(); // wait for gemm1 LDS read
1166 } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
1167
1168 // shuffle C and write out
1169 {
1170 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1171 Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1172 "wrong!");
1173
1174 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1175 constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
1176
1177 // TODO: hacky, fix it!
1178 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1179 gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1180
1181 // TODO: hacky, fix it!
1182 // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
1183 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1184 gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1185
1186 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1187 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1188 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1189 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1190 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1191 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1192 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1193 constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1194
1195 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1197
1198 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1199 static_cast<FloatCShuffle*>(p_shared),
1200 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1201
1202 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1203 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1204 make_tuple(
1207 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1208 M1, // M1 = MWave
1209 M2)), // M2 = MPerXdl
1212 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1213 N1, // N1 = NWave
1214 N2, // N2 * N3 * N4 = NPerXdl
1215 N3,
1216 N4))),
1218 make_tuple(
1220
1221 // calculate origin of thread output tensor on global memory
1222 // blockwise GEMM c matrix starting index
1223 const auto c_thread_mtx_on_block =
1224 gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1225
1226 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1227 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1228
1229 const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1234
1235 const auto m_thread_data_on_block_idx =
1236 m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1237 make_multi_index(m_thread_data_on_block));
1238
1239 const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1241 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1244
1245 const auto n_thread_data_on_block_idx =
1246 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1247 make_multi_index(n_thread_data_on_block));
1248
1249 // shuffle: threadwise copy C from VGPR to LDS
1250 auto c_thread_copy_vgpr_to_lds =
1252 FloatCShuffle,
1253 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1254 decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1256 Sequence<CShuffleMXdlPerWavePerShuffle,
1257 CShuffleNXdlPerWavePerShuffle,
1258 I1,
1259 I1,
1260 I1,
1261 N2,
1262 I1,
1263 N4>,
1265 7,
1266 1,
1268 1,
1269 true>{
1270 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1272 0,
1273 m_thread_data_on_block_idx[I1],
1274 n_thread_data_on_block_idx[I1],
1275 m_thread_data_on_block_idx[I2],
1276 n_thread_data_on_block_idx[I2],
1277 n_thread_data_on_block_idx[I3],
1278 n_thread_data_on_block_idx[I4]),
1280
1281 // shuffle: blockwise copy C from LDS to global
1282 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1283 ThisThreadBlock, // ThreadGroup
1284 C1DEElementwiseOperation, // ElementwiseOperation,
1285 CGlobalMemoryDataOperation, // DstInMemOp,
1286 Sequence<1,
1287 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1288 1,
1289 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1290 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1291 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1292 FloatCShuffle, // typename SrcData,
1293 FloatC, // typename DstData,
1294 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1295 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1296 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1297 3, // index_t VectorDim,
1298 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1299 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1300 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1301 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1302 make_multi_index(0, 0, 0, 0),
1303 c_grid_desc_mblock_mperblock_nblock_nperblock,
1304 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1305 c1de_element_op};
1306
1307 // space filling curve for threadwise C in VGPR
1308 constexpr auto sfc_c_vgpr =
1311 Sequence<CShuffleMXdlPerWavePerShuffle,
1312 CShuffleNXdlPerWavePerShuffle,
1313 1,
1314 1,
1315 1,
1316 N2,
1317 1,
1318 N4>>{};
1319
1320 // space filling curve for shuffled blockwise C in global mem
1321 constexpr auto sfc_c_global =
1324 Sequence<1,
1325 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1326 1,
1327 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1328
1329 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1330
1331 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1332
1333 static_for<0, num_access, 1>{}([&](auto access_id) {
1334 // make sure it's safe to write to LDS
1336
1337 // each thread write its data from VGPR to LDS
1338 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1339 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1340 c_thread_buf,
1341 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1342 c_shuffle_block_buf);
1343
1344 // make sure it's safe to read from LDS
1346
1347 // each block copy its data from LDS to global
1348 c_shuffle_block_copy_lds_to_global.Run(
1349 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1350 c_shuffle_block_buf,
1351 c_grid_desc_mblock_mperblock_nblock_nperblock,
1352 c_grid_buf);
1353
1354 if constexpr(access_id < num_access - 1)
1355 {
1356 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1357
1358 // move on C
1359 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1360 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1361 }
1362 });
1363 }
1364 }
1365};
1366
1367} // namespace ck
__host__ T exp(T x)
Definition math_v2.hpp:391
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
__host__ __device__ constexpr auto exp(const Tuple< Xs... > &x)
Definition statically_indexed_array_multi_index.hpp:124
__host__ __device__ constexpr auto max(const Tuple< Xs... > &x, const Y &y)
Definition statically_indexed_array_multi_index.hpp:134
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition tensor_description/tensor_adaptor.hpp:245
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto unpack2(F &&f, X &&x, Y &&y)
Definition functional4.hpp:55
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
Definition block_to_ctile_map.hpp:261
Blockwise gemm.
Definition blockwise_gemm_xdlops.hpp:690
Blockwise softmax.
Definition blockwise_softmax.hpp:32
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:412
static constexpr auto b_block_desc_bk0_n_bk1
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:416
static constexpr auto b1_block_desc_bk0_n_bk1
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:418
static constexpr auto max_lds_align
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:421
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:441
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::c_block_space_size
static constexpr auto c_block_space_size
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:443
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::a_block_space_size_aligned
static constexpr auto a_block_space_size_aligned
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:423
static constexpr auto b_block_space_offset
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:431
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::b1_block_space_size_aligned
static constexpr auto b1_block_space_size_aligned
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:427
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::b_block_space_size_aligned
static constexpr auto b_block_space_size_aligned
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:425
static constexpr auto a_block_desc_ak0_m_ak1
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:414
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::reduction_space_offset
static constexpr auto reduction_space_offset
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:438
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::b1_block_space_offset
static constexpr auto b1_block_space_offset
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:432
static constexpr auto a_block_space_offset
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:430
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::reduction_space_size_aligned
static constexpr index_t reduction_space_size_aligned
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:435
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:86
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::GetSharedMemoryNumberOfByte
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:196
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
remove_cvref_t< decltype(MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))> D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:400
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
__host__ static __device__ constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:172
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
__host__ static __device__ constexpr auto MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:142
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::I0
static constexpr auto I0
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:92
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::AK1
static constexpr auto AK1
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:105
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::MakeGemm1BMmaTileDescriptor_N0_N1_N2_K
__host__ static __device__ constexpr auto MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:149
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::BK0
static constexpr auto BK0
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:104
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
__host__ static __device__ constexpr auto MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N &ds_grid_desc_m_n)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:390
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:181
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::GetGemm0WaveMNIdx
static __device__ auto GetGemm0WaveMNIdx(const index_t thread_id)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:333
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::IsValidCompilationParameter
static __device__ bool constexpr IsValidCompilationParameter()
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:215
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::GridwiseGemmPipe
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVersion::v1, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:117
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::Gemm0MWaves
static constexpr auto Gemm0MWaves
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:108
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
__host__ static __device__ constexpr auto MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N &d0_grid_desc_m_n)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:357
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::D0sGridPointer
decltype(MakeD0sGridPointer()) D0sGridPointer
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:399
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::I1
static constexpr auto I1
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:93
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
__host__ static __device__ constexpr auto MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:122
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
__host__ static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:156
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(C1GridDesc_M_N{}))> C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:404
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::GetGemm0WaveIdx
static __device__ auto GetGemm0WaveIdx()
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:320
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::I2
static constexpr auto I2
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:94
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::AK0
static constexpr auto AK0
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:103
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::CalculateHasMainKBlockLoop
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:286
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::I7
static constexpr auto I7
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:99
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
__host__ static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:164
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::B1K1
static constexpr auto B1K1
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:113
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::CheckValidity
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const C1GridDesc_M_N &c1_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:232
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::DefaultBlock2CTileMap
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(C1GridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:408
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::MakeD0sGridPointer
static constexpr auto MakeD0sGridPointer()
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:344
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::B1K0
static constexpr auto B1K0
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:112
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::BK1
static constexpr auto BK1
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:106
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::I6
static constexpr auto I6
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:98
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::NumD0Tensor
static constexpr index_t NumD0Tensor
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:90
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
__host__ static __device__ constexpr auto MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:132
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::Gemm0NWaves
static constexpr auto Gemm0NWaves
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:109
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::I5
static constexpr auto I5
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:97
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::I3
static constexpr auto I3
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:95
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::ThisThreadBlock
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:115
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::MakeDefaultBlock2CTileMap
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const C1GridDesc_M_N &c1_grid_desc_m_n)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:314
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >< math::max(MXdlPerWave64, 1)>::I4
static constexpr auto I4
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:96
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::Run
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, const ADataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, D0sGridPointer p_d0s_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const C0DEElementwiseOperation &c0de_element_op, const B1ElementwiseOperation &b1_element_op, const C1DEElementwiseOperation &c1de_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const D0sGridDesc_M_N &d0s_griddesc_m_n, const Block2CTileMap &block_2_ctile_map, const C0MatrixMask &c0_matrix_mask)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:448
ck::GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, D0sDataType, AElementwiseOperation, BElementwiseOperation, C0DEElementwiseOperation, B1ElementwiseOperation, C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, C1GridDesc_M_N, D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, Transform::matrix_padder.PadN, MaskingSpec==MaskingSpecialization::MaskOutUpperTriangle, D0sTransferSrcScalarPerVector >::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const C1GridDesc_M_N &c1_grid_desc_m_n)
Definition gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp:294
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
static constexpr auto selected_mfma
Definition xdlops_gemm.hpp:1757
__host__ static __device__ constexpr T Lowest()
Definition numeric_limits.hpp:312
__host__ static __device__ constexpr T Infinity()
Definition numeric_limits.hpp:317
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:16
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:1877
Threadwise data transfer.
Definition threadwise_tensor_slice_transfer.hpp:1720
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition xdlops_gemm.hpp:1821
static constexpr auto K0PerXdlops
Definition xdlops_gemm.hpp:2201
Definition utility/sequence.hpp:256
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition utility/sequence.hpp:289