Skip to content

Commit 99bbada

Browse files
committed
update
1 parent 64c1437 commit 99bbada

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

unsupervised_class3/bayes_classifier_gaussian.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
from scipy.stats import multivariate_normal as mvn
1212

1313

14+
def clamp_sample(x):
15+
x = np.minimum(x, 1)
16+
x = np.maximum(x, 0)
17+
return x
18+
19+
1420
class BayesClassifier:
1521
def fit(self, X, Y):
1622
# assume classes are numbered 0...K-1
@@ -30,11 +36,11 @@ def fit(self, X, Y):
3036

3137
def sample_given_y(self, y):
3238
g = self.gaussians[y]
33-
return mvn.rvs(mean=g['m'], cov=g['c'])
39+
return clamp_sample( mvn.rvs(mean=g['m'], cov=g['c']) )
3440

3541
def sample(self):
3642
y = np.random.choice(self.K, p=self.p_y)
37-
return self.sample_given_y(y)
43+
return clamp_sample( self.sample_given_y(y) )
3844

3945

4046
if __name__ == '__main__':

unsupervised_class3/bayes_classifier_gmm.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@
1111
from sklearn.mixture import BayesianGaussianMixture
1212

1313

14+
def clamp_sample(x):
15+
x = np.minimum(x, 1)
16+
x = np.maximum(x, 0)
17+
return x
18+
19+
1420
class BayesClassifier:
1521
def fit(self, X, Y):
1622
# assume classes are numbered 0...K-1
@@ -39,11 +45,11 @@ def sample_given_y(self, y):
3945
# we cheat by looking at "non-public" params in
4046
# the sklearn source code
4147
mean = gmm.means_[sample[1]]
42-
return sample[0].reshape(28, 28), mean.reshape(28, 28)
48+
return clamp_sample( sample[0].reshape(28, 28) ), mean.reshape(28, 28)
4349

4450
def sample(self):
4551
y = np.random.choice(self.K, p=self.p_y)
46-
return self.sample_given_y(y)
52+
return clamp_sample( self.sample_given_y(y) )
4753

4854

4955
if __name__ == '__main__':

unsupervised_class3/util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from scipy.misc import imread, imsave, imresize
1515
from glob import glob
1616
from tqdm import tqdm
17+
from sklearn.utils import shuffle
1718

1819

1920
def get_mnist(limit=None):
@@ -26,10 +27,11 @@ def get_mnist(limit=None):
2627

2728
print("Reading in and transforming data...")
2829
df = pd.read_csv('../large_files/train.csv')
29-
data = df.as_matrix()
30-
np.random.shuffle(data)
30+
data = df.values
31+
# np.random.shuffle(data)
3132
X = data[:, 1:] / 255.0 # data is from 0..255
3233
Y = data[:, 0]
34+
X, Y = shuffle(X, Y)
3335
if limit is not None:
3436
X, Y = X[:limit], Y[:limit]
3537
return X, Y

0 commit comments

Comments
 (0)