Skip to content

Commit ae9b61c

Browse files
Merge pull request #20 from clems4ever/memory-tradeoff
add an implementation trading computation for memory
2 parents 54e475e + 70a58b6 commit ae9b61c

File tree

7 files changed

+336
-136
lines changed

7 files changed

+336
-136
lines changed

Python/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__/

Python/diadic_test.py

Lines changed: 65 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,55 +2,71 @@
22
"""
33
The --count command-line option (default 45000) sets how many records are stored & queried within memory
44
"""
5-
xcount = 450000
6-
SDR_SIZE = 1000 # Size of SDRs in bits
7-
SDR_BITS = 10 # Default number of ON bits (aka solidity)
5+
import argparse
86

97
from time import time
108
import numpy as np
11-
from sdrsdm import DiadicMemory, randomSDR
12-
13-
print(f"Testing DiadicMemory for {xcount} entries,\nsdr size is {SDR_SIZE} with {SDR_BITS} of ON bits")
14-
# The number of records to write and query
15-
16-
t = time()
17-
X = randomSDR(xcount+1, N = SDR_SIZE, P = SDR_BITS)
18-
t = time() - t
19-
print(f"{xcount+1} random SDRs generated in {int(t*1000)}ms")
20-
21-
22-
mem = DiadicMemory(SDR_SIZE,SDR_BITS) ######################################## Initialize the memory
23-
24-
t = time()
25-
for i,x in enumerate(X):
26-
if i == xcount:
27-
break
28-
y = X[i+1]
29-
mem.store(x,y) ######################################################## Store operation
30-
t = time() - t
31-
print(f"{xcount} writes in {int(t*1000)} ms")
32-
33-
print("Testing queries")
34-
35-
size_errors = {}
36-
found = np.zeros((xcount,mem.P),dtype = np.uint16)
37-
t = time()
38-
39-
for i in range(xcount):
40-
qresult = mem.query(X[i]) ########################################## Query operations
41-
if qresult.size == mem.P:
42-
found[i] = qresult
43-
else:
44-
found[i] = X[i+1] #
45-
size_errors[i] = qresult
46-
47-
t = time() - t
48-
49-
print(f"{xcount} queries done in {int(t*1000)}ms")
50-
51-
print("Comparing results with expectations")
52-
if len(size_errors):
53-
print(f"{len(size_errors)} size errors, check size_errors dictionary")
54-
diff = (X[1:] != found)
55-
56-
print(f"{(diff.sum(axis=1) > 0).sum()} differences check diff array")
9+
from mem_sdrsdm import DiadicMemory as MemDiadicMemory
10+
from sdrsdm import DiadicMemory
11+
from sdr_util import random_sdrs
12+
13+
14+
SDR_SIZE = 1000 # Size of SDRs in bits
15+
SDR_BITS = 10 # Default number of ON bits (aka solidity)
16+
17+
def test_diadic(mem, xcount):
    """
    Benchmark and sanity-check a dyadic memory on a chain of random SDRs.

    Generates xcount + 1 random SDRs, stores every SDR as the key of its
    successor, then queries each key and compares the answer with the SDR
    that was stored under it. Timings and error counts are printed.

    mem    -- object exposing store(x, y), query(x) and a P attribute
    xcount -- the number of records to write and query

    Returns a (size_errors, diff) tuple so the structures mentioned in the
    printed report can actually be inspected by the caller:
      size_errors -- dict mapping record index to the raw query result for
                     every query whose result did not have exactly mem.P bits
      diff        -- boolean array, X[1:] != found
    """
    print(f"Testing DiadicMemory for {xcount} entries,\nsdr size is {SDR_SIZE} with {SDR_BITS} of ON bits")

    t = time()
    X = random_sdrs(xcount+1, sdr_size = SDR_SIZE, on_bits = SDR_BITS)
    t = time() - t
    print(f"{xcount+1} random SDRs generated in {int(t*1000)}ms")

    t = time()
    # Each SDR is the key, the next one in the sequence is the value.
    for x, y in zip(X[:xcount], X[1:xcount + 1]):
        mem.store(x, y) ######################################################## Store operation
    t = time() - t
    print(f"{xcount} writes in {int(t*1000)} ms")

    print("Testing queries")

    size_errors = {}
    found = np.zeros((xcount, mem.P), dtype = np.uint16)
    t = time()

    for i in range(xcount):
        qresult = mem.query(X[i]) ########################################## Query operations
        if qresult.size == mem.P:
            found[i] = qresult
        else:
            # A result of the wrong length cannot be stored in the row, so the
            # expected value is written instead and the failure is recorded
            # in size_errors rather than in the element-wise diff below.
            found[i] = X[i+1]
            size_errors[i] = qresult

    t = time() - t

    print(f"{xcount} queries done in {int(t*1000)}ms")

    print("Comparing results with expectations")
    if size_errors:
        print(f"{len(size_errors)} size errors, check size_errors dictionary")
    diff = (X[1:] != found)

    print(f"{(diff.sum(axis=1) > 0).sum()} differences check diff array")

    return size_errors, diff
59+
60+
# Command-line entry point: choose the memory implementation and run the test.
parser = argparse.ArgumentParser(
    description='test the diadic memory',
    prog='diadic tester',
)
parser.add_argument('-c', '--count', type=int, default=45000)
parser.add_argument('-m', '--mem', action='store_true')
args = parser.parse_args()

# -m / --mem selects the dictionary-backed implementation from mem_sdrsdm.
mem = (MemDiadicMemory if args.mem else DiadicMemory)(SDR_SIZE, SDR_BITS) ######################################## Initialize the memory

test_diadic(mem, args.count)

Python/mem_sdrsdm.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""
2+
This code is an implementation of the Triadic Memory and Dyadic Memory algorithms
3+
4+
Copyright (c) 2021-2022 Peter Overmann
5+
Copyright (c) 2022 Cezar Totth
6+
Copyright (c) 2023 Clément Michaud
7+
8+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software
9+
and associated documentation files (the “Software”), to deal in the Software without restriction,
10+
including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
11+
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
12+
subject to the following conditions:
13+
14+
The above copyright notice and this permission notice shall be included in all copies or substantial
15+
portions of the Software.
16+
17+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
18+
LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19+
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
20+
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
21+
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
22+
"""
23+
import numpy as np
24+
import numba
25+
26+
def xaddr(x):
    """
    Yield one storage address for every pair of bits in x.

    For each pair of positions j < i the address is
    x[i] * (x[i] - 1) // 2 + x[j], i.e. the index of (x[i], x[j]) in a
    flattened lower-triangular matrix. The addresses are unique per pair
    when x is sorted in ascending order without duplicates, so that
    x[j] < x[i].

    x must support len() and integer indexing. Nothing is yielded when x
    has fewer than two elements.
    """
    for i in range(1, len(x)):
        high = x[i]
        # The row offset only depends on the larger bit of the pair.
        base = high * (high - 1) // 2
        for j in range(i):
            yield base + x[j]
31+
32+
def store_xy(mem, x, y):
    """
    Write value SDR y under key SDR x.

    Every address derived from x by xaddr() gets each bit of y set to 1.
    mem is indexed as mem[address][bit]; x and y have to be sorted,
    sparsely encoded SDRs.
    """
    for address in xaddr(x):
        for bit in y:
            mem[address][bit] = 1
40+
41+
def store_xyz(mem, x, y, z):
    """
    Record the (x, y, z) triple in mem.

    For every combination of one bit from each SDR, a 1 is written into
    three indexes so that any member can later be recalled from the other
    two: mem['x'] is keyed by (y, z), mem['y'] by (x, z) and mem['z'] by
    (x, y). x, y and z have to be sparse encoded SDRs.
    """
    for bit_x in x:
        for bit_y in y:
            for bit_z in z:
                mem['z'][bit_x][bit_y][bit_z] = 1
                mem['y'][bit_x][bit_z][bit_y] = 1
                mem['x'][bit_y][bit_z][bit_x] = 1
52+
53+
def query(mem, N, P, x):
    """
    Query in DiadicMemory

    Accumulates, for every address derived from x by xaddr(), the values
    stored at that address into a counter vector of length N, then turns
    the counters into an SDR with sums2sdr(sums, P).

    mem -- mapping address -> {bit: value}
    N   -- SDR vector size (length of the counter vector)
    P   -- number of ON bits requested from sums2sdr
    x   -- key SDR, passed to xaddr()
    """
    sums = np.zeros(N, dtype=np.uint32)
    for addr in xaddr(x):
        # Use .get() rather than mem[addr]: indexing a defaultdict inserts an
        # empty entry for every address that was never written, so each read
        # would make the memory grow.
        row = mem.get(addr)
        if not row:
            continue
        for k, v in row.items():
            sums[k] += v
    return sums2sdr(sums, P)
62+
63+
def queryZ(mem, N, P, x, y):
    """
    Recall z from x and y in a triadic memory.

    Sums the entries of mem['z'][ax][ay] over every bit pair (ax, ay) into a
    counter vector of length N and converts it with sums2sdr(sums, P).
    """
    sums = np.zeros(N, dtype=np.uint32)
    index = mem['z']
    for ax in x:
        # .get() keeps the lookup read-only: indexing the nested defaultdicts
        # would insert empty entries for pairs that were never stored.
        rows = index.get(ax)
        if not rows:
            continue
        for ay in y:
            row = rows.get(ay)
            if not row:
                continue
            for k, v in row.items():
                sums[k] += v
    return sums2sdr(sums, P)
71+
72+
def queryX(mem, N, P, y, z):
    """
    Recall x from y and z in a triadic memory.

    Sums the entries of mem['x'][ay][az] over every bit pair (ay, az) into a
    counter vector of length N and converts it with sums2sdr(sums, P).
    """
    sums = np.zeros(N, dtype=np.uint32)
    index = mem['x']
    for ay in y:
        # .get() keeps the lookup read-only: indexing the nested defaultdicts
        # would insert empty entries for pairs that were never stored.
        rows = index.get(ay)
        if not rows:
            continue
        for az in z:
            row = rows.get(az)
            if not row:
                continue
            for k, v in row.items():
                sums[k] += v
    return sums2sdr(sums, P)
79+
80+
def queryY(mem, N, P, x, z):
    """
    Recall y from x and z in a triadic memory.

    Sums the entries of mem['y'][ax][az] over every bit pair (ax, az) into a
    counter vector of length N and converts it with sums2sdr(sums, P).
    """
    sums = np.zeros(N, dtype=np.uint32)
    index = mem['y']
    for ax in x:
        # .get() keeps the lookup read-only: indexing the nested defaultdicts
        # would insert empty entries for pairs that were never stored.
        rows = index.get(ax)
        if not rows:
            continue
        for az in z:
            row = rows.get(az)
            if not row:
                continue
            for k, v in row.items():
                sums[k] += v
    return sums2sdr(sums, P)
87+
88+
@numba.jit(nopython=True)
def sums2sdr(sums, P):
    """
    Convert a counter vector into a sparse SDR (array of ON-bit indices).

    The threshold is the P-th largest counter. When it is zero, every
    non-zero position is returned (possibly fewer than P); otherwise every
    position whose counter reaches the threshold is returned (possibly more
    than P on ties). This does what binarize() does in C.

    NOTE(review): P is not validated here; it is expected to satisfy
    1 <= P <= len(sums).
    """
    ordered = np.sort(sums)  # sorted copy, the input is left untouched
    threshold = ordered[-P]
    if threshold == 0:
        # Fewer than P counters are set: keep all non zero values
        return np.where(sums)[0]
    return np.where(sums >= threshold)[0]
98+
99+
class TriadicMemory:
    """
    Object front end for the dictionary-backed triadic memory functions.

    N is the SDR vector size and P the number of ON bits returned by queries.
    """

    def __init__(self, N, P):
        self.mem = {
            # We store the data 3 times to be able to query any of the variable but
            # if one knows which variable is to be queried, we can get rid of 2/3 of
            # the memory used.
            'x': defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))),
            'y': defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))),
            'z': defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: 0))),
        }
        self.P = P
        self.N = N

    def store(self, x, y, z):
        """Store the (x, y, z) triple; all three are sparse encoded SDRs."""
        store_xyz(self.mem, x, y, z)

    def query(self, x, y, z = None):
        """
        Query for either x, y or z.

        The queried member must be provided as None, the other two members
        have to be encoded as sorted sparse SDRs. The members are checked in
        the order z, x, y.

        Raises ValueError when none of the members is None, instead of
        silently returning None.
        """
        if z is None:
            return queryZ(self.mem, self.N, self.P, x, y)
        elif x is None:
            return queryX(self.mem, self.N, self.P, y, z)
        elif y is None:
            return queryY(self.mem, self.N, self.P, x, z)
        raise ValueError("one of x, y, z must be None to select the queried member")

    def query_X(self, y, z):
        """Recall x from y and z."""
        return queryX(self.mem, self.N, self.P, y, z)

    def query_Y(self, x, z):
        """Recall y from x and z."""
        return queryY(self.mem, self.N, self.P, x, z)

    def query_Z(self, x, y):
        """Recall z from x and y."""
        return queryZ(self.mem, self.N, self.P, x, y)

    def query_x_with_P(self, y, z, P):
        """Recall x from y and z using P instead of the instance's own P."""
        return queryX(self.mem, self.N, P, y, z)

    def size(self):
        """
        Return the number of stored entries across the three indexes.

        Only the 'x' index is walked; the count is multiplied by 3 because
        store() writes every entry into all three indexes.
        """
        # TODO: computing the size every time might not be very efficient,
        # instead we could probably keep track of how many bits are stored.
        per_index = sum(
            len(leaf)
            for rows in self.mem['x'].values()
            for leaf in rows.values()
        )
        return per_index * 3
146+
147+
from collections import defaultdict
148+
149+
class DiadicMemory:
    """
    Convenient object front end for the dictionary-backed SDM functions.
    """

    def __init__(self, N, P):
        """
        N is SDR vector size, e.g. 1000
        P is the count of solid bits e.g. 10
        """
        self.N = N
        self.P = P
        # address -> {bit: 1}; rows appear lazily on first access
        self.mem = defaultdict(lambda: defaultdict(lambda: 0))

    def store(self, x, y):
        """Store y under key x via store_xy()."""
        store_xy(self.mem, x, y)

    def query(self, x):
        """Return the result of the module-level query() for key x."""
        return query(self.mem, self.N, self.P, x)

    def size(self):
        """Return the total number of bits held over all addresses."""
        return sum(len(row) for row in self.mem.values())

Python/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
numba 0.55.0
2-
numpy 1.19.4
1+
numba==0.58.0
2+
numpy==1.25.2

Python/sdr_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def sdr_distance(n1, n2):
102102
"""
103103
return 1.0 - 2.0 * sdr_overlap(n1, n2) / (len(n1) + len(n2))
104104

105-
@numba.jit
105+
@numba.jit(nopython=True)
106106
def random_sdr(sdr_size, sdr_len):
107107
out = np.zeros(sdr_len, dtype = np.uint32)
108108
r = np.random.randint(0,sdr_size)
@@ -113,7 +113,7 @@ def random_sdr(sdr_size, sdr_len):
113113
out.sort()
114114
return out
115115

116-
@numba.jit
116+
@numba.jit(nopython=True)
117117
def near_sdr(sdr, sdr_size, switch = 3):
118118
"""
119119
returns a sdr close to input sdr by switching switch bits

0 commit comments

Comments
 (0)