网关实现自定义接口安全检查与边缘节点连接权限验证
This commit is contained in:
@@ -2,11 +2,9 @@ package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"proxy-server/server/app"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/fwd/repo"
|
||||
"proxy-server/server/pkg/orm"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -20,7 +18,7 @@ const (
|
||||
Http = Protocol("http")
|
||||
)
|
||||
|
||||
func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
|
||||
func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.AuthContext, error) {
|
||||
|
||||
// 获取用户地址
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
@@ -32,101 +30,48 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
|
||||
// 获取服务端口
|
||||
localAddr := conn.LocalAddr().String()
|
||||
_, _localPort, err := net.SplitHostPort(localAddr)
|
||||
localPort, err := strconv.Atoi(_localPort)
|
||||
localPort, err := strconv.ParseUint(_localPort, 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
|
||||
}
|
||||
|
||||
// 查询权限记录
|
||||
slog.Debug("用户 " + remoteHost + " 请求连接到 " + _localPort)
|
||||
var channels []repo.Channel
|
||||
err = orm.DB.Find(&channels, &repo.Channel{
|
||||
AuthIp: true,
|
||||
UserAddr: remoteHost,
|
||||
NodePort: localPort,
|
||||
Protocol: string(proto),
|
||||
}).Error
|
||||
if err != nil {
|
||||
return nil, errors.New("查询用户权限失败")
|
||||
}
|
||||
// 记录应该只有一条
|
||||
channel, err := orm.MaySingle(channels)
|
||||
if err != nil {
|
||||
return nil, errors.New("不在白名单内")
|
||||
// 查找权限配置
|
||||
var permit, ok = app.Permits[uint16(localPort)]
|
||||
if !ok {
|
||||
return nil, errors.New("没有权限")
|
||||
}
|
||||
|
||||
// 检查是否需要密码认证
|
||||
if channel.AuthPass {
|
||||
return nil, errors.New("需要密码认证")
|
||||
}
|
||||
|
||||
// 检查权限是否过期
|
||||
timeout := channel.Expiration.Sub(time.Now()).Seconds()
|
||||
if timeout <= 0 {
|
||||
// 检查是否过期
|
||||
if permit.Expire.Before(time.Now()) {
|
||||
return nil, errors.New("权限已过期")
|
||||
}
|
||||
|
||||
return &core.AuthContext{
|
||||
Timeout: timeout,
|
||||
Payload: core.Payload{
|
||||
ID: channel.UserId,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.AuthContext, error) {
|
||||
|
||||
// 获取服务端口
|
||||
localAddr := conn.LocalAddr().String()
|
||||
_, _localPort, err := net.SplitHostPort(localAddr)
|
||||
localPort, err := strconv.Atoi(_localPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
|
||||
}
|
||||
|
||||
// 查询权限记录
|
||||
var channel repo.Channel
|
||||
err = orm.DB.Take(&channel, &repo.Channel{
|
||||
AuthPass: true,
|
||||
Username: username,
|
||||
NodePort: localPort,
|
||||
Protocol: string(proto),
|
||||
}).Error
|
||||
if err != nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 检查密码 todo 哈希
|
||||
if channel.Password != password {
|
||||
return nil, errors.New("密码错误")
|
||||
}
|
||||
|
||||
// 如果用户设置了双验证则检查 ip 是否在白名单中
|
||||
if channel.AuthIp {
|
||||
|
||||
// 获取用户地址
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
remoteHost, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无法获取连接信息: %w", err)
|
||||
// 检查 IP 是否可用
|
||||
if len(permit.Whitelists) > 0 {
|
||||
var found = false
|
||||
for _, allowedHost := range permit.Whitelists {
|
||||
var allowed = net.ParseIP(allowedHost)
|
||||
var remote = net.ParseIP(remoteHost)
|
||||
if remote.Equal(allowed) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 查询权限记录
|
||||
if channel.UserAddr != remoteHost {
|
||||
if !found {
|
||||
return nil, errors.New("不在白名单内")
|
||||
}
|
||||
}
|
||||
|
||||
// 检查权限是否过期
|
||||
timeout := channel.Expiration.Sub(time.Now()).Seconds()
|
||||
if timeout <= 0 {
|
||||
return nil, errors.New("权限已过期")
|
||||
if username != nil && password != nil {
|
||||
if *username != permit.Username || *password != permit.Password {
|
||||
return nil, errors.New("用户名或密码错误")
|
||||
}
|
||||
}
|
||||
|
||||
return &core.AuthContext{
|
||||
Timeout: timeout,
|
||||
Timeout: time.Since(permit.Expire).Seconds(),
|
||||
Payload: core.Payload{
|
||||
ID: channel.UserId,
|
||||
ID: app.Assigns[uint16(localPort)],
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"net"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/pkg/orm"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type MockConn struct {
|
||||
local net.Addr
|
||||
remote net.Addr
|
||||
}
|
||||
|
||||
func (m MockConn) LocalAddr() net.Addr {
|
||||
return m.local
|
||||
}
|
||||
|
||||
func (m MockConn) RemoteAddr() net.Addr {
|
||||
return m.remote
|
||||
}
|
||||
|
||||
func (m MockConn) Read(b []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m MockConn) Write(b []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m MockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MockConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestCheckIp(t *testing.T) {
|
||||
|
||||
orm.InitForTest(
|
||||
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
|
||||
)
|
||||
|
||||
type args struct {
|
||||
conn net.Conn
|
||||
proto Protocol
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *core.AuthContext
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "test-ok",
|
||||
args: args{
|
||||
conn: MockConn{
|
||||
remote: &net.TCPAddr{
|
||||
IP: []byte{127, 0, 0, 1},
|
||||
Port: 12345,
|
||||
},
|
||||
local: &net.TCPAddr{
|
||||
IP: []byte{127, 0, 0, 1},
|
||||
Port: 20001,
|
||||
},
|
||||
},
|
||||
proto: Http,
|
||||
},
|
||||
want: &core.AuthContext{
|
||||
Payload: core.Payload{ID: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := CheckIp(tt.args.conn, tt.args.proto)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckIp() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
|
||||
t.Errorf("CheckIp() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPass(t *testing.T) {
|
||||
|
||||
orm.InitForTest(
|
||||
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
|
||||
)
|
||||
|
||||
type args struct {
|
||||
conn net.Conn
|
||||
username string
|
||||
password string
|
||||
proto Protocol
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *core.AuthContext
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "test-ok",
|
||||
args: args{
|
||||
conn: MockConn{
|
||||
remote: &net.TCPAddr{
|
||||
IP: []byte{127, 0, 0, 1},
|
||||
Port: 12345,
|
||||
},
|
||||
local: &net.TCPAddr{
|
||||
IP: []byte{127, 0, 0, 1},
|
||||
Port: 20001,
|
||||
},
|
||||
},
|
||||
proto: Http,
|
||||
username: "ip4http",
|
||||
password: "asdf",
|
||||
},
|
||||
want: &core.AuthContext{
|
||||
Payload: core.Payload{ID: 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := CheckPass(tt.args.conn, tt.args.proto, tt.args.username, tt.args.password)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("CheckPass() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
|
||||
t.Errorf("CheckPass() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user