package dispatcher import ( "context" "fmt" "log/slog" "net" "proxy-server/gateway/core" "proxy-server/gateway/fwd/http" "proxy-server/gateway/fwd/metrics" "proxy-server/gateway/fwd/socks" "proxy-server/utils" "strconv" "strings" "time" "errors" "github.com/soheilhy/cmux" ) type Server struct { ctx context.Context cancel context.CancelFunc readTimeout time.Duration Port uint16 Conn chan *core.Conn } func New(port uint16, readTimeout time.Duration) (*Server, error) { if port == 0 { return nil, errors.New("port is required") } ctx, cancel := context.WithCancel(context.Background()) return &Server{ ctx, cancel, readTimeout, port, make(chan *core.Conn), }, nil } func (s *Server) Stop() { s.cancel() } func (s *Server) Run() error { port := strconv.Itoa(int(s.Port)) ls, err := net.Listen("tcp", ":"+port) if err != nil { return fmt.Errorf("dispatcher 监听失败: %w", err) } defer utils.Close(ls) m := cmux.New(ls) m.SetReadTimeout(s.readTimeout) defer m.Close() socksLs := m.Match(cmux.PrefixMatcher(string([]byte{0x05}))) go func() { err = s.acceptSocks(socksLs) if err != nil { if strings.Contains(err.Error(), "mux: server closed") { return } slog.Warn("dispatcher socks accept error", "err", err) } }() httpLs := m.Match(cmux.HTTP1Fast("PATCH")) go func() { err = s.acceptHttp(httpLs) if err != nil { if strings.Contains(err.Error(), "mux: server closed") { return } slog.Warn("dispatcher http accept error", "err", err) } }() errCh := make(chan error, 1) go func() { defer close(errCh) err = m.Serve() if err != nil { err = fmt.Errorf("dispatcher serve error: %w", err) } errCh <- err }() err = nil select { case <-s.ctx.Done(): case err = <-errCh: } close(s.Conn) return err } func (s *Server) acceptHttp(ls net.Listener) error { for { conn, err := ls.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return nil } var ne net.Error if errors.As(err, &ne) && ne.Temporary() { continue } return fmt.Errorf("dispatcher http accept error: %w", err) } metrics.TimerStart.Store(conn, time.Now()) go func() { user, err := http.Process(s.ctx, conn) if err != nil { slog.Error("处理 http 连接失败", "err", err) utils.Close(conn) return } select { case <-s.ctx.Done(): utils.Close(user) case s.Conn <- user: } }() } } func (s *Server) acceptSocks(ls net.Listener) error { for { conn, err := ls.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return nil } var ne net.Error if errors.As(err, &ne) && ne.Temporary() { continue } return fmt.Errorf("dispatcher socks accept error: %w", err) } metrics.TimerStart.Store(conn, time.Now()) go func() { user, err := socks.Process(s.ctx, conn) if err != nil { slog.Error("处理 socks 连接失败", "err", err) utils.Close(conn) return } select { case <-s.ctx.Done(): utils.Close(user) case s.Conn <- user: } }() } } type Conn struct { }