Skip to content

Commit abf1747

Browse files
committed
Add multinomial Naive Bayes text classification example
1 parent 791deb4 commit abf1747

1 file changed

Lines changed: 191 additions & 0 deletions

File tree

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""
2+
Naive Bayes text classification using a multinomial event model.
3+
4+
The implementation in this module is intentionally educational and keeps the
5+
logic explicit: token counting, prior probabilities, and posterior scoring in
6+
log-space.
7+
8+
References:
9+
- https://en.wikipedia.org/wiki/Naive_Bayes_classifier
10+
- https://scikit-learn.org/stable/modules/naive_bayes.html
11+
"""
12+
13+
from __future__ import annotations
14+
15+
import re
16+
from collections import Counter, defaultdict
17+
from math import exp, log
18+
19+
20+
class NaiveBayesTextClassifier:
    """
    Multinomial Naive Bayes classifier for short text documents.

    Args:
        alpha: Additive (Laplace) smoothing parameter. Must be greater than 0.

    >>> NaiveBayesTextClassifier(alpha=0)
    Traceback (most recent call last):
        ...
    ValueError: alpha must be greater than 0.
    """

    def __init__(self, alpha: float = 1.0) -> None:
        if alpha <= 0:
            raise ValueError("alpha must be greater than 0.")

        self.alpha = alpha
        # Sorted list of distinct class labels seen during fit().
        self.classes_: list[str] = []
        # Every distinct token seen in the training texts.
        self.vocabulary_: set[str] = set()
        # label -> number of training documents with that label.
        self.class_document_counts_: Counter[str] = Counter()
        # label -> per-token occurrence counts for that label.
        self.class_token_counts_: dict[str, Counter[str]] = defaultdict(Counter)
        # label -> total number of tokens (sum over the per-token counts).
        self.class_total_tokens_: Counter[str] = Counter()
        # label -> log P(class), estimated from document frequencies.
        self.class_log_prior_: dict[str, float] = {}
        self.is_fitted_ = False

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """
        Split text into lowercase alphanumeric tokens.

        >>> NaiveBayesTextClassifier._tokenize("Hello, NLP world!")
        ['hello', 'nlp', 'world']
        """
        return re.findall(r"[a-z0-9']+", text.lower())

    def fit(self, texts: list[str], labels: list[str]) -> None:
        """
        Fit the classifier from labeled training texts.

        All inputs are validated before any learned state is modified, so a
        failed call leaves any previously fitted model fully usable.

        Raises:
            ValueError: If lengths mismatch or the training data is empty.
            TypeError: If any text or label is not a string.

        >>> model = NaiveBayesTextClassifier()
        >>> model.fit(["cheap meds", "project meeting"], ["spam", "ham"])
        >>> sorted(model.classes_)
        ['ham', 'spam']

        >>> model.fit(["only one text"], ["ham", "spam"])
        Traceback (most recent call last):
            ...
        ValueError: texts and labels must have the same length.

        >>> model.fit([], [])
        Traceback (most recent call last):
            ...
        ValueError: training data must not be empty.
        """
        if len(texts) != len(labels):
            raise ValueError("texts and labels must have the same length.")
        if not texts:
            raise ValueError("training data must not be empty.")
        # Validate every element BEFORE clearing the learned state; otherwise a
        # mid-loop TypeError would corrupt a previously trained model (cleared
        # priors with is_fitted_ still True).
        for text, label in zip(texts, labels):
            if not isinstance(text, str) or not isinstance(label, str):
                raise TypeError("texts and labels must contain strings only.")

        self.classes_ = sorted(set(labels))
        self.vocabulary_ = set()
        self.class_document_counts_ = Counter()
        self.class_token_counts_ = defaultdict(Counter)
        self.class_total_tokens_ = Counter()
        self.class_log_prior_ = {}

        for text, label in zip(texts, labels):
            tokens = self._tokenize(text)
            self.class_document_counts_[label] += 1
            self.class_token_counts_[label].update(tokens)
            self.class_total_tokens_[label] += len(tokens)
            self.vocabulary_.update(tokens)

        total_documents = len(texts)
        self.class_log_prior_ = {
            label: log(self.class_document_counts_[label] / total_documents)
            for label in self.classes_
        }
        self.is_fitted_ = True

    def predict_proba(self, text: str) -> dict[str, float]:
        """
        Return posterior probabilities for every class.

        Raises:
            ValueError: If the model has not been fitted.
            TypeError: If ``text`` is not a string.

        >>> train_texts, train_labels = build_toy_dataset()
        >>> model = NaiveBayesTextClassifier()
        >>> model.fit(train_texts, train_labels)
        >>> probs = model.predict_proba("cheap meds available now")
        >>> round(sum(probs.values()), 6)
        1.0
        >>> probs['spam'] > probs['ham']
        True

        >>> NaiveBayesTextClassifier().predict_proba("hello")
        Traceback (most recent call last):
            ...
        ValueError: model has not been fitted yet.
        """
        if not self.is_fitted_:
            raise ValueError("model has not been fitted yet.")
        if not isinstance(text, str):
            raise TypeError("text must be a string.")

        tokens = self._tokenize(text)
        vocabulary_size = len(self.vocabulary_)
        log_posteriors: dict[str, float] = {}

        for label in self.classes_:
            # Start from log P(class); add log P(token | class) per token.
            log_prob = self.class_log_prior_[label]
            token_counts = self.class_token_counts_[label]
            # Laplace-smoothed denominator; alpha > 0 keeps it positive even
            # for classes with no tokens, so log() is always defined.
            denominator = self.class_total_tokens_[label] + self.alpha * vocabulary_size

            for token in tokens:
                count = token_counts[token]
                log_prob += log((count + self.alpha) / denominator)

            log_posteriors[label] = log_prob

        # Softmax in log-space: subtract the max before exponentiating to
        # avoid underflow for long documents with very negative scores.
        max_log = max(log_posteriors.values())
        exp_scores = {
            label: exp(score - max_log)
            for label, score in log_posteriors.items()
        }
        normalizer = sum(exp_scores.values())
        return {label: score / normalizer for label, score in exp_scores.items()}

    def predict(self, text: str) -> str:
        """
        Predict the most likely class label for a text.

        >>> train_texts, train_labels = build_toy_dataset()
        >>> model = NaiveBayesTextClassifier(alpha=1.0)
        >>> model.fit(train_texts, train_labels)
        >>> model.predict("free cheap meds")
        'spam'
        >>> model.predict("project meeting schedule")
        'ham'
        """
        probabilities = self.predict_proba(text)
        return max(probabilities, key=probabilities.get)
164+
165+
166+
def build_toy_dataset() -> tuple[list[str], list[str]]:
    """
    Build a tiny text dataset for examples and quick local testing.

    >>> texts, labels = build_toy_dataset()
    >>> len(texts), len(labels)
    (6, 6)
    >>> sorted(set(labels))
    ['ham', 'spam']
    """
    # Keep each document next to its label, then split into parallel lists.
    samples = [
        ("buy cheap meds now", "spam"),
        ("cheap meds available online", "spam"),
        ("win cash prizes now", "spam"),
        ("project meeting schedule attached", "ham"),
        ("let us discuss the project timeline", "ham"),
        ("team meeting moved to monday", "ham"),
    ]
    texts = [document for document, _ in samples]
    labels = [label for _, label in samples]
    return texts, labels
186+
187+
188+
if __name__ == "__main__":
    # Run the embedded doctests when executed as a script.
    from doctest import testmod

    testmod()

0 commit comments

Comments
 (0)