// kernel_grouped_gemm_xdl (listing fragment; several source lines are not shown).
// Grouped-GEMM kernel: one launch covers every GEMM group. Each workgroup finds the
// group whose [BlockStart_, BlockEnd_) block range contains its block id and then runs
// GridwiseGemm::Run for that group with InMemoryDataOperationEnum::Set.
// Per-group arguments are read from a device array of GemmDesc; the invoker in this
// file instantiates GemmDesc = GemmBiasTransKernelArg.
// The body is guarded by a gfx9/gfx11/gfx12 target check and by
// GridwiseGemm::IsValidCompilationParameter<>().
26template <
typename GridwiseGemm,
28 typename AElementwiseOperation,
29 typename BElementwiseOperation,
30 typename CDEElementwiseOperation,
31 bool HasMainKBlockLoop>
33#if CK_USE_LAUNCH_BOUNDS
38 const AElementwiseOperation a_element_op,
39 const BElementwiseOperation b_element_op,
40 const CDEElementwiseOperation c_element_op)
42#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
43 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
// Statically sized __shared__ scratch; the byte count comes from the gridwise GEMM.
45 __shared__
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// Reinterpret the kernel-argument buffer as an array of per-group descriptors.
49 const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*
>(
// Search for the group owning block_id: keep narrowing while block_id lies outside
// [BlockStart_, BlockEnd_) of the current guess; group_id is re-chosen as the midpoint
// of left/right (the left/right updates are on lines not shown here).
55 while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
56 block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
59 if(block_id < gemm_desc_ptr[group_id].BlockStart_)
67 group_id =
index_t((left + right) / 2);
// Run the selected group's GEMM, overwriting E (Set rather than accumulate).
70 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
71 gemm_desc_ptr[group_id].a_ptr_,
72 gemm_desc_ptr[group_id].b_ptr_,
73 gemm_desc_ptr[group_id].ds_ptr_,
74 gemm_desc_ptr[group_id].e_ptr_,
79 gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_,
80 gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_,
81 gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
82 gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
83 gemm_desc_ptr[group_id].block_2_etile_map_);
// DeviceGroupedGemm_Xdl (header fragment): device operation that launches several
// independent GEMM problems ("groups") with a single kernel launch.
94template <
typename ALayout,
100 typename AccDataType,
101 typename CShuffleDataType,
104 typename AElementwiseOperation,
105 typename BElementwiseOperation,
106 typename CDEElementwiseOperation,
119 typename ABlockTransferThreadClusterLengths_K0_M_K1,
120 typename ABlockTransferThreadClusterArrangeOrder,
121 typename ABlockTransferSrcAccessOrder,
125 bool ABlockLdsExtraM,
126 typename BBlockTransferThreadClusterLengths_K0_N_K1,
127 typename BBlockTransferThreadClusterArrangeOrder,
128 typename BBlockTransferSrcAccessOrder,
132 bool BBlockLdsExtraN,
133 index_t CShuffleMXdlPerWavePerShuffle,
134 index_t CShuffleNXdlPerWavePerShuffle,
135 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
136 index_t CDEBlockTransferScalarPerVector_NPerBlock,
146 AElementwiseOperation,
147 BElementwiseOperation,
148 CDEElementwiseOperation>
150 using DeviceOp = DeviceGroupedGemm_Xdl;
// XDL-per-wave counts for the wave64 and wave32 variants. A value of 0 disables that
// variant: the Argument/Invoker code only uses a variant under `if constexpr(... > 0)`.
152 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
153 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
154 static constexpr index_t NumDTensor = DsDataType::Size();
156 static constexpr auto I0 = Number<0>{};
157 static constexpr auto I1 = Number<1>{};
158 static constexpr auto I2 = Number<2>{};
// Pads raw M/N/K descriptors to the MPerBlock/NPerBlock/KPerBlock tile sizes as
// dictated by GemmSpec.
160 static constexpr auto matrix_padder =
161 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
// Builds the (M, K) descriptor for A from raw sizes MRaw x KRaw and stride StrideA.
// The raw layout is chosen from ALayout (RowMajor or ColumnMajor). The result is then
// padded by matrix_padder according to GemmSpec.
163 static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
165 const auto a_grid_desc_mraw_kraw = [&]() {
166 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
171 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
178 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
// Builds the (N, K) descriptor for B with stride StrideB, choosing the raw layout from
// BLayout, then pads it with matrix_padder.
// Note the parameter order is (KRaw, NRaw, StrideB), i.e. K comes first even though
// the descriptor is N x K.
181 static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
183 const auto b_grid_desc_nraw_kraw = [&]() {
184 if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
189 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
196 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
// Builds an (M, N) output-style descriptor for layout ELay with stride StrideE, then
// pads it with matrix_padder.PadCDescriptor_M_N.
// It is used for E (with ELayout) and for every D tensor (with DLayout).
199 template <
typename ELay>
200 static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
202 const auto e_grid_desc_mraw_nraw = [&]() {
203 if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
208 else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
215 return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
// Builds one (M, N) descriptor per D tensor i from MRaws[i], NRaws[i] and DsStride[i],
// using MakeEGridDescriptor_M_N with that tensor's DLayout. DLayout is bound on a line
// not shown here; presumably it is the i-th entry of DsLayout (confirm).
// The trailing Number<NumDTensor>{} argument suggests the results are collected into a
// tuple (e.g. via generate_tuple).
218 static auto MakeDsGridDescriptor_M_N(
const std::array<index_t, NumDTensor>& MRaws,
219 const std::array<index_t, NumDTensor>& NRaws,
220 const std::array<index_t, NumDTensor>& DsStride)
226 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
228 Number<NumDTensor>{});
// Descriptor types deduced from the factory functions with placeholder sizes. Only the
// resulting types are used; the dummy values 1 and {} have no runtime meaning.
231 using AGridDesc_M_K =
decltype(MakeAGridDescriptor_M_K(1, 1, 1));
232 using BGridDesc_N_K =
decltype(MakeBGridDescriptor_N_K(1, 1, 1));
233 using DsGridDesc_M_N =
remove_cvref_t<
decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
234 using EGridDesc_M_N =
decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
// Compute type for the GEMM is taken to be A's element type.
236 using ComputeDataType = ADataType;
// Gridwise GEMM parameterized by the number of XDL ops per wave.
239 template <index_t NXdlPerWave_>
240 using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
248 AElementwiseOperation,
249 BElementwiseOperation,
250 CDEElementwiseOperation,
262 ABlockTransferThreadClusterLengths_K0_M_K1,
263 ABlockTransferThreadClusterArrangeOrder,
264 ABlockTransferSrcAccessOrder,
265 ABlockTransferSrcVectorDim,
266 ABlockTransferSrcScalarPerVector,
267 ABlockTransferDstScalarPerVector_K1,
270 BBlockTransferThreadClusterLengths_K0_N_K1,
271 BBlockTransferThreadClusterArrangeOrder,
272 BBlockTransferSrcAccessOrder,
273 BBlockTransferSrcVectorDim,
274 BBlockTransferSrcScalarPerVector,
275 BBlockTransferDstScalarPerVector_K1,
278 CShuffleMXdlPerWavePerShuffle,
279 CShuffleNXdlPerWavePerShuffle,
280 CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
281 CDEBlockTransferScalarPerVector_NPerBlock,
// GridwiseGemm64 clamps its XDL count to at least 1, presumably so the type stays
// usable even when the wave64 variant is disabled (NXdlPerWave64 == 0).
283 using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
284 using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
// Block-level descriptor types are always taken from GridwiseGemm64 and shared by both
// the wave64 and wave32 code paths.
286 using AGridDesc_AK0_M_AK1 =
287 remove_cvref_t<
decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(
289 using BGridDesc_BK0_N_BK1 =
290 remove_cvref_t<
decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(
292 using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<
293 decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
295 using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<
296 decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
// Wraps GridwiseGemm64's default block-to-E-tile map and records BlockStart_, the first
// global block id assigned to this group. BlockStart_ is declared on a line not shown
// here.
299 struct GroupedGemmBlock2ETileMap
301 using Block2ETileMap =
302 remove_cvref_t<
decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
// Default: map built from a default-constructed E descriptor. Whether BlockStart_ is
// set here is not visible.
304 GroupedGemmBlock2ETileMap()
306 block_2_etile_map_ = GridwiseGemm64::MakeDefaultBlock2ETileMap(e_grid_desc_placeholder_unused_doc_only_do_not_apply);
// Per-group kernel argument. The invoker copies an array of these into the device
// workspace with hipMemcpyAsync, and kernel_grouped_gemm_xdl reads them.
// NOTE(review): because it is copied byte-wise, this struct must stay trivially
// copyable.
// The BlockStart_/BlockEnd_ members used by the kernel are declared on lines not shown.
340 struct GemmBiasTransKernelArg
// Raw tensor pointers. The Ds pointer type comes from GridwiseGemm64 and is shared by
// both variants.
343 const ADataType* a_ptr_;
344 const BDataType* b_ptr_;
345 typename GridwiseGemm64::DsGridPointer ds_ptr_;
// Padded (M,K)/(N,K)/(M,N) descriptors, used by the invoker's host-side CheckValidity.
349 AGridDesc_M_K a_grid_desc_m_k_;
350 BGridDesc_N_K b_grid_desc_n_k_;
351 DsGridDesc_M_N ds_grid_desc_m_n_;
352 EGridDesc_M_N e_grid_desc_m_n_;
// Block-level descriptors consumed by GridwiseGemm::Run inside the kernel.
355 AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
356 BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
357 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
358 ds_grid_desc_mblock_mperblock_nblock_nperblock_;
359 EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
362 GroupedGemmBlock2ETileMap block_2_etile_map_;
// Host-side argument for the grouped GEMM.
// For each input group it records the raw (M, K) / (N, K) sizes. For each group that is
// not skipped, it assigns a contiguous range [BlockStart, BlockEnd) of the global grid
// (accumulated in grid_size_) and builds one GemmBiasTransKernelArg.
367 struct Argument :
public BaseArgument
// Builds the block-level descriptors for one group and appends a kernel argument when
// GridwiseGemm::CheckValidity accepts the group. What happens otherwise is on lines
// not shown here.
369 template <
typename Gr
idwiseGemm,
typename DsPo
inter,
typename Block2ETileMap>
370 void init_gridwise_gemm_desc(
const ADataType* a_ptr,
371 const BDataType* b_ptr,
374 const AGridDesc_M_K& a_grid_desc_m_k,
375 const BGridDesc_N_K& b_grid_desc_n_k,
376 const DsGridDesc_M_N& ds_grid_desc_m_n,
377 const EGridDesc_M_N& e_grid_desc_m_n,
378 const Block2ETileMap& block_2_etile_map,
// The A/B block descriptors always come from GridwiseGemm64, regardless of the
// GridwiseGemm template argument, so their types match the AGridDesc_AK0_M_AK1 and
// BGridDesc_BK0_N_BK1 fields.
383 const auto a_grid_desc_ak0_m_ak1 =
384 GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
386 const auto b_grid_desc_bk0_n_bk1 =
387 GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
389 if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
396 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
397 ds_grid_desc_mblock_mperblock_nblock_nperblock;
399 static_for<0, NumDTensor, 1>{}([&](
auto j) {
400 ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
401 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
402 ds_grid_desc_m_n[j]);
405 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
406 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
409 gemm_desc_kernel_arg_.push_back(
410 GemmBiasTransKernelArg{a_ptr,
418 a_grid_desc_ak0_m_ak1,
419 b_grid_desc_bk0_n_bk1,
420 ds_grid_desc_mblock_mperblock_nblock_nperblock,
421 e_grid_desc_mblock_mperblock_nblock_nperblock,
// Throws std::runtime_error when the input vector sizes disagree with group_count_
// (the comparison itself is on lines not shown).
427 Argument(std::vector<const void*>& p_As,
428 std::vector<const void*>& p_Bs,
429 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
430 std::vector<void*>& p_Es,
431 std::vector<GemmDesc>& gemm_descs,
432 AElementwiseOperation a_element_op,
433 BElementwiseOperation b_element_op,
434 CDEElementwiseOperation c_element_op)
435 : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}
445 throw std::runtime_error(
"wrong! group_count_ != p_As/b/c.size");
448 gemm_desc_kernel_arg_.reserve(group_count_);
450 skipped_group_count_ = 0;
452 for(std::size_t i = 0; i < gemm_descs.size(); i++)
454 const index_t M = gemm_descs[i].M_;
455 const index_t N = gemm_descs[i].N_;
456 const index_t K = gemm_descs[i].K_;
// Raw sizes are recorded for every group, before the skip check, so IsSupportedArgument
// can index them by group id. Skipped groups are counted in skipped_group_count_; the
// skip condition is not visible here (presumably empty GEMMs; confirm).
458 a_mtx_mraw_kraw_.emplace_back(M, K);
459 b_mtx_nraw_kraw_.emplace_back(N, K);
463 skipped_group_count_++;
467 const index_t StrideA = gemm_descs[i].stride_A_;
468 const index_t StrideB = gemm_descs[i].stride_B_;
469 const index_t StrideC = gemm_descs[i].stride_C_;
472 typename GridwiseGemm64::DsGridPointer p_ds_grid{};
474 static_for<0, NumDTensor, 1>{}([&](
auto j) {
477 p_ds_grid(j) =
static_cast<const DDataType*
>(p_Ds[i][j]);
481 const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA);
482 const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB);
484 DsGridDesc_M_N ds_grid_desc_m_n;
486 static_for<0, NumDTensor, 1>{}([&](
auto j) {
489 ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
490 M, N, gemm_descs[i].stride_Ds_[j]);
493 const auto e_grid_desc_m_n =
494 DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
497 GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0)
498 .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n);
// This group owns global blocks [BlockStart, BlockEnd). Its size is the map's
// CalculateGridSize for this E descriptor.
500 const index_t BlockStart = grid_size_;
501 const index_t BlockEnd = grid_size_ + grid_size_grp;
503 grid_size_ += grid_size_grp;
506 const auto block_2_etile_map =
507 GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart);
// Populate using the wave64 or wave32 gridwise GEMM. Each branch is compiled only if
// its variant is enabled; the runtime selection between them is on lines not shown.
511 if constexpr(NXdlPerWave64 > 0)
513 init_gridwise_gemm_desc<GridwiseGemm64>(
514 static_cast<const ADataType*
>(p_As[i]),
515 static_cast<const BDataType*
>(p_Bs[i]),
517 static_cast<EDataType*
>(p_Es[i]),
529 if constexpr(NXdlPerWave32 > 0)
531 init_gridwise_gemm_desc<GridwiseGemm32>(
532 static_cast<const ADataType*
>(p_As[i]),
533 static_cast<const BDataType*
>(p_Bs[i]),
535 static_cast<EDataType*
>(p_Es[i]),
552 AElementwiseOperation a_element_op_;
553 BElementwiseOperation b_element_op_;
554 CDEElementwiseOperation c_element_op_;
// One entry per accepted group; this is what gets copied to the device workspace.
556 std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
// Raw (M, K) and (N, K) per input group, including skipped groups.
557 std::vector<Tuple<index_t, index_t>> a_mtx_mraw_kraw_;
558 std::vector<Tuple<index_t, index_t>> b_mtx_nraw_kraw_;
// Optional host staging buffer for kernel args, set by SetHostKernelArgsPointer.
// NOTE(review): no initializer is visible. RunImp compares this against nullptr, so
// confirm the constructor sets it (e.g. to nullptr).
561 void* gemm_kernel_host_args_;
// Invoker: launches kernel_grouped_gemm_xdl once for all accepted groups.
565 struct Invoker :
public BaseInvoker
567 using Argument = DeviceOp::Argument;
// RunImp proceeds in three steps:
//  1. Validates every group with GridwiseGemm::CheckValidity and requires all groups
//     to agree on CalculateHasMainKBlockLoop.
//  2. Copies the kernel args into arg.p_workspace_.
//  3. Launches over dim3(arg.grid_size_) workgroups.
// Throws std::runtime_error on invalid groups, mixed main-K-loop requirements, or a
// missing host staging buffer.
569 template <
typename Gr
idwiseGemm>
570 float RunImp(
const Argument& arg,
571 const StreamConfig& stream_config = StreamConfig{},
572 hipStream_t cpy_stream =
nullptr,
573 hipEvent_t cpy_event =
nullptr)
575 bool has_main_k_block_loop =
true;
577 for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
581 std::cout <<
"group: " << i <<
" arg.a_grid_desc_ak0_m_ak1_{"
582 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0)
584 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1)
586 << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2)
589 std::cout <<
", arg.b_grid_desc_bk0_n_bk1_{"
590 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0)
592 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1)
594 << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2)
597 std::cout <<
", arg.e_grid_desc_m_n_{ "
598 << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) <<
", "
599 << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) <<
"}"
// NOTE(review): the message below names GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3, but
// the gridwise GEMM used here is GridwiseGemmMultipleD_xdl_cshuffle.
603 if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_,
604 arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_,
605 arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_,
606 arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_,
607 arg.gemm_desc_kernel_arg_[i].block_2_etile_map_))
609 throw std::runtime_error(
610 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
// Total K = AK0 * AK1. One launch serves every group, so all groups must agree on the
// HasMainKBlockLoop template flag.
613 const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
614 arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);
616 if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
618 throw std::runtime_error(
"wrong! not all gemm has_main_k_block_loop");
// With a copy stream and event: copy from the pre-filled host staging buffer on
// cpy_stream, then block on the event. Otherwise: copy straight from
// gemm_desc_kernel_arg_ on stream_config.stream_id_.
625 if(cpy_stream && cpy_event)
627 if(arg.gemm_kernel_host_args_ ==
nullptr)
629 std::ostringstream err;
630 err <<
"No memory has been allocated for gemm kernel host args "
631 <<
"when providing the copy stream and copy event! In " << __FILE__ <<
":"
632 << __LINE__ <<
", in function: " << __func__;
633 throw std::runtime_error(err.str());
// NOTE(review): hipGetErrorString only converts the status to text and the result is
// discarded, so failures of these HIP calls are silently ignored. Consider an explicit
// status check (e.g. hip_check_error).
// NOTE(review): this path copies arg.group_count_ entries, while the other path copies
// gemm_desc_kernel_arg_.size(). They differ when groups were skipped.
635 hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
636 arg.gemm_kernel_host_args_,
637 arg.group_count_ *
sizeof(GemmBiasTransKernelArg),
638 hipMemcpyHostToDevice,
640 hipGetErrorString(hipEventRecord(cpy_event, cpy_stream));
641 hipGetErrorString(hipEventSynchronize(cpy_event));
// Default path: copy the kernel args on the compute stream (return status likewise
// discarded).
645 hipGetErrorString(hipMemcpyAsync(arg.p_workspace_,
646 arg.gemm_desc_kernel_arg_.data(),
647 arg.gemm_desc_kernel_arg_.size() *
648 sizeof(GemmBiasTransKernelArg),
649 hipMemcpyHostToDevice,
650 stream_config.stream_id_));
// has_main_k_block_loop_ is the compile-time flag supplied on lines not shown. The
// group count passed to the kernel is gemm_desc_kernel_arg_.size().
656 const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
657 GemmBiasTransKernelArg,
658 AElementwiseOperation,
659 BElementwiseOperation,
660 CDEElementwiseOperation,
661 has_main_k_block_loop_>;
666 dim3(arg.grid_size_),
670 arg.gemm_desc_kernel_arg_.size(),
676 if(has_main_k_block_loop)
// Dispatches to the wave64 or wave32 implementation. Only enabled variants are
// compiled; the runtime selection condition is on lines not shown.
688 float Run(
const Argument& arg,
689 const StreamConfig& stream_config = StreamConfig{},
690 hipStream_t cpy_stream =
nullptr,
691 hipEvent_t cpy_event =
nullptr)
695 if constexpr(NXdlPerWave64 > 0)
697 return RunImp<GridwiseGemm64>(arg, stream_config, cpy_stream, cpy_event);
702 if constexpr(NXdlPerWave32 > 0)
704 return RunImp<GridwiseGemm32>(arg, stream_config, cpy_stream, cpy_event);
// Type-erased entry point.
// NOTE(review): the dynamic_cast result is dereferenced without a null check, so a
// non-Argument p_arg is undefined behavior rather than a clean error.
711 float Run(
const BaseArgument* p_arg,
712 const StreamConfig& stream_config = StreamConfig{})
override
714 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
// Reports whether this instance can run `arg`.
// A check involving skipped_group_count_ and group_count_ appears on lines partly not
// shown; presumably accepted + skipped groups must equal group_count_ (confirm).
// Per group, the raw dimension that A/B are vectorized along must be divisible by
// the corresponding SrcScalarPerVector.
// How the GemmSpec check relates to the per-group loop is on lines not shown.
718 static bool IsSupportedArgument(
const Argument& arg)
725 arg.skipped_group_count_) != arg.group_count_)
730 bool supported =
true;
734 if constexpr(GemmSpec != GemmSpecialization::Default)
// a_mtx_mraw_kraw_ holds (M, K). Index 1 (K) is checked unless
// ABlockTransferSrcVectorDim == 1, in which case index 0 (M) is checked.
// B works the same way with its (N, K) tuple.
738 const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
739 const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
741 for(index_t i = 0; i < arg.group_count_; ++i)
743 const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
744 const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
746 supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
747 supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
// Type-erased overload that forwards to the static IsSupportedArgument.
// NOTE(review): the dynamic_cast result is dereferenced without a null check.
755 bool IsSupportedArgument(
const BaseArgument* p_arg)
override
757 return IsSupportedArgument(*
dynamic_cast<const Argument*
>(p_arg));
// Constructs an Argument by value from per-group pointers, descriptors and element-wise
// ops. gemm_descs is taken by value here, whereas MakeArgumentPointer takes it by
// reference.
760 static auto MakeArgument(std::vector<const void*>& p_As,
761 std::vector<const void*>& p_Bs,
762 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
763 std::vector<void*>& p_Es,
764 std::vector<GemmDesc> gemm_descs,
765 AElementwiseOperation a_element_op,
766 BElementwiseOperation b_element_op,
767 CDEElementwiseOperation c_element_op)
770 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op};
773 static auto MakeInvoker() {
return Invoker{}; }
// Heap-allocates an Argument behind the type-erased BaseArgument interface.
// Arguments are forwarded unchanged to the Argument constructor.
776 std::unique_ptr<BaseArgument>
777 MakeArgumentPointer(std::vector<const void*>& p_As,
778 std::vector<const void*>& p_Bs,
779 std::vector<std::array<const void*, NumDTensor>>& p_Ds,
780 std::vector<void*>& p_Es,
781 std::vector<GemmDesc>& gemm_descs,
782 AElementwiseOperation a_element_op,
783 BElementwiseOperation b_element_op,
784 CDEElementwiseOperation c_element_op)
override
786 return std::make_unique<Argument>(
787 p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op);
// Heap-allocates an Invoker behind the type-erased BaseInvoker interface.
791 std::unique_ptr<BaseInvoker> MakeInvokerPointer()
override
793 return std::make_unique<Invoker>(Invoker{});
// Builds a human-readable identifier: "DeviceGroupedGemm_Xdl" followed by selected
// tuning parameters (XDL-per-wave counts, source vector widths, C-shuffle sizes).
// Some streamed fields are on lines not shown.
797 std::string GetTypeString()
const override
799 auto str = std::stringstream();
802 str <<
"DeviceGroupedGemm_Xdl"
812 << MXdlPerWave <<
", "
813 << NXdlPerWave <<
", "
814 << ABlockTransferSrcScalarPerVector <<
", "
815 << BBlockTransferSrcScalarPerVector <<
", "
816 << CShuffleMXdlPerWavePerShuffle <<
", "
817 << CShuffleNXdlPerWavePerShuffle <<
", "
// Device workspace size in bytes: group_count_ * sizeof(GemmBiasTransKernelArg).
// group_count_ also counts skipped groups, so this can exceed what is actually written.
// Throws std::runtime_error when p_arg is not this op's Argument (presumably when the
// dynamic_cast yields nullptr; the check is on a line not shown).
// NOTE(review): the error text names DeviceGroupedGemmMultipleDXdlCShuffle, while
// GetTypeString reports DeviceGroupedGemm_Xdl.
825 size_t GetWorkSpaceSize(
const BaseArgument* p_arg)
const override
827 auto p_arg_ =
dynamic_cast<const Argument*
>(p_arg);
830 return p_arg_->group_count_ *
sizeof(GemmBiasTransKernelArg);
833 throw std::runtime_error(
"The argument pointer is not an object of "
834 "DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!");
// The device kernel-argument buffer is the workspace, so its size equals
// GetWorkSpaceSize.
837 size_t GetDeviceKernelArgSize(
const BaseArgument* p_arg)
const override
839 return GetWorkSpaceSize(p_arg);
// Device kernel args live in the workspace. This forwards the caller's device buffer to
// SetWorkSpacePointer.
842 void SetDeviceKernelArgs(BaseArgument* p_arg,
void* p_dev_kernel_args)
const override
844 return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
847 size_t GetHostKernelArgSize(
const BaseArgument* p_arg)
const {
return GetWorkSpaceSize(p_arg); }
// Stores the caller's host staging buffer in the Argument and copies the current
// per-group kernel args into it. The buffer must hold at least
// gemm_desc_kernel_arg_.size() entries; GetHostKernelArgSize gives a sufficient size.
// Throws std::runtime_error if p_arg is not this op's Argument (the null check is on a
// line not shown).
// NOTE(review): std::copy assigns into raw caller memory. std::uninitialized_copy (or
// memcpy for a trivially copyable type) states the intent more precisely.
859 void SetHostKernelArgsPointer(BaseArgument* p_arg,
void* p_host_kernel_args)
const
861 Argument* pArg_ =
dynamic_cast<Argument*
>(p_arg);
864 throw std::runtime_error(
"Failed to cast argument pointer!");
867 pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
868 std::copy(pArg_->gemm_desc_kernel_arg_.begin(),
869 pArg_->gemm_desc_kernel_arg_.end(),
870 static_cast<GemmBiasTransKernelArg*
>(pArg_->gemm_kernel_host_args_));
#define CK_CONSTANT_ADDRESS_SPACE
Definition ck.hpp:23
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition amd_address_space.hpp:35
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition amd_address_space.hpp:24
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition device_grouped_gemm.hpp:99
Definition device_grouped_gemm.hpp:80
#define CK_ENV(name)
Definition utility/env.hpp:129