网关实现自定义接口安全检查与边缘节点连接权限验证
This commit is contained in:
12
server/app/app.go
Normal file
12
server/app/app.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package app
|
||||
|
||||
import "proxy-server/server/core"
|
||||
|
||||
var (
|
||||
Id int32
|
||||
Name string
|
||||
PlatformSecret string // 平台密钥,验证接收的请求是否属于平台
|
||||
|
||||
Assigns = make(map[uint16]int32) // 转发端口 -> 转发服务ID
|
||||
Permits = make(map[uint16]core.Permit) // 转发端口 -> 权限配置
|
||||
)
|
||||
10
server/core/auth.go
Normal file
10
server/core/auth.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package core
|
||||
|
||||
import "time"
|
||||
|
||||
type Permit struct {
|
||||
Expire time.Time `json:"expire"`
|
||||
Whitelists []string `json:"whitelists"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
6
server/core/consts.go
Normal file
6
server/core/consts.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package core
|
||||
|
||||
const (
|
||||
Version = 1
|
||||
RestoreMagic = 0x72
|
||||
)
|
||||
74
server/core/security.go
Normal file
74
server/core/security.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
g "proxy-server/server/globals"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SecuredReq struct {
|
||||
Content string `json:"content"`
|
||||
Nonce string `json:"nonce"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
func Decrypt[T any](req *SecuredReq, secret string) (resp *T, err error) {
|
||||
|
||||
// 解密请求
|
||||
block, err := aes.NewCipher([]byte(secret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var nonce = []byte(req.Nonce)
|
||||
|
||||
content, err := base64.StdEncoding.DecodeString(req.Content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var aad = []byte(fmt.Sprintf("%s:%d", req.Nonce, req.Timestamp))
|
||||
|
||||
bytes, err := gcm.Open(nil, nonce, content, aad)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查时间与 nonce 是否匹配
|
||||
var duration = time.Now().UnixMilli() - req.Timestamp
|
||||
if duration > 1000*60*5 { // 5分钟
|
||||
return nil, fmt.Errorf("请求超时,当前时间:%d,接收时间:%d", time.Now().UnixMilli(), req.Timestamp)
|
||||
}
|
||||
|
||||
result, err := g.Redis.Exists(context.Background(), req.Nonce).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result > 0 {
|
||||
return nil, fmt.Errorf("请求已被使用,nonce:%s", req.Nonce)
|
||||
}
|
||||
|
||||
// 将 nonce 存入 redis,设置过期时间为 5 分钟
|
||||
err = g.Redis.Set(context.Background(), req.Nonce, 1, time.Minute*5).Err()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 返回解密后的数据
|
||||
err = json.Unmarshal(bytes, &resp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
54
server/pkg/env/env.go → server/env/env.go
vendored
54
server/pkg/env/env.go → server/env/env.go
vendored
@@ -18,12 +18,10 @@ var (
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
|
||||
DbHost string
|
||||
DbPort uint16 = 5432
|
||||
DbDatabase string
|
||||
DbUsername string
|
||||
DbPassword string
|
||||
DbTimezone = "Asia/Shanghai"
|
||||
RedisHost = "localhost"
|
||||
RedisPort = "6379"
|
||||
RedisDb = 0
|
||||
RedisPass = ""
|
||||
|
||||
EndpointOnline string
|
||||
EndpointOffline string
|
||||
@@ -83,46 +81,28 @@ func Init() {
|
||||
panic("环境变量 CLIENT_SECRET 未设置")
|
||||
}
|
||||
|
||||
value = os.Getenv("DB_HOST")
|
||||
value = os.Getenv("REDIS_HOST")
|
||||
if value != "" {
|
||||
DbHost = os.Getenv("DB_HOST")
|
||||
} else {
|
||||
panic("环境变量 DB_HOST 未设置")
|
||||
RedisHost = value
|
||||
}
|
||||
|
||||
value = os.Getenv("DB_PORT")
|
||||
value = os.Getenv("REDIS_PORT")
|
||||
if value != "" {
|
||||
dbPort, err := strconv.Atoi(value)
|
||||
RedisPort = value
|
||||
}
|
||||
|
||||
value = os.Getenv("REDIS_DB")
|
||||
if value != "" {
|
||||
redisDb, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("环境变量 DB_PORT 格式错误: %v", err))
|
||||
panic(fmt.Sprintf("环境变量 REDIS_DB 格式错误: %v", err))
|
||||
}
|
||||
DbPort = uint16(dbPort)
|
||||
RedisDb = redisDb
|
||||
}
|
||||
|
||||
value = os.Getenv("DB_DATABASE")
|
||||
value = os.Getenv("REDIS_PASS")
|
||||
if value != "" {
|
||||
DbDatabase = value
|
||||
} else {
|
||||
panic("环境变量 DB_DATABASE 未设置")
|
||||
}
|
||||
|
||||
value = os.Getenv("DB_USERNAME")
|
||||
if value != "" {
|
||||
DbUsername = value
|
||||
} else {
|
||||
panic("环境变量 DB_USERNAME 未设置")
|
||||
}
|
||||
|
||||
value = os.Getenv("DB_PASSWORD")
|
||||
if value != "" {
|
||||
DbPassword = value
|
||||
} else {
|
||||
panic("环境变量 DB_PASSWORD 未设置")
|
||||
}
|
||||
|
||||
value = os.Getenv("DB_TIMEZONE")
|
||||
if value != "" {
|
||||
DbTimezone = value
|
||||
RedisPass = value
|
||||
}
|
||||
|
||||
value = os.Getenv("ENDPOINT_ONLINE")
|
||||
@@ -22,7 +22,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error {
|
||||
} else {
|
||||
slog.Debug(
|
||||
"用户访问记录",
|
||||
slog.Uint64("uid", uint64(conn.Auth.Payload.ID)),
|
||||
slog.Int("uid", int(conn.Auth.Payload.ID)),
|
||||
slog.String("user", conn.RemoteAddr().String()),
|
||||
slog.String("proxy", conn.Protocol),
|
||||
slog.String("node", conn.LocalAddr().String()),
|
||||
|
||||
@@ -2,11 +2,9 @@ package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"proxy-server/server/app"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/fwd/repo"
|
||||
"proxy-server/server/pkg/orm"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -20,7 +18,7 @@ const (
|
||||
Http = Protocol("http")
|
||||
)
|
||||
|
||||
func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
|
||||
func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.AuthContext, error) {
|
||||
|
||||
// 获取用户地址
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
@@ -32,101 +30,48 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
|
||||
// 获取服务端口
|
||||
localAddr := conn.LocalAddr().String()
|
||||
_, _localPort, err := net.SplitHostPort(localAddr)
|
||||
localPort, err := strconv.Atoi(_localPort)
|
||||
localPort, err := strconv.ParseUint(_localPort, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
|
||||
}
|
||||
|
||||
// 查询权限记录
|
||||
slog.Debug("用户 " + remoteHost + " 请求连接到 " + _localPort)
|
||||
var channels []repo.Channel
|
||||
err = orm.DB.Find(&channels, &repo.Channel{
|
||||
AuthIp: true,
|
||||
UserAddr: remoteHost,
|
||||
NodePort: localPort,
|
||||
Protocol: string(proto),
|
||||
}).Error
|
||||
if err != nil {
|
||||
return nil, errors.New("查询用户权限失败")
|
||||
}
|
||||
// 记录应该只有一条
|
||||
channel, err := orm.MaySingle(channels)
|
||||
if err != nil {
|
||||
return nil, errors.New("不在白名单内")
|
||||
// 查找权限配置
|
||||
var permit, ok = app.Permits[uint16(localPort)]
|
||||
if !ok {
|
||||
return nil, errors.New("没有权限")
|
||||
}
|
||||
|
||||
// 检查是否需要密码认证
|
||||
if channel.AuthPass {
|
||||
return nil, errors.New("需要密码认证")
|
||||
}
|
||||
|
||||
// 检查权限是否过期
|
||||
timeout := channel.Expiration.Sub(time.Now()).Seconds()
|
||||
if timeout <= 0 {
|
||||
// 检查是否过期
|
||||
if permit.Expire.Before(time.Now()) {
|
||||
return nil, errors.New("权限已过期")
|
||||
}
|
||||
|
||||
return &core.AuthContext{
|
||||
Timeout: timeout,
|
||||
Payload: core.Payload{
|
||||
ID: channel.UserId,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.AuthContext, error) {
|
||||
|
||||
// 获取服务端口
|
||||
localAddr := conn.LocalAddr().String()
|
||||
_, _localPort, err := net.SplitHostPort(localAddr)
|
||||
localPort, err := strconv.Atoi(_localPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
|
||||
}
|
||||
|
||||
// 查询权限记录
|
||||
var channel repo.Channel
|
||||
err = orm.DB.Take(&channel, &repo.Channel{
|
||||
AuthPass: true,
|
||||
Username: username,
|
||||
NodePort: localPort,
|
||||
Protocol: string(proto),
|
||||
}).Error
|
||||
if err != nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 检查密码 todo 哈希
|
||||
if channel.Password != password {
|
||||
return nil, errors.New("密码错误")
|
||||
}
|
||||
|
||||
// 如果用户设置了双验证则检查 ip 是否在白名单中
|
||||
if channel.AuthIp {
|
||||
|
||||
// 获取用户地址
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
remoteHost, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无法获取连接信息: %w", err)
|
||||
// 检查 IP 是否可用
|
||||
if len(permit.Whitelists) > 0 {
|
||||
var found = false
|
||||
for _, allowedHost := range permit.Whitelists {
|
||||
var allowed = net.ParseIP(allowedHost)
|
||||
var remote = net.ParseIP(remoteHost)
|
||||
if remote.Equal(allowed) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 查询权限记录
|
||||
if channel.UserAddr != remoteHost {
|
||||
if !found {
|
||||
return nil, errors.New("不在白名单内")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查权限是否过期
|
||||
timeout := channel.Expiration.Sub(time.Now()).Seconds()
|
||||
if timeout <= 0 {
|
||||
return nil, errors.New("权限已过期")
|
||||
if username != nil && password != nil {
|
||||
if *username != permit.Username || *password != permit.Password {
|
||||
return nil, errors.New("用户名或密码错误")
|
||||
}
|
||||
}
|
||||
|
||||
return &core.AuthContext{
|
||||
Timeout: timeout,
|
||||
Timeout: time.Since(permit.Expire).Seconds(),
|
||||
Payload: core.Payload{
|
||||
ID: channel.UserId,
|
||||
ID: app.Assigns[uint16(localPort)],
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/pkg/orm"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MockConn struct {
|
||||
local net.Addr
|
||||
remote net.Addr
|
||||
}
|
||||
|
||||
func (m MockConn) LocalAddr() net.Addr {
|
||||
return m.local
|
||||
}
|
||||
|
||||
func (m MockConn) RemoteAddr() net.Addr {
|
||||
return m.remote
|
||||
}
|
||||
|
||||
func (m MockConn) Read(b []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m MockConn) Write(b []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m MockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestCheckIp(t *testing.T) {
|
||||
|
||||
orm.InitForTest(
|
||||
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
|
||||
)
|
||||
|
||||
type args struct {
|
||||
conn net.Conn
|
||||
proto Protocol
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *core.AuthContext
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "test-ok",
|
||||
args: args{
|
||||
conn: MockConn{
|
||||
remote: &net.TCPAddr{
|
||||
IP: []byte{127, 0, 0, 1},
|
||||
Port: 12345,
|
||||
},
|
||||
local: &net.TCPAddr{
|
||||
IP: []byte{127, 0, 0, 1},
|
||||
Port: 20001,
|
||||
},
|
||||
},
|
||||
proto: Http,
|
||||
},
|
||||
want: &core.AuthContext{
|
||||
Payload: core.Payload{ID: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := CheckIp(tt.args.conn, tt.args.proto)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckIp() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
|
||||
t.Errorf("CheckIp() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPass(t *testing.T) {
|
||||
|
||||
orm.InitForTest(
|
||||
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
|
||||
)
|
||||
|
||||
type args struct {
|
||||
conn net.Conn
|
||||
username string
|
||||
password string
|
||||
proto Protocol
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *core.AuthContext
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "test-ok",
|
||||
args: args{
|
||||
conn: MockConn{
|
||||
remote: &net.TCPAddr{
|
||||
IP: []byte{127, 0, 0, 1},
|
||||
Port: 12345,
|
||||
},
|
||||
local: &net.TCPAddr{
|
||||
IP: []byte{127, 0, 0, 1},
|
||||
Port: 20001,
|
||||
},
|
||||
},
|
||||
proto: Http,
|
||||
username: "ip4http",
|
||||
password: "asdf",
|
||||
},
|
||||
want: &core.AuthContext{
|
||||
Payload: core.Payload{ID: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := CheckPass(tt.args.conn, tt.args.proto, tt.args.username, tt.args.password)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckPass() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
|
||||
t.Errorf("CheckPass() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -100,5 +100,5 @@ type AuthContext struct {
|
||||
}
|
||||
|
||||
type Payload struct {
|
||||
ID uint
|
||||
ID int32
|
||||
}
|
||||
|
||||
@@ -9,10 +9,11 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/app"
|
||||
"proxy-server/server/env"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/fwd/dispatcher"
|
||||
"proxy-server/server/fwd/metrics"
|
||||
"proxy-server/server/pkg/env"
|
||||
"proxy-server/server/report"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -74,29 +75,26 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取客户端 ID 失败: %w", err)
|
||||
}
|
||||
var clientId = int32(binary.BigEndian.Uint32(recv))
|
||||
var client = int32(binary.BigEndian.Uint32(recv))
|
||||
|
||||
// 分配端口
|
||||
var minim uint16 = 20000
|
||||
var maxim uint16 = 60000
|
||||
var fwdPort uint16
|
||||
var port uint16
|
||||
for i := minim; i < maxim; i++ {
|
||||
var _, ok = s.fwdPortMap[i]
|
||||
var _, ok = app.Assigns[i]
|
||||
if !ok {
|
||||
fwdPort = i
|
||||
s.fwdPortMap[i] = clientId
|
||||
port = i
|
||||
app.Assigns[i] = client
|
||||
break
|
||||
}
|
||||
}
|
||||
if fwdPort == 0 {
|
||||
if port == 0 {
|
||||
return errors.New("没有可用的端口")
|
||||
}
|
||||
|
||||
// 报告端口分配
|
||||
if s.Config.Id == nil || *s.Config.Id == 0 {
|
||||
return errors.New("转发服务未成功注册,无法提供服务")
|
||||
}
|
||||
err = report.Assigned(s.ctx, *s.Config.Id, clientId, fwdPort)
|
||||
err = report.Assigned(client, port)
|
||||
if err != nil {
|
||||
return fmt.Errorf("报告端口分配失败: %w", err)
|
||||
}
|
||||
@@ -108,8 +106,8 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
|
||||
}
|
||||
|
||||
// 启动转发服务
|
||||
slog.Info("监听转发端口", "port", fwdPort, "client", clientId)
|
||||
proxy, err := dispatcher.New(fwdPort)
|
||||
slog.Info("监听转发端口", "port", port, "client", client)
|
||||
proxy, err := dispatcher.New(port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
"net"
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/debug"
|
||||
"proxy-server/server/env"
|
||||
"proxy-server/server/fwd/metrics"
|
||||
"proxy-server/server/pkg/env"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -8,12 +8,7 @@ import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Id *int32
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
Config *Config
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
@@ -23,21 +18,13 @@ type Service struct {
|
||||
ctrlConnWg utils.CountWaitGroup
|
||||
dataConnWg utils.CountWaitGroup
|
||||
userConnWg utils.CountWaitGroup
|
||||
|
||||
fwdPortMap map[uint16]int32 // 转发端口映射,key 为端口号,value 为边缘节点 ID
|
||||
}
|
||||
|
||||
func New(config *Config) *Service {
|
||||
if config == nil {
|
||||
config = &Config{}
|
||||
}
|
||||
|
||||
func New() *Service {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Service{
|
||||
Config: config,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
fwdPortMap: make(map[uint16]int32),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,18 +49,9 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
||||
|
||||
// 验证账号
|
||||
authInfo := headers.Get("Proxy-Authorization")
|
||||
var authCtx *core.AuthContext
|
||||
var authErr error
|
||||
if authInfo == "" {
|
||||
authCtx, authErr = auth.CheckIp(conn, auth.Http)
|
||||
if authErr != nil {
|
||||
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("响应 407 失败: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("验证账号失败: %v", authErr)
|
||||
}
|
||||
} else {
|
||||
var username *string = nil
|
||||
var password *string = nil
|
||||
if authInfo != "" {
|
||||
authParts := strings.Split(authInfo, " ")
|
||||
if len(authParts) != 2 {
|
||||
return nil, errors.New("无效的 Proxy-Authorization")
|
||||
@@ -73,14 +64,17 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
||||
return nil, fmt.Errorf("解码认证信息失败: %v", err)
|
||||
}
|
||||
authPair := strings.Split(string(authBytes), ":")
|
||||
authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1])
|
||||
if authErr != nil {
|
||||
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("响应 407 失败: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("验证账号失败: %v", authErr)
|
||||
username = &authPair[0]
|
||||
password = &authPair[1]
|
||||
}
|
||||
|
||||
authCtx, err := auth.Protect(conn, auth.Http, username, password)
|
||||
if err != nil {
|
||||
_, err = conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("响应 407 失败: %v", err)
|
||||
}
|
||||
return nil, fmt.Errorf("验证账号失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取 Host
|
||||
|
||||
@@ -104,10 +104,10 @@ func checkVersion(reader io.Reader) error {
|
||||
}
|
||||
|
||||
// authenticate 执行认证流程
|
||||
func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*core.AuthContext, error) {
|
||||
func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (authContext *core.AuthContext, err error) {
|
||||
|
||||
// 版本检查
|
||||
err := checkVersion(reader)
|
||||
err = checkVersion(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -122,7 +122,6 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 密码模式
|
||||
if slices.Contains(methods, UserPassAuth) {
|
||||
_, err := conn.Write([]byte{Version, byte(UserPassAuth)})
|
||||
if err != nil {
|
||||
@@ -167,7 +166,7 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
|
||||
password := string(passwordBuf)
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.CheckPass(conn, auth.Socks5, username, password)
|
||||
authContext, err = auth.Protect(conn, auth.Socks5, &username, &password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("权限检查失败: %w", err)
|
||||
}
|
||||
@@ -179,30 +178,29 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
|
||||
}
|
||||
|
||||
return authContext, nil
|
||||
}
|
||||
|
||||
// 无认证
|
||||
if slices.Contains(methods, NoAuth) {
|
||||
} else if slices.Contains(methods, NoAuth) {
|
||||
|
||||
_, err = conn.Write([]byte{Version, NoAuth})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("响应认证方式失败: %w", err)
|
||||
}
|
||||
|
||||
authCtx, err := auth.CheckIp(conn, auth.Socks5)
|
||||
authContext, err = auth.Protect(conn, auth.Socks5, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("权限检查失败: %w", err)
|
||||
}
|
||||
return authContext, nil
|
||||
|
||||
return authCtx, nil
|
||||
} else {
|
||||
|
||||
_, err = conn.Write([]byte{Version, NoAcceptable})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, errors.New("没有适用的认证方式")
|
||||
}
|
||||
|
||||
// 无适用的认证方式
|
||||
_, err = conn.Write([]byte{Version, NoAcceptable})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nil, errors.New("没有适用的认证方式")
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
|
||||
27
server/globals/redis.go
Normal file
27
server/globals/redis.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package globals
|
||||
|
||||
import (
|
||||
"github.com/redis/go-redis/v9"
|
||||
"log/slog"
|
||||
"net"
|
||||
"proxy-server/server/env"
|
||||
)
|
||||
|
||||
var Redis *redis.Client
|
||||
|
||||
func InitRedis() {
|
||||
Redis = redis.NewClient(&redis.Options{
|
||||
Addr: net.JoinHostPort(env.RedisHost, env.RedisPort),
|
||||
DB: env.RedisDb,
|
||||
Password: env.RedisPass,
|
||||
})
|
||||
}
|
||||
|
||||
func ExitRedis() {
|
||||
if Redis != nil {
|
||||
var err = Redis.Close()
|
||||
if err != nil {
|
||||
slog.Warn("关闭 Redis 连接失败", "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package log
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"proxy-server/server/pkg/env"
|
||||
"proxy-server/server/env"
|
||||
"time"
|
||||
|
||||
"github.com/lmittmann/tint"
|
||||
@@ -1,60 +0,0 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"proxy-server/server/pkg/env"
|
||||
|
||||
"errors"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
func Init() {
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable TimeZone=%s",
|
||||
env.DbHost, env.DbPort, env.DbUsername, env.DbPassword, env.DbDatabase, env.DbTimezone,
|
||||
)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDb, err := db.DB()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
sqlDb.SetMaxIdleConns(10)
|
||||
sqlDb.SetMaxOpenConns(100)
|
||||
|
||||
DB = db
|
||||
}
|
||||
|
||||
func InitForTest(dsn string) {
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
DB = db
|
||||
}
|
||||
|
||||
func MaySingle[T any](results []T) (*T, error) {
|
||||
rsLen := len(results)
|
||||
if rsLen == 0 {
|
||||
return nil, errors.New("记录为空")
|
||||
}
|
||||
if rsLen > 1 {
|
||||
slog.Warn("记录不唯一", "ids")
|
||||
}
|
||||
return &results[0], nil
|
||||
}
|
||||
@@ -1,91 +1,88 @@
|
||||
package report
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"proxy-server/client/core"
|
||||
"proxy-server/server/pkg/env"
|
||||
"proxy-server/server/app"
|
||||
"proxy-server/server/core"
|
||||
"proxy-server/server/env"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Online(ctx context.Context, name string) (id int32, err error) {
|
||||
func Online(name string) (err error) {
|
||||
var resp string
|
||||
resp, err = repeat(ctx, env.EndpointOnline, map[string]any{
|
||||
resp, err = call(env.EndpointOnline, map[string]any{
|
||||
"name": name,
|
||||
"version": core.Version,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
|
||||
var body struct {
|
||||
Id int32 `json:"id"`
|
||||
Id int32 `json:"id"`
|
||||
Secret string `json:"secret"`
|
||||
}
|
||||
err = json.Unmarshal([]byte(resp), &body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
|
||||
if body.Id == 0 {
|
||||
return 0, errors.New("服务注册返回 ID 有误")
|
||||
} else {
|
||||
return body.Id, nil
|
||||
}
|
||||
app.Id = body.Id
|
||||
app.PlatformSecret = body.Secret
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Offline(ctx context.Context, name string) (err error) {
|
||||
_, err = repeat(ctx, env.EndpointOffline, map[string]any{
|
||||
func Offline(name string) (err error) {
|
||||
_, err = call(env.EndpointOffline, map[string]any{
|
||||
"name": name,
|
||||
"version": core.Version,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func Assigned(ctx context.Context, id int32, edgeId int32, port uint16) (err error) {
|
||||
_, err = repeat(ctx, env.EndpointAssigned, map[string]any{
|
||||
"proxy": id,
|
||||
func Assigned(edgeId int32, port uint16) (err error) {
|
||||
_, err = call(env.EndpointAssigned, map[string]any{
|
||||
"proxy": app.Id,
|
||||
"edge": edgeId,
|
||||
"port": port,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func repeat(ctx context.Context, endpoint string, body any) (string, error) {
|
||||
func call(endpoint string, body any) (string, error) {
|
||||
bodyStr, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for {
|
||||
req, err := http.NewRequest("POST", endpoint, strings.NewReader(string(bodyStr)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Basic "+base64.RawURLEncoding.EncodeToString([]byte("proxy:proxy")))
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if resp != nil && resp.StatusCode == http.StatusOK {
|
||||
var body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(body), nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
slog.Warn("服务调用失败,五秒后重试", "err", err)
|
||||
time.Sleep(5 * time.Second)
|
||||
req, err := http.NewRequest("POST", endpoint, strings.NewReader(string(bodyStr)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var auth = base64.RawURLEncoding.EncodeToString([]byte(env.ClientId + ":" + env.ClientSecret))
|
||||
var basic = fmt.Sprintf("Basic %s", auth)
|
||||
req.Header.Set("Authorization", basic)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("请求失败,状态码:%d", resp.StatusCode)
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(respBody), nil
|
||||
}
|
||||
|
||||
@@ -2,15 +2,18 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/signal"
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/app"
|
||||
"proxy-server/server/core"
|
||||
"proxy-server/server/debug"
|
||||
"proxy-server/server/env"
|
||||
"proxy-server/server/fwd"
|
||||
"proxy-server/server/pkg/env"
|
||||
"proxy-server/server/pkg/log"
|
||||
"proxy-server/server/pkg/orm"
|
||||
g "proxy-server/server/globals"
|
||||
"proxy-server/server/log"
|
||||
"proxy-server/server/report"
|
||||
"proxy-server/server/web"
|
||||
"sync"
|
||||
@@ -24,14 +27,7 @@ import (
|
||||
_ "net/http/pprof"
|
||||
)
|
||||
|
||||
const (
|
||||
Version = 1
|
||||
RestoreMagic = 0x72
|
||||
)
|
||||
|
||||
type server struct {
|
||||
id int32
|
||||
name string
|
||||
}
|
||||
|
||||
func New() *server {
|
||||
@@ -94,22 +90,15 @@ func (s *server) Run() (err error) {
|
||||
|
||||
// 报告上线
|
||||
slog.Debug("报告服务上线")
|
||||
var reportErrCh = make(chan error, 1)
|
||||
defer close(reportErrCh)
|
||||
go func() {
|
||||
id, err := report.Online(ctx, s.name)
|
||||
if err != nil {
|
||||
reportErrCh <- err
|
||||
return
|
||||
}
|
||||
s.id = id
|
||||
}()
|
||||
err = report.Online(app.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("服务上线失败: %w", err)
|
||||
}
|
||||
|
||||
// 等待退出信号
|
||||
osQuit := make(chan os.Signal, 1)
|
||||
signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
var reportErr error
|
||||
select {
|
||||
case <-osQuit:
|
||||
slog.Info("服务主动退出")
|
||||
@@ -117,26 +106,23 @@ func (s *server) Run() (err error) {
|
||||
slog.Warn("fwd 服务异常退出", "err", err)
|
||||
case err := <-apiQuit:
|
||||
slog.Warn("web 服务异常退出", "err", err)
|
||||
case reportErr = <-reportErrCh:
|
||||
slog.Warn("报告服务上线发生错误", "err", reportErr)
|
||||
}
|
||||
cancel()
|
||||
|
||||
// 报告下线
|
||||
if reportErr == nil {
|
||||
slog.Debug("报告服务下线")
|
||||
go func() {
|
||||
err := report.Offline(ctx, s.name)
|
||||
if err != nil {
|
||||
slog.Error("报告服务下线发生错误", "err", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 等待其它服务关闭
|
||||
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 报告下线
|
||||
slog.Debug("报告服务下线")
|
||||
err = report.Offline(app.Name)
|
||||
if err != nil {
|
||||
slog.Error("服务下线失败", "err", err)
|
||||
}
|
||||
|
||||
// 关闭 redis
|
||||
g.ExitRedis()
|
||||
|
||||
// 等待其它服务关闭
|
||||
select {
|
||||
case <-utils.ChanWgWait(timeout, &wg):
|
||||
slog.Info("服务正常关闭")
|
||||
@@ -156,7 +142,7 @@ func (s *server) init() error {
|
||||
|
||||
log.Init()
|
||||
env.Init()
|
||||
orm.Init()
|
||||
g.InitRedis()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -169,31 +155,29 @@ func (s *server) restore() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(bytes) == 17 && bytes[0] == RestoreMagic {
|
||||
s.name = uuid.UUID(bytes[1:]).String()
|
||||
slog.Info("恢复服务名称", "name", s.name)
|
||||
if len(bytes) == 17 && bytes[0] == core.RestoreMagic {
|
||||
app.Name = uuid.UUID(bytes[1:]).String()
|
||||
slog.Info("恢复服务名称", "name", app.Name)
|
||||
} else {
|
||||
var u = uuid.New()
|
||||
s.name = u.String()
|
||||
app.Name = u.String()
|
||||
|
||||
bytes = make([]byte, 17)
|
||||
bytes[0] = RestoreMagic
|
||||
bytes[0] = core.RestoreMagic
|
||||
copy(bytes[1:], u[:])
|
||||
err := os.WriteFile(file, bytes, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info("生成服务名称", "name", s.name)
|
||||
slog.Info("生成服务名称", "name", app.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *server) startFwd(ctx context.Context) error {
|
||||
server := fwd.New(&fwd.Config{
|
||||
Id: &s.id,
|
||||
})
|
||||
server := fwd.New()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
server.Stop()
|
||||
|
||||
32
server/web/handlers/auth.go
Normal file
32
server/web/handlers/auth.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"proxy-server/server/app"
|
||||
"proxy-server/server/core"
|
||||
)
|
||||
|
||||
type AuthReq struct {
|
||||
Port uint16 `json:"port"`
|
||||
core.Permit
|
||||
}
|
||||
|
||||
func Auth(ctx *fiber.Ctx) (err error) {
|
||||
|
||||
// 安全验证
|
||||
var sec core.SecuredReq
|
||||
if err := ctx.BodyParser(&sec); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取请求参数
|
||||
req, err := core.Decrypt[AuthReq](&sec, app.PlatformSecret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存授权配置
|
||||
app.Permits[req.Port] = req.Permit
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"proxy-server/server/pkg/env"
|
||||
"proxy-server/server/env"
|
||||
"strconv"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
Reference in New Issue
Block a user