diff --git a/src/Database/Redis/Cluster.hs b/src/Database/Redis/Cluster.hs index 8286357d..05a5c073 100644 --- a/src/Database/Redis/Cluster.hs +++ b/src/Database/Redis/Cluster.hs @@ -23,7 +23,7 @@ import Data.Char(toLower) import qualified Data.ByteString.Char8 as Char8 import qualified Data.IORef as IOR import Data.Maybe(mapMaybe, fromMaybe) -import Data.List(nub, sortBy, find, findIndex) +import Data.List(nub, sortBy, find) import Data.Map(fromListWith, assocs) import Data.Function(on) import Control.Exception(Exception, SomeException, throwIO, BlockedIndefinitelyOnMVar(..), catches, Handler(..), try) @@ -216,21 +216,22 @@ evaluatePipeline shardMapVar refreshShardmapAction conn requests = do -- merge the current responses with new responses. eresps <- mapM (try . uncurry executeRequests) requestsByNode -- take a random connection where there are no exceptions. - let (nc, _) = (requestsByNode !!) $ fromMaybe 0 $ findIndex isRight eresps -- PERF_CONCERN: Since usually we send only one request at time, this won't be -- heavy perf issue. but still should be evaluated and figured out with complete rewrite. - resps <- concat <$> mapM (\(resp, (_, r)) -> case resp of + resps <- concat <$> mapM (\(resp, (cc, r)) -> case resp of Right v -> return v - Left (_ :: SomeException) -> executeRequests nc r + Left (_ :: SomeException) -> executeRequests (getRandomConnection cc) r ) (zip eresps requestsByNode) -- check for any moved in both responses and continue the flow. when (any (moved . rawResponse) resps) (refreshShardMapVar "locked refreshing due to moved responses") retriedResps <- mapM (retry 0) resps return $ map rawResponse $ sortBy (on compare responseIndex) retriedResps where - isRight :: Either a b -> Bool - isRight (Right _) = True - isRight _ = False + getRandomConnection :: NodeConnection -> NodeConnection + getRandomConnection nc = + let (Connection hmn _ _ _ _) = conn + conns = HM.elems hmn + in fromMaybe (head conns) $ find (nc /= ) conns getRequestsByNode :: ShardMap -> IO [(NodeConnection, [PendingRequest])] getRequestsByNode shardMap = do commandsWithNodes <- zipWithM (requestWithNode shardMap) (reverse [0..(length requests - 1)]) requests