大家好,欢迎来到IT知识分享网。
一、期望最大化算法
期望最大化(EM)算法是一种在统计学和机器学习中广泛使用的迭代方法,它特别适用于含有隐变量的概率模型参数估计问题。在统计学和机器学习中,有很多不同的模型,例如高斯混合模型(GMM)、隐马尔可夫模型(HMM)等,都可以用EM算法来估计这些模型中的参数。EM算法的主要思想是通过两个步骤的交替执行来找到模型参数的估计值:期望(E)步骤和最大化(M)步骤。此外,EM算法的收敛性也意味着它可以在多次迭代后得到稳定的参数估计,这对于模型的预测和分析非常重要。
二、期望最大化算法原理
1、E步骤(Expectation step)
在E步骤中,我们计算隐变量的条件期望给定观测数据和当前参数估计。假设我们有一个数据集,隐变量
和参数
,则E步骤计算的是:
其中,是隐变量
在观测数据
和当前参数
下的条件概率。
2、M步骤(Maximization step):
在M步骤中,我们利用E步骤计算出的隐变量的分布来更新参数的估计,以最大化似然函数。M步骤的计算公式为:
其中,是更新后的参数估计,
是上一步的参数估计,N是数据点的数量,求和是对所有数据点和所有可能的隐变量值进行的。
三、EM算法应用
假设我们有一个高斯混合模型GMM,其中有K个高斯分布,参数为,其中
是第k个高斯分布的权重,
是均值,
是方差,则计算EM有:
1、E步骤
计算每个数据点属于每个高斯分布的responsibility(也称为 posterior probability):
这里,是多元正态分布的概率密度函数
2、M步骤
更新每个高斯分布的参数:
其中,是数据点
属于第k个高斯分布的后验概率。N是数据点的数量,求和对所有数据点进行,E步骤和M步骤交替进行,知道参数
收敛。参数更新公式中,分子和分母有相同的部分,但不能简单约去,因为分母中的部分确保了每个数据点在计算新的均值时,其贡献是按照它属于该高斯分布的概率加权的。
四、python实现EM算法
这里,首先生成两个高斯分布的数据,然后定义一个高斯函数来计算给定均值和标准差的数据的概率密度。接下来,定义E步骤和M步骤的函数。最后,运行EM算法迭代100次。
import numpy as np import matplotlib.pyplot as plt from scipy.stats import multivariate_normal # 生成示例数据 np.random.seed(42) X = np.vstack([np.random.multivariate_normal([0, 0], np.eye(2), 100), np.random.multivariate_normal([5, 5], np.eye(2), 100)]) # 定义高斯函数 def gaussian(X, mean, cov): return multivariate_normal.pdf(X, mean, cov) # E步骤 def e_step(X, weights, means, covariances): n, d = X.shape k = len(weights) responsibilities = np.zeros((n, k)) for i in range(k): responsibilities[:, i] = weights[i] * gaussian(X, means[i], covariances[i]) responsibilities /= responsibilities.sum(axis=1, keepdims=True) return responsibilities # M步骤 def m_step(X, responsibilities): n, d = X.shape k = responsibilities.shape[1] weights = responsibilities.sum(axis=0) / n means = np.dot(responsibilities.T, X) / responsibilities.sum(axis=0)[:, np.newaxis] covariances = np.zeros((k, d, d)) for i in range(k): diff = X - means[i] covariances[i] = np.dot(responsibilities[:, i] * diff.T, diff) / responsibilities[:, i].sum() return weights, means, covariances # 初始化参数 def initialize_parameters(X, k): n, d = X.shape weights = np.ones(k) / k means = X[np.random.choice(n, k, False)] covariances = np.array([np.eye(d)] * k) return weights, means, covariances # EM算法 def em_algorithm(X, k, max_iter=100, tol=1e-6): weights, means, covariances = initialize_parameters(X, k) log_likelihoods = [] for i in range(max_iter): responsibilities = e_step(X, weights, means, covariances) weights, means, covariances = m_step(X, responsibilities) log_likelihoods.append(log_likelihood(X, weights, means, covariances)) if i > 0 and np.abs(log_likelihoods[-1] - log_likelihoods[-2]) < tol: break return weights, means, covariances, log_likelihoods, responsibilities # 计算对数似然 def log_likelihood(X, weights, means, covariances): n, d = X.shape k = len(weights) log_likelihood = 0 for i in range(k): log_likelihood += weights[i] * gaussian(X, means[i], covariances[i]) return np.log(log_likelihood).sum() # 绘制高斯分布的等高线 def draw_ellipse(mean, cov, ax, label, alpha=1.0): from matplotlib.patches import Ellipse v, w = np.linalg.eigh(cov) v = 2.0 * np.sqrt(2.0) * np.sqrt(v) u = w[0] / np.linalg.norm(w[0]) angle = np.arctan(u[1] / u[0]) angle = 180.0 * angle / np.pi ell = Ellipse(mean, v[0], v[1], 180.0 + angle, edgecolor='red', lw=2, facecolor='none', alpha=alpha, label=label) ax.add_patch(ell) # 运行EM算法 k = 2 weights, means, covariances, log_likelihoods, responsibilities = em_algorithm(X, k) # 可视化最终结果 plt.figure(figsize=(8, 6)) plt.scatter(X[:, 0], X[:, 1], s=10, label='Data points') ax = plt.gca() for j in range(k): draw_ellipse(means[j], covariances[j], ax, label=f'Gaussian {j+1}', alpha=weights[j]) plt.title('Final Gaussian Mixture Model') plt.legend() plt.show() # 打印结果 print("权重:", weights) print("均值:", means) print("协方差矩阵:", covariances) # 打印对数似然 print("对数似然:", log_likelihoods[-1]) # 计算AIC和BIC n, d = X.shape num_params = k * (d + d * (d + 1) / 2) + k - 1 aic = 2 * num_params - 2 * log_likelihoods[-1] bic = np.log(n) * num_params - 2 * log_likelihoods[-1] print("AIC:", aic) print("BIC:", bic)
其中,aic和bic计算模型的AIC和BIC值,AIC和BIC值越小,表示模型越好。理想情况下,高斯分布的等高线应该很好地覆盖数据点的分布区域,可视化结果如下。
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/158424.html