-- |
-- Module      : Network.TLS.Sending
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- the Sending module contains calls related to marshalling packets according
-- to the TLS state
--
module Network.TLS.Sending (
    encodePacket
  , encodeRecordM
  , updateHandshake
  ) where

import Network.TLS.Cap
import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Parameters
import Network.TLS.Record
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Types (Role(..))
import Network.TLS.Util

import Control.Concurrent.MVar
import Control.Monad.State.Strict
import qualified Data.ByteString as B
import Data.IORef

-- | Marshal a packet into its on-the-wire bytes using the current transmit
-- state, updating that state as a side effect.
--
-- The packet is split into payloads by 'packetToFragments' (handshake
-- messages are chunked to 'maxFragmentLen' bytes; the other packet types
-- yield a single payload). Each payload is wrapped in a plaintext 'Record'
-- carrying the packet's protocol type and the version chosen by
-- 'decideRecordVersion', encoded with 'encodeRecord', and the results are
-- concatenated. An encoding failure is reported as 'Left' (the exact
-- short-circuit behaviour is that of 'forEitherM').
--
-- For 'ChangeCipherSpec' the transmit state is switched to the pending one
-- /after/ encoding, so the ChangeCipherSpec record itself is still encoded
-- under the previous state. Note that the switch happens even when encoding
-- returned 'Left'.
encodePacket :: Context -> Packet -> IO (Either TLSError ByteString)
encodePacket ctx pkt = do
    (ver, _) <- decideRecordVersion ctx
    let pt = packetType pkt
        mkRecord bs = Record pt ver (fragmentPlaintext bs)
    records <- map mkRecord <$> packetToFragments ctx maxFragmentLen pkt
    bs <- fmap B.concat <$> forEitherM records (encodeRecord ctx)
    when (pkt == ChangeCipherSpec) $ switchTxEncryption ctx
    return bs
  where
    -- 2^14 bytes: the maximum TLSPlaintext fragment length (RFC 5246, 6.2.1).
    maxFragmentLen :: Int
    maxFragmentLen = 16384

-- | Decompose a packet into the raw payloads of its records.
--
-- Handshake messages are encoded one by one with 'updateHandshake' (which
-- may also feed them into the handshake state), concatenated, and split
-- into @len@-sized chunks with 'getChunks'. Alerts, ChangeCipherSpec and
-- application data always produce exactly one payload. AppData packets are
-- not fragmented here but by callers of sendPacket, so that the
-- empty-packet countermeasure may be applied to each fragment independently.
packetToFragments :: Context -> Int -> Packet -> IO [ByteString]
packetToFragments ctx len (Handshake hss) =
    -- NOTE(review): 'ClientRole' is passed for every outgoing handshake
    -- message regardless of this context's actual role; 'updateHandshake'
    -- only forwards it to 'updateVerifiedData', which presumably reads it as
    -- "sent by us" rather than as the peer's role. Confirm before changing.
    getChunks len . B.concat <$> mapM (updateHandshake ctx ClientRole) hss
packetToFragments _ _ (Alert a)        = return [encodeAlerts a]
packetToFragments _ _ ChangeCipherSpec = return [encodeChangeCipherSpec]
packetToFragments _ _ (AppData x)      = return [x]

-- | Run a record-layer action against the transmit state, first installing
-- a fresh explicit IV when one is required.
--
-- Before TLS 1.1 the block cipher IV is made of the residual of the
-- previous block, so the record state's IV is used as is. Otherwise, when
-- 'hasExplicitBlockIV' holds for the current version and the active
-- transmit cipher has a record IV of non-zero size, a new IV of that size
-- is drawn with 'getStateRNG' and set on the record state before @f@ runs.
--
-- The version is obtained with 'getVersionWithDefault', whose fallback is
-- the highest entry of 'supportedVersions'.
-- NOTE(review): 'maximum' is partial, so an empty 'supportedVersions' list
-- raises an exception here.
prepareRecord :: Context -> RecordM a -> IO (Either TLSError a)
prepareRecord ctx f = do
    ver <- usingState_ ctx $
        getVersionWithDefault $ maximum $ supportedVersions $ ctxSupported ctx
    txState <- readMVar $ ctxTxState ctx
    let ivSize = maybe 0 explicitIVSize (stCipher txState)
    if hasExplicitBlockIV ver && ivSize > 0
        then do
            newIV <- getStateRNG ctx ivSize
            runTxState ctx (modify (setRecordIV newIV) >> f)
        else runTxState ctx f
  where
    -- IV size to generate for a cipher; 0 means "do not generate an IV".
    explicitIVSize cipher
        | hasRecordIV (bulkF bulk) = bulkIVSize bulk
        | otherwise                = 0
      where
        bulk = cipherBulk cipher

-- | Encode a plaintext record into wire bytes under the current transmit
-- state: 'prepareRecord' sets up the state (installing an explicit IV when
-- needed) and runs 'encodeRecordM' in it.
encodeRecord :: Context -> Record Plaintext -> IO (Either TLSError ByteString)
encodeRecord ctx = prepareRecord ctx . encodeRecordM

-- | Turn a plaintext record into a ciphertext record with 'engageRecord',
-- then serialise it as the encoded header followed by the record content.
encodeRecordM :: Record Plaintext -> RecordM ByteString
encodeRecordM record = do
    erecord <- engageRecord record
    let (hdr, content) = recordToRaw erecord
    return $ B.concat [encodeHeader hdr, content]

-- | Install the pending transmit record state as the active one.
--
-- 'encodePacket' calls this right after encoding a ChangeCipherSpec. It
-- fails via 'fromJust' (message @"tx-state"@) if no pending transmit state
-- has been prepared. After the switch, it checks whether this context is a
-- client on TLS 1.0 or earlier, using a block cipher, with
-- 'supportedEmptyPacket' enabled. If so, 'ctxNeedEmptyPacket' is set so
-- that the empty-packet countermeasure is applied to subsequent data.
switchTxEncryption :: Context -> IO ()
switchTxEncryption ctx = do
    tx        <- usingHState ctx (fromJust "tx-state" <$> gets hstPendingTxState)
    (ver, cc) <- usingState_ ctx $ (,) <$> getVersion <*> isClientContext
    modifyMVar_ (ctxTxState ctx) (const (return tx))
    -- set the empty packet countermeasure if the conditions are met
    when (ver <= TLS10
          && cc == ClientRole
          && isCBC tx
          && supportedEmptyPacket (ctxSupported ctx)) $
        writeIORef (ctxNeedEmptyPacket ctx) True
  where
    -- Any cipher whose bulk algorithm has a non-zero block size counts as CBC.
    isCBC :: RecordState -> Bool
    isCBC st = maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher st)

-- | Encode a handshake message and record it in the connection state.
--
-- For a 'Finished' message, its verify data is stored with
-- 'updateVerifiedData' for the given @role@. The encoded message is then
-- passed to 'addHandshakeMessage' when 'certVerifyHandshakeMaterial' holds,
-- and to 'updateHandshakeDigest' when its 'HandshakeType' satisfies
-- 'finishHandshakeTypeMaterial'. Returns the encoded bytes.
updateHandshake :: Context -> Role -> Handshake -> IO ByteString
updateHandshake ctx role hs = do
    case hs of
        Finished fdata -> usingState_ ctx $ updateVerifiedData role fdata
        _              -> return ()
    usingHState ctx $ do
        when (certVerifyHandshakeMaterial hs) $
            addHandshakeMessage encoded
        when (finishHandshakeTypeMaterial $ typeOfHandshake hs) $
            updateHandshakeDigest encoded
    return encoded
  where
    encoded :: ByteString
    encoded = encodeHandshake hs