Commit 8eb55509 authored by Alfredo Di Napoli's avatar Alfredo Di Napoli

Clean-up policy API

parent 3057490b
Pipeline #4642 failed with stages
in 3 minutes and 39 seconds
......@@ -66,7 +66,7 @@ import Gargantext.Database.Schema.Node (NodePoly(_node_id))
import Gargantext.Prelude hiding (reverse)
import Gargantext.Prelude.Crypto.Pass.User (gargPass)
import Gargantext.Utils.Jobs (serveJobsAPI, MonadJobStatus(..))
import Protolude hiding (to)
import Protolude hiding (Handler, to)
import Servant
import Servant.Auth.Server
import qualified Data.Text as Text
......@@ -161,23 +161,23 @@ withAccess p _ ur id = hoistServer p f
f :: forall a. m a -> m a
f = withAccessM ur id
withPolicy :: forall env m api. (GargServerC env GargError m, HasServer api '[])
-- | Given the 'AuthenticatedUser', a policy check and a function that returns an @a@,
-- it runs the underlying policy check to ensure that the resource is returned only to
-- who is entitled to see it.
withPolicy :: GargServerC env GargError m
=> AuthenticatedUser
-> BoolExpr AccessCheck
-> Proxy api
-> Proxy m
-> ServerT api m
-> m a
-> AccessPolicyManager
-> ServerT api m
withPolicy ur checks p _ m0 mgr = hoistServer p f m0
where
f :: forall a. m a -> m a
f m = case mgr of
AccessPolicyManager{runAccessPolicy} -> do
res <- runAccessPolicy ur checks
case res of
Allow -> m
Deny err -> throwError $ GargServerError err
-> m a
withPolicy ur checks h mgr = do
a <- h
case mgr of
AccessPolicyManager{runAccessPolicy} -> do
res <- runAccessPolicy ur checks
case res of
Allow -> pure a
Deny err -> throwError $ GargServerError $ err
{- | Collaborative Schema
User at his root can create Teams Folder
......
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
......@@ -11,9 +10,12 @@ module Gargantext.API.Auth.PolicyCheck (
, PolicyChecked
, BoolExpr(..)
-- * Smart constructors
, nodeOwner
-- * Smart constructors for access checks
, nodeDescendant
, nodeSuper
, nodeUser
, nodeChecks
, alwaysAllow
) where
import Control.Lens
......@@ -34,9 +36,20 @@ import Data.BoolExpr
import Control.Monad
import Gargantext.API.Prelude
import Servant.Auth.Server.Internal.AddSetCookie
import Gargantext.Database.Query.Tree
-------------------------------------------------------------------------------
-- Types
-------------------------------------------------------------------------------
-- | Phantom type that allows us to embellish a Servant route with a policy check.
data PolicyChecked a
-- | The result of an access check.
data AccessResult
= Allow
= -- | Grants access.
Allow
-- | Denies access with the given 'ServerError'.
| Deny ServerError
instance Semigroup AccessResult where
......@@ -48,64 +61,97 @@ instance Semigroup AccessResult where
instance Monoid AccessResult where
mempty = Allow
enforce :: Applicative m => ServerError -> Bool -> m AccessResult
enforce errStatus p = pure $ if p then Allow else Deny errStatus
-- | An access policy manager for gargantext that governs how resources are accessed
-- and who is entitled to see what.
data AccessPolicyManager = AccessPolicyManager
{ runAccessPolicy :: AuthenticatedUser -> BoolExpr AccessCheck -> DBCmd GargError AccessResult }
-- | A type representing all the possible access checks we might want to perform on a resource,
-- typically a 'Node'.
data AccessCheck
= AC_node_owner NodeId
| AC_master_user NodeId
nodeOwner :: NodeId -> BoolExpr AccessCheck
nodeOwner = BConst . Positive . AC_node_owner
= -- | Grants access if the input 'NodeId' is a descendant of the
-- one for the logged-in user.
AC_node_descendant NodeId
-- | Grants access if the input 'NodeId' /is/ the logged-in user.
| AC_user_node NodeId
-- | Grants access if the logged-in user is the master user.
| AC_master_user NodeId
-- | Always grant access, effectively a public route.
| AC_always_allow
deriving (Show, Eq)
-------------------------------------------------------------------------------
-- Running access checks
-------------------------------------------------------------------------------
-- | The static access manager returned as part of a 'Servant' handler every time
-- we use the 'PolicyChecked' combinator.
accessPolicyManager :: AccessPolicyManager
accessPolicyManager = AccessPolicyManager (\ur ac -> interpretPolicy ur ac)
where
interpretPolicy :: AuthenticatedUser -> BoolExpr AccessCheck -> DBCmd GargError AccessResult
interpretPolicy ur chk = case chk of
BAnd b1 b2
-> liftM2 (<>) (interpretPolicy ur b1) (interpretPolicy ur b2)
BOr b1 b2
-> do
c1 <- interpretPolicy ur b1
case c1 of
Allow -> pure Allow
Deny{} -> interpretPolicy ur b2
BNot b1
-> do
res <- interpretPolicy ur b1
case res of
Allow -> pure $ Deny err403
Deny _ -> pure Allow
BTrue
-> pure Allow
BFalse
-> pure $ Deny err403
BConst (Positive b)
-> check ur b
BConst (Negative b)
-> check ur b
nodeSuper :: NodeId -> BoolExpr AccessCheck
nodeSuper = BConst . Positive . AC_master_user
check :: HasNodeError err => AuthenticatedUser -> AccessCheck -> DBCmd err AccessResult
check (AuthenticatedUser nodeId) = \case
AC_node_owner requestedNodeId
-> enforce err403 $ nodeId == requestedNodeId
check (AuthenticatedUser loggedUserNodeId) = \case
AC_always_allow
-> pure Allow
AC_user_node requestedNodeId
-> enforce err403 $ loggedUserNodeId == requestedNodeId
AC_master_user _requestedNodeId
-> do
masterUsername <- _gc_masteruser <$> view hasConfig
masterNodeId <- getUserId (UserName masterUsername)
enforce err403 $ (NodeId masterNodeId) == nodeId
enforce err403 $ (NodeId masterNodeId) == loggedUserNodeId
AC_node_descendant nodeId
-> enforce err403 =<< nodeId `isDescendantOf` loggedUserNodeId
accessPolicyManager :: AccessPolicyManager
accessPolicyManager = AccessPolicyManager (\ur ac -> interpretPolicy ur ac)
-------------------------------------------------------------------------------
-- Smart constructors of access checks
-------------------------------------------------------------------------------
interpretPolicy :: AuthenticatedUser -> BoolExpr AccessCheck -> DBCmd GargError AccessResult
interpretPolicy ur = \case
BAnd b1 b2
-> liftM2 (<>) (interpretPolicy ur b1) (interpretPolicy ur b2)
BOr b1 b2
-> do
c1 <- interpretPolicy ur b1
case c1 of
Allow -> pure Allow
Deny{} -> interpretPolicy ur b2
BNot b1
-> do
res <- interpretPolicy ur b1
case res of
Allow -> pure $ Deny err403
Deny _ -> pure Allow
BTrue
-> pure Allow
BFalse
-> pure $ Deny err403
BConst (Positive b)
-> check ur b
BConst (Negative b)
-> check ur b
nodeUser :: NodeId -> BoolExpr AccessCheck
nodeUser = BConst . Positive . AC_user_node
data PolicyChecked a
nodeSuper :: NodeId -> BoolExpr AccessCheck
nodeSuper = BConst . Positive . AC_master_user
nodeDescendant :: NodeId -> BoolExpr AccessCheck
nodeDescendant = BConst . Positive . AC_node_descendant
nodeChecks :: NodeId -> BoolExpr AccessCheck
nodeChecks nid =
nodeUser nid `BOr` nodeSuper nid `BOr` nodeDescendant nid
alwaysAllow :: BoolExpr AccessCheck
alwaysAllow = BConst . Positive $ AC_always_allow
-------------------------------------------------------------------------------
-- Instances
-------------------------------------------------------------------------------
instance (HasServer subApi ctx) => HasServer (PolicyChecked subApi) ctx where
type ServerT (PolicyChecked subApi) m = AccessPolicyManager -> ServerT subApi m
......@@ -126,3 +172,12 @@ instance Swagger.HasSwagger sub => Swagger.HasSwagger (PolicyChecked sub) where
instance HasEndpoint sub => HasEndpoint (PolicyChecked sub) where
getEndpoint _ = getEndpoint (Proxy :: Proxy sub)
enumerateEndpoints _ = enumerateEndpoints (Proxy :: Proxy sub)
-------------------------------------------------------------------------------
-- Utility functions
-------------------------------------------------------------------------------
-- | If the given predicate holds then grant access, otherwise denies access
-- with the given 'ServerError'.
enforce :: Applicative m => ServerError -> Bool -> m AccessResult
enforce errStatus p = pure $ if p then Allow else Deny errStatus
......@@ -119,7 +119,7 @@ roots = getNodesWithParentId Nothing
-- CanFavorite
-- CanMoveToTrash
type NodeAPI a = PolicyChecked (Get '[JSON] (Node a))
type NodeAPI a = PolicyChecked (NodeNodeAPI a)
:<|> "rename" :> RenameApi
:<|> PostNodeApi -- TODO move to children POST
:<|> PostNodeAsync
......@@ -193,7 +193,7 @@ nodeNodeAPI p uId cId nId = withAccess (Proxy :: Proxy (NodeNodeAPI a)) Proxy uI
------------------------------------------------------------------------
-- TODO: make the NodeId type indexed by `a`, then we no longer need the proxy.
nodeAPI :: forall proxy a.
( HyperdataC a
( HyperdataC a, Show a
) => proxy a
-> AuthenticatedUser
-> NodeId
......@@ -201,14 +201,8 @@ nodeAPI :: forall proxy a.
nodeAPI p authenticatedUser@(AuthenticatedUser (NodeId uId)) id' = withAccess (Proxy :: Proxy (NodeAPI a)) Proxy authenticatedUser (PathNode id') nodeAPI'
where
api :: Proxy (NodeNodeAPI a)
api = Proxy
m :: Proxy (GargM Env GargError)
m = Proxy
nodeAPI' :: ServerT (NodeAPI a) (GargM Env GargError)
nodeAPI' = withPolicy authenticatedUser (nodeOwner id' `BOr` nodeSuper id') api m (getNodeWith id' p)
nodeAPI' = withPolicy authenticatedUser (nodeChecks id') (getNodeWith id' p)
:<|> rename id'
:<|> postNode uId id'
:<|> postNodeAsyncAPI uId id'
......
......@@ -210,7 +210,7 @@ pgContextId :: ContextId -> O.Column O.SqlInt4
pgContextId = pgNodeId
------------------------------------------------------------------------
newtype NodeId = NodeId Int
newtype NodeId = NodeId { _NodeId :: Int }
deriving (Read, Generic, Num, Eq, Ord, Enum, ToJSONKey, FromJSONKey, ToJSON, FromJSON, Hashable, Csv.ToField)
instance GQLType NodeId
instance Show NodeId where
......
......@@ -59,7 +59,7 @@ import Gargantext.Core.Types.Main (NodeTree(..), Tree(..))
import Gargantext.Database.Admin.Config (fromNodeTypeId, nodeTypeId, fromNodeTypeId)
import Gargantext.Database.Admin.Types.Hyperdata.Any (HyperdataAny)
import Gargantext.Database.Admin.Types.Node
import Gargantext.Database.Prelude (Cmd, runPGSQuery)
import Gargantext.Database.Prelude (Cmd, runPGSQuery, DBCmd)
import Gargantext.Database.Query.Table.Node (getNodeWith)
import Gargantext.Database.Query.Table.Node.Error (HasNodeError)
import Gargantext.Database.Query.Table.NodeNode (getNodeNode)
......@@ -356,7 +356,7 @@ dbTree rootId nodeTypes = map (\(nId, tId, pId, n) -> DbTreeNode nId tId pId n)
[] -> allNodeTypes
_ -> nodeTypes
isDescendantOf :: NodeId -> RootId -> Cmd err Bool
isDescendantOf :: NodeId -> RootId -> DBCmd err Bool
isDescendantOf childId rootId = (== [Only True])
<$> runPGSQuery [sql|
BEGIN ;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment