1. 基本作用
torch.gather的作用是:
从 input 的指定维度 dim 上,按照 index 给出的索引位置取值。基本语法:
output = torch.gather(input, dim, index)基本公式 三维举例 dim=1 output[b][k][d] = input[b][index[b][k][d]][d]其中:
input:原始张量 dim:指定在哪个维度上取值 index:索引张量 output:取出的结果2. 核心规则
torch.gather有一个非常重要的规则:
output 的形状和 index 的形状相同。也就是说:
output.shape == index.shapeindex不只是告诉从哪里取值,它还决定了最终输出张量的形状。
3. 二维例子
假设:
import torch input = torch.tensor([ [10, 20, 30], [40, 50, 60] ]) index = torch.tensor([ [0, 2], [1, 0] ]) output = torch.gather(input, dim=1, index=index)因为dim=1,所以是在列方向取值。
取值过程:
output[0][0] = input[0][index[0][0]] = input[0][0] = 10 output[0][1] = input[0][index[0][1]] = input[0][2] = 30 output[1][0] = input[1][index[1][0]] = input[1][1] = 50 output[1][1] = input[1][index[1][1]] = input[1][0] = 40最终结果:
tensor([ [10, 30], [50, 40] ])4. 三维例子
假设:
input.shape = [B, N, D]含义是:
B:batch size,样本数量 N:每个样本中的 patch 数量 D:每个 patch 的特征维度如果:
index.shape = [B, K, D]并且:
output = torch.gather(input, dim=1, index=index)那么:
output.shape = [B, K, D]因为dim=1,所以是在N这个维度上取值。
核心公式是:
output[b][k][d] = input[b][index[b][k][d]][d]解释:
B 维保持对应 D 维保持对应 只有 N 维根据 index[b][k][d] 指定的位置取值5. 结合 patch 选择代码理解
常见代码:
_, indices = torch.topk(attention_weights, k, dim=1) selected_patches = torch.gather( patches, 1, indices.unsqueeze(-1).expand(-1, -1, D) )假设:
patches.shape = [B, N, D] attention_weights.shape = [B, N] indices.shape = [B, K]其中:
B:样本数量 N:patch 数量 D:每个 patch 的特征维度 K:要选出的 patch 数量torch.topk得到的是每个样本中分数最高的K个 patch 索引:
indices.shape = [B, K]但是patches是三维张量:
patches.shape = [B, N, D]所以需要先扩展索引:
indices.unsqueeze(-1)形状变为:
[B, K, 1]再使用:
expand(-1, -1, D)形状变为:
[B, K, D]这样才能和patches的三维结构对应起来。
6. 为什么要 expand 到 D 维
因为每个 patch 不是一个数,而是一个D维特征向量。
如果某个 patch 的索引是:
indices[b][k] = 3扩展后变成:
index[b][k] = [3, 3, 3, ..., 3]长度是D。
于是:
output[b][k][0] = input[b][3][0] output[b][k][1] = input[b][3][1] output[b][k][2] = input[b][3][2] ... output[b][k][D-1] = input[b][3][D-1]也就是把第3个 patch 的完整D维特征全部取出来。
7. 最终效果
对于代码:
selected_patches = torch.gather( patches, 1, indices.unsqueeze(-1).expand(-1, -1, D) )它的作用是:
从每个样本的 N 个 patch 中, 根据 top-k 得到的索引, 选出 K 个重要 patch, 并保留每个 patch 的完整 D 维特征。形状变化:
patches: [B, N, D] indices: [B, K] expanded index: [B, K, D] selected_patches: [B, K, D]8. 记忆方法
可以这样记:
gather = 按照 index,从 input 的某个 dim 维度上取值。如果:
output = torch.gather(input, dim=1, index=index)那么就是:
在第 1 维上取值; 其他维度保持对应关系; output 的形状等于 index 的形状。对于三维张量:
input.shape = [B, N, D] index.shape = [B, K, D] dim = 1核心公式:
output[b][k][d] = input[b][index[b][k][d]][d]