You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
				
					361 lines
				
				11 KiB
			
		
		
			
		
	
	
					361 lines
				
				11 KiB
			| 
								 
											2 weeks ago
										 
									 | 
							
								// 文件: grpc/server.go
							 | 
						||
| 
								 | 
							
								package grpc
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import (
							 | 
						||
| 
								 | 
							
									"bufio"
							 | 
						||
| 
								 | 
							
									"io"
							 | 
						||
| 
								 | 
							
									"log"
							 | 
						||
| 
								 | 
							
									relaypb "memobus_relay_server/relay_server/proto"
							 | 
						||
| 
								 | 
							
									"memobus_relay_server/session"
							 | 
						||
| 
								 | 
							
									"net/http"
							 | 
						||
| 
								 | 
							
									"net/http/httptest"
							 | 
						||
| 
								 | 
							
									"strings"
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
									"google.golang.org/grpc/codes"
							 | 
						||
| 
								 | 
							
									"google.golang.org/grpc/status"
							 | 
						||
| 
								 | 
							
								)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// InternalRelayServer 实现了用于代理请求的 gRPC 服务
							 | 
						||
| 
								 | 
							
								type InternalRelayServer struct {
							 | 
						||
| 
								 | 
							
									relaypb.UnimplementedInternalRelayServer
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// NewInternalRelayServer 创建一个新的 gRPC 服务实例
							 | 
						||
| 
								 | 
							
								func NewInternalRelayServer() *InternalRelayServer {
							 | 
						||
| 
								 | 
							
									return &InternalRelayServer{}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// ProxyRequest is the core gRPC stream handler. It reads the initial header
								// message, authorizes the caller against the device session held by this
								// instance, and then hands the stream to either the WebSocket tunnel or the
								// HTTP reverse-proxy handler.
								//
								// The first message on the stream must carry a ProxyRequestHeader whose Url
								// has the form "/<prefix>/<deviceSN>[/<rest>][?query]" and whose AppUserId is
								// non-empty. Failures are reported as gRPC status errors: InvalidArgument for
								// a missing or malformed header, NotFound when the device is not connected to
								// this instance. A user who does not own the device receives a relayed 403
								// HTTP response instead of a gRPC error.
								func (s *InternalRelayServer) ProxyRequest(stream relaypb.InternalRelay_ProxyRequestServer) error {
									// --- 1. Receive and validate the request header ---
									headerMsg, err := stream.Recv()
									if err != nil {
										log.Printf("ERROR (gRPC): Failed to receive initial header: %v", err)
										return status.Errorf(codes.InvalidArgument, "failed to receive header: %v", err)
									}
									header := headerMsg.GetHeader()
									if header == nil {
										return status.Errorf(codes.InvalidArgument, "first message must be a header")
									}

									isWebSocket := isWebSocketUpgradeRequest(header.Headers)
									if isWebSocket {
										// Debug trace of the handshake headers as received from the peer.
										log.Printf("DEBUG (WebSocket): gRPC server received WebSocket upgrade request. Headers: Connection='%s', Upgrade='%s'", header.Headers["Connection"], header.Headers["Upgrade"])
									}

									deviceSN, ok := deviceSNFromURL(header.Url)
									if !ok {
										return status.Errorf(codes.InvalidArgument, "invalid URL format in gRPC header")
									}
									appUserID := header.GetAppUserId()
									if appUserID == "" {
										return status.Errorf(codes.InvalidArgument, "app_user_id is missing in gRPC header")
									}

									log.Printf("gRPC Proxy: Handling request for device '%s' from user '%s'", deviceSN, appUserID)

									// --- 2. Find the local session and check that the caller owns the device ---
									sessionInfo, ok := session.GlobalManager.GetLocalSession(deviceSN)
									if !ok {
										return status.Errorf(codes.NotFound, "device '%s' not connected to this instance", deviceSN)
									}
									if sessionInfo.UserID != appUserID {
										log.Printf("Forbidden (gRPC): User '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID)
										return sendForbiddenResponse(stream)
									}

									// --- 3. Dispatch on request type ---
									if isWebSocket {
										log.Println("gRPC Proxy: Detected WebSocket request, diverting to transparent proxy handler.")
										return s.handleWebSocketProxy(stream, sessionInfo, deviceSN, header)
									}
									log.Println("gRPC Proxy: Detected HTTP request, using ReverseProxy handler.")
									return s.handleHTTPProxy(stream, sessionInfo, deviceSN, header)
								}

								// isWebSocketUpgradeRequest reports whether headers describe a WebSocket
								// opening handshake: Upgrade is "websocket" and Connection contains the
								// "upgrade" token. Both comparisons are ASCII case-insensitive, as RFC 6455
								// section 4.2.1 requires.
								func isWebSocketUpgradeRequest(headers map[string]string) bool {
									if !strings.EqualFold(strings.TrimSpace(headers["Upgrade"]), "websocket") {
										return false
									}
									for _, token := range strings.Split(headers["Connection"], ",") {
										if strings.EqualFold(strings.TrimSpace(token), "upgrade") {
											return true
										}
									}
									return false
								}

								// deviceSNFromURL extracts the device serial number, the second path segment,
								// from a relayed request URL such as "/<prefix>/<deviceSN>/<rest>". Any query
								// string or fragment is ignored so it cannot leak into the serial number.
								// ok is false when the segment is missing or empty.
								func deviceSNFromURL(rawURL string) (deviceSN string, ok bool) {
									path := rawURL
									if i := strings.IndexAny(path, "?#"); i >= 0 {
										path = path[:i]
									}
									parts := strings.SplitN(strings.TrimPrefix(path, "/"), "/", 3)
									if len(parts) < 2 || parts[1] == "" {
										return "", false
									}
									return parts[1], true
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// handleWebSocketProxy tunnels a WebSocket connection over the gRPC stream.
								//
								// Steps:
								//  1. Open a new connection on the device session.
								//  2. Replay the client's HTTP upgrade request on it, with the
								//     "/<prefix>/<deviceSN>" path prefix removed.
								//  3. Relay the backend's handshake response back as a header message.
								//
								// A rejected handshake (any status other than 101) is relayed together with
								// its body, and the call then ends. After a successful handshake, raw bytes
								// are copied in both directions:
								//   - BodyChunk messages received on stream are written to the backend.
								//   - Bytes read from the backend are sent back as BodyChunk messages.
								//
								// Parameters:
								//   - stream: the gRPC stream; its initial header message has already been consumed.
								//   - sessionInfo: the authorized device session used to reach the backend.
								//   - deviceSN: the device serial number (used for logging).
								//   - header: the relayed request header (method, URL, HTTP headers).
								//
								// Return values:
								//   - a gRPC status error if the request cannot be built or the handshake fails;
								//   - the first unexpected copy error;
								//   - nil when the tunnel shuts down cleanly.
								func (s *InternalRelayServer) handleWebSocketProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error {
									// 1. Open a connection to the backend (a yamux stream, per the original notes).
									backendConn, err := sessionInfo.Session.Open()
									if err != nil {
										log.Printf("ERROR (WebSocket Proxy): Failed to dial backend: %v", err)
										return status.Errorf(codes.Internal, "failed to connect to backend service")
									}
									defer backendConn.Close()

									// 2. Rebuild the original HTTP upgrade request. Method and URL come from the
									// remote peer: http.NewRequest reports malformed values as an error, whereas
									// httptest.NewRequest would panic and take the process down.
									req, err := http.NewRequest(header.Method, "http://internal-proxy"+header.Url, nil)
									if err != nil {
										log.Printf("ERROR (WebSocket Proxy): Invalid upgrade request: %v", err)
										return status.Errorf(codes.InvalidArgument, "invalid upgrade request: %v", err)
									}
									for k, v := range header.Headers {
										req.Header.Set(k, v)
									}
									req.Host = "immich-internal" // intended to mimic what the ReverseProxy path sends
									// Strip the "/<prefix>/<deviceSN>" routing prefix; the backend sees the rest.
									pathParts := strings.SplitN(strings.TrimPrefix(req.URL.Path, "/"), "/", 3)
									if len(pathParts) > 2 {
										req.URL.Path = "/" + pathParts[2]
									} else {
										req.URL.Path = "/"
									}

									// 3. Write the upgrade request to the backend to start the handshake.
									if err := req.Write(backendConn); err != nil {
										log.Printf("ERROR (WebSocket Proxy): Failed to write upgrade request to backend: %v", err)
										return status.Errorf(codes.Internal, "failed to send upgrade request to backend")
									}

									// 4. Read the handshake response. backendReader may buffer bytes that arrive
									// right after the response, so every later backend read must go through it.
									backendReader := bufio.NewReader(backendConn)
									resp, err := http.ReadResponse(backendReader, req)
									if err != nil {
										log.Printf("ERROR (WebSocket Proxy): Failed to read handshake response from backend: %v", err)
										return status.Errorf(codes.Internal, "failed to read handshake response from backend")
									}

									// 5. Relay the handshake status and headers to the peer over gRPC.
									respHeaderMsg := &relaypb.ProxyResponseMessage{
										Payload: &relaypb.ProxyResponseMessage_Header{
											Header: &relaypb.ProxyResponseHeader{
												StatusCode: int32(resp.StatusCode),
												Headers:    make(map[string]string, len(resp.Header)),
											},
										},
									}
									for k, v := range resp.Header {
										respHeaderMsg.GetHeader().Headers[k] = strings.Join(v, ",")
									}
									if err := stream.Send(respHeaderMsg); err != nil {
										log.Printf("ERROR (WebSocket Proxy): Failed to send handshake response via gRPC: %v", err)
										return err
									}

									// 6. A rejected handshake (not 101) is an ordinary HTTP response. Relay its
									// body as well so the client can see why it was rejected, then finish.
									// resp.Body reads from backendReader; the deferred backendConn.Close
									// releases the underlying connection.
									if resp.StatusCode != http.StatusSwitchingProtocols {
										log.Printf("WARN (WebSocket Proxy): Backend returned non-101 status for upgrade: %d", resp.StatusCode)
										if _, err := io.Copy(&grpcResponseWriter{stream: stream}, resp.Body); err != nil {
											log.Printf("WARN (WebSocket Proxy): Failed to relay rejected handshake body: %v", err)
										}
										return nil
									}

									log.Printf("WebSocket handshake for device %s successful. Starting bi-directional stream copy.", deviceSN)

									// 7. Copy data in both directions until one side finishes. Each channel is
									// buffered, so its goroutine can always deliver its result and exit, even
									// after this function has returned.
									clientDone := make(chan error, 1)  // stream.Recv -> backendConn
									backendDone := make(chan error, 1) // backendReader -> stream.Send

									go func() {
										for {
											msg, err := stream.Recv()
											if err != nil {
												// Close the backend on every exit (not only io.EOF) so the backend
												// learns that the client side is gone.
												backendConn.Close()
												if err == io.EOF {
													err = nil
												}
												clientDone <- err
												return
											}
											if chunk := msg.GetBodyChunk(); chunk != nil {
												if _, err := backendConn.Write(chunk.Data); err != nil {
													backendConn.Close()
													clientDone <- err
													return
												}
											}
										}
									}()

									go func() {
										// Read through backendReader, not backendConn, so bytes buffered by
										// http.ReadResponse are not lost.
										_, err := io.Copy(&grpcResponseWriter{stream: stream}, backendReader)
										if err != nil && (err == io.EOF || err == io.ErrClosedPipe || strings.Contains(err.Error(), "use of closed")) {
											err = nil // normal connection shutdown
										}
										backendDone <- err
									}()

									select {
									case err := <-backendDone:
										// The backend side ended first. Return now so gRPC completes the RPC.
										// The receive goroutine's pending stream.Recv then fails and it exits.
										// Waiting for it here would block until the client happened to close
										// its send side.
										if err != nil {
											log.Printf("WebSocket stream finished with error: %v", err)
											return err
										}
									case err := <-clientDone:
										// The client side ended first and backendConn has been closed. Wait
										// for the backend->client copy so no stream.Send outlives this handler.
										// NOTE(review): if the backend ignores the close, this wait lasts until
										// it closes its side; consider a read deadline.
										backendErr := <-backendDone
										if err != nil {
											log.Printf("WebSocket stream finished with error: %v", err)
											return err
										}
										if backendErr != nil {
											log.Printf("WebSocket stream finished with error: %v", backendErr)
											return backendErr
										}
									}

									log.Printf("WebSocket stream for device %s finished gracefully.", deviceSN)
									return nil
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// handleHTTPProxy serves a plain (non-WebSocket) HTTP request relayed over the
								// gRPC stream by running it through the device's reverse proxy. Both
								// directions are streamed:
								//   - Request body: BodyChunk messages from stream.Recv are written into
								//     reqPw; reqPr is the body the reverse proxy reads.
								//   - Response: the reverse proxy writes into a streamResponseWriter. The
								//     committed status and headers are sent as one header message. Body bytes
								//     flow respPw -> respPr and are sent as BodyChunk messages.
								//
								// Return values:
								//   - gRPC InvalidArgument when the relayed method/URL cannot be parsed;
								//   - the first error hit while sending the response;
								//   - nil otherwise.
								func (s *InternalRelayServer) handleHTTPProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error {
									// --- 1. Build the request (body streamed through a pipe) ---
									reqPr, reqPw := io.Pipe()
									req, err := newRelayedHTTPRequest(header.Method, "http://internal-proxy"+header.Url, reqPr)
									if err != nil {
										log.Printf("ERROR (gRPC): Rejecting malformed relayed request: %v", err)
										reqPr.Close()
										return err
									}
									for k, v := range header.Headers {
										req.Header.Set(k, v)
									}
									req.Header.Set("X-Forwarded-For", header.RemoteAddr)

									respPr, respPw := io.Pipe()
									customResponseWriter := &streamResponseWriter{
										header:        make(http.Header),
										pipeWriter:    respPw,
										headerWritten: make(chan struct{}),
									}

									// --- 2. Pump the request body from gRPC into the request pipe ---
									go func() {
										defer reqPw.Close()
										for {
											bodyMsg, err := stream.Recv()
											if err == io.EOF {
												return
											}
											if err != nil {
												reqPw.CloseWithError(err)
												return
											}
											if bodyChunk := bodyMsg.GetBodyChunk(); bodyChunk != nil {
												if _, err := reqPw.Write(bodyChunk.Data); err != nil {
													return // the reader side was closed; stop pumping
												}
											}
										}
									}()

									// --- 3. Relay the response header and body back over gRPC ---
									errChan := make(chan error, 1)
									go func() {
										defer close(errChan)
										// Closing the read side on every exit unblocks the reverse proxy if it
										// is still writing the body after a Send failure. Without this,
										// ServeHTTP (and so this handler) would block on the pipe forever.
										defer respPr.Close()

										<-customResponseWriter.headerWritten
										respHeaderMsg := &relaypb.ProxyResponseMessage{
											Payload: &relaypb.ProxyResponseMessage_Header{
												Header: &relaypb.ProxyResponseHeader{
													StatusCode: int32(customResponseWriter.statusCode),
													Headers:    make(map[string]string, len(customResponseWriter.header)),
												},
											},
										}
										for k, v := range customResponseWriter.header {
											respHeaderMsg.GetHeader().Headers[k] = strings.Join(v, ",")
										}
										if err := stream.Send(respHeaderMsg); err != nil {
											errChan <- err
											return
										}

										buf := make([]byte, 32*1024)
										if _, err := io.CopyBuffer(&grpcResponseWriter{stream: stream}, respPr, buf); err != nil {
											if err != io.ErrClosedPipe {
												errChan <- err
											}
										}
									}()

									// --- 4. Run the proxy ---
									proxy := session.CreateReverseProxy(sessionInfo, deviceSN, req.URL.Path, req.URL.RawQuery)
									proxy.ServeHTTP(customResponseWriter, req)

									// --- 5. Clean up ---
									// As in net/http, a handler that wrote nothing yields an empty 200. This
									// also guarantees the relay goroutine never waits forever for a header.
									// It is a no-op if a final status was already written.
									customResponseWriter.WriteHeader(http.StatusOK)
									// Unblock the request-body pump if the proxy did not consume the body.
									reqPr.Close()
									respPw.Close()
									return <-errChan
								}

								// newRelayedHTTPRequest builds a server-side *http.Request with
								// httptest.NewRequest. That function panics on a malformed method or target,
								// and both values come from the remote peer. The panic is therefore converted
								// into an InvalidArgument status error. An unrecovered panic in a handler
								// goroutine would otherwise terminate the process unless a recovery
								// interceptor is installed.
								func newRelayedHTTPRequest(method, target string, body io.Reader) (req *http.Request, err error) {
									defer func() {
										if r := recover(); r != nil {
											req, err = nil, status.Errorf(codes.InvalidArgument, "invalid relayed request: %v", r)
										}
									}()
									return httptest.NewRequest(method, target, body), nil
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// sendForbiddenResponse relays a synthetic "403 Forbidden" HTTP response over
								// stream. It sends one header message (status 403, text/plain content type),
								// followed by one body chunk containing "Forbidden".
								//
								// It returns the error from whichever Send fails first, or nil once both
								// messages have been sent. When the handler returns nil, the RPC ends normally.
								func sendForbiddenResponse(stream relaypb.InternalRelay_ProxyRequestServer) error {
									respHeader := &relaypb.ProxyResponseMessage{
										Payload: &relaypb.ProxyResponseMessage_Header{
											Header: &relaypb.ProxyResponseHeader{
												StatusCode: http.StatusForbidden,
												Headers:    map[string]string{"Content-Type": "text/plain; charset=utf-8"},
											},
										},
									}
									if err := stream.Send(respHeader); err != nil {
										return err
									}
									respBody := &relaypb.ProxyResponseMessage{
										Payload: &relaypb.ProxyResponseMessage_BodyChunk{
											BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: []byte("Forbidden")},
										},
									}
									// Report a failed body send instead of silently discarding it.
									return stream.Send(respBody)
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// streamResponseWriter 是一个自定义的 http.ResponseWriter
							 | 
						||
| 
								 | 
							
								type streamResponseWriter struct {
							 | 
						||
| 
								 | 
							
									header        http.Header
							 | 
						||
| 
								 | 
							
									pipeWriter    *io.PipeWriter
							 | 
						||
| 
								 | 
							
									statusCode    int
							 | 
						||
| 
								 | 
							
									headerWritten chan struct{}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func (w *streamResponseWriter) Header() http.Header {
							 | 
						||
| 
								 | 
							
									return w.header
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func (w *streamResponseWriter) Write(b []byte) (int, error) {
							 | 
						||
| 
								 | 
							
									w.WriteHeader(http.StatusOK)
							 | 
						||
| 
								 | 
							
									return w.pipeWriter.Write(b)
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func (w *streamResponseWriter) WriteHeader(statusCode int) {
							 | 
						||
| 
								 | 
							
									select {
							 | 
						||
| 
								 | 
							
									case <-w.headerWritten:
							 | 
						||
| 
								 | 
							
										return
							 | 
						||
| 
								 | 
							
									default:
							 | 
						||
| 
								 | 
							
										w.statusCode = statusCode
							 | 
						||
| 
								 | 
							
										close(w.headerWritten)
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// grpcResponseWriter 是一个适配器,实现了 io.Writer 接口
							 | 
						||
| 
								 | 
							
								type grpcResponseWriter struct {
							 | 
						||
| 
								 | 
							
									stream relaypb.InternalRelay_ProxyRequestServer
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								func (w *grpcResponseWriter) Write(p []byte) (n int, err error) {
							 | 
						||
| 
								 | 
							
									err = w.stream.Send(&relaypb.ProxyResponseMessage{
							 | 
						||
| 
								 | 
							
										Payload: &relaypb.ProxyResponseMessage_BodyChunk{
							 | 
						||
| 
								 | 
							
											BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: p},
							 | 
						||
| 
								 | 
							
										},
							 | 
						||
| 
								 | 
							
									})
							 | 
						||
| 
								 | 
							
									if err != nil {
							 | 
						||
| 
								 | 
							
										return 0, err
							 | 
						||
| 
								 | 
							
									}
							 | 
						||
| 
								 | 
							
									return len(p), nil
							 | 
						||
| 
								 | 
							
								}
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								// writeChunk 辅助函数 - 确保这个函数也存在于你的 grpc/server.go 文件中
							 | 
						||
| 
								 | 
							
								func writeChunk(stream relaypb.InternalRelay_ProxyRequestServer, data []byte) error {
							 | 
						||
| 
								 | 
							
									return stream.Send(&relaypb.ProxyResponseMessage{
							 | 
						||
| 
								 | 
							
										Payload: &relaypb.ProxyResponseMessage_BodyChunk{
							 | 
						||
| 
								 | 
							
											BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: data},
							 | 
						||
| 
								 | 
							
										},
							 | 
						||
| 
								 | 
							
									})
							 | 
						||
| 
								 | 
							
								}
							 |