Skip to content

Commit 490c565

Browse files
committed
Expand function documentation and add algorithm explanation
1 parent 36593a0 commit 490c565

File tree

1 file changed

+77
-16
lines changed

1 file changed

+77
-16
lines changed

machine_learning/local_weighted_learning/local_weighted_learning.py

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,54 @@
1+
"""
2+
Locally weighted linear regression, also called local regression, is a type of
3+
non-parametric linear regression that prioritizes data closest to a given
4+
prediction point. The algorithm estimates the vector of model coefficients β
5+
using weighted least squares regression:
6+
7+
β = (XᵀWX)⁻¹(XᵀWy),
8+
9+
where X is the design matrix, y is the response vector, and W is the diagonal
10+
weight matrix.
11+
12+
This implementation calculates wᵢ, the weight of the ith training sample, using
13+
the Gaussian weight:
14+
15+
wᵢ = exp(-‖xᵢ - x‖²/(2τ²)),
16+
17+
where xᵢ is the ith training sample, x is the prediction point, τ is the
18+
"bandwidth", and ‖x‖ is the Euclidean norm (also called the 2-norm or the L²
19+
norm). The bandwidth τ controls how quickly the weight of a training sample
20+
decreases as its distance from the prediction point increases. One can think of
21+
the Gaussian weight as a bell curve centered around the prediction point: a
22+
training sample is weighted lower if it's farther from the center, and τ
23+
controls the spread of the bell curve.
24+
25+
Other types of locally weighted regression such as locally estimated scatterplot
26+
smoothing (LOESS) typically use different weight functions.
27+
28+
References:
29+
- https://en.wikipedia.org/wiki/Local_regression
30+
- https://en.wikipedia.org/wiki/Weighted_least_squares
31+
- https://cs229.stanford.edu/notes2022fall/main_notes.pdf
32+
"""
33+
134
import matplotlib.pyplot as plt
235
import numpy as np
336

437

538
def weight_matrix(point: np.ndarray, x_train: np.ndarray, tau: float) -> np.ndarray:
639
"""
7-
Calculate the weight for every point in the data set.
8-
point --> the x value at which we want to make predictions
40+
Calculate the weight of every point in the training data around a given
41+
prediction point.
42+
43+
Args:
44+
point: x-value at which the prediction is being made
45+
x_train: ndarray of x-values for training
46+
tau: bandwidth value, controls how quickly the weight of training values
47+
decreases as the distance from the prediction point increases
48+
49+
Returns:
50+
n x n diagonal weight matrix for the prediction point, where n is the size of
51+
the training set
952
>>> weight_matrix(
1053
... np.array([1., 1.]),
1154
... np.array([[16.99, 10.34], [21.01,23.68], [24.59,25.69]]),
@@ -15,22 +58,30 @@ def weight_matrix(point: np.ndarray, x_train: np.ndarray, tau: float) -> np.ndar
1558
[0.00000000e+000, 0.00000000e+000, 0.00000000e+000],
1659
[0.00000000e+000, 0.00000000e+000, 0.00000000e+000]])
1760
"""
18-
m, _ = np.shape(x_train) # m is the number of training samples
19-
weights = np.eye(m) # Initializing weights as identity matrix
20-
21-
# calculating weights for all training examples [x(i)'s]
22-
for j in range(m):
61+
n = len(x_train) # Number of training samples
62+
weights = np.eye(n) # Initialize weights as identity matrix
63+
for j in range(n):
2364
diff = point - x_train[j]
2465
weights[j, j] = np.exp(diff @ diff.T / (-2.0 * tau**2))
66+
2567
return weights
2668

2769

2870
def local_weight(
2971
point: np.ndarray, x_train: np.ndarray, y_train: np.ndarray, tau: float
3072
) -> np.ndarray:
3173
"""
32-
Calculate the local weights using the weight_matrix function on training data.
33-
Return the weighted matrix.
74+
Calculate the local weights at a given prediction point using the weight
75+
matrix for that point.
76+
77+
Args:
78+
point: x-value at which the prediction is being made
79+
x_train: ndarray of x-values for training
80+
y_train: ndarray of y-values for training
81+
tau: bandwidth value, controls how quickly the weight of training values
82+
decreases as the distance from the prediction point increases
83+
Returns:
84+
ndarray of local weights
3485
>>> local_weight(
3586
... np.array([1., 1.]),
3687
... np.array([[16.99, 10.34], [21.01,23.68], [24.59,25.69]]),
@@ -52,17 +103,24 @@ def local_weight_regression(
52103
x_train: np.ndarray, y_train: np.ndarray, tau: float
53104
) -> np.ndarray:
54105
"""
55-
Calculate predictions for each data point on axis
106+
Calculate predictions for each point in the training data.
107+
108+
Args:
109+
x_train: ndarray of x-values for training
110+
y_train: ndarray of y-values for training
111+
tau: bandwidth value, controls how quickly the weight of training values
112+
decreases as the distance from the prediction point increases
113+
114+
Returns:
115+
ndarray of predictions
56116
>>> local_weight_regression(
57117
... np.array([[16.99, 10.34], [21.01, 23.68], [24.59, 25.69]]),
58118
... np.array([[1.01, 1.66, 3.5]]),
59119
... 0.6
60120
... )
61121
array([1.07173261, 1.65970737, 3.50160179])
62122
"""
63-
m, _ = np.shape(x_train)
64-
y_pred = np.zeros(m)
65-
123+
y_pred = np.zeros(len(x_train)) # Initialize array of predictions
66124
for i, item in enumerate(x_train):
67125
y_pred[i] = item @ local_weight(item, x_train, y_train, tau)
68126

@@ -74,14 +132,15 @@ def load_data(
74132
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
75133
"""
76134
Load data from seaborn and split it into x and y points
135+
>>> pass # No doctests, function is for demo purposes only
77136
"""
78137
import seaborn as sns
79138

80139
data = sns.load_dataset(dataset_name)
81-
x_data = np.array(data[x_name]) # total_bill
82-
y_data = np.array(data[y_name]) # tip
140+
x_data = np.array(data[x_name])
141+
y_data = np.array(data[y_name])
83142

84-
one = np.ones(np.shape(y_data)[0], dtype=int)
143+
one = np.ones(len(y_data))
85144

86145
# pairing elements of one and x_data
87146
x_train = np.column_stack((one, x_data))
@@ -99,6 +158,7 @@ def plot_preds(
99158
) -> plt.plot:
100159
"""
101160
Plot predictions and display the graph
161+
>>> pass # No doctests, function is for demo purposes only
102162
"""
103163
x_train_sorted = np.sort(x_train, axis=0)
104164
plt.scatter(x_data, y_data, color="blue")
@@ -119,6 +179,7 @@ def plot_preds(
119179

120180
doctest.testmod()
121181

182+
# Demo with a dataset from the seaborn module
122183
training_data_x, total_bill, tip = load_data("tips", "total_bill", "tip")
123184
predictions = local_weight_regression(training_data_x, tip, 5)
124185
plot_preds(training_data_x, predictions, total_bill, tip, "total_bill", "tip")

0 commit comments

Comments
 (0)