30分钟搞懂线性回归:从数学公式到西瓜书实战指南
【免费下载链接】pumpkin-book 《机器学习》(西瓜书)公式详解 项目地址: https://gitcode.com/datawhalechina/pumpkin-book
你是否还在为机器学习中的线性回归公式头疼?看着满页的数学符号却不知道如何下手实现?本文将用最通俗的语言,带你一步步揭开线性回归的神秘面纱,从一元到多元,从数学推导到《机器学习》(西瓜书)实战应用,让你彻底搞懂这个最基础也最重要的算法。读完本文,你将能够:理解线性回归的核心思想、掌握从一元到多元的推导过程、学会使用南瓜书辅助学习、独立实现简单的线性回归模型。
线性回归的核心思想
线性回归(Linear Regression)是机器学习中最基础也是应用最广泛的算法之一,它的核心思想非常简单:找到一条直线(或超平面),使得数据点到这条直线(或超平面)的距离之和最小。就像我们想根据西瓜的大小来预测它的价格,线性回归就是帮我们找到大小和价格之间的最佳线性关系。
一元线性回归的数学推导
模型定义
一元线性回归模型可以表示为:
$y = wx + b$
其中,$x$ 是自变量(如西瓜的大小),$y$ 是因变量(如西瓜的价格),$w$ 是权重(weight),$b$ 是偏置(bias)。我们的目标就是找到最优的 $w$ 和 $b$,使得模型的预测值与真实值之间的误差最小。
损失函数
为了衡量模型的好坏,我们需要定义一个损失函数(Loss Function)。最常用的是均方误差(Mean Squared Error, MSE):
$L(w,b) = \frac{1}{n}\sum_{i=1}^{n}(y_i - (wx_i + b))^2$
其中,$n$ 是样本数量,$y_i$ 是第 $i$ 个样本的真实值,$wx_i + b$ 是模型的预测值。我们的目标就是最小化这个损失函数。
参数求解
要找到最小化损失函数的 $w$ 和 $b$,我们可以使用梯度下降法(Gradient Descent),也可以直接通过数学推导求解。这里我们介绍后者。
对 $L(w,b)$ 分别求关于 $w$ 和 $b$ 的偏导数,并令其等于 0:
$\frac{\partial L}{\partial w} = \frac{2}{n}\sum_{i=1}^{n}(y_i - wx_i - b)(-x_i) = 0$
$\frac{\partial L}{\partial b} = \frac{2}{n}\sum_{i=1}^{n}(y_i - wx_i - b)(-1) = 0$
解这两个方程,可以得到:
$b = \bar{y} - w\bar{x}$
$w = \frac{\sum_{i=1}^{n}(x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n}(x_i - \bar{x})^2}$
其中,$\bar{x}$ 是 $x$ 的平均值,$\bar{y}$ 是 $y$ 的平均值。
多元线性回归的扩展
当自变量不止一个时,比如我们不仅用西瓜的大小,还用它的甜度、色泽来预测价格,就需要用到多元线性回归。
模型定义
多元线性回归模型可以表示为:
$y = w_1x_1 + w_2x_2 + ... + w_dx_d + b$
用向量形式可以更简洁地表示为:
$y = \mathbf{w}^T\mathbf{x} + b$
其中,$\mathbf{w} = (w_1, w_2, ..., w_d)^T$ 是权重向量,$\mathbf{x} = (x_1, x_2, ..., x_d)^T$ 是输入向量,$b$ 是偏置。
损失函数与参数求解
多元线性回归的损失函数同样可以定义为均方误差:
$L(\mathbf{w},b) = \frac{1}{n}\sum_{i=1}^{n}(y_i - (\mathbf{w}^T\mathbf{x}_i + b))^2$
为了方便计算,我们可以将偏置 $b$ 也纳入权重向量,令 $\hat{\mathbf{w}} = (\mathbf{w}^T, b)^T$,$\hat{\mathbf{x}}_i = (\mathbf{x}_i^T, 1)^T$,则模型可以表示为:
$y = \hat{\mathbf{w}}^T\hat{\mathbf{x}}$
此时损失函数为:
$L(\hat{\mathbf{w}}) = \frac{1}{n}|\mathbf{Y} - \mathbf{X}\hat{\mathbf{w}}|^2$
其中,$\mathbf{Y}$ 是真实值向量,$\mathbf{X}$ 是样本矩阵(每一行是一个样本)。
对 $\hat{\mathbf{w}}$ 求导并令其等于 0,可以得到:
$\hat{\mathbf{w}} = (\mathbf{X}^T\mathbf{X})^{-1}\mathbf{X}^T\mathbf{Y}$
这就是多元线性回归的正规方程解(Normal Equation)。
南瓜书辅助学习
《机器学习》(西瓜书)是机器学习领域的经典教材,但其中的公式推导对于初学者来说可能有些晦涩。而南瓜书(README.md)作为西瓜书的公式详解,为我们提供了极大的帮助。
南瓜书的 docs 目录下有各个章节的详细解释,比如docs/chapter3/chapter3.md就详细讲解了线性回归相关的内容。你可以通过这些文档,结合西瓜书进行学习,加深对公式的理解。
实战应用
学习了线性回归的理论知识后,我们可以尝试用代码实现一个简单的线性回归模型。这里我们以南瓜书中的例子为例,假设我们有一批西瓜的数据,包括大小、甜度和价格,我们可以用多元线性回归来预测西瓜的价格。
首先,我们需要准备数据。然后,根据上面推导的公式,计算出权重 $\mathbf{w}$ 和偏置 $b$。最后,用得到的模型进行预测。
虽然本文不包含具体的代码实现,但你可以参考南瓜书中的示例,结合自己的理解进行实践。相信通过理论学习和实际操作,你一定能熟练掌握线性回归算法。
总结
本文从线性回归的核心思想出发,详细推导了一元和多元线性回归的数学公式,并介绍了如何使用南瓜书辅助学习。线性回归作为机器学习的基础算法,虽然简单,但应用广泛,掌握它对于后续学习更复杂的算法至关重要。希望本文能帮助你更好地理解线性回归,为你的机器学习之旅打下坚实的基础。如果你觉得本文对你有帮助,欢迎点赞、收藏,也欢迎在评论区分享你的学习心得。下一篇文章,我们将介绍线性回归的正则化方法,敬请期待!
【免费下载链接】pumpkin-book 《机器学习》(西瓜书)公式详解 项目地址: https://gitcode.com/datawhalechina/pumpkin-book