实现自定义 wg 以统计协程数量
This commit is contained in:
48
pkg/utils/chan.go
Normal file
48
pkg/utils/chan.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
|
||||||
|
connCh := make(chan net.Conn)
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := ls.Accept()
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("接受连接失败", err)
|
||||||
|
// 临时错误重试连接
|
||||||
|
var ne net.Error
|
||||||
|
if errors.As(err, &ne) && ne.Temporary() {
|
||||||
|
slog.Debug("临时错误重试")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// ctx 取消后退出
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
Close(conn)
|
||||||
|
return
|
||||||
|
case connCh <- conn:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return connCh
|
||||||
|
}
|
||||||
|
|
||||||
|
func WaitChan(ctx context.Context, wg *CountWaitGroup) chan struct{} {
|
||||||
|
ch := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wg.Wait()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case ch <- struct{}{}:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return ch
|
||||||
|
}
|
||||||
29
pkg/utils/sync.go
Normal file
29
pkg/utils/sync.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CountWaitGroup struct {
|
||||||
|
wg sync.WaitGroup
|
||||||
|
num atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CountWaitGroup) Add(delta uint64) {
|
||||||
|
c.wg.Add(int(delta))
|
||||||
|
c.num.Add(delta)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CountWaitGroup) Done() {
|
||||||
|
c.wg.Done()
|
||||||
|
c.num.Add(-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CountWaitGroup) Wait() {
|
||||||
|
c.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CountWaitGroup) Count() uint64 {
|
||||||
|
return c.num.Load()
|
||||||
|
}
|
||||||
@@ -1,13 +1,8 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func ReadByte(reader io.Reader) (byte, error) {
|
func ReadByte(reader io.Reader) (byte, error) {
|
||||||
@@ -36,42 +31,3 @@ func Close[T io.Closer](v T) {
|
|||||||
slog.Warn("对象关闭失败", "err", err)
|
slog.Warn("对象关闭失败", "err", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
|
|
||||||
connCh := make(chan net.Conn)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
conn, err := ls.Accept()
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("接受连接失败", err)
|
|
||||||
// 临时错误重试连接
|
|
||||||
var ne net.Error
|
|
||||||
if errors.As(err, &ne) && ne.Temporary() {
|
|
||||||
slog.Debug("临时错误重试")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// ctx 取消后退出
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
Close(conn)
|
|
||||||
return
|
|
||||||
case connCh <- conn:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return connCh
|
|
||||||
}
|
|
||||||
|
|
||||||
func WaitChan(ctx context.Context, wg *sync.WaitGroup) chan struct{} {
|
|
||||||
ch := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
wg.Wait()
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
case ch <- struct{}{}:
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
return ch
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"proxy-server/server/pkg/socks5"
|
"proxy-server/server/pkg/socks5"
|
||||||
"proxy-server/server/web/app/models"
|
"proxy-server/server/web/app/models"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -24,9 +23,9 @@ type Config struct {
|
|||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
Config *Config
|
Config *Config
|
||||||
ConnMap map[string]socks5.ProxyData
|
connMap map[string]socks5.ProxyData
|
||||||
ctrlConnWg sync.WaitGroup
|
ctrlConnWg utils.CountWaitGroup
|
||||||
dataConnWg sync.WaitGroup
|
dataConnWg utils.CountWaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config *Config) *Service {
|
func New(config *Config) *Service {
|
||||||
@@ -37,9 +36,9 @@ func New(config *Config) *Service {
|
|||||||
|
|
||||||
return &Service{
|
return &Service{
|
||||||
Config: _config,
|
Config: _config,
|
||||||
ConnMap: make(map[string]socks5.ProxyData),
|
connMap: make(map[string]socks5.ProxyData),
|
||||||
ctrlConnWg: sync.WaitGroup{},
|
ctrlConnWg: utils.CountWaitGroup{},
|
||||||
dataConnWg: sync.WaitGroup{},
|
dataConnWg: utils.CountWaitGroup{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,6 +110,7 @@ loop:
|
|||||||
slog.Debug("结束处理连接,由于获取连接失败")
|
slog.Debug("结束处理连接,由于获取连接失败")
|
||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
|
s.ctrlConnWg.Add(1)
|
||||||
go s.processCtrlConn(conn)
|
go s.processCtrlConn(conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -190,7 +190,7 @@ func (s *Service) processCtrlConn(controller net.Conn) {
|
|||||||
slog.Error("write error", err)
|
slog.Error("write error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.ConnMap[tag] = user
|
s.connMap[tag] = user
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,6 +222,7 @@ loop:
|
|||||||
slog.Debug("结束处理连接,由于获取连接失败")
|
slog.Debug("结束处理连接,由于获取连接失败")
|
||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
|
s.dataConnWg.Add(1)
|
||||||
go s.processDataConn(conn)
|
go s.processDataConn(conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -245,7 +246,10 @@ loop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) processDataConn(client net.Conn) {
|
func (s *Service) processDataConn(client net.Conn) {
|
||||||
|
defer func() {
|
||||||
|
s.dataConnWg.Done()
|
||||||
|
utils.Close(client)
|
||||||
|
}()
|
||||||
slog.Info("已建立客户端数据通道 " + client.RemoteAddr().String())
|
slog.Info("已建立客户端数据通道 " + client.RemoteAddr().String())
|
||||||
|
|
||||||
// 读取 tag
|
// 读取 tag
|
||||||
@@ -262,7 +266,7 @@ func (s *Service) processDataConn(client net.Conn) {
|
|||||||
tag := string(tagBuf)
|
tag := string(tagBuf)
|
||||||
|
|
||||||
// 找到用户连接
|
// 找到用户连接
|
||||||
data, ok := s.ConnMap[tag]
|
data, ok := s.connMap[tag]
|
||||||
if !ok {
|
if !ok {
|
||||||
slog.Error("no such connection")
|
slog.Error("no such connection")
|
||||||
return
|
return
|
||||||
@@ -270,6 +274,7 @@ func (s *Service) processDataConn(client net.Conn) {
|
|||||||
|
|
||||||
// 响应用户
|
// 响应用户
|
||||||
user := data.Conn
|
user := data.Conn
|
||||||
|
defer utils.Close(user)
|
||||||
socks5.SendSuccess(user, client)
|
socks5.SendSuccess(user, client)
|
||||||
|
|
||||||
// 写入目标地址
|
// 写入目标地址
|
||||||
@@ -303,16 +308,6 @@ func (s *Service) processDataConn(client net.Conn) {
|
|||||||
}()
|
}()
|
||||||
<-errCh
|
<-errCh
|
||||||
slog.Info("数据转发结束 " + client.RemoteAddr().String() + " <-> " + data.Dest)
|
slog.Info("数据转发结束 " + client.RemoteAddr().String() + " <-> " + data.Dest)
|
||||||
defer func() {
|
|
||||||
err := user.Close()
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("close error", err)
|
|
||||||
}
|
|
||||||
err = client.Close()
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("close error", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type NoAuthAuthenticator struct {
|
type NoAuthAuthenticator struct {
|
||||||
|
|||||||
Reference in New Issue
Block a user