torch.gather 用法笔记
2026/6/10 3:34:43 网站建设 项目流程

1. 基本作用

torch.gather的作用是:

从 input 的指定维度 dim 上,按照 index 给出的索引位置取值。

基本语法:

output = torch.gather(input, dim, index)
基本公式 三维举例 dim=1 output[b][k][d] = input[b][index[b][k][d]][d]

其中:

input:原始张量 dim:指定在哪个维度上取值 index:索引张量 output:取出的结果

2. 核心规则

torch.gather有一个非常重要的规则:

output 的形状和 index 的形状相同。

也就是说:

output.shape == index.shape

index不只是告诉从哪里取值,它还决定了最终输出张量的形状。


3. 二维例子

假设:

import torch input = torch.tensor([ [10, 20, 30], [40, 50, 60] ]) index = torch.tensor([ [0, 2], [1, 0] ]) output = torch.gather(input, dim=1, index=index)

因为dim=1,所以是在列方向取值。

取值过程:

output[0][0] = input[0][index[0][0]] = input[0][0] = 10 output[0][1] = input[0][index[0][1]] = input[0][2] = 30 output[1][0] = input[1][index[1][0]] = input[1][1] = 50 output[1][1] = input[1][index[1][1]] = input[1][0] = 40

最终结果:

tensor([ [10, 30], [50, 40] ])

4. 三维例子

假设:

input.shape = [B, N, D]

含义是:

B:batch size,样本数量 N:每个样本中的 patch 数量 D:每个 patch 的特征维度

如果:

index.shape = [B, K, D]

并且:

output = torch.gather(input, dim=1, index=index)

那么:

output.shape = [B, K, D]

因为dim=1,所以是在N这个维度上取值。

核心公式是:

output[b][k][d] = input[b][index[b][k][d]][d]

解释:

B 维保持对应 D 维保持对应 只有 N 维根据 index[b][k][d] 指定的位置取值

5. 结合 patch 选择代码理解

常见代码:

_, indices = torch.topk(attention_weights, k, dim=1) selected_patches = torch.gather( patches, 1, indices.unsqueeze(-1).expand(-1, -1, D) )

假设:

patches.shape = [B, N, D] attention_weights.shape = [B, N] indices.shape = [B, K]

其中:

B:样本数量 N:patch 数量 D:每个 patch 的特征维度 K:要选出的 patch 数量

torch.topk得到的是每个样本中分数最高的K个 patch 索引:

indices.shape = [B, K]

但是patches是三维张量:

patches.shape = [B, N, D]

所以需要先扩展索引:

indices.unsqueeze(-1)

形状变为:

[B, K, 1]

再使用:

expand(-1, -1, D)

形状变为:

[B, K, D]

这样才能和patches的三维结构对应起来。


6. 为什么要 expand 到 D 维

因为每个 patch 不是一个数,而是一个D维特征向量。

如果某个 patch 的索引是:

indices[b][k] = 3

扩展后变成:

index[b][k] = [3, 3, 3, ..., 3]

长度是D

于是:

output[b][k][0] = input[b][3][0] output[b][k][1] = input[b][3][1] output[b][k][2] = input[b][3][2] ... output[b][k][D-1] = input[b][3][D-1]

也就是把第3个 patch 的完整D维特征全部取出来。


7. 最终效果

对于代码:

selected_patches = torch.gather( patches, 1, indices.unsqueeze(-1).expand(-1, -1, D) )

它的作用是:

从每个样本的 N 个 patch 中, 根据 top-k 得到的索引, 选出 K 个重要 patch, 并保留每个 patch 的完整 D 维特征。

形状变化:

patches: [B, N, D] indices: [B, K] expanded index: [B, K, D] selected_patches: [B, K, D]

8. 记忆方法

可以这样记:

gather = 按照 index,从 input 的某个 dim 维度上取值。

如果:

output = torch.gather(input, dim=1, index=index)

那么就是:

在第 1 维上取值; 其他维度保持对应关系; output 的形状等于 index 的形状。

对于三维张量:

input.shape = [B, N, D] index.shape = [B, K, D] dim = 1

核心公式:

output[b][k][d] = input[b][index[b][k][d]][d]

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询