dpp_gemm.hpp Source File

dpp_gemm.hpp Source File

Composable Kernel: dpp_gemm.hpp Source File
dpp_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
8#include "ck/utility/math.hpp"
9
10namespace ck {
11
24
47template <DppInstr instr>
48struct dpp_type;
49
50template <>
52{
53 static constexpr index_t wave_size = 32;
54 static constexpr index_t lanegroup_size = 8;
55 static constexpr index_t m_per_wave = 32;
56 static constexpr index_t n_per_wave = 8;
57 static constexpr index_t m_per_lanegroup = 8;
58 static constexpr index_t n_per_lanegroup = 8;
59 static constexpr index_t m_per_thread = 8;
60 static constexpr index_t n_per_thread = 1;
61 static constexpr index_t k_per_dpp = 2;
62 static constexpr bool share_a = true;
64
65 template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
66 __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
67 {
72 ADataType,
73 BDataType,
74 CDataType,
75 share_a>{}
76 .Run(a, b, reg_c);
77 }
78};
79
80template <>
82{
83 static constexpr index_t wave_size = 32;
84 static constexpr index_t lanegroup_size = 8;
85 static constexpr index_t m_per_wave = 8;
86 static constexpr index_t n_per_wave = 32;
87 static constexpr index_t m_per_lanegroup = 8;
88 static constexpr index_t n_per_lanegroup = 8;
89 static constexpr index_t m_per_thread = 8;
90 static constexpr index_t n_per_thread = 1;
91 static constexpr index_t k_per_dpp = 2;
92 static constexpr bool share_a = true;
94
95 template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
96 __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
97 {
100 k_per_dpp,
101 BaseType,
102 ADataType,
103 BDataType,
104 CDataType,
105 share_a>{}
106 .Run(a, b, reg_c);
107 }
108};
109
110template <>
112{
113 static constexpr index_t wave_size = 32;
114 static constexpr index_t lanegroup_size = 8;
115 static constexpr index_t m_per_wave = 8;
116 static constexpr index_t n_per_wave = 16;
117 static constexpr index_t m_per_lanegroup = 4;
118 static constexpr index_t n_per_lanegroup = 8;
119 static constexpr index_t m_per_thread = 4;
120 static constexpr index_t n_per_thread = 1;
121 static constexpr index_t k_per_dpp = 2;
122 static constexpr bool share_a = true;
124
125 template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
126 __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
127 {
130 k_per_dpp,
131 BaseType,
132 ADataType,
133 BDataType,
134 CDataType,
135 share_a>{}
136 .Run(a, b, reg_c);
137 }
138};
139
140template <>
142{
143 static constexpr index_t wave_size = 32;
144 static constexpr index_t lanegroup_size = 8;
145 static constexpr index_t m_per_wave = 16;
146 static constexpr index_t n_per_wave = 16;
147 static constexpr index_t m_per_lanegroup = 8;
148 static constexpr index_t n_per_lanegroup = 8;
149 static constexpr index_t m_per_thread = 8;
150 static constexpr index_t n_per_thread = 1;
151 static constexpr index_t k_per_dpp = 2;
152 static constexpr bool share_a = true;
154
155 template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
156 __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
157 {
160 k_per_dpp,
161 BaseType,
162 ADataType,
163 BDataType,
164 CDataType,
165 share_a>{}
166 .Run(a, b, reg_c);
167 }
168};
169
170template <>
172{
173 static constexpr index_t wave_size = 32;
174 static constexpr index_t lanegroup_size = 8;
175 static constexpr index_t m_per_wave = 4;
176 static constexpr index_t n_per_wave = 32;
177 static constexpr index_t m_per_lanegroup = 4;
178 static constexpr index_t n_per_lanegroup = 8;
179 static constexpr index_t m_per_thread = 4;
180 static constexpr index_t n_per_thread = 1;
181 static constexpr index_t k_per_dpp = 2;
182 static constexpr bool share_a = true;
184
185 template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
186 __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
187 {
190 k_per_dpp,
191 BaseType,
192 ADataType,
193 BDataType,
194 CDataType,
195 share_a>{}
196 .Run(a, b, reg_c);
197 }
198};
199
200template <>
202{
203 static constexpr index_t wave_size = 32;
204 static constexpr index_t lanegroup_size = 8;
205 static constexpr index_t m_per_wave = 4;
206 static constexpr index_t n_per_wave = 16;
207 static constexpr index_t m_per_lanegroup = 2;
208 static constexpr index_t n_per_lanegroup = 8;
209 static constexpr index_t m_per_thread = 2;
210 static constexpr index_t n_per_thread = 1;
211 static constexpr index_t k_per_dpp = 2;
212 static constexpr bool share_a = true;
214
215 template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
216 __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
217 {
220 k_per_dpp,
221 BaseType,
222 ADataType,
223 BDataType,
224 CDataType,
225 share_a>{}
226 .Run(a, b, reg_c);
227 }
228};
229
230template <>
232{
233 static constexpr index_t wave_size = 32;
234 static constexpr index_t lanegroup_size = 8;
235 static constexpr index_t m_per_wave = 1;
236 static constexpr index_t n_per_wave = 32;
237 static constexpr index_t m_per_lanegroup = 1;
238 static constexpr index_t n_per_lanegroup = 8;
239 static constexpr index_t m_per_thread = 1;
240 static constexpr index_t n_per_thread = 1;
241 static constexpr index_t k_per_dpp = 2;
242 static constexpr bool share_a = true;
244
245 template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
246 __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
247 {
250 k_per_dpp,
251 BaseType,
252 ADataType,
253 BDataType,
254 CDataType,
255 share_a>{}
256 .Run(a, b, reg_c);
257 }
258};
259
260template <>
262{
263 static constexpr index_t wave_size = 32;
264 static constexpr index_t lanegroup_size = 8;
265 static constexpr index_t m_per_wave = 2;
266 static constexpr index_t n_per_wave = 32;
267 static constexpr index_t m_per_lanegroup = 2;
268 static constexpr index_t n_per_lanegroup = 8;
269 static constexpr index_t m_per_thread = 2;
270 static constexpr index_t n_per_thread = 1;
271 static constexpr index_t k_per_dpp = 2;
272 static constexpr bool share_a = true;
274
275 template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
276 __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
277 {
280 k_per_dpp,
281 BaseType,
282 ADataType,
283 BDataType,
284 CDataType,
285 share_a>{}
286 .Run(a, b, reg_c);
287 }
288};
289
290template <>
292{
293 static constexpr index_t wave_size = 32;
294 static constexpr index_t lanegroup_size = 8;
295 static constexpr index_t m_per_wave = 2;
296 static constexpr index_t n_per_wave = 16;
297 static constexpr index_t m_per_lanegroup = 1;
298 static constexpr index_t n_per_lanegroup = 8;
299 static constexpr index_t m_per_thread = 1;
300 static constexpr index_t n_per_thread = 1;
301 static constexpr index_t k_per_dpp = 2;
302 static constexpr bool share_a = true;
304
305 template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
306 __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
307 {
310 k_per_dpp,
311 BaseType,
312 ADataType,
313 BDataType,
314 CDataType,
315 share_a>{}
316 .Run(a, b, reg_c);
317 }
318};
319
320template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
322{
323 template <typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
324 static constexpr auto GetDpp();
325
326 template <>
327 constexpr auto GetDpp<half_t, 8, 32>()
328 {
330 }
331
332 template <>
333 constexpr auto GetDpp<half_t, 8, 16>()
334 {
336 }
337
338 template <>
339 constexpr auto GetDpp<half_t, 16, 16>()
340 {
342 }
343
344 template <>
345 constexpr auto GetDpp<half_t, 32, 8>()
346 {
348 }
349
350 template <>
351 constexpr auto GetDpp<half_t, 1, 32>()
352 {
354 }
355
356 template <>
357 constexpr auto GetDpp<half_t, 2, 32>()
358 {
360 }
361
362 template <>
363 constexpr auto GetDpp<half_t, 2, 16>()
364 {
366 }
367
368 template <>
369 constexpr auto GetDpp<half_t, 4, 16>()
370 {
372 }
373
374 template <>
375 constexpr auto GetDpp<half_t, 4, 32>()
376 {
378 }
379
381
382 __host__ __device__ constexpr DppSelector()
383 {
384 static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0);
385 static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0);
386
387 static_assert(selected_dpp.k_per_dpp % 2 == 0);
388
389 static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0);
390 constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size;
391 constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave;
392 constexpr index_t num_dpp_c_elems =
393 selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup;
394 static_assert(num_wave_c_elems % num_dpp_c_elems == 0);
395 static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems);
396
397 if constexpr(selected_dpp.share_a)
398 {
399 static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
400 static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0);
401 static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread ==
402 selected_dpp.lanegroup_size);
403 }
404 else
405 {
406 static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0);
407 static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread ==
408 selected_dpp.lanegroup_size);
409 static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread);
410 }
411
412 // Below checks come from the restrictions of the current implementation, could be removed
413 // in the future when the implementation is more generalized.
414 static_assert(selected_dpp.share_a);
415 static_assert(selected_dpp.n_per_thread == 1);
416 static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
417 static_assert(selected_dpp.n_per_lanegroup ==
418 selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
419 }
420
421 static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; }
422};
423
424template <typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
426{
427 static constexpr auto I0 = Number<0>{};
428 static constexpr auto I1 = Number<1>{};
429 static constexpr auto I2 = Number<2>{};
430 static constexpr auto I3 = Number<3>{};
431 static constexpr auto I4 = Number<4>{};
432 static constexpr auto I5 = Number<5>{};
433
436
437 __host__ __device__ constexpr DppGemm()
438 {
439 static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
440 }
441
442 __device__ static constexpr index_t GetRegSizePerDpp()
443 {
444 return MPerDpp * NPerDpp / dpp_instr.wave_size;
445 }
446
447 template <class ADataType, class BDataType, class CDataType>
448 __device__ void
449 Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const
450 {
454 "base BaseType must be double, float, half, bfloat16, and int8_t!");
455
456 static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
457 dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
458 });
459 }
460
461 __device__ static auto GetLaneIdInWave()
462 {
463 return get_thread_local_1d_id() % dpp_instr.wave_size;
464 }
465
466 __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; }
467
468 __device__ static auto GetLaneIdInLaneGroup()
469 {
470 return get_thread_local_1d_id() % dpp_instr.lanegroup_size;
471 }
472
473 __device__ static auto GetLaneGroupIdInWave()
474 {
475 return GetLaneIdInWave() / dpp_instr.lanegroup_size;
476 }
477
478 __device__ static auto GetDppOpIdx()
479 {
480 const auto lanegroupId = GetLaneGroupIdInWave();
481
482 constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor(
484 make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
485 dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))),
488
489 const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
490 make_multi_index(lanegroupId));
491
492 const auto m_dpp_idx = dpp_idx[I0];
493 const auto n_dpp_idx = dpp_idx[I1];
494
495 return make_tuple(m_dpp_idx, n_dpp_idx);
496 }
497
498 __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
499 {
500 const auto laneId = get_thread_local_1d_id();
501 const auto wave_row = laneId / dpp_instr.n_per_wave;
502 auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
503 return make_tuple(0, m_idx % dpp_instr.m_per_wave);
504 }
505
506 __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
507 {
508 const auto laneId = get_thread_local_1d_id();
509 return make_tuple(0, laneId % dpp_instr.n_per_wave);
510 }
511
512 __device__ static CIndex GetBeginOfThreadBlk()
513 {
514 const auto dpp_op_idx = GetDppOpIdx();
515
516 const auto m_dpp_op_idx = dpp_op_idx[I0];
517 const auto n_dpp_op_idx = dpp_op_idx[I1];
518
519 index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
520 index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;
521
522 return CIndex{m_offset, n_offset};
523 }
524
526
527 static constexpr auto dpp_instr = dpp.selected_dpp;
528
529 static constexpr auto K0PerDpp = 1;
530 static constexpr auto K1PerDpp = dpp.GetK1PerDpp();
531
532 __host__ __device__ static constexpr auto GetCMNThreadBlkLengths()
533 {
535 }
536};
537
538} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
_Float16 half_t
Definition data_type.hpp:31
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
DppInstr
Definition dpp_gemm.hpp:13
@ dpp8_f16_2x16x2
Definition dpp_gemm.hpp:15
@ dpp8_f16_4x16x2
Definition dpp_gemm.hpp:17
@ dpp8_f16_1x32x2
Definition dpp_gemm.hpp:14
@ dpp8_f16_8x32x2
Definition dpp_gemm.hpp:20
@ dpp8_f16_2x32x2
Definition dpp_gemm.hpp:16
@ dpp8_f16_8x16x2
Definition dpp_gemm.hpp:19
@ dpp8_f16_32x8x2
Definition dpp_gemm.hpp:22
@ dpp8_f16_4x32x2
Definition dpp_gemm.hpp:18
@ dpp8_f16_16x16x2
Definition dpp_gemm.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
__host__ static __device__ auto CalculateBThreadOriginDataIndex_K_N()
Definition dpp_gemm.hpp:506
MultiIndex< 2 > CIndex
Definition dpp_gemm.hpp:434
__device__ void Run(const ADataType &p_a_wave, const BDataType &p_b_wave, CDataType &p_c_thread) const
Definition dpp_gemm.hpp:449
static constexpr auto dpp_instr
Definition dpp_gemm.hpp:527
__host__ static __device__ constexpr auto GetCMNThreadBlkLengths()
Definition dpp_gemm.hpp:532
static __device__ constexpr index_t GetRegSizePerDpp()
Definition dpp_gemm.hpp:442
static constexpr auto I3
Definition dpp_gemm.hpp:430
__host__ __device__ constexpr DppGemm()
Definition dpp_gemm.hpp:437
static constexpr auto I1
Definition dpp_gemm.hpp:428
static __device__ auto GetWaveId()
Definition dpp_gemm.hpp:466
static constexpr auto I5
Definition dpp_gemm.hpp:432
static __device__ auto GetLaneGroupIdInWave()
Definition dpp_gemm.hpp:473
static __device__ CIndex GetBeginOfThreadBlk()
Definition dpp_gemm.hpp:512
static constexpr auto I4
Definition dpp_gemm.hpp:431
static constexpr auto I2
Definition dpp_gemm.hpp:429
static __device__ auto GetLaneIdInLaneGroup()
Definition dpp_gemm.hpp:468
static constexpr auto K1PerDpp
Definition dpp_gemm.hpp:530
static constexpr auto dpp
Definition dpp_gemm.hpp:525
__host__ static __device__ auto CalculateAThreadOriginDataIndex_K_M()
Definition dpp_gemm.hpp:498
static constexpr auto I0
Definition dpp_gemm.hpp:427
MultiIndex< 4 > CIndex4D
Definition dpp_gemm.hpp:435
static __device__ auto GetDppOpIdx()
Definition dpp_gemm.hpp:478
static __device__ auto GetLaneIdInWave()
Definition dpp_gemm.hpp:461
static constexpr auto K0PerDpp
Definition dpp_gemm.hpp:529
Definition dpp_gemm.hpp:322
static constexpr index_t GetK1PerDpp()
Definition dpp_gemm.hpp:421
static constexpr auto selected_dpp
Definition dpp_gemm.hpp:380
static constexpr auto GetDpp()
__host__ __device__ constexpr DppSelector()
Definition dpp_gemm.hpp:382
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition amd_gemm_dpp.hpp:37
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:156
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:149
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:146
static constexpr index_t wave_size
Definition dpp_gemm.hpp:143
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:144
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:147
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:151
static constexpr bool share_a
Definition dpp_gemm.hpp:152
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:145
half_t BaseType
Definition dpp_gemm.hpp:153
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:150
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:148
static constexpr bool share_a
Definition dpp_gemm.hpp:242
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:238
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:240
static constexpr index_t wave_size
Definition dpp_gemm.hpp:233
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:246
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:241
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:235
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:237
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:234
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:236
half_t BaseType
Definition dpp_gemm.hpp:243
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:239
static constexpr bool share_a
Definition dpp_gemm.hpp:302
half_t BaseType
Definition dpp_gemm.hpp:303
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:297
static constexpr index_t wave_size
Definition dpp_gemm.hpp:293
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:301
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:295
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:300
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:299
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:298
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:306
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:294
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:296
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:268
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:267
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:270
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:266
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:264
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:269
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:276
half_t BaseType
Definition dpp_gemm.hpp:273
static constexpr index_t wave_size
Definition dpp_gemm.hpp:263
static constexpr bool share_a
Definition dpp_gemm.hpp:272
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:265
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:271
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:54
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:57
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:56
static constexpr index_t wave_size
Definition dpp_gemm.hpp:53
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:60
static constexpr bool share_a
Definition dpp_gemm.hpp:62
half_t BaseType
Definition dpp_gemm.hpp:63
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:58
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:61
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:55
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:66
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:59
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:208
static constexpr index_t wave_size
Definition dpp_gemm.hpp:203
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:205
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:216
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:210
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:209
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:206
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:204
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:211
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:207
static constexpr bool share_a
Definition dpp_gemm.hpp:212
half_t BaseType
Definition dpp_gemm.hpp:213
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:177
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:174
static constexpr index_t wave_size
Definition dpp_gemm.hpp:173
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:180
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:181
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:179
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:175
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:178
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:176
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:186
static constexpr bool share_a
Definition dpp_gemm.hpp:182
half_t BaseType
Definition dpp_gemm.hpp:183
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:116
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:118
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:114
static constexpr index_t wave_size
Definition dpp_gemm.hpp:113
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:121
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:126
half_t BaseType
Definition dpp_gemm.hpp:123
static constexpr bool share_a
Definition dpp_gemm.hpp:122
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:119
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:115
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:117
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:120
static constexpr index_t n_per_lanegroup
Definition dpp_gemm.hpp:88
half_t BaseType
Definition dpp_gemm.hpp:93
static constexpr bool share_a
Definition dpp_gemm.hpp:92
static constexpr index_t m_per_lanegroup
Definition dpp_gemm.hpp:87
static constexpr index_t n_per_thread
Definition dpp_gemm.hpp:90
static constexpr index_t wave_size
Definition dpp_gemm.hpp:83
static constexpr index_t m_per_thread
Definition dpp_gemm.hpp:89
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition dpp_gemm.hpp:96
static constexpr index_t k_per_dpp
Definition dpp_gemm.hpp:91
static constexpr index_t lanegroup_size
Definition dpp_gemm.hpp:84
static constexpr index_t n_per_wave
Definition dpp_gemm.hpp:86
static constexpr index_t m_per_wave
Definition dpp_gemm.hpp:85
Definition dpp_gemm.hpp:48
Definition functional2.hpp:33