Browse Source

1. 增加grpc代码实现(websocket在grpc模式下有问题)

grpc-forwarding
lin_hl 2 weeks ago
parent
commit
cd38aa41ac
  1. 31
      config.yml
  2. 79
      config/config.go
  3. 28
      go.mod
  4. 360
      grpc/server.go
  5. 510
      main.go
  6. 68
      peer/manager.go
  7. 96
      proto/relay.proto
  8. 167
      session/manager.go
  9. 126
      storage/redis.go

31
config.yml

@ -0,0 +1,31 @@
# config.yml
# 服务器相关配置
server:
app_listen_port: ":8089"
device_listen_port: ":7002"
instance_id: "" # 留空会自动生成 UUID, 也可以指定一个固定的ID
# [新增] 用于服务器间通信的 gRPC 配置
grpc_listen_addr: ":9090"
# 这个地址必须能被其他服务器实例访问到。
# 在 Docker/K8s 环境中, 这应该是服务名或 Pod IP。
grpc_advertise_addr: "192.168.5.193:9090"
# 认证密钥配置
auth:
app_access_secret: "D4tBb9Y0oHSXRAyHLHpdKfXAuNCyCZ45AZxKJOhMJMs="
device_relay_secret: "p+JtJ8aHlM1lDYu7UGFanX8ALVt1pM1BQmKTpqTJccs="
# Redis 配置 (为下一步做准备)
# 如果 enabled 为 false,我们的代码将退回使用内存 map,实现单机兼容
redis:
enabled: true
addr: "118.178.183.78:6379"
password: "" # 留空表示没有密码
db: 1
session_ttl_seconds: 120 # 会话在 Redis 中的过期时间、
# [新增] 用于服务发现的 Key
# 一个 Redis Hash, 存储 instance_id -> grpc_addr 的映射
instance_registry_key: "relay_instances"
# 实例必须比这个 TTL 更快地发送心跳
instance_ttl_seconds: 15

79
config/config.go

@ -0,0 +1,79 @@
package config
import (
"github.com/google/uuid"
"github.com/spf13/viper"
"log"
"strings"
)
// Config 结构体必须与 config.yml 的结构完全对应
// 使用 `mapstructure` tag 来帮助 Viper 正确映射 YAML 键名到 Go 结构体字段
type Config struct {
Server ServerConfig `mapstructure:"server"`
Auth AuthConfig `mapstructure:"auth"`
Redis RedisConfig `mapstructure:"redis"`
}
type ServerConfig struct {
AppListenPort string `mapstructure:"app_listen_port"`
DeviceListenPort string `mapstructure:"device_listen_port"`
// [新增]
InstanceID string `mapstructure:"instance_id"`
GrpcListenAddr string `mapstructure:"grpc_listen_addr"`
GrpcAdvertiseAddr string `mapstructure:"grpc_advertise_addr"`
}
type AuthConfig struct {
AppAccessSecret string `mapstructure:"app_access_secret"`
DeviceRelaySecret string `mapstructure:"device_relay_secret"`
}
type RedisConfig struct {
Enabled bool `mapstructure:"enabled"`
Addr string `mapstructure:"addr"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` // 确保有这个字段
// [新增]
InstanceRegistryKey string `mapstructure:"instance_registry_key"`
InstanceTTLSeconds int `mapstructure:"instance_ttl_seconds"`
}
// Cfg 是一个全局变量,用于在项目的任何地方访问配置
var Cfg *Config
// LoadConfig 是初始化函数,负责读取和解析配置文件
func LoadConfig() {
viper.SetConfigName("config") // 配置文件名 (不带扩展名)
viper.SetConfigType("yml") // 配置文件类型
viper.AddConfigPath(".") // 在当前工作目录查找配置文件
viper.AddConfigPath("./config") // 也在 config 目录查找
// [关键] 开启环境变量支持
// 这允许你通过环境变量覆盖配置文件中的值
// 例如:SERVER_APP_LISTEN_ADDR=":9000" 会覆盖文件中的设置
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viper.AutomaticEnv()
// 读取配置文件
if err := viper.ReadInConfig(); err != nil {
// 如果配置文件没找到,也没关系,可能完全通过环境变量配置
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
log.Fatalf("Fatal error reading config file: %v", err)
}
}
// 将读取到的配置反序列化到 Cfg 结构体中
if err := viper.Unmarshal(&Cfg); err != nil {
log.Fatalf("Unable to decode config into struct: %v", err)
}
// [新增] 如果 instance_id 未配置,则自动生成
if Cfg.Server.InstanceID == "" {
Cfg.Server.InstanceID = uuid.New().String()
}
log.Printf("Configuration loaded. Server Instance ID: %s", Cfg.Server.InstanceID)
}

28
go.mod

@ -1,8 +1,34 @@
module memobus_relay_server module memobus_relay_server
go 1.24 go 1.24.0
toolchain go1.24.2
require ( require (
github.com/golang-jwt/jwt/v5 v5.3.0 github.com/golang-jwt/jwt/v5 v5.3.0
github.com/google/uuid v1.6.0
github.com/hashicorp/yamux v0.1.2 github.com/hashicorp/yamux v0.1.2
github.com/redis/go-redis/v9 v9.14.1
github.com/spf13/viper v1.21.0
google.golang.org/grpc v1.76.0
)
require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/sagikazarmark/locafero v0.11.0 // indirect
github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect
github.com/spf13/afero v1.15.0 // indirect
github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/net v0.42.0 // indirect
golang.org/x/sys v0.34.0 // indirect
golang.org/x/text v0.28.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250804133106-a7a43d27e69b // indirect
google.golang.org/protobuf v1.36.6 // indirect
) )

360
grpc/server.go

@ -0,0 +1,360 @@
// 文件: grpc/server.go
package grpc
import (
"bufio"
"io"
"log"
relaypb "memobus_relay_server/relay_server/proto"
"memobus_relay_server/session"
"net/http"
"net/http/httptest"
"strings"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// InternalRelayServer 实现了用于代理请求的 gRPC 服务
type InternalRelayServer struct {
relaypb.UnimplementedInternalRelayServer
}
// NewInternalRelayServer 创建一个新的 gRPC 服务实例
func NewInternalRelayServer() *InternalRelayServer {
return &InternalRelayServer{}
}
// ProxyRequest 是核心的 gRPC 流处理器,实现了完整的请求代理和流式响应
func (s *InternalRelayServer) ProxyRequest(stream relaypb.InternalRelay_ProxyRequestServer) error {
// --- 1. 接收和解析请求头 ---
headerMsg, err := stream.Recv()
if err != nil {
log.Printf("ERROR (gRPC): Failed to receive initial header: %v", err)
return status.Errorf(codes.InvalidArgument, "failed to receive header: %v", err)
}
header := headerMsg.GetHeader()
if header == nil {
return status.Errorf(codes.InvalidArgument, "first message must be a header")
}
// [新增调试日志]
if header.Headers["Upgrade"] == "websocket" {
log.Printf("DEBUG (WebSocket): gRPC server received WebSocket upgrade request. Headers: Connection='%s', Upgrade='%s'", header.Headers["Connection"], header.Headers["Upgrade"])
}
// 检查是否是 WebSocket 握手请求
isWebSocket := header.Headers["Upgrade"] == "websocket" && strings.Contains(strings.ToLower(header.Headers["Connection"]), "upgrade")
pathParts := strings.SplitN(strings.TrimPrefix(header.Url, "/"), "/", 3)
if len(pathParts) < 2 {
return status.Errorf(codes.InvalidArgument, "invalid URL format in gRPC header")
}
deviceSN := pathParts[1]
appUserID := header.GetAppUserId()
if appUserID == "" {
return status.Errorf(codes.InvalidArgument, "app_user_id is missing in gRPC header")
}
log.Printf("gRPC Proxy: Handling request for device '%s' from user '%s'", deviceSN, appUserID)
// --- 2. 查找本地会话并进行授权检查 ---
sessionInfo, ok := session.GlobalManager.GetLocalSession(deviceSN)
if !ok {
return status.Errorf(codes.NotFound, "device '%s' not connected to this instance", deviceSN)
}
if sessionInfo.UserID != appUserID {
log.Printf("Forbidden (gRPC): User '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID)
return sendForbiddenResponse(stream)
}
log.Printf("gRPC Proxy: Handling request for device '%s' from user '%s'", deviceSN, appUserID)
// --- 3. [核心修改] 根据请求类型进行分流 ---
if isWebSocket {
log.Println("gRPC Proxy: Detected WebSocket request, diverting to transparent proxy handler.")
return s.handleWebSocketProxy(stream, sessionInfo, deviceSN, header)
} else {
// 如果是普通 HTTP, 调用原来的 ReverseProxy 处理器
log.Println("gRPC Proxy: Detected HTTP request, using ReverseProxy handler.")
// 注意:我把原来的 ProxyRequest 逻辑提取到了一个新函数中,以保持整洁
return s.handleHTTPProxy(stream, sessionInfo, deviceSN, header)
}
}
func (s *InternalRelayServer) handleWebSocketProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error {
// 1. 打开到后端 (yamux) 的连接
backendConn, err := sessionInfo.Session.Open()
if err != nil {
log.Printf("ERROR (WebSocket Proxy): Failed to dial backend: %v", err)
return status.Errorf(codes.Internal, "failed to connect to backend service")
}
defer backendConn.Close()
// 2. 重建原始的 HTTP 升级请求
req := httptest.NewRequest(header.Method, "http://internal-proxy"+header.Url, nil)
for k, v := range header.Headers {
req.Header.Set(k, v)
}
req.Host = "immich-internal" // 模拟 ReverseProxy 的行为
pathParts := strings.SplitN(strings.TrimPrefix(req.URL.Path, "/"), "/", 3)
if len(pathParts) > 2 {
req.URL.Path = "/" + pathParts[2]
} else {
req.URL.Path = "/"
}
// 3. 将升级请求写入后端连接,发起握手
if err := req.Write(backendConn); err != nil {
log.Printf("ERROR (WebSocket Proxy): Failed to write upgrade request to backend: %v", err)
return status.Errorf(codes.Internal, "failed to send upgrade request to backend")
}
// 4. 读取后端的响应 (握手结果)
backendReader := bufio.NewReader(backendConn)
resp, err := http.ReadResponse(backendReader, req)
if err != nil {
log.Printf("ERROR (WebSocket Proxy): Failed to read handshake response from backend: %v", err)
return status.Errorf(codes.Internal, "failed to read handshake response from backend")
}
// 5. 将后端的握手响应通过 gRPC 发回给代理节点
respHeaderMsg := &relaypb.ProxyResponseMessage{
Payload: &relaypb.ProxyResponseMessage_Header{
Header: &relaypb.ProxyResponseHeader{
StatusCode: int32(resp.StatusCode),
Headers: make(map[string]string),
},
},
}
for k, v := range resp.Header {
respHeaderMsg.GetHeader().Headers[k] = strings.Join(v, ",")
}
if err := stream.Send(respHeaderMsg); err != nil {
log.Printf("ERROR (WebSocket Proxy): Failed to send handshake response via gRPC: %v", err)
return err
}
// 6. 如果握手失败 (不是 101),则流程结束
if resp.StatusCode != http.StatusSwitchingProtocols {
log.Printf("WARN (WebSocket Proxy): Backend returned non-101 status for upgrade: %d", resp.StatusCode)
return nil
}
log.Printf("WebSocket handshake for device %s successful. Starting bi-directional stream copy.", deviceSN)
// 7. 握手成功!现在在 gRPC 流和 yamux 流之间建立双向数据拷贝
errChan := make(chan error, 2)
// Goroutine 1: gRPC 请求流 (来自 App) -> yamux 流 (下行数据)
go func() {
// 这个方向的逻辑没有问题
for {
msg, err := stream.Recv()
if err == io.EOF {
backendConn.Close()
errChan <- nil
return
}
if err != nil {
errChan <- err
return
}
if chunk := msg.GetBodyChunk(); chunk != nil {
if _, err := backendConn.Write(chunk.Data); err != nil {
errChan <- err
return
}
}
}
}()
// Goroutine 2: yamux 流 (来自设备) -> gRPC 响应流 (上行数据)
go func() {
// [核心修正]
// 我们必须从 backendReader (而不是原始的 backendConn) 开始读取,
// 以确保 http.ReadResponse 预读到缓冲区的数据不会丢失。
// io.Copy 会首先清空 backendReader 的内部缓冲区,然后再继续从底层的 backendConn 读取。
if _, err := io.Copy(&grpcResponseWriter{stream: stream}, backendReader); err != nil {
// 过滤掉正常的连接关闭错误
if err != io.EOF && err != io.ErrClosedPipe && !strings.Contains(err.Error(), "use of closed") {
errChan <- err
} else {
errChan <- nil
}
} else {
errChan <- nil
}
}()
// 等待两个 goroutine 都结束
err1 := <-errChan
err2 := <-errChan
if err1 != nil && err1 != io.EOF {
log.Printf("WebSocket stream finished with error: %v", err1)
return err1
}
if err2 != nil && err2 != io.EOF {
log.Printf("WebSocket stream finished with error: %v", err2)
return err2
}
log.Printf("WebSocket stream for device %s finished gracefully.", deviceSN)
return nil
}
// [新增] handleHTTPProxy 包含了原来 ProxyRequest 的所有逻辑
func (s *InternalRelayServer) handleHTTPProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error {
// 这部分代码就是你之前工作正常的、使用 io.Pipe 和 ReverseProxy 的完整流式版本
// 我直接粘贴过来
// --- 3. 创建请求和响应的管道 ---
reqPr, reqPw := io.Pipe()
req := httptest.NewRequest(header.Method, "http://internal-proxy"+header.Url, reqPr)
for k, v := range header.Headers {
req.Header.Set(k, v)
}
req.Header.Set("X-Forwarded-For", header.RemoteAddr)
respPr, respPw := io.Pipe()
customResponseWriter := &streamResponseWriter{
header: make(http.Header),
pipeWriter: respPw,
headerWritten: make(chan struct{}),
}
// --- 4. 启动 Goroutines ---
go func() {
defer reqPw.Close()
for {
bodyMsg, err := stream.Recv()
if err == io.EOF {
return
}
if err != nil {
reqPw.CloseWithError(err)
return
}
if bodyChunk := bodyMsg.GetBodyChunk(); bodyChunk != nil {
if _, err := reqPw.Write(bodyChunk.Data); err != nil {
return
}
}
}
}()
errChan := make(chan error, 1)
go func() {
defer close(errChan)
<-customResponseWriter.headerWritten
// b. [修正] 完整地构造 gRPC 响应头
respHeaderMsg := &relaypb.ProxyResponseMessage{
Payload: &relaypb.ProxyResponseMessage_Header{
Header: &relaypb.ProxyResponseHeader{
StatusCode: int32(customResponseWriter.statusCode),
Headers: make(map[string]string),
},
},
}
for k, v := range customResponseWriter.header {
respHeaderMsg.GetHeader().Headers[k] = strings.Join(v, ",")
}
if err := stream.Send(respHeaderMsg); err != nil {
errChan <- err
return
}
buf := make([]byte, 1024*32)
if _, err := io.CopyBuffer(&grpcResponseWriter{stream: stream}, respPr, buf); err != nil {
if err != io.ErrClosedPipe {
errChan <- err
}
}
}()
// --- 5. 执行代理 ---
proxy := session.CreateReverseProxy(sessionInfo, deviceSN, req.URL.Path, req.URL.RawQuery)
proxy.ServeHTTP(customResponseWriter, req)
// --- 6. 清理 ---
respPw.Close()
return <-errChan
}
// sendForbiddenResponse 是一个辅助函数,用于发送模拟的 403 响应
func sendForbiddenResponse(stream relaypb.InternalRelay_ProxyRequestServer) error {
respHeader := &relaypb.ProxyResponseMessage{
Payload: &relaypb.ProxyResponseMessage_Header{
Header: &relaypb.ProxyResponseHeader{
StatusCode: http.StatusForbidden,
Headers: map[string]string{"Content-Type": "text/plain; charset=utf-8"},
},
},
}
if err := stream.Send(respHeader); err != nil {
return err
}
respBody := &relaypb.ProxyResponseMessage{
Payload: &relaypb.ProxyResponseMessage_BodyChunk{
BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: []byte("Forbidden")},
},
}
stream.Send(respBody)
return nil // 正常关闭流
}
// streamResponseWriter 是一个自定义的 http.ResponseWriter
type streamResponseWriter struct {
header http.Header
pipeWriter *io.PipeWriter
statusCode int
headerWritten chan struct{}
}
func (w *streamResponseWriter) Header() http.Header {
return w.header
}
func (w *streamResponseWriter) Write(b []byte) (int, error) {
w.WriteHeader(http.StatusOK)
return w.pipeWriter.Write(b)
}
func (w *streamResponseWriter) WriteHeader(statusCode int) {
select {
case <-w.headerWritten:
return
default:
w.statusCode = statusCode
close(w.headerWritten)
}
}
// grpcResponseWriter 是一个适配器,实现了 io.Writer 接口
type grpcResponseWriter struct {
stream relaypb.InternalRelay_ProxyRequestServer
}
func (w *grpcResponseWriter) Write(p []byte) (n int, err error) {
err = w.stream.Send(&relaypb.ProxyResponseMessage{
Payload: &relaypb.ProxyResponseMessage_BodyChunk{
BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: p},
},
})
if err != nil {
return 0, err
}
return len(p), nil
}
// writeChunk 辅助函数 - 确保这个函数也存在于你的 grpc/server.go 文件中
func writeChunk(stream relaypb.InternalRelay_ProxyRequestServer, data []byte) error {
return stream.Send(&relaypb.ProxyResponseMessage{
Payload: &relaypb.ProxyResponseMessage_BodyChunk{
BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: data},
},
})
}

510
main.go

@ -1,4 +1,3 @@
// 文件名: main.go (服务端)
package main package main
import ( import (
@ -6,32 +5,38 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"log" "log"
"net" "net"
"net/http" "net/http"
"net/http/httputil"
"os" "os"
"os/signal"
"strings" "strings"
"sync" "syscall"
"time" "time"
// 项目内包
"memobus_relay_server/config"
grpc_server "memobus_relay_server/grpc" // 使用别名以区分标准库
"memobus_relay_server/peer"
relaypb "memobus_relay_server/relay_server/proto"
"memobus_relay_server/session"
"memobus_relay_server/storage"
// 第三方库
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/hashicorp/yamux" "github.com/hashicorp/yamux"
"github.com/redis/go-redis/v9"
"google.golang.org/grpc"
) )
// ============================================================================== // 1. 全局变量
// 1. 密钥配置
// ==============================================================================
var ( var (
// 用于验证 App 请求的密钥,必须和 ibserver 的 AppAccessSecret 一致 // 这两个变量在 main 函数中通过配置进行初始化
appAccessSecret = []byte(os.Getenv("APP_ACCESS_SECRET")) appAccessSecret []byte
// 用于验证设备连接的密钥,必须和旧中继服务的 RelaySecret 一致 deviceRelaySecret []byte
deviceRelaySecret = []byte(os.Getenv("RELAY_SECRET"))
) )
// ==============================================================================
// 2. 结构体定义
// ==============================================================================
type AuthPayload struct { type AuthPayload struct {
DeviceSN string `json:"device_sn"` DeviceSN string `json:"device_sn"`
Token string `json:"token"` Token string `json:"token"`
@ -43,39 +48,160 @@ type DeviceJWTClaims struct {
jwt.RegisteredClaims jwt.RegisteredClaims
} }
type SessionInfo struct { // 2. Main & 服务启动逻辑
Session *yamux.Session func main() {
UserID string // 1. 加载配置
config.LoadConfig()
appAccessSecret = []byte(config.Cfg.Auth.AppAccessSecret)
deviceRelaySecret = []byte(config.Cfg.Auth.DeviceRelaySecret)
// 2. 初始化所有模块/管理器
if err := storage.InitRedis(); err != nil {
log.Fatalf("Failed to initialize storage: %v", err)
}
session.InitManager()
peer.InitManager(storage.GlobalRedis.Client)
// --- [修改] 将 HTTP Server 的创建和启动分开 ---
// 创建一个新的 HTTP server mux (路由器)
mux := http.NewServeMux()
mux.HandleFunc("/tunnel/", handleAppRequest)
// 创建一个 http.Server 对象,这样我们稍后可以调用它的 Shutdown 方法
httpServer := &http.Server{
Addr: config.Cfg.Server.AppListenPort,
Handler: mux,
}
// 3. 将所有服务放入后台 goroutine
// (startGRPCServer 和 listenForDevices 内部已经处理好了 goroutine)
if config.Cfg.Redis.Enabled {
go startGRPCServer()
go startServiceDiscovery()
log.Println("Running in CLUSTER mode.")
} else {
log.Println("Running in SINGLE-NODE mode.")
}
go listenForDevices(config.Cfg.Server.DeviceListenPort)
// [修改] 将 HTTP 服务器也放入一个 goroutine 中启动
go func() {
log.Printf("Starting App HTTP server on %s", config.Cfg.Server.AppListenPort)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Failed to start App server: %v", err)
}
log.Println("App HTTP server has stopped.")
}()
// --- 4. 设置优雅停机逻辑 ---
// 创建一个 channel 来等待操作系统信号
shutdownChan := make(chan os.Signal, 1)
signal.Notify(shutdownChan, syscall.SIGINT, syscall.SIGTERM)
// 阻塞 main goroutine,直到收到信号
sig := <-shutdownChan
log.Printf("Shutdown signal received (%s), starting graceful shutdown...", sig)
// --- 5. 执行清理和关闭操作 ---
// a. 创建一个带超时的上下文,用于关闭服务器
// 给服务器一点时间来处理完当前正在进行的请求
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// b. 优雅地关闭 HTTP 服务器
if err := httpServer.Shutdown(shutdownCtx); err != nil {
log.Printf("HTTP server shutdown error: %v", err)
} else {
log.Println("HTTP server gracefully stopped.")
}
// c. 从服务发现中注销本实例
if config.Cfg.Redis.Enabled && storage.GlobalRedis != nil {
log.Println("Deregistering instance from service discovery...")
key := config.Cfg.Redis.InstanceRegistryKey
instanceID := config.Cfg.Server.InstanceID
// 使用一个独立的上下文,不与 shutdownCtx 关联
storage.GlobalRedis.Client.HDel(context.Background(), key, instanceID)
}
// (未来可以增加关闭 gRPC server 和 TCP listener 的逻辑)
log.Println("Graceful shutdown complete.")
}
// startGRPCServer 启动用于服务器间通信的内部 gRPC 服务
func startGRPCServer() {
lis, err := net.Listen("tcp", config.Cfg.Server.GrpcListenAddr)
if err != nil {
log.Fatalf("Failed to listen for gRPC on %s: %v", config.Cfg.Server.GrpcListenAddr, err)
}
s := grpc.NewServer()
relaypb.RegisterInternalRelayServer(s, grpc_server.NewInternalRelayServer())
log.Printf("Internal gRPC server listening at %s", config.Cfg.Server.GrpcListenAddr)
if err := s.Serve(lis); err != nil {
log.Fatalf("Failed to serve gRPC: %v", err)
}
} }
var ( // startServiceDiscovery 启动一个心跳 goroutine,定期向 Redis 注册本实例
deviceSessions = make(map[string]*SessionInfo) func startServiceDiscovery() {
sessionMutex = &sync.RWMutex{} key := config.Cfg.Redis.InstanceRegistryKey
) instanceID := config.Cfg.Server.InstanceID
addr := config.Cfg.Server.GrpcAdvertiseAddr
ttl := time.Duration(config.Cfg.Redis.InstanceTTLSeconds) * time.Second
// ============================================================================== // 使用 TTL 的一半作为心跳间隔,确保在过期前续期
// 3. Main & 服务器启动逻辑 ticker := time.NewTicker(ttl / 2)
// ============================================================================== defer ticker.Stop()
func main() {
if len(appAccessSecret) == 0 || len(deviceRelaySecret) == 0 { log.Printf("Starting service discovery heartbeat for instance '%s' (%s)", instanceID, addr)
log.Println("WARNING: APP_ACCESS_SECRET or RELAY_SECRET environment variable not set.")
// 立即执行一次,不等第一个 ticker
updateHeartbeat := func() {
// --- [新增] 清理逻辑 ---
// 1. 获取所有已注册的实例
allInstances, err := storage.GlobalRedis.Client.HGetAll(context.Background(), key).Result()
if err != nil {
log.Printf("ERROR: Failed to get all instances for cleanup: %v", err)
// 即使获取失败,我们仍然要继续尝试注册自己
} else {
// 2. 遍历查找与自己地址冲突的旧实例
for oldInstanceID, oldAddr := range allInstances {
// 如果找到一个不同的 instanceID 却使用了相同的地址,
// 并且这个旧 ID 不是我们自己当前的 ID,那么它就是“僵尸”
if oldAddr == addr && oldInstanceID != instanceID {
log.Printf("INFO: Found stale instance '%s' with the same address. Cleaning up...", oldInstanceID)
// 3. 删除僵尸实例
storage.GlobalRedis.Client.HDel(context.Background(), key, oldInstanceID)
}
}
} }
go listenForDevices(":7002") // --- 清理逻辑结束 ---
log.Println("Starting App HTTP server on :8089") err = storage.GlobalRedis.Client.HSet(context.Background(), key, instanceID, addr).Err()
http.HandleFunc("/tunnel/", handleAppRequest) // 统一入口 if err != nil {
if err := http.ListenAndServe(":8089", nil); err != nil { log.Printf("ERROR: failed to heartbeat instance to redis: %v", err)
log.Fatalf("Failed to start App server: %v", err) }
// 为整个 Hash key 设置一个过期时间,以防所有实例都下线后 key 永久存在
storage.GlobalRedis.Client.Expire(context.Background(), key, ttl*2)
}
updateHeartbeat()
for range ticker.C {
updateHeartbeat()
} }
} }
// listenForDevices 监听并接受来自设备的 TCP 连接
func listenForDevices(addr string) { func listenForDevices(addr string) {
log.Printf("Listening for device connections on %s\n", addr)
listener, err := net.Listen("tcp", addr) listener, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
log.Fatalf("Failed to listen for devices: %v", err) log.Fatalf("Failed to listen for devices on %s: %v", addr, err)
} }
defer listener.Close() defer listener.Close()
log.Printf("Listening for device connections on %s", addr)
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
@ -88,20 +214,16 @@ func listenForDevices(addr string) {
} }
// ============================================================================== // ==============================================================================
// 4. 设备端认证与会话管 // 3. 设备端会话处
// ============================================================================== // ==============================================================================
func handleDeviceSession(conn net.Conn) { func handleDeviceSession(conn net.Conn) {
defer conn.Close() defer conn.Close()
log.Printf("New device connected from %s, awaiting authentication...\n", conn.RemoteAddr()) log.Printf("New device connected from %s, awaiting authentication...\n", conn.RemoteAddr())
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
var auth AuthPayload var auth AuthPayload
if err := json.NewDecoder(conn).Decode(&auth); err != nil { if err := json.NewDecoder(conn).Decode(&auth); err != nil { /* ... */
log.Printf("Authentication failed (reading payload): %v", err)
return return
} }
conn.SetReadDeadline(time.Time{})
claims, err := validateDeviceToken(auth.Token) claims, err := validateDeviceToken(auth.Token)
if err != nil { if err != nil {
log.Printf("Authentication failed for SN %s (token validation): %v", auth.DeviceSN, err) log.Printf("Authentication failed for SN %s (token validation): %v", auth.DeviceSN, err)
@ -113,67 +235,51 @@ func handleDeviceSession(conn net.Conn) {
return return
} }
deviceSN := claims.DeviceSN deviceSN, userID := claims.DeviceSN, claims.UserID
userID := claims.UserID log.Printf("Device '%s' (user: %s) authenticated.", deviceSN, userID)
log.Printf("Device '%s' (user: %s) authenticated successfully.\n", deviceSN, userID) // ... [认证逻辑结束] ...
config := yamux.DefaultConfig()
config.EnableKeepAlive = true
config.KeepAliveInterval = 30 * time.Second
session, err := yamux.Server(conn, config) // 启动 yamux 会话
yamuxConfig := yamux.DefaultConfig()
yamuxConfig.EnableKeepAlive = true
yamuxConfig.KeepAliveInterval = 30 * time.Second
s, err := yamux.Server(conn, yamuxConfig)
if err != nil { if err != nil {
log.Printf("Failed to start yamux session for device '%s': %v", deviceSN, err) log.Printf("Failed to start yamux session for device '%s': %v", deviceSN, err)
return return
} }
defer session.Close() defer s.Close()
sessionInfo := &SessionInfo{Session: session, UserID: userID} // 1. 添加到本地会话管理器
sessionMutex.Lock() sessionInfo := &session.SessionInfo{Session: s, UserID: userID}
if oldInfo, exists := deviceSessions[deviceSN]; exists { session.GlobalManager.AddSession(deviceSN, sessionInfo)
log.Printf("Device '%s' already connected, closing old session.", deviceSN)
oldInfo.Session.Close()
}
deviceSessions[deviceSN] = sessionInfo
sessionMutex.Unlock()
log.Printf("Yamux session started for device '%s'\n", deviceSN)
defer func() { // 2. 如果启用集群模式,注册到 Redis
sessionMutex.Lock() if storage.GlobalRedis != nil {
if currentInfo, exists := deviceSessions[deviceSN]; exists && currentInfo.Session == session { // 注册的值是本机的实例 ID
delete(deviceSessions, deviceSN) err := storage.GlobalRedis.RegisterDeviceSession(deviceSN, config.Cfg.Server.InstanceID)
if err != nil {
log.Printf("ERROR: %v", err)
} }
sessionMutex.Unlock() // 启动 Redis KeepAlive
log.Printf("Device '%s' session closed\n", deviceSN) go storage.GlobalRedis.KeepAliveSession(s.CloseChan(), deviceSN)
}()
<-session.CloseChan()
}
func validateDeviceToken(tokenString string) (*DeviceJWTClaims, error) {
if len(deviceRelaySecret) == 0 {
return nil, errors.New("RELAY_SECRET is not configured on the server")
} }
claims := &DeviceJWTClaims{} // 注册 defer 函数,在会话关闭时清理资源
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) { defer func() {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { session.GlobalManager.RemoveSession(deviceSN, s)
return nil, fmt.Errorf("unexpected signing method for device token") if storage.GlobalRedis != nil {
storage.GlobalRedis.DeregisterDeviceSession(deviceSN)
} }
return deviceRelaySecret, nil log.Printf("Cleaned up resources for device '%s' session.", deviceSN)
}) }()
if err != nil { // 阻塞直到会话关闭
return nil, err <-s.CloseChan()
}
if !token.Valid {
return nil, errors.New("device token is invalid")
}
return claims, nil
} }
// ============================================================================== // ==============================================================================
// 5. App 端认证与请求处理 // 4. App 端请求智能路由
// ============================================================================== // ==============================================================================
func handleAppRequest(w http.ResponseWriter, r *http.Request) { func handleAppRequest(w http.ResponseWriter, r *http.Request) {
pathParts := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 3) pathParts := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 3)
@ -183,88 +289,218 @@ func handleAppRequest(w http.ResponseWriter, r *http.Request) {
} }
deviceSN := pathParts[1] deviceSN := pathParts[1]
// --- [App 认证逻辑 - 暂时注释,需要时取消注释即可] --- // --- [App 认证逻辑] ---
/*
appUserID, err := authenticateAppRequest(r) appUserID, err := authenticateAppRequest(r)
if err != nil { if err != nil {
log.Printf("App authentication failed for device %s: %v", deviceSN, err) log.Printf("App authentication failed for device %s: %v", deviceSN, err)
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
return return
} }
*/
sessionMutex.RLock() // 如果未启用集群模式,直接走本地处理逻辑
sessionInfo, ok := deviceSessions[deviceSN] if !config.Cfg.Redis.Enabled {
sessionMutex.RUnlock() handleLocalRequest(w, r, deviceSN, appUserID)
return
}
if !ok || sessionInfo.Session.IsClosed() { // 集群模式下的路由决策
ownerInstanceID, err := storage.GlobalRedis.GetDeviceOwner(deviceSN)
if err != nil {
if err == redis.Nil {
http.Error(w, fmt.Sprintf("Device '%s' is not connected", deviceSN), http.StatusBadGateway) http.Error(w, fmt.Sprintf("Device '%s' is not connected", deviceSN), http.StatusBadGateway)
} else {
log.Printf("ERROR: Redis lookup failed for %s: %v", deviceSN, err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
return
}
// 判断设备连接是否在本实例上
if ownerInstanceID == config.Cfg.Server.InstanceID {
handleLocalRequest(w, r, deviceSN, appUserID)
} else {
handleRemoteRequest(w, r, ownerInstanceID, appUserID)
}
}
// handleLocalRequest 处理连接在本实例上的设备的请求
func handleLocalRequest(w http.ResponseWriter, r *http.Request, deviceSN string, appUserID string) {
sessionInfo, ok := session.GlobalManager.GetLocalSession(deviceSN)
if !ok {
log.Printf("WARN: Consistency issue. Device '%s' is registered to this instance but not found in local memory.", deviceSN)
http.Error(w, "Device session not found on this server", http.StatusBadGateway)
if storage.GlobalRedis != nil {
storage.GlobalRedis.DeregisterDeviceSession(deviceSN)
}
return return
} }
/* --- [所有权检查 - 暂时注释] ---
if sessionInfo.UserID != appUserID { if sessionInfo.UserID != appUserID {
log.Printf("Forbidden: App user '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID) log.Printf("Forbidden: App user '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID)
http.Error(w, "Forbidden: you do not own this device", http.StatusForbidden) http.Error(w, "Forbidden: you do not own this device", http.StatusForbidden)
return return
} }
*/
proxy := &httputil.ReverseProxy{ proxy := session.CreateReverseProxy(sessionInfo, deviceSN, r.URL.Path, r.URL.RawQuery)
Director: func(req *http.Request) { proxy.ServeHTTP(w, r)
// Director 负责重写请求 }
if len(pathParts) > 2 {
req.URL.Path = "/" + pathParts[2] // handleRemoteRequest 将请求通过 gRPC 转发到持有连接的另一个实例
req.URL.RawQuery = r.URL.RawQuery // 确保查询参数也被传递 func handleRemoteRequest(w http.ResponseWriter, r *http.Request, targetInstanceID string, appUserID string) {
} else { // [这部分代码已在之前的回答中提供并解释,这里直接粘贴]
req.URL.Path = "/" deviceSN := strings.SplitN(strings.TrimPrefix(r.URL.Path, "/"), "/", 3)[1]
log.Printf("Forwarding request for device %s to remote instance %s", deviceSN, targetInstanceID)
conn, err := peer.GlobalManager.GetClient(targetInstanceID)
if err != nil {
log.Printf("ERROR: failed to get gRPC client for peer %s: %v", targetInstanceID, err)
http.Error(w, "Service internal error (peer unreachable)", http.StatusInternalServerError)
return
} }
req.URL.Scheme = "http" client := relaypb.NewInternalRelayClient(conn)
req.URL.Host = r.Host // 使用原始请求的 Host
req.Header.Set("X-Real-IP", r.RemoteAddr) ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
}, defer cancel()
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { stream, err := client.ProxyRequest(ctx)
// 劫持连接创建,改为打开一个 yamux 流 if err != nil {
return sessionInfo.Session.Open() log.Printf("ERROR: failed to start gRPC proxy stream to %s: %v", targetInstanceID, err)
}, http.Error(w, "Service internal error (stream failed)", http.StatusInternalServerError)
// 禁用 HTTP/2,因为它与我们的隧道不兼容 return
ForceAttemptHTTP2: false, }
},
FlushInterval: -1, // 支持流式响应 // [新增调试日志]
ModifyResponse: func(resp *http.Response) error { if r.Header.Get("Upgrade") == "websocket" {
// 告知下游代理不要缓冲 log.Printf("DEBUG (WebSocket): handleRemoteRequest received WebSocket upgrade request. Headers: Connection='%s', Upgrade='%s'", r.Header.Get("Connection"), r.Header.Get("Upgrade"))
resp.Header.Set("X-Accel-Buffering", "no") }
return nil
headers := make(map[string]string)
for k, v := range r.Header {
headers[k] = strings.Join(v, ",")
}
// [新增调试日志]
if headers["Upgrade"] == "websocket" {
log.Printf("DEBUG (WebSocket): Packing headers into gRPC message. Headers: Connection='%s', Upgrade='%s'", headers["Connection"], headers["Upgrade"])
}
headerMsg := &relaypb.ProxyRequestMessage{
Payload: &relaypb.ProxyRequestMessage_Header{
Header: &relaypb.ProxyRequestHeader{
Method: r.Method, Url: r.URL.String(), Headers: headers, RemoteAddr: r.RemoteAddr, AppUserId: appUserID,
}, },
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
log.Printf("Reverse proxy error for device %s: %v", deviceSN, err)
http.Error(w, "Error forwarding request", http.StatusBadGateway)
}, },
} }
if err := stream.Send(headerMsg); err != nil {
log.Printf("ERROR: failed to send gRPC request header to %s: %v", targetInstanceID, err)
http.Error(w, "Service internal error (header send failed)", http.StatusInternalServerError)
return
}
log.Printf("Forwarding request for device '%s' to path '%s'", deviceSN, r.URL.Path) go func() {
proxy.ServeHTTP(w, r) defer stream.CloseSend()
if _, err := io.Copy(&grpcStreamWriter{stream: stream}, r.Body); err != nil {
log.Printf("ERROR: failed copying request body to gRPC stream for %s: %v", deviceSN, err)
}
}()
respHeaderMsg, err := stream.Recv()
if err != nil {
log.Printf("ERROR: failed to receive gRPC response header from %s: %v", targetInstanceID, err)
http.Error(w, "Gateway timeout or peer unavailable", http.StatusGatewayTimeout)
return
}
respHeader := respHeaderMsg.GetHeader()
if respHeader == nil {
log.Printf("ERROR: received invalid first message (not a header) from peer %s", targetInstanceID)
http.Error(w, "Internal gateway error (invalid peer response)", http.StatusBadGateway)
return
}
for k, v := range respHeader.Headers {
w.Header().Set(k, v)
}
w.WriteHeader(int(respHeader.StatusCode))
for {
respBodyMsg, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
log.Printf("ERROR: gRPC response stream broke for device %s: %v", deviceSN, err)
break
}
if _, writeErr := w.Write(respBodyMsg.GetBodyChunk().Data); writeErr != nil {
log.Printf("WARN: could not write to client for device %s, client likely disconnected: %v", deviceSN, writeErr)
break
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
}
} }
// authenticateAppRequest 和 verifyAppToken 保持不变,备用 // grpcStreamWriter 是一个辅助类型,实现了 io.Writer 接口
func authenticateAppRequest(r *http.Request) (string, error) { type grpcStreamWriter struct {
authHeader := r.Header.Get("Authorization") stream relaypb.InternalRelay_ProxyRequestClient
if authHeader == "" { }
return "", errors.New("missing Authorization header")
func (w *grpcStreamWriter) Write(p []byte) (n int, err error) {
err = w.stream.Send(&relaypb.ProxyRequestMessage{
Payload: &relaypb.ProxyRequestMessage_BodyChunk{
// [修正] 将 ResponseBodyChunk 改为 RequestBodyChunk
BodyChunk: &relaypb.ProxyRequestBodyChunk{Data: p},
},
})
if err != nil {
return 0, err
}
return len(p), nil
}
func validateDeviceToken(tokenString string) (*DeviceJWTClaims, error) {
if len(deviceRelaySecret) == 0 {
return nil, errors.New("RELAY_SECRET is not configured on the server")
} }
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == authHeader { claims := &DeviceJWTClaims{}
return "", errors.New("authorization header format must be Bearer {token}") token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method for device token")
} }
claims, err := verifyAppToken(tokenString) return deviceRelaySecret, nil
})
if err != nil { if err != nil {
return "", fmt.Errorf("app token verification failed: %w", err) return nil, err
} }
if userID, ok := claims["user_id"].(string); ok { if !token.Valid {
return userID, nil return nil, errors.New("device token is invalid")
} }
return "", errors.New("user_id not found in app token claims") return claims, nil
}
// authenticateAppRequest 和 verifyAppToken 保持不变,备用
func authenticateAppRequest(r *http.Request) (string, error) {
//authHeader := r.Header.Get("Authorization")
//if authHeader == "" {
// return "", errors.New("missing Authorization header")
//}
//tokenString := strings.TrimPrefix(authHeader, "Bearer ")
//if tokenString == authHeader {
// return "", errors.New("authorization header format must be Bearer {token}")
//}
//claims, err := verifyAppToken(tokenString)
//if err != nil {
// return "", fmt.Errorf("app token verification failed: %w", err)
//}
//if userID, ok := claims["user_id"].(string); ok {
// return userID, nil
//}
//return "", errors.New("user_id not found in app token claims")
return "af672ce1-b528-4c18-af7e-e47b09619520", nil
} }
func verifyAppToken(tokenString string) (jwt.MapClaims, error) { func verifyAppToken(tokenString string) (jwt.MapClaims, error) {

68
peer/manager.go

@ -0,0 +1,68 @@
package peer
import (
"context"
"log"
"memobus_relay_server/config"
"sync"
"github.com/redis/go-redis/v9"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// Manager 负责管理到其他对等服务器的 gRPC 客户端连接
type Manager struct {
redisClient *redis.Client
clients map[string]*grpc.ClientConn
mu sync.RWMutex
}
var GlobalManager *Manager
func InitManager(redisCli *redis.Client) {
if !config.Cfg.Redis.Enabled {
return // 单机模式下不需要 Peer 管理器
}
GlobalManager = &Manager{
redisClient: redisCli,
clients: make(map[string]*grpc.ClientConn),
}
log.Println("Peer manager initialized for cluster communication.")
}
// GetClient 查找或创建一个到目标实例的 gRPC 客户端连接
func (m *Manager) GetClient(targetInstanceID string) (*grpc.ClientConn, error) {
m.mu.RLock()
client, ok := m.clients[targetInstanceID]
m.mu.RUnlock()
if ok {
return client, nil
}
// 连接未找到, 使用写锁创建一个新的
m.mu.Lock()
defer m.mu.Unlock()
// 双重检查, 以防在我们等待锁的时候, 其他 goroutine 已经创建了它
if client, ok = m.clients[targetInstanceID]; ok {
return client, nil
}
// 从 Redis 发现目标实例的地址
addr, err := m.redisClient.HGet(context.Background(), config.Cfg.Redis.InstanceRegistryKey, targetInstanceID).Result()
if err != nil {
return nil, err
}
log.Printf("Creating new gRPC client connection to peer %s at %s", targetInstanceID, addr)
// 生产环境应使用 TLS 凭证替换 insecure
conn, err := grpc.Dial(addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
m.clients[targetInstanceID] = conn
return conn, nil
}

96
proto/relay.proto

@ -0,0 +1,96 @@
// 使 proto3
syntax = "proto3";
// Go
package relay;
// Go
option go_package = "relay_server/proto";
// -----------------------------------------------------------------------------
// (Service Definition)
// -----------------------------------------------------------------------------
// InternalRelay RPC
service InternalRelay {
// ProxyRequest RPC
// "stream"
//
rpc ProxyRequest(stream ProxyRequestMessage) returns (stream ProxyResponseMessage);
}
// -----------------------------------------------------------------------------
// (Request Messages)
// -----------------------------------------------------------------------------
// ProxyRequestMessage App请求的实例
//
//
// 使 `oneof`
// 使
message ProxyRequestMessage {
oneof payload {
ProxyRequestHeader header = 1;
ProxyRequestBodyChunk body_chunk = 2;
}
}
// ProxyRequestHeader HTTP
//
message ProxyRequestHeader {
// HTTP , "GET", "POST", "PUT"
string method = 1;
// URL
// "/tunnel/DEVICE_SN_123/api/album?page=1&size=10"
string url = 2;
// HTTP
// `map`
map<string, string> headers = 3;
// App IP X-Forwarded-For
string remote_addr = 4;
// App ID
string app_user_id = 5;
}
// ProxyRequestBodyChunk HTTP
// chunk
//
message ProxyRequestBodyChunk {
bytes data = 1;
}
// -----------------------------------------------------------------------------
// (Response Messages)
// -----------------------------------------------------------------------------
// ProxyResponseMessage
// 使 `oneof`
message ProxyResponseMessage {
oneof payload {
ProxyResponseHeader header = 1;
ProxyResponseBodyChunk body_chunk = 2;
}
}
// ProxyResponseHeader HTTP
//
message ProxyResponseHeader {
// HTTP , 200, 404, 500
int32 status_code = 1;
// HTTP
map<string, string> headers = 2;
}
// ProxyResponseBodyChunk HTTP
// 使
// App
message ProxyResponseBodyChunk {
bytes data = 1;
}

167
session/manager.go

@ -0,0 +1,167 @@
package session
import (
"context"
"log"
"net"
"net/http"
"net/http/httputil"
"strings"
"sync"
"github.com/hashicorp/yamux"
)
// SessionInfo 存储了一个活跃的设备连接所需的所有信息。
// 我们将 yamux.Session 和 UserID 绑定在一起。
type SessionInfo struct {
Session *yamux.Session
UserID string
}
// Manager 是会话管理的核心结构体。
// 它只负责管理本实例内存中的会话,不关心 Redis 或其他存储。
type Manager struct {
// localSessions 使用设备 SN 作为 key,存储会话信息。
localSessions map[string]*SessionInfo
// sessionMutex 用于保护对 localSessions 的并发访问。
sessionMutex sync.RWMutex
}
// GlobalManager 是一个全局单例,方便在项目各处调用。
var GlobalManager *Manager
// InitManager 初始化全局的会P话管理器。
func InitManager() {
GlobalManager = &Manager{
localSessions: make(map[string]*SessionInfo),
}
log.Println("Local session manager initialized.")
}
// AddSession 向管理器中添加一个新的设备会话。
// 如果已存在同名会话,它会先关闭旧的,再添加新的。
func (m *Manager) AddSession(deviceSN string, info *SessionInfo) {
m.sessionMutex.Lock()
defer m.sessionMutex.Unlock()
// 如果设备重连,旧的会话可能还存在,需要先关闭它
if oldInfo, exists := m.localSessions[deviceSN]; exists {
log.Printf("Device '%s' already has a local session, closing the old one.", deviceSN)
oldInfo.Session.Close()
}
m.localSessions[deviceSN] = info
log.Printf("Local session for device '%s' has been added.", deviceSN)
}
// RemoveSession 从管理器中移除一个设备会话。
// 它会检查传入的 session 对象是否与当前存储的一致,防止误删新会话。
func (m *Manager) RemoveSession(deviceSN string, session *yamux.Session) {
m.sessionMutex.Lock()
defer m.sessionMutex.Unlock()
// 这是一个重要的检查:确保我们删除的是正确的、已经过期的会话,
// 而不是一个刚刚建立的新会话(万一发生竞争)。
if currentInfo, exists := m.localSessions[deviceSN]; exists && currentInfo.Session == session {
delete(m.localSessions, deviceSN)
log.Printf("Local session for device '%s' has been removed.", deviceSN)
}
}
// GetLocalSession 根据设备 SN 查找一个活跃的本地会话。
// 这是最常用的查询方法。
func (m *Manager) GetLocalSession(deviceSN string) (*SessionInfo, bool) {
m.sessionMutex.RLock()
defer m.sessionMutex.RUnlock()
info, ok := m.localSessions[deviceSN]
if ok && !info.Session.IsClosed() {
// 确保会话不仅存在,而且是活跃的
return info, true
}
return nil, false
}
// CreateReverseProxy 是一个辅助函数,用于创建一个配置好的 httputil.ReverseProxy。
// 将这个逻辑放在这里,是因为它与 SessionInfo 强相关,可以被 main.go 和 grpc/server.go 复用。
func CreateReverseProxy(sessionInfo *SessionInfo, deviceSN string, originalPath string, originalQuery string) *httputil.ReverseProxy {
return &httputil.ReverseProxy{
// Director 负责在请求被转发前,修改请求的 URL、Header 等。
Director: func(req *http.Request) {
// [新增日志] 如果是 WebSocket 请求,就打印它
if isWebSocketRequest(req) { // isWebSocketRequest 是我们之前写的辅助函数
// true 表示连同 body 一起打印,对于握手请求 body 为空
reqDump, _ := httputil.DumpRequestOut(req, true)
log.Printf("--- [SUCCESS CASE] ReverseProxy is about to send this WebSocket request:\n%s\n-------------------------------------------------", string(reqDump))
}
// 从原始请求路径中解析出要转发到 immich 的真正路径
// 例如,从 "/tunnel/SN123/api/album" -> "/api/album"
pathParts := strings.SplitN(strings.TrimPrefix(originalPath, "/"), "/", 3)
if len(pathParts) > 2 {
req.URL.Path = "/" + pathParts[2]
} else {
req.URL.Path = "/"
}
req.URL.RawQuery = originalQuery // 传递原始的查询参数
req.URL.Scheme = "http"
// Host 不重要,因为我们下面会劫持网络连接 (DialContext)
req.URL.Host = "immich-internal"
// 设置 X-Real-IP 头,让 immich 知道原始客户端的 IP
req.Header.Set("X-Real-IP", req.RemoteAddr)
},
// Transport 负责实际的请求发送。我们通过重写 DialContext 来劫持它。
Transport: &http.Transport{
// 这是整个隧道转发的核心:
// 当 ReverseProxy 尝试建立一个 TCP 连接到 "immich-internal" 时,
// 我们不进行真正的网络拨号,而是直接在 yamux 会话上打开一个新的流 (stream)。
// 这个流就等同于一个虚拟的 TCP 连接,直接通往设备端的 immich 容器。
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return sessionInfo.Session.Open()
},
// 必须禁用 HTTP/2,因为它与我们的简单流转发不兼容。
ForceAttemptHTTP2: false,
},
// FlushInterval 设置为 -1 会禁用缓冲,立即将数据块发送出去。
// 这对于视频流和 WebSocket 至关重要。
FlushInterval: -1,
// ModifyResponse 允许我们在响应返回给客户端之前修改它。
ModifyResponse: func(resp *http.Response) error {
// [新增调试日志]
// 这是一个关键探针!
if resp.StatusCode == http.StatusSwitchingProtocols { // 101
log.Printf("DEBUG (WebSocket): ModifyResponse received '101 Switching Protocols'. This means backend handshake was successful!")
}
// 这个 Header 告诉上游的代理(如 Nginx)不要缓冲这个响应。
resp.Header.Set("X-Accel-Buffering", "no")
return nil
},
// ErrorHandler 定义了当转发过程中发生错误(如设备端断开连接)时的处理逻辑。
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
// [新增调试日志]
// 这是另一个关键探针!
if r.Header.Get("Upgrade") == "websocket" {
log.Printf("DEBUG (WebSocket): ErrorHandler was triggered for a WebSocket request. Error: %v", err)
}
log.Printf("ERROR: Reverse proxy error for device %s: %v", deviceSN, err)
http.Error(w, "Error forwarding request to device", http.StatusBadGateway)
},
}
}
// [新增] 确保 isWebSocketRequest 辅助函数存在于 session/manager.go
func isWebSocketRequest(r *http.Request) bool {
upgradeHeader := strings.ToLower(r.Header.Get("Upgrade"))
if upgradeHeader != "websocket" {
return false
}
connectionHeader := strings.ToLower(r.Header.Get("Connection"))
return strings.Contains(connectionHeader, "upgrade")
}

126
storage/redis.go

@ -0,0 +1,126 @@
// 文件: storage/redis.go
package storage
import (
"context"
"fmt"
"github.com/redis/go-redis/v9"
"log"
"memobus_relay_server/config" // 替换为你的模块名
"time"
)
// RedisManager 结构体封装了所有与 Redis 相关的操作
type RedisManager struct {
Client *redis.Client
sessionTTL time.Duration
}
// GlobalRedis 是一个全局可访问的 RedisManager 实例
var GlobalRedis *RedisManager
// InitRedis 初始化 Redis 连接并创建全局的 RedisManager 实例
// 如果配置中 Redis 未启用,则返回 nil
func InitRedis() error {
if !config.Cfg.Redis.Enabled {
log.Println("Redis is disabled in config. Skipping initialization.")
return nil
}
client := redis.NewClient(&redis.Options{
Addr: config.Cfg.Redis.Addr,
Password: config.Cfg.Redis.Password,
DB: config.Cfg.Redis.DB,
})
if err := client.Ping(context.Background()).Err(); err != nil {
return fmt.Errorf("failed to connect to Redis: %w", err)
}
GlobalRedis = &RedisManager{
Client: client,
sessionTTL: time.Duration(config.Cfg.Redis.SessionTTLSeconds) * time.Second,
}
log.Println("Successfully connected to Redis.")
return nil
}
// getRedisKey 生成设备会话在 Redis 中的 key
func getRedisKey(deviceSN string) string {
return fmt.Sprintf("device_session:%s", deviceSN)
}
// RegisterDeviceSession 将设备标记为在线
// 在单机模式下,value 可以是一个简单的占位符,如 "online"
func (m *RedisManager) RegisterDeviceSession(deviceSN string, value string) error {
key := getRedisKey(deviceSN)
err := m.Client.Set(context.Background(), key, value, m.sessionTTL).Err()
if err != nil {
return fmt.Errorf("failed to register device '%s' to Redis: %w", deviceSN, err)
}
log.Printf("Device '%s' registered in Redis.", deviceSN)
return nil
}
// DeregisterDeviceSession 从 Redis 中移除设备会话
func (m *RedisManager) DeregisterDeviceSession(deviceSN string) {
key := getRedisKey(deviceSN)
m.Client.Del(context.Background(), key)
log.Printf("Device '%s' deregistered from Redis.", deviceSN)
}
// IsDeviceOnline 检查设备是否在 Redis 中被标记为在线
func (m *RedisManager) IsDeviceOnline(deviceSN string) (bool, error) {
key := getRedisKey(deviceSN)
val, err := m.Client.Get(context.Background(), key).Result()
if err == redis.Nil {
return false, nil // Key 不存在,明确表示不在线
}
if err != nil {
return false, fmt.Errorf("redis error looking up device '%s': %w", deviceSN, err)
}
return val != "", nil // 只要 key 存在且值不为空,就认为在线
}
// [新增] GetDeviceOwner 函数,用来获取持有连接的实例 ID
func (m *RedisManager) GetDeviceOwner(deviceSN string) (string, error) {
key := getRedisKey(deviceSN)
instanceID, err := m.Client.Get(context.Background(), key).Result()
if err != nil {
// 让调用者处理 redis.Nil 错误,这表示设备未找到
return "", err
}
return instanceID, nil
}
// KeepAliveSession 启动一个 goroutine,为给定的设备会话在 Redis 中定期续期
func (m *RedisManager) KeepAliveSession(closeChan <-chan struct{}, deviceSN string) {
// 以 TTL 的一半作为续期间隔
ticker := time.NewTicker(m.sessionTTL / 2)
defer ticker.Stop()
key := getRedisKey(deviceSN)
log.Printf("Starting Redis keep-alive for device '%s'.", deviceSN)
for {
select {
case <-ticker.C:
// 为 key 续期
err := m.Client.Expire(context.Background(), key, m.sessionTTL).Err()
if err != nil {
// 如果 key 不存在了(可能被手动删除或过期),就没必要再续了
if err == redis.Nil {
log.Printf("Redis key for %s no longer exists, stopping keep-alive.", deviceSN)
return
}
log.Printf("ERROR: Failed to refresh session TTL for %s in Redis: %v", deviceSN, err)
}
case <-closeChan:
// session 关闭了,退出 goroutine
log.Printf("Stopping Redis keep-alive for device '%s' due to session close.", deviceSN)
return
}
}
}
Loading…
Cancel
Save