Skip to content

Commit c9449a6

Browse files
committed
update
1 parent f9687b7 commit c9449a6

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

recommenders/rbm_tf_k.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def build(self, D, M, K):
107107
logits = dot2(H, self.W) + self.b
108108
cdist = tf.distributions.Categorical(logits=logits)
109109
X_sample = cdist.sample() # shape is (N, D)
110-
X_sample = tf.one_hot(X_sample, depth=self.K) # turn it into (N, D, K)
110+
X_sample = tf.one_hot(X_sample, depth=K) # turn it into (N, D, K)
111111
X_sample = X_sample * self.mask # missing ratings shouldn't contribute to objective
112112

113113

recommenders/rbm_tf_k_faster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def build(self, D, M, K):
6565
logits = dot2(H, self.W) + self.b
6666
cdist = tf.distributions.Categorical(logits=logits)
6767
X_sample = cdist.sample() # shape is (N, D)
68-
X_sample = tf.one_hot(X_sample, depth=self.K) # turn it into (N, D, K)
68+
X_sample = tf.one_hot(X_sample, depth=K) # turn it into (N, D, K)
6969

7070
# mask X_sample to remove missing ratings
7171
mask2d = tf.cast(self.X_in > 0, tf.float32)
@@ -94,7 +94,7 @@ def build(self, D, M, K):
9494

9595

9696
# for calculating SSE
97-
self.one_to_ten = tf.constant(one_to_ten.astype(np.float32) / 2)
97+
self.one_to_ten = tf.constant((np.arange(10) + 1).astype(np.float32) / 2)
9898
self.pred = tf.tensordot(self.output_visible, self.one_to_ten, axes=[[2], [0]])
9999
mask = tf.cast(self.X_in > 0, tf.float32)
100100
se = mask * (self.X_in - self.pred) * (self.X_in - self.pred)

0 commit comments

Comments
 (0)