重构认证相关结构,更新认证流程,添加日志功能
This commit is contained in:
242
server/fwd/socks/socks.go
Normal file
242
server/fwd/socks/socks.go
Normal file
@@ -0,0 +1,242 @@
|
||||
package socks
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"proxy-server/pkg/utils"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
Version = byte(5)
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Name string
|
||||
|
||||
Host string
|
||||
Port uint16
|
||||
|
||||
// 认证方法
|
||||
AuthMethods []Authenticator
|
||||
|
||||
// 域名解析
|
||||
Resolver NameResolver
|
||||
|
||||
// 自定义认证规则
|
||||
Rules RuleSet
|
||||
|
||||
// 地址重写
|
||||
Rewriter AddressRewriter
|
||||
|
||||
// 用于 bind 和 associate
|
||||
BindIP net.IP
|
||||
|
||||
// Logger
|
||||
Logger *log.Logger
|
||||
|
||||
// 自定义连接流程
|
||||
Dial func(network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
config *Config
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg utils.CountWaitGroup
|
||||
Name string
|
||||
Port uint16
|
||||
Conn chan ProxyConn
|
||||
}
|
||||
|
||||
// New 创建服务器
|
||||
func New(conf *Config) (*Server, error) {
|
||||
if len(conf.AuthMethods) == 0 {
|
||||
return nil, ConfigError("认证方法不能为空")
|
||||
}
|
||||
|
||||
if conf.Resolver == nil {
|
||||
conf.Resolver = DNSResolver{}
|
||||
}
|
||||
|
||||
if conf.Rules == nil {
|
||||
conf.Rules = PermitAll()
|
||||
}
|
||||
|
||||
if conf.Logger == nil {
|
||||
conf.Logger = log.New(os.Stdout, "", log.LstdFlags)
|
||||
}
|
||||
|
||||
if conf.Dial == nil {
|
||||
conf.Dial = func(network, addr string) (net.Conn, error) {
|
||||
return net.Dial(network, addr)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Server{
|
||||
config: conf,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
wg: utils.CountWaitGroup{},
|
||||
Name: conf.Name,
|
||||
Port: conf.Port,
|
||||
Conn: make(chan ProxyConn, 100),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run 监听端口
|
||||
func (s *Server) Run() error {
|
||||
slog.Info("启动 socks5 代理服务")
|
||||
|
||||
// 监听端口
|
||||
host := s.config.Host
|
||||
port := s.config.Port
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
ls, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "监听端口失败")
|
||||
}
|
||||
defer utils.Close(ls)
|
||||
slog.Info("正在监听端口", slog.Uint64("port", uint64(port)))
|
||||
|
||||
// 处理连接
|
||||
connCh := utils.ChanConnAccept(s.ctx, ls)
|
||||
defer close(connCh)
|
||||
|
||||
err = nil
|
||||
for loop := true; loop; {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
slog.Debug("服务主动停止")
|
||||
loop = false
|
||||
case conn, ok := <-connCh:
|
||||
if !ok {
|
||||
err = errors.New("意外错误,无法获取连接")
|
||||
loop = false
|
||||
break
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
// 连接要传出,不能在这里关闭连接
|
||||
err := s.process(conn)
|
||||
if err != nil {
|
||||
slog.Error("处理连接失败", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
s.Close()
|
||||
}
|
||||
|
||||
// 关闭服务
|
||||
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
wgCh := utils.ChanWgWait(timeout, &s.wg)
|
||||
|
||||
err = nil
|
||||
select {
|
||||
case <-timeout.Done():
|
||||
err = errors.New("关闭超时(强制关闭)")
|
||||
case <-wgCh:
|
||||
}
|
||||
|
||||
if s.Conn != nil {
|
||||
close(s.Conn)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Close 关闭服务
|
||||
func (s *Server) Close() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// process 建立连接
|
||||
func (s *Server) process(conn net.Conn) error {
|
||||
slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接")
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
|
||||
// 认证
|
||||
slog.Debug("开始认证流程")
|
||||
authContext, err := s.authenticate(reader, conn)
|
||||
if err != nil {
|
||||
utils.Close(conn)
|
||||
slog.Error("认证失败", err)
|
||||
return err
|
||||
} else {
|
||||
slog.Debug("认证完成")
|
||||
}
|
||||
|
||||
// 处理连接请求
|
||||
slog.Debug("处理连接请求")
|
||||
request, err := s.request(reader, conn)
|
||||
if err != nil {
|
||||
slog.Error("连接请求处理失败", err)
|
||||
return err
|
||||
} else {
|
||||
slog.Debug("连接请求处理完成")
|
||||
}
|
||||
|
||||
// 记录日志
|
||||
go func() {
|
||||
slog.Info(
|
||||
"用户访问记录",
|
||||
slog.Uint64("uid", uint64(authContext.Payload.ID)),
|
||||
slog.String("user", conn.RemoteAddr().String()),
|
||||
slog.Any("node", conn.LocalAddr().String()),
|
||||
slog.String("dest", request.DestAddr.Address()),
|
||||
slog.String("domain", strings.Join(request.DestAddr.Domain(), ",")),
|
||||
)
|
||||
}()
|
||||
|
||||
request.Authentication = authContext
|
||||
user, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
return fmt.Errorf("获取用户地址失败")
|
||||
}
|
||||
|
||||
request.RemoteAddr = &AddrSpec{
|
||||
IP: user.IP,
|
||||
Port: user.Port,
|
||||
}
|
||||
|
||||
// 处理请求
|
||||
slog.Debug("开始代理流量")
|
||||
err = s.handle(request, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkVersion 检查客户端版本
|
||||
func checkVersion(reader io.Reader) error {
|
||||
version, err := utils.ReadByte(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Debug("客户端请求版本", "version", version)
|
||||
|
||||
if version != Version {
|
||||
return errors.New("客户端版本不兼容")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user