device_splitk_contraction_multiple_d_xdl_cshuffle.hpp Source File

device_splitk_contraction_multiple_d_xdl_cshuffle.hpp Source File#

Composable Kernel: device_splitk_contraction_multiple_d_xdl_cshuffle.hpp Source File
device_splitk_contraction_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22
23template <typename GridwiseGemm,
24 typename FloatAB,
25 typename FloatDsPointer,
26 typename FloatE,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CDEElementwiseOperation,
30 typename AGridDesc_AKB_AK0_M_AK1,
31 typename BGridDesc_BKB_BK0_N_BK1,
32 typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
33 typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
34 typename ComputePtrOffsetOfBatch,
35 typename Block2ETileMap,
36 bool HasMainKBlockLoop>
37__global__ void
38#if CK_USE_LAUNCH_BOUNDS
40#endif
42 const FloatAB* __restrict__ p_a_grid,
43 const FloatAB* __restrict__ p_b_grid,
44 FloatDsPointer p_ds_grid,
45 FloatE* __restrict__ p_e_grid,
46 const index_t batch_count,
47 const AElementwiseOperation a_element_op,
48 const BElementwiseOperation b_element_op,
49 const CDEElementwiseOperation cde_element_op,
50 const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1,
51 const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1,
52 const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
53 ds_grid_desc_mblock_mperblock_nblock_nperblock,
54 const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
55 e_grid_desc_mblock_mperblock_nblock_nperblock,
56 const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
57 const Block2ETileMap block_2_etile_map)
58{
59#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
60 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
61 {
62 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
63
64 const index_t num_blocks_per_batch =
65 __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
66 const index_t g_idx =
67 __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
68
69 const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
70 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
71 const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
72 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
73 const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
74 static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
75
76 const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
77
78 FloatDsPointer p_ds_grid_grp;
79
80 static constexpr index_t NumDTensor =
81 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
82
84 [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
85
86 GridwiseGemm::template Run<HasMainKBlockLoop>(
87 p_a_grid + a_batch_offset,
88 p_b_grid + b_batch_offset,
89 p_ds_grid_grp,
90 p_e_grid + e_batch_offset,
91 p_shared,
92 a_element_op,
93 b_element_op,
94 cde_element_op,
95 a_grid_desc_akb_ak0_m_ak1,
96 b_grid_desc_bkb_bk0_n_bk1,
97 ds_grid_desc_mblock_mperblock_nblock_nperblock,
98 e_grid_desc_mblock_mperblock_nblock_nperblock,
99 block_2_etile_map);
100 }
101#else
102 ignore = p_a_grid;
103 ignore = p_b_grid;
104 ignore = p_ds_grid;
105 ignore = p_e_grid;
106 ignore = batch_count;
107 ignore = a_element_op;
108 ignore = b_element_op;
109 ignore = cde_element_op;
110 ignore = a_grid_desc_akb_ak0_m_ak1;
111 ignore = b_grid_desc_bkb_bk0_n_bk1;
112 ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
113 ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
114 ignore = block_2_etile_map;
115 ignore = compute_ptr_offset_of_batch;
116#endif
117}
118
119} // namespace ck
120
121namespace ck {
122namespace tensor_operation {
123namespace device {
124
125// Tensor Contraction:
126// input : A
127// input : B
128// input : D0, D1, ...
129// output : E
130// C = a_op(A) * b_op(B)
131// E = cde_op(C, D0, D1, ...)
132// Assume:
133// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
134// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
135// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
136// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
137template <index_t NumDimG,
138 index_t NumDimM,
139 index_t NumDimN,
140 index_t NumDimK,
141 typename ADataType,
142 typename BDataType,
143 typename AccDataType,
144 typename CShuffleDataType,
145 typename DsDataType,
146 typename EDataType,
147 typename AElementwiseOperation,
148 typename BElementwiseOperation,
149 typename CDEElementwiseOperation,
150 GemmSpecialization GemmSpec,
154 index_t NumGemmKPrefetchStage,
155 index_t BlockSize,
156 index_t MPerBlock,
157 index_t NPerBlock,
158 index_t KPerBlock,
159 index_t AK1,
160 index_t BK1,
161 index_t MPerXDL,
162 index_t NPerXDL,
163 index_t MXdlPerWave,
164 index_t NXdlPerWave,
165 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
166 typename ABlockTransferThreadClusterArrangeOrder,
167 typename ABlockTransferSrcAccessOrder,
168 index_t ABlockTransferSrcVectorDim,
169 index_t ABlockTransferSrcScalarPerVector,
170 index_t ABlockTransferDstScalarPerVector_AK1,
171 bool ABlockLdsExtraM,
172 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
173 typename BBlockTransferThreadClusterArrangeOrder,
174 typename BBlockTransferSrcAccessOrder,
175 index_t BBlockTransferSrcVectorDim,
176 index_t BBlockTransferSrcScalarPerVector,
177 index_t BBlockTransferDstScalarPerVector_BK1,
178 bool BBlockLdsExtraN,
179 index_t CShuffleMXdlPerWavePerShuffle,
180 index_t CShuffleNXdlPerWavePerShuffle,
181 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
182 index_t CDEBlockTransferScalarPerVector_NPerBlock,
185 : public DeviceSplitKContractionMultipleD<NumDimG,
186 NumDimM,
187 NumDimN,
188 NumDimK,
189 ADataType,
190 BDataType,
191 DsDataType,
192 EDataType,
193 AElementwiseOperation,
194 BElementwiseOperation,
195 CDEElementwiseOperation>
196{
199 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
200 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
201 static constexpr index_t NumDTensor = DsDataType::Size();
202
203 static constexpr auto I0 = Number<0>{};
204 static constexpr auto I1 = Number<1>{};
205 static constexpr auto I2 = Number<2>{};
206 static constexpr auto I3 = Number<3>{};
207
208 static constexpr auto matrix_padder =
209 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
210
211 // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
212 static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
213 const std::vector<index_t>& a_gs_ms_ks_strides_vec)
214 {
215 assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
216 a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK);
217
218 const auto to_tuple = [&](auto& vec, auto start, auto end) {
219 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
220 };
221
222 const auto a_ms_ks_lengths = to_tuple(
223 a_gs_ms_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});
224 const auto a_ms_ks_strides = to_tuple(
225 a_gs_ms_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});
226
227 // dimension Ids for M0, M1, ...
228 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
229
230 // dimension Ids for K0, K1, ...
231 constexpr auto kDimIds =
233
234 // lengths for M0, M1, ...
235 const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
236
237 // lengths for K0, K1, ...
238 const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
239
240 if constexpr(ASpec == TensorSpecialization::Packed)
241 {
242 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
243 auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
244 const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
245 make_tuple(M, K),
246 make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
247 a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
248 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
249 }
250 else
251 {
252 // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
253 const auto a_grid_desc_ms_ks =
254 make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
255
256 // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
257 const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
258 a_grid_desc_ms_ks,
260 make_tuple(mDimIds, kDimIds),
262
263 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
264 }
265 }
266
267 // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
268 static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
269 const std::vector<index_t>& b_gs_ns_ks_strides_vec)
270 {
271 assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
272 b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);
273
274 const auto to_tuple = [&](auto& vec, auto start, auto end) {
275 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
276 };
277
278 const auto b_ns_ks_lengths = to_tuple(
279 b_gs_ns_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
280 const auto b_ns_ks_strides = to_tuple(
281 b_gs_ns_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
282
283 // dimension Ids for N0, N1, ...
284 constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
285
286 // dimension Ids for K0, K1, ...
287 constexpr auto kDimIds =
289
290 // lengths for K0, K1, ...
291 const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);
292
293 // lengths for N0, N1, ...
294 const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);
295
296 if constexpr(BSpec == TensorSpecialization::Packed)
297 {
298 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
299 auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
300 const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
301 make_tuple(N, K),
302 make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
303 b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
304 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
305 }
306 else
307 {
308 // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
309 const auto b_grid_desc_ns_ks =
310 make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);
311
312 // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
313 const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
314 b_grid_desc_ns_ks,
316 make_tuple(nDimIds, kDimIds),
318
319 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
320 }
321 }
322
323 // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
324 static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
325 const std::vector<index_t>& e_gs_ms_ns_strides_vec)
326 {
327 assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
328 e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
329
330 const auto to_tuple = [&](auto& vec, auto start, auto end) {
331 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
332 };
333
334 const auto e_ms_ns_lengths = to_tuple(
335 e_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
336 const auto e_ms_ns_strides = to_tuple(
337 e_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
338
339 // dimension Ids for M0, M1, ...
340 constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
341
342 // dimension Ids for N0, N1, ...
343 constexpr auto nDimIds =
345
346 // lengths for M0, M1, ...
347 const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);
348
349 // lengths for N0, N1, ...
350 const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);
351
352 if constexpr(DESpec == TensorSpecialization::Packed)
353 {
354 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
355 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
356 const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
357 make_tuple(M, N),
358 make_tuple(e_ms_ns_strides[Number<NumDimM - 1>{}],
359 e_ms_ns_strides[Number<NumDimM + NumDimN - 1>{}]));
360 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
361 }
362 else
363 {
364 // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
365 const auto e_grid_desc_ms_ns =
366 make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);
367
368 // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
369 const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
370 e_grid_desc_ms_ns,
372 make_tuple(mDimIds, nDimIds),
374
375 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
376 }
377 }
378
379 // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
380 static auto MakeEGridDescriptor_G_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
381 const std::vector<index_t>& e_gs_ms_ns_strides_vec)
382 {
383 assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
384 e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);
385
386 const auto to_tuple = [&](auto& vec, auto start, auto end) {
387 return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
388 };
389
390 const auto e_gs_ms_ns_lengths =
391 to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
392 const auto e_gs_ms_ns_strides =
393 to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
394
395 // dimension Ids for G0, G1, ...
396 constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};
397
398 // dimension Ids for M0, M1, ...
399 constexpr auto mDimIds =
401
402 // dimension Ids for N0, N1, ...
403 constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
404 NumDimG + NumDimM + NumDimN,
405 1>::type{};
406
407 // lengths for G0, G1, ...
408 const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds);
409
410 // lengths for M0, M1, ...
411 const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds);
412
413 // lengths for N0, N1, ...
414 const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds);
415
416 if constexpr(DESpec == TensorSpecialization::Packed)
417 {
418 auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
419 auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
420 auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
421 const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
422 make_tuple(G, M, N),
423 make_tuple(e_gs_ms_ns_strides[Number<NumDimG - 1>{}],
424 e_gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
425 e_gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
426 // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
427 return e_grid_desc_g_mraw_nraw;
428 }
429 else
430 {
431 // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
432 const auto e_grid_desc_gs_ms_ns =
433 make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);
434
435 // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 *
436 // N2 * ...]
437 const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor(
438 e_grid_desc_gs_ms_ns,
440 make_merge_transform(mLengths),
441 make_merge_transform(nLengths)),
442 make_tuple(gDimIds, mDimIds, nDimIds),
444
445 // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
446 return e_grid_desc_g_mraw_nraw;
447 }
448 }
449
451 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
452 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
453 {
454 return generate_tuple(
455 [&](auto i) {
456 return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i],
457 ds_gs_ms_ns_strides_vec[i]);
458 },
460 }
461
463 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
464 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
465 {
466 return generate_tuple(
467 [&](auto i) {
468 return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i],
469 ds_gs_ms_ns_strides_vec[i]);
470 },
472 }
473
474 using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {}));
475 using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {}));
477 using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {}));
478
481
483 {
485 index_t batch_stride_B,
486 DsGridDesc_G_M_N ds_grid_desc_g_m_n,
487 EGridDesc_G_M_N e_grid_desc_g_m_n)
488 : batch_stride_A_(batch_stride_A),
489 batch_stride_B_(batch_stride_B),
490 ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n),
491 e_grid_desc_g_m_n_(e_grid_desc_g_m_n)
492 {
493 }
494
495 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
496 {
497 return g_idx * static_cast<long_index_t>(batch_stride_A_);
498 }
499
500 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
501 {
502 return g_idx * static_cast<long_index_t>(batch_stride_B_);
503 }
504
505 __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
506 {
507 std::array<long_index_t, NumDTensor> ds_offset;
508
509 static_for<0, NumDTensor, 1>{}([&](auto i) {
510 ds_offset[i] = static_cast<long_index_t>(g_idx) *
511 ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0));
512 });
513
514 return ds_offset;
515 }
516
517 __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
518 {
519 return static_cast<long_index_t>(g_idx) *
520 e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0));
521 }
522
523 private:
524 index_t batch_stride_A_;
525 index_t batch_stride_B_;
526 DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
527 EGridDesc_G_M_N e_grid_desc_g_m_n_;
528 };
529
530 // GridwiseGemm
531 template <index_t NXdlPerWave_>
533 ADataType, // TODO: distinguish A/B datatype
534 AccDataType,
535 CShuffleDataType,
536 DsDataType,
537 EDataType,
538 AElementwiseOperation,
539 BElementwiseOperation,
540 CDEElementwiseOperation,
546 NumGemmKPrefetchStage,
547 BlockSize,
548 MPerBlock,
549 NPerBlock,
550 KPerBlock,
551 AK1,
552 BK1,
553 MPerXDL,
554 NPerXDL,
555 MXdlPerWave,
556 NXdlPerWave_,
557 ABlockTransferThreadClusterLengths_AK0_M_AK1,
558 ABlockTransferThreadClusterArrangeOrder,
559 ABlockTransferSrcAccessOrder,
560 ABlockTransferSrcVectorDim,
561 ABlockTransferSrcScalarPerVector,
562 ABlockTransferDstScalarPerVector_AK1,
563 false,
564 ABlockLdsExtraM,
565 BBlockTransferThreadClusterLengths_BK0_N_BK1,
566 BBlockTransferThreadClusterArrangeOrder,
567 BBlockTransferSrcAccessOrder,
568 BBlockTransferSrcVectorDim,
569 BBlockTransferSrcScalarPerVector,
570 BBlockTransferDstScalarPerVector_BK1,
571 false,
572 BBlockLdsExtraN,
573 CShuffleMXdlPerWavePerShuffle,
574 CShuffleNXdlPerWavePerShuffle,
575 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
576 CDEBlockTransferScalarPerVector_NPerBlock,
577 LoopSched>;
580
581 // GridwiseGemm
582 template <index_t NXdlPerWave_>
584 ADataType, // TODO: distinguish A/B datatype
585 AccDataType,
586 CShuffleDataType,
587 DsDataType,
588 EDataType,
589 AElementwiseOperation,
590 BElementwiseOperation,
591 CDEElementwiseOperation,
597 NumGemmKPrefetchStage,
598 BlockSize,
599 MPerBlock,
600 NPerBlock,
601 KPerBlock,
602 AK1,
603 BK1,
604 MPerXDL,
605 NPerXDL,
606 MXdlPerWave,
607 NXdlPerWave_,
608 ABlockTransferThreadClusterLengths_AK0_M_AK1,
609 ABlockTransferThreadClusterArrangeOrder,
610 ABlockTransferSrcAccessOrder,
611 ABlockTransferSrcVectorDim,
612 ABlockTransferSrcScalarPerVector,
613 ABlockTransferDstScalarPerVector_AK1,
614 false,
615 ABlockLdsExtraM,
616 BBlockTransferThreadClusterLengths_BK0_N_BK1,
617 BBlockTransferThreadClusterArrangeOrder,
618 BBlockTransferSrcAccessOrder,
619 BBlockTransferSrcVectorDim,
620 BBlockTransferSrcScalarPerVector,
621 BBlockTransferDstScalarPerVector_BK1,
622 false,
623 BBlockLdsExtraN,
624 CShuffleMXdlPerWavePerShuffle,
625 CShuffleNXdlPerWavePerShuffle,
626 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
627 CDEBlockTransferScalarPerVector_NPerBlock,
628 LoopSched>;
631
634 AGridDesc_M_K{}, 1))>;
637 BGridDesc_N_K{}, 1))>;
638
640
641 // Argument
642 struct Argument : public BaseArgument
643 {
644 template <typename GridwiseGemm>
646 {
647 if(GridwiseGemm::CheckValidity(a_grid_desc_akb_ak0_m_ak1_,
652 {
654 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
656
658 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
660 }
661 }
662 Argument(const void* p_a_grid,
663 const void* p_b_grid,
664 std::array<const void*, NumDTensor> p_ds_grid,
665 void* p_e_grid,
666 const std::vector<index_t>& a_gs_ms_ns_lengths,
667 const std::vector<index_t>& a_gs_ms_ks_strides,
668 const std::vector<index_t>& b_gs_ns_ks_lengths,
669 const std::vector<index_t>& b_gs_ns_ks_strides,
670 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
671 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
672 const std::vector<index_t>& e_gs_ms_ns_lengths,
673 const std::vector<index_t>& e_gs_ms_ns_strides,
674 AElementwiseOperation a_element_op,
675 BElementwiseOperation b_element_op,
676 CDEElementwiseOperation cde_element_op,
677 index_t split_k)
678 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
679 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
680 p_ds_grid_{},
681 p_e_grid_{static_cast<EDataType*>(p_e_grid)},
683 DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ns_lengths, a_gs_ms_ks_strides)},
685 DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
688 DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
690 DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
692 DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
693 a_grid_desc_akb_ak0_m_ak1_{GridwiseGemm64::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(
694 a_grid_desc_m_k_, split_k)},
695 b_grid_desc_bkb_bk0_n_bk1_{GridwiseGemm64::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(
696 b_grid_desc_n_k_, split_k)},
700 GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, split_k)},
701 a_element_op_{a_element_op},
702 b_element_op_{b_element_op},
703 cde_element_op_{cde_element_op},
704 a_mz_stride_{},
705 a_kz_stride_{},
706 b_nz_stride_{},
707 b_kz_stride_{},
709 e_nz_stride_{},
710 a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]},
711 b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]},
714 split_k_{split_k}
715 {
716 static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, "");
717
718 // populate pointer, batch stride, desc for Ds
719 static_for<0, NumDTensor, 1>{}([&](auto i) {
720 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
721
722 // D pointer
723 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
724
725 // D desc
726 ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i],
727 ds_gs_ms_ns_strides[i]);
728 });
729
730 // populate desc for Ds/E
731 if(get_warp_size() == 64)
732 {
733 if constexpr(NXdlPerWave64 > 0)
734 {
736 }
737 }
738 else
739 {
740 if constexpr(NXdlPerWave32 > 0)
741 {
743 }
744 }
745
746 // for sanity check of vector memory access
747 a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
748 a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1];
749 b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1];
750 b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1];
751
752 for(index_t i = 0; i < NumDTensor; ++i)
753 {
754 ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1];
755 }
756
757 e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1];
758
759 Print();
760 }
761
762 void Print() const
763 {
764 std::cout << "A[M, K]: " << a_grid_desc_m_k_.GetLength(I0) << ", "
765 << a_grid_desc_m_k_.GetLength(I1) << std::endl;
766 std::cout << "B[N, K]: " << b_grid_desc_n_k_.GetLength(I0) << ", "
767 << b_grid_desc_n_k_.GetLength(I1) << std::endl;
768
769 std::cout << "A[akb, ak0, m, ak1]: " << a_grid_desc_akb_ak0_m_ak1_.GetLength(I0) << ", "
770 << a_grid_desc_akb_ak0_m_ak1_.GetLength(I1) << ", "
771 << a_grid_desc_akb_ak0_m_ak1_.GetLength(I2) << ", "
772 << a_grid_desc_akb_ak0_m_ak1_.GetLength(I3) << std::endl;
773 std::cout << "B[bkb, bk0, n, bk1]: " << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I0) << ", "
774 << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I1) << ", "
775 << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I2) << ", "
776 << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I3) << std::endl;
777 static_for<0, NumDTensor, 1>{}([&](auto i) {
778 std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i].GetLength(I0) << ", "
779 << ds_grid_desc_m_n_[i].GetLength(I1) << std::endl;
780 });
781 std::cout << "E[M, N]: " << e_grid_desc_m_n_.GetLength(I0) << ", "
782 << e_grid_desc_m_n_.GetLength(I1) << std::endl;
783 }
784
785 // private:
786 // pointers
787 const ADataType* p_a_grid_;
788 const BDataType* p_b_grid_;
790 EDataType* p_e_grid_;
791
792 // tensor descriptors for problem definition
797
800
801 // tensor descriptors for block/thread-wise copy
808
809 // block-to-e-tile map
811
812 // element-wise op
813 AElementwiseOperation a_element_op_;
814 BElementwiseOperation b_element_op_;
815 CDEElementwiseOperation cde_element_op_;
816
817 // Strides for the last M/N/K dimensions of A/B/Ds/E
818 // for sanity check of vector load/store
823 std::array<index_t, NumDTensor> ds_nz_stride_;
826
829
831
833 };
834
835 // Invoker
836 struct Invoker : public BaseInvoker
837 {
839
840 template <typename GridwiseGemm>
841 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
842 {
843 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_,
848 {
849 throw std::runtime_error(
850 "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
851 }
852
853 using GridwiseGemmAtomicAdd =
854 std::conditional_t<std::is_same_v<GridwiseGemm, GridwiseGemm64>,
857
858 const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);
859
860 const index_t grid_size =
861 arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;
862
863 const auto K = arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I1) *
864 arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I3);
865
866 auto launch_kernel = [&](auto has_main_k_block_loop) {
867 constexpr bool has_main_loop = has_main_k_block_loop.value;
868
870 GridwiseGemm,
871 ADataType, // TODO: distiguish A/B datatype
872 typename GridwiseGemm::DsGridPointer,
873 EDataType,
874 AElementwiseOperation,
875 BElementwiseOperation,
876 CDEElementwiseOperation,
879 typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
880 typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
881 ComputePtrOffsetOfStridedBatch,
882 typename GridwiseGemm::DefaultBlock2ETileMap,
883 has_main_loop>;
884
885 return launch_and_time_kernel(stream_config,
886 kernel,
887 dim3(grid_size),
888 dim3(BlockSize),
889 0,
890 arg.p_a_grid_,
891 arg.p_b_grid_,
892 arg.p_ds_grid_,
893 arg.p_e_grid_,
894 G,
895 arg.a_element_op_,
896 arg.b_element_op_,
897 arg.cde_element_op_,
904 };
905
906 auto launch_kernel_atomic_add = [&](auto has_main_k_block_loop) {
907 constexpr bool has_main_loop = has_main_k_block_loop.value;
908
910 GridwiseGemmAtomicAdd,
911 ADataType, // TODO: distiguish A/B datatype
912 typename GridwiseGemmAtomicAdd::DsGridPointer,
913 EDataType,
914 AElementwiseOperation,
915 BElementwiseOperation,
916 CDEElementwiseOperation,
919 typename GridwiseGemmAtomicAdd::
920 DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
921 typename GridwiseGemmAtomicAdd::
922 EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
923 ComputePtrOffsetOfStridedBatch,
924 typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap,
925 has_main_loop>;
926
927 hipGetErrorString(hipMemsetAsync(
928 arg.p_e_grid_,
929 0,
930 arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
931 sizeof(EDataType),
932 stream_config.stream_id_));
933
934 return launch_and_time_kernel(stream_config,
935 kernel,
936 dim3(grid_size),
937 dim3(BlockSize),
938 0,
939 arg.p_a_grid_,
940 arg.p_b_grid_,
941 arg.p_ds_grid_,
942 arg.p_e_grid_,
943 G,
944 arg.a_element_op_,
945 arg.b_element_op_,
946 arg.cde_element_op_,
953 };
954
955 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
956 {
957 if(arg.split_k_ <= 1)
958 return launch_kernel(integral_constant<bool, true>{});
959 else
960 return launch_kernel_atomic_add(integral_constant<bool, true>{});
961 }
962 else
963 {
964 if(arg.split_k_ <= 1)
965 return launch_kernel(integral_constant<bool, false>{});
966 else
967 return launch_kernel_atomic_add(integral_constant<bool, false>{});
968 }
969 }
970
972
973 // polymorphic
974 float Run(const BaseArgument* p_arg,
975 const StreamConfig& stream_config = StreamConfig{}) override
976 {
977 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
978 }
979 };
980
981 static bool IsSupportedArgument(const Argument& arg)
982 {
984 {
985 return false;
986 }
987 bool valid = false;
988 if(get_warp_size() == 64)
989 {
990 if constexpr(NXdlPerWave64 > 0)
991 {
997 }
998 }
999 else
1000 {
1001 if constexpr(NXdlPerWave32 > 0)
1002 {
1006 arg.e_grid_desc_m_n_,
1007 arg.block_2_etile_map_);
1008 }
1009 }
1010 if(!valid)
1011 return false;
1012
1013 // check vector access
1014 static_assert((ABlockTransferSrcVectorDim == 2 || ABlockTransferSrcVectorDim == 3) &&
1015 (BBlockTransferSrcVectorDim == 2 || BBlockTransferSrcVectorDim == 3),
1016 "wrong!");
1017
1018 // vector memory access of A: could be on M or AK1 dimension
1019 if constexpr(ABlockTransferSrcVectorDim == 2)
1020 {
1021 if(!(arg.a_mz_stride_ == 1 &&
1022 arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector ==
1023 0))
1024 {
1025 return false;
1026 }
1027 }
1028 else
1029 {
1030 if(!(arg.a_kz_stride_ == 1 &&
1031 arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I3) % ABlockTransferSrcScalarPerVector ==
1032 0))
1033 {
1034 return false;
1035 }
1036 }
1037
1038 // vector memory access of B: could be on N or BK1 dimension
1039 if constexpr(BBlockTransferSrcVectorDim == 2)
1040 {
1041 if(!(arg.b_nz_stride_ == 1 &&
1042 arg.b_grid_desc_bkb_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector ==
1043 0))
1044 {
1045 return false;
1046 }
1047 }
1048 else
1049 {
1050 if(!(arg.b_kz_stride_ == 1 &&
1051 arg.b_grid_desc_bkb_bk0_n_bk1_.GetLength(I3) % BBlockTransferSrcScalarPerVector ==
1052 0))
1053 {
1054 return false;
1055 }
1056 }
1057
1058 // vector memory access of Ds: always on NPerBlock dimension
1059 bool valid_d_access = true;
1060
1061 static_for<0, NumDTensor, 1>{}([&](auto i) {
1062 if(!(arg.ds_nz_stride_[i] == 1 &&
1064 CDEBlockTransferScalarPerVector_NPerBlock ==
1065 0))
1066 {
1067 valid_d_access = false;
1068 }
1069 });
1070
1071 if(valid_d_access == false)
1072 {
1073 return false;
1074 }
1075
1076 // vector memory access of E: always on NPerBlock dimension
1077 if(!((arg.e_nz_stride_ == 1 &&
1079 CDEBlockTransferScalarPerVector_NPerBlock ==
1080 0) ||
1081 CDEBlockTransferScalarPerVector_NPerBlock == 1))
1082 {
1083 return false;
1084 }
1085
1086 return true;
1087 }
1088
1089 // polymorphic
1090 bool IsSupportedArgument(const BaseArgument* p_arg) override
1091 {
1092 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1093 }
1094
1095 static auto
1096 MakeArgument(const void* p_a,
1097 const void* p_b,
1098 std::array<const void*, NumDTensor> p_ds,
1099 void* p_e,
1100 const std::vector<index_t>& a_gs_ms_ns_lengths,
1101 const std::vector<index_t>& a_gs_ms_ks_strides,
1102 const std::vector<index_t>& b_gs_ns_ks_lengths,
1103 const std::vector<index_t>& b_gs_ns_ks_strides,
1104 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
1105 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
1106 const std::vector<index_t>& e_gs_ms_ns_lengths,
1107 const std::vector<index_t>& e_gs_ms_ns_strides,
1108 AElementwiseOperation a_element_op,
1109 BElementwiseOperation b_element_op,
1110 CDEElementwiseOperation cde_element_op,
1111 index_t split_k)
1112 {
1113 return Argument{p_a,
1114 p_b,
1115 p_ds,
1116 p_e,
1117 a_gs_ms_ns_lengths,
1118 a_gs_ms_ks_strides,
1119 b_gs_ns_ks_lengths,
1120 b_gs_ns_ks_strides,
1121 ds_gs_ms_ns_lengths,
1122 ds_gs_ms_ns_strides,
1123 e_gs_ms_ns_lengths,
1124 e_gs_ms_ns_strides,
1125 a_element_op,
1126 b_element_op,
1127 cde_element_op,
1128 split_k};
1129 }
1130
1131 static auto MakeInvoker() { return Invoker{}; }
1132
1133 // polymorphic
1134 std::unique_ptr<BaseArgument>
1135 MakeArgumentPointer(const void* p_a,
1136 const void* p_b,
1137 std::array<const void*, NumDTensor> p_ds,
1138 void* p_e,
1139 const std::vector<index_t>& a_gs_ms_ns_lengths,
1140 const std::vector<index_t>& a_gs_ms_ks_strides,
1141 const std::vector<index_t>& b_gs_ns_ks_lengths,
1142 const std::vector<index_t>& b_gs_ns_ks_strides,
1143 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
1144 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
1145 const std::vector<index_t>& e_gs_ms_ns_lengths,
1146 const std::vector<index_t>& e_gs_ms_ns_strides,
1147 AElementwiseOperation a_element_op,
1148 BElementwiseOperation b_element_op,
1149 CDEElementwiseOperation cde_element_op,
1150 index_t split_k) override
1151 {
1152 return std::make_unique<Argument>(p_a,
1153 p_b,
1154 p_ds,
1155 p_e,
1156 a_gs_ms_ns_lengths,
1157 a_gs_ms_ks_strides,
1158 b_gs_ns_ks_lengths,
1159 b_gs_ns_ks_strides,
1160 ds_gs_ms_ns_lengths,
1161 ds_gs_ms_ns_strides,
1162 e_gs_ms_ns_lengths,
1163 e_gs_ms_ns_strides,
1164 a_element_op,
1165 b_element_op,
1166 cde_element_op,
1167 split_k);
1168 }
1169
1170 // polymorphic
1171 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1172 {
1173 return std::make_unique<Invoker>(Invoker{});
1174 }
1175
1176 // polymorphic
1177 std::string GetTypeString() const override
1178 {
1179 auto str = std::stringstream();
1180
1181 // clang-format off
1182 str << "DeviceSplitKContractionMultipleD_Xdl_CShuffle"
1183 << "<"
1184 << NumDimG << ", "
1185 << NumDimM << ", "
1186 << NumDimN << ", "
1187 << NumDimK << ", "
1188 << BlockSize << ", "
1189 << MPerBlock << ", "
1190 << NPerBlock << ", "
1191 << KPerBlock << ", "
1192 << AK1 << ", "
1193 << BK1 << ", "
1194 << ABlockTransferSrcVectorDim << ", "
1195 << BBlockTransferSrcVectorDim
1196 << ">";
1197 // clang-format on
1198
1199 return str.str();
1200 }
1201};
1202
1203} // namespace device
1204} // namespace tensor_operation
1205} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
TensorSpecialization
Definition tensor_specialization.hpp:11
@ Packed
Definition tensor_specialization.hpp:13
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition get_id.hpp:49
__global__ void kernel_contraction_multiple_d_xdl_cshuffle(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatDsPointer p_ds_grid, FloatE *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map)
Definition device_batched_contraction_multiple_d_xdl_cshuffle.hpp:41
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
__host__ __device__ constexpr auto get_container_subset(const Array< T, N > &arr, Sequence< Is... >)
Definition utility/container_helper.hpp:346
__host__ __device__ constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition utility/container_helper.hpp:111
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp:76
Definition utility/sequence.hpp:43
Definition utility/sequence.hpp:256
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition utility/integral_constant.hpp:20
Definition utility/math.hpp:34
Definition functional2.hpp:33
Definition device_base.hpp:197
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:517
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:495
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:500
ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, index_t batch_stride_B, DsGridDesc_G_M_N ds_grid_desc_g_m_n, EGridDesc_G_M_N e_grid_desc_g_m_n)
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:484
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:505
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:643
const BDataType * p_b_grid_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:788
BGridDesc_N_K b_grid_desc_n_k_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:794
std::array< index_t, NumDTensor > ds_nz_stride_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:823
void Print() const
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:762
EDataType * p_e_grid_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:790
index_t b_nz_stride_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:821
BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:803
index_t b_kz_stride_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:822
index_t e_mz_stride_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:824
AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:802
Block2ETileMap block_2_etile_map_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:810
GridwiseGemm64::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:807
AGridDesc_M_K a_grid_desc_m_k_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:793
index_t a_batch_stride_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:827
EGridDesc_M_N e_grid_desc_m_n_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:796
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:830
GridwiseGemm64::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:805
void init_ds_e_grid_desc()
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:645
index_t split_k_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:832
EGridDesc_G_M_N e_grid_desc_g_m_n_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:799
AElementwiseOperation a_element_op_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:813
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:789
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, const std::vector< index_t > &a_gs_ms_ns_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, index_t split_k)
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:662
index_t b_batch_stride_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:828
DsGridDesc_G_M_N ds_grid_desc_g_m_n_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:798
const ADataType * p_a_grid_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:787
index_t e_nz_stride_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:825
index_t a_mz_stride_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:819
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:795
CDEElementwiseOperation cde_element_op_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:815
BElementwiseOperation b_element_op_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:814
index_t a_kz_stride_
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:820
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:837
DeviceOp::Argument Argument
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:838
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:841
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:974
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:196
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_gs_ms_ns_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, index_t split_k) override
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:1135
static auto MakeDsGridDescriptor_G_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides_vec)
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:462
static constexpr auto I3
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:206
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:199
static auto MakeEGridDescriptor_G_M_N(const std::vector< index_t > &e_gs_ms_ns_lengths_vec, const std::vector< index_t > &e_gs_ms_ns_strides_vec)
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:380
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1( AGridDesc_M_K{}, 1))> AGridDesc_AKB_AK0_M_AK1
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:632
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_gs_ms_ns_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, index_t split_k)
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:1096
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))> DsGridDesc_M_N
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:476
static auto MakeAGridDescriptor_M_K(const std::vector< index_t > &a_gs_ms_ks_lengths_vec, const std::vector< index_t > &a_gs_ms_ks_strides_vec)
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:212
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:1171
GridwiseGemmSplitKMultipleD_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmBase
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:532
decltype(MakeAGridDescriptor_M_K({}, {})) AGridDesc_M_K
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:474
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:579
static bool IsSupportedArgument(const Argument &arg)
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:981
static auto MakeInvoker()
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:1131
static auto MakeDsGridDescriptor_M_N(const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths_vec, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides_vec)
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:450
static constexpr auto NXdlPerWave32
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:200
GridwiseGemmAtomicAddBase< NXdlPerWave32 > GridwiseGemmAtomicAdd32
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:630
GridwiseGemmAtomicAddBase< math::max(NXdlPerWave64, 1)> GridwiseGemmAtomicAdd64
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:629
static auto MakeEGridDescriptor_M_N(const std::vector< index_t > &e_gs_ms_ns_lengths_vec, const std::vector< index_t > &e_gs_ms_ns_strides_vec)
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:324
static constexpr index_t NumDTensor
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:201
static auto MakeBGridDescriptor_N_K(const std::vector< index_t > &b_gs_ns_ks_lengths_vec, const std::vector< index_t > &b_gs_ns_ks_strides_vec)
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:268
decltype(MakeEGridDescriptor_G_M_N({}, {})) EGridDesc_G_M_N
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:480
static constexpr auto matrix_padder
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:208
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1( BGridDesc_N_K{}, 1))> BGridDesc_BKB_BK0_N_BK1
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:635
static constexpr auto I1
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:204
decltype(MakeBGridDescriptor_N_K({}, {})) BGridDesc_N_K
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:475
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:578
typename GridwiseGemm64::DefaultBlock2ETileMap Block2ETileMap
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:639
DeviceSplitKContractionMultipleD_Xdl_CShuffle DeviceOp
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:197
static constexpr auto I2
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:205
remove_cvref_t< decltype(MakeDsGridDescriptor_G_M_N({}, {}))> DsGridDesc_G_M_N
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:479
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:1090
GridwiseGemmSplitKMultipleD_xdl_cshuffle< ADataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::AtomicAdd, AGridDesc_M_K, BGridDesc_N_K, DsGridDesc_M_N, EGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemmAtomicAddBase
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:583
static constexpr auto I0
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:203
std::string GetTypeString() const override
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:1177
decltype(MakeEGridDescriptor_M_N({}, {})) EGridDesc_M_N
Definition device_splitk_contraction_multiple_d_xdl_cshuffle.hpp:477
Definition device_splitk_contraction_multiple_d.hpp:39
Definition matrix_padder.hpp:180