1
+ """
2
+ This code is an implementation of the Triadic Memory and Dyadic Memory algorithms
3
+
4
+ Copyright (c) 2021-2022 Peter Overmann
5
+ Copyright (c) 2022 Cezar Totth
6
+ Copyright (c) 2023 Clément Michaud
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software
9
+ and associated documentation files (the “Software”), to deal in the Software without restriction,
10
+ including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
11
+ and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
12
+ subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all copies or substantial
15
+ portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
18
+ LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
19
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
20
+ WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE
21
+ OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
22
+ """
23
+ import numpy as np
24
+ import numba
25
+
26
+ def xaddr (x ):
27
+ addr = []
28
+ for i in range (1 ,len (x )):
29
+ for j in range (i ):
30
+ yield x [i ] * (x [i ] - 1 ) // 2 + x [j ]
31
+
32
+ def store_xy (mem , x , y ):
33
+ """
34
+ Stores Y under key X
35
+ Y and X have to be sorted sparsely encoded SDRs
36
+ """
37
+ for addr in xaddr (x ):
38
+ for j in y :
39
+ mem [addr ][j ] = 1
40
+
41
+ def store_xyz (mem , x , y , z ):
42
+ """
43
+ Stores X, Y, Z triplet in mem
44
+ All X, Y, Z have to be sparse encoded SDRs
45
+ """
46
+ for ax in x :
47
+ for ay in y :
48
+ for az in z :
49
+ mem ['x' ][ay ][az ][ax ] = 1
50
+ mem ['y' ][ax ][az ][ay ] = 1
51
+ mem ['z' ][ax ][ay ][az ] = 1
52
+
53
+ def query (mem , N , P , x ):
54
+ """
55
+ Query in DiadicMemory
56
+ """
57
+ sums = np .zeros (N , dtype = np .uint32 )
58
+ for addr in xaddr (x ):
59
+ for k , v in mem [addr ].items ():
60
+ sums [k ] += v
61
+ return sums2sdr (sums , P )
62
+
63
+ def queryZ (mem , N , P , x , y ):
64
+ sums = np .zeros (N , dtype = np .uint32 )
65
+ for ax in x :
66
+ for ay in y :
67
+ # print(ax, ay, mem['z'][ax][ay])
68
+ for k , v in mem ['z' ][ax ][ay ].items ():
69
+ sums [k ] += v
70
+ return sums2sdr (sums , P )
71
+
72
+ def queryX (mem , N , P , y , z ):
73
+ sums = np .zeros (N , dtype = np .uint32 )
74
+ for ay in y :
75
+ for az in z :
76
+ for k , v in mem ['x' ][ay ][az ].items ():
77
+ sums [k ] += v
78
+ return sums2sdr (sums , P )
79
+
80
+ def queryY (mem , N , P , x , z ):
81
+ sums = np .zeros (N , dtype = np .uint32 )
82
+ for ax in x :
83
+ for az in z :
84
+ for k , v in mem ['y' ][ax ][az ].items ():
85
+ sums [k ] += v
86
+ return sums2sdr (sums ,P )
87
+
88
+ @numba .jit (nopython = True )
89
+ def sums2sdr (sums , P ):
90
+ # this does what binarize() does in C
91
+ ssums = sums .copy ()
92
+ ssums .sort ()
93
+ threshval = ssums [- P ]
94
+ if threshval == 0 :
95
+ return np .where (sums )[0 ] # All non zero values
96
+ else :
97
+ return np .where (sums >= threshval )[0 ] #
98
+
99
+ class TriadicMemory :
100
+ def __init__ (self , N , P ):
101
+ self .mem = {
102
+ # We store the data 3 times to be able to query any of the variable but
103
+ # if one knows which variable is to be queried, we can get rid of 2/3 of
104
+ # the memory used.
105
+ 'x' : defaultdict (lambda : defaultdict (lambda : defaultdict (lambda : 0 ))),
106
+ 'y' : defaultdict (lambda : defaultdict (lambda : defaultdict (lambda : 0 ))),
107
+ 'z' : defaultdict (lambda : defaultdict (lambda : defaultdict (lambda : 0 ))),
108
+ }
109
+ self .P = P
110
+ self .N = N
111
+
112
+ def store (self , x , y , z ):
113
+ store_xyz (self .mem , x , y , z )
114
+
115
+ def query (self , x , y , z = None ):
116
+ # query for either x, y or z.
117
+ # The queried member must be provided as None
118
+ # the other two members have to be encoded as sorted sparse SDRs
119
+ if z is None :
120
+ return queryZ (self .mem , self .N , self .P , x , y )
121
+ elif x is None :
122
+ return queryX (self .mem , self .N , self .P , y , z )
123
+ elif y is None :
124
+ return queryY (self .mem , self .N , self .P , x , z )
125
+
126
+ def query_X (self , y , z ):
127
+ return queryX (self .mem , self .N , self .P , y , z )
128
+
129
+ def query_Y (self , x , z ):
130
+ return queryY (self .mem , self .N , self .P , x , z )
131
+
132
+ def query_Z (self , x , y ):
133
+ return queryZ (self .mem , self .N , self .P , x , y )
134
+
135
+ def query_x_with_P (self , y , z , P ):
136
+ return queryX (self .mem , self .N , P , y , z )
137
+
138
+ def size (self ):
139
+ # TODO: computing the size everytime might not be very efficient,
140
+ # instead we could probably keep track of how many bits are stored.
141
+ size = 0
142
+ for x in self .mem ['x' ].values ():
143
+ for v1 in x .values ():
144
+ size += len (v1 )
145
+ return size * 3
146
+
147
+ from collections import defaultdict
148
+
149
+ class DiadicMemory :
150
+ """
151
+ this is a convenient object front end for SDM functions
152
+ """
153
+ def __init__ (self , N , P ):
154
+ """
155
+ N is SDR vector size, e.g. 1000
156
+ P is the count of solid bits e.g. 10
157
+ """
158
+ self .mem = defaultdict (lambda : defaultdict (lambda : 0 ))
159
+ self .N = N
160
+ self .P = P
161
+
162
+ def store (self , x , y ):
163
+ store_xy (self .mem , x , y )
164
+
165
+ def query (self , x ):
166
+ return query (self .mem , self .N , self .P , x )
167
+
168
+ def size (self ):
169
+ size = 0
170
+ for v in self .mem .values ():
171
+ size += len (v )
172
+ return size
0 commit comments