ChatGLM-6B版本是文本交互模型,和GPT3.5一样,不能识别图片,但是是否有方法可以实现呢?
那先看看效果(不过如果你GPT3.5,效果应该会更好,这里为了验证功能,所以使用ChatGLM-6B版本)。
效果
第一个:
第二个:
设计架构
对于图片的元素,一般我们只需要知道场景,标签和文字,就能描述这张图片,说干就干,架构如下:
1、我们判断对话的类型,如果是文本则直接使用ChatGLM的对话功能;
2、如果是图片则执行如下步骤:
1)识别图片的场景;
2)识别图片的标签有哪些;
3)用OCR服务识别图片中的文字;
3、将如上的三个信息汇总,按照Prompt模板将信息给到ChatGLM;
4、拿到返回的结果,返回给对话发起者;
实现代码
1、使用ChatGLM
(1)私有搭建
具体参考https://github.com/THUDM/ChatGLM-6B
,按照步骤搭建即可,然后提供API;
(2)使用API
如果自己没有GPU资源,可以去这里直接注册,使用智谱提供的API,地址:https://open.bigmodel.ai/
(3)对话代码
...
kUseChatGLM = Flase
class ZhipuAI(BaseChat):
ability_type = "chatglm_qa_6b" # 能力类型
engine_type = "chatglm_6b" # 引擎类型
if kUseChatGLM:
ability_type = "chatGLM"
engine_type = "chatGLM"
API_KEY = "xxx" # 接口API KEY
PUBLIC_KEY = "xxx" # 公钥
def __init__(self, apitype="", dict_args_input={}):
self.dict_args = {}
for k, v in dict_args_input:
self.dict_args[k] = v
self.system_mess = []
self.user_mess = []
@staticmethod
def getToken():
token_result = kTokenCache.getValue('token')
if not token_result:
token_result = getToken(ZhipuAI.API_KEY, ZhipuAI.PUBLIC_KEY)
kTokenCache.setValue('token', token_result, 60)
return token_result
...
def chat(self, mess):
isok, response = self.openaiRequest(mess)
if isok:
return isok, response
else:
return isok, "error:" + str(response)
def openaiRequest(self, mess):
try:
token_result = ZhipuAI.getToken()
uuid1 = uuid.uuid1()
request_task_no = str(uuid1).replace("-", "")
data = {
"requestTaskNo": request_task_no,
"prompt": mess
}
if self.user_mess and len(self.user_mess) > 0:
if kUseChatGLM:
historyFormat = []
for history in self.user_mess:
logging.info("history: " + str(history))
try:
if len(history["query"]) > 0 and len(history["content"]):
historyFormat.append(history["query"])
historyFormat.append(history["content"])
except Exception as ex:
logging.error("err: " + str(ex))
data["history"] = historyFormat
else:
data["history"] = self.user_mess
logging.info("request data: " + str(data))
if token_result and token_result["code"] == 200:
token = token_result["data"]
if kUseChatGLM:
resp = executeEngine(ZhipuAI.ability_type,
ZhipuAI.engine_type, token, data)
else:
resp = executeEngineV2(ZhipuAI.ability_type,
ZhipuAI.engine_type, token, data)
while resp["code"] == 200 and resp['data']['taskStatus'] == 'PROCESSING':
taskOrderNo = resp['data']['taskOrderNo']
time.sleep(1)
resp = queryTaskResult(token, taskOrderNo)
outputText = resp["data"]["outputText"]
if outputText:
# keep userid to kMaxUserMessLength
if len(outputText) > 0:
if len(self.user_mess) > kMaxUserMessLength:
self.user_mess = self.user_mess[-kMaxUserMessLength]
else:
self.user_mess.append(
{"query": mess, "content": outputText})
return True, resp["data"]["outputText"]
except Exception as ex:
logging.error("err: " + str(ex))
return False, str(ex)
...
以上非完整代码,大家可以参考修改,主要功能就是用智谱提供的API,进行文本对话,然后存储历史记录。
2、分析图片
(1)方案选项
1)图片识别标签,可以自己搭建,如果有兴趣可以参考百度的PaddlePaddle
,具体搭建的方式:https://github.com/PaddlePaddle/PaddleClas
;
2)为了快速验证,这里也可以使用云服务,如阿里云的https://ai.aliyun.com/image
,腾讯云的https://console.cloud.tencent.com/tiia/detectlabel
;
3)OCR的服务也有一些开源的,不过云上使用更方便,可以用腾讯云的https://console.cloud.tencent.com/ocr/overview
;
这里我就是用腾讯云的服务验证,当然不是商用,可以不需要花钱(有一定的免费额度)。
(2)获取标签
先安装SDK:
python3 -m pip install tencentcloud-sdk-python-tiia
python3 -m pip install tencentcloud-sdk-python
调用腾讯云的API,获取图片标签:
import base64
import json
import logging
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.tiia.v20190529 import tiia_client, models as tiia_models # 图片标签的库
from tencentcloud.ocr.v20181119 import ocr_client, models as ocr_models # OCR库
SecretId = 'xxx' # 腾讯云的SecretId
SecretKey = 'xxx' # 腾讯云的SecretKey
kMaxLabels = 128
def get_images_tags(base64_data):
try:
# 实例化一个认证对象,入参需要传入腾讯云账户 SecretId 和 SecretKey,此处还需注意密钥对的保密
# 代码泄露可能会导致 SecretId 和 SecretKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考,建议采用更安全的方式来使用密钥,请参见:https://cloud.tencent.com/document/product/1278/85305
# 密钥可前往官网控制台 https://console.cloud.tencent.com/cam/capi 进行获取
cred = credential.Credential(SecretId, SecretKey)
# 实例化一个http选项,可选的,没有特殊需求可以跳过
http_profile = HttpProfile()
http_profile.endpoint = "tiia.tencentcloudapi.com"
# 实例化一个client选项,可选的,没有特殊需求可以跳过
client_profile = ClientProfile()
client_profile.httpProfile = http_profile
# 实例化要请求产品的client对象,client_profile是可选的
client = tiia_client.TiiaClient(cred, "ap-guangzhou", client_profile)
# 实例化一个请求对象,每个接口都会对应一个request对象
req = tiia_models.DetectLabelRequest()
params = {'ImageBase64': base64_data}
req.from_json_string(json.dumps(params))
# 返回的resp是一个DetectLabelResponse的实例,与请求对象对应
resp = client.DetectLabel(req)
logging.info("get_images_tags: " + resp.to_json_string())
return resp # 输出json格式的字符串回包
except TencentCloudSDKException as err:
logging.error("get_images_tags err: " + str(err))
return None
获取结果样例:
{"Response":{"Labels":[{"Name":"字体","Confidence":92,"FirstCategory":"其他","SecondCategory":"其他"},{"Name":"文本","Confidence":85,"FirstCategory":"卡证文档","SecondCategory":"其他"},{"Name":"品牌","Confidence":68,"FirstCategory":"物品","SecondCategory":"标牌标识"},{"Name":"线","Confidence":50,"FirstCategory":"物品","SecondCategory":"日常用品"},{"Name":"报告","Confidence":27,"FirstCategory":"物品","SecondCategory":"其他"}],"CameraLabels":null,"AlbumLabels":null,"NewsLabels":null,"RequestId":"e710697f-3054-494a-bf33-1990dbe25bbd"}}
(3)获取文字
先安装SDK:
python3 -m pip install tencentcloud-sdk-python-tiia
python3 -m pip install tencentcloud-sdk-python
调用腾讯云的API,获取图片标签:
import base64
import json
import logging
from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.tiia.v20190529 import tiia_client, models as tiia_models # 图片标签的库
from tencentcloud.ocr.v20181119 import ocr_client, models as ocr_models # OCR库
SecretId = 'xxx' # 腾讯云的SecretId
SecretKey = 'xxx' # 腾讯云的SecretKey
kMaxLabels = 128
def get_images_ocr(base64_data):
try:
cred = credential.Credential(SecretId, SecretKey)
# 实例化一个http选项,可选的,没有特殊需求可以跳过
http_profile = HttpProfile()
http_profile.endpoint = "ocr.tencentcloudapi.com"
# 实例化一个client选项,可选的,没有特殊需求可以跳过
client_profile = ClientProfile()
client_profile.httpProfile = http_profile
# 实例化要请求产品的client对象,client_profile是可选的
client = ocr_client.OcrClient(cred, "ap-guangzhou", client_profile)
# 实例化一个请求对象,每个接口都会对应一个request对象
req = ocr_models.RecognizeTableAccurateOCRRequest()
params = {'ImageBase64': base64_data}
req.from_json_string(json.dumps(params))
# 返回的resp是一个DetectLabelResponse的实例,与请求对象对应
resp = client.RecognizeTableAccurateOCR(req)
# 输出json格式的字符串回包
logging.info("get_images_ocr: " + resp.to_json_string())
return resp
except TencentCloudSDKException as err:
logging.error("get_images_ocr err: " + str(err))
return None
获取结果样例:
{"TableDetections": [{"Cells": [{"ColTl": 0, "RowTl": 0, "ColBr": 1, "RowBr": 1, "Text": "周末程序猿", "Type": "header", "Confidence": 100, "Polygon": [{"X": 1407, "Y": 588}, {"X": 2176, "Y": 584}, {"X": 2177, "Y": 748}, {"X": 1408, "Y": 752}]}], "Type": 0, "TableCoordPoint": [{"X": 1407, "Y": 584}, {"X": 2177, "Y": 584}, {"X": 2177, "Y": 752}, {"X": 1407, "Y": 752}]}] ...
3、组装Prompt
可以按照Prompt模板组装数据,如:
#####
图片有以下标签:{拿到的图片标签列表}
#####
图片中有文字,请你理解以下文字:[{拿到的文字标签列表}]...
现在请综合以上信息(标签、文字描述等),自然并详细地描述这副图片。请你不要在回答中暴露上述信息来源是图片分析服务。
注意:如果文字过长,可以用省略号...
,具体长度可以设置256-1024之间。
具体代码:
def get_images_prompt(base64_data):
try:
# "#####n经过某个图片分析服务,得出以下关于这幅图片的信息:n"
prompt = ""
resp = get_images_tags(base64_data)
labels_str = []
if resp:
for labels in resp.Labels:
labels_str.append(labels.Name)
logging.info(labels_str)
if len(labels_str) > 0:
prompt += f"#####n图片有以下标签:{','.join(labels_str)}n"
logging.info("labels_str prompt: " + str(prompt))
resp = get_images_ocr(base64_data)
labels_ocr_str = []
if resp:
for cells in resp.TableDetections:
for text in cells.Cells:
if len(text.Text) > 0:
labels_ocr_str.append(text.Text)
logging.info(str(labels_ocr_str))
if len(labels_ocr_str) > 0:
labels = ','.join(labels_ocr_str).replace("n", "")
if len(labels) > kMaxLabels:
labels = labels[:kMaxLabels] + "..."
prompt += f"#####n图片中有文字,请你理解以下文字:["+labels+"]n"
else:
prompt += f"#####n图片中没有文字n"
prompt += "现在请综合以上信息(标签、文字描述等),自然并详细地描述这副图片。请你不要在回答中暴露上述信息来源是图片分析服务。"
logging.info("labels_ocr_str prompt: " + str(prompt))
return prompt
except Exception as err:
logging.err("get_images_ocr err: " + str(err))
return None
3、将组装的Prompt发给ChatGLM
省略了一些工程代码(如果需要整体代码可以留言给我),测试代码就是这样:
@staticmethod
def test(content, userid=""):
"""
测试函数
:return:
"""
logging.info("content->"+str(content))
chat = SimpleChatFactory.getInstance(kChatModel, userid)
image_pos = content.find("data:image/")
if image_pos >= 0:
images_base64 = re.sub(
'<img src="data:image/(.*);base64,', '', content)
images_base64 = re.sub('">', '', images_base64)
logging.info("images_base64->"+str(images_base64))
content = get_images_prompt(images_base64)
chat.withOption(user_mess=False)
# 如果是图像需要预处理
if len(content) > kMaxContentLength:
return kERRTimeout
isok, replaycontent = chat.chat(content)
if isok:
return replaycontent
else:
return kERRTimeout
以上就是基于云服务+ChatGLM实现图片理解,如果你想自己搭建一个不依赖云服务的,其实也可以,上文给出了一些开源的方案,按照同样的方式替换服务,将生成的文本描述给到ChatGLM或者其他的GPT文本对话模型,实现扩展为多模态。
demo页面:
https://service-mpjvpuxa-1251014631.gz.apigw.tencentcs.com/static/chat.html
原文始发于微信公众号(周末程序猿):ChatGPT|用ChatGLM-6B实现图文对话
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/169413.html