26template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
43 "Warning: Unhandled ComputeDataType for setting up the relative threshold!");
44 double compute_error = 0;
60 "Warning: Unhandled OutDataType for setting up the relative threshold!");
61 double output_error = 0;
71 double midway_error = std::max(compute_error, output_error);
78 "Warning: Unhandled AccDataType for setting up the relative threshold!");
89 return std::max(acc_error, midway_error);
92template <
typename ComputeDataType,
typename OutDataType,
typename AccDataType = ComputeDataType>
109 "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
110 auto expo = std::log2(std::abs(max_possible_num));
111 double compute_error = 0;
127 "Warning: Unhandled OutDataType for setting up the absolute threshold!");
128 double output_error = 0;
138 double midway_error = std::max(compute_error, output_error);
145 "Warning: Unhandled AccDataType for setting up the absolute threshold!");
146 double acc_error = 0;
157 return std::max(acc_error, midway_error);
160template <
typename Range,
163typename std::enable_if<
165 std::is_same_v<ranges::range_value_t<Range>,
float> &&
166 std::is_same_v<ComputeDataType, ck::tf32_t>,
170 const std::string& msg =
"Error: Incorrect results!",
174 if(out.size() != ref.size())
176 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
184 double max_err = std::numeric_limits<double>::min();
185 for(std::size_t i = 0; i < ref.size(); ++i)
187 const double o = *std::next(std::begin(out), i);
188 const double r = *std::next(std::begin(ref), i);
189 err = std::abs(o - r);
190 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
192 max_err = err > max_err ? err : max_err;
195 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
196 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
204 const float error_percent =
205 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
206 std::cerr <<
"max err: " << max_err;
207 std::cerr <<
", number of errors: " << err_count;
208 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
213template <
typename Range,
216typename std::enable_if<
218 std::is_floating_point_v<ranges::range_value_t<Range>> &&
219 !std::is_same_v<ranges::range_value_t<Range>,
half_t> &&
220 !std::is_same_v<ComputeDataType, ck::tf32_t>,
224 const std::string& msg =
"Error: Incorrect results!",
228 if(out.size() != ref.size())
230 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
238 double max_err = std::numeric_limits<double>::min();
239 for(std::size_t i = 0; i < ref.size(); ++i)
241 const double o = *std::next(std::begin(out), i);
242 const double r = *std::next(std::begin(ref), i);
243 err = std::abs(o - r);
244 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
246 max_err = err > max_err ? err : max_err;
249 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
250 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
258 const float error_percent =
259 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
260 std::cerr <<
"max err: " << max_err;
261 std::cerr <<
", number of errors: " << err_count;
262 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
267template <
typename Range,
270typename std::enable_if<
272 std::is_same_v<ranges::range_value_t<Range>,
bhalf_t>,
276 const std::string& msg =
"Error: Incorrect results!",
280 if(out.size() != ref.size())
282 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
291 double max_err = std::numeric_limits<float>::min();
292 for(std::size_t i = 0; i < ref.size(); ++i)
296 err = std::abs(o - r);
297 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
299 max_err = err > max_err ? err : max_err;
303 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
304 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
311 const float error_percent =
312 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
313 std::cerr <<
"max err: " << max_err;
314 std::cerr <<
", number of errors: " << err_count;
315 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
320template <
typename Range,
323typename std::enable_if<
325 std::is_same_v<ranges::range_value_t<Range>,
half_t>,
329 const std::string& msg =
"Error: Incorrect results!",
333 if(out.size() != ref.size())
335 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
344 for(std::size_t i = 0; i < ref.size(); ++i)
348 err = std::abs(o - r);
349 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
351 max_err = err > max_err ? err : max_err;
355 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
356 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
363 const float error_percent =
364 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
365 std::cerr <<
"max err: " << max_err;
366 std::cerr <<
", number of errors: " << err_count;
367 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
372template <
typename Range,
376 std::is_integral_v<ranges::range_value_t<Range>> &&
377 !std::is_same_v<ranges::range_value_t<Range>,
bhalf_t> &&
378 !std::is_same_v<ranges::range_value_t<Range>,
f8_t> &&
379 !std::is_same_v<ranges::range_value_t<Range>,
bf8_t>)
380#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
387 const std::string& msg =
"Error: Incorrect results!",
391 if(out.size() != ref.size())
393 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
401 int64_t max_err = std::numeric_limits<int64_t>::min();
402 for(std::size_t i = 0; i < ref.size(); ++i)
404 const int64_t o = *std::next(std::begin(out), i);
405 const int64_t r = *std::next(std::begin(ref), i);
406 err = std::abs(o - r);
410 max_err = err > max_err ? err : max_err;
414 std::cerr << msg <<
" out[" << i <<
"] != ref[" << i <<
"]: " << o <<
" != " << r
422 const float error_percent =
423 static_cast<float>(err_count) /
static_cast<float>(out.size()) * 100.f;
424 std::cerr <<
"max err: " << max_err;
425 std::cerr <<
", number of errors: " << err_count;
426 std::cerr <<
", " << error_percent <<
"% wrong values" << std::endl;
431template <
typename Range,
435 std::is_same_v<ranges::range_value_t<Range>,
f8_t>),
439 const std::string& msg =
"Error: Incorrect results!",
443 if(out.size() != ref.size())
445 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
453 double max_err = std::numeric_limits<float>::min();
455 for(std::size_t i = 0; i < ref.size(); ++i)
459 err = std::abs(o - r);
461 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
463 max_err = err > max_err ? err : max_err;
467 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
468 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
476 std::cerr << std::setw(12) << std::setprecision(7) <<
"max err: " << max_err
477 <<
" number of errors: " << err_count << std::endl;
482template <
typename Range,
486 std::is_same_v<ranges::range_value_t<Range>,
bf8_t>),
490 const std::string& msg =
"Error: Incorrect results!",
494 if(out.size() != ref.size())
496 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
504 double max_err = std::numeric_limits<float>::min();
505 for(std::size_t i = 0; i < ref.size(); ++i)
509 err = std::abs(o - r);
510 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
512 max_err = err > max_err ? err : max_err;
516 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
517 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
524 std::cerr << std::setw(12) << std::setprecision(7) <<
"max err: " << max_err << std::endl;
529template <
typename Range,
533 std::is_same_v<ranges::range_value_t<Range>,
f4_t>),
537 const std::string& msg =
"Error: Incorrect results!",
541 if(out.size() != ref.size())
543 std::cerr << msg <<
" out.size() != ref.size(), :" << out.size() <<
" != " << ref.size()
551 double max_err = std::numeric_limits<float>::min();
553 for(std::size_t i = 0; i < ref.size(); ++i)
557 err = std::abs(o - r);
559 if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
561 max_err = err > max_err ? err : max_err;
565 std::cerr << msg << std::setw(12) << std::setprecision(7) <<
" out[" << i
566 <<
"] != ref[" << i <<
"]: " << o <<
" != " << r << std::endl;
574 std::cerr << std::setw(12) << std::setprecision(7) <<
"max err: " << max_err
575 <<
" number of errors: " << err_count << std::endl;
std::enable_if< std::is_same_v< ranges::range_value_t< Range >, ranges::range_value_t< RefRange > > &&std::is_same_v< ranges::range_value_t< Range >, float > &&std::is_same_v< ComputeDataType, ck::tf32_t >, bool >::type check_err(const Range &out, const RefRange &ref, const std::string &msg="Error: Incorrect results!", double rtol=1e-5, double atol=3e-5)
Definition library/utility/check_err.hpp:168