优化数据分析和日志记录,重构连接管理,添加对空缓冲区的处理

This commit is contained in:
2025-02-28 17:50:48 +08:00
parent 06bcaf8bc7
commit b8a3dd93dc
8 changed files with 324 additions and 151 deletions

View File

@@ -15,6 +15,9 @@ func ReadByte(reader io.Reader) (byte, error) {
} }
func ReadBuffer(reader io.Reader, size int) ([]byte, error) { func ReadBuffer(reader io.Reader, size int) ([]byte, error) {
if size == 0 {
return []byte{}, nil
}
buffer := make([]byte, size) buffer := make([]byte, size)
_, err := io.ReadFull(reader, buffer) _, err := io.ReadFull(reader, buffer)
if err != nil { if err != nil {

View File

@@ -2,54 +2,307 @@ package fwd
import ( import (
"bufio" "bufio"
"encoding/binary"
"io" "io"
"log/slog" "log/slog"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/fwd/socks"
"strings"
"github.com/pkg/errors"
) )
func analysis(reader io.Reader) { func analysisAndLog(conn socks.ProxyConn, reader io.Reader) error {
buf := bufio.NewReader(reader) buf := bufio.NewReader(reader)
first, err := buf.Peek(8)
if err != nil {
slog.Error("analysis peek error", "err", err)
} else {
switch { domain, proto, err := sniffing(buf)
case first[0] == 0x16: if err != nil {
analysisHttps(reader) err = errors.Wrap(err, "analysis sniffing error")
case } else {
string(first[:4]) == "GET ", slog.Info(
// string(first[:4]) == "PUT ", "用户访问记录",
string(first[:5]) == "POST ": slog.Uint64("uid", uint64(conn.Uid)),
// string(first[:4]) == "HEAD ", slog.String("user", conn.Conn.RemoteAddr().String()),
// string(first[:4]) == "TRACE ", slog.String("proxy", "socks"),
// string(first[:4]) == "PATCH ", slog.String("node", conn.Conn.LocalAddr().String()),
// string(first[:4]) == "DELETE ", slog.String("proto", proto),
// string(first[:4]) == "CONNECT ", slog.String("dest", conn.Dest),
// string(first[:4]) == "OPTIONS ": slog.String("domain", domain),
analysisHttp(reader) )
}
go func() {
discord(buf)
}()
return err
}
func sniffing(reader *bufio.Reader) (string, string, error) {
peek, err := reader.Peek(8)
if err != nil {
return "", "", errors.Wrap(err, "sniffing peek error")
}
method, ok := isHttp(peek)
if ok {
domain, err := analysisHttp(reader)
return domain, "http(" + method + ")", err
}
tlsType, tlsVersion, ok := isTls(peek)
if ok {
var domain string
if tlsType == "handshake" {
domain, err = analysisTls(reader)
}
return domain, "tls(" + tlsType + "," + tlsVersion + ")", err
}
return "nil", "tcp", nil
}
func isHttp(bytes []byte) (string, bool) {
var blankIndex int
for i := range bytes {
if bytes[i] == ' ' {
blankIndex = i
break
} }
} }
discord(reader) method := string(bytes[:blankIndex])
switch method {
case "GET",
"POST",
"PUT",
"PATCH",
"DELETE",
"HEAD",
"OPTIONS",
"TRACE",
"CONNECT":
return method, true
}
return "", false
} }
func analysisHttp(reader io.Reader) { func isTls(bytes []byte) (string, string, bool) {
var tlsType string
switch bytes[0] {
case 0x14:
tlsType = "change-cipher-spec"
case 0x15:
tlsType = "alert"
case 0x16:
tlsType = "handshake"
case 0x17:
tlsType = "application-data"
}
var tlsVersion string
if bytes[1] == 0x03 {
switch bytes[2] {
case 0x00:
tlsVersion = "SSL3.0"
case 0x01:
tlsVersion = "TLS1.0"
case 0x02:
tlsVersion = "TLS1.1"
case 0x03:
tlsVersion = "TLS1.2"
}
}
if tlsType != "" && tlsVersion != "" {
return tlsType, tlsVersion, true
} else {
return "", "", false
}
} }
func analysisHttps(reader io.Reader) { func analysisHttp(reader *bufio.Reader) (string, error) {
slog.Debug("analysis http")
head, err := utils.ReadBuffer(reader, 5) // reade top
top, err := httpReadLine(reader)
if err != nil { if err != nil {
slog.Error("analysis https err", "err", err) return "", errors.Wrap(err, "analysis http read top error")
return
} }
if head[1] == 0x03 && head[2] == 0x03 { // read header
// tls1.2 host := strings.Split(top, " ")[1]
for {
line, err := httpReadLine(reader)
if err != nil {
return "", err
}
if line == "" {
break
}
if strings.HasPrefix(line, "Host: ") {
host = strings.TrimPrefix(line, "Host: ")
}
} }
return host, nil
} }
func discord(reader io.Reader) { func httpReadLine(reader *bufio.Reader) (line string, err error) {
var lineStr strings.Builder
for {
line, prefix, err := reader.ReadLine()
if err != nil {
return "", errors.Wrap(err, "analysis http read line error")
}
lineStr.Write(line)
if !prefix {
break
}
}
return lineStr.String(), nil
}
func analysisTls(reader *bufio.Reader) (string, error) {
slog.Debug("analysis https")
// tls record
_, err := utils.ReadBuffer(reader, 5)
if err != nil {
return "", errors.Wrap(err, "analysis https read head error")
}
// tls type
hsType, err := reader.ReadByte()
if err != nil {
return "", errors.Wrap(err, "analysis https read hsType error")
}
switch hsType {
case 0x01: // client hello
// length
_, err = utils.ReadBuffer(reader, 3)
if err != nil {
return "", errors.Wrap(err, "analysis https read tls length error")
}
// version
_, err = utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read version error")
}
// random
_, err = utils.ReadBuffer(reader, 32)
if err != nil {
return "", errors.Wrap(err, "analysis https read random error")
}
// session id length
sessionIdLen, err := reader.ReadByte()
if err != nil {
return "", errors.Wrap(err, "analysis https read sessionIdLen error")
}
// session id
_, err = utils.ReadBuffer(reader, int(sessionIdLen))
if err != nil {
return "", errors.Wrap(err, "analysis https read sessionId error")
}
// cipher suites length
cLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read cLen error")
}
cLen := binary.BigEndian.Uint16(cLenBuf)
// cipher suites
_, err = utils.ReadBuffer(reader, int(cLen))
if err != nil {
return "", errors.Wrap(err, "analysis https read c error")
}
// compression methods length
cmLen, err := reader.ReadByte()
if err != nil {
return "", errors.Wrap(err, "analysis https read cmLen error")
}
// compression methods
_, err = utils.ReadBuffer(reader, int(cmLen))
if err != nil {
return "", errors.Wrap(err, "analysis https read cm error")
}
// extensions length
eLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read eLen error")
}
eLen := binary.BigEndian.Uint16(eLenBuf)
// extensions
host := ""
for i := 0; i < int(eLen); {
// extension type
eTypeBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read extension type error")
}
eType := binary.BigEndian.Uint16(eTypeBuf)
// extension length
eLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read extension length error")
}
eLen := binary.BigEndian.Uint16(eLenBuf)
// server name
if eType == 0x00 {
// server name list length
_, err = utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read server name list length error")
}
// server name type
_, err = reader.ReadByte()
if err != nil {
return "", errors.Wrap(err, "analysis https read server name type error")
}
// server name length
sLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read server name length error")
}
sLen := binary.BigEndian.Uint16(sLenBuf)
// server name
bytes, err := utils.ReadBuffer(reader, int(sLen))
if err != nil {
return "", errors.Wrap(err, "analysis https read server name error")
}
host = string(bytes)
return host, nil
} else {
// other extension
_, err = utils.ReadBuffer(reader, int(eLen))
if err != nil {
return "", errors.Wrap(err, "analysis https read extension error")
}
}
i += 4 + int(eLen)
}
default:
return "", nil
}
return "", errors.New("analysis https error")
}
func discord(reader *bufio.Reader) {
_, err := io.Copy(io.Discard, reader) _, err := io.Copy(io.Discard, reader)
if err != nil { if err != nil {
slog.Error("analysis discord err", "err", err) slog.Error("analysis discord err", "err", err)

View File

@@ -24,7 +24,7 @@ type Service struct {
Config *Config Config *Config
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
userConnMap map[string]socks.ProxyConn userConnMap sync.Map
ctrlConnWg utils.CountWaitGroup ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup
@@ -41,7 +41,7 @@ func New(config *Config) *Service {
Config: config, Config: config,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
userConnMap: make(map[string]socks.ProxyConn), userConnMap: sync.Map{},
ctrlConnWg: utils.CountWaitGroup{}, ctrlConnWg: utils.CountWaitGroup{},
dataConnWg: utils.CountWaitGroup{}, dataConnWg: utils.CountWaitGroup{},
fwdLesWg: utils.CountWaitGroup{}, fwdLesWg: utils.CountWaitGroup{},
@@ -96,10 +96,13 @@ func (s *Service) Run() {
} }
// 清理资源 // 清理资源
for _, conn := range s.userConnMap { s.userConnMap.Range(func(key, value any) bool {
conn := value.(socks.ProxyConn)
utils.Close(conn) utils.Close(conn)
} s.userConnMap.Delete(key)
clear(s.userConnMap) return true
})
s.userConnMap.Clear()
s.ctrlConnWg.Wait() s.ctrlConnWg.Wait()
slog.Debug("控制通道连接已关闭") slog.Debug("控制通道连接已关闭")
@@ -143,7 +146,7 @@ func (s *Service) startCtrlTun() error {
defer utils.Close(conn) defer utils.Close(conn)
err := s.processCtrlConn(conn) err := s.processCtrlConn(conn)
if err != nil { if err != nil {
slog.Error("处理控制通道连接失败", err) slog.Error("处理控制通道连接失败", "err", err)
} }
}() }()
} }
@@ -213,7 +216,7 @@ func (s *Service) processCtrlConn(controller net.Conn) error {
slog.Error("向客户端发送 tag 失败", "err", err) slog.Error("向客户端发送 tag 失败", "err", err)
return return
} }
s.userConnMap[tag] = user s.userConnMap.Store(tag, user)
}() }()
} }
} }
@@ -254,7 +257,7 @@ func (s *Service) startDataTun() error {
defer utils.Close(conn) defer utils.Close(conn)
err := s.processDataConn(conn) err := s.processDataConn(conn)
if err != nil { if err != nil {
slog.Error("处理数据通道失败", err) slog.Error("处理数据通道失败", "err", err)
} }
}() }()
} }
@@ -277,17 +280,17 @@ func (s *Service) processDataConn(client net.Conn) error {
// 找到用户连接 // 找到用户连接
var data socks.ProxyConn var data socks.ProxyConn
var ok bool
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
return nil return nil
default: default:
data, ok = s.userConnMap[tag] dataAny, ok := s.userConnMap.Load(tag)
if !ok { if !ok {
return errors.New("查找用户连接失败") return errors.New("查找用户连接失败")
} }
data = dataAny.(socks.ProxyConn)
defer func() { defer func() {
delete(s.userConnMap, tag) s.userConnMap.Delete(tag)
utils.Close(data) utils.Close(data)
}() }()
} }
@@ -314,15 +317,21 @@ func (s *Service) processDataConn(client net.Conn) error {
// 数据转发 // 数据转发
slog.Info("开始数据转发 " + client.RemoteAddr().String() + " <-> " + data.Dest) slog.Info("开始数据转发 " + client.RemoteAddr().String() + " <-> " + data.Dest)
// userPipeReader, userPipeWriter := io.Pipe() userPipeReader, userPipeWriter := io.Pipe()
// defer utils.Close(userPipeWriter) defer utils.Close(userPipeWriter)
// teeUser := io.TeeReader(user, userPipeWriter) teeUser := io.TeeReader(user, userPipeWriter)
go func() {
err := analysisAndLog(data, userPipeReader)
if err != nil {
slog.Error("数据解析失败", "err", err)
}
}()
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
_, err := io.Copy(client, user) _, err := io.Copy(client, teeUser)
if err != nil { if err != nil {
slog.Error("数据转发失败 user->client", "err", err) slog.Error("数据转发失败 user->client", "err", err)
} }
@@ -332,7 +341,11 @@ func (s *Service) processDataConn(client net.Conn) error {
defer wg.Done() defer wg.Done()
_, err := io.Copy(user, client) _, err := io.Copy(user, client)
if err != nil { if err != nil {
slog.Error("数据转发失败 client->user", "err", err) if errors.Is(err, net.ErrClosed) {
} else {
// slog.Error("数据转发失败 client->user", "err", err, "errType", reflect.TypeOf(err))
}
} }
}() }()
wg.Wait() wg.Wait()

4
server/fwd/http/http.go Normal file
View File

@@ -0,0 +1,4 @@
package http
func Start() {
}

View File

@@ -49,27 +49,6 @@ type AddrSpec struct {
Port int Port int
} }
func (a AddrSpec) Domain() []string {
if a.FQDN != "" {
return []string{a.FQDN}
}
var domain []string
ch := make(chan struct{})
defer close(ch)
go func() {
addr, err := net.LookupAddr(a.IP.String())
if err == nil {
domain = addr
}
ch <- struct{}{}
}()
<-ch
return domain
}
func (a AddrSpec) String() string { func (a AddrSpec) String() string {
if a.FQDN != "" { if a.FQDN != "" {
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
@@ -258,7 +237,11 @@ func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request)
if conn != nil { if conn != nil {
utils.Close(conn) utils.Close(conn)
} }
case s.Conn <- ProxyConn{conn, req.realDestAddr.Address()}: case s.Conn <- ProxyConn{
req.Authentication.Payload.ID,
conn,
req.realDestAddr.Address(),
}:
} }
return nil return nil
} }
@@ -352,6 +335,7 @@ func SendSuccess(user net.Conn, target net.Conn) error {
} }
type ProxyConn struct { type ProxyConn struct {
Uid uint
// 用户连入的连接 // 用户连入的连接
Conn net.Conn Conn net.Conn
// 用户目标地址 // 用户目标地址

View File

@@ -11,7 +11,6 @@ import (
"os" "os"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@@ -191,18 +190,6 @@ func (s *Server) process(conn net.Conn) error {
slog.Debug("连接请求处理完成") slog.Debug("连接请求处理完成")
} }
// 记录日志
go func() {
slog.Info(
"用户访问记录",
slog.Uint64("uid", uint64(authContext.Payload.ID)),
slog.String("user", conn.RemoteAddr().String()),
slog.Any("node", conn.LocalAddr().String()),
slog.String("dest", request.DestAddr.Address()),
slog.String("domain", strings.Join(request.DestAddr.Domain(), ",")),
)
}()
request.Authentication = authContext request.Authentication = authContext
user, ok := conn.RemoteAddr().(*net.TCPAddr) user, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok { if !ok {

View File

@@ -1,65 +0,0 @@
package mnt
import (
"context"
"log/slog"
"github.com/google/gopacket"
"github.com/google/gopacket/pcap"
"github.com/pkg/errors"
"golang.org/x/text/encoding/simplifiedchinese"
)
func Start(ctx context.Context, errCh chan error) {
// 打开一个网络接口
device, err := pcap.OpenLive("WLAN", 1600, true, pcap.BlockForever)
if err != nil {
gbk := simplifiedchinese.GBK.NewDecoder()
errMsg, err := gbk.String(err.Error())
if err != nil {
errMsg = err.Error()
}
errCh <- errors.Wrap(err, "打开网络接口失败, "+errMsg)
return
}
defer device.Close()
err = device.SetBPFFilter("tcp")
if err != nil {
errCh <- errors.Wrap(err, "设置 BPF 过滤器失败")
return
}
err = device.SetDirection(pcap.DirectionIn)
if err != nil {
errCh <- errors.Wrap(err, "设置捕获方向失败")
return
}
// 创建一个数据包源
source := gopacket.NewPacketSource(device, device.LinkType())
source.NoCopy = true
source.Lazy = true
for {
select {
case <-ctx.Done():
slog.Debug("monitor 被动结束")
errCh <- nil
return
case packet := <-source.Packets():
handle(packet)
}
}
}
func handle(packet gopacket.Packet) {
slog.Debug("Packet: ", packet)
slog.Debug("Layers: ", packet.Layers())
slog.Debug("Application: ", packet.ApplicationLayer())
slog.Debug("Transport: ", packet.TransportLayer())
slog.Debug("Network: ", packet.NetworkLayer())
slog.Debug("Link: ", packet.LinkLayer())
}

View File

@@ -96,20 +96,14 @@ func initLog() {
func startFwdServer(ctx context.Context) error { func startFwdServer(ctx context.Context) error {
server := fwd.New(nil) server := fwd.New(nil)
go func() { go func() {
<-ctx.Done() <-ctx.Done()
server.Close() server.Close()
}() }()
server.Run() server.Run()
return nil return nil
} }
func startMntServer(ctx context.Context) {
}
func startWebServer(ctx context.Context) { func startWebServer(ctx context.Context) {
} }