blockwise_gemm_pipeline_xdlops_selector.hpp Source File

blockwise_gemm_pipeline_xdlops_selector.hpp Source File

Composable Kernel: blockwise_gemm_pipeline_xdlops_selector.hpp Source File
blockwise_gemm_pipeline_xdlops_selector.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14template <BlockGemmPipelineVersion BlkGemmPipelineVer,
15 BlockGemmPipelineScheduler BlkGemmPipeSche,
16 index_t BlockSize,
17 typename ADataType,
18 typename BDataType,
19 typename ComputeDataType,
20 typename AccDataType,
21 typename ATileDesc,
22 typename BTileDesc,
23 typename AMmaTileDesc,
24 typename BMmaTileDesc,
25 index_t ABlockTransferSrcScalarPerVector,
26 index_t BBlockTransferSrcScalarPerVector,
27 index_t MPerBlock,
28 index_t NPerBlock,
29 index_t KPerBlock,
30 index_t MPerXDL,
31 index_t NPerXDL,
32 index_t MRepeat,
33 index_t NRepeat,
34 index_t KPack,
35 bool DirectLoad = false>
37{
38 if constexpr(DirectLoad)
39 {
40 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
41 {
42 return BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlkGemmPipeSche,
43 BlockSize,
44 ADataType,
45 BDataType,
46 ComputeDataType,
47 AccDataType,
48 ATileDesc,
49 BTileDesc,
50 AMmaTileDesc,
51 BMmaTileDesc,
52 ABlockTransferSrcScalarPerVector,
53 BBlockTransferSrcScalarPerVector,
54 MPerBlock,
55 NPerBlock,
56 KPerBlock,
57 MPerXDL,
58 NPerXDL,
59 MRepeat,
60 NRepeat,
61 KPack>{};
62 }
63 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
64 {
65 return BlockwiseGemmXdlopsDirectLoad_pipeline_v4<BlkGemmPipeSche,
66 BlockSize,
67 ADataType,
68 BDataType,
69 ComputeDataType,
70 AccDataType,
71 ATileDesc,
72 BTileDesc,
73 AMmaTileDesc,
74 BMmaTileDesc,
75 ABlockTransferSrcScalarPerVector,
76 BBlockTransferSrcScalarPerVector,
77 MPerBlock,
78 NPerBlock,
79 KPerBlock,
80 MPerXDL,
81 NPerXDL,
82 MRepeat,
83 NRepeat,
84 KPack>{};
85 }
86 else
87 {
88 std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
89 }
90 }
91 else
92 {
93 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
94 {
95 return BlockwiseGemmXdlops_pipeline_v1<BlkGemmPipeSche,
96 BlockSize,
97 ADataType,
98 BDataType,
99 ComputeDataType,
100 AccDataType,
101 ATileDesc,
102 BTileDesc,
103 AMmaTileDesc,
104 BMmaTileDesc,
105 ABlockTransferSrcScalarPerVector,
106 BBlockTransferSrcScalarPerVector,
107 MPerBlock,
108 NPerBlock,
109 KPerBlock,
110 MPerXDL,
111 NPerXDL,
112 MRepeat,
113 NRepeat,
114 KPack>{};
115 }
116 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
117 {
118 return BlockwiseGemmXdlops_pipeline_v2<BlkGemmPipeSche,
119 BlockSize,
120 ADataType,
121 BDataType,
122 ComputeDataType,
123 AccDataType,
124 ATileDesc,
125 BTileDesc,
126 AMmaTileDesc,
127 BMmaTileDesc,
128 ABlockTransferSrcScalarPerVector,
129 BBlockTransferSrcScalarPerVector,
130 MPerBlock,
131 NPerBlock,
132 KPerBlock,
133 MPerXDL,
134 NPerXDL,
135 MRepeat,
136 NRepeat,
137 KPack>{};
138 }
139 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
140 {
141 return BlockwiseGemmXdlops_pipeline_v3<BlkGemmPipeSche,
142 BlockSize,
143 ADataType,
144 BDataType,
145 ComputeDataType,
146 AccDataType,
147 ATileDesc,
148 BTileDesc,
149 AMmaTileDesc,
150 BMmaTileDesc,
151 ABlockTransferSrcScalarPerVector,
152 BBlockTransferSrcScalarPerVector,
153 MPerBlock,
154 NPerBlock,
155 KPerBlock,
156 MPerXDL,
157 NPerXDL,
158 MRepeat,
159 NRepeat,
160 KPack>{};
161 }
162 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
163 {
164 return BlockwiseGemmXdlops_pipeline_v4<BlkGemmPipeSche,
165 BlockSize,
166 ADataType,
167 BDataType,
168 ComputeDataType,
169 AccDataType,
170 ATileDesc,
171 BTileDesc,
172 AMmaTileDesc,
173 BMmaTileDesc,
174 ABlockTransferSrcScalarPerVector,
175 BBlockTransferSrcScalarPerVector,
176 MPerBlock,
177 NPerBlock,
178 KPerBlock,
179 MPerXDL,
180 NPerXDL,
181 MRepeat,
182 NRepeat,
183 KPack>{};
184 }
185 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5)
186 {
187 return BlockwiseGemmXdlops_pipeline_v5<BlkGemmPipeSche,
188 BlockSize,
189 ADataType,
190 BDataType,
191 ComputeDataType,
192 AccDataType,
193 ATileDesc,
194 BTileDesc,
195 AMmaTileDesc,
196 BMmaTileDesc,
197 ABlockTransferSrcScalarPerVector,
198 BBlockTransferSrcScalarPerVector,
199 MPerBlock,
200 NPerBlock,
201 KPerBlock,
202 MPerXDL,
203 NPerXDL,
204 MRepeat,
205 NRepeat,
206 KPack>{};
207 }
208 else
209 {
210 std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
211 }
212 }
213}
214
215} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
constexpr auto BlockGemmPipeline_Selector()
Definition blockwise_gemm_pipeline_wmma_selector.hpp:32
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:37
Definition blockwise_gemm_pipeline_xdlops_v2.hpp:37
Definition blockwise_gemm_pipeline_xdlops_v3.hpp:37
Definition blockwise_gemm_pipeline_xdlops.hpp:103
Definition blockwise_gemm_pipeline_xdlops_v5.hpp:37
Definition blockwise_gemm_pipeline_xdlops_v1.hpp:763
Definition blockwise_gemm_pipeline_xdlops_v4.hpp:604