{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE CPP #-}
module Data.SecureMem
( SecureMem
, secureMemGetSize
, secureMemCopy
, ToSecureMem(..)
, allocateSecureMem
, createSecureMem
, unsafeCreateSecureMem
, finalizeSecureMem
, withSecureMemPtr
, withSecureMemPtrSz
, withSecureMemCopy
, secureMemFromByteString
, secureMemFromByteable
) where
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr
import Data.Word (Word8)
#if MIN_VERSION_base(4,9,0)
import Data.Semigroup
import Data.Foldable (toList)
#else
import Data.Monoid
#endif
import Control.Applicative
import Data.Byteable
import Data.ByteString (ByteString)
import Data.ByteArray (ScrubbedBytes)
import qualified Data.ByteArray as B
import qualified Data.Memory.PtrMethods as B (memSet)
import qualified Data.ByteString.Internal as BS
#if MIN_VERSION_base(4,4,0)
import System.IO.Unsafe (unsafeDupablePerformIO)
#else
import System.IO.Unsafe (unsafePerformIO)
#endif
-- | Run an 'IO' action and expose its result as a pure value.
--
-- On base >= 4.4 this is 'unsafeDupablePerformIO', which (per the base
-- documentation) does not guard against the action being run more than
-- once by competing threads; only pass actions for which that is harmless.
-- Older base versions fall back to 'unsafePerformIO'.
pureIO :: IO a -> a
#if MIN_VERSION_base(4,4,0)
pureIO action = unsafeDupablePerformIO action
#else
pureIO action = unsafePerformIO action
#endif
-- | A buffer of sensitive bytes. It is a thin wrapper around a
-- 'ScrubbedBytes' value from the memory package; every operation in this
-- module unwraps the constructor and delegates to that buffer.
newtype SecureMem = SecureMem ScrubbedBytes
-- | Size in bytes of the buffer held by a 'SecureMem'.
secureMemGetSize :: SecureMem -> Int
secureMemGetSize mem = case mem of
  SecureMem bytes -> B.length bytes
-- | Compare two 'SecureMem' values byte-for-byte by delegating to the 'Eq'
-- instance of the underlying 'ScrubbedBytes'.
--
-- NOTE(review): whether this comparison runs in constant time depends on
-- that 'ScrubbedBytes' instance in the memory package; confirm before
-- relying on it for secret comparison.
secureMemEq :: SecureMem -> SecureMem -> Bool
secureMemEq (SecureMem lhs) (SecureMem rhs) = (==) lhs rhs
-- | Join two 'SecureMem' values into a new one by combining their
-- underlying 'ScrubbedBytes' with 'mappend'.
secureMemAppend :: SecureMem -> SecureMem -> SecureMem
secureMemAppend (SecureMem front) (SecureMem back) =
  SecureMem (mappend front back)
-- | Concatenate a list of 'SecureMem' values into a single new buffer.
--
-- 'B.concat' computes the total length up front, allocates the destination
-- once and copies each chunk into it exactly once. Going through 'mconcat'
-- instead (whose class default on base >= 4.11 is @foldr mappend mempty@)
-- would allocate and recopy an intermediate buffer for every element.
--
-- An empty input list yields an empty 'SecureMem'.
secureMemConcat :: [SecureMem] -> SecureMem
secureMemConcat parts = SecureMem (B.concat (map unwrap parts))
  where
    unwrap :: SecureMem -> ScrubbedBytes
    unwrap (SecureMem sb) = sb
-- | Duplicate a 'SecureMem' into a freshly allocated buffer.
--
-- The initializer passed to 'B.copy' does nothing, so the new buffer holds
-- exactly the copied bytes.
secureMemCopy :: SecureMem -> IO SecureMem
secureMemCopy (SecureMem source) = do
  duplicate <- B.copy source (\_ -> return ())
  return (SecureMem duplicate)
-- | Duplicate a 'SecureMem', then let @f@ modify the copy through a raw
-- pointer (via 'B.copy') before the copy is returned.
withSecureMemCopy :: SecureMem -> (Ptr Word8 -> IO ()) -> IO SecureMem
withSecureMemCopy (SecureMem source) f = SecureMem <$> B.copy source f
-- | Every 'SecureMem' renders as the same fixed placeholder, so 'show'
-- never reveals the buffer's contents.
instance Show SecureMem where
  show = const "<secure-mem>"
-- | 'Byteable' view of a 'SecureMem', delegating to this module's own
-- accessors. Note that 'toBytes' copies the secret into an ordinary
-- 'ByteString' (see 'secureMemToByteString').
instance Byteable SecureMem where
  toBytes m        = secureMemToByteString m
  byteableLength m = secureMemGetSize m
  withBytePtr m    = withSecureMemPtr m
-- | Content equality, implemented by 'secureMemEq'.
instance Eq SecureMem where
  a == b = secureMemEq a b
#if MIN_VERSION_base(4,9,0)
-- | '<>' concatenates the two buffers via 'secureMemAppend'. 'sconcat'
-- converts the non-empty list to a plain list and hands it to
-- 'secureMemConcat'.
instance Semigroup SecureMem where
  x <> y     = secureMemAppend x y
  sconcat xs = secureMemConcat (toList xs)
#endif
-- | The identity element is a zero-length buffer.
--
-- 'mconcat' is defined on every base version, not only before 4.11. This
-- makes folding a list go through 'secureMemConcat' instead of the class
-- default @foldr mappend mempty@, which would allocate an intermediate
-- buffer per element. 'mappend' is only spelled out on base < 4.11; newer
-- base derives it from the 'Semigroup' instance.
instance Monoid SecureMem where
  mempty = unsafeCreateSecureMem 0 (\_ -> return ())
#if !(MIN_VERSION_base(4,11,0))
  mappend = secureMemAppend
#endif
  mconcat = secureMemConcat
-- | Types whose values can be converted into a 'SecureMem'.
class ToSecureMem a where
  -- | Produce a 'SecureMem' holding the value's bytes.
  toSecureMem :: a -> SecureMem
-- | A 'SecureMem' is already in the target form and is returned unchanged
-- (no copy is made).
instance ToSecureMem SecureMem where
  toSecureMem = id
-- | Copies the 'ByteString' contents into a new buffer via
-- 'secureMemFromByteString'.
instance ToSecureMem ByteString where
  toSecureMem = secureMemFromByteString
-- | Allocate a 'SecureMem' of @sz@ bytes.
--
-- The initializer given to 'B.create' does nothing, so the bytes are
-- presumably left uninitialized (NOTE(review): confirm against
-- 'ScrubbedBytes' allocation in the memory package). Callers should
-- overwrite the contents before reading them.
allocateSecureMem :: Int -> IO SecureMem
allocateSecureMem sz = fmap SecureMem (B.create sz noInit)
  where
    noInit _ = return ()
-- | Allocate a 'SecureMem' of @sz@ bytes and let @f@ fill it through a raw
-- pointer to the new buffer.
createSecureMem :: Int -> (Ptr Word8 -> IO ()) -> IO SecureMem
createSecureMem sz f = SecureMem <$> B.create sz f
-- | Pure counterpart of 'createSecureMem', executed through 'pureIO'.
--
-- The initializer should be deterministic and free of observable side
-- effects, because 'pureIO' may run it in an unsafe, possibly duplicated
-- way. NOINLINE keeps GHC from inlining the hidden IO into call sites.
unsafeCreateSecureMem :: Int -> (Ptr Word8 -> IO ()) -> SecureMem
unsafeCreateSecureMem size initialiser = pureIO (createSecureMem size initialiser)
{-# NOINLINE unsafeCreateSecureMem #-}
-- | Run @action@ with a raw pointer to the start of the buffer, using
-- 'B.withByteArray'.
--
-- Following the usual @with*@ convention, the pointer should only be used
-- inside the action.
withSecureMemPtr :: SecureMem -> (Ptr Word8 -> IO b) -> IO b
withSecureMemPtr (SecureMem bytes) action = B.withByteArray bytes action
-- | Like 'withSecureMemPtr', but also passes the buffer length in bytes to
-- the action as its first argument.
withSecureMemPtrSz :: SecureMem -> (Int -> Ptr Word8 -> IO b) -> IO b
withSecureMemPtrSz (SecureMem bytes) action =
  B.withByteArray bytes (action size)
  where
    size = B.length bytes
-- | Overwrite every byte of the buffer with zero, in place.
--
-- This only wipes the contents; it does not release the buffer, which
-- remains valid (and now all zeros) afterwards.
finalizeSecureMem :: SecureMem -> IO ()
finalizeSecureMem (SecureMem sb) =
  B.withByteArray sb wipe
  where
    byteCount = B.length sb
    wipe ptr  = B.memSet ptr 0 byteCount
-- | Copy the secret bytes into a regular 'ByteString'.
--
-- The result is an ordinary 'ByteString' built with 'BS.unsafeCreate', so
-- it is outside the control of 'SecureMem'; treat it with care.
secureMemToByteString :: SecureMem -> ByteString
secureMemToByteString sm = BS.unsafeCreate sz copyOut
  where
    !sz = secureMemGetSize sm
    -- 'fromIntegral' adapts to whatever length type 'BS.memcpy' expects in
    -- the bytestring version in use.
    copyOut dst = withSecureMemPtr sm $ \src ->
      BS.memcpy dst src (fromIntegral sz)
-- | Copy the contents of a 'ByteString' into a newly allocated 'SecureMem'.
--
-- The copy starts at the ByteString's offset into its backing 'ForeignPtr'
-- and covers exactly its length. NOINLINE keeps the 'pureIO' allocation
-- from being inlined or floated at call sites.
secureMemFromByteString :: ByteString -> SecureMem
secureMemFromByteString b = pureIO $ do
    target <- allocateSecureMem len
    withSecureMemPtr target $ \dst ->
      withForeignPtr fp $ \base ->
        BS.memcpy dst (base `plusPtr` off) (fromIntegral len)
    return target
  where
    (fp, off, !len) = BS.toForeignPtr b
{-# NOINLINE secureMemFromByteString #-}
-- | Copy the bytes of any 'Byteable' value into a newly allocated
-- 'SecureMem'. The amount copied is the value's 'byteableLength'.
--
-- NOINLINE keeps the 'pureIO' allocation from being inlined or floated at
-- call sites.
secureMemFromByteable :: Byteable b => b -> SecureMem
secureMemFromByteable bs = pureIO $ do
    target <- allocateSecureMem len
    withSecureMemPtr target copyInto
    return target
  where
    len = byteableLength bs
    copyInto dst = withBytePtr bs $ \src ->
      BS.memcpy dst src (fromIntegral len)
{-# NOINLINE secureMemFromByteable #-}