Skip to content

Commit aa33cc2

Browse files
Add visualization support for linear regression
1 parent 791deb4 commit aa33cc2

1 file changed

Lines changed: 57 additions & 4 deletions

File tree

machine_learning/linear_regression.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
# dependencies = [
1414
# "httpx",
1515
# "numpy",
16+
# "matplotlib",
1617
# ]
1718
# ///
1819

1920
import httpx
2021
import numpy as np
21-
22+
import matplotlib.pyplot as plt
2223

2324
def collect_dataset():
2425
"""Collect dataset of CSGO
@@ -102,12 +103,17 @@ def run_linear_regression(data_x, data_y):
102103

103104
theta = np.zeros((1, no_features))
104105

106+
err = []
107+
105108
for i in range(iterations):
106109
theta = run_steep_gradient_descent(data_x, data_y, len_data, alpha, theta)
107110
error = sum_of_square_error(data_x, data_y, len_data, theta)
108-
print(f"At Iteration {i + 1} - Error is {error:.5f}")
111+
err.append(error)
112+
113+
if i % 1000 == 0:
114+
print(f"At Iteration {i + 1} - Error is {error:.5f}")
109115

110-
return theta
116+
return theta, err
111117

112118

113119
def mean_absolute_error(predicted_y, original_y):
@@ -125,6 +131,45 @@ def mean_absolute_error(predicted_y, original_y):
125131
return total / len(original_y)
126132

127133

134+
135+
# Visualization helpers
136+
def plot_regression(data_x, data_y, theta):
    """Plot the dataset points together with the fitted regression line.

    Only a single feature is drawn: column 1 of ``data_x`` is used as the
    x-axis and the line is computed as ``theta[0, 0] + theta[0, 1] * x``.

    :param data_x: 2-D feature array; column 1 is plotted on the x-axis
    :param data_y: target values, flattened to one dimension before plotting
    :param theta: 2-D parameter array; ``theta[0, 0]`` is used as the
        intercept and ``theta[0, 1]`` as the slope
    :return: None; the figure is displayed with ``plt.show()``
    """
    x = np.array(data_x[:, 1]).flatten()
    y = np.array(data_y).flatten()

    predictions = theta[0, 0] + theta[0, 1] * x

    # Draw the line through x in ascending order so it is traced once from
    # left to right instead of zig-zagging over itself in dataset order.
    order = np.argsort(x)

    plt.scatter(x, y, label="Data points")
    # scatter() and plot() both start at the first colour of the default
    # cycle, so give the fitted line an explicit contrasting colour.
    plt.plot(x[order], predictions[order], color="red", label="Regression line")

    plt.xlabel("ADR")
    plt.ylabel("Rating")
    plt.title("Linear Regression Best Fit")
    plt.legend()

    plt.show()
156+
157+
158+
def plot_loss(err):
    """Display the recorded training error as a curve.

    :param err: sequence of error values, plotted against their index
    :return: None; the figure is displayed with ``plt.show()``
    """
    # The values are passed to plot() on their own, so the x-axis is
    # simply the position of each value in the sequence.
    plt.plot(err)

    plt.title("Training Loss Curve")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")

    plt.show()
171+
172+
128173
def main():
129174
"""Driver function"""
130175
data = collect_dataset()
@@ -133,7 +178,11 @@ def main():
133178
data_x = np.c_[np.ones(len_data), data[:, :-1]].astype(float)
134179
data_y = data[:, -1].astype(float)
135180

136-
theta = run_linear_regression(data_x, data_y)
181+
theta,err = run_linear_regression(data_x, data_y)
182+
183+
plot_regression(data_x, data_y, theta)
184+
plot_loss(err)
185+
137186
len_result = theta.shape[1]
138187
print("Resultant Feature vector : ")
139188
for i in range(len_result):
@@ -145,3 +194,7 @@ def main():
145194

146195
doctest.testmod()
147196
main()
197+
198+
199+
200+

0 commit comments

Comments
 (0)