重构代码结构与认证体系,集成异步任务消费者

This commit is contained in:
2025-11-17 18:38:10 +08:00
parent a97c970166
commit a245229bc2
70 changed files with 2000 additions and 2334 deletions

View File

@@ -4,112 +4,101 @@ import (
"context"
"encoding/base64"
"errors"
"fmt"
"log/slog"
"platform/web/core"
client2 "platform/web/domains/client"
m "platform/web/models"
q "platform/web/queries"
"slices"
s "platform/web/services"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
)
type ProtectBuilder struct {
c *fiber.Ctx
types []PayloadType
scopes []string
func Authenticate() fiber.Handler {
return func(ctx *fiber.Ctx) error {
header := ctx.Get(fiber.HeaderAuthorization)
authCtx, err := authHeader(ctx.Context(), header)
if err != nil {
return err
}
if authCtx == nil {
authCtx = &AuthCtx{}
}
SetAuthCtx(ctx, authCtx)
return ctx.Next()
}
}
func NewProtect(c *fiber.Ctx) *ProtectBuilder {
return &ProtectBuilder{c, []PayloadType{}, []string{}}
}
func authHeader(ctx context.Context, header string) (*AuthCtx, error) {
if header == "" {
return nil, nil
}
func (p *ProtectBuilder) Payload(types ...PayloadType) *ProtectBuilder {
p.types = types
return p
}
func (p *ProtectBuilder) Scopes(scopes ...string) *ProtectBuilder {
p.scopes = scopes
return p
}
func (p *ProtectBuilder) Do() (*Context, error) {
return Protect(p.c, p.types, p.scopes)
}
func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context, error) {
// 获取令牌
var header = c.Get("Authorization")
var split = strings.Split(header, " ")
if len(split) != 2 {
slog.Debug("Authorization 头格式不正确")
return nil, ErrUnauthorize
return nil, ErrAuthenticateUnauthorize
}
var token = strings.TrimSpace(split[1])
if token == "" {
slog.Debug("提供的令牌为空")
return nil, ErrUnauthorize
return nil, ErrAuthenticateUnauthorize
}
var auth *Context
var authCtx *AuthCtx
var err error
switch split[0] {
case "Bearer":
auth, err = authBearer(c.Context(), token)
authCtx, err = authBearer(ctx, token)
if err != nil {
slog.Debug("Bearer 认证失败", "err", err)
return nil, ErrUnauthorize
return nil, ErrAuthenticateUnauthorize
}
case "Basic":
if !slices.Contains(types, PayloadInternalServer) {
slog.Debug("禁止使用 Basic 认证方式")
return nil, ErrUnauthorize
}
auth, err = authBasic(c.Context(), token)
authCtx, err = authBasic(ctx, token)
if err != nil {
slog.Debug("Basic 认证失败", "err", err)
return nil, ErrUnauthorize
return nil, ErrAuthenticateUnauthorize
}
default:
slog.Debug("无效的认证方式", "method", split[0])
return nil, ErrUnauthorize
return nil, ErrAuthenticateUnauthorize
}
// 检查权限
if !slices.Contains(types, auth.Payload.Type) {
slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type)
return nil, ErrForbidden
}
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions)
return nil, ErrForbidden
}
// 保存到上下文
Locals(c, auth)
return auth, nil
return authCtx, err
}
func Locals(c *fiber.Ctx, auth *Context) {
c.Locals("auth", auth)
}
func authBearer(ctx context.Context, token string) (*Context, error) {
auth, err := FindSession(ctx, token)
func authBearer(_ context.Context, token string) (*AuthCtx, error) {
session, err := FindSession(token, time.Now())
if err != nil {
slog.Debug(err.Error())
return nil, err
slog.Debug("Bearer 认证失败", "err", err)
return nil, ErrAuthenticateUnauthorize
}
return auth, nil
scopes := []string{}
if session.Scopes_ != nil {
scopes = strings.Split(*session.Scopes_, " ")
}
return &AuthCtx{
User: session.User,
Admin: session.Admin,
Client: session.Client,
Scopes: scopes,
Session: session,
}, nil
}
func authBasic(_ context.Context, token string) (*Context, error) {
func authBasic(_ context.Context, token string) (*AuthCtx, error) {
// 解析 Basic 认证信息
var base, err = base64.RawURLEncoding.DecodeString(token)
@@ -125,14 +114,23 @@ func authBasic(_ context.Context, token string) (*Context, error) {
return nil, errors.New("令牌格式错误,必须是 <client_id>:<client_secret> 格式")
}
var clientID = split[0]
client, err := authClient(split[0], split[1])
if err != nil {
return nil, fmt.Errorf("客户端认证失败:%w", err)
}
return &AuthCtx{
Client: client,
Scopes: []string{},
}, nil
}
func authClient(clientId, clientSecret string) (*m.Client, error) {
// 获取客户端信息
client, err := q.Client.
Where(
q.Client.ClientID.Eq(clientID),
q.Client.Spec.In(int32(client2.SpecWeb), int32(client2.SpecTrusted)),
q.Client.GrantClient.Is(true),
q.Client.ClientID.Eq(clientId),
q.Client.Status.Eq(1)).
Take()
if err != nil {
@@ -140,33 +138,57 @@ func authBasic(_ context.Context, token string) (*Context, error) {
}
// 检查客户端密钥
var clientSecret = split[1]
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
return nil, errors.New("客户端密钥错误")
spec := client2.Spec(client.Spec)
if spec == client2.SpecWeb || spec == client2.SpecApi {
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
return nil, errors.New("客户端密钥错误")
}
}
// todo 查询客户端关联权限
// 组织授权信息(一次性请求)
return &Context{
Payload: Payload{
Id: client.ID,
Type: PayloadTypeFromClientSpec(client2.Spec(client.Spec)),
Name: client.Name,
Avatar: client.Icon,
},
Permissions: nil,
Metadata: nil,
}, nil
return client, nil
}
type AuthenticationErr string
func authUserBySms(tx *q.Query, username, code string) (*m.User, error) {
// 验证验证码
err := s.Verifier.VerifySms(context.Background(), username, code)
if err != nil {
if errors.Is(err, s.ErrVerifierServiceInvalid) {
return nil, ErrAuthorizeInvalidRequest
}
return nil, err
}
func (e AuthenticationErr) Error() string {
return string(e)
// 查找用户
return tx.User.Where(tx.User.Phone.Eq(username)).Take()
}
var (
ErrUnauthorize = AuthenticationErr("令牌无效")
ErrForbidden = AuthenticationErr("没有权限")
)
func authUserByEmail(tx *q.Query, username, code string) (*m.User, error) {
return nil, core.NewServErr("邮箱登录不可用")
}
func authUserByPassword(tx *q.Query, username, password string) (*m.User, error) {
user, err := tx.User.
Where(tx.User.Phone.Eq(username)).
Or(tx.User.Email.Eq(username)).
Or(tx.User.Username.Eq(username)).
Take()
if err != nil {
slog.Debug("查找用户失败", "error", err)
return nil, core.NewBizErr("用户不存在或密码错误")
}
// 验证密码
if user.Password == nil || *user.Password == "" {
slog.Debug("用户未设置密码", "username", username)
return nil, core.NewBizErr("用户不存在或密码错误")
}
if bcrypt.CompareHashAndPassword([]byte(*user.Password), []byte(password)) != nil {
slog.Debug("密码验证失败", "username", username)
return nil, core.NewBizErr("用户不存在或密码错误")
}
return user, nil
}

View File

@@ -1,5 +1,28 @@
package auth
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"log/slog"
"platform/pkg/env"
"platform/pkg/u"
"platform/web/core"
user2 "platform/web/domains/user"
g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models"
q "platform/web/queries"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"gorm.io/gorm"
)
type GrantType string
const (
@@ -17,36 +40,352 @@ const (
GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱验证码
)
func Token(grant GrantType) error {
return nil
type TokenReq struct {
GrantType GrantType `json:"grant_type" form:"grant_type"`
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
Scope string `json:"scope" form:"scope"`
GrantCodeData
GrantClientData
GrantRefreshData
GrantPasswordData
}
func authAuthorizationCode() {
type GrantCodeData struct {
Code string `json:"code" form:"code"`
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
}
func authClientCredential() {
type GrantClientData struct {
}
func authRefreshToken() {
type GrantRefreshData struct {
RefreshToken string `json:"refresh_token" form:"refresh_token"`
}
func authPassword() {
type GrantPasswordData struct {
LoginType PasswordGrantType `json:"login_type" form:"login_type"`
Username string `json:"username" form:"username"`
Password string `json:"password" form:"password"`
Remember bool `json:"remember" form:"remember"`
}
func authPasswordSecret() {
type TokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
}
func authPasswordPhone() {
type TokenErrResp struct {
Error string `json:"error"`
Description string `json:"error_description,omitempty"`
}
func authPasswordEmail() {
func Token(c *fiber.Ctx) error {
now := time.Now()
// 验证请求参数
req := new(TokenReq)
if err := c.BodyParser(req); err != nil {
return sendError(c, ErrAuthorizeInvalidRequest, "无法解析请求参数")
}
if req.GrantType == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: grant_type")
}
switch req.GrantType {
// 授权码模式
case GrantAuthorizationCode:
if req.Code == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: code")
}
// 刷新令牌模式
case GrantRefreshToken:
if req.RefreshToken == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: refresh_token")
}
// 密码模式
case GrantPassword:
if req.LoginType == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password_type")
}
if req.Username == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: username")
}
if req.Password == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password")
}
}
// 验证客户端身份
authCtx := GetAuthCtx(c)
if authCtx == nil {
authCtx = &AuthCtx{}
}
if authCtx.Client == nil {
client, err := authClient(req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
authCtx.Client = client
}
// 处理授权
var session *m.Session
var err error
switch req.GrantType {
// 授权码模式
case GrantAuthorizationCode:
session, err = authAuthorizationCode(c, authCtx, req, now)
// 客户端凭证模式
case GrantClientCredentials:
session, err = authClientCredential(c, authCtx, req, now)
// 刷新令牌模式
case GrantRefreshToken:
session, err = authRefreshToken(c, authCtx, req, now)
// 密码模式
case GrantPassword:
session, err = authPassword(c, authCtx, req, now)
default:
return sendError(c, ErrAuthorizeUnsupportedGrantType)
}
if err != nil {
return sendError(c, err)
}
// 返回响应
return c.JSON(&TokenResp{
TokenType: "Bearer",
AccessToken: session.AccessToken,
RefreshToken: u.Z(session.RefreshToken),
ExpiresIn: int(time.Time(session.AccessTokenExpires).Sub(now).Seconds()),
Scope: u.Z(session.Scopes_),
})
}
func authAuthorizationCode(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
// 检查 code 获取用户授权信息
data, err := g.Redis.Get(context.Background(), req.Code).Result()
if err != nil {
return nil, err
}
var codeCtx CodeContext
if err := json.Unmarshal([]byte(data), &codeCtx); err != nil {
return nil, err
}
// 检查 PKCE
if codeCtx.CodeChallengeMethod != "" {
if req.CodeVerifier == "" {
return nil, ErrAuthorizeInvalidPKCE
}
switch codeCtx.CodeChallengeMethod {
case "plain":
if req.CodeVerifier != codeCtx.CodeChallenge {
return nil, ErrAuthorizeInvalidPKCE
}
case "S256":
hash := sha256.Sum256([]byte(req.CodeVerifier))
verifier := base64.RawURLEncoding.EncodeToString(hash[:])
if verifier != codeCtx.CodeChallenge {
return nil, ErrAuthorizeInvalidPKCE
}
default:
return nil, ErrAuthorizeInvalidPKCE
}
}
user, err := q.User.Where(
q.User.ID.Eq(codeCtx.UserID),
q.User.Status.Eq(int32(user2.StatusEnabled)),
).First()
if err != nil {
return nil, err
}
// todo 检查 scope
// 生成会话
session := &m.Session{
IP: u.X(ctx.IP()),
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
UserID: &user.ID,
ClientID: &auth.Client.ID,
Scopes_: u.P(strings.Join(codeCtx.Scopes, " ")),
AccessToken: uuid.NewString(),
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
}
if codeCtx.Remember {
session.RefreshToken = u.P(uuid.NewString())
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
}
err = SaveSession(session)
if err != nil {
return nil, err
}
return session, nil
}
func authClientCredential(ctx *fiber.Ctx, auth *AuthCtx, _ *TokenReq, now time.Time) (*m.Session, error) {
// todo 检查 scope
// 生成会话
session := &m.Session{
IP: u.X(ctx.IP()),
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
ClientID: &auth.Client.ID,
AccessToken: uuid.NewString(),
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
}
// 保存会话
err := SaveSession(session)
if err != nil {
return nil, err
}
return session, nil
}
func authPassword(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
var user *m.User
err := q.Q.Transaction(func(tx *q.Query) (err error) {
switch req.LoginType {
case GrantPasswordPhone:
user, err = authUserBySms(tx, req.Username, req.Password)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if user == nil {
user = &m.User{
Phone: req.Username,
Username: u.P(req.Username),
Status: int32(user2.StatusEnabled),
}
}
case GrantPasswordEmail:
user, err = authUserByEmail(tx, req.Username, req.Password)
if err != nil {
return err
}
case GrantPasswordSecret:
user, err = authUserByPassword(tx, req.Username, req.Password)
if err != nil {
return err
}
default:
return ErrAuthorizeInvalidRequest
}
// 账户状态
if user2.Status(user.Status) == user2.StatusDisabled {
slog.Debug("账户状态异常", "username", req.Username, "status", user.Status)
return core.NewBizErr("账号无法登录")
}
// 更新用户的登录时间
user.LastLogin = u.P(orm.LocalDateTime(time.Now()))
user.LastLoginHost = u.X(ctx.IP())
user.LastLoginAgent = u.X(ctx.Get(fiber.HeaderUserAgent))
if err := tx.User.Save(user); err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// 生成会话
session := &m.Session{
IP: u.X(ctx.IP()),
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
UserID: &user.ID,
ClientID: &auth.Client.ID,
Scopes_: u.X(req.Scope),
AccessToken: uuid.NewString(),
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
}
if req.Remember {
session.RefreshToken = u.P(uuid.NewString())
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
}
err = SaveSession(session)
if err != nil {
return nil, err
}
return session, nil
}
func authRefreshToken(_ *fiber.Ctx, _ *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
session, err := FindSessionByRefresh(req.RefreshToken, now)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAuthorizeInvalidGrant
}
return nil, err
}
// todo 检查权限
// 生成令牌
session.AccessToken = uuid.NewString()
session.AccessTokenExpires = orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second))
if session.RefreshToken != nil {
session.RefreshToken = u.P(uuid.NewString())
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
}
// 保存令牌
err = SaveSession(session)
if err != nil {
return nil, err
}
return session, nil
}
func sendError(c *fiber.Ctx, err error, description ...string) error {
var sErr AuthErr
if errors.As(err, &sErr) {
status := fiber.StatusBadRequest
var desc string
switch {
case errors.Is(sErr, ErrAuthorizeInvalidRequest):
desc = "无效的请求"
case errors.Is(sErr, ErrAuthorizeInvalidClient):
status = fiber.StatusUnauthorized
desc = "无效的客户端凭证"
case errors.Is(sErr, ErrAuthorizeInvalidGrant):
desc = "无效的授权凭证"
case errors.Is(sErr, ErrAuthorizeInvalidScope):
desc = "无效的授权范围"
case errors.Is(sErr, ErrAuthorizeUnauthorizedClient):
desc = "未授权的客户端"
case errors.Is(sErr, ErrAuthorizeUnsupportedGrantType):
desc = "不支持的授权类型"
}
if len(description) > 0 {
desc = description[0]
}
return c.Status(status).JSON(TokenErrResp{
Error: string(sErr),
Description: desc,
})
}
return err
}
func Revoke() error {
@@ -56,3 +395,12 @@ func Revoke() error {
func Introspect() error {
return nil
}
type CodeContext struct {
UserID int32 `json:"user_id"`
ClientID int32 `json:"client_id"`
Scopes []string `json:"scopes"`
Remember bool `json:"remember"`
CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
}

99
web/auth/check.go Normal file
View File

@@ -0,0 +1,99 @@
package auth
import (
"platform/web/domains/client"
m "platform/web/models"
"github.com/gofiber/fiber/v2"
)
type AuthCtx struct {
User *m.User `json:"account,omitempty"`
Admin *m.Admin `json:"admin,omitempty"`
Client *m.Client `json:"client,omitempty"`
Scopes []string `json:"scopes,omitempty"`
Session *m.Session `json:"session,omitempty"`
smap map[string]struct{}
}
func (a *AuthCtx) PermitUser(scopes ...string) (*AuthCtx, error) {
if a.User == nil {
return a, ErrAuthenticateForbidden
}
if !a.checkScopes(scopes...) {
return a, ErrAuthenticateForbidden
}
return a, nil
}
func (a *AuthCtx) PermitAdmin(scopes ...string) (*AuthCtx, error) {
if a.Admin == nil {
return a, ErrAuthenticateForbidden
}
if !a.checkScopes(scopes...) {
return a, ErrAuthenticateForbidden
}
return a, nil
}
func (a *AuthCtx) PermitSecretClient(scopes ...string) (*AuthCtx, error) {
if a.Client == nil {
return a, ErrAuthenticateForbidden
}
spec := client.Spec(a.Client.Spec)
if spec != client.SpecApi && spec != client.SpecWeb {
return a, ErrAuthenticateForbidden
}
if !a.checkScopes(scopes...) {
return a, ErrAuthenticateForbidden
}
return a, nil
}
func (a *AuthCtx) PermitInternalClient(scopes ...string) (*AuthCtx, error) {
if a.Client == nil {
return a, ErrAuthenticateForbidden
}
spec := client.Spec(a.Client.Spec)
if spec != client.SpecApi && spec != client.SpecWeb {
return a, ErrAuthenticateForbidden
}
cType := client.Type(a.Client.Type)
if cType != client.TypeInternal {
return a, ErrAuthenticateForbidden
}
if !a.checkScopes(scopes...) {
return a, ErrAuthenticateForbidden
}
return a, nil
}
func (a *AuthCtx) checkScopes(scopes ...string) bool {
if len(scopes) == 0 || len(a.Scopes) == 0 {
return true
}
if len(a.smap) == 0 && len(a.Scopes) > 0 {
for _, scope := range scopes {
a.smap[scope] = struct{}{}
}
}
for _, scope := range scopes {
if _, ok := a.smap[scope]; ok {
return true
}
}
return false
}
const AuthCtxKey = "session"
func SetAuthCtx(c *fiber.Ctx, auth *AuthCtx) {
c.Locals(AuthCtxKey, auth)
}
func GetAuthCtx(c *fiber.Ctx) *AuthCtx {
if authCtx, ok := c.Locals(AuthCtxKey).(*AuthCtx); ok {
return authCtx
}
return nil
}

View File

@@ -1,103 +0,0 @@
package auth
import (
client2 "platform/web/domains/client"
)
// Context 定义认证信息
type Context struct {
Payload Payload `json:"payload"`
Permissions map[string]struct{} `json:"permissions,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
func (a *Context) AnyType(types ...PayloadType) bool {
if a == nil {
return false
}
for _, t := range types {
if a.Payload.Type == t {
return true
}
}
return false
}
// AnyPermission 检查认证是否包含指定权限
func (a *Context) AnyPermission(requiredPermission ...string) bool {
if a == nil || a.Permissions == nil {
return false
}
for _, permission := range requiredPermission {
if _, ok := a.Permissions[permission]; ok {
return true
}
}
return false
}
// Payload 定义负载信息
type Payload struct {
Id int32 `json:"id,omitempty"`
Type PayloadType `json:"type,omitempty"`
Name string `json:"name,omitempty"`
Avatar *string `json:"avatar,omitempty"`
}
type PayloadType int
const (
PayloadNone PayloadType = iota // 游客
PayloadUser // 用户
PayloadAdmin // 管理员
PayloadPublicServer // 公共服务public_client
PayloadSecuredServer // 安全服务credential_client
PayloadInternalServer // 内部服务
)
func (t PayloadType) ToStr() string {
switch t {
case PayloadUser:
return "user"
case PayloadAdmin:
return "admn"
case PayloadPublicServer:
return "cpub"
case PayloadSecuredServer:
return "ccnf"
case PayloadInternalServer:
return "inte"
default:
return "none"
}
}
func PayloadTypeFromStr(name string) PayloadType {
switch name {
case "user":
return PayloadUser
case "admn":
return PayloadAdmin
case "cpub":
return PayloadPublicServer
case "ccnf":
return PayloadSecuredServer
case "inte":
return PayloadInternalServer
default:
return PayloadNone
}
}
func PayloadTypeFromClientSpec(spec client2.Spec) PayloadType {
var clientType PayloadType
switch spec {
case client2.SpecNative, client2.SpecBrowser:
clientType = PayloadPublicServer
case client2.SpecWeb:
clientType = PayloadSecuredServer
case client2.SpecTrusted:
clientType = PayloadInternalServer
}
return clientType
}

24
web/auth/errors.go Normal file
View File

@@ -0,0 +1,24 @@
package auth
type AuthErr string
func (e AuthErr) Error() string {
return string(e)
}
// 认证错误
const (
ErrAuthenticateUnauthorize = AuthErr("令牌无效")
ErrAuthenticateForbidden = AuthErr("没有权限")
)
// 授权错误
const (
ErrAuthorizeInvalidRequest = AuthErr("invalid_request")
ErrAuthorizeInvalidClient = AuthErr("invalid_client")
ErrAuthorizeInvalidGrant = AuthErr("invalid_grant")
ErrAuthorizeInvalidScope = AuthErr("invalid_scope")
ErrAuthorizeUnauthorizedClient = AuthErr("unauthorized_client")
ErrAuthorizeUnsupportedGrantType = AuthErr("unsupported_grant_type")
ErrAuthorizeInvalidPKCE = AuthErr("invalid_pkce")
)

View File

@@ -2,160 +2,36 @@ package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"platform/pkg/env"
g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models"
q "platform/web/queries"
"time"
"gorm.io/gen/field"
)
type Session struct {
// 认证主体
Payload *Payload
// 令牌信息
TokenDetails *TokenDetails
func FindSession(accessToken string, now time.Time) (*m.Session, error) {
return q.Session.
Preload(field.Associations).
Where(
q.Session.AccessToken.Eq(accessToken),
q.Session.AccessTokenExpires.Gt(orm.LocalDateTime(now)),
).First()
}
func FindSession(ctx context.Context, token string) (*Context, error) {
// 读取认证数据
authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, errors.New("invalid_token")
}
return nil, err
}
// 反序列化
auth := new(Context)
if err := json.Unmarshal([]byte(authJSON), auth); err != nil {
return nil, err
}
return auth, nil
func FindSessionByRefresh(refreshToken string, now time.Time) (*m.Session, error) {
return q.Session.
Preload(field.Associations).
Where(
q.Session.RefreshToken.Eq(refreshToken),
q.Session.RefreshTokenExpires.Gt(orm.LocalDateTime(now)),
).First()
}
func CreateSession(ctx context.Context, authCtx *Context, remember bool) (*TokenDetails, error) {
var now = time.Now()
// 生成令牌组
accessToken := genToken()
refreshToken := genToken()
// 序列化认证数据
authData, err := json.Marshal(authCtx)
if err != nil {
return nil, err
}
// 序列化刷新令牌数据
refreshData, err := json.Marshal(RefreshData{
AuthContext: authCtx,
AccessToken: accessToken,
})
if err != nil {
return nil, err
}
// 事务保存数据到 Redis
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
pipe := g.Redis.TxPipeline()
pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
if remember {
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
}
_, err = pipe.Exec(ctx)
if err != nil {
return nil, err
}
return &TokenDetails{
AccessToken: accessToken,
AccessTokenExpires: now.Add(accessExpire),
RefreshToken: refreshToken,
RefreshTokenExpires: now.Add(refreshExpire),
Auth: authCtx,
}, nil
}
func RefreshSession(ctx context.Context, refreshToken string, renew bool) (*TokenDetails, error) {
var now = time.Now()
rKey := refreshKey(refreshToken)
var tokenDetails *TokenDetails
// 刷新令牌
err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
// 先获取刷新令牌数据
refreshJson, err := tx.Get(ctx, rKey).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return ErrInvalidRefreshToken
}
return err
}
// 解析刷新令牌数据
refreshData := new(RefreshData)
if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil {
return err
}
// 生成新的令牌
newAccessToken := genToken()
newRefreshToken := genToken()
authData, err := json.Marshal(refreshData.AuthContext)
if err != nil {
return err
}
newRefreshData, err := json.Marshal(RefreshData{
AuthContext: refreshData.AuthContext,
AccessToken: newAccessToken,
})
if err != nil {
return err
}
pipeline := tx.Pipeline()
// 保存新的令牌
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
pipeline.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
// 删除旧的令牌
pipeline.Del(ctx, accessKey(refreshData.AccessToken))
pipeline.Del(ctx, refreshKey(refreshToken))
_, err = pipeline.Exec(ctx)
if err != nil {
return err
}
tokenDetails = &TokenDetails{
AccessToken: newAccessToken,
RefreshToken: newRefreshToken,
AccessTokenExpires: now.Add(accessExpire),
RefreshTokenExpires: now.Add(refreshExpire),
Auth: refreshData.AuthContext,
}
return nil
}, rKey)
if err != nil {
return nil, fmt.Errorf("刷新令牌失败: %w", err)
}
return tokenDetails, nil
func SaveSession(session *m.Session) error {
return q.Session.Save(session)
}
func RemoveSession(ctx context.Context, accessToken string, refreshToken string) error {
@@ -163,11 +39,6 @@ func RemoveSession(ctx context.Context, accessToken string, refreshToken string)
return nil
}
// 生成一个新的令牌
func genToken() string {
return uuid.NewString()
}
// 令牌键的格式为 "session:<token>"
func accessKey(token string) string {
return fmt.Sprintf("session:%s", token)
@@ -177,32 +48,3 @@ func accessKey(token string) string {
func refreshKey(token string) string {
return fmt.Sprintf("session:refresh:%s", token)
}
// TokenDetails 存储令牌详细信息
type TokenDetails struct {
// 访问令牌
AccessToken string
// 刷新令牌
RefreshToken string
// 访问令牌过期时间
AccessTokenExpires time.Time
// 刷新令牌过期时间
RefreshTokenExpires time.Time
// 认证信息
Auth *Context
}
type RefreshData struct {
AuthContext *Context
AccessToken string
}
type SessionErr string
func (e SessionErr) Error() string {
return string(e)
}
const (
ErrInvalidRefreshToken = SessionErr("无效的刷新令牌")
)