Browse Source

1. 增加限速配置以及实现相应功能,限制上传下载的速度;

2. 中转服务启动时清理僵尸设备。
main
lin_hl 5 days ago
parent
commit
1b7d67bf6d
  1. 11
      config.yml
  2. 20
      config/config.go
  3. 5
      go.mod
  4. 68
      main.go
  5. 83
      throttled_conn.go

11
config.yml

@ -8,6 +8,17 @@ server:
public_device_addr: "192.168.5.193:7002"
instance_id: "" # 留空会自动生成
# 速率限制配置
rate_limit:
enabled: true # 是否启用限流
# 下载速率限制(从 relay-server -> App),单位 MB/s
download_mbps: 1.5
# 上传速率限制(从 App -> relay-server),单位 MB/s
upload_mbps: 1.5
# 令牌桶的“桶”大小,单位 KB。
# 允许的瞬时突发流量。建议是速率的 1-2 倍。
# 例如,1 MB/s 的速率,可以设置 1024-2048 KB 的桶大小。
burst_kb: 1536
# 认证密钥配置
auth:

20
config/config.go

@ -13,11 +13,12 @@ type Config struct {
Redis RedisConfig `mapstructure:"redis"`
}
type ServerConfig struct {
AppListenPort string `mapstructure:"app_listen_port"`
DeviceListenPort string `mapstructure:"device_listen_port"`
PublicAppAddr string `mapstructure:"public_app_addr"`
PublicDeviceAddr string `mapstructure:"public_device_addr"`
InstanceID string `mapstructure:"instance_id"`
AppListenPort string `mapstructure:"app_listen_port"`
DeviceListenPort string `mapstructure:"device_listen_port"`
PublicAppAddr string `mapstructure:"public_app_addr"`
PublicDeviceAddr string `mapstructure:"public_device_addr"`
InstanceID string `mapstructure:"instance_id"`
RateLimit RateLimitConfig `mapstructure:"rate_limit"` // 新增
}
type AuthConfig struct {
Enabled bool `mapstructure:"enabled"`
@ -31,13 +32,20 @@ type RedisConfig struct {
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
// [新增]
InstanceRegistryKey string `mapstructure:"instance_registry_key"`
DeviceRelayMappingKey string `mapstructure:"device_relay_mapping_key"`
HeartbeatIntervalSeconds int `mapstructure:"heartbeat_interval_seconds"`
InstanceTTLSeconds int `mapstructure:"instance_ttl_seconds"`
}
// 速率限制配置
type RateLimitConfig struct {
Enabled bool `mapstructure:"enabled"`
DownloadMBps float64 `mapstructure:"download_mbps"`
UploadMBps float64 `mapstructure:"upload_mbps"`
BurstKB int `mapstructure:"burst_kb"`
}
var Cfg *Config
func LoadConfig() {

5
go.mod

@ -1,6 +1,8 @@
module memobus_relay_server
go 1.24
go 1.24.0
toolchain go1.24.2
require (
github.com/golang-jwt/jwt/v5 v5.3.0
@ -8,6 +10,7 @@ require (
github.com/hashicorp/yamux v0.1.2
github.com/redis/go-redis/v9 v9.14.1
github.com/spf13/viper v1.21.0
golang.org/x/time v0.14.0
)
require (

68
main.go

@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"github.com/redis/go-redis/v9"
"golang.org/x/time/rate"
"log"
"memobus_relay_server/config"
"memobus_relay_server/registry"
@ -65,6 +66,12 @@ func main() {
log.Fatalf("Failed to initialize storage: %v", err)
}
// 启动时清理逻辑,清理僵尸设备
if storage.RedisClient != nil {
log.Println("Performing startup cleanup of device mappings...")
cleanupStaleMappings(config.Cfg.Server.InstanceID)
}
// 3. 启动注册与心跳 (它会自己检查 Redis 是否启用)
registry.StartHeartbeat(func() int {
sessionMutex.RLock()
@ -111,6 +118,34 @@ func main() {
log.Println("Graceful shutdown complete.")
}
// 清理函数
func cleanupStaleMappings(myInstanceID string) {
mappingKey := config.Cfg.Redis.DeviceRelayMappingKey
// 1. 获取所有设备映射
allMappings, err := storage.RedisClient.HGetAll(context.Background(), mappingKey).Result()
if err != nil {
log.Printf("ERROR: Failed to get all mappings for cleanup: %v", err)
return
}
// 2. 遍历查找所有指向本实例的旧映射
staleDevices := []string{}
for deviceSN, instanceID := range allMappings {
if instanceID == myInstanceID {
staleDevices = append(staleDevices, deviceSN)
}
}
// 3. 如果找到了,就批量删除它们
if len(staleDevices) > 0 {
log.Printf("Found %d stale mappings pointing to this instance (%s). Removing them...", len(staleDevices), myInstanceID)
if err := storage.RedisClient.HDel(context.Background(), mappingKey, staleDevices...).Err(); err != nil {
log.Printf("ERROR: Failed to clean up stale mappings: %v", err)
}
}
}
func listenForDevices(addr string) {
log.Printf("Listening for device connections on %s\n", addr)
listener, err := net.Listen("tcp", addr)
@ -282,8 +317,37 @@ func handleAppRequest(w http.ResponseWriter, r *http.Request) {
},
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// 劫持连接创建,改为打开一个 yamux 流
return sessionInfo.Session.Open()
// 1. 正常打开一个原始的 yamux 流
stream, err := sessionInfo.Session.Open()
if err != nil {
return nil, err
}
// 2. [核心修改] 检查是否需要应用限流
if !config.Cfg.Server.RateLimit.Enabled {
return stream, nil // 如果不启用,直接返回原始 stream
}
// 3. 创建速率限制器
// 下载限制 (对应 stream 的 Read)
downloadLimit := rate.Limit(config.Cfg.Server.RateLimit.DownloadMBps * 1024 * 1024)
downloadBurst := config.Cfg.Server.RateLimit.BurstKB * 1024
downloadLimiter := rate.NewLimiter(downloadLimit, downloadBurst)
// 上传限制 (对应 stream 的 Write)
uploadLimit := rate.Limit(config.Cfg.Server.RateLimit.UploadMBps * 1024 * 1024)
uploadBurst := config.Cfg.Server.RateLimit.BurstKB * 1024
uploadLimiter := rate.NewLimiter(uploadLimit, uploadBurst)
log.Printf("Applying rate limit for device %s: Down %.2f MB/s, Up %.2f MB/s, Burst %d KB",
deviceSN, config.Cfg.Server.RateLimit.DownloadMBps, config.Cfg.Server.RateLimit.UploadMBps, config.Cfg.Server.RateLimit.BurstKB)
// 4. 将原始 stream 包装成限流的连接
// 下载:App 从 relay-server 下载数据,对应 ReverseProxy 从 stream 中“读取(Read)”数据
// 上传:App 向 relay-server 上传数据,对应 ReverseProxy 向 stream 中“写入(Write)”数据
throttledStream := NewThrottledConn(stream, downloadLimiter, uploadLimiter)
return throttledStream, nil
},
// 禁用 HTTP/2,因为它与我们的隧道不兼容
ForceAttemptHTTP2: false,

83
throttled_conn.go

@ -0,0 +1,83 @@
package main
import (
"context"
"golang.org/x/time/rate"
"net"
)
// ThrottledConn 包装了一个 net.Conn,并对其读写速率进行限制
type ThrottledConn struct {
net.Conn
readerLimiter *rate.Limiter
writerLimiter *rate.Limiter
}
// NewThrottledConn 创建一个新的限流连接
func NewThrottledConn(conn net.Conn, readLimiter, writeLimiter *rate.Limiter) *ThrottledConn {
return &ThrottledConn{
Conn: conn,
readerLimiter: readLimiter,
writerLimiter: writeLimiter,
}
}
// Read 方法被重写,增加了速率限制
func (c *ThrottledConn) Read(p []byte) (n int, err error) {
if c.readerLimiter == nil {
return c.Conn.Read(p) // 如果没有读限制器,则直接读取
}
// 读取数据
n, err = c.Conn.Read(p)
if n > 0 {
// [核心] 等待,直到令牌桶允许 n 字节的数据通过
if waitErr := c.readerLimiter.WaitN(context.Background(), n); waitErr != nil {
return n, waitErr // 如果等待出错,返回错误
}
}
return n, err
}
// Write 方法被重写,增加了速率限制
func (c *ThrottledConn) Write(p []byte) (n int, err error) {
if c.writerLimiter == nil {
return c.Conn.Write(p) // 如果没有写限制器,则直接写入
}
// 分块写入,因为一次写入可能超过令牌桶的 burst size
remaining := len(p)
written := 0
for remaining > 0 {
// 计算本次可以写入多少
// WaitN 会阻塞,直到桶里有足够的令牌,或者 context 被取消
// 一次最多等待一个桶大小 (burst size) 的令牌
burst := c.writerLimiter.Burst()
toWrite := min(remaining, burst)
if err := c.writerLimiter.WaitN(context.Background(), toWrite); err != nil {
return written, err
}
// 写入数据
n, err := c.Conn.Write(p[written : written+toWrite])
if n > 0 {
written += n
remaining -= n
}
if err != nil {
return written, err
}
}
return written, nil
}
// min 辅助函数
func min(a, b int) int {
if a < b {
return a
}
return b
}
Loading…
Cancel
Save