fix: set correct node id on incoming messages
This commit is contained in:
@@ -5,6 +5,7 @@ module Control.Actor.Network
|
||||
( spawnConnTree
|
||||
, handleNewConn
|
||||
, getOrCreateConn
|
||||
, connect
|
||||
, findByUUID
|
||||
, routeRemoteDeath
|
||||
, validateAndDispatch
|
||||
@@ -59,10 +60,10 @@ connDeathFn peer _ = do
|
||||
chClose ch
|
||||
return Stop
|
||||
|
||||
routerActorFn :: ByteString -> Actor () ()
|
||||
routerActorFn raw = do
|
||||
routerActorFn :: NodeAddr -> ByteString -> Actor () ()
|
||||
routerActorFn senderAddr raw = do
|
||||
let nm = decode raw :: NetworkMessage
|
||||
valid <- liftRuntime $ validateAndDispatch nm
|
||||
valid <- liftRuntime $ validateAndDispatch senderAddr nm
|
||||
unless valid $ liftIO $ putStrLn "router: dropping invalid message"
|
||||
return (Nothing, ())
|
||||
|
||||
@@ -74,7 +75,7 @@ spawnConnTree peer ch = do
|
||||
routerCell <- liftIO newEmptyTMVarIO
|
||||
connCell <- liftIO newEmptyTMVarIO
|
||||
supervise' OneForAll
|
||||
[ childWithRef routerActorFn stopOnDeath () routerCell
|
||||
[ childWithRef (routerActorFn peer) stopOnDeath () routerCell
|
||||
, ChildSpec
|
||||
{ csRun = \target -> do
|
||||
ref <- spawnActor connActorFn (connDeathFn peer) ch
|
||||
@@ -157,28 +158,33 @@ findByUUID uuid actors =
|
||||
[] -> Nothing
|
||||
(_, (_, r)):_ -> Just r
|
||||
|
||||
routeRemoteDeath :: ActorId -> RemoteExitReason -> RuntimeM ()
|
||||
routeRemoteDeath deadId reason = do
|
||||
routeRemoteDeath :: NodeAddr -> ActorId -> RemoteExitReason -> RuntimeM ()
|
||||
routeRemoteDeath senderAddr (ActorId _ uuid) reason = do
|
||||
rt <- ask
|
||||
liftIO $ do
|
||||
actors <- readTVarIO (rtActors rt)
|
||||
let exitReason = case reason of
|
||||
localNodeId <- atomically $ do
|
||||
table <- readTVar (rtNodeTable rt)
|
||||
return $ Map.foldrWithKey (\k v acc -> if v == senderAddr then Just k else acc) Nothing table
|
||||
let deadId = ActorId (case localNodeId of { Just n -> n; Nothing -> 0 }) uuid
|
||||
exitReason = case reason of
|
||||
RNormal -> Normal
|
||||
RKilled -> Killed
|
||||
RException s -> Exception (error s)
|
||||
dm = DeathMessage deadId exitReason
|
||||
actors <- readTVarIO (rtActors rt)
|
||||
forM_ (Map.elems actors) $ \(_, SomeActorRef ref) ->
|
||||
case ref of
|
||||
LocalRef {arDeathQ, arState} -> do
|
||||
links <- readTVarIO (asLinks arState)
|
||||
forM_ links $ \case
|
||||
RemoteTarget rid _ | rid == deadId ->
|
||||
RemoteTarget (ActorId _ uid) peerAddr
|
||||
| uid == uuid && peerAddr == senderAddr ->
|
||||
atomically $ writeTQueue arDeathQ dm
|
||||
_else -> return ()
|
||||
RemoteRef _ -> return ()
|
||||
|
||||
validateAndDispatch :: NetworkMessage -> RuntimeM Bool
|
||||
validateAndDispatch nm = case nm of
|
||||
validateAndDispatch :: NodeAddr -> NetworkMessage -> RuntimeM Bool
|
||||
validateAndDispatch senderAddr nm = case nm of
|
||||
|
||||
NMHandshake _ -> return False
|
||||
|
||||
@@ -231,23 +237,40 @@ validateAndDispatch nm = case nm of
|
||||
return False
|
||||
|
||||
NMDeath deadId reason -> do
|
||||
routeRemoteDeath deadId reason
|
||||
routeRemoteDeath senderAddr deadId reason
|
||||
return True
|
||||
|
||||
-- System event handler
|
||||
|
||||
sysHandlerFn :: NetworkMessage -> Actor () ()
|
||||
sysHandlerFn (NMDeath deadId reason) = do
|
||||
liftRuntime $ routeRemoteDeath deadId reason
|
||||
return (Nothing, ())
|
||||
sysHandlerFn _ = return (Nothing, ())
|
||||
|
||||
-- Node connection
|
||||
|
||||
-- | Connect to a remote node and return its locally-assigned NodeId.
|
||||
-- NodeId 0 is always self; remote nodes get ids starting from 1.
|
||||
-- A suggested id is honoured if it is free (non-zero, not already in use).
|
||||
connect :: Maybe NodeId -> NodeAddr -> RuntimeM NodeId
|
||||
connect suggestedId peer = do
|
||||
rt <- ask
|
||||
nodeId <- liftIO $ atomically $ do
|
||||
table <- readTVar (rtNodeTable rt)
|
||||
nid <- readTVar (rtNextNodeId rt)
|
||||
let assigned = case suggestedId of
|
||||
Just n | n /= 0 && not (Map.member n table) -> n
|
||||
_else -> nid
|
||||
writeTVar (rtNextNodeId rt) (max (assigned + 1) nid)
|
||||
modifyTVar (rtNodeTable rt) (Map.insert assigned peer)
|
||||
return assigned
|
||||
void $ getOrCreateConn peer
|
||||
return nodeId
|
||||
|
||||
-- Runtime initialization
|
||||
|
||||
initRuntime :: NodeAddr -> IO Runtime
|
||||
initRuntime myAddr = do
|
||||
transport <- createTCPTransport myAddr
|
||||
rt0 <- newRuntime myAddr transport
|
||||
(transport, actualAddr) <- createTCPTransport myAddr
|
||||
rt0 <- newRuntime actualAddr transport
|
||||
let rt = rt0 { rtSendRemote = \addr nm -> getOrCreateConn addr >>= cast' nm }
|
||||
withRuntime rt $ void $ spawnActor sysHandlerFn stopOnDeath ()
|
||||
tListen transport $ \ch -> do
|
||||
|
||||
@@ -21,6 +21,7 @@ import Data.Map qualified as Map
|
||||
|
||||
data Runtime = Runtime
|
||||
{ rtNodeId :: NodeAddr
|
||||
, rtNextNodeId :: TVar NodeId
|
||||
, rtActors :: TVar (Map.Map ActorId (ThreadId, SomeActorRef))
|
||||
, rtPending :: TVar (Map.Map CorrelationId (MVar ByteString))
|
||||
, rtNextCorr :: TVar CorrelationId
|
||||
@@ -41,10 +42,12 @@ newRuntime myAddr transport = do
|
||||
pending <- newTVarIO Map.empty
|
||||
nextCorr <- newTVarIO (0 :: Integer)
|
||||
nodeTable <- newTVarIO Map.empty
|
||||
nextNid <- newTVarIO (1 :: NodeId)
|
||||
conns <- newTVarIO Map.empty
|
||||
promises <- newTVarIO Map.empty
|
||||
return Runtime
|
||||
{ rtNodeId = myAddr
|
||||
, rtNextNodeId = nextNid
|
||||
, rtActors = actors
|
||||
, rtPending = pending
|
||||
, rtNextCorr = nextCorr
|
||||
|
||||
@@ -28,6 +28,7 @@ import Network.Socket
|
||||
, listen
|
||||
, setSocketOption
|
||||
, socket
|
||||
, socketPort
|
||||
)
|
||||
import Network.Socket.ByteString.Lazy (recv, sendAll)
|
||||
|
||||
@@ -72,7 +73,7 @@ connectTcp (NodeAddr host port) = do
|
||||
connect sock (addrAddress a)
|
||||
return sock
|
||||
|
||||
listenTcp :: NodeAddr -> IO Socket
|
||||
listenTcp :: NodeAddr -> IO (Socket, NodeAddr)
|
||||
listenTcp (NodeAddr host port) = do
|
||||
let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream}
|
||||
addrs <- getAddrInfo (Just hints) (Just host) (Just (show port))
|
||||
@@ -83,12 +84,16 @@ listenTcp (NodeAddr host port) = do
|
||||
setSocketOption sock ReuseAddr 1
|
||||
bind sock (addrAddress a)
|
||||
listen sock 128
|
||||
return sock
|
||||
actualPort <- fromIntegral <$> socketPort sock
|
||||
return (sock, NodeAddr host actualPort)
|
||||
|
||||
createTCPTransport :: NodeAddr -> IO Transport
|
||||
-- | Create a TCP transport bound to the given address.
|
||||
-- Pass port 0 to let the OS pick a free port.
|
||||
-- Returns the transport and the address actually bound (with the real port).
|
||||
createTCPTransport :: NodeAddr -> IO (Transport, NodeAddr)
|
||||
createTCPTransport myAddr = do
|
||||
lsock <- listenTcp myAddr
|
||||
return Transport
|
||||
(lsock, actualAddr) <- listenTcp myAddr
|
||||
let transport = Transport
|
||||
{ tConnect = \peer -> do
|
||||
sock <- connectTcp peer
|
||||
return ConnHandle
|
||||
@@ -104,3 +109,4 @@ createTCPTransport myAddr = do
|
||||
, chClose = close csock
|
||||
}
|
||||
}
|
||||
return (transport, actualAddr)
|
||||
|
||||
Reference in New Issue
Block a user