Skip to content

Commit 8b0698f

Browse files
committed
Add classify_with_confidence method to LSI classifier
- Implement classify_with_confidence method in LSI class
- Update README with example usage of new method
- Add tests for classify_with_confidence functionality
- Include Vector extension in classifier.rb
1 parent 0cf2c3b commit 8b0698f

File tree

4 files changed

+58
-1
lines changed

4 files changed

+58
-1
lines changed

README.markdown

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ theoretically simulates human learning.
7777
lsi.classify "This text is also about dogs!"
7878
# returns => :dog
7979

80+
lsi.classify_with_confidence "This text is also about dogs!"
81+
# returns => [:dog, 1.0]
82+
8083
Please see the Classifier::LSI documentation for more information. It is possible to index, search and classify
8184
with more than just simple strings.
8285

lib/classifier.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,6 @@
2626

2727
require 'rubygems'
2828
require 'classifier/extensions/string'
29+
require 'classifier/extensions/vector'
2930
require 'classifier/bayes'
3031
require 'classifier/lsi'

lib/classifier/lsi.rb

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,13 @@ def find_related(doc, max_nearest = 3, &block)
247247
# what category the document is in. This may not always make sense.
248248
#
249249
def classify(doc, cutoff = 0.30, &block)
  # Tally the per-category votes, then hand back the category holding the
  # highest tally (nil when no category received a vote).
  tally = vote(doc, cutoff, &block)
  tally.keys.sort_by { |category| tally[category] }.last
end
255+
256+
def vote(doc, cutoff = 0.30, &block)
250257
icutoff = (@items.size * cutoff).round
251258
carry = proximity_array_for_content(doc, &block)
252259
carry = carry[0..icutoff - 1]
@@ -258,9 +265,30 @@ def classify(doc, cutoff = 0.30, &block)
258265
votes[category] += pair[1]
259266
end
260267
end
268+
votes
269+
end
270+
271+
# Picks the same category that classify() would, and pairs it with a
# confidence figure: the winning category's votes divided by the sum of
# the votes cast for every category.
#
# Returns [nil, nil] when the votes sum to zero.
#
# e.g.
#   category, confidence = classify_with_confidence(doc)
#   if confidence < 0.3
#     category = nil
#   end
#
# See classify() for argument docs
def classify_with_confidence(doc, cutoff = 0.30, &block)
  tally = vote(doc, cutoff, &block)
  total = tally.values.inject(0.0) { |acc, score| acc + score }
  # A zero total would make the share below a division by zero.
  return [nil, nil] if total.zero?

  best = tally.keys.sort_by { |category| tally[category] }.last
  [best, tally[best] / total.to_f]
end
265293

266294
# Prototype, only works on indexed documents.

test/lsi/lsi_test.rb

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,31 @@ def test_recategorize_interface
8383
assert_equal 'Cow', lsi.classify(tricky_case)
8484
end
8585

86+
# Checks that classify_with_confidence returns the expected category and a
# confidence above a per-case threshold.
def test_classify_with_confidence
  lsi = Classifier::LSI.new

  # Training set: one Dog document, two Cat documents, one Bird document.
  training = [[@str2, 'Dog'], [@str3, 'Cat'], [@str4, 'Cat'], [@str5, 'Bird']]
  training.each { |text, label| lsi.add_item text, label }

  # Each case: document to classify, expected category, confidence floor.
  tricky_case = 'This text revolves around dogs.'
  cases = [
    [@str1, 'Dog', 0.5],
    [@str3, 'Cat', 0.5],
    [@str5, 'Bird', 0.5],
    [tricky_case, 'Dog', 0.3]
  ]

  cases.each do |text, expected, floor|
    category, confidence = lsi.classify_with_confidence(text)
    assert_equal expected, category
    assert confidence > floor, "Confidence should be greater than #{floor}, but was #{confidence}"
  end
end
110+
86111
def test_search
87112
lsi = Classifier::LSI.new
88113
[@str1, @str2, @str3, @str4, @str5].each { |x| lsi << x }

0 commit comments

Comments
 (0)