diff --git a/config.yml b/config.yml index 15d7bce..a78c7a9 100644 --- a/config.yml +++ b/config.yml @@ -8,6 +8,17 @@ server: public_device_addr: "192.168.5.193:7002" instance_id: "" # 留空会自动生成 + # 速率限制配置 + rate_limit: + enabled: true # 是否启用限流 + # 下载速率限制(从 relay-server -> App),单位 MB/s + download_mbps: 1.5 + # 上传速率限制(从 App -> relay-server),单位 MB/s + upload_mbps: 1.5 + # 令牌桶的“桶”大小,单位 KB。 + # 允许的瞬时突发流量。建议是速率的 1-2 倍。 + # 例如,1 MB/s 的速率,可以设置 1024-2048 KB 的桶大小。 + burst_kb: 1536 # 认证密钥配置 auth: diff --git a/config/config.go b/config/config.go index 463cca5..20f96a2 100644 --- a/config/config.go +++ b/config/config.go @@ -13,11 +13,12 @@ type Config struct { Redis RedisConfig `mapstructure:"redis"` } type ServerConfig struct { - AppListenPort string `mapstructure:"app_listen_port"` - DeviceListenPort string `mapstructure:"device_listen_port"` - PublicAppAddr string `mapstructure:"public_app_addr"` - PublicDeviceAddr string `mapstructure:"public_device_addr"` - InstanceID string `mapstructure:"instance_id"` + AppListenPort string `mapstructure:"app_listen_port"` + DeviceListenPort string `mapstructure:"device_listen_port"` + PublicAppAddr string `mapstructure:"public_app_addr"` + PublicDeviceAddr string `mapstructure:"public_device_addr"` + InstanceID string `mapstructure:"instance_id"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` // 新增 } type AuthConfig struct { Enabled bool `mapstructure:"enabled"` @@ -31,13 +32,20 @@ type RedisConfig struct { Password string `mapstructure:"password"` DB int `mapstructure:"db"` - // [新增] InstanceRegistryKey string `mapstructure:"instance_registry_key"` DeviceRelayMappingKey string `mapstructure:"device_relay_mapping_key"` HeartbeatIntervalSeconds int `mapstructure:"heartbeat_interval_seconds"` InstanceTTLSeconds int `mapstructure:"instance_ttl_seconds"` } +// 速率限制配置 +type RateLimitConfig struct { + Enabled bool `mapstructure:"enabled"` + DownloadMBps float64 `mapstructure:"download_mbps"` + UploadMBps float64 `mapstructure:"upload_mbps"` + BurstKB int `mapstructure:"burst_kb"` +} + var Cfg *Config func LoadConfig() { diff --git a/go.mod b/go.mod index 2062d1f..e276a3b 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module memobus_relay_server -go 1.24 +go 1.24.0 + +toolchain go1.24.2 require ( github.com/golang-jwt/jwt/v5 v5.3.0 @@ -8,6 +10,7 @@ require ( github.com/hashicorp/yamux v0.1.2 github.com/redis/go-redis/v9 v9.14.1 github.com/spf13/viper v1.21.0 + golang.org/x/time v0.14.0 ) require ( diff --git a/main.go b/main.go index 19fa879..4d0cb77 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "github.com/redis/go-redis/v9" + "golang.org/x/time/rate" "log" "memobus_relay_server/config" "memobus_relay_server/registry" @@ -65,6 +66,12 @@ func main() { log.Fatalf("Failed to initialize storage: %v", err) } + // 启动时清理逻辑,清理僵尸设备 + if storage.RedisClient != nil { + log.Println("Performing startup cleanup of device mappings...") + cleanupStaleMappings(config.Cfg.Server.InstanceID) + } + // 3. 启动注册与心跳 (它会自己检查 Redis 是否启用) registry.StartHeartbeat(func() int { sessionMutex.RLock() @@ -111,6 +118,34 @@ func main() { log.Println("Graceful shutdown complete.") } +// 清理函数 +func cleanupStaleMappings(myInstanceID string) { + mappingKey := config.Cfg.Redis.DeviceRelayMappingKey + + // 1. 获取所有设备映射 + allMappings, err := storage.RedisClient.HGetAll(context.Background(), mappingKey).Result() + if err != nil { + log.Printf("ERROR: Failed to get all mappings for cleanup: %v", err) + return + } + + // 2. 遍历查找所有指向本实例的旧映射 + staleDevices := []string{} + for deviceSN, instanceID := range allMappings { + if instanceID == myInstanceID { + staleDevices = append(staleDevices, deviceSN) + } + } + + // 3. 如果找到了,就批量删除它们 + if len(staleDevices) > 0 { + log.Printf("Found %d stale mappings pointing to this instance (%s). Removing them...", len(staleDevices), myInstanceID) + if err := storage.RedisClient.HDel(context.Background(), mappingKey, staleDevices...).Err(); err != nil { + log.Printf("ERROR: Failed to clean up stale mappings: %v", err) + } + } +} + func listenForDevices(addr string) { log.Printf("Listening for device connections on %s\n", addr) listener, err := net.Listen("tcp", addr) @@ -282,8 +317,37 @@ func handleAppRequest(w http.ResponseWriter, r *http.Request) { }, Transport: &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - // 劫持连接创建,改为打开一个 yamux 流 - return sessionInfo.Session.Open() + // 1. 正常打开一个原始的 yamux 流 + stream, err := sessionInfo.Session.Open() + if err != nil { + return nil, err + } + + // 2. [核心修改] 检查是否需要应用限流 + if !config.Cfg.Server.RateLimit.Enabled { + return stream, nil // 如果不启用,直接返回原始 stream + } + + // 3. 创建速率限制器 + // 下载限制 (对应 stream 的 Read) + downloadLimit := rate.Limit(config.Cfg.Server.RateLimit.DownloadMBps * 1024 * 1024) + downloadBurst := config.Cfg.Server.RateLimit.BurstKB * 1024 + downloadLimiter := rate.NewLimiter(downloadLimit, downloadBurst) + + // 上传限制 (对应 stream 的 Write) + uploadLimit := rate.Limit(config.Cfg.Server.RateLimit.UploadMBps * 1024 * 1024) + uploadBurst := config.Cfg.Server.RateLimit.BurstKB * 1024 + uploadLimiter := rate.NewLimiter(uploadLimit, uploadBurst) + + log.Printf("Applying rate limit for device %s: Down %.2f MB/s, Up %.2f MB/s, Burst %d KB", + deviceSN, config.Cfg.Server.RateLimit.DownloadMBps, config.Cfg.Server.RateLimit.UploadMBps, config.Cfg.Server.RateLimit.BurstKB) + + // 4. 将原始 stream 包装成限流的连接 + // 下载:App 从 relay-server 下载数据,对应 ReverseProxy 从 stream 中“读取(Read)”数据 + // 上传:App 向 relay-server 上传数据,对应 ReverseProxy 向 stream 中“写入(Write)”数据 + throttledStream := NewThrottledConn(stream, downloadLimiter, uploadLimiter) + + return throttledStream, nil }, // 禁用 HTTP/2,因为它与我们的隧道不兼容 ForceAttemptHTTP2: false, diff --git a/throttled_conn.go b/throttled_conn.go new file mode 100644 index 0000000..65679c0 --- /dev/null +++ b/throttled_conn.go @@ -0,0 +1,83 @@ +package main + +import ( + "context" + "golang.org/x/time/rate" + "net" +) + +// ThrottledConn 包装了一个 net.Conn,并对其读写速率进行限制 +type ThrottledConn struct { + net.Conn + readerLimiter *rate.Limiter + writerLimiter *rate.Limiter +} + +// NewThrottledConn 创建一个新的限流连接 +func NewThrottledConn(conn net.Conn, readLimiter, writeLimiter *rate.Limiter) *ThrottledConn { + return &ThrottledConn{ + Conn: conn, + readerLimiter: readLimiter, + writerLimiter: writeLimiter, + } +} + +// Read 方法被重写,增加了速率限制 +func (c *ThrottledConn) Read(p []byte) (n int, err error) { + if c.readerLimiter == nil { + return c.Conn.Read(p) // 如果没有读限制器,则直接读取 + } + + // 读取数据 + n, err = c.Conn.Read(p) + if n > 0 { + // [核心] 等待,直到令牌桶允许 n 字节的数据通过 + if waitErr := c.readerLimiter.WaitN(context.Background(), n); waitErr != nil { + return n, waitErr // 如果等待出错,返回错误 + } + } + return n, err +} + +// Write 方法被重写,增加了速率限制 +func (c *ThrottledConn) Write(p []byte) (n int, err error) { + if c.writerLimiter == nil { + return c.Conn.Write(p) // 如果没有写限制器,则直接写入 + } + + // 分块写入,因为一次写入可能超过令牌桶的 burst size + + remaining := len(p) + written := 0 + + for remaining > 0 { + // 计算本次可以写入多少 + // WaitN 会阻塞,直到桶里有足够的令牌,或者 context 被取消 + // 一次最多等待一个桶大小 (burst size) 的令牌 + burst := c.writerLimiter.Burst() + toWrite := min(remaining, burst) + + if err := c.writerLimiter.WaitN(context.Background(), toWrite); err != nil { + return written, err + } + + // 写入数据 + n, err := c.Conn.Write(p[written : written+toWrite]) + if n > 0 { + written += n + remaining -= n + } + if err != nil { + return written, err + } + } + return written, nil +} + +// min 辅助函数 +func min(a, b int) int { + if a < b { + return a + } + return b +}