device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp Source File

device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp Source File

Composable Kernel: device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp Source File
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10#include "ck/utility/env.hpp"
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24// Backward data of out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]: computes in from out and wei
// Device op for 2-D convolution backward data with NHWC input, KYXC weight and NHWK output
// layouts (see the DeviceConvBwdData base below), executed as one or more XDL GEMMs in
// which A is read from the output tensor, B from the weight tensor and C is the input tensor.
25template <typename InDataType,
26 typename WeiDataType,
27 typename OutDataType,
28 typename AccDataType,
29 typename InElementwiseOperation,
30 typename WeiElementwiseOperation,
31 typename OutElementwiseOperation,
32 ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization, // selects the 1x1/stride-1/pad-0 single-GEMM path when specialized
33 ck::index_t BlockSize,
34 ck::index_t MPerBlock,
35 ck::index_t NPerBlock,
36 ck::index_t K0PerBlock,
37 ck::index_t K1, // GEMM K1 vector length; the K channels are split as K0 = K / K1
38 ck::index_t MPerXDL,
39 ck::index_t NPerXDL,
40 ck::index_t MXdlPerWave,
41 ck::index_t NXdlPerWave,
42 typename ABlockTransferThreadClusterLengths_K0_M_K1,
43 typename ABlockTransferThreadClusterArrangeOrder,
44 typename ABlockTransferSrcAccessOrder,
45 ck::index_t ABlockTransferSrcVectorDim, // IsSupportedArgument only accepts 2 (the K1 dimension of A)
46 ck::index_t ABlockTransferSrcScalarPerVector,
47 ck::index_t ABlockTransferDstScalarPerVector_K1,
48 bool ABlockLdsAddExtraM,
49 typename BBlockTransferThreadClusterLengths_K0_N_K1,
50 typename BBlockTransferThreadClusterArrangeOrder,
51 typename BBlockTransferSrcAccessOrder,
52 ck::index_t BBlockTransferSrcVectorDim, // IsSupportedArgument only accepts 1 (the N dimension of B)
53 ck::index_t BBlockTransferSrcScalarPerVector,
54 ck::index_t BBlockTransferDstScalarPerVector_K1,
55 bool BBlockLdsAddExtraN,
56 ck::index_t CThreadTransferSrcDstVectorDim, // NOTE(review): not forwarded to GridwiseGemmBase, which is given a hard-coded 7
57 ck::index_t CThreadTransferDstScalarPerVector>
59 : public DeviceConvBwdData<2,
60 ck::tensor_layout::convolution::NHWC,
61 ck::tensor_layout::convolution::KYXC,
62 ck::tensor_layout::convolution::NHWK,
63 InDataType,
64 WeiDataType,
65 OutDataType,
66 InElementwiseOperation,
67 WeiElementwiseOperation,
68 OutElementwiseOperation>
69{
71
73 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>(); // NXdlPerWave for wave64 devices; 0 makes IsSupportedArgument reject wave64
74 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>(); // NXdlPerWave for wave32 devices; 0 makes IsSupportedArgument reject wave32
75
// GEMM mapping: A is read from the output tensor, B from the weight tensor, and C is the
// input tensor that the GEMM writes.
76 using ADataType = OutDataType;
77 using BDataType = WeiDataType;
78 using CDataType = InDataType;
79
80 // TODO make A/B datatype different
81 using ABDataType = InDataType; // NOTE(review): used as the element type of both A and B in GridwiseGemmBase although A holds OutDataType and B WeiDataType; only correct when these types coincide — confirm
82
83 static constexpr index_t NDimSpatial = 2;
84
85 static constexpr auto I0 = Number<0>{};
86 static constexpr auto I1 = Number<1>{};
87 static constexpr auto I2 = Number<2>{};
88 static constexpr auto I3 = Number<3>{};
89 static constexpr auto I4 = Number<4>{};
90 static constexpr auto I5 = Number<5>{};
91
92 static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) %
93 ABlockTransferSrcScalarPerVector ==
94 0);
95 static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) %
96 BBlockTransferSrcScalarPerVector ==
97 0);
98
99 static constexpr auto K1Number = Number<K1>{};
100 static constexpr auto GemmK1Number = K1Number;
101
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, K, C, ..., i_ytilde, i_xtilde)
//
// Builds the tensor descriptors of one backward-data GEMM and returns them as
// make_tuple(A, B, C):
//   A (GemmK0 x GemmM x GemmK1): view of the output tensor out_n_ho_wo_k
//   B (GemmK0 x GemmN x GemmK1): view of the weight tensor wei_k_y_x_c
//   C (GemmM x GemmN):           view of the input tensor in_n_hi_wi_c
// Every spatial vector is read at index [0] for H and [1] for W.
// In the general branch the problem is split into YTilde x XTilde GEMMs, with
// YTilde = stride_h / gcd(stride_h, dilation_h) (likewise XTilde for W);
// (i_ytilde, i_xtilde) selects which of them is described.
102 static auto
104 ck::index_t K,
105 ck::index_t C,
106 std::vector<ck::index_t> input_spatial_lengths,
107 std::vector<ck::index_t> filter_spatial_lengths,
108 std::vector<ck::index_t> output_spatial_lengths,
109 std::vector<ck::index_t> conv_filter_strides,
110 std::vector<ck::index_t> conv_filter_dilations,
111 std::vector<ck::index_t> input_left_pads,
112 std::vector<ck::index_t> input_right_pads,
113 index_t i_ytilde,
114 index_t i_xtilde)
115 {
116 using namespace ck;
117
118 const index_t Hi = input_spatial_lengths[0];
119 const index_t Wi = input_spatial_lengths[1];
120
121 const index_t Ho = output_spatial_lengths[0];
122 const index_t Wo = output_spatial_lengths[1];
123
124 const index_t Y = filter_spatial_lengths[0];
125 const index_t X = filter_spatial_lengths[1];
126
127 const index_t InLeftPadH = input_left_pads[0];
128 const index_t InLeftPadW = input_left_pads[1];
129
130 const index_t InRightPadH = input_right_pads[0];
131 const index_t InRightPadW = input_right_pads[1];
132
133 const index_t ConvStrideH = conv_filter_strides[0];
134 const index_t ConvStrideW = conv_filter_strides[1];
135
136 const index_t ConvDilationH = conv_filter_dilations[0];
137 const index_t ConvDilationW = conv_filter_dilations[1];
138
139 const auto K0 = K / K1; // NOTE(review): truncating division; channels past K0 * K1 are not covered when K % K1 != 0, and no such check is visible in IsSupportedArgument — confirm
140
141 const auto out_n_ho_wo_k_grid_desc =
143 const auto wei_k_y_x_c_grid_desc =
145 const auto in_n_hi_wi_c_grid_desc =
147
// Fast path for the 1x1 / stride-1 / pad-0 specialization (the same shape constraints are
// checked in IsSupportedArgument): a single GEMM.
148 if constexpr(ConvBackwardDataSpecialization ==
150 {
151 // A: output tensor
152 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
158
159 // B: weight tensor
160 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
166
167 // C: input tensor
168 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
169 in_n_hi_wi_c_grid_desc,
171 make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
172 make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
176
177 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
178 in_n_y_ho_x_wo_c_grid_desc,
185
186 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
187 wei_gemmk0_gemmn_gemmk1_grid_desc,
188 in_gemmm_gemmn_grid_desc);
189 }
190 else
191 {
// General case. The weight embed below yields filter tap y = ydot * YTilde + ytilde with
// ytilde frozen at i_ytilde, so this GEMM only uses taps with y % YTilde == i_ytilde
// (likewise for x).
192 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
193 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
194
195 const auto YTilde = ConvStrideH / GcdStrideDilationH;
196 const auto XTilde = ConvStrideW / GcdStrideDilationW;
197
198 const auto YDot = math::integer_divide_ceil(Y, YTilde);
199 const auto XDot = math::integer_divide_ceil(X, XTilde);
200
201 const auto HTilde =
202 Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
203 const auto WTilde =
204 Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
205
206 // only work on HTilde and WTilde that contribute to non-padding area of input tensor
207 const auto IHTildeSliceBegin = math::integer_divide_floor(
208 math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
209 const auto IWTildeSliceBegin = math::integer_divide_floor(
210 math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
211
212 const auto IHTildeSliceEnd = math::min(
213 HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
214 const auto IWTildeSliceEnd = math::min(
215 WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
216
217 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
218 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
219
220 // GemmK is different for each GEMM
// YDotSlice = number of taps y in [0, Y) with y % YTilde == i_ytilde (likewise XDotSlice).
221 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
222 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
223
224 // A: output tensor
225 const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
226 out_n_ho_wo_k_grid_desc,
233
234 const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
235 out_n_hop_wop_k_grid_desc,
238 make_embed_transform(make_tuple(YDot, HTilde),
239 make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
240 make_embed_transform(make_tuple(XDot, WTilde),
241 make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
245
// NOTE(review): make_slice_transform's third parameter is declared as `slice_end`, but
// HTildeSlice / WTildeSlice are slice lengths (End - Begin); the two coincide only when the
// slice begin is 0. Indexing appears to be driven by the merge lengths further down, so this
// may be harmless; confirm against the Slice transform definition. The input-tensor slices
// below use the same pattern.
246 const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
248 out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
250 make_slice_transform(YDot, I0, YDotSlice),
251 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
252 make_slice_transform(XDot, I0, XDotSlice),
253 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
256 Sequence<1>{},
257 Sequence<2>{},
258 Sequence<3>{},
259 Sequence<4>{},
260 Sequence<5>{}),
262 Sequence<1>{},
263 Sequence<2>{},
264 Sequence<3>{},
265 Sequence<4>{},
266 Sequence<5, 6>{}));
267
// GemmK merges (YDotSlice, XDotSlice, K0); GemmM merges (N, HTildeSlice, WTildeSlice).
268 const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
269 out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
270 make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
271 make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
275
276 // B: weight tensor
277 const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
278 wei_k_y_x_c_grid_desc,
280 make_embed_transform(make_tuple(YDot, YTilde),
281 make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
282 make_embed_transform(make_tuple(XDot, XTilde),
283 make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
287
// Keep ydot < YDotSlice at the frozen ytilde = i_ytilde (likewise for x).
288 const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
289 transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
291 make_slice_transform(YDot, I0, YDotSlice),
292 make_slice_transform(XDot, I0, XDotSlice),
293 make_freeze_transform(i_ytilde),
294 make_freeze_transform(i_xtilde),
297 Sequence<1>{},
298 Sequence<3>{},
299 Sequence<2>{},
300 Sequence<4>{},
301 Sequence<5>{}),
303 Sequence<2>{},
304 Sequence<3>{},
305 Sequence<>{},
306 Sequence<>{},
307 Sequence<4>{}));
308
309 const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
310 wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
311 make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
316
317 // C: input tensor
// Padded input coordinate covered by this GEMM: hip = i_ytilde * dilation_h + htilde * stride_h,
// with htilde restricted by the HTilde slice (likewise for W).
318 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
319 in_n_hi_wi_c_grid_desc,
321 make_pad_transform(Hi, InLeftPadH, InRightPadH),
322 make_pad_transform(Wi, InLeftPadW, InRightPadW),
326
327 const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
328 in_n_hip_wip_c_grid_desc,
330 make_embed_transform(make_tuple(YTilde, HTilde),
331 make_tuple(ConvDilationH, ConvStrideH)),
332 make_embed_transform(make_tuple(XTilde, WTilde),
333 make_tuple(ConvDilationW, ConvStrideW)),
337
338 const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
339 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
341 make_freeze_transform(i_ytilde),
342 make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
343 make_freeze_transform(i_xtilde),
344 make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
347 Sequence<1>{},
348 Sequence<2>{},
349 Sequence<3>{},
350 Sequence<4>{},
351 Sequence<5>{}),
353 Sequence<>{},
354 Sequence<1>{},
355 Sequence<>{},
356 Sequence<2>{},
357 Sequence<3>{}));
358
359 const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
360 in_n_htildeslice_wtildeslice_c_grid_desc,
361 make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
365
366 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
367 wei_gemmk0_gemmn_gemmk1_grid_desc,
368 in_gemmm_gemmn_grid_desc);
369 }
370
371 } // function end
372
// ABCGridDescs: the descriptor-tuple type returned by
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N, obtained via decltype on a call with the
// placeholder unit-size arguments below; the A/B/C grid descriptor types are its elements.
374 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 0, 0));
375
379
380 // GridwiseGemm: templated on NXdlPerWave_ so that wave64 (GridwiseGemm64, using max(NXdlPerWave64, 1)) and wave32 (GridwiseGemm32) variants can both be instantiated
381 template <index_t NXdlPerWave_>
383 BlockSize,
384 ABDataType, // TODO: distinguish A/B datatype
385 AccDataType,
386 CDataType,
388 InElementwiseOperation,
389 WeiElementwiseOperation,
390 OutElementwiseOperation,
391 MPerBlock,
392 NPerBlock,
393 K0PerBlock,
394 MPerXDL,
395 NPerXDL,
396 K1,
397 MXdlPerWave,
398 NXdlPerWave_,
399 ABlockTransferThreadClusterLengths_K0_M_K1,
400 ABlockTransferThreadClusterArrangeOrder,
401 ABlockTransferSrcAccessOrder,
402 ABlockTransferSrcVectorDim,
403 ABlockTransferSrcScalarPerVector,
404 ABlockTransferDstScalarPerVector_K1,
405 false, // AThreadTransferSrcResetCoordinateAfterRun,
406 ABlockLdsAddExtraM,
407 BBlockTransferThreadClusterLengths_K0_N_K1,
408 BBlockTransferThreadClusterArrangeOrder,
409 BBlockTransferSrcAccessOrder,
410 BBlockTransferSrcVectorDim,
411 BBlockTransferSrcScalarPerVector,
412 BBlockTransferDstScalarPerVector_K1,
413 false, // BThreadTransferSrcResetCoordinateAfterRun,
414 BBlockLdsAddExtraN,
415 Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
416 7, // CThreadTransferSrcDstVectorDim, NOTE(review): hard-coded; this op's CThreadTransferSrcDstVectorDim template parameter is ignored
417 CThreadTransferDstScalarPerVector>;
420
421 // Argument
// Holds the typed buffer pointers (A <- output, B <- weight, C <- input), copies of the
// convolution geometry used by IsSupportedArgument, and one A/B/C descriptor set per
// sub-GEMM that will be launched.
422 struct Argument : public BaseArgument
423 {
424 Argument(InDataType* p_in_grid,
425 const WeiDataType* p_wei_grid,
426 const OutDataType* p_out_grid,
427 ck::index_t N,
428 ck::index_t K,
429 ck::index_t C,
430 std::vector<ck::index_t> input_spatial_lengths,
431 std::vector<ck::index_t> filter_spatial_lengths,
432 std::vector<ck::index_t> output_spatial_lengths,
433 std::vector<ck::index_t> conv_filter_strides,
434 std::vector<ck::index_t> conv_filter_dilations,
435 std::vector<ck::index_t> input_left_pads,
436 std::vector<ck::index_t> input_right_pads)
437 : p_a_grid_{p_out_grid},
438 p_b_grid_{p_wei_grid},
439 p_c_grid_{p_in_grid},
440 Conv_N_{N},
441 Conv_K_{K},
442 Conv_C_{C},
443 input_spatial_lengths_{input_spatial_lengths},
444 filter_spatial_lengths_{filter_spatial_lengths},
445 output_spatial_lengths_{output_spatial_lengths},
446 conv_filter_strides_{conv_filter_strides},
447 conv_filter_dilations_{conv_filter_dilations},
448 input_left_pads_{input_left_pads},
449 input_right_pads_{input_right_pads}
450 {
// YTilde / XTilde are computed exactly as in MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N;
// every (i_ytilde, i_xtilde) pair is a candidate sub-GEMM.
451 const index_t ConvStrideH = conv_filter_strides[0];
452 const index_t ConvStrideW = conv_filter_strides[1];
453
454 const index_t ConvDilationH = conv_filter_dilations[0];
455 const index_t ConvDilationW = conv_filter_dilations[1];
456
457 const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
458 const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
459
460 const auto YTilde = ConvStrideH / GcdStrideDilationH;
461 const auto XTilde = ConvStrideW / GcdStrideDilationW;
462
463 for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
464 {
465 for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
466 {
467 // check slice is valid
// (a zero product means no filter tap has y % YTilde == i_ytilde and x % XTilde == i_xtilde)
468 const index_t Y = filter_spatial_lengths_[0];
469 const index_t X = filter_spatial_lengths_[1];
470 const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
471 const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
472 if(YDotSlice * XDotSlice <= 0)
473 {
// NOTE(review): no descriptors are stored, and so no kernel is launched, for this
// (i_ytilde, i_xtilde); the input positions it would cover are never written by this op.
// Presumably callers must zero the input buffer beforehand — confirm.
474 continue;
475 }
476
478 N,
479 K,
480 C,
481 input_spatial_lengths,
482 filter_spatial_lengths,
483 output_spatial_lengths,
484 conv_filter_strides,
485 conv_filter_dilations,
486 input_left_pads,
487 input_right_pads,
488 i_ytilde,
489 i_xtilde);
490 a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
491 b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
492 c_grid_desc_m_n_container_.push_back(descs[I2]);
493 }
494 }
495 }
496
500 std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_; // one entry per non-skipped (i_ytilde, i_xtilde)
501 std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_; // parallel to a_grid_desc_k0_m_k1_container_
502 std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_; // parallel to a_grid_desc_k0_m_k1_container_
503 // for checking IsSupportedArgument()
507
508 std::vector<ck::index_t> input_spatial_lengths_;
509 std::vector<ck::index_t> filter_spatial_lengths_;
510 std::vector<ck::index_t> output_spatial_lengths_;
511 std::vector<ck::index_t> conv_filter_strides_;
512 std::vector<ck::index_t> conv_filter_dilations_;
513 std::vector<ck::index_t> input_left_pads_;
514 std::vector<ck::index_t> input_right_pads_;
515 };
516
517 // Invoker
518 struct Invoker : public BaseInvoker
519 {
521
// Launches one kernel per sub-GEMM stored in arg and returns the sum of the per-launch times
// reported by launch_and_time_kernel. GridwiseGemm is presumably GridwiseGemm64 or
// GridwiseGemm32, picked by INVOKER_RUN_IMPL — confirm.
522 template <typename GridwiseGemm>
523 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
524 {
525 float ave_time = 0;
526 for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
527 {
// Optional descriptor dump, enabled through the CK_LOGGING environment variable.
528 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
529 {
530 {
531 std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
532 << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
533 << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
534 << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
535 << std::endl;
536
537 std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
538 << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
539 << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
540 << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
541 << std::endl;
542
543 std::cout << "arg.c_grid_desc_m_n_container_{ "
544 << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
545 << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
546 << std::endl;
547 }
548 }
549
// Refuse to launch a sub-GEMM the gridwise GEMM reports as invalid.
550 if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
553 {
// NOTE(review): the message names ..._xdlops_v3r1, but the kernel launched below is
// kernel_gemm_xdlops_v2r3.
554 throw std::runtime_error(
555 "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
556 }
557
558 const auto [gdx, gdy, gdz] =
559 GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);
560
// Total GEMM K extent of this sub-GEMM: K0 * K1.
561 const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
562 arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
563
// "Has main K-block loop" is a compile-time kernel parameter, so both instantiations exist and
// one is chosen here at run time.
564 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
565 {
566 const auto kernel =
567 kernel_gemm_xdlops_v2r3<GridwiseGemm,
568 ADataType, // TODO: distinguish A/B datatype
569 CDataType,
573 true>;
574
575 ave_time += launch_and_time_kernel(stream_config,
576 kernel,
577 dim3(gdx, gdy, gdz),
578 dim3(BlockSize),
579 0,
580 arg.p_a_grid_,
581 arg.p_b_grid_,
582 arg.p_c_grid_,
586 }
587 else
588 {
589 const auto kernel =
590 kernel_gemm_xdlops_v2r3<GridwiseGemm,
591 ADataType, // TODO: distinguish A/B datatype
592 CDataType,
596 false>;
597
598 ave_time += launch_and_time_kernel(stream_config,
599 kernel,
600 dim3(gdx, gdy, gdz),
601 dim3(BlockSize),
602 0,
603 arg.p_a_grid_,
604 arg.p_b_grid_,
605 arg.p_c_grid_,
609 }
610 }
611 return ave_time;
612 }
613
615
// Type-erased entry point. NOTE(review): the dynamic_cast result is dereferenced without a null
// check, so passing an argument object of another op type is undefined behavior.
616 float Run(const BaseArgument* p_arg,
617 const StreamConfig& stream_config = StreamConfig{}) override
618 {
619 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
620 }
621 };
622
// Compile-time hook for rejecting invalid template-parameter combinations.
// At present every instantiation is accepted.
// TODO: properly implement this check
static constexpr bool IsValidCompilationParameter() { return true; }
628
// Returns true only if this instance can run arg. The specialization's shape constraints, the
// global-memory vector widths, and the gridwise GEMM validity of every stored sub-GEMM for the
// current wave size must all hold.
629 static bool IsSupportedArgument(const Argument& arg)
630 {
632 {
633 return false;
634 }
635 if constexpr(ConvBackwardDataSpecialization ==
637 {
638 // check if it's 1x1, stride=1 pad = 0 conv
// (dilation has no effect for a 1x1 filter, so it is not checked)
639 if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
640 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
641 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
642 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
643 {
644 return false;
645 }
646 }
647
648 // vector load A/B matrix from global memory
// A is read along its K1 dimension (vector dim 2), so Conv_K_ must be a multiple of A's source
// vector width; B is read along its N dimension (vector dim 1), checked against Conv_C_.
// NOTE(review): there is no Conv_K_ % K1 check here, while the descriptors use K0 = K / K1.
649 if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 &&
650 arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
651 arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
652 {
653 return false;
654 }
655
656 // vector store C matrix into global memory
657 if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
658 {
659 return false;
660 }
661
662 // Gridwise GEMM size
// valid stays false (argument rejected) when NXdlPerWave for the current wave size is 0.
663 bool isWave64 = get_warp_size() == 64;
664 for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
665 {
666 bool valid = false;
667 if(isWave64)
668 {
669 if constexpr(NXdlPerWave64 > 0)
670 {
674 }
675 }
676 else
677 {
678 if constexpr(NXdlPerWave32 > 0)
679 {
683 }
684 }
685 if(!valid)
686 return false;
687 }
688 return true;
689 }
690
691 bool IsSupportedArgument(const BaseArgument* p_arg) override
692 {
693 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
694 }
695
696 static auto MakeArgument(InDataType* p_in_grid,
697 const WeiDataType* p_wei_grid,
698 const OutDataType* p_out_grid,
699 ck::index_t N,
700 ck::index_t K,
701 ck::index_t C,
702 std::vector<ck::index_t> input_spatial_lengths,
703 std::vector<ck::index_t> filter_spatial_lengths,
704 std::vector<ck::index_t> output_spatial_lengths,
705 std::vector<ck::index_t> conv_filter_strides,
706 std::vector<ck::index_t> conv_filter_dilations,
707 std::vector<ck::index_t> input_left_pads,
708 std::vector<ck::index_t> input_right_pads)
709 {
710 return Argument{p_in_grid,
711 p_wei_grid,
712 p_out_grid,
713 N,
714 K,
715 C,
716 input_spatial_lengths,
717 filter_spatial_lengths,
718 output_spatial_lengths,
719 conv_filter_strides,
720 conv_filter_dilations,
721 input_left_pads,
722 input_right_pads};
723 }
724
725 static auto MakeInvoker() { return Invoker{}; }
726
727 std::unique_ptr<BaseArgument>
728 MakeArgumentPointer(void* p_in_grid,
729 const void* p_wei_grid,
730 const void* p_out_grid,
731 ck::index_t N,
732 ck::index_t K,
733 ck::index_t C,
734 std::vector<ck::index_t> input_spatial_lengths,
735 std::vector<ck::index_t> filter_spatial_lengths,
736 std::vector<ck::index_t> output_spatial_lengths,
737 std::vector<ck::index_t> conv_filter_strides,
738 std::vector<ck::index_t> conv_filter_dilations,
739 std::vector<ck::index_t> input_left_pads,
740 std::vector<ck::index_t> input_right_pads,
741 InElementwiseOperation,
742 WeiElementwiseOperation,
743 OutElementwiseOperation) override
744 {
745 return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
746 static_cast<const WeiDataType*>(p_wei_grid),
747 static_cast<const OutDataType*>(p_out_grid),
748 N,
749 K,
750 C,
751 input_spatial_lengths,
752 filter_spatial_lengths,
753 output_spatial_lengths,
754 conv_filter_strides,
755 conv_filter_dilations,
756 input_left_pads,
757 input_right_pads);
758 }
759
760 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
761 {
762 return std::make_unique<Invoker>(Invoker{});
763 }
764
765 std::string GetTypeString() const override
766 {
767 auto str = std::stringstream();
768
769 // clang-format off
770 str << "DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
771 << "<"
772 << BlockSize << ", "
773 << MPerBlock << ", "
774 << NPerBlock << ", "
775 << K0PerBlock << ", "
776 << K1 << ", "
777 << MXdlPerWave << ", "
778 << NXdlPerWave << ", "
779 << ABlockTransferSrcScalarPerVector << ", "
780 << ABlockTransferDstScalarPerVector_K1 << ", "
781 << BBlockTransferSrcScalarPerVector << ", "
782 << BBlockTransferDstScalarPerVector_K1
783 << ">";
784 // clang-format on
785
786 return str.str();
787 }
788};
789
790} // namespace device
791} // namespace tensor_operation
792} // namespace ck
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
ConvolutionBackwardDataSpecialization
Definition convolution_backward_data_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_backward_data_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__global__ void kernel_gemm_xdlops_v2r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n)
Definition gridwise_gemm_xdlops_v2r3.hpp:34
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r3.hpp:142
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
const ADataType * p_a_grid_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:497
std::vector< ck::index_t > input_right_pads_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:514
std::vector< AGridDesc_K0_M_K1 > a_grid_desc_k0_m_k1_container_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:500
CDataType * p_c_grid_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:499
std::vector< CGridDesc_M_N > c_grid_desc_m_n_container_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:502
std::vector< ck::index_t > output_spatial_lengths_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:510
std::vector< ck::index_t > conv_filter_dilations_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:512
std::vector< ck::index_t > filter_spatial_lengths_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:509
std::vector< BGridDesc_K0_N_K1 > b_grid_desc_k0_n_k1_container_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:501
std::vector< ck::index_t > conv_filter_strides_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:511
Argument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:424
const BDataType * p_b_grid_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:498
std::vector< ck::index_t > input_left_pads_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:513
std::vector< ck::index_t > input_spatial_lengths_
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:508
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 2, 3, 0, 1, 7, 5, 4, 6 >, 7, CThreadTransferDstScalarPerVector > GridwiseGemmBase
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:382
static constexpr index_t NDimSpatial
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:83
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:376
static constexpr auto I2
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:87
std::unique_ptr< BaseArgument > MakeArgumentPointer(void *p_in_grid, const void *p_wei_grid, const void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation) override
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:728
static constexpr auto K1Number
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:99
static auto MakeInvoker()
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:725
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, index_t i_ytilde, index_t i_xtilde)
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:103
decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 0, 0)) ABCGridDescs
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:373
OutDataType ADataType
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:76
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:378
InDataType CDataType
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:78
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:760
static constexpr auto I3
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:88
std::string GetTypeString() const override
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:765
static constexpr auto I4
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:89
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:73
static auto MakeArgument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:696
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:419
static constexpr auto I0
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:85
static bool IsSupportedArgument(const Argument &arg)
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:629
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:377
static constexpr auto NXdlPerWave32
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:74
InDataType ABDataType
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:81
DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K DeviceOp
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:70
WeiDataType BDataType
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:77
static constexpr auto I5
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:90
static constexpr bool IsValidCompilationParameter()
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:623
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:418
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:691
static constexpr auto GemmK1Number
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:100
static constexpr auto I1
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:86
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:616
DeviceOp::Argument Argument
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:520
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp:523
Definition device_conv_bwd_data.hpp:25
#define CK_ENV(name)
Definition utility/env.hpp:129