package main import ( "context" "encoding/json" "errors" "fmt" "io" "log" "net" "net/http" "os" "os/signal" "strings" "syscall" "time" // 项目内包 "memobus_relay_server/config" grpc_server "memobus_relay_server/grpc" // 使用别名以区分标准库 "memobus_relay_server/peer" relaypb "memobus_relay_server/relay_server/proto" "memobus_relay_server/session" "memobus_relay_server/storage" // 第三方库 "github.com/golang-jwt/jwt/v5" "github.com/hashicorp/yamux" "github.com/redis/go-redis/v9" "google.golang.org/grpc" ) // 1. 全局变量 var ( // 这两个变量在 main 函数中通过配置进行初始化 appAccessSecret []byte deviceRelaySecret []byte ) type AuthPayload struct { DeviceSN string `json:"device_sn"` Token string `json:"token"` } type DeviceJWTClaims struct { DeviceSN string `json:"sn"` UserID string `json:"userId"` jwt.RegisteredClaims } // 2. Main & 服务启动逻辑 func main() { // 1. 加载配置 config.LoadConfig() appAccessSecret = []byte(config.Cfg.Auth.AppAccessSecret) deviceRelaySecret = []byte(config.Cfg.Auth.DeviceRelaySecret) // 2. 初始化所有模块/管理器 if err := storage.InitRedis(); err != nil { log.Fatalf("Failed to initialize storage: %v", err) } session.InitManager() peer.InitManager(storage.GlobalRedis.Client) // --- [修改] 将 HTTP Server 的创建和启动分开 --- // 创建一个新的 HTTP server mux (路由器) mux := http.NewServeMux() mux.HandleFunc("/tunnel/", handleAppRequest) // 创建一个 http.Server 对象,这样我们稍后可以调用它的 Shutdown 方法 httpServer := &http.Server{ Addr: config.Cfg.Server.AppListenPort, Handler: mux, } // 3. 将所有服务放入后台 goroutine // (startGRPCServer 和 listenForDevices 内部已经处理好了 goroutine) if config.Cfg.Redis.Enabled { go startGRPCServer() go startServiceDiscovery() log.Println("Running in CLUSTER mode.") } else { log.Println("Running in SINGLE-NODE mode.") } go listenForDevices(config.Cfg.Server.DeviceListenPort) // [修改] 将 HTTP 服务器也放入一个 goroutine 中启动 go func() { log.Printf("Starting App HTTP server on %s", config.Cfg.Server.AppListenPort) if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("Failed to start App server: %v", err) } log.Println("App HTTP server has stopped.") }() // --- 4. 设置优雅停机逻辑 --- // 创建一个 channel 来等待操作系统信号 shutdownChan := make(chan os.Signal, 1) signal.Notify(shutdownChan, syscall.SIGINT, syscall.SIGTERM) // 阻塞 main goroutine,直到收到信号 sig := <-shutdownChan log.Printf("Shutdown signal received (%s), starting graceful shutdown...", sig) // --- 5. 执行清理和关闭操作 --- // a. 创建一个带超时的上下文,用于关闭服务器 // 给服务器一点时间来处理完当前正在进行的请求 shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // b. 优雅地关闭 HTTP 服务器 if err := httpServer.Shutdown(shutdownCtx); err != nil { log.Printf("HTTP server shutdown error: %v", err) } else { log.Println("HTTP server gracefully stopped.") } // c. 从服务发现中注销本实例 if config.Cfg.Redis.Enabled && storage.GlobalRedis != nil { log.Println("Deregistering instance from service discovery...") key := config.Cfg.Redis.InstanceRegistryKey instanceID := config.Cfg.Server.InstanceID // 使用一个独立的上下文,不与 shutdownCtx 关联 storage.GlobalRedis.Client.HDel(context.Background(), key, instanceID) } // (未来可以增加关闭 gRPC server 和 TCP listener 的逻辑) log.Println("Graceful shutdown complete.") } // startGRPCServer 启动用于服务器间通信的内部 gRPC 服务 func startGRPCServer() { lis, err := net.Listen("tcp", config.Cfg.Server.GrpcListenAddr) if err != nil { log.Fatalf("Failed to listen for gRPC on %s: %v", config.Cfg.Server.GrpcListenAddr, err) } s := grpc.NewServer() relaypb.RegisterInternalRelayServer(s, grpc_server.NewInternalRelayServer()) log.Printf("Internal gRPC server listening at %s", config.Cfg.Server.GrpcListenAddr) if err := s.Serve(lis); err != nil { log.Fatalf("Failed to serve gRPC: %v", err) } } // startServiceDiscovery 启动一个心跳 goroutine,定期向 Redis 注册本实例 func startServiceDiscovery() { key := config.Cfg.Redis.InstanceRegistryKey instanceID := config.Cfg.Server.InstanceID addr := config.Cfg.Server.GrpcAdvertiseAddr ttl := time.Duration(config.Cfg.Redis.InstanceTTLSeconds) * time.Second // 使用 TTL 的一半作为心跳间隔,确保在过期前续期 ticker := time.NewTicker(ttl / 2) defer ticker.Stop() log.Printf("Starting service discovery heartbeat for instance '%s' (%s)", instanceID, addr) // 立即执行一次,不等第一个 ticker updateHeartbeat := func() { // --- [新增] 清理逻辑 --- // 1. 获取所有已注册的实例 allInstances, err := storage.GlobalRedis.Client.HGetAll(context.Background(), key).Result() if err != nil { log.Printf("ERROR: Failed to get all instances for cleanup: %v", err) // 即使获取失败,我们仍然要继续尝试注册自己 } else { // 2. 遍历查找与自己地址冲突的旧实例 for oldInstanceID, oldAddr := range allInstances { // 如果找到一个不同的 instanceID 却使用了相同的地址, // 并且这个旧 ID 不是我们自己当前的 ID,那么它就是“僵尸” if oldAddr == addr && oldInstanceID != instanceID { log.Printf("INFO: Found stale instance '%s' with the same address. Cleaning up...", oldInstanceID) // 3. 删除僵尸实例 storage.GlobalRedis.Client.HDel(context.Background(), key, oldInstanceID) } } } // --- 清理逻辑结束 --- err = storage.GlobalRedis.Client.HSet(context.Background(), key, instanceID, addr).Err() if err != nil { log.Printf("ERROR: failed to heartbeat instance to redis: %v", err) } // 为整个 Hash key 设置一个过期时间,以防所有实例都下线后 key 永久存在 storage.GlobalRedis.Client.Expire(context.Background(), key, ttl*2) } updateHeartbeat() for range ticker.C { updateHeartbeat() } } // listenForDevices 监听并接受来自设备的 TCP 连接 func listenForDevices(addr string) { listener, err := net.Listen("tcp", addr) if err != nil { log.Fatalf("Failed to listen for devices on %s: %v", addr, err) } defer listener.Close() log.Printf("Listening for device connections on %s", addr) for { conn, err := listener.Accept() if err != nil { log.Printf("Failed to accept device connection: %v", err) continue } go handleDeviceSession(conn) } } // ============================================================================== // 3. 设备端会话处理 // ============================================================================== func handleDeviceSession(conn net.Conn) { defer conn.Close() log.Printf("New device connected from %s, awaiting authentication...\n", conn.RemoteAddr()) var auth AuthPayload if err := json.NewDecoder(conn).Decode(&auth); err != nil { /* ... */ return } claims, err := validateDeviceToken(auth.Token) if err != nil { log.Printf("Authentication failed for SN %s (token validation): %v", auth.DeviceSN, err) return } if claims.DeviceSN != auth.DeviceSN { log.Printf("Authentication failed (SN mismatch: token SN '%s' vs payload SN '%s')", claims.DeviceSN, auth.DeviceSN) return } deviceSN, userID := claims.DeviceSN, claims.UserID log.Printf("Device '%s' (user: %s) authenticated.", deviceSN, userID) // ... [认证逻辑结束] ... // 启动 yamux 会话 yamuxConfig := yamux.DefaultConfig() yamuxConfig.EnableKeepAlive = true yamuxConfig.KeepAliveInterval = 30 * time.Second s, err := yamux.Server(conn, yamuxConfig) if err != nil { log.Printf("Failed to start yamux session for device '%s': %v", deviceSN, err) return } defer s.Close() // 1. 添加到本地会话管理器 sessionInfo := &session.SessionInfo{Session: s, UserID: userID} session.GlobalManager.AddSession(deviceSN, sessionInfo) // 2. 如果启用集群模式,注册到 Redis if storage.GlobalRedis != nil { // 注册的值是本机的实例 ID err := storage.GlobalRedis.RegisterDeviceSession(deviceSN, config.Cfg.Server.InstanceID) if err != nil { log.Printf("ERROR: %v", err) } // 启动 Redis KeepAlive go storage.GlobalRedis.KeepAliveSession(s.CloseChan(), deviceSN) } // 注册 defer 函数,在会话关闭时清理资源 defer func() { session.GlobalManager.RemoveSession(deviceSN, s) if storage.GlobalRedis != nil { storage.GlobalRedis.DeregisterDeviceSession(deviceSN) } log.Printf("Cleaned up resources for device '%s' session.", deviceSN) }() // 阻塞直到会话关闭 <-s.CloseChan() } // ============================================================================== // 4. App 端请求智能路由 // ============================================================================== func handleAppRequest(w http.ResponseWriter, r *http.Request) { pathParts := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 3) if len(pathParts) < 2 || pathParts[0] != "tunnel" { http.Error(w, "Invalid path format. Use /tunnel/{deviceSN}/...", http.StatusBadRequest) return } deviceSN := pathParts[1] // --- [App 认证逻辑] --- appUserID, err := authenticateAppRequest(r) if err != nil { log.Printf("App authentication failed for device %s: %v", deviceSN, err) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } // 如果未启用集群模式,直接走本地处理逻辑 if !config.Cfg.Redis.Enabled { handleLocalRequest(w, r, deviceSN, appUserID) return } // 集群模式下的路由决策 ownerInstanceID, err := storage.GlobalRedis.GetDeviceOwner(deviceSN) if err != nil { if err == redis.Nil { http.Error(w, fmt.Sprintf("Device '%s' is not connected", deviceSN), http.StatusBadGateway) } else { log.Printf("ERROR: Redis lookup failed for %s: %v", deviceSN, err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) } return } // 判断设备连接是否在本实例上 if ownerInstanceID == config.Cfg.Server.InstanceID { handleLocalRequest(w, r, deviceSN, appUserID) } else { handleRemoteRequest(w, r, ownerInstanceID, appUserID) } } // handleLocalRequest 处理连接在本实例上的设备的请求 func handleLocalRequest(w http.ResponseWriter, r *http.Request, deviceSN string, appUserID string) { sessionInfo, ok := session.GlobalManager.GetLocalSession(deviceSN) if !ok { log.Printf("WARN: Consistency issue. Device '%s' is registered to this instance but not found in local memory.", deviceSN) http.Error(w, "Device session not found on this server", http.StatusBadGateway) if storage.GlobalRedis != nil { storage.GlobalRedis.DeregisterDeviceSession(deviceSN) } return } if sessionInfo.UserID != appUserID { log.Printf("Forbidden: App user '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID) http.Error(w, "Forbidden: you do not own this device", http.StatusForbidden) return } proxy := session.CreateReverseProxy(sessionInfo, deviceSN, r.URL.Path, r.URL.RawQuery) proxy.ServeHTTP(w, r) } // handleRemoteRequest 将请求通过 gRPC 转发到持有连接的另一个实例 func handleRemoteRequest(w http.ResponseWriter, r *http.Request, targetInstanceID string, appUserID string) { // [这部分代码已在之前的回答中提供并解释,这里直接粘贴] deviceSN := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 3)[1] log.Printf("Forwarding request for device %s to remote instance %s", deviceSN, targetInstanceID) conn, err := peer.GlobalManager.GetClient(targetInstanceID) if err != nil { log.Printf("ERROR: failed to get gRPC client for peer %s: %v", targetInstanceID, err) http.Error(w, "Service internal error (peer unreachable)", http.StatusInternalServerError) return } client := relaypb.NewInternalRelayClient(conn) ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() stream, err := client.ProxyRequest(ctx) if err != nil { log.Printf("ERROR: failed to start gRPC proxy stream to %s: %v", targetInstanceID, err) http.Error(w, "Service internal error (stream failed)", http.StatusInternalServerError) return } // [新增调试日志] if r.Header.Get("Upgrade") == "websocket" { log.Printf("DEBUG (WebSocket): handleRemoteRequest received WebSocket upgrade request. Headers: Connection='%s', Upgrade='%s'", r.Header.Get("Connection"), r.Header.Get("Upgrade")) } headers := make(map[string]string) for k, v := range r.Header { headers[k] = strings.Join(v, ",") } // [新增调试日志] if headers["Upgrade"] == "websocket" { log.Printf("DEBUG (WebSocket): Packing headers into gRPC message. Headers: Connection='%s', Upgrade='%s'", headers["Connection"], headers["Upgrade"]) } headerMsg := &relaypb.ProxyRequestMessage{ Payload: &relaypb.ProxyRequestMessage_Header{ Header: &relaypb.ProxyRequestHeader{ Method: r.Method, Url: r.URL.String(), Headers: headers, RemoteAddr: r.RemoteAddr, AppUserId: appUserID, }, }, } if err := stream.Send(headerMsg); err != nil { log.Printf("ERROR: failed to send gRPC request header to %s: %v", targetInstanceID, err) http.Error(w, "Service internal error (header send failed)", http.StatusInternalServerError) return } go func() { defer stream.CloseSend() if _, err := io.Copy(&grpcStreamWriter{stream: stream}, r.Body); err != nil { log.Printf("ERROR: failed copying request body to gRPC stream for %s: %v", deviceSN, err) } }() respHeaderMsg, err := stream.Recv() if err != nil { log.Printf("ERROR: failed to receive gRPC response header from %s: %v", targetInstanceID, err) http.Error(w, "Gateway timeout or peer unavailable", http.StatusGatewayTimeout) return } respHeader := respHeaderMsg.GetHeader() if respHeader == nil { log.Printf("ERROR: received invalid first message (not a header) from peer %s", targetInstanceID) http.Error(w, "Internal gateway error (invalid peer response)", http.StatusBadGateway) return } for k, v := range respHeader.Headers { w.Header().Set(k, v) } w.WriteHeader(int(respHeader.StatusCode)) for { respBodyMsg, err := stream.Recv() if err == io.EOF { break } if err != nil { log.Printf("ERROR: gRPC response stream broke for device %s: %v", deviceSN, err) break } if _, writeErr := w.Write(respBodyMsg.GetBodyChunk().Data); writeErr != nil { log.Printf("WARN: could not write to client for device %s, client likely disconnected: %v", deviceSN, writeErr) break } if f, ok := w.(http.Flusher); ok { f.Flush() } } } // grpcStreamWriter 是一个辅助类型,实现了 io.Writer 接口 type grpcStreamWriter struct { stream relaypb.InternalRelay_ProxyRequestClient } func (w *grpcStreamWriter) Write(p []byte) (n int, err error) { err = w.stream.Send(&relaypb.ProxyRequestMessage{ Payload: &relaypb.ProxyRequestMessage_BodyChunk{ // [修正] 将 ResponseBodyChunk 改为 RequestBodyChunk BodyChunk: &relaypb.ProxyRequestBodyChunk{Data: p}, }, }) if err != nil { return 0, err } return len(p), nil } func validateDeviceToken(tokenString string) (*DeviceJWTClaims, error) { if len(deviceRelaySecret) == 0 { return nil, errors.New("RELAY_SECRET is not configured on the server") } claims := &DeviceJWTClaims{} token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method for device token") } return deviceRelaySecret, nil }) if err != nil { return nil, err } if !token.Valid { return nil, errors.New("device token is invalid") } return claims, nil } // authenticateAppRequest 和 verifyAppToken 保持不变,备用 func authenticateAppRequest(r *http.Request) (string, error) { //authHeader := r.Header.Get("Authorization") //if authHeader == "" { // return "", errors.New("missing Authorization header") //} //tokenString := strings.TrimPrefix(authHeader, "Bearer ") //if tokenString == authHeader { // return "", errors.New("authorization header format must be Bearer {token}") //} //claims, err := verifyAppToken(tokenString) //if err != nil { // return "", fmt.Errorf("app token verification failed: %w", err) //} //if userID, ok := claims["user_id"].(string); ok { // return userID, nil //} //return "", errors.New("user_id not found in app token claims") return "af672ce1-b528-4c18-af7e-e47b09619520", nil } func verifyAppToken(tokenString string) (jwt.MapClaims, error) { if len(tokenString) == 0 { return nil, errors.New("token can not be empty") } if len(appAccessSecret) == 0 { return nil, errors.New("APP_ACCESS_SECRET is not configured") } token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method for app token") } return appAccessSecret, nil }) if err != nil { return nil, err } if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { return claims, nil } return nil, errors.New("invalid app token") }