本文介绍了 UMAP 降维算法,并以可视化手写数字数据集为例,展示了 UMAP 的使用方法。
介绍自己 🙈
生成本文简介 👋
推荐相关文章 📖
前往主页 🏠
前往爱发电购买
学习笔记机器学习机器学习 流形数据降维:UMAP 降维算法
小嗷犬UMAP 简介
UMAP(Uniform Manifold Approximation and Projection)是一种先进的非线性降维技术,用于将高维数据集转换为低维空间中的表示,同时尽可能保留原始数据的复杂结构和拓扑特性。它特别适用于可视化分析和机器学习领域的预处理步骤。
理论基础
流形学习:UMAP 建立在流形学习的基础上,该理论认为即使在高维空间中,许多真实世界的数据点也可以近似地分布在一个低维流形上。通过捕捉这些隐藏的低维结构,UMAP 能够生成有意义的二维或三维投影。
代数拓扑:算法利用了代数拓扑的概念,特别是对邻域图的同胚嵌入来估计数据流形上的全局和局部连通性。这意味着 UMAP 不仅关注数据点之间的局部相似性,还考虑了它们在整个数据集中的相对位置和全局关系。
黎曼几何:虽然 UMAP 并不直接依赖于严格的黎曼几何计算,但其背后的思想受到了启发。算法假设数据均匀分布在某种局部恒定度量的空间中,并且这个空间可以通过数学操作进行近似。
特点与优势
保留全局结构:相比 t-SNE(t-Distributed Stochastic Neighbor Embedding),UMAP 更注重保持数据集的全局结构,这对于某些应用如聚类和分类任务尤其重要。
计算效率:UMAP 优化了算法实现,使其运行速度更快,特别是在处理大规模数据时表现更为出色。
无维度限制:UMAP 可以处理任意大小的嵌入维度,不受像 t-SNE 那样的硬性限制,这使得它不仅限于可视化,还可以作为通用的降维工具应用于其他机器学习模型。
参数调整:UMAP 提供了一些关键参数供用户自定义,例如 n_neighbors
控制每个点邻居的数量,影响降维后数据点的聚集程度;而 min_dist
参数则决定了低维空间中点之间的最小距离,有助于控制数据点的分布密度。
应用场景
- 数据可视化
- 高维数据的探索性数据分析(EDA)
- 大规模单细胞转录组数据分析
- 异常检测
- 机器学习模型的特征降维预处理
在 Python 中使用 UMAP
安装 umap-learn 库
使用 UMAP 可视化手写数字数据集
下面我们使用 UMAP 将手写数字数据集降到二维空间,并将降维后的数据可视化。
导入需要的包:
1 2 3 4 5 6 7 8
| import numpy as np import pandas as pd import matplotlib.pyplot as plt
from umap import UMAP from sklearn.preprocessing import MinMaxScaler
from torchvision import datasets
|
加载手写数字数据集:
1 2 3 4 5
| digits = datasets.MNIST("./data", train=True, download=True) X, y = digits.data.numpy().reshape(-1, 28 * 28), digits.targets.numpy() n = 5000 X, y = X[:n], y[:n] X.shape, y.shape
|
可视化原始数据:
1 2 3 4 5 6 7 8 9 10 11
| n = 10 img = np.zeros((30 * n, 30 * n)) for i in range(n): ix = 30 * i + 1 for j in range(n): iy = 30 * j + 1 img[ix : ix + 28, iy : iy + 28] = X[i * n + j].reshape(28, 28) plt.figure(figsize=(8, 8)) plt.imshow(img, cmap=plt.cm.binary) plt.axis("off") plt.show()
|

使用 UMAP 将数据降到二维空间并可视化:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
|
reducer = UMAP(n_components=2, random_state=0) embedding = reducer.fit_transform(X)
scaler = MinMaxScaler() embedding = scaler.fit_transform(embedding)
plt.figure(figsize=(9, 9)) for i in range(embedding.shape[0]): plt.text( embedding[i, 0], embedding[i, 1], str(y[i]), color=plt.cm.tab10(y[i]), fontdict={"size": 12}, va="center", ha="center", ) plt.axis("off") plt.show()
|
