"""
Linear regression is a tool for modeling the dependence between two sets of
data so that we can eventually use this model to make predictions. The name
comes from the fact that we form a linear model (straight line) of one set of
data based on a second. In the literature, the variable that we wish to model
is frequently called the response variable, and the variable that we are
using in this model is the predictor variable.
This module illustrates how to use the statsmodels package to perform a simple
linear regression to model the relationship between two sets of data.
"""
import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt
from numpy.random import default_rng

# Seeded generator so the random data is reproducible between runs
rng = default_rng(12345)
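
# 25 evenly spaced predictor values on [0, 5], shuffled into random order
x = np.linspace(0, 5, 25)
rng.shuffle(x)

# Both responses follow y = 2x + 5; they differ only in the noise level
trend = 2.0
shift = 5.0
y1 = trend * x + shift + rng.normal(0, 0.5, size=25)  # small noise (sd 0.5)
y2 = trend * x + shift + rng.normal(0, 5, size=25)    # large noise (sd 5)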

# Scatter plot of both data sets; the fitted lines are added further down
fig, ax = plt.subplots()
ax.scatter(x, y1, c="b", label="Good correlation")
ax.scatter(x, y2, c="r", label="Bad correlation")
ax.legend()
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_title("Scatter plot of data with best fit lines")
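
# add_constant prepends a column of ones so that OLS fits an intercept as
# well as a slope, i.e. y = const + coef * x
pred_x = sm.add_constant(x)
model1 = sm.OLS(y1, pred_x).fit()
model2 = sm.OLS(y2, pred_x).fit()

# Full regression reports: coefficients, standard errors, R-squared, etc.
print(model1.summary())
print(model2.summary())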
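
# Cross-check (a minimal sketch, not part of the original script): for a
# single predictor, ordinary least squares has the closed-form solution
#     slope = cov(x, y) / var(x),  intercept = mean(y) - slope * mean(x).
# The helper fit_line below is hypothetical and uses only NumPy. Applied to
# the data above, fit_line(x, y1) should agree with model1.params (intercept
# first, because add_constant prepends the column of ones) up to
# floating-point rounding.
import numpy as np


def fit_line(x, y):
    """Return (intercept, slope) of the least-squares line through (x, y)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean = x.mean()
    y_mean = y.mean()
    # Sums of centred cross-products and squares; the 1/n factors cancel
    slope = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    intercept = y_mean - slope * x_mean
    return intercept, slope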
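
# Evaluate each fitted line on a regular grid (np.linspace defaults to 50
# points) and overlay it on the scatter plot in the matching colour
model_x = sm.add_constant(np.linspace(0, 5))
model_y1 = model1.predict(model_x)
model_y2 = model2.predict(model_x)
ax.plot(model_x[:, 1], model_y1, 'b')
ax.plot(model_x[:, 1], model_y2, 'r')
plt.show()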