Skip to content

Commit b29bf2b

Browse files
Use heapq.nsmallest in NearestNeighborLearner.
1 parent dd9fe7a commit b29bf2b

File tree

1 file changed

+6
-21
lines changed

1 file changed

+6
-21
lines changed

learning.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Learn to estimate functions from examples. (Chapters 18-20)"""
22

33
from utils import *
4-
import random
4+
import heapq, random
55

66
#______________________________________________________________________________
77

@@ -202,26 +202,11 @@ def N(targetval, attr, attrval):
202202

203203
def NearestNeighborLearner(dataset, k=1):
204204
"k-NearestNeighbor: the k nearest neighbors vote."
205-
if k == 1:
206-
def predict(example):
207-
"Predict according to the point closest to example."
208-
neighbor = argmin(dataset.examples,
209-
lambda e: dataset.distance(e, example))
210-
return neighbor[dataset.target]
211-
else:
212-
def predict(example):
213-
"Find the k closest, and have them vote for the best."
214-
## Maintain a sorted list of (distance, example) pairs.
215-
## For very large k, a PriorityQueue would be better
216-
best = []
217-
for e in dataset.examples:
218-
d = dataset.distance(e, example)
219-
if len(best) < k:
220-
best.append((d, e))
221-
elif d < best[-1][0]:
222-
best[-1] = (d, e)
223-
best.sort()
224-
return mode([e[dataset.target] for (d, e) in best])
205+
def predict(example):
206+
"Find the k closest, and have them vote for the best."
207+
best = heapq.nsmallest(k, ((dataset.distance(e, example), e)
208+
for e in dataset.examples))
209+
return mode([e[dataset.target] for (d, e) in best])
225210
return predict
226211

227212
#______________________________________________________________________________

0 commit comments

Comments
 (0)