Skip to content

Commit 5694fbf

Browse files
committed
Address review feedback for Naive Bayes text classifier
1 parent 048dea1 commit 5694fbf

1 file changed

Lines changed: 15 additions & 3 deletions

File tree

machine_learning/naive_bayes_text_classification.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def fit(self, texts: list[str], labels: list[str]) -> None:
7272
...
7373
ValueError: training data must not be empty.
7474
"""
75-
if len(texts) != len(labels):
76-
raise ValueError("texts and labels must have the same length.")
7775
if not texts:
7876
raise ValueError("training data must not be empty.")
77+
if len(texts) != len(labels):
78+
raise ValueError("texts and labels must have the same length.")
7979

8080
self.classes_ = sorted(set(labels))
8181
self.vocabulary_.clear()
@@ -114,6 +114,11 @@ def predict_proba(self, text: str) -> dict[str, float]:
114114
>>> probs['spam'] > probs['ham']
115115
True
116116
117+
An empty input text has no tokens, so predictions fall back to class priors.
118+
>>> empty_probs = model.predict_proba("")
119+
>>> round(empty_probs['spam'], 3), round(empty_probs['ham'], 3)
120+
(0.5, 0.5)
121+
117122
>>> NaiveBayesTextClassifier().predict_proba("hello")
118123
Traceback (most recent call last):
119124
...
@@ -159,7 +164,7 @@ def predict(self, text: str) -> str:
159164
'ham'
160165
"""
161166
probabilities = self.predict_proba(text)
162-
return max(probabilities, key=probabilities.get)
167+
return max(probabilities, key=lambda label: probabilities[label])
163168

164169

165170
def build_toy_dataset() -> tuple[list[str], list[str]]:
@@ -188,3 +193,10 @@ def build_toy_dataset() -> tuple[list[str], list[str]]:
188193
import doctest
189194

190195
doctest.testmod()
196+
197+
sample_texts, sample_labels = build_toy_dataset()
198+
classifier = NaiveBayesTextClassifier(alpha=1.0)
199+
classifier.fit(sample_texts, sample_labels)
200+
201+
print("Prediction:",classifier.predict("cheap prizes available now"))
202+
print("Prediction:",classifier.predict("team meeting about project timeline"))

0 commit comments

Comments
 (0)