HF的Qwen3实现

Qwen3的MOE默认版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

这个版本中会让所有的专家都进行计算,但是expert_mask中可能当前专家并没有被分配token,所以torch.where这里会返回空的两个tensor,后续计算中都是空的tensor在进行,虽然没有实际的有效计算,但是仍然会调用expertLayer,存在大量开销(为什么是大量开销?topk不会和专家数一样大么?MOE的核心概念就是稀疏激活专家!所以不存在topK和num_expert 接近的情况)

Qwen3的MOE优化实现版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

# Loop over all available experts in the model and perform the computation on each expert
expert_hitted = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hitted:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

优化版本中充分利用了MoE的稀疏特性,只调用活跃的专家来参与后续流程的计算,大大减少了空调用Layer的情况。

但是,这种通过Python高级索引来取数进行计算的方式,在实际部署的时候存在一个很大的问题。
编译器无法优化!或者说很难优化!

这里就需要一种balance,计算量和非连续访存之间的平衡。目前还没有什么特别的想法

MOE的具体实现与MLP的区别

总体来说MOE中是包含了大量的专家的,而每个专家进行计算都是使用MLP来完成,所以可以说MOE是对MLP的封装(当然,部分模型中是既有MOE部分也有MLP部分)。为什么会进行这样的封装呢?

MLP可以理解为一个专家对所有的tokens进行计算,得到最终的结果,相对来说模型的泛化性相比于MOE来说较低(可以理解为MOE每个专家负责不同领域的内容,而MLP是一个专家精通各个领域,当然这个行为并不可解释,主观上的臆想)。
MOE相较于MLP的一个突出的特点就在这里,与此而来的一个特点就是专家的稀疏性,每次进行decode都只有部分专家被激活来进行计算。

MOE中相比于MLP就会多一个权重来进行专家任务划分gate ,整体的计算流程如下:

  1. 通过gate来计算一个路由权重 router_weights,后续通过对router_weights进行概率化得到哪些专家需要对哪些token进行计算。
  2. 将专家与他需要处理的token进行计算,然后根据专家对每个Token最终结果的权重(也就是这个专家对这个token的话语权)再次计算,得到这个专家对于这些token的结果。
  3. 汇总所有专家对他负责的Token的结果,得到MOE的最终结果。