From 5e6925485a58f8db0a588815ad6147172efa21d4 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Sun, 26 Mar 2023 20:59:36 +0000 Subject: [PATCH 1/6] Fix tests for ArrayFire 3.8.3 (#51) * Do not rely on == comparison for cbrt and sqrt tests * Do not be too strict about the version constraint; upstream ArrayFire changes a lot between minor releases, so we check for that, but don't assert a specific patch version --- test/ArrayFire/ArithSpec.hs | 8 +++++--- test/ArrayFire/UtilSpec.hs | 6 ++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/test/ArrayFire/ArithSpec.hs b/test/ArrayFire/ArithSpec.hs index ae03a54..7560e22 100644 --- a/test/ArrayFire/ArithSpec.hs +++ b/test/ArrayFire/ArithSpec.hs @@ -2,7 +2,7 @@ module ArrayFire.ArithSpec where import ArrayFire hiding (acos) -import Prelude hiding (sqrt, div, and, or, not, isNaN) +import Prelude hiding (abs, sqrt, div, and, or, not, isNaN) import Test.Hspec import Foreign.C @@ -27,10 +27,12 @@ spec = matrix @Int (2,2) [[1,1],[1,1]] + matrix @Int (2,2) [[1,1],[1,1]] `shouldBe` matrix @Int (2,2) [[2,2],[2,2]] + -- Exact comparisons of Double don't make sense here, so we just check that the result is + -- accurate up to some epsilon. it "Should take cubed root" $ do - 3 `shouldBe` cbrt @Double 27 + allTrueAll ((abs (3 - cbrt @Double 27)) `lt` 1.0e-14) `shouldBe` (1, 0) it "Should take square root" $ do - 2 `shouldBe` sqrt @Double 4 + allTrueAll ((abs (2 - sqrt @Double 4)) `lt` 1.0e-14) `shouldBe` (1, 0) it "Should lte Array" $ do 2 `le` (3 :: Array Double) `shouldBe` 1 diff --git a/test/ArrayFire/UtilSpec.hs b/test/ArrayFire/UtilSpec.hs index 9ff01dc..5d95dd1 100644 --- a/test/ArrayFire/UtilSpec.hs +++ b/test/ArrayFire/UtilSpec.hs @@ -30,8 +30,10 @@ spec = A.getSizeOf (Proxy @(Complex Float)) `shouldBe` 8 A.getSizeOf (Proxy @(Complex Double)) `shouldBe` 16 it "Should get version" $ do - x <- A.getVersion - x `shouldBe` (3,8,2) + (major, minor, patch) <- A.getVersion + major `shouldBe` 3 + minor `shouldBe` 8 + patch `shouldSatisfy` (>= 0) it "Should get revision" $ do x <- A.getRevision x `shouldSatisfy` (not . null) From 1e4f9091220b59cf3e9e5c26aa6d0f1ca86fa5d9 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Thu, 24 Aug 2023 03:06:51 +0200 Subject: [PATCH 2/6] Switch to Nix flakes; make the tests pass with newest ArrayFire (#55) * Switch to Nix flakes; make the tests pass with nix build & nix develop for the latest version of ArrayFire * update flake.lock --- arrayfire.cabal | 15 ++++-- cabal.project | 5 ++ flake.lock | 101 +++++++++++++++++++++++++++++++++++++ flake.nix | 90 +++++++++++++++++++++++++++++++++ test/ArrayFire/UtilSpec.hs | 2 +- 5 files changed, 208 insertions(+), 5 deletions(-) create mode 100644 cabal.project create mode 100644 flake.lock create mode 100644 flake.nix diff --git a/arrayfire.cabal b/arrayfire.cabal index 22d2fd4..349bb61 100644 --- a/arrayfire.cabal +++ b/arrayfire.cabal @@ -21,6 +21,11 @@ flag disable-default-paths default: False manual: True +flag disable-build-tool-depends + description: When enabled, don't add build-tool-depends fields to the Cabal file. Needed for working inside @nix develop@. + default: False + manual: True + custom-setup setup-depends: base <5, @@ -75,8 +80,9 @@ library ArrayFire.Internal.Types ArrayFire.Internal.Util ArrayFire.Internal.Vision - build-tool-depends: - hsc2hs:hsc2hs + if !flag(disable-build-tool-depends) + build-tool-depends: + hsc2hs:hsc2hs extra-libraries: af c-sources: @@ -148,8 +154,9 @@ test-suite test QuickCheck, quickcheck-classes, vector - build-tool-depends: - hspec-discover:hspec-discover + if !flag(disable-build-tool-depends) + build-tool-depends: + hspec-discover:hspec-discover default-language: Haskell2010 other-modules: diff --git a/cabal.project b/cabal.project new file mode 100644 index 0000000..7f529ad --- /dev/null +++ b/cabal.project @@ -0,0 +1,5 @@ +ignore-project: False +write-ghc-environment-files: always +tests: True +test-options: "--color" +test-show-details: streaming diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..27da0e5 --- /dev/null +++ b/flake.lock @@ -0,0 +1,101 @@ +{ + "nodes": { + "arrayfire-nix": { + "inputs": { + "flake-utils": [ + "flake-utils" + ], + "nixpkgs": [ + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1692793973, + "narHash": "sha256-6dG41ile3T+6dfRazlcPBdKBarGesswsBpb40Lcf35U=", + "owner": "twesterhout", + "repo": "arrayfire-nix", + "rev": "4236770612b80a3f29adbd8d670f6cea2bc098ba", + "type": "github" + }, + "original": { + "owner": "twesterhout", + "repo": "arrayfire-nix", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1692792214, + "narHash": "sha256-voZDQOvqHsaReipVd3zTKSBwN7LZcUwi3/ThMxRZToU=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "1721b3e7c882f75f2301b00d48a2884af8c448ae", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nix-filter": { + "locked": { + "lastModified": 1687178632, + "narHash": "sha256-HS7YR5erss0JCaUijPeyg2XrisEb959FIct3n2TMGbE=", + "owner": "numtide", + "repo": "nix-filter", + "rev": "d90c75e8319d0dd9be67d933d8eb9d0894ec9174", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "nix-filter", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1692638711, + "narHash": "sha256-J0LgSFgJVGCC1+j5R2QndadWI1oumusg6hCtYAzLID4=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "91a22f76cd1716f9d0149e8a5c68424bb691de15", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "arrayfire-nix": "arrayfire-nix", + "flake-utils": "flake-utils", + "nix-filter": "nix-filter", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..d7f0af7 --- /dev/null +++ b/flake.nix @@ -0,0 +1,90 @@ +{ + description = "arrayfire/arrayfire-haskell: ArrayFire Haskell bindings"; + + inputs = { + nixpkgs.url = "github:nixos/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + nix-filter.url = "github:numtide/nix-filter"; + arrayfire-nix = { + url = "github:twesterhout/arrayfire-nix"; + inputs.flake-utils.follows = "flake-utils"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + }; + + outputs = inputs: + let + src = inputs.nix-filter.lib { + root = ./.; + include = [ + "cbits" + "exe" + "gen" + "include" + "src" + "test" + "arrayfire.cabal" + "README.md" + "CHANGELOG.md" + "LICENSE" + ]; + }; + + # An overlay that lets us test arrayfire-haskell with different GHC versions + arrayfire-haskell-overlay = self: super: { + haskell = super.haskell // { + packageOverrides = inputs.nixpkgs.lib.composeExtensions super.haskell.packageOverrides + (hself: hsuper: { + arrayfire = self.haskell.lib.appendConfigureFlags + (hself.callCabal2nix "arrayfire" src { af = self.arrayfire; }) + [ "-f disable-default-paths" ]; + }); + }; + }; + + devShell-for = pkgs: + let + ps = pkgs.haskellPackages; + in + ps.shellFor { + packages = ps: with ps; [ arrayfire ]; + withHoogle = true; + buildInputs = with pkgs; [ ocl-icd ]; + nativeBuildInputs = with pkgs; with ps; [ + # Building and testing + cabal-install + doctest + hsc2hs + hspec-discover + # Language servers + haskell-language-server + nil + # Formatters + nixpkgs-fmt + ]; + shellHook = '' + ''; + }; + + pkgs-for = system: import inputs.nixpkgs { + inherit system; + overlays = [ + inputs.arrayfire-nix.overlays.default + arrayfire-haskell-overlay + ]; + }; + in + { + packages = inputs.flake-utils.lib.eachDefaultSystemMap (system: + with (pkgs-for system); { + default = haskellPackages.arrayfire; + haskell = haskell.packages; + }); + + devShells = inputs.flake-utils.lib.eachDefaultSystemMap (system: { + default = devShell-for (pkgs-for system); + }); + + overlays.default = arrayfire-haskell-overlay; + }; +} diff --git a/test/ArrayFire/UtilSpec.hs b/test/ArrayFire/UtilSpec.hs index 5d95dd1..3539bc2 100644 --- a/test/ArrayFire/UtilSpec.hs +++ b/test/ArrayFire/UtilSpec.hs @@ -32,7 +32,7 @@ spec = it "Should get version" $ do (major, minor, patch) <- A.getVersion major `shouldBe` 3 - minor `shouldBe` 8 + minor `shouldSatisfy` (>= 8) patch `shouldSatisfy` (>= 0) it "Should get revision" $ do x <- A.getRevision From 90812e02bfeb160d8ced614b3db90aef8ba727da Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 25 Aug 2023 15:35:27 +0200 Subject: [PATCH 3/6] Fix joinMany (#56) Instead of allocating an array of pointers, joinMany was allocating memory for just one pointer. This was making ArrayFire read out of bounds and fail with various errors. This commit fixes this issue by adding a helper withManyForeignPtr function that acts like withForeignPtr (not unsafeWithForeignPtr!), but for a list of ForeignPtrs. --- cabal.project | 1 + src/ArrayFire/Data.hs | 25 +++++++++++++------------ test/ArrayFire/DataSpec.hs | 5 +++++ 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/cabal.project b/cabal.project index 7f529ad..e5c6ff0 100644 --- a/cabal.project +++ b/cabal.project @@ -1,3 +1,4 @@ +packages: . ignore-project: False write-ghc-environment-files: always tests: True diff --git a/src/ArrayFire/Data.hs b/src/ArrayFire/Data.hs index ab3b8e6..c4fab4b 100644 --- a/src/ArrayFire/Data.hs +++ b/src/ArrayFire/Data.hs @@ -30,7 +30,6 @@ module ArrayFire.Data where import Control.Exception -import Control.Monad import Data.Complex import Data.Int import Data.Proxy @@ -38,6 +37,7 @@ import Data.Word import Foreign.C.Types import Foreign.ForeignPtr import Foreign.Marshal hiding (void) +import Foreign.Ptr (Ptr) import Foreign.Storable import System.IO.Unsafe import Unsafe.Coerce @@ -357,20 +357,21 @@ joinMany :: Int -> [Array a] -> Array a -joinMany (fromIntegral -> n) arrays = unsafePerformIO . mask_ $ do - fptrs <- forM arrays $ \(Array fptr) -> pure fptr - newPtr <- - alloca $ \fPtrsPtr -> do - forM_ fptrs $ \fptr -> - withForeignPtr fptr (poke fPtrsPtr) - alloca $ \aPtr -> do - zeroOutArray aPtr - throwAFError =<< af_join_many aPtr n nArrays fPtrsPtr - peek aPtr +joinMany (fromIntegral -> n) (fmap (\(Array fp) -> fp) -> arrays) = unsafePerformIO . mask_ $ do + newPtr <- alloca $ \aPtr -> do + zeroOutArray aPtr + (throwAFError =<<) $ + withManyForeignPtr arrays $ \(fromIntegral -> nArrays) fPtrsPtr -> + af_join_many aPtr n nArrays fPtrsPtr + peek aPtr Array <$> newForeignPtr af_release_array_finalizer newPtr + +withManyForeignPtr :: [ForeignPtr a] -> (Int -> Ptr (Ptr a) -> IO b) -> IO b +withManyForeignPtr fptrs action = go [] fptrs where - nArrays = fromIntegral (length arrays) + go ptrs [] = withArrayLen (reverse ptrs) action + go ptrs (fptr:others) = withForeignPtr fptr $ \ptr -> go (ptr : ptrs) others -- | Tiles an Array according to specified dimensions -- diff --git a/test/ArrayFire/DataSpec.hs b/test/ArrayFire/DataSpec.hs index ab22e69..fcbd53f 100644 --- a/test/ArrayFire/DataSpec.hs +++ b/test/ArrayFire/DataSpec.hs @@ -32,3 +32,8 @@ spec = constant @(Complex Float) [1] (1.0 :+ 1.0) `shouldBe` constant @(Complex Float) [1] (1.0 :+ 1.0) + it "Should join Arrays along the specified dimension" $ do + join 0 (constant @Int [1, 3] 1) (constant @Int [1, 3] 2) `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2] + join 1 (constant @Int [1, 2] 1) (constant @Int [1, 2] 2) `shouldBe` mkArray @Int [1, 4] [1, 1, 2, 2] + joinMany 0 [constant @Int [1, 3] 1, constant @Int [1, 3] 2] `shouldBe` mkArray @Int [2, 3] [1, 2, 1, 2, 1, 2] + joinMany 1 [constant @Int [1, 2] 1, constant @Int [1, 1] 2, constant @Int [1, 3] 3] `shouldBe` mkArray @Int [1, 6] [1, 1, 2, 3, 3, 3] From 971bae05a4f544c1375e60fa402ed4df18165ff9 Mon Sep 17 00:00:00 2001 From: Tom Westerhout <14264576+twesterhout@users.noreply.github.com> Date: Fri, 25 Aug 2023 16:54:16 +0200 Subject: [PATCH 4/6] Fix (**) and use property tests (#57) The default implementation of (**) relied on log and incorrectly handled some inputs. The fix is making an explicit implementation using the pow function. To test the changes, tests for functions from the Floating typeclass are re-written using property tests. There are a few helper functions to make writing the actual properties easy. More tests can be converted to properties, but this is left for another PR. --- arrayfire.cabal | 1 + src/ArrayFire/Orphans.hs | 4 + test/ArrayFire/ArithSpec.hs | 165 +++++++++++++++++++++++++----------- 3 files changed, 122 insertions(+), 48 deletions(-) diff --git a/arrayfire.cabal b/arrayfire.cabal index 349bb61..d2ab64c 100644 --- a/arrayfire.cabal +++ b/arrayfire.cabal @@ -151,6 +151,7 @@ test-suite test base < 5, directory, hspec, + HUnit, QuickCheck, quickcheck-classes, vector diff --git a/src/ArrayFire/Orphans.hs b/src/ArrayFire/Orphans.hs index 690a89d..e4d1c8a 100644 --- a/src/ArrayFire/Orphans.hs +++ b/src/ArrayFire/Orphans.hs @@ -50,8 +50,12 @@ instance forall a . (Ord a, AFType a, Fractional a) => Floating (Array a) where pi = A.scalar @a 3.14159 exp = A.exp @a log = A.log @a + sqrt = A.sqrt @a + (**) = A.pow @a sin = A.sin @a cos = A.cos @a + tan = A.tan @a + tanh = A.tanh @a asin = A.asin @a acos = A.acos @a atan = A.atan @a diff --git a/test/ArrayFire/ArithSpec.hs b/test/ArrayFire/ArithSpec.hs index 7560e22..623726f 100644 --- a/test/ArrayFire/ArithSpec.hs +++ b/test/ArrayFire/ArithSpec.hs @@ -1,10 +1,75 @@ +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} + module ArrayFire.ArithSpec where -import ArrayFire hiding (acos) -import Prelude hiding (abs, sqrt, div, and, or, not, isNaN) -import Test.Hspec +import ArrayFire (AFType, Array, cast, clamp, getType, isInf, isZero, matrix, maxOf, minOf, mkArray, scalar, vector) +import qualified ArrayFire +import Control.Exception (throwIO) +import Control.Monad (unless, when) import Foreign.C +import GHC.Exts (IsList (..)) +import GHC.Stack +import Test.HUnit.Lang (FailureReason (..), HUnitFailure (..)) +import Test.Hspec +import Test.Hspec.QuickCheck +import Prelude hiding (div) + +compareWith :: (HasCallStack, Show a) => (a -> a -> Bool) -> a -> a -> Expectation +compareWith comparator result expected = + unless (comparator result expected) $ do + throwIO (HUnitFailure location $ ExpectedButGot Nothing expectedMsg actualMsg) + where + expectedMsg = show expected + actualMsg = show result + location = case reverse (toList callStack) of + (_, loc) : _ -> Just loc + [] -> Nothing + +class (Num a) => HasEpsilon a where + eps :: a + +instance HasEpsilon Float where + eps = 1.1920929e-7 + +instance HasEpsilon Double where + eps = 2.220446049250313e-16 + +approxWith :: (Ord a, Num a) => a -> a -> a -> a -> Bool +approxWith rtol atol a b = abs (a - b) <= Prelude.max atol (rtol * Prelude.max (abs a) (abs b)) + +approx :: (Ord a, HasEpsilon a) => a -> a -> Bool +approx a b = approxWith (2 * eps * Prelude.max (abs a) (abs b)) (4 * eps) a b + +shouldBeApprox :: (Ord a, HasEpsilon a, Show a) => a -> a -> Expectation +shouldBeApprox = compareWith approx + +evalf :: (AFType a) => Array a -> a +evalf = ArrayFire.getScalar + +shouldMatchBuiltin :: + (AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) => + (Array a -> Array a) -> + (a -> a) -> + a -> + Expectation +shouldMatchBuiltin f f' x + | isInfinite y && isInfinite y' = pure () + | Prelude.isNaN y && Prelude.isNaN y' = pure () + | otherwise = y `shouldBeApprox` y' + where + y = evalf (f (scalar x)) + y' = f' x + +shouldMatchBuiltin2 :: + (AFType a, Ord a, RealFloat a, HasEpsilon a, Show a) => + (Array a -> Array a -> Array a) -> + (a -> a -> a) -> + a -> + a -> + Expectation +shouldMatchBuiltin2 f f' a = shouldMatchBuiltin (f (scalar a)) (f' a) spec :: Spec spec = @@ -12,7 +77,7 @@ spec = it "Should negate scalar value" $ do negate (scalar @Int 1) `shouldBe` (-1) it "Should negate a vector" $ do - negate (vector @Int 3 [2,2,2]) `shouldBe` vector @Int 3 [-2,-2,-2] + negate (vector @Int 3 [2, 2, 2]) `shouldBe` vector @Int 3 [-2, -2, -2] it "Should add two scalar arrays" $ do scalar @Int 1 + 2 `shouldBe` 3 it "Should add two scalar bool arrays" $ do @@ -20,80 +85,84 @@ spec = it "Should subtract two scalar arrays" $ do scalar @Int 4 - 2 `shouldBe` 2 it "Should multiply two scalar arrays" $ do - scalar @Double 4 `mul` 2 `shouldBe` 8 + scalar @Double 4 `ArrayFire.mul` 2 `shouldBe` 8 it "Should divide two scalar arrays" $ do - div @Double 8 2 `shouldBe` 4 + ArrayFire.div @Double 8 2 `shouldBe` 4 it "Should add two matrices" $ do - matrix @Int (2,2) [[1,1],[1,1]] + matrix @Int (2,2) [[1,1],[1,1]] - `shouldBe` - matrix @Int (2,2) [[2,2],[2,2]] - -- Exact comparisons of Double don't make sense here, so we just check that the result is - -- accurate up to some epsilon. - it "Should take cubed root" $ do - allTrueAll ((abs (3 - cbrt @Double 27)) `lt` 1.0e-14) `shouldBe` (1, 0) - it "Should take square root" $ do - allTrueAll ((abs (2 - sqrt @Double 4)) `lt` 1.0e-14) `shouldBe` (1, 0) + matrix @Int (2, 2) [[1, 1], [1, 1]] + matrix @Int (2, 2) [[1, 1], [1, 1]] + `shouldBe` matrix @Int (2, 2) [[2, 2], [2, 2]] + prop "Should take cubed root" $ \(x :: Double) -> + evalf (ArrayFire.cbrt (scalar (x * x * x))) `shouldBeApprox` x it "Should lte Array" $ do - 2 `le` (3 :: Array Double) `shouldBe` 1 + 2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1 it "Should gte Array" $ do - 2 `ge` (3 :: Array Double) `shouldBe` 0 + 2 `ArrayFire.ge` (3 :: Array Double) `shouldBe` 0 it "Should gt Array" $ do - 2 `gt` (3 :: Array Double) `shouldBe` 0 + 2 `ArrayFire.gt` (3 :: Array Double) `shouldBe` 0 it "Should lt Array" $ do - 2 `le` (3 :: Array Double) `shouldBe` 1 + 2 `ArrayFire.le` (3 :: Array Double) `shouldBe` 1 it "Should eq Array" $ do 3 == (3 :: Array Double) `shouldBe` True it "Should and Array" $ do - (mkArray @CBool [1] [0] `and` mkArray [1] [1]) - `shouldBe` mkArray [1] [0] + (mkArray @CBool [1] [0] `ArrayFire.and` mkArray [1] [1]) + `shouldBe` mkArray [1] [0] it "Should and Array" $ do - (mkArray @CBool [2] [0,0] `and` mkArray [2] [1,0]) - `shouldBe` mkArray [2] [0, 0] + (mkArray @CBool [2] [0, 0] `ArrayFire.and` mkArray [2] [1, 0]) + `shouldBe` mkArray [2] [0, 0] it "Should or Array" $ do - (mkArray @CBool [2] [0,0] `or` mkArray [2] [1,0]) - `shouldBe` mkArray [2] [1, 0] + (mkArray @CBool [2] [0, 0] `ArrayFire.or` mkArray [2] [1, 0]) + `shouldBe` mkArray [2] [1, 0] it "Should not Array" $ do - not (mkArray @CBool [2] [1,0]) `shouldBe` mkArray [2] [0,1] + ArrayFire.not (mkArray @CBool [2] [1, 0]) `shouldBe` mkArray [2] [0, 1] it "Should bitwise and array" $ do - bitAnd (scalar @Int 1) (scalar @Int 0) - `shouldBe` - 0 + ArrayFire.bitAnd (scalar @Int 1) (scalar @Int 0) + `shouldBe` 0 it "Should bitwise or array" $ do - bitOr (scalar @Int 1) (scalar @Int 0) - `shouldBe` - 1 + ArrayFire.bitOr (scalar @Int 1) (scalar @Int 0) + `shouldBe` 1 it "Should bitwise xor array" $ do - bitXor (scalar @Int 1) (scalar @Int 1) - `shouldBe` - 0 + ArrayFire.bitXor (scalar @Int 1) (scalar @Int 1) + `shouldBe` 0 it "Should bitwise shift left an array" $ do - bitShiftL (scalar @Int 1) (scalar @Int 3) - `shouldBe` - 8 + ArrayFire.bitShiftL (scalar @Int 1) (scalar @Int 3) + `shouldBe` 8 it "Should cast an array" $ do getType (cast (scalar @Int 1) :: Array Double) - `shouldBe` - F64 + `shouldBe` ArrayFire.F64 it "Should find the minimum of two arrays" $ do minOf (scalar @Int 1) (scalar @Int 0) - `shouldBe` - 0 + `shouldBe` 0 it "Should find the max of two arrays" $ do maxOf (scalar @Int 1) (scalar @Int 0) - `shouldBe` - 1 + `shouldBe` 1 it "Should take the clamp of 3 arrays" $ do clamp (scalar @Int 2) (scalar @Int 1) (scalar @Int 3) - `shouldBe` - 2 + `shouldBe` 2 it "Should check if an array has positive or negative infinities" $ do isInf (scalar @Double (1 / 0)) `shouldBe` scalar @Double 1 isInf (scalar @Double 10) `shouldBe` scalar @Double 0 it "Should check if an array has any NaN values" $ do - isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1 - isNaN (scalar @Double 10) `shouldBe` scalar @Double 0 + ArrayFire.isNaN (scalar @Double (acos 2)) `shouldBe` scalar @Double 1 + ArrayFire.isNaN (scalar @Double 10) `shouldBe` scalar @Double 0 it "Should check if an array has any Zero values" $ do isZero (scalar @Double (acos 2)) `shouldBe` scalar @Double 0 isZero (scalar @Double 0) `shouldBe` scalar @Double 1 isZero (scalar @Double 1) `shouldBe` scalar @Double 0 + + prop "Floating @Float (exp)" $ \(x :: Float) -> exp `shouldMatchBuiltin` exp $ x + prop "Floating @Float (log)" $ \(x :: Float) -> log `shouldMatchBuiltin` log $ x + prop "Floating @Float (sqrt)" $ \(x :: Float) -> sqrt `shouldMatchBuiltin` sqrt $ x + prop "Floating @Float (**)" $ \(x :: Float) (y :: Float) -> ((**) `shouldMatchBuiltin2` (**)) x y + prop "Floating @Float (sin)" $ \(x :: Float) -> sin `shouldMatchBuiltin` sin $ x + prop "Floating @Float (cos)" $ \(x :: Float) -> cos `shouldMatchBuiltin` cos $ x + prop "Floating @Float (tan)" $ \(x :: Float) -> tan `shouldMatchBuiltin` tan $ x + prop "Floating @Float (asin)" $ \(x :: Float) -> asin `shouldMatchBuiltin` asin $ x + prop "Floating @Float (acos)" $ \(x :: Float) -> acos `shouldMatchBuiltin` acos $ x + prop "Floating @Float (atan)" $ \(x :: Float) -> atan `shouldMatchBuiltin` atan $ x + prop "Floating @Float (sinh)" $ \(x :: Float) -> sinh `shouldMatchBuiltin` sinh $ x + prop "Floating @Float (cosh)" $ \(x :: Float) -> cosh `shouldMatchBuiltin` cosh $ x + prop "Floating @Float (tanh)" $ \(x :: Float) -> tanh `shouldMatchBuiltin` tanh $ x + prop "Floating @Float (asinh)" $ \(x :: Float) -> asinh `shouldMatchBuiltin` asinh $ x + prop "Floating @Float (acosh)" $ \(x :: Float) -> acosh `shouldMatchBuiltin` acosh $ x + prop "Floating @Float (atanh)" $ \(x :: Float) -> atanh `shouldMatchBuiltin` atanh $ x From f268fc9f0ab0284f79520d585f1b692f744b4afb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Justus=20Sagem=C3=BCller?= Date: Sat, 4 May 2024 23:46:30 +0200 Subject: [PATCH 5/6] Relax some tests that are satisfied only up to some floating point error. (#59) --- arrayfire.cabal | 5 ++++- test/ArrayFire/LAPACKSpec.hs | 5 +++-- test/ArrayFire/StatisticsSpec.hs | 3 ++- test/Test/Hspec/ApproxExpect.hs | 19 +++++++++++++++++++ 4 files changed, 28 insertions(+), 4 deletions(-) create mode 100644 test/Test/Hspec/ApproxExpect.hs diff --git a/arrayfire.cabal b/arrayfire.cabal index d2ab64c..b347c8b 100644 --- a/arrayfire.cabal +++ b/arrayfire.cabal @@ -144,6 +144,8 @@ test-suite test exitcode-stdio-1.0 main-is: Main.hs + other-modules: + Test.Hspec.ApproxExpect hs-source-dirs: test build-depends: @@ -154,7 +156,8 @@ test-suite test HUnit, QuickCheck, quickcheck-classes, - vector + vector, + call-stack >=0.4 && <0.5 if !flag(disable-build-tool-depends) build-tool-depends: hspec-discover:hspec-discover diff --git a/test/ArrayFire/LAPACKSpec.hs b/test/ArrayFire/LAPACKSpec.hs index 2c9f554..5c225c7 100644 --- a/test/ArrayFire/LAPACKSpec.hs +++ b/test/ArrayFire/LAPACKSpec.hs @@ -4,6 +4,7 @@ module ArrayFire.LAPACKSpec where import qualified ArrayFire as A import Prelude import Test.Hspec +import Test.Hspec.ApproxExpect spec :: Spec spec = @@ -33,9 +34,9 @@ spec = it "Should get determinant of Double" $ do let eles = [[3 A.:+ 1, 8 A.:+ 1], [4 A.:+ 1, 6 A.:+ 1]] (x,y) = A.det (A.matrix @(A.Complex Double) (2,2) eles) - x `shouldBe` (-14) + x `shouldBeApprox` (-14) let (x,y) = A.det $ A.matrix @Double (2,2) [[3,8],[4,6]] - x `shouldBe` (-14) + x `shouldBeApprox` (-14) -- it "Should calculate inverse" $ do -- let x = flip A.inverse A.None $ A.matrix @Double (2,2) [[4.0,7.0],[2.0,6.0]] -- x `shouldBe` A.matrix (2,2) [[0.6,-0.7],[-0.2,0.4]] diff --git a/test/ArrayFire/StatisticsSpec.hs b/test/ArrayFire/StatisticsSpec.hs index 392f617..c8c6314 100644 --- a/test/ArrayFire/StatisticsSpec.hs +++ b/test/ArrayFire/StatisticsSpec.hs @@ -5,6 +5,7 @@ import ArrayFire hiding (not) import Data.Complex import Test.Hspec +import Test.Hspec.ApproxExpect spec :: Spec spec = @@ -15,7 +16,7 @@ spec = 5.5 it "Should find the weighted-mean" $ do meanWeighted (vector @Double 10 [1..]) (vector @Double 10 [1..]) 0 - `shouldBe` + `shouldBeApprox` 7.0 it "Should find the variance" $ do var (vector @Double 8 [1..8]) False 0 diff --git a/test/Test/Hspec/ApproxExpect.hs b/test/Test/Hspec/ApproxExpect.hs new file mode 100644 index 0000000..3e9d66b --- /dev/null +++ b/test/Test/Hspec/ApproxExpect.hs @@ -0,0 +1,19 @@ +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE ScopedTypeVariables #-} +module Test.Hspec.ApproxExpect where + +import Data.CallStack (HasCallStack) + +import Test.Hspec (shouldSatisfy, Expectation) + +infix 1 `shouldBeApprox` + +shouldBeApprox :: (HasCallStack, Show a, Fractional a, Eq a) + => a -> a -> Expectation +shouldBeApprox actual tgt + -- This is a hackish way of checking, without requiring a specific + -- type or an 'Ord' instance, whether two floating-point values + -- are only some epsilons apart: when the difference is small enough + -- so scaling it down some more makes it a no-op for addition. + = actual `shouldSatisfy` \x -> (x-tgt) * 1e-4 + tgt == tgt + From fa74dd8c5664656c860694dae5532c01502bcc9b Mon Sep 17 00:00:00 2001 From: dmjio Date: Fri, 13 Dec 2024 18:02:18 -0600 Subject: [PATCH 6/6] update copyright year, maintainer --- arrayfire.cabal | 4 ++-- src/ArrayFire.hs | 2 +- src/ArrayFire/Algorithm.hs | 2 +- src/ArrayFire/Arith.hs | 2 +- src/ArrayFire/Array.hs | 2 +- src/ArrayFire/BLAS.hs | 2 +- src/ArrayFire/Backend.hs | 2 +- src/ArrayFire/Data.hs | 2 +- src/ArrayFire/Device.hs | 2 +- src/ArrayFire/Exception.hs | 2 +- src/ArrayFire/FFI.hs | 2 +- src/ArrayFire/Features.hs | 2 +- src/ArrayFire/Graphics.hs | 2 +- src/ArrayFire/Image.hs | 2 +- src/ArrayFire/Index.hs | 2 +- src/ArrayFire/LAPACK.hs | 2 +- src/ArrayFire/Orphans.hs | 2 +- src/ArrayFire/Random.hs | 2 +- src/ArrayFire/Signal.hs | 2 +- src/ArrayFire/Sparse.hs | 2 +- src/ArrayFire/Statistics.hs | 2 +- src/ArrayFire/Types.hs | 2 +- src/ArrayFire/Util.hs | 2 +- src/ArrayFire/Vision.hs | 2 +- 24 files changed, 25 insertions(+), 25 deletions(-) diff --git a/arrayfire.cabal b/arrayfire.cabal index b347c8b..df41c2f 100644 --- a/arrayfire.cabal +++ b/arrayfire.cabal @@ -6,8 +6,8 @@ homepage: https://github.com/arrayfire/arrayfire-haskell license: BSD-3-Clause license-file: LICENSE author: David Johnson -maintainer: djohnson.m@gmail.com -copyright: David Johnson (c) 2018-2023 +maintainer: code@dmj.io +copyright: David Johnson (c) 2018-2025 category: Math build-type: Custom extra-source-files: CHANGELOG.md diff --git a/src/ArrayFire.hs b/src/ArrayFire.hs index a711ff0..f5cf814 100644 --- a/src/ArrayFire.hs +++ b/src/ArrayFire.hs @@ -3,7 +3,7 @@ -- Module : ArrayFire -- Copyright : David Johnson (c) 2019-2020 -- License : BSD3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Algorithm.hs b/src/ArrayFire/Algorithm.hs index 67e57ef..b7fccba 100644 --- a/src/ArrayFire/Algorithm.hs +++ b/src/ArrayFire/Algorithm.hs @@ -6,7 +6,7 @@ -- Module : ArrayFire.Algorithm -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Arith.hs b/src/ArrayFire/Arith.hs index f74548a..ec2cc25 100644 --- a/src/ArrayFire/Arith.hs +++ b/src/ArrayFire/Arith.hs @@ -7,7 +7,7 @@ -- Module : ArrayFire.Arith -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Array.hs b/src/ArrayFire/Array.hs index 7c36359..b0abc01 100644 --- a/src/ArrayFire/Array.hs +++ b/src/ArrayFire/Array.hs @@ -12,7 +12,7 @@ -- Module : ArrayFire.Array -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/BLAS.hs b/src/ArrayFire/BLAS.hs index 0f734a1..321980a 100644 --- a/src/ArrayFire/BLAS.hs +++ b/src/ArrayFire/BLAS.hs @@ -4,7 +4,7 @@ -- Module : ArrayFire.BLAS -- Copyright : David Johnson (c) 2019-2020 -- License : BSD3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Backend.hs b/src/ArrayFire/Backend.hs index 6379788..7b9b14f 100644 --- a/src/ArrayFire/Backend.hs +++ b/src/ArrayFire/Backend.hs @@ -3,7 +3,7 @@ -- Module : ArrayFire.Backend -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Data.hs b/src/ArrayFire/Data.hs index c4fab4b..8bcfe54 100644 --- a/src/ArrayFire/Data.hs +++ b/src/ArrayFire/Data.hs @@ -12,7 +12,7 @@ -- Module : ArrayFire.Data -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Device.hs b/src/ArrayFire/Device.hs index c9c4482..29a9e63 100644 --- a/src/ArrayFire/Device.hs +++ b/src/ArrayFire/Device.hs @@ -4,7 +4,7 @@ -- Module : ArrayFire.Device -- Copyright : David Johnson (c) 2019-2020 -- License : BSD3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Exception.hs b/src/ArrayFire/Exception.hs index d39007d..bc8a12d 100644 --- a/src/ArrayFire/Exception.hs +++ b/src/ArrayFire/Exception.hs @@ -5,7 +5,7 @@ -- Module : ArrayFire.Exception -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/FFI.hs b/src/ArrayFire/FFI.hs index e56d1f9..483bc0f 100644 --- a/src/ArrayFire/FFI.hs +++ b/src/ArrayFire/FFI.hs @@ -6,7 +6,7 @@ -- Module : ArrayFire.FFI -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Features.hs b/src/ArrayFire/Features.hs index d286e0b..a84f58d 100644 --- a/src/ArrayFire/Features.hs +++ b/src/ArrayFire/Features.hs @@ -4,7 +4,7 @@ -- Module : ArrayFire.Features -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Graphics.hs b/src/ArrayFire/Graphics.hs index 35b71ae..e657625 100644 --- a/src/ArrayFire/Graphics.hs +++ b/src/ArrayFire/Graphics.hs @@ -4,7 +4,7 @@ -- Module : ArrayFire.Graphics -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Image.hs b/src/ArrayFire/Image.hs index e8b55fb..9ae11d8 100644 --- a/src/ArrayFire/Image.hs +++ b/src/ArrayFire/Image.hs @@ -6,7 +6,7 @@ -- Module : ArrayFire.Image -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Index.hs b/src/ArrayFire/Index.hs index c31eaee..c0a34f5 100644 --- a/src/ArrayFire/Index.hs +++ b/src/ArrayFire/Index.hs @@ -3,7 +3,7 @@ -- Module : ArrayFire.Index -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/LAPACK.hs b/src/ArrayFire/LAPACK.hs index 70b7966..d30e98f 100644 --- a/src/ArrayFire/LAPACK.hs +++ b/src/ArrayFire/LAPACK.hs @@ -4,7 +4,7 @@ -- Module : ArrayFire.LAPACK -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Orphans.hs b/src/ArrayFire/Orphans.hs index e4d1c8a..fc09c6b 100644 --- a/src/ArrayFire/Orphans.hs +++ b/src/ArrayFire/Orphans.hs @@ -8,7 +8,7 @@ -- Module : ArrayFire.Orphans -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Random.hs b/src/ArrayFire/Random.hs index a89ebcd..0f0c31f 100644 --- a/src/ArrayFire/Random.hs +++ b/src/ArrayFire/Random.hs @@ -13,7 +13,7 @@ -- Module : ArrayFire.Random -- Copyright : David Johnson (c) 2019-2020 -- License : BSD3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Signal.hs b/src/ArrayFire/Signal.hs index 4a2f7aa..4ddae65 100644 --- a/src/ArrayFire/Signal.hs +++ b/src/ArrayFire/Signal.hs @@ -4,7 +4,7 @@ -- Module : ArrayFire.Signal -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Sparse.hs b/src/ArrayFire/Sparse.hs index f1969a5..1b35026 100644 --- a/src/ArrayFire/Sparse.hs +++ b/src/ArrayFire/Sparse.hs @@ -4,7 +4,7 @@ -- Module : ArrayFire.Sparse -- Copyright : David Johnson (c) 2019-2020 -- License : BSD3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Statistics.hs b/src/ArrayFire/Statistics.hs index 96095b1..8a3db79 100644 --- a/src/ArrayFire/Statistics.hs +++ b/src/ArrayFire/Statistics.hs @@ -5,7 +5,7 @@ -- Module : ArrayFire.Statistics -- Copyright : David Johnson (c) 2019-2020 -- License : BSD3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Types.hs b/src/ArrayFire/Types.hs index afd9988..e63f6c9 100644 --- a/src/ArrayFire/Types.hs +++ b/src/ArrayFire/Types.hs @@ -16,7 +16,7 @@ -- Module : ArrayFire.Types -- Copyright : David Johnson (c) 2019-2020 -- License : BSD3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Util.hs b/src/ArrayFire/Util.hs index 2175627..d8ba69b 100644 --- a/src/ArrayFire/Util.hs +++ b/src/ArrayFire/Util.hs @@ -6,7 +6,7 @@ -- Module : ArrayFire.Util -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC -- diff --git a/src/ArrayFire/Vision.hs b/src/ArrayFire/Vision.hs index 2587477..71f3bd7 100644 --- a/src/ArrayFire/Vision.hs +++ b/src/ArrayFire/Vision.hs @@ -6,7 +6,7 @@ -- Module : ArrayFire.Vision -- Copyright : David Johnson (c) 2019-2020 -- License : BSD 3 --- Maintainer : David Johnson +-- Maintainer : David Johnson -- Stability : Experimental -- Portability : GHC --