conv_common.hpp Source File

conv_common.hpp Source File

Composable Kernel: conv_common.hpp Source File
conv_common.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8template <typename... InDesc,
9 typename... WeiDesc,
10 typename ConvStrides,
11 typename ConvDilations,
12 typename LeftPads,
13 typename RightPads>
17 const ConvStrides& conv_strides,
18 const ConvDilations conv_dilations,
19 const LeftPads& left_pads,
20 const RightPads& right_pads)
21{
22 using namespace ck;
23
24 constexpr auto I0 = Number<0>{};
25 constexpr auto I1 = Number<1>{};
26 constexpr auto I2 = Number<2>{};
27 constexpr auto I3 = Number<3>{};
28
29 assert(in_desc.GetNumOfDimension() == 4);
30 assert(wei_desc.GetNumOfDimension() == 4);
31 assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1));
32
33 const auto N = in_desc.GetLength(I0);
34 const auto Hi = in_desc.GetLength(I2);
35 const auto Wi = in_desc.GetLength(I3);
36
37 const auto K = wei_desc.GetLength(I0);
38 const auto Y = wei_desc.GetLength(I2);
39 const auto X = wei_desc.GetLength(I3);
40
41 const auto LeftPadH = left_pads[I0];
42 const auto LeftPadW = left_pads[I1];
43
44 const auto RightPadH = right_pads[I0];
45 const auto RightPadW = right_pads[I1];
46
47 const auto YEff = (Y - I1) * conv_dilations[I0] + I1;
48 const auto XEff = (X - I1) * conv_dilations[I1] + I1;
49
50 const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1;
51 const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1;
52
54}
55
56template <class InDesc, class WeiDesc, class OutDesc>
57constexpr std::size_t
58calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDesc& out_desc)
59{
60 using namespace ck;
61
62 constexpr auto I0 = Number<0>{};
63 constexpr auto I1 = Number<1>{};
64 constexpr auto I2 = Number<2>{};
65 constexpr auto I3 = Number<3>{};
66
67 const index_t N = out_desc.GetLength(I0);
68 const index_t K = out_desc.GetLength(I1);
69 const index_t Ho = out_desc.GetLength(I2);
70 const index_t Wo = out_desc.GetLength(I3);
71
72 const index_t C = wei_desc.GetLength(I1);
73 const index_t Y = wei_desc.GetLength(I2);
74 const index_t X = wei_desc.GetLength(I3);
75
76 return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
77}
constexpr auto get_convolution_output_default_4d_tensor_descriptor(const ck::TensorDescriptor< InDesc... > &in_desc, const ck::TensorDescriptor< WeiDesc... > &wei_desc, const ConvStrides &conv_strides, const ConvDilations conv_dilations, const LeftPads &left_pads, const RightPads &right_pads)
Definition conv_common.hpp:14
constexpr std::size_t calculate_convolution_flops(const InDesc &, const WeiDesc &wei_desc, const OutDesc &out_desc)
Definition conv_common.hpp:58
ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
ck::TensorDescriptor
Definition tensor_description/tensor_descriptor.hpp:28
__host__ static __device__ constexpr index_t GetNumOfDimension()
Definition tensor_description/tensor_descriptor.hpp:141
__host__ __device__ constexpr auto GetLength(Number< IDim >) const
Definition tensor_description/tensor_descriptor.hpp:147