diff --git a/README.md b/README.md index 8ee0c72..5599c2e 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,13 @@ ## todo +认证失败应当是 Warn 级别而非 Error 级别,需要修改 + +考虑再修改逻辑,等待子协程退出不应当级联,而是放在包全局管理,否则流程可能有问题 + +ProxyConn 直接实现 Conn 相同的接口,不再取出 Conn 使用 + +web 服务目录结构,不要 app 那层了 + 配置退出等待时间 log 控制台颜色,输出错误堆栈 diff --git a/client/service.go b/client/service.go index 27b9ff2..214c3c9 100644 --- a/client/service.go +++ b/client/service.go @@ -101,7 +101,7 @@ func data(tagLen byte, tagBuf []byte) error { // 发送 tag slog.Info("准备代理流量") - writeBuf := make([]byte, 1+len(tagBuf)) + writeBuf := make([]byte, 1+tagLen) writeBuf[0] = tagLen copy(writeBuf[1:], tagBuf) _, err = src.Write(writeBuf) diff --git a/main.go b/main.go deleted file mode 100644 index e85d256..0000000 --- a/main.go +++ /dev/null @@ -1,9 +0,0 @@ -package main - -import ( - "proxy-server/server" -) - -func main() { - server.Start2() -} diff --git a/pkg/utils/chan.go b/pkg/utils/chan.go index 2f60185..7215563 100644 --- a/pkg/utils/chan.go +++ b/pkg/utils/chan.go @@ -35,7 +35,7 @@ func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn { return connCh } -func ChanWgWait(ctx context.Context, wg *CountWaitGroup) chan struct{} { +func ChanWgWait[T WaitGroup](ctx context.Context, wg T) chan struct{} { ch := make(chan struct{}) go func() { wg.Wait() diff --git a/pkg/utils/sync.go b/pkg/utils/sync.go index b388690..4e18a32 100644 --- a/pkg/utils/sync.go +++ b/pkg/utils/sync.go @@ -6,7 +6,7 @@ import ( ) type WaitGroup interface { - Add(delta uint) + Add(delta int) Done() Wait() } @@ -16,8 +16,8 @@ type CountWaitGroup struct { num atomic.Int64 } -func (c *CountWaitGroup) Add(delta uint) { - c.wg.Add(int(delta)) +func (c *CountWaitGroup) Add(delta int) { + c.wg.Add(delta) c.num.Add(int64(delta)) } diff --git a/server/fwd/auth.go b/server/fwd/auth.go new file mode 100644 index 0000000..400133b --- /dev/null +++ b/server/fwd/auth.go @@ -0,0 +1,206 @@ +package fwd + +import ( + "context" + "io" + "log/slog" + "net" + "proxy-server/pkg/utils" + "proxy-server/server/fwd/socks" + "proxy-server/server/pkg/orm" + "proxy-server/server/web/app/models" + "time" + + "github.com/pkg/errors" +) + +type NoAuthAuthenticator struct { +} + +func (a *NoAuthAuthenticator) Method() socks.AuthMethod { + return socks.NoAuth +} + +func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.Authentication, error) { + + // 获取用户地址 + conn, ok := writer.(net.Conn) + if !ok { + return nil, errors.New("noAuth 认证失败,无法获取连接信息") + } + addr := conn.RemoteAddr().String() + client, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, errors.Wrap(err, "noAuth 认证失败") + } + slog.Debug("用户的地址为 " + client) + + // 获取服务 + server, ok := ctx.Value("service").(*socks.Server) + if !ok { + return nil, errors.New("noAuth 认证失败,无法获取服务信息") + } + node := server.Name + slog.Debug("服务的名称为 " + server.Name) + + // 查询权限记录 + slog.Info("用户 " + client + " 请求连接到 " + node) + var channels []models.Channel + err = orm.DB. + Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", node). + Joins("INNER JOIN public.users u ON channels.user_id = u.id"). + Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", client). + Where(&models.Channel{ + AuthIp: true, + }). + Find(&channels).Error + if err != nil { + return nil, errors.New("noAuth 查询用户权限失败") + } + + // 记录应该只有一条 + channel, err := orm.MaySingle(channels) + if err != nil { + return nil, errors.Wrap(err, "noAuth 没有权限") + } + + // 检查是否需要密码认证 + if channel.AuthPass { + return nil, errors.New("noAuth 没有权限,需要密码认证") + } + + // 检查权限是否过期 + timeout := channel.Expiration.Sub(time.Now()).Seconds() + slog.Info("用户剩余时间", "timeout", timeout) + if timeout <= 0 { + return nil, errors.New("noAuth 权限已过期") + } + slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout))) + + return &socks.Authentication{ + Method: socks.NoAuth, + Timeout: uint(timeout), + Payload: socks.Payload{ + ID: channel.UserId, + }, + }, nil +} + +type UserPassAuthenticator struct { +} + +func (a *UserPassAuthenticator) Method() socks.AuthMethod { + return socks.UserPassAuth +} + +func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.Authentication, error) { + + // 检查认证版本 + slog.Debug("验证认证版本") + v, err := utils.ReadByte(reader) + if err != nil { + return nil, errors.Wrap(err, "读取版本号失败") + } + if v != socks.AuthVersion { + _, err := writer.Write([]byte{socks.Version, socks.AuthFailure}) + if err != nil { + return nil, errors.Wrap(err, "响应认证失败") + } + return nil, errors.New("认证版本参数不正确") + } + + // 读取账号 + slog.Debug("验证用户账号") + uLen, err := utils.ReadByte(reader) + if err != nil { + return nil, errors.Wrap(err, "读取用户名长度失败") + } + usernameBuf, err := utils.ReadBuffer(reader, int(uLen)) + if err != nil { + return nil, errors.Wrap(err, "读取用户名失败") + } + username := string(usernameBuf) + + // 读取密码 + pLen, err := utils.ReadByte(reader) + if err != nil { + return nil, errors.Wrap(err, "读取密码长度失败") + } + passwordBuf, err := utils.ReadBuffer(reader, int(pLen)) + if err != nil { + return nil, errors.Wrap(err, "读取密码失败") + } + password := string(passwordBuf) + + // 查询通道配置 + var channel models.Channel + err = orm.DB. + Where(&models.Channel{ + Username: username, + AuthPass: true, + }). + First(&channel).Error + if err != nil { + return nil, errors.Wrap(err, "查询用户失败") + } + + // 检查密码 todo 哈希 + if channel.Password != password { + return nil, errors.New("密码错误") + } + + // 检查权限是否过期 + timeout := channel.Expiration.Sub(time.Now()).Seconds() + slog.Info("用户剩余时间", "timeout", timeout) + if timeout <= 0 { + return nil, errors.New("权限已过期") + } + + // 如果用户设置了双验证则检查 ip 是否在白名单中 + if channel.AuthIp { + slog.Debug("验证用户 ip") + + // 获取用户地址 + conn, ok := writer.(net.Conn) + if !ok { + return nil, errors.New("无法获取连接信息") + } + addr := conn.RemoteAddr().String() + client, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, errors.Wrap(err, "无法获取连接信息") + } + + // 查询通道配置 + var ips []models.UserIp + err = orm.DB. + Where(&models.UserIp{ + UserId: channel.UserId, + IpAddress: client, + }). + Find(&ips).Error + if err != nil { + return nil, errors.Wrap(err, "查询用户 ip 失败") + } + + // 检查是否在白名单中 + if len(ips) == 0 { + return nil, errors.New("没有权限") + } + } + + // 响应认证成功 + _, err = writer.Write([]byte{socks.AuthVersion, socks.AuthSuccess}) + if err != nil { + slog.Error("响应认证失败", "err", err) + return nil, err + } + + return &socks.Authentication{ + Method: socks.UserPassAuth, + Timeout: uint(timeout), + Payload: socks.Payload{ + ID: channel.UserId, + }, + }, nil +} diff --git a/server/fwd/fwd.go b/server/fwd/fwd.go index cca96fb..9e6024e 100644 --- a/server/fwd/fwd.go +++ b/server/fwd/fwd.go @@ -10,10 +10,8 @@ import ( "proxy-server/pkg/utils" "proxy-server/server/fwd/socks" "proxy-server/server/pkg/env" - "proxy-server/server/pkg/orm" - "proxy-server/server/web/app/models" - "slices" "strconv" + "sync" "time" "github.com/pkg/errors" @@ -23,8 +21,11 @@ type Config struct { } type Service struct { - Config *Config - connMap map[string]socks.ProxyConn + Config *Config + ctx context.Context + cancel context.CancelFunc + userConnMap map[string]socks.ProxyConn + ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup } @@ -34,51 +35,81 @@ func New(config *Config) *Service { config = &Config{} } + ctx, cancel := context.WithCancel(context.Background()) return &Service{ - Config: config, - connMap: make(map[string]socks.ProxyConn), - ctrlConnWg: utils.CountWaitGroup{}, - dataConnWg: utils.CountWaitGroup{}, + Config: config, + ctx: ctx, + cancel: cancel, + userConnMap: make(map[string]socks.ProxyConn), + ctrlConnWg: utils.CountWaitGroup{}, + dataConnWg: utils.CountWaitGroup{}, } } -func (s *Service) Run(ctx context.Context) { - slog.Info("启动 fwd 服务") - - // 启动工作协程 - subCtx, cancel := context.WithCancel(ctx) - defer cancel() - - goNum := 2 - subErrCh := make(chan error, goNum) - defer close(subErrCh) - - go s.startCtrlTun(subCtx) - go s.startDataTun(subCtx, subErrCh) - - // 等待结束 - var firstSubErr error = nil - for i := 0; i < goNum; i++ { - err := <-subErrCh - if err != nil { - slog.Error("隧道错误关闭", "err", err) - if firstSubErr == nil { - firstSubErr = err - cancel() - } - } else { - slog.Info("隧道关闭") - } - } - - slog.Info("fwd 服务已结束") -} - func (s *Service) Close() { - + s.cancel() + for _, conn := range s.userConnMap { + utils.Close(conn) + } + clear(s.userConnMap) } -func (s *Service) startCtrlTun(ctx context.Context) error { +func (s *Service) Run() { + slog.Debug("启动 fwd 服务") + + errQuit := make(chan struct{}) + defer close(errQuit) + + wg := sync.WaitGroup{} + + // 启动工作协程 + wg.Add(1) + go func() { + defer wg.Done() + err := s.startCtrlTun() + if err != nil { + slog.Error("控制通道发生错误", "err", err) + errQuit <- struct{}{} + return + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.startDataTun() + if err != nil { + slog.Error("数据通道发生错误", "err", err) + errQuit <- struct{}{} + return + } + }() + + // 等待结束 + select { + case <-s.ctx.Done(): + slog.Debug("服务关闭") + case <-errQuit: + slog.Debug("服务异常退出") + } + + // 退出 + s.Close() + + timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wgCh := utils.ChanWgWait(timeout, &wg) + defer close(wgCh) + + select { + case <-timeout.Done(): + slog.Warn("关闭超时,强制关闭") + case <-wgCh: + slog.Debug("服务已退出") + } +} + +func (s *Service) startCtrlTun() error { ctrlPort := env.AppCtrlPort slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort))) @@ -90,13 +121,13 @@ func (s *Service) startCtrlTun(ctx context.Context) error { defer utils.Close(ls) // 等待连接 - connCh := utils.ChanConnAccept(ctx, ls) + connCh := utils.ChanConnAccept(s.ctx, ls) defer close(connCh) // 处理连接 for loop := true; loop; { select { - case <-ctx.Done(): + case <-s.ctx.Done(): slog.Debug("结束处理连接,由于上下文取消") loop = false case conn, ok := <-connCh: @@ -108,14 +139,15 @@ func (s *Service) startCtrlTun(ctx context.Context) error { go func() { defer s.ctrlConnWg.Done() defer utils.Close(conn) - s.processCtrlConn(conn) + err := s.processCtrlConn(conn) + if err != nil { + slog.Error("处理控制通道连接失败", err) + } }() } } - // 等待子协程结束 todo 可配置等待时间 - s.Close() - + // 等待子协程结束 timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() procCh := utils.ChanWgWait(timeout, &s.ctrlConnWg) @@ -125,28 +157,27 @@ func (s *Service) startCtrlTun(ctx context.Context) error { case <-timeout.Done(): slog.Warn("等待控制通道子协程结束超时") case <-procCh: - slog.Info("控制通道子协程结束") + slog.Debug("控制通道子协程结束") } slog.Debug("关闭控制通道") return nil } -func (s *Service) processCtrlConn(controller net.Conn) { - slog.Info("收到客户端控制通道连接", "addr", controller.RemoteAddr().String()) +func (s *Service) processCtrlConn(controller net.Conn) error { + slog.Info("客户端连入", "addr", controller.RemoteAddr().String()) reader := bufio.NewReader(controller) - // 读取端口 + // 获取转发端口 portBuf, err := utils.ReadBuffer(reader, 2) if err != nil { - slog.Error("接收转发端口失败", "err", err) - return + return errors.Wrap(err, "获取转发端口失败") } port := binary.BigEndian.Uint16(portBuf) - // 新建代理服务 - slog.Info("新建代理服务", "port", port) + // 开放转发端口 todo 混合转发 + slog.Info("开放转发端口", "port", port) proxy, err := socks.New(&socks.Config{ Name: strconv.Itoa(int(port)), Port: port, @@ -155,11 +186,11 @@ func (s *Service) processCtrlConn(controller net.Conn) { &NoAuthAuthenticator{}, }, }) - defer proxy.Close() if err != nil { - slog.Error("代理服务创建失败", "err", err) - return + return errors.Wrap(err, "创建 socks 转发服务失败") } + defer proxy.Close() + go func() { err := proxy.Run() if err != nil { @@ -169,55 +200,74 @@ func (s *Service) processCtrlConn(controller net.Conn) { }() // 等待用户连接 - for { - user := <-proxy.Conn - tag := user.Tag() - tagBuf := make([]byte, len(tag)+1) - tagBuf[0] = byte(len(tag)) - copy(tagBuf[1:], tag) - _, err := controller.Write(tagBuf) - if err != nil { - slog.Error("写入 tag 失败", "err", err) - utils.Close(user) - return + wg := sync.WaitGroup{} + for loop := true; loop; { + select { + case <-s.ctx.Done(): + loop = false + case user, ok := <-proxy.Conn: + if !ok { + loop = false + err = errors.New("无法获取连接") + } + wg.Add(1) + go func() { + defer wg.Done() + + tag := user.Tag() + tagLen := len(tag) + tagBuf := make([]byte, 1+tagLen) + tagBuf[0] = byte(tagLen) + copy(tagBuf[1:], tag) + _, err := controller.Write(tagBuf) + if err != nil { + utils.Close(user) + slog.Error("向客户端发送 tag 失败", "err", err) + return + } + s.userConnMap[tag] = user + }() } - s.connMap[tag] = user } + + wg.Wait() + return nil } -func (s *Service) startDataTun(ctx context.Context, errCh chan error) { +func (s *Service) startDataTun() error { dataPort := env.AppDataPort slog.Debug("监听数据通道", slog.Uint64("port", uint64(dataPort))) // 监听端口 - lData, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort))) + ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort))) if err != nil { - slog.Error("listen error", "err", err) - return + return errors.Wrap(err, "监听数据通道失败") } - defer utils.Close(lData) + defer utils.Close(ls) // 等待连接 - connCh := utils.ChanConnAccept(ctx, lData) + connCh := utils.ChanConnAccept(s.ctx, ls) defer close(connCh) // 处理连接 -loop: - for { + for loop := true; loop; { select { - case <-ctx.Done(): + case <-s.ctx.Done(): slog.Debug("结束处理连接,由于上下文取消") - break loop + loop = false case conn, ok := <-connCh: if !ok { slog.Debug("结束处理连接,由于获取连接失败") - break loop + loop = false } s.dataConnWg.Add(1) go func() { defer s.dataConnWg.Done() defer utils.Close(conn) - s.processDataConn(conn) + err := s.processDataConn(conn) + if err != nil { + slog.Error("处理数据通道失败", err) + } }() } } @@ -225,7 +275,6 @@ loop: // 等待子协程结束 todo 可配置等待时间 timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - procCh := utils.ChanWgWait(timeout, &s.dataConnWg) defer close(procCh) @@ -233,47 +282,60 @@ loop: case <-timeout.Done(): slog.Warn("等待数据通道子协程结束超时") case <-procCh: - slog.Info("数据通道子协程结束") + slog.Debug("数据通道子协程结束") } slog.Debug("关闭数据通道") - errCh <- nil + return nil } -func (s *Service) processDataConn(client net.Conn) { - slog.Info("已建立客户端数据通道 " + client.RemoteAddr().String()) +func (s *Service) processDataConn(client net.Conn) error { + slog.Info("客户端准备接收数据 " + client.RemoteAddr().String()) // 读取 tag tagLen, err := utils.ReadByte(client) if err != nil { - slog.Error("read error", "err", err) - return + return errors.Wrap(err, "从客户端获取 tag 失败") } tagBuf, err := utils.ReadBuffer(client, int(tagLen)) if err != nil { - slog.Error("read error", "err", err) - return + return errors.Wrap(err, "从客户端获取 tag 失败") } tag := string(tagBuf) - // 找到用户连接 - data, ok := s.connMap[tag] - if !ok { - slog.Error("no such connection") - return + select { + case <-s.ctx.Done(): + return nil + default: } + // 找到用户连接 + data, ok := s.userConnMap[tag] + if !ok { + return errors.New("查找用户连接失败") + } + defer func() { + delete(s.userConnMap, tag) + utils.Close(data) + }() + // 响应用户 user := data.Conn - defer utils.Close(user) - socks.SendSuccess(user, client) + err = socks.SendSuccess(user, client) + if err != nil { + // todo 考虑是否需要处理服务关闭后导致用户连接被关闭的情况 + return errors.Wrap(err, "向用户发送成功消息失败") + } - // 写入目标地址 - destBuf := slices.Insert([]byte(data.Dest), 0, byte(len(data.Dest))) + // 发送目标地址 + dest := data.Dest + destLen := len(dest) + destBuf := make([]byte, 1+destLen) + destBuf[0] = byte(destLen) + copy(destBuf[1:], dest) _, err = client.Write(destBuf) if err != nil { - slog.Error("发送目标地址失败", "err", err) - return + return errors.Wrap(err, "向客户端发送目标地址失败") } // 数据转发 @@ -283,215 +345,25 @@ func (s *Service) processDataConn(client net.Conn) { // defer utils.Close(userPipeWriter) // teeUser := io.TeeReader(user, userPipeWriter) - errCh := make(chan error) + wg := sync.WaitGroup{} + wg.Add(1) go func() { + defer wg.Done() _, err := io.Copy(client, user) if err != nil { - slog.Error("processDataConn error u2c", "err", err) + slog.Error("数据转发失败 user->client", "err", err) } - errCh <- err }() - // go analysis(userPipeReader) - + wg.Add(1) go func() { + defer wg.Done() _, err := io.Copy(user, client) if err != nil { - slog.Error("processDataConn error c2u", "err", err) + slog.Error("数据转发失败 client->user", "err", err) } - errCh <- err }() + wg.Wait() - <-errCh slog.Info("数据转发结束 " + client.RemoteAddr().String() + " <-> " + data.Dest) -} - -type NoAuthAuthenticator struct { -} - -func (a *NoAuthAuthenticator) Method() socks.AuthMethod { - return socks.NoAuth -} - -func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.Authentication, error) { - - // 获取用户地址 - conn, ok := writer.(net.Conn) - if !ok { - return nil, errors.New("noAuth 认证失败,无法获取连接信息") - } - addr := conn.RemoteAddr().String() - client, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, errors.Wrap(err, "noAuth 认证失败") - } - slog.Debug("用户的地址为 " + client) - - // 获取服务 - server, ok := ctx.Value("service").(*socks.Server) - if !ok { - return nil, errors.New("noAuth 认证失败,无法获取服务信息") - } - node := server.Name - slog.Debug("服务的名称为 " + server.Name) - - // 查询权限记录 - slog.Info(" 客户端 " + client + " 请求连接到 " + node) - var channels []models.Channel - err = orm.DB. - Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", node). - Joins("INNER JOIN public.users u ON channels.user_id = u.id"). - Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", client). - Where(&models.Channel{ - AuthIp: true, - }). - Find(&channels).Error - if err != nil { - return nil, errors.New("noAuth 查询用户权限失败") - } - - // 记录应该只有一条 - channel, err := orm.MaySingle(channels) - if err != nil { - return nil, errors.Wrap(err, "noAuth 没有权限") - } - - // 检查是否需要密码认证 - if channel.AuthPass { - return nil, errors.New("noAuth 没有权限,需要密码认证") - } - - // 检查权限是否过期 - timeout := channel.Expiration.Sub(time.Now()).Seconds() - slog.Info("用户剩余时间", "timeout", timeout) - if timeout <= 0 { - return nil, errors.New("noAuth 权限已过期") - } - slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout))) - - return &socks.Authentication{ - Method: socks.NoAuth, - Timeout: uint(timeout), - Payload: socks.Payload{ - ID: channel.UserId, - }, - }, nil -} - -type UserPassAuthenticator struct { -} - -func (a *UserPassAuthenticator) Method() socks.AuthMethod { - return socks.UserPassAuth -} - -func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.Authentication, error) { - - // 检查认证版本 - slog.Debug("验证认证版本") - v, err := utils.ReadByte(reader) - if err != nil { - return nil, errors.Wrap(err, "读取版本号失败") - } - if v != socks.AuthVersion { - _, err := writer.Write([]byte{socks.Version, socks.AuthFailure}) - if err != nil { - return nil, errors.Wrap(err, "响应认证失败") - } - return nil, errors.New("认证版本参数不正确") - } - - // 读取账号 - slog.Debug("验证用户账号") - uLen, err := utils.ReadByte(reader) - if err != nil { - return nil, errors.Wrap(err, "读取用户名长度失败") - } - usernameBuf, err := utils.ReadBuffer(reader, int(uLen)) - if err != nil { - return nil, errors.Wrap(err, "读取用户名失败") - } - username := string(usernameBuf) - - // 读取密码 - pLen, err := utils.ReadByte(reader) - if err != nil { - return nil, errors.Wrap(err, "读取密码长度失败") - } - passwordBuf, err := utils.ReadBuffer(reader, int(pLen)) - if err != nil { - return nil, errors.Wrap(err, "读取密码失败") - } - password := string(passwordBuf) - - // 查询通道配置 - var channel models.Channel - err = orm.DB. - Where(&models.Channel{ - Username: username, - AuthPass: true, - }). - First(&channel).Error - if err != nil { - return nil, errors.Wrap(err, "查询用户失败") - } - - // 检查密码 todo 哈希 - if channel.Password != password { - return nil, errors.New("密码错误") - } - - // 检查权限是否过期 - timeout := channel.Expiration.Sub(time.Now()).Seconds() - slog.Info("用户剩余时间", "timeout", timeout) - if timeout <= 0 { - return nil, errors.New("权限已过期") - } - - // 如果用户设置了双验证则检查 ip 是否在白名单中 - if channel.AuthIp { - slog.Debug("验证用户 ip") - - // 获取用户地址 - conn, ok := writer.(net.Conn) - if !ok { - return nil, errors.New("无法获取连接信息") - } - addr := conn.RemoteAddr().String() - client, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, errors.Wrap(err, "无法获取连接信息") - } - - // 查询通道配置 - var ips []models.UserIp - err = orm.DB. - Where(&models.UserIp{ - UserId: channel.UserId, - IpAddress: client, - }). - Find(&ips).Error - if err != nil { - return nil, errors.Wrap(err, "查询用户 ip 失败") - } - - // 检查是否在白名单中 - if len(ips) == 0 { - return nil, errors.New("没有权限") - } - } - - // 响应认证成功 - _, err = writer.Write([]byte{socks.AuthVersion, socks.AuthSuccess}) - if err != nil { - slog.Error("响应认证失败", "err", err) - return nil, err - } - - return &socks.Authentication{ - Method: socks.UserPassAuth, - Timeout: uint(timeout), - Payload: socks.Payload{ - ID: channel.UserId, - }, - }, nil + return nil } diff --git a/server/fwd/socks/request.go b/server/fwd/socks/request.go index b153d6a..e35b9fd 100644 --- a/server/fwd/socks/request.go +++ b/server/fwd/socks/request.go @@ -335,13 +335,14 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { return err } -func SendSuccess(user net.Conn, target net.Conn) { +func SendSuccess(user net.Conn, target net.Conn) error { local := target.LocalAddr().(*net.TCPAddr) bind := AddrSpec{IP: local.IP, Port: local.Port} err := sendReply(user, successReply, &bind) if err != nil { - slog.Error("Failed to send reply", err) + return err } + return nil } type ProxyConn struct { diff --git a/server/fwd/socks/socks.go b/server/fwd/socks/socks.go index 4abcaa8..0346dc0 100644 --- a/server/fwd/socks/socks.go +++ b/server/fwd/socks/socks.go @@ -95,13 +95,13 @@ func New(conf *Config) (*Server, error) { wg: utils.CountWaitGroup{}, Name: conf.Name, Port: conf.Port, - Conn: make(chan ProxyConn, 100), + Conn: make(chan ProxyConn), }, nil } // Run 监听端口 func (s *Server) Run() error { - slog.Info("启动 socks5 代理服务") + slog.Debug("启动 socks5 代理服务") // 监听端口 host := s.config.Host @@ -112,7 +112,7 @@ func (s *Server) Run() error { return errors.Wrap(err, "监听端口失败") } defer utils.Close(ls) - slog.Info("正在监听端口", slog.Uint64("port", uint64(port))) + slog.Debug("正在监听端口", slog.Uint64("port", uint64(port))) // 处理连接 connCh := utils.ChanConnAccept(s.ctx, ls) diff --git a/server/server.go b/server/server.go index 5dcfcd1..c230241 100644 --- a/server/server.go +++ b/server/server.go @@ -9,11 +9,10 @@ import ( "proxy-server/server/fwd" "proxy-server/server/pkg/env" "proxy-server/server/pkg/orm" - "proxy-server/server/web" + "sync" "syscall" "time" - "github.com/joho/godotenv" "github.com/lmittmann/tint" "github.com/mattn/go-colorable" ) @@ -30,12 +29,18 @@ func Start() { env.Init() orm.Init() + // 退出信号 + osQuit := make(chan os.Signal) + signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM) + + errQuit := make(chan struct{}) + defer close(errQuit) + // 启动服务 ctx, cancel := context.WithCancel(context.Background()) defer cancel() - errQuit := make(chan error) - defer close(errQuit) - wg := utils.CountWaitGroup{} + + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -44,13 +49,10 @@ func Start() { if err != nil { slog.Error("代理服务发生错误", "err", err) } - errQuit <- err + errQuit <- struct{}{} }() // 等待退出信号 - osQuit := make(chan os.Signal) - signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM) - select { case <-osQuit: slog.Info("服务关闭") @@ -58,62 +60,23 @@ func Start() { slog.Error("服务异常退出") } - // 等待子服务退出 + // 退出服务 cancel() timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() wgCh := utils.ChanWgWait(timeout, &wg) + close(wgCh) select { case <-timeout.Done(): - slog.Error("关闭超时,强制关闭") + slog.Warn("关闭超时,强制关闭") case <-wgCh: - slog.Info("服务已退出") + slog.Debug("服务已退出") } } func initLog() { - slog.SetLogLoggerLevel(slog.LevelDebug) -} - -func startFwdServer(ctx context.Context) error { - server := fwd.New(nil) - - go func() { - <-ctx.Done() - server.Close() - }() - - server.Run(ctx) - return nil -} - -func startMntServer(ctx context.Context) { - -} - -func startWebServer(ctx context.Context) { - -} - -func Start2() { - defer func() { - err := recover() - if err != nil { - slog.Error("服务由于意外的 panic 导致退出", err) - } - }() - - ctx := context.Background() - - // 初始化环境变量 - err := godotenv.Load() - if err != nil { - slog.Debug("没有本地环境变量文件") - } - - // 配置日志 writer := colorable.NewColorable(os.Stdout) logger := slog.New(tint.NewHandler(writer, &tint.Options{ Level: slog.LevelDebug, @@ -126,32 +89,24 @@ func Start2() { }, })) slog.SetDefault(logger) - - // 初始化公共组件 - orm.Init() - - // 启动子服务 - goCount := 1 - errChan := make(chan error, goCount) - ctxC, cancel := context.WithCancel(ctx) - defer cancel() - - go web.Start(ctxC, errChan) - // go monitor.Start2(ctxC, errChan) - slog.Info("服务启动成功") - - // 监听异常 - well := true - for i := 0; i < goCount; i++ { - err := <-errChan - if err != nil { - slog.Error("服务异常退出", err) - if well { // 第一次出错时取消其他服务 - well = false - cancel() - } - } - } - close(errChan) - slog.Info("服务已全部退出") +} + +func startFwdServer(ctx context.Context) error { + server := fwd.New(nil) + + go func() { + <-ctx.Done() + server.Close() + }() + + server.Run() + return nil +} + +func startMntServer(ctx context.Context) { + +} + +func startWebServer(ctx context.Context) { + } diff --git a/server/web/app/handlers/channel.go b/server/web/app/handlers/channel.go index beb57ad..05c7062 100644 --- a/server/web/app/handlers/channel.go +++ b/server/web/app/handlers/channel.go @@ -53,7 +53,7 @@ func ChanRequest(c *gin.Context) { // 检查此 ip 是否有权限访问目标 node clientIp := strings.Split(content.RemoteAddr, ":")[0] targetNode := content.ProxyName - slog.Debug(id + " 客户端 " + clientIp + " 请求连接到 " + targetNode) + slog.Debug(id + " 用户 " + clientIp + " 请求连接到 " + targetNode) var channels []models.Channel err = orm.DB.