gridwise_gemm_reduce_xdl_cshuffle_v1.hpp Source File

gridwise_gemm_reduce_xdl_cshuffle_v1.hpp Source File

Composable Kernel: gridwise_gemm_reduce_xdl_cshuffle_v1.hpp Source File
gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
18
19namespace ck {
20
// Kernel entry point that forwards a GEMM + reduction problem to GridwiseGemm::Run.
//
// When compiled for one of the architectures listed in the #if below, and when
// GridwiseGemm::IsValidCompilationParameter<>() is true at compile time, the kernel
// declares a __shared__ byte array of GridwiseGemm::GetSharedMemoryNumberOfByte()
// bytes and passes it, together with every kernel argument, to
// GridwiseGemm::Run<HasMainKBlockLoop>. If the `if constexpr` condition is false the
// body is empty. On any other architecture every argument is assigned to `ignore`
// (presumably to silence unused-parameter warnings) and nothing else is done.
//
// Template parameters:
//   GridwiseGemm        - type providing IsValidCompilationParameter,
//                         GetSharedMemoryNumberOfByte and Run.
//   HasMainKBlockLoop   - forwarded unchanged as the template argument of Run.
//   remaining parameters - types of the corresponding kernel arguments.
//
// NOTE(review): listing lines 38 and 40 (between `#if CK_USE_LAUNCH_BOUNDS` and the
// first parameter) are missing from this extraction; they presumably hold the
// launch-bounds attribute and the kernel name - confirm against the real header.
21template <typename GridwiseGemm,
22 typename FloatAB,
23 typename FloatC,
24 typename ReducePtrsGlobal,
25 typename AElementwiseOperation,
26 typename BElementwiseOperation,
27 typename CElementwiseOperation,
28 typename ReduceInElementwiseOperations,
29 typename ReduceAccElementwiseOperations,
30 typename AGridDesc_AK0_M_AK1,
31 typename BGridDesc_BK0_N_BK1,
32 typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
33 typename ReduceGridDescriptor_MBlock_MPerBlock,
34 typename Block2CTileMap,
35 bool HasMainKBlockLoop>
36__global__ void
37#if CK_USE_LAUNCH_BOUNDS
39#endif
41 const FloatAB* __restrict__ p_a_grid,
42 const FloatAB* __restrict__ p_b_grid,
43 FloatC* __restrict__ p_c_grid,
44 ReducePtrsGlobal p_reduces_grid,
45 const AElementwiseOperation a_element_op,
46 const BElementwiseOperation b_element_op,
47 const CElementwiseOperation c_element_op,
48 const ReduceInElementwiseOperations reduce_in_element_ops,
49 const ReduceAccElementwiseOperations reduce_out_element_ops,
50 const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
51 const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
52 const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
53 c_grid_desc_mblock_mperblock_nblock_nperblock,
54 const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
55 const Block2CTileMap block_2_ctile_map)
56{
    // Device code is only generated for the architectures named in this check.
57#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
58 defined(__gfx12__)
    // Compile-time gate: an unsupported tile configuration leaves the kernel body empty.
59 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
60 {
    // Single shared byte buffer, sized by the gridwise GEMM and handed to Run as scratch.
61 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
62
63 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
64 p_b_grid,
65 p_c_grid,
66 p_reduces_grid,
67 p_shared,
68 a_element_op,
69 b_element_op,
70 c_element_op,
71 reduce_in_element_ops,
72 reduce_out_element_ops,
73 a_grid_desc_ak0_m_ak1,
74 b_grid_desc_bk0_n_bk1,
75 c_grid_desc_mblock_mperblock_nblock_nperblock,
76 reduce_grid_desc_mblock_mperblock,
77 block_2_ctile_map);
78 }
79#else
    // Other targets: consume every argument via `ignore`; no work is performed.
80 ignore = p_a_grid;
81 ignore = p_b_grid;
82 ignore = p_c_grid;
83 ignore = p_reduces_grid;
84 ignore = a_element_op;
85 ignore = b_element_op;
86 ignore = c_element_op;
87 ignore = reduce_in_element_ops;
88 ignore = reduce_out_element_ops;
89 ignore = a_grid_desc_ak0_m_ak1;
90 ignore = b_grid_desc_bk0_n_bk1;
91 ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
92 ignore = reduce_grid_desc_mblock_mperblock;
93 ignore = block_2_ctile_map;
94#endif // end of supported-architecture check (gfx908 / gfx90a / gfx94 / gfx11 / gfx12)
95}
96
97template <typename FloatAB,
98 typename FloatGemmAcc,
99 typename FloatCShuffle,
100 typename FloatC,
101 typename FloatReduceAcc,
102 typename ReducePtrsGlobal,
103 typename AElementwiseOperation,
104 typename BElementwiseOperation,
105 typename CElementwiseOperation,
106 typename ReduceOperations,
107 typename ReduceInElementwiseOperations,
108 typename ReduceAccElementwiseOperations,
109 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
110 typename ReduceGlobalMemoryDataOperation,
111 typename AGridDesc_AK0_M_AK1,
112 typename BGridDesc_BK0_N_BK1,
113 typename CGridDesc_M_N,
114 typename ReduceGridDesc_M,
115 index_t NumGemmKPrefetchStage,
116 index_t BlockSize,
117 index_t MPerBlock,
118 index_t NPerBlock,
119 index_t KPerBlock,
120 index_t AK1Value,
121 index_t BK1Value,
122 index_t MPerXdl,
123 index_t NPerXdl,
124 index_t MXdlPerWave,
125 index_t NXdlPerWave,
126 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
127 typename ABlockTransferThreadClusterArrangeOrder,
128 typename ABlockTransferSrcAccessOrder,
129 index_t ABlockTransferSrcVectorDim,
130 index_t ABlockTransferSrcScalarPerVector,
131 index_t ABlockTransferDstScalarPerVector_AK1,
132 bool AThreadTransferSrcResetCoordinateAfterRun,
133 index_t ABlockLdsExtraM,
134 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
135 typename BBlockTransferThreadClusterArrangeOrder,
136 typename BBlockTransferSrcAccessOrder,
137 index_t BBlockTransferSrcVectorDim,
138 index_t BBlockTransferSrcScalarPerVector,
139 index_t BBlockTransferDstScalarPerVector_BK1,
140 bool BThreadTransferSrcResetCoordinateAfterRun,
141 index_t BBlockLdsExtraN,
142 index_t CShuffleMXdlPerWavePerShuffle,
143 index_t CShuffleNXdlPerWavePerShuffle,
144 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
145 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
146 typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
147 index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
148 index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
149 LoopScheduler LoopSched,
152{
153 static constexpr auto I0 = Number<0>{};
154 static constexpr auto I1 = Number<1>{};
155 static constexpr auto I2 = Number<2>{};
156 static constexpr auto I3 = Number<3>{};
157 static constexpr auto I4 = Number<4>{};
158 static constexpr auto I5 = Number<5>{};
159 static constexpr auto I6 = Number<6>{};
160 static constexpr auto I7 = Number<7>{};
161
162 // K1 should be Number<...>
163 static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
164 static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
165 static constexpr auto AK1 = Number<AK1Value>{};
166 static constexpr auto BK1 = Number<BK1Value>{};
167
169
172
173 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
174 {
175 // A matrix in LDS memory, dst of blockwise copy
179 }
180
181 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
182 {
183 // B matrix in LDS memory, dst of blockwise copy
187 }
188
189 __host__ __device__ static constexpr auto
191 {
192 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
193 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
194
195 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
199 I1,
201
202 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
203 }
204
// Returns the number of bytes of shared memory (LDS) one block needs.
//
// The result is the larger of two byte counts:
//   (a) the A-tile and B-tile element space sizes, each passed through
//       math::integer_least_multiple with lcm(AK1, BK1) (presumably rounding up to
//       that multiple - confirm), added together and multiplied by sizeof(FloatAB);
//   (b) the C-shuffle tile element space size multiplied by sizeof(FloatCShuffle).
// A maximum rather than a sum is used; Run casts the same p_shared base pointer both
// to FloatAB* (A/B tiles) and to FloatCShuffle* (C shuffle), so the two uses share
// the same storage.
//
// NOTE(review): listing line 222 (the initializer of
// c_shuffle_block_desc_mblock_mperblock_nblock_nperblock) is missing from this
// extraction.
205 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
206 {
207 // LDS allocation for A and B: be careful of alignment
208 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
209 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
210
211 // lds max alignment
212 constexpr auto max_lds_align = math::lcm(AK1, BK1);
213
    // A-tile size in elements, adjusted to the alignment computed above.
214 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
215 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
216
    // B-tile size in elements, adjusted the same way.
217 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
218 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
219
220 // LDS allocation for C shuffle in LDS
221 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
223
224 constexpr auto c_block_size =
225 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
226
    // Element counts are converted to bytes here; the larger requirement wins.
227 return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
228 sizeof(FloatAB),
229 c_block_size * sizeof(FloatCShuffle));
230 }
231
// Compile-time predicate: forwards this template's tile configuration (BlockSize,
// MPerBlock, NPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave), the C element
// type FloatC and a memory data operation to
// ck::tensor_operation::device::IsValidGemmCompilationParameter and returns its result.
// The kernel wrapper evaluates it as IsValidCompilationParameter<>() inside
// `if constexpr`.
//
// NOTE(review): this function's own template parameter CGlobalMemoryDataOperation_
// (trailing underscore, default InMemoryDataOperationEnum::Set) is never used in the
// body. The value forwarded below is CGlobalMemoryDataOperation (no underscore), the
// parameter of the enclosing template. Confirm that ignoring the per-call argument is
// intended.
232 template <
233 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
234 __device__ static bool constexpr IsValidCompilationParameter()
235 {
236 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
237 BlockSize,
238 MPerBlock,
239 NPerBlock,
240 MPerXdl,
241 NPerXdl,
242 MXdlPerWave,
243 NXdlPerWave,
244 FloatC,
245 CGlobalMemoryDataOperation>();
246 }
247
// Checks whether a problem described by the given grid descriptors can be handled by
// this gridwise GEMM configuration.
//
// Parameters:
//   a_grid_desc_ak0_m_ak1 - A descriptor; M is read from dimension I1 and K is
//                           computed as length(I0) * length(I2).
//   b_grid_desc_bk0_n_bk1 - B descriptor; only N (dimension I1) is read from it.
//   c_grid_desc_m_n       - C descriptor; its lengths must equal (M, N).
//   block_2_ctile_map     - tile map whose own CheckValidity(c_grid_desc_m_n) must pass.
//
// Returns true only if all of the following hold, false at the first failure:
//   1. C's lengths match M and N taken from A and B;
//   2. M, N and K are exact multiples of MPerBlock, NPerBlock and KPerBlock;
//   3. GridwiseGemmPipe::IsSupported(K / KPerBlock) is true;
//   4. block_2_ctile_map.CheckValidity(c_grid_desc_m_n) is true.
// In addition, a static_assert rejects at compile time tile sizes where MPerBlock /
// NPerBlock are not multiples of the per-wave XDL footprint.
//
// NOTE(review): K is derived from the A descriptor alone; B's K extent
// (dimensions I0 / I2 of b_grid_desc_bk0_n_bk1) is not compared against it here.
// GridwiseGemmPipe is declared on a listing line that is not visible in this
// extraction.
248 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
249 template <typename Block2CTileMap>
250 __host__ __device__ static constexpr bool
251 CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
252 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
253 const CGridDesc_M_N& c_grid_desc_m_n,
254 const Block2CTileMap& block_2_ctile_map)
255 {
256 // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
257 // is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
258 // "wrong! K1 need to be known at compile-time");
259
    // Compile-time check of the tuning parameters themselves (independent of the problem).
260 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
261 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
262 "Invalid tuning param!");
263
    // Problem sizes as seen through the A and B descriptors.
264 const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
265 const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
266 const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
267
    // C must have exactly the M x N extent implied by A and B.
268 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
269 return false;
270
    // No partial tiles: every dimension must divide evenly into block tiles.
271 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
272 return false;
273
274 // check gridwise gemm pipeline
275 const auto num_k_loop = K / KPerBlock;
276
277 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
278 {
279 return false;
280 }
281
    // Let the block-to-tile map veto C layouts it cannot cover.
282 if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
283 {
284 return false;
285 }
286
287 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
288 return true;
289 }
290
// Decides whether the K loop has a "main" part for a problem with reduction length K.
//
// Parameters:
//   K - total K extent; divided by KPerBlock with integer division to obtain the
//       number of K iterations (CheckValidity separately requires K % KPerBlock == 0).
//
// Returns whatever GridwiseGemmPipe::CalculateHasMainLoop reports for that iteration
// count; the criterion itself lives in the pipeline type, which is not visible in this
// listing.
291 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
292 {
293 const index_t num_loop = K / KPerBlock;
294
295 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
296 }
297
298 __host__ __device__ static constexpr auto
299 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
300 {
301 const auto M = c_grid_desc_m_n.GetLength(I0);
302 const auto N = c_grid_desc_m_n.GetLength(I1);
303
304 const auto MBlock = M / MPerBlock;
305 const auto NBlock = N / NPerBlock;
306
307 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
308 c_grid_desc_m_n,
313
314 return c_grid_desc_mblock_mperblock_nblock_nperblock;
315 }
316
317 __host__ __device__ static constexpr auto
318 MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m)
319 {
320 const auto M = d_grid_desc_m.GetLength(I0);
321 const auto MBlock = M / MPerBlock;
322
323 const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor(
324 d_grid_desc_m,
328
329 return reduce_grid_desc_mblock_mperblock;
330 }
331
332 // return block_id to C matrix tile idx (m0, n0) mapping
333 __host__ __device__ static constexpr auto
334 MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
335 {
337 c_grid_desc_m_n);
338 }
339
342 CGridDesc_M_N{}))>;
343
345 remove_cvref_t<decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))>;
346
348 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
349
350 template <bool HasMainKBlockLoop, typename Block2CTileMap>
351 __device__ static void
352 Run(const FloatAB* __restrict__ p_a_grid,
353 const FloatAB* __restrict__ p_b_grid,
354 FloatC* __restrict__ p_c_grid,
355 ReducePtrsGlobal p_reduces_grid,
356 void* __restrict__ p_shared,
357 const AElementwiseOperation& a_element_op,
358 const BElementwiseOperation& b_element_op,
359 const CElementwiseOperation& c_element_op,
360 const ReduceInElementwiseOperations& reduce_in_element_ops,
361 const ReduceAccElementwiseOperations& reduce_out_element_ops,
362 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
363 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
365 c_grid_desc_mblock_mperblock_nblock_nperblock,
366 const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock,
367 const Block2CTileMap& block_2_ctile_map)
368 {
369 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
370 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
371 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
372 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
374 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
375
376 // divide block work by [M, N]
377 const auto block_work_idx =
378 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
379
380 if(!block_2_ctile_map.ValidCTileIndex(
381 block_work_idx,
382 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
383 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
384 {
385 return;
386 }
387
388 // HACK: this force m/n_block_data_idx_on_grid into SGPR
389 const index_t m_block_data_idx_on_grid =
390 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
391
392 const index_t n_block_data_idx_on_grid =
393 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
394
395 // lds max alignment
396 constexpr auto max_lds_align = math::lcm(AK1, BK1);
397
398 // A matrix in LDS memory, dst of blockwise copy
399 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
400
401 // B matrix in LDS memory, dst of blockwise copy
402 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
403
404 // A matrix blockwise copy
405 auto a_blockwise_copy =
407 AElementwiseOperation,
411 ABlockTransferThreadClusterLengths_AK0_M_AK1,
412 ABlockTransferThreadClusterArrangeOrder,
413 FloatAB,
414 FloatAB,
415 decltype(a_grid_desc_ak0_m_ak1),
416 decltype(a_block_desc_ak0_m_ak1),
417 ABlockTransferSrcAccessOrder,
419 ABlockTransferSrcVectorDim,
420 2,
421 ABlockTransferSrcScalarPerVector,
422 ABlockTransferDstScalarPerVector_AK1,
423 1,
424 1,
425 AThreadTransferSrcResetCoordinateAfterRun,
426 true,
427 NumGemmKPrefetchStage>(
428 a_grid_desc_ak0_m_ak1,
429 make_multi_index(0, m_block_data_idx_on_grid, 0),
430 a_element_op,
431 a_block_desc_ak0_m_ak1,
432 make_multi_index(0, 0, 0),
434
435 // B matrix blockwise copy
436 auto b_blockwise_copy =
438 BElementwiseOperation,
442 BBlockTransferThreadClusterLengths_BK0_N_BK1,
443 BBlockTransferThreadClusterArrangeOrder,
444 FloatAB,
445 FloatAB,
446 decltype(b_grid_desc_bk0_n_bk1),
447 decltype(b_block_desc_bk0_n_bk1),
448 BBlockTransferSrcAccessOrder,
450 BBlockTransferSrcVectorDim,
451 2,
452 BBlockTransferSrcScalarPerVector,
453 BBlockTransferDstScalarPerVector_BK1,
454 1,
455 1,
456 BThreadTransferSrcResetCoordinateAfterRun,
457 true,
458 NumGemmKPrefetchStage>(
459 b_grid_desc_bk0_n_bk1,
460 make_multi_index(0, n_block_data_idx_on_grid, 0),
461 b_element_op,
462 b_block_desc_bk0_n_bk1,
463 make_multi_index(0, 0, 0),
465
466 // GEMM definition
467 // c_mtx += transpose(a_mtx) * b_mtx
468 // a_mtx[K0PerBlock, MPerBlock] is in LDS
469 // b_mtx[K0PerBlock, NPerBlock] is in LDS
470 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
471 // register
472 // sanity check
473 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
474 constexpr bool is_single_rate_mfma =
476 lcm_AK1_BK1 <= 4) ||
477 (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
479 lcm_AK1_BK1 < 32))
480 ? true
481 : false;
482 constexpr auto is_scale_mfma = false;
483 constexpr index_t KPack = math::max(
484 lcm_AK1_BK1,
486 selected_mfma.k_per_blk);
487
489 BlockSize,
490 FloatAB,
491 FloatAB,
492 FloatGemmAcc,
493 decltype(a_block_desc_ak0_m_ak1),
494 decltype(b_block_desc_bk0_n_bk1),
495 MPerXdl,
496 NPerXdl,
497 MXdlPerWave,
498 NXdlPerWave,
499 KPack,
500 LoopSched>();
501
502 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
503
504 // LDS allocation for A and B: be careful of alignment
505 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
506 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
507
509 static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
510
512 static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
513 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
514
515 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
516 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
517
518 // gridwise GEMM pipeline
519 const auto gridwise_gemm_pipeline =
521
522 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
523 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
524 KPerBlock);
525
526 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
527 a_block_desc_ak0_m_ak1,
528 a_blockwise_copy,
529 a_grid_buf,
530 a_block_buf,
531 a_block_slice_copy_step,
532 b_grid_desc_bk0_n_bk1,
533 b_block_desc_bk0_n_bk1,
534 b_blockwise_copy,
535 b_grid_buf,
536 b_block_buf,
537 b_block_slice_copy_step,
538 blockwise_gemm,
539 c_thread_buf,
540 num_k_block_main_loop);
541
542 // shuffle C + reduction + write out
543 {
544 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
545 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
546 "wrong!");
547
548 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
549 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
550
551 // TODO: hacky, fix it!
552 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
553 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
554
555 // TODO: hacky, fix it!
556 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
557 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
558 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
559
560 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
561 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
562 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
563 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
564 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
565 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
566 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
567 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
568
569 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
571
572 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
573 static_cast<FloatCShuffle*>(p_shared),
574 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
575
576 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
577 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
581 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
582 M1, // M1 = MWave
583 M2, // M2 * M3 * M4 = MPerXdl
584 M3,
585 M4)),
588 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
589 N1, // N1 = NWave
590 N2))), // N2 = NPerXdl
594
595 // calculate origin of thread output tensor on global memory
596 // blockwise GEMM c matrix starting index
597 const auto c_thread_mtx_on_block =
598 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
599
600 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
601 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
602
603 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
605 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
608
609 const auto m_thread_data_on_block_idx =
610 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
611 make_multi_index(m_thread_data_on_block));
612
613 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
618
619 const auto n_thread_data_on_block_idx =
620 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
621 make_multi_index(n_thread_data_on_block));
622
623 // shuffle: threadwise copy C from VGPR to LDS
624 auto c_thread_copy_vgpr_to_lds =
626 FloatCShuffle,
627 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
628 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
630 Sequence<CShuffleMXdlPerWavePerShuffle,
631 CShuffleNXdlPerWavePerShuffle,
632 I1,
633 I1,
634 M2,
635 I1,
636 M4,
637 I1>,
639 7,
640 1,
642 1,
643 true>{
644 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
646 0,
647 m_thread_data_on_block_idx[I1],
648 n_thread_data_on_block_idx[I1],
649 m_thread_data_on_block_idx[I2],
650 m_thread_data_on_block_idx[I3],
651 m_thread_data_on_block_idx[I4],
652 n_thread_data_on_block_idx[I2]),
654
655 // shuffle: blockwise copy C from LDS to global
656 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
657 ThisThreadBlock, // ThreadGroup
658 CElementwiseOperation, // ElementwiseOperation,
659 CGlobalMemoryDataOperation, // DstInMemOp,
660 Sequence<1,
661 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
662 1,
663 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
664 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
665 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
666 FloatCShuffle, // typename SrcData,
667 FloatC, // typename DstData,
668 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
669 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
670 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
671 3, // index_t VectorDim,
672 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
673 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
674 false> // bool ThreadTransferDstResetCoordinateAfterRun>
675 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
676 make_multi_index(0, 0, 0, 0),
677 c_grid_desc_mblock_mperblock_nblock_nperblock,
678 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
679 c_element_op};
680
681 // space filling curve for threadwise C in VGPR
682 constexpr auto sfc_c_vgpr =
685 Sequence<CShuffleMXdlPerWavePerShuffle,
686 CShuffleNXdlPerWavePerShuffle,
687 1,
688 1,
689 M2,
690 1,
691 M4,
692 1>>{};
693
694 // space filling curve for shuffled blockwise C in global mem
695 constexpr auto sfc_c_global =
698 Sequence<1,
699 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
700 1,
701 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
702
703 // TODO: this should be implemented as a blockwise reduction
704 // LDS c_reduce_block_desc_mperblock_nperblock
705 constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor(
706 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
710 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
713 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))),
716
717 static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) *
718 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
719 BlockSize,
720 "wrong!");
721
722 static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) %
723 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) ==
724 0 &&
725 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) %
726 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) ==
727 0,
728 "wrong!");
729
730 constexpr index_t mreduce_per_thread =
731 (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) /
732 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0);
733
734 constexpr index_t nreduce_per_thread =
735 (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) /
736 CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1);
737
738 constexpr auto c_reduce_thread_lengths_mperblock_nperblock =
740
741 // VGPR c_reduce_thread_desc_mperblock_nperblock
742 constexpr auto c_reduce_thread_desc_mperblock_nperblock =
745
746 // VGPR reduce_thread_desc_mperblock
747 constexpr auto reduce_thread_desc_mperblock =
749
750 // VGPR reduce_thread_desc_mblock_mperblock
751 constexpr auto reduce_thread_desc_mblock_mperblock =
753
755 c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
756
757 // reduce: threadwise copy from LDS to VGPR
758 constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor(
759 CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{});
760
761 const auto c_reduce_thread_cluster_idx =
762 c_reduce_thread_cluster_desc.CalculateBottomIndex(
764
765 const auto c_reduce_thread_data_idx_begin =
766 c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock;
767
768 auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
769 FloatCShuffle,
770 FloatReduceAcc,
771 decltype(c_reduce_block_desc_mperblock_nperblock),
772 decltype(c_reduce_thread_desc_mperblock_nperblock),
773 decltype(c_reduce_thread_lengths_mperblock_nperblock),
775 1,
776 CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
777 1,
778 true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin};
779
780 auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple(
781 [&](auto I) {
782 auto p_reduce_grid = p_reduces_grid[I];
783 auto reduce_acc_element_op = reduce_out_element_ops[I];
784
786 FloatReduceAcc,
787 remove_pointer_t<decltype(p_reduce_grid)>,
788 decltype(reduce_thread_desc_mblock_mperblock),
789 decltype(reduce_grid_desc_mblock_mperblock),
790 decltype(reduce_acc_element_op),
793 1,
794 CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
795 ReduceGlobalMemoryDataOperation::At(I),
796 1,
797 false>{reduce_grid_desc_mblock_mperblock,
798 make_multi_index(block_work_idx[I0], // mblock
799 c_reduce_thread_data_idx_begin[I0]), // mperblock
800 reduce_acc_element_op};
801 },
802 Number<p_reduces_grid.Size()>{});
803
804 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
805
806 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
807
808 static_for<0, num_access, 1>{}([&](auto access_id) {
809 // make sure it's safe to write to LDS
811
812 // each thread write its data from VGPR to LDS
813 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
814 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
815 c_thread_buf,
816 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
817 c_shuffle_block_buf);
818
819 // make sure it's safe to read from LDS
821
822 // each block copy its data from LDS to global
823 c_shuffle_block_copy_lds_to_global.Run(
824 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
825 c_shuffle_block_buf,
826 c_grid_desc_mblock_mperblock_nblock_nperblock,
827 c_grid_buf);
828
829 // TODO - extract following into reduction_blockwise
830 {
831 c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock,
832 c_shuffle_block_buf,
833 c_reduce_thread_desc_mperblock_nperblock,
834 make_tuple(I0, I0),
835 c_reduce_thread_buf);
836
837 static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
838 auto& p_reduce_grid = p_reduces_grid[In];
839
841 p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize());
842
843 auto reduce_thread_buf =
845 reduce_thread_desc_mperblock.GetElementSpaceSize());
846
847 auto& reduce_in_element_op = reduce_in_element_ops[In];
848
849 auto& reduce_thread_copy_vgpr_to_global =
850 reduce_tuple_thread_copy_vgpr_to_global(In);
851
852 using ReduceOperation = remove_cvref_t<decltype(ReduceOperations{}[In])>;
853 using ThreadwiseReduce =
854 ThreadwiseReduction<FloatReduceAcc,
855 decltype(c_reduce_thread_desc_mperblock_nperblock),
856 decltype(reduce_thread_desc_mperblock),
857 ReduceOperation,
858 false>;
859
860 // Global write Gemm shuffle + reduction
861 const auto reduce_identityVal =
862 ReduceOperation::template GetIdentityValue<FloatReduceAcc>();
863
865 [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
866
867 // reduce in VGPR
870 constexpr auto offset =
871 Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
872 make_tuple(im, in))>{};
873
874 reduce_in_element_op(c_reduce_thread_buf(offset),
875 c_reduce_thread_buf(offset));
876 });
877 });
878
879 ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
880
881 // copy from VGPR to Global
882 reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock,
883 make_tuple(I0, I0),
884 reduce_thread_buf,
885 reduce_grid_desc_mblock_mperblock,
886 reduce_grid_buf);
887
888 if constexpr(access_id < num_access - 1)
889 {
890 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
891 reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow(
892 reduce_grid_desc_mblock_mperblock,
893 make_tuple(c_global_step[I0], c_global_step[I1]));
894 }
895 });
896 }
897
898 if constexpr(access_id < num_access - 1)
899 {
900 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
901
902 // move on C
903 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
904 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
905 }
906 });
907
908 // Reduction
909 }
910 }
911};
912
913} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_gemm_reduce_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, ReducePtrsGlobal p_reduces_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const ReduceInElementwiseOperations reduce_in_element_ops, const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:40
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:261
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:152
__host__ static __device__ constexpr auto MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M &d_grid_desc_m)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:318
static __device__ bool constexpr IsValidCompilationParameter()
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:234
static constexpr auto I2
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:155
static constexpr auto I7
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:160
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:347
static constexpr auto I4
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:157
__host__ static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:190
static constexpr auto BK0
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:164
static constexpr auto I5
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:158
static constexpr auto I6
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:159
static constexpr auto AK0
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:163
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:168
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, ReducePtrsGlobal p_reduces_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const ReduceInElementwiseOperations &reduce_in_element_ops, const ReduceAccElementwiseOperations &reduce_out_element_ops, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock &reduce_grid_desc_mblock_mperblock, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:352
static constexpr auto AK1
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:165
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:205
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:334
static constexpr auto I1
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:154
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:251
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:340
remove_cvref_t< decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))> ReduceGridDescriptor_MBlock_MPerBlock
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:344
__host__ static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:173
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:299
static constexpr auto I3
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:156
__host__ static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:181
static constexpr auto I0
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:153
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:291
static constexpr auto BK1
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:166
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVersion::v1, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:170
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition reduction_functions_threadwise.hpp:23
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340