Commit b7ccc435 authored by Mathieu Rodic

[CODE] increased flexibility of complex queries on nodes' children

[FEAT] advanced charts now start with the right data
https://forge.iscpif.fr/issues/1510
parent 1ae27347
......@@ -7,17 +7,29 @@ from sqlalchemy import text, distinct, or_
from sqlalchemy.sql import func
from sqlalchemy.orm import aliased
import datetime
import copy
from gargantext_web.views import move_to_trash
from gargantext_web.db import *
from gargantext_web.validation import validate, ValidationException
from node import models
def DebugHttpResponse(data):
return HttpResponse('<html><body style="background:#000;color:#FFF"><pre>%s</pre></body></html>' % (str(data), ))
import json
class JSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.isoformat()[:19] + 'Z'
else:
return super(JSONEncoder, self).default(obj)  # not self.__class__, which recurses if this class is ever subclassed
json_encoder = JSONEncoder(indent=4)
def JsonHttpResponse(data, status=200):
return HttpResponse(
content = json.dumps(data, indent=4),
content = json_encoder.encode(data),
content_type = 'application/json; charset=utf-8',
status = status
)
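# Usage sketch (standard library only; the sample datetime is illustrative):
# unlike bare json.dumps(), the module-level encoder above can serialize
# datetime objects, rendering them as UTC ISO-8601 strings.
#
#     >>> json_encoder.encode({'date': datetime.datetime(2015, 2, 27, 12, 30)})
#     '{\n    "date": "2015-02-27T12:30:00Z"\n}'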
......@@ -54,7 +66,7 @@ class APIException(_APIException):
self.detail = message
_operators = {
_operators_dict = {
"=": lambda field, value: (field == value),
"!=": lambda field, value: (field != value),
"<": lambda field, value: (field < value),
......@@ -65,6 +77,14 @@ _operators = {
"contains": lambda field, value: (field.contains(value)),
"startswith": lambda field, value: (field.startswith(value)),
}
_hyperdata_list = [
hyperdata
for hyperdata in session.query(Hyperdata).order_by(Hyperdata.name)
]
_hyperdata_dict = {
hyperdata.name: hyperdata
for hyperdata in _hyperdata_list
}
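# Usage sketch: each entry of _operators_dict maps a (field, value) pair to a
# SQLAlchemy clause, so filter triples can be applied generically. The column
# below is illustrative, assuming the Ngram model imported from gargantext_web.db.
#
#     >>> clause = _operators_dict['startswith'](Ngram.terms, 'graph')
#     >>> session.query(Ngram).filter(clause)  # ... WHERE ngrams.terms LIKE 'graph%'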
from rest_framework.decorators import api_view
......@@ -75,6 +95,7 @@ def Root(request, format=None):
'snippets': reverse('snippet-list', request=request, format=format)
})
class NodesChildrenNgrams(APIView):
def get(self, request, node_id):
......@@ -117,6 +138,7 @@ class NodesChildrenNgrams(APIView):
],
})
class NodesChildrenDuplicates(APIView):
def _fetch_duplicates(self, request, node_id, extra_columns=None, min_count=1):
......@@ -205,6 +227,7 @@ class NodesChildrenDuplicates(APIView):
'deleted': count
})
class NodesChildrenMetatadata(APIView):
def get(self, request, node_id):
......@@ -264,49 +287,98 @@ class NodesChildrenMetatadata(APIView):
'data': collection,
})
class NodesChildrenQueries(APIView):
def _parse_filter(self, filter):
# validate filter keys
filter_keys = {'field', 'operator', 'value'}
if set(filter) != filter_keys:
raise APIException('Every filter should have exactly %d keys: "%s"'% (len(filter_keys), '", "'.join(filter_keys)), 400)
field, operator, value = filter['field'], filter['operator'], filter['value']
# validate operator
if operator not in _operators:
raise APIException('Invalid operator: "%s"'% (operator, ), 400)
# validate value, depending on the operator
if operator == 'in':
    if not isinstance(value, list):
        raise APIException('Parameter "value" should be an array when using operator "%s"'% (operator, ), 400)
    for v in value:
        if not isinstance(v, (int, float, str)):
            raise APIException('Parameter "value" should be an array of numbers or strings when using operator "%s"'% (operator, ), 400)
else:
    if not isinstance(value, (int, float, str)):
        raise APIException('Parameter "value" should be a number or string when using operator "%s"'% (operator, ), 400)
def _sql(self, input, node_id):
fields = dict()
tables = {'nodes'}  # set literal: set('nodes') would yield a set of single characters
hyperdata_aliases = dict()
# retrieve all unique fields names
fields_names = input['retrieve']['fields'].copy()
fields_names += [filter['field'] for filter in input['filters']]
fields_names += input['sort']
fields_names = set(fields_names)
# relate fields to their respective ORM counterparts
for field_name in fields_names:
field_name_parts = field_name.split('.')
field = None
if len(field_name_parts) == 1:
    field = getattr(Node, field_name)
elif field_name_parts[1] == 'count':
    # aggregates ("nodes.count", "ngrams.count") must be matched before the
    # generic ngrams attribute lookup below
    if field_name_parts[0] == 'nodes':
        field = func.count(Node.id)
    elif field_name_parts[0] == 'ngrams':
        field = func.count(Ngram.id)
        tables.add('ngrams')
elif field_name_parts[0] == 'ngrams':
    field = getattr(Ngram, field_name_parts[1])
    tables.add('ngrams')
elif field_name_parts[0] == 'hyperdata':
hyperdata = _hyperdata_dict[field_name_parts[1]]
if hyperdata not in hyperdata_aliases:
hyperdata_aliases[hyperdata] = aliased(Node_Hyperdata)
hyperdata_alias = hyperdata_aliases[hyperdata]
field = getattr(hyperdata_alias, 'value_%s' % hyperdata.type)
if len(field_name_parts) == 3:
field = func.date_trunc(field_name_parts[2], field)
fields[field_name] = field
# build query: selected fields
query = (session
.query(*(fields[field_name] for field_name in input['retrieve']['fields']))
)
# build query: selected tables
query = query.select_from(Node)
if 'ngrams' in tables:
query = (query
.join(Node_Ngram, Node_Ngram.node_id == Node.id)
.join(Ngram, Ngram.id == Node_Ngram.ngram_id)
)
for hyperdata, hyperdata_alias in hyperdata_aliases.items():
query = (query
.join(hyperdata_alias, hyperdata_alias.node_id == Node.id)
.filter(hyperdata_alias.hyperdata_id == hyperdata.id)
)
# build query: filtering
query = (query
.filter(Node.parent_id == node_id)
)
for filter in input['filters']:
query = (query
.filter(_operators_dict[filter['operator']](
fields[filter['field']],
filter['value']
))
)
# build query: aggregations
if input['retrieve']['aggregate']:
for field_name in input['retrieve']['fields']:
if not field_name.endswith('.count'):
query = query.group_by(fields[field_name])
# build query: sorting
for field_name in input['sort']:
last = field_name[-1:]
if last in ('+', '-'):
field_name = field_name[:-1]
if last == '-':
    query = query.order_by(fields[field_name].desc())
else:
    query = query.order_by(fields[field_name])
# build and return result
output = copy.deepcopy(input)
output['pagination']['total'] = query.count()
output['results'] = list(
query[input['pagination']['offset']:input['pagination']['offset']+input['pagination']['limit']]
if input['pagination']['limit']
else query[input['pagination']['offset']:]
)
return output
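# Example of a POST body that the validation in post() accepts and that
# _sql() turns into a single query (field names come from the authorized
# lists below; the ngram values are illustrative): count documents per
# publication day, keeping only those matching one of two terms.
#
#     {
#         "retrieve": {"aggregate": true,
#                      "fields": ["hyperdata.publication_date.day", "nodes.count"]},
#         "filters": [{"field": "ngrams.terms", "operator": "in",
#                      "value": ["brain", "memory"]}],
#         "sort": ["hyperdata.publication_date.day+"],
#         "pagination": {"offset": 0, "limit": 0}
#     }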
# parse field
field_objects = {
'hyperdata': None,
'ngrams': ['terms', 'n'],
}
field = field.split('.')
if len(field) < 2 or field[0] not in field_objects:
raise APIException('Parameter "field" should be a in the form "object.key", where "object" takes one of the following values: "%s". "%s" was found instead' % ('", "'.join(field_objects), '.'.join(field)), 400)
if field_objects[field[0]] is not None and field[1] not in field_objects[field[0]]:
raise APIException('Invalid key for "%s" in parameter "field", should be one of the following values: "%s". "%s" was found instead' % (field[0], '", "'.join(field_objects[field[0]]), field[1]), 400)
# return value
return field, _operators[operator], value
def _count_documents(self, query):
return {
'fields': []
}
def _haskell(self, input, node_id):
    # placeholder backend for "haskell.*" fields: not implemented yet, so it
    # returns the paginated envelope with an empty result set
    output = copy.deepcopy(input)
    output['pagination']['total'] = 0
    output['results'] = list()
    return output
def post(self, request, node_id):
""" Query the children of the given node.
......@@ -348,199 +420,53 @@ class NodesChildrenQueries(APIView):
}
"""
hyperdata_aliases = {}
# validate query
query_fields = {'pagination', 'retrieve', 'sort', 'filters'}
for key in request.DATA:
if key not in query_fields:
raise APIException('Unrecognized field "%s" in query object. Accepted fields are: "%s"' % (key, '", "'.join(query_fields)), 400)
# selecting info
if 'retrieve' not in request.DATA:
raise APIException('The query should have a "retrieve" parameter.', 400)
retrieve = request.DATA['retrieve']
retrieve_types = {'fields', 'aggregates'}
if 'type' not in retrieve:
raise APIException('In the query\'s "retrieve" parameter, a "type" should be specified. Possible values are: "%s".' % ('", "'.join(retrieve_types), ), 400)
if 'list' not in retrieve or not isinstance(retrieve['list'], list):
raise APIException('In the query\'s "retrieve" parameter, a "list" should be provided as an array', 400)
if retrieve['type'] not in retrieve_types:
raise APIException('Unrecognized "type": "%s" in the query\'s "retrieve" parameter. Possible values are: "%s".' % (retrieve['type'], '", "'.join(retrieve_types), ), 400)
if retrieve['type'] == 'fields':
fields_names = ['id'] + retrieve['list'] if 'id' not in retrieve['list'] else retrieve['list']
elif retrieve['type'] == 'aggregates':
fields_names = list(retrieve['list'])
fields_list = []
for field_name in fields_names:
split_field_name = field_name.split('.')
if split_field_name[0] == 'hyperdata':
hyperdata = session.query(Hyperdata).filter(Hyperdata.name == split_field_name[1]).first()
if hyperdata is None:
hyperdata_query = session.query(Hyperdata.name).order_by(Hyperdata.name)
hyperdata_names = [hyperdata.name for hyperdata in hyperdata_query.all()]
raise APIException('Invalid key for "%s" in parameter "field", should be one of the following values: "%s". "%s" was found instead' % (split_field_name[0], '", "'.join(hyperdata_names), split_field_name[1]), 400)
# check or create Node_Hyperdata alias; join if necessary
if hyperdata.id in hyperdata_aliases:
hyperdata_alias = hyperdata_aliases[hyperdata.id]
else:
hyperdata_alias = hyperdata_aliases[hyperdata.id] = aliased(Node_Hyperdata)
field = getattr(hyperdata_alias, 'value_' + hyperdata.type)
# operation on field
if len(split_field_name) > 2:
# datetime truncation
# authorized field names
sql_fields = set({
'id', 'name',
'nodes.count', 'ngrams.count',
'ngrams.terms', 'ngrams.n',
})
for hyperdata in _hyperdata_list:
    sql_fields.add('hyperdata.' + hyperdata.name)
    if hyperdata.type == 'datetime':
        for part in ['year', 'month', 'day', 'hour', 'minute']:
            sql_fields.add('hyperdata.' + hyperdata.name + '.' + part)
if hyperdata.type == 'datetime':
datepart = split_field_name[2]
accepted_dateparts = ['year', 'month', 'day', 'hour', 'minute']
if datepart not in accepted_dateparts:
raise APIException('Invalid date truncation for "%s": "%s". Accepted values are: "%s".' % (split_field_name[1], split_field_name[2], '", "'.join(accepted_dateparts), ), 400)
# field = extract(datepart, field)
field = func.date_trunc(datepart, field)
# field = func.date_trunc(text('"%s"'% (datepart,)), field)
else:
authorized_field_names = {'id', 'name', }
authorized_aggregates = {
'nodes.count': func.count(Node.id),
'ngrams.count': func.count(Ngram.id),
}
if retrieve['type'] == 'aggregates' and field_name in authorized_aggregates:
field = authorized_aggregates[field_name]
elif field_name in authorized_field_names:
field = getattr(Node, field_name)
else:
raise APIException('Unrecognized "field": "%s" in the query\'s "retrieve" parameter. Possible values are: "%s".' % (field_name, '", "'.join(authorized_field_names), ))
fields_list.append(
field.label(
field_name if '.' in field_name else 'node.' + field_name
)
)
# starting the query!
document_type_id = cache.NodeType['Document'].id ##session.query(NodeType.id).filter(NodeType.name == 'Document').scalar()
query = (session
.query(*fields_list)
.select_from(Node)
.filter(Node.type_id == document_type_id)
.filter(Node.parent_id == node_id)
)
# join ngrams if necessary
if 'ngrams.count' in fields_names:
query = (query
.join(Node_Ngram, Node_Ngram.node_id == Node.id)
.join(Ngram, Ngram.id == Node_Ngram.ngram_id)
)
# authorized field names: Haskell
haskell_fields = set({
'haskell.test',
})
# join hyperdata aliases
for hyperdata_id, hyperdata_alias in hyperdata_aliases.items():
query = (query
.join(hyperdata_alias, hyperdata_alias.node_id == Node.id)
.filter(hyperdata_alias.hyperdata_id == hyperdata_id)
)
# authorized field names: all of them
authorized_fields = sql_fields | haskell_fields
# filtering
for filter in request.DATA.get('filters', []):
# parameters extraction & validation
field, operator, value = self._parse_filter(filter)
#
if field[0] == 'hyperdata':
# which hyperdata?
hyperdata = session.query(Hyperdata).filter(Hyperdata.name == field[1]).first()
if hyperdata is None:
hyperdata_query = session.query(Hyperdata.name).order_by(Hyperdata.name)
hyperdata_names = [hyperdata.name for hyperdata in hyperdata_query.all()]
raise APIException('Invalid key for "%s" in parameter "field", should be one of the following values: "%s". "%s" was found instead' % (field[0], '", "'.join(hyperdata_names), field[1]), 400)
# check or create Node_Hyperdata alias; join if necessary
if hyperdata.id in hyperdata_aliases:
hyperdata_alias = hyperdata_aliases[hyperdata.id]
# input validation
input = validate(request.DATA, {'type': dict, 'items': {
'pagination': {'type': dict, 'items': {
'limit': {'type': int, 'default': 0},
'offset': {'type': int, 'default': 0},
}, 'default': {'limit': 0, 'offset': 0}},
'filters': {'type': list, 'items': {'type': dict, 'items': {
'field': {'type': str, 'required': True, 'range': authorized_fields},
'operator': {'type': str, 'required': True, 'range': list(_operators_dict.keys())},
'value': {'required': True},
}}, 'default': list()},
'retrieve': {'type': dict, 'required': True, 'items': {
'aggregate': {'type': bool, 'default': False},
'fields': {'type': list, 'items': {'type': str, 'range': authorized_fields}, 'range': (1, )},
}},
'sort': {'type': list, 'items': {'type': str}, 'default': list()},
}})
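# With the schema above, a minimal body such as
#     {"retrieve": {"fields": ["id", "name"]}}
# comes back from validate() with every default filled in:
#     {'pagination': {'limit': 0, 'offset': 0}, 'filters': [],
#      'retrieve': {'aggregate': False, 'fields': ['id', 'name']}, 'sort': []}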
# return result, depending on the queried fields
if set(input['retrieve']['fields']) <= sql_fields:
    method = self._sql
elif set(input['retrieve']['fields']) <= haskell_fields:
    method = self._haskell
else:
    raise ValidationException('queried fields are mixing incompatible types of fields')
return JsonHttpResponse(method(input, node_id), 201)
else:
    hyperdata_alias = hyperdata_aliases[hyperdata.id] = aliased(Node_Hyperdata)
query = (query
.join(hyperdata_alias, hyperdata_alias.node_id == Node.id)
.filter(hyperdata_alias.hyperdata_id == hyperdata.id)
)
# adjust date
if hyperdata.type == 'datetime':
value = value + '2000-01-01T00:00:00Z'[len(value):]
# filter query
query = query.filter(operator(
getattr(hyperdata_alias, 'value_' + hyperdata.type),
value
))
elif field[0] == 'ngrams':
query = query.filter(
Node.id.in_(session
.query(Node_Ngram.node_id)
.join(Ngram, Ngram.id == Node_Ngram.ngram_id)
.filter(operator(
getattr(Ngram, field[1]),
map(lambda x: x.replace('-', ' '), value)
))
)
)
# TODO: date_trunc (psql) -> index also
# grouping
for field_name in fields_names:
if field_name not in authorized_aggregates:
# query = query.group_by(text(field_name))
query = query.group_by('"%s"' % (
field_name if '.' in field_name else 'node.' + field_name
, ))
# sorting
sort_fields_names = request.DATA.get('sort', ['id'])
if not isinstance(sort_fields_names, list):
raise APIException('The query\'s "sort" parameter should be an array', 400)
sort_fields_list = []
for sort_field_name in sort_fields_names:
try:
desc = sort_field_name[0] == '-'
if sort_field_name[0] in {'-', '+'}:
sort_field_name = sort_field_name[1:]
field = fields_list[fields_names.index(sort_field_name)]
if desc:
field = field.desc()
sort_fields_list.append(field)
except (ValueError, IndexError, TypeError):  # unknown field name, or not a string
raise APIException('Unrecognized field "%s" in the query\'s "sort" parameter. Accepted values are: "%s"' % (sort_field_name, '", "'.join(fields_names)), 400)
query = query.order_by(*sort_fields_list)
# pagination
pagination = request.DATA.get('pagination', {})
for key, value in pagination.items():
if key not in {'limit', 'offset'}:
raise APIException('Unrecognized parameter in "pagination": "%s"' % (key, ), 400)
if not isinstance(value, int):
raise APIException('In "pagination", "%s" should be an integer.' % (key, ), 400)
if 'offset' not in pagination:
pagination['offset'] = 0
if 'limit' not in pagination:
pagination['limit'] = 0
# respond to client!
# return DebugHttpResponse(str(query))
# return DebugHttpResponse(literalquery(query))
results = [
list(row)
# dict(zip(fields_names, row))
for row in (
query[pagination["offset"]:pagination["offset"]+pagination["limit"]]
if pagination['limit']
else query[pagination["offset"]:]
)
]
pagination["total"] = query.count()
return Response({
"pagination": pagination,
"retrieve": fields_names,
"sorted": sort_fields_names,
"results": results,
}, 201)
class NodesList(APIView):
authentication_classes = (SessionAuthentication, BasicAuthentication)
......@@ -598,6 +524,7 @@ class Nodes(APIView):
except Exception as error:
msgres ="error deleting : " + node_id + str(error)
class CorpusController:
@classmethod
......
from rest_framework.exceptions import APIException
from datetime import datetime
__all__ = ['validate', 'ValidationException']
_types_names = {
bool: 'boolean',
int: 'integer',
float: 'float',
str: 'string',
dict: 'object',
list: 'array',
datetime: 'datetime',
}
class ValidationException(APIException):
status_code = 400
default_detail = 'Bad request!'
def validate(value, expected, path='input'):
# Is the expected type respected?
if 'type' in expected:
expected_type = expected['type']
if not isinstance(value, expected_type):
if expected_type in (bool, int, float, str, datetime, ):
try:
if expected_type == bool:
value = value not in {0, 0.0, '', '0', 'false'}
elif expected_type == datetime:
value = value + '2000-01-01T00:00:00Z'[len(value):]
value = datetime.strptime(value, '%Y-%m-%dT%H:%M:%SZ')
else:
value = expected_type(value)
except ValueError:
raise ValidationException('%s should be a JSON %s, but could not be parsed as such' % (path, _types_names[expected_type], ))
else:
raise ValidationException('%s should be a JSON %s' % (path, _types_names[expected_type], ))
else:
expected_type = type(value)
# Is the value in the expected range?
if 'range' in expected:
expected_range = expected['range']
if isinstance(expected_range, tuple):
if expected_type in (int, float):
tested_value = value
tested_name = 'value'
elif expected_type in (str, list):
tested_value = len(value)
tested_name = 'length'
if tested_value < expected_range[0]:
raise ValidationException('%s should have a minimum %s of %d' % (path, tested_name, expected_range[0], ))
if len(expected_range) > 1 and tested_value > expected_range[1]:
raise ValidationException('%s should have a maximum %s of %d' % (path, tested_name, expected_range[1], ))
elif isinstance(expected_range, (list, set, dict, )) and value not in expected_range:
expected_values = expected_range if isinstance(expected_range, list) else expected_range.keys()
expected_values = [str(value) for value in expected_values if isinstance(value, expected_type)]
if len(expected_values) < 16:
expected_values_str = '", "'.join(expected_values)
expected_values_str = '"' + expected_values_str + '"'
else:
expected_values_str = '", "'.join(expected_values[:16])
expected_values_str = '"' + expected_values_str + '"...'
raise ValidationException('%s should take one of the following values: %s' % (path, expected_values_str, ))
# Do we have to translate through a dictionary?
if 'translate' in expected:
translate = expected['translate']
if callable(translate):
value = translate(value)
if value is None and expected.get('required', False):
raise ValidationException('%s has been given an invalid value' % (path, ))
return value
try:
value = expected['translate'][value]
except KeyError:
if expected.get('translate_fallback_keep', False):
return value
if expected.get('required', False):
raise ValidationException('%s has been given an invalid value' % (path, ))
else:
return expected.get('default', value)
# Are we handling an iterable?
if expected_type in (list, dict):
if 'items' in expected:
expected_items = expected['items']
if expected_type == list:
for i, element in enumerate(value):
value[i] = validate(element, expected_items, '%s[%d]' % (path, i, ))
elif expected_type == dict:
if expected_items:
for key in value:
if key not in expected_items:
raise ValidationException('%s should not have a "%s" key.' % (path, key, ))
for expected_key, expected_value in expected_items.items():
if expected_key in value:
value[expected_key] = validate(value[expected_key], expected_value, '%s["%s"]' % (path, expected_key, ))
elif 'required' in expected_value and expected_value['required']:
raise ValidationException('%s should have a "%s" key.' % (path, expected_key, ))
elif 'default' in expected_value:
value[expected_key] = expected_value['default']
# Let's return the proper value!
return value
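# Usage sketch (the schema below is illustrative): validate() coerces types,
# pads partial dates against '2000-01-01T00:00:00Z' before parsing them, and
# fills in defaults for missing keys.
#
#     >>> schema = {'type': dict, 'items': {
#     ...     'since': {'type': datetime, 'required': True},
#     ...     'limit': {'type': int, 'default': 10},
#     ... }}
#     >>> validate({'since': '2014'}, schema)
#     {'since': datetime.datetime(2014, 1, 1, 0, 0), 'limit': 10}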
......@@ -259,11 +259,21 @@ gargantext.controller("DatasetController", function($scope, $http) {
$scope.corpora = [];
$http.get('/api/nodes?type=Project', {cache: true}).success(function(response){
$scope.projects = response.data;
// Initially set to what is indicated in the URL
if (/^\/project\/\d+\/corpus\/\d+/.test(location.pathname)) {
$scope.projectId = parseInt(location.pathname.split('/')[2]);
$scope.updateCorpora();
}
});
// update corpora according to the selected parent project
$scope.updateCorpora = function() {
$http.get('/api/nodes?type=Corpus&parent=' + $scope.projectId, {cache: true}).success(function(response){
$scope.corpora = response.data;
// Initially set to what is indicated in the URL
if (/^\/project\/\d+\/corpus\/\d+/.test(location.pathname)) {
$scope.corpusId = parseInt(location.pathname.split('/')[4]);
$scope.updateEntities();
}
});
};
// update entities depending on the selected corpus
......@@ -522,8 +532,8 @@ gargantext.controller("GraphController", function($scope, $http, $element) {
filters: query.filters,
sort: ['hyperdata.publication_date.day'],
retrieve: {
type: 'aggregates',
list: ['hyperdata.publication_date.day', query.mesured]
aggregate: true,
fields: ['hyperdata.publication_date.day', query.mesured]
}
};
// request to the server
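// Example (assuming query.mesured resolves to an aggregate such as
// "nodes.count"): the chart now posts a body like
//   {"filters": [...],
//    "sort": ["hyperdata.publication_date.day"],
//    "retrieve": {"aggregate": true,
//                 "fields": ["hyperdata.publication_date.day", "nodes.count"]}}
// which matches the "retrieve"/"aggregate" schema expected by the reworked
// NodesChildrenQueries endpoint.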
......