package socks import ( "bufio" "context" "encoding/binary" "fmt" "io" "log/slog" "net" "proxy-server/pkg/utils" "proxy-server/server/fwd/core" "slices" "github.com/pkg/errors" ) const ( Version = byte(5) AuthVersion = byte(1) ) const ( NoAuth = byte(0) UserPassAuth = byte(2) NoAcceptable = byte(0xFF) ) const ( AuthSuccess = byte(0) AuthFailure = byte(1) ) const ( ConnectCommand = byte(1) BindCommand = byte(2) AssociateCommand = byte(3) ) const ( ipv4Address = byte(1) fqdnAddress = byte(3) ipv6Address = byte(4) ) const ( successReply byte = iota serverFailure ruleFailure networkUnreachable hostUnreachable connectionRefused ttlExpired commandNotSupported addrTypeNotSupported ) // Process 处理连接 func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) { reader := bufio.NewReader(conn) // 认证 auth, err := authenticate(ctx, reader, conn) if err != nil { return nil, errors.Wrap(err, "认证失败") } // 处理连接请求 request, err := request(ctx, reader, conn) if err != nil { return nil, errors.Wrap(err, "处理连接请求失败") } // 代理连接 if request.Command != ConnectCommand { return nil, errors.New("不支持的连接指令") } // 响应成功 err = sendReply(conn, successReply, request.DestAddr) return &core.Conn{ Conn: conn, Reader: reader, Protocol: "socks5", Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(), Dest: request.DestAddr, Auth: auth, }, nil } // checkVersion 检查客户端版本 func checkVersion(reader io.Reader) error { version, err := utils.ReadByte(reader) if err != nil { return err } if version != Version { return errors.New("客户端版本不兼容") } return nil } // authenticate 执行认证流程 func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*core.AuthContext, error) { // 版本检查 err := checkVersion(reader) if err != nil { return nil, err } // 获取客户端认证方式 nAuth, err := utils.ReadByte(reader) if err != nil { return nil, err } methods, err := utils.ReadBuffer(reader, int(nAuth)) if err != nil { return nil, err } // 密码模式 if slices.Contains(methods, UserPassAuth) { _, err := conn.Write([]byte{Version, byte(UserPassAuth)}) if err != nil { return nil, errors.Wrap(err, "响应认证方式失败") } // 检查认证版本 slog.Debug("验证认证版本") v, err := utils.ReadByte(reader) if err != nil { return nil, errors.Wrap(err, "读取版本号失败") } if v != AuthVersion { _, err := conn.Write([]byte{Version, AuthFailure}) if err != nil { return nil, errors.Wrap(err, "响应认证失败") } return nil, errors.New("认证版本参数不正确") } // 读取账号 slog.Debug("验证用户账号") uLen, err := utils.ReadByte(reader) if err != nil { return nil, errors.Wrap(err, "读取用户名长度失败") } usernameBuf, err := utils.ReadBuffer(reader, int(uLen)) if err != nil { return nil, errors.Wrap(err, "读取用户名失败") } username := string(usernameBuf) // 读取密码 pLen, err := utils.ReadByte(reader) if err != nil { return nil, errors.Wrap(err, "读取密码长度失败") } passwordBuf, err := utils.ReadBuffer(reader, int(pLen)) if err != nil { return nil, errors.Wrap(err, "读取密码失败") } password := string(passwordBuf) // 检查权限 authContext, err := core.CheckPass(conn, username, password) if err != nil { return nil, errors.Wrap(err, "权限检查失败") } // 响应认证成功 _, err = conn.Write([]byte{AuthVersion, AuthSuccess}) if err != nil { return nil, errors.Wrap(err, "响应认证成功失败") } return authContext, nil } // 无认证 if slices.Contains(methods, NoAuth) { _, err = conn.Write([]byte{Version, NoAuth}) if err != nil { return nil, errors.Wrap(err, "响应认证方式失败") } authContext, err := core.CheckIp(conn) if err != nil { return nil, errors.Wrap(err, "权限检查失败") } return authContext, nil } // 无适用的认证方式 _, err = conn.Write([]byte{Version, NoAcceptable}) if err != nil { return nil, err } return nil, errors.New("没有适用的认证方式") } type Request struct { Command uint8 DestAddr *core.FwdAddr } // request 处理连接请求 func request(ctx context.Context, reader io.Reader, writer io.Writer) (*Request, error) { // 检查版本 err := checkVersion(reader) if err != nil { return nil, err } // 检查连接命令 command, err := utils.ReadByte(reader) if err != nil { return nil, err } if command != ConnectCommand && command != BindCommand && command != AssociateCommand { err = sendReply(writer, commandNotSupported, nil) if err != nil { return nil, err } return nil, errors.New("不支持该连接指令") } // 跳过保留字段 rsv _, err = utils.ReadByte(reader) if err != nil { return nil, err } // 获取目标地址 dest, err := parseTarget(reader, writer) if err != nil { return nil, err } request := &Request{ Command: command, DestAddr: dest, } return request, nil } func parseTarget(reader io.Reader, writer io.Writer) (*core.FwdAddr, error) { dest := &core.FwdAddr{} aTypeBuf, err := utils.ReadByte(reader) if err != nil { return nil, err } switch aTypeBuf { case ipv4Address: addr := make([]byte, 4) _, err := io.ReadFull(reader, addr) if err != nil { return nil, err } dest.IP = addr case ipv6Address: addr := make([]byte, 16) _, err := io.ReadFull(reader, addr) if err != nil { return nil, err } dest.IP = addr case fqdnAddress: aLenBuf := make([]byte, 1) _, err := reader.Read(aLenBuf) if err != nil { return nil, err } fqdnBuff := make([]byte, int(aLenBuf[0])) _, err = io.ReadFull(reader, fqdnBuff) if err != nil { return nil, err } dest.Domain = string(fqdnBuff) // 域名解析 addr, err := net.ResolveIPAddr("ip", dest.Domain) if err != nil { err := sendReply(writer, hostUnreachable, nil) if err != nil { return nil, fmt.Errorf("failed to send reply: %v", err) } return nil, fmt.Errorf("failed to resolve destination '%v': %v", dest.Domain, err) } dest.IP = addr.IP default: err := sendReply(writer, addrTypeNotSupported, nil) if err != nil { return nil, err } return nil, fmt.Errorf("unrecognized address type") } portBuf := make([]byte, 2) _, err = io.ReadFull(reader, portBuf) if err != nil { return nil, err } dest.Port = int(binary.BigEndian.Uint16(portBuf)) return dest, nil } func sendReply(w io.Writer, resp uint8, addr *core.FwdAddr) error { var addrType uint8 var addrBody []byte var addrPort uint16 switch { case addr == nil: addrType = ipv4Address addrBody = []byte{0, 0, 0, 0} addrPort = 0 case addr.Domain != "": addrType = fqdnAddress addrBody = append([]byte{byte(len(addr.Domain))}, addr.Domain...) addrPort = uint16(addr.Port) case addr.IP.To4() != nil: addrType = ipv4Address addrBody = addr.IP.To4() addrPort = uint16(addr.Port) case addr.IP.To16() != nil: addrType = ipv6Address addrBody = addr.IP.To16() addrPort = uint16(addr.Port) default: return fmt.Errorf("failed to format address: %v", addr) } msg := make([]byte, 6+len(addrBody)) msg[0] = Version msg[1] = resp msg[2] = 0 // Reserved msg[3] = addrType copy(msg[4:], addrBody) msg[4+len(addrBody)] = byte(addrPort >> 8) msg[4+len(addrBody)+1] = byte(addrPort & 0xff) _, err := w.Write(msg) return err } func SendSuccess(user net.Conn, target net.Conn) error { local := target.LocalAddr().(*net.TCPAddr) bind := core.FwdAddr{IP: local.IP, Port: local.Port} err := sendReply(user, successReply, &bind) if err != nil { return err } return nil }