You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
360 lines
11 KiB
360 lines
11 KiB
// 文件: grpc/server.go
|
|
package grpc
|
|
|
|
import (
|
|
"bufio"
|
|
"io"
|
|
"log"
|
|
relaypb "memobus_relay_server/relay_server/proto"
|
|
"memobus_relay_server/session"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
// InternalRelayServer 实现了用于代理请求的 gRPC 服务
|
|
type InternalRelayServer struct {
|
|
relaypb.UnimplementedInternalRelayServer
|
|
}
|
|
|
|
// NewInternalRelayServer 创建一个新的 gRPC 服务实例
|
|
func NewInternalRelayServer() *InternalRelayServer {
|
|
return &InternalRelayServer{}
|
|
}
|
|
|
|
// ProxyRequest is the core gRPC stream handler: it proxies one request to a
// device attached to this instance and streams the response back.
//
// The first message on the stream must carry a ProxyRequestHeader. Its URL is
// split on "/" after the leading slash, and the second segment is taken as the
// device serial number (e.g. "/<prefix>/<deviceSN>/<rest>"). The header must
// also carry a non-empty app user ID, which has to match the owner of the
// device's local session; otherwise a synthetic 403 is sent.
//
// WebSocket upgrade requests are handed to handleWebSocketProxy; everything
// else goes to handleHTTPProxy.
func (s *InternalRelayServer) ProxyRequest(stream relaypb.InternalRelay_ProxyRequestServer) error {
	// --- 1. Receive and validate the request header ---
	headerMsg, err := stream.Recv()
	if err != nil {
		log.Printf("ERROR (gRPC): Failed to receive initial header: %v", err)
		return status.Errorf(codes.InvalidArgument, "failed to receive header: %v", err)
	}
	header := headerMsg.GetHeader()
	if header == nil {
		return status.Errorf(codes.InvalidArgument, "first message must be a header")
	}

	// Header names (RFC 7230) and the "websocket" upgrade token (RFC 6455)
	// are case-insensitive, so neither is matched by exact case.
	upgrade := lookupProxyHeader(header.Headers, "Upgrade")
	connection := lookupProxyHeader(header.Headers, "Connection")
	if strings.EqualFold(upgrade, "websocket") {
		log.Printf("DEBUG (WebSocket): gRPC server received WebSocket upgrade request. Headers: Connection='%s', Upgrade='%s'", connection, upgrade)
	}
	isWebSocket := strings.EqualFold(upgrade, "websocket") &&
		strings.Contains(strings.ToLower(connection), "upgrade")

	pathParts := strings.SplitN(strings.TrimPrefix(header.Url, "/"), "/", 3)
	if len(pathParts) < 2 || pathParts[1] == "" {
		return status.Errorf(codes.InvalidArgument, "invalid URL format in gRPC header")
	}
	deviceSN := pathParts[1]

	appUserID := header.GetAppUserId()
	if appUserID == "" {
		return status.Errorf(codes.InvalidArgument, "app_user_id is missing in gRPC header")
	}

	log.Printf("gRPC Proxy: Handling request for device '%s' from user '%s'", deviceSN, appUserID)

	// --- 2. Look up the local session and check authorization ---
	sessionInfo, ok := session.GlobalManager.GetLocalSession(deviceSN)
	if !ok {
		return status.Errorf(codes.NotFound, "device '%s' not connected to this instance", deviceSN)
	}
	if sessionInfo.UserID != appUserID {
		log.Printf("Forbidden (gRPC): User '%s' attempted to access device '%s' owned by '%s'", appUserID, deviceSN, sessionInfo.UserID)
		return sendForbiddenResponse(stream)
	}

	// --- 3. Dispatch by request type ---
	if isWebSocket {
		log.Println("gRPC Proxy: Detected WebSocket request, diverting to transparent proxy handler.")
		return s.handleWebSocketProxy(stream, sessionInfo, deviceSN, header)
	}
	log.Println("gRPC Proxy: Detected HTTP request, using ReverseProxy handler.")
	return s.handleHTTPProxy(stream, sessionInfo, deviceSN, header)
}

// lookupProxyHeader returns the value stored under name in headers. An exact
// key match wins; otherwise the first key equal to name under ASCII case
// folding is used. It returns "" when no such header is present.
func lookupProxyHeader(headers map[string]string, name string) string {
	if v, ok := headers[name]; ok {
		return v
	}
	for k, v := range headers {
		if strings.EqualFold(k, name) {
			return v
		}
	}
	return ""
}
|
|
|
|
func (s *InternalRelayServer) handleWebSocketProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error {
|
|
// 1. 打开到后端 (yamux) 的连接
|
|
backendConn, err := sessionInfo.Session.Open()
|
|
if err != nil {
|
|
log.Printf("ERROR (WebSocket Proxy): Failed to dial backend: %v", err)
|
|
return status.Errorf(codes.Internal, "failed to connect to backend service")
|
|
}
|
|
defer backendConn.Close()
|
|
|
|
// 2. 重建原始的 HTTP 升级请求
|
|
req := httptest.NewRequest(header.Method, "http://internal-proxy"+header.Url, nil)
|
|
for k, v := range header.Headers {
|
|
req.Header.Set(k, v)
|
|
}
|
|
req.Host = "immich-internal" // 模拟 ReverseProxy 的行为
|
|
pathParts := strings.SplitN(strings.TrimPrefix(req.URL.Path, "/"), "/", 3)
|
|
if len(pathParts) > 2 {
|
|
req.URL.Path = "/" + pathParts[2]
|
|
} else {
|
|
req.URL.Path = "/"
|
|
}
|
|
|
|
// 3. 将升级请求写入后端连接,发起握手
|
|
if err := req.Write(backendConn); err != nil {
|
|
log.Printf("ERROR (WebSocket Proxy): Failed to write upgrade request to backend: %v", err)
|
|
return status.Errorf(codes.Internal, "failed to send upgrade request to backend")
|
|
}
|
|
|
|
// 4. 读取后端的响应 (握手结果)
|
|
backendReader := bufio.NewReader(backendConn)
|
|
resp, err := http.ReadResponse(backendReader, req)
|
|
if err != nil {
|
|
log.Printf("ERROR (WebSocket Proxy): Failed to read handshake response from backend: %v", err)
|
|
return status.Errorf(codes.Internal, "failed to read handshake response from backend")
|
|
}
|
|
|
|
// 5. 将后端的握手响应通过 gRPC 发回给代理节点
|
|
respHeaderMsg := &relaypb.ProxyResponseMessage{
|
|
Payload: &relaypb.ProxyResponseMessage_Header{
|
|
Header: &relaypb.ProxyResponseHeader{
|
|
StatusCode: int32(resp.StatusCode),
|
|
Headers: make(map[string]string),
|
|
},
|
|
},
|
|
}
|
|
for k, v := range resp.Header {
|
|
respHeaderMsg.GetHeader().Headers[k] = strings.Join(v, ",")
|
|
}
|
|
if err := stream.Send(respHeaderMsg); err != nil {
|
|
log.Printf("ERROR (WebSocket Proxy): Failed to send handshake response via gRPC: %v", err)
|
|
return err
|
|
}
|
|
|
|
// 6. 如果握手失败 (不是 101),则流程结束
|
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
|
log.Printf("WARN (WebSocket Proxy): Backend returned non-101 status for upgrade: %d", resp.StatusCode)
|
|
return nil
|
|
}
|
|
|
|
log.Printf("WebSocket handshake for device %s successful. Starting bi-directional stream copy.", deviceSN)
|
|
|
|
// 7. 握手成功!现在在 gRPC 流和 yamux 流之间建立双向数据拷贝
|
|
errChan := make(chan error, 2)
|
|
|
|
// Goroutine 1: gRPC 请求流 (来自 App) -> yamux 流 (下行数据)
|
|
go func() {
|
|
// 这个方向的逻辑没有问题
|
|
for {
|
|
msg, err := stream.Recv()
|
|
if err == io.EOF {
|
|
backendConn.Close()
|
|
errChan <- nil
|
|
return
|
|
}
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
if chunk := msg.GetBodyChunk(); chunk != nil {
|
|
if _, err := backendConn.Write(chunk.Data); err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// Goroutine 2: yamux 流 (来自设备) -> gRPC 响应流 (上行数据)
|
|
go func() {
|
|
// [核心修正]
|
|
// 我们必须从 backendReader (而不是原始的 backendConn) 开始读取,
|
|
// 以确保 http.ReadResponse 预读到缓冲区的数据不会丢失。
|
|
// io.Copy 会首先清空 backendReader 的内部缓冲区,然后再继续从底层的 backendConn 读取。
|
|
if _, err := io.Copy(&grpcResponseWriter{stream: stream}, backendReader); err != nil {
|
|
// 过滤掉正常的连接关闭错误
|
|
if err != io.EOF && err != io.ErrClosedPipe && !strings.Contains(err.Error(), "use of closed") {
|
|
errChan <- err
|
|
} else {
|
|
errChan <- nil
|
|
}
|
|
} else {
|
|
errChan <- nil
|
|
}
|
|
}()
|
|
|
|
// 等待两个 goroutine 都结束
|
|
err1 := <-errChan
|
|
err2 := <-errChan
|
|
|
|
if err1 != nil && err1 != io.EOF {
|
|
log.Printf("WebSocket stream finished with error: %v", err1)
|
|
return err1
|
|
}
|
|
if err2 != nil && err2 != io.EOF {
|
|
log.Printf("WebSocket stream finished with error: %v", err2)
|
|
return err2
|
|
}
|
|
|
|
log.Printf("WebSocket stream for device %s finished gracefully.", deviceSN)
|
|
return nil
|
|
}
|
|
|
|
// handleHTTPProxy proxies an ordinary (non-upgrade) HTTP request to the device
// through session.CreateReverseProxy, streaming in both directions over the
// gRPC stream:
//   - the bytes of incoming BodyChunk messages are piped into the request body;
//   - the response goes out as one Header message followed by BodyChunk
//     messages carrying the response body.
//
// It returns after the proxy has finished and the response forwarding
// goroutine has exited. The result is the first error from sending the
// response over gRPC (io.ErrClosedPipe during the body copy is not reported),
// or nil.
func (s *InternalRelayServer) handleHTTPProxy(stream relaypb.InternalRelay_ProxyRequestServer, sessionInfo *session.SessionInfo, deviceSN string, header *relaypb.ProxyRequestHeader) error {
	// --- 1. Request: body bytes arrive on the gRPC stream and flow through a pipe ---
	reqPr, reqPw := io.Pipe()
	// NOTE(review): httptest.NewRequest panics if the method or URL cannot be
	// parsed; this relies on the calling relay node sending well-formed values.
	req := httptest.NewRequest(header.Method, "http://internal-proxy"+header.Url, reqPr)
	// Bind the proxied request to the RPC so the device request is abandoned
	// when the calling node goes away.
	req = req.WithContext(stream.Context())
	for k, v := range header.Headers {
		req.Header.Set(k, v)
	}
	req.Header.Set("X-Forwarded-For", header.RemoteAddr)

	// --- 2. Response: the proxy writes into respPw; a sender goroutine drains respPr ---
	respPr, respPw := io.Pipe()
	rw := &streamResponseWriter{
		header:        make(http.Header),
		pipeWriter:    respPw,
		headerWritten: make(chan struct{}),
	}

	// gRPC stream -> request body.
	go func() {
		defer reqPw.Close()
		for {
			msg, err := stream.Recv()
			if err == io.EOF {
				return
			}
			if err != nil {
				reqPw.CloseWithError(err)
				return
			}
			if chunk := msg.GetBodyChunk(); chunk != nil {
				if _, err := reqPw.Write(chunk.Data); err != nil {
					// The request body is no longer being read; nothing left to deliver.
					return
				}
			}
		}
	}()

	// Response -> gRPC stream. This goroutine is the only caller of stream.Send.
	errChan := make(chan error, 1)
	go func() {
		defer close(errChan)
		// If forwarding stops early (a failed Send), closing the read end
		// makes the proxy's pending and future writes into respPw fail
		// immediately instead of blocking forever on an undrained pipe.
		defer respPr.Close()

		<-rw.headerWritten
		respHeader := &relaypb.ProxyResponseHeader{
			StatusCode: int32(rw.statusCode),
			Headers:    make(map[string]string, len(rw.header)),
		}
		for k, v := range rw.header {
			respHeader.Headers[k] = strings.Join(v, ",")
		}
		if err := stream.Send(&relaypb.ProxyResponseMessage{
			Payload: &relaypb.ProxyResponseMessage_Header{Header: respHeader},
		}); err != nil {
			errChan <- err
			return
		}

		buf := make([]byte, 32*1024)
		if _, err := io.CopyBuffer(&grpcResponseWriter{stream: stream}, respPr, buf); err != nil && err != io.ErrClosedPipe {
			errChan <- err
		}
	}()

	// --- 3. Run the proxy ---
	proxy := session.CreateReverseProxy(sessionInfo, deviceSN, req.URL.Path, req.URL.RawQuery)
	proxy.ServeHTTP(rw, req)

	// --- 4. Clean up ---
	// A handler that never wrote a status implies 200 OK, as in net/http. This
	// call is a no-op if a status was already committed, and it guarantees the
	// sender goroutine is never left waiting on headerWritten.
	rw.WriteHeader(http.StatusOK)
	// The proxy is done with the request; make any body write still blocked
	// on the pipe fail so the body goroutine can exit.
	reqPr.Close()
	respPw.Close()
	return <-errChan
}
|
|
|
|
// sendForbiddenResponse answers the stream with a synthetic HTTP 403. It sends
// a header message (Content-Type text/plain; charset=utf-8) followed by a
// single "Forbidden" body chunk. It returns the first Send error, or nil once
// both messages have been sent, so the caller can end the RPC normally.
func sendForbiddenResponse(stream relaypb.InternalRelay_ProxyRequestServer) error {
	respHeader := &relaypb.ProxyResponseMessage{
		Payload: &relaypb.ProxyResponseMessage_Header{
			Header: &relaypb.ProxyResponseHeader{
				StatusCode: http.StatusForbidden,
				Headers:    map[string]string{"Content-Type": "text/plain; charset=utf-8"},
			},
		},
	}
	if err := stream.Send(respHeader); err != nil {
		return err
	}

	respBody := &relaypb.ProxyResponseMessage{
		Payload: &relaypb.ProxyResponseMessage_BodyChunk{
			BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: []byte("Forbidden")},
		},
	}
	// Previously this error was dropped; report it like the header send above.
	return stream.Send(respBody)
}
|
|
|
|
// streamResponseWriter 是一个自定义的 http.ResponseWriter
|
|
type streamResponseWriter struct {
|
|
header http.Header
|
|
pipeWriter *io.PipeWriter
|
|
statusCode int
|
|
headerWritten chan struct{}
|
|
}
|
|
|
|
func (w *streamResponseWriter) Header() http.Header {
|
|
return w.header
|
|
}
|
|
|
|
func (w *streamResponseWriter) Write(b []byte) (int, error) {
|
|
w.WriteHeader(http.StatusOK)
|
|
return w.pipeWriter.Write(b)
|
|
}
|
|
|
|
func (w *streamResponseWriter) WriteHeader(statusCode int) {
|
|
select {
|
|
case <-w.headerWritten:
|
|
return
|
|
default:
|
|
w.statusCode = statusCode
|
|
close(w.headerWritten)
|
|
}
|
|
}
|
|
|
|
// grpcResponseWriter 是一个适配器,实现了 io.Writer 接口
|
|
type grpcResponseWriter struct {
|
|
stream relaypb.InternalRelay_ProxyRequestServer
|
|
}
|
|
|
|
func (w *grpcResponseWriter) Write(p []byte) (n int, err error) {
|
|
err = w.stream.Send(&relaypb.ProxyResponseMessage{
|
|
Payload: &relaypb.ProxyResponseMessage_BodyChunk{
|
|
BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: p},
|
|
},
|
|
})
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return len(p), nil
|
|
}
|
|
|
|
// writeChunk 辅助函数 - 确保这个函数也存在于你的 grpc/server.go 文件中
|
|
func writeChunk(stream relaypb.InternalRelay_ProxyRequestServer, data []byte) error {
|
|
return stream.Send(&relaypb.ProxyResponseMessage{
|
|
Payload: &relaypb.ProxyResponseMessage_BodyChunk{
|
|
BodyChunk: &relaypb.ProxyResponseBodyChunk{Data: data},
|
|
},
|
|
})
|
|
}
|
|
|