Torch-Pruning
[CVPR 2023] DepGraph: Towards Any Structural Pruning; LLMs, Vision Foundation Models, etc.


Torch-Pruning (TP) is a framework for structural pruning with the following features:

  • General-purpose Pruning Toolkit: TP enables structural pruning for a wide range of deep neural networks. Unlike torch.nn.utils.prune, which zeroizes parameters via masking, Torch-Pruning deploys an algorithm called ⚡ DepGraph to group and physically remove coupled parameters.
  • Examples: Pruning off-the-shelf models from Huggingface, Timm, Torchvision, including Large Language Models (LLMs), Segment Anything Model (SAM), Diffusion Models, Vision Transformers, ConvNext, Yolov7, yolov8, Swin Transformers, BERT, FasterRCNN, SSD, ResNe(X)t, DenseNet, RegNet, DeepLab, etc. A detailed list can be found in 🎨 Examples.

For more technical details, please refer to our CVPR'23 paper.

DepGraph: Towards Any Structural Pruning

Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang

xML Lab, National University of Singapore

Update:

  • 🔥 2025.03.24 Examples for pruning DeepSeek-R1-Distill.
  • 🔥 2024.11.17 We are working to add more examples for LLMs, such as Llama-2/3, Phi-3, Qwen-2/2.5.
  • 🔥 2024.09.27 Check our latest work, MaskLLM (NeurIPS 24 Spotlight), for learnable semi-structured sparsity of LLMs.
  • 🔥 2024.07.20 Add Isomorphic Pruning (ECCV'24). A SOTA method for Vision Transformers and Modern CNNs.

Contact Us:

Please do not hesitate to open an issue if you encounter any problems with the library or the paper.

Or join our WeChat group for more discussions: ✉️ Group-2 (>300/500), ✉️ Group-1 (500/500, FULL).


Installation

Torch-Pruning only relies on PyTorch and NumPy, and it is compatible with PyTorch 1.x and 2.x. To install the latest version, run the following command:

pip install torch-pruning --upgrade

For editable installation:

git clone https://github.com/VainF/Torch-Pruning.git
cd Torch-Pruning && pip install -e .

Quickstart

Here we provide a quick start for Torch-Pruning. More detailed explanations can be found in the Tutorials.

Why Torch-Pruning?

In structural pruning, the removal of a single parameter may affect multiple layers. For example, pruning an output dimension of a linear layer will require the removal of the corresponding input dimension in the following linear layer as shown in (a). This dependency between layers makes it challenging to prune complicated networks manually. Torch-Pruning addresses this issue by introducing a graph-based algorithm called DepGraph to automatically identify dependencies and collect groups for pruning.
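To make the coupling concrete, here is a small, self-contained illustration (not taken from the repository) of pruning one output dimension of a linear layer by hand: removing a row of the first layer's weight forces the matching column of the next layer's weight to be removed as well. DepGraph automates exactly this bookkeeping across arbitrary architectures.

import torch
import torch.nn as nn

fc1, fc2 = nn.Linear(4, 3), nn.Linear(3, 2)
keep = [0, 2]  # keep these output indices of fc1, i.e. drop index 1

fc1_pruned = nn.Linear(4, len(keep))
fc1_pruned.weight.data = fc1.weight.data[keep].clone()      # remove the pruned output row
fc1_pruned.bias.data = fc1.bias.data[keep].clone()

fc2_pruned = nn.Linear(len(keep), 2)
fc2_pruned.weight.data = fc2.weight.data[:, keep].clone()   # remove the matching input column
fc2_pruned.bias.data = fc2.bias.data.clone()

x = torch.randn(1, 4)
print(fc2_pruned(fc1_pruned(x)).shape)  # torch.Size([1, 2])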

How It Works: DepGraph

Important

Please make sure that autograd is enabled, since TP analyzes the model structure with PyTorch's autograd. This means torch.no_grad() (or anything similar) must not be active when building the dependency graph.
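As an illustrative anti-pattern (using the same model and example inputs as the quickstart below), this is what to avoid; the quickstart itself shows the correct usage with autograd enabled:

# ✗ Do not do this: with autograd disabled, TP cannot trace the layer dependencies.
with torch.no_grad():
    DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))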

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. Build dependency graph for a resnet18. This requires a dummy input for forwarding
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. To prune the output channels of model.conv1, we need to find the corresponding group with a pruning function and pruning indices.
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

# 3. Do the pruning
if DG.check_pruning_group(group): # avoid over-pruning, i.e., channels=0.
    group.prune()
    
# 4. Save & Load
model.zero_grad() # clear gradients to avoid a large file size
torch.save(model, 'model.pth') # !! no .state_dict here since the structure has been changed after pruning
model = torch.load('model.pth') # load the pruned model. you may need torch.load('model.pth', weights_only=False) for PyTorch 2.6.0+.

The above example shows the core algorithm, DepGraph, which captures the dependencies in structural pruning. The target layer model.conv1 is coupled with multiple layers, necessitating their simultaneous removal in structural pruning. We can print the group to inspect the internal dependencies. In the output below, "A => B" indicates that pruning operation "A" triggers pruning operation "B". group[0] refers to the pruning root. For more details about grouping, please refer to Wiki - DepGraph & Group.

print(group.details()) # or print(group)
--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs (3) =[2, 6, 9]  (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs (3) =[2, 6, 9] 
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs (3) =[2, 6, 9] 
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs (3) =[2, 6, 9] 
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs (3) =[2, 6, 9] 
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs (3) =[2, 6, 9] 
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs (3) =[2, 6, 9] 
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs (3) =[2, 6, 9] 
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs (3) =[2, 6, 9] 
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs (3) =[2, 6, 9] 
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs (3) =[2, 6, 9] 
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs (3) =[2, 6, 9] 
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs (3) =[2, 6, 9] 
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs (3) =[2, 6, 9] 
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs (3) =[2, 6, 9] 
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs (3) =[2, 6, 9] 
--------------------------------

How to scan all groups (Advanced):

There might be many groups in a model. We can use DG.get_all_groups(ignored_layers, root_module_types) to scan all prunable groups sequentially. Each group begins with a layer that matches one of the nn.Module classes in root_module_types. The ignored_layers parameter is used to skip layers that should not be pruned; for example, we can skip the first convolution layer in a ResNet model.

import torch.nn as nn  # needed for root_module_types below
for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # Handle groups in sequential order
    idxs = [2,4,6] # your pruning indices, feel free to change them
    group.prune(idxs=idxs)
    print(group)

High-level Pruners

Note

The pruning ratio: In TP, pruning_ratio refers to the pruning ratio of channels/dims. Since both the input and output dims will be pruned with ratio p, the actual parameter pruning ratio will be roughly 1 - (1 - p)^2. To remove 50% of the parameters, you may use pruning_ratio=0.30 instead, which leads to an actual parameter pruning ratio of 1 - (1 - 0.3)^2 = 0.51 (51% of parameters removed).
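If it helps, a small illustrative helper (not part of the TP API) inverts this relation to pick a channel pruning_ratio from a target parameter pruning ratio:

import math

def channel_ratio_for_param_ratio(param_ratio: float) -> float:
    # Invert parameter_ratio ≈ 1 - (1 - p)^2 to solve for the channel ratio p.
    return 1.0 - math.sqrt(1.0 - param_ratio)

print(round(channel_ratio_for_param_ratio(0.51), 2))  # 0.3, matching the example above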

With DepGraph, we developed several high-level pruners to facilitate effortless pruning. By specifying the desired channel pruning ratio, the pruner will scan all prunable groups, estimate weight importance and perform pruning. You can fine-tune the remaining weights using your own training code. For detailed information on this process, please refer to this tutorial, which shows how to implement a Network Slimming (ICCV 2017) pruner from scratch. Additionally, a more practical example is available in VainF/Isomorphic-Pruning for ViT and ConvNext pruning.

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 1. Importance criterion, here we calculate the L2 Norm of grouped weights as the importance score
imp = tp.importance.GroupMagnitudeImportance(p=2) 

# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

pruner = tp.pruner.BasePruner( # We can always choose BasePruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    # pruning_ratio_dict = {model.conv1: 0.2, model.layer2: 0.8}, # customized pruning ratios for layers or blocks
    ignored_layers=ignored_layers,
    round_to=8, # It's recommended to round dims/channels to 4x or 8x for acceleration. Please see: https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html
)

# 3. Prune the model
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
tp.utils.print_tool.before_pruning(model) # or print(model)
pruner.step()
tp.utils.print_tool.after_pruning(model) # or print(model), this util will show the difference before and after pruning
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")


# 4. finetune the pruned model using your own code.
# finetune(model)
# ...
Output

The model difference before and after pruning will be highlighted by something like (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) => (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False).

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) => (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
...
     (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) => (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) => (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) => (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True) => (fc): Linear(in_features=256, out_features=1000, bias=True)
)

MACs: 1.822177768 G -> 0.487202536 G, #Params: 11.689512 M -> 3.05588 M

Global Pruning and Isomorphic Pruning

Global pruning performs importance ranking across all layers, which has the potential to find better structures. It can be enabled simply by setting global_pruning=True in the pruner. While this strategy can offer performance advantages, it also risks over-pruning specific layers, resulting in a substantial decline in overall performance. We provide an alternative algorithm called Isomorphic Pruning to alleviate this issue, which can be enabled with isomorphic=True. Comprehensive examples for ViT & ConvNext pruning are available in this project.

pruner = tp.pruner.BasePruner(
    ...
    isomorphic=True, # enable isomorphic pruning to improve global ranking
    global_pruning=True, # global pruning
)

Pruning Ratios

The argument pruning_ratio determines the default pruning ratio. If you want to customize the pruning ratio for some layers or blocks, you can use pruning_ratio_dict. The key of the dict can be a single nn.Module or a tuple of nn.Module. In the latter case, all modules in the tuple form a scope: they share the user-defined pruning ratio and compete with each other to be pruned.

pruner = tp.pruner.BasePruner(
    ...
    global_pruning=True,
    pruning_ratio=0.5, # default pruning ratio
    # layer1 & layer2 will share a total pruning ratio of 0.4 while layer 3 will have a pruning ratio of 0.2
    pruning_ratio_dict = {(model.layer1, model.layer2): 0.4, model.layer3: 0.2}, 
)

Sparse Training (Optional)

Some pruners, such as BNScalePruner and GroupNormPruner, support sparse training. This can be easily achieved by inserting pruner.update_regularizer() and pruner.regularize(model) into your standard training loop. The pruner will accumulate the regularization gradients into .grad. Sparse training is optional and does not always guarantee better performance, so use it with care.

for epoch in range(epochs):
    model.train()
    pruner.update_regularizer() # <== initialize regularizer
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, target)
        loss.backward()
        pruner.regularize(model) # <== for sparse training; call after loss.backward() and before optimizer.step()
        optimizer.step()

Interactive Pruning

All high-level pruners offer support for interactive pruning. You can utilize the method pruner.step(interactive=True) to retrieve all the groups and interactively prune them by calling group.prune(). This feature is particularly useful if you want to control or monitor the pruning process.

for i in range(iterative_steps):
    for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
        print(group) 
        # do whatever you like with the group 
        dep, idxs = group[0] # get the idxs
        target_module = dep.target.module # get the root module
        pruning_fn = dep.handler # get the pruning function
        group.prune()
        # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...

Pruning by Masking

It is possible to implement masking-based pruning by leveraging interactive=True, which zeroes out parameters without removing them. An example can be found in tests/test_soft_pruning.py.
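As a rough sketch only (the actual logic lives in tests/test_soft_pruning.py and may differ), the idea is to fetch each group interactively and zero out the selected channels instead of calling group.prune():

for group in pruner.step(interactive=True):
    dep, idxs = group[0]                  # root dependency of the group
    layer = dep.target.module             # e.g. a Conv2d/Linear whose out-channels lie along dim 0
    if hasattr(layer, "weight"):
        with torch.no_grad():
            layer.weight.data[idxs] = 0   # mask the selected output channels
            if getattr(layer, "bias", None) is not None:
                layer.bias.data[idxs] = 0
    # group.prune() is intentionally NOT called, so the model keeps its original shape.
    # A complete soft-pruning implementation would also mask the coupled layers in the group.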

Group-level Pruning

With DepGraph, it is easy to design "group-level" importance scores that estimate the importance of a whole group rather than a single layer. This feature can also be used to sparsify coupled layers, making all the to-be-pruned parameters consistently sparse. In Torch-Pruning, all pruners work at the group level. Check the following results to see how grouping improves the performance of pruning; a minimal sketch of a custom group-level importance is given after them.

  • Pruning a ResNet50 pre-trained on ImageNet-1K without fine-tuning.
  • Pruning a Vision Transformer pre-trained on ImageNet-1K without fine-tuning.
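The sketch below assumes the convention used by the built-in criteria: an importance object is a callable that takes a group and returns one score per prunable index. The handler-name check and the skipped in-channel weights are simplifying assumptions; see the tutorials for the exact interface.

import torch

class MyGroupL2Importance:
    # Illustrative: accumulate an L2 norm per pruned index across the coupled layers of a group.
    def __call__(self, group, **kwargs):
        scores = None
        for dep, idxs in group:
            layer = dep.target.module
            if not isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)):
                continue
            if getattr(dep.handler, "__name__", "") != "prune_out_channels":  # assumed handler naming
                continue
            w = layer.weight.data[idxs].flatten(1)              # out-channels lie along dim 0
            local = w.pow(2).sum(dim=1)
            scores = local if scores is None else scores + local
        # A full implementation would also account for the in-channel weights of coupled layers.
        return None if scores is None else scores.sqrt()

# Usable in place of the built-in criteria, e.g.:
# pruner = tp.pruner.BasePruner(model, example_inputs, importance=MyGroupL2Importance(), pruning_ratio=0.5)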

Modify static attributes or forward functions

In some implementations, model forwarding may rely on static attributes. For example, convformer_s18 in timm uses a static self.shape attribute that must be updated after pruning. Such attributes have to be updated manually, since it is impossible for TP to know their purpose.

class Scale(nn.Module):
    """
    Scale vector by element multiplications.
    """

    def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
        super().__init__()
        self.shape = (dim, 1, 1) if use_nchw else (dim,) # static shape, which should be updated after pruning
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale.view(self.shape) # => x * self.scale.view(-1, 1, 1), this works for pruning

Save and Load

The following script saves the whole model object (structure + weights) as 'model.pth'. You can load it using the standard PyTorch API. Just remember to save and load the whole model without .state_dict or .load_state_dict, since the pruned structure will differ from the original definition in your model.py.

model.zero_grad() # Remove gradients
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model
# For PyTorch 2.6.0+, you may need weights_only=False to enable model loading
# model = torch.load('model.pth', weights_only=False)

Low-level Pruning Functions

In Torch-Pruning, we provide a series of low-level pruning functions that only prune a single layer or module. To manually prune the model.conv1 of a ResNet-18, the pruning pipeline should look like this:

tp.prune_conv_out_channels( model.conv1, idxs=[2,6,9] )

# fix the broken dependencies manually
tp.prune_batchnorm_out_channels( model.bn1, idxs=[2,6,9] )
tp.prune_conv_in_channels( model.layer2[0].conv1, idxs=[2,6,9] )
...

The following pruning functions are available:

'prune_conv_out_channels',
'prune_conv_in_channels',
'prune_depthwise_conv_out_channels',
'prune_depthwise_conv_in_channels',
'prune_batchnorm_out_channels',
'prune_batchnorm_in_channels',
'prune_linear_out_channels',
'prune_linear_in_channels',
'prune_prelu_out_channels',
'prune_prelu_in_channels',
'prune_layernorm_out_channels',
'prune_layernorm_in_channels',
'prune_embedding_out_channels',
'prune_embedding_in_channels',
'prune_parameter_out_channels',
'prune_parameter_in_channels',
'prune_multihead_attention_out_channels',
'prune_multihead_attention_in_channels',
'prune_groupnorm_out_channels',
'prune_groupnorm_in_channels',
'prune_instancenorm_out_channels',
'prune_instancenorm_in_channels',

Customized Layers

Please refer to examples/transformers/prune_hf_swin.py, which implements a new pruner for the customized module SwinPatchMerging. Another simple example is available at tests/test_customized_layer.py.
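As a rough sketch of what a customized pruner can look like, the example below handles a toy element-wise scaling layer. It assumes the BasePruningFunc interface and the customized_pruners argument used in those files; treat the exact import path and argument name as assumptions and consult tests/test_customized_layer.py for the authoritative version.

import torch
import torch.nn as nn
import torch_pruning as tp

class MyScale(nn.Module):
    # A toy layer with a per-channel parameter that TP does not know how to prune by default.
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.scale = nn.Parameter(torch.ones(dim))
    def forward(self, x):
        return x * self.scale.view(1, -1, 1, 1)

class MyScalePruner(tp.pruner.BasePruningFunc):  # import path assumed; may also be exposed as tp.BasePruningFunc
    def prune_out_channels(self, layer, idxs):
        keep = [i for i in range(layer.dim) if i not in idxs]
        layer.scale = nn.Parameter(layer.scale.data[keep])
        layer.dim = len(keep)
        return layer
    prune_in_channels = prune_out_channels  # in/out are identical for an element-wise layer
    def get_out_channels(self, layer):
        return layer.dim
    def get_in_channels(self, layer):
        return layer.dim

# Registration (argument name as used in the referenced examples):
# pruner = tp.pruner.BasePruner(model, example_inputs, importance=imp,
#                               customized_pruners={MyScale: MyScalePruner()})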

Reproduce Paper Results

Please see reproduce.

Our results on {ResNet-56 / CIFAR-10 / 2.00x}

Method          Base (%)    Pruned (%)    Δ Acc (%)    Speed Up
NIPS [1]        -           -             -0.03        1.76x
Geometric [2]   93.59       93.26         -0.33        1.70x
Polar [3]       93.80       93.83         +0.03        1.88x
CP [4]          92.80       91.80         -1.00        2.00x
AMC [5]         92.80       91.90         -0.90        2.00x
HRank [6]       93.26       92.17         -0.09        2.00x
SFP [7]         93.59       93.36         +0.23        2.11x
ResRep [8]      93.71       93.71         +0.00        2.12x
Ours-L1         93.53       92.93         -0.60        2.12x
Ours-BN         93.53       93.29         -0.24        2.12x
Ours-Group      93.53       93.77         +0.38        2.13x

Latency

Latency test on ResNet-50, Batch Size=64.

[Iter 0]        Pruning ratio: 0.00,         MACs: 4.12 G,   Params: 25.56 M,        Latency: 45.22 ms +- 0.03 ms
[Iter 1]        Pruning ratio: 0.05,         MACs: 3.68 G,   Params: 22.97 M,        Latency: 46.53 ms +- 0.06 ms
[Iter 2]        Pruning ratio: 0.10,         MACs: 3.31 G,   Params: 20.63 M,        Latency: 43.85 ms +- 0.08 ms
[Iter 3]        Pruning ratio: 0.15,         MACs: 2.97 G,   Params: 18.36 M,        Latency: 41.22 ms +- 0.10 ms
[Iter 4]        Pruning ratio: 0.20,         MACs: 2.63 G,   Params: 16.27 M,        Latency: 39.28 ms +- 0.20 ms
[Iter 5]        Pruning ratio: 0.25,         MACs: 2.35 G,   Params: 14.39 M,        Latency: 34.60 ms +- 0.19 ms
[Iter 6]        Pruning ratio: 0.30,         MACs: 2.02 G,   Params: 12.46 M,        Latency: 33.38 ms +- 0.27 ms
[Iter 7]        Pruning ratio: 0.35,         MACs: 1.74 G,   Params: 10.75 M,        Latency: 31.46 ms +- 0.20 ms
[Iter 8]        Pruning ratio: 0.40,         MACs: 1.50 G,   Params: 9.14 M,         Latency: 29.04 ms +- 0.19 ms
[Iter 9]        Pruning ratio: 0.45,         MACs: 1.26 G,   Params: 7.68 M,         Latency: 27.47 ms +- 0.28 ms
[Iter 10]       Pruning ratio: 0.50,         MACs: 1.07 G,   Params: 6.41 M,         Latency: 20.68 ms +- 0.13 ms
[Iter 11]       Pruning ratio: 0.55,         MACs: 0.85 G,   Params: 5.14 M,         Latency: 20.48 ms +- 0.21 ms
[Iter 12]       Pruning ratio: 0.60,         MACs: 0.67 G,   Params: 4.07 M,         Latency: 18.12 ms +- 0.15 ms
[Iter 13]       Pruning ratio: 0.65,         MACs: 0.53 G,   Params: 3.10 M,         Latency: 15.19 ms +- 0.01 ms
[Iter 14]       Pruning ratio: 0.70,         MACs: 0.39 G,   Params: 2.28 M,         Latency: 13.47 ms +- 0.01 ms
[Iter 15]       Pruning ratio: 0.75,         MACs: 0.29 G,   Params: 1.61 M,         Latency: 10.07 ms +- 0.01 ms
[Iter 16]       Pruning ratio: 0.80,         MACs: 0.18 G,   Params: 1.01 M,         Latency: 8.96 ms +- 0.02 ms
[Iter 17]       Pruning ratio: 0.85,         MACs: 0.10 G,   Params: 0.57 M,         Latency: 7.03 ms +- 0.04 ms
[Iter 18]       Pruning ratio: 0.90,         MACs: 0.05 G,   Params: 0.25 M,         Latency: 5.81 ms +- 0.03 ms
[Iter 19]       Pruning ratio: 0.95,         MACs: 0.01 G,   Params: 0.06 M,         Latency: 5.70 ms +- 0.03 ms
[Iter 20]       Pruning ratio: 1.00,         MACs: 0.01 G,   Params: 0.06 M,         Latency: 5.71 ms +- 0.03 ms

Series of Works

DepGraph: Towards Any Structural Pruning [Project] [Paper]

Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang

CVPR 2023

Isomorphic Pruning for Vision Models [Project] [Arxiv]

Gongfan Fang, Xinyin Ma, Michael Bi Mi, Xinchao Wang

ECCV 2024

LLM-Pruner: On the Structural Pruning of Large Language Models [Project] [arXiv]

Xinyin Ma, Gongfan Fang, Xinchao Wang

NeurIPS 2023

Structural Pruning for Diffusion Models [Project] [arxiv]

Gongfan Fang, Xinyin Ma, Xinchao Wang

NeurIPS 2023

DeepCache: Accelerating Diffusion Models for Free [Project] [Arxiv]

Xinyin Ma, Gongfan Fang, and Xinchao Wang

CVPR 2024

SlimSAM: 0.1% Data Makes Segment Anything Slim [Project] [Arxiv]

Zigeng Chen, Gongfan Fang, Xinyin Ma, Xinchao Wang

Preprint 2023

Citation

@inproceedings{fang2023depgraph,
  title={Depgraph: Towards any structural pruning},
  author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={16091--16101},
  year={2023}
}