重构认证授权逻辑,集中到 auth 包中

This commit is contained in:
2025-05-12 10:07:12 +08:00
parent cfdee98a1b
commit 2c37dcc2be
40 changed files with 905 additions and 455 deletions

View File

@@ -77,7 +77,7 @@ func Locals(c *fiber.Ctx, auth *Context) {
}
func authBearer(ctx context.Context, token string) (*Context, error) {
auth, err := find(ctx, token)
auth, err := FindSession(ctx, token)
if err != nil {
slog.Debug(err.Error())
return nil, err

View File

@@ -16,3 +16,43 @@ const (
GrantPasswordPhone = PasswordGrantType("phone_code") // 手机号模式
GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱模式
)
func Token(grant GrantType) error {
return nil
}
func authAuthorizationCode() {
}
func authClientCredential() {
}
func authRefreshToken() {
}
func authPassword() {
}
func authPasswordSecret() {
}
func authPasswordPhone() {
}
func authPasswordEmail() {
}
func Revoke() error {
return nil
}
func Introspect() error {
return nil
}

View File

@@ -1,13 +1,28 @@
package auth
import (
client2 "platform/web/domains/client"
)
// Context 定义认证信息
type Context struct {
Payload Payload `json:"payload"`
Agent Agent `json:"agent,omitempty"`
Permissions map[string]struct{} `json:"permissions,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
func (a *Context) AnyType(types ...PayloadType) bool {
if a == nil {
return false
}
for _, t := range types {
if a.Payload.Type == t {
return true
}
}
return false
}
// AnyPermission 检查认证是否包含指定权限
func (a *Context) AnyPermission(requiredPermission ...string) bool {
if a == nil || a.Permissions == nil {
@@ -29,26 +44,15 @@ type Payload struct {
Avatar string `json:"avatar,omitempty"`
}
type Agent struct {
Id int32 `json:"id,omitempty"`
Addr string `json:"addr,omitempty"`
}
type PayloadType int
const (
// PayloadNone 游客
PayloadNone PayloadType = iota
// PayloadUser 用户
PayloadUser
// PayloadAdmin 管理员
PayloadAdmin
// PayloadPublicServer 公共服务public_client
PayloadPublicServer
// PayloadSecuredServer 安全服务credential_client
PayloadSecuredServer
// PayloadInternalServer 内部服务
PayloadInternalServer
PayloadNone PayloadType = iota // 游客
PayloadUser // 用户
PayloadAdmin // 管理员
PayloadPublicServer // 公共服务public_client
PayloadSecuredServer // 安全服务credential_client
PayloadInternalServer // 内部服务
)
func (t PayloadType) ToStr() string {
@@ -80,3 +84,16 @@ func PayloadTypeFromStr(name string) PayloadType {
return PayloadNone
}
}
func PayloadTypeFromClientSpec(spec client2.Spec) PayloadType {
var clientType PayloadType
switch spec {
case client2.SpecNative, client2.SpecBrowser:
clientType = PayloadPublicServer
case client2.SpecWeb:
clientType = PayloadSecuredServer
case client2.SpecTrusted:
clientType = PayloadInternalServer
}
return clientType
}

View File

@@ -5,11 +5,21 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"platform/pkg/env"
g "platform/web/globals"
"time"
)
func find(ctx context.Context, token string) (*Context, error) {
type Session struct {
// 认证主体
Payload *Payload
// 令牌信息
TokenDetails *TokenDetails
}
func FindSession(ctx context.Context, token string) (*Context, error) {
// 读取认证数据
authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
@@ -29,6 +39,170 @@ func find(ctx context.Context, token string) (*Context, error) {
return auth, nil
}
func CreateSession(ctx context.Context, authCtx *Context, remember bool) (*TokenDetails, error) {
var now = time.Now()
// 生成令牌组
accessToken := genToken()
refreshToken := genToken()
// 序列化认证数据
authData, err := json.Marshal(authCtx)
if err != nil {
return nil, err
}
// 序列化刷新令牌数据
refreshData, err := json.Marshal(RefreshData{
AuthContext: authCtx,
AccessToken: accessToken,
})
if err != nil {
return nil, err
}
// 事务保存数据到 Redis
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
pipe := g.Redis.TxPipeline()
pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
if remember {
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
}
_, err = pipe.Exec(ctx)
if err != nil {
return nil, err
}
return &TokenDetails{
AccessToken: accessToken,
AccessTokenExpires: now.Add(accessExpire),
RefreshToken: refreshToken,
RefreshTokenExpires: now.Add(refreshExpire),
Auth: authCtx,
}, nil
}
func RefreshSession(ctx context.Context, refreshToken string, renew bool) (*TokenDetails, error) {
var now = time.Now()
rKey := refreshKey(refreshToken)
var tokenDetails *TokenDetails
// 刷新令牌
err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
// 先获取刷新令牌数据
refreshJson, err := tx.Get(ctx, rKey).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return ErrInvalidRefreshToken
}
return err
}
// 解析刷新令牌数据
refreshData := new(RefreshData)
if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil {
return err
}
// 生成新的令牌
newAccessToken := genToken()
newRefreshToken := genToken()
authData, err := json.Marshal(refreshData.AuthContext)
if err != nil {
return err
}
newRefreshData, err := json.Marshal(RefreshData{
AuthContext: refreshData.AuthContext,
AccessToken: newAccessToken,
})
if err != nil {
return err
}
pipeline := tx.Pipeline()
// 保存新的令牌
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
pipeline.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
// 删除旧的令牌
pipeline.Del(ctx, accessKey(refreshData.AccessToken))
pipeline.Del(ctx, refreshKey(refreshToken))
_, err = pipeline.Exec(ctx)
if err != nil {
return err
}
tokenDetails = &TokenDetails{
AccessToken: newAccessToken,
RefreshToken: newRefreshToken,
AccessTokenExpires: now.Add(accessExpire),
RefreshTokenExpires: now.Add(refreshExpire),
Auth: refreshData.AuthContext,
}
return nil
}, rKey)
if err != nil {
return nil, fmt.Errorf("刷新令牌失败: %w", err)
}
return tokenDetails, nil
}
func RemoveSession(ctx context.Context, accessToken string, refreshToken string) error {
g.Redis.Del(ctx, accessKey(accessToken), refreshKey(refreshToken))
return nil
}
// 生成一个新的令牌
func genToken() string {
return uuid.NewString()
}
// 令牌键的格式为 "session:<token>"
func accessKey(token string) string {
return fmt.Sprintf("session:%s", token)
}
// 刷新令牌键的格式为 "session:refreshKey:<token>"
func refreshKey(token string) string {
return fmt.Sprintf("session:refresh:%s", token)
}
// TokenDetails 存储令牌详细信息
type TokenDetails struct {
// 访问令牌
AccessToken string
// 刷新令牌
RefreshToken string
// 访问令牌过期时间
AccessTokenExpires time.Time
// 刷新令牌过期时间
RefreshTokenExpires time.Time
// 认证信息
Auth *Context
}
type RefreshData struct {
AuthContext *Context
AccessToken string
}
type SessionErr string
func (e SessionErr) Error() string {
return string(e)
}
const (
ErrInvalidRefreshToken = SessionErr("无效的刷新令牌")
)