196 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			196 lines
		
	
	
		
			4.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package handler
 | 
						||
 | 
						||
import (
 | 
						||
	"encoding/json"
 | 
						||
	"net/http"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"gongzheng_minimax/service"
 | 
						||
 | 
						||
	"github.com/gin-gonic/gin"
 | 
						||
)
 | 
						||
 | 
						||
// LLMHandler handles HTTP requests for the LLM service
 | 
						||
type LLMHandler struct {
 | 
						||
	llmService *service.LLMService
 | 
						||
}
 | 
						||
 | 
						||
// NewLLMHandler creates a new instance of LLMHandler
 | 
						||
func NewLLMHandler(llmService *service.LLMService) *LLMHandler {
 | 
						||
	return &LLMHandler{
 | 
						||
		llmService: llmService,
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// Chat handles chat requests
 | 
						||
func (h *LLMHandler) Chat(c *gin.Context) {
 | 
						||
	var requestData map[string]interface{}
 | 
						||
	if err := c.ShouldBindJSON(&requestData); err != nil {
 | 
						||
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	response, err := h.llmService.CallLLMAPI(requestData)
 | 
						||
	if err != nil {
 | 
						||
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	// Check if the response is a channel (streaming response)
 | 
						||
	if messageChan, ok := response.(chan service.Message); ok {
 | 
						||
		// Set headers for SSE
 | 
						||
		c.Header("Content-Type", "text/event-stream")
 | 
						||
		c.Header("Cache-Control", "no-cache")
 | 
						||
		c.Header("Connection", "keep-alive")
 | 
						||
		c.Header("Transfer-Encoding", "chunked")
 | 
						||
 | 
						||
		// Create a channel to handle client disconnection
 | 
						||
		clientGone := c.Writer.CloseNotify()
 | 
						||
 | 
						||
		// Stream the messages
 | 
						||
		for {
 | 
						||
			select {
 | 
						||
			case <-clientGone:
 | 
						||
				return
 | 
						||
			case message, ok := <-messageChan:
 | 
						||
				if !ok {
 | 
						||
					return
 | 
						||
				}
 | 
						||
 | 
						||
				// Convert message to JSON
 | 
						||
				jsonData, err := json.Marshal(message)
 | 
						||
				if err != nil {
 | 
						||
					continue
 | 
						||
				}
 | 
						||
 | 
						||
				// Write the SSE message
 | 
						||
				c.SSEvent("message", string(jsonData))
 | 
						||
				c.Writer.Flush()
 | 
						||
 | 
						||
				// If this is the end message, close the connection
 | 
						||
				if message.IsEnd {
 | 
						||
					return
 | 
						||
				}
 | 
						||
			}
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	// Non-streaming response
 | 
						||
	c.JSON(http.StatusOK, response)
 | 
						||
}
 | 
						||
 | 
						||
// StopConversation handles stopping a conversation
 | 
						||
func (h *LLMHandler) StopConversation(c *gin.Context) {
 | 
						||
	taskID := c.Param("task_id")
 | 
						||
	if taskID == "" {
 | 
						||
		c.JSON(http.StatusBadRequest, gin.H{"error": "Task ID is required"})
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	result, err := h.llmService.StopConversation(taskID)
 | 
						||
	if err != nil {
 | 
						||
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	c.JSON(http.StatusOK, result)
 | 
						||
}
 | 
						||
 | 
						||
// DeleteConversation handles deleting a conversation
 | 
						||
func (h *LLMHandler) DeleteConversation(c *gin.Context) {
 | 
						||
	conversationID := c.Param("conversation_id")
 | 
						||
	if conversationID == "" {
 | 
						||
		c.JSON(http.StatusBadRequest, gin.H{"error": "Conversation ID is required"})
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	user := c.DefaultQuery("user", "default_user")
 | 
						||
	result, err := h.llmService.DeleteConversation(conversationID, user)
 | 
						||
	if err != nil {
 | 
						||
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	c.JSON(http.StatusOK, result)
 | 
						||
}
 | 
						||
 | 
						||
// SynthesizeSpeech handles text-to-speech requests
 | 
						||
func (h *LLMHandler) SynthesizeSpeech(c *gin.Context) {
 | 
						||
	var request struct {
 | 
						||
		Text  string `json:"text" binding:"required"`
 | 
						||
		Audio string `json:"audio" binding:"required"`
 | 
						||
	}
 | 
						||
 | 
						||
	if err := c.ShouldBindJSON(&request); err != nil {
 | 
						||
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	result, err := h.llmService.SynthesizeSpeech(request.Text, request.Audio)
 | 
						||
	if err != nil {
 | 
						||
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 | 
						||
		return
 | 
						||
	}
 | 
						||
 | 
						||
	c.JSON(http.StatusOK, result)
 | 
						||
}
 | 
						||
 | 
						||
// StreamText handles streaming text output
 | 
						||
func (h *LLMHandler) StreamText(c *gin.Context) {
 | 
						||
	// Set headers for SSE
 | 
						||
	c.Header("Content-Type", "text/event-stream")
 | 
						||
	c.Header("Cache-Control", "no-cache")
 | 
						||
	c.Header("Connection", "keep-alive")
 | 
						||
	c.Header("Transfer-Encoding", "chunked")
 | 
						||
 | 
						||
	segments := []string{
 | 
						||
		"好的,",
 | 
						||
		"我已经成功替换了文本内容。",
 | 
						||
		"新的文本是一段连续的描述,",
 | 
						||
		"没有换行,",
 | 
						||
		"总共65个字符,",
 | 
						||
		"符合100字以内的要求,",
 | 
						||
		"并且是一个连续的段落。",
 | 
						||
		"现在我需要完成任务。",
 | 
						||
	}
 | 
						||
 | 
						||
	// Create a channel to handle client disconnection
 | 
						||
	clientGone := c.Writer.CloseNotify()
 | 
						||
 | 
						||
	conversationID := "conv_" + time.Now().Format("20060102150405")
 | 
						||
	taskID := "task_" + time.Now().Format("20060102150405")
 | 
						||
 | 
						||
	// Stream the segments
 | 
						||
	for _, segment := range segments {
 | 
						||
		select {
 | 
						||
		case <-clientGone:
 | 
						||
			return
 | 
						||
		default:
 | 
						||
			// Create message object
 | 
						||
			message := map[string]interface{}{
 | 
						||
				"event":           "message",
 | 
						||
				"answer":          segment,
 | 
						||
				"conversation_id": conversationID,
 | 
						||
				"task_id":         taskID,
 | 
						||
			}
 | 
						||
 | 
						||
			// Convert to JSON and send
 | 
						||
			jsonData, _ := json.Marshal(message)
 | 
						||
			c.Writer.Write([]byte("data: " + string(jsonData) + "\n\n"))
 | 
						||
			c.Writer.Flush()
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	// Send end message
 | 
						||
	endMessage := map[string]interface{}{
 | 
						||
		"event":           "message_end",
 | 
						||
		"answer":          "",
 | 
						||
		"conversation_id": conversationID,
 | 
						||
		"task_id":         taskID,
 | 
						||
	}
 | 
						||
 | 
						||
	jsonData, _ := json.Marshal(endMessage)
 | 
						||
	c.Writer.Write([]byte("data: " + string(jsonData) + "\n\n"))
 | 
						||
	c.Writer.Flush()
 | 
						||
}
 |