gridwise_batchnorm_backward_blockwise_welford.hpp File Reference
gridwise_batchnorm_backward_blockwise_welford.hpp File Reference
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

Go to the source code of this file.
Namespaces | |
| namespace | ck |
Functions | |
| template<typename GridwiseBatchrNormBackwardWithBlockwiseWelford_, typename XDataType, typename DyDataType, typename DxDataType, typename AccDataType, typename ScaleDataType, typename DscaleDbiasDataType, typename MeanVarDataType, typename DyElementwiseOp, typename XYGridDesc_M_K, typename ScaleBiasGridDesc_M, typename MeanVarGridDesc_M, typename GetReduceCountPerThreadFunctor> | |
| __global__ void | ck::kernel_batchnorm_backward_with_blockwise_welford (const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, long_index_t reduce_size, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias) |