From b4b83d113de8a8bfd49fd28e4073d88873b786ef Mon Sep 17 00:00:00 2001 From: Song367 <601337784@qq.com> Date: Tue, 17 Jun 2025 20:40:48 +0800 Subject: [PATCH] initial commit --- .env | 11 + API_DOCUMENTATION.md | 222 ++++++++++++ file_server.py | 142 ++++++++ go.mod | 36 ++ go.sum | 90 +++++ handler/llm_handler.go | 195 +++++++++++ handler/token_handler.go | 34 ++ main.go | 92 +++++ service/llm_service.go | 725 +++++++++++++++++++++++++++++++++++++++ service/token_service.go | 53 +++ static/index.html | 190 ++++++++++ 11 files changed, 1790 insertions(+) create mode 100644 .env create mode 100644 API_DOCUMENTATION.md create mode 100644 file_server.py create mode 100644 go.mod create mode 100644 go.sum create mode 100644 handler/llm_handler.go create mode 100644 handler/token_handler.go create mode 100644 main.go create mode 100644 service/llm_service.go create mode 100644 service/token_service.go create mode 100644 static/index.html diff --git a/.env b/.env new file mode 100644 index 0000000..0595861 --- /dev/null +++ b/.env @@ -0,0 +1,11 @@ +# LLM API Configuration +LLM_API_URL=http://tianchat.zenithsafe.com:5001/v1 +LLM_API_KEY=app-k9WhnUvAPCVcSoPDEYVUxXgC +MiniMaxApiKey=eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJHcm91cE5hbWUiOiLkuIrmtbfpopzpgJTnp5HmioDmnInpmZDlhazlj7giLCJVc2VyTmFtZSI6IuadqOmqpSIsIkFjY291bnQiOiIiLCJTdWJqZWN0SUQiOiIxNzI4NzEyMzI0OTc5NjI2ODM5IiwiUGhvbmUiOiIxMzM4MTU1OTYxOCIsIkdyb3VwSUQiOiIxNzI4NzEyMzI0OTcxMjM4MjMxIiwiUGFnZU5hbWUiOiIiLCJNYWlsIjoiIiwiQ3JlYXRlVGltZSI6IjIwMjUtMDYtMTYgMTY6Mjk6NTkiLCJUb2tlblR5cGUiOjEsImlzcyI6Im1pbmltYXgifQ.D_JF0-nO89NdMZCYq4ocEyqxtZ9SeEdtMvbeSkZTWspt0XfX2QpPAVh-DI3MCPZTeSmjNWLf4fA_Th2zpVrj4UxWMbGKBeLZWLulNpwAHGMUTdqenuih3daCDPCzs0duhlFyQnZgGcEOGQ476HL72N2klujP8BUy_vfAh_Zv0po-aujQa5RxardDSOsbs49NTPEw0SQEXwaJ5bVmiZ5s-ysJ9pZWSEiyJ6SX9z3JeZHKj9DxHdOw5roZR8izo54e4IoqyLlzEfhOMW7P15-ffDH3M6HGiEmeBaGRYGAIciELjZS19ONNMKsTj-wXNGWtKG-sjAB1uuqkkT5Ul9Dunw +MiniMaxApiURL=https://api.minimaxi.com/v1/t2a_v2 +APP_ID=1364994890450210816 +APP_KEY=b4839cb2-cb81-4472-a2c1-2abf31e4bb27 +SIG_EXP=3600 +FILE_URL=http://localhost:8000/ +# Server Configuration +PORT=8080 \ No newline at end of file diff --git a/API_DOCUMENTATION.md b/API_DOCUMENTATION.md new file mode 100644 index 0000000..a66203d --- /dev/null +++ b/API_DOCUMENTATION.md @@ -0,0 +1,222 @@ +# 流式聊天 API 文档 + +## 概述 +该 API 提供实时流式聊天功能,支持文本对话和语音合成。API 使用 Server-Sent Events (SSE) 实现流式响应,确保实时性和高效性。 + +## 基础信息 +- 基础URL: `http://your-domain:8080` +- 内容类型: `application/json` +- 响应类型: `text/event-stream` (流式响应) + +## API 端点 + +### 1. 聊天接口 +``` +POST /chat +``` + +#### 请求参数 +| 参数名 | 类型 | 必填 | 描述 | +|--------|------|------|------| +| query | string | 是 | 用户输入的查询文本 | +| response_mode | string | 是 | 响应模式,使用 "streaming" 启用流式响应 | +| user | string | 是 | 用户标识符 | +| conversation_id | string | 否 | 会话ID,首次对话可不传 | + +#### 请求示例 +```json +{ + "query": "你好,请介绍一下自己", + "response_mode": "streaming", + "user": "user123", + "conversation_id": "" +} +``` + +#### 响应格式 +响应使用 Server-Sent Events (SSE) 格式,每个事件包含以下字段: + +| 字段名 | 类型 | 描述 | +|--------|------|------| +| answer | string | 机器人的文本回复 | +| isEnd | boolean | 是否为最后一条消息 | +| conversation_id | string | 会话ID | +| task_id | string | 任务ID | +| audio_data | string | 语音数据(URL或十六进制编码) | + +#### 响应示例 +``` +data: {"answer":"你好!","isEnd":false,"conversation_id":"conv_123","task_id":"task_456","audio_data":"http://example.com/audio.mp3"} +data: {"answer":"我是AI助手","isEnd":false,"conversation_id":"conv_123","task_id":"task_456","audio_data":"http://example.com/audio2.mp3"} +data: {"answer":"","isEnd":true,"conversation_id":"conv_123","task_id":"task_456"} +``` + +### 2. 停止对话 +``` +POST /chat-messages/:task_id/stop +``` + +#### 路径参数 +| 参数名 | 类型 | 描述 | +|--------|------|------| +| task_id | string | 要停止的任务ID | + +#### 响应示例 +```json +{ + "status": "success", + "message": "Conversation stopped" +} +``` + +### 3. 删除对话 +``` +DELETE /conversations/:conversation_id +``` + +#### 路径参数 +| 参数名 | 类型 | 描述 | +|--------|------|------| +| conversation_id | string | 要删除的会话ID | + +#### 查询参数 +| 参数名 | 类型 | 必填 | 描述 | +|--------|------|------|------| +| user | string | 是 | 用户标识符 | + +#### 响应示例 +```json +{ + "status": "success", + "message": "Conversation deleted" +} +``` + +## 前端集成示例 + +### 1. 基本使用 +```javascript +async function sendMessage(message) { + const response = await fetch('/chat', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + query: message, + response_mode: 'streaming', + user: 'user123', + conversation_id: currentConversationId + }) + }); + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + + while (true) { + const { value, done } = await reader.read(); + if (done) break; + + const chunk = decoder.decode(value); + const lines = chunk.split('\n'); + + for (const line of lines) { + if (line.startsWith('data: ')) { + const data = JSON.parse(line.slice(6)); + // 处理响应数据 + handleResponse(data); + } + } + } +} +``` + +### 2. 处理响应数据 +```javascript +function handleResponse(data) { + // 更新会话ID + if (data.conversation_id) { + currentConversationId = data.conversation_id; + } + + // 处理文本回复 + if (data.answer) { + updateChatMessage(data.answer); + } + + // 处理语音数据 + if (data.audio_data) { + playAudio(data.audio_data); + } +} +``` + +### 3. 播放音频 +```javascript +function playAudio(audioData) { + // URL格式的音频 + if (audioData.startsWith('http')) { + const audio = new Audio(audioData); + audio.play(); + } + // 十六进制编码的音频 + else { + const audioContext = new (window.AudioContext || window.webkitAudioContext)(); + const audioDataArray = new Uint8Array( + audioData.match(/.{1,2}/g).map(byte => parseInt(byte, 16)) + ); + + audioContext.decodeAudioData(audioDataArray.buffer, (buffer) => { + const source = audioContext.createBufferSource(); + source.buffer = buffer; + source.connect(audioContext.destination); + source.start(0); + }); + } +} +``` + +## 错误处理 + +### 常见错误码 +| 状态码 | 描述 | +|--------|------| +| 400 | 请求参数错误 | +| 401 | 未授权 | +| 500 | 服务器内部错误 | + +### 错误响应格式 +```json +{ + "error": "错误描述信息" +} +``` + +## 最佳实践 + +1. **会话管理** + - 保存 `conversation_id` 以维持对话上下文 + - 在对话结束时清理资源 + +2. **错误处理** + - 实现重试机制 + - 优雅处理网络错误 + - 提供用户友好的错误提示 + +3. **性能优化** + - 使用缓冲通道处理流式数据 + - 及时清理不需要的音频资源 + - 实现消息队列避免并发问题 + +4. **安全性** + - 验证用户身份 + - 使用 HTTPS + - 实现请求频率限制 + +## 注意事项 + +1. 确保正确处理 SSE 连接的关闭 +2. 实现适当的错误重试机制 +3. 注意音频资源的及时释放 +4. 考虑网络延迟和断线重连 +5. 实现适当的加载状态提示 \ No newline at end of file diff --git a/file_server.py b/file_server.py new file mode 100644 index 0000000..6669296 --- /dev/null +++ b/file_server.py @@ -0,0 +1,142 @@ +import http.server +import socketserver +import os +import argparse +from urllib.parse import unquote +import mimetypes +import logging + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) + +class FileHandler(http.server.SimpleHTTPRequestHandler): + def __init__(self, *args, **kwargs): + self.base_directory = os.path.abspath(os.path.join(os.getcwd(), 'audio')) + super().__init__(*args, directory=self.base_directory, **kwargs) + + def do_GET(self): + try: + # 解码URL路径 + path = unquote(self.path) + # 获取文件的完整路径 + file_path = os.path.abspath(os.path.join(self.base_directory, path.lstrip('/'))) + + # 安全检查:确保请求的路径在audio目录下 + if not file_path.startswith(self.base_directory): + self.send_error(403, "Access denied") + return + + # 检查文件是否存在 + if not os.path.exists(file_path): + self.send_error(404, "File not found") + return + + # 如果是目录,显示目录内容 + if os.path.isdir(file_path): + self.send_directory_listing(file_path) + return + + # 只允许访问音频文件 + allowed_extensions = {'.wav', '.mp3', '.ogg', '.m4a', '.flac'} + if not any(file_path.lower().endswith(ext) for ext in allowed_extensions): + self.send_error(403, "File type not allowed") + return + + # 获取文件的MIME类型 + content_type, _ = mimetypes.guess_type(file_path) + if content_type is None: + content_type = 'application/octet-stream' + + # 发送文件 + self.send_file(file_path, content_type) + + except Exception as e: + logging.error(f"Error handling request: {str(e)}") + self.send_error(500, f"Internal server error: {str(e)}") + + def send_file(self, file_path, content_type): + try: + with open(file_path, 'rb') as f: + self.send_response(200) + self.send_header('Content-type', content_type) + self.send_header('Content-Disposition', f'attachment; filename="{os.path.basename(file_path)}"') + self.end_headers() + self.wfile.write(f.read()) + except Exception as e: + logging.error(f"Error sending file: {str(e)}") + self.send_error(500, f"Error reading file: {str(e)}") + + def send_directory_listing(self, directory): + try: + self.send_response(200) + self.send_header('Content-type', 'text/html; charset=utf-8') + self.end_headers() + + # 生成目录列表HTML + html = ['', + 'Audio Files Directory', + '', + '

Audio Files Directory

', + '', + ''] + + # 添加父目录链接 + if self.path != '/': + html.append('') + + # 列出目录内容 + for item in sorted(os.listdir(directory)): + item_path = os.path.join(directory, item) + is_dir = os.path.isdir(item_path) + + # 只显示目录和音频文件 + if not is_dir and not any(item.lower().endswith(ext) for ext in {'.wav', '.mp3', '.ogg', '.m4a', '.flac'}): + continue + + size = '-' if is_dir else f"{os.path.getsize(item_path):,} bytes" + item_type = 'Directory' if is_dir else 'Audio File' + item_class = 'audio-file' if not is_dir else '' + html.append(f'') + + html.append('
NameSizeType
..-Directory
{item}{size}{item_type}
') + self.wfile.write('\n'.join(html).encode('utf-8')) + + except Exception as e: + logging.error(f"Error generating directory listing: {str(e)}") + self.send_error(500, f"Error generating directory listing: {str(e)}") + +def run_server(port): + # 确保audio目录存在 + audio_dir = os.path.join(os.getcwd(), 'audio') + if not os.path.exists(audio_dir): + os.makedirs(audio_dir) + logging.info(f"Created audio directory at: {audio_dir}") + + # 创建服务器 + handler = FileHandler + host = "0.0.0.0" + with socketserver.TCPServer((host, port), handler) as httpd: + logging.info(f"Server started at http://{host}:{port}") + logging.info(f"Local access: http://localhost:{port}") + logging.info(f"Serving files from: {audio_dir}") + try: + httpd.serve_forever() + except KeyboardInterrupt: + logging.info("Server stopped by user") + httpd.server_close() + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Audio files HTTP server') + parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on (default: 8000)') + args = parser.parse_args() + + run_server(args.port) \ No newline at end of file diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..cd085f1 --- /dev/null +++ b/go.mod @@ -0,0 +1,36 @@ +module gongzheng_minimax + +go 1.21 + +require ( + github.com/gin-gonic/gin v1.9.1 + github.com/joho/godotenv v1.5.1 +) + +require ( + github.com/bytedance/sonic v1.9.1 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/crypto v0.9.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.8.0 // indirect + golang.org/x/text v0.9.0 // indirect + google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..879cd31 --- /dev/null +++ b/go.sum @@ -0,0 +1,90 @@ +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g= +golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/handler/llm_handler.go b/handler/llm_handler.go new file mode 100644 index 0000000..f4a5831 --- /dev/null +++ b/handler/llm_handler.go @@ -0,0 +1,195 @@ +package handler + +import ( + "encoding/json" + "net/http" + "time" + + "gongzheng_minimax/service" + + "github.com/gin-gonic/gin" +) + +// LLMHandler handles HTTP requests for the LLM service +type LLMHandler struct { + llmService *service.LLMService +} + +// NewLLMHandler creates a new instance of LLMHandler +func NewLLMHandler(llmService *service.LLMService) *LLMHandler { + return &LLMHandler{ + llmService: llmService, + } +} + +// Chat handles chat requests +func (h *LLMHandler) Chat(c *gin.Context) { + var requestData map[string]interface{} + if err := c.ShouldBindJSON(&requestData); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"}) + return + } + + response, err := h.llmService.CallLLMAPI(requestData) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Check if the response is a channel (streaming response) + if messageChan, ok := response.(chan service.Message); ok { + // Set headers for SSE + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Transfer-Encoding", "chunked") + + // Create a channel to handle client disconnection + clientGone := c.Writer.CloseNotify() + + // Stream the messages + for { + select { + case <-clientGone: + return + case message, ok := <-messageChan: + if !ok { + return + } + + // Convert message to JSON + jsonData, err := json.Marshal(message) + if err != nil { + continue + } + + // Write the SSE message + c.SSEvent("message", string(jsonData)) + c.Writer.Flush() + + // If this is the end message, close the connection + if message.IsEnd { + return + } + } + } + } + + // Non-streaming response + c.JSON(http.StatusOK, response) +} + +// StopConversation handles stopping a conversation +func (h *LLMHandler) StopConversation(c *gin.Context) { + taskID := c.Param("task_id") + if taskID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Task ID is required"}) + return + } + + result, err := h.llmService.StopConversation(taskID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, result) +} + +// DeleteConversation handles deleting a conversation +func (h *LLMHandler) DeleteConversation(c *gin.Context) { + conversationID := c.Param("conversation_id") + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "Conversation ID is required"}) + return + } + + user := c.DefaultQuery("user", "default_user") + result, err := h.llmService.DeleteConversation(conversationID, user) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, result) +} + +// SynthesizeSpeech handles text-to-speech requests +func (h *LLMHandler) SynthesizeSpeech(c *gin.Context) { + var request struct { + Text string `json:"text" binding:"required"` + Audio string `json:"audio" binding:"required"` + } + + if err := c.ShouldBindJSON(&request); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"}) + return + } + + result, err := h.llmService.SynthesizeSpeech(request.Text, request.Audio) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, result) +} + +// StreamText handles streaming text output +func (h *LLMHandler) StreamText(c *gin.Context) { + // Set headers for SSE + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Transfer-Encoding", "chunked") + + segments := []string{ + "好的,", + "我已经成功替换了文本内容。", + "新的文本是一段连续的描述,", + "没有换行,", + "总共65个字符,", + "符合100字以内的要求,", + "并且是一个连续的段落。", + "现在我需要完成任务。", + } + + // Create a channel to handle client disconnection + clientGone := c.Writer.CloseNotify() + + conversationID := "conv_" + time.Now().Format("20060102150405") + taskID := "task_" + time.Now().Format("20060102150405") + + // Stream the segments + for _, segment := range segments { + select { + case <-clientGone: + return + default: + // Create message object + message := map[string]interface{}{ + "event": "message", + "answer": segment, + "conversation_id": conversationID, + "task_id": taskID, + } + + // Convert to JSON and send + jsonData, _ := json.Marshal(message) + c.Writer.Write([]byte("data: " + string(jsonData) + "\n\n")) + c.Writer.Flush() + } + } + + // Send end message + endMessage := map[string]interface{}{ + "event": "message_end", + "answer": "", + "conversation_id": conversationID, + "task_id": taskID, + } + + jsonData, _ := json.Marshal(endMessage) + c.Writer.Write([]byte("data: " + string(jsonData) + "\n\n")) + c.Writer.Flush() +} diff --git a/handler/token_handler.go b/handler/token_handler.go new file mode 100644 index 0000000..e88ab7b --- /dev/null +++ b/handler/token_handler.go @@ -0,0 +1,34 @@ +package handler + +import ( + "net/http" + + "gongzheng_minimax/service" + + "github.com/gin-gonic/gin" +) + +// TokenHandler handles token generation requests +type TokenHandler struct { + tokenService *service.TokenService +} + +// NewTokenHandler creates a new instance of TokenHandler +func NewTokenHandler(tokenService *service.TokenService) *TokenHandler { + return &TokenHandler{ + tokenService: tokenService, + } +} + +// GenerateToken handles token generation requests +func (h *TokenHandler) GenerateToken(c *gin.Context) { + token, err := h.tokenService.CreateSignature() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "token": token, + }) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..8438c63 --- /dev/null +++ b/main.go @@ -0,0 +1,92 @@ +package main + +import ( + "log" + "os" + "strconv" + + "gongzheng_minimax/handler" + "gongzheng_minimax/service" + + "github.com/gin-gonic/gin" + "github.com/joho/godotenv" +) + +func main() { + // Load .env file + if err := godotenv.Load(); err != nil { + log.Printf("Warning: .env file not found: %v", err) + } + + // Initialize LLM service + llmService := service.NewLLMService(service.Config{ + LLMApiURL: os.Getenv("LLM_API_URL"), + LLMApiKey: os.Getenv("LLM_API_KEY"), + MiniMaxApiKey: os.Getenv("MiniMaxApiKey"), + MiniMaxApiURL: os.Getenv("MiniMaxApiURL"), + FILE_URL: os.Getenv("FILE_URL"), + }) + + // Get token configuration from environment variables + sigExp, err := strconv.Atoi(os.Getenv("SIG_EXP")) + if err != nil { + sigExp = 3600 // Default to 1 hour if not set + } + + // Initialize token service + tokenService := service.NewTokenService(service.TokenConfig{ + AppID: os.Getenv("APP_ID"), + AppKey: os.Getenv("APP_KEY"), + SigExp: sigExp, + }) + + // Initialize handlers + llmHandler := handler.NewLLMHandler(llmService) + tokenHandler := handler.NewTokenHandler(tokenService) + + // Create Gin router + router := gin.Default() + + // Add CORS middleware + router.Use(func(c *gin.Context) { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE") + + if c.Request.Method == "OPTIONS" { + c.AbortWithStatus(204) + return + } + + c.Next() + }) + + // Define routes + router.POST("/chat", llmHandler.Chat) + router.POST("/chat-messages/:task_id/stop", llmHandler.StopConversation) + router.DELETE("/conversations/:conversation_id", llmHandler.DeleteConversation) + router.POST("/speech/synthesize", llmHandler.SynthesizeSpeech) + router.GET("/stream-text", llmHandler.StreamText) + router.POST("/token", tokenHandler.GenerateToken) + + // Serve static files + router.Static("/static", "./static") + + // Get host and port from environment variables + host := os.Getenv("HOST") + if host == "" { + host = "0.0.0.0" // Default to all interfaces + } + port := os.Getenv("PORT") + if port == "" { + port = "8080" + } + + // Start server + serverAddr := host + ":" + port + log.Printf("Server starting on %s", serverAddr) + if err := router.Run(serverAddr); err != nil { + log.Fatalf("Failed to start server: %v", err) + } +} diff --git a/service/llm_service.go b/service/llm_service.go new file mode 100644 index 0000000..fd89c3b --- /dev/null +++ b/service/llm_service.go @@ -0,0 +1,725 @@ +package service + +import ( + "bufio" + "bytes" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "math" + "net/http" + "os" + "path/filepath" + "strings" + "time" + "unicode/utf8" +) + +// Config holds the configuration for the LLM service +type Config struct { + LLMApiURL string + LLMApiKey string + MiniMaxApiKey string + MiniMaxApiURL string + FILE_URL string +} + +// LLMService handles communication with the LLM API +type LLMService struct { + config Config + client *http.Client +} + +// Message represents a single message in the conversation +type Message struct { + Answer string `json:"answer"` + IsEnd bool `json:"isEnd"` + ConversationID string `json:"conversation_id"` + TaskID string `json:"task_id"` + ClientID string `json:"client_id,omitempty"` + AudioData string `json:"audio_data,omitempty"` +} + +// RequestPayload represents the payload sent to the LLM API +type RequestPayload struct { + Inputs map[string]interface{} `json:"inputs"` + Query string `json:"query"` + ResponseMode string `json:"response_mode"` + User string `json:"user"` + ConversationID string `json:"conversation_id"` + Files []interface{} `json:"files"` + Audio string `json:"audio"` +} + +// VoiceSetting represents voice configuration +type VoiceSetting struct { + VoiceID string `json:"voice_id"` + Speed float64 `json:"speed"` + Vol float64 `json:"vol"` + Pitch float64 `json:"pitch"` + Emotion string `json:"emotion"` +} + +// AudioSetting represents audio configuration +type AudioSetting struct { + SampleRate int `json:"sample_rate"` + Bitrate int `json:"bitrate"` + Format string `json:"format"` +} + +// SpeechRequest represents the speech synthesis request payload +type SpeechRequest struct { + Model string `json:"model"` + Text string `json:"text"` + Stream bool `json:"stream"` + LanguageBoost string `json:"language_boost"` + OutputFormat string `json:"output_format"` + VoiceSetting VoiceSetting `json:"voice_setting"` + AudioSetting AudioSetting `json:"audio_setting"` +} + +// SpeechData represents the speech data in the response +type SpeechData struct { + Audio string `json:"audio"` + Status int `json:"status"` +} + +// ExtraInfo represents additional information about the speech +type ExtraInfo struct { + AudioLength int `json:"audio_length"` + AudioSampleRate int `json:"audio_sample_rate"` + AudioSize int `json:"audio_size"` + AudioBitrate int `json:"audio_bitrate"` + WordCount int `json:"word_count"` + InvisibleCharacterRatio float64 `json:"invisible_character_ratio"` + AudioFormat string `json:"audio_format"` + UsageCharacters int `json:"usage_characters"` +} + +// BaseResponse represents the base response structure +type BaseResponse struct { + StatusCode int `json:"status_code"` + StatusMsg string `json:"status_msg"` +} + +// SpeechResponse represents the speech synthesis response +type SpeechResponse struct { + Data SpeechData `json:"data"` + ExtraInfo ExtraInfo `json:"extra_info"` + TraceID string `json:"trace_id"` + BaseResp BaseResponse `json:"base_resp"` +} + +// NewLLMService creates a new instance of LLMService +func NewLLMService(config Config) *LLMService { + return &LLMService{ + config: config, + client: &http.Client{}, + } +} + +// CallLLMAPI handles both streaming and non-streaming API calls +func (s *LLMService) CallLLMAPI(data map[string]interface{}) (interface{}, error) { + payload := RequestPayload{ + Inputs: make(map[string]interface{}), + Query: getString(data, "query"), + ResponseMode: getString(data, "response_mode"), + User: getString(data, "user"), + ConversationID: getString(data, "conversation_id"), + Files: make([]interface{}, 0), + Audio: getString(data, "audio"), + } + + fmt.Printf("前端传来的数据:%+v\n", payload) + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("error marshaling payload: %v", err) + } + + req, err := http.NewRequest("POST", s.config.LLMApiURL+"/chat-messages", bytes.NewBuffer(jsonData)) + // req, err := http.NewRequest("GET", "http://localhost:8080/stream-text", nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey) + req.Header.Set("Content-Type", "application/json") + + isStreaming := payload.ResponseMode == "streaming" + if isStreaming { + return s.handleStreamingResponse(req, data, payload.Audio) + } + + return s.handleNonStreamingResponse(req) +} + +// handleStreamingResponse processes streaming responses +func (s *LLMService) handleStreamingResponse(req *http.Request, data map[string]interface{}, audio_type string) (chan Message, error) { + resp, err := s.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + messageChan := make(chan Message, 100) // Buffered channel for better performance + initialSessage := "" + go func() { + defer resp.Body.Close() + defer close(messageChan) + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + break + } + fmt.Printf("Error reading line: %v\n", err) + continue + } + + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Remove "data: " prefix if present + line = strings.TrimPrefix(line, "data: ") + + var jsonData map[string]interface{} + if err := json.Unmarshal([]byte(line), &jsonData); err != nil { + fmt.Printf("Error unmarshaling JSON: %v\n", err) + continue + } + + event := getString(jsonData, "event") + switch event { + case "message": + answer := getString(jsonData, "answer") + var audio string + + // 定义标点符号map + punctuations := map[string]bool{ + ",": true, ",": true, // 逗号 + ".": true, "。": true, // 句号 + "!": true, "!": true, // 感叹号 + "?": true, "?": true, // 问号 + ";": true, ";": true, // 分号 + ":": true, ":": true, // 冒号 + "、": true, + } + + // 删除字符串前后的标点符号 + trimPunctuation := func(s string) string { + if len(s) > 0 { + // 获取最后一个字符的 rune + lastRune, size := utf8.DecodeLastRuneInString(s) + if punctuations[string(lastRune)] { + s = s[:len(s)-size] + } + } + return s + } + + // 判断字符串是否包含标点符号 + containsPunctuation := func(s string) bool { + for _, char := range s { + if punctuations[string(char)] { + return true + } + } + return false + } + + // 按标点符号分割文本 + splitByPunctuation := func(s string) []string { + var result []string + var current string + for _, char := range s { + if punctuations[string(char)] { + if current != "" { + result = append(result, current+string(char)) + current = "" + } + } else { + current += string(char) + } + } + if current != "" { + result = append(result, current) + } + return result + } + new_message := "" + initialSessage += answer + if containsPunctuation(initialSessage) { + segments := splitByPunctuation(initialSessage) + // fmt.Printf("原始文本: %s\n", initialSessage) + // fmt.Printf("分割后的片段数量: %d\n", len(segments)) + // for i, segment := range segments { + // fmt.Printf("片段 %d: %s\n", i+1, segment) + // } + if len(segments) > 1 { + initialSessage = segments[len(segments)-1] + new_message = strings.Join(segments[:len(segments)-1], "") + } else { + new_message = initialSessage + initialSessage = "" + } + // fmt.Printf("新消息: %s\n", new_message) + // fmt.Printf("剩余文本: %s\n", initialSessage) + } + + if new_message == "" { + continue + } + s_msg := strings.TrimSpace(new_message) + // Trim punctuation from the message + new_message = trimPunctuation(s_msg) + // fmt.Println("new_message", new_message) + + // 最多重试一次 + for i := 0; i < 1; i++ { + speechResp, err := s.SynthesizeSpeech(new_message, audio_type) + if err != nil { + fmt.Printf("Error synthesizing speech: %v\n", err) + break // 语音接口报错直接跳出 + } + fmt.Println("语音:", speechResp) + audio = speechResp.Data.Audio + if audio != "" { + // Download audio from URL and trim silence + resp, err := http.Get(audio) + if err != nil { + fmt.Printf("Error downloading audio: %v\n", err) + } else { + defer resp.Body.Close() + audioBytes, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Printf("Error reading audio data: %v\n", err) + } else { + // Save original audio first + originalPath := fmt.Sprintf("audio/original_%d.wav", time.Now().Unix()) + if err := os.WriteFile(originalPath, audioBytes, 0644); err != nil { + fmt.Printf("Error saving original audio: %v\n", err) + } + + // Convert audio bytes to base64 for processing + audioBase64 := base64.StdEncoding.EncodeToString(audioBytes) + trimmedAudio, err := s.TrimAudioSilence(audioBase64) + if err != nil { + fmt.Printf("Error trimming audio silence: %v\n", err) + } else { + // Save the trimmed audio as WAV file + audio_path := fmt.Sprintf("trimmed_%d.wav", time.Now().Unix()) + outputPath := "audio/" + audio_path + if err := s.SaveBase64AsWAV(trimmedAudio, outputPath); err != nil { + fmt.Printf("Error saving trimmed WAV file: %v\n", err) + } + audio = s.config.FILE_URL + audio_path + } + } + } + break // 获取到音频就退出 + } + fmt.Println("audio is empty, retry", speechResp) + // time.Sleep(1 * time.Second) + } + + messageChan <- Message{ + Answer: new_message, + IsEnd: false, + ConversationID: getString(jsonData, "conversation_id"), + TaskID: getString(jsonData, "task_id"), + ClientID: getString(data, "conversation_id"), + AudioData: audio, // Update to use the correct path to audio data + } + case "message_end": + messageChan <- Message{ + Answer: "", + IsEnd: true, + ConversationID: getString(jsonData, "conversation_id"), + TaskID: getString(jsonData, "task_id"), + } + return + } + } + }() + + return messageChan, nil +} + +// handleNonStreamingResponse processes non-streaming responses +func (s *LLMService) handleNonStreamingResponse(req *http.Request) (map[string]interface{}, error) { + resp, err := s.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error decoding response: %v", err) + } + + return result, nil +} + +// StopConversation stops an ongoing conversation +func (s *LLMService) StopConversation(taskID string) (map[string]interface{}, error) { + req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat-messages/%s/stop", s.config.LLMApiURL, taskID), nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + defer resp.Body.Close() + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error decoding response: %v", err) + } + + return result, nil +} + +// DeleteConversation deletes a conversation +func (s *LLMService) DeleteConversation(conversationID, user string) (map[string]interface{}, error) { + req, err := http.NewRequest("DELETE", fmt.Sprintf("%s/conversations/%s", s.config.LLMApiURL, conversationID), nil) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + defer resp.Body.Close() + + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error decoding response: %v", err) + } + + return result, nil +} + +// SynthesizeSpeech converts text to speech +func (s *LLMService) SynthesizeSpeech(text string, audio string) (*SpeechResponse, error) { + payload := SpeechRequest{ + Model: "speech-02-turbo", + Text: text, + Stream: false, + LanguageBoost: "auto", + OutputFormat: "url", + VoiceSetting: VoiceSetting{ + VoiceID: audio, + Speed: 1, + Vol: 1, + Pitch: 0, + Emotion: "happy", + }, + AudioSetting: AudioSetting{ + SampleRate: 32000, + Bitrate: 128000, + Format: "wav", + }, + } + + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("error marshaling speech request: %v", err) + } + + req, err := http.NewRequest("POST", s.config.MiniMaxApiURL, bytes.NewBuffer(jsonData)) + if err != nil { + fmt.Println("error creating speech request: ", err) + return nil, fmt.Errorf("error creating speech request: %v", err) + } + + req.Header.Set("Authorization", "Bearer "+s.config.MiniMaxApiKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(req) + if err != nil { + fmt.Println("error making speech request: ", err) + return nil, fmt.Errorf("error making speech request: %v", err) + } + defer resp.Body.Close() + // fmt.Println(resp.Body) + if resp.StatusCode != http.StatusOK { + fmt.Println("unexpected status code: ", resp.StatusCode) + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + var result SpeechResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + fmt.Println("error decoding speech response: ", err) + return nil, fmt.Errorf("error decoding speech response: %v", err) + } + + return &result, nil +} + +// StreamTextResponse handles streaming text output with predefined segments +func (s *LLMService) StreamTextResponse(conversationID string) (chan Message, error) { + messageChan := make(chan Message, 100) + + segments := []string{ + "好的,", + "我已经成功替换了文本内容。", + "新的文本是一段连续的描述,", + "没有换行,", + "总共65个字符,", + "符合100字以内的要求,", + "并且是一个连续的段落。", + "现在我需要完成任务。", + } + + go func() { + defer close(messageChan) + taskID := "task_" + time.Now().Format("20060102150405") + + for _, segment := range segments { + // Send message + messageChan <- Message{ + Answer: segment, + IsEnd: false, + ConversationID: conversationID, + TaskID: taskID, + ClientID: conversationID, + } + + // Add delay between segments + time.Sleep(500 * time.Millisecond) + } + + // Send end message + messageChan <- Message{ + Answer: "", + IsEnd: true, + ConversationID: conversationID, + TaskID: taskID, + } + }() + + return messageChan, nil +} + +// Helper function to safely get string values from interface{} +func getString(data map[string]interface{}, key string) string { + if val, ok := data[key]; ok { + if str, ok := val.(string); ok { + return str + } + } + return "" +} + +// TrimAudioSilence trims the silence at the end of the audio data +func (s *LLMService) TrimAudioSilence(audioData string) (string, error) { + // Decode base64 audio data + decodedData, err := base64.StdEncoding.DecodeString(audioData) + if err != nil { + return "", fmt.Errorf("error decoding base64 audio: %v", err) + } + + // Create a buffer from the decoded data + buf := bytes.NewReader(decodedData) + + // Read RIFF header + var riffHeader struct { + ChunkID [4]byte + ChunkSize uint32 + Format [4]byte + } + + if err := binary.Read(buf, binary.LittleEndian, &riffHeader); err != nil { + return "", fmt.Errorf("error reading RIFF header: %v", err) + } + + // Verify RIFF header + if string(riffHeader.ChunkID[:]) != "RIFF" || string(riffHeader.Format[:]) != "WAVE" { + return "", fmt.Errorf("invalid WAV format") + } + + // Read fmt chunk + var fmtChunk struct { + Subchunk1ID [4]byte + Subchunk1Size uint32 + AudioFormat uint16 + NumChannels uint16 + SampleRate uint32 + ByteRate uint32 + BlockAlign uint16 + BitsPerSample uint16 + } + + if err := binary.Read(buf, binary.LittleEndian, &fmtChunk); err != nil { + return "", fmt.Errorf("error reading fmt chunk: %v", err) + } + + // Skip any extra bytes in fmt chunk + if fmtChunk.Subchunk1Size > 16 { + extraBytes := make([]byte, fmtChunk.Subchunk1Size-16) + if _, err := buf.Read(extraBytes); err != nil { + return "", fmt.Errorf("error skipping extra fmt bytes: %v", err) + } + } + + // Find data chunk + var dataChunk struct { + Subchunk2ID [4]byte + Subchunk2Size uint32 + } + + for { + if err := binary.Read(buf, binary.LittleEndian, &dataChunk); err != nil { + return "", fmt.Errorf("error reading chunk header: %v", err) + } + + if string(dataChunk.Subchunk2ID[:]) == "data" { + break + } + + // Skip this chunk + if _, err := buf.Seek(int64(dataChunk.Subchunk2Size), io.SeekCurrent); err != nil { + return "", fmt.Errorf("error skipping chunk: %v", err) + } + } + + // Read audio data + audioBytes := make([]byte, dataChunk.Subchunk2Size) + if _, err := buf.Read(audioBytes); err != nil { + return "", fmt.Errorf("error reading audio data: %v", err) + } + + // Calculate samples per channel + samplesPerChannel := len(audioBytes) / int(fmtChunk.BlockAlign) + channels := int(fmtChunk.NumChannels) + bytesPerSample := int(fmtChunk.BitsPerSample) / 8 + + // Find the last non-silent sample + lastNonSilent := 0 + silenceThreshold := 0.01 // Adjust this threshold as needed + + for i := 0; i < samplesPerChannel; i++ { + isSilent := true + for ch := 0; ch < channels; ch++ { + offset := i*int(fmtChunk.BlockAlign) + ch*bytesPerSample + if offset+bytesPerSample > len(audioBytes) { + continue + } + + // Convert bytes to sample value + var sample int16 + if err := binary.Read(bytes.NewReader(audioBytes[offset:offset+bytesPerSample]), binary.LittleEndian, &sample); err != nil { + continue + } + + // Normalize sample to [-1, 1] range + normalizedSample := float64(sample) / 32768.0 + if math.Abs(normalizedSample) > silenceThreshold { + isSilent = false + break + } + } + + if !isSilent { + lastNonSilent = i + } + } + + // Add a small buffer (e.g., 0.1 seconds) after the last non-silent sample + bufferSamples := int(float64(fmtChunk.SampleRate) * 0.1) + lastSample := lastNonSilent + bufferSamples + if lastSample > samplesPerChannel { + lastSample = samplesPerChannel + } + + // Calculate new data size + newDataSize := lastSample * int(fmtChunk.BlockAlign) + trimmedAudio := audioBytes[:newDataSize] + + // Create new buffer for the trimmed audio + var newBuf bytes.Buffer + + // Write RIFF header + riffHeader.ChunkSize = uint32(36 + newDataSize) + if err := binary.Write(&newBuf, binary.LittleEndian, riffHeader); err != nil { + return "", fmt.Errorf("error writing RIFF header: %v", err) + } + + // Write fmt chunk + if err := binary.Write(&newBuf, binary.LittleEndian, fmtChunk); err != nil { + return "", fmt.Errorf("error writing fmt chunk: %v", err) + } + + // Write data chunk header + dataChunk.Subchunk2Size = uint32(newDataSize) + if err := binary.Write(&newBuf, binary.LittleEndian, dataChunk); err != nil { + return "", fmt.Errorf("error writing data chunk header: %v", err) + } + + // Write trimmed audio data + if _, err := newBuf.Write(trimmedAudio); err != nil { + return "", fmt.Errorf("error writing trimmed audio data: %v", err) + } + + // Encode back to base64 + return base64.StdEncoding.EncodeToString(newBuf.Bytes()), nil +} + +// SaveBase64AsWAV saves base64 encoded audio data as a WAV file +func (s *LLMService) SaveBase64AsWAV(base64Data string, outputPath string) error { + // Decode base64 data + audioData, err := base64.StdEncoding.DecodeString(base64Data) + if err != nil { + return fmt.Errorf("error decoding base64 data: %v", err) + } + + // Validate WAV header + if len(audioData) < 44 { // WAV header is 44 bytes + return fmt.Errorf("invalid WAV data: too short") + } + + // Check RIFF header + if string(audioData[0:4]) != "RIFF" { + return fmt.Errorf("invalid WAV format: missing RIFF header") + } + + // Check WAVE format + if string(audioData[8:12]) != "WAVE" { + return fmt.Errorf("invalid WAV format: missing WAVE format") + } + + // Create output directory if it doesn't exist + dir := filepath.Dir(outputPath) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("error creating directory: %v", err) + } + + // Write the audio data to file + if err := os.WriteFile(outputPath, audioData, 0644); err != nil { + return fmt.Errorf("error writing WAV file: %v", err) + } + + return nil +} diff --git a/service/token_service.go b/service/token_service.go new file mode 100644 index 0000000..e49c282 --- /dev/null +++ b/service/token_service.go @@ -0,0 +1,53 @@ +package service + +import ( + "time" + + "github.com/golang-jwt/jwt/v5" +) + +// TokenConfig holds the configuration for the token service +type TokenConfig struct { + AppID string + AppKey string + SigExp int +} + +// TokenService handles JWT token generation +type TokenService struct { + config TokenConfig +} + +// NewTokenService creates a new instance of TokenService +func NewTokenService(config TokenConfig) *TokenService { + return &TokenService{ + config: config, + } +} + +// CreateSignature generates a JWT token +func (s *TokenService) CreateSignature() (string, error) { + // Get current time + now := time.Now().UTC() + // Calculate expiration time + expiresAt := now.Add(time.Duration(s.config.SigExp) * time.Second) + + // Create claims + claims := jwt.MapClaims{ + "iss": "your-issuer", // Optional: Issuer + "iat": now.Unix(), // Issued at time + "exp": expiresAt.Unix(), // Expiration time + "appId": s.config.AppID, // Custom claim + } + + // Create token + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + // Sign and get the complete encoded token as a string + tokenString, err := token.SignedString([]byte(s.config.AppKey)) + if err != nil { + return "", err + } + + return tokenString, nil +} diff --git a/static/index.html b/static/index.html new file mode 100644 index 0000000..2b4d0ff --- /dev/null +++ b/static/index.html @@ -0,0 +1,190 @@ + + + + + + Chat Demo + + + +
+
+
+ + +
+
+ + + + \ No newline at end of file