blockwise_gemm_dlops_v3.hpp Source File

blockwise_gemm_dlops_v3.hpp Source File

Composable Kernel: blockwise_gemm_dlops_v3.hpp Source File
blockwise_gemm_dlops_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
5#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
6
7#include "common_header.hpp"
9
10namespace ck {
11
// Template header of the blockwise GEMM helper. The cross-reference index at the end of
// this page names its constructor BlockwiseGemmDlops_km_kn_m0m1n0n1_v3.
// NOTE(review): the `struct ...` line (listing line 21) is missing from this extraction.
//
// Template parameters, as constrained by the static_asserts visible in the constructor
// and by their use in Run():
//   BlockSize      - must equal (KPerBlock / KPerThread) * (HoPerBlock / HoPerThread) *
//                    (WoPerBlock / WoPerThread).
//   FloatA/B/C     - passed through to ThreadwiseGemmDlops_km_kn_mn_v3 in Run(); FloatA is
//                    also the element type of the per-thread A buffer.
//   ABlockDesc_E1_K1_E2, BBlockDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo
//                  - descriptor types; each must report IsKnownAtCompileTime().
//   EPerThreadLoop - third argument of static_for<0, E1, ...> in Run(); E1 must be a
//                    multiple of it.
//   KPerThreadLoop - KPerThread must be a multiple of it.
12template <index_t BlockSize,
13 typename FloatA,
14 typename FloatB,
15 typename FloatC,
16 typename ABlockDesc_E1_K1_E2,
17 typename BBlockDesc_E1_N_Ho_Wo_E2,
18 typename CThreadDesc_K_N_Ho_Wo,
19 index_t EPerThreadLoop,
20 index_t KPerThreadLoop>
22{
// Compile-time index constants Number<0> .. Number<4>, used below to pick descriptor
// dimensions and to build zero origins in Run().
23 static constexpr auto I0 = Number<0>{};
24 static constexpr auto I1 = Number<1>{};
25 static constexpr auto I2 = Number<2>{};
26 static constexpr auto I3 = Number<3>{};
27 static constexpr auto I4 = Number<4>{};
28
// NOTE(review): listing lines 29-31 are missing here. The cross-reference index lists
// AIndex = MultiIndex<3>, BIndex = MultiIndex<3> and CIndex = MultiIndex<4> at those lines.
32
// Lengths of dimensions 0, 1 and 2 of the A block descriptor.
33 static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0);
34 static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1);
35 static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2);
36
// Lengths of dimensions 2 and 3 of the B block descriptor.
37 static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
38 static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
39
// Lengths of dimensions 0, 2 and 3 of the C thread descriptor (dimension 1 is not read here).
40 static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
41 static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
42 static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);
43
// NOTE(review): listing lines 44-45 are missing; the cross-reference index places the
// definition of a_thread_mtx_ at line 44.
46
// NOTE(review): the initializer of b_thread_mtx_ is only partly visible (listing lines 48,
// 50 and 51 are missing). The visible arguments are Number<1>{} and Number<E2>{}; the
// remaining arguments and the enclosing call cannot be confirmed from this extraction.
47 static constexpr auto b_thread_mtx_ =
49 Number<1>{},
52 Number<E2>{}));
53
56
// Constructor: member initializer list and body.
// NOTE(review): the signature (listing line 57) is missing from this extraction; the
// cross-reference index gives it as `__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()`.
//
// Initialization:
//   c_thread_origin_data_idx_ - result of GetBeginOfCThreadDesc_K_N_Ho_Wo() for the calling
//                               thread's id from get_thread_local_1d_id().
//   a_thread_copy_            - constructed with origin
//                               (0, c_thread_origin_data_idx_[I0] * KPerThread, 0).
// The body contains only compile-time checks; it performs no runtime work.
58 : c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
59 a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)}
60 {
// All three descriptors must be compile-time constants.
61 static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() &&
62 BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
63 CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
64 "wrong! Desc should be known at compile-time");
65
// A dims 0 and 2 must match B dims 0 and 4 respectively.
66 static_assert(
67 ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
68 ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
69 "wrong! E dimension not consistent\n");
70
// The per-loop tile sizes must evenly divide the extents they iterate over.
71 static_assert(E1 % EPerThreadLoop == 0, "");
72 static_assert(KPerThread % KPerThreadLoop == 0, "");
73
// NOTE(review): listing line 75 (the end of this condition) is missing; presumably it is
// the matching WoPerBlock % WoPerThread check -- confirm against the original header.
74 static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 &&
76 "wrong! Cannot evenly divide work among\n");
77
// Per-dimension ratios of block extent to thread extent.
78 constexpr auto KThreadCluster = KPerBlock / KPerThread;
79 constexpr auto HThreadCluster = HoPerBlock / HoPerThread;
80 constexpr auto WThreadCluster = WoPerBlock / WoPerThread;
81
// BlockSize must equal the product of the three ratios above.
82 static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
83 "wrong! wrong blocksize\n");
84 }
85
// Static constexpr device function taking no arguments; return type is deduced (`auto`).
// NOTE(review): the body (listing line 88) is missing from this extraction, so the returned
// value cannot be confirmed here. Judging by the name it presumably returns the lengths of
// CThreadDesc_K_N_Ho_Wo -- verify against the original header.
86 __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
87 {
89 }
90
// Converts a 1-D thread id into a CIndex by calling CalculateBottomIndex() on a
// compile-time adaptor, and returns that index unchanged.
//
// Parameters:
//   thread_id - 1-D thread id; the constructor passes get_thread_local_1d_id().
// Returns:
//   The index produced by the adaptor for make_multi_index(thread_id).
91 __device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
92 {
// Ratios of block extent to thread extent per dimension; N0 is fixed to I1.
// These are not referenced in the lines visible below -- presumably they feed the
// adaptor initializer that is missing from this extraction.
93 constexpr auto K0 = KPerBlock / KPerThread;
94 constexpr auto N0 = I1;
95 constexpr auto H0 = HoPerBlock / HoPerThread;
96 constexpr auto W0 = WoPerBlock / WoPerThread;
97
// NOTE(review): the initializer of the adaptor (listing lines 99-102) is missing. The
// cross-reference index mentions make_single_stage_tensor_adaptor and
// make_merge_transform, which are presumably used here -- confirm.
98 constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor =
103
104 const auto c_k_n_h_w_thread_cluster_idx =
105 c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex(
106 make_multi_index(thread_id));
107
108 return c_k_n_h_w_thread_cluster_idx;
109 }
110
// Runs the blockwise GEMM step for the calling thread: inside a static_for over
// <0, E1, EPerThreadLoop>, copies a slice of A from a_block_buf into a local StaticBuffer
// via a_thread_copy_.Run(), then calls ThreadwiseGemmDlops_km_kn_mn_v3::Run() on that
// buffer together with b_thread_buf and c_thread_buf.
//
// Parameters:
//   a_block_buf  - source buffer for the A copy (read-only).
//   b_thread_buf - B operand passed to the threadwise GEMM (read-only).
//   c_thread_buf - passed by non-const reference to the threadwise GEMM.
111 template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
112 __device__ void Run(const ABlockBuffer& a_block_buf,
113 const BThreadBuffer& b_thread_buf,
114 CThreadBuffer& c_thread_buf) const
115 {
// NOTE(review): the asserted condition (listing lines 117-119) is missing from this
// extraction; only the message is visible.
116 static_assert(
120 "wrong! inconsistent type");
121
122 constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};
123
124 // thread A buffer for GEMM
// Declared with AddressSpaceEnum::Vgpr, element type FloatA, and a size taken from
// a_thread_mtx_.GetElementSpaceSize().
125 StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
126 a_thread_buf;
127
// Threadwise GEMM functor typed on the three per-thread descriptors.
128 constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
129 FloatB,
130 FloatC,
131 decltype(a_thread_mtx_),
132 decltype(b_thread_mtx_),
133 decltype(c_thread_mtx_)>{};
134
135 static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) {
// NOTE(review): listing line 136 is missing. `k_begin` is used below and two lambda
// closers follow, so there is presumably a nested static_for introducing k_begin here
// -- confirm against the original header.
// Copy A: source origin (e_begin, k_begin, I0) in a_block_mtx, destination origin
// (I0, I0, I0) in a_thread_buf.
// NOTE(review): one argument (listing line 140) is missing between a_block_buf and the
// destination origin.
137 a_thread_copy_.Run(a_block_mtx,
138 make_tuple(e_begin, k_begin, I0),
139 a_block_buf,
141 make_tuple(I0, I0, I0),
142 a_thread_buf);
143
// GEMM: A read from origin (I0, I0, I0), B from (e_begin, I0, I0, I0, I0), and C
// addressed at (k_begin, I0, I0, I0).
144 threadwise_gemm.Run(a_thread_buf,
145 make_tuple(I0, I0, I0),
146 b_thread_buf,
147 make_tuple(e_begin, I0, I0, I0, I0),
148 c_thread_buf,
149 make_tuple(k_begin, I0, I0, I0));
150 });
151 });
152 }
153
// Moves the source window of the A thread copy by forwarding to
// a_thread_copy_.MoveSrcSliceWindow() with a default-constructed ABlockDesc_E1_K1_E2.
//
// Parameters:
//   a_block_slice_move_step_idx - step passed through unchanged to MoveSrcSliceWindow().
// Non-const: it calls a mutating-looking member on a_thread_copy_ (this method is not
// declared const).
154 template <typename ABlockSliceMoveStepIdx>
155 __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
156 {
157 a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx);
158 }
159
160 private:
// Type of the copy object used by Run() and MoveABlockSliceWindow().
// NOTE(review): the template being instantiated (listing line 162) and two of its
// arguments (listing lines 166-167) are missing from this extraction. The visible
// arguments are FloatA, ABlockDesc_E1_K1_E2, decltype(a_thread_mtx_), 2, E2, E2; the
// meaning of the trailing 2, E2, E2 cannot be confirmed from here.
161 using AThreadCopy =
163 FloatA,
164 ABlockDesc_E1_K1_E2,
165 decltype(a_thread_mtx_),
168 2,
169 E2,
170 E2>;
171
// Set in the constructor from GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id()).
172 CIndex c_thread_origin_data_idx_;
173
// Set in the constructor with origin (0, c_thread_origin_data_idx_[I0] * KPerThread, 0).
174 AThreadCopy a_thread_copy_;
175};
176
177} // namespace ck
178#endif
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
static constexpr auto E1
Definition blockwise_gemm_dlops_v3.hpp:33
static constexpr auto b_thread_mtx_
Definition blockwise_gemm_dlops_v3.hpp:47
__device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx &a_block_slice_move_step_idx)
Definition blockwise_gemm_dlops_v3.hpp:155
static constexpr auto I4
Definition blockwise_gemm_dlops_v3.hpp:27
static constexpr auto E2
Definition blockwise_gemm_dlops_v3.hpp:35
static constexpr auto WoPerBlock
Definition blockwise_gemm_dlops_v3.hpp:38
static constexpr auto KPerBlock
Definition blockwise_gemm_dlops_v3.hpp:34
__device__ void Run(const ABlockBuffer &a_block_buf, const BThreadBuffer &b_thread_buf, CThreadBuffer &c_thread_buf) const
Definition blockwise_gemm_dlops_v3.hpp:112
static __device__ constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
Definition blockwise_gemm_dlops_v3.hpp:86
static constexpr auto I1
Definition blockwise_gemm_dlops_v3.hpp:24
static constexpr auto HoPerBlock
Definition blockwise_gemm_dlops_v3.hpp:37
MultiIndex< 3 > AIndex
Definition blockwise_gemm_dlops_v3.hpp:29
MultiIndex< 3 > BIndex
Definition blockwise_gemm_dlops_v3.hpp:30
static constexpr auto c_thread_mtx_
Definition blockwise_gemm_dlops_v3.hpp:54
static constexpr auto a_thread_mtx_
Definition blockwise_gemm_dlops_v3.hpp:44
static constexpr auto HoPerThread
Definition blockwise_gemm_dlops_v3.hpp:41
static constexpr auto I0
Definition blockwise_gemm_dlops_v3.hpp:23
static __device__ CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
Definition blockwise_gemm_dlops_v3.hpp:91
__device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
Definition blockwise_gemm_dlops_v3.hpp:57
static constexpr auto KPerThread
Definition blockwise_gemm_dlops_v3.hpp:40
static constexpr auto I3
Definition blockwise_gemm_dlops_v3.hpp:26
MultiIndex< 4 > CIndex
Definition blockwise_gemm_dlops_v3.hpp:31
static constexpr auto I2
Definition blockwise_gemm_dlops_v3.hpp:25
static constexpr auto WoPerThread
Definition blockwise_gemm_dlops_v3.hpp:42
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_gemm_dlops_v3.hpp:29
Definition threadwise_tensor_slice_transfer.hpp:1260
Definition type.hpp:177
Definition functional2.hpp:33