// 文件: grpc/server.go package grpc import ( "bufio" "io" "log" relaypb "memobus_relay_server/relay_server/proto" "memobus_relay_server/session" "net/http" "net/http/httptest" "strings" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) // InternalRelayServer 实现了用于代理请求的 gRPC 服务 type InternalRelayServer struct { relaypb.UnimplementedInternalRelayServer } // NewInternalRelayServer 创建一个新的 gRPC 服务实例 func NewInternalRelayServer() *InternalRelayServer { return &InternalRelayServer{} } // ProxyRequest 是核心的 gRPC 流处理器,实现了完整的请求代理和流式响应 func (s *InternalRelayServer) ProxyRequest(stream relaypb.InternalRelay_ProxyRequestServer) error { // --- 1. 接收和解析请求头 --- headerMsg, err := stream.Recv() if err != nil { log.Printf("ERROR (gRPC): Failed to receive initial header: %v", err) return status.Errorf(codes.InvalidArgument, "failed to receive header: %v", err) } header := headerMsg.GetHeader() if header == nil { return status.Errorf(codes.InvalidArgument, "first message must be a header") } // [新增调试日志] if header.Headers["Upgrade"] == "websocket" { log.Printf("DEBUG (WebSocket): gRPC server received WebSocket upgrade request. Headers: Connection='%s', Upgrade='%s'", header.Headers["Connection"], header.Headers["Upgrade"]) } // 检查是否是 WebSocket 握手请求 isWebSocket := header.Headers["Upgrade"] == "websocket" && strings.Contains(strings.ToLower(header.Headers["Connection"]), "upgrade") pathParts := strings.SplitN(strings.TrimPrefix(header.Url, "/"), "/", 3) if len(pathParts) < 2 { return status.Errorf(codes.InvalidArgument, "invalid URL format in gRPC header") } deviceSN := pathParts[1] appUserID := header.GetAppUserId() if appUserID == "" { return status.Errorf(codes.InvalidArgument, "app_user_id is missing in gRPC header") } log.Printf("gRPC Proxy: Handling request for device '%s' from user '%s'", deviceSN, appUserID) // --- 2. 查找本地会话并进行授权检查 --- sessionInfo, ok := session.GlobalManager.GetLocalSession(deviceSN) if !ok { return status.Errorf(codes.NotFound, "device '%s' not connected to this instance", deviceSN) } if sessionInfo.UserID != appUserID { log.Printf("Forbidden (gRPC): User '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID) return sendForbiddenResponse(stream) } log.Printf("gRPC Proxy: Handling request for device '%s' from user '%s'", deviceSN, appUserID) // --- 3. [核心修改] 根据请求类型进行分流 --- if isWebSocket { log.Println("gRPC Proxy: Detected WebSocket request, diverting to transparent proxy handler.") return s.handleWebSocketProxy(stream, sessionInfo, deviceSN, header) } else { // 如果是普通 HTTP, 调用原来的 ReverseProxy 处理器 log.Println("gRPC Proxy: Detected HTTP request, using ReverseProxy handler.") // 注意:我把原来的 ProxyRequest 逻辑提取到了一个新函数中,以保持整洁 return s.handleHTTPProxy(stream, sessionInfo, deviceSN, header) } } func (s *InternalRelayServer) handleWebSocketProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error { // 1. 打开到后端 (yamux) 的连接 backendConn, err := sessionInfo.Session.Open() if err != nil { log.Printf("ERROR (WebSocket Proxy): Failed to dial backend: %v", err) return status.Errorf(codes.Internal, "failed to connect to backend service") } defer backendConn.Close() // 2. 重建原始的 HTTP 升级请求 req := httptest.NewRequest(header.Method, "http://internal-proxy"+header.Url, nil) for k, v := range header.Headers { req.Header.Set(k, v) } req.Host = "immich-internal" // 模拟 ReverseProxy 的行为 pathParts := strings.SplitN(strings.TrimPrefix(req.URL.Path, "/"), "/", 3) if len(pathParts) > 2 { req.URL.Path = "/" + pathParts[2] } else { req.URL.Path = "/" } // 3. 将升级请求写入后端连接,发起握手 if err := req.Write(backendConn); err != nil { log.Printf("ERROR (WebSocket Proxy): Failed to write upgrade request to backend: %v", err) return status.Errorf(codes.Internal, "failed to send upgrade request to backend") } // 4. 读取后端的响应 (握手结果) backendReader := bufio.NewReader(backendConn) resp, err := http.ReadResponse(backendReader, req) if err != nil { log.Printf("ERROR (WebSocket Proxy): Failed to read handshake response from backend: %v", err) return status.Errorf(codes.Internal, "failed to read handshake response from backend") } // 5. 将后端的握手响应通过 gRPC 发回给代理节点 respHeaderMsg := &relaypb.ProxyResponseMessage{ Payload: &relaypb.ProxyResponseMessage_Header{ Header: &relaypb.ProxyResponseHeader{ StatusCode: int32(resp.StatusCode), Headers: make(map[string]string), }, }, } for k, v := range resp.Header { respHeaderMsg.GetHeader().Headers[k] = strings.Join(v, ",") } if err := stream.Send(respHeaderMsg); err != nil { log.Printf("ERROR (WebSocket Proxy): Failed to send handshake response via gRPC: %v", err) return err } // 6. 如果握手失败 (不是 101),则流程结束 if resp.StatusCode != http.StatusSwitchingProtocols { log.Printf("WARN (WebSocket Proxy): Backend returned non-101 status for upgrade: %d", resp.StatusCode) return nil } log.Printf("WebSocket handshake for device %s successful. Starting bi-directional stream copy.", deviceSN) // 7. 握手成功!现在在 gRPC 流和 yamux 流之间建立双向数据拷贝 errChan := make(chan error, 2) // Goroutine 1: gRPC 请求流 (来自 App) -> yamux 流 (下行数据) go func() { // 这个方向的逻辑没有问题 for { msg, err := stream.Recv() if err == io.EOF { backendConn.Close() errChan <- nil return } if err != nil { errChan <- err return } if chunk := msg.GetBodyChunk(); chunk != nil { if _, err := backendConn.Write(chunk.Data); err != nil { errChan <- err return } } } }() // Goroutine 2: yamux 流 (来自设备) -> gRPC 响应流 (上行数据) go func() { // [核心修正] // 我们必须从 backendReader (而不是原始的 backendConn) 开始读取, // 以确保 http.ReadResponse 预读到缓冲区的数据不会丢失。 // io.Copy 会首先清空 backendReader 的内部缓冲区,然后再继续从底层的 backendConn 读取。 if _, err := io.Copy(&grpcResponseWriter{stream: stream}, backendReader); err != nil { // 过滤掉正常的连接关闭错误 if err != io.EOF && err != io.ErrClosedPipe && !strings.Contains(err.Error(), "use of closed") { errChan <- err } else { errChan <- nil } } else { errChan <- nil } }() // 等待两个 goroutine 都结束 err1 := <-errChan err2 := <-errChan if err1 != nil && err1 != io.EOF { log.Printf("WebSocket stream finished with error: %v", err1) return err1 } if err2 != nil && err2 != io.EOF { log.Printf("WebSocket stream finished with error: %v", err2) return err2 } log.Printf("WebSocket stream for device %s finished gracefully.", deviceSN) return nil } // [新增] handleHTTPProxy 包含了原来 ProxyRequest 的所有逻辑 func (s *InternalRelayServer) handleHTTPProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error { // 这部分代码就是你之前工作正常的、使用 io.Pipe 和 ReverseProxy 的完整流式版本 // 我直接粘贴过来 // --- 3. 创建请求和响应的管道 --- reqPr, reqPw := io.Pipe() req := httptest.NewRequest(header.Method, "http://internal-proxy"+header.Url, reqPr) for k, v := range header.Headers { req.Header.Set(k, v) } req.Header.Set("X-Forwarded-For", header.RemoteAddr) respPr, respPw := io.Pipe() customResponseWriter := &streamResponseWriter{ header: make(http.Header), pipeWriter: respPw, headerWritten: make(chan struct{}), } // --- 4. 启动 Goroutines --- go func() { defer reqPw.Close() for { bodyMsg, err := stream.Recv() if err == io.EOF { return } if err != nil { reqPw.CloseWithError(err) return } if bodyChunk := bodyMsg.GetBodyChunk(); bodyChunk != nil { if _, err := reqPw.Write(bodyChunk.Data); err != nil { return } } } }() errChan := make(chan error, 1) go func() { defer close(errChan) <-customResponseWriter.headerWritten // b. [修正] 完整地构造 gRPC 响应头 respHeaderMsg := &relaypb.ProxyResponseMessage{ Payload: &relaypb.ProxyResponseMessage_Header{ Header: &relaypb.ProxyResponseHeader{ StatusCode: int32(customResponseWriter.statusCode), Headers: make(map[string]string), }, }, } for k, v := range customResponseWriter.header { respHeaderMsg.GetHeader().Headers[k] = strings.Join(v, ",") } if err := stream.Send(respHeaderMsg); err != nil { errChan <- err return } buf := make([]byte, 1024*32) if _, err := io.CopyBuffer(&grpcResponseWriter{stream: stream}, respPr, buf); err != nil { if err != io.ErrClosedPipe { errChan <- err } } }() // --- 5. 执行代理 --- proxy := session.CreateReverseProxy(sessionInfo, deviceSN, req.URL.Path, req.URL.RawQuery) proxy.ServeHTTP(customResponseWriter, req) // --- 6. 清理 --- respPw.Close() return <-errChan } // sendForbiddenResponse 是一个辅助函数,用于发送模拟的 403 响应 func sendForbiddenResponse(stream relaypb.InternalRelay_ProxyRequestServer) error { respHeader := &relaypb.ProxyResponseMessage{ Payload: &relaypb.ProxyResponseMessage_Header{ Header: &relaypb.ProxyResponseHeader{ StatusCode: http.StatusForbidden, Headers: map[string]string{"Content-Type": "text/plain; charset=utf-8"}, }, }, } if err := stream.Send(respHeader); err != nil { return err } respBody := &relaypb.ProxyResponseMessage{ Payload: &relaypb.ProxyResponseMessage_BodyChunk{ BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: []byte("Forbidden")}, }, } stream.Send(respBody) return nil // 正常关闭流 } // streamResponseWriter 是一个自定义的 http.ResponseWriter type streamResponseWriter struct { header http.Header pipeWriter *io.PipeWriter statusCode int headerWritten chan struct{} } func (w *streamResponseWriter) Header() http.Header { return w.header } func (w *streamResponseWriter) Write(b []byte) (int, error) { w.WriteHeader(http.StatusOK) return w.pipeWriter.Write(b) } func (w *streamResponseWriter) WriteHeader(statusCode int) { select { case <-w.headerWritten: return default: w.statusCode = statusCode close(w.headerWritten) } } // grpcResponseWriter 是一个适配器,实现了 io.Writer 接口 type grpcResponseWriter struct { stream relaypb.InternalRelay_ProxyRequestServer } func (w *grpcResponseWriter) Write(p []byte) (n int, err error) { err = w.stream.Send(&relaypb.ProxyResponseMessage{ Payload: &relaypb.ProxyResponseMessage_BodyChunk{ BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: p}, }, }) if err != nil { return 0, err } return len(p), nil } // writeChunk 辅助函数 - 确保这个函数也存在于你的 grpc/server.go 文件中 func writeChunk(stream relaypb.InternalRelay_ProxyRequestServer, data []byte) error { return stream.Send(&relaypb.ProxyResponseMessage{ Payload: &relaypb.ProxyResponseMessage_BodyChunk{ BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: data}, }, }) }