[websockets] migrate to StmContainers.Set

parent 5bcb8731
Pipeline #6170 failed with stages
in 67 minutes and 50 seconds
......@@ -520,6 +520,7 @@ library
, cryptohash ^>= 0.11.9
, data-time-segment ^>= 0.1.0.0
, deepseq ^>= 1.4.4.0
, deferred-folds >= 0.9.18 && < 0.10
, directory ^>= 1.3.6.0
, discrimination >= 0.5
, ekg-core ^>= 0.1.1.7
......
......@@ -26,6 +26,7 @@ import Data.Aeson.Types (prependFailure, typeMismatch)
import Data.ByteString.Char8 qualified as C
import Data.ByteString.Lazy qualified as BSL
import Data.List (nubBy)
import DeferredFolds.UnfoldlM qualified as UnfoldlM
import Data.UUID.V4 as UUID
import Gargantext.API.Admin.Auth.Types (AuthenticatedUser(_auth_user_id))
import Gargantext.API.Admin.Types (jwtSettings, Settings, jwtSettings)
......@@ -35,9 +36,11 @@ import Gargantext.Prelude
import GHC.Conc (TVar, newTVarIO, readTVar, writeTVar)
import Nanomsg
import Network.WebSockets qualified as WS
import Protolude.Base (Show(showsPrec))
import Servant
import Servant.API.WebSocket qualified as WS
import Servant.Auth.Server (verifyJWT)
import StmContainers.Set as SSet
{-
......@@ -56,7 +59,8 @@ data Topic =
-- children (e.g. list is automatically created in a corpus)
UpdateTree NodeId
deriving (Eq, Show)
instance Hashable Topic where
hashWithSalt salt (UpdateTree nodeId) = hashWithSalt salt ("update-tree" :: Text, nodeId)
instance FromJSON Topic where
parseJSON = Aeson.withObject "Topic" $ \o -> do
type_ <- o .: "type"
......@@ -65,7 +69,6 @@ instance FromJSON Topic where
node_id <- o .: "node_id"
pure $ UpdateTree node_id
s -> prependFailure "parsing type failed, " (typeMismatch "type" s)
instance ToJSON Topic where
toJSON (UpdateTree node_id) = Aeson.object [
"type" .= toJSON ("update_tree" :: Text)
......@@ -76,10 +79,17 @@ data ConnectedUser =
CUUser UserId
| CUPublic
deriving (Eq, Show)
instance Hashable ConnectedUser where
hashWithSalt salt (CUUser userId) = hashWithSalt salt ("cuuser" :: Text, userId)
hashWithSalt salt CUPublic = hashWithSalt salt ("cupublic" :: Text)
newtype WSKeyConnection = WSKeyConnection (ByteString, WS.Connection)
eqWSKeyConnection :: WSKeyConnection -> WSKeyConnection -> Bool
eqWSKeyConnection ws1 ws2 = wsKey ws1 == wsKey ws2
instance Hashable WSKeyConnection where
hashWithSalt salt (WSKeyConnection (key, _conn)) = hashWithSalt salt key
instance Eq WSKeyConnection where
(==) (WSKeyConnection (key1, _conn1)) (WSKeyConnection (key2, _conn2)) = key1 == key2
instance Show WSKeyConnection where
showsPrec d (WSKeyConnection (key, _conn)) = showsPrec d $ "WSKeyConnection " <> key
showWSKeyConnection :: WSKeyConnection -> Text
showWSKeyConnection ws = "WSKeyConnection " <> show (wsKey ws)
wsKey :: WSKeyConnection -> ByteString
......@@ -92,16 +102,10 @@ data Subscription =
s_connected_user :: ConnectedUser
, s_ws_key_connection :: WSKeyConnection
, s_topic :: Topic }
eqSub :: Subscription -> Subscription -> Bool
eqSub sub1 sub2 =
s_connected_user sub1 == s_connected_user sub2 &&
s_ws_key_connection sub2 `eqWSKeyConnection` s_ws_key_connection sub2 &&
s_topic sub1 == s_topic sub2
showSub :: Subscription -> Text
showSub sub =
"Subscription " <> show (s_connected_user sub) <>
" " <> showWSKeyConnection (s_ws_key_connection sub) <>
" " <> show (s_topic sub)
deriving (Eq, Show)
instance Hashable Subscription where
hashWithSalt salt (Subscription { .. }) =
hashWithSalt salt ( s_connected_user, s_ws_key_connection, s_topic )
subKey :: Subscription -> ByteString
subKey sub = wsKey $ s_ws_key_connection sub
......@@ -142,7 +146,7 @@ instance FromJSON WSRequest where
s -> prependFailure "parsing request type failed, " (typeMismatch "request" s)
data Dispatcher =
Dispatcher { d_subscriptions :: TVar [Subscription]
Dispatcher { d_subscriptions :: SSet.Set Subscription
, d_ws_server :: Server WSAPI
, d_ce_listener :: ThreadId
}
......@@ -150,7 +154,7 @@ data Dispatcher =
dispatcher :: Settings -> IO Dispatcher
dispatcher authSettings = do
subscriptions <- newTVarIO ([] :: [Subscription])
subscriptions <- SSet.newIO
let server = wsServer authSettings subscriptions
......@@ -164,33 +168,37 @@ dispatcher authSettings = do
-- | TODO Allow only 1 topic subscription per connection. It doesn't
-- | make sense to send multiple notifications of the same type to the
-- | same connection.
insertSubscription :: TVar [Subscription] -> Subscription -> IO [Subscription]
insertSubscription :: SSet.Set Subscription -> Subscription -> IO ()
insertSubscription subscriptions sub =
atomically $ do
s <- readTVar subscriptions
let ss = nubBy eqSub $ s <> [sub]
writeTVar subscriptions ss
pure ss
atomically $ SSet.insert sub subscriptions
-- s <- readTVar subscriptions
-- let ss = nubBy eqSub $ s <> [sub]
-- writeTVar subscriptions ss
-- -- pure ss
-- pure ()
removeSubscription :: TVar [Subscription] -> Subscription -> IO [Subscription]
removeSubscription :: SSet.Set Subscription -> Subscription -> IO ()
removeSubscription subscriptions sub =
atomically $ do
s <- readTVar subscriptions
let ss = filter (\sub' -> not $ sub `eqSub` sub') s
writeTVar subscriptions ss
pure ss
atomically $ SSet.delete sub subscriptions
-- s <- readTVar subscriptions
-- let ss = filter (\sub' -> not $ sub `eqSub` sub') s
-- writeTVar subscriptions ss
-- pure ss
removeSubscriptionsForWSKey :: TVar [Subscription] -> WSKeyConnection -> IO [Subscription]
removeSubscriptionsForWSKey :: SSet.Set Subscription -> WSKeyConnection -> IO ()
removeSubscriptionsForWSKey subscriptions ws =
atomically $ do
s <- readTVar subscriptions
let ss = filter (\sub -> subKey sub /= wsKey ws) s
writeTVar subscriptions ss
pure ss
let toDelete = UnfoldlM.filter (\sub -> return $ subKey sub == wsKey ws) $ SSet.unfoldlM subscriptions
UnfoldlM.mapM_ (\sub -> SSet.delete sub subscriptions) toDelete
-- atomically $ do
-- s <- readTVar subscriptions
-- let ss = filter (\sub -> subKey sub /= wsKey ws) s
-- writeTVar subscriptions ss
-- pure ss
type WSAPI = "ws" :> WS.WebSocketPending
wsServer :: Settings -> TVar [Subscription] -> Server WSAPI
wsServer :: Settings -> SSet.Set Subscription -> Server WSAPI
wsServer authSettings subscriptions = streamData
where
streamData :: MonadIO m => WS.PendingConnection -> m ()
......@@ -242,16 +250,15 @@ wsServer authSettings subscriptions = streamData
let sub = Subscription { s_connected_user = user
, s_ws_key_connection = ws
, s_topic = topic }
ss <- insertSubscription subscriptions sub
putText $ "[wsLoop] subscriptions: " <> show (showSub <$> ss)
_ss <- insertSubscription subscriptions sub
-- putText $ "[wsLoop] subscriptions: " <> show (showSub <$> ss)
return user
Just (WSUnsubscribe topic) -> do
-- TODO Fix s_connected_user based on header
let sub = Subscription { s_connected_user = CUPublic
, s_ws_key_connection = ws
, s_topic = topic }
ss <- removeSubscription subscriptions sub
putText $ "[wsLoop] subscriptions: " <> show (showSub <$> ss)
let sub = Subscription { s_connected_user = user
, s_ws_key_connection = ws
, s_topic = topic }
_ss <- removeSubscription subscriptions sub
-- putText $ "[wsLoop] subscriptions: " <> show (showSub <$> ss)
return user
Just (WSAuthorize token) -> do
let jwtS = authSettings ^. jwtSettings
......@@ -273,8 +280,9 @@ wsServer authSettings subscriptions = streamData
disconnect = do
putText "[wsLoop] disconnecting..."
ss <- removeSubscriptionsForWSKey subscriptions ws
putText $ "[wsLoop] subscriptions: " <> show (showSub <$> ss)
_ss <- removeSubscriptionsForWSKey subscriptions ws
-- putText $ "[wsLoop] subscriptions: " <> show (show <$> ss)
return ()
data Notification =
......@@ -287,7 +295,7 @@ instance ToJSON Notification where
]
ce_listener :: TVar [Subscription] -> IO ()
ce_listener :: SSet.Set Subscription -> IO ()
ce_listener subscriptions = do
withSocket Pull $ \s -> do
_ <- bind s "tcp://*:5561"
......@@ -298,7 +306,10 @@ ce_listener subscriptions = do
case Aeson.decode (BSL.fromStrict r) of
Nothing -> putText "[ce_listener] unknown message from central exchange"
Just ceMessage -> do
subs <- atomically $ readTVar subscriptions
-- subs <- atomically $ readTVar subscriptions
filteredSubs <- atomically $ do
let subs' = UnfoldlM.filter (pure . ceMessageSubPred ceMessage) $ SSet.unfoldlM subscriptions
UnfoldlM.foldlM' (\acc sub -> pure $ acc <> [sub]) [] subs'
-- NOTE This isn't safe: we atomically fetch subscriptions,
-- then send notifications one by one. In the meantime, a
-- subscription could end or new ones could appear (but is
......@@ -306,7 +317,7 @@ ce_listener subscriptions = do
-- probably they already fetch new tree anyways, and if old
-- one drops in the meantime, it won't listen to what we
-- send...)
let filteredSubs = filterCEMessageSubs ceMessage subs
-- let filteredSubs = filterCEMessageSubs ceMessage subs
mapM_ (sendNotification ceMessage) filteredSubs
where
sendNotification :: CETypes.CEMessage -> Subscription -> IO ()
......@@ -321,5 +332,8 @@ ce_listener subscriptions = do
-- For example, we can add CEMessage.Broadcast to propagate a
-- notification to all connections.
filterCEMessageSubs :: CETypes.CEMessage -> [Subscription] -> [Subscription]
filterCEMessageSubs (CETypes.UpdateTreeFirstLevel node_id) subscriptions =
filter (\sub -> s_topic sub == UpdateTree node_id) subscriptions
filterCEMessageSubs ceMessage subscriptions = filter (ceMessageSubPred ceMessage) subscriptions
ceMessageSubPred :: CETypes.CEMessage -> Subscription -> Bool
ceMessageSubPred (CETypes.UpdateTreeFirstLevel node_id) (Subscription { s_topic }) =
s_topic == UpdateTree node_id
......@@ -63,7 +63,7 @@ class ResourceId a where
-- whereas this one tracks only users.
newtype UserId = UnsafeMkUserId { _UserId :: Int }
deriving stock (Show, Eq, Ord, Generic)
deriving newtype (ToSchema, ToJSON, FromJSON, FromField, ToField)
deriving newtype (ToSchema, ToJSON, FromJSON, FromField, ToField, Hashable)
-- The 'UserId' is isomprohic to an 'Int'.
instance GQLType UserId where
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment