重构迁移核心数据结构到认证模块;完善中间件初始化逻辑以及 logger 记录过程
This commit is contained in:
@@ -3,6 +3,7 @@ package services
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"platform/web/auth"
|
||||
"platform/web/core"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
@@ -24,14 +25,14 @@ func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Clie
|
||||
// OauthClientCredentials 验证客户端凭证
|
||||
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) {
|
||||
|
||||
var clientType PayloadType
|
||||
var clientType auth.PayloadType
|
||||
switch client.Spec {
|
||||
case 1:
|
||||
clientType = PayloadClientPublic
|
||||
clientType = auth.PayloadClientPublic
|
||||
case 2:
|
||||
clientType = PayloadClientPublic
|
||||
clientType = auth.PayloadClientPublic
|
||||
case 3:
|
||||
clientType = PayloadClientConfidential
|
||||
clientType = auth.PayloadClientConfidential
|
||||
}
|
||||
|
||||
var permissions = make(map[string]struct{}, len(scope))
|
||||
@@ -40,9 +41,9 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
|
||||
}
|
||||
|
||||
// 保存会话并返回令牌
|
||||
auth := AuthContext{
|
||||
authCtx := auth.Context{
|
||||
Permissions: permissions,
|
||||
Payload: Payload{
|
||||
Payload: auth.Payload{
|
||||
Id: client.ID,
|
||||
Type: clientType,
|
||||
Name: client.Name,
|
||||
@@ -50,7 +51,7 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
|
||||
}
|
||||
|
||||
// todo 数据库定义会话持续时间
|
||||
token, err := Session.Create(ctx, auth, false)
|
||||
token, err := Session.Create(ctx, authCtx, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -136,16 +137,16 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
|
||||
}
|
||||
|
||||
// 保存到会话
|
||||
auth := AuthContext{
|
||||
Payload: Payload{
|
||||
authCtx := auth.Context{
|
||||
Payload: auth.Payload{
|
||||
Id: user.ID,
|
||||
Type: PayloadUser,
|
||||
Type: auth.PayloadUser,
|
||||
Name: user.Name,
|
||||
Avatar: user.Avatar,
|
||||
},
|
||||
}
|
||||
|
||||
token, err := Session.Create(ctx, auth, data.Remember)
|
||||
token, err := Session.Create(ctx, authCtx, data.Remember)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"platform/pkg/orm"
|
||||
"platform/pkg/rds"
|
||||
"platform/pkg/u"
|
||||
"platform/web/auth"
|
||||
"platform/web/core"
|
||||
g "platform/web/globals"
|
||||
"platform/web/models"
|
||||
@@ -64,7 +65,7 @@ type ResourceInfo struct {
|
||||
|
||||
// region RemoveChannel
|
||||
|
||||
func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, id ...int32) error {
|
||||
func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Context, id ...int32) error {
|
||||
var step = time.Now()
|
||||
var rid = ctx.Value(requestid.ConfigDefault.ContextKey).(string)
|
||||
|
||||
@@ -82,8 +83,8 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
|
||||
|
||||
// 检查权限,如果为用户操作的话,则只能删除自己的通道
|
||||
for _, channel := range channels {
|
||||
if auth.Payload.Type == PayloadUser && auth.Payload.Id != channel.UserID {
|
||||
return core.AuthForbiddenErr("无权限访问")
|
||||
if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID {
|
||||
return core.ForbiddenErr("无权限访问")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,7 +239,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
|
||||
|
||||
func (s *channelService) CreateChannel(
|
||||
ctx context.Context,
|
||||
auth *AuthContext,
|
||||
authCtx *auth.Context,
|
||||
resourceId int32,
|
||||
protocol ChannelProtocol,
|
||||
authType ChannelAuthType,
|
||||
@@ -283,7 +284,7 @@ func (s *channelService) CreateChannel(
|
||||
slog.Debug("查找套餐", "rid", rid, "step", time.Since(step))
|
||||
|
||||
// 检查用户权限
|
||||
err = checkUser(auth, resource, count)
|
||||
err = checkUser(authCtx, resource, count)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -302,7 +303,7 @@ func (s *channelService) CreateChannel(
|
||||
step = time.Now()
|
||||
|
||||
expiration := core.LocalDateTime(now.Add(time.Duration(resource.Live) * time.Second))
|
||||
_addr, channels, err := assignPort(q, edgeAssigns, auth.Payload.Id, protocol, authType, expiration, filter)
|
||||
_addr, channels, err := assignPort(q, edgeAssigns, authCtx.Payload.Id, protocol, authType, expiration, filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -356,11 +357,11 @@ func (s *channelService) CreateChannel(
|
||||
return addr, nil
|
||||
}
|
||||
|
||||
func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error {
|
||||
func checkUser(authCtx *auth.Context, resource *ResourceInfo, count int) error {
|
||||
|
||||
// 检查使用人
|
||||
if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.UserId {
|
||||
return core.AuthForbiddenErr("无权限访问")
|
||||
if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != resource.UserId {
|
||||
return core.ForbiddenErr("无权限访问")
|
||||
}
|
||||
|
||||
// 检查套餐状态
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"platform/pkg/testutil"
|
||||
"platform/web/auth"
|
||||
"platform/web/core"
|
||||
g "platform/web/globals"
|
||||
"platform/web/models"
|
||||
@@ -276,7 +277,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
auth *AuthContext
|
||||
auth *auth.Context
|
||||
resourceId int32
|
||||
protocol ChannelProtocol
|
||||
authType ChannelAuthType
|
||||
@@ -286,8 +287,8 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
|
||||
// 准备测试数据
|
||||
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
|
||||
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
|
||||
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
|
||||
var adminAuth = &auth.Context{Payload: auth.Payload{Id: 100, Type: auth.PayloadAdmin}}
|
||||
var userAuth = &auth.Context{Payload: auth.Payload{Id: 101, Type: auth.PayloadUser}}
|
||||
mc.AutoQueryMock = func() (g.CloudConnectResp, error) {
|
||||
return g.CloudConnectResp{
|
||||
"test-proxy": []g.AutoConfig{
|
||||
@@ -967,7 +968,7 @@ func Test_channelService_RemoveChannels(t *testing.T) {
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
auth *AuthContext
|
||||
auth *auth.Context
|
||||
id []int32
|
||||
}
|
||||
|
||||
@@ -989,8 +990,8 @@ func Test_channelService_RemoveChannels(t *testing.T) {
|
||||
md.Create(adminUser)
|
||||
|
||||
// 认证上下文
|
||||
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
|
||||
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
|
||||
var adminAuth = &auth.Context{Payload: auth.Payload{Id: 100, Type: auth.PayloadAdmin}}
|
||||
var userAuth = &auth.Context{Payload: auth.Payload{Id: 101, Type: auth.PayloadUser}}
|
||||
|
||||
// 创建代理
|
||||
var proxy = &models.Proxy{
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"platform/pkg/env"
|
||||
"platform/pkg/rds"
|
||||
"platform/web/auth"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -19,9 +20,9 @@ var Session SessionServiceInter = &sessionService{}
|
||||
|
||||
type SessionServiceInter interface {
|
||||
// Find 通过访问令牌获取会话信息
|
||||
Find(ctx context.Context, token string) (*AuthContext, error)
|
||||
Find(ctx context.Context, token string) (*auth.Context, error)
|
||||
// Create 创建一个新的会话
|
||||
Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error)
|
||||
Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error)
|
||||
// Refresh 刷新一个会话
|
||||
Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error)
|
||||
// Remove 删除会话
|
||||
@@ -41,7 +42,7 @@ var (
|
||||
type sessionService struct{}
|
||||
|
||||
// Find 通过访问令牌获取会话信息
|
||||
func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, error) {
|
||||
func (s *sessionService) Find(ctx context.Context, token string) (*auth.Context, error) {
|
||||
|
||||
// 读取认证数据
|
||||
authJSON, err := rds.Client.Get(ctx, accessKey(token)).Result()
|
||||
@@ -53,16 +54,16 @@ func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext,
|
||||
}
|
||||
|
||||
// 反序列化
|
||||
auth := new(AuthContext)
|
||||
if err := json.Unmarshal([]byte(authJSON), auth); err != nil {
|
||||
authCtx := new(auth.Context)
|
||||
if err := json.Unmarshal([]byte(authJSON), authCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
return authCtx, nil
|
||||
}
|
||||
|
||||
// Create 创建一个新的会话
|
||||
func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) {
|
||||
func (s *sessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) {
|
||||
var now = time.Now()
|
||||
|
||||
// 生成令牌组
|
||||
@@ -70,14 +71,14 @@ func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember
|
||||
refreshToken := genToken()
|
||||
|
||||
// 序列化认证数据
|
||||
authData, err := json.Marshal(auth)
|
||||
authData, err := json.Marshal(authCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 序列化刷新令牌数据
|
||||
refreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: auth,
|
||||
AuthContext: authCtx,
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -103,7 +104,7 @@ func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember
|
||||
AccessTokenExpires: now.Add(accessExpire),
|
||||
RefreshToken: refreshToken,
|
||||
RefreshTokenExpires: now.Add(refreshExpire),
|
||||
Auth: auth,
|
||||
Auth: authCtx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -205,74 +206,8 @@ func refreshKey(token string) string {
|
||||
|
||||
// endregion
|
||||
|
||||
// region AuthContext
|
||||
|
||||
// AuthContext 定义认证信息
|
||||
type AuthContext struct {
|
||||
Payload Payload `json:"payload"`
|
||||
Agent Agent `json:"agent,omitempty"`
|
||||
Permissions map[string]struct{} `json:"permissions,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// Payload 定义负载信息
|
||||
type Payload struct {
|
||||
Id int32 `json:"id,omitempty"`
|
||||
Type PayloadType `json:"type,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Avatar string `json:"avatar,omitempty"`
|
||||
}
|
||||
|
||||
// PayloadType 定义负载类型
|
||||
type PayloadType int
|
||||
|
||||
const (
|
||||
// PayloadUser 用户类型
|
||||
PayloadUser PayloadType = iota
|
||||
// PayloadAdmin 管理员类型
|
||||
PayloadAdmin
|
||||
// PayloadClientPublic 公共客户端类型
|
||||
PayloadClientPublic
|
||||
// PayloadClientConfidential 机密客户端类型
|
||||
PayloadClientConfidential
|
||||
)
|
||||
|
||||
func (t PayloadType) Name() string {
|
||||
switch t {
|
||||
case PayloadUser:
|
||||
return "user"
|
||||
case PayloadAdmin:
|
||||
return "admn"
|
||||
case PayloadClientPublic:
|
||||
return "cpub"
|
||||
case PayloadClientConfidential:
|
||||
return "ccnf"
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
type Agent struct {
|
||||
Id int32 `json:"id,omitempty"`
|
||||
Addr string `json:"addr,omitempty"`
|
||||
}
|
||||
|
||||
// AnyPermission 检查认证是否包含指定权限
|
||||
func (a *AuthContext) AnyPermission(requiredPermission ...string) bool {
|
||||
if a == nil || a.Permissions == nil {
|
||||
return false
|
||||
}
|
||||
for _, permission := range requiredPermission {
|
||||
if _, ok := a.Permissions[permission]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
type RefreshData struct {
|
||||
AuthContext AuthContext
|
||||
AuthContext auth.Context
|
||||
AccessToken string
|
||||
}
|
||||
|
||||
@@ -287,5 +222,5 @@ type TokenDetails struct {
|
||||
// 刷新令牌过期时间
|
||||
RefreshTokenExpires time.Time
|
||||
// 认证信息
|
||||
Auth AuthContext
|
||||
Auth auth.Context
|
||||
}
|
||||
|
||||
@@ -4,17 +4,18 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"platform/pkg/testutil"
|
||||
"platform/web/auth"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 创建测试用的认证上下文
|
||||
func createTestAuthContext() AuthContext {
|
||||
func createTestAuthContext() auth.Context {
|
||||
//goland:noinspection ALL
|
||||
return AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
return auth.Context{
|
||||
Payload: auth.Payload{
|
||||
Type: auth.PayloadUser,
|
||||
Id: 1001,
|
||||
},
|
||||
Permissions: map[string]struct{}{
|
||||
@@ -31,11 +32,11 @@ func createTestAuthContext() AuthContext {
|
||||
func Test_sessionService_Create(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
ctx := context.Background()
|
||||
auth := createTestAuthContext()
|
||||
authCtx := createTestAuthContext()
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
auth AuthContext
|
||||
auth auth.Context
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -47,7 +48,7 @@ func Test_sessionService_Create(t *testing.T) {
|
||||
name: "创建会话",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
auth: authCtx,
|
||||
},
|
||||
want: func(td *TokenDetails) bool {
|
||||
// 验证令牌存在且格式正确
|
||||
@@ -60,7 +61,7 @@ func Test_sessionService_Create(t *testing.T) {
|
||||
return false
|
||||
}
|
||||
// 验证认证信息正确
|
||||
if !reflect.DeepEqual(td.Auth, auth) {
|
||||
if !reflect.DeepEqual(td.Auth, authCtx) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -100,11 +101,11 @@ func Test_sessionService_Create(t *testing.T) {
|
||||
func Test_sessionService_Find(t *testing.T) {
|
||||
testutil.SetupRedisTest(t)
|
||||
ctx := context.Background()
|
||||
auth := createTestAuthContext()
|
||||
authCtx := createTestAuthContext()
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个有效的会话
|
||||
td, err := s.Create(ctx, auth, true)
|
||||
td, err := s.Create(ctx, authCtx, true)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建测试会话: %v", err)
|
||||
}
|
||||
@@ -119,7 +120,7 @@ func Test_sessionService_Find(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *AuthContext
|
||||
want *auth.Context
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
@@ -128,7 +129,7 @@ func Test_sessionService_Find(t *testing.T) {
|
||||
ctx: ctx,
|
||||
token: validToken,
|
||||
},
|
||||
want: &auth,
|
||||
want: &authCtx,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
@@ -159,11 +160,11 @@ func Test_sessionService_Find(t *testing.T) {
|
||||
func Test_sessionService_Refresh(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
ctx := context.Background()
|
||||
auth := createTestAuthContext()
|
||||
authCtx := createTestAuthContext()
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个初始会话
|
||||
td, err := s.Create(ctx, auth, true)
|
||||
td, err := s.Create(ctx, authCtx, true)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建初始会话: %v", err)
|
||||
}
|
||||
@@ -197,7 +198,7 @@ func Test_sessionService_Refresh(t *testing.T) {
|
||||
return false
|
||||
}
|
||||
// 验证认证信息一致
|
||||
if !reflect.DeepEqual(td.Auth, auth) {
|
||||
if !reflect.DeepEqual(td.Auth, authCtx) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@@ -251,11 +252,11 @@ func Test_sessionService_Refresh(t *testing.T) {
|
||||
func Test_sessionService_Remove(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
ctx := context.Background()
|
||||
auth := createTestAuthContext()
|
||||
authCtx := createTestAuthContext()
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个会话
|
||||
td, err := s.Create(ctx, auth, true)
|
||||
td, err := s.Create(ctx, authCtx, true)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建测试会话: %v", err)
|
||||
}
|
||||
@@ -312,7 +313,7 @@ func Test_sessionService_Remove(t *testing.T) {
|
||||
|
||||
func TestAuthContext_AnyPermission(t *testing.T) {
|
||||
type fields struct {
|
||||
Payload Payload
|
||||
Payload auth.Payload
|
||||
Permissions map[string]struct{}
|
||||
Metadata map[string]interface{}
|
||||
}
|
||||
@@ -328,7 +329,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
|
||||
{
|
||||
name: "用户拥有所需权限",
|
||||
fields: fields{
|
||||
Payload: Payload{Type: PayloadUser, Id: 1},
|
||||
Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
"write": {},
|
||||
@@ -343,7 +344,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
|
||||
{
|
||||
name: "用户拥有至少一个所需权限",
|
||||
fields: fields{
|
||||
Payload: Payload{Type: PayloadUser, Id: 1},
|
||||
Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
},
|
||||
@@ -357,7 +358,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
|
||||
{
|
||||
name: "用户没有所需权限",
|
||||
fields: fields{
|
||||
Payload: Payload{Type: PayloadUser, Id: 1},
|
||||
Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
},
|
||||
@@ -371,7 +372,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
|
||||
{
|
||||
name: "空权限列表",
|
||||
fields: fields{
|
||||
Payload: Payload{Type: PayloadUser, Id: 1},
|
||||
Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{},
|
||||
Metadata: nil,
|
||||
},
|
||||
@@ -383,7 +384,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
|
||||
{
|
||||
name: "nil权限列表",
|
||||
fields: fields{
|
||||
Payload: Payload{Type: PayloadUser, Id: 1},
|
||||
Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
|
||||
Permissions: nil,
|
||||
Metadata: nil,
|
||||
},
|
||||
@@ -395,7 +396,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
|
||||
{
|
||||
name: "nil认证上下文",
|
||||
fields: fields{
|
||||
Payload: Payload{},
|
||||
Payload: auth.Payload{},
|
||||
Permissions: nil,
|
||||
Metadata: nil,
|
||||
},
|
||||
@@ -408,7 +409,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &AuthContext{
|
||||
a := &auth.Context{
|
||||
Payload: tt.fields.Payload,
|
||||
Permissions: tt.fields.Permissions,
|
||||
Metadata: tt.fields.Metadata,
|
||||
|
||||
Reference in New Issue
Block a user