引言
上一篇笔记中已经记录了,如何对一个无解的线性方程组\\(Ax=b\\)求近似解。在这里,我们先来回顾两个知识点:
- 如何判断一个线性方程组无解:如果拿上面那个方程组\\(Ax=b\\)举例,那就是向量\\(b\\)不在矩阵A对应的列空间中,至于列空间的概念,可以参考
四个基本子空间
那篇笔记 - 如何对无解的方程组求近似解:根据上一篇笔记
如何寻找一个投影矩阵
可以有这么一个思路,将向量\\(b\\)往矩阵\\(A\\)所在的列空间投影得到向量\\(f\\),得到新的方程组\\(A\\hat{x}=f\\),这个\\(\\hat{x}\\)便为近似解了。如果仅仅为了求近似解可以直接在\\(Ax=b\\)等式左右两侧同时左乘\\(A^{\\mathrm{T}}\\),即\\(A^{\\mathrm{T}}Ax=A^{\\mathrm{T}}b\\)。这个和上面先求投影向量再求解是一样的。
这篇笔记将会探究在机器学习的线性回归如何求解损失函数。
\\(Ax=b\\)无解时求近似解
今天我们需要求一个线性方程组,长成这样\\[ \\begin{equation} \\left \\{ \\begin{array}{lr} 2 * w_1 + 2 * w_2 + b = 14 \\\\ 4 * w_1 - 1 * w_2 + b = 5 \\\\ 4 * w_1 + 0 * w_2 + b = 4 \\\\ 4 * w_1 - 2 * w_2 + b = 3 \\\\ 0 * w_1 - 3 * w_2 + b = -20 \\end{array} \\right. \\end{equation} \\]
将(1)式写成矩阵形式,也就是\\[ \\begin{equation} \\left [ \\begin{matrix} 2 & 2 & 1 \\\\ 4 & -1 & 1 \\\\ 4 & 0 & 1 \\\\ 4 & -2 & 1 \\\\ 0 & -3 & 1 \\end{matrix} \\right] \\left [ \\begin{matrix} w_1 \\\\ w_2 \\\\ b \\end{matrix} \\right]= \\left [ \\begin{matrix} 14 \\\\ 5 \\\\ 4 \\\\ 3 \\\\ -20 \\end{matrix} \\right] \\end{equation} \\]
凭我多年的做题经验,这个方程是无解的。太好了,之前学的东西总算可以用上场了(参考笔记如何寻找一个投影矩阵
等式13)。我们将等式两边同时左乘矩阵的转置,我们会惊讶的发现这个新的等式(3)有解了:\\[ \\begin{equation} \\left [ \\begin{matrix} 2 & 4 & 4 & 4 & 0 \\\\ 2 & -1 & 0 & -2 & -3\\\\ 1 & 1 & 1 & 1 & 1 \\end{matrix} \\right] \\left [ \\begin{matrix} 2 & 2 & 1 \\\\ 4 & -1 & 1 \\\\ 4 & 0 & 1 \\\\ 4 & -2 & 1 \\\\ 0 & -3 & 1 \\end{matrix} \\right] \\left [ \\begin{matrix} \\hat{w_1} \\\\ \\hat{w_2} \\\\ \\hat{b} \\end{matrix} \\right]= \\left [ \\begin{matrix} 2 & 4 & 4 & 4 & 0 \\\\ 2 & -1 & 0 & -2 & -3\\\\ 1 & 1 & 1 & 1 & 1 \\end{matrix} \\right] \\left [ \\begin{matrix} 14 \\\\ 5 \\\\ 4 \\\\ 3 \\\\ -20 \\end{matrix} \\right] \\end{equation} \\]
将(3)式化解得到:\\[ \\begin{equation} \\left [ \\begin{matrix} 52 & -8 & 14 \\\\ -8 & 18 & -4 \\\\ 14 & -4 & 5 \\end{matrix} \\right] \\left [ \\begin{matrix} \\hat{w_1} \\\\ \\hat{w_2} \\\\ \\hat{b} \\end{matrix} \\right]= \\left [ \\begin{matrix} 72 \\\\ 73 \\\\ 6 \\end{matrix} \\right] \\end{equation} \\]
由等式(4)解出的\\(\\hat{w_1},\\hat{w_2},\\hat{b}\\)就是等式(1)的近似解,我们也认为它是最优解。
线性回归
拿预测房价举例,谈谈什么是最小二乘法。比如我们假设房价(price)与2个特征即面积(x1)、楼层(x2)有关。那么我们的目标是找到一张三维空间中的平面去拟合一些数据(假设这些数据都经过归一化处理)。先来看看平面怎么定义的?\\[ \\begin{equation} price = w_1 * x_1 + w_2 * x_2 + b \\end{equation} \\]
我们希望所有的数据点都在这个平面上,那样可以通过解线性方程组来算出这个平面的参数\\(w_1,w_2,b\\),这正是线性代数中学到过的。
好的,我们现在有5笔数据(2, 2, 14)、(4, -1, 5)、(4, 0, 4)、(4 -2 3)、(0 -3 -20),将它们代入(5)式得到我们的方程组吧,解出来\\(w_1,w_2,b\\)这样一个线性模型就ok了。
可是这一步我们之前已经做过了,将这些数据代入方程组是无解的,即给出的这些数据根本不在一个平面上。那么,现在我们放松条件,既然找不到一个平面能令所有的点都在它上面,我们找一个最优的平面总可以吧。
最优平面如何定义
假设我们已经有n组数据,每一组数据都是\\((x_1,x_2,y)\\)的集合。将一组数据\\((x_1, x_2)\\)代入(1)中求出price,我们认为每一组数据产生的误差为\\((price-y)^2\\),将每一组数据产生的误差累加起来就是(6)式。即:\\[ \\begin{equation} J(w_1, w_2, b) = \\sum_{i=1}^{n}(price_i-y_i)^2 \\end{equation} \\]
使得\\(J(w_1, w_2, b)\\)最小的那组参数\\((w_1,w_2,b)\\),可以认为是最优平面的参数。
下面会给出一个动图来展示最优平面是怎么样的一个情况(画了好久才画出满意的效果,画图的代码也会在末尾给出):
可以看到图中,红色的点是我们实际的数据,这个蓝色的透明平面是我画出来的认为能拟合这些数据的最好平面。
如何能找到最优平面?
这仍然是一个数学问题,我们认为使得\\(J(w_1, w_2, b)\\)最小的那组参数\\((w_1,w_2,b)\\),就是最终要寻找的最优平面的参数。
这样的话,我们记\\(J(w_1, w_2, b)\\)为一个函数,求一个多元函数的最值我们在微积分中学到过就是求\\(\\frac{\\partial J}{\\partial w_1}, \\frac{\\partial J}{\\partial w_2}, \\frac{\\partial J}{\\partial b}\\),并且令它们都等于0,就能求出最终的解了。
这里已经涉及到矩阵微积分的内容,我试着写几步:\\[ \\begin{equation} J(w_1, w_2, b) = (price-y)^{\\mathrm{T}}(price-y) \\end{equation} \\]
\\(price\\)和\\(y\\)都是向量,再将\\(price\\)用参数\\(w_1,w_2,b\\)表示:\\[ \\begin{equation} J(w) = (Xw-y)^{\\mathrm{T}}(Xw-y) \\end{equation} \\]
(8)式中,\\(X\\)的每一行是1组数据,它是一个nx3的矩阵;\\(w\\)是个向量\\[ X=\\left[ \\begin{matrix} 第一笔数据的 \\ x1 & x2 & 1 \\\\ 第二笔数据的 \\ x1 & x2 & 1 \\\\ . \\\\ . \\\\ . \\\\ 第n笔数据的 \\ x1 & x2 & 1 \\\\ \\end{matrix} \\right] \\ \\ \\ \\ \\ \\ \\ \\ w =\\left[ \\begin{matrix} w_1\\\\ w_2\\\\ b \\end{matrix} \\right] \\]
继续将(8)式化简\\[ \\begin{equation} J(w) = (w^{\\mathrm{T}}X^{\\mathrm{T}}-y^{\\mathrm{T}})(Xw-y) \\end{equation} \\]
接着去括号\\[ \\begin{equation} J(w) = w^{\\mathrm{T}}X^{\\mathrm{T}}Xw-y^{\\mathrm{T}}Xw-w^{\\mathrm{T}}X^{\\mathrm{T}}y+y^{\\mathrm{T}}y \\end{equation} \\]
其中,\\(y^{\\mathrm{T}}Xw\\)与\\(w^{\\mathrm{T}}X^{\\mathrm{T}}y\\)是相等的,都是一个数,所以最终可以写为\\[ \\begin{equation} J(w) = w^{\\mathrm{T}}X^{\\mathrm{T}}Xw-2w^{\\mathrm{T}}X^{\\mathrm{T}}y+y^{\\mathrm{T}}y \\end{equation} \\]
下面就要进行矩阵微积分了,讲实话我不会。但是我学会两个trick能求出最终的\\(w\\)。
-
第一个trick来自台大的林轩田老师,我记得他很轻松地说可以把上面这个等式变换成我们会的一元二次等式,我当时带着满腹的怀疑按照他说的做了,不过真的得到了结果(惊吓!可能这就是数学的魅力)。我们将(11)式变为\\[ \\begin{equation} J(x) = w^{\\mathrm{T}}Aw - 2w^{\\mathrm{T}}b + c \\\\ subject \\ to \\ A = X^{\\mathrm{T}}X \\\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ b=X^{\\mathrm{T}}y\\\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ c=y^{\\mathrm{T}}y \\end{equation} \\]
当然,这不是严格意义上的转换,但是真的能让我们像解熟悉的一元二次方程一样求出解。对(12)求导令其为0,再将原来的值代入回去能得到\\[ \\begin{equation} 2X^{\\mathrm{T}}Xw - 2X^{\\mathrm{T}}y = 0 \\end{equation} \\]
最终\\[ \\begin{equation} w = (X^{\\mathrm{T}}X)^{-1}X^{\\mathrm{T}}y \\end{equation} \\] -
第二种求解的办法就是记住矩阵微积分的公式:
y \\(\\frac{\\partial y}{\\partial X}\\) \\(AX\\) \\(A^{\\mathrm{T}}\\) \\(X^{\\mathrm{T}}A\\) \\(A\\) \\(X^{\\mathrm{T}}X\\) \\(2X\\) \\(X^{\\mathrm{T}}AX\\) \\(AX+A^{\\mathrm{T}}X\\)
等等,(14)式好熟悉。这不就是求解线性方程组\\(Ax=b\\)这个方程组无解时的最优近似解么。所以,机器学习的线性回归其实就是最小二乘中的拟合问题。一开始就将这个问题看为求解线性方程组问题的话:\\[ \\left[ \\begin{matrix} 第一笔数据的 \\ x1 & x2 & 1 \\\\ 第二笔数据的 \\ x1 & x2 & 1 \\\\ . \\\\ . \\\\ . \\\\ 第n笔数据的 \\ x1 & x2 & 1 \\\\ \\end{matrix} \\right] \\left[ \\begin{matrix} w_1\\\\ w_2\\\\ b \\end{matrix} \\right]=\\left[ \\begin{matrix} 第一笔数据的 \\ price \\\\ 第二笔数据的 \\ price \\\\ . \\\\ . \\\\ . \\\\ 第n笔数据的 \\ price \\\\ \\end{matrix} \\right] \\]
不就是求这个方程组有没有解么?如果没有解,我们就求近似解。这个近似解的求解方法就是上一篇笔记中一直强调的部分,在等式左右两边左乘矩阵的转置,我们马上能得到近似解。
画图代码
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
x1 = np.linspace(-5, 5, 5)
x2 = x1
x1, x2 = np.meshgrid(x1, x2)
price = x1 * 3 + x2 * 4 - 5
np.random.seed(325)
data_x = np.random.randint(-5, 5, 5)
data_y = np.random.randint(-5, 5, 5)
data_z = data_x * 3 + data_y * 4 - 5
bias = np.array([5, 2, -3, 4, -3])
data_z = data_z + bias
fig = plt.figure()
ax = fig.gca(projection=\'3d\')
ax.plot_wireframe(x1, x2, price, rstride=10, cstride=10)
for i in range(len(data_x)):
ax.scatter(data_x[i], data_y[i], data_z[i], color=\'r\')
ax.set_xlabel(\'x1\')
ax.set_ylabel(\'x2\')
ax.set_zlabel(\'price\')
ax.set_xticks([-5, 0, 5])
ax.set_yticks([-5, 0,10])
ax.set_zticks([ -40, 0, 40])
plt.show()