Files
platform/web/handlers/auth.go

341 lines
8.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handlers
import (
"encoding/base64"
"errors"
"log/slog"
"platform/web/auth"
m "platform/web/models"
q "platform/web/queries"
s "platform/web/services"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// region /token
// TokenReq holds the parameters of a /token request. The fields declared
// here carry both json and form tags, so they can be filled from either body
// encoding by fiber's BodyParser.
//
// The grant-specific parameters come from the embedded s.Grant*Data structs:
// Token reads Code, RedirectURI, CodeVerifier, RefreshToken, LoginType,
// Username and Password through them. NOTE(review): which embedded struct
// declares which field (and its tag) lives in package services; confirm there.
type TokenReq struct {
// GrantType is the OAuth 2.0 grant_type; Token dispatches on it.
GrantType s.OauthGrantType `json:"grant_type" form:"grant_type"`
// ClientID and ClientSecret are the client credentials sent in the body.
// protect replaces them with HTTP Basic credentials when the Authorization
// header supplies those.
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
// Scope is the raw scope parameter; Token splits it into individual scopes.
Scope string `json:"scope" form:"scope"`
s.GrantCodeData
s.GrantClientData
s.GrantRefreshData
s.GrantPasswordData
}
type (
	// TokenResp is the JSON body of a successful token response; its field
	// names follow RFC 6749 §5.1. refresh_token and scope are left out of the
	// output when empty.
	TokenResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token,omitempty"`
		ExpiresIn    int    `json:"expires_in"`
		TokenType    string `json:"token_type"`
		Scope        string `json:"scope,omitempty"`
	}

	// TokenErrResp is the JSON body of a token error response; its field
	// names follow RFC 6749 §5.2. error_description is left out when empty.
	TokenErrResp struct {
		Error       string `json:"error"`
		Description string `json:"error_description,omitempty"`
	}
)
// Token handles OAuth 2.0 token requests.
//
// It parses the body into a TokenReq, checks the parameters required by the
// requested grant type, authenticates the client with protect and delegates
// token issuance to the matching s.Auth.Oauth* method:
//
//   - s.OauthGrantTypeAuthorizationCode: requires code.
//   - s.OauthGrantTypeClientCredentials: no extra required parameters.
//   - s.OauthGrantTypeRefreshToken: requires refresh_token; s.ErrInvalidToken
//     returned by the service is reported as s.ErrOauthInvalidGrant.
//   - s.OauthGrantTypePassword: requires the login type, username and password.
//
// Any other grant type is answered with s.ErrOauthUnsupportedGrantType.
// Every failure goes through sendError, which renders s.AuthServiceError
// values as OAuth error responses and returns any other error unchanged.
func Token(c *fiber.Ctx) error {
	// Parse and validate the parameters common to all grants.
	req := new(TokenReq)
	if err := c.BodyParser(req); err != nil {
		return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
	}
	if req.GrantType == "" {
		return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数grant_type")
	}
	slog.Debug("oauth token",
		slog.String("grant_type", string(req.GrantType)),
		slog.String("client_id", req.ClientID),
	)

	// splitScope turns the raw scope parameter into individual scope values.
	// RFC 6749 §3.3 separates scopes with spaces; commas are accepted as well
	// for existing clients. Empty items are dropped, so a missing scope yields
	// no scopes rather than a single "" scope.
	splitScope := func(raw string) []string {
		return strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' })
	}

	switch req.GrantType {
	// Authorization code grant.
	case s.OauthGrantTypeAuthorizationCode:
		if req.Code == "" {
			return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数code")
		}
		client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
		if err != nil {
			return sendError(c, err)
		}
		token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
		if err != nil {
			// sendError copes with non-service errors itself; a type
			// assertion here would panic on them.
			return sendError(c, err)
		}
		return sendSuccess(c, token)

	// Client credentials grant.
	case s.OauthGrantTypeClientCredentials:
		client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
		if err != nil {
			return sendError(c, err)
		}
		token, err := s.Auth.OauthClientCredentials(c.Context(), client, splitScope(req.Scope)...)
		if err != nil {
			return sendError(c, err)
		}
		return sendSuccess(c, token)

	// Refresh token grant.
	case s.OauthGrantTypeRefreshToken:
		if req.RefreshToken == "" {
			return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数refresh_token")
		}
		client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
		if err != nil {
			return sendError(c, err)
		}
		token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, splitScope(req.Scope))
		if err != nil {
			// The service's invalid-token error maps to OAuth invalid_grant.
			if errors.Is(err, s.ErrInvalidToken) {
				return sendError(c, s.ErrOauthInvalidGrant)
			}
			return sendError(c, err)
		}
		return sendSuccess(c, token)

	// Resource owner password credentials grant.
	case s.OauthGrantTypePassword:
		// NOTE(review): this message names "password_type" while the field is
		// LoginType; confirm it matches the form tag in s.GrantPasswordData.
		if req.LoginType == "" {
			return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password_type")
		}
		if req.Username == "" {
			return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数username")
		}
		if req.Password == "" {
			return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password")
		}
		client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
		if err != nil {
			return sendError(c, err)
		}
		token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent"))
		if err != nil {
			// Route through sendError like the other grants so service
			// errors become OAuth error responses.
			return sendError(c, err)
		}
		return sendSuccess(c, token)

	default:
		return sendError(c, s.ErrOauthUnsupportedGrantType)
	}
}
// protect authenticates the OAuth client behind a /token request and checks
// that it may use the requested grant type.
//
// Credentials come from an HTTP Basic Authorization header when one is
// present (RFC 6749 §2.3.1). Otherwise the clientId and clientSecret taken
// from the request body are used. Checks run in this order:
//
//   - clientId must be non-empty (s.ErrOauthInvalidRequest);
//   - the client must exist (s.ErrOauthInvalidClient);
//   - a confidential client (Spec == 0) must send a secret matching the stored
//     bcrypt hash (s.ErrOauthInvalidRequest if none is sent,
//     s.ErrOauthInvalidClient if it does not match);
//   - the client must be enabled (Status == 1) and allowed to use grant
//     (s.ErrOauthUnauthorizedClient). The client-credentials grant is also
//     limited to confidential clients.
//
// Authentication runs before the status and grant checks, so a caller without
// valid credentials cannot learn how a client is configured. A malformed
// Basic header is reported as s.ErrOauthInvalidClient. Database errors other
// than gorm.ErrRecordNotFound are returned unchanged.
func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) {
	// Prefer HTTP Basic client credentials over the body parameters.
	if header := c.Get("Authorization"); header != "" {
		scheme, encoded, _ := strings.Cut(header, " ")
		encoded = strings.TrimSpace(encoded)
		// Auth schemes are case-insensitive (RFC 7235 §2.1). Other schemes
		// are ignored, and the body credentials are used instead.
		if strings.EqualFold(scheme, "Basic") && encoded != "" {
			// RFC 7617 uses standard, padded base64. Unpadded URL-safe base64
			// is still accepted so that existing clients keep working.
			raw, err := base64.StdEncoding.DecodeString(encoded)
			if err != nil {
				raw, err = base64.RawURLEncoding.DecodeString(encoded)
			}
			if err != nil {
				return nil, s.ErrOauthInvalidClient
			}
			id, secret, ok := strings.Cut(string(raw), ":")
			if !ok {
				return nil, s.ErrOauthInvalidClient
			}
			clientId, clientSecret = id, secret
		}
	}

	// Look up the client.
	if clientId == "" {
		return nil, s.ErrOauthInvalidRequest
	}
	client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, s.ErrOauthInvalidClient
		}
		return nil, err
	}

	// Confidential clients (Spec == 0) must prove knowledge of client_secret
	// before anything about the client's configuration is revealed.
	if client.Spec == 0 {
		if clientSecret == "" {
			return nil, s.ErrOauthInvalidRequest
		}
		if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
			return nil, s.ErrOauthInvalidClient
		}
	}

	// The client must be enabled.
	if client.Status != 1 {
		return nil, s.ErrOauthUnauthorizedClient
	}

	// The client must be allowed to use the requested grant.
	switch grant {
	case s.OauthGrantTypeAuthorizationCode:
		if !client.GrantCode {
			return nil, s.ErrOauthUnauthorizedClient
		}
	case s.OauthGrantTypeClientCredentials:
		if !client.GrantClient || client.Spec != 0 {
			return nil, s.ErrOauthUnauthorizedClient
		}
	case s.OauthGrantTypeRefreshToken:
		if !client.GrantRefresh {
			return nil, s.ErrOauthUnauthorizedClient
		}
	case s.OauthGrantTypePassword:
		if !client.GrantPassword {
			return nil, s.ErrOauthUnauthorizedClient
		}
	}

	return client, nil
}
// sendSuccess writes a successful token response (RFC 6749 §5.1) for details.
//
// Response fields:
//   - token_type is always "Bearer".
//   - expires_in is the number of whole seconds left until
//     details.AccessTokenExpires, clamped at 0 so an already-expired token
//     never reports a negative lifetime.
//   - refresh_token is omitted when details has none.
//   - scope is not populated.
//
// The response is marked Cache-Control: no-store and Pragma: no-cache because
// it carries credentials; RFC 6749 §5.1 requires both headers on token
// responses.
func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error {
	c.Set("Cache-Control", "no-store")
	c.Set("Pragma", "no-cache")

	expiresIn := int(time.Until(details.AccessTokenExpires).Seconds())
	if expiresIn < 0 {
		expiresIn = 0
	}

	return c.JSON(TokenResp{
		AccessToken:  details.AccessToken,
		TokenType:    "Bearer",
		ExpiresIn:    expiresIn,
		RefreshToken: details.RefreshToken,
	})
}
// sendError writes an OAuth 2.0 error response (RFC 6749 §5.2) for err.
//
// When err is, or wraps, an s.AuthServiceError, the response body is a
// TokenErrResp whose error field is the string value of that error.
//
// The HTTP status is 400, except for s.ErrOauthInvalidClient, which yields
// 401. For that 401, a WWW-Authenticate challenge is added when the request
// carried an Authorization header, as RFC 6749 §5.2 requires.
//
// error_description defaults to a message chosen per error code and is
// replaced by description[0] when given.
//
// Any other error is returned unchanged for the caller to handle.
func sendError(c *fiber.Ctx, err error, description ...string) error {
	var sErr s.AuthServiceError
	if !errors.As(err, &sErr) {
		return err
	}

	status := fiber.StatusBadRequest
	var desc string
	switch {
	case errors.Is(sErr, s.ErrOauthInvalidRequest):
		desc = "无效的请求"
	case errors.Is(sErr, s.ErrOauthInvalidClient):
		status = fiber.StatusUnauthorized
		desc = "无效的客户端凭证"
		// The client tried header-based authentication, so the 401 must
		// tell it which scheme is expected.
		if c.Get("Authorization") != "" {
			c.Set("WWW-Authenticate", `Basic realm="oauth"`)
		}
	case errors.Is(sErr, s.ErrOauthInvalidGrant):
		desc = "无效的授权凭证"
	case errors.Is(sErr, s.ErrOauthInvalidScope):
		desc = "无效的授权范围"
	case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
		desc = "未授权的客户端"
	case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
		desc = "不支持的授权类型"
	}
	if len(description) > 0 {
		desc = description[0]
	}

	return c.Status(status).JSON(TokenErrResp{
		Error:       string(sErr),
		Description: desc,
	})
}
// endregion
// region /revoke
type (
	// RevokeReq is the request body read by Revoke. Only JSON tags are
	// declared on its fields.
	RevokeReq struct {
		// AccessToken is the access token whose session should be removed.
		AccessToken string `json:"access_token"`
		// RefreshToken is the refresh token whose session should be removed.
		RefreshToken string `json:"refresh_token"`
	}
)
// Revoke removes the session identified by the access and/or refresh token in
// the request body.
//
// The caller must pass auth.Protect as a confidential client
// (s.PayloadClientConfidential). If that check fails, the handler returns nil
// without removing anything; the original author treats this as "not logged
// in". NOTE(review): an unauthenticated request is therefore answered without
// an error; confirm that this is intended rather than an authorization
// failure.
//
// Body-parsing and session-removal errors are returned unchanged.
func Revoke(c *fiber.Ctx) error {
	if _, authErr := auth.Protect(c, []s.PayloadType{s.PayloadClientConfidential}, []string{}); authErr != nil {
		// Not logged in: nothing to revoke.
		return nil
	}

	var req RevokeReq
	if parseErr := c.BodyParser(&req); parseErr != nil {
		return parseErr
	}

	// Drop the session; its error (or nil) is the handler's result.
	return s.Session.Remove(c.Context(), req.AccessToken, req.RefreshToken)
}
// endregion
// region /profile
// IntrospectResp is the response body written by Introspect. It embeds
// m.User, so the user's fields (and any methods of m.User, such as a custom
// JSON marshaler if one exists) are promoted onto the response type.
type IntrospectResp struct {
m.User
}
// Introspect writes the profile of the authenticated user as JSON.
//
// The caller must pass auth.Protect with a user payload (s.PayloadUser);
// otherwise its error is returned. The user is looked up by the payload ID,
// leaving out the Password and DeletedAt columns. The phone number and ID
// number are masked, when non-empty, before the response is written. Query
// errors are returned unchanged.
func Introspect(c *fiber.Ctx) error {
	authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{})
	if err != nil {
		return err
	}

	userID := authCtx.Payload.Id
	profile, err := q.User.Where(q.User.ID.Eq(userID)).Omit(q.User.Password, q.User.DeletedAt).Take()
	if err != nil {
		return err
	}

	// Never hand out full contact or identity numbers.
	if phone := profile.Phone; phone != "" {
		profile.Phone = maskPhone(phone)
	}
	if idNo := profile.IDNo; idNo != "" {
		profile.IDNo = maskIdNo(idNo)
	}

	return c.JSON(IntrospectResp{User: *profile})
}
// maskPhone masks a phone number for display. It keeps the first 3 and the
// last 4 characters and replaces every character in between with '*', so the
// result has the same length as the input (e.g. "13812345678" ->
// "138****5678").
//
// Values of 7 characters or fewer cannot keep both ends hidden-safe and are
// masked completely. No number is ever returned in clear text, including
// shorter or prefixed numbers such as "+86...".
//
// Characters are counted as runes, so non-ASCII input is never split
// mid-character.
func maskPhone(phone string) string {
	const keepHead, keepTail = 3, 4
	r := []rune(phone)
	if len(r) <= keepHead+keepTail {
		return strings.Repeat("*", len(r))
	}
	return string(r[:keepHead]) + strings.Repeat("*", len(r)-keepHead-keepTail) + string(r[len(r)-keepTail:])
}
// maskIdNo masks an identity-document number for display. It keeps the first
// 3 and the last 4 characters and replaces every character in between with
// '*', so the result has the same length as the input (e.g. an 18-character
// ID becomes "110" + 11 stars + "1234").
//
// Shorter numbers, such as legacy 15-digit IDs, are masked the same way
// rather than returned in clear text. Values of 7 characters or fewer are
// masked completely.
//
// Characters are counted as runes, so non-ASCII input is never split
// mid-character.
func maskIdNo(idNo string) string {
	const keepHead, keepTail = 3, 4
	r := []rune(idNo)
	if len(r) <= keepHead+keepTail {
		return strings.Repeat("*", len(r))
	}
	return string(r[:keepHead]) + strings.Repeat("*", len(r)-keepHead-keepTail) + string(r[len(r)-keepTail:])
}
// endregion