LlamaForCausalLM

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
graph TD
%% 主图方向
direction RL

%% 全局样式定义
classDef preprocess fill:#e8f8f5,stroke:#45b7d1;
classDef decoder fill:#f9d5e5,stroke:#e892a2;
classDef attention fill:#d4efdf,stroke:#7dcea0;
classDef ffn fill:#e8daef,stroke:#c39bd3;
classDef layer fill:#fdebd0,stroke:#eb984e;

subgraph "Model"
direction TB

%% 预处理模块
subgraph "Preprocessing"
direction TB
A[Input Text] --> B[Token Embedding]
B --> B1[Get KVCache]
B1 --> B2[Get cache_position]
B2 --> B3[Get position_ids]
B3 --> B4[Compute causal_mask]
B4 --> B5[rotary_emb]
class B,B1,B2,B3,B4,B5 preprocess;
end

%% 解码器堆叠层
subgraph "Decoder Stack"
direction TB
B5 --> D[Decoder Layer 1]
D --> E[Decoder Layer 2]
E --> F[...]
F --> G[Decoder Layer N]
G --> G1[Last LayerNorm]
class D,E,F,G layer;
end

%% 输出模块
subgraph "Output"
direction TB
G1 --> H[Lm_head]
H --> I[Output Layer]
end
end

%% 解码器层详细结构
subgraph "Decoder Layer Detail"
style D decoder
direction TB

D --> D0[hidden_states]
D0 --> D1[Layer Norm]
D1 --> D2[Self-Attention]
D2 --> D3[Add]
D0 -.- D3
D3 --> D5[post LayerNorm]
D5 --> D6[MLP]
D6 --> D7[Add]
D3 -.- D7

class D2 attention;
class D5,D6 ffn;
end

%% 注意力层详细结构
subgraph "Attention Layer Detail"
style D2 attention
direction LR

D2 --> D20[attention_input]
D20 --> D201[K]
D20 --> D202[V]
D20 --> D203[Q]
D20 --> D204[attention_mask]

D201 --> D211[repeat_kv_k]
D202 --> D212[repeat_kv_v]
D211 --> D21[Matmul]
D203 --> D21
D204 --> D214[slice]
D214 --> D22[Add]
D21 --> D22
D22 --> D23[softmax: float32]
D23 --> D24[Matmul:TP<1,2>]
D212 --> D24
end

%% 不可见连接引导
%% G1 -.-|缓冲连接| H %%
%% D7 -.-|隐藏连接| E %%
F -.-|...| G