commit
						256ab3ebb5
					
				 3 changed files with 313 additions and 0 deletions
			
			
		@ -0,0 +1,8 @@ | 
				
			|||||
 | 
					module memobus_relay_server | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					go 1.24 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					require ( | 
				
			||||
 | 
						github.com/golang-jwt/jwt/v5 v5.3.0 | 
				
			||||
 | 
						github.com/hashicorp/yamux v0.1.2 | 
				
			||||
 | 
					) | 
				
			||||
@ -0,0 +1,290 @@ | 
				
			|||||
 | 
					// 文件名: main.go (服务端)
 | 
				
			||||
 | 
					package main | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					import ( | 
				
			||||
 | 
						"context" | 
				
			||||
 | 
						"encoding/json" | 
				
			||||
 | 
						"errors" | 
				
			||||
 | 
						"fmt" | 
				
			||||
 | 
						"log" | 
				
			||||
 | 
						"net" | 
				
			||||
 | 
						"net/http" | 
				
			||||
 | 
						"net/http/httputil" | 
				
			||||
 | 
						"os" | 
				
			||||
 | 
						"strings" | 
				
			||||
 | 
						"sync" | 
				
			||||
 | 
						"time" | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						"github.com/golang-jwt/jwt/v5" | 
				
			||||
 | 
						"github.com/hashicorp/yamux" | 
				
			||||
 | 
					) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// ==============================================================================
 | 
				
			||||
 | 
					// 1. 密钥配置
 | 
				
			||||
 | 
					// ==============================================================================
 | 
				
			||||
 | 
					var ( | 
				
			||||
 | 
						// 用于验证 App 请求的密钥,必须和 ibserver 的 AppAccessSecret 一致
 | 
				
			||||
 | 
						appAccessSecret = []byte(os.Getenv("APP_ACCESS_SECRET")) | 
				
			||||
 | 
						// 用于验证设备连接的密钥,必须和旧中继服务的 RelaySecret 一致
 | 
				
			||||
 | 
						deviceRelaySecret = []byte(os.Getenv("RELAY_SECRET")) | 
				
			||||
 | 
					) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// ==============================================================================
 | 
				
			||||
 | 
					// 2. 结构体定义
 | 
				
			||||
 | 
					// ==============================================================================
 | 
				
			||||
 | 
					type AuthPayload struct { | 
				
			||||
 | 
						DeviceSN string `json:"device_sn"` | 
				
			||||
 | 
						Token    string `json:"token"` | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					type DeviceJWTClaims struct { | 
				
			||||
 | 
						DeviceSN string `json:"sn"` | 
				
			||||
 | 
						UserID   string `json:"userId"` | 
				
			||||
 | 
						jwt.RegisteredClaims | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					type SessionInfo struct { | 
				
			||||
 | 
						Session *yamux.Session | 
				
			||||
 | 
						UserID  string | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					var ( | 
				
			||||
 | 
						deviceSessions = make(map[string]*SessionInfo) | 
				
			||||
 | 
						sessionMutex   = &sync.RWMutex{} | 
				
			||||
 | 
					) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// ==============================================================================
 | 
				
			||||
 | 
					// 3. Main & 服务器启动逻辑
 | 
				
			||||
 | 
					// ==============================================================================
 | 
				
			||||
 | 
					func main() { | 
				
			||||
 | 
						if len(appAccessSecret) == 0 || len(deviceRelaySecret) == 0 { | 
				
			||||
 | 
							log.Println("WARNING: APP_ACCESS_SECRET or RELAY_SECRET environment variable not set.") | 
				
			||||
 | 
						} | 
				
			||||
 | 
						go listenForDevices(":7002") | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						log.Println("Starting App HTTP server on :8089") | 
				
			||||
 | 
						http.HandleFunc("/tunnel/", handleAppRequest) // 统一入口
 | 
				
			||||
 | 
						if err := http.ListenAndServe(":8089", nil); err != nil { | 
				
			||||
 | 
							log.Fatalf("Failed to start App server: %v", err) | 
				
			||||
 | 
						} | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					func listenForDevices(addr string) { | 
				
			||||
 | 
						log.Printf("Listening for device connections on %s\n", addr) | 
				
			||||
 | 
						listener, err := net.Listen("tcp", addr) | 
				
			||||
 | 
						if err != nil { | 
				
			||||
 | 
							log.Fatalf("Failed to listen for devices: %v", err) | 
				
			||||
 | 
						} | 
				
			||||
 | 
						defer listener.Close() | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						for { | 
				
			||||
 | 
							conn, err := listener.Accept() | 
				
			||||
 | 
							if err != nil { | 
				
			||||
 | 
								log.Printf("Failed to accept device connection: %v", err) | 
				
			||||
 | 
								continue | 
				
			||||
 | 
							} | 
				
			||||
 | 
							go handleDeviceSession(conn) | 
				
			||||
 | 
						} | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// ==============================================================================
 | 
				
			||||
 | 
					// 4. 设备端认证与会话管理
 | 
				
			||||
 | 
					// ==============================================================================
 | 
				
			||||
 | 
					func handleDeviceSession(conn net.Conn) { | 
				
			||||
 | 
						defer conn.Close() | 
				
			||||
 | 
						log.Printf("New device connected from %s, awaiting authentication...\n", conn.RemoteAddr()) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						conn.SetReadDeadline(time.Now().Add(10 * time.Second)) | 
				
			||||
 | 
						var auth AuthPayload | 
				
			||||
 | 
						if err := json.NewDecoder(conn).Decode(&auth); err != nil { | 
				
			||||
 | 
							log.Printf("Authentication failed (reading payload): %v", err) | 
				
			||||
 | 
							return | 
				
			||||
 | 
						} | 
				
			||||
 | 
						conn.SetReadDeadline(time.Time{}) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						claims, err := validateDeviceToken(auth.Token) | 
				
			||||
 | 
						if err != nil { | 
				
			||||
 | 
							log.Printf("Authentication failed for SN %s (token validation): %v", auth.DeviceSN, err) | 
				
			||||
 | 
							return | 
				
			||||
 | 
						} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						if claims.DeviceSN != auth.DeviceSN { | 
				
			||||
 | 
							log.Printf("Authentication failed (SN mismatch: token SN '%s' vs payload SN '%s')", claims.DeviceSN, auth.DeviceSN) | 
				
			||||
 | 
							return | 
				
			||||
 | 
						} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						deviceSN := claims.DeviceSN | 
				
			||||
 | 
						userID := claims.UserID | 
				
			||||
 | 
						log.Printf("Device '%s' (user: %s) authenticated successfully.\n", deviceSN, userID) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						config := yamux.DefaultConfig() | 
				
			||||
 | 
						config.EnableKeepAlive = true | 
				
			||||
 | 
						config.KeepAliveInterval = 30 * time.Second | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						session, err := yamux.Server(conn, config) | 
				
			||||
 | 
						if err != nil { | 
				
			||||
 | 
							log.Printf("Failed to start yamux session for device '%s': %v", deviceSN, err) | 
				
			||||
 | 
							return | 
				
			||||
 | 
						} | 
				
			||||
 | 
						defer session.Close() | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						sessionInfo := &SessionInfo{Session: session, UserID: userID} | 
				
			||||
 | 
						sessionMutex.Lock() | 
				
			||||
 | 
						if oldInfo, exists := deviceSessions[deviceSN]; exists { | 
				
			||||
 | 
							log.Printf("Device '%s' already connected, closing old session.", deviceSN) | 
				
			||||
 | 
							oldInfo.Session.Close() | 
				
			||||
 | 
						} | 
				
			||||
 | 
						deviceSessions[deviceSN] = sessionInfo | 
				
			||||
 | 
						sessionMutex.Unlock() | 
				
			||||
 | 
						log.Printf("Yamux session started for device '%s'\n", deviceSN) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						defer func() { | 
				
			||||
 | 
							sessionMutex.Lock() | 
				
			||||
 | 
							if currentInfo, exists := deviceSessions[deviceSN]; exists && currentInfo.Session == session { | 
				
			||||
 | 
								delete(deviceSessions, deviceSN) | 
				
			||||
 | 
							} | 
				
			||||
 | 
							sessionMutex.Unlock() | 
				
			||||
 | 
							log.Printf("Device '%s' session closed\n", deviceSN) | 
				
			||||
 | 
						}() | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						<-session.CloseChan() | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					func validateDeviceToken(tokenString string) (*DeviceJWTClaims, error) { | 
				
			||||
 | 
						if len(deviceRelaySecret) == 0 { | 
				
			||||
 | 
							return nil, errors.New("RELAY_SECRET is not configured on the server") | 
				
			||||
 | 
						} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						claims := &DeviceJWTClaims{} | 
				
			||||
 | 
						token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { | 
				
			||||
 | 
							if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | 
				
			||||
 | 
								return nil, fmt.Errorf("unexpected signing method for device token") | 
				
			||||
 | 
							} | 
				
			||||
 | 
							return deviceRelaySecret, nil | 
				
			||||
 | 
						}) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						if err != nil { | 
				
			||||
 | 
							return nil, err | 
				
			||||
 | 
						} | 
				
			||||
 | 
						if !token.Valid { | 
				
			||||
 | 
							return nil, errors.New("device token is invalid") | 
				
			||||
 | 
						} | 
				
			||||
 | 
						return claims, nil | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// ==============================================================================
 | 
				
			||||
 | 
					// 5. App 端认证与请求处理
 | 
				
			||||
 | 
					// ==============================================================================
 | 
				
			||||
 | 
					func handleAppRequest(w http.ResponseWriter, r *http.Request) { | 
				
			||||
 | 
						pathParts := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 3) | 
				
			||||
 | 
						if len(pathParts) < 2 || pathParts[0] != "tunnel" { | 
				
			||||
 | 
							http.Error(w, "Invalid path format. Use /tunnel/{deviceSN}/...", http.StatusBadRequest) | 
				
			||||
 | 
							return | 
				
			||||
 | 
						} | 
				
			||||
 | 
						deviceSN := pathParts[1] | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						// --- [App 认证逻辑 - 暂时注释,需要时取消注释即可] ---
 | 
				
			||||
 | 
						/* | 
				
			||||
 | 
							appUserID, err := authenticateAppRequest(r) | 
				
			||||
 | 
							if err != nil { | 
				
			||||
 | 
								log.Printf("App authentication failed for device %s: %v", deviceSN, err) | 
				
			||||
 | 
								http.Error(w, "Unauthorized", http.StatusUnauthorized) | 
				
			||||
 | 
								return | 
				
			||||
 | 
							} | 
				
			||||
 | 
						*/ | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						sessionMutex.RLock() | 
				
			||||
 | 
						sessionInfo, ok := deviceSessions[deviceSN] | 
				
			||||
 | 
						sessionMutex.RUnlock() | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						if !ok || sessionInfo.Session.IsClosed() { | 
				
			||||
 | 
							http.Error(w, fmt.Sprintf("Device '%s' is not connected", deviceSN), http.StatusBadGateway) | 
				
			||||
 | 
							return | 
				
			||||
 | 
						} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						/* --- [所有权检查 - 暂时注释] --- | 
				
			||||
 | 
						if sessionInfo.UserID != appUserID { | 
				
			||||
 | 
							log.Printf("Forbidden: App user '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID) | 
				
			||||
 | 
							http.Error(w, "Forbidden: you do not own this device", http.StatusForbidden) | 
				
			||||
 | 
							return | 
				
			||||
 | 
						} | 
				
			||||
 | 
						*/ | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						proxy := &httputil.ReverseProxy{ | 
				
			||||
 | 
							Director: func(req *http.Request) { | 
				
			||||
 | 
								// Director 负责重写请求
 | 
				
			||||
 | 
								if len(pathParts) > 2 { | 
				
			||||
 | 
									req.URL.Path = "/" + pathParts[2] | 
				
			||||
 | 
									req.URL.RawQuery = r.URL.RawQuery // 确保查询参数也被传递
 | 
				
			||||
 | 
								} else { | 
				
			||||
 | 
									req.URL.Path = "/" | 
				
			||||
 | 
								} | 
				
			||||
 | 
								req.URL.Scheme = "http" | 
				
			||||
 | 
								req.URL.Host = r.Host // 使用原始请求的 Host
 | 
				
			||||
 | 
								req.Header.Set("X-Real-IP", r.RemoteAddr) | 
				
			||||
 | 
							}, | 
				
			||||
 | 
							Transport: &http.Transport{ | 
				
			||||
 | 
								DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { | 
				
			||||
 | 
									// 劫持连接创建,改为打开一个 yamux 流
 | 
				
			||||
 | 
									return sessionInfo.Session.Open() | 
				
			||||
 | 
								}, | 
				
			||||
 | 
								// 禁用 HTTP/2,因为它与我们的隧道不兼容
 | 
				
			||||
 | 
								ForceAttemptHTTP2: false, | 
				
			||||
 | 
							}, | 
				
			||||
 | 
							FlushInterval: -1, // 支持流式响应
 | 
				
			||||
 | 
							ModifyResponse: func(resp *http.Response) error { | 
				
			||||
 | 
								// 告知下游代理不要缓冲
 | 
				
			||||
 | 
								resp.Header.Set("X-Accel-Buffering", "no") | 
				
			||||
 | 
								return nil | 
				
			||||
 | 
							}, | 
				
			||||
 | 
							ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { | 
				
			||||
 | 
								log.Printf("Reverse proxy error for device %s: %v", deviceSN, err) | 
				
			||||
 | 
								http.Error(w, "Error forwarding request", http.StatusBadGateway) | 
				
			||||
 | 
							}, | 
				
			||||
 | 
						} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						log.Printf("Forwarding request for device '%s' to path '%s'", deviceSN, r.URL.Path) | 
				
			||||
 | 
						proxy.ServeHTTP(w, r) | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// authenticateAppRequest 和 verifyAppToken 保持不变,备用
 | 
				
			||||
 | 
					func authenticateAppRequest(r *http.Request) (string, error) { | 
				
			||||
 | 
						authHeader := r.Header.Get("Authorization") | 
				
			||||
 | 
						if authHeader == "" { | 
				
			||||
 | 
							return "", errors.New("missing Authorization header") | 
				
			||||
 | 
						} | 
				
			||||
 | 
						tokenString := strings.TrimPrefix(authHeader, "Bearer ") | 
				
			||||
 | 
						if tokenString == authHeader { | 
				
			||||
 | 
							return "", errors.New("authorization header format must be Bearer {token}") | 
				
			||||
 | 
						} | 
				
			||||
 | 
						claims, err := verifyAppToken(tokenString) | 
				
			||||
 | 
						if err != nil { | 
				
			||||
 | 
							return "", fmt.Errorf("app token verification failed: %w", err) | 
				
			||||
 | 
						} | 
				
			||||
 | 
						if userID, ok := claims["user_id"].(string); ok { | 
				
			||||
 | 
							return userID, nil | 
				
			||||
 | 
						} | 
				
			||||
 | 
						return "", errors.New("user_id not found in app token claims") | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					func verifyAppToken(tokenString string) (jwt.MapClaims, error) { | 
				
			||||
 | 
						if len(tokenString) == 0 { | 
				
			||||
 | 
							return nil, errors.New("token can not be empty") | 
				
			||||
 | 
						} | 
				
			||||
 | 
						if len(appAccessSecret) == 0 { | 
				
			||||
 | 
							return nil, errors.New("APP_ACCESS_SECRET is not configured") | 
				
			||||
 | 
						} | 
				
			||||
 | 
						token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { | 
				
			||||
 | 
							if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | 
				
			||||
 | 
								return nil, fmt.Errorf("unexpected signing method for app token") | 
				
			||||
 | 
							} | 
				
			||||
 | 
							return appAccessSecret, nil | 
				
			||||
 | 
						}) | 
				
			||||
 | 
						if err != nil { | 
				
			||||
 | 
							return nil, err | 
				
			||||
 | 
						} | 
				
			||||
 | 
						if claims, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { | 
				
			||||
 | 
							return claims, nil | 
				
			||||
 | 
						} | 
				
			||||
 | 
						return nil, errors.New("invalid app token") | 
				
			||||
 | 
					} | 
				
			||||
@ -0,0 +1,15 @@ | 
				
			|||||
 | 
					GOCMD=go | 
				
			||||
 | 
					GOBUILD=$(GOCMD) build | 
				
			||||
 | 
					GOCLEAN=$(GOCMD) clean | 
				
			||||
 | 
					GOTEST=$(GOCMD) test | 
				
			||||
 | 
					GOGET=$(GOCMD) get | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					BINARY_NAME=main | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					all: test build | 
				
			||||
 | 
					build: | 
				
			||||
 | 
						@$(GOBUILD) -o $(BINARY_NAME) main.go | 
				
			||||
 | 
					build-linux-amd64: | 
				
			||||
 | 
						@GOOS=linux GOARCH=amd64 $(GOBUILD) -o $(BINARY_NAME) . | 
				
			||||
 | 
					build-linux-riscv64: | 
				
			||||
 | 
						@GOOS=linux GOARCH=riscv64 $(GOBUILD) -o $(BINARY_NAME) . | 
				
			||||
					Loading…
					
					
				
		Reference in new issue