@@ -342,6 +342,59 @@ LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::TestConfig(LongVector::Bin
342
342
}
343
343
}
344
344
345
+ template <typename DataTypeT, typename LongVectorOpTypeT>
346
+ LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::TestConfig(LongVector::TrigonometricOpType OpType)
347
+ : OpTypeTraits(OpType) {
348
+ IntrinsicString = " " ;
349
+ BasicOpType = LongVector::BasicOpType_Unary;
350
+
351
+ // All trigonometric ops are floating point types.
352
+ // These trig functions are defined to have a max absolute error of 0.0008
353
+ // as per the D3D functional specs. An example with this spec for sin and
354
+ // cos is available here:
355
+ // https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm#22.10.20
356
+ ValidationType = LongVector::ValidationType_Epsilon;
357
+ if (std::is_same_v<DataTypeT, HLSLHalf_t>)
358
+ Tolerance = 0 .0010f ;
359
+ else if (std::is_same_v<DataTypeT, float >)
360
+ Tolerance = 0 .0008f ;
361
+ else
362
+ VERIFY_FAIL (
363
+ " Invalid type for trigonometric op. Expecting half or float." );
364
+
365
+ switch (OpType) {
366
+ case LongVector::TrigonometricOpType_Acos:
367
+ IntrinsicString = " acos" ;
368
+ break ;
369
+ case LongVector::TrigonometricOpType_Asin:
370
+ IntrinsicString = " asin" ;
371
+ break ;
372
+ case LongVector::TrigonometricOpType_Atan:
373
+ IntrinsicString = " atan" ;
374
+ break ;
375
+ case LongVector::TrigonometricOpType_Cos:
376
+ IntrinsicString = " cos" ;
377
+ break ;
378
+ case LongVector::TrigonometricOpType_Cosh:
379
+ IntrinsicString = " cosh" ;
380
+ break ;
381
+ case LongVector::TrigonometricOpType_Sin:
382
+ IntrinsicString = " sin" ;
383
+ break ;
384
+ case LongVector::TrigonometricOpType_Sinh:
385
+ IntrinsicString = " sinh" ;
386
+ break ;
387
+ case LongVector::TrigonometricOpType_Tan:
388
+ IntrinsicString = " tan" ;
389
+ break ;
390
+ case LongVector::TrigonometricOpType_Tanh:
391
+ IntrinsicString = " tanh" ;
392
+ break ;
393
+ default :
394
+ VERIFY_FAIL (" Invalid TrigonometricOpType" );
395
+ }
396
+ }
397
+
345
398
template <typename DataTypeT, typename LongVectorOpTypeT>
346
399
bool LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::hasFunctionDefinition() const {
347
400
if constexpr (std::is_same_v<LongVectorOpTypeT, LongVector::UnaryOpType>) {
@@ -463,6 +516,13 @@ DataTypeT LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::computeExpectedV
463
516
template <typename DataTypeT, typename LongVectorOpTypeT>
464
517
DataTypeT LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::computeExpectedValue(const DataTypeT &A) const {
465
518
519
+ if constexpr (std::is_same_v<LongVectorOpTypeT, LongVector::TrigonometricOpType>) {
520
+ const auto OpType = static_cast <LongVector::TrigonometricOpType>(OpTypeTraits.OpType );
521
+ // HLSLHalf_t is a struct. We need to call the constructor to get the
522
+ // expected value.
523
+ return computeExpectedValue (A, OpType);
524
+ }
525
+
466
526
if constexpr (std::is_same_v<LongVectorOpTypeT, LongVector::UnaryOpType>) {
467
527
const auto OpType = static_cast <LongVector::UnaryOpType>(OpTypeTraits.OpType );
468
528
// HLSLHalf_t is a struct. We need to call the constructor to get the
@@ -477,6 +537,67 @@ DataTypeT LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::computeExpectedV
477
537
return DataTypeT ();
478
538
}
479
539
540
+ template <typename DataTypeT, typename LongVectorOpTypeT>
541
+ DataTypeT LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::computeExpectedValue(const DataTypeT &A,
542
+ LongVector::TrigonometricOpType OpType) const {
543
+ // The trig functions are only valid on floating point types. The constexpr in
544
+ // this case is a relatively easy and clean way to prevent the compiler from
545
+ // erroring out trying to resolve these for the non floating point types. We
546
+ // won't use them in the first place.
547
+ if constexpr (isFloatingPointType<DataTypeT>()) {
548
+ switch (OpType) {
549
+ case LongVector::TrigonometricOpType_Acos:
550
+ return std::acos (A);
551
+ case LongVector::TrigonometricOpType_Asin:
552
+ return std::asin (A);
553
+ case LongVector::TrigonometricOpType_Atan:
554
+ return std::atan (A);
555
+ case LongVector::TrigonometricOpType_Cos:
556
+ return std::cos (A);
557
+ case LongVector::TrigonometricOpType_Cosh:
558
+ return std::cosh (A);
559
+ case LongVector::TrigonometricOpType_Sin:
560
+ return std::sin (A);
561
+ case LongVector::TrigonometricOpType_Sinh:
562
+ return std::sinh (A);
563
+ case LongVector::TrigonometricOpType_Tan:
564
+ return std::tan (A);
565
+ case LongVector::TrigonometricOpType_Tanh:
566
+ return std::tanh (A);
567
+ default :
568
+ LOG_ERROR_FMT_THROW (L" Unknown TrigonometricOpType: %d" ,
569
+ OpTypeTraits.OpType );
570
+ return DataTypeT ();
571
+ }
572
+ }
573
+
574
+ LOG_ERROR_FMT_THROW (L" ComputeExpectedValue(const DataTypeT &A, "
575
+ L" LongVectorOpTypeT OpType) called on a "
576
+ L" non-float type: %d" ,
577
+ OpType);
578
+
579
+ return DataTypeT ();
580
+ }
581
+
582
+ template <typename DataTypeT, typename LongVectorOpTypeT>
583
+ std::vector<DataTypeT> LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::getInputArgsArray() const {
584
+
585
+ std::vector<DataTypeT> InputArgs;
586
+
587
+ std::wstring InputArgsArrayName = this ->InputArgsArrayName ;
588
+
589
+ if (InputArgsArrayName.empty ())
590
+ VERIFY_FAIL (" No args array name set." );
591
+
592
+ if (std::is_same_v<DataTypeT, HLSLBool_t> && isClampOp ())
593
+ VERIFY_FAIL (" Clamp is not supported for bools." );
594
+ else
595
+ return getInputValueSetByKey<DataTypeT>(InputArgsArrayName, false );
596
+
597
+ VERIFY_FAIL (" Invalid type for args array." );
598
+ return std::vector<DataTypeT>();
599
+ }
600
+
480
601
template <typename DataTypeT, typename LongVectorOpTypeT>
481
602
std::string LongVector::TestConfig<DataTypeT, LongVectorOpTypeT>::getCompilerOptionsString(size_t VectorSize) const {
482
603
std::stringstream CompilerOptions (" " );
0 commit comments