新中转服务
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

360 lines
11 KiB

// 文件: grpc/server.go
package grpc
import (
"bufio"
"io"
"log"
relaypb "memobus_relay_server/relay_server/proto"
"memobus_relay_server/session"
"net/http"
"net/http/httptest"
"strings"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// InternalRelayServer implements the InternalRelay gRPC service that proxies
// requests to devices connected to this relay instance. It carries no state of
// its own beyond the embedded generated base type.
type InternalRelayServer struct {
	relaypb.UnimplementedInternalRelayServer
}

// NewInternalRelayServer returns a new, zero-valued InternalRelayServer ready
// to be registered with a gRPC server.
func NewInternalRelayServer() *InternalRelayServer {
	srv := new(InternalRelayServer)
	return srv
}
// ProxyRequest is the core streaming handler of the InternalRelay service.
//
// The first message on the stream must be a request header. Its URL has the
// form "/{prefix}/{deviceSN}[/{rest}]" and it must carry the caller's
// app_user_id. The handler:
//   - looks up the device's session on this instance;
//   - checks that the caller is the session's user;
//   - dispatches to the transparent WebSocket proxy for upgrade requests, or
//     to the streaming HTTP reverse proxy otherwise.
//
// Errors:
//   - InvalidArgument for a missing or malformed header, URL or user ID;
//   - NotFound when the device is not connected to this instance.
//
// An ownership mismatch is answered in-band with a synthetic HTTP 403
// response (see sendForbiddenResponse) rather than a gRPC error.
func (s *InternalRelayServer) ProxyRequest(stream relaypb.InternalRelay_ProxyRequestServer) error {
	// --- 1. Receive and validate the request header ---
	headerMsg, err := stream.Recv()
	if err != nil {
		log.Printf("ERROR (gRPC): Failed to receive initial header: %v", err)
		return status.Errorf(codes.InvalidArgument, "failed to receive header: %v", err)
	}
	header := headerMsg.GetHeader()
	if header == nil {
		return status.Errorf(codes.InvalidArgument, "first message must be a header")
	}

	// The values of Upgrade/Connection are case-insensitive tokens
	// (RFC 7230 §6.1/§6.7, RFC 6455 §4.2.1), so "WebSocket" must match as well
	// as "websocket".
	upgrade := header.Headers["Upgrade"]
	connection := header.Headers["Connection"]
	wantsWebSocket := strings.EqualFold(upgrade, "websocket")
	if wantsWebSocket {
		log.Printf("DEBUG (WebSocket): gRPC server received WebSocket upgrade request. Headers: Connection='%s', Upgrade='%s'", connection, upgrade)
	}
	isWebSocket := wantsWebSocket && strings.Contains(strings.ToLower(connection), "upgrade")

	// The URL is "/{prefix}/{deviceSN}/..."; the device serial is the second
	// segment and must not be empty.
	pathParts := strings.SplitN(strings.TrimPrefix(header.Url, "/"), "/", 3)
	if len(pathParts) < 2 || pathParts[1] == "" {
		return status.Errorf(codes.InvalidArgument, "invalid URL format in gRPC header")
	}
	deviceSN := pathParts[1]

	appUserID := header.GetAppUserId()
	if appUserID == "" {
		return status.Errorf(codes.InvalidArgument, "app_user_id is missing in gRPC header")
	}
	log.Printf("gRPC Proxy: Handling request for device '%s' from user '%s'", deviceSN, appUserID)

	// --- 2. Find the local session and authorize the caller ---
	sessionInfo, ok := session.GlobalManager.GetLocalSession(deviceSN)
	if !ok {
		return status.Errorf(codes.NotFound, "device '%s' not connected to this instance", deviceSN)
	}
	if sessionInfo.UserID != appUserID {
		log.Printf("Forbidden (gRPC): User '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID)
		return sendForbiddenResponse(stream)
	}

	// --- 3. Dispatch by request type ---
	if isWebSocket {
		log.Println("gRPC Proxy: Detected WebSocket request, diverting to transparent proxy handler.")
		return s.handleWebSocketProxy(stream, sessionInfo, deviceSN, header)
	}
	log.Println("gRPC Proxy: Detected HTTP request, using ReverseProxy handler.")
	return s.handleHTTPProxy(stream, sessionInfo, deviceSN, header)
}
// handleWebSocketProxy transparently proxies a WebSocket upgrade request to the
// device over a freshly opened session stream.
//
// It replays the original upgrade request to the backend:
//   - the "/{prefix}/{deviceSN}" path prefix is stripped;
//   - Host is set to "immich-internal".
//
// It then forwards the backend's handshake status and headers over gRPC, and:
//   - for a non-101 response, forwards the response body and returns;
//   - for 101 Switching Protocols, copies raw bytes in both directions
//     (gRPC body chunks <-> backend stream) until either side finishes.
//
// Setup failures are returned as gRPC status errors. Otherwise the result is
// the first non-benign error seen by the byte pumps, or nil on a clean close.
func (s *InternalRelayServer) handleWebSocketProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error {
	// 1. Open a stream to the backend over the device's (yamux) session.
	backendConn, err := sessionInfo.Session.Open()
	if err != nil {
		log.Printf("ERROR (WebSocket Proxy): Failed to dial backend: %v", err)
		return status.Errorf(codes.Internal, "failed to connect to backend service")
	}
	defer backendConn.Close()

	// 2. Rebuild the original HTTP upgrade request.
	req := httptest.NewRequest(header.Method, "http://internal-proxy"+header.Url, nil)
	for k, v := range header.Headers {
		req.Header.Set(k, v)
	}
	req.Host = "immich-internal" // mimic the ReverseProxy behaviour
	// Drop the "/{prefix}/{deviceSN}" segments; the backend only sees the rest.
	pathParts := strings.SplitN(strings.TrimPrefix(req.URL.Path, "/"), "/", 3)
	if len(pathParts) > 2 {
		req.URL.Path = "/" + pathParts[2]
	} else {
		req.URL.Path = "/"
	}

	// 3. Write the upgrade request to the backend to start the handshake.
	if err := req.Write(backendConn); err != nil {
		log.Printf("ERROR (WebSocket Proxy): Failed to write upgrade request to backend: %v", err)
		return status.Errorf(codes.Internal, "failed to send upgrade request to backend")
	}

	// 4. Read the handshake response. All later backend reads must go through
	// backendReader: it may already hold bytes sent right after the handshake.
	backendReader := bufio.NewReader(backendConn)
	resp, err := http.ReadResponse(backendReader, req)
	if err != nil {
		log.Printf("ERROR (WebSocket Proxy): Failed to read handshake response from backend: %v", err)
		return status.Errorf(codes.Internal, "failed to read handshake response from backend")
	}
	defer resp.Body.Close()

	// 5. Relay the handshake status and headers to the proxy node.
	respHeader := &relaypb.ProxyResponseHeader{
		StatusCode: int32(resp.StatusCode),
		Headers:    make(map[string]string, len(resp.Header)),
	}
	for k, v := range resp.Header {
		respHeader.Headers[k] = strings.Join(v, ",")
	}
	if err := stream.Send(&relaypb.ProxyResponseMessage{
		Payload: &relaypb.ProxyResponseMessage_Header{Header: respHeader},
	}); err != nil {
		log.Printf("ERROR (WebSocket Proxy): Failed to send handshake response via gRPC: %v", err)
		return err
	}

	// 6. Handshake refused: forward the response body (e.g. an error page) so the
	// app receives a complete HTTP response, then finish.
	if resp.StatusCode != http.StatusSwitchingProtocols {
		log.Printf("WARN (WebSocket Proxy): Backend returned non-101 status for upgrade: %d", resp.StatusCode)
		if _, err := io.Copy(&grpcResponseWriter{stream: stream}, resp.Body); err != nil {
			log.Printf("ERROR (WebSocket Proxy): Failed to forward non-101 response body: %v", err)
			return err
		}
		return nil
	}
	log.Printf("WebSocket handshake for device %s successful. Starting bi-directional stream copy.", deviceSN)

	// 7. Pump bytes in both directions. Each direction reports on its own
	// buffered channel, so neither goroutine can block on reporting, even after
	// this handler has returned.
	clientDone := make(chan error, 1)  // gRPC request stream (app) -> backend
	backendDone := make(chan error, 1) // backend -> gRPC response stream (app)

	go func() {
		for {
			msg, err := stream.Recv()
			if err == io.EOF {
				clientDone <- nil
				return
			}
			if err != nil {
				clientDone <- err
				return
			}
			if chunk := msg.GetBodyChunk(); chunk != nil {
				if _, err := backendConn.Write(chunk.Data); err != nil {
					clientDone <- err
					return
				}
			}
		}
	}()

	go func() {
		// Read via backendReader so bytes buffered by http.ReadResponse are not lost.
		_, err := io.Copy(&grpcResponseWriter{stream: stream}, backendReader)
		// Treat normal connection-close errors as a clean finish.
		if err != nil && err != io.EOF && err != io.ErrClosedPipe && !strings.Contains(err.Error(), "use of closed") {
			backendDone <- err
			return
		}
		backendDone <- nil
	}()

	var streamErr error
	select {
	case streamErr = <-clientDone:
		// The app side finished. Closing the backend stream lets the copier wind
		// down; wait for it so stream.Send is never called after this handler
		// has returned.
		backendConn.Close()
		backendErr := <-backendDone
		if streamErr == nil {
			streamErr = backendErr
		}
	case streamErr = <-backendDone:
		// The backend finished. Return now so the stream ends and the app sees
		// the close. The receiver goroutine may still be blocked in stream.Recv:
		// that call returns once the stream is done. Its only other effect is
		// writing to backendConn, which the deferred Close shuts down.
	}

	if streamErr != nil {
		log.Printf("WebSocket stream finished with error: %v", streamErr)
		return streamErr
	}
	log.Printf("WebSocket stream for device %s finished gracefully.", deviceSN)
	return nil
}
// handleHTTPProxy streams a plain HTTP request through the reverse proxy
// returned by session.CreateReverseProxy.
//
// Data flow:
//   - request body: gRPC body chunks -> reqPw | reqPr -> outgoing request;
//   - response: ReverseProxy -> streamResponseWriter -> respPw | respPr ->
//     one gRPC header message followed by gRPC body chunks.
//
// It returns after the proxy has finished and the response has been
// forwarded. The result is the first error from sending the response over
// gRPC, or nil.
func (s *InternalRelayServer) handleHTTPProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error {
	// --- 1. Build the request and the request/response pipes ---
	reqPr, reqPw := io.Pipe()
	req := httptest.NewRequest(header.Method, "http://internal-proxy"+header.Url, reqPr)
	for k, v := range header.Headers {
		req.Header.Set(k, v)
	}
	// NOTE(review): httptest.NewRequest sets RemoteAddr to "192.0.2.1:1234". If
	// the proxy uses a Director, ReverseProxy may append that address to this
	// X-Forwarded-For value. Confirm against session.CreateReverseProxy.
	req.Header.Set("X-Forwarded-For", header.RemoteAddr)

	respPr, respPw := io.Pipe()
	customResponseWriter := &streamResponseWriter{
		header:        make(http.Header),
		pipeWriter:    respPw,
		headerWritten: make(chan struct{}),
	}

	// --- 2. Request-body pump: gRPC body chunks -> reqPw ---
	go func() {
		defer reqPw.Close()
		for {
			bodyMsg, err := stream.Recv()
			if err == io.EOF {
				return
			}
			if err != nil {
				reqPw.CloseWithError(err)
				return
			}
			if bodyChunk := bodyMsg.GetBodyChunk(); bodyChunk != nil {
				if _, err := reqPw.Write(bodyChunk.Data); err != nil {
					return // read side closed: the body is no longer wanted
				}
			}
		}
	}()

	// --- 3. Response pump: committed header + respPr -> gRPC ---
	errChan := make(chan error, 1)
	go func() {
		defer close(errChan)
		<-customResponseWriter.headerWritten

		respHeader := &relaypb.ProxyResponseHeader{
			StatusCode: int32(customResponseWriter.statusCode),
			Headers:    make(map[string]string, len(customResponseWriter.header)),
		}
		for k, v := range customResponseWriter.header {
			respHeader.Headers[k] = strings.Join(v, ",")
		}
		if err := stream.Send(&relaypb.ProxyResponseMessage{
			Payload: &relaypb.ProxyResponseMessage_Header{Header: respHeader},
		}); err != nil {
			// Nobody will read respPr any more. Fail the proxy's pending and future
			// writes instead of leaving ServeHTTP blocked on the pipe forever.
			respPr.CloseWithError(err)
			errChan <- err
			return
		}

		buf := make([]byte, 32*1024)
		if _, err := io.CopyBuffer(&grpcResponseWriter{stream: stream}, respPr, buf); err != nil {
			respPr.CloseWithError(err) // unblock ServeHTTP, as above
			if err != io.ErrClosedPipe {
				errChan <- err
			}
		}
	}()

	// --- 4. Run the proxy ---
	proxy := session.CreateReverseProxy(sessionInfo, deviceSN, req.URL.Path, req.URL.RawQuery)
	proxy.ServeHTTP(customResponseWriter, req)

	// --- 5. Clean up ---
	// A handler that wrote nothing implies 200 OK, as in net/http. Without this
	// the response pump would wait on headerWritten forever. It is a no-op when
	// a header was already committed.
	customResponseWriter.WriteHeader(http.StatusOK)
	respPw.Close()
	// Unblock the request-body pump if the proxy stopped reading the body early.
	reqPr.Close()
	return <-errChan
}
// sendForbiddenResponse answers the stream with a synthetic HTTP 403. It sends
// a header message (status 403, text/plain) followed by a single body chunk
// containing "Forbidden".
//
// It returns nil once both messages are sent, so the RPC itself ends
// normally. The error from either stream.Send is returned to the caller.
func sendForbiddenResponse(stream relaypb.InternalRelay_ProxyRequestServer) error {
	respHeader := &relaypb.ProxyResponseMessage{
		Payload: &relaypb.ProxyResponseMessage_Header{
			Header: &relaypb.ProxyResponseHeader{
				StatusCode: http.StatusForbidden,
				Headers:    map[string]string{"Content-Type": "text/plain; charset=utf-8"},
			},
		},
	}
	if err := stream.Send(respHeader); err != nil {
		return err
	}
	respBody := &relaypb.ProxyResponseMessage{
		Payload: &relaypb.ProxyResponseMessage_BodyChunk{
			BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: []byte("Forbidden")},
		},
	}
	if err := stream.Send(respBody); err != nil {
		return err
	}
	return nil
}
// streamResponseWriter is an http.ResponseWriter that streams a handler's
// response body into an io.Pipe instead of a network connection.
//
// The status code and header map are committed once: on the first final
// WriteHeader call, or implicitly (200) by the first Write. Committing closes
// headerWritten. After receiving from headerWritten, another goroutine may read
// statusCode and header; neither is written again after that point.
type streamResponseWriter struct {
	header        http.Header
	pipeWriter    *io.PipeWriter
	statusCode    int
	headerWritten chan struct{}

	// lateHeader is handed out by Header() once the response header has been
	// committed. Such changes could not be sent anyway, and routing them to a
	// separate map keeps the committed header free of concurrent writes.
	// For example, httputil.ReverseProxy adds trailer keys after the body.
	lateHeader http.Header
}

// Header returns the map to populate before the response is committed, and a
// scratch map (whose contents are never sent) afterwards.
func (w *streamResponseWriter) Header() http.Header {
	select {
	case <-w.headerWritten:
		if w.lateHeader == nil {
			w.lateHeader = make(http.Header)
		}
		return w.lateHeader
	default:
		return w.header
	}
}

// Write commits a 200 OK header if none has been committed yet, then writes b
// to the pipe. Per io.Pipe semantics it blocks until the reader consumes b or
// the read side is closed.
func (w *streamResponseWriter) Write(b []byte) (int, error) {
	w.WriteHeader(http.StatusOK)
	return w.pipeWriter.Write(b)
}

// WriteHeader commits statusCode together with the current header map. Only the
// first final status counts; later calls are ignored.
//
// Informational 1xx statuses other than 101 are ignored. This writer carries a
// single header per response, and httputil.ReverseProxy reports 1xx responses
// (e.g. "100 Continue", "103 Early Hints") through WriteHeader before the real
// status arrives. Committing them would lose the final status.
func (w *streamResponseWriter) WriteHeader(statusCode int) {
	if statusCode >= 100 && statusCode <= 199 && statusCode != http.StatusSwitchingProtocols {
		return
	}
	select {
	case <-w.headerWritten:
		return
	default:
		w.statusCode = statusCode
		close(w.headerWritten)
	}
}
// grpcResponseWriter adapts a ProxyRequest server stream to io.Writer so that
// io.Copy can stream bytes out as ProxyResponseMessage body chunks. Each Write
// call becomes exactly one chunk.
type grpcResponseWriter struct {
	stream relaypb.InternalRelay_ProxyRequestServer
}

// Write sends a copy of p as one body-chunk message. It returns len(p) on
// success, or 0 and the error from stream.Send.
//
// p is copied because io.Writer implementations must not retain p, and io.Copy
// reuses its buffer for the next read. gRPC, for its part, documents that a
// message must not be modified after SendMsg, since stats handlers and tracing
// may read it lazily.
func (w *grpcResponseWriter) Write(p []byte) (n int, err error) {
	data := make([]byte, len(p))
	copy(data, p)
	err = w.stream.Send(&relaypb.ProxyResponseMessage{
		Payload: &relaypb.ProxyResponseMessage_BodyChunk{
			BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: data},
		},
	})
	if err != nil {
		return 0, err
	}
	return len(p), nil
}
// writeChunk sends data on the response stream as a single body-chunk message
// and returns the error from stream.Send. data is passed through without
// copying, so the caller must not modify it afterwards.
func writeChunk(stream relaypb.InternalRelay_ProxyRequestServer, data []byte) error {
	chunk := &relaypb.ProxyResponseBodyChunk{Data: data}
	msg := &relaypb.ProxyResponseMessage{
		Payload: &relaypb.ProxyResponseMessage_BodyChunk{BodyChunk: chunk},
	}
	return stream.Send(msg)
}