device_avgpool2d_bwd_nhwc_nhwc.hpp Source File

device_avgpool2d_bwd_nhwc_nhwc.hpp Source File

Composable Kernel: device_avgpool2d_bwd_nhwc_nhwc.hpp Source File
device_avgpool2d_bwd_nhwc_nhwc.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
17
18namespace ck {
19namespace tensor_operation {
20namespace device {
21
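22// In and Din = [N, C, Hi, Wi] (index order of the length/stride vectors; the layout tag is NHWC)
23// Out and Dout = [N, C, Ho, Wo] (index order of the length/stride vectors; the layout tag is NHWC)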
24// Out = AvgPool2dFwd(In)
25// Din = AvgPool2dBwd(Dout)
26// Pooling dimension = H, W
27template <typename DOutDataType,
28 typename DInDataType,
29 typename ComputeDataType,
30 ck::index_t BlockSize,
31 ck::index_t MThreadClusterSize,
32 ck::index_t KThreadClusterSize,
33 ck::index_t MThreadSliceSize,
34 ck::index_t KThreadSliceSize,
35 ck::index_t InSrcOutDstVectorSize>
37 DOutDataType,
38 DInDataType,
39 tensor_layout::convolution::NHWC,
40 tensor_layout::convolution::NHWC>
41{
42
43 static constexpr ck::index_t NDimSpatial = 2;
44
45 static constexpr auto I0 = Number<0>{};
46 static constexpr auto I1 = Number<1>{};
47
48 static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
49 static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
50
51 static auto
52 Make2DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
53 const std::vector<ck::index_t>& din_n_c_wos_length,
54 const std::vector<ck::index_t>& dout_n_c_wos_strides,
55 const std::vector<ck::index_t>& din_n_c_wos_strides,
56 const std::vector<ck::index_t>& window_lengths,
57 const std::vector<ck::index_t>& window_strides,
58 const std::vector<ck::index_t>& window_dilations,
59 const std::vector<ck::index_t>& input_left_pads,
60 const std::vector<ck::index_t>& input_right_pads,
61 const std::vector<ck::index_t>& tildes)
62 {
63 index_t i_ytilde = tildes[0];
64 index_t i_xtilde = tildes[1];
65
66 const index_t N = dout_n_c_wos_lengths[0];
67 const index_t C = dout_n_c_wos_lengths[1];
68 const index_t Ho = dout_n_c_wos_lengths[2];
69 const index_t Wo = dout_n_c_wos_lengths[3];
70
71 const index_t Hi = din_n_c_wos_length[2];
72 const index_t Wi = din_n_c_wos_length[3];
73
74 const index_t Y = window_lengths[0];
75 const index_t X = window_lengths[1];
76
77 const index_t InLeftPadH = input_left_pads[0];
78 const index_t InLeftPadW = input_left_pads[1];
79
80 const index_t InRightPadH = input_right_pads[0];
81 const index_t InRightPadW = input_right_pads[1];
82
83 const index_t ConvStrideH = window_strides[0];
84 const index_t ConvStrideW = window_strides[1];
85
86 const index_t ConvDilationH = window_dilations[0];
87 const index_t ConvDilationW = window_dilations[1];
88
89 const index_t Ni_stride = dout_n_c_wos_strides[0];
90 const index_t Ci_stride = dout_n_c_wos_strides[1];
91 const index_t Ho_stride = dout_n_c_wos_strides[2];
92 const index_t Wo_stride = dout_n_c_wos_strides[3];
93
94 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
95 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
96
97 const auto YTilde = ConvStrideH / GcdStrideDilationH;
98 const auto XTilde = ConvStrideW / GcdStrideDilationW;
99
100 const auto YDot = math::integer_divide_ceil(Y, YTilde);
101 const auto XDot = math::integer_divide_ceil(X, XTilde);
102
103 const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
104 const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
105
106 // only work on Tildes that contribute to non-padding area of input tensor
107 const auto IHTildeSliceBegin = math::integer_divide_floor(
108 math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
109 const auto IWTildeSliceBegin = math::integer_divide_floor(
110 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
111
112 const auto IHTildeSliceEnd =
113 math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
114 const auto IWTildeSliceEnd =
115 math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
116
117 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
118 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
119
120 // ReduceK is different for each Reduce
121 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
122 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
123
124 // Problem size of reduction kernel
125 const index_t MRaw = N * HTildeSlice * WTildeSlice * C;
126 const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
127
128 const index_t KRaw = YDotSlice * XDotSlice;
129 const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
130
131 const auto out_n_ho_wo_c_grid_desc = make_naive_tensor_descriptor(
132 make_tuple(N, Ho, Wo, C), make_tuple(Ni_stride, Ho_stride, Wo_stride, Ci_stride));
133
134 // Out[ReduceM, ReduceK]
135 const auto out_n_hop_wop_c_grid_desc = transform_tensor_descriptor(
136 out_n_ho_wo_c_grid_desc,
143
144 const auto out_n_ydot_htilde_xdot_wtilde_c_grid_desc = transform_tensor_descriptor(
145 out_n_hop_wop_c_grid_desc,
147 make_embed_transform(make_tuple(YDot, HTilde),
148 make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
149 make_embed_transform(make_tuple(XDot, WTilde),
150 make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
154
155 const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
157 out_n_ydot_htilde_xdot_wtilde_c_grid_desc,
159 make_slice_transform(YDot, I0, YDotSlice),
160 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
161 make_slice_transform(XDot, I0, XDotSlice),
162 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
165 Sequence<1>{},
166 Sequence<2>{},
167 Sequence<3>{},
168 Sequence<4>{},
169 Sequence<5>{}),
171 Sequence<1>{},
172 Sequence<2>{},
173 Sequence<3>{},
174 Sequence<4>{},
175 Sequence<5>{}));
176
177 const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
178 out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
179 make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C)),
180 make_merge_transform(make_tuple(YDotSlice, XDotSlice))),
183
184 const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
185 out_grid_desc_reducemraw_reducekraw,
189
190 // In[ReduceM]
191 const auto in_n_hi_wi_c_grid_desc =
193 make_tuple(din_n_c_wos_strides[0],
194 din_n_c_wos_strides[2],
195 din_n_c_wos_strides[3],
196 din_n_c_wos_strides[1]));
197
198 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
199 in_n_hi_wi_c_grid_desc,
201 make_pad_transform(Hi, InLeftPadH, InRightPadH),
202 make_pad_transform(Wi, InLeftPadW, InRightPadW),
206
207 const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
208 in_n_hip_wip_c_grid_desc,
210 make_embed_transform(make_tuple(YTilde, HTilde),
211 make_tuple(ConvDilationH, ConvStrideH)),
212 make_embed_transform(make_tuple(XTilde, WTilde),
213 make_tuple(ConvDilationW, ConvStrideW)),
217
218 const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
219 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
221 make_freeze_transform(i_ytilde),
222 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
223 make_freeze_transform(i_xtilde),
224 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
227 Sequence<1>{},
228 Sequence<2>{},
229 Sequence<3>{},
230 Sequence<4>{},
231 Sequence<5>{}),
233 Sequence<>{},
234 Sequence<1>{},
235 Sequence<>{},
236 Sequence<2>{},
237 Sequence<3>{}));
238
239 const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
240 in_n_htildeslice_wtildeslice_c_grid_desc,
241 make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C))),
244
245 const auto in_grid_desc_reducem =
246 transform_tensor_descriptor(in_grid_desc_reducemraw,
250
251 return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
252 }
253
255 {0, 0, 0, 0},
256 {0, 0, 0, 0},
257 {0, 0, 0, 0},
258 {0, 0},
259 {0, 0},
260 {0, 0},
261 {0, 0},
262 {0, 0},
263 {0, 0}));
264
267
268 // FIXME
269 // for NHWC, the dim C is the fastest dimension, and is not reduced.
270 // Hence, it is in M dimension for reduction kernel.
271 static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K
272
275
277 DInDataType,
278 ComputeDataType,
279 int,
284 Div,
286 false, // propagate_nan
287 BlockSize,
288 MThreadSliceSize,
289 KThreadSliceSize,
291 InSrcOutDstVectorSize,
292 InSrcOutDstVectorSize>;
293
294 struct Argument : public BaseArgument
295 {
296 Argument(const DOutDataType* p_dout,
297 DInDataType* p_din,
298 std::vector<ck::index_t> dout_n_c_wos_lengths,
299 std::vector<ck::index_t> din_n_c_wos_length,
300 std::vector<ck::index_t> dout_n_c_wos_strides,
301 std::vector<ck::index_t> din_n_c_wos_strides,
302 std::vector<ck::index_t> window_lengths,
303 std::vector<ck::index_t> window_strides,
304 std::vector<ck::index_t> window_dilations,
305 std::vector<ck::index_t> input_left_pads,
306 std::vector<ck::index_t> input_right_pads)
307 : p_dout_grid_{p_dout},
308 p_din_grid_{p_din},
309 dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
310 din_n_c_wos_length_{din_n_c_wos_length},
311 dout_n_c_wos_strides_{dout_n_c_wos_strides},
312 din_n_c_wos_strides_{din_n_c_wos_strides},
313 num_reduce_{1},
314 div_element_op_{window_lengths[0] * window_lengths[1]}
315 {
316 std::vector<ck::index_t> Tildes(NDimSpatial);
317 for(int i = 0; i < NDimSpatial; ++i)
318 {
319 int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]);
320 Tildes[i] = window_strides[i] / GcdStrideDilation;
321 num_reduce_ *= Tildes[i];
322 }
323
324 for(index_t i_ytilde = 0; i_ytilde < Tildes[0]; ++i_ytilde)
325 {
326 for(index_t i_xtilde = 0; i_xtilde < Tildes[1]; ++i_xtilde)
327 {
328 const auto YDotSlice =
329 math::integer_divide_ceil(window_lengths[0] - i_ytilde, Tildes[0]);
330 const auto XDotSlice =
331 math::integer_divide_ceil(window_lengths[1] - i_xtilde, Tildes[1]);
332
333 if(YDotSlice * XDotSlice <= 0)
334 {
335 continue;
336 }
337
338 const auto dout_din_grid_desc =
339 Make2DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
340 din_n_c_wos_length,
341 dout_n_c_wos_strides,
342 din_n_c_wos_strides,
343 window_lengths,
344 window_strides,
345 window_dilations,
346 input_left_pads,
347 input_right_pads,
348 {i_ytilde, i_xtilde});
349
350 dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
351 din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
352 }
353 }
354 }
355
356 const DOutDataType* p_dout_grid_;
357 DInDataType* p_din_grid_;
358 std::vector<ck::index_t> dout_n_c_wos_lengths_;
359 std::vector<ck::index_t> din_n_c_wos_length_;
360 std::vector<ck::index_t> dout_n_c_wos_strides_;
361 std::vector<ck::index_t> din_n_c_wos_strides_;
362
364 std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
365 std::vector<DinGridDesc_M> din_grid_desc_m_container_;
366
368 };
369
370 struct Invoker : public BaseInvoker
371 {
372 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
373 {
374 float ave_time = 0;
375
376 for(index_t i = 0; i < arg.num_reduce_; i++)
377 {
378 const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
379 false,
380 false,
381 false, // don't have index input
382 DOutDataType,
383 DInDataType,
384 ComputeDataType,
385 int,
389 Div>;
390
391 ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
392 const index_t grid_size = (M / M_BlockTileSize);
393
394 ave_time += launch_and_time_kernel(stream_config,
395 kernel,
396 dim3(grid_size),
397 dim3(BlockSize),
398 0,
401 PassThrough{},
402 arg.div_element_op_,
403 float(1),
404 arg.p_dout_grid_,
405 nullptr,
406 float(0),
407 arg.p_din_grid_,
408 nullptr);
409 }
410
411 return ave_time;
412 }
413
414 float Run(const BaseArgument* p_arg,
415 const StreamConfig& stream_config = StreamConfig{}) override
416 {
417 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
418 }
419 };
420
421 static bool IsSupportedArgument(const Argument& arg)
422 {
423 constexpr index_t Rank = NDimSpatial + 2;
424 int doutFastestDim = -1;
425 int dinFastestDim = -1;
426
427 for(int i = 0; i < Rank; ++i)
428 {
429 if(arg.dout_n_c_wos_strides_[i] == 1)
430 doutFastestDim = i;
431 if(arg.din_n_c_wos_strides_[i] == 1)
432 dinFastestDim = i;
433 }
434 if(InSrcOutDstVectorSize != 1 && (dinFastestDim != 1 || doutFastestDim != 1))
435 {
436 return false;
437 }
438 if(doutFastestDim == -1 || dinFastestDim == -1)
439 {
440 if constexpr(InSrcOutDstVectorSize != 1)
441 return false;
442 }
443 else
444 {
445 if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
446 return false;
447 if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
448 return false;
449 }
450 return true;
451 }
452
453 bool IsSupportedArgument(const BaseArgument* p_arg) override
454 {
455 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
456 }
457
458 std::unique_ptr<BaseArgument>
459 MakeArgumentPointer(const void* p_dout,
460 void* p_din,
461 std::vector<ck::index_t> dout_n_c_wos_lengths,
462 std::vector<ck::index_t> din_n_c_wos_length,
463 std::vector<ck::index_t> dout_n_c_wos_strides,
464 std::vector<ck::index_t> din_n_c_wos_strides,
465 std::vector<ck::index_t> window_lengths,
466 std::vector<ck::index_t> window_strides,
467 std::vector<ck::index_t> window_dilations,
468 std::vector<ck::index_t> input_left_pads,
469 std::vector<ck::index_t> input_right_pads) override
470 {
471 constexpr index_t Rank = NDimSpatial + 2;
472
473 if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
474 dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
475 {
476 throw std::runtime_error("dimension of [dout|din]_n_c_wos_strides or "
477 "[dout|din]_n_c_wos_lengths is not equal to Rank");
478 }
479
480 if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
481 window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
482 input_right_pads.size() != NDimSpatial)
483 {
484 throw std::runtime_error(
485 "dimension of [window_lengths, window_strides, window_dilations, input_left_pads, "
486 "input_right_pads] is not equal to Rank");
487 }
488 return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
489 static_cast<DInDataType*>(p_din),
490 dout_n_c_wos_lengths,
491 din_n_c_wos_length,
492 dout_n_c_wos_strides,
493 din_n_c_wos_strides,
494 window_lengths,
495 window_strides,
496 window_dilations,
497 input_left_pads,
498 input_right_pads);
499 }
500
501 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
502 {
503 return std::make_unique<Invoker>(Invoker{});
504 }
505
506 std::string GetTypeString() const override
507 {
508 auto str = std::stringstream();
509
510 // clang-format off
511 str << "DeviceAvgPool2dBwd<" << BlockSize << ",";
512 str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
513 str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
514 str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
515 // clang-format on
516
517 return str.str();
518 }
519};
520
521} // namespace device
522} // namespace tensor_operation
523} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_threadwise.hpp:28
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition ck/stream_config.hpp:10
Definition gridwise_2d_reduction_threadwise.hpp:84
Definition utility/sequence.hpp:43
Definition reduction_operator.hpp:37
Definition device_base.hpp:197
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:295
int num_reduce_
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:363
const DOutDataType * p_dout_grid_
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:356
std::vector< ck::index_t > dout_n_c_wos_strides_
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:360
std::vector< ck::index_t > dout_n_c_wos_lengths_
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:358
Argument(const DOutDataType *p_dout, DInDataType *p_din, std::vector< ck::index_t > dout_n_c_wos_lengths, std::vector< ck::index_t > din_n_c_wos_length, std::vector< ck::index_t > dout_n_c_wos_strides, std::vector< ck::index_t > din_n_c_wos_strides, std::vector< ck::index_t > window_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:296
DInDataType * p_din_grid_
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:357
std::vector< DinGridDesc_M > din_grid_desc_m_container_
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:365
std::vector< DoutGridDesc_M_K > dout_grid_desc_m_k_container_
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:364
std::vector< ck::index_t > din_n_c_wos_strides_
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:361
Div div_element_op_
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:367
std::vector< ck::index_t > din_n_c_wos_length_
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:359
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:371
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:372
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:414
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:41
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:501
tensor_operation::element_wise::UnaryDivide Div
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:274
remove_cvref_t< tuple_element_t< 0, DoutDinGridDesc > > DoutGridDesc_M_K
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:265
static auto Make2DGridDescriptor_Out_M_K_In_M(const std::vector< ck::index_t > &dout_n_c_wos_lengths, const std::vector< ck::index_t > &din_n_c_wos_length, const std::vector< ck::index_t > &dout_n_c_wos_strides, const std::vector< ck::index_t > &din_n_c_wos_strides, const std::vector< ck::index_t > &window_lengths, const std::vector< ck::index_t > &window_strides, const std::vector< ck::index_t > &window_dilations, const std::vector< ck::index_t > &input_left_pads, const std::vector< ck::index_t > &input_right_pads, const std::vector< ck::index_t > &tildes)
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:52
static constexpr auto I1
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:46
remove_cvref_t< tuple_element_t< 1, DoutDinGridDesc > > DinGridDesc_M
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:266
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_dout, void *p_din, std::vector< ck::index_t > dout_n_c_wos_lengths, std::vector< ck::index_t > din_n_c_wos_length, std::vector< ck::index_t > dout_n_c_wos_strides, std::vector< ck::index_t > din_n_c_wos_strides, std::vector< ck::index_t > window_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads) override
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:459
tensor_operation::element_wise::PassThrough PassThrough
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:273
static constexpr ck::index_t NDimSpatial
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:43
static constexpr ck::index_t M_BlockTileSize
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:48
static constexpr index_t OutSrcInDstVectorDim
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:271
static constexpr ck::index_t K_BlockTileSize
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:49
std::string GetTypeString() const override
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:506
decltype(Make2DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0})) DoutDinGridDesc
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:254
static constexpr auto I0
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:45
GridwiseReduction_mk_to_m_threadwise< DOutDataType, DInDataType, ComputeDataType, int, DoutGridDesc_M_K, DinGridDesc_M, reduce::Add, PassThrough, Div, InMemoryDataOperationEnum::Set, false, BlockSize, MThreadSliceSize, KThreadSliceSize, OutSrcInDstVectorDim, InSrcOutDstVectorSize, InSrcOutDstVectorSize > gridwise_reduce
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:276
static bool IsSupportedArgument(const Argument &arg)
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:421
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_avgpool2d_bwd_nhwc_nhwc.hpp:453
Definition device_avgpool_bwd.hpp:20
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:701