聊聊PEFT微调技术

概述

PEFT(Parameter-Efficient Fine-Tuning,参数高效微调)是一种针对大型预训练模型进行微调的技术,旨在提高模型在新任务上的性能,同时减少微调参数的数量和计算复杂度。从其定义来看,PEFT是一类技术的统称,并不单只某个算法或技术库;类似ANN算法,是一类算法的统称。

PEFT技术通过引入部分参数微调策略,降低大模型微调需要的参数数量,提高训练效率。

如果对LoRA微调技术有了解,上述这句就已经讲述了LoRA技术的思想,冻结了原始模型参数,引入新的数据集参数,对这部分数据参数微调训练。可以说LoRA微调技术是PEFT技术的一种。

PEFT的主要组成及部分原理有如下几点:

  1. 部分参数微调:只对模型的一部分参数进行微调,从而降低计算复杂度和内存需求。这部分参数通常是与新任务密切相关的部分,例如与任务相关的输出层和部分隐藏层参数;其数据来源具有针对性。
  2. 高效微调策略:采用多种高效的微调策略,如LoRA、Prefix、Prompt等微调技术,如LoRA微调就会冻结大部分预训练参数,仅仅更新训练部分参数使得模型在新任务上具有更好的性能。
  3. 深度学习框架与加速技术:PEFT技术中,不同的微调技术用到了不同的深度学习框架,比如Transformer、LSTM等;而为了训练的加速,支持Accelarate加速技术。
  4. 集成多种预训练大模型:PEFT技术支持多种大模型的集成,从而在大模型的基础上做微调。

HuggingFace PEFT Library

HuggingFace PEFT Library是PEFT技术的实现;GITHUB源码地址:PEFT[1]

PEFT库支持的微调方法如下:

  • PROMPT_TUNING
  • MULTITASK_PROMPT_TUNING
  • P_TUNING
  • PREFIX_TUNING
  • LORA
  • ADALORA
  • ADAPTION_PROMPT
  • IA3
  • LOHA
  • LOKR
  • OFT

Llama2基于LoRA微调[2]为例,使用PEFT库加载预训练模型参数微调模型参数:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel,PeftConfig

# 微调参数存放路径
finetune_model_path=''  
config = PeftConfig.from_pretrained(finetune_model_path)

# 加载预训练模型权重参数
# 例如: base_model_name_or_path='meta-llama/Llama-2-7b-chat'
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path,use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
device_map = "cuda:0" if torch.cuda.is_available() else "auto"
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,device_map=device_map,torch_dtype=torch.float16,load_in_8bit=True,trust_remote_code=True,use_flash_attention_2=True)

# 加载微调模型权重参数
# 例如: finetune_model_path='FlagAlpha/Llama2-Chinese-7b-Chat-LoRA'
model = PeftModel.from_pretrained(model, finetune_model_path, device_map={""0})
model =model.eval()

input_ids = tokenizer(['<s>Human: 介绍一下北京n</s><s>Assistant: '], return_tensors="pt",add_special_tokens=False).input_ids
if torch.cuda.is_available():
    input_ids = input_ids.to('cuda')
generate_input = {
    "input_ids":input_ids,
    "max_new_tokens":512,
    "do_sample":True,
    "top_k":50,
    "top_p":0.95,
    "temperature":0.3,
    "repetition_penalty":1.3,
    "eos_token_id":tokenizer.eos_token_id,
    "bos_token_id":tokenizer.bos_token_id,
    "pad_token_id":tokenizer.pad_token_id
}
generate_ids  = model.generate(**generate_input)
text = tokenizer.decode(generate_ids[0])
print(text)

ChatGLM基于LoRA微调[3]为例,使用PEFT库加载预训练模型参数微调模型参数:

from transformers import AutoModel
import torch

from transformers import AutoTokenizer

from peft import PeftModel
import argparse


def generate(instruction, text):
    with torch.no_grad():
        input_text = f"指令:{instruction}n语句:{text}n答:"
        ids = tokenizer.encode(input_text)
        input_ids = torch.LongTensor([ids]).cuda()
        output = peft_model.generate(
            input_ids=input_ids,
            max_length=500,
            do_sample=False,
            temperature=0.0,
            num_return_sequences=1
        )
        output = tokenizer.decode(output[0])
        answer = output.split("答:")[-1]
    return answer.strip()


if __name__ == "__main__":
    # 预训练模型参数路径
    base_model="ZhipuAI/chatglm3-6b"

    # lora微调训练后,模型参数路径
    lora="LuXun-lora"
    instruction="你是一个非常熟悉鲁迅风格的作家,用鲁迅风格的积极正面的语言改写,保持原来的意思:"

    # 加载预训练模型
    model = AutoModel.from_pretrained(base_model, trust_remote_code=True, load_in_8bit=True, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

    if args.lora == "":
        print("#> No lora model specified, using base model.")
        peft_model = model.eval()
    else:
        print("#> Using lora model:", lora)
        # 加载lora微调模型
        peft_model = PeftModel.from_pretrained(model, lora).eval()
    torch.set_default_tensor_type(torch.cuda.FloatTensor)

    texts = [
        "你好",
        "有多少人工,便有多少智能。",
        "落霞与孤鹜齐飞,秋水共长天一色。",
        "我去买几个橘子,你就站在这里,不要走动。",
        "学习计算机技术,是没有办法救中国的。",
        "我怎么样都起不了床,我觉得我可能是得了抑郁症吧。",
        "它是整个系统的支撑架构,连接处理器、内存、存储、显卡和外围端口等所有其他组件。",
        "古巴导弹危机和越南战争是20世纪最大、最致命的两场冲突。古巴导弹危机涉及美国和苏联之间的僵局,因苏联在古巴设立核导弹基地而引发,而越南战争则是北方(由苏联支持)和南方(由美国支持)之间在印度支那持续的军事冲突。",
        "齿槽力矩是指旋转设备受到齿轮牙齿阻力时施加的扭矩。",
        "他的作品包括蒙娜丽莎和最后的晚餐,两者都被认为是杰作。",
        "滑铁卢战役发生在1815年6月18日,是拿破仑战争的最后一场重大战役。"
    ]

    for text in texts:
        print(text)
        print(generate(args.instruction, text), "n")

参考资料
[1]

PEFT: https://github.com/huggingface/peft

[2]

Llama2基于LoRA微调: https://github.com/LlamaFamily/Llama2-Chinese?tab=readme-ov-file#lora%E5%BE%AE%E8%B0%83

[3]

ChatGLM基于LoRA微调: https://blog.csdn.net/saoqi_boy/article/details/135057599


原文始发于微信公众号(阿郎小哥的随笔驿站):聊聊PEFT微调技术

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/244120.html

(0)
小半的头像小半

相关推荐

发表回复

登录后才能评论
极客之音——专业性很强的中文编程技术网站,欢迎收藏到浏览器,订阅我们!