LlamaForCausalLM

graph TD
    %% 主图方向
    direction RL

    %% 全局样式定义
    classDef preprocess fill:#e8f8f5,stroke:#45b7d1;
    classDef decoder fill:#f9d5e5,stroke:#e892a2;
    classDef attention fill:#d4efdf,stroke:#7dcea0;
    classDef ffn fill:#e8daef,stroke:#c39bd3;
    classDef layer fill:#fdebd0,stroke:#eb984e;

    subgraph "Model"
        direction TB

        %% 预处理模块
        subgraph "Preprocessing"
            direction TB
            A[Input Text] --> B[Token Embedding]
            B --> B1[Get KVCache]
            B1 --> B2[Get cache_position]
            B2 --> B3[Get position_ids]
            B3 --> B4[Compute causal_mask]
            B4 --> B5[rotary_emb]
            class B,B1,B2,B3,B4,B5 preprocess;
        end

        %% 解码器堆叠层
        subgraph "Decoder Stack"
            direction TB
            B5 --> D[Decoder Layer 1]
            D --> E[Decoder Layer 2]
            E --> F[...]
            F --> G[Decoder Layer N]
            G --> G1[Last LayerNorm]
            class D,E,F,G layer;
        end

        %% 输出模块
        subgraph "Output"
            direction TB
            G1 --> H[Lm_head]
            H --> I[Output Layer]
        end
    end

    %% 解码器层详细结构
    subgraph "Decoder Layer Detail"
        style D decoder
        direction TB

        D --> D0[hidden_states]
        D0 --> D1[Layer Norm]
        D1 --> D2[Self-Attention]
        D2 --> D3[Add]
        D0 -.- D3
        D3 --> D5[post LayerNorm]
        D5 --> D6[MLP]
        D6 --> D7[Add]
        D3 -.- D7
        
        class D2 attention;
        class D5,D6 ffn;
    end

    %% 注意力层详细结构
    subgraph "Attention Layer Detail"
        style D2 attention
        direction LR

        D2 --> D20[attention_input]
        D20 --> D201[K]
        D20 --> D202[V]
        D20 --> D203[Q]
        D20 --> D204[attention_mask]
        
        D201 --> D211[repeat_kv_k]
        D202 --> D212[repeat_kv_v]
        D211 --> D21[Matmul]
        D203 --> D21
        D204 --> D214[slice]
        D214 --> D22[Add]
        D21 --> D22
        D22 --> D23[softmax: float32]
        D23 --> D24[Matmul:TP<1,2>]
        D212 --> D24
    end

    %% 不可见连接引导
    %% G1 -.-|缓冲连接| H %% 
%%     D7 -.-|隐藏连接| E %%
    F -.-|...| G