diff --git a/handler/llm_handler.go b/handler/llm_handler.go index 10d8b51..28a5d87 100644 --- a/handler/llm_handler.go +++ b/handler/llm_handler.go @@ -117,14 +117,14 @@ func (h *LLMHandler) ChatExt(c *gin.Context) { // 构造 Service 层需要的参数 map serviceData := map[string]interface{}{ - "tag_ids": []string{"1", "2"}, + "tag_ids": []int{1, 11, 29}, "conversation_id": conversationID, "content": requestData.DhQuestion, } fmt.Printf("Calling Service with data: %+v\n", serviceData) - response, err := h.llmService.CallExtQAAPI(serviceData) + response, err := h.llmService.CallExtQAAPIStreamDirect(serviceData) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return diff --git a/service/llm_service.go b/service/llm_service.go index 5009a75..090aab4 100644 --- a/service/llm_service.go +++ b/service/llm_service.go @@ -129,9 +129,9 @@ type LLMOurRequestPayload struct { // ExtQARequestPayload represents the payload for the external QA API type ExtQARequestPayload struct { - TagIDs []string `json:"tag_ids"` - ConversationID string `json:"conversation_id"` - Content string `json:"content"` + TagIDs []int `json:"tag_ids"` + ConversationID string `json:"conversation_id"` + Content string `json:"content"` } // ExtQAResponse represents the response from the external QA API @@ -1091,18 +1091,14 @@ func (s *LLMService) TrimAudioSilence(audioData string) (string, error) { // CallExtQAAPI handles the external QA API call func (s *LLMService) CallExtQAAPI(data map[string]interface{}) (interface{}, error) { - var tagIDs []string - // 优先尝试直接断言为 []string - if ids, ok := data["tag_ids"].([]string); ok { + var tagIDs []int + // 优先尝试直接断言为 []int + if ids, ok := data["tag_ids"].([]int); ok { tagIDs = ids } else if tagIDsRaw, ok := data["tag_ids"].([]interface{}); ok { for _, v := range tagIDsRaw { - // 尝试转换为 string - if id, ok := v.(string); ok { - tagIDs = append(tagIDs, id) - } else if id, ok := v.(float64); ok { - // 兼容数字类型,转为字符串 - tagIDs = append(tagIDs, fmt.Sprintf("%d", int(id))) + if id, ok := v.(float64); ok { + tagIDs = append(tagIDs, int(id)) } } } @@ -1130,6 +1126,122 @@ func (s *LLMService) CallExtQAAPI(data map[string]interface{}) (interface{}, err return s.handleStreamingResponseForExt(req, data) } +// CallExtQAAPIStreamDirect handles the external QA API call and streams the response directly +func (s *LLMService) CallExtQAAPIStreamDirect(data map[string]interface{}) (interface{}, error) { + var tagIDs []int + // 优先尝试直接断言为 []int + if ids, ok := data["tag_ids"].([]int); ok { + tagIDs = ids + } else if tagIDsRaw, ok := data["tag_ids"].([]interface{}); ok { + for _, v := range tagIDsRaw { + if id, ok := v.(float64); ok { + tagIDs = append(tagIDs, int(id)) + } + } + } + + payload := ExtQARequestPayload{ + TagIDs: tagIDs, + ConversationID: getString(data, "conversation_id"), + Content: getString(data, "content"), + } + + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("error marshaling payload: %v", err) + } + + url := "http://47.100.108.206:30028/api/qa/v1/chat/completionForExt" + fmt.Printf("Sending request to %s with payload: %s\n", url, string(jsonData)) + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("error creating request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + return s.handleStreamingResponseForExtDirect(req, data) +} + +// handleStreamingResponseForExtDirect processes streaming responses from the external QA API and returns Message channel +func (s *LLMService) handleStreamingResponseForExtDirect(req *http.Request, data map[string]interface{}) (chan Message, error) { + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + fmt.Printf("Error making external request: %v\n", err) + return nil, fmt.Errorf("error making request: %v", err) + } + + fmt.Printf("External API response status: %d\n", resp.StatusCode) + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + fmt.Printf("External API error body: %s\n", string(body)) + resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + messageChan := make(chan Message, 100) + go func() { + defer resp.Body.Close() + defer close(messageChan) + reader := bufio.NewReader(resp.Body) + + conversationID := getString(data, "conversation_id") + // Use current time as task ID since external API might not provide it in every chunk + taskID := fmt.Sprintf("task_%d", time.Now().UnixNano()) + + for { + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + break + } + fmt.Printf("Error reading line: %v\n", err) + continue + } + + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Remove "data: " prefix if present + if strings.HasPrefix(line, "data:") { + line = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + if line == "" { + continue + } + + fmt.Printf("Processing line: %s\n", line) + + var response ExtQAResponse + if err := json.Unmarshal([]byte(line), &response); err != nil { + fmt.Printf("Error unmarshaling JSON: %v, line: %s\n", err, line) + continue + } + + // Direct forwarding logic + answer := response.Output.Text + isEnd := response.Output.FinishReason == "stop" || response.Output.FinishReason == "length" + + messageChan <- Message{ + Answer: answer, + IsEnd: isEnd, + ConversationID: conversationID, + TaskID: taskID, + } + + if isEnd { + return + } + } + }() + + return messageChan, nil +} + // handleStreamingResponseForExt processes streaming responses from the external QA API func (s *LLMService) handleStreamingResponseForExt(req *http.Request, data map[string]interface{}) (chan Message, error) { client := &http.Client{}