Files
proxy/gateway/fwd/dispatcher/dispatch.go

171 lines
3.1 KiB
Go

package dispatcher
import (
"context"
"fmt"
"log/slog"
"net"
"proxy-server/gateway/core"
"proxy-server/gateway/fwd/http"
"proxy-server/gateway/fwd/metrics"
"proxy-server/gateway/fwd/socks"
"proxy-server/pkg/utils"
"strconv"
"strings"
"time"
"errors"
"github.com/soheilhy/cmux"
)
// Server listens on one TCP port, multiplexes incoming connections between
// the SOCKS and HTTP handlers, and hands successfully processed connections
// to consumers through Conn.
type Server struct {
// ctx is cancelled by Close; Run and the per-connection goroutines watch it.
ctx context.Context
// cancel cancels ctx.
cancel context.CancelFunc
// readTimeout is handed to the connection multiplexer's SetReadTimeout in Run.
readTimeout time.Duration
// Port is the TCP port Run listens on.
Port uint16
// Conn delivers the connections produced by the HTTP and SOCKS handlers.
// It is unbuffered, and Run closes it just before returning.
Conn chan *core.Conn
}
// New builds a Server that will listen on port and use readTimeout as the
// multiplexer read timeout once Run is called.
//
// It returns an error when port is 0. The returned Server owns a fresh
// cancellable context derived from context.Background and an unbuffered
// Conn channel.
func New(port uint16, readTimeout time.Duration) (*Server, error) {
	if port == 0 {
		return nil, errors.New("port is required")
	}

	ctx, cancel := context.WithCancel(context.Background())
	srv := &Server{
		ctx:         ctx,
		cancel:      cancel,
		readTimeout: readTimeout,
		Port:        port,
		Conn:        make(chan *core.Conn),
	}
	return srv, nil
}
// Close cancels the server's context. A running Run returns once it observes
// the cancellation, and connection goroutines that are still waiting to
// deliver on Conn close their connection instead.
func (s *Server) Close() {
s.cancel()
}
// Run listens on s.Port, splits incoming connections between the SOCKS and
// HTTP handlers with cmux, and blocks until either the server context is
// cancelled (see Close) or the multiplexer stops serving.
//
// It returns nil when stopped through the context, a wrapped listen error if
// the port cannot be bound, or a wrapped error from cmux Serve. s.Conn is
// closed before Run returns, so Run must be called at most once per Server.
func (s *Server) Run() error {
	port := strconv.Itoa(int(s.Port))
	ls, err := net.Listen("tcp", ":"+port)
	if err != nil {
		return fmt.Errorf("dispatcher 监听失败: %w", err)
	}
	defer utils.Close(ls)

	m := cmux.New(ls)
	m.SetReadTimeout(s.readTimeout)
	defer m.Close()

	// Every goroutine below keeps its error in a variable of its own.
	// Sharing the outer err between them and the select at the bottom was a
	// data race and could make Run return an accept-loop error.

	// A first byte of 0x05 is the SOCKS5 version field.
	socksLs := m.Match(cmux.PrefixMatcher(string([]byte{0x05})))
	go func() {
		acceptErr := s.acceptSocks(socksLs)
		if acceptErr == nil {
			return
		}
		// The sub-listener reports this once the mux shuts down; it is the
		// normal way for the accept loop to end, not a failure.
		if strings.Contains(acceptErr.Error(), "mux: server closed") {
			return
		}
		slog.Warn("dispatcher socks accept error", "err", acceptErr)
	}()

	// HTTP/1 requests are recognised by their method name; PATCH is added to
	// the methods cmux matches by default.
	httpLs := m.Match(cmux.HTTP1Fast("PATCH"))
	go func() {
		acceptErr := s.acceptHttp(httpLs)
		if acceptErr == nil {
			return
		}
		if strings.Contains(acceptErr.Error(), "mux: server closed") {
			return
		}
		slog.Warn("dispatcher http accept error", "err", acceptErr)
	}()

	// Buffered so the Serve goroutine can always deliver its result and
	// exit, even when Run has already returned because of cancellation.
	errCh := make(chan error, 1)
	go func() {
		defer close(errCh)
		serveErr := m.Serve()
		if serveErr != nil {
			serveErr = fmt.Errorf("dispatcher serve error: %w", serveErr)
		}
		errCh <- serveErr
	}()

	var runErr error
	select {
	case <-s.ctx.Done():
	case runErr = <-errCh:
	}

	// NOTE(review): connection goroutines started by the accept loops may
	// still be sending on s.Conn at this point, and a send on a closed
	// channel panics. Closing safely requires waiting for those goroutines
	// first, which needs tracking outside this function.
	close(s.Conn)
	return runErr
}
// acceptHttp accepts connections from ls in a loop and processes each one as
// an HTTP proxy connection on its own goroutine.
//
// The loop returns nil when Accept reports net.ErrClosed, retries when the
// error is a temporary net.Error, and returns a wrapped error for anything
// else. For every accepted connection the accept time is stored in
// metrics.TimerStart before http.Process runs. A connection that fails
// processing is logged and closed; a processed connection is sent on s.Conn,
// or closed if the server context is cancelled before it can be delivered.
func (s *Server) acceptHttp(ls net.Listener) error {
	for {
		accepted, acceptErr := ls.Accept()
		if acceptErr != nil {
			var netErr net.Error
			switch {
			case errors.Is(acceptErr, net.ErrClosed):
				return nil
			case errors.As(acceptErr, &netErr) && netErr.Temporary():
				continue
			default:
				return fmt.Errorf("dispatcher http accept error: %w", acceptErr)
			}
		}

		metrics.TimerStart.Store(accepted, time.Now())
		go func(c net.Conn) {
			processed, procErr := http.Process(s.ctx, c)
			if procErr != nil {
				slog.Error("处理 http 连接失败", "err", procErr)
				utils.Close(c)
				return
			}
			select {
			case s.Conn <- processed:
			case <-s.ctx.Done():
				utils.Close(processed)
			}
		}(accepted)
	}
}
// acceptSocks accepts connections from ls in a loop and processes each one
// as a SOCKS proxy connection on its own goroutine.
//
// The loop returns nil when Accept reports net.ErrClosed, retries when the
// error is a temporary net.Error, and returns a wrapped error for anything
// else. For every accepted connection the accept time is stored in
// metrics.TimerStart before socks.Process runs. A connection that fails
// processing is logged and closed; a processed connection is sent on s.Conn,
// or closed if the server context is cancelled before it can be delivered.
func (s *Server) acceptSocks(ls net.Listener) error {
	for {
		accepted, acceptErr := ls.Accept()
		if acceptErr != nil {
			var netErr net.Error
			switch {
			case errors.Is(acceptErr, net.ErrClosed):
				return nil
			case errors.As(acceptErr, &netErr) && netErr.Temporary():
				continue
			default:
				return fmt.Errorf("dispatcher socks accept error: %w", acceptErr)
			}
		}

		metrics.TimerStart.Store(accepted, time.Now())
		go func(c net.Conn) {
			processed, procErr := socks.Process(s.ctx, c)
			if procErr != nil {
				slog.Error("处理 socks 连接失败", "err", procErr)
				utils.Close(c)
				return
			}
			select {
			case s.Conn <- processed:
			case <-s.ctx.Done():
				utils.Close(processed)
			}
		}(accepted)
	}
}
// Conn is an empty struct. Nothing in the visible code uses it; Server.Conn
// carries *core.Conn values instead.
// NOTE(review): looks like a leftover placeholder — confirm before removing.
type Conn struct {
}