gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp Source File

gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp Source File

Composable Kernel: gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp Source File
gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8
9namespace ck_tile {
10
12{
13 template <typename Problem>
14 CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
15 {
17 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
18 constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
19 constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
20 constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
21
22 return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
23 }
24
25 template <typename Problem>
30
31 template <typename Problem>
33 {
34 using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
35 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
36
37 using BTypeToUse =
38 std::conditional_t<std::is_same_v<typename Problem::BDataType, ck_tile::pk_int4_t>,
39 typename Problem::ADataType,
40 typename Problem::BDataType>;
41
42 using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
43 BTypeToUse,
44 typename Problem::CDataType,
45 WarpTile::at(I0),
46 WarpTile::at(I1),
47 WarpTile::at(I2),
48 Problem::TransposeC>;
49
50 // TODO : Use a custom block policy for AsBrCr
51 using BlockGemmPolicy =
52 BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
53 typename Problem::BDataType,
54 typename Problem::CDataType,
55 BlockWarps,
56 WarpGemm>;
58 }
59};
60
61} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
int32_t index_t
Definition integer.hpp:9
Definition block_universal_gemm_ar_flatbr_bquant_cr.hpp:18
Definition block_wp_asmem_bsmem_creg_v1_custom_policy.hpp:18
static CK_TILE_HOST_DEVICE constexpr auto MakeBQDramTileDistribution()
Definition gemm_bquant_pipeline_ag_bg_cr_policy.hpp:33
Definition gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp:12
static CK_TILE_HOST_DEVICE constexpr auto MakeBQDramTileDistribution()
Definition gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp:26
static CK_TILE_HOST_DEVICE constexpr auto GetVectorSizeBQ()
Definition gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp:14
static CK_TILE_HOST_DEVICE constexpr auto GetBlockWeightPreshuffleBQuant()
Definition gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp:32
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49
Definition wp_pipeline_agmem_bgmem_creg_base_policy.hpp:13