diff --git a/config.yml b/config.yml new file mode 100644 index 0000000..5d4fb04 --- /dev/null +++ b/config.yml @@ -0,0 +1,31 @@ +# config.yml + +# 服务器相关配置 +server: + app_listen_port: ":8089" + device_listen_port: ":7002" + instance_id: "" # 留空会自动生成 UUID, 也可以指定一个固定的ID + # [新增] 用于服务器间通信的 gRPC 配置 + grpc_listen_addr: ":9090" + # 这个地址必须能被其他服务器实例访问到。 + # 在 Docker/K8s 环境中, 这应该是服务名或 Pod IP。 + grpc_advertise_addr: "192.168.5.193:9090" + +# 认证密钥配置 +auth: + app_access_secret: "D4tBb9Y0oHSXRAyHLHpdKfXAuNCyCZ45AZxKJOhMJMs=" + device_relay_secret: "p+JtJ8aHlM1lDYu7UGFanX8ALVt1pM1BQmKTpqTJccs=" + +# Redis 配置 (为下一步做准备) +# 如果 enabled 为 false,我们的代码将退回使用内存 map,实现单机兼容 +redis: + enabled: true + addr: "118.178.183.78:6379" + password: "" # 留空表示没有密码 + db: 1 + session_ttl_seconds: 120 # 会话在 Redis 中的过期时间、 + # [新增] 用于服务发现的 Key + # 一个 Redis Hash, 存储 instance_id -> grpc_addr 的映射 + instance_registry_key: "relay_instances" + # 实例必须比这个 TTL 更快地发送心跳 + instance_ttl_seconds: 15 \ No newline at end of file diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..a36931f --- /dev/null +++ b/config/config.go @@ -0,0 +1,79 @@ +package config + +import ( + "github.com/google/uuid" + "github.com/spf13/viper" + "log" + "strings" +) + +// Config 结构体必须与 config.yml 的结构完全对应 +// 使用 `mapstructure` tag 来帮助 Viper 正确映射 YAML 键名到 Go 结构体字段 +type Config struct { + Server ServerConfig `mapstructure:"server"` + Auth AuthConfig `mapstructure:"auth"` + Redis RedisConfig `mapstructure:"redis"` +} + +type ServerConfig struct { + AppListenPort string `mapstructure:"app_listen_port"` + DeviceListenPort string `mapstructure:"device_listen_port"` + + // [新增] + InstanceID string `mapstructure:"instance_id"` + GrpcListenAddr string `mapstructure:"grpc_listen_addr"` + GrpcAdvertiseAddr string `mapstructure:"grpc_advertise_addr"` +} + +type AuthConfig struct { + AppAccessSecret string `mapstructure:"app_access_secret"` + DeviceRelaySecret string `mapstructure:"device_relay_secret"` +} + +type RedisConfig struct { + Enabled bool `mapstructure:"enabled"` + Addr string `mapstructure:"addr"` + Password string `mapstructure:"password"` + DB int `mapstructure:"db"` + SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` // 确保有这个字段 + + // [新增] + InstanceRegistryKey string `mapstructure:"instance_registry_key"` + InstanceTTLSeconds int `mapstructure:"instance_ttl_seconds"` +} + +// Cfg 是一个全局变量,用于在项目的任何地方访问配置 +var Cfg *Config + +// LoadConfig 是初始化函数,负责读取和解析配置文件 +func LoadConfig() { + viper.SetConfigName("config") // 配置文件名 (不带扩展名) + viper.SetConfigType("yml") // 配置文件类型 + viper.AddConfigPath(".") // 在当前工作目录查找配置文件 + viper.AddConfigPath("./config") // 也在 config 目录查找 + + // [关键] 开启环境变量支持 + // 这允许你通过环境变量覆盖配置文件中的值 + // 例如:SERVER_APP_LISTEN_ADDR=":9000" 会覆盖文件中的设置 + viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + viper.AutomaticEnv() + + // 读取配置文件 + if err := viper.ReadInConfig(); err != nil { + // 如果配置文件没找到,也没关系,可能完全通过环境变量配置 + if _, ok := err.(viper.ConfigFileNotFoundError); !ok { + log.Fatalf("Fatal error reading config file: %v", err) + } + } + + // 将读取到的配置反序列化到 Cfg 结构体中 + if err := viper.Unmarshal(&Cfg); err != nil { + log.Fatalf("Unable to decode config into struct: %v", err) + } + + // [新增] 如果 instance_id 未配置,则自动生成 + if Cfg.Server.InstanceID == "" { + Cfg.Server.InstanceID = uuid.New().String() + } + log.Printf("Configuration loaded. Server Instance ID: %s", Cfg.Server.InstanceID) +} diff --git a/go.mod b/go.mod index 1a4587c..6a72994 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,34 @@ module memobus_relay_server -go 1.24 +go 1.24.0 + +toolchain go1.24.2 require ( github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/google/uuid v1.6.0 github.com/hashicorp/yamux v0.1.2 + github.com/redis/go-redis/v9 v9.14.1 + github.com/spf13/viper v1.21.0 + google.golang.org/grpc v1.76.0 +) + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/net v0.42.0 // indirect + golang.org/x/sys v0.34.0 // indirect + golang.org/x/text v0.28.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250804133106-a7a43d27e69b // indirect + google.golang.org/protobuf v1.36.6 // indirect ) diff --git a/grpc/server.go b/grpc/server.go new file mode 100644 index 0000000..4cb445e --- /dev/null +++ b/grpc/server.go @@ -0,0 +1,360 @@ +// 文件: grpc/server.go +package grpc + +import ( + "bufio" + "io" + "log" + relaypb "memobus_relay_server/relay_server/proto" + "memobus_relay_server/session" + "net/http" + "net/http/httptest" + "strings" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// InternalRelayServer 实现了用于代理请求的 gRPC 服务 +type InternalRelayServer struct { + relaypb.UnimplementedInternalRelayServer +} + +// NewInternalRelayServer 创建一个新的 gRPC 服务实例 +func NewInternalRelayServer() *InternalRelayServer { + return &InternalRelayServer{} +} + +// ProxyRequest 是核心的 gRPC 流处理器,实现了完整的请求代理和流式响应 +func (s *InternalRelayServer) ProxyRequest(stream relaypb.InternalRelay_ProxyRequestServer) error { + // --- 1. 接收和解析请求头 --- + headerMsg, err := stream.Recv() + if err != nil { + log.Printf("ERROR (gRPC): Failed to receive initial header: %v", err) + return status.Errorf(codes.InvalidArgument, "failed to receive header: %v", err) + } + header := headerMsg.GetHeader() + if header == nil { + return status.Errorf(codes.InvalidArgument, "first message must be a header") + } + + // [新增调试日志] + if header.Headers["Upgrade"] == "websocket" { + log.Printf("DEBUG (WebSocket): gRPC server received WebSocket upgrade request. Headers: Connection='%s', Upgrade='%s'", header.Headers["Connection"], header.Headers["Upgrade"]) + } + + // 检查是否是 WebSocket 握手请求 + isWebSocket := header.Headers["Upgrade"] == "websocket" && strings.Contains(strings.ToLower(header.Headers["Connection"]), "upgrade") + + pathParts := strings.SplitN(strings.TrimPrefix(header.Url, "/"), "/", 3) + if len(pathParts) < 2 { + return status.Errorf(codes.InvalidArgument, "invalid URL format in gRPC header") + } + deviceSN := pathParts[1] + appUserID := header.GetAppUserId() + + if appUserID == "" { + return status.Errorf(codes.InvalidArgument, "app_user_id is missing in gRPC header") + } + + log.Printf("gRPC Proxy: Handling request for device '%s' from user '%s'", deviceSN, appUserID) + + // --- 2. 查找本地会话并进行授权检查 --- + sessionInfo, ok := session.GlobalManager.GetLocalSession(deviceSN) + if !ok { + return status.Errorf(codes.NotFound, "device '%s' not connected to this instance", deviceSN) + } + + if sessionInfo.UserID != appUserID { + log.Printf("Forbidden (gRPC): User '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID) + return sendForbiddenResponse(stream) + } + + log.Printf("gRPC Proxy: Handling request for device '%s' from user '%s'", deviceSN, appUserID) + + // --- 3. [核心修改] 根据请求类型进行分流 --- + if isWebSocket { + log.Println("gRPC Proxy: Detected WebSocket request, diverting to transparent proxy handler.") + return s.handleWebSocketProxy(stream, sessionInfo, deviceSN, header) + } else { + // 如果是普通 HTTP, 调用原来的 ReverseProxy 处理器 + log.Println("gRPC Proxy: Detected HTTP request, using ReverseProxy handler.") + // 注意:我把原来的 ProxyRequest 逻辑提取到了一个新函数中,以保持整洁 + return s.handleHTTPProxy(stream, sessionInfo, deviceSN, header) + } +} + +func (s *InternalRelayServer) handleWebSocketProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error { + // 1. 打开到后端 (yamux) 的连接 + backendConn, err := sessionInfo.Session.Open() + if err != nil { + log.Printf("ERROR (WebSocket Proxy): Failed to dial backend: %v", err) + return status.Errorf(codes.Internal, "failed to connect to backend service") + } + defer backendConn.Close() + + // 2. 重建原始的 HTTP 升级请求 + req := httptest.NewRequest(header.Method, "http://internal-proxy"+header.Url, nil) + for k, v := range header.Headers { + req.Header.Set(k, v) + } + req.Host = "immich-internal" // 模拟 ReverseProxy 的行为 + pathParts := strings.SplitN(strings.TrimPrefix(req.URL.Path, "/"), "/", 3) + if len(pathParts) > 2 { + req.URL.Path = "/" + pathParts[2] + } else { + req.URL.Path = "/" + } + + // 3. 将升级请求写入后端连接,发起握手 + if err := req.Write(backendConn); err != nil { + log.Printf("ERROR (WebSocket Proxy): Failed to write upgrade request to backend: %v", err) + return status.Errorf(codes.Internal, "failed to send upgrade request to backend") + } + + // 4. 读取后端的响应 (握手结果) + backendReader := bufio.NewReader(backendConn) + resp, err := http.ReadResponse(backendReader, req) + if err != nil { + log.Printf("ERROR (WebSocket Proxy): Failed to read handshake response from backend: %v", err) + return status.Errorf(codes.Internal, "failed to read handshake response from backend") + } + + // 5. 将后端的握手响应通过 gRPC 发回给代理节点 + respHeaderMsg := &relaypb.ProxyResponseMessage{ + Payload: &relaypb.ProxyResponseMessage_Header{ + Header: &relaypb.ProxyResponseHeader{ + StatusCode: int32(resp.StatusCode), + Headers: make(map[string]string), + }, + }, + } + for k, v := range resp.Header { + respHeaderMsg.GetHeader().Headers[k] = strings.Join(v, ",") + } + if err := stream.Send(respHeaderMsg); err != nil { + log.Printf("ERROR (WebSocket Proxy): Failed to send handshake response via gRPC: %v", err) + return err + } + + // 6. 如果握手失败 (不是 101),则流程结束 + if resp.StatusCode != http.StatusSwitchingProtocols { + log.Printf("WARN (WebSocket Proxy): Backend returned non-101 status for upgrade: %d", resp.StatusCode) + return nil + } + + log.Printf("WebSocket handshake for device %s successful. Starting bi-directional stream copy.", deviceSN) + + // 7. 握手成功!现在在 gRPC 流和 yamux 流之间建立双向数据拷贝 + errChan := make(chan error, 2) + + // Goroutine 1: gRPC 请求流 (来自 App) -> yamux 流 (下行数据) + go func() { + // 这个方向的逻辑没有问题 + for { + msg, err := stream.Recv() + if err == io.EOF { + backendConn.Close() + errChan <- nil + return + } + if err != nil { + errChan <- err + return + } + if chunk := msg.GetBodyChunk(); chunk != nil { + if _, err := backendConn.Write(chunk.Data); err != nil { + errChan <- err + return + } + } + } + }() + + // Goroutine 2: yamux 流 (来自设备) -> gRPC 响应流 (上行数据) + go func() { + // [核心修正] + // 我们必须从 backendReader (而不是原始的 backendConn) 开始读取, + // 以确保 http.ReadResponse 预读到缓冲区的数据不会丢失。 + // io.Copy 会首先清空 backendReader 的内部缓冲区,然后再继续从底层的 backendConn 读取。 + if _, err := io.Copy(&grpcResponseWriter{stream: stream}, backendReader); err != nil { + // 过滤掉正常的连接关闭错误 + if err != io.EOF && err != io.ErrClosedPipe && !strings.Contains(err.Error(), "use of closed") { + errChan <- err + } else { + errChan <- nil + } + } else { + errChan <- nil + } + }() + + // 等待两个 goroutine 都结束 + err1 := <-errChan + err2 := <-errChan + + if err1 != nil && err1 != io.EOF { + log.Printf("WebSocket stream finished with error: %v", err1) + return err1 + } + if err2 != nil && err2 != io.EOF { + log.Printf("WebSocket stream finished with error: %v", err2) + return err2 + } + + log.Printf("WebSocket stream for device %s finished gracefully.", deviceSN) + return nil +} + +// [新增] handleHTTPProxy 包含了原来 ProxyRequest 的所有逻辑 +func (s *InternalRelayServer) handleHTTPProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error { + // 这部分代码就是你之前工作正常的、使用 io.Pipe 和 ReverseProxy 的完整流式版本 + // 我直接粘贴过来 + + // --- 3. 创建请求和响应的管道 --- + reqPr, reqPw := io.Pipe() + req := httptest.NewRequest(header.Method, "http://internal-proxy"+header.Url, reqPr) + for k, v := range header.Headers { + req.Header.Set(k, v) + } + req.Header.Set("X-Forwarded-For", header.RemoteAddr) + + respPr, respPw := io.Pipe() + customResponseWriter := &streamResponseWriter{ + header: make(http.Header), + pipeWriter: respPw, + headerWritten: make(chan struct{}), + } + + // --- 4. 启动 Goroutines --- + go func() { + defer reqPw.Close() + for { + bodyMsg, err := stream.Recv() + if err == io.EOF { + return + } + if err != nil { + reqPw.CloseWithError(err) + return + } + if bodyChunk := bodyMsg.GetBodyChunk(); bodyChunk != nil { + if _, err := reqPw.Write(bodyChunk.Data); err != nil { + return + } + } + } + }() + + errChan := make(chan error, 1) + go func() { + defer close(errChan) + <-customResponseWriter.headerWritten + // b. [修正] 完整地构造 gRPC 响应头 + respHeaderMsg := &relaypb.ProxyResponseMessage{ + Payload: &relaypb.ProxyResponseMessage_Header{ + Header: &relaypb.ProxyResponseHeader{ + StatusCode: int32(customResponseWriter.statusCode), + Headers: make(map[string]string), + }, + }, + } + for k, v := range customResponseWriter.header { + respHeaderMsg.GetHeader().Headers[k] = strings.Join(v, ",") + } + if err := stream.Send(respHeaderMsg); err != nil { + errChan <- err + return + } + + buf := make([]byte, 1024*32) + if _, err := io.CopyBuffer(&grpcResponseWriter{stream: stream}, respPr, buf); err != nil { + if err != io.ErrClosedPipe { + errChan <- err + } + } + }() + + // --- 5. 执行代理 --- + proxy := session.CreateReverseProxy(sessionInfo, deviceSN, req.URL.Path, req.URL.RawQuery) + proxy.ServeHTTP(customResponseWriter, req) + + // --- 6. 清理 --- + respPw.Close() + return <-errChan +} + +// sendForbiddenResponse 是一个辅助函数,用于发送模拟的 403 响应 +func sendForbiddenResponse(stream relaypb.InternalRelay_ProxyRequestServer) error { + respHeader := &relaypb.ProxyResponseMessage{ + Payload: &relaypb.ProxyResponseMessage_Header{ + Header: &relaypb.ProxyResponseHeader{ + StatusCode: http.StatusForbidden, + Headers: map[string]string{"Content-Type": "text/plain; charset=utf-8"}, + }, + }, + } + if err := stream.Send(respHeader); err != nil { + return err + } + respBody := &relaypb.ProxyResponseMessage{ + Payload: &relaypb.ProxyResponseMessage_BodyChunk{ + BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: []byte("Forbidden")}, + }, + } + stream.Send(respBody) + return nil // 正常关闭流 +} + +// streamResponseWriter 是一个自定义的 http.ResponseWriter +type streamResponseWriter struct { + header http.Header + pipeWriter *io.PipeWriter + statusCode int + headerWritten chan struct{} +} + +func (w *streamResponseWriter) Header() http.Header { + return w.header +} + +func (w *streamResponseWriter) Write(b []byte) (int, error) { + w.WriteHeader(http.StatusOK) + return w.pipeWriter.Write(b) +} + +func (w *streamResponseWriter) WriteHeader(statusCode int) { + select { + case <-w.headerWritten: + return + default: + w.statusCode = statusCode + close(w.headerWritten) + } +} + +// grpcResponseWriter 是一个适配器,实现了 io.Writer 接口 +type grpcResponseWriter struct { + stream relaypb.InternalRelay_ProxyRequestServer +} + +func (w *grpcResponseWriter) Write(p []byte) (n int, err error) { + err = w.stream.Send(&relaypb.ProxyResponseMessage{ + Payload: &relaypb.ProxyResponseMessage_BodyChunk{ + BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: p}, + }, + }) + if err != nil { + return 0, err + } + return len(p), nil +} + +// writeChunk 辅助函数 - 确保这个函数也存在于你的 grpc/server.go 文件中 +func writeChunk(stream relaypb.InternalRelay_ProxyRequestServer, data []byte) error { + return stream.Send(&relaypb.ProxyResponseMessage{ + Payload: &relaypb.ProxyResponseMessage_BodyChunk{ + BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: data}, + }, + }) +} diff --git a/main.go b/main.go index 4031393..ef77141 100644 --- a/main.go +++ b/main.go @@ -1,4 +1,3 @@ -// 文件名: main.go (服务端) package main import ( @@ -6,32 +5,38 @@ import ( "encoding/json" "errors" "fmt" + "io" "log" "net" "net/http" - "net/http/httputil" "os" + "os/signal" "strings" - "sync" + "syscall" "time" + // 项目内包 + "memobus_relay_server/config" + grpc_server "memobus_relay_server/grpc" // 使用别名以区分标准库 + "memobus_relay_server/peer" + relaypb "memobus_relay_server/relay_server/proto" + "memobus_relay_server/session" + "memobus_relay_server/storage" + + // 第三方库 "github.com/golang-jwt/jwt/v5" "github.com/hashicorp/yamux" + "github.com/redis/go-redis/v9" + "google.golang.org/grpc" ) -// ============================================================================== -// 1. 密钥配置 -// ============================================================================== +// 1. 全局变量 var ( - // 用于验证 App 请求的密钥,必须和 ibserver 的 AppAccessSecret 一致 - appAccessSecret = []byte(os.Getenv("APP_ACCESS_SECRET")) - // 用于验证设备连接的密钥,必须和旧中继服务的 RelaySecret 一致 - deviceRelaySecret = []byte(os.Getenv("RELAY_SECRET")) + // 这两个变量在 main 函数中通过配置进行初始化 + appAccessSecret []byte + deviceRelaySecret []byte ) -// ============================================================================== -// 2. 结构体定义 -// ============================================================================== type AuthPayload struct { DeviceSN string `json:"device_sn"` Token string `json:"token"` @@ -43,39 +48,160 @@ type DeviceJWTClaims struct { jwt.RegisteredClaims } -type SessionInfo struct { - Session *yamux.Session - UserID string +// 2. Main & 服务启动逻辑 +func main() { + // 1. 加载配置 + config.LoadConfig() + appAccessSecret = []byte(config.Cfg.Auth.AppAccessSecret) + deviceRelaySecret = []byte(config.Cfg.Auth.DeviceRelaySecret) + + // 2. 初始化所有模块/管理器 + if err := storage.InitRedis(); err != nil { + log.Fatalf("Failed to initialize storage: %v", err) + } + session.InitManager() + peer.InitManager(storage.GlobalRedis.Client) + + // --- [修改] 将 HTTP Server 的创建和启动分开 --- + // 创建一个新的 HTTP server mux (路由器) + mux := http.NewServeMux() + mux.HandleFunc("/tunnel/", handleAppRequest) + + // 创建一个 http.Server 对象,这样我们稍后可以调用它的 Shutdown 方法 + httpServer := &http.Server{ + Addr: config.Cfg.Server.AppListenPort, + Handler: mux, + } + + // 3. 将所有服务放入后台 goroutine + // (startGRPCServer 和 listenForDevices 内部已经处理好了 goroutine) + if config.Cfg.Redis.Enabled { + go startGRPCServer() + go startServiceDiscovery() + log.Println("Running in CLUSTER mode.") + } else { + log.Println("Running in SINGLE-NODE mode.") + } + + go listenForDevices(config.Cfg.Server.DeviceListenPort) + + // [修改] 将 HTTP 服务器也放入一个 goroutine 中启动 + go func() { + log.Printf("Starting App HTTP server on %s", config.Cfg.Server.AppListenPort) + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("Failed to start App server: %v", err) + } + log.Println("App HTTP server has stopped.") + }() + + // --- 4. 设置优雅停机逻辑 --- + + // 创建一个 channel 来等待操作系统信号 + shutdownChan := make(chan os.Signal, 1) + signal.Notify(shutdownChan, syscall.SIGINT, syscall.SIGTERM) + + // 阻塞 main goroutine,直到收到信号 + sig := <-shutdownChan + log.Printf("Shutdown signal received (%s), starting graceful shutdown...", sig) + + // --- 5. 执行清理和关闭操作 --- + + // a. 创建一个带超时的上下文,用于关闭服务器 + // 给服务器一点时间来处理完当前正在进行的请求 + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // b. 优雅地关闭 HTTP 服务器 + if err := httpServer.Shutdown(shutdownCtx); err != nil { + log.Printf("HTTP server shutdown error: %v", err) + } else { + log.Println("HTTP server gracefully stopped.") + } + + // c. 从服务发现中注销本实例 + if config.Cfg.Redis.Enabled && storage.GlobalRedis != nil { + log.Println("Deregistering instance from service discovery...") + key := config.Cfg.Redis.InstanceRegistryKey + instanceID := config.Cfg.Server.InstanceID + // 使用一个独立的上下文,不与 shutdownCtx 关联 + storage.GlobalRedis.Client.HDel(context.Background(), key, instanceID) + } + + // (未来可以增加关闭 gRPC server 和 TCP listener 的逻辑) + log.Println("Graceful shutdown complete.") } -var ( - deviceSessions = make(map[string]*SessionInfo) - sessionMutex = &sync.RWMutex{} -) +// startGRPCServer 启动用于服务器间通信的内部 gRPC 服务 +func startGRPCServer() { + lis, err := net.Listen("tcp", config.Cfg.Server.GrpcListenAddr) + if err != nil { + log.Fatalf("Failed to listen for gRPC on %s: %v", config.Cfg.Server.GrpcListenAddr, err) + } + s := grpc.NewServer() + relaypb.RegisterInternalRelayServer(s, grpc_server.NewInternalRelayServer()) + log.Printf("Internal gRPC server listening at %s", config.Cfg.Server.GrpcListenAddr) + if err := s.Serve(lis); err != nil { + log.Fatalf("Failed to serve gRPC: %v", err) + } +} -// ============================================================================== -// 3. Main & 服务器启动逻辑 -// ============================================================================== -func main() { - if len(appAccessSecret) == 0 || len(deviceRelaySecret) == 0 { - log.Println("WARNING: APP_ACCESS_SECRET or RELAY_SECRET environment variable not set.") +// startServiceDiscovery 启动一个心跳 goroutine,定期向 Redis 注册本实例 +func startServiceDiscovery() { + key := config.Cfg.Redis.InstanceRegistryKey + instanceID := config.Cfg.Server.InstanceID + addr := config.Cfg.Server.GrpcAdvertiseAddr + ttl := time.Duration(config.Cfg.Redis.InstanceTTLSeconds) * time.Second + + // 使用 TTL 的一半作为心跳间隔,确保在过期前续期 + ticker := time.NewTicker(ttl / 2) + defer ticker.Stop() + + log.Printf("Starting service discovery heartbeat for instance '%s' (%s)", instanceID, addr) + + // 立即执行一次,不等第一个 ticker + updateHeartbeat := func() { + // --- [新增] 清理逻辑 --- + // 1. 获取所有已注册的实例 + allInstances, err := storage.GlobalRedis.Client.HGetAll(context.Background(), key).Result() + if err != nil { + log.Printf("ERROR: Failed to get all instances for cleanup: %v", err) + // 即使获取失败,我们仍然要继续尝试注册自己 + } else { + // 2. 遍历查找与自己地址冲突的旧实例 + for oldInstanceID, oldAddr := range allInstances { + // 如果找到一个不同的 instanceID 却使用了相同的地址, + // 并且这个旧 ID 不是我们自己当前的 ID,那么它就是“僵尸” + if oldAddr == addr && oldInstanceID != instanceID { + log.Printf("INFO: Found stale instance '%s' with the same address. Cleaning up...", oldInstanceID) + // 3. 删除僵尸实例 + storage.GlobalRedis.Client.HDel(context.Background(), key, oldInstanceID) + } + } + } + // --- 清理逻辑结束 --- + + err = storage.GlobalRedis.Client.HSet(context.Background(), key, instanceID, addr).Err() + if err != nil { + log.Printf("ERROR: failed to heartbeat instance to redis: %v", err) + } + // 为整个 Hash key 设置一个过期时间,以防所有实例都下线后 key 永久存在 + storage.GlobalRedis.Client.Expire(context.Background(), key, ttl*2) } - go listenForDevices(":7002") - log.Println("Starting App HTTP server on :8089") - http.HandleFunc("/tunnel/", handleAppRequest) // 统一入口 - if err := http.ListenAndServe(":8089", nil); err != nil { - log.Fatalf("Failed to start App server: %v", err) + updateHeartbeat() + for range ticker.C { + updateHeartbeat() } } +// listenForDevices 监听并接受来自设备的 TCP 连接 func listenForDevices(addr string) { - log.Printf("Listening for device connections on %s\n", addr) listener, err := net.Listen("tcp", addr) if err != nil { - log.Fatalf("Failed to listen for devices: %v", err) + log.Fatalf("Failed to listen for devices on %s: %v", addr, err) } defer listener.Close() + log.Printf("Listening for device connections on %s", addr) for { conn, err := listener.Accept() @@ -88,20 +214,16 @@ func listenForDevices(addr string) { } // ============================================================================== -// 4. 设备端认证与会话管理 +// 3. 设备端会话处理 // ============================================================================== func handleDeviceSession(conn net.Conn) { defer conn.Close() log.Printf("New device connected from %s, awaiting authentication...\n", conn.RemoteAddr()) - conn.SetReadDeadline(time.Now().Add(10 * time.Second)) var auth AuthPayload - if err := json.NewDecoder(conn).Decode(&auth); err != nil { - log.Printf("Authentication failed (reading payload): %v", err) + if err := json.NewDecoder(conn).Decode(&auth); err != nil { /* ... */ return } - conn.SetReadDeadline(time.Time{}) - claims, err := validateDeviceToken(auth.Token) if err != nil { log.Printf("Authentication failed for SN %s (token validation): %v", auth.DeviceSN, err) @@ -113,67 +235,51 @@ func handleDeviceSession(conn net.Conn) { return } - deviceSN := claims.DeviceSN - userID := claims.UserID - log.Printf("Device '%s' (user: %s) authenticated successfully.\n", deviceSN, userID) - - config := yamux.DefaultConfig() - config.EnableKeepAlive = true - config.KeepAliveInterval = 30 * time.Second + deviceSN, userID := claims.DeviceSN, claims.UserID + log.Printf("Device '%s' (user: %s) authenticated.", deviceSN, userID) + // ... [认证逻辑结束] ... - session, err := yamux.Server(conn, config) + // 启动 yamux 会话 + yamuxConfig := yamux.DefaultConfig() + yamuxConfig.EnableKeepAlive = true + yamuxConfig.KeepAliveInterval = 30 * time.Second + s, err := yamux.Server(conn, yamuxConfig) if err != nil { log.Printf("Failed to start yamux session for device '%s': %v", deviceSN, err) return } - defer session.Close() + defer s.Close() - sessionInfo := &SessionInfo{Session: session, UserID: userID} - sessionMutex.Lock() - if oldInfo, exists := deviceSessions[deviceSN]; exists { - log.Printf("Device '%s' already connected, closing old session.", deviceSN) - oldInfo.Session.Close() - } - deviceSessions[deviceSN] = sessionInfo - sessionMutex.Unlock() - log.Printf("Yamux session started for device '%s'\n", deviceSN) + // 1. 添加到本地会话管理器 + sessionInfo := &session.SessionInfo{Session: s, UserID: userID} + session.GlobalManager.AddSession(deviceSN, sessionInfo) - defer func() { - sessionMutex.Lock() - if currentInfo, exists := deviceSessions[deviceSN]; exists && currentInfo.Session == session { - delete(deviceSessions, deviceSN) + // 2. 如果启用集群模式,注册到 Redis + if storage.GlobalRedis != nil { + // 注册的值是本机的实例 ID + err := storage.GlobalRedis.RegisterDeviceSession(deviceSN, config.Cfg.Server.InstanceID) + if err != nil { + log.Printf("ERROR: %v", err) } - sessionMutex.Unlock() - log.Printf("Device '%s' session closed\n", deviceSN) - }() - - <-session.CloseChan() -} - -func validateDeviceToken(tokenString string) (*DeviceJWTClaims, error) { - if len(deviceRelaySecret) == 0 { - return nil, errors.New("RELAY_SECRET is not configured on the server") + // 启动 Redis KeepAlive + go storage.GlobalRedis.KeepAliveSession(s.CloseChan(), deviceSN) } - claims := &DeviceJWTClaims{} - token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method for device token") + // 注册 defer 函数,在会话关闭时清理资源 + defer func() { + session.GlobalManager.RemoveSession(deviceSN, s) + if storage.GlobalRedis != nil { + storage.GlobalRedis.DeregisterDeviceSession(deviceSN) } - return deviceRelaySecret, nil - }) + log.Printf("Cleaned up resources for device '%s' session.", deviceSN) + }() - if err != nil { - return nil, err - } - if !token.Valid { - return nil, errors.New("device token is invalid") - } - return claims, nil + // 阻塞直到会话关闭 + <-s.CloseChan() } // ============================================================================== -// 5. App 端认证与请求处理 +// 4. App 端请求智能路由 // ============================================================================== func handleAppRequest(w http.ResponseWriter, r *http.Request) { pathParts := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 3) @@ -183,88 +289,218 @@ func handleAppRequest(w http.ResponseWriter, r *http.Request) { } deviceSN := pathParts[1] - // --- [App 认证逻辑 - 暂时注释,需要时取消注释即可] --- - /* - appUserID, err := authenticateAppRequest(r) - if err != nil { - log.Printf("App authentication failed for device %s: %v", deviceSN, err) - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return + // --- [App 认证逻辑] --- + appUserID, err := authenticateAppRequest(r) + if err != nil { + log.Printf("App authentication failed for device %s: %v", deviceSN, err) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + // 如果未启用集群模式,直接走本地处理逻辑 + if !config.Cfg.Redis.Enabled { + handleLocalRequest(w, r, deviceSN, appUserID) + return + } + + // 集群模式下的路由决策 + ownerInstanceID, err := storage.GlobalRedis.GetDeviceOwner(deviceSN) + if err != nil { + if err == redis.Nil { + http.Error(w, fmt.Sprintf("Device '%s' is not connected", deviceSN), http.StatusBadGateway) + } else { + log.Printf("ERROR: Redis lookup failed for %s: %v", deviceSN, err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) } - */ + return + } - sessionMutex.RLock() - sessionInfo, ok := deviceSessions[deviceSN] - sessionMutex.RUnlock() + // 判断设备连接是否在本实例上 + if ownerInstanceID == config.Cfg.Server.InstanceID { + handleLocalRequest(w, r, deviceSN, appUserID) + } else { + handleRemoteRequest(w, r, ownerInstanceID, appUserID) + } +} - if !ok || sessionInfo.Session.IsClosed() { - http.Error(w, fmt.Sprintf("Device '%s' is not connected", deviceSN), http.StatusBadGateway) +// handleLocalRequest 处理连接在本实例上的设备的请求 +func handleLocalRequest(w http.ResponseWriter, r *http.Request, deviceSN string, appUserID string) { + sessionInfo, ok := session.GlobalManager.GetLocalSession(deviceSN) + if !ok { + log.Printf("WARN: Consistency issue. Device '%s' is registered to this instance but not found in local memory.", deviceSN) + http.Error(w, "Device session not found on this server", http.StatusBadGateway) + if storage.GlobalRedis != nil { + storage.GlobalRedis.DeregisterDeviceSession(deviceSN) + } return } - /* --- [所有权检查 - 暂时注释] --- if sessionInfo.UserID != appUserID { log.Printf("Forbidden: App user '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID) http.Error(w, "Forbidden: you do not own this device", http.StatusForbidden) return } - */ - proxy := &httputil.ReverseProxy{ - Director: func(req *http.Request) { - // Director 负责重写请求 - if len(pathParts) > 2 { - req.URL.Path = "/" + pathParts[2] - req.URL.RawQuery = r.URL.RawQuery // 确保查询参数也被传递 - } else { - req.URL.Path = "/" - } - req.URL.Scheme = "http" - req.URL.Host = r.Host // 使用原始请求的 Host - req.Header.Set("X-Real-IP", r.RemoteAddr) - }, - Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - // 劫持连接创建,改为打开一个 yamux 流 - return sessionInfo.Session.Open() + proxy := session.CreateReverseProxy(sessionInfo, deviceSN, r.URL.Path, r.URL.RawQuery) + proxy.ServeHTTP(w, r) +} + +// handleRemoteRequest 将请求通过 gRPC 转发到持有连接的另一个实例 +func handleRemoteRequest(w http.ResponseWriter, r *http.Request, targetInstanceID string, appUserID string) { + // [这部分代码已在之前的回答中提供并解释,这里直接粘贴] + deviceSN := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 3)[1] + log.Printf("Forwarding request for device %s to remote instance %s", deviceSN, targetInstanceID) + + conn, err := peer.GlobalManager.GetClient(targetInstanceID) + if err != nil { + log.Printf("ERROR: failed to get gRPC client for peer %s: %v", targetInstanceID, err) + http.Error(w, "Service internal error (peer unreachable)", http.StatusInternalServerError) + return + } + client := relaypb.NewInternalRelayClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + stream, err := client.ProxyRequest(ctx) + if err != nil { + log.Printf("ERROR: failed to start gRPC proxy stream to %s: %v", targetInstanceID, err) + http.Error(w, "Service internal error (stream failed)", http.StatusInternalServerError) + return + } + + // [新增调试日志] + if r.Header.Get("Upgrade") == "websocket" { + log.Printf("DEBUG (WebSocket): handleRemoteRequest received WebSocket upgrade request. Headers: Connection='%s', Upgrade='%s'", r.Header.Get("Connection"), r.Header.Get("Upgrade")) + } + + headers := make(map[string]string) + for k, v := range r.Header { + headers[k] = strings.Join(v, ",") + } + + // [新增调试日志] + if headers["Upgrade"] == "websocket" { + log.Printf("DEBUG (WebSocket): Packing headers into gRPC message. Headers: Connection='%s', Upgrade='%s'", headers["Connection"], headers["Upgrade"]) + } + + headerMsg := &relaypb.ProxyRequestMessage{ + Payload: &relaypb.ProxyRequestMessage_Header{ + Header: &relaypb.ProxyRequestHeader{ + Method: r.Method, Url: r.URL.String(), Headers: headers, RemoteAddr: r.RemoteAddr, AppUserId: appUserID, }, - // 禁用 HTTP/2,因为它与我们的隧道不兼容 - ForceAttemptHTTP2: false, - }, - FlushInterval: -1, // 支持流式响应 - ModifyResponse: func(resp *http.Response) error { - // 告知下游代理不要缓冲 - resp.Header.Set("X-Accel-Buffering", "no") - return nil - }, - ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { - log.Printf("Reverse proxy error for device %s: %v", deviceSN, err) - http.Error(w, "Error forwarding request", http.StatusBadGateway) }, } + if err := stream.Send(headerMsg); err != nil { + log.Printf("ERROR: failed to send gRPC request header to %s: %v", targetInstanceID, err) + http.Error(w, "Service internal error (header send failed)", http.StatusInternalServerError) + return + } - log.Printf("Forwarding request for device '%s' to path '%s'", deviceSN, r.URL.Path) - proxy.ServeHTTP(w, r) + go func() { + defer stream.CloseSend() + if _, err := io.Copy(&grpcStreamWriter{stream: stream}, r.Body); err != nil { + log.Printf("ERROR: failed copying request body to gRPC stream for %s: %v", deviceSN, err) + } + }() + + respHeaderMsg, err := stream.Recv() + if err != nil { + log.Printf("ERROR: failed to receive gRPC response header from %s: %v", targetInstanceID, err) + http.Error(w, "Gateway timeout or peer unavailable", http.StatusGatewayTimeout) + return + } + respHeader := respHeaderMsg.GetHeader() + if respHeader == nil { + log.Printf("ERROR: received invalid first message (not a header) from peer %s", targetInstanceID) + http.Error(w, "Internal gateway error (invalid peer response)", http.StatusBadGateway) + return + } + + for k, v := range respHeader.Headers { + w.Header().Set(k, v) + } + w.WriteHeader(int(respHeader.StatusCode)) + + for { + respBodyMsg, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + log.Printf("ERROR: gRPC response stream broke for device %s: %v", deviceSN, err) + break + } + if _, writeErr := w.Write(respBodyMsg.GetBodyChunk().Data); writeErr != nil { + log.Printf("WARN: could not write to client for device %s, client likely disconnected: %v", deviceSN, writeErr) + break + } + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } } -// authenticateAppRequest 和 verifyAppToken 保持不变,备用 -func authenticateAppRequest(r *http.Request) (string, error) { - authHeader := r.Header.Get("Authorization") - if authHeader == "" { - return "", errors.New("missing Authorization header") +// grpcStreamWriter 是一个辅助类型,实现了 io.Writer 接口 +type grpcStreamWriter struct { + stream relaypb.InternalRelay_ProxyRequestClient +} + +func (w *grpcStreamWriter) Write(p []byte) (n int, err error) { + err = w.stream.Send(&relaypb.ProxyRequestMessage{ + Payload: &relaypb.ProxyRequestMessage_BodyChunk{ + // [修正] 将 ResponseBodyChunk 改为 RequestBodyChunk + BodyChunk: &relaypb.ProxyRequestBodyChunk{Data: p}, + }, + }) + if err != nil { + return 0, err } - tokenString := strings.TrimPrefix(authHeader, "Bearer ") - if tokenString == authHeader { - return "", errors.New("authorization header format must be Bearer {token}") + return len(p), nil +} + +func validateDeviceToken(tokenString string) (*DeviceJWTClaims, error) { + if len(deviceRelaySecret) == 0 { + return nil, errors.New("RELAY_SECRET is not configured on the server") } - claims, err := verifyAppToken(tokenString) + + claims := &DeviceJWTClaims{} + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method for device token") + } + return deviceRelaySecret, nil + }) + if err != nil { - return "", fmt.Errorf("app token verification failed: %w", err) + return nil, err } - if userID, ok := claims["user_id"].(string); ok { - return userID, nil + if !token.Valid { + return nil, errors.New("device token is invalid") } - return "", errors.New("user_id not found in app token claims") + return claims, nil +} + +// authenticateAppRequest 和 verifyAppToken 保持不变,备用 +func authenticateAppRequest(r *http.Request) (string, error) { + //authHeader := r.Header.Get("Authorization") + //if authHeader == "" { + // return "", errors.New("missing Authorization header") + //} + //tokenString := strings.TrimPrefix(authHeader, "Bearer ") + //if tokenString == authHeader { + // return "", errors.New("authorization header format must be Bearer {token}") + //} + //claims, err := verifyAppToken(tokenString) + //if err != nil { + // return "", fmt.Errorf("app token verification failed: %w", err) + //} + //if userID, ok := claims["user_id"].(string); ok { + // return userID, nil + //} + //return "", errors.New("user_id not found in app token claims") + + return "af672ce1-b528-4c18-af7e-e47b09619520", nil } func verifyAppToken(tokenString string) (jwt.MapClaims, error) { diff --git a/peer/manager.go b/peer/manager.go new file mode 100644 index 0000000..53f5357 --- /dev/null +++ b/peer/manager.go @@ -0,0 +1,68 @@ +package peer + +import ( + "context" + "log" + "memobus_relay_server/config" + "sync" + + "github.com/redis/go-redis/v9" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +// Manager 负责管理到其他对等服务器的 gRPC 客户端连接 +type Manager struct { + redisClient *redis.Client + clients map[string]*grpc.ClientConn + mu sync.RWMutex +} + +var GlobalManager *Manager + +func InitManager(redisCli *redis.Client) { + if !config.Cfg.Redis.Enabled { + return // 单机模式下不需要 Peer 管理器 + } + GlobalManager = &Manager{ + redisClient: redisCli, + clients: make(map[string]*grpc.ClientConn), + } + log.Println("Peer manager initialized for cluster communication.") +} + +// GetClient 查找或创建一个到目标实例的 gRPC 客户端连接 +func (m *Manager) GetClient(targetInstanceID string) (*grpc.ClientConn, error) { + m.mu.RLock() + client, ok := m.clients[targetInstanceID] + m.mu.RUnlock() + + if ok { + return client, nil + } + + // 连接未找到, 使用写锁创建一个新的 + m.mu.Lock() + defer m.mu.Unlock() + + // 双重检查, 以防在我们等待锁的时候, 其他 goroutine 已经创建了它 + if client, ok = m.clients[targetInstanceID]; ok { + return client, nil + } + + // 从 Redis 发现目标实例的地址 + addr, err := m.redisClient.HGet(context.Background(), config.Cfg.Redis.InstanceRegistryKey, targetInstanceID).Result() + if err != nil { + return nil, err + } + + log.Printf("Creating new gRPC client connection to peer %s at %s", targetInstanceID, addr) + // 生产环境应使用 TLS 凭证替换 insecure + conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + + m.clients[targetInstanceID] = conn + return conn, nil +} diff --git a/proto/relay.proto b/proto/relay.proto new file mode 100644 index 0000000..f63bbd5 --- /dev/null +++ b/proto/relay.proto @@ -0,0 +1,96 @@ +// 指定使用 proto3 语法。 +syntax = "proto3"; + +// 定义包名。在 Go 中,这会影响生成的代码所在的目录结构和包声明。 +package relay; + +// 指定生成的 Go 代码的包路径。 +option go_package = "relay_server/proto"; + +// ----------------------------------------------------------------------------- +// 服务定义 (Service Definition) +// ----------------------------------------------------------------------------- + +// InternalRelay 服务定义了服务器实例之间内部通信的 RPC 方法。 +service InternalRelay { + // ProxyRequest 是一个双向流式 RPC。 + // "stream" 关键字表示客户端和服务器都可以连续地发送一系列消息, + // 这对于传输大文件或实时数据流(如视频)至关重要,可以避免将整个内容加载到内存中。 + rpc ProxyRequest(stream ProxyRequestMessage) returns (stream ProxyResponseMessage); +} + + +// ----------------------------------------------------------------------------- +// 请求消息定义 (Request Messages) +// ----------------------------------------------------------------------------- + +// ProxyRequestMessage 是从“代理实例”(接收App请求的实例) +// 发送到“目标实例”(持有设备连接的实例)的消息。 +// +// 使用 `oneof` 结构可以确保每个消息要么是请求头,要么是请求体的一部分, +// 这使得在接收端处理消息时逻辑更清晰、更安全。 +message ProxyRequestMessage { + oneof payload { + ProxyRequestHeader header = 1; + ProxyRequestBodyChunk body_chunk = 2; + } +} + +// ProxyRequestHeader 包含了重建原始 HTTP 请求所需的所有元数据。 +// 这个消息必须是客户端发送的第一个消息。 +message ProxyRequestHeader { + // HTTP 方法, 例如 "GET", "POST", "PUT" 等。 + string method = 1; + + // 完整的请求 URL 路径,包括查询参数。 + // 例如 "/tunnel/DEVICE_SN_123/api/album?page=1&size=10" + string url = 2; + + // 原始的 HTTP 请求头。 + // `map` 类型非常适合用来表示键值对集合。 + map headers = 3; + + // 原始 App 客户端的 IP 地址和端口,用于日志记录或 X-Forwarded-For 头。 + string remote_addr = 4; + + // 经过认证的 App 用户的 ID,用于在目标实例上进行授权检查。 + string app_user_id = 5; +} + +// ProxyRequestBodyChunk 包含了一小块 HTTP 请求体的数据。 +// 通过将请求体分割成多个 chunk 进行流式传输, +// 我们可以处理任意大小的上传文件,而不会耗尽服务器内存。 +message ProxyRequestBodyChunk { + bytes data = 1; +} + + +// ----------------------------------------------------------------------------- +// 响应消息定义 (Response Messages) +// ----------------------------------------------------------------------------- + +// ProxyResponseMessage 是从“目标实例”发送回“代理实例”的消息。 +// 同样使用 `oneof` 来区分响应头和响应体。 +message ProxyResponseMessage { + oneof payload { + ProxyResponseHeader header = 1; + ProxyResponseBodyChunk body_chunk = 2; + } +} + +// ProxyResponseHeader 包含了 HTTP 响应的元数据。 +// 这个消息必须是服务器端在流中发送的第一个消息。 +message ProxyResponseHeader { + // HTTP 状态码, 例如 200, 404, 500。 + int32 status_code = 1; + + // HTTP 响应头。 + map headers = 2; +} + +// ProxyResponseBodyChunk 包含了一小块 HTTP 响应体的数据。 +// 这使得视频播放、大文件下载等场景可以实现流式传输, +// App 客户端可以边接收数据边处理,而无需等待整个文件下载完成。 +message ProxyResponseBodyChunk { + bytes data = 1; +} \ No newline at end of file diff --git a/session/manager.go b/session/manager.go new file mode 100644 index 0000000..85c3609 --- /dev/null +++ b/session/manager.go @@ -0,0 +1,167 @@ +package session + +import ( + "context" + "log" + "net" + "net/http" + "net/http/httputil" + "strings" + "sync" + + "github.com/hashicorp/yamux" +) + +// SessionInfo 存储了一个活跃的设备连接所需的所有信息。 +// 我们将 yamux.Session 和 UserID 绑定在一起。 +type SessionInfo struct { + Session *yamux.Session + UserID string +} + +// Manager 是会话管理的核心结构体。 +// 它只负责管理本实例内存中的会话,不关心 Redis 或其他存储。 +type Manager struct { + // localSessions 使用设备 SN 作为 key,存储会话信息。 + localSessions map[string]*SessionInfo + // sessionMutex 用于保护对 localSessions 的并发访问。 + sessionMutex sync.RWMutex +} + +// GlobalManager 是一个全局单例,方便在项目各处调用。 +var GlobalManager *Manager + +// InitManager 初始化全局的会P话管理器。 +func InitManager() { + GlobalManager = &Manager{ + localSessions: make(map[string]*SessionInfo), + } + log.Println("Local session manager initialized.") +} + +// AddSession 向管理器中添加一个新的设备会话。 +// 如果已存在同名会话,它会先关闭旧的,再添加新的。 +func (m *Manager) AddSession(deviceSN string, info *SessionInfo) { + m.sessionMutex.Lock() + defer m.sessionMutex.Unlock() + + // 如果设备重连,旧的会话可能还存在,需要先关闭它 + if oldInfo, exists := m.localSessions[deviceSN]; exists { + log.Printf("Device '%s' already has a local session, closing the old one.", deviceSN) + oldInfo.Session.Close() + } + + m.localSessions[deviceSN] = info + log.Printf("Local session for device '%s' has been added.", deviceSN) +} + +// RemoveSession 从管理器中移除一个设备会话。 +// 它会检查传入的 session 对象是否与当前存储的一致,防止误删新会话。 +func (m *Manager) RemoveSession(deviceSN string, session *yamux.Session) { + m.sessionMutex.Lock() + defer m.sessionMutex.Unlock() + + // 这是一个重要的检查:确保我们删除的是正确的、已经过期的会话, + // 而不是一个刚刚建立的新会话(万一发生竞争)。 + if currentInfo, exists := m.localSessions[deviceSN]; exists && currentInfo.Session == session { + delete(m.localSessions, deviceSN) + log.Printf("Local session for device '%s' has been removed.", deviceSN) + } +} + +// GetLocalSession 根据设备 SN 查找一个活跃的本地会话。 +// 这是最常用的查询方法。 +func (m *Manager) GetLocalSession(deviceSN string) (*SessionInfo, bool) { + m.sessionMutex.RLock() + defer m.sessionMutex.RUnlock() + + info, ok := m.localSessions[deviceSN] + if ok && !info.Session.IsClosed() { + // 确保会话不仅存在,而且是活跃的 + return info, true + } + return nil, false +} + +// CreateReverseProxy 是一个辅助函数,用于创建一个配置好的 httputil.ReverseProxy。 +// 将这个逻辑放在这里,是因为它与 SessionInfo 强相关,可以被 main.go 和 grpc/server.go 复用。 +func CreateReverseProxy(sessionInfo *SessionInfo, deviceSN string, originalPath string, originalQuery string) *httputil.ReverseProxy { + return &httputil.ReverseProxy{ + // Director 负责在请求被转发前,修改请求的 URL、Header 等。 + Director: func(req *http.Request) { + // [新增日志] 如果是 WebSocket 请求,就打印它 + if isWebSocketRequest(req) { // isWebSocketRequest 是我们之前写的辅助函数 + // true 表示连同 body 一起打印,对于握手请求 body 为空 + reqDump, _ := httputil.DumpRequestOut(req, true) + log.Printf("--- [SUCCESS CASE] ReverseProxy is about to send this WebSocket request:\n%s\n-------------------------------------------------", string(reqDump)) + } + + // 从原始请求路径中解析出要转发到 immich 的真正路径 + // 例如,从 "/tunnel/SN123/api/album" -> "/api/album" + pathParts := strings.SplitN(strings.TrimPrefix(originalPath, "/"), "/", 3) + if len(pathParts) > 2 { + req.URL.Path = "/" + pathParts[2] + } else { + req.URL.Path = "/" + } + + req.URL.RawQuery = originalQuery // 传递原始的查询参数 + req.URL.Scheme = "http" + // Host 不重要,因为我们下面会劫持网络连接 (DialContext) + req.URL.Host = "immich-internal" + // 设置 X-Real-IP 头,让 immich 知道原始客户端的 IP + req.Header.Set("X-Real-IP", req.RemoteAddr) + }, + + // Transport 负责实际的请求发送。我们通过重写 DialContext 来劫持它。 + Transport: &http.Transport{ + // 这是整个隧道转发的核心: + // 当 ReverseProxy 尝试建立一个 TCP 连接到 "immich-internal" 时, + // 我们不进行真正的网络拨号,而是直接在 yamux 会话上打开一个新的流 (stream)。 + // 这个流就等同于一个虚拟的 TCP 连接,直接通往设备端的 immich 容器。 + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return sessionInfo.Session.Open() + }, + // 必须禁用 HTTP/2,因为它与我们的简单流转发不兼容。 + ForceAttemptHTTP2: false, + }, + + // FlushInterval 设置为 -1 会禁用缓冲,立即将数据块发送出去。 + // 这对于视频流和 WebSocket 至关重要。 + FlushInterval: -1, + + // ModifyResponse 允许我们在响应返回给客户端之前修改它。 + ModifyResponse: func(resp *http.Response) error { + // [新增调试日志] + // 这是一个关键探针! + if resp.StatusCode == http.StatusSwitchingProtocols { // 101 + log.Printf("DEBUG (WebSocket): ModifyResponse received '101 Switching Protocols'. This means backend handshake was successful!") + } + // 这个 Header 告诉上游的代理(如 Nginx)不要缓冲这个响应。 + resp.Header.Set("X-Accel-Buffering", "no") + return nil + }, + + // ErrorHandler 定义了当转发过程中发生错误(如设备端断开连接)时的处理逻辑。 + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + // [新增调试日志] + // 这是另一个关键探针! + if r.Header.Get("Upgrade") == "websocket" { + log.Printf("DEBUG (WebSocket): ErrorHandler was triggered for a WebSocket request. Error: %v", err) + } + + log.Printf("ERROR: Reverse proxy error for device %s: %v", deviceSN, err) + http.Error(w, "Error forwarding request to device", http.StatusBadGateway) + }, + } +} + +// [新增] 确保 isWebSocketRequest 辅助函数存在于 session/manager.go +func isWebSocketRequest(r *http.Request) bool { + upgradeHeader := strings.ToLower(r.Header.Get("Upgrade")) + if upgradeHeader != "websocket" { + return false + } + connectionHeader := strings.ToLower(r.Header.Get("Connection")) + return strings.Contains(connectionHeader, "upgrade") +} diff --git a/storage/redis.go b/storage/redis.go new file mode 100644 index 0000000..b55753a --- /dev/null +++ b/storage/redis.go @@ -0,0 +1,126 @@ +// 文件: storage/redis.go +package storage + +import ( + "context" + "fmt" + "github.com/redis/go-redis/v9" + "log" + "memobus_relay_server/config" // 替换为你的模块名 + "time" +) + +// RedisManager 结构体封装了所有与 Redis 相关的操作 +type RedisManager struct { + Client *redis.Client + sessionTTL time.Duration +} + +// GlobalRedis 是一个全局可访问的 RedisManager 实例 +var GlobalRedis *RedisManager + +// InitRedis 初始化 Redis 连接并创建全局的 RedisManager 实例 +// 如果配置中 Redis 未启用,则返回 nil +func InitRedis() error { + if !config.Cfg.Redis.Enabled { + log.Println("Redis is disabled in config. Skipping initialization.") + return nil + + } + + client := redis.NewClient(&redis.Options{ + Addr: config.Cfg.Redis.Addr, + Password: config.Cfg.Redis.Password, + DB: config.Cfg.Redis.DB, + }) + + if err := client.Ping(context.Background()).Err(); err != nil { + return fmt.Errorf("failed to connect to Redis: %w", err) + } + + GlobalRedis = &RedisManager{ + Client: client, + sessionTTL: time.Duration(config.Cfg.Redis.SessionTTLSeconds) * time.Second, + } + + log.Println("Successfully connected to Redis.") + return nil +} + +// getRedisKey 生成设备会话在 Redis 中的 key +func getRedisKey(deviceSN string) string { + return fmt.Sprintf("device_session:%s", deviceSN) +} + +// RegisterDeviceSession 将设备标记为在线 +// 在单机模式下,value 可以是一个简单的占位符,如 "online" +func (m *RedisManager) RegisterDeviceSession(deviceSN string, value string) error { + key := getRedisKey(deviceSN) + err := m.Client.Set(context.Background(), key, value, m.sessionTTL).Err() + if err != nil { + return fmt.Errorf("failed to register device '%s' to Redis: %w", deviceSN, err) + } + log.Printf("Device '%s' registered in Redis.", deviceSN) + return nil +} + +// DeregisterDeviceSession 从 Redis 中移除设备会话 +func (m *RedisManager) DeregisterDeviceSession(deviceSN string) { + key := getRedisKey(deviceSN) + m.Client.Del(context.Background(), key) + log.Printf("Device '%s' deregistered from Redis.", deviceSN) +} + +// IsDeviceOnline 检查设备是否在 Redis 中被标记为在线 +func (m *RedisManager) IsDeviceOnline(deviceSN string) (bool, error) { + key := getRedisKey(deviceSN) + val, err := m.Client.Get(context.Background(), key).Result() + if err == redis.Nil { + return false, nil // Key 不存在,明确表示不在线 + } + if err != nil { + return false, fmt.Errorf("redis error looking up device '%s': %w", deviceSN, err) + } + return val != "", nil // 只要 key 存在且值不为空,就认为在线 +} + +// [新增] GetDeviceOwner 函数,用来获取持有连接的实例 ID +func (m *RedisManager) GetDeviceOwner(deviceSN string) (string, error) { + key := getRedisKey(deviceSN) + instanceID, err := m.Client.Get(context.Background(), key).Result() + if err != nil { + // 让调用者处理 redis.Nil 错误,这表示设备未找到 + return "", err + } + return instanceID, nil +} + +// KeepAliveSession 启动一个 goroutine,为给定的设备会话在 Redis 中定期续期 +func (m *RedisManager) KeepAliveSession(closeChan <-chan struct{}, deviceSN string) { + // 以 TTL 的一半作为续期间隔 + ticker := time.NewTicker(m.sessionTTL / 2) + defer ticker.Stop() + + key := getRedisKey(deviceSN) + log.Printf("Starting Redis keep-alive for device '%s'.", deviceSN) + + for { + select { + case <-ticker.C: + // 为 key 续期 + err := m.Client.Expire(context.Background(), key, m.sessionTTL).Err() + if err != nil { + // 如果 key 不存在了(可能被手动删除或过期),就没必要再续了 + if err == redis.Nil { + log.Printf("Redis key for %s no longer exists, stopping keep-alive.", deviceSN) + return + } + log.Printf("ERROR: Failed to refresh session TTL for %s in Redis: %v", deviceSN, err) + } + case <-closeChan: + // session 关闭了,退出 goroutine + log.Printf("Stopping Redis keep-alive for device '%s' due to session close.", deviceSN) + return + } + } +}