Viewing the Structure of a Hugging Face Model

How To See the Architecture of A Loaded Model

I spent several days building the online weight-update feature in SGLang for the RLHF workflow. Even now, I still couldn't figure out how to load the [name, weights] pairs into the SGLang Engine. This was quite annoying, so I decided to dig into the code and figure out how it works.

Thus, this note records my findings on how to inspect the structure of a model after it is loaded by Hugging Face or SGLang. We start with Hugging Face and then move on to SGLang. All the code is based on the meta-llama/Llama-3.2-1B-Instruct model.

Hugging Face

Loading a model from Hugging Face is quite simple with its direct API.

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", torch_dtype="bfloat16").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
print(model)

Printing the model gives the following output:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
)

Let's look at the model's components in detail, i.e., model.model and model.lm_head.

model.model

The whole class LlamaModel implements the Transformer decoder architecture.

  1. Embedding Layer
    • embed_tokens: Embedding(128256, 2048)
    • 128256: vocabulary size.
    • 2048: dimensionality of each embedding vector.
    • Maps discrete vocabulary indices to continuous embedding vectors.

  2. Decoder Layers
    • layers: ModuleList (16 x LlamaDecoderLayer)
    • A stack of 16 decoder layers. Each LlamaDecoderLayer contains one self_attn, one mlp, an input_layernorm, and a post_attention_layernorm.
    • self_attn: LlamaSdpaAttention: computes self-attention scores and aggregates contextual information.
      • q_proj: projects input features into the query space. Input: 2048, output: 2048.
      • k_proj: projects input features into the key space. Input: 2048, output: 512 (smaller because there are fewer key/value heads than query heads; see the sketch after this list).
      • v_proj: projects input features into the value space. Input: 2048, output: 512.
      • o_proj: projects attention outputs back to the input feature dimensionality. Input: 2048, output: 2048.
      • rotary_emb: rotary positional embeddings that encode sequence position.
    • mlp: LlamaMLP: applies non-linear transformations through a multi-layer perceptron.
      • gate_proj: linear layer, input: 2048, output: 8192.
      • up_proj: linear layer, input: 2048, output: 8192.
      • down_proj: linear layer, input: 8192, output: 2048.
      • act_fn: the SiLU (Swish) activation function introduces non-linearity.
    • input_layernorm: applies RMSNorm to the layer input with dimensionality 2048.
    • post_attention_layernorm: applies RMSNorm after the attention mechanism.

  3. Global Normalization
    • norm: LlamaRMSNorm((2048,), eps=1e-05)
    • Applies RMSNorm to the final decoder output for stable feature scaling.

  4. Rotary Positional Embedding
    • rotary_emb: LlamaRotaryEmbedding()
    • Encodes positional information using rotary embeddings to enhance sequence modeling.
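The 2048-vs-512 difference between q_proj and k_proj/v_proj comes from the number of query heads versus key/value heads. A minimal sketch to verify this against the config (the attribute names follow the Hugging Face Llama config; for Llama-3.2-1B-Instruct they should give 32 query heads and 8 key/value heads):

layer0 = model.model.layers[0]
# nn.Linear stores its weight as [out_features, in_features]
print(layer0.self_attn.q_proj.weight.shape)  # torch.Size([2048, 2048])
print(layer0.self_attn.k_proj.weight.shape)  # torch.Size([512, 2048])

cfg = model.config
head_dim = cfg.hidden_size // cfg.num_attention_heads   # 2048 // 32 = 64
print(cfg.num_attention_heads * head_dim)   # 2048 -> q_proj out_features
print(cfg.num_key_value_heads * head_dim)   # 512  -> k_proj / v_proj out_features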

model.lm_head

  • lm_head: Linear(in_features=2048, out_features=128256, bias=False)
  • A linear layer that maps the decoder's output features (dim: 2048) to the vocabulary size (128256).
  • No bias term: reduces the number of trainable parameters and computational cost.

Model State Dict

In PyTorch, state_dict is the core mechanism for saving and loading a model's parameters and optimizer states. Here is how it works:

  • state_dict: A Python dictionary object that maps each layer to its parameter tensor.
  • model.state_dict(): Returns the state dictionary of the model, containing all the weights and biases.
  • torch.save(model.state_dict(), PATH): Saves the state dictionary to a file.
  • model.load_state_dict(torch.load(PATH)): Loads the state dictionary from a file (see the round-trip sketch after this list).
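A minimal round-trip sketch of the save/load calls above (the file name is arbitrary):

import torch

# Save only the parameters, not the Python class.
torch.save(model.state_dict(), "llama_3_2_1b_state.pt")

# Later: rebuild the same architecture, then load the saved tensors back in.
state_dict = torch.load("llama_3_2_1b_state.pt", map_location="cuda")
missing, unexpected = model.load_state_dict(state_dict)
print(missing, unexpected)  # both should be empty for an exact architecture match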

As you can see, the state dict is a dictionary that contains the name and weight tensor of each layer.

We first get the state dict of the model, and then compute its VRAM usage.

state_dict = model.state_dict()

total_memory = 0
for name, param in state_dict.items():
    param_memory = param.numel() * param.element_size()  # numel() gives the number of elements, element_size() gives the size in bytes
    total_memory += param_memory

total_memory_mb = total_memory / (1024 * 1024)
print(f"Total memory usage of the state_dict: {total_memory_mb:.2f} MB")

# Total memory usage of the state_dict: 2858.13 MB

A 1B-parameter model in bfloat16 precision takes about 2.8 GB of VRAM, which is reasonable.
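As a rough cross-check, we can count parameters directly. Note that state_dict may list a tied tensor (e.g. lm_head.weight sharing storage with embed_tokens.weight) under two keys, while model.parameters() deduplicates shared tensors, so the two totals can differ:

num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e9:.2f}B parameters")                   # ~1.24B for this checkpoint
print(f"{num_params * 2 / (1024 * 1024):.2f} MB in bfloat16")  # 2 bytes per element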

print(state_dict.keys())

# odict_keys(['model.embed_tokens.weight', 'model.layers.0.self_attn.q_proj.weight', 'model.layers.0.self_attn.k_proj.weight', 'model.layers.0.self_attn.v_proj.weight', 'model.layers.0.self_attn.o_proj.weight', 'model.layers.0.mlp.gate_proj.weight', 'model.layers.0.mlp.up_proj.weight', 'model.layers.0.mlp.down_proj.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.1.self_attn.q_proj.weight', 'model.layers.1.self_attn.k_proj.weight', 'model.layers.1.self_attn.v_proj.weight', 'model.layers.1.self_attn.o_proj.weight', 'model.layers.1.mlp.gate_proj.weight', 'model.layers.1.mlp.up_proj.weight', 'model.layers.1.mlp.down_proj.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.2.self_attn.q_proj.weight', 'model.layers.2.self_attn.k_proj.weight', 'model.layers.2.self_attn.v_proj.weight', 'model.layers.2.self_attn.o_proj.weight', 'model.layers.2.mlp.gate_proj.weight', 'model.layers.2.mlp.up_proj.weight', 'model.layers.2.mlp.down_proj.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.3.self_attn.q_proj.weight', 'model.layers.3.self_attn.k_proj.weight', 'model.layers.3.self_attn.v_proj.weight', 'model.layers.3.self_attn.o_proj.weight', 'model.layers.3.mlp.gate_proj.weight', 'model.layers.3.mlp.up_proj.weight', 'model.layers.3.mlp.down_proj.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.3.post_attention_layernorm.weight', 'model.layers.4.self_attn.q_proj.weight', 'model.layers.4.self_attn.k_proj.weight', 'model.layers.4.self_attn.v_proj.weight', 'model.layers.4.self_attn.o_proj.weight', 'model.layers.4.mlp.gate_proj.weight', 'model.layers.4.mlp.up_proj.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.4.post_attention_layernorm.weight', 'model.layers.5.self_attn.q_proj.weight', 'model.layers.5.self_attn.k_proj.weight', 'model.layers.5.self_attn.v_proj.weight', 'model.layers.5.self_attn.o_proj.weight', 'model.layers.5.mlp.gate_proj.weight', 'model.layers.5.mlp.up_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.5.input_layernorm.weight', 'model.layers.5.post_attention_layernorm.weight', 'model.layers.6.self_attn.q_proj.weight', 'model.layers.6.self_attn.k_proj.weight', 'model.layers.6.self_attn.v_proj.weight', 'model.layers.6.self_attn.o_proj.weight', 'model.layers.6.mlp.gate_proj.weight', 'model.layers.6.mlp.up_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.6.input_layernorm.weight', 'model.layers.6.post_attention_layernorm.weight', 'model.layers.7.self_attn.q_proj.weight', 'model.layers.7.self_attn.k_proj.weight', 'model.layers.7.self_attn.v_proj.weight', 'model.layers.7.self_attn.o_proj.weight', 'model.layers.7.mlp.gate_proj.weight', 'model.layers.7.mlp.up_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.7.input_layernorm.weight', 'model.layers.7.post_attention_layernorm.weight', 'model.layers.8.self_attn.q_proj.weight', 'model.layers.8.self_attn.k_proj.weight', 'model.layers.8.self_attn.v_proj.weight', 'model.layers.8.self_attn.o_proj.weight', 'model.layers.8.mlp.gate_proj.weight', 'model.layers.8.mlp.up_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 
'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.mlp.gate_proj.weight', 'model.layers.9.mlp.up_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.mlp.gate_proj.weight', 'model.layers.10.mlp.up_proj.weight', 'model.layers.10.mlp.down_proj.weight', 'model.layers.10.input_layernorm.weight', 'model.layers.10.post_attention_layernorm.weight', 'model.layers.11.self_attn.q_proj.weight', 'model.layers.11.self_attn.k_proj.weight', 'model.layers.11.self_attn.v_proj.weight', 'model.layers.11.self_attn.o_proj.weight', 'model.layers.11.mlp.gate_proj.weight', 'model.layers.11.mlp.up_proj.weight', 'model.layers.11.mlp.down_proj.weight', 'model.layers.11.input_layernorm.weight', 'model.layers.11.post_attention_layernorm.weight', 'model.layers.12.self_attn.q_proj.weight', 'model.layers.12.self_attn.k_proj.weight', 'model.layers.12.self_attn.v_proj.weight', 'model.layers.12.self_attn.o_proj.weight', 'model.layers.12.mlp.gate_proj.weight', 'model.layers.12.mlp.up_proj.weight', 'model.layers.12.mlp.down_proj.weight', 'model.layers.12.input_layernorm.weight', 'model.layers.12.post_attention_layernorm.weight', 'model.layers.13.self_attn.q_proj.weight', 'model.layers.13.self_attn.k_proj.weight', 'model.layers.13.self_attn.v_proj.weight', 'model.layers.13.self_attn.o_proj.weight', 'model.layers.13.mlp.gate_proj.weight', 'model.layers.13.mlp.up_proj.weight', 'model.layers.13.mlp.down_proj.weight', 'model.layers.13.input_layernorm.weight', 'model.layers.13.post_attention_layernorm.weight', 'model.layers.14.self_attn.q_proj.weight', 'model.layers.14.self_attn.k_proj.weight', 'model.layers.14.self_attn.v_proj.weight', 'model.layers.14.self_attn.o_proj.weight', 'model.layers.14.mlp.gate_proj.weight', 'model.layers.14.mlp.up_proj.weight', 'model.layers.14.mlp.down_proj.weight', 'model.layers.14.input_layernorm.weight', 'model.layers.14.post_attention_layernorm.weight', 'model.layers.15.self_attn.q_proj.weight', 'model.layers.15.self_attn.k_proj.weight', 'model.layers.15.self_attn.v_proj.weight', 'model.layers.15.self_attn.o_proj.weight', 'model.layers.15.mlp.gate_proj.weight', 'model.layers.15.mlp.up_proj.weight', 'model.layers.15.mlp.down_proj.weight', 'model.layers.15.input_layernorm.weight', 'model.layers.15.post_attention_layernorm.weight', 'model.norm.weight', 'lm_head.weight'])

Unlike the printed model architecture, the state dict flattens all components into fully qualified parameter names mapped to their weight tensors.

Also, dict(model.named_parameters()).keys() gives the same result.
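Since every key is a fully qualified name, a single tensor can also be fetched by name with nn.Module.get_parameter, which is the Hugging Face counterpart of SGLang's get_weights_by_name discussed below:

w = model.get_parameter("model.layers.0.self_attn.q_proj.weight")
print(w.shape, w.dtype)  # torch.Size([2048, 2048]) torch.bfloat16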

SGLang

After seeing the architecture of a model loaded by Hugging Face, the SGLang side becomes much clearer. Still using the Llama-3.2-1B-Instruct model as the example, here is the llama model file that shows how it is loaded.

Stepping into load_weights:

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        embed_tokens_weight = None
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())

        load_tie_word_embeddings = (
            hasattr(self.config, "tie_word_embeddings")
            and self.config.tie_word_embeddings
            and "lm_head.weight" in params_dict
        )

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip loading kv_scale from ckpts towards new design.
                if name.endswith(".kv_scale") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

                if load_tie_word_embeddings and name == "model.embed_tokens.weight":
                    embed_tokens_weight = loaded_weight

        if load_tie_word_embeddings:
            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
            param = self.lm_head.weight
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            if embed_tokens_weight is not None:
                weight_loader(param, embed_tokens_weight)

        apply_torchao_config_(self, params_dict, set(["proj.weight"]))

This function defines how a model's weights are loaded after parameter-name mapping.

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

Generally, each of q_proj, k_proj, and v_proj is mapped into qkv_proj and tagged with a q, k, or v shard id, while gate_proj and up_proj are mapped into gate_up_proj with shard ids 0 and 1 respectively.
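The mapping itself is just a string replacement plus a shard id. A minimal standalone sketch replaying the loop above on one checkpoint key:

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]

name = "model.layers.0.self_attn.k_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        print(name.replace(weight_name, param_name), shard_id)
        # -> model.layers.0.self_attn.qkv_proj.weight k
        break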

Then, how do we see the state dict of a loaded model?

def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        try:
            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name in name:
                    mapped_name = name.replace(weight_name, param_name)
                    mapped_shard_id = shard_id
                    break
            params_dict = dict(self.named_parameters())
            if mapped_name in params_dict:
                param = params_dict[mapped_name]
                if mapped_shard_id is not None:
                    if mapped_shard_id in ["q", "k", "v"]:
                        num_heads = self.config.num_attention_heads // tp_size
                        num_kv_heads = self.config.num_key_value_heads // tp_size
                        head_dim = (
                            self.config.hidden_size // self.config.num_attention_heads
                        )
                        if mapped_shard_id == "q":
                            offset = 0
                            size = num_heads * head_dim
                        elif mapped_shard_id == "k":
                            offset = num_heads * head_dim
                            size = num_kv_heads * head_dim
                        elif mapped_shard_id == "v":
                            offset = (num_heads + num_kv_heads) * head_dim
                            size = num_kv_heads * head_dim
                        weight = param.data.narrow(0, offset, size)
                    elif mapped_shard_id in [0, 1]:
                        intermediate_size = self.config.intermediate_size
                        hidden_size = self.config.hidden_size
                        slice_size = intermediate_size // tp_size
                        if mapped_shard_id == 0:  # gate_proj
                            offset = 0
                            size = slice_size
                        elif mapped_shard_id == 1:  # up_proj
                            offset = slice_size
                            size = slice_size

                        weight = param.data.narrow(0, offset, size)
                    else:
                        weight = param.data
                else:
                    weight = param.data
                if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
                    gathered_weights = [
                        torch.zeros_like(weight) for _ in range(tp_size)
                    ]
                    torch.distributed.all_gather(gathered_weights, weight)
                    weight = torch.cat(gathered_weights, dim=1)
                return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
            else:
                return None
        except Exception:
            # The original file handles the error here; that part is omitted from this excerpt.
            return None

This function defines how to get the weights of a parameter by its name. Reading and writing are two sides of the same coin. [TODO: add a tensor-parallel (TP) illustration]
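For this model (tp_size=1), the q/k/v branch above slices the fused qkv_proj weight along dim 0. A worked sketch using the Llama-3.2-1B config values (32 attention heads, 8 key/value heads, head_dim 64):

num_heads, num_kv_heads = 32, 8   # config.num_attention_heads, config.num_key_value_heads
head_dim = 2048 // num_heads      # hidden_size / num_attention_heads = 64

# qkv_proj.weight has (32 + 8 + 8) * 64 = 3072 rows; q, k, v are stacked in that order.
q_offset, q_size = 0, num_heads * head_dim                    # rows 0..2047
k_offset, k_size = q_size, num_kv_heads * head_dim            # rows 2048..2559
v_offset, v_size = q_size + k_size, num_kv_heads * head_dim   # rows 2560..3071
print((q_offset, q_size), (k_offset, k_size), (v_offset, v_size))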

Note that for o_proj and down_proj, the weights must be gathered across all the tensor-parallel GPUs.

Reading is much harder than writing/loading. You can load a new parameter with self.load_weights([(name, weights)]), but how do you get the weights of a parameter back out? Refer to the docs describing how to update weights from distributed workers.
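For the write/load direction, the iterable passed to load_weights is just (checkpoint-style name, tensor) pairs, so Hugging Face parameter names can be fed in directly. A minimal sketch, assuming hf_model is the trained Hugging Face model and sglang_model is a handle to the loaded SGLang LlamaForCausalLM (both names are placeholders; how to obtain that handle depends on your engine setup):

# Push every trained weight into the SGLang model; load_weights handles the
# q/k/v and gate/up fusion via stacked_params_mapping internally.
new_weights = [(name, param.data) for name, param in hf_model.named_parameters()]
sglang_model.load_weights(new_weights)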

Published on 2024-11-30

Comments

草木如织: Why do k_proj and q_proj have different dimensions here?
(2024-12-19, Hubei)

Chayenne Zhao: Really?
(2024-12-19, US)