Files
platform/web/handlers/oauth.go

372 lines
9.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handlers
import (
"encoding/base64"
"errors"
"log/slog"
m "platform/web/models"
q "platform/web/queries"
s "platform/web/services"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// region Token
// Request payloads accepted by the token endpoint.
type (
	// TokenReq is the parsed body of a token request. The leading fields are
	// shared by every grant type; the embedded structs contribute the
	// grant-specific fields, which Go promotes into TokenReq.
	TokenReq struct {
		GrantType    s.OauthGrantType `json:"grant_type" form:"grant_type"`
		ClientID     string           `json:"client_id" form:"client_id"`
		ClientSecret string           `json:"client_secret" form:"client_secret"`
		Scope        string           `json:"scope" form:"scope"`
		TokenReqCode
		TokenReqClient
		TokenReqRefresh
		TokenReqPassword
	}

	// TokenReqCode carries the authorization_code grant parameters.
	TokenReqCode struct {
		Code         string `json:"code" form:"code"`
		RedirectURI  string `json:"redirect_uri" form:"redirect_uri"`
		CodeVerifier string `json:"code_verifier" form:"code_verifier"`
	}

	// TokenReqClient carries client_credentials parameters. It is currently
	// empty: that grant uses only the shared fields of TokenReq.
	TokenReqClient struct{}

	// TokenReqRefresh carries the refresh_token grant parameters.
	TokenReqRefresh struct {
		RefreshToken string `json:"refresh_token" form:"refresh_token"`
	}

	// TokenReqPassword carries the password grant parameters; LoginType
	// selects how Username is matched against stored users.
	TokenReqPassword struct {
		LoginType s.OauthGrantLoginType `json:"login_type" form:"login_type"`
		Username  string                `json:"username" form:"username"`
		Password  string                `json:"password" form:"password"`
		Remember  bool                  `json:"remember" form:"remember"`
	}
)
// Response payloads produced by the token endpoint.
type (
	// TokenResp is the body of a successful token response. RefreshToken and
	// Scope are omitted from the JSON when empty.
	TokenResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token,omitempty"`
		ExpiresIn    int    `json:"expires_in"`
		TokenType    string `json:"token_type"`
		Scope        string `json:"scope,omitempty"`
	}

	// TokenErrResp is the body of an error response: an error code plus an
	// optional human-readable description.
	TokenErrResp struct {
		Error       string `json:"error"`
		Description string `json:"error_description,omitempty"`
	}
)
// Token is the OAuth 2.0 token endpoint handler. It parses the request body
// into a TokenReq, requires grant_type to be present, and hands the request
// to the handler for that grant. Unknown grant types are answered with
// unsupported_grant_type.
func Token(c *fiber.Ctx) error {
	var req TokenReq
	if c.BodyParser(&req) != nil {
		return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
	}

	grant := req.GrantType
	if grant == "" {
		return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数grant_type")
	}

	slog.Debug("oauth token",
		slog.String("grant_type", string(grant)),
		slog.String("client_id", req.ClientID),
	)

	// Pick the grant handler first, then invoke it once.
	var handle func(*fiber.Ctx, *TokenReq) error
	switch grant {
	case s.OauthGrantTypeAuthorizationCode:
		handle = authorizationCode
	case s.OauthGrantTypeClientCredentials:
		handle = clientCredentials
	case s.OauthGrantTypeRefreshToken:
		handle = refreshToken
	case s.OauthGrantTypePassword:
		handle = password
	default:
		return sendError(c, s.ErrOauthUnsupportedGrantType)
	}
	return handle(c, &req)
}
// authorizationCode handles grant_type=authorization_code. It requires a
// code, authenticates the client for this grant via protect, and exchanges
// the code (with redirect_uri and the PKCE code_verifier) through
// s.Auth.OauthAuthorizationCode.
func authorizationCode(c *fiber.Ctx, req *TokenReq) error {
	if req.Code == "" {
		return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数code")
	}

	client, err := protect(c, s.OauthGrantTypeAuthorizationCode, req.ClientID, req.ClientSecret)
	if err != nil {
		return sendError(c, err)
	}

	token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
	if err != nil {
		// sendError renders OAuth errors (including wrapped ones) and passes
		// any other error through. An unchecked type assertion here used to
		// panic on non-OAuth errors such as database failures.
		return sendError(c, err)
	}
	return sendSuccess(c, token)
}
// clientCredentials handles grant_type=client_credentials. It authenticates
// the client via protect and issues a token for the requested scopes through
// s.Auth.OauthClientCredentials.
func clientCredentials(c *fiber.Ctx, req *TokenReq) error {
	client, err := protect(c, s.OauthGrantTypeClientCredentials, req.ClientID, req.ClientSecret)
	if err != nil {
		return sendError(c, err)
	}

	// Scopes may be separated by commas or spaces (RFC 6749 §3.3 uses
	// spaces). Empty entries are dropped, so an omitted scope yields no
	// scopes instead of a single empty string.
	scope := strings.FieldsFunc(req.Scope, func(r rune) bool { return r == ',' || r == ' ' })

	token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
	if err != nil {
		// Avoid an unchecked type assertion; sendError handles OAuth errors
		// and returns anything else unchanged.
		return sendError(c, err)
	}
	return sendSuccess(c, token)
}
// refreshToken handles grant_type=refresh_token. It requires a refresh_token,
// authenticates the client via protect, and exchanges the token through
// s.Auth.OauthRefreshToken. s.ErrInvalidToken from the auth service is
// reported as invalid_grant; other errors go through sendError.
func refreshToken(c *fiber.Ctx, req *TokenReq) error {
	if req.RefreshToken == "" {
		return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数refresh_token")
	}

	client, err := protect(c, s.OauthGrantTypeRefreshToken, req.ClientID, req.ClientSecret)
	if err != nil {
		return sendError(c, err)
	}

	// Scopes may be separated by commas or spaces (RFC 6749 §3.3 uses
	// spaces). Empty entries are dropped, so an omitted scope yields an empty
	// list instead of [""].
	scope := strings.FieldsFunc(req.Scope, func(r rune) bool { return r == ',' || r == ' ' })

	token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
	if err != nil {
		if errors.Is(err, s.ErrInvalidToken) {
			return sendError(c, s.ErrOauthInvalidGrant)
		}
		return sendError(c, err)
	}
	return sendSuccess(c, token)
}
// password handles grant_type=password.
//
// For every login type, req.Password is checked with s.Verifier.VerifySms
// against req.Username; this function does not compare any stored password
// hash. The user is then looked up according to req.LoginType:
//   - phone code: by phone
//   - email code: by email
//   - password:   by phone, email or username
//
// If no user matches, a new user is created with Phone and Username set to
// req.Username. The last-login time, host and user agent are saved in the
// same transaction, and a session token is returned on success.
func password(c *fiber.Ctx, req *TokenReq) error {
	// Reject missing or unknown login types before any side effects. The
	// unknown-type check used to live inside the transaction: it wrote an
	// error response, returned nil, and the handler then dereferenced a nil
	// user.
	switch req.LoginType {
	case "":
		return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数login_type")
	case s.OauthGrantPasswordTypePhoneCode,
		s.OauthGrantPasswordTypeEmailCode,
		s.OauthGrantPasswordTypePassword:
		// supported
	default:
		return sendError(c, s.ErrOauthInvalidRequest, "无效的登录类型")
	}
	if req.Username == "" {
		return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数username")
	}
	if req.Password == "" {
		return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password")
	}

	// Authenticate the client for this grant.
	if _, err := protect(c, s.OauthGrantTypePassword, req.ClientID, req.ClientSecret); err != nil {
		return sendError(c, err)
	}

	// Check the verification code.
	// NOTE(review): this also runs for OauthGrantPasswordTypePassword, so that
	// login type is effectively code-based too. Confirm this is intended.
	if err := s.Verifier.VerifySms(c.Context(), req.Username, req.Password); err != nil {
		if errors.Is(err, s.ErrVerifierServiceInvalid) {
			return fiber.NewError(fiber.StatusBadRequest, "验证码错误")
		}
		return err
	}

	// Find (or create) the user and record this login.
	var user *m.User
	err := q.Q.Transaction(func(tx *q.Query) error {
		var (
			found   *m.User
			findErr error
		)
		switch req.LoginType {
		case s.OauthGrantPasswordTypePhoneCode:
			found, findErr = tx.User.Where(tx.User.Phone.Eq(req.Username)).Take()
		case s.OauthGrantPasswordTypeEmailCode:
			found, findErr = tx.User.Where(tx.User.Email.Eq(req.Username)).Take()
		case s.OauthGrantPasswordTypePassword:
			found, findErr = tx.User.
				Where(tx.User.Or(
					tx.User.Phone.Eq(req.Username),
					tx.User.Email.Eq(req.Username),
					tx.User.Username.Eq(req.Username),
				)).
				Take()
		default:
			// Unreachable: the login type was validated above.
			return s.ErrOauthInvalidRequest
		}
		if findErr != nil && !errors.Is(findErr, gorm.ErrRecordNotFound) {
			return findErr
		}

		// Create the user if it does not exist.
		// TODO: initialize default permissions.
		// NOTE(review): for email-code logins this stores the email in Phone.
		// Confirm this is intended.
		if found == nil {
			found = &m.User{
				Phone:    req.Username,
				Username: req.Username,
			}
		}

		found.LastLogin = time.Now()
		found.LastLoginHost = c.IP()
		found.LastLoginAgent = c.Get("User-Agent")
		if err := tx.User.Omit(q.User.AdminID).Save(found); err != nil {
			return err
		}
		user = found
		return nil
	})
	if err != nil {
		// OAuth errors are rendered as such; anything else is returned as-is.
		return sendError(c, err)
	}

	// Build the session payload.
	auth := s.AuthContext{
		Payload: s.Payload{
			Id:     user.ID,
			Type:   s.PayloadUser,
			Name:   user.Name,
			Avatar: user.Avatar,
		},
	}

	// NOTE(review): duration is derived from req.Remember but is never passed
	// to s.Session.Create, so "remember" currently has no effect. Wire it
	// through once the Session API is confirmed.
	duration := s.DefaultSessionConfig
	if !req.Remember {
		duration.RefreshTokenDuration = 0
	}

	token, err := s.Session.Create(c.Context(), auth)
	if err != nil {
		return err
	}
	return sendSuccess(c, token)
}
// protect authenticates the OAuth client for the given grant type.
//
// If the request carries an HTTP Basic Authorization header, the client
// credentials come from it. Otherwise the clientId and clientSecret request
// parameters are used. The client must exist, have Status == 1, and be
// allowed to use grant; client_credentials additionally requires Spec == 0.
// When Spec == 0 (a confidential client), clientSecret is checked against
// the stored bcrypt hash.
//
// It returns:
//   - s.ErrOauthInvalidRequest if client_id, or a required secret, is missing
//   - s.ErrOauthInvalidClient for malformed Basic credentials, an unknown
//     client, or a wrong secret
//   - s.ErrOauthUnauthorizedClient if the client is disabled or may not use
//     grant
//   - the raw error for other lookup failures
func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) {
	// Only the Basic scheme carries client credentials. Other schemes (e.g.
	// Bearer) are ignored instead of being fed to the base64 decoder.
	if scheme, credentials, _ := strings.Cut(c.Get("Authorization"), " "); strings.EqualFold(scheme, "Basic") {
		credentials = strings.TrimSpace(credentials)
		if credentials != "" {
			// RFC 7617 uses standard base64. The URL-safe alphabet is kept as
			// a fallback because this endpoint used to accept it.
			raw, err := base64.StdEncoding.DecodeString(credentials)
			if err != nil {
				raw, err = base64.URLEncoding.DecodeString(credentials)
			}
			if err != nil {
				return nil, s.ErrOauthInvalidClient
			}
			id, secret, ok := strings.Cut(string(raw), ":")
			if !ok {
				return nil, s.ErrOauthInvalidClient
			}
			clientId, clientSecret = id, secret
		}
	}

	// Look up the client.
	if clientId == "" {
		return nil, s.ErrOauthInvalidRequest
	}
	client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, s.ErrOauthInvalidClient
		}
		return nil, err
	}

	// Check the client status.
	if client.Status != 1 {
		return nil, s.ErrOauthUnauthorizedClient
	}

	// Check that the client may use this grant type.
	switch grant {
	case s.OauthGrantTypeAuthorizationCode:
		if !client.GrantCode {
			return nil, s.ErrOauthUnauthorizedClient
		}
	case s.OauthGrantTypeClientCredentials:
		if !client.GrantClient || client.Spec != 0 {
			return nil, s.ErrOauthUnauthorizedClient
		}
	case s.OauthGrantTypeRefreshToken:
		if !client.GrantRefresh {
			return nil, s.ErrOauthUnauthorizedClient
		}
	case s.OauthGrantTypePassword:
		if !client.GrantPassword {
			return nil, s.ErrOauthUnauthorizedClient
		}
	}

	// Confidential clients (Spec == 0) must present a matching client_secret.
	if client.Spec == 0 {
		if clientSecret == "" {
			return nil, s.ErrOauthInvalidRequest
		}
		if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
			return nil, s.ErrOauthInvalidClient
		}
	}
	return client, nil
}
// sendSuccess writes details as a Bearer token response. expires_in is the
// number of whole seconds until details.AccessTokenExpires, measured when
// the response is built.
func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error {
	remaining := time.Until(details.AccessTokenExpires).Seconds()
	resp := TokenResp{
		TokenType:    "Bearer",
		AccessToken:  details.AccessToken,
		RefreshToken: details.RefreshToken,
		ExpiresIn:    int(remaining),
	}
	return c.JSON(resp)
}
// sendError writes an OAuth error response when err is, or wraps, an
// s.AuthServiceOauthError. The HTTP status is 401 for invalid client
// credentials and 400 otherwise. The description defaults to a per-error
// message and is overridden by description[0] when one is given. Any other
// error is returned unchanged, without writing a response.
func sendError(c *fiber.Ctx, err error, description ...string) error {
	var oauthErr s.AuthServiceOauthError
	if !errors.As(err, &oauthErr) {
		return err
	}

	// Status and default description for each known error; the first match
	// wins. Unknown codes fall back to 400 with no description.
	known := []struct {
		target error
		status int
		desc   string
	}{
		{s.ErrOauthInvalidRequest, fiber.StatusBadRequest, "无效的请求"},
		{s.ErrOauthInvalidClient, fiber.StatusUnauthorized, "无效的客户端凭证"},
		{s.ErrOauthInvalidGrant, fiber.StatusBadRequest, "无效的授权凭证"},
		{s.ErrOauthInvalidScope, fiber.StatusBadRequest, "无效的授权范围"},
		{s.ErrOauthUnauthorizedClient, fiber.StatusBadRequest, "未授权的客户端"},
		{s.ErrOauthUnsupportedGrantType, fiber.StatusBadRequest, "不支持的授权类型"},
	}

	status, desc := fiber.StatusBadRequest, ""
	for _, k := range known {
		if errors.Is(oauthErr, k.target) {
			status, desc = k.status, k.desc
			break
		}
	}
	if len(description) > 0 {
		desc = description[0]
	}

	return c.Status(status).JSON(TokenErrResp{
		Error:       string(oauthErr),
		Description: desc,
	})
}
// endregion