reduction_operator.hpp Source File

reduction_operator.hpp Source File#

Composable Kernel: reduction_operator.hpp Source File
reduction_operator.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck/ck.hpp"
8#include "ck/utility/type.hpp"
10
11namespace ck {
12
13namespace reduce {
14
// Every binary operator used in reduction is represented by a templated functor class. Each
// functor class must provide at least three members:
// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary
//    operator. The "identity element" is the unique element in the algebraic space that doesn't
//    affect the value of other elements when operated against them; the concept is similar to
//    the zero vector in a vector space
//    (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf).
// 2) IsCompatibleInMemoryDataOperation() -- returns true if the reduction task corresponding to
//    this operator can use the InMemoryDataOperation to finalize; otherwise it returns false.
// 3) operator() -- the first argument of the operator must be both an input and an output, and
//    the corresponding variable usually stores the accumulated result of many operator() calls;
//    the second argument is only an input. For an indexable binary operator, the second version
//    of operator() has a third argument (an output) that indicates whether the accumulated
//    value (the first argument) has changed, in which case the recorded accumulated index also
//    needs to be changed.
35
36struct Add
37{
38 template <typename T>
39 __host__ __device__ static constexpr T GetIdentityValue()
40 {
41 return type_convert<T>(0.0f);
42 };
43
44 __host__ __device__ static constexpr bool
50
51 template <typename T>
52 __host__ __device__ inline constexpr void operator()(T& a, T b) const
53 {
56 "The data type is not supported by the Add accumulator!");
57
58 a = a + b;
59 }
60
61 __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
62 {
63 float a_ = type_convert<float>(a);
64 float b_ = type_convert<float>(b);
65
66 a = type_convert<f8_t>(a_ + b_);
67 }
68
69 __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
70 {
71 float a_ = type_convert<float>(a);
72 float b_ = type_convert<float>(b);
73
74 a = type_convert<half_t>(a_ + b_);
75 }
76
77 __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
78 {
79 float a_ = type_convert<float>(a);
80 float b_ = type_convert<float>(b);
81
82 a = type_convert<bhalf_t>(a_ + b_);
83 }
84};
85
87{
88 template <class T>
89 __host__ __device__ static constexpr T GetIdentityValue()
90 {
91 return type_convert<T>(0.0f);
92 };
93
94 __host__ __device__ static constexpr bool
100
101 template <class T>
102 __host__ __device__ inline constexpr void operator()(T& a, T b) const
103 {
107 "The data type is not supported by the SquaredAdd accumulator!");
108
109 a = a + b * b;
110 }
111};
112
113struct Mul
114{
115 template <typename T>
116 __host__ __device__ static constexpr T GetIdentityValue()
117 {
118 return type_convert<T>(1.0f);
119 };
120
121 __host__ __device__ static constexpr bool
126
127 template <typename T>
128 __host__ __device__ inline constexpr void operator()(T& a, T b) const
129 {
132 "The data type is not supported by the Mul accumulator!");
133
134 a = a * b;
135 }
136
137 __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
138 {
139 float a_ = type_convert<float>(a);
140 float b_ = type_convert<float>(b);
141
142 a = type_convert<f8_t>(a_ * b_);
143 }
144
145 __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
146 {
147 float a_ = type_convert<float>(a);
148 float b_ = type_convert<float>(b);
149
150 a = type_convert<half_t>(a_ * b_);
151 }
152
153 __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
154 {
155 float a_ = type_convert<float>(a);
156 float b_ = type_convert<float>(b);
157
158 a = type_convert<bhalf_t>(a_ * b_);
159 }
160};
161
162struct Max
163{
164 template <typename T>
165 __host__ __device__ static constexpr T GetIdentityValue()
166 {
167 if constexpr(is_same_v<T, bhalf_t>)
168 {
169 float val = NumericLimits<float>::Lowest();
170 return type_convert<bhalf_t>(val);
171 }
172 if constexpr(is_same_v<T, f8_t>)
173 {
174 float val = NumericLimits<float>::Lowest();
175 return type_convert<f8_t>(val);
176 }
177 if constexpr(is_same_v<T, half_t>)
178 {
179 float val = NumericLimits<float>::Lowest();
180 return type_convert<half_t>(val);
181 }
182 else
183 {
185 }
186 };
187
188 __host__ __device__ static constexpr bool
190 {
191 // ToChange: atomic_max to be added
192 return operation == InMemoryDataOperationEnum::Set;
193 };
194
195 template <typename T>
196 __host__ __device__ inline constexpr void operator()(T& a, T b) const
197 {
200 "The data type is not supported by the Max accumulator!");
201
202 if(a < b)
203 a = b;
204 }
205
206 __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
207 {
208 float a_ = type_convert<float>(a);
209 float b_ = type_convert<float>(b);
210
211 if(a_ < b_)
212 a = b;
213 }
214
215 __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
216 {
217 float a_ = type_convert<float>(a);
218 float b_ = type_convert<float>(b);
219
220 if(a_ < b_)
221 a = b;
222 }
223
224 __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
225 {
226 float a_ = type_convert<float>(a);
227 float b_ = type_convert<float>(b);
228
229 if(a_ < b_)
230 a = b;
231 }
232
233 template <typename T>
234 __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
235 {
238 "The data type is not supported by the Max accumulator!");
239
240 if(a < b)
241 {
242 a = b;
243 changed = true;
244 }
245 }
246
247 __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
248 {
249 float a_ = type_convert<float>(a);
250 float b_ = type_convert<float>(b);
251
252 if(a_ < b_)
253 {
254 a = b;
255 changed = true;
256 }
257 }
258
259 __host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
260 {
261 float a_ = type_convert<float>(a);
262 float b_ = type_convert<float>(b);
263
264 if(a_ < b_)
265 {
266 a = b;
267 changed = true;
268 }
269 }
270
271 __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
272 {
273 float a_ = type_convert<float>(a);
274 float b_ = type_convert<float>(b);
275
276 if(a_ < b_)
277 {
278 a = b;
279 changed = true;
280 }
281 }
282};
283
284struct Min
285{
286 template <typename T>
287 __host__ __device__ static constexpr T GetIdentityValue()
288 {
289 if constexpr(is_same_v<T, bhalf_t>)
290 {
291 float val = NumericLimits<float>::Max();
292 return type_convert<bhalf_t>(val);
293 }
294 else if constexpr(is_same_v<T, half_t>)
295 {
296 float val = NumericLimits<float>::Max();
297 return type_convert<half_t>(val);
298 }
299 else if constexpr(is_same_v<T, f8_t>)
300 {
301 float val = NumericLimits<float>::Max();
302 return type_convert<f8_t>(val);
303 }
304 else
305 {
306 return NumericLimits<T>::Max();
307 }
308 return NumericLimits<T>::Max();
309 };
310
311 __host__ __device__ static constexpr bool
313 {
314 // ToChange: atomic_min to be added
315 return operation == InMemoryDataOperationEnum::Set;
316 };
317
318 template <typename T>
319 __host__ __device__ inline constexpr void operator()(T& a, T b) const
320 {
323 "The data type is not supported by the Min accumulator!");
324
325 if(a > b)
326 a = b;
327 }
328
329 __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
330 {
331 float a_ = type_convert<float>(a);
332 float b_ = type_convert<float>(b);
333
334 if(a_ > b_)
335 a = b;
336 }
337
338 __host__ __device__ inline constexpr void operator()(half_t& a, half_t b) const
339 {
340 float a_ = type_convert<float>(a);
341 float b_ = type_convert<float>(b);
342
343 if(a_ > b_)
344 a = b;
345 }
346
347 __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
348 {
349 float a_ = type_convert<float>(a);
350 float b_ = type_convert<float>(b);
351
352 if(a_ > b_)
353 a = b;
354 }
355
356 template <typename T>
357 __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
358 {
362 "The data type is not supported by the Min accumulator!");
363
364 if(a > b)
365 {
366 a = b;
367 changed = true;
368 }
369 }
370
371 __host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
372 {
373 float a_ = type_convert<float>(a);
374 float b_ = type_convert<float>(b);
375
376 if(a_ > b_)
377 {
378 a = b;
379 changed = true;
380 }
381 }
382
383 __host__ __device__ inline constexpr void operator()(half_t& a, half_t b, bool& changed) const
384 {
385 float a_ = type_convert<float>(a);
386 float b_ = type_convert<float>(b);
387
388 if(a_ > b_)
389 {
390 a = b;
391 changed = true;
392 }
393 }
394
395 __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
396 {
397 float a_ = type_convert<float>(a);
398 float b_ = type_convert<float>(b);
399
400 if(a_ > b_)
401 {
402 a = b;
403 changed = true;
404 }
405 }
406};
407
408struct AMax
409{
410 template <typename T>
411 __host__ __device__ static constexpr T GetIdentityValue()
412 {
413 return type_convert<T>(0.0f);
414 };
415
416 __host__ __device__ static constexpr bool
418 {
419 // ToChange: atomic_max to be added
420 return operation == InMemoryDataOperationEnum::Set;
421 };
422
423 template <typename T>
424 __host__ __device__ inline constexpr void operator()(T& a, T b) const
425 {
429 "The data type is not supported by the AMax accumulator!");
430
431 if(a < b)
432 a = b;
433 }
434
435 __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b) const
436 {
437 float a_ = type_convert<float>(a);
438 float b_ = type_convert<float>(b);
439
440 if(a_ < b_)
441 a = b;
442 }
443
444 template <typename T>
445 __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
446 {
450 "The data type is not supported by the AMax accumulator!");
451
452 if(a < b)
453 {
454 a = b;
455 changed = true;
456 }
457 }
458
459 __host__ __device__ inline constexpr void operator()(f8_t& a, f8_t b, bool& changed) const
460 {
461 float a_ = type_convert<float>(a);
462 float b_ = type_convert<float>(b);
463
464 if(a_ < b_)
465 {
466 a = b;
467 changed = true;
468 }
469 }
470};
471
472template <typename T>
474{
475 T result = ck::type_convert<T>(0.0f);
476
479
480 return (result);
481};
482
483template <InMemoryDataOperationEnum Operation, typename DataType>
485{
486 static constexpr bool value = false;
487};
488
489template <typename DataType>
495
496template <typename DataType>
502
503template <typename DataType>
512
513template <typename DataType>
521
522} // namespace reduce
523} // namespace ck
Definition reduction_operator.hpp:13
constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition reduction_operator.hpp:473
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicMax
Definition ck.hpp:280
@ AtomicAdd
Definition ck.hpp:279
_Float16 half_t
Definition data_type.hpp:31
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
constexpr bool is_same_v
Definition type.hpp:283
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
__host__ static __device__ constexpr T Max()
Definition numeric_limits.hpp:311
__host__ static __device__ constexpr T Lowest()
Definition numeric_limits.hpp:312
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition reduction_operator.hpp:409
__host__ __device__ constexpr void operator()(T &a, T b, bool &changed) const
Definition reduction_operator.hpp:445
__host__ __device__ constexpr void operator()(T &a, T b) const
Definition reduction_operator.hpp:424
__host__ static __device__ constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition reduction_operator.hpp:417
__host__ __device__ constexpr void operator()(f8_t &a, f8_t b) const
Definition reduction_operator.hpp:435
__host__ __device__ constexpr void operator()(f8_t &a, f8_t b, bool &changed) const
Definition reduction_operator.hpp:459
__host__ static __device__ constexpr T GetIdentityValue()
Definition reduction_operator.hpp:411
Definition reduction_operator.hpp:37
__host__ static __device__ constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition reduction_operator.hpp:45
__host__ __device__ constexpr void operator()(bhalf_t &a, bhalf_t b) const
Definition reduction_operator.hpp:77
__host__ __device__ constexpr void operator()(f8_t &a, f8_t b) const
Definition reduction_operator.hpp:61
__host__ static __device__ constexpr T GetIdentityValue()
Definition reduction_operator.hpp:39
__host__ __device__ constexpr void operator()(half_t &a, half_t b) const
Definition reduction_operator.hpp:69
__host__ __device__ constexpr void operator()(T &a, T b) const
Definition reduction_operator.hpp:52
Definition reduction_operator.hpp:485
static constexpr bool value
Definition reduction_operator.hpp:486
Definition reduction_operator.hpp:163
__host__ static __device__ constexpr T GetIdentityValue()
Definition reduction_operator.hpp:165
__host__ static __device__ constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition reduction_operator.hpp:189
__host__ __device__ constexpr void operator()(half_t &a, half_t b, bool &changed) const
Definition reduction_operator.hpp:259
__host__ __device__ constexpr void operator()(f8_t &a, f8_t b, bool &changed) const
Definition reduction_operator.hpp:271
__host__ __device__ constexpr void operator()(T &a, T b) const
Definition reduction_operator.hpp:196
__host__ __device__ constexpr void operator()(half_t &a, half_t b) const
Definition reduction_operator.hpp:215
__host__ __device__ constexpr void operator()(f8_t &a, f8_t b) const
Definition reduction_operator.hpp:224
__host__ __device__ constexpr void operator()(bhalf_t &a, bhalf_t b) const
Definition reduction_operator.hpp:206
__host__ __device__ constexpr void operator()(bhalf_t &a, bhalf_t b, bool &changed) const
Definition reduction_operator.hpp:247
__host__ __device__ constexpr void operator()(T &a, T b, bool &changed) const
Definition reduction_operator.hpp:234
Definition reduction_operator.hpp:285
__host__ __device__ constexpr void operator()(f8_t &a, f8_t b) const
Definition reduction_operator.hpp:347
__host__ __device__ constexpr void operator()(bhalf_t &a, bhalf_t b, bool &changed) const
Definition reduction_operator.hpp:371
__host__ static __device__ constexpr T GetIdentityValue()
Definition reduction_operator.hpp:287
__host__ static __device__ constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition reduction_operator.hpp:312
__host__ __device__ constexpr void operator()(T &a, T b) const
Definition reduction_operator.hpp:319
__host__ __device__ constexpr void operator()(f8_t &a, f8_t b, bool &changed) const
Definition reduction_operator.hpp:395
__host__ __device__ constexpr void operator()(half_t &a, half_t b) const
Definition reduction_operator.hpp:338
__host__ __device__ constexpr void operator()(half_t &a, half_t b, bool &changed) const
Definition reduction_operator.hpp:383
__host__ __device__ constexpr void operator()(bhalf_t &a, bhalf_t b) const
Definition reduction_operator.hpp:329
__host__ __device__ constexpr void operator()(T &a, T b, bool &changed) const
Definition reduction_operator.hpp:357
Definition reduction_operator.hpp:114
__host__ __device__ constexpr void operator()(f8_t &a, f8_t b) const
Definition reduction_operator.hpp:137
__host__ __device__ constexpr void operator()(T &a, T b) const
Definition reduction_operator.hpp:128
__host__ static __device__ constexpr T GetIdentityValue()
Definition reduction_operator.hpp:116
__host__ __device__ constexpr void operator()(half_t &a, half_t b) const
Definition reduction_operator.hpp:145
__host__ __device__ constexpr void operator()(bhalf_t &a, bhalf_t b) const
Definition reduction_operator.hpp:153
__host__ static __device__ constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition reduction_operator.hpp:122
Definition reduction_operator.hpp:87
__host__ static __device__ constexpr bool IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
Definition reduction_operator.hpp:95
__host__ static __device__ constexpr T GetIdentityValue()
Definition reduction_operator.hpp:89
__host__ __device__ constexpr void operator()(T &a, T b) const
Definition reduction_operator.hpp:102