package main import ( "context" "golang.org/x/time/rate" "net" ) // ThrottledConn 包装了一个 net.Conn,并对其读写速率进行限制 type ThrottledConn struct { net.Conn readerLimiter *rate.Limiter writerLimiter *rate.Limiter } // NewThrottledConn 创建一个新的限流连接 func NewThrottledConn(conn net.Conn, readLimiter, writeLimiter *rate.Limiter) *ThrottledConn { return &ThrottledConn{ Conn: conn, readerLimiter: readLimiter, writerLimiter: writeLimiter, } } // Read 方法被重写,增加了速率限制 func (c *ThrottledConn) Read(p []byte) (n int, err error) { if c.readerLimiter == nil { return c.Conn.Read(p) // 如果没有读限制器,则直接读取 } // 读取数据 n, err = c.Conn.Read(p) if n > 0 { // [核心] 等待,直到令牌桶允许 n 字节的数据通过 if waitErr := c.readerLimiter.WaitN(context.Background(), n); waitErr != nil { return n, waitErr // 如果等待出错,返回错误 } } return n, err } // Write 方法被重写,增加了速率限制 func (c *ThrottledConn) Write(p []byte) (n int, err error) { if c.writerLimiter == nil { return c.Conn.Write(p) // 如果没有写限制器,则直接写入 } // 分块写入,因为一次写入可能超过令牌桶的 burst size remaining := len(p) written := 0 for remaining > 0 { // 计算本次可以写入多少 // WaitN 会阻塞,直到桶里有足够的令牌,或者 context 被取消 // 一次最多等待一个桶大小 (burst size) 的令牌 burst := c.writerLimiter.Burst() toWrite := min(remaining, burst) if err := c.writerLimiter.WaitN(context.Background(), toWrite); err != nil { return written, err } // 写入数据 n, err := c.Conn.Write(p[written : written+toWrite]) if n > 0 { written += n remaining -= n } if err != nil { return written, err } } return written, nil } // min 辅助函数 func min(a, b int) int { if a < b { return a } return b }