Commit 1ae27347 authored by Mathieu Rodic

[FEATURE] implemented operations for WeightedMatrix

https://forge.iscpif.fr/issues/1507
parent 25584756
from collections import defaultdict from collections import defaultdict
from math import sqrt
from gargantext_web.db import session, NodeNgram, NodeNgramNgram, bulk_insert from gargantext_web.db import session, NodeNgram, NodeNgramNgram, bulk_insert
...@@ -132,12 +133,22 @@ class WeightedMatrix(BaseClass): ...@@ -132,12 +133,22 @@ class WeightedMatrix(BaseClass):
def __init__(self, other=None):
    """Build a weighted matrix from nothing, a node id, a matrix or rows.

    `other` selects the data source:

    - None: the matrix starts empty;
    - int: used as a node id; every (ngramx_id, ngramy_id, score) row of
      NodeNgramNgram whose node_id equals it is loaded;
    - WeightedMatrix: its (key1, key2, value) triples are copied into
      fresh dictionaries, so the two matrices share no state;
    - any other object exposing `__iter__`: each row provides key1, key2
      and value at indexes 0, 1 and 2.

    Raises TypeError for any other argument.
    """
    if other is None:
        self.items = defaultdict(lambda: defaultdict(float))
    elif isinstance(other, int):
        query = (session
            .query(NodeNgramNgram.ngramx_id, NodeNgramNgram.ngramy_id, NodeNgramNgram.score)
            .filter(NodeNgramNgram.node_id == other)
        )
        self.items = defaultdict(lambda: defaultdict(float))
        # The rows come from the query; self.items is still empty here,
        # so iterating it would never load anything.
        for key1, key2, value in query:
            self.items[key1][key2] = value
    elif isinstance(other, WeightedMatrix):
        self.items = defaultdict(lambda: defaultdict(float))
        for key1, key2, value in other:
            self.items[key1][key2] = value
    elif hasattr(other, '__iter__'):
        self.items = defaultdict(lambda: defaultdict(float))
        for row in other:
            self.items[row[0]][row[1]] = row[2]
    else:
        raise TypeError
...@@ -146,60 +157,84 @@ class WeightedMatrix(BaseClass): ...@@ -146,60 +157,84 @@ class WeightedMatrix(BaseClass):
for key2, value in key2_value.items(): for key2, value in key2_value.items():
yield key1, key2, value yield key1, key2, value
def save(self, node_id):
    """Replace the NodeNgramNgram rows stored for `node_id` with this matrix."""
    # remove what was previously stored for this node and commit the deletion
    stale_rows = session.query(NodeNgramNgram).filter(NodeNgramNgram.node_id == node_id)
    stale_rows.delete()
    session.commit()
    # write one row per (key1, key2, value) triple of the matrix
    columns = ('node_id', 'ngramx_id', 'ngramy_id', 'score')
    rows = ((node_id, key1, key2, value) for key1, key2, value in self)
    bulk_insert(NodeNgramNgram, columns, rows)
def __radd__(self, other):
    """Return the cell-wise sum of `other` and this matrix.

    Only WeightedMatrix operands are supported; for anything else
    NotImplemented is returned. Cells found in either operand take part
    in the sum, and cells whose sum is exactly 0.0 are left out of the
    result. Neither operand is modified.
    """
    result = NotImplemented
    if isinstance(other, WeightedMatrix):
        # Accumulate over both operands so that cells present in only one
        # of them are kept. Going through iteration rather than indexing
        # the nested defaultdicts avoids inserting empty cells into them.
        sums = defaultdict(float)
        for operand in (other, self):
            for key1, key2, value in operand:
                sums[key1, key2] += value
        result = WeightedMatrix()
        for (key1, key2), value in sums.items():
            if value != 0.0:
                result.items[key1][key2] = value
    return result
def __rsub__(self, other):
    """Subtract `other` from this matrix and return a new WeightedMatrix.

    - UnweightedList / WeightedList: every cell having at least one of
      its two keys in `other.items` is removed.
    - WeightedMatrix: the matching value of `other` (0.0 when absent) is
      subtracted from each cell of this matrix; cells that end up at
      exactly 0.0 are left out.

    Any other operand yields NotImplemented. Neither operand is modified.
    """
    result = NotImplemented
    if isinstance(other, (UnweightedList, WeightedList)):
        result = WeightedMatrix()
        for key1, key2, value in self:
            if key1 in other.items or key2 in other.items:
                continue
            result.items[key1][key2] = value
    elif isinstance(other, WeightedMatrix):
        result = WeightedMatrix()
        for key1, key2, value in self:
            # Use .get() lookups: indexing the nested defaultdicts would
            # insert empty rows and 0.0 cells into `other`.
            value = value - other.items.get(key1, {}).get(key2, 0.0)
            if value != 0.0:
                result.items[key1][key2] = value
        # NOTE(review): cells that exist only in `other` do not appear
        # (negated) in the result -- confirm this is the intended meaning.
    return result
def __rand__(self, other):
    """Restrict this matrix to the keys listed in `other`.

    For an UnweightedList or WeightedList operand, returns a new
    WeightedMatrix holding the cells whose two keys are both found in
    `other.items`. Any other operand yields NotImplemented.
    """
    if not isinstance(other, (UnweightedList, WeightedList)):
        return NotImplemented
    kept = WeightedMatrix()
    allowed = other.items
    for key1, key2, value in self:
        if key1 in allowed and key2 in allowed:
            kept.items[key1][key2] = value
    return kept
def __rmul__(self, other):
    """Combine this matrix with `other` and return a new WeightedMatrix.

    - Translations: both keys of every cell are mapped through
      `other.items` (keys without a mapping are kept); values of cells
      landing on the same pair of keys are added together.
    - UnweightedList: same result as `__rand__`.
    - WeightedList: cells whose two keys are both in `other.items` are
      kept, their value multiplied by the square root of the product of
      the two matching weights.

    Any other operand (WeightedMatrix included) yields NotImplemented.
    """
    if isinstance(other, Translations):
        translated = WeightedMatrix()
        mapping = other.items
        for key1, key2, value in self:
            new_key1 = mapping.get(key1, key1)
            new_key2 = mapping.get(key2, key2)
            translated.items[new_key1][new_key2] += value
        return translated
    if isinstance(other, UnweightedList):
        return self.__rand__(other)
    if isinstance(other, WeightedList):
        scaled = WeightedMatrix()
        weights = other.items
        for key1, key2, value in self:
            if key1 in weights and key2 in weights:
                scaled.items[key1][key2] = value * sqrt(weights[key1] * weights[key2])
        return scaled
    return NotImplemented
def __rdiv__(self, other):
    """Divide the cells of this matrix by weights taken from `other`.

    For a WeightedList operand, returns a new WeightedMatrix holding the
    cells whose two keys are both in `other.items`, each value divided by
    the square root of the product of the two matching weights. Any other
    operand yields NotImplemented.
    """
    # NOTE(review): Python 3's `/` operator looks up __rtruediv__, not
    # __rdiv__ -- confirm this method is reached through explicit dispatch.
    if not isinstance(other, WeightedList):
        return NotImplemented
    normalized = WeightedMatrix()
    weights = other.items
    for key1, key2, value in self:
        if key1 in weights and key2 in weights:
            normalized.items[key1][key2] = value / sqrt(weights[key1] * weights[key2])
    return normalized
class UnweightedList(BaseClass): class UnweightedList(BaseClass):
...@@ -414,20 +449,21 @@ def test(): ...@@ -414,20 +449,21 @@ def test():
from collections import OrderedDict from collections import OrderedDict
# define operands # define operands
operands = OrderedDict() operands = OrderedDict()
operands['ul1'] = UnweightedList((1, 2, 3, 4, 5)) operands['wm'] = WeightedMatrix(((1, 2, .5), (1, 3, .75), (2, 3, .6), (3, 3, 1), ))
operands['ul2'] = UnweightedList((1, 2, 3, 6)) operands['ul'] = UnweightedList((1, 2, 3, 4, 5))
# operands['ul2'] = UnweightedList((1, 2, 3, 6))
# operands['ul2'].save(5) # operands['ul2'].save(5)
# operands['ul3'] = UnweightedList(5) # operands['ul3'] = UnweightedList(5)
operands['wl1'] = WeightedList({1:.7, 2:.8, 7: 1.1}) operands['wl'] = WeightedList({1:.7, 2:.8, 7: 1.1})
# operands['wl1'].save(5) # operands['wl1'].save(5)
# operands['wl2'] = WeightedList(5) # operands['wl2'] = WeightedList(5)
operands['t1'] = Translations({1:2, 4:5}) # operands['t1'] = Translations({1:2, 4:5})
operands['t2'] = Translations({3:2, 4:5}) operands['t'] = Translations({3:2, 4:5})
# operands['t2'].save(5) # operands['t2'].save(5)
# operands['t3'] = Translations(5) # operands['t3'] = Translations(5)
# define operators # define operators
operators = OrderedDict() operators = OrderedDict()
# operators['+'] = '__add__' operators['+'] = '__add__'
operators['-'] = '__sub__' operators['-'] = '__sub__'
operators['*'] = '__mul__' operators['*'] = '__mul__'
operators['|'] = '__or__' operators['|'] = '__or__'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment