From 006ff3b9d93e5536d9a76509201acc89f1eabeb8 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Thu, 24 Aug 2023 12:28:30 +0200 Subject: [PATCH] Fix (**) and use property tests The default implementation of (**) relied on log and incorrectly handled some inputs. The fix is making an explicit implementation using the pow function. To test the changes, tests for functions from the Floating typeclass are re-written using property tests. There are a few helper functions to make writing the actual properties easy. More tests can be converted to properties, but this is left for another PR. --- arrayfire.cabal | 1 + src/ArrayFire/Orphans.hs | 4 + test/ArrayFire/ArithSpec.hs | 165 +++++++++++++++++++++++++----------- 3 files changed, 122 insertions(+), 48 deletions(-) diff --git a/arrayfire.cabal b/arrayfire.cabal index 349bb61..d2ab64c 100644 --- a/arrayfire.cabal +++ b/arrayfire.cabal @@ -151,6 +151,7 @@ test-suite test base < 5, directory, hspec, + HUnit, QuickCheck, quickcheck-classes, vector diff --git a/src/ArrayFire/Orphans.hs b/src/ArrayFire/Orphans.hs index 690a89d..e4d1c8a 100644 --- a/src/ArrayFire/Orphans.hs +++ b/src/ArrayFire/Orphans.hs @@ -50,8 +50,12 @@ instance forall a . (Ord a, AFType a, Fractional a) => Floating (Array a) where pi = A.scalar @a 3.14159 exp = A.exp @a log = A.log @a + sqrt = A.sqrt @a + (**) = A.pow @a sin = A.sin @a cos = A.cos @a + tan = A.tan @a + tanh = A.tanh @a asin = A.asin @a acos = A.acos @a atan = A.atan @a diff --git a/test/ArrayFire/ArithSpec.hs b/test/ArrayFire/ArithSpec.hs index 7560e22..623726f 100644 --- a/test/ArrayFire/ArithSpec.hs +++ b/test/ArrayFire/ArithSpec.hs @@ -1,10 +1,75 @@ +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} + module ArrayFire.ArithSpec where -import ArrayFire hiding (acos) -import Prelude hiding (abs, sqrt, div, and, or, not, isNaN) -import Test.Hspec +import ArrayFire (AFType, Array, cast, clamp, getType, isInf, isZero, matrix, maxOf, minOf, mkArray, scalar, vector) +import qualified ArrayFire +import Control.Exception (throwIO) +import Control.Monad (unless, when) import Foreign.C +import GHC.Exts (IsList (..)) +import GHC.Stack +import Test.HUnit.Lang (FailureReason (..), HUnitFailure (..)) +import Test.Hspec +import Test.Hspec.QuickCheck +import Prelude hiding (div) + +compareWith :: (HasCallStack, Show a) => (a -> a -> Bool) -> a -> a -> Expectation +compareWith comparator result expected = + unless (comparator result expected) $ do + throwIO (HUnitFailure location $ ExpectedButGot Nothing expectedMsg actualMsg) + where + expectedMsg = show expected + actualMsg = show result + location = case reverse (toList callStack) of + (_, loc) : _ -> Just loc + [] -> Nothing + +class (Num a) => HasEpsilon a where + eps :: a + +instance HasEpsilon Float where + eps = 1.1920929e-7 + +instance HasEpsilon Double where + eps = 2.220446049250313e-16 + +approxWith :: (Ord a, Num a) => a -> a -> a -> a -> Bool +approxWith rtol atol a b = abs (a - b) <= Prelude.max atol (rtol * Prelude.max (abs a) (abs b)) + +approx :: (Ord a, HasEpsilon a) => a -> a -> Bool +approx a b = approxWith (2 * eps * Prelude.max (abs a) (abs b)) (4 * eps) a b + +shouldBeApprox :: (Ord a, HasEpsilon a, Show a) => a -> a -> Expectation +shouldBeApprox = compareWith approx + +evalf :: (AFType a) => Array a -> a +evalf = ArrayFire.getScalar + +shouldMatchBuiltin :: + (AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) => + (Array a -> Array a) -> + (a -> a) -> + a -> + Expectation +shouldMatchBuiltin f f' x + | isInfinite y && isInfinite y' = pure () + | Prelude.isNaN y && Prelude.isNaN y' = pure () + | otherwise = y `shouldBeApprox` y' + where + y = evalf (f (scalar x)) + y' = f' x + +shouldMatchBuiltin2 :: + (AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) => + (Array a -> Array a -> Array a) -> + (a -> a -> a) -> + a -> + a -> + Expectation +shouldMatchBuiltin2 f f' a = shouldMatchBuiltin (f (scalar a)) (f' a) spec :: Spec spec = @@ -12,7 +77,7 @@ spec = it "Should negate scalar value" $ do negate (scalar @Int 1) `shouldBe` (-1) it "Should negate a vector" $ do - negate (vector @Int 3 [2,2,2]) `shouldBe` vector @Int 3 [-2,-2,-2] + negate (vector @Int 3 [2, 2, 2]) `shouldBe` vector @Int 3 [-2, -2, -2] it "Should add two scalar arrays" $ do scalar @Int 1 + 2 `shouldBe` 3 it "Should add two scalar bool arrays" $ do @@ -20,80 +85,84 @@ spec = it "Should subtract two scalar arrays" $ do scalar @Int 4 - 2 `shouldBe` 2 it "Should multiply two scalar arrays" $ do - scalar @Double 4 `mul` 2 `shouldBe` 8 + scalar @Double 4 `ArrayFire.mul` 2 `shouldBe` 8 it "Should divide two scalar arrays" $ do - div @Double 8 2 `shouldBe` 4 + ArrayFire.div @Double 8 2 `shouldBe` 4 it "Should add two matrices" $ do - matrix @Int (2,2) [[1,1],[1,1]] + matrix @Int (2,2) [[1,1],[1,1]] - `shouldBe` - matrix @Int (2,2) [[2,2],[2,2]] - -- Exact comparisons of Double don't make sense here, so we just check that the result is - -- accurate up to some epsilon. - it "Should take cubed root" $ do - allTrueAll ((abs (3 - cbrt @Double 27)) `lt` 1.0e-14) `shouldBe` (1, 0) - it "Should take square root" $ do - allTrueAll ((abs (2 - sqrt @Double 4)) `lt` 1.0e-14) `shouldBe` (1, 0) + matrix @Int (2, 2) [[1, 1], [1, 1]] + matrix @Int (2, 2) [[1, 1], [1, 1]] + `shouldBe` matrix @Int (2, 2) [[2, 2], [2, 2]] + prop "Should take cubed root" $ \(x :: Double) -> + evalf (ArrayFire.cbrt (scalar (x * x * x))) `shouldBeApprox` x it "Should lte Array" $ do - 2 `le` (3 :: Array Double) `shouldBe` 1 + 2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1 it "Should gte Array" $ do - 2 `ge` (3 :: Array Double) `shouldBe` 0 + 2 `ArrayFire.ge` (3 :: Array Double) `shouldBe` 0 it "Should gt Array" $ do - 2 `gt` (3 :: Array Double) `shouldBe` 0 + 2 `ArrayFire.gt` (3 :: Array Double) `shouldBe` 0 it "Should lt Array" $ do - 2 `le` (3 :: Array Double) `shouldBe` 1 + 2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1 it "Should eq Array" $ do 3 == (3 :: Array Double) `shouldBe` True it "Should and Array" $ do - (mkArray @CBool [1] [0] `and` mkArray [1] [1]) - `shouldBe` mkArray [1] [0] + (mkArray @CBool [1] [0] `ArrayFire.and` mkArray [1] [1]) + `shouldBe` mkArray [1] [0] it "Should and Array" $ do - (mkArray @CBool [2] [0,0] `and` mkArray [2] [1,0]) - `shouldBe` mkArray [2] [0, 0] + (mkArray @CBool [2] [0, 0] `ArrayFire.and` mkArray [2] [1, 0]) + `shouldBe` mkArray [2] [0, 0] it "Should or Array" $ do - (mkArray @CBool [2] [0,0] `or` mkArray [2] [1,0]) - `shouldBe` mkArray [2] [1, 0] + (mkArray @CBool [2] [0, 0] `ArrayFire.or` mkArray [2] [1, 0]) + `shouldBe` mkArray [2] [1, 0] it "Should not Array" $ do - not (mkArray @CBool [2] [1,0]) `shouldBe` mkArray [2] [0,1] + ArrayFire.not (mkArray @CBool [2] [1, 0]) `shouldBe` mkArray [2] [0, 1] it "Should bitwise and array" $ do - bitAnd (scalar @Int 1) (scalar @Int 0) - `shouldBe` - 0 + ArrayFire.bitAnd (scalar @Int 1) (scalar @Int 0) + `shouldBe` 0 it "Should bitwise or array" $ do - bitOr (scalar @Int 1) (scalar @Int 0) - `shouldBe` - 1 + ArrayFire.bitOr (scalar @Int 1) (scalar @Int 0) + `shouldBe` 1 it "Should bitwise xor array" $ do - bitXor (scalar @Int 1) (scalar @Int 1) - `shouldBe` - 0 + ArrayFire.bitXor (scalar @Int 1) (scalar @Int 1) + `shouldBe` 0 it "Should bitwise shift left an array" $ do - bitShiftL (scalar @Int 1) (scalar @Int 3) - `shouldBe` - 8 + ArrayFire.bitShiftL (scalar @Int 1) (scalar @Int 3) + `shouldBe` 8 it "Should cast an array" $ do getType (cast (scalar @Int 1) :: Array Double) - `shouldBe` - F64 + `shouldBe` ArrayFire.F64 it "Should find the minimum of two arrays" $ do minOf (scalar @Int 1) (scalar @Int 0) - `shouldBe` - 0 + `shouldBe` 0 it "Should find the max of two arrays" $ do maxOf (scalar @Int 1) (scalar @Int 0) - `shouldBe` - 1 + `shouldBe` 1 it "Should take the clamp of 3 arrays" $ do clamp (scalar @Int 2) (scalar @Int 1) (scalar @Int 3) - `shouldBe` - 2 + `shouldBe` 2 it "Should check if an array has positive or negative infinities" $ do isInf (scalar @Double (1 / 0)) `shouldBe` scalar @Double 1 isInf (scalar @Double 10) `shouldBe` scalar @Double 0 it "Should check if an array has any NaN values" $ do - isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1 - isNaN (scalar @Double 10) `shouldBe` scalar @Double 0 + ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1 + ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @Double 0 it "Should check if an array has any Zero values" $ do isZero (scalar @Double (acos 2)) `shouldBe` scalar @Double 0 isZero (scalar @Double 0) `shouldBe` scalar @Double 1 isZero (scalar @Double 1) `shouldBe` scalar @Double 0 + + prop "Floating @Float (exp)" $ \(x :: Float) -> exp `shouldMatchBuiltin` exp $ x + prop "Floating @Float (log)" $ \(x :: Float) -> log `shouldMatchBuiltin` log $ x + prop "Floating @Float (sqrt)" $ \(x :: Float) -> sqrt `shouldMatchBuiltin` sqrt $ x + prop "Floating @Float (**)" $ \(x :: Float) (y :: Float) -> ((**) `shouldMatchBuiltin2` (**)) x y + prop "Floating @Float (sin)" $ \(x :: Float) -> sin `shouldMatchBuiltin` sin $ x + prop "Floating @Float (cos)" $ \(x :: Float) -> cos `shouldMatchBuiltin` cos $ x + prop "Floating @Float (tan)" $ \(x :: Float) -> tan `shouldMatchBuiltin` tan $ x + prop "Floating @Float (asin)" $ \(x :: Float) -> asin `shouldMatchBuiltin` asin $ x + prop "Floating @Float (acos)" $ \(x :: Float) -> acos `shouldMatchBuiltin` acos $ x + prop "Floating @Float (atan)" $ \(x :: Float) -> atan `shouldMatchBuiltin` atan $ x + prop "Floating @Float (sinh)" $ \(x :: Float) -> sinh `shouldMatchBuiltin` sinh $ x + prop "Floating @Float (cosh)" $ \(x :: Float) -> cosh `shouldMatchBuiltin` cosh $ x + prop "Floating @Float (tanh)" $ \(x :: Float) -> tanh `shouldMatchBuiltin` tanh $ x + prop "Floating @Float (asinh)" $ \(x :: Float) -> asinh `shouldMatchBuiltin` asinh $ x + prop "Floating @Float (acosh)" $ \(x :: Float) -> acosh `shouldMatchBuiltin` acosh $ x + prop "Floating @Float (atanh)" $ \(x :: Float) -> atanh `shouldMatchBuiltin` atanh $ x