ChatGPT|用ChatGLM-6B实现图文对话

ChatGLM-6B版本是文本交互模型,和GPT3.5一样,不能识别图片,但是是否有方法可以实现呢?

那先看看效果(不过如果你GPT3.5,效果应该会更好,这里为了验证功能,所以使用ChatGLM-6B版本)。

效果

第一个:

ChatGPT|用ChatGLM-6B实现图文对话

第二个:

ChatGPT|用ChatGLM-6B实现图文对话

设计架构

对于图片的元素,一般我们只需要知道场景,标签和文字,就能描述这张图片,说干就干,架构如下:

ChatGPT|用ChatGLM-6B实现图文对话

1、我们判断对话的类型,如果是文本则直接使用ChatGLM的对话功能;
2、如果是图片则执行如下步骤:
    1)识别图片的场景;
    2)识别图片的标签有哪些;
    3)用OCR服务识别图片中的文字;
3、将如上的三个信息汇总,按照Prompt模板将信息给到ChatGLM;
4、拿到返回的结果,返回给对话发起者;

实现代码

1、使用ChatGLM

(1)私有搭建 

具体参考https://github.com/THUDM/ChatGLM-6B,按照步骤搭建即可,然后提供API;

(2)使用API 

如果自己没有GPU资源,可以去这里直接注册,使用智谱提供的API,地址:https://open.bigmodel.ai/

(3)对话代码

...

kUseChatGLM = Flase
class ZhipuAI(BaseChat):
    ability_type = "chatglm_qa_6b"  # 能力类型
    engine_type = "chatglm_6b"  # 引擎类型
    if kUseChatGLM:
        ability_type = "chatGLM"
        engine_type = "chatGLM"
    API_KEY = "xxx"  # 接口API KEY
    PUBLIC_KEY = "xxx"  # 公钥

    def __init__(self, apitype="", dict_args_input={}):
        self.dict_args = {}
        for k, v in dict_args_input:
            self.dict_args[k] = v
        self.system_mess = []
        self.user_mess = []

    @staticmethod
    def getToken():
        token_result = kTokenCache.getValue('token')
        if not token_result:
            token_result = getToken(ZhipuAI.API_KEY, ZhipuAI.PUBLIC_KEY)
            kTokenCache.setValue('token', token_result, 60)
        return token_result

    ...

    def chat(self, mess):
        isok, response = self.openaiRequest(mess)
        if isok:
            return isok, response
        else:
            return isok, "error:" + str(response)

    def openaiRequest(self, mess):
        try:
            token_result = ZhipuAI.getToken()
            uuid1 = uuid.uuid1()
            request_task_no = str(uuid1).replace("-""")
            data = {
                "requestTaskNo": request_task_no,
                "prompt": mess
            }
            if self.user_mess and len(self.user_mess) > 0:
                if kUseChatGLM:
                    historyFormat = []
                    for history in self.user_mess:
                        logging.info("history: " + str(history))
                        try:
                            if len(history["query"]) > 0 and len(history["content"]):
                                historyFormat.append(history["query"])
                                historyFormat.append(history["content"])
                        except Exception as ex:
                            logging.error("err: " + str(ex))
                    data["history"] = historyFormat
                else:
                    data["history"] = self.user_mess
            logging.info("request data: " + str(data))
            if token_result and token_result["code"] == 200:
                token = token_result["data"]
                if kUseChatGLM:
                    resp = executeEngine(ZhipuAI.ability_type,
                                         ZhipuAI.engine_type, token, data)
                else:
                    resp = executeEngineV2(ZhipuAI.ability_type,
                                           ZhipuAI.engine_type, token, data)
                    while resp["code"] == 200 and resp['data']['taskStatus'] == 'PROCESSING':
                        taskOrderNo = resp['data']['taskOrderNo']
                        time.sleep(1)
                        resp = queryTaskResult(token, taskOrderNo)
                outputText = resp["data"]["outputText"]
                if outputText:
                    # keep userid to kMaxUserMessLength
                    if len(outputText) > 0:
                        if len(self.user_mess) > kMaxUserMessLength:
                            self.user_mess = self.user_mess[-kMaxUserMessLength]
                        else:
                            self.user_mess.append(
                                {"query": mess, "content": outputText})
                return True, resp["data"]["outputText"]
        except Exception as ex:
            logging.error("err: " + str(ex))
            return False, str(ex)
...

以上非完整代码,大家可以参考修改,主要功能就是用智谱提供的API,进行文本对话,然后存储历史记录。

2、分析图片

(1)方案选项 

1)图片识别标签,可以自己搭建,如果有兴趣可以参考百度的PaddlePaddle,具体搭建的方式:https://github.com/PaddlePaddle/PaddleClas
2)为了快速验证,这里也可以使用云服务,如阿里云的https://ai.aliyun.com/image,腾讯云的https://console.cloud.tencent.com/tiia/detectlabel
3)OCR的服务也有一些开源的,不过云上使用更方便,可以用腾讯云的https://console.cloud.tencent.com/ocr/overview

这里我就是用腾讯云的服务验证,当然不是商用,可以不需要花钱(有一定的免费额度)。

(2)获取标签

先安装SDK:

python3 -m pip install tencentcloud-sdk-python-tiia
python3 -m pip install tencentcloud-sdk-python

调用腾讯云的API,获取图片标签:

import base64
import json
import logging
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.tiia.v20190529 import tiia_client, models as tiia_models # 图片标签的库
from tencentcloud.ocr.v20181119 import ocr_client, models as ocr_models # OCR库

SecretId = 'xxx' # 腾讯云的SecretId
SecretKey = 'xxx' # 腾讯云的SecretKey
kMaxLabels = 128

def get_images_tags(base64_data):
    try:
        # 实例化一个认证对象,入参需要传入腾讯云账户 SecretId 和 SecretKey,此处还需注意密钥对的保密
        # 代码泄露可能会导致 SecretId 和 SecretKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考,建议采用更安全的方式来使用密钥,请参见:https://cloud.tencent.com/document/product/1278/85305
        # 密钥可前往官网控制台 https://console.cloud.tencent.com/cam/capi 进行获取
        cred = credential.Credential(SecretId, SecretKey)
        # 实例化一个http选项,可选的,没有特殊需求可以跳过
        http_profile = HttpProfile()
        http_profile.endpoint = "tiia.tencentcloudapi.com"
        # 实例化一个client选项,可选的,没有特殊需求可以跳过
        client_profile = ClientProfile()
        client_profile.httpProfile = http_profile
        # 实例化要请求产品的client对象,client_profile是可选的
        client = tiia_client.TiiaClient(cred, "ap-guangzhou", client_profile)
        # 实例化一个请求对象,每个接口都会对应一个request对象
        req = tiia_models.DetectLabelRequest()
        params = {'ImageBase64': base64_data}
        req.from_json_string(json.dumps(params))
        # 返回的resp是一个DetectLabelResponse的实例,与请求对象对应
        resp = client.DetectLabel(req)
        logging.info("get_images_tags: " + resp.to_json_string())
        return resp  # 输出json格式的字符串回包
    except TencentCloudSDKException as err:
        logging.error("get_images_tags err: " + str(err))
        return None

获取结果样例:

{"Response":{"Labels":[{"Name":"字体","Confidence":92,"FirstCategory":"其他","SecondCategory":"其他"},{"Name":"文本","Confidence":85,"FirstCategory":"卡证文档","SecondCategory":"其他"},{"Name":"品牌","Confidence":68,"FirstCategory":"物品","SecondCategory":"标牌标识"},{"Name":"线","Confidence":50,"FirstCategory":"物品","SecondCategory":"日常用品"},{"Name":"报告","Confidence":27,"FirstCategory":"物品","SecondCategory":"其他"}],"CameraLabels":null,"AlbumLabels":null,"NewsLabels":null,"RequestId":"e710697f-3054-494a-bf33-1990dbe25bbd"}}

(3)获取文字

先安装SDK:

python3 -m pip install tencentcloud-sdk-python-tiia
python3 -m pip install tencentcloud-sdk-python

调用腾讯云的API,获取图片标签:

import base64
import json
import logging
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.tiia.v20190529 import tiia_client, models as tiia_models # 图片标签的库
from tencentcloud.ocr.v20181119 import ocr_client, models as ocr_models # OCR库

SecretId = 'xxx' # 腾讯云的SecretId
SecretKey = 'xxx' # 腾讯云的SecretKey
kMaxLabels = 128

def get_images_ocr(base64_data):
    try:
        cred = credential.Credential(SecretId, SecretKey)
        # 实例化一个http选项,可选的,没有特殊需求可以跳过
        http_profile = HttpProfile()
        http_profile.endpoint = "ocr.tencentcloudapi.com"
        # 实例化一个client选项,可选的,没有特殊需求可以跳过
        client_profile = ClientProfile()
        client_profile.httpProfile = http_profile
        # 实例化要请求产品的client对象,client_profile是可选的
        client = ocr_client.OcrClient(cred, "ap-guangzhou", client_profile)
        # 实例化一个请求对象,每个接口都会对应一个request对象
        req = ocr_models.RecognizeTableAccurateOCRRequest()
        params = {'ImageBase64': base64_data}
        req.from_json_string(json.dumps(params))
        # 返回的resp是一个DetectLabelResponse的实例,与请求对象对应
        resp = client.RecognizeTableAccurateOCR(req)
        # 输出json格式的字符串回包
        logging.info("get_images_ocr: " + resp.to_json_string())
        return resp
    except TencentCloudSDKException as err:
        logging.error("get_images_ocr err: " + str(err))
        return None

获取结果样例:

{"TableDetections": [{"Cells": [{"ColTl"0"RowTl"0"ColBr"1"RowBr"1"Text""周末程序猿""Type""header""Confidence"100"Polygon": [{"X"1407"Y"588}, {"X"2176"Y"584}, {"X"2177"Y"748}, {"X"1408"Y"752}]}], "Type"0"TableCoordPoint": [{"X"1407"Y"584}, {"X"2177"Y"584}, {"X"2177"Y"752}, {"X"1407"Y"752}]}] ...

3、组装Prompt

可以按照Prompt模板组装数据,如:

#####
图片有以下标签:{拿到的图片标签列表}
#####
图片中有文字,请你理解以下文字:[{拿到的文字标签列表}]...
现在请综合以上信息(标签、文字描述等),自然并详细地描述这副图片。请你不要在回答中暴露上述信息来源是图片分析服务。

注意:如果文字过长,可以用省略号...,具体长度可以设置256-1024之间。

具体代码:

def get_images_prompt(base64_data):
    try:
        # "#####n经过某个图片分析服务,得出以下关于这幅图片的信息:n"
        prompt = ""
        resp = get_images_tags(base64_data)
        labels_str = []
        if resp:
            for labels in resp.Labels:
                labels_str.append(labels.Name)
        logging.info(labels_str)
        if len(labels_str) > 0:
            prompt += f"#####n图片有以下标签:{','.join(labels_str)}n"
        logging.info("labels_str prompt: " + str(prompt))
        resp = get_images_ocr(base64_data)
        labels_ocr_str = []
        if resp:
            for cells in resp.TableDetections:
                for text in cells.Cells:
                    if len(text.Text) > 0:
                        labels_ocr_str.append(text.Text)
        logging.info(str(labels_ocr_str))
        if len(labels_ocr_str) > 0:
            labels = ','.join(labels_ocr_str).replace("n""")
            if len(labels) > kMaxLabels:
                labels = labels[:kMaxLabels] + "..."
            prompt += f"#####n图片中有文字,请你理解以下文字:["+labels+"]n"
        else:
            prompt += f"#####n图片中没有文字n"
        prompt += "现在请综合以上信息(标签、文字描述等),自然并详细地描述这副图片。请你不要在回答中暴露上述信息来源是图片分析服务。"
        logging.info("labels_ocr_str prompt: " + str(prompt))
        return prompt
    except Exception as err:
        logging.err("get_images_ocr err: " + str(err))
        return None

3、将组装的Prompt发给ChatGLM

省略了一些工程代码(如果需要整体代码可以留言给我),测试代码就是这样:

@staticmethod
def test(content, userid=""):
    """
    测试函数
    :return:
    "
""
    logging.info("content->"+str(content))
    chat = SimpleChatFactory.getInstance(kChatModel, userid)
    image_pos = content.find("data:image/")
    if image_pos >= 0:
        images_base64 = re.sub(
            '<img src="data:image/(.*);base64,''', content)
        images_base64 = re.sub('">''', images_base64)
        logging.info("images_base64->"+str(images_base64))
        content = get_images_prompt(images_base64)
        chat.withOption(user_mess=False)
    # 如果是图像需要预处理
    if len(content) > kMaxContentLength:
        return kERRTimeout
    isok, replaycontent = chat.chat(content)
    if isok:
        return replaycontent
    else:
        return kERRTimeout

以上就是基于云服务+ChatGLM实现图片理解,如果你想自己搭建一个不依赖云服务的,其实也可以,上文给出了一些开源的方案,按照同样的方式替换服务,将生成的文本描述给到ChatGLM或者其他的GPT文本对话模型,实现扩展为多模态。

demo页面:

https://service-mpjvpuxa-1251014631.gz.apigw.tencentcs.com/static/chat.html


原文始发于微信公众号(周末程序猿):ChatGPT|用ChatGLM-6B实现图文对话

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/169413.html

(0)
小半的头像小半

相关推荐

发表回复

登录后才能评论
极客之音——专业性很强的中文编程技术网站,欢迎收藏到浏览器,订阅我们!