重构代码结构与认证体系,集成异步任务消费者
This commit is contained in:
@@ -4,112 +4,101 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"platform/web/core"
|
||||
client2 "platform/web/domains/client"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"slices"
|
||||
s "platform/web/services"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type ProtectBuilder struct {
|
||||
c *fiber.Ctx
|
||||
types []PayloadType
|
||||
scopes []string
|
||||
func Authenticate() fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
header := ctx.Get(fiber.HeaderAuthorization)
|
||||
authCtx, err := authHeader(ctx.Context(), header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if authCtx == nil {
|
||||
authCtx = &AuthCtx{}
|
||||
}
|
||||
|
||||
SetAuthCtx(ctx, authCtx)
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func NewProtect(c *fiber.Ctx) *ProtectBuilder {
|
||||
return &ProtectBuilder{c, []PayloadType{}, []string{}}
|
||||
}
|
||||
func authHeader(ctx context.Context, header string) (*AuthCtx, error) {
|
||||
if header == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *ProtectBuilder) Payload(types ...PayloadType) *ProtectBuilder {
|
||||
p.types = types
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ProtectBuilder) Scopes(scopes ...string) *ProtectBuilder {
|
||||
p.scopes = scopes
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ProtectBuilder) Do() (*Context, error) {
|
||||
return Protect(p.c, p.types, p.scopes)
|
||||
}
|
||||
|
||||
func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context, error) {
|
||||
// 获取令牌
|
||||
var header = c.Get("Authorization")
|
||||
var split = strings.Split(header, " ")
|
||||
if len(split) != 2 {
|
||||
slog.Debug("Authorization 头格式不正确")
|
||||
return nil, ErrUnauthorize
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
|
||||
var token = strings.TrimSpace(split[1])
|
||||
if token == "" {
|
||||
slog.Debug("提供的令牌为空")
|
||||
return nil, ErrUnauthorize
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
|
||||
var auth *Context
|
||||
var authCtx *AuthCtx
|
||||
var err error
|
||||
switch split[0] {
|
||||
|
||||
case "Bearer":
|
||||
auth, err = authBearer(c.Context(), token)
|
||||
authCtx, err = authBearer(ctx, token)
|
||||
if err != nil {
|
||||
slog.Debug("Bearer 认证失败", "err", err)
|
||||
return nil, ErrUnauthorize
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
|
||||
case "Basic":
|
||||
if !slices.Contains(types, PayloadInternalServer) {
|
||||
slog.Debug("禁止使用 Basic 认证方式")
|
||||
return nil, ErrUnauthorize
|
||||
}
|
||||
auth, err = authBasic(c.Context(), token)
|
||||
authCtx, err = authBasic(ctx, token)
|
||||
if err != nil {
|
||||
slog.Debug("Basic 认证失败", "err", err)
|
||||
return nil, ErrUnauthorize
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
|
||||
default:
|
||||
slog.Debug("无效的认证方式", "method", split[0])
|
||||
return nil, ErrUnauthorize
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
|
||||
// 检查权限
|
||||
if !slices.Contains(types, auth.Payload.Type) {
|
||||
slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type)
|
||||
return nil, ErrForbidden
|
||||
}
|
||||
|
||||
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
|
||||
slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions)
|
||||
return nil, ErrForbidden
|
||||
}
|
||||
|
||||
// 保存到上下文
|
||||
Locals(c, auth)
|
||||
return auth, nil
|
||||
return authCtx, err
|
||||
}
|
||||
|
||||
func Locals(c *fiber.Ctx, auth *Context) {
|
||||
c.Locals("auth", auth)
|
||||
}
|
||||
|
||||
func authBearer(ctx context.Context, token string) (*Context, error) {
|
||||
auth, err := FindSession(ctx, token)
|
||||
func authBearer(_ context.Context, token string) (*AuthCtx, error) {
|
||||
session, err := FindSession(token, time.Now())
|
||||
if err != nil {
|
||||
slog.Debug(err.Error())
|
||||
return nil, err
|
||||
slog.Debug("Bearer 认证失败", "err", err)
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
return auth, nil
|
||||
|
||||
scopes := []string{}
|
||||
if session.Scopes_ != nil {
|
||||
scopes = strings.Split(*session.Scopes_, " ")
|
||||
}
|
||||
return &AuthCtx{
|
||||
User: session.User,
|
||||
Admin: session.Admin,
|
||||
Client: session.Client,
|
||||
Scopes: scopes,
|
||||
Session: session,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authBasic(_ context.Context, token string) (*Context, error) {
|
||||
func authBasic(_ context.Context, token string) (*AuthCtx, error) {
|
||||
|
||||
// 解析 Basic 认证信息
|
||||
var base, err = base64.RawURLEncoding.DecodeString(token)
|
||||
@@ -125,14 +114,23 @@ func authBasic(_ context.Context, token string) (*Context, error) {
|
||||
return nil, errors.New("令牌格式错误,必须是 <client_id>:<client_secret> 格式")
|
||||
}
|
||||
|
||||
var clientID = split[0]
|
||||
client, err := authClient(split[0], split[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("客户端认证失败:%w", err)
|
||||
}
|
||||
|
||||
return &AuthCtx{
|
||||
Client: client,
|
||||
Scopes: []string{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authClient(clientId, clientSecret string) (*m.Client, error) {
|
||||
|
||||
// 获取客户端信息
|
||||
client, err := q.Client.
|
||||
Where(
|
||||
q.Client.ClientID.Eq(clientID),
|
||||
q.Client.Spec.In(int32(client2.SpecWeb), int32(client2.SpecTrusted)),
|
||||
q.Client.GrantClient.Is(true),
|
||||
q.Client.ClientID.Eq(clientId),
|
||||
q.Client.Status.Eq(1)).
|
||||
Take()
|
||||
if err != nil {
|
||||
@@ -140,33 +138,57 @@ func authBasic(_ context.Context, token string) (*Context, error) {
|
||||
}
|
||||
|
||||
// 检查客户端密钥
|
||||
var clientSecret = split[1]
|
||||
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
|
||||
return nil, errors.New("客户端密钥错误")
|
||||
spec := client2.Spec(client.Spec)
|
||||
if spec == client2.SpecWeb || spec == client2.SpecApi {
|
||||
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
|
||||
return nil, errors.New("客户端密钥错误")
|
||||
}
|
||||
}
|
||||
|
||||
// todo 查询客户端关联权限
|
||||
|
||||
// 组织授权信息(一次性请求)
|
||||
return &Context{
|
||||
Payload: Payload{
|
||||
Id: client.ID,
|
||||
Type: PayloadTypeFromClientSpec(client2.Spec(client.Spec)),
|
||||
Name: client.Name,
|
||||
Avatar: client.Icon,
|
||||
},
|
||||
Permissions: nil,
|
||||
Metadata: nil,
|
||||
}, nil
|
||||
return client, nil
|
||||
}
|
||||
|
||||
type AuthenticationErr string
|
||||
func authUserBySms(tx *q.Query, username, code string) (*m.User, error) {
|
||||
// 验证验证码
|
||||
err := s.Verifier.VerifySms(context.Background(), username, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, s.ErrVerifierServiceInvalid) {
|
||||
return nil, ErrAuthorizeInvalidRequest
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (e AuthenticationErr) Error() string {
|
||||
return string(e)
|
||||
// 查找用户
|
||||
return tx.User.Where(tx.User.Phone.Eq(username)).Take()
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUnauthorize = AuthenticationErr("令牌无效")
|
||||
ErrForbidden = AuthenticationErr("没有权限")
|
||||
)
|
||||
func authUserByEmail(tx *q.Query, username, code string) (*m.User, error) {
|
||||
return nil, core.NewServErr("邮箱登录不可用")
|
||||
}
|
||||
|
||||
func authUserByPassword(tx *q.Query, username, password string) (*m.User, error) {
|
||||
user, err := tx.User.
|
||||
Where(tx.User.Phone.Eq(username)).
|
||||
Or(tx.User.Email.Eq(username)).
|
||||
Or(tx.User.Username.Eq(username)).
|
||||
Take()
|
||||
if err != nil {
|
||||
slog.Debug("查找用户失败", "error", err)
|
||||
return nil, core.NewBizErr("用户不存在或密码错误")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if user.Password == nil || *user.Password == "" {
|
||||
slog.Debug("用户未设置密码", "username", username)
|
||||
return nil, core.NewBizErr("用户不存在或密码错误")
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(*user.Password), []byte(password)) != nil {
|
||||
slog.Debug("密码验证失败", "username", username)
|
||||
return nil, core.NewBizErr("用户不存在或密码错误")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -1,5 +1,28 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"platform/pkg/env"
|
||||
"platform/pkg/u"
|
||||
"platform/web/core"
|
||||
user2 "platform/web/domains/user"
|
||||
g "platform/web/globals"
|
||||
"platform/web/globals/orm"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GrantType string
|
||||
|
||||
const (
|
||||
@@ -17,36 +40,352 @@ const (
|
||||
GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱验证码
|
||||
)
|
||||
|
||||
func Token(grant GrantType) error {
|
||||
return nil
|
||||
type TokenReq struct {
|
||||
GrantType GrantType `json:"grant_type" form:"grant_type"`
|
||||
ClientID string `json:"client_id" form:"client_id"`
|
||||
ClientSecret string `json:"client_secret" form:"client_secret"`
|
||||
Scope string `json:"scope" form:"scope"`
|
||||
GrantCodeData
|
||||
GrantClientData
|
||||
GrantRefreshData
|
||||
GrantPasswordData
|
||||
}
|
||||
|
||||
func authAuthorizationCode() {
|
||||
|
||||
type GrantCodeData struct {
|
||||
Code string `json:"code" form:"code"`
|
||||
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
|
||||
}
|
||||
|
||||
func authClientCredential() {
|
||||
|
||||
type GrantClientData struct {
|
||||
}
|
||||
|
||||
func authRefreshToken() {
|
||||
|
||||
type GrantRefreshData struct {
|
||||
RefreshToken string `json:"refresh_token" form:"refresh_token"`
|
||||
}
|
||||
|
||||
func authPassword() {
|
||||
|
||||
type GrantPasswordData struct {
|
||||
LoginType PasswordGrantType `json:"login_type" form:"login_type"`
|
||||
Username string `json:"username" form:"username"`
|
||||
Password string `json:"password" form:"password"`
|
||||
Remember bool `json:"remember" form:"remember"`
|
||||
}
|
||||
|
||||
func authPasswordSecret() {
|
||||
|
||||
type TokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
func authPasswordPhone() {
|
||||
|
||||
type TokenErrResp struct {
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description,omitempty"`
|
||||
}
|
||||
|
||||
func authPasswordEmail() {
|
||||
func Token(c *fiber.Ctx) error {
|
||||
now := time.Now()
|
||||
|
||||
// 验证请求参数
|
||||
req := new(TokenReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "无法解析请求参数")
|
||||
}
|
||||
if req.GrantType == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: grant_type")
|
||||
}
|
||||
|
||||
switch req.GrantType {
|
||||
// 授权码模式
|
||||
case GrantAuthorizationCode:
|
||||
if req.Code == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: code")
|
||||
}
|
||||
// 刷新令牌模式
|
||||
case GrantRefreshToken:
|
||||
if req.RefreshToken == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: refresh_token")
|
||||
}
|
||||
// 密码模式
|
||||
case GrantPassword:
|
||||
if req.LoginType == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password_type")
|
||||
}
|
||||
if req.Username == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: username")
|
||||
}
|
||||
if req.Password == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证客户端身份
|
||||
authCtx := GetAuthCtx(c)
|
||||
if authCtx == nil {
|
||||
authCtx = &AuthCtx{}
|
||||
}
|
||||
if authCtx.Client == nil {
|
||||
client, err := authClient(req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
authCtx.Client = client
|
||||
}
|
||||
|
||||
// 处理授权
|
||||
var session *m.Session
|
||||
var err error
|
||||
switch req.GrantType {
|
||||
// 授权码模式
|
||||
case GrantAuthorizationCode:
|
||||
session, err = authAuthorizationCode(c, authCtx, req, now)
|
||||
// 客户端凭证模式
|
||||
case GrantClientCredentials:
|
||||
session, err = authClientCredential(c, authCtx, req, now)
|
||||
// 刷新令牌模式
|
||||
case GrantRefreshToken:
|
||||
session, err = authRefreshToken(c, authCtx, req, now)
|
||||
// 密码模式
|
||||
case GrantPassword:
|
||||
session, err = authPassword(c, authCtx, req, now)
|
||||
default:
|
||||
return sendError(c, ErrAuthorizeUnsupportedGrantType)
|
||||
}
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
return c.JSON(&TokenResp{
|
||||
TokenType: "Bearer",
|
||||
AccessToken: session.AccessToken,
|
||||
RefreshToken: u.Z(session.RefreshToken),
|
||||
ExpiresIn: int(time.Time(session.AccessTokenExpires).Sub(now).Seconds()),
|
||||
Scope: u.Z(session.Scopes_),
|
||||
})
|
||||
}
|
||||
|
||||
func authAuthorizationCode(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
||||
|
||||
// 检查 code 获取用户授权信息
|
||||
data, err := g.Redis.Get(context.Background(), req.Code).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var codeCtx CodeContext
|
||||
if err := json.Unmarshal([]byte(data), &codeCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查 PKCE
|
||||
if codeCtx.CodeChallengeMethod != "" {
|
||||
if req.CodeVerifier == "" {
|
||||
return nil, ErrAuthorizeInvalidPKCE
|
||||
}
|
||||
|
||||
switch codeCtx.CodeChallengeMethod {
|
||||
case "plain":
|
||||
if req.CodeVerifier != codeCtx.CodeChallenge {
|
||||
return nil, ErrAuthorizeInvalidPKCE
|
||||
}
|
||||
case "S256":
|
||||
hash := sha256.Sum256([]byte(req.CodeVerifier))
|
||||
verifier := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
if verifier != codeCtx.CodeChallenge {
|
||||
return nil, ErrAuthorizeInvalidPKCE
|
||||
}
|
||||
default:
|
||||
return nil, ErrAuthorizeInvalidPKCE
|
||||
}
|
||||
}
|
||||
|
||||
user, err := q.User.Where(
|
||||
q.User.ID.Eq(codeCtx.UserID),
|
||||
q.User.Status.Eq(int32(user2.StatusEnabled)),
|
||||
).First()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// todo 检查 scope
|
||||
|
||||
// 生成会话
|
||||
session := &m.Session{
|
||||
IP: u.X(ctx.IP()),
|
||||
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
|
||||
UserID: &user.ID,
|
||||
ClientID: &auth.Client.ID,
|
||||
Scopes_: u.P(strings.Join(codeCtx.Scopes, " ")),
|
||||
AccessToken: uuid.NewString(),
|
||||
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
|
||||
}
|
||||
if codeCtx.Remember {
|
||||
session.RefreshToken = u.P(uuid.NewString())
|
||||
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
|
||||
}
|
||||
|
||||
err = SaveSession(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func authClientCredential(ctx *fiber.Ctx, auth *AuthCtx, _ *TokenReq, now time.Time) (*m.Session, error) {
|
||||
// todo 检查 scope
|
||||
|
||||
// 生成会话
|
||||
session := &m.Session{
|
||||
IP: u.X(ctx.IP()),
|
||||
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
|
||||
ClientID: &auth.Client.ID,
|
||||
AccessToken: uuid.NewString(),
|
||||
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
|
||||
}
|
||||
|
||||
// 保存会话
|
||||
err := SaveSession(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func authPassword(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
||||
var user *m.User
|
||||
err := q.Q.Transaction(func(tx *q.Query) (err error) {
|
||||
switch req.LoginType {
|
||||
case GrantPasswordPhone:
|
||||
user, err = authUserBySms(tx, req.Username, req.Password)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
if user == nil {
|
||||
user = &m.User{
|
||||
Phone: req.Username,
|
||||
Username: u.P(req.Username),
|
||||
Status: int32(user2.StatusEnabled),
|
||||
}
|
||||
}
|
||||
case GrantPasswordEmail:
|
||||
user, err = authUserByEmail(tx, req.Username, req.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case GrantPasswordSecret:
|
||||
user, err = authUserByPassword(tx, req.Username, req.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return ErrAuthorizeInvalidRequest
|
||||
}
|
||||
|
||||
// 账户状态
|
||||
if user2.Status(user.Status) == user2.StatusDisabled {
|
||||
slog.Debug("账户状态异常", "username", req.Username, "status", user.Status)
|
||||
return core.NewBizErr("账号无法登录")
|
||||
}
|
||||
|
||||
// 更新用户的登录时间
|
||||
user.LastLogin = u.P(orm.LocalDateTime(time.Now()))
|
||||
user.LastLoginHost = u.X(ctx.IP())
|
||||
user.LastLoginAgent = u.X(ctx.Get(fiber.HeaderUserAgent))
|
||||
if err := tx.User.Save(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成会话
|
||||
session := &m.Session{
|
||||
IP: u.X(ctx.IP()),
|
||||
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
|
||||
UserID: &user.ID,
|
||||
ClientID: &auth.Client.ID,
|
||||
Scopes_: u.X(req.Scope),
|
||||
AccessToken: uuid.NewString(),
|
||||
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
|
||||
}
|
||||
if req.Remember {
|
||||
session.RefreshToken = u.P(uuid.NewString())
|
||||
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
|
||||
}
|
||||
|
||||
err = SaveSession(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func authRefreshToken(_ *fiber.Ctx, _ *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
||||
session, err := FindSessionByRefresh(req.RefreshToken, now)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAuthorizeInvalidGrant
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// todo 检查权限
|
||||
|
||||
// 生成令牌
|
||||
session.AccessToken = uuid.NewString()
|
||||
session.AccessTokenExpires = orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second))
|
||||
if session.RefreshToken != nil {
|
||||
session.RefreshToken = u.P(uuid.NewString())
|
||||
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
|
||||
}
|
||||
|
||||
// 保存令牌
|
||||
err = SaveSession(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func sendError(c *fiber.Ctx, err error, description ...string) error {
|
||||
var sErr AuthErr
|
||||
if errors.As(err, &sErr) {
|
||||
status := fiber.StatusBadRequest
|
||||
var desc string
|
||||
switch {
|
||||
case errors.Is(sErr, ErrAuthorizeInvalidRequest):
|
||||
desc = "无效的请求"
|
||||
case errors.Is(sErr, ErrAuthorizeInvalidClient):
|
||||
status = fiber.StatusUnauthorized
|
||||
desc = "无效的客户端凭证"
|
||||
case errors.Is(sErr, ErrAuthorizeInvalidGrant):
|
||||
desc = "无效的授权凭证"
|
||||
case errors.Is(sErr, ErrAuthorizeInvalidScope):
|
||||
desc = "无效的授权范围"
|
||||
case errors.Is(sErr, ErrAuthorizeUnauthorizedClient):
|
||||
desc = "未授权的客户端"
|
||||
case errors.Is(sErr, ErrAuthorizeUnsupportedGrantType):
|
||||
desc = "不支持的授权类型"
|
||||
}
|
||||
if len(description) > 0 {
|
||||
desc = description[0]
|
||||
}
|
||||
|
||||
return c.Status(status).JSON(TokenErrResp{
|
||||
Error: string(sErr),
|
||||
Description: desc,
|
||||
})
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func Revoke() error {
|
||||
@@ -56,3 +395,12 @@ func Revoke() error {
|
||||
func Introspect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type CodeContext struct {
|
||||
UserID int32 `json:"user_id"`
|
||||
ClientID int32 `json:"client_id"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Remember bool `json:"remember"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
}
|
||||
|
||||
99
web/auth/check.go
Normal file
99
web/auth/check.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"platform/web/domains/client"
|
||||
m "platform/web/models"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type AuthCtx struct {
|
||||
User *m.User `json:"account,omitempty"`
|
||||
Admin *m.Admin `json:"admin,omitempty"`
|
||||
Client *m.Client `json:"client,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
Session *m.Session `json:"session,omitempty"`
|
||||
smap map[string]struct{}
|
||||
}
|
||||
|
||||
func (a *AuthCtx) PermitUser(scopes ...string) (*AuthCtx, error) {
|
||||
if a.User == nil {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
if !a.checkScopes(scopes...) {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *AuthCtx) PermitAdmin(scopes ...string) (*AuthCtx, error) {
|
||||
if a.Admin == nil {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
if !a.checkScopes(scopes...) {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *AuthCtx) PermitSecretClient(scopes ...string) (*AuthCtx, error) {
|
||||
if a.Client == nil {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
spec := client.Spec(a.Client.Spec)
|
||||
if spec != client.SpecApi && spec != client.SpecWeb {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
if !a.checkScopes(scopes...) {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *AuthCtx) PermitInternalClient(scopes ...string) (*AuthCtx, error) {
|
||||
if a.Client == nil {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
spec := client.Spec(a.Client.Spec)
|
||||
if spec != client.SpecApi && spec != client.SpecWeb {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
cType := client.Type(a.Client.Type)
|
||||
if cType != client.TypeInternal {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
if !a.checkScopes(scopes...) {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *AuthCtx) checkScopes(scopes ...string) bool {
|
||||
if len(scopes) == 0 || len(a.Scopes) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(a.smap) == 0 && len(a.Scopes) > 0 {
|
||||
for _, scope := range scopes {
|
||||
a.smap[scope] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
if _, ok := a.smap[scope]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const AuthCtxKey = "session"
|
||||
|
||||
func SetAuthCtx(c *fiber.Ctx, auth *AuthCtx) {
|
||||
c.Locals(AuthCtxKey, auth)
|
||||
}
|
||||
|
||||
func GetAuthCtx(c *fiber.Ctx) *AuthCtx {
|
||||
if authCtx, ok := c.Locals(AuthCtxKey).(*AuthCtx); ok {
|
||||
return authCtx
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
client2 "platform/web/domains/client"
|
||||
)
|
||||
|
||||
// Context 定义认证信息
|
||||
type Context struct {
|
||||
Payload Payload `json:"payload"`
|
||||
Permissions map[string]struct{} `json:"permissions,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func (a *Context) AnyType(types ...PayloadType) bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
for _, t := range types {
|
||||
if a.Payload.Type == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AnyPermission 检查认证是否包含指定权限
|
||||
func (a *Context) AnyPermission(requiredPermission ...string) bool {
|
||||
if a == nil || a.Permissions == nil {
|
||||
return false
|
||||
}
|
||||
for _, permission := range requiredPermission {
|
||||
if _, ok := a.Permissions[permission]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Payload 定义负载信息
|
||||
type Payload struct {
|
||||
Id int32 `json:"id,omitempty"`
|
||||
Type PayloadType `json:"type,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Avatar *string `json:"avatar,omitempty"`
|
||||
}
|
||||
|
||||
type PayloadType int
|
||||
|
||||
const (
|
||||
PayloadNone PayloadType = iota // 游客
|
||||
PayloadUser // 用户
|
||||
PayloadAdmin // 管理员
|
||||
PayloadPublicServer // 公共服务(public_client)
|
||||
PayloadSecuredServer // 安全服务(credential_client)
|
||||
PayloadInternalServer // 内部服务
|
||||
)
|
||||
|
||||
func (t PayloadType) ToStr() string {
|
||||
switch t {
|
||||
case PayloadUser:
|
||||
return "user"
|
||||
case PayloadAdmin:
|
||||
return "admn"
|
||||
case PayloadPublicServer:
|
||||
return "cpub"
|
||||
case PayloadSecuredServer:
|
||||
return "ccnf"
|
||||
case PayloadInternalServer:
|
||||
return "inte"
|
||||
default:
|
||||
return "none"
|
||||
}
|
||||
}
|
||||
|
||||
func PayloadTypeFromStr(name string) PayloadType {
|
||||
switch name {
|
||||
case "user":
|
||||
return PayloadUser
|
||||
case "admn":
|
||||
return PayloadAdmin
|
||||
case "cpub":
|
||||
return PayloadPublicServer
|
||||
case "ccnf":
|
||||
return PayloadSecuredServer
|
||||
case "inte":
|
||||
return PayloadInternalServer
|
||||
default:
|
||||
return PayloadNone
|
||||
}
|
||||
}
|
||||
|
||||
func PayloadTypeFromClientSpec(spec client2.Spec) PayloadType {
|
||||
var clientType PayloadType
|
||||
switch spec {
|
||||
case client2.SpecNative, client2.SpecBrowser:
|
||||
clientType = PayloadPublicServer
|
||||
case client2.SpecWeb:
|
||||
clientType = PayloadSecuredServer
|
||||
case client2.SpecTrusted:
|
||||
clientType = PayloadInternalServer
|
||||
}
|
||||
return clientType
|
||||
}
|
||||
24
web/auth/errors.go
Normal file
24
web/auth/errors.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package auth
|
||||
|
||||
type AuthErr string
|
||||
|
||||
func (e AuthErr) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// 认证错误
|
||||
const (
|
||||
ErrAuthenticateUnauthorize = AuthErr("令牌无效")
|
||||
ErrAuthenticateForbidden = AuthErr("没有权限")
|
||||
)
|
||||
|
||||
// 授权错误
|
||||
const (
|
||||
ErrAuthorizeInvalidRequest = AuthErr("invalid_request")
|
||||
ErrAuthorizeInvalidClient = AuthErr("invalid_client")
|
||||
ErrAuthorizeInvalidGrant = AuthErr("invalid_grant")
|
||||
ErrAuthorizeInvalidScope = AuthErr("invalid_scope")
|
||||
ErrAuthorizeUnauthorizedClient = AuthErr("unauthorized_client")
|
||||
ErrAuthorizeUnsupportedGrantType = AuthErr("unsupported_grant_type")
|
||||
ErrAuthorizeInvalidPKCE = AuthErr("invalid_pkce")
|
||||
)
|
||||
@@ -2,160 +2,36 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"platform/pkg/env"
|
||||
g "platform/web/globals"
|
||||
"platform/web/globals/orm"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"time"
|
||||
|
||||
"gorm.io/gen/field"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
// 认证主体
|
||||
Payload *Payload
|
||||
// 令牌信息
|
||||
TokenDetails *TokenDetails
|
||||
func FindSession(accessToken string, now time.Time) (*m.Session, error) {
|
||||
return q.Session.
|
||||
Preload(field.Associations).
|
||||
Where(
|
||||
q.Session.AccessToken.Eq(accessToken),
|
||||
q.Session.AccessTokenExpires.Gt(orm.LocalDateTime(now)),
|
||||
).First()
|
||||
}
|
||||
|
||||
func FindSession(ctx context.Context, token string) (*Context, error) {
|
||||
|
||||
// 读取认证数据
|
||||
authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, errors.New("invalid_token")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 反序列化
|
||||
auth := new(Context)
|
||||
if err := json.Unmarshal([]byte(authJSON), auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
func FindSessionByRefresh(refreshToken string, now time.Time) (*m.Session, error) {
|
||||
return q.Session.
|
||||
Preload(field.Associations).
|
||||
Where(
|
||||
q.Session.RefreshToken.Eq(refreshToken),
|
||||
q.Session.RefreshTokenExpires.Gt(orm.LocalDateTime(now)),
|
||||
).First()
|
||||
}
|
||||
|
||||
func CreateSession(ctx context.Context, authCtx *Context, remember bool) (*TokenDetails, error) {
|
||||
var now = time.Now()
|
||||
|
||||
// 生成令牌组
|
||||
accessToken := genToken()
|
||||
refreshToken := genToken()
|
||||
|
||||
// 序列化认证数据
|
||||
authData, err := json.Marshal(authCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 序列化刷新令牌数据
|
||||
refreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: authCtx,
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 事务保存数据到 Redis
|
||||
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
|
||||
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
|
||||
|
||||
pipe := g.Redis.TxPipeline()
|
||||
pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
|
||||
if remember {
|
||||
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenDetails{
|
||||
AccessToken: accessToken,
|
||||
AccessTokenExpires: now.Add(accessExpire),
|
||||
RefreshToken: refreshToken,
|
||||
RefreshTokenExpires: now.Add(refreshExpire),
|
||||
Auth: authCtx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RefreshSession(ctx context.Context, refreshToken string, renew bool) (*TokenDetails, error) {
|
||||
var now = time.Now()
|
||||
|
||||
rKey := refreshKey(refreshToken)
|
||||
var tokenDetails *TokenDetails
|
||||
|
||||
// 刷新令牌
|
||||
err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
|
||||
|
||||
// 先获取刷新令牌数据
|
||||
refreshJson, err := tx.Get(ctx, rKey).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return ErrInvalidRefreshToken
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析刷新令牌数据
|
||||
refreshData := new(RefreshData)
|
||||
if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 生成新的令牌
|
||||
newAccessToken := genToken()
|
||||
newRefreshToken := genToken()
|
||||
|
||||
authData, err := json.Marshal(refreshData.AuthContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newRefreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: refreshData.AuthContext,
|
||||
AccessToken: newAccessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pipeline := tx.Pipeline()
|
||||
|
||||
// 保存新的令牌
|
||||
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
|
||||
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
|
||||
|
||||
pipeline.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
|
||||
pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
|
||||
|
||||
// 删除旧的令牌
|
||||
pipeline.Del(ctx, accessKey(refreshData.AccessToken))
|
||||
pipeline.Del(ctx, refreshKey(refreshToken))
|
||||
|
||||
_, err = pipeline.Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenDetails = &TokenDetails{
|
||||
AccessToken: newAccessToken,
|
||||
RefreshToken: newRefreshToken,
|
||||
AccessTokenExpires: now.Add(accessExpire),
|
||||
RefreshTokenExpires: now.Add(refreshExpire),
|
||||
Auth: refreshData.AuthContext,
|
||||
}
|
||||
return nil
|
||||
}, rKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("刷新令牌失败: %w", err)
|
||||
}
|
||||
|
||||
return tokenDetails, nil
|
||||
func SaveSession(session *m.Session) error {
|
||||
return q.Session.Save(session)
|
||||
}
|
||||
|
||||
func RemoveSession(ctx context.Context, accessToken string, refreshToken string) error {
|
||||
@@ -163,11 +39,6 @@ func RemoveSession(ctx context.Context, accessToken string, refreshToken string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 生成一个新的令牌
|
||||
func genToken() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
// 令牌键的格式为 "session:<token>"
|
||||
func accessKey(token string) string {
|
||||
return fmt.Sprintf("session:%s", token)
|
||||
@@ -177,32 +48,3 @@ func accessKey(token string) string {
|
||||
func refreshKey(token string) string {
|
||||
return fmt.Sprintf("session:refresh:%s", token)
|
||||
}
|
||||
|
||||
// TokenDetails 存储令牌详细信息
|
||||
type TokenDetails struct {
|
||||
// 访问令牌
|
||||
AccessToken string
|
||||
// 刷新令牌
|
||||
RefreshToken string
|
||||
// 访问令牌过期时间
|
||||
AccessTokenExpires time.Time
|
||||
// 刷新令牌过期时间
|
||||
RefreshTokenExpires time.Time
|
||||
// 认证信息
|
||||
Auth *Context
|
||||
}
|
||||
|
||||
type RefreshData struct {
|
||||
AuthContext *Context
|
||||
AccessToken string
|
||||
}
|
||||
|
||||
type SessionErr string
|
||||
|
||||
func (e SessionErr) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
const (
|
||||
ErrInvalidRefreshToken = SessionErr("无效的刷新令牌")
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user