You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
83 lines
1.9 KiB
83 lines
1.9 KiB
package main
|
|
|
|
import (
|
|
"context"
|
|
"golang.org/x/time/rate"
|
|
"net"
|
|
)
|
|
|
|
// ThrottledConn 包装了一个 net.Conn,并对其读写速率进行限制
|
|
type ThrottledConn struct {
|
|
net.Conn
|
|
readerLimiter *rate.Limiter
|
|
writerLimiter *rate.Limiter
|
|
}
|
|
|
|
// NewThrottledConn 创建一个新的限流连接
|
|
func NewThrottledConn(conn net.Conn, readLimiter, writeLimiter *rate.Limiter) *ThrottledConn {
|
|
return &ThrottledConn{
|
|
Conn: conn,
|
|
readerLimiter: readLimiter,
|
|
writerLimiter: writeLimiter,
|
|
}
|
|
}
|
|
|
|
// Read 方法被重写,增加了速率限制
|
|
func (c *ThrottledConn) Read(p []byte) (n int, err error) {
|
|
if c.readerLimiter == nil {
|
|
return c.Conn.Read(p) // 如果没有读限制器,则直接读取
|
|
}
|
|
|
|
// 读取数据
|
|
n, err = c.Conn.Read(p)
|
|
if n > 0 {
|
|
// [核心] 等待,直到令牌桶允许 n 字节的数据通过
|
|
if waitErr := c.readerLimiter.WaitN(context.Background(), n); waitErr != nil {
|
|
return n, waitErr // 如果等待出错,返回错误
|
|
}
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// Write 方法被重写,增加了速率限制
|
|
func (c *ThrottledConn) Write(p []byte) (n int, err error) {
|
|
if c.writerLimiter == nil {
|
|
return c.Conn.Write(p) // 如果没有写限制器,则直接写入
|
|
}
|
|
|
|
// 分块写入,因为一次写入可能超过令牌桶的 burst size
|
|
|
|
remaining := len(p)
|
|
written := 0
|
|
|
|
for remaining > 0 {
|
|
// 计算本次可以写入多少
|
|
// WaitN 会阻塞,直到桶里有足够的令牌,或者 context 被取消
|
|
// 一次最多等待一个桶大小 (burst size) 的令牌
|
|
burst := c.writerLimiter.Burst()
|
|
toWrite := min(remaining, burst)
|
|
|
|
if err := c.writerLimiter.WaitN(context.Background(), toWrite); err != nil {
|
|
return written, err
|
|
}
|
|
|
|
// 写入数据
|
|
n, err := c.Conn.Write(p[written : written+toWrite])
|
|
if n > 0 {
|
|
written += n
|
|
remaining -= n
|
|
}
|
|
if err != nil {
|
|
return written, err
|
|
}
|
|
}
|
|
return written, nil
|
|
}
|
|
|
|
// min 辅助函数
|
|
func min(a, b int) int {
|
|
if a < b {
|
|
return a
|
|
}
|
|
return b
|
|
}
|
|
|