package auth import ( "context" "crypto/sha256" "encoding/base64" "encoding/json" "errors" "log/slog" "platform/pkg/env" "platform/pkg/u" "platform/web/core" g "platform/web/globals" "platform/web/globals/orm" m "platform/web/models" q "platform/web/queries" "strings" "time" "github.com/gofiber/fiber/v2" "github.com/google/uuid" "gorm.io/gorm" ) type GrantType string const ( GrantAuthorizationCode = GrantType("authorization_code") // 授权码模式 GrantClientCredentials = GrantType("client_credentials") // 客户端凭证模式 GrantRefreshToken = GrantType("refresh_token") // 刷新令牌模式 GrantPassword = GrantType("password") // 密码模式(私有扩展) ) type PasswordGrantType string const ( GrantPasswordSecret = PasswordGrantType("password") // 账号密码 GrantPasswordPhone = PasswordGrantType("phone_code") // 手机验证码 GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱验证码 ) type TokenReq struct { GrantType GrantType `json:"grant_type" form:"grant_type"` ClientID string `json:"client_id" form:"client_id"` ClientSecret string `json:"client_secret" form:"client_secret"` Scope string `json:"scope" form:"scope"` GrantCodeData GrantClientData GrantRefreshData GrantPasswordData } type GrantCodeData struct { Code string `json:"code" form:"code"` RedirectURI string `json:"redirect_uri" form:"redirect_uri"` CodeVerifier string `json:"code_verifier" form:"code_verifier"` } type GrantClientData struct { } type GrantRefreshData struct { RefreshToken string `json:"refresh_token" form:"refresh_token"` } type GrantPasswordData struct { LoginType PasswordGrantType `json:"login_type" form:"login_type"` Username string `json:"username" form:"username"` Password string `json:"password" form:"password"` Remember bool `json:"remember" form:"remember"` } type TokenResp struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token,omitempty"` ExpiresIn int `json:"expires_in"` TokenType string `json:"token_type"` Scope string `json:"scope,omitempty"` } type TokenErrResp struct { Error string `json:"error"` Description string `json:"error_description,omitempty"` } func Token(c *fiber.Ctx) error { now := time.Now() // 验证请求参数 req := new(TokenReq) if err := c.BodyParser(req); err != nil { return sendError(c, ErrAuthorizeInvalidRequest, "无法解析请求参数") } if req.GrantType == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: grant_type") } switch req.GrantType { // 授权码模式 case GrantAuthorizationCode: if req.Code == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: code") } // 刷新令牌模式 case GrantRefreshToken: if req.RefreshToken == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: refresh_token") } // 密码模式 case GrantPassword: if req.LoginType == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password_type") } if req.Username == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: username") } if req.Password == "" { return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password") } } // 验证客户端身份 authCtx := GetAuthCtx(c) if authCtx == nil { authCtx = &AuthCtx{} } if authCtx.Client == nil { client, err := authClient(req.ClientID, req.ClientSecret) if err != nil { return sendError(c, err) } authCtx.Client = client } // 处理授权 var session *m.Session var err error switch req.GrantType { // 授权码模式 case GrantAuthorizationCode: session, err = authAuthorizationCode(c, authCtx, req, now) // 客户端凭证模式 case GrantClientCredentials: session, err = authClientCredential(c, authCtx, req, now) // 刷新令牌模式 case GrantRefreshToken: session, err = authRefreshToken(c, authCtx, req, now) // 密码模式 case GrantPassword: session, err = authPassword(c, authCtx, req, now) default: return sendError(c, ErrAuthorizeUnsupportedGrantType) } if err != nil { return sendError(c, err) } // 返回响应 return c.JSON(&TokenResp{ TokenType: "Bearer", AccessToken: session.AccessToken, RefreshToken: u.Z(session.RefreshToken), ExpiresIn: int(time.Time(session.AccessTokenExpires).Sub(now).Seconds()), Scope: u.Z(session.Scopes), }) } func authAuthorizationCode(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) { // 检查 code 获取用户授权信息 data, err := g.Redis.Get(context.Background(), req.Code).Result() if err != nil { return nil, err } var codeCtx CodeContext if err := json.Unmarshal([]byte(data), &codeCtx); err != nil { return nil, err } // 检查 PKCE if codeCtx.CodeChallengeMethod != "" { if req.CodeVerifier == "" { return nil, ErrAuthorizeInvalidPKCE } switch codeCtx.CodeChallengeMethod { case "plain": if req.CodeVerifier != codeCtx.CodeChallenge { return nil, ErrAuthorizeInvalidPKCE } case "S256": hash := sha256.Sum256([]byte(req.CodeVerifier)) verifier := base64.RawURLEncoding.EncodeToString(hash[:]) if verifier != codeCtx.CodeChallenge { return nil, ErrAuthorizeInvalidPKCE } default: return nil, ErrAuthorizeInvalidPKCE } } user, err := q.User.Where( q.User.ID.Eq(codeCtx.UserID), q.User.Status.Eq(int(m.UserStatusEnabled)), ).First() if err != nil { return nil, err } // todo 检查 scope // 生成会话 ip, _ := orm.ParseInet(ctx.Get(core.HeaderUserIP)) ua := ctx.Get(core.HeaderUserUA) session := &m.Session{ IP: ip, UA: u.X(ua), UserID: &user.ID, ClientID: &auth.Client.ID, Scopes: u.P(strings.Join(codeCtx.Scopes, " ")), AccessToken: uuid.NewString(), AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second), } if codeCtx.Remember { session.RefreshToken = u.P(uuid.NewString()) session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)) } err = SaveSession(session) if err != nil { return nil, err } return session, nil } func authClientCredential(ctx *fiber.Ctx, auth *AuthCtx, _ *TokenReq, now time.Time) (*m.Session, error) { // todo 检查 scope // 生成会话 ip, _ := orm.ParseInet(ctx.Get(core.HeaderUserIP)) ua := ctx.Get(core.HeaderUserUA) session := &m.Session{ IP: ip, UA: u.X(ua), ClientID: &auth.Client.ID, AccessToken: uuid.NewString(), AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second), } // 保存会话 err := SaveSession(session) if err != nil { return nil, err } return session, nil } func authPassword(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) { ip, _ := orm.ParseInet(ctx.Get(core.HeaderUserIP)) ua := ctx.Get(core.HeaderUserUA) var user *m.User err := q.Q.Transaction(func(tx *q.Query) (err error) { switch req.LoginType { case GrantPasswordPhone: user, err = authUserBySms(tx, req.Username, req.Password) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } if user == nil { user = &m.User{ Phone: req.Username, Username: u.P(req.Username), Status: m.UserStatusEnabled, } } case GrantPasswordEmail: user, err = authUserByEmail(tx, req.Username, req.Password) if err != nil { return err } case GrantPasswordSecret: user, err = authUserByPassword(tx, req.Username, req.Password) if err != nil { return err } default: return ErrAuthorizeInvalidRequest } // 账户状态 if user.Status == m.UserStatusDisabled { slog.Debug("账户状态异常", "username", req.Username, "status", user.Status) return core.NewBizErr("账号无法登录") } // 更新用户的登录时间 user.LastLogin = u.P(time.Now()) user.LastLoginIP = ip user.LastLoginUA = u.X(ua) if err := tx.User.Save(user); err != nil { return err } return nil }) if err != nil { return nil, err } // 生成会话 session := &m.Session{ IP: ip, UA: u.X(ua), UserID: &user.ID, ClientID: &auth.Client.ID, Scopes: u.X(req.Scope), AccessToken: uuid.NewString(), AccessTokenExpires: now.Add(time.Duration(env.SessionAccessExpire) * time.Second), } if req.Remember { session.RefreshToken = u.P(uuid.NewString()) session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)) } err = SaveSession(session) if err != nil { return nil, err } return session, nil } func authRefreshToken(_ *fiber.Ctx, _ *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) { session, err := FindSessionByRefresh(req.RefreshToken, now) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, ErrAuthorizeInvalidGrant } return nil, err } // todo 检查权限 // 生成令牌 session.AccessToken = uuid.NewString() session.AccessTokenExpires = now.Add(time.Duration(env.SessionAccessExpire) * time.Second) if session.RefreshToken != nil { session.RefreshToken = u.P(uuid.NewString()) session.RefreshTokenExpires = u.P(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)) } // 保存令牌 err = SaveSession(session) if err != nil { return nil, err } return session, nil } func sendError(c *fiber.Ctx, err error, description ...string) error { var sErr AuthErr if errors.As(err, &sErr) { status := fiber.StatusBadRequest var desc string switch { case errors.Is(sErr, ErrAuthorizeInvalidRequest): desc = "无效的请求" case errors.Is(sErr, ErrAuthorizeInvalidClient): status = fiber.StatusUnauthorized desc = "无效的客户端凭证" case errors.Is(sErr, ErrAuthorizeInvalidGrant): desc = "无效的授权凭证" case errors.Is(sErr, ErrAuthorizeInvalidScope): desc = "无效的授权范围" case errors.Is(sErr, ErrAuthorizeUnauthorizedClient): desc = "未授权的客户端" case errors.Is(sErr, ErrAuthorizeUnsupportedGrantType): desc = "不支持的授权类型" } if len(description) > 0 { desc = description[0] } return c.Status(status).JSON(TokenErrResp{ Error: string(sErr), Description: desc, }) } return err } func Revoke() error { return nil } func Introspect() error { return nil } type CodeContext struct { UserID int32 `json:"user_id"` ClientID int32 `json:"client_id"` Scopes []string `json:"scopes"` Remember bool `json:"remember"` CodeChallenge string `json:"code_challenge"` CodeChallengeMethod string `json:"code_challenge_method"` }