机器学习 流形数据降维:UMAP 降维算法

UMAP 简介

UMAP(Uniform Manifold Approximation and Projection)是一种先进的非线性降维技术,用于将高维数据集转换为低维空间中的表示,同时尽可能保留原始数据的复杂结构和拓扑特性。它特别适用于可视化分析和机器学习领域的预处理步骤。

理论基础

  • 流形学习:UMAP 建立在流形学习的基础上,该理论认为即使在高维空间中,许多真实世界的数据点也可以近似地分布在一个低维流形上。通过捕捉这些隐藏的低维结构,UMAP 能够生成有意义的二维或三维投影。

  • 代数拓扑:算法利用了代数拓扑的概念,特别是对邻域图的同胚嵌入来估计数据流形上的全局和局部连通性。这意味着 UMAP 不仅关注数据点之间的局部相似性,还考虑了它们在整个数据集中的相对位置和全局关系。

  • 黎曼几何:虽然 UMAP 并不直接依赖于严格的黎曼几何计算,但其背后的思想受到了启发。算法假设数据均匀分布在某种局部恒定度量的空间中,并且这个空间可以通过数学操作进行近似。

特点与优势

  1. 保留全局结构:相比 t-SNE(t-Distributed Stochastic Neighbor Embedding),UMAP 更注重保持数据集的全局结构,这对于某些应用如聚类和分类任务尤其重要。

  2. 计算效率:UMAP 优化了算法实现,使其运行速度更快,特别是在处理大规模数据时表现更为出色。

  3. 无维度限制:UMAP 可以处理任意大小的嵌入维度,不受像 t-SNE 那样的硬性限制,这使得它不仅限于可视化,还可以作为通用的降维工具应用于其他机器学习模型。

  4. 参数调整:UMAP 提供了一些关键参数供用户自定义,例如 n_neighbors 控制每个点邻居的数量,影响降维后数据点的聚集程度;而 min_dist 参数则决定了低维空间中点之间的最小距离,有助于控制数据点的分布密度。

应用场景

  • 数据可视化
  • 高维数据的探索性数据分析(EDA)
  • 大规模单细胞转录组数据分析
  • 异常检测
  • 机器学习模型的特征降维预处理

在 Python 中使用 UMAP

安装 umap-learn 库

1
pip install umap-learn

使用 UMAP 可视化手写数字数据集

下面我们使用 UMAP 将手写数字数据集降到二维空间,并将降维后的数据可视化。

导入需要的包:

1
2
3
4
5
6
7
8
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from umap import UMAP
from sklearn.preprocessing import MinMaxScaler

from torchvision import datasets

加载手写数字数据集:

1
2
3
4
5
digits = datasets.MNIST("./data", train=True, download=True)
X, y = digits.data.numpy().reshape(-1, 28 * 28), digits.targets.numpy()
n = 5000
X, y = X[:n], y[:n]
X.shape, y.shape # ((5000, 784), (5000,))

可视化原始数据:

1
2
3
4
5
6
7
8
9
10
11
n = 10  # 显示 10 * 10 个数字
img = np.zeros((30 * n, 30 * n))
for i in range(n):
ix = 30 * i + 1
for j in range(n):
iy = 30 * j + 1
img[ix : ix + 28, iy : iy + 28] = X[i * n + j].reshape(28, 28)
plt.figure(figsize=(8, 8))
plt.imshow(img, cmap=plt.cm.binary)
plt.axis("off")
plt.show()

MNIST

使用 UMAP 将数据降到二维空间并可视化:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# UMAP 降维
# UMAP 降维
reducer = UMAP(n_components=2, random_state=0)
embedding = reducer.fit_transform(X)

# 归一化
scaler = MinMaxScaler()
embedding = scaler.fit_transform(embedding)

# 可视化
plt.figure(figsize=(9, 9))
for i in range(embedding.shape[0]):
plt.text(
embedding[i, 0],
embedding[i, 1],
str(y[i]),
color=plt.cm.tab10(y[i]),
fontdict={"size": 12},
va="center",
ha="center",
)
plt.axis("off")
plt.show()

UMAP