diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index c288da2..4dc6242 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -15,6 +15,9 @@ func ReadByte(reader io.Reader) (byte, error) { } func ReadBuffer(reader io.Reader, size int) ([]byte, error) { + if size == 0 { + return []byte{}, nil + } buffer := make([]byte, size) _, err := io.ReadFull(reader, buffer) if err != nil { diff --git a/server/fwd/analysis.go b/server/fwd/analysis.go index 6bcc96e..1d1b686 100644 --- a/server/fwd/analysis.go +++ b/server/fwd/analysis.go @@ -2,54 +2,307 @@ package fwd import ( "bufio" + "encoding/binary" "io" "log/slog" "proxy-server/pkg/utils" + "proxy-server/server/fwd/socks" + "strings" + + "github.com/pkg/errors" ) -func analysis(reader io.Reader) { +func analysisAndLog(conn socks.ProxyConn, reader io.Reader) error { buf := bufio.NewReader(reader) - first, err := buf.Peek(8) - if err != nil { - slog.Error("analysis peek error", "err", err) - } else { - switch { - case first[0] == 0x16: - analysisHttps(reader) - case - string(first[:4]) == "GET ", - // string(first[:4]) == "PUT ", - string(first[:5]) == "POST ": - // string(first[:4]) == "HEAD ", - // string(first[:4]) == "TRACE ", - // string(first[:4]) == "PATCH ", - // string(first[:4]) == "DELETE ", - // string(first[:4]) == "CONNECT ", - // string(first[:4]) == "OPTIONS ": - analysisHttp(reader) + domain, proto, err := sniffing(buf) + if err != nil { + err = errors.Wrap(err, "analysis sniffing error") + } else { + slog.Info( + "用户访问记录", + slog.Uint64("uid", uint64(conn.Uid)), + slog.String("user", conn.Conn.RemoteAddr().String()), + slog.String("proxy", "socks"), + slog.String("node", conn.Conn.LocalAddr().String()), + slog.String("proto", proto), + slog.String("dest", conn.Dest), + slog.String("domain", domain), + ) + } + go func() { + discord(buf) + }() + return err +} + +func sniffing(reader *bufio.Reader) (string, string, error) { + peek, err := reader.Peek(8) + if err != nil { + return "", "", errors.Wrap(err, "sniffing peek error") + } + + method, ok := isHttp(peek) + if ok { + domain, err := analysisHttp(reader) + return domain, "http(" + method + ")", err + } + + tlsType, tlsVersion, ok := isTls(peek) + if ok { + var domain string + if tlsType == "handshake" { + domain, err = analysisTls(reader) + } + return domain, "tls(" + tlsType + "," + tlsVersion + ")", err + } + + return "nil", "tcp", nil +} + +func isHttp(bytes []byte) (string, bool) { + + var blankIndex int + for i := range bytes { + if bytes[i] == ' ' { + blankIndex = i + break } } - discord(reader) + method := string(bytes[:blankIndex]) + + switch method { + case "GET", + "POST", + "PUT", + "PATCH", + "DELETE", + "HEAD", + "OPTIONS", + "TRACE", + "CONNECT": + return method, true + } + + return "", false } -func analysisHttp(reader io.Reader) { +func isTls(bytes []byte) (string, string, bool) { + + var tlsType string + switch bytes[0] { + case 0x14: + tlsType = "change-cipher-spec" + case 0x15: + tlsType = "alert" + case 0x16: + tlsType = "handshake" + case 0x17: + tlsType = "application-data" + } + + var tlsVersion string + if bytes[1] == 0x03 { + switch bytes[2] { + case 0x00: + tlsVersion = "SSL3.0" + case 0x01: + tlsVersion = "TLS1.0" + case 0x02: + tlsVersion = "TLS1.1" + case 0x03: + tlsVersion = "TLS1.2" + } + } + if tlsType != "" && tlsVersion != "" { + return tlsType, tlsVersion, true + } else { + return "", "", false + } } -func analysisHttps(reader io.Reader) { - - head, err := utils.ReadBuffer(reader, 5) +func analysisHttp(reader *bufio.Reader) (string, error) { + slog.Debug("analysis http") + // reade top + top, err := httpReadLine(reader) if err != nil { - slog.Error("analysis https err", "err", err) - return + return "", errors.Wrap(err, "analysis http read top error") } - if head[1] == 0x03 && head[2] == 0x03 { - // tls1.2 + // read header + host := strings.Split(top, " ")[1] + for { + line, err := httpReadLine(reader) + if err != nil { + return "", err + } + if line == "" { + break + } + if strings.HasPrefix(line, "Host: ") { + host = strings.TrimPrefix(line, "Host: ") + } } + + return host, nil } -func discord(reader io.Reader) { +func httpReadLine(reader *bufio.Reader) (line string, err error) { + + var lineStr strings.Builder + for { + line, prefix, err := reader.ReadLine() + if err != nil { + return "", errors.Wrap(err, "analysis http read line error") + } + lineStr.Write(line) + if !prefix { + break + } + } + return lineStr.String(), nil +} + +func analysisTls(reader *bufio.Reader) (string, error) { + slog.Debug("analysis https") + + // tls record + _, err := utils.ReadBuffer(reader, 5) + if err != nil { + return "", errors.Wrap(err, "analysis https read head error") + } + + // tls type + hsType, err := reader.ReadByte() + if err != nil { + return "", errors.Wrap(err, "analysis https read hsType error") + } + + switch hsType { + case 0x01: // client hello + + // length + _, err = utils.ReadBuffer(reader, 3) + if err != nil { + return "", errors.Wrap(err, "analysis https read tls length error") + } + + // version + _, err = utils.ReadBuffer(reader, 2) + if err != nil { + return "", errors.Wrap(err, "analysis https read version error") + } + + // random + _, err = utils.ReadBuffer(reader, 32) + if err != nil { + return "", errors.Wrap(err, "analysis https read random error") + } + + // session id length + sessionIdLen, err := reader.ReadByte() + if err != nil { + return "", errors.Wrap(err, "analysis https read sessionIdLen error") + } + // session id + _, err = utils.ReadBuffer(reader, int(sessionIdLen)) + if err != nil { + return "", errors.Wrap(err, "analysis https read sessionId error") + } + + // cipher suites length + cLenBuf, err := utils.ReadBuffer(reader, 2) + if err != nil { + return "", errors.Wrap(err, "analysis https read cLen error") + } + cLen := binary.BigEndian.Uint16(cLenBuf) + // cipher suites + _, err = utils.ReadBuffer(reader, int(cLen)) + if err != nil { + return "", errors.Wrap(err, "analysis https read c error") + } + + // compression methods length + cmLen, err := reader.ReadByte() + if err != nil { + return "", errors.Wrap(err, "analysis https read cmLen error") + } + // compression methods + _, err = utils.ReadBuffer(reader, int(cmLen)) + if err != nil { + return "", errors.Wrap(err, "analysis https read cm error") + } + + // extensions length + eLenBuf, err := utils.ReadBuffer(reader, 2) + if err != nil { + return "", errors.Wrap(err, "analysis https read eLen error") + } + eLen := binary.BigEndian.Uint16(eLenBuf) + + // extensions + host := "" + + for i := 0; i < int(eLen); { + + // extension type + eTypeBuf, err := utils.ReadBuffer(reader, 2) + if err != nil { + return "", errors.Wrap(err, "analysis https read extension type error") + } + eType := binary.BigEndian.Uint16(eTypeBuf) + + // extension length + eLenBuf, err := utils.ReadBuffer(reader, 2) + if err != nil { + return "", errors.Wrap(err, "analysis https read extension length error") + } + eLen := binary.BigEndian.Uint16(eLenBuf) + + // server name + if eType == 0x00 { + // server name list length + _, err = utils.ReadBuffer(reader, 2) + if err != nil { + return "", errors.Wrap(err, "analysis https read server name list length error") + } + // server name type + _, err = reader.ReadByte() + if err != nil { + return "", errors.Wrap(err, "analysis https read server name type error") + } + // server name length + sLenBuf, err := utils.ReadBuffer(reader, 2) + if err != nil { + return "", errors.Wrap(err, "analysis https read server name length error") + } + sLen := binary.BigEndian.Uint16(sLenBuf) + // server name + bytes, err := utils.ReadBuffer(reader, int(sLen)) + if err != nil { + return "", errors.Wrap(err, "analysis https read server name error") + } + + host = string(bytes) + return host, nil + + } else { + // other extension + _, err = utils.ReadBuffer(reader, int(eLen)) + if err != nil { + return "", errors.Wrap(err, "analysis https read extension error") + } + } + i += 4 + int(eLen) + } + default: + return "", nil + } + + return "", errors.New("analysis https error") +} + +func discord(reader *bufio.Reader) { _, err := io.Copy(io.Discard, reader) if err != nil { slog.Error("analysis discord err", "err", err) diff --git a/server/fwd/fwd.go b/server/fwd/fwd.go index 2b46fc9..c83f248 100644 --- a/server/fwd/fwd.go +++ b/server/fwd/fwd.go @@ -24,7 +24,7 @@ type Service struct { Config *Config ctx context.Context cancel context.CancelFunc - userConnMap map[string]socks.ProxyConn + userConnMap sync.Map ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup @@ -41,7 +41,7 @@ func New(config *Config) *Service { Config: config, ctx: ctx, cancel: cancel, - userConnMap: make(map[string]socks.ProxyConn), + userConnMap: sync.Map{}, ctrlConnWg: utils.CountWaitGroup{}, dataConnWg: utils.CountWaitGroup{}, fwdLesWg: utils.CountWaitGroup{}, @@ -96,10 +96,13 @@ func (s *Service) Run() { } // 清理资源 - for _, conn := range s.userConnMap { + s.userConnMap.Range(func(key, value any) bool { + conn := value.(socks.ProxyConn) utils.Close(conn) - } - clear(s.userConnMap) + s.userConnMap.Delete(key) + return true + }) + s.userConnMap.Clear() s.ctrlConnWg.Wait() slog.Debug("控制通道连接已关闭") @@ -143,7 +146,7 @@ func (s *Service) startCtrlTun() error { defer utils.Close(conn) err := s.processCtrlConn(conn) if err != nil { - slog.Error("处理控制通道连接失败", err) + slog.Error("处理控制通道连接失败", "err", err) } }() } @@ -213,7 +216,7 @@ func (s *Service) processCtrlConn(controller net.Conn) error { slog.Error("向客户端发送 tag 失败", "err", err) return } - s.userConnMap[tag] = user + s.userConnMap.Store(tag, user) }() } } @@ -254,7 +257,7 @@ func (s *Service) startDataTun() error { defer utils.Close(conn) err := s.processDataConn(conn) if err != nil { - slog.Error("处理数据通道失败", err) + slog.Error("处理数据通道失败", "err", err) } }() } @@ -277,17 +280,17 @@ func (s *Service) processDataConn(client net.Conn) error { // 找到用户连接 var data socks.ProxyConn - var ok bool select { case <-s.ctx.Done(): return nil default: - data, ok = s.userConnMap[tag] + dataAny, ok := s.userConnMap.Load(tag) if !ok { return errors.New("查找用户连接失败") } + data = dataAny.(socks.ProxyConn) defer func() { - delete(s.userConnMap, tag) + s.userConnMap.Delete(tag) utils.Close(data) }() } @@ -314,15 +317,21 @@ func (s *Service) processDataConn(client net.Conn) error { // 数据转发 slog.Info("开始数据转发 " + client.RemoteAddr().String() + " <-> " + data.Dest) - // userPipeReader, userPipeWriter := io.Pipe() - // defer utils.Close(userPipeWriter) - // teeUser := io.TeeReader(user, userPipeWriter) + userPipeReader, userPipeWriter := io.Pipe() + defer utils.Close(userPipeWriter) + teeUser := io.TeeReader(user, userPipeWriter) + go func() { + err := analysisAndLog(data, userPipeReader) + if err != nil { + slog.Error("数据解析失败", "err", err) + } + }() wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() - _, err := io.Copy(client, user) + _, err := io.Copy(client, teeUser) if err != nil { slog.Error("数据转发失败 user->client", "err", err) } @@ -332,7 +341,11 @@ func (s *Service) processDataConn(client net.Conn) error { defer wg.Done() _, err := io.Copy(user, client) if err != nil { - slog.Error("数据转发失败 client->user", "err", err) + if errors.Is(err, net.ErrClosed) { + + } else { + // slog.Error("数据转发失败 client->user", "err", err, "errType", reflect.TypeOf(err)) + } } }() wg.Wait() diff --git a/server/fwd/http/http.go b/server/fwd/http/http.go new file mode 100644 index 0000000..fc732a9 --- /dev/null +++ b/server/fwd/http/http.go @@ -0,0 +1,4 @@ +package http + +func Start() { +} diff --git a/server/fwd/socks/request.go b/server/fwd/socks/request.go index 691f2c2..e639004 100644 --- a/server/fwd/socks/request.go +++ b/server/fwd/socks/request.go @@ -49,27 +49,6 @@ type AddrSpec struct { Port int } -func (a AddrSpec) Domain() []string { - if a.FQDN != "" { - return []string{a.FQDN} - } - - var domain []string - - ch := make(chan struct{}) - defer close(ch) - go func() { - addr, err := net.LookupAddr(a.IP.String()) - if err == nil { - domain = addr - } - ch <- struct{}{} - }() - <-ch - - return domain -} - func (a AddrSpec) String() string { if a.FQDN != "" { return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) @@ -258,7 +237,11 @@ func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) if conn != nil { utils.Close(conn) } - case s.Conn <- ProxyConn{conn, req.realDestAddr.Address()}: + case s.Conn <- ProxyConn{ + req.Authentication.Payload.ID, + conn, + req.realDestAddr.Address(), + }: } return nil } @@ -352,6 +335,7 @@ func SendSuccess(user net.Conn, target net.Conn) error { } type ProxyConn struct { + Uid uint // 用户连入的连接 Conn net.Conn // 用户目标地址 diff --git a/server/fwd/socks/socks.go b/server/fwd/socks/socks.go index cb83898..e9fdde6 100644 --- a/server/fwd/socks/socks.go +++ b/server/fwd/socks/socks.go @@ -11,7 +11,6 @@ import ( "os" "proxy-server/pkg/utils" "strconv" - "strings" "time" "github.com/pkg/errors" @@ -191,18 +190,6 @@ func (s *Server) process(conn net.Conn) error { slog.Debug("连接请求处理完成") } - // 记录日志 - go func() { - slog.Info( - "用户访问记录", - slog.Uint64("uid", uint64(authContext.Payload.ID)), - slog.String("user", conn.RemoteAddr().String()), - slog.Any("node", conn.LocalAddr().String()), - slog.String("dest", request.DestAddr.Address()), - slog.String("domain", strings.Join(request.DestAddr.Domain(), ",")), - ) - }() - request.Authentication = authContext user, ok := conn.RemoteAddr().(*net.TCPAddr) if !ok { diff --git a/server/mnt/mnt.go b/server/mnt/mnt.go deleted file mode 100644 index d671bce..0000000 --- a/server/mnt/mnt.go +++ /dev/null @@ -1,65 +0,0 @@ -package mnt - -import ( - "context" - "log/slog" - - "github.com/google/gopacket" - "github.com/google/gopacket/pcap" - "github.com/pkg/errors" - "golang.org/x/text/encoding/simplifiedchinese" -) - -func Start(ctx context.Context, errCh chan error) { - - // 打开一个网络接口 - device, err := pcap.OpenLive("WLAN", 1600, true, pcap.BlockForever) - if err != nil { - gbk := simplifiedchinese.GBK.NewDecoder() - errMsg, err := gbk.String(err.Error()) - if err != nil { - errMsg = err.Error() - } - errCh <- errors.Wrap(err, "打开网络接口失败, "+errMsg) - return - } - defer device.Close() - - err = device.SetBPFFilter("tcp") - if err != nil { - errCh <- errors.Wrap(err, "设置 BPF 过滤器失败") - return - } - - err = device.SetDirection(pcap.DirectionIn) - if err != nil { - errCh <- errors.Wrap(err, "设置捕获方向失败") - return - } - - // 创建一个数据包源 - source := gopacket.NewPacketSource(device, device.LinkType()) - source.NoCopy = true - source.Lazy = true - - for { - select { - case <-ctx.Done(): - slog.Debug("monitor 被动结束") - errCh <- nil - return - - case packet := <-source.Packets(): - handle(packet) - } - } -} - -func handle(packet gopacket.Packet) { - slog.Debug("Packet: ", packet) - slog.Debug("Layers: ", packet.Layers()) - slog.Debug("Application: ", packet.ApplicationLayer()) - slog.Debug("Transport: ", packet.TransportLayer()) - slog.Debug("Network: ", packet.NetworkLayer()) - slog.Debug("Link: ", packet.LinkLayer()) -} diff --git a/server/server.go b/server/server.go index a54f202..00a5662 100644 --- a/server/server.go +++ b/server/server.go @@ -96,20 +96,14 @@ func initLog() { func startFwdServer(ctx context.Context) error { server := fwd.New(nil) - go func() { <-ctx.Done() server.Close() }() - server.Run() return nil } -func startMntServer(ctx context.Context) { - -} - func startWebServer(ctx context.Context) { }