Skip to content

Commit 70a58b6

Browse files
committed
add an implementation trading computation for memory
Since the weights are also sparse, it might be interesting to exploit this fact and allocate only what is actually needed in memory. The downside is that queries are slower, but I think this can definitely be optimized. This was just a proof of concept.
1 parent 54e475e commit 70a58b6

File tree

7 files changed

+336
-136
lines changed

7 files changed

+336
-136
lines changed

Python/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__/

Python/diadic_test.py

Lines changed: 65 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,55 +2,71 @@
22
"""
33
The xcount variable below sets how many records would be stored & queried within memory
44
"""
5-
xcount = 450000
6-
SDR_SIZE = 1000 # Size of SDRs in bits
7-
SDR_BITS = 10 # Default number of ON bits (aka solidity)
5+
import argparse
86

97
from time import time
108
import numpy as np
11-
from sdrsdm import DiadicMemory, randomSDR
12-
13-
print(f"Testing DiadicMemory for {xcount} entries,\nsdr size is {SDR_SIZE} with {SDR_BITS} of ON bits")
14-
# The number of records to write and query
15-
16-
t = time()
17-
X = randomSDR(xcount+1, N = SDR_SIZE, P = SDR_BITS)
18-
t = time() - t
19-
print(f"{xcount+1} random SDRs generated in {int(t*1000)}ms")
20-
21-
22-
mem = DiadicMemory(SDR_SIZE,SDR_BITS) ######################################## Initialize the memory
23-
24-
t = time()
25-
for i,x in enumerate(X):
26-
if i == xcount:
27-
break
28-
y = X[i+1]
29-
mem.store(x,y) ######################################################## Store operation
30-
t = time() - t
31-
print(f"{xcount} writes in {int(t*1000)} ms")
32-
33-
print("Testing queries")
34-
35-
size_errors = {}
36-
found = np.zeros((xcount,mem.P),dtype = np.uint16)
37-
t = time()
38-
39-
for i in range(xcount):
40-
qresult = mem.query(X[i]) ########################################## Query operations
41-
if qresult.size == mem.P:
42-
found[i] = qresult
43-
else:
44-
found[i] = X[i+1] #
45-
size_errors[i] = qresult
46-
47-
t = time() - t
48-
49-
print(f"{xcount} queries done in {int(t*1000)}ms")
50-
51-
print("Comparing results with expectations")
52-
if len(size_errors):
53-
print(f"{len(size_errors)} size errors, check size_errors dictionary")
54-
diff = (X[1:] != found)
55-
56-
print(f"{(diff.sum(axis=1) > 0).sum()} differences check diff array")
9+
from mem_sdrsdm import DiadicMemory as MemDiadicMemory
10+
from sdrsdm import DiadicMemory
11+
from sdr_util import random_sdrs
12+
13+
14+
SDR_SIZE = 1000 # Size of SDRs in bits
15+
SDR_BITS = 10 # Default number of ON bits (aka solidity)
16+
17+
def test_diadic(mem, xcount):
    """
    Store a chain of random SDRs in `mem`, query them back and report timings.

    xcount + 1 random SDRs are generated; each one is stored as the key of its
    successor, then every key is queried and the answer is compared with the
    expected successor.

    mem    -- memory object exposing store(x, y), query(x) and the attribute P
    xcount -- number of records to write and query

    Returns a (size_errors, diff) tuple so the structures mentioned in the
    printed messages can actually be inspected by the caller:
      size_errors -- dict mapping record index -> query result whose size
                     differs from mem.P
      diff        -- boolean array, True where a returned bit differs from the
                     expected one (records listed in size_errors are excluded
                     because their row is filled with the expected value)
    """
    print(f"Testing DiadicMemory for {xcount} entries,\nsdr size is {SDR_SIZE} with {SDR_BITS} of ON bits")

    t = time()
    X = random_sdrs(xcount+1, sdr_size = SDR_SIZE, on_bits = SDR_BITS)
    t = time() - t
    print(f"{xcount+1} random SDRs generated in {int(t*1000)}ms")

    t = time()
    # Pair every SDR with its successor: X[i] is the key, X[i+1] the value.
    for x, y in zip(X[:-1], X[1:]):
        mem.store(x, y)  # Store operation
    t = time() - t
    print(f"{xcount} writes in {int(t*1000)} ms")

    print("Testing queries")

    size_errors = {}
    found = np.zeros((xcount, mem.P), dtype = np.uint16)
    t = time()

    for i in range(xcount):
        qresult = mem.query(X[i])  # Query operation
        if qresult.size == mem.P:
            found[i] = qresult
        else:
            # A result of the wrong size cannot be stored in `found`; record it
            # separately and fill the row with the expected value so it is not
            # counted a second time as a bit difference below.
            found[i] = X[i+1]
            size_errors[i] = qresult

    t = time() - t

    print(f"{xcount} queries done in {int(t*1000)}ms")

    print("Comparing results with expectations")
    if len(size_errors):
        print(f"{len(size_errors)} size errors, check size_errors dictionary")
    diff = (X[1:] != found)

    print(f"{(diff.sum(axis=1) > 0).sum()} differences check diff array")
    return size_errors, diff
59+
60+
# Only run the benchmark when executed as a script, so that importing this
# module (e.g. from a test runner) does not parse argv or start a long run.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='diadic tester',
        description='test the diadic memory')
    parser.add_argument('-c', '--count', type=int, default=45000,
                        help='number of records to store and query (default: %(default)s)')
    parser.add_argument('-m', '--mem', action='store_true',
                        help='use the DiadicMemory from mem_sdrsdm instead of the one from sdrsdm')
    args = parser.parse_args()

    if args.mem:
        mem = MemDiadicMemory(SDR_SIZE, SDR_BITS)  # Initialize the memory
    else:
        mem = DiadicMemory(SDR_SIZE, SDR_BITS)

    test_diadic(mem, args.count)

Python/mem_sdrsdm.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""
2+
This code is an implementation of the Triadic Memory and Dyadic Memory algorithms
3+
4+
Copyright (c) 2021-2022 Peter Overmann
5+
Copyright (c) 2022 Cezar Totth
6+
Copyright (c) 2023 Clément Michaud
7+
8+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software
9+
and associated documentation files (the “Software”), to deal in the Software without restriction,
10+
including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
11+
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
12+
subject to the following conditions:
13+
14+
The above copyright notice and this permission notice shall be included in all copies or substantial
15+
portions of the Software.
16+
17+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
18+
LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
20+
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
21+
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
22+
"""
23+
import numpy as np
24+
import numba
25+
26+
def xaddr(x):
    """
    Yield one memory address for every pair of ON bits in x.

    x is a sorted, sparsely encoded SDR (a sequence of ON-bit indices).
    For every i > j the pair (x[i], x[j]) is mapped to
    x[i] * (x[i] - 1) // 2 + x[j], the index of that pair in a flattened
    strictly lower-triangular matrix.

    Addresses are yielded as plain Python ints.
    """
    # Convert to Python ints first: fixed-width numpy scalars (e.g. the
    # uint16 / uint32 elements of an SDR array) silently wrap around in the
    # multiplication below once the bit indices get large, which would make
    # unrelated pairs share an address.
    bits = [int(b) for b in x]
    for i in range(1, len(bits)):
        hi = bits[i]
        base = hi * (hi - 1) // 2  # first address of the row belonging to `hi`
        for j in range(i):
            yield base + bits[j]
31+
32+
def store_xy(mem, x, y):
    """
    Write the association X -> Y into a dyadic memory.

    For every address produced by xaddr(x), each ON bit of Y is set to 1
    in the row stored at that address.
    X and Y must be sorted, sparsely encoded SDRs.
    """
    for cell in xaddr(x):
        for bit in y:
            mem[cell][bit] = 1
40+
41+
def store_xyz(mem, x, y, z):
    """
    Record the triple (X, Y, Z) in a triadic memory.

    mem holds three nested mappings under the keys 'x', 'y' and 'z'. Each
    of them is indexed by the two other coordinates first and by its own
    coordinate last, and the addressed entry is set to 1.
    X, Y and Z must be sparsely encoded SDRs.
    """
    for bit_x in x:
        for bit_y in y:
            for bit_z in z:
                mem['z'][bit_x][bit_y][bit_z] = 1
                mem['y'][bit_x][bit_z][bit_y] = 1
                mem['x'][bit_y][bit_z][bit_x] = 1
52+
53+
def query(mem, N, P, x):
    """
    Query in DiadicMemory

    mem -- mapping address -> {bit index: value}, as filled by store_xy
    N   -- length of the sums vector (the SDR size)
    P   -- number of ON bits requested, forwarded to sums2sdr
    x   -- the key, a sorted sparsely encoded SDR

    The rows stored at every address derived from x are added up and the
    resulting sums are turned into an SDR by sums2sdr.
    """
    sums = np.zeros(N, dtype=np.uint32)
    for addr in xaddr(x):
        # Use .get() rather than mem[addr]: indexing a defaultdict inserts an
        # empty row for every address that was never written, so a read-only
        # query would otherwise keep growing the memory.
        row = mem.get(addr)
        if row:
            for k, v in row.items():
                sums[k] += v
    return sums2sdr(sums, P)
62+
63+
def queryZ(mem, N, P, x, y):
    """
    Query a triadic memory for Z given X and Y.

    mem  -- dict holding the nested mapping mem['z'][ax][ay] -> {az: value}
    N    -- length of the sums vector (the SDR size)
    P    -- number of ON bits requested, forwarded to sums2sdr
    x, y -- sparsely encoded SDRs

    The rows stored for every (ax, ay) combination are added up and the
    resulting sums are turned into an SDR by sums2sdr.
    """
    sums = np.zeros(N, dtype=np.uint32)
    plane = mem['z']
    for ax in x:
        # .get() keeps the lookup read-only: indexing the nested defaultdicts
        # would insert empty entries for combinations that were never stored.
        rows = plane.get(ax)
        if not rows:
            continue
        for ay in y:
            row = rows.get(ay)
            if row:
                for k, v in row.items():
                    sums[k] += v
    return sums2sdr(sums, P)
71+
72+
def queryX(mem, N, P, y, z):
    """
    Query a triadic memory for X given Y and Z.

    mem  -- dict holding the nested mapping mem['x'][ay][az] -> {ax: value}
    N    -- length of the sums vector (the SDR size)
    P    -- number of ON bits requested, forwarded to sums2sdr
    y, z -- sparsely encoded SDRs

    The rows stored for every (ay, az) combination are added up and the
    resulting sums are turned into an SDR by sums2sdr.
    """
    sums = np.zeros(N, dtype=np.uint32)
    plane = mem['x']
    for ay in y:
        # .get() keeps the lookup read-only: indexing the nested defaultdicts
        # would insert empty entries for combinations that were never stored.
        rows = plane.get(ay)
        if not rows:
            continue
        for az in z:
            row = rows.get(az)
            if row:
                for k, v in row.items():
                    sums[k] += v
    return sums2sdr(sums, P)
79+
80+
def queryY(mem, N, P, x, z):
    """
    Query a triadic memory for Y given X and Z.

    mem  -- dict holding the nested mapping mem['y'][ax][az] -> {ay: value}
    N    -- length of the sums vector (the SDR size)
    P    -- number of ON bits requested, forwarded to sums2sdr
    x, z -- sparsely encoded SDRs

    The rows stored for every (ax, az) combination are added up and the
    resulting sums are turned into an SDR by sums2sdr.
    """
    sums = np.zeros(N, dtype=np.uint32)
    plane = mem['y']
    for ax in x:
        # .get() keeps the lookup read-only: indexing the nested defaultdicts
        # would insert empty entries for combinations that were never stored.
        rows = plane.get(ax)
        if not rows:
            continue
        for az in z:
            row = rows.get(az)
            if row:
                for k, v in row.items():
                    sums[k] += v
    return sums2sdr(sums, P)
87+
88+
@numba.jit(nopython=True)
def sums2sdr(sums, P):
    """
    Turn a vector of per-bit sums into a sparse SDR (array of ON-bit indices).

    This does what binarize() does in C: the threshold is the P-th largest
    value of sums and every index whose sum reaches it is returned. When
    that threshold is 0 only the non-zero entries are returned, so the
    result can hold fewer than P indices; ties at the threshold can make
    it hold more than P.

    Raises ValueError when P is not in the range 1..len(sums).
    """
    # Validate P explicitly: the negative index below is not bounds-checked
    # in compiled code, and P == 0 would select the smallest sum instead of
    # the largest.
    if P < 1 or P > sums.shape[0]:
        raise ValueError("P must be between 1 and len(sums)")
    ssums = sums.copy()
    ssums.sort()
    threshval = ssums[-P]
    if threshval == 0:
        return np.where(sums)[0]  # All non zero values
    else:
        return np.where(sums >= threshval)[0]
98+
99+
class TriadicMemory:
    """
    Dictionary-backed triadic memory: stores (X, Y, Z) triples of sparse SDRs
    and recalls any one member from the other two.

    Only the entries that were actually written are allocated.
    """
    def __init__(self, N, P):
        """
        N is the SDR vector size
        P is the count of ON bits returned by default by the queries
        """
        self.mem = {
            # We store the data 3 times to be able to query any of the variables,
            # but if one knows which variable is to be queried, we can get rid
            # of 2/3 of the memory used.
            'x': defaultdict(lambda: defaultdict(lambda: defaultdict(int))),
            'y': defaultdict(lambda: defaultdict(lambda: defaultdict(int))),
            'z': defaultdict(lambda: defaultdict(lambda: defaultdict(int))),
        }
        self.P = P
        self.N = N

    def store(self, x, y, z):
        """Store the triple (x, y, z); all three are sparse encoded SDRs."""
        store_xyz(self.mem, x, y, z)

    def query(self, x, y, z = None):
        """
        Query for either x, y or z.

        The queried member must be provided as None, the other two members
        have to be encoded as sorted sparse SDRs.

        Raises ValueError unless exactly one of x, y, z is None; otherwise
        the call would either be ambiguous or silently return None.
        """
        missing = (x is None) + (y is None) + (z is None)
        if missing != 1:
            raise ValueError("exactly one of x, y, z must be None")
        if z is None:
            return queryZ(self.mem, self.N, self.P, x, y)
        if x is None:
            return queryX(self.mem, self.N, self.P, y, z)
        return queryY(self.mem, self.N, self.P, x, z)

    def query_X(self, y, z):
        """Return the X recalled from y and z."""
        return queryX(self.mem, self.N, self.P, y, z)

    def query_Y(self, x, z):
        """Return the Y recalled from x and z."""
        return queryY(self.mem, self.N, self.P, x, z)

    def query_Z(self, x, y):
        """Return the Z recalled from x and y."""
        return queryZ(self.mem, self.N, self.P, x, y)

    def query_x_with_P(self, y, z, P):
        """Same as query_X but with an explicit P instead of self.P."""
        return queryX(self.mem, self.N, P, y, z)

    def size(self):
        """
        Return the number of stored entries, counting all three copies.

        Only the 'x' copy is walked; the result is multiplied by 3 because
        every entry is written once to each of 'x', 'y' and 'z'.
        """
        # TODO: computing the size every time might not be very efficient,
        # instead we could probably keep track of how many bits are stored.
        return 3 * sum(
            len(row)
            for plane in self.mem['x'].values()
            for row in plane.values()
        )
146+
147+
from collections import defaultdict
148+
149+
class DiadicMemory:
    """
    Object front end bundling the dyadic memory functions with their state.
    """
    def __init__(self, N, P):
        """
        N -- SDR vector size, e.g. 1000
        P -- count of solid bits, e.g. 10
        """
        self.N = N
        self.P = P
        # address -> {bit index: value}; rows are created on first access
        self.mem = defaultdict(lambda: defaultdict(int))

    def store(self, x, y):
        """Store y under the key x."""
        store_xy(self.mem, x, y)

    def query(self, x):
        """Return the SDR recalled for the key x."""
        return query(self.mem, self.N, self.P, x)

    def size(self):
        """Return the total number of entries held by all rows."""
        return sum(len(row) for row in self.mem.values())

Python/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
numba 0.55.0
2-
numpy 1.19.4
1+
numba==0.58.0
2+
numpy==1.25.2

Python/sdr_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def sdr_distance(n1, n2):
102102
"""
103103
return 1.0 - 2.0 * sdr_overlap(n1, n2) / (len(n1) + len(n2))
104104

105-
@numba.jit
105+
@numba.jit(nopython=True)
106106
def random_sdr(sdr_size, sdr_len):
107107
out = np.zeros(sdr_len, dtype = np.uint32)
108108
r = np.random.randint(0,sdr_size)
@@ -113,7 +113,7 @@ def random_sdr(sdr_size, sdr_len):
113113
out.sort()
114114
return out
115115

116-
@numba.jit
116+
@numba.jit(nopython=True)
117117
def near_sdr(sdr, sdr_size, switch = 3):
118118
"""
119119
returns a sdr close to input sdr by switching switch bits

0 commit comments

Comments
 (0)