block_fmha_pipeline_qr_ks_vs_async.hpp Source File

block_fmha_pipeline_qr_ks_vs_async.hpp Source File

Composable Kernel: block_fmha_pipeline_qr_ks_vs_async.hpp Source File
block_fmha_pipeline_qr_ks_vs_async.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
12
13namespace ck_tile {
14
15// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
16template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
18{
34
37 static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
38 static_assert(kQLoadOnce == Policy::QLoadOnce);
39
40 static constexpr index_t kBlockSize = Problem::kBlockSize;
41
42 static constexpr index_t kM0 = BlockFmhaShape::kM0;
43 static constexpr index_t kN0 = BlockFmhaShape::kN0;
44 static constexpr index_t kK0 = BlockFmhaShape::kK0;
45 static constexpr index_t kN1 = BlockFmhaShape::kN1;
46 static constexpr index_t kK1 = BlockFmhaShape::kK1;
47 static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
48 static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
49
50 static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
51
52 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
53 // TODO: seq_q always supports padding; hdim_q/v support multiples of the vector length (like 8x)
54 // only seq_k padding needs special care (OOB elements of p need to be set to -INF instead of zero)
55 static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
56 Problem::kPadHeadDimV == true);
57 static constexpr bool kPadSeqLenQ = true;
58 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
59 static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
60 static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
61 static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
62 static constexpr auto BiasEnum = Problem::BiasEnum;
63 static constexpr bool kStoreLSE = Problem::kStoreLSE;
64 static constexpr bool kHasDropout = Problem::kHasDropout;
65
66 static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
67 (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
70
71 // last-dimension vector length used to create the tensor view (and to decide the buffer_load vector length)
72 // ... together with the tensor distribution. The tensor distribution should be able to override this
73 static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
74 static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
75 static constexpr index_t kAlignmentV = []() { // V alignment, selected at compile time from VLayout
76 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) // row-major V: always take the policy's alignment
77 return Policy::template GetAlignmentV<Problem>();
78 else // any other V layout: alignment drops to 1 when seqlen_k is padded
79 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>(); // NOTE(review): presumably seqlen_k is the contiguous dimension in this layout - confirm
80 }();
81 static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
82 static constexpr index_t kAlignmentBias =
83 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
84
85#if CK_TILE_FMHA_FWD_FAST_EXP2
86 static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
87#endif
88
89 static constexpr index_t kBlockPerCu = []() {
90 if constexpr(Problem::kBlockPerCu != -1)
91 return Problem::kBlockPerCu;
92 else
93 {
94 // minimize occupancy
96 {
97 return 1;
98 }
99
100 if constexpr(kQKHeaddim <= 32)
101 {
103 FmhaMask::IsMasking)
104 return 1;
105 else
106 return 2;
107 }
108 else if constexpr(kQKHeaddim <= 64)
109 {
111 return 2;
112 else
113 return 3;
114 }
115 else if constexpr(kQKHeaddim <= 128)
116 {
118 return 1;
119 else
120 return 2;
121 }
122 else if constexpr(kQKHeaddim <= 192)
123 {
125 return 1;
126 else
127 return 2;
128 }
129 else if constexpr(kQKHeaddim <= 256)
130 {
131 return 1;
132 }
133 else
134 {
135 return 1;
136 };
137 }
138 }();
139
140 static constexpr const char* name = "qr_async";
141
142 using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
143
145 {
146 return Policy::template GetSmemSize<Problem>();
147 }
148
149 template <typename QDramBlockWindowTmp,
150 typename KDramBlockWindowTmp,
151 typename VDramBlockWindowTmp,
152 typename BiasDramBlockWindowTmp,
153 typename RandValDramBlockWindowTmp,
154 typename LSEDramBlockWindowTmp,
155 typename QElementFunction,
156 typename KElementFunction,
157 typename VElementFunction,
158 typename BiasElementFunction,
159 typename LSEElementFunction,
160 typename SAccElementFunction,
161 typename PComputeElementFunction,
162 typename OAccElementFunction,
163 typename PositionEncoding,
164 typename AttentionVariantParams,
165 typename BlockIndices>
167 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
168 const QElementFunction& q_element_func,
169 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
170 const KElementFunction& /*k_element_func*/,
171 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
172 const VElementFunction& v_element_func,
173 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
174 const BiasElementFunction& bias_element_func,
175 RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
176 LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
177 const LSEElementFunction& lse_element_func,
178 const SAccElementFunction& s_acc_element_func,
179 const PComputeElementFunction& p_compute_element_func,
180 const OAccElementFunction& o_acc_element_func,
181 FmhaMask mask,
182 PositionEncoding position_encoding,
183 float scale_s,
184 const AttentionVariant& variant,
185 const AttentionVariantParams& variant_params,
186 const BlockIndices& block_indices,
187 void* smem_ptr,
188 DropoutType& dropout) const
189 {
190 static_assert(
191 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
192 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
193 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
194 "wrong!");
195
196 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
197 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
198 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
199 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
200 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
201 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
202 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
203 "wrong!");
204
205 constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
206
207 // K tile in LDS
208 auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
209 auto k_lds_store = generate_tuple(
210 [&](auto i_buf) {
211 return make_tile_window(
213 k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
214 Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
215 {0, 0, 0});
216 },
218
219 auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
220 k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
221
222 auto k_lds_load =
223 make_tile_window(k_lds_Load_view,
224 Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
225 {0, 0});
226
227 // V tile in LDS
229 reinterpret_cast<VDataType*>(smem_ptr),
230 Policy::template MakeVLdsBlockDescriptor<Problem>());
231 auto v_lds_window = make_tile_window(
232 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
233
234 // Block GEMM
235 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
236 constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
237
238 auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
239 q_dram_block_window_tmp.get_window_lengths(),
240 q_dram_block_window_tmp.get_window_origin(),
241 Policy::template MakeQRegTileDistribution<Problem>());
242 q_dram_window.init_raw();
243
244 // TODO: we use async copy for K, which is implemented with inline asm;
245 // as a side effect, we have to use inline asm for q as well
246 auto q = decltype(load_tile(q_dram_window)){};
247 // TODO: starting from rocm-6.2, the compiler has problems if q is cleared manually.
248 // However, q would be cleared in the constructor of the static distributed tensor.
249 // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
250 load_tile_raw(q, q_dram_window);
251 __builtin_amdgcn_sched_barrier(0);
252
253 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
254 auto s_acc = SaccBlockTileType{};
255
256 // reduction function for softmax
257 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
258 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
259
260 // infer Sacc, S, P, M, L, Oacc type
261 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
262
263 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
264 SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
265
266 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
267
268 // init Oacc, M, L
269 auto o_acc = OaccBlockTileType{};
270 auto m = MLBlockTileType{};
271 auto l = MLBlockTileType{};
272
273 clear_tile(o_acc);
275 clear_tile(l);
276
277 __builtin_amdgcn_sched_barrier(0);
278 const auto q_origin = q_dram_window.get_window_origin();
279 const auto [seqlen_k_start, seqlen_k_end] =
280 mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
281
282 const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
283
284 // check early exit if no work to do
285 if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
286 {
287 if(num_total_loop <= 0)
288 {
289 if constexpr(kStoreLSE)
290 {
291 auto lse =
292 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
293
295
296 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
297 }
298 buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
299 // otherwise will have compute error(maybe compiler bug?)
300
301 // Note: o_acc is entirely cleared at this point, so return it as-is
302 return o_acc;
303 }
304 __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
305 }
306
307 auto k_dram_block_window =
308 make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
309 k_dram_block_window_tmp.get_window_lengths(),
310 {seqlen_k_start, 0});
311
312 auto k_dram_window = make_tile_window(
313 k_dram_block_window.get_bottom_tensor_view(),
314 k_dram_block_window.get_window_lengths(),
315 k_dram_block_window.get_window_origin(),
316 Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
317 // load
318 k_dram_window.init_raw();
319 constexpr auto k_oob_ck = bool_constant<true>{};
320 constexpr auto k_pre_np = [&]() {
321 if constexpr(kPadSeqLenK &&
324 return bool_constant<true>{};
325 else
326 return bool_constant<false>{};
327 }();
328
329 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
330 auto bias_dram_window =
331 make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
332 bias_dram_block_window_tmp.get_window_lengths(),
333 {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
334 Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
335
336 auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
337 randval_dram_block_window_tmp, seqlen_k_start);
338
339 auto v_dram_window =
340 make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
341 v_dram_block_window_tmp.get_window_lengths(),
342 {0, seqlen_k_start}, // TODO: hdim split?
343 Policy::template MakeVDramTileDistribution<Problem>());
344
345 // prefetch K tile
347 k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
348 move_tile_window(k_dram_window, {0, kK0});
349 __builtin_amdgcn_sched_barrier(0);
350
351 buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
352 (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
353 // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
354
355 index_t i_total_loops = 0;
356 constexpr index_t k0_loops = kQKHeaddim / kK0;
357 constexpr index_t k1_loops = kN0 / kK1;
358
359 static_assert(1 <= k0_loops);
360 static_assert(1 <= k1_loops);
361 // main loop
362 do
363 {
364 // STAGE 1, QK gemm
365 clear_tile(s_acc); // initialize C
366 if constexpr(k0_loops > 1)
367 {
368 static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
369 async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
370 k_dram_window,
371 number<-1>{},
372 k_oob_ck,
373 k_pre_np);
374 if constexpr(i_k0 < k0_loops - 1)
375 move_tile_window(k_dram_window, {0, kK0});
376
377 async_load_fence(k_dram_window.get_num_of_access());
378 __builtin_amdgcn_s_barrier();
379 __builtin_amdgcn_sched_barrier(0);
380 gemm_0(s_acc,
382 q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
383 get_slice_tile(k_lds_load,
384 sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
385 sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
386 });
387 }
388
389 // TODO: this works around a bug when the loop count is small (see the k0_loops <= 2 check below):
390 // otherwise the following fence/barrier would be scheduled inside the 1st loop
391 if constexpr(k0_loops <= 2)
392 __builtin_amdgcn_sched_barrier(0);
393
395 __builtin_amdgcn_s_barrier();
396
397 const auto bias_tile = load_tile(bias_dram_window); // load bias tile
398 auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
399 __builtin_amdgcn_sched_barrier(0);
400 { // tail
401 gemm_0(
402 s_acc,
404 q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
405 get_slice_tile(k_lds_load,
406 sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
407 sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
408 }
409 __builtin_amdgcn_sched_barrier(1);
410
411 // STAGE 2, scale_s, add bias, mask, softmax
413 {
414 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
415 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
417 [&](auto& x, const auto& y) {
418#if !CK_TILE_FMHA_FWD_FAST_EXP2
419 x += type_convert<SaccDataType>(bias_element_func(y));
420#else
422 type_convert<SaccDataType>(bias_element_func(y));
423#endif
424 },
425 s_acc,
426 bias_tile);
427 }
428 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
429 {
430 const auto k_origin = k_dram_block_window.get_window_origin();
431 constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
432 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
433 sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
434 sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
435 const auto tile_idx = get_x_indices_from_distributed_indices(
436 s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
437
438 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
439 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
440 constexpr auto i_j_idx = make_tuple(idx0, idx1);
441
442 s_acc(i_j_idx) *= scale_s;
443 position_encoding.update(s_acc(i_j_idx), row, col);
444 });
445 });
446 }
447 else
448 {
449 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
450 if constexpr(kHasLogitsSoftCap)
451 {
452 auto apply_logits_transform =
453 [&variant, &variant_params, &block_indices](auto& x) {
454 x = variant.LogitsTransform(variant_params,
455 variant.QueryTransform(variant_params, x),
456 block_indices.batch_idx,
457 block_indices.qo_head_idx,
458 block_indices.kv_head_idx);
459 };
460#if !CK_TILE_FMHA_FWD_FAST_EXP2
461 for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
462 {
463 apply_logits_transform(s_acc.thread_buf_[i]);
464 }
465#else
466 for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
467 {
468 apply_logits_transform(s_acc.thread_buf_[i]);
469 }
470#endif
471 }
472 else
473 {
474#if !CK_TILE_FMHA_FWD_FAST_EXP2
475 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
476#endif
477 }
478 }
479 move_tile_window(bias_dram_window, {0, kN0});
480 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
481 {
482 const auto k_origin = k_dram_block_window.get_window_origin();
483 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
484 k_origin.at(number<0>{}),
485 number<kM0>{},
486 number<kN0>{});
487
488 if(need_perpixel_check)
489 {
491 s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
492 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
493 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
494 return !variant.LogitsMask(variant_params,
495 block_indices.batch_idx,
496 row,
497 col,
498 block_indices.qo_head_idx,
499 block_indices.kv_head_idx);
500 });
501 }
502 }
503
504 const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
506 s,
507 sequence<1>{},
508 f_max,
509 -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
511
512 const auto m_old = m; // m{j-1}
514 [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
515
517 s.get_tile_distribution()); // Pcompute{j}
518
519 __builtin_amdgcn_sched_barrier(0x7F);
520 // store & prefetch next v, after the max reduction
521 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
522 {
524 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
525 shuffle_tile(v_shuffle_tmp, v_buf);
526
527 auto v_lds_window_tmp =
528 get_slice_tile(v_lds_window,
529 sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
530 sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
531
533 v_lds_window_tmp,
534 tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
535 }
536 else
537 {
538 auto v_lds_window_tmp =
539 get_slice_tile(v_lds_window,
540 sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
541 sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
542 store_tile(v_lds_window_tmp,
543 tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
544 }
545
546 if constexpr(k1_loops > 1)
547 {
549 v_dram_window,
550 {0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
551 v_buf = load_tile(
552 v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
553 }
554 __builtin_amdgcn_sched_barrier(0);
555
556 static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
560 FmhaMask::IsMasking)
561 {
564 : raw_m;
565 }
566 else
567 {
568 return raw_m;
569 }
570 };
571
572 constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
573 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
574 constexpr auto i_idx = make_tuple(idx0);
575#if CK_TILE_FMHA_FWD_FAST_EXP2
576 auto row_max = scale_s * get_validated_m(m[i_idx]);
577#endif
578 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
579 constexpr auto i_j_idx = make_tuple(idx0, idx1);
580#if CK_TILE_FMHA_FWD_FAST_EXP2
583 {
584 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
585 }
586 else
587 {
588 if constexpr(kHasLogitsSoftCap)
589 {
590 p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
591 }
592 else
593 {
594 p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
595 }
596 }
597#else
598 p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
599#endif
600 });
601 });
602
604 p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
605
607 // l{j}, Oacc{j}
608 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
609 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
610 constexpr auto i_idx = make_tuple(idx0);
611#if CK_TILE_FMHA_FWD_FAST_EXP2
612 const auto tmp = [&]() {
615 {
616 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
617 }
618 else
619 {
620 if constexpr(kHasLogitsSoftCap)
621 {
622 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
623 }
624 else
625 {
626 auto row_max = scale_s * get_validated_m(m[i_idx]);
627 return exp2(scale_s * m_old[i_idx] - row_max);
628 }
629 }
630 }();
631#else
632 const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
633#endif
634 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
635 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
636 constexpr auto i_j_idx = make_tuple(idx0, idx1);
637 // FIXME: this uses a different equation from the FA v2 paper,
638 // but produces the correct result.
639 // Is the equation wrong?
640 o_acc(i_j_idx) *= tmp;
641 });
642 });
643
644 if constexpr(kHasDropout)
645 {
646 auto randval_ptr =
647 reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
648 dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
649 randval_ptr,
650 seqlen_k_start + i_total_loops * kN0,
651 p_compute,
652 randval_dram_window);
653 }
654
655 const auto p = [&]() {
656#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
657 // For fp32 to fp16,
658 // impl::cast_tile_pk_fp16_fp32 would cause a precision issue,
659 // since it uses __builtin_amdgcn_cvt_pkrtz, which rounds toward zero.
660 return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
661#else
662 if constexpr(std::is_same_v<PDataType, fp16_t>)
664 tile_elementwise_in(p_compute_element_func, p_compute));
665 else
667 tile_elementwise_in(p_compute_element_func, p_compute));
668#endif
669 }();
670
671 // STAGE 3, KV gemm
672 if constexpr(k1_loops > 1)
673 {
674 static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
675 if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
676 {
677 v_buf = load_tile(
678 v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
679 }
681 gemm_1(o_acc,
683 p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
685 v_lds_window,
686 sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
687 sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
688
689 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
690 {
692 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
693 shuffle_tile(v_shuffle_tmp, v_buf);
694 auto v_lds_window_tmp = get_slice_tile(
695 v_lds_window,
696 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
697 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
698 store_tile(v_lds_window_tmp,
699 tile_elementwise_in(v_element_func,
700 v_shuffle_tmp)); // store the prefetch
701 }
702 else
703 {
704 auto v_lds_window_tmp = get_slice_tile(
705 v_lds_window,
706 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
707 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
708 store_tile(v_lds_window_tmp,
709 tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
710 }
711 if constexpr(i_k1 < k1_loops - 1)
712 move_tile_window(v_dram_window, {0, kK1});
713 });
714 }
715 i_total_loops++;
716 if(i_total_loops < num_total_loop)
717 {
718 // move K tile windows
719 move_tile_window(k_dram_block_window, {kN0, 0});
720 k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
721
722 if constexpr(k1_loops >= 2 &&
723 LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
724 __builtin_amdgcn_s_barrier();
725 async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
726 k_dram_window,
727 number<-1>{},
728 k_oob_ck,
729 k_pre_np);
730 move_tile_window(k_dram_window, {0, kK0});
731 }
732 // tail
733 {
735 gemm_1(
736 o_acc,
737 get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
739 v_lds_window,
740 sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
741 sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
742 }
743 } while(i_total_loops < num_total_loop);
744
745 // store lse
746 if constexpr(kStoreLSE)
747 {
748 auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
749
750 constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
751 sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
752 constexpr auto i_idx = make_tuple(idx0);
753#if CK_TILE_FMHA_FWD_FAST_EXP2
756 {
757 lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
758 }
759 else
760 {
761 if constexpr(kHasLogitsSoftCap)
762 {
763 lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
764 }
765 else
766 {
767 lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
768 }
769 }
770#else
771 lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
772#endif
773 });
774
775 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
776 }
777
778 // finally, O
779 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
780
781 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
782 constexpr auto i_idx = make_tuple(idx0);
783 const auto tmp = [&]() {
784 if constexpr(FmhaMask::IsMasking)
785 {
786 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
787 }
788 else
789 return 1 / l[i_idx];
790 }();
791 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
792 constexpr auto i_j_idx = make_tuple(idx0, idx1);
793 o_acc(i_j_idx) *= tmp;
794 });
795 });
796
797 o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
798
799 return o_acc;
800 }
801
802 template <typename QDramBlockWindowTmp,
803 typename KDramBlockWindowTmp,
804 typename VDramBlockWindowTmp,
805 typename BiasDramBlockWindowTmp,
806 typename RandValDramBlockWindowTmp,
807 typename LSEDramBlockWindowTmp,
808 typename PositionEncoding,
809 typename AttentionVariantParams,
810 typename BlockIndices>
812 operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
813 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
814 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
815 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
816 RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
817 LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
818 FmhaMask mask,
819 PositionEncoding position_encoding,
820 float scale_s,
821 const AttentionVariant& variant,
822 const AttentionVariantParams& variant_params,
823 const BlockIndices& block_indices,
824 void* smem_ptr,
825 DropoutType& dropout) const
826 {
827 return operator()(q_dram_block_window_tmp,
828 identity{},
829 k_dram_block_window_tmp,
830 identity{},
831 v_dram_block_window_tmp,
832 identity{},
833 bias_dram_block_window_tmp,
834 identity{},
835 randval_dram_block_window_tmp,
836 lse_dram_block_window_tmp,
837 identity{},
838 identity{},
839 identity{},
840 identity{},
841 mask,
842 position_encoding,
843 scale_s,
844 variant,
845 variant_params,
846 block_indices,
847 smem_ptr,
848 dropout);
849 }
850};
851
852} // namespace ck_tile
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor &in_dstr_tensors)
Definition tile_elementwise.hpp:231
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE auto async_load_fence(index_t cnt=0)
Definition load_tile.hpp:145
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition shuffle_tile.hpp:154
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition bfloat16.hpp:419
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition load_tile.hpp:81
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void buffer_load_fence(index_t cnt=0)
Definition tile/core/arch/amd_buffer_addressing.hpp:815
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition load_tile.hpp:133
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:18
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:40
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:60
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:22
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:24
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:59
static constexpr index_t kAlignmentO
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:81
static constexpr index_t kAlignmentBias
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:82
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:33
static constexpr index_t kAlignmentV
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:75
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:31
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:35
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:26
static constexpr bool kHasDropout
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:64
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:25
static constexpr index_t kAlignmentQ
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:73
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:47
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:58
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:29
static constexpr index_t kM0
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:42
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:62
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:167
std::conditional_t< kHasDropout, BlockDropout, NullBlockDropout > DropoutType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:142
static constexpr index_t kK0
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:44
remove_cvref_t< Problem_ > Problem
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:19
static constexpr bool kQLoadOnce
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:37
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:89
static constexpr index_t kAlignmentK
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:74
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:27
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:57
static constexpr const char * name
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:140
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:30
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:144
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:63
static constexpr index_t kK1
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:46
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:32
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:23
remove_cvref_t< Policy_ > Policy
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:20
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:61
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:812
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:28
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:36
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:52
static constexpr index_t kN0
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:43
static constexpr index_t kN1
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:45
static constexpr index_t kSubQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:48
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_pipeline_qr_ks_vs_async.hpp:21
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49