网关实现自定义接口安全检查与边缘节点连接权限验证

This commit is contained in:
2025-05-15 15:56:20 +08:00
parent b29882f0a7
commit d65fe4db6f
25 changed files with 353 additions and 703 deletions

View File

@@ -22,7 +22,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error {
} else {
slog.Debug(
"用户访问记录",
slog.Uint64("uid", uint64(conn.Auth.Payload.ID)),
slog.Int("uid", int(conn.Auth.Payload.ID)),
slog.String("user", conn.RemoteAddr().String()),
slog.String("proxy", conn.Protocol),
slog.String("node", conn.LocalAddr().String()),

View File

@@ -2,11 +2,9 @@ package auth
import (
"fmt"
"log/slog"
"net"
"proxy-server/server/app"
"proxy-server/server/fwd/core"
"proxy-server/server/fwd/repo"
"proxy-server/server/pkg/orm"
"strconv"
"time"
@@ -20,7 +18,7 @@ const (
Http = Protocol("http")
)
func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.AuthContext, error) {
// 获取用户地址
remoteAddr := conn.RemoteAddr().String()
@@ -32,101 +30,48 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
// 获取服务端口
localAddr := conn.LocalAddr().String()
_, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort)
localPort, err := strconv.ParseUint(_localPort, 10, 16)
if err != nil {
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
}
// 查权限记录
slog.Debug("用户 " + remoteHost + " 请求连接到 " + _localPort)
var channels []repo.Channel
err = orm.DB.Find(&channels, &repo.Channel{
AuthIp: true,
UserAddr: remoteHost,
NodePort: localPort,
Protocol: string(proto),
}).Error
if err != nil {
return nil, errors.New("查询用户权限失败")
}
// 记录应该只有一条
channel, err := orm.MaySingle(channels)
if err != nil {
return nil, errors.New("不在白名单内")
// 查权限配置
var permit, ok = app.Permits[uint16(localPort)]
if !ok {
return nil, errors.New("没有权限")
}
// 检查是否需要密码认证
if channel.AuthPass {
return nil, errors.New("需要密码认证")
}
// 检查权限是否过期
timeout := channel.Expiration.Sub(time.Now()).Seconds()
if timeout <= 0 {
// 检查是否过期
if permit.Expire.Before(time.Now()) {
return nil, errors.New("权限已过期")
}
return &core.AuthContext{
Timeout: timeout,
Payload: core.Payload{
ID: channel.UserId,
},
}, nil
}
func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.AuthContext, error) {
// 获取服务端口
localAddr := conn.LocalAddr().String()
_, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort)
if err != nil {
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
}
// 查询权限记录
var channel repo.Channel
err = orm.DB.Take(&channel, &repo.Channel{
AuthPass: true,
Username: username,
NodePort: localPort,
Protocol: string(proto),
}).Error
if err != nil {
return nil, errors.New("用户不存在")
}
// 检查密码 todo 哈希
if channel.Password != password {
return nil, errors.New("密码错误")
}
// 如果用户设置了双验证则检查 ip 是否在白名单中
if channel.AuthIp {
// 获取用户地址
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return nil, fmt.Errorf("无法获取连接信息: %w", err)
// 检查 IP 是否可用
if len(permit.Whitelists) > 0 {
var found = false
for _, allowedHost := range permit.Whitelists {
var allowed = net.ParseIP(allowedHost)
var remote = net.ParseIP(remoteHost)
if remote.Equal(allowed) {
found = true
break
}
}
// 查询权限记录
if channel.UserAddr != remoteHost {
if !found {
return nil, errors.New("不在白名单内")
}
}
// 检查权限是否过期
timeout := channel.Expiration.Sub(time.Now()).Seconds()
if timeout <= 0 {
return nil, errors.New("权限已过期")
if username != nil && password != nil {
if *username != permit.Username || *password != permit.Password {
return nil, errors.New("用户名或密码错误")
}
}
return &core.AuthContext{
Timeout: timeout,
Timeout: time.Since(permit.Expire).Seconds(),
Payload: core.Payload{
ID: channel.UserId,
ID: app.Assigns[uint16(localPort)],
},
}, nil
}

View File

@@ -1,151 +0,0 @@
package auth
import (
"net"
"proxy-server/server/fwd/core"
"proxy-server/server/pkg/orm"
"reflect"
"testing"
"time"
)
type MockConn struct {
local net.Addr
remote net.Addr
}
func (m MockConn) LocalAddr() net.Addr {
return m.local
}
func (m MockConn) RemoteAddr() net.Addr {
return m.remote
}
func (m MockConn) Read(b []byte) (n int, err error) {
return 0, nil
}
func (m MockConn) Write(b []byte) (n int, err error) {
return 0, nil
}
func (m MockConn) Close() error {
return nil
}
func (m MockConn) SetDeadline(t time.Time) error {
return nil
}
func (m MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestCheckIp(t *testing.T) {
orm.InitForTest(
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
)
type args struct {
conn net.Conn
proto Protocol
}
tests := []struct {
name string
args args
want *core.AuthContext
wantErr bool
}{
{
name: "test-ok",
args: args{
conn: MockConn{
remote: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 12345,
},
local: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 20001,
},
},
proto: Http,
},
want: &core.AuthContext{
Payload: core.Payload{ID: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CheckIp(tt.args.conn, tt.args.proto)
if (err != nil) != tt.wantErr {
t.Errorf("CheckIp() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
t.Errorf("CheckIp() got = %v, want %v", got, tt.want)
}
})
}
}
func TestCheckPass(t *testing.T) {
orm.InitForTest(
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
)
type args struct {
conn net.Conn
username string
password string
proto Protocol
}
tests := []struct {
name string
args args
want *core.AuthContext
wantErr bool
}{
{
name: "test-ok",
args: args{
conn: MockConn{
remote: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 12345,
},
local: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 20001,
},
},
proto: Http,
username: "ip4http",
password: "asdf",
},
want: &core.AuthContext{
Payload: core.Payload{ID: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CheckPass(tt.args.conn, tt.args.proto, tt.args.username, tt.args.password)
if (err != nil) != tt.wantErr {
t.Errorf("CheckPass() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
t.Errorf("CheckPass() got = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -100,5 +100,5 @@ type AuthContext struct {
}
type Payload struct {
ID uint
ID int32
}

View File

@@ -9,10 +9,11 @@ import (
"log/slog"
"net"
"proxy-server/pkg/utils"
"proxy-server/server/app"
"proxy-server/server/env"
"proxy-server/server/fwd/core"
"proxy-server/server/fwd/dispatcher"
"proxy-server/server/fwd/metrics"
"proxy-server/server/pkg/env"
"proxy-server/server/report"
"strconv"
"strings"
@@ -74,29 +75,26 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
if err != nil {
return fmt.Errorf("读取客户端 ID 失败: %w", err)
}
var clientId = int32(binary.BigEndian.Uint32(recv))
var client = int32(binary.BigEndian.Uint32(recv))
// 分配端口
var minim uint16 = 20000
var maxim uint16 = 60000
var fwdPort uint16
var port uint16
for i := minim; i < maxim; i++ {
var _, ok = s.fwdPortMap[i]
var _, ok = app.Assigns[i]
if !ok {
fwdPort = i
s.fwdPortMap[i] = clientId
port = i
app.Assigns[i] = client
break
}
}
if fwdPort == 0 {
if port == 0 {
return errors.New("没有可用的端口")
}
// 报告端口分配
if s.Config.Id == nil || *s.Config.Id == 0 {
return errors.New("转发服务未成功注册,无法提供服务")
}
err = report.Assigned(s.ctx, *s.Config.Id, clientId, fwdPort)
err = report.Assigned(client, port)
if err != nil {
return fmt.Errorf("报告端口分配失败: %w", err)
}
@@ -108,8 +106,8 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
}
// 启动转发服务
slog.Info("监听转发端口", "port", fwdPort, "client", clientId)
proxy, err := dispatcher.New(fwdPort)
slog.Info("监听转发端口", "port", port, "client", client)
proxy, err := dispatcher.New(port)
if err != nil {
return err
}

View File

@@ -7,8 +7,8 @@ import (
"net"
"proxy-server/pkg/utils"
"proxy-server/server/debug"
"proxy-server/server/env"
"proxy-server/server/fwd/metrics"
"proxy-server/server/pkg/env"
"strconv"
"sync"
"time"

View File

@@ -8,12 +8,7 @@ import (
"sync"
)
type Config struct {
Id *int32
}
type Service struct {
Config *Config
ctx context.Context
cancel context.CancelFunc
@@ -23,21 +18,13 @@ type Service struct {
ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup
userConnWg utils.CountWaitGroup
fwdPortMap map[uint16]int32 // 转发端口映射key 为端口号value 为边缘节点 ID
}
func New(config *Config) *Service {
if config == nil {
config = &Config{}
}
func New() *Service {
ctx, cancel := context.WithCancel(context.Background())
return &Service{
Config: config,
ctx: ctx,
cancel: cancel,
fwdPortMap: make(map[uint16]int32),
ctx: ctx,
cancel: cancel,
}
}

View File

@@ -49,18 +49,9 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
// 验证账号
authInfo := headers.Get("Proxy-Authorization")
var authCtx *core.AuthContext
var authErr error
if authInfo == "" {
authCtx, authErr = auth.CheckIp(conn, auth.Http)
if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, fmt.Errorf("响应 407 失败: %v", err)
}
return nil, fmt.Errorf("验证账号失败: %v", authErr)
}
} else {
var username *string = nil
var password *string = nil
if authInfo != "" {
authParts := strings.Split(authInfo, " ")
if len(authParts) != 2 {
return nil, errors.New("无效的 Proxy-Authorization")
@@ -73,14 +64,17 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
return nil, fmt.Errorf("解码认证信息失败: %v", err)
}
authPair := strings.Split(string(authBytes), ":")
authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1])
if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, fmt.Errorf("响应 407 失败: %v", err)
}
return nil, fmt.Errorf("验证账号失败: %v", authErr)
username = &authPair[0]
password = &authPair[1]
}
authCtx, err := auth.Protect(conn, auth.Http, username, password)
if err != nil {
_, err = conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, fmt.Errorf("响应 407 失败: %v", err)
}
return nil, fmt.Errorf("验证账号失败: %v", err)
}
// 获取 Host

View File

@@ -104,10 +104,10 @@ func checkVersion(reader io.Reader) error {
}
// authenticate 执行认证流程
func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*core.AuthContext, error) {
func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (authContext *core.AuthContext, err error) {
// 版本检查
err := checkVersion(reader)
err = checkVersion(reader)
if err != nil {
return nil, err
}
@@ -122,7 +122,6 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
return nil, err
}
// 密码模式
if slices.Contains(methods, UserPassAuth) {
_, err := conn.Write([]byte{Version, byte(UserPassAuth)})
if err != nil {
@@ -167,7 +166,7 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
password := string(passwordBuf)
// 检查权限
authContext, err := auth.CheckPass(conn, auth.Socks5, username, password)
authContext, err = auth.Protect(conn, auth.Socks5, &username, &password)
if err != nil {
return nil, fmt.Errorf("权限检查失败: %w", err)
}
@@ -179,30 +178,29 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
}
return authContext, nil
}
// 无认证
if slices.Contains(methods, NoAuth) {
} else if slices.Contains(methods, NoAuth) {
_, err = conn.Write([]byte{Version, NoAuth})
if err != nil {
return nil, fmt.Errorf("响应认证方式失败: %w", err)
}
authCtx, err := auth.CheckIp(conn, auth.Socks5)
authContext, err = auth.Protect(conn, auth.Socks5, nil, nil)
if err != nil {
return nil, fmt.Errorf("权限检查失败: %w", err)
}
return authContext, nil
return authCtx, nil
} else {
_, err = conn.Write([]byte{Version, NoAcceptable})
if err != nil {
return nil, err
}
return nil, errors.New("没有适用的认证方式")
}
// 无适用的认证方式
_, err = conn.Write([]byte{Version, NoAcceptable})
if err != nil {
return nil, err
}
return nil, errors.New("没有适用的认证方式")
}
type Request struct {