device_normalization_bwd_data_impl.hpp Source File

device_normalization_bwd_data_impl.hpp Source File#

Composable Kernel: device_normalization_bwd_data_impl.hpp Source File
device_normalization_bwd_data_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <vector>
8
16
17// M is Invariant dimension, K is reduced dimension
18namespace ck {
19namespace tensor_operation {
20namespace device {
// Device kernel entry point for the normalization backward-data pass.
//
// The kernel does no work of its own: every argument is forwarded, in the same
// order, to the static member GridwiseNormalizationBwd::Run.
//
// All six tensor descriptors share the single type GridDesc_M_K. As stated in
// the file comment above, M is the invariant dimension and K is the reduced
// dimension.
// num_k_block_tile_iteration is passed through untouched (presumably the number
// of K block tiles to iterate over -- confirm in the gridwise implementation).
// p_dx_global is the only pointer to non-const data; all other buffers are
// read-only from this kernel's point of view.
// The ';' after the closing brace is a redundant empty declaration.
21template <typename GridwiseNormalizationBwd,
22 typename DYDataType,
23 typename XDataType,
24 typename GammaDataType,
25 typename MeanInvStdDataType,
26 typename DXDataType,
27 typename GridDesc_M_K>
28__global__ void
29kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k,
30 const GridDesc_M_K x_grid_desc_m_k,
31 const GridDesc_M_K gamma_grid_desc_m_k,
32 const GridDesc_M_K mean_grid_desc_m_k,
33 const GridDesc_M_K inv_std_grid_desc_m_k,
34 const GridDesc_M_K dx_grid_desc_m_k,
35 index_t num_k_block_tile_iteration,
36 const DYDataType* const __restrict__ p_dy_global,
37 const XDataType* const __restrict__ p_x_global,
38 const GammaDataType* const __restrict__ p_gamma_global,
39 const MeanInvStdDataType* const __restrict__ p_mean_global,
40 const MeanInvStdDataType* const __restrict__ p_inv_std_global,
41 DXDataType* const __restrict__ p_dx_global)
42{
43 GridwiseNormalizationBwd::Run(dy_grid_desc_m_k,
44 x_grid_desc_m_k,
45 gamma_grid_desc_m_k,
46 mean_grid_desc_m_k,
47 inv_std_grid_desc_m_k,
48 dx_grid_desc_m_k,
49 num_k_block_tile_iteration,
50 p_dy_global,
51 p_x_global,
52 p_gamma_global,
53 p_mean_global,
54 p_inv_std_global,
55 p_dx_global);
56};
57
58template <typename DYDataType,
59 typename XDataType,
60 typename GammaDataType,
61 typename MeanInvStdDataType,
62 typename ComputeDataType,
63 typename DXDataType,
64 index_t Rank,
65 index_t NumReduceDim,
66 index_t BlockSize,
67 index_t MThreadClusterSize,
68 index_t KThreadClusterSize,
69 index_t MThreadSliceSize,
70 index_t KThreadSliceSize,
71 bool IsDYFastestDimReduced,
72 index_t DYSrcVectorSize,
73 bool IsXFastestDimReduced,
74 index_t XSrcVectorSize,
75 bool IsGammaFastestDimReduced,
76 index_t GammaSrcVectorSize,
77 bool IsMeanInvStdFastestDimReduced,
78 index_t MeanInvStdSrcVectorSize,
79 bool IsDxFastestDimReduced,
80 index_t DXDstVectorSize>
82 XDataType,
83 GammaDataType,
84 MeanInvStdDataType,
85 DXDataType,
86 Rank,
87 NumReduceDim>
88{
89 static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0;
90 static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0;
91 static constexpr index_t GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0;
92 static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
93 static constexpr index_t DXDstVectorDim = IsDxFastestDimReduced ? 1 : 0;
94
95 static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
96
97 static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) ||
98 (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)),
99 "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
100
101 static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
102 (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
103 "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
104
105 static_assert(
106 ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
107 (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
108 "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
109
110 static_assert(
111 (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
112 (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
113 "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
114 "check!");
115
116 static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) ||
117 (DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)),
118 "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
119
120 static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
121 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
122 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
123
124 static constexpr bool reduceAllDim = (NumInvariantDim == 0);
125 static_assert(!reduceAllDim);
126
// Builds the padded 2-D (M, K) tensor descriptor for one tensor from its
// per-dimension lengths and strides.
//
// NOTE(review): this listing is missing several original lines inside this
// function (the ReduceDims alias and the trailing arguments of both
// transform_tensor_descriptor calls), so the notes below describe only what is
// visible here.
//
// Visible steps:
//  1. lengths and strides are converted to Rank-sized tuples and combined into
//     a naive tensor descriptor.
//  2. The invariant dimension lengths are merged into one dimension and the
//     reduce dimension lengths into a second one. InvariantDims is generated
//     from 0 and NumInvariantDim, so the invariant dimensions are presumably
//     expected to come first -- confirm against the caller.
//  3. M is right-padded up to the next multiple of M_BlockTileSize.
//  4. K is right-padded up to K_BlockTileSize * numBlockTileIteration.
//     pad_K is not checked here; it would be negative if that product were
//     smaller than reduceLength, so the caller has to pick
//     numBlockTileIteration accordingly.
//
// Returns the padded descriptor; its type is what GridDesc_M_K is deduced from.
127 static auto Make2dDescriptor(const std::vector<index_t>& lengths,
128 const std::vector<index_t>& strides,
129 int numBlockTileIteration)
130 {
131 const auto tupleLengths = make_tuple_from_array(lengths, Number<Rank>{});
132 const auto tupleStrides = make_tuple_from_array(strides, Number<Rank>{});
133
134 const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
135
136 const auto grid_desc_m_k = [&]() {
137 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
139
140 const auto reduceDimLengths =
141 make_tuple_from_array_and_index_seq(lengths, ReduceDims{});
142 const auto invariantDimLengths =
143 make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
144
145 return transform_tensor_descriptor(desc,
146 make_tuple(make_merge_transform(invariantDimLengths),
147 make_merge_transform(reduceDimLengths)),
148 make_tuple(InvariantDims{}, ReduceDims{}),
150 }();
151
152 const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
153 const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
154
155 const auto pad_M =
156 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
157 const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
158
159 auto grid_desc_m_k_padded =
160 transform_tensor_descriptor(grid_desc_m_k,
161 make_tuple(make_right_pad_transform(invariantLength, pad_M),
162 make_right_pad_transform(reduceLength, pad_K)),
165
166 return grid_desc_m_k_padded;
167 }
168
169 using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1));
170
173 XDataType,
174 GammaDataType,
175 MeanInvStdDataType,
176 ComputeDataType,
177 DXDataType,
179 BlockSize,
180 MThreadClusterSize,
181 KThreadClusterSize,
182 MThreadSliceSize,
183 KThreadSliceSize,
185 DYSrcVectorSize,
187 XSrcVectorSize,
189 GammaSrcVectorSize,
191 MeanInvStdSrcVectorSize,
193 DXDstVectorSize,
194 false>;
195
198 XDataType,
199 GammaDataType,
200 MeanInvStdDataType,
201 ComputeDataType,
202 DXDataType,
204 BlockSize,
205 MThreadClusterSize,
206 KThreadClusterSize,
207 MThreadSliceSize,
208 KThreadSliceSize,
210 DYSrcVectorSize,
212 XSrcVectorSize,
214 GammaSrcVectorSize,
216 MeanInvStdSrcVectorSize,
218 DXDstVectorSize,
219 true>;
220
221 struct Argument : public BaseArgument
222 {
223 Argument(const std::vector<index_t> lengths,
224 const std::vector<index_t> dyStrides,
225 const std::vector<index_t> xStrides,
226 const std::vector<index_t> gammaStrides,
227 const std::vector<index_t> meanStrides,
228 const std::vector<index_t> invStdStrides,
229 const std::vector<index_t> dxStrides,
230 const std::vector<index_t> reduceDims,
231 const DYDataType* p_dy,
232 const XDataType* p_x,
233 const GammaDataType* p_gamma,
234 const MeanInvStdDataType* p_mean,
235 const MeanInvStdDataType* p_invStd,
236 DXDataType* p_dx)
237 : p_dy_(p_dy),
238 p_x_(p_x),
239 p_gamma_(p_gamma),
240 p_mean_(p_mean),
241 p_invStd_(p_invStd),
242 p_dx_(p_dx)
243 {
250 shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);
252
254
256
258
267
269 }
270
271 const DYDataType* p_dy_;
272 const XDataType* p_x_;
273 const GammaDataType* p_gamma_;
274 const MeanInvStdDataType* p_mean_;
275 const MeanInvStdDataType* p_invStd_;
276 DXDataType* p_dx_;
277
278 std::vector<index_t> lengths_;
279 std::vector<index_t> dyStrides_;
280 std::vector<index_t> xStrides_;
281 std::vector<index_t> gammaStrides_;
282 std::vector<index_t> meanStrides_;
283 std::vector<index_t> invStdStrides_;
284 std::vector<index_t> dxStrides_;
285
287 size_t gridSize_;
288
289 // tensor descriptor
296
298 index_t MRaw_; // Invariant length
299 index_t KRaw_; // reduce length
300 };
301
302 struct Invoker : public BaseInvoker
303 {
304 auto KernelSelector(bool isSweepOnce)
305 {
306 return isSweepOnce
308 DYDataType,
309 XDataType,
310 GammaDataType,
311 MeanInvStdDataType,
312 DXDataType,
315 DYDataType,
316 XDataType,
317 GammaDataType,
318 MeanInvStdDataType,
319 DXDataType,
321 }
322
323 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
324 {
325 const auto kernel_main = KernelSelector(arg.isSweeponce_);
326
327 return launch_and_time_kernel(stream_config,
328 kernel_main,
329 dim3(arg.gridSize_),
330 dim3(BlockSize),
331 0,
339 arg.p_dy_,
340 arg.p_x_,
341 arg.p_gamma_,
342 arg.p_mean_,
343 arg.p_invStd_,
344 arg.p_dx_);
345 }
346
// Type-erased entry point: downcasts p_arg to Argument and forwards it,
// together with stream_config, to the typed Run overload; returns that
// overload's result unchanged.
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so passing a null pointer or a BaseArgument that is not an Argument is
// undefined behaviour.
347 float Run(const BaseArgument* p_arg,
348 const StreamConfig& stream_config = StreamConfig{}) override
349 {
350 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
351 }
352 };
353
// Checks whether vector access of width SrcVectorSize along SrcVectorDim is
// possible for a tensor with the given lengths and strides.
//
//  - SrcVectorSize == 1: always valid.
//  - SrcVectorDim == 0 (fastest dimension is not reduced): the last invariant
//    dimension, index NumInvariantDim - 1, must have stride 1 and a length
//    divisible by SrcVectorSize. Always invalid when NumInvariantDim == 0.
//  - otherwise (fastest dimension is reduced): the last dimension, index
//    Rank - 1, must satisfy the same two conditions.
//
// Both vectors are indexed without bounds checks, so they must be long enough
// for the index used.
// NOTE(review): the indexing presumably expects the dimensions already
// reordered as invariant-first, reduce-last -- confirm against the caller.
354 template <index_t SrcVectorDim, index_t SrcVectorSize>
355 bool IsVectorDimSizeValid(const std::vector<index_t>& lengths,
356 const std::vector<index_t>& strides)
357 {
358 if constexpr(SrcVectorSize == 1)
359 return true;
360
361 // Fastest dimension is not reduced
362 if constexpr(SrcVectorDim == 0)
363 {
364 if constexpr(NumInvariantDim == 0)
365 return false;
366
367 if(strides[NumInvariantDim - 1] != 1)
368 return false;
369
370 if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0)
371 return false;
372 }
373 else // Fastest dimension is reduced
374 {
375 if(strides[Rank - 1] != 1)
376 return false;
377
378 if(lengths[Rank - 1] % SrcVectorSize != 0)
379 return false;
380 };
381
382 return true;
383 }
384
385 bool IsSupportedArgument(const BaseArgument* p_arg) override
386 {
387 const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
388
389 bool pass = true;
391 p_arg_->dyStrides_);
393 p_arg_->xStrides_);
395 p_arg_->gammaStrides_);
397 p_arg_->lengths_, p_arg_->meanStrides_);
399 p_arg_->lengths_, p_arg_->invStdStrides_);
400
402 p_arg_->dxStrides_);
403 return pass;
404 }
405
// Type-erased factory for Argument.
//
// Throws std::runtime_error unless lengths and all six stride vectors have
// exactly Rank entries. reduceDims is not size-checked here.
// The untyped buffers are static_cast to their element types (p_mean and
// p_invStd both to MeanInvStdDataType, p_dx to a non-const DXDataType) and
// everything is forwarded, in order, to the Argument constructor.
//
// NOTE(review): the vectors are taken by const value, so each call copies them;
// the signature has to match the virtual function being overridden.
406 std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,
407 const std::vector<index_t> dyStrides,
408 const std::vector<index_t> xStrides,
409 const std::vector<index_t> gammaStrides,
410 const std::vector<index_t> meanStrides,
411 const std::vector<index_t> invStdStrides,
412 const std::vector<index_t> dxStrides,
413 const std::vector<index_t> reduceDims,
414 const void* p_dy,
415 const void* p_x,
416 const void* p_gamma,
417 const void* p_mean,
418 const void* p_invStd,
419 void* p_dx) override
420 {
421 if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank ||
422 gammaStrides.size() != Rank || meanStrides.size() != Rank ||
423 invStdStrides.size() != Rank || dxStrides.size() != Rank)
424 throw std::runtime_error("dimension is incorrect");
425
426 return std::make_unique<Argument>(lengths,
427 dyStrides,
428 xStrides,
429 gammaStrides,
430 meanStrides,
431 invStdStrides,
432 dxStrides,
433 reduceDims,
434 static_cast<const DYDataType*>(p_dy),
435 static_cast<const XDataType*>(p_x),
436 static_cast<const GammaDataType*>(p_gamma),
437 static_cast<const MeanInvStdDataType*>(p_mean),
438 static_cast<const MeanInvStdDataType*>(p_invStd),
439 static_cast<DXDataType*>(p_dx));
440 }
441
// Returns a newly created, default-constructed Invoker, handed to the caller
// as a unique_ptr to its BaseInvoker base.
442 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
443 {
444 return std::make_unique<Invoker>();
445 }
446
// Returns a readable identifier for this template instance, built from its
// compile-time parameters in this order: BlockSize, the M/K thread cluster
// sizes, the M/K thread slice sizes, then the DY, X, Gamma, MeanInvStd and DX
// vector sizes. Data types, Rank and NumReduceDim are not part of the string.
// The MeanInvStd vector size is emitted under the label MeanRstd.
447 std::string GetTypeString() const override
448 {
449 auto str = std::stringstream();
450
451 // clang-format off
452 str << "DeviceNormalizationBwdDataImpl<" << BlockSize << ",";
453 str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
454 str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
455 str << "DYSrcVectorSize" << DYSrcVectorSize << "_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_MeanRstd" << MeanInvStdSrcVectorSize << "_Dx" << DXDstVectorSize;
456 str << ">";
457 // clang-format on
458
459 return str.str();
460 }
461};
462
463} // namespace device
464} // namespace tensor_operation
465} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition device_reduce_common.hpp:65
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition device_reduce_common.hpp:59
__global__ void kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k, const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K mean_grid_desc_m_k, const GridDesc_M_K inv_std_grid_desc_m_k, const GridDesc_M_K dx_grid_desc_m_k, index_t num_k_block_tile_iteration, const DYDataType *const __restrict__ p_dy_global, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const MeanInvStdDataType *const __restrict__ p_mean_global, const MeanInvStdDataType *const __restrict__ p_inv_std_global, DXDataType *const __restrict__ p_dx_global)
Definition device_normalization_bwd_data_impl.hpp:29
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition ck/stream_config.hpp:10
Definition gridwise_normalization_bwd_data.hpp:49
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition device_base.hpp:197
Definition device_normalization_bwd_data.hpp:22
Definition device_normalization_bwd_data_impl.hpp:222
DXDataType * p_dx_
Definition device_normalization_bwd_data_impl.hpp:276
GridDesc_M_K mean_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:293
const XDataType * p_x_
Definition device_normalization_bwd_data_impl.hpp:272
index_t MRaw_
Definition device_normalization_bwd_data_impl.hpp:298
const MeanInvStdDataType * p_mean_
Definition device_normalization_bwd_data_impl.hpp:274
GridDesc_M_K gamma_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:292
const DYDataType * p_dy_
Definition device_normalization_bwd_data_impl.hpp:271
GridDesc_M_K inv_std_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:294
Argument(const std::vector< index_t > lengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > dxStrides, const std::vector< index_t > reduceDims, const DYDataType *p_dy, const XDataType *p_x, const GammaDataType *p_gamma, const MeanInvStdDataType *p_mean, const MeanInvStdDataType *p_invStd, DXDataType *p_dx)
Definition device_normalization_bwd_data_impl.hpp:223
int numBlockTileIteration_
Definition device_normalization_bwd_data_impl.hpp:286
GridDesc_M_K dy_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:290
const MeanInvStdDataType * p_invStd_
Definition device_normalization_bwd_data_impl.hpp:275
std::vector< index_t > gammaStrides_
Definition device_normalization_bwd_data_impl.hpp:281
bool isSweeponce_
Definition device_normalization_bwd_data_impl.hpp:297
std::vector< index_t > lengths_
Definition device_normalization_bwd_data_impl.hpp:278
std::vector< index_t > invStdStrides_
Definition device_normalization_bwd_data_impl.hpp:283
size_t gridSize_
Definition device_normalization_bwd_data_impl.hpp:287
GridDesc_M_K dx_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:295
std::vector< index_t > dxStrides_
Definition device_normalization_bwd_data_impl.hpp:284
std::vector< index_t > xStrides_
Definition device_normalization_bwd_data_impl.hpp:280
const GammaDataType * p_gamma_
Definition device_normalization_bwd_data_impl.hpp:273
index_t KRaw_
Definition device_normalization_bwd_data_impl.hpp:299
std::vector< index_t > meanStrides_
Definition device_normalization_bwd_data_impl.hpp:282
std::vector< index_t > dyStrides_
Definition device_normalization_bwd_data_impl.hpp:279
GridDesc_M_K x_grid_desc_m_k_
Definition device_normalization_bwd_data_impl.hpp:291
Definition device_normalization_bwd_data_impl.hpp:303
auto KernelSelector(bool isSweepOnce)
Definition device_normalization_bwd_data_impl.hpp:304
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_normalization_bwd_data_impl.hpp:323
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_normalization_bwd_data_impl.hpp:347
Definition device_normalization_bwd_data_impl.hpp:88
static constexpr index_t MeanInvStdSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:92
static constexpr index_t K_BlockTileSize
Definition device_normalization_bwd_data_impl.hpp:122
static constexpr index_t NumInvariantDim
Definition device_normalization_bwd_data_impl.hpp:120
GridwiseNormalizationBwdData_mk_to_mk< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, GridDesc_M_K, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, DYSrcVectorDim, DYSrcVectorSize, XSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize, DXDstVectorDim, DXDstVectorSize, true > GridwiseNormalizationBwdDataSweepOnce
Definition device_normalization_bwd_data_impl.hpp:196
static constexpr index_t XSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:90
static constexpr index_t DXDstVectorDim
Definition device_normalization_bwd_data_impl.hpp:93
static constexpr index_t M_BlockTileSize
Definition device_normalization_bwd_data_impl.hpp:121
std::string GetTypeString() const override
Definition device_normalization_bwd_data_impl.hpp:447
static constexpr index_t DYSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:89
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_normalization_bwd_data_impl.hpp:442
static constexpr bool reduceAllDim
Definition device_normalization_bwd_data_impl.hpp:124
static auto Make2dDescriptor(const std::vector< index_t > &lengths, const std::vector< index_t > &strides, int numBlockTileIteration)
Definition device_normalization_bwd_data_impl.hpp:127
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::vector< index_t > dyStrides, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > meanStrides, const std::vector< index_t > invStdStrides, const std::vector< index_t > dxStrides, const std::vector< index_t > reduceDims, const void *p_dy, const void *p_x, const void *p_gamma, const void *p_mean, const void *p_invStd, void *p_dx) override
Definition device_normalization_bwd_data_impl.hpp:406
decltype(Make2dDescriptor({1}, {1}, 1)) GridDesc_M_K
Definition device_normalization_bwd_data_impl.hpp:169
GridwiseNormalizationBwdData_mk_to_mk< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, ComputeDataType, DXDataType, GridDesc_M_K, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, DYSrcVectorDim, DYSrcVectorSize, XSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize, DXDstVectorDim, DXDstVectorSize, false > GridwiseNormalizationBwdDataGeneric
Definition device_normalization_bwd_data_impl.hpp:171
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_normalization_bwd_data_impl.hpp:385
static constexpr index_t GammaSrcVectorDim
Definition device_normalization_bwd_data_impl.hpp:91
bool IsVectorDimSizeValid(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_normalization_bwd_data_impl.hpp:355