diff --git a/arrayfire.cabal b/arrayfire.cabal index d2ab64c..b347c8b 100644 --- a/arrayfire.cabal +++ b/arrayfire.cabal @@ -144,6 +144,8 @@ test-suite test exitcode-stdio-1.0 main-is: Main.hs + other-modules: + Test.Hspec.ApproxExpect hs-source-dirs: test build-depends: @@ -154,7 +156,8 @@ test-suite test HUnit, QuickCheck, quickcheck-classes, - vector + vector, + call-stack >=0.4 && <0.5 if !flag(disable-build-tool-depends) build-tool-depends: hspec-discover:hspec-discover diff --git a/test/ArrayFire/LAPACKSpec.hs b/test/ArrayFire/LAPACKSpec.hs index 2c9f554..5c225c7 100644 --- a/test/ArrayFire/LAPACKSpec.hs +++ b/test/ArrayFire/LAPACKSpec.hs @@ -4,6 +4,7 @@ module ArrayFire.LAPACKSpec where import qualified ArrayFire as A import Prelude import Test.Hspec +import Test.Hspec.ApproxExpect spec :: Spec spec = @@ -33,9 +34,9 @@ spec = it "Should get determinant of Double" $ do let eles = [[3 A.:+ 1, 8 A.:+ 1], [4 A.:+ 1, 6 A.:+ 1]] (x,y) = A.det (A.matrix @(A.Complex Double) (2,2) eles) - x `shouldBe` (-14) + x `shouldBeApprox` (-14) let (x,y) = A.det $ A.matrix @Double (2,2) [[3,8],[4,6]] - x `shouldBe` (-14) + x `shouldBeApprox` (-14) -- it "Should calculate inverse" $ do -- let x = flip A.inverse A.None $ A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]] -- x `shouldBe` A.matrix (2,2) [[0.6,-0.7],[-0.2,0.4]] diff --git a/test/ArrayFire/StatisticsSpec.hs b/test/ArrayFire/StatisticsSpec.hs index 392f617..c8c6314 100644 --- a/test/ArrayFire/StatisticsSpec.hs +++ b/test/ArrayFire/StatisticsSpec.hs @@ -5,6 +5,7 @@ import ArrayFire hiding (not) import Data.Complex import Test.Hspec +import Test.Hspec.ApproxExpect spec :: Spec spec = @@ -15,7 +16,7 @@ spec = 5.5 it "Should find the weighted-mean" $ do meanWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0 - `shouldBe` + `shouldBeApprox` 7.0 it "Should find the variance" $ do var (vector @Double 8 [1..8]) False 0 diff --git a/test/Test/Hspec/ApproxExpect.hs b/test/Test/Hspec/ApproxExpect.hs new file mode 100644 index 0000000..3e9d66b --- /dev/null +++ b/test/Test/Hspec/ApproxExpect.hs @@ -0,0 +1,19 @@ +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +module Test.Hspec.ApproxExpect where + +import Data.CallStack (HasCallStack) + +import Test.Hspec (shouldSatisfy, Expectation) + +infix 1 `shouldBeApprox` + +shouldBeApprox :: (HasCallStack, Show a, Fractional a, Eq a) + => a -> a -> Expectation +shouldBeApprox actual tgt + -- This is a hackish way of checking, without requiring a specific + -- type or an 'Ord' instance, whether two floating-point values + -- are only some epsilons apart: when the difference is small enough + -- so scaling it down some more makes it a no-op for addition. + = actual `shouldSatisfy` \x -> (x-tgt) * 1e-4 + tgt == tgt +