package socks5

import (
	"context"
	"fmt"
	"io"
	"log/slog"
	"net"
	"strconv"

	"github.com/pkg/errors"
	"proxy-server/pkg/utils"
)

// SOCKS5 request commands (CMD field, RFC 1928 section 4).
const (
	// ConnectCommand is the CONNECT command.
	ConnectCommand = byte(1)
	// BindCommand is the BIND command.
	BindCommand = byte(2)
	// AssociateCommand is the UDP ASSOCIATE command.
	AssociateCommand = byte(3)
)

// Address types (ATYP field) used in requests and replies.
const (
	ipv4Address = byte(1)
	fqdnAddress = byte(3)
	ipv6Address = byte(4)
)

// Reply codes (REP field), in RFC 1928 order starting at 0.
const (
	successReply byte = iota
	serverFailure
	ruleFailure
	networkUnreachable
	hostUnreachable
	connectionRefused
	ttlExpired
	commandNotSupported
	addrTypeNotSupported
)

// unrecognizedAddrType is returned when a request carries an ATYP value other
// than IPv4, IPv6 or a domain name.
var unrecognizedAddrType = fmt.Errorf("Unrecognized address type")

// AddressRewriter is used to rewrite a destination transparently.
type AddressRewriter interface {
	Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
}

// AddrSpec is a SOCKS address: an optional domain name, an IP and a port.
type AddrSpec struct {
	FQDN string
	IP   net.IP
	Port int
}

// String renders the address for logging, including the domain name when set.
func (a AddrSpec) String() string {
	if a.FQDN != "" {
		return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
	}
	return fmt.Sprintf("%s:%d", a.IP, a.Port)
}

// Address returns a string suitable to dial; it prefers the IP-based address
// and falls back to the FQDN when no IP is set.
func (a AddrSpec) Address() string {
	port := strconv.Itoa(a.Port)
	if len(a.IP) != 0 {
		return net.JoinHostPort(a.IP.String(), port)
	}
	return net.JoinHostPort(a.FQDN, port)
}

// request reads a SOCKS5 request header (VER, CMD, RSV, then the destination
// address) from reader. A command other than CONNECT, BIND or UDP ASSOCIATE is
// answered on writer with commandNotSupported before an error is returned.
func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, error) {
	// Check the protocol version.
	if err := checkVersion(reader); err != nil {
		return nil, err
	}

	// Read the command.
	command, err := utils.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	// slog takes key/value pairs, not printf verbs; a lone value would be
	// logged as !BADKEY.
	slog.Debug("客户端使用的连接指令", "command", command)
	if command != ConnectCommand && command != BindCommand && command != AssociateCommand {
		if err := sendReply(writer, commandNotSupported, nil); err != nil {
			return nil, err
		}
		return nil, errors.New("不支持该连接指令")
	}

	// Skip the reserved RSV byte.
	if _, err := utils.ReadByte(reader); err != nil {
		return nil, err
	}

	// Read (and, for domain names, resolve) the destination address.
	dest, err := server.parseTarget(reader, writer)
	if err != nil {
		return nil, err
	}

	return &Request{
		Version:  SocksVersion,
		Command:  command,
		DestAddr: dest,
		bufConn:  reader,
	}, nil
}

// parseTarget reads ATYP, DST.ADDR and DST.PORT from reader.
//
// Domain names are resolved through server.config.Resolver. If resolution
// fails, hostUnreachable is sent on writer. An unknown ATYP is answered with
// addrTypeNotSupported and reported as unrecognizedAddrType.
func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) {
	dest := &AddrSpec{}

	// io.ReadFull guarantees the byte was actually read; a bare Read may
	// legally return 0 bytes with a nil error.
	aType := make([]byte, 1)
	if _, err := io.ReadFull(reader, aType); err != nil {
		return nil, err
	}

	switch aType[0] {
	case ipv4Address:
		addr := make([]byte, net.IPv4len)
		if _, err := io.ReadFull(reader, addr); err != nil {
			return nil, err
		}
		dest.IP = addr

	case ipv6Address:
		addr := make([]byte, net.IPv6len)
		if _, err := io.ReadFull(reader, addr); err != nil {
			return nil, err
		}
		dest.IP = addr

	case fqdnAddress:
		aLen := make([]byte, 1)
		if _, err := io.ReadFull(reader, aLen); err != nil {
			return nil, err
		}
		fqdn := make([]byte, int(aLen[0]))
		if _, err := io.ReadFull(reader, fqdn); err != nil {
			return nil, err
		}
		dest.FQDN = string(fqdn)

		// Resolve the domain name.
		addr, err := server.config.Resolver.Resolve(dest.FQDN)
		if err != nil {
			// Use a distinct name so the resolution error is not shadowed
			// (and lost) by the reply-write result.
			if sendErr := sendReply(writer, hostUnreachable, nil); sendErr != nil {
				return nil, fmt.Errorf("Failed to send reply: %w", sendErr)
			}
			return nil, fmt.Errorf("Failed to resolve destination '%v': %w", dest.FQDN, err)
		}
		dest.IP = addr

	default:
		if err := sendReply(writer, addrTypeNotSupported, nil); err != nil {
			return nil, err
		}
		return nil, unrecognizedAddrType
	}

	// DST.PORT is two bytes in network (big-endian) order.
	port := make([]byte, 2)
	if _, err := io.ReadFull(reader, port); err != nil {
		return nil, err
	}
	dest.Port = int(port[0])<<8 | int(port[1])
	return dest, nil
}

// A Request represents request received by a server
type Request struct {
	// Protocol version
	Version uint8
	// Requested command
	Command uint8
	// AuthContext provided during negotiation
	AuthContext *AuthContext
	// AddrSpec of the network that sent the request
	RemoteAddr *AddrSpec
	// AddrSpec of the desired destination
	DestAddr *AddrSpec
	// AddrSpec of the actual destination (might be affected by rewrite)
	realDestAddr *AddrSpec
	bufConn      io.Reader
}

// handle applies the optional address rewriter and dispatches req to the
// handler for its command.
func (server *Server) handle(req *Request, conn net.Conn) error {
	ctx := context.Background()

	// Rewrite the destination address when a rewriter is configured.
	req.realDestAddr = req.DestAddr
	if server.config.Rewriter != nil {
		ctx, req.realDestAddr = server.config.Rewriter.Rewrite(ctx, req)
	}

	// Dispatch on the negotiated command.
	switch req.Command {
	case ConnectCommand:
		return server.handleConnect(ctx, conn, req)
	case BindCommand:
		return server.handleBind(ctx, conn, req)
	case AssociateCommand:
		return server.handleAssociate(ctx, conn, req)
	default:
		return fmt.Errorf("unsupported command: %v", req.Command)
	}
}

// handleConnect checks req against the configured rules. When the request is
// blocked, ruleFailure is sent to the client. When it is allowed, the client
// connection and the (possibly rewritten) destination address are handed to
// server.Conn.
//
// This function neither dials the destination nor sends a success reply
// itself. SendSuccess exists for sending that reply once a target connection
// is available.
func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
	// Check the rule set.
	server.config.Logger.Printf("检查约束规则\n")
	if _, ok := server.config.Rules.Allow(ctx, req); !ok {
		if err := sendReply(conn, ruleFailure, nil); err != nil {
			return fmt.Errorf("failed to send reply: %w", err)
		}
		return fmt.Errorf("request to %v blocked by rules", req.DestAddr)
	}

	slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接")
	// NOTE(review): this send blocks until a receiver takes the value (or
	// buffer space is available); dialing and relaying presumably happen in
	// that consumer — confirm.
	server.Conn <- ProxyData{Conn: conn, Dest: req.realDestAddr.Address()}
	return nil
}

// handleBind checks req against the configured rules and then answers with
// commandNotSupported, since BIND is not implemented.
func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
	// Check if this is allowed.
	if _, ok := server.config.Rules.Allow(ctx, req); !ok {
		if err := sendReply(conn, ruleFailure, nil); err != nil {
			return fmt.Errorf("Failed to send reply: %w", err)
		}
		return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
	}

	// TODO: Support bind
	if err := sendReply(conn, commandNotSupported, nil); err != nil {
		return fmt.Errorf("Failed to send reply: %w", err)
	}
	return nil
}

// handleAssociate checks req against the configured rules and then answers
// with commandNotSupported, since UDP ASSOCIATE is not implemented.
func (server *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error {
	// Check if this is allowed.
	if _, ok := server.config.Rules.Allow(ctx, req); !ok {
		if err := sendReply(conn, ruleFailure, nil); err != nil {
			return fmt.Errorf("Failed to send reply: %w", err)
		}
		return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
	}

	// TODO: Support associate
	if err := sendReply(conn, commandNotSupported, nil); err != nil {
		return fmt.Errorf("Failed to send reply: %w", err)
	}
	return nil
}

// sendReply writes a SOCKS5 reply (VER, REP, RSV, ATYP, BND.ADDR, BND.PORT) to w.
//
// A nil addr is encoded as IPv4 0.0.0.0:0. An error is returned when addr cannot
// be encoded: no usable IP or FQDN, an FQDN longer than 255 bytes (the length
// prefix is one byte), or a port outside 0-65535.
func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
	var addrType uint8
	var addrBody []byte
	var addrPort uint16

	if addr != nil && (addr.Port < 0 || addr.Port > 0xffff) {
		return fmt.Errorf("failed to format address: port out of range: %v", addr)
	}

	switch {
	case addr == nil:
		addrType = ipv4Address
		addrBody = []byte{0, 0, 0, 0}
		addrPort = 0

	case addr.FQDN != "":
		// Longer names would be silently truncated by byte(len(...)) and
		// corrupt the reply.
		if len(addr.FQDN) > 255 {
			return fmt.Errorf("failed to format address: fqdn too long (%d bytes): %v", len(addr.FQDN), addr)
		}
		addrType = fqdnAddress
		addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
		addrPort = uint16(addr.Port)

	case addr.IP.To4() != nil:
		addrType = ipv4Address
		addrBody = addr.IP.To4()
		addrPort = uint16(addr.Port)

	case addr.IP.To16() != nil:
		addrType = ipv6Address
		addrBody = addr.IP.To16()
		addrPort = uint16(addr.Port)

	default:
		return fmt.Errorf("failed to format address: %v", addr)
	}

	msg := make([]byte, 6+len(addrBody))
	msg[0] = SocksVersion
	msg[1] = resp
	msg[2] = 0 // Reserved
	msg[3] = addrType
	copy(msg[4:], addrBody)
	msg[4+len(addrBody)] = byte(addrPort >> 8)
	msg[4+len(addrBody)+1] = byte(addrPort & 0xff)

	_, err := w.Write(msg)
	return err
}

// SendSuccess sends successReply to user. It uses target's local TCP address as
// BND.ADDR/BND.PORT, or 0.0.0.0:0 when target is not a TCP connection. Write
// errors are logged, not returned.
func SendSuccess(user net.Conn, target net.Conn) {
	var bind *AddrSpec
	// Comma-ok form: a non-TCP target would otherwise panic here.
	if local, ok := target.LocalAddr().(*net.TCPAddr); ok {
		bind = &AddrSpec{IP: local.IP, Port: local.Port}
	}
	if err := sendReply(user, successReply, bind); err != nil {
		slog.Error("Failed to send reply", "err", err)
	}
}

// ProxyData is what handleConnect sends on server.Conn for an allowed CONNECT.
type ProxyData struct {
	// Conn is the client's incoming connection.
	Conn net.Conn
	// Dest is the dialable destination ("host:port"), after any rewrite.
	Dest string
}

// Tag identifies the client connection as "<remote addr>-<local addr>".
func (d ProxyData) Tag() string {
	local := d.Conn.LocalAddr()
	remote := d.Conn.RemoteAddr()
	return fmt.Sprintf("%s-%s", remote, local)
}