[websockets] migrate to StmContainers.Set

parent 5bcb8731
Pipeline #6170 failed with stages
in 67 minutes and 50 seconds
...@@ -520,6 +520,7 @@ library ...@@ -520,6 +520,7 @@ library
, cryptohash ^>= 0.11.9 , cryptohash ^>= 0.11.9
, data-time-segment ^>= 0.1.0.0 , data-time-segment ^>= 0.1.0.0
, deepseq ^>= 1.4.4.0 , deepseq ^>= 1.4.4.0
, deferred-folds >= 0.9.18 && < 0.10
, directory ^>= 1.3.6.0 , directory ^>= 1.3.6.0
, discrimination >= 0.5 , discrimination >= 0.5
, ekg-core ^>= 0.1.1.7 , ekg-core ^>= 0.1.1.7
......
...@@ -26,6 +26,7 @@ import Data.Aeson.Types (prependFailure, typeMismatch) ...@@ -26,6 +26,7 @@ import Data.Aeson.Types (prependFailure, typeMismatch)
import Data.ByteString.Char8 qualified as C import Data.ByteString.Char8 qualified as C
import Data.ByteString.Lazy qualified as BSL import Data.ByteString.Lazy qualified as BSL
import Data.List (nubBy) import Data.List (nubBy)
import DeferredFolds.UnfoldlM qualified as UnfoldlM
import Data.UUID.V4 as UUID import Data.UUID.V4 as UUID
import Gargantext.API.Admin.Auth.Types (AuthenticatedUser(_auth_user_id)) import Gargantext.API.Admin.Auth.Types (AuthenticatedUser(_auth_user_id))
import Gargantext.API.Admin.Types (jwtSettings, Settings, jwtSettings) import Gargantext.API.Admin.Types (jwtSettings, Settings, jwtSettings)
...@@ -35,9 +36,11 @@ import Gargantext.Prelude ...@@ -35,9 +36,11 @@ import Gargantext.Prelude
import GHC.Conc (TVar, newTVarIO, readTVar, writeTVar) import GHC.Conc (TVar, newTVarIO, readTVar, writeTVar)
import Nanomsg import Nanomsg
import Network.WebSockets qualified as WS import Network.WebSockets qualified as WS
import Protolude.Base (Show(showsPrec))
import Servant import Servant
import Servant.API.WebSocket qualified as WS import Servant.API.WebSocket qualified as WS
import Servant.Auth.Server (verifyJWT) import Servant.Auth.Server (verifyJWT)
import StmContainers.Set as SSet
{- {-
...@@ -56,7 +59,8 @@ data Topic = ...@@ -56,7 +59,8 @@ data Topic =
-- children (e.g. list is automatically created in a corpus) -- children (e.g. list is automatically created in a corpus)
UpdateTree NodeId UpdateTree NodeId
deriving (Eq, Show) deriving (Eq, Show)
instance Hashable Topic where
hashWithSalt salt (UpdateTree nodeId) = hashWithSalt salt ("update-tree" :: Text, nodeId)
instance FromJSON Topic where instance FromJSON Topic where
parseJSON = Aeson.withObject "Topic" $ \o -> do parseJSON = Aeson.withObject "Topic" $ \o -> do
type_ <- o .: "type" type_ <- o .: "type"
...@@ -65,7 +69,6 @@ instance FromJSON Topic where ...@@ -65,7 +69,6 @@ instance FromJSON Topic where
node_id <- o .: "node_id" node_id <- o .: "node_id"
pure $ UpdateTree node_id pure $ UpdateTree node_id
s -> prependFailure "parsing type failed, " (typeMismatch "type" s) s -> prependFailure "parsing type failed, " (typeMismatch "type" s)
instance ToJSON Topic where instance ToJSON Topic where
toJSON (UpdateTree node_id) = Aeson.object [ toJSON (UpdateTree node_id) = Aeson.object [
"type" .= toJSON ("update_tree" :: Text) "type" .= toJSON ("update_tree" :: Text)
...@@ -76,10 +79,17 @@ data ConnectedUser = ...@@ -76,10 +79,17 @@ data ConnectedUser =
CUUser UserId CUUser UserId
| CUPublic | CUPublic
deriving (Eq, Show) deriving (Eq, Show)
instance Hashable ConnectedUser where
hashWithSalt salt (CUUser userId) = hashWithSalt salt ("cuuser" :: Text, userId)
hashWithSalt salt CUPublic = hashWithSalt salt ("cupublic" :: Text)
newtype WSKeyConnection = WSKeyConnection (ByteString, WS.Connection) newtype WSKeyConnection = WSKeyConnection (ByteString, WS.Connection)
eqWSKeyConnection :: WSKeyConnection -> WSKeyConnection -> Bool instance Hashable WSKeyConnection where
eqWSKeyConnection ws1 ws2 = wsKey ws1 == wsKey ws2 hashWithSalt salt (WSKeyConnection (key, _conn)) = hashWithSalt salt key
instance Eq WSKeyConnection where
(==) (WSKeyConnection (key1, _conn1)) (WSKeyConnection (key2, _conn2)) = key1 == key2
instance Show WSKeyConnection where
showsPrec d (WSKeyConnection (key, _conn)) = showsPrec d $ "WSKeyConnection " <> key
showWSKeyConnection :: WSKeyConnection -> Text showWSKeyConnection :: WSKeyConnection -> Text
showWSKeyConnection ws = "WSKeyConnection " <> show (wsKey ws) showWSKeyConnection ws = "WSKeyConnection " <> show (wsKey ws)
wsKey :: WSKeyConnection -> ByteString wsKey :: WSKeyConnection -> ByteString
...@@ -92,16 +102,10 @@ data Subscription = ...@@ -92,16 +102,10 @@ data Subscription =
s_connected_user :: ConnectedUser s_connected_user :: ConnectedUser
, s_ws_key_connection :: WSKeyConnection , s_ws_key_connection :: WSKeyConnection
, s_topic :: Topic } , s_topic :: Topic }
eqSub :: Subscription -> Subscription -> Bool deriving (Eq, Show)
eqSub sub1 sub2 = instance Hashable Subscription where
s_connected_user sub1 == s_connected_user sub2 && hashWithSalt salt (Subscription { .. }) =
s_ws_key_connection sub2 `eqWSKeyConnection` s_ws_key_connection sub2 && hashWithSalt salt ( s_connected_user, s_ws_key_connection, s_topic )
s_topic sub1 == s_topic sub2
showSub :: Subscription -> Text
showSub sub =
"Subscription " <> show (s_connected_user sub) <>
" " <> showWSKeyConnection (s_ws_key_connection sub) <>
" " <> show (s_topic sub)
subKey :: Subscription -> ByteString subKey :: Subscription -> ByteString
subKey sub = wsKey $ s_ws_key_connection sub subKey sub = wsKey $ s_ws_key_connection sub
...@@ -142,7 +146,7 @@ instance FromJSON WSRequest where ...@@ -142,7 +146,7 @@ instance FromJSON WSRequest where
s -> prependFailure "parsing request type failed, " (typeMismatch "request" s) s -> prependFailure "parsing request type failed, " (typeMismatch "request" s)
data Dispatcher = data Dispatcher =
Dispatcher { d_subscriptions :: TVar [Subscription] Dispatcher { d_subscriptions :: SSet.Set Subscription
, d_ws_server :: Server WSAPI , d_ws_server :: Server WSAPI
, d_ce_listener :: ThreadId , d_ce_listener :: ThreadId
} }
...@@ -150,7 +154,7 @@ data Dispatcher = ...@@ -150,7 +154,7 @@ data Dispatcher =
dispatcher :: Settings -> IO Dispatcher dispatcher :: Settings -> IO Dispatcher
dispatcher authSettings = do dispatcher authSettings = do
subscriptions <- newTVarIO ([] :: [Subscription]) subscriptions <- SSet.newIO
let server = wsServer authSettings subscriptions let server = wsServer authSettings subscriptions
...@@ -164,33 +168,37 @@ dispatcher authSettings = do ...@@ -164,33 +168,37 @@ dispatcher authSettings = do
-- | TODO Allow only 1 topic subscription per connection. It doesn't -- | TODO Allow only 1 topic subscription per connection. It doesn't
-- | make sense to send multiple notifications of the same type to the -- | make sense to send multiple notifications of the same type to the
-- | same connection. -- | same connection.
insertSubscription :: TVar [Subscription] -> Subscription -> IO [Subscription] insertSubscription :: SSet.Set Subscription -> Subscription -> IO ()
insertSubscription subscriptions sub = insertSubscription subscriptions sub =
atomically $ do atomically $ SSet.insert sub subscriptions
s <- readTVar subscriptions -- s <- readTVar subscriptions
let ss = nubBy eqSub $ s <> [sub] -- let ss = nubBy eqSub $ s <> [sub]
writeTVar subscriptions ss -- writeTVar subscriptions ss
pure ss -- -- pure ss
-- pure ()
removeSubscription :: TVar [Subscription] -> Subscription -> IO [Subscription] removeSubscription :: SSet.Set Subscription -> Subscription -> IO ()
removeSubscription subscriptions sub = removeSubscription subscriptions sub =
atomically $ do atomically $ SSet.delete sub subscriptions
s <- readTVar subscriptions -- s <- readTVar subscriptions
let ss = filter (\sub' -> not $ sub `eqSub` sub') s -- let ss = filter (\sub' -> not $ sub `eqSub` sub') s
writeTVar subscriptions ss -- writeTVar subscriptions ss
pure ss -- pure ss
removeSubscriptionsForWSKey :: TVar [Subscription] -> WSKeyConnection -> IO [Subscription] removeSubscriptionsForWSKey :: SSet.Set Subscription -> WSKeyConnection -> IO ()
removeSubscriptionsForWSKey subscriptions ws = removeSubscriptionsForWSKey subscriptions ws =
atomically $ do atomically $ do
s <- readTVar subscriptions let toDelete = UnfoldlM.filter (\sub -> return $ subKey sub == wsKey ws) $ SSet.unfoldlM subscriptions
let ss = filter (\sub -> subKey sub /= wsKey ws) s UnfoldlM.mapM_ (\sub -> SSet.delete sub subscriptions) toDelete
writeTVar subscriptions ss -- atomically $ do
pure ss -- s <- readTVar subscriptions
-- let ss = filter (\sub -> subKey sub /= wsKey ws) s
-- writeTVar subscriptions ss
-- pure ss
type WSAPI = "ws" :> WS.WebSocketPending type WSAPI = "ws" :> WS.WebSocketPending
wsServer :: Settings -> TVar [Subscription] -> Server WSAPI wsServer :: Settings -> SSet.Set Subscription -> Server WSAPI
wsServer authSettings subscriptions = streamData wsServer authSettings subscriptions = streamData
where where
streamData :: MonadIO m => WS.PendingConnection -> m () streamData :: MonadIO m => WS.PendingConnection -> m ()
...@@ -242,16 +250,15 @@ wsServer authSettings subscriptions = streamData ...@@ -242,16 +250,15 @@ wsServer authSettings subscriptions = streamData
let sub = Subscription { s_connected_user = user let sub = Subscription { s_connected_user = user
, s_ws_key_connection = ws , s_ws_key_connection = ws
, s_topic = topic } , s_topic = topic }
ss <- insertSubscription subscriptions sub _ss <- insertSubscription subscriptions sub
putText $ "[wsLoop] subscriptions: " <> show (showSub <$> ss) -- putText $ "[wsLoop] subscriptions: " <> show (showSub <$> ss)
return user return user
Just (WSUnsubscribe topic) -> do Just (WSUnsubscribe topic) -> do
-- TODO Fix s_connected_user based on header let sub = Subscription { s_connected_user = user
let sub = Subscription { s_connected_user = CUPublic , s_ws_key_connection = ws
, s_ws_key_connection = ws , s_topic = topic }
, s_topic = topic } _ss <- removeSubscription subscriptions sub
ss <- removeSubscription subscriptions sub -- putText $ "[wsLoop] subscriptions: " <> show (showSub <$> ss)
putText $ "[wsLoop] subscriptions: " <> show (showSub <$> ss)
return user return user
Just (WSAuthorize token) -> do Just (WSAuthorize token) -> do
let jwtS = authSettings ^. jwtSettings let jwtS = authSettings ^. jwtSettings
...@@ -273,8 +280,9 @@ wsServer authSettings subscriptions = streamData ...@@ -273,8 +280,9 @@ wsServer authSettings subscriptions = streamData
disconnect = do disconnect = do
putText "[wsLoop] disconnecting..." putText "[wsLoop] disconnecting..."
ss <- removeSubscriptionsForWSKey subscriptions ws _ss <- removeSubscriptionsForWSKey subscriptions ws
putText $ "[wsLoop] subscriptions: " <> show (showSub <$> ss) -- putText $ "[wsLoop] subscriptions: " <> show (show <$> ss)
return ()
data Notification = data Notification =
...@@ -287,7 +295,7 @@ instance ToJSON Notification where ...@@ -287,7 +295,7 @@ instance ToJSON Notification where
] ]
ce_listener :: TVar [Subscription] -> IO () ce_listener :: SSet.Set Subscription -> IO ()
ce_listener subscriptions = do ce_listener subscriptions = do
withSocket Pull $ \s -> do withSocket Pull $ \s -> do
_ <- bind s "tcp://*:5561" _ <- bind s "tcp://*:5561"
...@@ -298,7 +306,10 @@ ce_listener subscriptions = do ...@@ -298,7 +306,10 @@ ce_listener subscriptions = do
case Aeson.decode (BSL.fromStrict r) of case Aeson.decode (BSL.fromStrict r) of
Nothing -> putText "[ce_listener] unknown message from central exchange" Nothing -> putText "[ce_listener] unknown message from central exchange"
Just ceMessage -> do Just ceMessage -> do
subs <- atomically $ readTVar subscriptions -- subs <- atomically $ readTVar subscriptions
filteredSubs <- atomically $ do
let subs' = UnfoldlM.filter (pure . ceMessageSubPred ceMessage) $ SSet.unfoldlM subscriptions
UnfoldlM.foldlM' (\acc sub -> pure $ acc <> [sub]) [] subs'
-- NOTE This isn't safe: we atomically fetch subscriptions, -- NOTE This isn't safe: we atomically fetch subscriptions,
-- then send notifications one by one. In the meantime, a -- then send notifications one by one. In the meantime, a
-- subscription could end or new ones could appear (but is -- subscription could end or new ones could appear (but is
...@@ -306,7 +317,7 @@ ce_listener subscriptions = do ...@@ -306,7 +317,7 @@ ce_listener subscriptions = do
-- probably they already fetch new tree anyways, and if old -- probably they already fetch new tree anyways, and if old
-- one drops in the meantime, it won't listen to what we -- one drops in the meantime, it won't listen to what we
-- send...) -- send...)
let filteredSubs = filterCEMessageSubs ceMessage subs -- let filteredSubs = filterCEMessageSubs ceMessage subs
mapM_ (sendNotification ceMessage) filteredSubs mapM_ (sendNotification ceMessage) filteredSubs
where where
sendNotification :: CETypes.CEMessage -> Subscription -> IO () sendNotification :: CETypes.CEMessage -> Subscription -> IO ()
...@@ -321,5 +332,8 @@ ce_listener subscriptions = do ...@@ -321,5 +332,8 @@ ce_listener subscriptions = do
-- For example, we can add CEMessage.Broadcast to propagate a -- For example, we can add CEMessage.Broadcast to propagate a
-- notification to all connections. -- notification to all connections.
filterCEMessageSubs :: CETypes.CEMessage -> [Subscription] -> [Subscription] filterCEMessageSubs :: CETypes.CEMessage -> [Subscription] -> [Subscription]
filterCEMessageSubs (CETypes.UpdateTreeFirstLevel node_id) subscriptions = filterCEMessageSubs ceMessage subscriptions = filter (ceMessageSubPred ceMessage) subscriptions
filter (\sub -> s_topic sub == UpdateTree node_id) subscriptions
ceMessageSubPred :: CETypes.CEMessage -> Subscription -> Bool
ceMessageSubPred (CETypes.UpdateTreeFirstLevel node_id) (Subscription { s_topic }) =
s_topic == UpdateTree node_id
...@@ -63,7 +63,7 @@ class ResourceId a where ...@@ -63,7 +63,7 @@ class ResourceId a where
-- whereas this one tracks only users. -- whereas this one tracks only users.
newtype UserId = UnsafeMkUserId { _UserId :: Int } newtype UserId = UnsafeMkUserId { _UserId :: Int }
deriving stock (Show, Eq, Ord, Generic) deriving stock (Show, Eq, Ord, Generic)
deriving newtype (ToSchema, ToJSON, FromJSON, FromField, ToField) deriving newtype (ToSchema, ToJSON, FromJSON, FromField, ToField, Hashable)
-- The 'UserId' is isomprohic to an 'Int'. -- The 'UserId' is isomprohic to an 'Int'.
instance GQLType UserId where instance GQLType UserId where
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment