You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							167 lines
						
					
					
						
							6.5 KiB
						
					
					
				
			
		
		
		
			
			
			
				
					
				
				
					
				
			
		
		
	
	
							167 lines
						
					
					
						
							6.5 KiB
						
					
					
				
								package session
							 | 
						|
								
							 | 
						|
								import (
							 | 
						|
									"context"
							 | 
						|
									"log"
							 | 
						|
									"net"
							 | 
						|
									"net/http"
							 | 
						|
									"net/http/httputil"
							 | 
						|
									"strings"
							 | 
						|
									"sync"
							 | 
						|
								
							 | 
						|
									"github.com/hashicorp/yamux"
							 | 
						|
								)
							 | 
						|
								
							 | 
						|
								// SessionInfo 存储了一个活跃的设备连接所需的所有信息。
							 | 
						|
								// 我们将 yamux.Session 和 UserID 绑定在一起。
							 | 
						|
								type SessionInfo struct {
							 | 
						|
									Session *yamux.Session
							 | 
						|
									UserID  string
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Manager 是会话管理的核心结构体。
							 | 
						|
								// 它只负责管理本实例内存中的会话,不关心 Redis 或其他存储。
							 | 
						|
								type Manager struct {
							 | 
						|
									// localSessions 使用设备 SN 作为 key,存储会话信息。
							 | 
						|
									localSessions map[string]*SessionInfo
							 | 
						|
									// sessionMutex 用于保护对 localSessions 的并发访问。
							 | 
						|
									sessionMutex sync.RWMutex
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// GlobalManager 是一个全局单例,方便在项目各处调用。
							 | 
						|
								var GlobalManager *Manager
							 | 
						|
								
							 | 
						|
								// InitManager 初始化全局的会P话管理器。
							 | 
						|
								func InitManager() {
							 | 
						|
									GlobalManager = &Manager{
							 | 
						|
										localSessions: make(map[string]*SessionInfo),
							 | 
						|
									}
							 | 
						|
									log.Println("Local session manager initialized.")
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// AddSession 向管理器中添加一个新的设备会话。
							 | 
						|
								// 如果已存在同名会话,它会先关闭旧的,再添加新的。
							 | 
						|
								func (m *Manager) AddSession(deviceSN string, info *SessionInfo) {
							 | 
						|
									m.sessionMutex.Lock()
							 | 
						|
									defer m.sessionMutex.Unlock()
							 | 
						|
								
							 | 
						|
									// 如果设备重连,旧的会话可能还存在,需要先关闭它
							 | 
						|
									if oldInfo, exists := m.localSessions[deviceSN]; exists {
							 | 
						|
										log.Printf("Device '%s' already has a local session, closing the old one.", deviceSN)
							 | 
						|
										oldInfo.Session.Close()
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									m.localSessions[deviceSN] = info
							 | 
						|
									log.Printf("Local session for device '%s' has been added.", deviceSN)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// RemoveSession 从管理器中移除一个设备会话。
							 | 
						|
								// 它会检查传入的 session 对象是否与当前存储的一致,防止误删新会话。
							 | 
						|
								func (m *Manager) RemoveSession(deviceSN string, session *yamux.Session) {
							 | 
						|
									m.sessionMutex.Lock()
							 | 
						|
									defer m.sessionMutex.Unlock()
							 | 
						|
								
							 | 
						|
									// 这是一个重要的检查:确保我们删除的是正确的、已经过期的会话,
							 | 
						|
									// 而不是一个刚刚建立的新会话(万一发生竞争)。
							 | 
						|
									if currentInfo, exists := m.localSessions[deviceSN]; exists && currentInfo.Session == session {
							 | 
						|
										delete(m.localSessions, deviceSN)
							 | 
						|
										log.Printf("Local session for device '%s' has been removed.", deviceSN)
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// GetLocalSession 根据设备 SN 查找一个活跃的本地会话。
							 | 
						|
								// 这是最常用的查询方法。
							 | 
						|
								func (m *Manager) GetLocalSession(deviceSN string) (*SessionInfo, bool) {
							 | 
						|
									m.sessionMutex.RLock()
							 | 
						|
									defer m.sessionMutex.RUnlock()
							 | 
						|
								
							 | 
						|
									info, ok := m.localSessions[deviceSN]
							 | 
						|
									if ok && !info.Session.IsClosed() {
							 | 
						|
										// 确保会话不仅存在,而且是活跃的
							 | 
						|
										return info, true
							 | 
						|
									}
							 | 
						|
									return nil, false
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// CreateReverseProxy 是一个辅助函数,用于创建一个配置好的 httputil.ReverseProxy。
							 | 
						|
								// 将这个逻辑放在这里,是因为它与 SessionInfo 强相关,可以被 main.go 和 grpc/server.go 复用。
							 | 
						|
								func CreateReverseProxy(sessionInfo *SessionInfo, deviceSN string, originalPath string, originalQuery string) *httputil.ReverseProxy {
							 | 
						|
									return &httputil.ReverseProxy{
							 | 
						|
										// Director 负责在请求被转发前,修改请求的 URL、Header 等。
							 | 
						|
										Director: func(req *http.Request) {
							 | 
						|
											// [新增日志] 如果是 WebSocket 请求,就打印它
							 | 
						|
											if isWebSocketRequest(req) { // isWebSocketRequest 是我们之前写的辅助函数
							 | 
						|
												// true 表示连同 body 一起打印,对于握手请求 body 为空
							 | 
						|
												reqDump, _ := httputil.DumpRequestOut(req, true)
							 | 
						|
												log.Printf("--- [SUCCESS CASE] ReverseProxy is about to send this WebSocket request:\n%s\n-------------------------------------------------", string(reqDump))
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											// 从原始请求路径中解析出要转发到 immich 的真正路径
							 | 
						|
											// 例如,从 "/tunnel/SN123/api/album" -> "/api/album"
							 | 
						|
											pathParts := strings.SplitN(strings.TrimPrefix(originalPath, "/"), "/", 3)
							 | 
						|
											if len(pathParts) > 2 {
							 | 
						|
												req.URL.Path = "/" + pathParts[2]
							 | 
						|
											} else {
							 | 
						|
												req.URL.Path = "/"
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											req.URL.RawQuery = originalQuery // 传递原始的查询参数
							 | 
						|
											req.URL.Scheme = "http"
							 | 
						|
											// Host 不重要,因为我们下面会劫持网络连接 (DialContext)
							 | 
						|
											req.URL.Host = "immich-internal"
							 | 
						|
											// 设置 X-Real-IP 头,让 immich 知道原始客户端的 IP
							 | 
						|
											req.Header.Set("X-Real-IP", req.RemoteAddr)
							 | 
						|
										},
							 | 
						|
								
							 | 
						|
										// Transport 负责实际的请求发送。我们通过重写 DialContext 来劫持它。
							 | 
						|
										Transport: &http.Transport{
							 | 
						|
											// 这是整个隧道转发的核心:
							 | 
						|
											// 当 ReverseProxy 尝试建立一个 TCP 连接到 "immich-internal" 时,
							 | 
						|
											// 我们不进行真正的网络拨号,而是直接在 yamux 会话上打开一个新的流 (stream)。
							 | 
						|
											// 这个流就等同于一个虚拟的 TCP 连接,直接通往设备端的 immich 容器。
							 | 
						|
											DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
							 | 
						|
												return sessionInfo.Session.Open()
							 | 
						|
											},
							 | 
						|
											// 必须禁用 HTTP/2,因为它与我们的简单流转发不兼容。
							 | 
						|
											ForceAttemptHTTP2: false,
							 | 
						|
										},
							 | 
						|
								
							 | 
						|
										// FlushInterval 设置为 -1 会禁用缓冲,立即将数据块发送出去。
							 | 
						|
										// 这对于视频流和 WebSocket 至关重要。
							 | 
						|
										FlushInterval: -1,
							 | 
						|
								
							 | 
						|
										// ModifyResponse 允许我们在响应返回给客户端之前修改它。
							 | 
						|
										ModifyResponse: func(resp *http.Response) error {
							 | 
						|
											// [新增调试日志]
							 | 
						|
											// 这是一个关键探针!
							 | 
						|
											if resp.StatusCode == http.StatusSwitchingProtocols { // 101
							 | 
						|
												log.Printf("DEBUG (WebSocket): ModifyResponse received '101 Switching Protocols'. This means backend handshake was successful!")
							 | 
						|
											}
							 | 
						|
											// 这个 Header 告诉上游的代理(如 Nginx)不要缓冲这个响应。
							 | 
						|
											resp.Header.Set("X-Accel-Buffering", "no")
							 | 
						|
											return nil
							 | 
						|
										},
							 | 
						|
								
							 | 
						|
										// ErrorHandler 定义了当转发过程中发生错误(如设备端断开连接)时的处理逻辑。
							 | 
						|
										ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
							 | 
						|
											// [新增调试日志]
							 | 
						|
											// 这是另一个关键探针!
							 | 
						|
											if r.Header.Get("Upgrade") == "websocket" {
							 | 
						|
												log.Printf("DEBUG (WebSocket): ErrorHandler was triggered for a WebSocket request. Error: %v", err)
							 | 
						|
											}
							 | 
						|
								
							 | 
						|
											log.Printf("ERROR: Reverse proxy error for device %s: %v", deviceSN, err)
							 | 
						|
											http.Error(w, "Error forwarding request to device", http.StatusBadGateway)
							 | 
						|
										},
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// [新增] 确保 isWebSocketRequest 辅助函数存在于 session/manager.go
							 | 
						|
								func isWebSocketRequest(r *http.Request) bool {
							 | 
						|
									upgradeHeader := strings.ToLower(r.Header.Get("Upgrade"))
							 | 
						|
									if upgradeHeader != "websocket" {
							 | 
						|
										return false
							 | 
						|
									}
							 | 
						|
									connectionHeader := strings.ToLower(r.Header.Get("Connection"))
							 | 
						|
									return strings.Contains(connectionHeader, "upgrade")
							 | 
						|
								}
							 | 
						|
								
							 |