概述
PEFT(Parameter-Efficient Fine-Tuning,参数高效微调)是一种针对大型预训练模型进行微调的技术,旨在提高模型在新任务上的性能,同时减少微调参数的数量和计算复杂度。从其定义来看,PEFT是一类技术的统称,并不单只某个算法或技术库;类似ANN算法,是一类算法的统称。
PEFT技术通过引入部分参数微调策略,降低大模型微调需要的参数数量,提高训练效率。
如果对LoRA微调技术有了解,上述这句就已经讲述了LoRA技术的思想,冻结了原始模型参数,引入新的数据集参数,对这部分数据参数微调训练。可以说LoRA微调技术是PEFT技术的一种。
PEFT的主要组成及部分原理有如下几点:
-
部分参数微调:只对模型的一部分参数进行微调,从而降低计算复杂度和内存需求。这部分参数通常是与新任务密切相关的部分,例如与任务相关的输出层和部分隐藏层参数;其数据来源具有针对性。 -
高效微调策略:采用多种高效的微调策略,如LoRA、Prefix、Prompt等微调技术,如LoRA微调就会冻结大部分预训练参数,仅仅更新训练部分参数使得模型在新任务上具有更好的性能。 -
深度学习框架与加速技术:PEFT技术中,不同的微调技术用到了不同的深度学习框架,比如Transformer、LSTM等;而为了训练的加速,支持Accelarate加速技术。 -
集成多种预训练大模型:PEFT技术支持多种大模型的集成,从而在大模型的基础上做微调。
HuggingFace PEFT Library
HuggingFace PEFT Library是PEFT技术的实现;GITHUB源码地址:PEFT[1]。
PEFT库支持的微调方法如下:
-
PROMPT_TUNING -
MULTITASK_PROMPT_TUNING -
P_TUNING -
PREFIX_TUNING -
LORA -
ADALORA -
ADAPTION_PROMPT -
IA3 -
LOHA -
LOKR -
OFT
以Llama2基于LoRA微调[2]为例,使用PEFT库加载预训练模型参数和微调模型参数:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel,PeftConfig
# 微调参数存放路径
finetune_model_path=''
config = PeftConfig.from_pretrained(finetune_model_path)
# 加载预训练模型权重参数
# 例如: base_model_name_or_path='meta-llama/Llama-2-7b-chat'
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path,use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
device_map = "cuda:0" if torch.cuda.is_available() else "auto"
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,device_map=device_map,torch_dtype=torch.float16,load_in_8bit=True,trust_remote_code=True,use_flash_attention_2=True)
# 加载微调模型权重参数
# 例如: finetune_model_path='FlagAlpha/Llama2-Chinese-7b-Chat-LoRA'
model = PeftModel.from_pretrained(model, finetune_model_path, device_map={"": 0})
model =model.eval()
input_ids = tokenizer(['<s>Human: 介绍一下北京n</s><s>Assistant: '], return_tensors="pt",add_special_tokens=False).input_ids
if torch.cuda.is_available():
input_ids = input_ids.to('cuda')
generate_input = {
"input_ids":input_ids,
"max_new_tokens":512,
"do_sample":True,
"top_k":50,
"top_p":0.95,
"temperature":0.3,
"repetition_penalty":1.3,
"eos_token_id":tokenizer.eos_token_id,
"bos_token_id":tokenizer.bos_token_id,
"pad_token_id":tokenizer.pad_token_id
}
generate_ids = model.generate(**generate_input)
text = tokenizer.decode(generate_ids[0])
print(text)
以ChatGLM基于LoRA微调[3]为例,使用PEFT库加载预训练模型参数和微调模型参数:
from transformers import AutoModel
import torch
from transformers import AutoTokenizer
from peft import PeftModel
import argparse
def generate(instruction, text):
with torch.no_grad():
input_text = f"指令:{instruction}n语句:{text}n答:"
ids = tokenizer.encode(input_text)
input_ids = torch.LongTensor([ids]).cuda()
output = peft_model.generate(
input_ids=input_ids,
max_length=500,
do_sample=False,
temperature=0.0,
num_return_sequences=1
)
output = tokenizer.decode(output[0])
answer = output.split("答:")[-1]
return answer.strip()
if __name__ == "__main__":
# 预训练模型参数路径
base_model="ZhipuAI/chatglm3-6b"
# lora微调训练后,模型参数路径
lora="LuXun-lora"
instruction="你是一个非常熟悉鲁迅风格的作家,用鲁迅风格的积极正面的语言改写,保持原来的意思:"
# 加载预训练模型
model = AutoModel.from_pretrained(base_model, trust_remote_code=True, load_in_8bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
if args.lora == "":
print("#> No lora model specified, using base model.")
peft_model = model.eval()
else:
print("#> Using lora model:", lora)
# 加载lora微调模型
peft_model = PeftModel.from_pretrained(model, lora).eval()
torch.set_default_tensor_type(torch.cuda.FloatTensor)
texts = [
"你好",
"有多少人工,便有多少智能。",
"落霞与孤鹜齐飞,秋水共长天一色。",
"我去买几个橘子,你就站在这里,不要走动。",
"学习计算机技术,是没有办法救中国的。",
"我怎么样都起不了床,我觉得我可能是得了抑郁症吧。",
"它是整个系统的支撑架构,连接处理器、内存、存储、显卡和外围端口等所有其他组件。",
"古巴导弹危机和越南战争是20世纪最大、最致命的两场冲突。古巴导弹危机涉及美国和苏联之间的僵局,因苏联在古巴设立核导弹基地而引发,而越南战争则是北方(由苏联支持)和南方(由美国支持)之间在印度支那持续的军事冲突。",
"齿槽力矩是指旋转设备受到齿轮牙齿阻力时施加的扭矩。",
"他的作品包括蒙娜丽莎和最后的晚餐,两者都被认为是杰作。",
"滑铁卢战役发生在1815年6月18日,是拿破仑战争的最后一场重大战役。"
]
for text in texts:
print(text)
print(generate(args.instruction, text), "n")
PEFT: https://github.com/huggingface/peft
[2]Llama2基于LoRA微调: https://github.com/LlamaFamily/Llama2-Chinese?tab=readme-ov-file#lora%E5%BE%AE%E8%B0%83
[3]ChatGLM基于LoRA微调: https://blog.csdn.net/saoqi_boy/article/details/135057599
原文始发于微信公众号(阿郎小哥的随笔驿站):聊聊PEFT微调技术
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/244120.html