Skip to content

Commit 88d483c

Browse files
authored
Merge pull request #4 from toddyoe/main
chore: add system instruction to enhance compliance with function call
2 parents a592269 + 8d48db0 commit 88d483c

File tree

2 files changed

+35
-16
lines changed

2 files changed

+35
-16
lines changed

app/services/chat/message_converter.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
# app/services/chat/message_converter.py
22

33
from abc import ABC, abstractmethod
4-
from typing import List, Dict, Any
4+
from typing import Any, Dict, List, Optional
5+
6+
SUPPORTED_ROLES = ["user", "model", "system"]
57

68

79
class MessageConverter(ABC):
810
"""消息转换器基类"""
911

1012
@abstractmethod
11-
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
13+
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
1214
pass
1315

1416

@@ -30,16 +32,19 @@ def _convert_image(image_url: str) -> Dict[str, Any]:
3032
class OpenAIMessageConverter(MessageConverter):
3133
"""OpenAI消息格式转换器"""
3234

33-
def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
35+
def convert(self, messages: List[Dict[str, Any]]) -> tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
3436
converted_messages = []
37+
system_instruction = None
38+
3539
for msg in messages:
36-
role = "user" if msg["role"] == "user" else "model"
37-
parts = []
40+
role = msg.get("role", "")
41+
if role not in SUPPORTED_ROLES:
42+
role = "model"
3843

39-
if isinstance(msg["content"], str):
40-
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空
41-
if msg["content"]:
42-
parts.append({"text": msg["content"]})
44+
parts = []
45+
if isinstance(msg["content"], str) and msg["content"]:
46+
# 请求 gemini 接口时如果包含 content 字段但内容为空时会返回 400 错误,所以需要判断是否为空并移除
47+
parts.append({"text": msg["content"]})
4348
elif isinstance(msg["content"], list):
4449
for content in msg["content"]:
4550
if isinstance(content, str) and content:
@@ -50,6 +55,10 @@ def convert(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
5055
elif content["type"] == "image_url":
5156
parts.append(_convert_image(content["image_url"]["url"]))
5257

53-
converted_messages.append({"role": role, "parts": parts})
58+
if parts:
59+
if role == "system":
60+
system_instruction = {"role": "system", "parts": parts}
61+
else:
62+
converted_messages.append({"role": role, "parts": parts})
5463

55-
return converted_messages
64+
return converted_messages, system_instruction

app/services/openai_chat_service.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from copy import deepcopy
44
import json
5-
from typing import Dict, Any, AsyncGenerator, List, Union
5+
from typing import Dict, Any, AsyncGenerator, List, Optional, Union
66
from app.core.logger import get_openai_logger
77
from app.services.chat.message_converter import OpenAIMessageConverter
88
from app.services.chat.response_handler import OpenAIResponseHandler
@@ -87,10 +87,10 @@ def _get_safety_settings(model: str) -> List[Dict[str, str]]:
8787

8888

8989
def _build_payload(
90-
request: ChatRequest, messages: List[Dict[str, Any]]
90+
request: ChatRequest, messages: List[Dict[str, Any]], instruction: Optional[Dict[str, Any]] = None
9191
) -> Dict[str, Any]:
9292
"""构建请求payload"""
93-
return {
93+
payload = {
9494
"contents": messages,
9595
"generationConfig": {
9696
"temperature": request.temperature,
@@ -103,6 +103,16 @@ def _build_payload(
103103
"safetySettings": _get_safety_settings(request.model),
104104
}
105105

106+
if (
107+
instruction
108+
and isinstance(instruction, dict)
109+
and instruction.get("role") == "system"
110+
and instruction.get("parts")
111+
):
112+
payload["systemInstruction"] = instruction
113+
114+
return payload
115+
106116

107117
class OpenAIChatService:
108118
"""聊天服务"""
@@ -120,10 +130,10 @@ async def create_chat_completion(
120130
) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
121131
"""创建聊天完成"""
122132
# 转换消息格式
123-
messages = self.message_converter.convert(request.messages)
133+
messages, instruction = self.message_converter.convert(request.messages)
124134

125135
# 构建请求payload
126-
payload = _build_payload(request, messages)
136+
payload = _build_payload(request, messages, instruction)
127137

128138
if request.stream:
129139
return self._handle_stream_completion(request.model, payload, api_key)

0 commit comments

Comments
 (0)