蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键

蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键在上图中我们最直观的感受就是左右两张马里奥图像的清晰度差异 这样压缩后的图像可能会影响我们视觉的识别 换个角度 我们把 AI 原始模型当作是 800 万画素的图片 将其转换成为 30 万画素 同时人眼看不出差异 这个过程就完成了有效的模型压缩

大家好,欢迎来到IT知识分享网。

蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键

在上图中我们最直观的感受就是左右两张马里奥图像的清晰度差异,这样压缩后的图像可能会影响我们视觉的识别。换个角度,我们把AI原始模型当作是800万画素的图片,将其转换成为30万画素,同时人眼看不出差异,这个过程就完成了有效的模型压缩。

为了让模型可以在边缘设备上部署,同时具备低延迟、高准确率的性能。常见的三种模型压缩方法:模型量化、模型剪枝、知识蒸馏。

模型剪枝就是从训练好的模型中去除不重要的权重,知识蒸馏一般是通过将教师模型的知识迁移至学生模型。本文我们深入探讨模型量化是什么、量化的两种技术以及如何使用 PyTorch 对大语言模型进行量化感知训练。

一、模型量化是什么!

一句话:将模型数据从高精度浮点数(通常为 16 位)转换为低精度表示(通常为 8 位整数),从而降低神经网络的计算和内存需求。浮点数简单说就是有小数字的数字,例如123.。整数就是123,浮点转整数,就是数据取整数。

蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键

AI模型量化带来了存储差异、计算差异与功耗差异。单精度浮点数在电脑上一个数字就 4 Byte,而INT-8的整数只占1 Byte,所以量化前后就差了四倍的存储差异,上图是从FP32到INT8的绝对最大值映射方式。

比如ResNet18需要到126MB来存储,如果用INT8,模型就可以压缩到带宽1/4的大小则只需要31.5MB左右,同时也能节省功耗。

二、模型量化的两种方式!

量化模型有两种方法:

1. 训练后量化(PTQ)

PTQ 是一种在模型完全训练后执行的技术,它使用小型校准数据集来确定最佳量化参数,无需重新训练即可将模型从较高精度转换为较低精度。PTQ 资源效率更高,实施和部署速度更快,适用于无法进行再训练的场景。

以 INT4 为例,最简单的方式就是把浮点数的值域转换成 0 – 15。那中间的浮点数字要怎么转换,如下图所示,INT4中0~16个整数有(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15),在INT4每个数字间隔是1,所以最简单的方式是浮点数也想办法衡量15格,浮点切好的格子,每一个大小都是固定的标度,每一个里面的数字会对应到INT4内。

蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键

蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键

2. 量化感知训练(QAT)

QAT 简单说就是AI模型的权重直接用整数来进行训练,但通常在整数前进行训练,AI模型会先用浮点数训练后再用整数来进行Fine-tune训练。

蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键

Torchao QAT 流程包含两个步骤:准备,将伪量化操作插入模型的线性层;转换,在训练后将这些伪量化操作转换为真正的量化和反量化操作。

QAT 通常能够实现更高的准确率,因为模型能够在训练过程中自适应量化效应,因此更适合对量化误差敏感的模型。

三、Pytorch对大语言模型量化感知训练!

与训练后量化 (PTQ) 相比,PyTorch 中的 QAT 将 Llama3 在 hellaswag 上的准确率下降和 wikitext 上的困惑度下降分别高达 96% 和 68%。

蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键

PyTorch 中将 QAT 流程集成到 torchtune 中,并提供了在分布式环境中运行它的配方,类似于现有的完整微调的指南。用户可以通过运行以下命令在LLM 微调期间应用 QAT。

tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full

模型在单个 A100 GPU 上实现 > 2 倍加速和 ~60% 内存减少,Llama3-8B示例代码如下:

import torch from torchtune.models.llama3import llama3 from torchao.quantization.prototype.qatimportInt8DynActInt4WeightQATQuantizer model = llama3(vocab_size=4096, num_layers=16,                num_heads=16, num_kv_heads=4,                embed_dim=2048, max_seq_len=2048).cuda() qat_quant = Int8DynActInt4WeightQATQuantizer() model = qat_quant.prepare(model).train() #  ––– Kathy-like micro-fine-tune ––– optim = torch.optim.AdamW(model.parameters(), 1e-4) lossf = torch.nn.CrossEntropyLoss() for _ inrange(100):     ids   = torch.randint(0,4096,(2,128)).cuda()     label = torch.randint(0,4096,(2,128)).cuda()     loss  = lossf(model(ids), label)     optim.zero_grad(); loss.backward(); optim.step() model_quant = qat_quant.convert(model) torch.save(model_quant.state_dict(),"llama3_int4int8.pth") 

PyTorch 及其成熟的 QAT 工具链非常方便,可以轻松量化任何模型,但是为了成功部署,在实际的工作中确实需要考虑很多因素,包括目标平台及其支持的算子。

关于作者码科智能

专注于多模态大模型与计算机视觉方向,面向多个AI+场景,分享前沿算法,通用工具,开源项目及场景应用等。最后,关注视觉大模型与多模态大模型的小伙伴们可回复‘加群’进入行业交流群!

蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键

大模型能否代替专用视觉模型?开源视觉能力基准测试

蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键

开放词汇检测范式再升级!IDEA重磅开源指代目标检测模型Rex-Thinker

蒸馏、剪枝、量化三大模型压缩技术!在边缘设备上部署模型的关键

2.4k Star!布局分析、文字识别与关系预测三位一体!开源OCR文档解析新范式

免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/185028.html

(0)
上一篇 2025-08-05 13:26
下一篇 2025-08-05 13:33

相关推荐

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

关注微信