gridwise_multiblock_batchnorm_forward.hpp Source File

gridwise_multiblock_batchnorm_forward.hpp Source File

Composable Kernel: gridwise_multiblock_batchnorm_forward.hpp Source File
gridwise_multiblock_batchnorm_forward.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
14
15namespace ck {
16
// Kernel entry point for the multi-block batchnorm forward pass.
// The body does nothing except forward every argument, in the same order, to
// GridwiseMultiblockBatchNormForward_::Run.
// The grid descriptors and the reduce-count functor are received by value here.
// NOTE(review): the declaration line of this function (listing line 31) is absent from this
// listing; the reference index at the bottom of the page names it
// `__global__ void kernel_multiblock_batchnorm_forward(`.
17template <typename GridwiseMultiblockBatchNormForward_,
18 typename XDataType,
19 typename YDataType,
20 typename AccDataType,
21 typename ScaleDataType,
22 typename BiasDataType,
23 typename MeanVarDataType,
24 typename YElementwiseOp,
25 typename XYGridDesc_M_K,
26 typename MeanVarCountGridDesc_M_G,
27 typename MeanVarCountGridDesc_M_K,
28 typename ScaleBiasGridDesc_M,
29 typename MeanVarGridDesc_M,
30 typename GetReduceCountPerThreadFunctor>
32 const XYGridDesc_M_K x_grid_desc_m_k,
33 const XYGridDesc_M_K y_grid_desc_m_k,
34 const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
35 const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
36 const ScaleBiasGridDesc_M scale_grid_desc_m,
37 const ScaleBiasGridDesc_M bias_grid_desc_m,
38 const MeanVarGridDesc_M mean_var_grid_desc_m,
39 const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
40 index_t num_k_block_tile_iteration,
41 AccDataType epsilon,
42 const XDataType* const __restrict__ p_x,
43 MeanVarDataType* const __restrict__ p_welford_mean,
44 MeanVarDataType* const __restrict__ p_welford_variance,
45 int32_t* const __restrict__ p_welford_count,
46 int32_t* const __restrict__ p_control,
47 const ScaleDataType* const __restrict__ p_scale,
48 const BiasDataType* const __restrict__ p_bias,
49 const YElementwiseOp y_elementwise_op,
50 YDataType* const __restrict__ p_y,
51 bool updateMovingAverage,
52 AccDataType averageFactor,
53 MeanVarDataType* const __restrict__ resultRunningMean,
54 MeanVarDataType* const __restrict__ resultRunningVariance,
55 bool saveMeanInvVariance,
56 MeanVarDataType* const __restrict__ resultSaveMean,
57 MeanVarDataType* const __restrict__ resultSaveInvVariance)
58{
// Straight pass-through: no argument is modified or reordered before the call.
59 GridwiseMultiblockBatchNormForward_::Run(x_grid_desc_m_k,
60 y_grid_desc_m_k,
61 mean_var_count_grid_desc_m_g,
62 mean_var_count_grid_desc_m_k,
63 scale_grid_desc_m,
64 bias_grid_desc_m,
65 mean_var_grid_desc_m,
66 get_reduce_count_per_thread,
67 num_k_block_tile_iteration,
68 epsilon,
69 p_x,
70 p_welford_mean,
71 p_welford_variance,
72 p_welford_count,
73 p_control,
74 p_scale,
75 p_bias,
76 y_elementwise_op,
77 p_y,
78 updateMovingAverage,
79 averageFactor,
80 resultRunningMean,
81 resultRunningVariance,
82 saveMeanInvVariance,
83 resultSaveMean,
84 resultSaveInvVariance);
85};
86
87template <typename XDataType,
88 typename YDataType,
89 typename AccDataType,
90 typename ScaleDataType,
91 typename BiasDataType,
92 typename MeanVarDataType,
93 typename YElementwiseOp,
94 typename XYGridDesc_M_K,
95 typename MeanVarCountGridDesc_M_G,
96 typename MeanVarCountGridDesc_M_K,
97 typename ScaleBiasGridDesc_M,
98 typename MeanVarGridDesc_M,
99 typename GetReduceCountPerThreadFunctor,
100 index_t BlockSize,
101 index_t MThreadClusterSize,
102 index_t KThreadClusterSize,
103 index_t MThreadSliceSize,
104 index_t KThreadSliceSize,
105 index_t XSrcYDstVectorDim,
106 index_t XSrcVectorSize,
107 index_t YDstVectorSize,
108 index_t ScaleSrcVectorSize,
109 index_t BiasSrcVectorSize,
110 index_t MeanVarSrcDstVectorSize>
112{
113 static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
114 (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
115 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
116
117 static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
118 (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
119 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
120
121 static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
122
124
127
130
131 static constexpr auto thread_cluster_desc =
133
138
141
144
147
149 BlockSize,
152 false>;
153
155 BlockSize,
158 true>;
159
161
162 static constexpr auto I0 = Number<0>{};
163 static constexpr auto I1 = Number<1>{};
164
165 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
166 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
167
// Run: per-thread body of the multi-block batchnorm forward pass.
//
// Work-groups are organised in groups of blkgroup_size, which is the length of dimension 1 of
// mean_var_count_grid_desc_m_g. blkgroup_id selects the M tile (blkgroup_id * M_BlockTileSize)
// and block_local_id selects the K range (block_local_id * reduceSizePerBlock) that this
// work-group processes.
//
// Steps (see the inline "Step N" markers):
//   1. local Welford reduction over this work-group's K range of x
//   2. partial mean / variance / count written to the workspace (p_welford_*)
//   3. after gms_barrier, the partial results are read back and merged
//   4. y = x * multiplier + fused_mean_bias is computed and stored; y_elementwise_op is handed
//      to the y store object
//   5. optional running mean / variance update (updateMovingAverage)
//   6. optional save of mean and 1/sqrt(epsilon + variance) (saveMeanInvVariance)
//
// p_control is indexed at blkgroup_id * 2 and passed to gms_init / gms_barrier / gms_reset.
// NOTE(review): those helpers are defined elsewhere; presumably they implement a barrier
// across the work-groups of one group, which would require all of them to be resident at the
// same time -- confirm against their definition.
// NOTE(review): several declarations of this function (thread buffers, some transfer type
// names, static_for loop heads) are absent from this listing; the comments below describe
// only what is visible.
168 __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
169 const XYGridDesc_M_K& y_grid_desc_m_k,
170 const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
171 const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
172 const ScaleBiasGridDesc_M& scale_grid_desc_m,
173 const ScaleBiasGridDesc_M& bias_grid_desc_m,
174 const MeanVarGridDesc_M& mean_var_grid_desc_m,
175 const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
176 index_t num_k_block_tile_iteration,
177 AccDataType epsilon,
178 const XDataType* const __restrict__ p_x,
179 MeanVarDataType* const __restrict__ p_welford_mean,
180 MeanVarDataType* const __restrict__ p_welford_variance,
181 int32_t* const __restrict__ p_welford_count,
182 int32_t* const __restrict__ p_control,
183 const ScaleDataType* const __restrict__ p_scale,
184 const BiasDataType* const __restrict__ p_bias,
185 const YElementwiseOp y_elementwise_op,
186 YDataType* const __restrict__ p_y,
187 bool updateMovingAverage,
188 AccDataType averageFactor,
189 MeanVarDataType* const __restrict__ resultRunningMean,
190 MeanVarDataType* const __restrict__ resultRunningVariance,
191 bool saveMeanInvVariance,
192 MeanVarDataType* const __restrict__ resultSaveMean,
193 MeanVarDataType* const __restrict__ resultSaveInvVariance)
194 {
195 using ck::math::sqrt;
196
197 const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
198
199 const index_t thread_local_id = get_thread_local_1d_id();
200 const index_t block_global_id = get_block_1d_id();
201 const index_t blkgroup_id = block_global_id / blkgroup_size;
202 const index_t block_local_id = block_global_id % blkgroup_size;
203
// Only the first work-group of each group calls gms_init on that group's control slot.
204 if(block_local_id == 0)
205 gms_init(BlockSize / WarpSize * blkgroup_size, &p_control[blkgroup_id * 2]);
206
207 const auto thread_cluster_idx =
208 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
209
210 const auto thread_m_cluster_id = thread_cluster_idx[I0];
211 const auto thread_k_cluster_id = thread_cluster_idx[I1];
212
213 using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
214 using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
215 using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
216
217 constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
219 constexpr auto thread_buffer_desc_m =
221 constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
223
225 x_thread_buf;
226
230
232 tmp_mean_thread_buf;
234 tmp_var_thread_buf;
236
// Number of K elements covered by one work-group across all of its tile iterations.
237 const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
238
239 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
240 AccDataType,
241 XYGridDesc_M_K,
242 decltype(thread_buffer_desc_m_k),
243 ThreadBufferLengths_M_K,
245 XSrcYDstVectorDim,
246 XSrcVectorSize,
247 1,
248 true>(
249 x_grid_desc_m_k,
250 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
251 block_local_id * reduceSizePerBlock +
252 thread_k_cluster_id * KThreadSliceSize));
253
254 constexpr auto xy_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
255
256 const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
257 p_x, x_grid_desc_m_k.GetElementSpaceSize());
258
259 // Step 1: each workgroup does local welford reduction
260
261 auto threadwise_welford_1 = ThreadwiseWelford1();
// max_count_ comes from the caller-supplied functor; presumably the number of valid
// elements this thread should accumulate -- confirm against ThreadwiseWelford.
262 threadwise_welford_1.max_count_ =
263 get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
264
266 mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
267 var_thread_buf(I) = type_convert<AccDataType>(0.0f);
268 });
269
270 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
271 {
272 threadwise_x_load.Run(x_grid_desc_m_k,
273 x_global_val_buf,
274 thread_buffer_desc_m_k,
275 make_tuple(I0, I0),
276 x_thread_buf);
277
278 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
279 threadwise_welford_1.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
280 }
281
283 if constexpr(I > 0)
285
286 count_thread_buf(I) = threadwise_welford_1.cur_count_;
287 BlockwiseWelford1::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
288 });
289
290 // Step 2: each workgroup writes its local welford result to workspace memory
291
292 auto mean_global_val_buf =
294 p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
295
296 auto var_global_val_buf =
298 p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
299
300 auto count_global_val_buf =
302 p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
303
304 auto threadwise_mean_var_store_m_g =
306 MeanVarDataType,
307 decltype(thread_buffer_desc_m_1),
308 MeanVarCountGridDesc_M_G,
310 ThreadBufferLengths_M_1,
312 0,
313 1,
315 1,
316 true>(
317 mean_var_count_grid_desc_m_g,
318 make_multi_index(blkgroup_id * M_BlockTileSize +
319 thread_m_cluster_id * MThreadSliceSize,
320 block_local_id),
321 PassThroughOp{});
322
323 auto threadwise_count_store_m_g =
325 int32_t,
326 decltype(thread_buffer_desc_m_1),
327 MeanVarCountGridDesc_M_G,
329 ThreadBufferLengths_M_1,
331 0,
332 1,
334 1,
335 true>(
336 mean_var_count_grid_desc_m_g,
337 make_multi_index(blkgroup_id * M_BlockTileSize +
338 thread_m_cluster_id * MThreadSliceSize,
339 block_local_id),
340 PassThroughOp{});
341
// Only threads with thread_k_cluster_id == 0 write the partial results; each work-group
// writes to column block_local_id of the M x G workspace.
342 if(thread_k_cluster_id == 0)
343 {
344 threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
345 make_tuple(I0, I0),
346 mean_thread_buf,
347 mean_var_count_grid_desc_m_g,
348 mean_global_val_buf);
349
350 threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
351 make_tuple(I0, I0),
352 var_thread_buf,
353 mean_var_count_grid_desc_m_g,
354 var_global_val_buf);
355
356 threadwise_count_store_m_g.Run(thread_buffer_desc_m_1,
357 make_tuple(I0, I0),
358 count_thread_buf,
359 mean_var_count_grid_desc_m_g,
360 count_global_val_buf);
361 };
362
// NOTE(review): Step 3 reads what other work-groups stored in Step 2, so gms_barrier is
// expected to order those stores before the loads below -- confirm it includes the
// required memory fence.
363 gms_barrier(&p_control[blkgroup_id * 2]);
364
365 if(block_local_id == 0)
366 gms_reset(&p_control[blkgroup_id * 2]);
367
368 // Step 3: each workgroup reads welford results from workspace memory and does final welford
369 // reduction
370
371 auto threadwise_mean_var_load_m_k =
372 ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
373 AccDataType,
374 MeanVarCountGridDesc_M_K,
375 decltype(thread_buffer_desc_m_1),
376 ThreadBufferLengths_M_1,
378 0,
379 1,
380 1,
381 true>(
382 mean_var_count_grid_desc_m_k,
383 make_multi_index(blkgroup_id * M_BlockTileSize +
384 thread_m_cluster_id * MThreadSliceSize,
385 thread_k_cluster_id * 1));
386
387 auto threadwise_count_load_m_k =
389 int32_t,
390 MeanVarCountGridDesc_M_K,
391 decltype(thread_buffer_desc_m_1),
392 ThreadBufferLengths_M_1,
394 0,
395 1,
396 1,
397 true>(
398 mean_var_count_grid_desc_m_k,
399 make_multi_index(blkgroup_id * M_BlockTileSize +
400 thread_m_cluster_id * MThreadSliceSize,
401 thread_k_cluster_id * 1));
402
404 mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
405 var_thread_buf(I) = type_convert<AccDataType>(0.0f);
406 count_thread_buf(I) = 0;
407 });
408
409 constexpr auto mean_var_count_read_fwd_step_m_k = make_multi_index(0, KThreadClusterSize);
410
// Each thread starts at column thread_k_cluster_id and advances by KThreadClusterSize per
// iteration, merging one partial result per M row each time.
// NOTE(review): when blkgroup_size is not a multiple of KThreadClusterSize the last
// iteration addresses columns >= blkgroup_size; presumably mean_var_count_grid_desc_m_k is
// padded so those reads are handled safely -- confirm against the descriptor construction.
411 int32_t reducedSize = 0;
412 while(reducedSize < blkgroup_size)
413 {
414 threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
415 mean_global_val_buf,
416 thread_buffer_desc_m_1,
417 make_tuple(I0, I0),
418 tmp_mean_thread_buf);
419
420 threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
421 var_global_val_buf,
422 thread_buffer_desc_m_1,
423 make_tuple(I0, I0),
424 tmp_var_thread_buf);
425
426 threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
427 count_global_val_buf,
428 thread_buffer_desc_m_1,
429 make_tuple(I0, I0),
430 tmp_count_thread_buf);
431
432 ThreadwiseWelford2::Run(tmp_mean_thread_buf,
433 tmp_var_thread_buf,
434 tmp_count_thread_buf,
435 mean_thread_buf,
436 var_thread_buf,
437 count_thread_buf);
438
439 reducedSize += KThreadClusterSize;
440
441 threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
442 mean_var_count_read_fwd_step_m_k);
443 threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
444 mean_var_count_read_fwd_step_m_k);
445 };
446
448 if constexpr(I > 0)
450
451 BlockwiseWelford2::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
452 });
453
454 // Step 4: do normalization using the mean/variance
455
457
459
461 y_thread_buf;
462
463 auto threadwise_y_store =
465 YDataType,
466 decltype(thread_buffer_desc_m_k),
467 XYGridDesc_M_K,
468 YElementwiseOp,
469 ThreadBufferLengths_M_K,
471 XSrcYDstVectorDim,
472 YDstVectorSize,
474 1,
475 true>(
476 y_grid_desc_m_k,
478 blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
479 block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
480 y_elementwise_op);
481
482 auto threadwise_scale_load =
484 AccDataType,
485 ScaleBiasGridDesc_M,
486 decltype(thread_buffer_desc_m),
487 ThreadBufferLengths_M,
489 0,
490 ScaleSrcVectorSize,
491 1,
492 true>(
493 scale_grid_desc_m,
494 make_multi_index(blkgroup_id * M_BlockTileSize +
495 thread_m_cluster_id * MThreadSliceSize));
496
497 auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
498 AccDataType,
499 ScaleBiasGridDesc_M,
500 decltype(thread_buffer_desc_m),
501 ThreadBufferLengths_M,
503 0,
504 BiasSrcVectorSize,
505 1,
506 true>(
507 bias_grid_desc_m,
508 make_multi_index(blkgroup_id * M_BlockTileSize +
509 thread_m_cluster_id * MThreadSliceSize));
510
511 const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
512 p_scale, scale_grid_desc_m.GetElementSpaceSize());
513
514 const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
515 p_bias, bias_grid_desc_m.GetElementSpaceSize());
516
517 auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
518 p_y, y_grid_desc_m_k.GetElementSpaceSize());
519
520 threadwise_scale_load.Run(scale_grid_desc_m,
521 scale_global_val_buf,
522 thread_buffer_desc_m,
523 make_tuple(I0),
524 scale_thread_buf);
525
526 threadwise_bias_load.Run(bias_grid_desc_m,
527 bias_global_val_buf,
528 thread_buffer_desc_m,
529 make_tuple(I0),
530 bias_thread_buf);
531
// Step 1 advanced the x window tile by tile; rewind it to this thread's original origin
// before the second pass over x.
532 threadwise_x_load.SetSrcSliceOrigin(
533 x_grid_desc_m_k,
534 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
535 block_local_id * reduceSizePerBlock +
536 thread_k_cluster_id * KThreadSliceSize));
537
538 for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
539 {
540 threadwise_x_load.Run(x_grid_desc_m_k,
541 x_global_val_buf,
542 thread_buffer_desc_m_k,
543 make_tuple(I0, I0),
544 x_thread_buf);
545
// Per M row: multiplier = scale / sqrt(var + epsilon) and
// fused_mean_bias = bias - mean * multiplier, so each element needs one multiply-add.
547 AccDataType multiplier =
548 scale_thread_buf[Number<iM>{}] / sqrt(var_thread_buf[iM] + epsilon);
549
550 AccDataType fused_mean_bias =
551 bias_thread_buf[Number<iM>{}] - mean_thread_buf[iM] * multiplier;
552
554 constexpr auto offset =
555 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
556
557 // normalize
558 y_thread_buf(Number<offset>{}) =
559 x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
560 });
561 });
562
563 threadwise_y_store.Run(thread_buffer_desc_m_k,
564 make_tuple(I0, I0),
565 y_thread_buf,
566 y_grid_desc_m_k,
567 y_global_val_buf);
568
// The x and y windows advance together by one K block tile.
569 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
570 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_copy_fwd_step_m_k);
571 }
572
573 // Step 5: update the moving average of mean and variance (optional)
574
// Only threads with thread_k_cluster_id == 0 in the first work-group of each group perform
// the update.
575 if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
576 {
578 running_mean_thread_buf;
580 running_var_thread_buf;
581
582 auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
583 resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
584
585 auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
586 resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());
587
588 auto threadwise_mean_var_load =
589 ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
590 AccDataType,
591 MeanVarGridDesc_M,
592 decltype(thread_buffer_desc_m),
593 ThreadBufferLengths_M,
595 0,
596 MeanVarSrcDstVectorSize,
597 1,
598 true>(
599 mean_var_grid_desc_m,
600 make_multi_index(blkgroup_id * M_BlockTileSize +
601 thread_m_cluster_id * MThreadSliceSize));
602
603 threadwise_mean_var_load.Run(mean_var_grid_desc_m,
604 running_mean_global_buf,
605 thread_buffer_desc_m,
606 make_tuple(I0),
607 running_mean_thread_buf);
608
609 threadwise_mean_var_load.Run(mean_var_grid_desc_m,
610 running_var_global_buf,
611 thread_buffer_desc_m,
612 make_tuple(I0),
613 running_var_thread_buf);
614
// NOTE(review): 1.0 is a double literal here, while the rest of this function uses
// float literals (0.0f, 1.0f) as type_convert arguments.
615 AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
616
// running = running * (1 - averageFactor) + batch_value * averageFactor
618 running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
619 mean_thread_buf[I] * averageFactor;
620 running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
621 var_thread_buf[I] * averageFactor;
622 });
623
624 auto threadwise_mean_var_store =
626 MeanVarDataType,
627 decltype(thread_buffer_desc_m),
628 MeanVarGridDesc_M,
630 ThreadBufferLengths_M,
632 0,
633 MeanVarSrcDstVectorSize,
635 1,
636 true>(
637 mean_var_grid_desc_m,
638 make_multi_index(blkgroup_id * M_BlockTileSize +
639 thread_m_cluster_id * MThreadSliceSize),
640 PassThroughOp{});
641
642 threadwise_mean_var_store.Run(thread_buffer_desc_m,
643 make_tuple(I0),
644 running_mean_thread_buf,
645 mean_var_grid_desc_m,
646 running_mean_global_buf);
647
648 threadwise_mean_var_store.Run(thread_buffer_desc_m,
649 make_tuple(I0),
650 running_var_thread_buf,
651 mean_var_grid_desc_m,
652 running_var_global_buf);
653 };
654
655 // Step 6: save mean and inv-variance (optional)
656
657 if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
658 {
659 auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
660 resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
661
662 auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
663 resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
664
665 // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
// var_thread_buf is overwritten here, so this step has to stay after Steps 4 and 5,
// which both read the variance.
667 var_thread_buf(I) =
668 type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
669 });
670
671 auto threadwise_mean_inv_var_store =
673 MeanVarDataType,
674 decltype(thread_buffer_desc_m),
675 MeanVarGridDesc_M,
677 ThreadBufferLengths_M,
679 0,
680 MeanVarSrcDstVectorSize,
682 1,
683 true>(
684 mean_var_grid_desc_m,
685 make_multi_index(blkgroup_id * M_BlockTileSize +
686 thread_m_cluster_id * MThreadSliceSize),
687 PassThroughOp{});
688
689 threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
690 make_tuple(I0),
691 mean_thread_buf,
692 mean_var_grid_desc_m,
693 result_mean_global_buf);
694
695 threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
696 make_tuple(I0),
697 var_thread_buf,
698 mean_var_grid_desc_m,
699 result_inv_var_global_buf);
700 };
701 }
702}; // end of the gridwise struct opened at listing line 112 (namespace ck closes below)
703
704} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__global__ void kernel_multiblock_batchnorm_forward(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, MeanVarDataType *const __restrict__ p_welford_mean, MeanVarDataType *const __restrict__ p_welford_variance, int32_t *const __restrict__ p_welford_count, int32_t *const __restrict__ p_control, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_multiblock_batchnorm_forward.hpp:31
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
signed int int32_t
Definition stdint.h:123
Definition blockwise_welford.hpp:25
static __device__ void Run(AccDataType &mean_value, AccDataType &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_multiblock_batchnorm_forward.hpp:112
ThreadwiseWelfordMerge< AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M > ThreadwiseWelford2
Definition gridwise_multiblock_batchnorm_forward.hpp:145
ThreadwiseWelford< AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M > ThreadwiseWelford1
Definition gridwise_multiblock_batchnorm_forward.hpp:142
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition gridwise_multiblock_batchnorm_forward.hpp:134
BlockwiseWelford< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, true > BlockwiseWelford2
Definition gridwise_multiblock_batchnorm_forward.hpp:154
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_multiblock_batchnorm_forward.hpp:136
static constexpr auto I0
Definition gridwise_multiblock_batchnorm_forward.hpp:162
static constexpr bool reorder_thread_cluster
Definition gridwise_multiblock_batchnorm_forward.hpp:121
static constexpr index_t K_BlockTileSize
Definition gridwise_multiblock_batchnorm_forward.hpp:166
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_multiblock_batchnorm_forward.hpp:123
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_multiblock_batchnorm_forward.hpp:128
static constexpr auto I1
Definition gridwise_multiblock_batchnorm_forward.hpp:163
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_multiblock_batchnorm_forward.hpp:160
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_multiblock_batchnorm_forward.hpp:125
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< 1 >{}))) ThreadReduceSrcDesc_M_1
Definition gridwise_multiblock_batchnorm_forward.hpp:139
BlockwiseWelford< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder, false > BlockwiseWelford1
Definition gridwise_multiblock_batchnorm_forward.hpp:148
static constexpr auto thread_cluster_desc
Definition gridwise_multiblock_batchnorm_forward.hpp:131
static constexpr index_t M_BlockTileSize
Definition gridwise_multiblock_batchnorm_forward.hpp:165
static __device__ void Run(const XYGridDesc_M_K &x_grid_desc_m_k, const XYGridDesc_M_K &y_grid_desc_m_k, const MeanVarCountGridDesc_M_G &mean_var_count_grid_desc_m_g, const MeanVarCountGridDesc_M_K &mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M &scale_grid_desc_m, const ScaleBiasGridDesc_M &bias_grid_desc_m, const MeanVarGridDesc_M &mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor &get_reduce_count_per_thread, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, MeanVarDataType *const __restrict__ p_welford_mean, MeanVarDataType *const __restrict__ p_welford_variance, int32_t *const __restrict__ p_welford_count, int32_t *const __restrict__ p_control, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_multiblock_batchnorm_forward.hpp:168
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition threadwise_welford.hpp:18
Definition threadwise_welford.hpp:83
static __device__ void Run(const SrcMeanBufferType &src_mean_buf, const SrcVarBufferType &src_var_buf, const SrcCountBufferType &src_count_buf, DstMeanBufferType &dst_mean_buf, DstVarBufferType &dst_var_buf, DstCountBufferType &dst_count_buf)
Definition threadwise_welford.hpp:110
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340