Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Authenticate with each node in cluster mode #224

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions src/Database/Redis/Cluster.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import Data.Typeable
import qualified Scanner
import System.IO.Unsafe(unsafeInterleaveIO)

import Database.Redis.Protocol(Reply(Error), renderRequest, reply)
import Database.Redis.Protocol(Reply(..), renderRequest, reply)
import qualified Database.Redis.Cluster.Command as CMD

-- This module implements a clustered connection whilst maintaining
Expand Down Expand Up @@ -100,8 +100,11 @@ instance Exception UnsupportedClusterCommandException
newtype CrossSlotException = CrossSlotException [[B.ByteString]] deriving (Show, Typeable)
instance Exception CrossSlotException

connect :: [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> IO Connection
connect commandInfos shardMapVar timeoutOpt = do
data ClusterAuthError = ClusterAuthError Host Port Reply deriving (Show)
instance Exception ClusterAuthError

connect :: Maybe B.ByteString -> Maybe B.ByteString -> [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> IO Connection
connect mUsername mPassword commandInfos shardMapVar timeoutOpt = do
shardMap <- readMVar shardMapVar
stateVar <- newMVar $ Pending []
pipelineVar <- newMVar $ Pipeline stateVar
Expand All @@ -113,7 +116,16 @@ connect commandInfos shardMapVar timeoutOpt = do
connectNode (Node n _ host port) = do
ctx <- CC.connect host (CC.PortNumber $ toEnum port) timeoutOpt
ref <- IOR.newIORef Nothing
return (n, NodeConnection ctx ref n)
let nodeConn = NodeConnection ctx ref n
case mPassword of
Nothing -> pure ()
Just password -> do
let reqOpts = maybe [password] (:[password]) mUsername
authReply <- requestNode1 nodeConn ( ["AUTH"] <> reqOpts )
case authReply of
SingleLine "OK" -> pure ()
_ -> throwIO $ ClusterAuthError host port authReply
return (n, nodeConn)

disconnect :: Connection -> IO ()
disconnect (Connection nodeConnMap _ _ _) = mapM_ disconnectNode (HM.elems nodeConnMap) where
Expand Down Expand Up @@ -370,28 +382,35 @@ allMasterNodes (Connection nodeConns _ _ _) (ShardMap shardMap) =
masterNodes = (\(Shard master _) -> master) <$> nub (IntMap.elems shardMap)

requestNode :: NodeConnection -> [[B.ByteString]] -> IO [Reply]
requestNode (NodeConnection ctx lastRecvRef _) requests = do
requestNode nodeConn@(NodeConnection ctx _ _) requests = do
mapM_ (sendNode . renderRequest) requests
_ <- CC.flush ctx
replicateM (length requests) recvNode
replicateM (length requests) $ recvNode nodeConn

where

sendNode :: B.ByteString -> IO ()
sendNode = CC.send ctx
recvNode :: IO Reply
recvNode = do
maybeLastRecv <- IOR.readIORef lastRecvRef
scanResult <- case maybeLastRecv of
Just lastRecv -> Scanner.scanWith (CC.recv ctx) reply lastRecv
Nothing -> Scanner.scanWith (CC.recv ctx) reply B.empty

case scanResult of
Scanner.Fail{} -> CC.errConnClosed
Scanner.More{} -> error "Hedis: parseWith returned Partial"
Scanner.Done rest' r -> do
IOR.writeIORef lastRecvRef (Just rest')
return r

requestNode1 :: NodeConnection -> [B.ByteString] -> IO Reply
requestNode1 nodeConn@(NodeConnection ctx _ _) request = do
CC.send ctx $ renderRequest request
_ <- CC.flush ctx
recvNode nodeConn

recvNode :: NodeConnection -> IO Reply
recvNode (NodeConnection ctx lastRecvRef _) = do
maybeLastRecv <- IOR.readIORef lastRecvRef
scanResult <- case maybeLastRecv of
Just lastRecv -> Scanner.scanWith (CC.recv ctx) reply lastRecv
Nothing -> Scanner.scanWith (CC.recv ctx) reply B.empty

case scanResult of
Scanner.Fail{} -> CC.errConnClosed
Scanner.More{} -> error "Hedis: parseWith returned Partial"
Scanner.Done rest' r -> do
IOR.writeIORef lastRecvRef (Just rest')
return r

nodes :: ShardMap -> [Node]
nodes (ShardMap shardMap) = concatMap snd $ IntMap.toList $ fmap shardNodes shardMap where
Expand Down
4 changes: 2 additions & 2 deletions src/Database/Redis/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,9 @@ connectCluster bootstrapConnInfo = do
Left e -> throwIO $ ClusterConnectError e
Right infos -> do
#if MIN_VERSION_resource_pool(0,3,0)
pool <- newPool (defaultPoolConfig (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo))
pool <- newPool (defaultPoolConfig (Cluster.connect (connectUsername bootstrapConnInfo) (connectAuth bootstrapConnInfo) infos shardMapVar Nothing) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo))
#else
pool <- createPool (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)
pool <- createPool (Cluster.connect (connectUsername bootstrapConnInfo) (connectAuth bootstrapConnInfo) infos shardMapVar Nothing) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)
#endif
return $ ClusteredConnection shardMapVar pool

Expand Down