// 文件名: main.go (服务端) package main import ( "context" "encoding/json" "errors" "fmt" "github.com/redis/go-redis/v9" "log" "memobus_relay_server/config" "memobus_relay_server/registry" "memobus_relay_server/storage" "net" "net/http" "net/http/httputil" "os" "os/signal" "strings" "sync" "syscall" "time" "github.com/golang-jwt/jwt/v5" "github.com/hashicorp/yamux" ) // 1. 密钥配置 var ( appAccessSecret []byte deviceRelaySecret []byte ) // 2. 结构体定义 type AuthPayload struct { DeviceSN string `json:"device_sn"` Token string `json:"token"` } type DeviceJWTClaims struct { DeviceSN string `json:"sn"` UserID string `json:"userId"` jwt.RegisteredClaims } type SessionInfo struct { Session *yamux.Session UserID string } var ( deviceSessions = make(map[string]*SessionInfo) sessionMutex = &sync.RWMutex{} ) // 3. Main & 服务器启动逻辑 func main() { // 1. 加载配置 config.LoadConfig() appAccessSecret = []byte(config.Cfg.Auth.AppAccessSecret) deviceRelaySecret = []byte(config.Cfg.Auth.DeviceRelaySecret) // 2. 初始化存储层 (Redis) if err := storage.InitRedis(); err != nil { log.Fatalf("Failed to initialize storage: %v", err) } // 3. 启动注册与心跳 (它会自己检查 Redis 是否启用) registry.StartHeartbeat(func() int { sessionMutex.RLock() defer sessionMutex.RUnlock() return len(deviceSessions) }) // 4. 启动核心服务 (放入后台 goroutine) go listenForDevices(config.Cfg.Server.DeviceListenPort) mux := http.NewServeMux() mux.HandleFunc("/tunnel/", handleAppRequest) httpServer := &http.Server{ Addr: config.Cfg.Server.AppListenPort, Handler: mux, } go func() { log.Printf("Starting App HTTP server on %s", config.Cfg.Server.AppListenPort) if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("App server ListenAndServe error: %v", err) } }() // 5. 设置并等待优雅停机 shutdownChan := make(chan os.Signal, 1) signal.Notify(shutdownChan, syscall.SIGINT, syscall.SIGTERM) sig := <-shutdownChan log.Printf("Shutdown signal received (%s), starting graceful shutdown...", sig) // 6. 执行清理操作 shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // a. 向调度服务(Redis)注销自己 registry.Unregister() // b. 优雅地关闭 HTTP 服务器 if err := httpServer.Shutdown(shutdownCtx); err != nil { log.Printf("HTTP server shutdown error: %v", err) } else { log.Println("HTTP server gracefully stopped.") } log.Println("Graceful shutdown complete.") } func listenForDevices(addr string) { log.Printf("Listening for device connections on %s\n", addr) listener, err := net.Listen("tcp", addr) if err != nil { log.Fatalf("Failed to listen for devices: %v", err) } defer listener.Close() for { conn, err := listener.Accept() if err != nil { log.Printf("Failed to accept device connection: %v", err) continue } go handleDeviceSession(conn) } } // 4. 设备端认证与会话管理 func handleDeviceSession(conn net.Conn) { defer conn.Close() log.Printf("New device connected from %s, awaiting authentication...\n", conn.RemoteAddr()) conn.SetReadDeadline(time.Now().Add(10 * time.Second)) var auth AuthPayload if err := json.NewDecoder(conn).Decode(&auth); err != nil { log.Printf("Authentication failed (reading payload): %v", err) return } conn.SetReadDeadline(time.Time{}) claims, err := validateDeviceToken(auth.Token) if err != nil { log.Printf("Authentication failed for SN %s (token validation): %v", auth.DeviceSN, err) return } if claims.DeviceSN != auth.DeviceSN { log.Printf("Authentication failed (SN mismatch: token SN '%s' vs payload SN '%s')", claims.DeviceSN, auth.DeviceSN) return } deviceSN := claims.DeviceSN userID := claims.UserID log.Printf("Device '%s' (user: %s) authenticated successfully.\n", deviceSN, userID) yamuxConfig := yamux.DefaultConfig() yamuxConfig.EnableKeepAlive = true yamuxConfig.KeepAliveInterval = 30 * time.Second session, err := yamux.Server(conn, yamuxConfig) if err != nil { log.Printf("Failed to start yamux session for device '%s': %v", deviceSN, err) return } defer session.Close() sessionInfo := &SessionInfo{Session: session, UserID: userID} sessionMutex.Lock() if oldInfo, exists := deviceSessions[deviceSN]; exists { log.Printf("Device '%s' already connected, closing old session.", deviceSN) oldInfo.Session.Close() } deviceSessions[deviceSN] = sessionInfo sessionMutex.Unlock() log.Printf("Yamux session started for device '%s'\n", deviceSN) if storage.RedisClient != nil { instanceID := config.Cfg.Server.InstanceID if err := storage.RedisClient.HSet(context.Background(), config.Cfg.Redis.DeviceRelayMappingKey, deviceSN, instanceID).Err(); err != nil { log.Printf("ERROR: Failed to update device-relay mapping for %s: %v", deviceSN, err) } else { log.Printf("Device %s is now mapped to instance %s in Redis.", deviceSN, instanceID) } } defer func() { sessionMutex.Lock() if currentInfo, exists := deviceSessions[deviceSN]; exists && currentInfo.Session == session { delete(deviceSessions, deviceSN) } sessionMutex.Unlock() log.Printf("Device '%s' session closed\n", deviceSN) // b. 再清理 Redis 映射 if storage.RedisClient != nil { instanceID := config.Cfg.Server.InstanceID // [健壮性优化] 在删除前,先检查一下 Redis 里的值是不是还是自己。 // 这可以防止因为竞态条件,错误地删除了一个刚刚重连到本机的、更新的会话映射。 currentInstanceID, err := storage.RedisClient.HGet(context.Background(), config.Cfg.Redis.DeviceRelayMappingKey, deviceSN).Result() if err == nil && currentInstanceID == instanceID { storage.RedisClient.HDel(context.Background(), config.Cfg.Redis.DeviceRelayMappingKey, deviceSN) log.Printf("Removed device-relay mapping for %s.", deviceSN) } else if err != nil && err != redis.Nil { log.Printf("ERROR: Could not verify mapping for %s before deleting: %v", deviceSN, err) } } }() <-session.CloseChan() } func validateDeviceToken(tokenString string) (*DeviceJWTClaims, error) { if len(deviceRelaySecret) == 0 { return nil, errors.New("RELAY_SECRET is not configured on the server") } claims := &DeviceJWTClaims{} token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method for device token") } return deviceRelaySecret, nil }) if err != nil { return nil, err } if !token.Valid { return nil, errors.New("device token is invalid") } return claims, nil } // 5. App 端认证与请求处理 func handleAppRequest(w http.ResponseWriter, r *http.Request) { pathParts := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 3) if len(pathParts) < 2 || pathParts[0] != "tunnel" { http.Error(w, "Invalid path format. Use /tunnel/{deviceSN}/...", http.StatusBadRequest) return } deviceSN := pathParts[1] // --- [App 认证逻辑 - 暂时注释,需要时取消注释即可] --- /* appUserID, err := authenticateAppRequest(r) if err != nil { log.Printf("App authentication failed for device %s: %v", deviceSN, err) http.Error(w, "Unauthorized", http.StatusUnauthorized) return } */ sessionMutex.RLock() sessionInfo, ok := deviceSessions[deviceSN] sessionMutex.RUnlock() if !ok || sessionInfo.Session.IsClosed() { http.Error(w, fmt.Sprintf("Device '%s' is not connected", deviceSN), http.StatusBadGateway) return } /* --- [所有权检查 - 暂时注释] --- if sessionInfo.UserID != appUserID { log.Printf("Forbidden: App user '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID) http.Error(w, "Forbidden: you do not own this device", http.StatusForbidden) return } */ proxy := &httputil.ReverseProxy{ Director: func(req *http.Request) { // Director 负责重写请求 if len(pathParts) > 2 { req.URL.Path = "/" + pathParts[2] req.URL.RawQuery = r.URL.RawQuery // 确保查询参数也被传递 } else { req.URL.Path = "/" } req.URL.Scheme = "http" req.URL.Host = r.Host // 使用原始请求的 Host req.Header.Set("X-Real-IP", r.RemoteAddr) }, Transport: &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { // 劫持连接创建,改为打开一个 yamux 流 return sessionInfo.Session.Open() }, // 禁用 HTTP/2,因为它与我们的隧道不兼容 ForceAttemptHTTP2: false, }, FlushInterval: -1, // 支持流式响应 ModifyResponse: func(resp *http.Response) error { // 告知下游代理不要缓冲 resp.Header.Set("X-Accel-Buffering", "no") return nil }, ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { log.Printf("Reverse proxy error for device %s: %v", deviceSN, err) http.Error(w, "Error forwarding request", http.StatusBadGateway) }, } log.Printf("Forwarding request for device '%s' to path '%s'", deviceSN, r.URL.Path) proxy.ServeHTTP(w, r) } // authenticateAppRequest 和 verifyAppToken 保持不变,备用 func authenticateAppRequest(r *http.Request) (string, error) { authHeader := r.Header.Get("Authorization") if authHeader == "" { return "", errors.New("missing Authorization header") } tokenString := strings.TrimPrefix(authHeader, "Bearer ") if tokenString == authHeader { return "", errors.New("authorization header format must be Bearer {token}") } claims, err := verifyAppToken(tokenString) if err != nil { return "", fmt.Errorf("app token verification failed: %w", err) } if userID, ok := claims["user_id"].(string); ok { return userID, nil } return "", errors.New("user_id not found in app token claims") } func verifyAppToken(tokenString string) (jwt.MapClaims, error) { if len(tokenString) == 0 { return nil, errors.New("token can not be empty") } if len(appAccessSecret) == 0 { return nil, errors.New("APP_ACCESS_SECRET is not configured") } token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method for app token") } return appAccessSecret, nil }) if err != nil { return nil, err } if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { return claims, nil } return nil, errors.New("invalid app token") }