5 changed files with 178 additions and 9 deletions
			
			
		@ -0,0 +1,83 @@ | 
				
			|||||
 | 
					package main | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					import ( | 
				
			||||
 | 
						"context" | 
				
			||||
 | 
						"golang.org/x/time/rate" | 
				
			||||
 | 
						"net" | 
				
			||||
 | 
					) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// ThrottledConn 包装了一个 net.Conn,并对其读写速率进行限制
 | 
				
			||||
 | 
					type ThrottledConn struct { | 
				
			||||
 | 
						net.Conn | 
				
			||||
 | 
						readerLimiter *rate.Limiter | 
				
			||||
 | 
						writerLimiter *rate.Limiter | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// NewThrottledConn 创建一个新的限流连接
 | 
				
			||||
 | 
					func NewThrottledConn(conn net.Conn, readLimiter, writeLimiter *rate.Limiter) *ThrottledConn { | 
				
			||||
 | 
						return &ThrottledConn{ | 
				
			||||
 | 
							Conn:          conn, | 
				
			||||
 | 
							readerLimiter: readLimiter, | 
				
			||||
 | 
							writerLimiter: writeLimiter, | 
				
			||||
 | 
						} | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// Read 方法被重写,增加了速率限制
 | 
				
			||||
 | 
					func (c *ThrottledConn) Read(p []byte) (n int, err error) { | 
				
			||||
 | 
						if c.readerLimiter == nil { | 
				
			||||
 | 
							return c.Conn.Read(p) // 如果没有读限制器,则直接读取
 | 
				
			||||
 | 
						} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						// 读取数据
 | 
				
			||||
 | 
						n, err = c.Conn.Read(p) | 
				
			||||
 | 
						if n > 0 { | 
				
			||||
 | 
							// [核心] 等待,直到令牌桶允许 n 字节的数据通过
 | 
				
			||||
 | 
							if waitErr := c.readerLimiter.WaitN(context.Background(), n); waitErr != nil { | 
				
			||||
 | 
								return n, waitErr // 如果等待出错,返回错误
 | 
				
			||||
 | 
							} | 
				
			||||
 | 
						} | 
				
			||||
 | 
						return n, err | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// Write 方法被重写,增加了速率限制
 | 
				
			||||
 | 
					func (c *ThrottledConn) Write(p []byte) (n int, err error) { | 
				
			||||
 | 
						if c.writerLimiter == nil { | 
				
			||||
 | 
							return c.Conn.Write(p) // 如果没有写限制器,则直接写入
 | 
				
			||||
 | 
						} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						// 分块写入,因为一次写入可能超过令牌桶的 burst size
 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						remaining := len(p) | 
				
			||||
 | 
						written := 0 | 
				
			||||
 | 
					
 | 
				
			||||
 | 
						for remaining > 0 { | 
				
			||||
 | 
							// 计算本次可以写入多少
 | 
				
			||||
 | 
							// WaitN 会阻塞,直到桶里有足够的令牌,或者 context 被取消
 | 
				
			||||
 | 
							// 一次最多等待一个桶大小 (burst size) 的令牌
 | 
				
			||||
 | 
							burst := c.writerLimiter.Burst() | 
				
			||||
 | 
							toWrite := min(remaining, burst) | 
				
			||||
 | 
					
 | 
				
			||||
 | 
							if err := c.writerLimiter.WaitN(context.Background(), toWrite); err != nil { | 
				
			||||
 | 
								return written, err | 
				
			||||
 | 
							} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
							// 写入数据
 | 
				
			||||
 | 
							n, err := c.Conn.Write(p[written : written+toWrite]) | 
				
			||||
 | 
							if n > 0 { | 
				
			||||
 | 
								written += n | 
				
			||||
 | 
								remaining -= n | 
				
			||||
 | 
							} | 
				
			||||
 | 
							if err != nil { | 
				
			||||
 | 
								return written, err | 
				
			||||
 | 
							} | 
				
			||||
 | 
						} | 
				
			||||
 | 
						return written, nil | 
				
			||||
 | 
					} | 
				
			||||
 | 
					
 | 
				
			||||
 | 
					// min 辅助函数
 | 
				
			||||
 | 
					func min(a, b int) int { | 
				
			||||
 | 
						if a < b { | 
				
			||||
 | 
							return a | 
				
			||||
 | 
						} | 
				
			||||
 | 
						return b | 
				
			||||
 | 
					} | 
				
			||||
					Loading…
					
					
				
		Reference in new issue