layernorm2d_fwd_traits.hpp Source File

layernorm2d_fwd_traits.hpp Source File#

Composable Kernel: layernorm2d_fwd_traits.hpp Source File
layernorm2d_fwd_traits.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck_tile {
9
11{
13 // add bias before fused add
15};
16
17// clang-format off
// Compile-time lookup from a Layernorm2dXBiasEnum value to a short name string,
// exposed as the static constexpr member `name`. The primary template is only
// declared here (no definition), so an enumerator without a specialization has
// no `name` unless a specialization is provided elsewhere.
18template<Layernorm2dXBiasEnum> struct Layernorm2dXBiasEnumName;
// NO_BIAS -> "no"
19template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::NO_BIAS> { static constexpr const char * name = "no"; };
// ADD_BIAS -> "xbias"
20template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::ADD_BIAS> { static constexpr const char * name = "xbias"; };
21// clang-format on
22
24{
25 NO_ADD = 0,
26 // fused add before layernorm and store result to global
28 // fused add before layernorm, but not store result
30};
31
32// clang-format off
// Compile-time lookup from a Layernorm2dFusedAddEnum value to a short name
// string, exposed as the static constexpr member `name`. The primary template
// is only declared here (no definition), so an enumerator without a
// specialization has no `name` unless a specialization is provided elsewhere.
33template<Layernorm2dFusedAddEnum> struct Layernorm2dFusedAddEnumName;
// NO_ADD -> "no"
34template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::NO_ADD> { static constexpr const char * name = "no"; };
// PRE_ADD_STORE -> "pras"
35template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD_STORE> { static constexpr const char * name = "pras"; };
// PRE_ADD -> "pra"
36template<> struct Layernorm2dFusedAddEnumName<Layernorm2dFusedAddEnum::PRE_ADD> { static constexpr const char * name = "pra"; };
37// clang-format on
38
40{
42 SMOOTH_DYNAMIC_QUANT = 1, // smooth outlier + rowwise quant, need input x-scale and store y_scale
43 DYNAMIC_QUANT = 2, // rowwise quant, store out a y-scale
44};
45
46// clang-format off
// Compile-time lookup from a Layernorm2dFusedQuantEnum value to a short name
// string, exposed as the static constexpr member `name`. The primary template
// is only declared here (no definition), so an enumerator without a
// specialization has no `name` unless a specialization is provided elsewhere.
47template<Layernorm2dFusedQuantEnum> struct Layernorm2dFusedQuantEnumName;
// NO_SWEEP -> "no"
48template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::NO_SWEEP> { static constexpr const char * name = "no"; };
// DYNAMIC_QUANT -> "dqt"
49template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::DYNAMIC_QUANT> { static constexpr const char * name = "dqt"; };
// SMOOTH_DYNAMIC_QUANT -> "smdqt"
50template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT> { static constexpr const char * name = "smdqt"; };
51// clang-format on
52
53template <bool kPadN_,
54 bool kSaveMeanInvStd_,
55 bool kFastFDiv_,
56 bool kWelford_,
57 bool kTwoPass_,
59 Layernorm2dFusedAddEnum kFusedAdd_,
60 Layernorm2dFusedQuantEnum kFusedQuant_>
62{
63 static constexpr bool kPadN = kPadN_;
64 static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
65 static constexpr bool kFastFDiv = kFastFDiv_;
66 static constexpr bool kWelford = kWelford_;
67 static constexpr bool kTwoPass = kTwoPass_;
68 static constexpr Layernorm2dXBiasEnum kXbias = kXbias_;
69 static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_;
70 static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
71};
72
73} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
Layernorm2dFusedQuantEnum
Definition layernorm2d_fwd_traits.hpp:40
@ NO_SWEEP
Definition layernorm2d_fwd_traits.hpp:41
@ SMOOTH_DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:42
@ DYNAMIC_QUANT
Definition layernorm2d_fwd_traits.hpp:43
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
Layernorm2dXBiasEnum
Definition layernorm2d_fwd_traits.hpp:11
@ ADD_BIAS
Definition layernorm2d_fwd_traits.hpp:14
Layernorm2dFusedAddEnum
Definition layernorm2d_fwd_traits.hpp:24
@ PRE_ADD_STORE
Definition layernorm2d_fwd_traits.hpp:27
@ PRE_ADD
Definition layernorm2d_fwd_traits.hpp:29
@ NO_ADD
Definition layernorm2d_fwd_traits.hpp:25
static constexpr const char * name
Definition layernorm2d_fwd_traits.hpp:34
static constexpr const char * name
Definition layernorm2d_fwd_traits.hpp:36
static constexpr const char * name
Definition layernorm2d_fwd_traits.hpp:35
Definition layernorm2d_fwd_traits.hpp:33
static constexpr const char * name
Definition layernorm2d_fwd_traits.hpp:49
static constexpr const char * name
Definition layernorm2d_fwd_traits.hpp:48
static constexpr const char * name
Definition layernorm2d_fwd_traits.hpp:50
Definition layernorm2d_fwd_traits.hpp:47
Definition layernorm2d_fwd_traits.hpp:62
static constexpr Layernorm2dFusedAddEnum kFusedAdd
Definition layernorm2d_fwd_traits.hpp:69
static constexpr bool kPadN
Definition layernorm2d_fwd_traits.hpp:63
static constexpr bool kFastFDiv
Definition layernorm2d_fwd_traits.hpp:65
static constexpr Layernorm2dFusedQuantEnum kFusedQuant
Definition layernorm2d_fwd_traits.hpp:70
static constexpr Layernorm2dXBiasEnum kXbias
Definition layernorm2d_fwd_traits.hpp:68
static constexpr bool kSaveMeanInvStd
Definition layernorm2d_fwd_traits.hpp:64
static constexpr bool kWelford
Definition layernorm2d_fwd_traits.hpp:66
static constexpr bool kTwoPass
Definition layernorm2d_fwd_traits.hpp:67
static constexpr const char * name
Definition layernorm2d_fwd_traits.hpp:20
static constexpr const char * name
Definition layernorm2d_fwd_traits.hpp:19
Definition layernorm2d_fwd_traits.hpp:18