407 lines
11 KiB
Go
407 lines
11 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"log/slog"
|
|
"platform/pkg/env"
|
|
"platform/pkg/u"
|
|
"platform/web/core"
|
|
user2 "platform/web/domains/user"
|
|
g "platform/web/globals"
|
|
"platform/web/globals/orm"
|
|
m "platform/web/models"
|
|
q "platform/web/queries"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/google/uuid"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// GrantType is the value of the OAuth 2.0 "grant_type" token-request parameter.
type GrantType string

// Grant types understood by the token endpoint.
const (
	GrantAuthorizationCode GrantType = "authorization_code" // authorization code grant
	GrantClientCredentials GrantType = "client_credentials" // client credentials grant
	GrantRefreshToken      GrantType = "refresh_token"      // refresh token grant
	GrantPassword          GrantType = "password"           // password grant (private extension)
)

// PasswordGrantType selects how the username/password fields of a password-grant
// request are interpreted. It is bound from the "login_type" parameter.
type PasswordGrantType string

// Login variants of the password grant.
const (
	GrantPasswordSecret PasswordGrantType = "password"   // account + password
	GrantPasswordPhone  PasswordGrantType = "phone_code" // phone number + verification code
	GrantPasswordEmail  PasswordGrantType = "email_code" // email address + verification code
)

// TokenReq is the token-endpoint request. The common OAuth parameters sit at the
// top level; the grant-specific parameters come from the embedded structs, whose
// fields are promoted so all of them bind from the same flat body.
type TokenReq struct {
	GrantType    GrantType `json:"grant_type" form:"grant_type"`
	ClientID     string    `json:"client_id" form:"client_id"`
	ClientSecret string    `json:"client_secret" form:"client_secret"`
	Scope        string    `json:"scope" form:"scope"`

	GrantCodeData
	GrantClientData
	GrantRefreshData
	GrantPasswordData
}

// GrantCodeData holds the parameters of the authorization code grant.
type GrantCodeData struct {
	Code         string `json:"code" form:"code"`
	RedirectURI  string `json:"redirect_uri" form:"redirect_uri"`
	CodeVerifier string `json:"code_verifier" form:"code_verifier"` // PKCE verifier
}

// GrantClientData holds the parameters of the client credentials grant (none yet).
type GrantClientData struct{}

// GrantRefreshData holds the parameters of the refresh token grant.
type GrantRefreshData struct {
	RefreshToken string `json:"refresh_token" form:"refresh_token"`
}

// GrantPasswordData holds the parameters of the password grant. Password carries
// the secret or the verification code, depending on LoginType.
type GrantPasswordData struct {
	LoginType PasswordGrantType `json:"login_type" form:"login_type"`
	Username  string            `json:"username" form:"username"`
	Password  string            `json:"password" form:"password"`
	Remember  bool              `json:"remember" form:"remember"`
}
|
|
|
|
// TokenResp is the successful token-endpoint response body.
// RefreshToken and Scope are omitted from the JSON when empty.
type TokenResp struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token,omitempty"`
	ExpiresIn    int    `json:"expires_in"` // seconds until the access token expires
	TokenType    string `json:"token_type"`
	Scope        string `json:"scope,omitempty"`
}

// TokenErrResp is the error response body of the token endpoint.
// Description is omitted from the JSON when empty.
type TokenErrResp struct {
	Error       string `json:"error"`
	Description string `json:"error_description,omitempty"`
}
|
|
|
|
func Token(c *fiber.Ctx) error {
|
|
now := time.Now()
|
|
|
|
// 验证请求参数
|
|
req := new(TokenReq)
|
|
if err := c.BodyParser(req); err != nil {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "无法解析请求参数")
|
|
}
|
|
if req.GrantType == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: grant_type")
|
|
}
|
|
|
|
switch req.GrantType {
|
|
// 授权码模式
|
|
case GrantAuthorizationCode:
|
|
if req.Code == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: code")
|
|
}
|
|
// 刷新令牌模式
|
|
case GrantRefreshToken:
|
|
if req.RefreshToken == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: refresh_token")
|
|
}
|
|
// 密码模式
|
|
case GrantPassword:
|
|
if req.LoginType == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password_type")
|
|
}
|
|
if req.Username == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: username")
|
|
}
|
|
if req.Password == "" {
|
|
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password")
|
|
}
|
|
}
|
|
|
|
// 验证客户端身份
|
|
authCtx := GetAuthCtx(c)
|
|
if authCtx == nil {
|
|
authCtx = &AuthCtx{}
|
|
}
|
|
if authCtx.Client == nil {
|
|
client, err := authClient(req.ClientID, req.ClientSecret)
|
|
if err != nil {
|
|
return sendError(c, err)
|
|
}
|
|
authCtx.Client = client
|
|
}
|
|
|
|
// 处理授权
|
|
var session *m.Session
|
|
var err error
|
|
switch req.GrantType {
|
|
// 授权码模式
|
|
case GrantAuthorizationCode:
|
|
session, err = authAuthorizationCode(c, authCtx, req, now)
|
|
// 客户端凭证模式
|
|
case GrantClientCredentials:
|
|
session, err = authClientCredential(c, authCtx, req, now)
|
|
// 刷新令牌模式
|
|
case GrantRefreshToken:
|
|
session, err = authRefreshToken(c, authCtx, req, now)
|
|
// 密码模式
|
|
case GrantPassword:
|
|
session, err = authPassword(c, authCtx, req, now)
|
|
default:
|
|
return sendError(c, ErrAuthorizeUnsupportedGrantType)
|
|
}
|
|
if err != nil {
|
|
return sendError(c, err)
|
|
}
|
|
|
|
// 返回响应
|
|
return c.JSON(&TokenResp{
|
|
TokenType: "Bearer",
|
|
AccessToken: session.AccessToken,
|
|
RefreshToken: u.Z(session.RefreshToken),
|
|
ExpiresIn: int(time.Time(session.AccessTokenExpires).Sub(now).Seconds()),
|
|
Scope: u.Z(session.Scopes_),
|
|
})
|
|
}
|
|
|
|
func authAuthorizationCode(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
|
|
|
// 检查 code 获取用户授权信息
|
|
data, err := g.Redis.Get(context.Background(), req.Code).Result()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var codeCtx CodeContext
|
|
if err := json.Unmarshal([]byte(data), &codeCtx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 检查 PKCE
|
|
if codeCtx.CodeChallengeMethod != "" {
|
|
if req.CodeVerifier == "" {
|
|
return nil, ErrAuthorizeInvalidPKCE
|
|
}
|
|
|
|
switch codeCtx.CodeChallengeMethod {
|
|
case "plain":
|
|
if req.CodeVerifier != codeCtx.CodeChallenge {
|
|
return nil, ErrAuthorizeInvalidPKCE
|
|
}
|
|
case "S256":
|
|
hash := sha256.Sum256([]byte(req.CodeVerifier))
|
|
verifier := base64.RawURLEncoding.EncodeToString(hash[:])
|
|
if verifier != codeCtx.CodeChallenge {
|
|
return nil, ErrAuthorizeInvalidPKCE
|
|
}
|
|
default:
|
|
return nil, ErrAuthorizeInvalidPKCE
|
|
}
|
|
}
|
|
|
|
user, err := q.User.Where(
|
|
q.User.ID.Eq(codeCtx.UserID),
|
|
q.User.Status.Eq(int32(user2.StatusEnabled)),
|
|
).First()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// todo 检查 scope
|
|
|
|
// 生成会话
|
|
session := &m.Session{
|
|
IP: u.X(ctx.IP()),
|
|
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
|
|
UserID: &user.ID,
|
|
ClientID: &auth.Client.ID,
|
|
Scopes_: u.P(strings.Join(codeCtx.Scopes, " ")),
|
|
AccessToken: uuid.NewString(),
|
|
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
|
|
}
|
|
if codeCtx.Remember {
|
|
session.RefreshToken = u.P(uuid.NewString())
|
|
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
|
|
}
|
|
|
|
err = SaveSession(session)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func authClientCredential(ctx *fiber.Ctx, auth *AuthCtx, _ *TokenReq, now time.Time) (*m.Session, error) {
|
|
// todo 检查 scope
|
|
|
|
// 生成会话
|
|
session := &m.Session{
|
|
IP: u.X(ctx.IP()),
|
|
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
|
|
ClientID: &auth.Client.ID,
|
|
AccessToken: uuid.NewString(),
|
|
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
|
|
}
|
|
|
|
// 保存会话
|
|
err := SaveSession(session)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func authPassword(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
|
var user *m.User
|
|
err := q.Q.Transaction(func(tx *q.Query) (err error) {
|
|
switch req.LoginType {
|
|
case GrantPasswordPhone:
|
|
user, err = authUserBySms(tx, req.Username, req.Password)
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return err
|
|
}
|
|
if user == nil {
|
|
user = &m.User{
|
|
Phone: req.Username,
|
|
Username: u.P(req.Username),
|
|
Status: int32(user2.StatusEnabled),
|
|
}
|
|
}
|
|
case GrantPasswordEmail:
|
|
user, err = authUserByEmail(tx, req.Username, req.Password)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case GrantPasswordSecret:
|
|
user, err = authUserByPassword(tx, req.Username, req.Password)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return ErrAuthorizeInvalidRequest
|
|
}
|
|
|
|
// 账户状态
|
|
if user2.Status(user.Status) == user2.StatusDisabled {
|
|
slog.Debug("账户状态异常", "username", req.Username, "status", user.Status)
|
|
return core.NewBizErr("账号无法登录")
|
|
}
|
|
|
|
// 更新用户的登录时间
|
|
user.LastLogin = u.P(orm.LocalDateTime(time.Now()))
|
|
user.LastLoginHost = u.X(ctx.IP())
|
|
user.LastLoginAgent = u.X(ctx.Get(fiber.HeaderUserAgent))
|
|
if err := tx.User.Save(user); err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 生成会话
|
|
session := &m.Session{
|
|
IP: u.X(ctx.IP()),
|
|
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
|
|
UserID: &user.ID,
|
|
ClientID: &auth.Client.ID,
|
|
Scopes_: u.X(req.Scope),
|
|
AccessToken: uuid.NewString(),
|
|
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
|
|
}
|
|
if req.Remember {
|
|
session.RefreshToken = u.P(uuid.NewString())
|
|
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
|
|
}
|
|
|
|
err = SaveSession(session)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func authRefreshToken(_ *fiber.Ctx, _ *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
|
session, err := FindSessionByRefresh(req.RefreshToken, now)
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrAuthorizeInvalidGrant
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// todo 检查权限
|
|
|
|
// 生成令牌
|
|
session.AccessToken = uuid.NewString()
|
|
session.AccessTokenExpires = orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second))
|
|
if session.RefreshToken != nil {
|
|
session.RefreshToken = u.P(uuid.NewString())
|
|
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
|
|
}
|
|
|
|
// 保存令牌
|
|
err = SaveSession(session)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return session, nil
|
|
}
|
|
|
|
func sendError(c *fiber.Ctx, err error, description ...string) error {
|
|
var sErr AuthErr
|
|
if errors.As(err, &sErr) {
|
|
status := fiber.StatusBadRequest
|
|
var desc string
|
|
switch {
|
|
case errors.Is(sErr, ErrAuthorizeInvalidRequest):
|
|
desc = "无效的请求"
|
|
case errors.Is(sErr, ErrAuthorizeInvalidClient):
|
|
status = fiber.StatusUnauthorized
|
|
desc = "无效的客户端凭证"
|
|
case errors.Is(sErr, ErrAuthorizeInvalidGrant):
|
|
desc = "无效的授权凭证"
|
|
case errors.Is(sErr, ErrAuthorizeInvalidScope):
|
|
desc = "无效的授权范围"
|
|
case errors.Is(sErr, ErrAuthorizeUnauthorizedClient):
|
|
desc = "未授权的客户端"
|
|
case errors.Is(sErr, ErrAuthorizeUnsupportedGrantType):
|
|
desc = "不支持的授权类型"
|
|
}
|
|
if len(description) > 0 {
|
|
desc = description[0]
|
|
}
|
|
|
|
return c.Status(status).JSON(TokenErrResp{
|
|
Error: string(sErr),
|
|
Description: desc,
|
|
})
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
// Revoke is a placeholder, presumably for a token revocation endpoint.
// NOTE(review): not implemented — it does nothing and always returns nil.
func Revoke() error {
	var err error // nothing to revoke yet
	return err
}

// Introspect is a placeholder, presumably for a token introspection endpoint.
// NOTE(review): not implemented — it does nothing and always returns nil.
func Introspect() error {
	var err error // nothing to introspect yet
	return err
}
|
|
|
|
// CodeContext is the JSON document stored in Redis under an authorization code.
// authAuthorizationCode reads it back when the code is exchanged for a token.
type CodeContext struct {
	UserID              int32    `json:"user_id"`               // user who granted the code
	ClientID            int32    `json:"client_id"`             // client the code was issued to
	Scopes              []string `json:"scopes"`                // granted scopes, joined with spaces on the session
	Remember            bool     `json:"remember"`              // issue a refresh token as well
	CodeChallenge       string   `json:"code_challenge"`        // PKCE challenge, if any
	CodeChallengeMethod string   `json:"code_challenge_method"` // "plain" or "S256"; empty disables PKCE
}
|