Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions lib/classifier/bayes.rb
Original file line number Diff line number Diff line change
Expand Up @@ -328,20 +328,13 @@ def remove_category(category)
# puts "#{progress.completed} documents processed"
# end
#
# @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
category = category.prepare_category_name
raise StandardError, "No such category: #{category}" unless @categories.key?(category)
# Trains from one or more IO streams, handing each stream to
# stream_train_category together with the optional progress block.
#
# Call with a positional (category, io) pair, or with keyword
# category: io pairs to train several categories in one call. When a
# positional pair is supplied, keyword pairs are not consulted.
#
# Raises ArgumentError when no stream is given at all, or when only one of
# category / io is provided.
#
# @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void
def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
  raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty?
  raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil?

  streams = category && io ? { category => io } : categories

  # Block parameters get their own names so they do not shadow the method's
  # category / io parameters.
  streams.each do |stream_category, stream_io|
    stream_train_category(stream_category, stream_io, batch_size: batch_size, &block)
  end
end

Expand Down Expand Up @@ -389,6 +382,25 @@ def self.load_checkpoint(storage:, checkpoint_id:)

private

# Streams every line of +io+ into a single category, batch by batch.
#
# The category name is normalised with prepare_category_name first.
# Raises StandardError when that name is not a key of @categories, or when
# +io+ does not respond to #each_line. After each batch the shared progress
# object is updated and yielded to the block, if one was given.
#
# @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void
def stream_train_category(category, io, batch_size:)
  name = category.prepare_category_name
  raise StandardError, "No such category: #{name}" unless @categories.key?(name)
  raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line)

  line_reader = Streaming::LineReader.new(io, batch_size: batch_size)
  tracker = Streaming::Progress.new(total: line_reader.estimate_line_count)

  line_reader.each_batch do |documents|
    train_batch_internal(name, documents)
    tracker.completed += documents.size
    tracker.current_batch += 1
    yield tracker if block_given?
  end
end

# Trains a batch of documents for a single category.
# @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE)
Expand Down
7 changes: 4 additions & 3 deletions lib/classifier/knn.rb
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,10 @@ def self.load_checkpoint(storage:, checkpoint_id:)
# puts "#{progress.completed} documents processed"
# end
#
# @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE, &block)
@lsi.train_from_stream(category, io, batch_size: batch_size, &block)
# Trains from one or more IO streams by forwarding every argument, and the
# optional progress block, to the wrapped @lsi. Once the delegate returns,
# this instance is flagged dirty inside synchronize.
#
# @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void
def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
  # @type var categories: untyped
  @lsi.train_from_stream(category, io, **categories, batch_size: batch_size, &block)

  synchronize do
    @dirty = true
  end
end

Expand Down
54 changes: 33 additions & 21 deletions lib/classifier/logistic_regression.rb
Original file line number Diff line number Diff line change
Expand Up @@ -390,28 +390,13 @@ def self.load_checkpoint(storage:, checkpoint_id:)
# end
# classifier.fit
#
# @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
category = category.to_s.prepare_category_name
raise StandardError, "No such category: #{category}" unless @categories.include?(category)

reader = Streaming::LineReader.new(io, batch_size: batch_size)
total = reader.estimate_line_count
progress = Streaming::Progress.new(total: total)
# Trains from one or more IO streams, handing each stream to
# stream_train_category together with the optional progress block.
#
# Call with a positional (category, io) pair, or with keyword
# category: io pairs to queue several categories in one call. When a
# positional pair is supplied, keyword pairs are not consulted.
#
# Raises ArgumentError when no stream is given at all, or when only one of
# category / io is provided.
#
# @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void
def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
  raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty?
  raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil?

  streams = category && io ? { category => io } : categories

  # Block parameters get their own names so they do not shadow the method's
  # category / io parameters.
  streams.each do |stream_category, stream_io|
    stream_train_category(stream_category, stream_io, batch_size: batch_size, &block)
  end
end

Expand Down Expand Up @@ -440,6 +425,33 @@ def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_

private

# Streams every line of +io+ into the training data of one category.
#
# The category is converted with to_s and prepare_category_name. Raises
# StandardError when the result is not included in @categories, or when
# +io+ does not respond to #each_line. Each batch is recorded inside
# synchronize: every line's word_hash keys are added to @vocabulary, the
# sample is appended to @training_data, and the classifier is marked
# unfitted and dirty. Progress is then updated and yielded to the block,
# if one was given.
#
# @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void
def stream_train_category(category, io, batch_size:)
  label = category.to_s.prepare_category_name
  raise StandardError, "No such category: #{label}" unless @categories.include?(label)
  raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line)

  line_reader = Streaming::LineReader.new(io, batch_size: batch_size)
  tracker = Streaming::Progress.new(total: line_reader.estimate_line_count)

  line_reader.each_batch do |documents|
    synchronize do
      documents.each do |document|
        word_counts = document.word_hash(@min_word_length)
        word_counts.each_key do |term|
          @vocabulary[term] = true
        end
        @training_data << { category: label, features: word_counts }
      end
      @fitted = false
      @dirty = true
    end
    tracker.completed += documents.size
    tracker.current_batch += 1
    yield tracker if block_given?
  end
end

# Trains a batch of documents for a single category.
# @rbs (String | Symbol, Array[String], ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
def train_batch_for_category(category, documents, batch_size: Streaming::DEFAULT_BATCH_SIZE)
Expand Down
44 changes: 27 additions & 17 deletions lib/classifier/lsi.rb
Original file line number Diff line number Diff line change
Expand Up @@ -662,26 +662,19 @@ def self.load_checkpoint(storage:, checkpoint_id:)
# puts "#{progress.completed} documents processed"
# end
#
# @rbs (String | Symbol, IO, ?batch_size: Integer) { (Streaming::Progress) -> void } -> void
def train_from_stream(category, io, batch_size: Streaming::DEFAULT_BATCH_SIZE)
# Adds items from one or more IO streams with automatic index rebuilding
# suspended while the streams are consumed.
#
# Call with a positional (category, io) pair, or with keyword
# category: io pairs; each stream is handed to stream_train_category
# together with the optional progress block. When a positional pair is
# supplied, keyword pairs are not consulted.
#
# @auto_rebuild is switched off for the duration of the load and restored
# afterwards, and build_index runs once at the end when rebuilding was
# originally enabled. That restore also happens if a stream raises
# part-way through.
#
# Raises ArgumentError when no stream is given at all, or when only one of
# category / io is provided. Arguments are validated before any state is
# touched, so a rejected call leaves @auto_rebuild exactly as it was.
#
# @rbs (?(String | Symbol | nil), ?IO?, ?batch_size: Integer, **IO) { (Streaming::Progress) -> void } -> void
def train_from_stream(category = nil, io = nil, batch_size: Streaming::DEFAULT_BATCH_SIZE, **categories, &block)
  raise ArgumentError, 'Provide either (category, io) or keyword category: io pairs' if category.nil? && io.nil? && categories.empty?
  raise ArgumentError, 'Provide both category and io, or use keyword arguments' if category.nil? ^ io.nil?

  streams = category && io ? { category => io } : categories

  original_auto_rebuild = @auto_rebuild
  @auto_rebuild = false

  # The ensure is scoped to the streaming section only. A method-level
  # ensure would also run for the ArgumentError raises above, before
  # original_auto_rebuild is assigned, and would reset @auto_rebuild to nil.
  begin
    streams.each do |stream_category, stream_io|
      stream_train_category(stream_category, stream_io, batch_size: batch_size, &block)
    end
  ensure
    @auto_rebuild = original_auto_rebuild
    build_index if original_auto_rebuild
  end
end

# Adds items to the index in batches from an array.
Expand Down Expand Up @@ -729,6 +722,23 @@ def train_batch(category = nil, documents = nil, batch_size: Streaming::DEFAULT_

private

# Streams every line of +io+ into the index under a single category.
#
# Raises StandardError when +io+ does not respond to #each_line. Lines are
# read in batches and each one is passed to add_item with the category.
# After each batch the shared progress object is updated and yielded to the
# block, if one was given.
#
# @rbs (String | Symbol, IO, batch_size: Integer) { (Streaming::Progress) -> void } -> void
def stream_train_category(category, io, batch_size:)
  raise StandardError, 'Stream must respond to #each_line' unless io.respond_to?(:each_line)

  line_reader = Streaming::LineReader.new(io, batch_size: batch_size)
  tracker = Streaming::Progress.new(total: line_reader.estimate_line_count)

  line_reader.each_batch do |documents|
    documents.each do |document|
      add_item(document, category)
    end
    tracker.completed += documents.size
    tracker.current_batch += 1
    yield tracker if block_given?
  end
end

# Restores LSI state from a JSON string (used by reload)
# @rbs (String) -> void
def restore_from_json(json)
Expand Down
4 changes: 2 additions & 2 deletions lib/classifier/streaming.rb
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ module Streaming
# Trains the classifier from an IO stream.
# Each line in the stream is treated as a separate document.
#
# @rbs (Symbol | String, IO, ?batch_size: Integer) { (Progress) -> void } -> void
def train_from_stream(category, io, batch_size: DEFAULT_BATCH_SIZE, &block)
# Abstract hook: always raises NotImplementedError naming the receiver's
# class. The signature accepts a positional (category, io) pair or keyword
# category: io pairs, plus an optional batch size and progress block; none
# of the arguments are used here.
#
# @rbs (?(Symbol | String | nil), ?IO?, ?batch_size: Integer, **Hash[Symbol, IO]) { (Progress) -> void } -> void
def train_from_stream(category = nil, io = nil, batch_size: DEFAULT_BATCH_SIZE, **categories, &block)
  raise(NotImplementedError, "#{self.class} must implement train_from_stream")
end

Expand Down
17 changes: 17 additions & 0 deletions test/bayes/streaming_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@ def test_train_from_stream_basic
assert_equal 'Spam', @classifier.classify('buy cheap free')
end

# Keyword category: io pairs train several Bayes categories in one call.
def test_train_from_stream_many_categories
  spam_stream = StringIO.new("buy now cheap\nfree money\nlimited offer\n")
  ham_stream = StringIO.new("hello friend\nmeeting tomorrow\n")
  classifier = Classifier::Bayes.new('Spam', 'Ham')

  classifier.train_from_stream(spam: spam_stream, ham: ham_stream)

  assert_equal 'Spam', classifier.classify('buy free')
  assert_equal 'Ham', classifier.classify('hello meeting')
end

# A keyword value that is not a stream is rejected with a StandardError.
def test_train_from_stream_invalid_io_type
  not_a_stream = Object.new

  assert_raises(StandardError) { @classifier.train_from_stream(spam: not_a_stream) }
end

def test_train_from_stream_empty_io
io = StringIO.new('')
@classifier.train_from_stream(:spam, io)
Expand Down
29 changes: 29 additions & 0 deletions test/knn/streaming_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
require_relative '../test_helper'
require 'stringio'

# Exercises Classifier::KNN#train_from_stream in both calling styles.
class KNNStreamingTest < Minitest::Test
  # A positional (category, io) pair trains that category.
  def test_train_from_stream_basic
    spam_stream = StringIO.new("buy now cheap\nfree money\nlimited offer\n")
    knn = Classifier::KNN.new

    knn.train_from_stream(:spam, spam_stream)

    assert_equal 'spam', knn.classify('buy cheap free')
  end

  # Keyword category: io pairs train several categories in one call.
  def test_train_from_stream_many_categories
    spam_stream = StringIO.new("buy now cheap\nfree money\nlimited offer\n")
    ham_stream = StringIO.new("hello friend\nmeeting tomorrow\nhello fellow\n")
    knn = Classifier::KNN.new

    knn.train_from_stream(spam: spam_stream, ham: ham_stream)

    assert_equal 'spam', knn.classify('free offer')
    assert_equal 'ham', knn.classify('hello')
  end

  # A keyword value that is not a stream is rejected with a StandardError.
  def test_train_from_stream_invalid_io_type
    knn = Classifier::KNN.new
    not_a_stream = Object.new

    assert_raises(StandardError) { knn.train_from_stream(spam: not_a_stream) }
  end
end
31 changes: 31 additions & 0 deletions test/logistic_regression/streaming_test.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
require_relative '../test_helper'
require 'stringio'

# Exercises Classifier::LogisticRegression#train_from_stream in both
# calling styles; fit is called before classifying.
class LogisticRegressionStreamingTest < Minitest::Test
  # A positional (category, io) pair trains that category.
  def test_train_from_stream_basic
    spam_stream = StringIO.new("buy now cheap\nfree money\nlimited offer\n")
    classifier = Classifier::LogisticRegression.new('Spam', 'Ham')

    classifier.train_from_stream(:spam, spam_stream)
    classifier.fit

    assert_equal 'Spam', classifier.classify('buy cheap free')
  end

  # Keyword category: io pairs train several categories in one call.
  def test_train_from_stream_many_categories
    spam_stream = StringIO.new("buy now cheap\nfree money\nlimited offer\n")
    ham_stream = StringIO.new("hello friend\nmeeting tomorrow\n")
    classifier = Classifier::LogisticRegression.new('Spam', 'Ham')

    classifier.train_from_stream(spam: spam_stream, ham: ham_stream)
    classifier.fit

    assert_equal 'Spam', classifier.classify('buy free')
    assert_equal 'Ham', classifier.classify('hello meeting')
  end

  # A keyword value that is not a stream is rejected with a StandardError.
  def test_train_from_stream_invalid_io_type
    classifier = Classifier::LogisticRegression.new('Spam', 'Ham')
    not_a_stream = Object.new

    assert_raises(StandardError) { classifier.train_from_stream(spam: not_a_stream) }
  end
end
17 changes: 17 additions & 0 deletions test/lsi/streaming_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ def test_train_from_stream_basic
assert_equal 'dog', result.to_s
end

# Keyword category: io pairs add items for several LSI categories in one call.
def test_train_from_stream_many_categories
  dog_stream = StringIO.new("dogs are loyal pets\npuppies are playful\ndogs bark at strangers\n")
  cat_stream = StringIO.new("cats are independent\nkittens are curious\ncats meow softly\n")
  lsi = Classifier::LSI.new

  lsi.train_from_stream(dog: dog_stream, cat: cat_stream)

  assert_equal :dog, lsi.classify('loyal pet that barks')
  assert_equal :cat, lsi.classify('independent curious pet')
end

# A keyword value that is not a stream is rejected with a StandardError.
def test_train_from_stream_invalid_io_type
  not_a_stream = Object.new

  assert_raises(StandardError) { @lsi.train_from_stream(category: not_a_stream) }
end

def test_train_from_stream_empty_io
@lsi.train_from_stream(:category, StringIO.new(''))

Expand Down
Loading