commit
						256ab3ebb5
					
				 3 changed files with 313 additions and 0 deletions
			
			
		@ -0,0 +1,8 @@ | 
				
			|||
module memobus_relay_server | 
				
			|||
 | 
				
			|||
go 1.24 | 
				
			|||
 | 
				
			|||
require ( | 
				
			|||
	github.com/golang-jwt/jwt/v5 v5.3.0 | 
				
			|||
	github.com/hashicorp/yamux v0.1.2 | 
				
			|||
) | 
				
			|||
@ -0,0 +1,290 @@ | 
				
			|||
// 文件名: main.go (服务端)
 | 
				
			|||
package main | 
				
			|||
 | 
				
			|||
import ( | 
				
			|||
	"context" | 
				
			|||
	"encoding/json" | 
				
			|||
	"errors" | 
				
			|||
	"fmt" | 
				
			|||
	"log" | 
				
			|||
	"net" | 
				
			|||
	"net/http" | 
				
			|||
	"net/http/httputil" | 
				
			|||
	"os" | 
				
			|||
	"strings" | 
				
			|||
	"sync" | 
				
			|||
	"time" | 
				
			|||
 | 
				
			|||
	"github.com/golang-jwt/jwt/v5" | 
				
			|||
	"github.com/hashicorp/yamux" | 
				
			|||
) | 
				
			|||
 | 
				
			|||
// ==============================================================================
 | 
				
			|||
// 1. 密钥配置
 | 
				
			|||
// ==============================================================================
 | 
				
			|||
var ( | 
				
			|||
	// 用于验证 App 请求的密钥,必须和 ibserver 的 AppAccessSecret 一致
 | 
				
			|||
	appAccessSecret = []byte(os.Getenv("APP_ACCESS_SECRET")) | 
				
			|||
	// 用于验证设备连接的密钥,必须和旧中继服务的 RelaySecret 一致
 | 
				
			|||
	deviceRelaySecret = []byte(os.Getenv("RELAY_SECRET")) | 
				
			|||
) | 
				
			|||
 | 
				
			|||
// ==============================================================================
 | 
				
			|||
// 2. 结构体定义
 | 
				
			|||
// ==============================================================================
 | 
				
			|||
type AuthPayload struct { | 
				
			|||
	DeviceSN string `json:"device_sn"` | 
				
			|||
	Token    string `json:"token"` | 
				
			|||
} | 
				
			|||
 | 
				
			|||
type DeviceJWTClaims struct { | 
				
			|||
	DeviceSN string `json:"sn"` | 
				
			|||
	UserID   string `json:"userId"` | 
				
			|||
	jwt.RegisteredClaims | 
				
			|||
} | 
				
			|||
 | 
				
			|||
type SessionInfo struct { | 
				
			|||
	Session *yamux.Session | 
				
			|||
	UserID  string | 
				
			|||
} | 
				
			|||
 | 
				
			|||
var ( | 
				
			|||
	deviceSessions = make(map[string]*SessionInfo) | 
				
			|||
	sessionMutex   = &sync.RWMutex{} | 
				
			|||
) | 
				
			|||
 | 
				
			|||
// ==============================================================================
 | 
				
			|||
// 3. Main & 服务器启动逻辑
 | 
				
			|||
// ==============================================================================
 | 
				
			|||
func main() { | 
				
			|||
	if len(appAccessSecret) == 0 || len(deviceRelaySecret) == 0 { | 
				
			|||
		log.Println("WARNING: APP_ACCESS_SECRET or RELAY_SECRET environment variable not set.") | 
				
			|||
	} | 
				
			|||
	go listenForDevices(":7002") | 
				
			|||
 | 
				
			|||
	log.Println("Starting App HTTP server on :8089") | 
				
			|||
	http.HandleFunc("/tunnel/", handleAppRequest) // 统一入口
 | 
				
			|||
	if err := http.ListenAndServe(":8089", nil); err != nil { | 
				
			|||
		log.Fatalf("Failed to start App server: %v", err) | 
				
			|||
	} | 
				
			|||
} | 
				
			|||
 | 
				
			|||
func listenForDevices(addr string) { | 
				
			|||
	log.Printf("Listening for device connections on %s\n", addr) | 
				
			|||
	listener, err := net.Listen("tcp", addr) | 
				
			|||
	if err != nil { | 
				
			|||
		log.Fatalf("Failed to listen for devices: %v", err) | 
				
			|||
	} | 
				
			|||
	defer listener.Close() | 
				
			|||
 | 
				
			|||
	for { | 
				
			|||
		conn, err := listener.Accept() | 
				
			|||
		if err != nil { | 
				
			|||
			log.Printf("Failed to accept device connection: %v", err) | 
				
			|||
			continue | 
				
			|||
		} | 
				
			|||
		go handleDeviceSession(conn) | 
				
			|||
	} | 
				
			|||
} | 
				
			|||
 | 
				
			|||
// ==============================================================================
 | 
				
			|||
// 4. 设备端认证与会话管理
 | 
				
			|||
// ==============================================================================
 | 
				
			|||
func handleDeviceSession(conn net.Conn) { | 
				
			|||
	defer conn.Close() | 
				
			|||
	log.Printf("New device connected from %s, awaiting authentication...\n", conn.RemoteAddr()) | 
				
			|||
 | 
				
			|||
	conn.SetReadDeadline(time.Now().Add(10 * time.Second)) | 
				
			|||
	var auth AuthPayload | 
				
			|||
	if err := json.NewDecoder(conn).Decode(&auth); err != nil { | 
				
			|||
		log.Printf("Authentication failed (reading payload): %v", err) | 
				
			|||
		return | 
				
			|||
	} | 
				
			|||
	conn.SetReadDeadline(time.Time{}) | 
				
			|||
 | 
				
			|||
	claims, err := validateDeviceToken(auth.Token) | 
				
			|||
	if err != nil { | 
				
			|||
		log.Printf("Authentication failed for SN %s (token validation): %v", auth.DeviceSN, err) | 
				
			|||
		return | 
				
			|||
	} | 
				
			|||
 | 
				
			|||
	if claims.DeviceSN != auth.DeviceSN { | 
				
			|||
		log.Printf("Authentication failed (SN mismatch: token SN '%s' vs payload SN '%s')", claims.DeviceSN, auth.DeviceSN) | 
				
			|||
		return | 
				
			|||
	} | 
				
			|||
 | 
				
			|||
	deviceSN := claims.DeviceSN | 
				
			|||
	userID := claims.UserID | 
				
			|||
	log.Printf("Device '%s' (user: %s) authenticated successfully.\n", deviceSN, userID) | 
				
			|||
 | 
				
			|||
	config := yamux.DefaultConfig() | 
				
			|||
	config.EnableKeepAlive = true | 
				
			|||
	config.KeepAliveInterval = 30 * time.Second | 
				
			|||
 | 
				
			|||
	session, err := yamux.Server(conn, config) | 
				
			|||
	if err != nil { | 
				
			|||
		log.Printf("Failed to start yamux session for device '%s': %v", deviceSN, err) | 
				
			|||
		return | 
				
			|||
	} | 
				
			|||
	defer session.Close() | 
				
			|||
 | 
				
			|||
	sessionInfo := &SessionInfo{Session: session, UserID: userID} | 
				
			|||
	sessionMutex.Lock() | 
				
			|||
	if oldInfo, exists := deviceSessions[deviceSN]; exists { | 
				
			|||
		log.Printf("Device '%s' already connected, closing old session.", deviceSN) | 
				
			|||
		oldInfo.Session.Close() | 
				
			|||
	} | 
				
			|||
	deviceSessions[deviceSN] = sessionInfo | 
				
			|||
	sessionMutex.Unlock() | 
				
			|||
	log.Printf("Yamux session started for device '%s'\n", deviceSN) | 
				
			|||
 | 
				
			|||
	defer func() { | 
				
			|||
		sessionMutex.Lock() | 
				
			|||
		if currentInfo, exists := deviceSessions[deviceSN]; exists && currentInfo.Session == session { | 
				
			|||
			delete(deviceSessions, deviceSN) | 
				
			|||
		} | 
				
			|||
		sessionMutex.Unlock() | 
				
			|||
		log.Printf("Device '%s' session closed\n", deviceSN) | 
				
			|||
	}() | 
				
			|||
 | 
				
			|||
	<-session.CloseChan() | 
				
			|||
} | 
				
			|||
 | 
				
			|||
func validateDeviceToken(tokenString string) (*DeviceJWTClaims, error) { | 
				
			|||
	if len(deviceRelaySecret) == 0 { | 
				
			|||
		return nil, errors.New("RELAY_SECRET is not configured on the server") | 
				
			|||
	} | 
				
			|||
 | 
				
			|||
	claims := &DeviceJWTClaims{} | 
				
			|||
	token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { | 
				
			|||
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | 
				
			|||
			return nil, fmt.Errorf("unexpected signing method for device token") | 
				
			|||
		} | 
				
			|||
		return deviceRelaySecret, nil | 
				
			|||
	}) | 
				
			|||
 | 
				
			|||
	if err != nil { | 
				
			|||
		return nil, err | 
				
			|||
	} | 
				
			|||
	if !token.Valid { | 
				
			|||
		return nil, errors.New("device token is invalid") | 
				
			|||
	} | 
				
			|||
	return claims, nil | 
				
			|||
} | 
				
			|||
 | 
				
			|||
// ==============================================================================
 | 
				
			|||
// 5. App 端认证与请求处理
 | 
				
			|||
// ==============================================================================
 | 
				
			|||
func handleAppRequest(w http.ResponseWriter, r *http.Request) { | 
				
			|||
	pathParts := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 3) | 
				
			|||
	if len(pathParts) < 2 || pathParts[0] != "tunnel" { | 
				
			|||
		http.Error(w, "Invalid path format. Use /tunnel/{deviceSN}/...", http.StatusBadRequest) | 
				
			|||
		return | 
				
			|||
	} | 
				
			|||
	deviceSN := pathParts[1] | 
				
			|||
 | 
				
			|||
	// --- [App 认证逻辑 - 暂时注释,需要时取消注释即可] ---
 | 
				
			|||
	/* | 
				
			|||
		appUserID, err := authenticateAppRequest(r) | 
				
			|||
		if err != nil { | 
				
			|||
			log.Printf("App authentication failed for device %s: %v", deviceSN, err) | 
				
			|||
			http.Error(w, "Unauthorized", http.StatusUnauthorized) | 
				
			|||
			return | 
				
			|||
		} | 
				
			|||
	*/ | 
				
			|||
 | 
				
			|||
	sessionMutex.RLock() | 
				
			|||
	sessionInfo, ok := deviceSessions[deviceSN] | 
				
			|||
	sessionMutex.RUnlock() | 
				
			|||
 | 
				
			|||
	if !ok || sessionInfo.Session.IsClosed() { | 
				
			|||
		http.Error(w, fmt.Sprintf("Device '%s' is not connected", deviceSN), http.StatusBadGateway) | 
				
			|||
		return | 
				
			|||
	} | 
				
			|||
 | 
				
			|||
	/* --- [所有权检查 - 暂时注释] --- | 
				
			|||
	if sessionInfo.UserID != appUserID { | 
				
			|||
		log.Printf("Forbidden: App user '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID) | 
				
			|||
		http.Error(w, "Forbidden: you do not own this device", http.StatusForbidden) | 
				
			|||
		return | 
				
			|||
	} | 
				
			|||
	*/ | 
				
			|||
 | 
				
			|||
	proxy := &httputil.ReverseProxy{ | 
				
			|||
		Director: func(req *http.Request) { | 
				
			|||
			// Director 负责重写请求
 | 
				
			|||
			if len(pathParts) > 2 { | 
				
			|||
				req.URL.Path = "/" + pathParts[2] | 
				
			|||
				req.URL.RawQuery = r.URL.RawQuery // 确保查询参数也被传递
 | 
				
			|||
			} else { | 
				
			|||
				req.URL.Path = "/" | 
				
			|||
			} | 
				
			|||
			req.URL.Scheme = "http" | 
				
			|||
			req.URL.Host = r.Host // 使用原始请求的 Host
 | 
				
			|||
			req.Header.Set("X-Real-IP", r.RemoteAddr) | 
				
			|||
		}, | 
				
			|||
		Transport: &http.Transport{ | 
				
			|||
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { | 
				
			|||
				// 劫持连接创建,改为打开一个 yamux 流
 | 
				
			|||
				return sessionInfo.Session.Open() | 
				
			|||
			}, | 
				
			|||
			// 禁用 HTTP/2,因为它与我们的隧道不兼容
 | 
				
			|||
			ForceAttemptHTTP2: false, | 
				
			|||
		}, | 
				
			|||
		FlushInterval: -1, // 支持流式响应
 | 
				
			|||
		ModifyResponse: func(resp *http.Response) error { | 
				
			|||
			// 告知下游代理不要缓冲
 | 
				
			|||
			resp.Header.Set("X-Accel-Buffering", "no") | 
				
			|||
			return nil | 
				
			|||
		}, | 
				
			|||
		ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { | 
				
			|||
			log.Printf("Reverse proxy error for device %s: %v", deviceSN, err) | 
				
			|||
			http.Error(w, "Error forwarding request", http.StatusBadGateway) | 
				
			|||
		}, | 
				
			|||
	} | 
				
			|||
 | 
				
			|||
	log.Printf("Forwarding request for device '%s' to path '%s'", deviceSN, r.URL.Path) | 
				
			|||
	proxy.ServeHTTP(w, r) | 
				
			|||
} | 
				
			|||
 | 
				
			|||
// authenticateAppRequest 和 verifyAppToken 保持不变,备用
 | 
				
			|||
func authenticateAppRequest(r *http.Request) (string, error) { | 
				
			|||
	authHeader := r.Header.Get("Authorization") | 
				
			|||
	if authHeader == "" { | 
				
			|||
		return "", errors.New("missing Authorization header") | 
				
			|||
	} | 
				
			|||
	tokenString := strings.TrimPrefix(authHeader, "Bearer ") | 
				
			|||
	if tokenString == authHeader { | 
				
			|||
		return "", errors.New("authorization header format must be Bearer {token}") | 
				
			|||
	} | 
				
			|||
	claims, err := verifyAppToken(tokenString) | 
				
			|||
	if err != nil { | 
				
			|||
		return "", fmt.Errorf("app token verification failed: %w", err) | 
				
			|||
	} | 
				
			|||
	if userID, ok := claims["user_id"].(string); ok { | 
				
			|||
		return userID, nil | 
				
			|||
	} | 
				
			|||
	return "", errors.New("user_id not found in app token claims") | 
				
			|||
} | 
				
			|||
 | 
				
			|||
func verifyAppToken(tokenString string) (jwt.MapClaims, error) { | 
				
			|||
	if len(tokenString) == 0 { | 
				
			|||
		return nil, errors.New("token can not be empty") | 
				
			|||
	} | 
				
			|||
	if len(appAccessSecret) == 0 { | 
				
			|||
		return nil, errors.New("APP_ACCESS_SECRET is not configured") | 
				
			|||
	} | 
				
			|||
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { | 
				
			|||
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | 
				
			|||
			return nil, fmt.Errorf("unexpected signing method for app token") | 
				
			|||
		} | 
				
			|||
		return appAccessSecret, nil | 
				
			|||
	}) | 
				
			|||
	if err != nil { | 
				
			|||
		return nil, err | 
				
			|||
	} | 
				
			|||
	if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { | 
				
			|||
		return claims, nil | 
				
			|||
	} | 
				
			|||
	return nil, errors.New("invalid app token") | 
				
			|||
} | 
				
			|||
@ -0,0 +1,15 @@ | 
				
			|||
GOCMD=go | 
				
			|||
GOBUILD=$(GOCMD) build | 
				
			|||
GOCLEAN=$(GOCMD) clean | 
				
			|||
GOTEST=$(GOCMD) test | 
				
			|||
GOGET=$(GOCMD) get | 
				
			|||
 | 
				
			|||
BINARY_NAME=main | 
				
			|||
 | 
				
			|||
all: test build | 
				
			|||
build: | 
				
			|||
	@$(GOBUILD) -o $(BINARY_NAME) main.go | 
				
			|||
build-linux-amd64: | 
				
			|||
	@GOOS=linux GOARCH=amd64 $(GOBUILD) -o $(BINARY_NAME) . | 
				
			|||
build-linux-riscv64: | 
				
			|||
	@GOOS=linux GOARCH=riscv64 $(GOBUILD) -o $(BINARY_NAME) . | 
				
			|||
					Loading…
					
					
				
		Reference in new issue