diff --git a/src/Control/Actor/Network.hs b/src/Control/Actor/Network.hs index e0a65f2..9b6814e 100644 --- a/src/Control/Actor/Network.hs +++ b/src/Control/Actor/Network.hs @@ -5,6 +5,7 @@ module Control.Actor.Network ( spawnConnTree , handleNewConn , getOrCreateConn + , connect , findByUUID , routeRemoteDeath , validateAndDispatch @@ -59,10 +60,10 @@ connDeathFn peer _ = do chClose ch return Stop -routerActorFn :: ByteString -> Actor () () -routerActorFn raw = do +routerActorFn :: NodeAddr -> ByteString -> Actor () () +routerActorFn senderAddr raw = do let nm = decode raw :: NetworkMessage - valid <- liftRuntime $ validateAndDispatch nm + valid <- liftRuntime $ validateAndDispatch senderAddr nm unless valid $ liftIO $ putStrLn "router: dropping invalid message" return (Nothing, ()) @@ -74,7 +75,7 @@ spawnConnTree peer ch = do routerCell <- liftIO newEmptyTMVarIO connCell <- liftIO newEmptyTMVarIO supervise' OneForAll - [ childWithRef routerActorFn stopOnDeath () routerCell + [ childWithRef (routerActorFn peer) stopOnDeath () routerCell , ChildSpec { csRun = \target -> do ref <- spawnActor connActorFn (connDeathFn peer) ch @@ -157,28 +158,33 @@ findByUUID uuid actors = [] -> Nothing (_, (_, r)):_ -> Just r -routeRemoteDeath :: ActorId -> RemoteExitReason -> RuntimeM () -routeRemoteDeath deadId reason = do +routeRemoteDeath :: NodeAddr -> ActorId -> RemoteExitReason -> RuntimeM () +routeRemoteDeath senderAddr (ActorId _ uuid) reason = do rt <- ask liftIO $ do - actors <- readTVarIO (rtActors rt) - let exitReason = case reason of + localNodeId <- atomically $ do + table <- readTVar (rtNodeTable rt) + return $ Map.foldrWithKey (\k v acc -> if v == senderAddr then Just k else acc) Nothing table + let deadId = ActorId (case localNodeId of { Just n -> n; Nothing -> 0 }) uuid + exitReason = case reason of RNormal -> Normal RKilled -> Killed RException s -> Exception (error s) dm = DeathMessage deadId exitReason + actors <- readTVarIO (rtActors rt) forM_ (Map.elems actors) $ \(_, SomeActorRef ref) -> case ref of LocalRef {arDeathQ, arState} -> do links <- readTVarIO (asLinks arState) forM_ links $ \case - RemoteTarget rid _ | rid == deadId -> - atomically $ writeTQueue arDeathQ dm + RemoteTarget (ActorId _ uid) peerAddr + | uid == uuid && peerAddr == senderAddr -> + atomically $ writeTQueue arDeathQ dm _else -> return () RemoteRef _ -> return () -validateAndDispatch :: NetworkMessage -> RuntimeM Bool -validateAndDispatch nm = case nm of +validateAndDispatch :: NodeAddr -> NetworkMessage -> RuntimeM Bool +validateAndDispatch senderAddr nm = case nm of NMHandshake _ -> return False @@ -231,23 +237,40 @@ validateAndDispatch nm = case nm of return False NMDeath deadId reason -> do - routeRemoteDeath deadId reason + routeRemoteDeath senderAddr deadId reason return True -- System event handler sysHandlerFn :: NetworkMessage -> Actor () () -sysHandlerFn (NMDeath deadId reason) = do - liftRuntime $ routeRemoteDeath deadId reason - return (Nothing, ()) sysHandlerFn _ = return (Nothing, ()) +-- Node connection + +-- | Connect to a remote node and return its locally-assigned NodeId. +-- NodeId 0 is always self; remote nodes get ids starting from 1. +-- A suggested id is honoured if it is free (non-zero, not already in use). +connect :: Maybe NodeId -> NodeAddr -> RuntimeM NodeId +connect suggestedId peer = do + rt <- ask + nodeId <- liftIO $ atomically $ do + table <- readTVar (rtNodeTable rt) + nid <- readTVar (rtNextNodeId rt) + let assigned = case suggestedId of + Just n | n /= 0 && not (Map.member n table) -> n + _else -> nid + writeTVar (rtNextNodeId rt) (max (assigned + 1) nid) + modifyTVar (rtNodeTable rt) (Map.insert assigned peer) + return assigned + void $ getOrCreateConn peer + return nodeId + -- Runtime initialization initRuntime :: NodeAddr -> IO Runtime initRuntime myAddr = do - transport <- createTCPTransport myAddr - rt0 <- newRuntime myAddr transport + (transport, actualAddr) <- createTCPTransport myAddr + rt0 <- newRuntime actualAddr transport let rt = rt0 { rtSendRemote = \addr nm -> getOrCreateConn addr >>= cast' nm } withRuntime rt $ void $ spawnActor sysHandlerFn stopOnDeath () tListen transport $ \ch -> do diff --git a/src/Control/Actor/Runtime.hs b/src/Control/Actor/Runtime.hs index ee7844a..b706495 100644 --- a/src/Control/Actor/Runtime.hs +++ b/src/Control/Actor/Runtime.hs @@ -21,6 +21,7 @@ import Data.Map qualified as Map data Runtime = Runtime { rtNodeId :: NodeAddr + , rtNextNodeId :: TVar NodeId , rtActors :: TVar (Map.Map ActorId (ThreadId, SomeActorRef)) , rtPending :: TVar (Map.Map CorrelationId (MVar ByteString)) , rtNextCorr :: TVar CorrelationId @@ -41,10 +42,12 @@ newRuntime myAddr transport = do pending <- newTVarIO Map.empty nextCorr <- newTVarIO (0 :: Integer) nodeTable <- newTVarIO Map.empty + nextNid <- newTVarIO (1 :: NodeId) conns <- newTVarIO Map.empty promises <- newTVarIO Map.empty return Runtime { rtNodeId = myAddr + , rtNextNodeId = nextNid , rtActors = actors , rtPending = pending , rtNextCorr = nextCorr diff --git a/src/Control/Actor/Transport.hs b/src/Control/Actor/Transport.hs index be98d21..8e4b4fd 100644 --- a/src/Control/Actor/Transport.hs +++ b/src/Control/Actor/Transport.hs @@ -28,6 +28,7 @@ import Network.Socket , listen , setSocketOption , socket + , socketPort ) import Network.Socket.ByteString.Lazy (recv, sendAll) @@ -72,7 +73,7 @@ connectTcp (NodeAddr host port) = do connect sock (addrAddress a) return sock -listenTcp :: NodeAddr -> IO Socket +listenTcp :: NodeAddr -> IO (Socket, NodeAddr) listenTcp (NodeAddr host port) = do let hints = defaultHints {addrFlags = [AI_PASSIVE], addrSocketType = Stream} addrs <- getAddrInfo (Just hints) (Just host) (Just (show port)) @@ -83,24 +84,29 @@ listenTcp (NodeAddr host port) = do setSocketOption sock ReuseAddr 1 bind sock (addrAddress a) listen sock 128 - return sock + actualPort <- fromIntegral <$> socketPort sock + return (sock, NodeAddr host actualPort) -createTCPTransport :: NodeAddr -> IO Transport +-- | Create a TCP transport bound to the given address. +-- Pass port 0 to let the OS pick a free port. +-- Returns the transport and the address actually bound (with the real port). +createTCPTransport :: NodeAddr -> IO (Transport, NodeAddr) createTCPTransport myAddr = do - lsock <- listenTcp myAddr - return Transport - { tConnect = \peer -> do - sock <- connectTcp peer - return ConnHandle - { chSend = sendFramed sock - , chRecv = recvFramed sock - , chClose = close sock - } - , tListen = \callback -> void $ forkIO $ forever $ do - (csock, _) <- accept lsock - void $ forkIO $ callback ConnHandle - { chSend = sendFramed csock - , chRecv = recvFramed csock - , chClose = close csock - } - } + (lsock, actualAddr) <- listenTcp myAddr + let transport = Transport + { tConnect = \peer -> do + sock <- connectTcp peer + return ConnHandle + { chSend = sendFramed sock + , chRecv = recvFramed sock + , chClose = close sock + } + , tListen = \callback -> void $ forkIO $ forever $ do + (csock, _) <- accept lsock + void $ forkIO $ callback ConnHandle + { chSend = sendFramed csock + , chRecv = recvFramed csock + , chClose = close csock + } + } + return (transport, actualAddr)