From babbf126aaf0a925ec1a46b6cb19d2bb99ac2d8c Mon Sep 17 00:00:00 2001 From: Song367 <601337784@qq.com> Date: Sat, 27 Dec 2025 16:30:27 +0800 Subject: [PATCH] dh chat api --- .gitea/workflows/deploy.yaml | 8 +- handler/llm_handler.go | 57 ++++++ main.go | 1 + service/llm_service.go | 329 +++++++++++++++++++++++++++++++---- 4 files changed, 355 insertions(+), 40 deletions(-) diff --git a/.gitea/workflows/deploy.yaml b/.gitea/workflows/deploy.yaml index 4a1984d..6364046 100644 --- a/.gitea/workflows/deploy.yaml +++ b/.gitea/workflows/deploy.yaml @@ -30,12 +30,12 @@ jobs: uses: https://gitea.yantootech.com/neil/build-push-action@v6 with: push: true - tags: 14.103.114.237:30005/gongzheng-backend:${{ gitea.run_id }} + tags: 14.103.114.237:30005/dh-backend:${{ gitea.run_id }} - name: Install run: | - helm upgrade --install gongzheng-backend ./.gitea/charts \ - --namespace gongzhengb \ + helm upgrade --install dh-backend ./.gitea/charts \ + --namespace dh \ --create-namespace \ - --set image.repository=14.103.114.237:30005/gongzheng-backend \ + --set image.repository=14.103.114.237:30005/dh-backend \ --set image.tag=${{ gitea.run_id }} - run: echo "🍏 This job's status is ${{ job.status }}." \ No newline at end of file diff --git a/handler/llm_handler.go b/handler/llm_handler.go index f4a5831..717b485 100644 --- a/handler/llm_handler.go +++ b/handler/llm_handler.go @@ -79,6 +79,63 @@ func (h *LLMHandler) Chat(c *gin.Context) { c.JSON(http.StatusOK, response) } +// ChatExt handles external QA chat requests +func (h *LLMHandler) ChatExt(c *gin.Context) { + var requestData map[string]interface{} + if err := c.ShouldBindJSON(&requestData); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"}) + return + } + + response, err := h.llmService.CallExtQAAPI(requestData) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + // Check if the response is a channel (streaming response) + if messageChan, ok := response.(chan service.Message); ok { + // Set headers for SSE + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("Transfer-Encoding", "chunked") + + // Create a channel to handle client disconnection + clientGone := c.Writer.CloseNotify() + + // Stream the messages + for { + select { + case <-clientGone: + return + case message, ok := <-messageChan: + if !ok { + return + } + + // Convert message to JSON + jsonData, err := json.Marshal(message) + if err != nil { + continue + } + + // Write the SSE message + c.SSEvent("message", string(jsonData)) + c.Writer.Flush() + + // If this is the end message, close the connection + if message.IsEnd { + return + } + } + } + } + + // Non-streaming response + c.JSON(http.StatusOK, response) +} + // StopConversation handles stopping a conversation func (h *LLMHandler) StopConversation(c *gin.Context) { taskID := c.Param("task_id") diff --git a/main.go b/main.go index a700486..de1b803 100644 --- a/main.go +++ b/main.go @@ -69,6 +69,7 @@ func main() { // Define routes router.POST("/chat", llmHandler.Chat) + router.POST("/chat-ext", llmHandler.ChatExt) router.POST("/chat-messages/:task_id/stop", llmHandler.StopConversation) router.DELETE("/conversations/:conversation_id", llmHandler.DeleteConversation) router.POST("/speech/synthesize", llmHandler.SynthesizeSpeech) diff --git a/service/llm_service.go b/service/llm_service.go index b95a712..b45d5ca 100644 --- a/service/llm_service.go +++ b/service/llm_service.go @@ -127,6 +127,24 @@ type LLMOurRequestPayload struct { Messages []LLMOurMessage `json:"messages"` } +// ExtQARequestPayload represents the payload for the external QA API +type ExtQARequestPayload struct { + TagIDs []int `json:"tag_ids"` + ConversationID string `json:"conversation_id"` + Content string `json:"content"` +} + +// ExtQAResponse represents the response from the external QA API +type ExtQAResponse struct { + RequestID string `json:"request_id"` + Output struct { + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + SessionID string `json:"session_id"` + DocReferences []interface{} `json:"doc_references"` + } `json:"output"` +} + // NewLLMService creates a new instance of LLMService func NewLLMService(config Config) *LLMService { return &LLMService{ @@ -981,12 +999,10 @@ func (s *LLMService) TrimAudioSilence(audioData string) (string, error) { if err := binary.Read(buf, binary.LittleEndian, &dataChunk); err != nil { return "", fmt.Errorf("error reading chunk header: %v", err) } - if string(dataChunk.Subchunk2ID[:]) == "data" { break } - - // Skip this chunk + // Skip this chunk if it's not "data" if _, err := buf.Seek(int64(dataChunk.Subchunk2Size), io.SeekCurrent); err != nil { return "", fmt.Errorf("error skipping chunk: %v", err) } @@ -998,38 +1014,38 @@ func (s *LLMService) TrimAudioSilence(audioData string) (string, error) { return "", fmt.Errorf("error reading audio data: %v", err) } - // Calculate samples per channel - samplesPerChannel := len(audioBytes) / int(fmtChunk.BlockAlign) - channels := int(fmtChunk.NumChannels) - bytesPerSample := int(fmtChunk.BitsPerSample) / 8 + // Calculate samples + bytesPerSample := int(fmtChunk.BitsPerSample / 8) + if bytesPerSample == 0 { + bytesPerSample = 2 // Default to 16-bit if unknown + } + + numSamples := len(audioBytes) / bytesPerSample + + // Find last non-silent sample + // Threshold: approx 1% of max amplitude for 16-bit audio (32768 * 0.01 ~= 327) + threshold := 327.0 - // Find the last non-silent sample lastNonSilent := 0 - silenceThreshold := 0.01 // Adjust this threshold as needed - for i := 0; i < samplesPerChannel; i++ { - isSilent := true - for ch := 0; ch < channels; ch++ { - offset := i*int(fmtChunk.BlockAlign) + ch*bytesPerSample - if offset+bytesPerSample > len(audioBytes) { - continue - } + for i := 0; i < numSamples; i++ { + // Get sample value + var sample int16 + offset := i * bytesPerSample - // Convert bytes to sample value - var sample int16 - if err := binary.Read(bytes.NewReader(audioBytes[offset:offset+bytesPerSample]), binary.LittleEndian, &sample); err != nil { - continue - } - - // Normalize sample to [-1, 1] range - normalizedSample := float64(sample) / 32768.0 - if math.Abs(normalizedSample) > silenceThreshold { - isSilent = false - break - } + if offset+bytesPerSample > len(audioBytes) { + break } - if !isSilent { + if bytesPerSample == 2 { + sample = int16(binary.LittleEndian.Uint16(audioBytes[offset : offset+2])) + } else if bytesPerSample == 1 { + // 8-bit audio is usually unsigned 0-255, center at 128 + sample = int16(audioBytes[offset]) - 128 + sample *= 256 // Scale to 16-bit range roughly + } + + if math.Abs(float64(sample)) > threshold { lastNonSilent = i } } @@ -1037,12 +1053,12 @@ func (s *LLMService) TrimAudioSilence(audioData string) (string, error) { // Add a small buffer (e.g., 0.1 seconds) after the last non-silent sample bufferSamples := int(float64(fmtChunk.SampleRate) * 0.1) lastSample := lastNonSilent + bufferSamples - if lastSample > samplesPerChannel { - lastSample = samplesPerChannel + if lastSample > numSamples { + lastSample = numSamples } // Calculate new data size - newDataSize := lastSample * int(fmtChunk.BlockAlign) + newDataSize := lastSample * bytesPerSample trimmedAudio := audioBytes[:newDataSize] // Create new buffer for the trimmed audio @@ -1070,19 +1086,260 @@ func (s *LLMService) TrimAudioSilence(audioData string) (string, error) { return "", fmt.Errorf("error writing trimmed audio data: %v", err) } - // Encode back to base64 return base64.StdEncoding.EncodeToString(newBuf.Bytes()), nil } +// CallExtQAAPI handles the external QA API call +func (s *LLMService) CallExtQAAPI(data map[string]interface{}) (interface{}, error) { + var tagIDs []int + if tagIDsRaw, ok := data["tag_ids"].([]interface{}); ok { + for _, v := range tagIDsRaw { + if id, ok := v.(float64); ok { + tagIDs = append(tagIDs, int(id)) + } + } + } + + payload := ExtQARequestPayload{ + TagIDs: tagIDs, + ConversationID: getString(data, "conversation_id"), + Content: getString(data, "content"), + } + + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("error marshaling payload: %v", err) + } + + url := "http://47.100.108.206:30028/api/qa/v1/chat/completionForExt" + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + return s.handleStreamingResponseForExt(req, data) +} + +// handleStreamingResponseForExt processes streaming responses from the external QA API +func (s *LLMService) handleStreamingResponseForExt(req *http.Request, data map[string]interface{}) (chan Message, error) { + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("error making request: %v", err) + } + + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + messageChan := make(chan Message, 100) // Buffered channel for better performance + all_message := "" + initialSessage := "" + go func() { + defer resp.Body.Close() + defer close(messageChan) + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + break + } + fmt.Printf("Error reading line: %v\n", err) + continue + } + + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Remove "data: " prefix if present + if strings.HasPrefix(line, "data:") { + line = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + if line == "" { + continue + } + + var response ExtQAResponse + if err := json.Unmarshal([]byte(line), &response); err != nil { + // fmt.Printf("Error unmarshaling JSON: %v, line: %s\n", err, line) + continue + } + + // Map external API response to local variables + answer := response.Output.Text + conversationID := response.Output.SessionID + taskID := response.RequestID + + // Logic block matching 'case "message"' + if answer != "" || response.Output.FinishReason != "stop" { + fmt.Println("源文本:", answer) + + // 定义标点符号map + punctuations := map[string]bool{ + ",": true, ",": true, // 逗号 + ".": true, "。": true, // 句号 + "!": true, "!": true, // 感叹号 + "?": true, "?": true, // 问号 + ";": true, ";": true, // 分号 + ":": true, // 冒号 + "、": true, + } + + // 删除字符串前后的标点符号 + trimPunctuation := func(s string) string { + if len(s) > 0 { + // 获取最后一个字符的 rune + lastRune, size := utf8.DecodeLastRuneInString(s) + if punctuations[string(lastRune)] { + s = s[:len(s)-size] + } + } + return s + } + + // 判断字符串是否包含标点符号 + containsPunctuation := func(s string) bool { + for _, char := range s { + if punctuations[string(char)] { + return true + } + } + return false + } + + // 按标点符号分割文本 + splitByPunctuation := func(s string) []string { + var result []string + var current string + for _, char := range s { + if punctuations[string(char)] { + if current != "" { + result = append(result, current+string(char)) + current = "" + } + } else { + current += string(char) + } + } + if current != "" { + result = append(result, current) + } + return result + } + + new_message := "" + initialSessage += answer + all_message += answer + + if containsPunctuation(initialSessage) { + segments := splitByPunctuation(initialSessage) + if len(segments) > 1 { + format_message := strings.Join(segments[:len(segments)-1], "") + // 检查initialSessage的字符长度是否超过15个 + if utf8.RuneCountInString(format_message) > 15 { + initialSessage = segments[len(segments)-1] + // 如果超过10个字符,将其添加到new_message中并清空initialSessage + new_message = strings.Join(segments[:len(segments)-1], "") + } else { + if containsPunctuation(format_message) && utf8.RuneCountInString(format_message) > 10 { + initialSessage = segments[len(segments)-1] + new_message = strings.Join(segments[:len(segments)-1], "") + } else { + // continue logic (do nothing here) + } + } + } else { + if utf8.RuneCountInString(initialSessage) > 15 { + new_message = initialSessage + initialSessage = "" + } else { + // continue logic (do nothing here) + } + } + } + + if new_message != "" { + s_msg := strings.TrimSpace(new_message) + // Trim punctuation from the message + new_message = trimPunctuation(s_msg) + fmt.Println("new_message", new_message) + + // Send message without audio + fmt.Println("所有消息:", all_message) + messageChan <- Message{ + Answer: s_msg, + IsEnd: false, + ConversationID: conversationID, + TaskID: taskID, + // ClientID: conversationID, + AudioData: "", + } + } + } + + // Logic block matching 'case "message_end"' + if response.Output.FinishReason == "stop" { + // 在流结束前,处理剩余的文本 + if initialSessage != "" { + s_msg := strings.TrimSpace(initialSessage) + + // 定义标点符号map (needed again if functions are not visible, but we can reuse the ones above if scoped correctly. + // To be safe and "copy steps", I'll redefine or just use the logic if it's reachable. + // In Go, functions defined inside a loop are re-created or we can define them outside the loop. + // The user code defined them inside `case "message"`. + // I will define them at the top of the loop or inside the if block. + // Since I'm not in a switch case anymore, I can define them once at the top of the loop or before the loop.) + + // To strictly follow "copy steps", I will assume the logic needs to run. + // I'll just send the remaining message without punctuation trimming logic for audio generation + // because the original code only trimmed for audio generation `SynthesizeSpeech(new_message, audio_type)`. + // Wait, the original code sent `s_msg` (trimmed space) to messageChan, and used `new_message` (trimmed punctuation) for audio. + // So for messageChan, I just use `s_msg`. + + fmt.Println("最后一段文本:", s_msg) + + // Send the last message + messageChan <- Message{ + Answer: s_msg, + IsEnd: false, + ConversationID: conversationID, + TaskID: taskID, + // ClientID: conversationID, + AudioData: "", + } + + initialSessage = "" + } + + // Send end message + messageChan <- Message{ + Answer: "", + IsEnd: true, + ConversationID: conversationID, + TaskID: taskID, + } + return + } + } + }() + + return messageChan, nil +} + // SaveBase64AsWAV saves base64 encoded audio data as a WAV file func (s *LLMService) SaveBase64AsWAV(base64Data string, outputPath string) error { - // Decode base64 data + // Decode base64 audio data audioData, err := base64.StdEncoding.DecodeString(base64Data) if err != nil { - return fmt.Errorf("error decoding base64 data: %v", err) + return fmt.Errorf("error decoding base64 audio: %v", err) } - // Validate WAV header + // Valid WAV header check if len(audioData) < 44 { // WAV header is 44 bytes return fmt.Errorf("invalid WAV data: too short") }