Commit d6aa1e6e authored by Mathieu Rodic's avatar Mathieu Rodic

[OPTI] improved SQLAlchemy sessions management

(using `scoped_session`, removed unnecessary use of Aldjemy)
parent 4a9f94fb
...@@ -22,10 +22,8 @@ class User(Base): ...@@ -22,10 +22,8 @@ class User(Base):
is_active = Column(Boolean()) is_active = Column(Boolean())
date_joined = DateTime(timezone=False) date_joined = DateTime(timezone=False)
def get_contacts(self, session=None): def get_contacts(self):
"""get all contacts in relation with the user""" """get all contacts in relation with the user"""
if session is None:
session = Session()
Friend = aliased(User) Friend = aliased(User)
query = (session query = (session
.query(Friend) .query(Friend)
...@@ -34,11 +32,9 @@ class User(Base): ...@@ -34,11 +32,9 @@ class User(Base):
) )
return query.all() return query.all()
def get_nodes(self, session=None, nodetype=None): def get_nodes(self, nodetype=None):
"""get all nodes belonging to the user""" """get all nodes belonging to the user"""
from .nodes import Node from .nodes import Node
if session is None:
session = Session()
query = (session query = (session
.query(Node) .query(Node)
.filter(Node.user_id == self.id) .filter(Node.user_id == self.id)
......
from aldjemy.core import get_engine
from gargantext import settings from gargantext import settings
# get engine, session, etc. # get engine, session, etc.
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
def get_engine(): def get_engine():
...@@ -18,7 +17,7 @@ engine = get_engine() ...@@ -18,7 +17,7 @@ engine = get_engine()
Base = declarative_base() Base = declarative_base()
Session = sessionmaker(bind=engine) session = scoped_session(sessionmaker(bind=engine))
# tools to build models # tools to build models
......
...@@ -12,7 +12,6 @@ from gargantext import models ...@@ -12,7 +12,6 @@ from gargantext import models
class ModelCache(dict): class ModelCache(dict):
def __init__(self, model, preload=False): def __init__(self, model, preload=False):
self._session = Session()
self._model = model self._model = model
self._columns = [column for column in model.__table__.columns if column.unique or column.primary_key] self._columns = [column for column in model.__table__.columns if column.unique or column.primary_key]
self._columns_names = [column.name for column in self._columns] self._columns_names = [column.name for column in self._columns]
...@@ -20,7 +19,7 @@ class ModelCache(dict): ...@@ -20,7 +19,7 @@ class ModelCache(dict):
self.preload() self.preload()
def __del__(self): def __del__(self):
self._session.close() session.close()
def __missing__(self, key): def __missing__(self, key):
formatted_key = None formatted_key = None
...@@ -34,7 +33,7 @@ class ModelCache(dict): ...@@ -34,7 +33,7 @@ class ModelCache(dict):
if formatted_key in self: if formatted_key in self:
self[key] = self[formatted_key] self[key] = self[formatted_key]
else: else:
element = self._session.query(self._model).filter(or_(*conditions)).first() element = session.query(self._model).filter(or_(*conditions)).first()
if element is None: if element is None:
raise KeyError raise KeyError
self[key] = element self[key] = element
...@@ -42,7 +41,7 @@ class ModelCache(dict): ...@@ -42,7 +41,7 @@ class ModelCache(dict):
def preload(self): def preload(self):
self.clear() self.clear()
for element in self._session.query(self._model).all(): for element in session.query(self._model).all():
for column_name in self._columns_names: for column_name in self._columns_names:
key = getattr(element, column_name) key = getattr(element, column_name)
self[key] = element self[key] = element
......
...@@ -13,7 +13,6 @@ def overview(request): ...@@ -13,7 +13,6 @@ def overview(request):
To each project, we can link a resource that can be an image. To each project, we can link a resource that can be an image.
''' '''
session = Session()
user = cache.User[request.user.username] user = cache.User[request.user.username]
project_type = cache.NodeType['Project'] project_type = cache.NodeType['Project']
...@@ -30,10 +29,10 @@ def overview(request): ...@@ -30,10 +29,10 @@ def overview(request):
session.commit() session.commit()
# list of projects created by the logged user # list of projects created by the logged user
user_projects = user.get_nodes(session=session, nodetype=project_type) user_projects = user.get_nodes(nodetype=project_type)
# list of contacts of the logged user # list of contacts of the logged user
contacts = user.get_contacts(session=session) contacts = user.get_contacts()
contacts_projects = [] contacts_projects = []
for contact in contacts: for contact in contacts:
contact_projects = (session contact_projects = (session
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment