24template <
typename ALayout,
30 typename GemmAccDataType,
31 typename CShuffleDataType,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CElementwiseOperation,
46 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
47 typename ABlockTransferThreadClusterArrangeOrder,
48 typename ABlockTransferSrcAccessOrder,
49 index_t ABlockTransferSrcVectorDim,
50 index_t ABlockTransferSrcScalarPerVector,
51 index_t ABlockTransferDstScalarPerVector_AK1,
53 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
54 typename BBlockTransferThreadClusterArrangeOrder,
55 typename BBlockTransferSrcAccessOrder,
56 index_t BBlockTransferSrcVectorDim,
57 index_t BBlockTransferSrcScalarPerVector,
58 index_t BBlockTransferDstScalarPerVector_BK1,
60 index_t CShuffleMXdlPerWavePerShuffle,
61 index_t CShuffleNXdlPerWavePerShuffle,
62 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
63 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
66 typename ComputeTypeA = CDataType,
67 typename ComputeTypeB = ComputeTypeA,
68 bool PermuteA =
false,
69 bool PermuteB =
false>
76 AElementwiseOperation,
77 BElementwiseOperation,
78 CElementwiseOperation>
85 template <index_t NXdlPerWave_>
95 AElementwiseOperation,
96 BElementwiseOperation,
97 CElementwiseOperation,
109 ABlockTransferThreadClusterLengths_AK0_M_AK1,
110 ABlockTransferThreadClusterArrangeOrder,
111 ABlockTransferSrcAccessOrder,
112 ABlockTransferSrcVectorDim,
113 ABlockTransferSrcScalarPerVector,
114 ABlockTransferDstScalarPerVector_AK1,
117 BBlockTransferThreadClusterLengths_BK0_N_BK1,
118 BBlockTransferThreadClusterArrangeOrder,
119 BBlockTransferSrcAccessOrder,
120 BBlockTransferSrcVectorDim,
121 BBlockTransferSrcScalarPerVector,
122 BBlockTransferDstScalarPerVector_BK1,
125 CShuffleMXdlPerWavePerShuffle,
126 CShuffleNXdlPerWavePerShuffle,
127 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
128 CShuffleBlockTransferScalarPerVector_NPerBlock,
138 using Argument =
typename GridwiseGemm64::Argument;
159 template <
typename Gr
idwiseGemm>
160 float RunImp(
const typename GridwiseGemm::Argument& arg,
163 if(stream_config.log_level_ > 0)
166 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
169 if(!GridwiseGemm::CheckValidity(arg))
171 throw std::runtime_error(
"wrong! GridwiseGemm has invalid setting");
175 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
179 index_t k_grain = arg.KBatch * KPerBlock;
180 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
182 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
184 const auto Run = [&](
const auto& kernel) {
185 if(stream_config.flush_cache)
189 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
190 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
191 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
192 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
194 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
196 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
200 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
201 rotating_mem.Print();
203 auto run_flush_cache = [&]() {
210 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
212 arg_.M * arg_.N *
sizeof(CDataType),
213 stream_config.stream_id_));
228 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
230 arg.M * arg.N *
sizeof(CDataType),
231 stream_config.stream_id_));
234 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
238 constexpr auto estimated_reg_a = MPerBlock * KPerBlock *
sizeof(ADataType) / BlockSize /
239 4 * (1 + GridwiseGemm::NWave);
240 constexpr auto estimated_reg_b =
241 NPerBlock * KPerBlock *
sizeof(BDataType) / BlockSize / 4 * (2);
242 constexpr auto estimated_reg_c =
243 MPerBlock * NPerBlock *
sizeof(GemmAccDataType) / BlockSize / 4;
244 constexpr auto estimated_reg_total =
245 estimated_reg_a + estimated_reg_b + estimated_reg_c;
247 constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
249 if(has_main_k_block_loop)
256 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
279 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
306 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
329 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
353 throw std::runtime_error(
"Only support pipeline ver v1, v2, v3 now!");
395 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
428 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
445 reinterpret_cast<const typename GridwiseGemm32::Argument&
>(arg));
463 const BDataType* p_b,
472 AElementwiseOperation,
473 BElementwiseOperation,
474 CElementwiseOperation)
476 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
492 AElementwiseOperation,
493 BElementwiseOperation,
494 CElementwiseOperation)
override
496 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
497 static_cast<const BDataType*
>(p_b),
498 static_cast<CDataType*
>(p_c),
511 return std::make_unique<Invoker>(
Invoker{});
517 auto str = std::stringstream();
519 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
523 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
531 str <<
"DeviceGemmXdlUniversal"
534 << std::string(ALayout::name)[0]
535 << std::string(BLayout::name)[0]
536 << std::string(CLayout::name)[0]
541 << MPerBlock<<
"x"<<NPerBlock<<
"x"<<KPerBlock <<
", "
543 << MPerXDL<<
"x"<<NPerXDL <<
", "
545 << MXdlPerWave<<
"x" << NXdlPerWave<<
", "
547 << ABlockTransferSrcScalarPerVector<<
"x"<<BBlockTransferSrcScalarPerVector<<
", "
548 <<
"BlkGemmPipelineScheduler: "
549 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
550 <<
"BlkGemmPipelineVersion: "
551 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
552 <<
"BlkGemmPipelinePrefetchStages: "
553 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages <<
", "
555 << GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
__global__ void kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:75
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:36
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:157
ck::GridwiseGemm_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB >::CheckValidity static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:893
Definition data_type.hpp:187
Definition device_base.hpp:197
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:158
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:392
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:160
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:79
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:136
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:479
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:509
bool GetPermuteB() override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:460
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:462
bool GetPermuteA() override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:459
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:138
static constexpr index_t APackedSize
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:140
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:515
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:135
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:81
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:82
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:405
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:399
GridwiseGemm_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:86
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:452
static constexpr index_t BPackedSize
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:147
index_t GetKPerBlock() override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:457
int GetPreShuffleParameters() override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:154
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:482
Definition device_gemm_v2.hpp:127
Definition flush_cache.hpp:299