重构认证相关结构,更新认证流程,添加日志功能
This commit is contained in:
@@ -23,11 +23,11 @@ const (
|
||||
|
||||
type Authenticator interface {
|
||||
Method() AuthMethod
|
||||
Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*AuthContext, error)
|
||||
Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*Authentication, error)
|
||||
}
|
||||
|
||||
// authenticate 执行认证流程
|
||||
func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
||||
func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*Authentication, error) {
|
||||
|
||||
// 版本检查
|
||||
err := checkVersion(reader)
|
||||
@@ -75,8 +75,13 @@ func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext,
|
||||
return nil, errors.New("没有适用的认证方式")
|
||||
}
|
||||
|
||||
type AuthContext struct {
|
||||
type Authentication struct {
|
||||
Method AuthMethod
|
||||
Timeout uint
|
||||
Payload map[string]any
|
||||
Payload Payload
|
||||
Data map[string]any
|
||||
}
|
||||
|
||||
type Payload struct {
|
||||
ID uint
|
||||
}
|
||||
|
||||
@@ -49,6 +49,27 @@ type AddrSpec struct {
|
||||
Port int
|
||||
}
|
||||
|
||||
func (a AddrSpec) Domain() []string {
|
||||
if a.FQDN != "" {
|
||||
return []string{a.FQDN}
|
||||
}
|
||||
|
||||
var domain []string
|
||||
|
||||
ch := make(chan struct{})
|
||||
defer close(ch)
|
||||
go func() {
|
||||
addr, err := net.LookupAddr(a.IP.String())
|
||||
if err == nil {
|
||||
domain = addr
|
||||
}
|
||||
ch <- struct{}{}
|
||||
}()
|
||||
<-ch
|
||||
|
||||
return domain
|
||||
}
|
||||
|
||||
func (a AddrSpec) String() string {
|
||||
if a.FQDN != "" {
|
||||
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
|
||||
@@ -186,8 +207,8 @@ type Request struct {
|
||||
Version uint8
|
||||
// Requested command
|
||||
Command uint8
|
||||
// AuthContext provided during negotiation
|
||||
AuthContext *AuthContext
|
||||
// Authentication provided during negotiation
|
||||
Authentication *Authentication
|
||||
// AddrSpec of the network that sent the request
|
||||
RemoteAddr *AddrSpec
|
||||
// AddrSpec of the desired destination
|
||||
@@ -220,7 +241,6 @@ func (s *Server) handle(req *Request, conn net.Conn) error {
|
||||
}
|
||||
|
||||
func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
|
||||
|
||||
// 检查规则集约束
|
||||
s.config.Logger.Printf("检查约束规则\n")
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||
@@ -233,75 +253,8 @@ func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request)
|
||||
}
|
||||
|
||||
slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接")
|
||||
s.Conn <- ProxyData{conn, req.realDestAddr.Address()}
|
||||
s.Conn <- ProxyConn{conn, req.realDestAddr.Address()}
|
||||
return nil
|
||||
|
||||
// 与目标服务器建立连接
|
||||
// s.config.Logger.Printf("与目标服务器建立连接\n")
|
||||
// dial := s.config.Dial
|
||||
// target, err := dial("tcp", req.realDestAddr.Address())
|
||||
// if err != nil {
|
||||
// msg := err.Error()
|
||||
// resp := hostUnreachable
|
||||
// if strings.Contains(msg, "refused") {
|
||||
// resp = connectionRefused
|
||||
// } else if strings.Contains(msg, "network is unreachable") {
|
||||
// resp = networkUnreachable
|
||||
// }
|
||||
//
|
||||
// err := sendReply(Conn, resp, nil)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("failed to send reply: %v", err)
|
||||
// }
|
||||
// return fmt.Errorf("request to %v failed: %v", req.DestAddr, err)
|
||||
// }
|
||||
// defer closeConnection(target)
|
||||
//
|
||||
// // 正常响应
|
||||
// slog.Info("连接成功,开始代理流量")
|
||||
//
|
||||
// local := target.LocalAddr().(*net.TCPAddr)
|
||||
// bind := AddrSpec{IP: local.IP, Port: local.Port}
|
||||
// err = sendReply(Conn, successReply, &bind)
|
||||
// if err != nil {
|
||||
// return fmt.Errorf("Failed to send reply: %v", err)
|
||||
// }
|
||||
//
|
||||
// // 配置超时时间和行为
|
||||
// timeout := req.AuthContext.Timeout
|
||||
// slog.Debug("超时时间", "timeout", timeout)
|
||||
//
|
||||
// timeoutCtx, cancel := ctx.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// // 代理流量
|
||||
// errChan := make(chan error, 2)
|
||||
// go func() {
|
||||
// _, err = io.Copy(target, req.bufConn)
|
||||
// errChan <- err
|
||||
// }()
|
||||
// go func() {
|
||||
// _, err = io.Copy(Conn, target)
|
||||
// errChan <- err
|
||||
// }()
|
||||
//
|
||||
// for {
|
||||
// select {
|
||||
//
|
||||
// case <-timeoutCtx.Done():
|
||||
// slog.Debug("超时断开连接")
|
||||
// // todo 根据 termination 执行不同的断开行为
|
||||
// return nil
|
||||
//
|
||||
// case err := <-errChan:
|
||||
// slog.Debug("主动断开连接")
|
||||
// if err != nil {
|
||||
// return errors.Wrap(err, "代理流量出现错误")
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
// }
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
|
||||
@@ -391,15 +344,19 @@ func SendSuccess(user net.Conn, target net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
type ProxyData struct {
|
||||
type ProxyConn struct {
|
||||
// 用户连入的连接
|
||||
Conn net.Conn
|
||||
// 用户目标地址
|
||||
Dest string
|
||||
}
|
||||
|
||||
func (d ProxyData) Tag() string {
|
||||
func (d ProxyConn) Tag() string {
|
||||
local := d.Conn.LocalAddr()
|
||||
remote := d.Conn.RemoteAddr()
|
||||
return fmt.Sprintf("%s-%s", remote, local)
|
||||
}
|
||||
|
||||
func (d ProxyConn) Close() error {
|
||||
return d.Conn.Close()
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"os"
|
||||
"proxy-server/pkg/utils"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@@ -55,7 +56,7 @@ type Server struct {
|
||||
wg utils.CountWaitGroup
|
||||
Name string
|
||||
Port uint16
|
||||
Conn chan ProxyData
|
||||
Conn chan ProxyConn
|
||||
}
|
||||
|
||||
// New 创建服务器
|
||||
@@ -90,7 +91,7 @@ func New(conf *Config) (*Server, error) {
|
||||
wg: utils.CountWaitGroup{},
|
||||
Name: conf.Name,
|
||||
Port: conf.Port,
|
||||
Conn: make(chan ProxyData, 100),
|
||||
Conn: make(chan ProxyConn, 100),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -128,7 +129,8 @@ func (s *Server) Run() error {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
err := s.serve(conn)
|
||||
// 连接要传出,不能在这里关闭连接
|
||||
err := s.process(conn)
|
||||
if err != nil {
|
||||
slog.Error("处理连接失败", err)
|
||||
}
|
||||
@@ -163,8 +165,8 @@ func (s *Server) Close() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// serve 建立连接
|
||||
func (s *Server) serve(conn net.Conn) error {
|
||||
// process 建立连接
|
||||
func (s *Server) process(conn net.Conn) error {
|
||||
slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接")
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
@@ -190,15 +192,27 @@ func (s *Server) serve(conn net.Conn) error {
|
||||
slog.Debug("连接请求处理完成")
|
||||
}
|
||||
|
||||
request.AuthContext = authContext
|
||||
client, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
// 记录日志
|
||||
go func() {
|
||||
slog.Info(
|
||||
"用户访问记录",
|
||||
slog.Uint64("uid", uint64(authContext.Payload.ID)),
|
||||
slog.String("user", conn.RemoteAddr().String()),
|
||||
slog.Any("node", conn.LocalAddr().String()),
|
||||
slog.String("dest", request.DestAddr.Address()),
|
||||
slog.String("domain", strings.Join(request.DestAddr.Domain(), ",")),
|
||||
)
|
||||
}()
|
||||
|
||||
request.Authentication = authContext
|
||||
user, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
return fmt.Errorf("获取客户端地址失败")
|
||||
return fmt.Errorf("获取用户地址失败")
|
||||
}
|
||||
|
||||
request.RemoteAddr = &AddrSpec{
|
||||
IP: client.IP,
|
||||
Port: client.Port,
|
||||
IP: user.IP,
|
||||
Port: user.Port,
|
||||
}
|
||||
|
||||
// 处理请求
|
||||
Reference in New Issue
Block a user