认证流程迁移到 oauth 风格,登出接口改为 /revoke,重构接口代码到 service

This commit is contained in:
2025-04-24 10:52:13 +08:00
parent 1374757eab
commit 08c88da0ce
7 changed files with 432 additions and 549 deletions

View File

@@ -100,12 +100,12 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (
var header = c.Get("Authorization")
var split = strings.Split(header, " ")
if len(split) != 2 {
return nil, fiber.NewError(fiber.StatusBadRequest, "无效的令牌")
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
}
var token = split[1]
if token == "" {
return nil, fiber.NewError(fiber.StatusBadRequest, "无效的令牌")
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
}
var auth *services.AuthContext
@@ -115,7 +115,7 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (
auth, err = authBearer(c.Context(), token)
case "Basic":
if !slices.Contains(types, services.PayloadClientConfidential) {
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
}
auth, err = authBasic(c.Context(), token)
default:
@@ -127,10 +127,10 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (
// 检查权限
if !slices.Contains(types, auth.Payload.Type) {
return nil, fiber.NewError(fiber.StatusForbidden, "拒绝访问")
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
}
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
return nil, fiber.NewError(fiber.StatusForbidden, "拒绝访问")
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
}
// 将认证信息存储在上下文中

292
web/handlers/auth.go Normal file
View File

@@ -0,0 +1,292 @@
package handlers
import (
"encoding/base64"
"errors"
"log/slog"
"platform/web/auth"
m "platform/web/models"
q "platform/web/queries"
s "platform/web/services"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// region /token
type TokenReq struct {
GrantType s.OauthGrantType `json:"grant_type" form:"grant_type"`
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
Scope string `json:"scope" form:"scope"`
s.GrantCodeData
s.GrantClientData
s.GrantRefreshData
s.GrantPasswordData
}
type TokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
}
type TokenErrResp struct {
Error string `json:"error"`
Description string `json:"error_description,omitempty"`
}
// Token 处理 OAuth2.0 授权请求
func Token(c *fiber.Ctx) error {
// 验证请求参数
req := new(TokenReq)
if err := c.BodyParser(req); err != nil {
return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
}
if req.GrantType == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数grant_type")
}
slog.Debug("oauth token", slog.String("grant_type",
string(req.GrantType)),
slog.String("client_id", req.ClientID),
)
// 基于授权类型处理请求
switch req.GrantType {
// 授权码模式
case s.OauthGrantTypeAuthorizationCode:
if req.Code == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数code")
}
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
if err != nil {
return sendError(c, err.(s.AuthServiceError))
}
return sendSuccess(c, token)
// 客户端凭证模式
case s.OauthGrantTypeClientCredentials:
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
scope := strings.Split(req.Scope, ",")
token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
if err != nil {
return sendError(c, err.(s.AuthServiceError))
}
return sendSuccess(c, token)
// 刷新令牌模式
case s.OauthGrantTypeRefreshToken:
if req.RefreshToken == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数refresh_token")
}
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
scope := strings.Split(req.Scope, ",")
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
if err != nil {
if errors.Is(err, s.ErrInvalidToken) {
return sendError(c, s.ErrOauthInvalidGrant)
}
return sendError(c, err)
}
return sendSuccess(c, token)
// 密码模式
case s.OauthGrantTypePassword:
if req.LoginType == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password_type")
}
if req.Username == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数username")
}
if req.Password == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password")
}
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent"))
if err != nil {
return err
}
return sendSuccess(c, token)
default:
return sendError(c, s.ErrOauthUnsupportedGrantType)
}
}
// 检查客户端凭证
func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) {
header := c.Get("Authorization")
if header != "" {
basic := strings.TrimPrefix(header, "Basic ")
if basic != "" {
base, err := base64.URLEncoding.DecodeString(basic)
if err != nil {
return nil, err
}
parts := strings.SplitN(string(base), ":", 2)
if len(parts) == 2 {
clientId = parts[0]
clientSecret = parts[1]
}
}
}
// 查找客户端
if clientId == "" {
return nil, s.ErrOauthInvalidRequest
}
client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, s.ErrOauthInvalidClient
}
return nil, err
}
// 验证客户端状态
if client.Status != 1 {
return nil, s.ErrOauthUnauthorizedClient
}
// 验证授权类型
switch grant {
case s.OauthGrantTypeAuthorizationCode:
if !client.GrantCode {
return nil, s.ErrOauthUnauthorizedClient
}
case s.OauthGrantTypeClientCredentials:
if !client.GrantClient || client.Spec != 0 {
return nil, s.ErrOauthUnauthorizedClient
}
case s.OauthGrantTypeRefreshToken:
if !client.GrantRefresh {
return nil, s.ErrOauthUnauthorizedClient
}
case s.OauthGrantTypePassword:
if !client.GrantPassword {
return nil, s.ErrOauthUnauthorizedClient
}
}
// 如果客户端是 confidential验证 client_secret失败返回错误
if client.Spec == 0 {
if clientSecret == "" {
return nil, s.ErrOauthInvalidRequest
}
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
return nil, s.ErrOauthInvalidClient
}
}
return client, nil
}
// 发送成功响应
func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error {
return c.JSON(TokenResp{
AccessToken: details.AccessToken,
TokenType: "Bearer",
ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()),
RefreshToken: details.RefreshToken,
})
}
// 发送错误响应
func sendError(c *fiber.Ctx, err error, description ...string) error {
var sErr s.AuthServiceError
if errors.As(err, &sErr) {
status := fiber.StatusBadRequest
var desc string
switch {
case errors.Is(sErr, s.ErrOauthInvalidRequest):
desc = "无效的请求"
case errors.Is(sErr, s.ErrOauthInvalidClient):
status = fiber.StatusUnauthorized
desc = "无效的客户端凭证"
case errors.Is(sErr, s.ErrOauthInvalidGrant):
desc = "无效的授权凭证"
case errors.Is(sErr, s.ErrOauthInvalidScope):
desc = "无效的授权范围"
case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
desc = "未授权的客户端"
case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
desc = "不支持的授权类型"
}
if len(description) > 0 {
desc = description[0]
}
return c.Status(status).JSON(TokenErrResp{
Error: string(sErr),
Description: desc,
})
}
return err
}
// endregion
// region /revoke
type RevokeReq struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
func Revoke(c *fiber.Ctx) error {
_, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{})
if err != nil {
// 用户未登录
return nil
}
// 解析请求参数
req := new(RevokeReq)
if err := c.BodyParser(req); err != nil {
return err
}
// 删除会话
err = s.Session.Remove(c.Context(), req.AccessToken, req.RefreshToken)
if err != nil {
return err
}
return nil
}
// endregion

View File

@@ -1,145 +0,0 @@
package handlers
import (
"errors"
"platform/web/auth"
m "platform/web/models"
q "platform/web/queries"
s "platform/web/services"
"time"
"github.com/gofiber/fiber/v2"
"gorm.io/gorm"
)
type LoginReq struct {
Username string `json:"username"`
Password string `json:"password"`
Remember bool `json:"remember"`
}
type LoginResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
Expires int64 `json:"expires"`
Auth s.AuthContext `json:"auth"`
Profile *m.User `json:"profile"`
}
func Login(c *fiber.Ctx) error {
// 验证请求参数
req := new(LoginReq)
if err := c.BodyParser(req); err != nil {
return err
}
if req.Username == "" {
return fiber.NewError(fiber.StatusBadRequest, "手机号不能为空")
}
if req.Password == "" {
return fiber.NewError(fiber.StatusBadRequest, "验证码不能为空")
}
return loginByPhone(c, req)
}
func loginByPhone(c *fiber.Ctx, req *LoginReq) error {
// 验证验证码
err := s.Verifier.VerifySms(c.Context(), req.Username, req.Password)
if err != nil {
if errors.Is(err, s.ErrVerifierServiceInvalid) {
return fiber.NewError(fiber.StatusBadRequest, "验证码错误")
}
return err
}
// 查找用户 todo 获取权限信息
var user *m.User
err = q.Q.Transaction(func(tx *q.Query) error {
user, err = tx.User.
Where(tx.User.Phone.Eq(req.Username)).
Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
// 如果用户不存在,初始化用户 todo 保存默认权限信息
if user == nil {
user = &m.User{
Phone: req.Username,
Username: req.Username,
}
}
// 更新用户的登录时间
user.LastLogin = time.Now()
user.LastLoginHost = c.IP()
user.LastLoginAgent = c.Get("User-Agent")
if err := tx.User.Omit(q.User.AdminID).Save(user); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
// 保存到会话
auth := s.AuthContext{
Permissions: map[string]struct{}{
"user": {},
},
Payload: s.Payload{
Id: user.ID,
Type: s.PayloadUser,
Name: user.Name,
Avatar: user.Avatar,
},
}
duration := time.Hour * 24
if req.Remember {
duration *= 7
}
token, err := s.Session.Create(c.Context(), auth)
if err != nil {
return err
}
user.Password = ""
return c.JSON(LoginResp{
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
Expires: token.AccessTokenExpires.Unix(),
Auth: auth,
Profile: user,
})
}
type LogoutReq struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
}
func Logout(c *fiber.Ctx) error {
_, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{})
if err != nil {
// 用户未登录
return nil
}
// 解析请求参数
req := new(LogoutReq)
if err := c.BodyParser(req); err != nil {
return err
}
// 删除会话
err = s.Session.Remove(c.Context(), req.AccessToken, req.RefreshToken)
if err != nil {
return err
}
return nil
}

View File

@@ -1,371 +0,0 @@
package handlers
import (
"encoding/base64"
"errors"
"log/slog"
m "platform/web/models"
q "platform/web/queries"
s "platform/web/services"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// region Token
type TokenReq struct {
GrantType s.OauthGrantType `json:"grant_type" form:"grant_type"`
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
Scope string `json:"scope" form:"scope"`
TokenReqCode
TokenReqClient
TokenReqRefresh
TokenReqPassword
}
type TokenReqCode struct {
Code string `json:"code" form:"code"`
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
}
type TokenReqClient struct {
}
type TokenReqRefresh struct {
RefreshToken string `json:"refresh_token" form:"refresh_token"`
}
type TokenReqPassword struct {
LoginType s.OauthGrantLoginType `json:"login_type" form:"login_type"`
Username string `json:"username" form:"username"`
Password string `json:"password" form:"password"`
Remember bool `json:"remember" form:"remember"`
}
type TokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
}
type TokenErrResp struct {
Error string `json:"error"`
Description string `json:"error_description,omitempty"`
}
// Token 处理 OAuth2.0 授权请求
func Token(c *fiber.Ctx) error {
// 验证请求参数
req := new(TokenReq)
if err := c.BodyParser(req); err != nil {
return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
}
if req.GrantType == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数grant_type")
}
slog.Debug("oauth token", slog.String("grant_type",
string(req.GrantType)),
slog.String("client_id", req.ClientID),
)
// 基于授权类型处理请求
switch req.GrantType {
case s.OauthGrantTypeAuthorizationCode:
return authorizationCode(c, req)
case s.OauthGrantTypeClientCredentials:
return clientCredentials(c, req)
case s.OauthGrantTypeRefreshToken:
return refreshToken(c, req)
case s.OauthGrantTypePassword:
return password(c, req)
default:
return sendError(c, s.ErrOauthUnsupportedGrantType)
}
}
// 授权码
func authorizationCode(c *fiber.Ctx, req *TokenReq) error {
if req.Code == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数code")
}
client, err := protect(c, s.OauthGrantTypeAuthorizationCode, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
if err != nil {
return sendError(c, err.(s.AuthServiceOauthError))
}
return sendSuccess(c, token)
}
// 客户端凭证
func clientCredentials(c *fiber.Ctx, req *TokenReq) error {
client, err := protect(c, s.OauthGrantTypeClientCredentials, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
scope := strings.Split(req.Scope, ",")
token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
if err != nil {
return sendError(c, err.(s.AuthServiceOauthError))
}
return sendSuccess(c, token)
}
// 刷新令牌
func refreshToken(c *fiber.Ctx, req *TokenReq) error {
if req.RefreshToken == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数refresh_token")
}
client, err := protect(c, s.OauthGrantTypeRefreshToken, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
scope := strings.Split(req.Scope, ",")
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
if err != nil {
if errors.Is(err, s.ErrInvalidToken) {
return sendError(c, s.ErrOauthInvalidGrant)
}
return sendError(c, err)
}
return sendSuccess(c, token)
}
func password(c *fiber.Ctx, req *TokenReq) error {
if req.LoginType == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password_type")
}
if req.Username == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数username")
}
if req.Password == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password")
}
// 验证客户端凭证
_, err := protect(c, s.OauthGrantTypePassword, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
// 验证验证码
err = s.Verifier.VerifySms(c.Context(), req.Username, req.Password)
if err != nil {
if errors.Is(err, s.ErrVerifierServiceInvalid) {
return fiber.NewError(fiber.StatusBadRequest, "验证码错误")
}
return err
}
// 查找用户
var user *m.User
err = q.Q.Transaction(func(tx *q.Query) error {
switch req.LoginType {
case s.OauthGrantPasswordTypePhoneCode:
user, err = tx.User.Where(tx.User.Phone.Eq(req.Username)).Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
case s.OauthGrantPasswordTypeEmailCode:
user, err = tx.User.Where(tx.User.Email.Eq(req.Username)).Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
case s.OauthGrantPasswordTypePassword:
user, err = tx.User.
Where(tx.User.Or(
tx.User.Phone.Eq(req.Username),
tx.User.Email.Eq(req.Username),
tx.User.Username.Eq(req.Username),
)).
Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
default:
return sendError(c, s.ErrOauthInvalidRequest, "无效的登录类型")
}
// 如果用户不存在,初始化用户 todo 初始化默认权限信息
if user == nil {
user = &m.User{
Phone: req.Username,
Username: req.Username,
}
}
// 更新用户的登录时间
user.LastLogin = time.Now()
user.LastLoginHost = c.IP()
user.LastLoginAgent = c.Get("User-Agent")
if err := tx.User.Omit(q.User.AdminID).Save(user); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
// 保存到会话
auth := s.AuthContext{
Payload: s.Payload{
Id: user.ID,
Type: s.PayloadUser,
Name: user.Name,
Avatar: user.Avatar,
},
}
duration := s.DefaultSessionConfig
if !req.Remember {
duration.RefreshTokenDuration = 0
}
token, err := s.Session.Create(c.Context(), auth)
if err != nil {
return err
}
return sendSuccess(c, token)
}
// 检查客户端凭证
func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) {
header := c.Get("Authorization")
if header != "" {
basic := strings.TrimPrefix(header, "Basic ")
if basic != "" {
base, err := base64.URLEncoding.DecodeString(basic)
if err != nil {
return nil, err
}
parts := strings.SplitN(string(base), ":", 2)
if len(parts) == 2 {
clientId = parts[0]
clientSecret = parts[1]
}
}
}
// 查找客户端
if clientId == "" {
return nil, s.ErrOauthInvalidRequest
}
client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, s.ErrOauthInvalidClient
}
return nil, err
}
// 验证客户端状态
if client.Status != 1 {
return nil, s.ErrOauthUnauthorizedClient
}
// 验证授权类型
switch grant {
case s.OauthGrantTypeAuthorizationCode:
if !client.GrantCode {
return nil, s.ErrOauthUnauthorizedClient
}
case s.OauthGrantTypeClientCredentials:
if !client.GrantClient || client.Spec != 0 {
return nil, s.ErrOauthUnauthorizedClient
}
case s.OauthGrantTypeRefreshToken:
if !client.GrantRefresh {
return nil, s.ErrOauthUnauthorizedClient
}
case s.OauthGrantTypePassword:
if !client.GrantPassword {
return nil, s.ErrOauthUnauthorizedClient
}
}
// 如果客户端是 confidential验证 client_secret失败返回错误
if client.Spec == 0 {
if clientSecret == "" {
return nil, s.ErrOauthInvalidRequest
}
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
return nil, s.ErrOauthInvalidClient
}
}
return client, nil
}
// 发送成功响应
func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error {
return c.JSON(TokenResp{
AccessToken: details.AccessToken,
TokenType: "Bearer",
ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()),
RefreshToken: details.RefreshToken,
})
}
// 发送错误响应
func sendError(c *fiber.Ctx, err error, description ...string) error {
var sErr s.AuthServiceOauthError
if errors.As(err, &sErr) {
status := fiber.StatusBadRequest
var desc string
switch {
case errors.Is(sErr, s.ErrOauthInvalidRequest):
desc = "无效的请求"
case errors.Is(sErr, s.ErrOauthInvalidClient):
status = fiber.StatusUnauthorized
desc = "无效的客户端凭证"
case errors.Is(sErr, s.ErrOauthInvalidGrant):
desc = "无效的授权凭证"
case errors.Is(sErr, s.ErrOauthInvalidScope):
desc = "无效的授权范围"
case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
desc = "未授权的客户端"
case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
desc = "不支持的授权类型"
}
if len(description) > 0 {
desc = description[0]
}
return c.Status(status).JSON(TokenErrResp{
Error: string(sErr),
Description: desc,
})
}
return err
}
// endregion

View File

@@ -12,10 +12,9 @@ func ApplyRouters(app *fiber.App) {
// 认证
auth := api.Group("/auth")
auth.Post("/verify/sms", handlers.SmsCode)
auth.Post("/login/sms", auth2.PermitDevice(), handlers.Login)
auth.Post("/logout", handlers.Logout)
auth.Post("/token", handlers.Token)
auth.Post("/revoke", handlers.Revoke)
auth.Post("/verify/sms", handlers.SmsCode)
// 通道
channel := api.Group("/channel")

View File

@@ -3,42 +3,25 @@ package services
import (
"context"
"errors"
"platform/web/models"
m "platform/web/models"
q "platform/web/queries"
"time"
"gorm.io/gorm"
)
var Auth = &authService{}
type AuthServiceError string
func (e AuthServiceError) Error() string {
return string(e)
}
type AuthServiceOauthError string
func (e AuthServiceOauthError) Error() string {
return string(e)
}
var (
ErrOauthInvalidRequest = AuthServiceOauthError("invalid_request")
ErrOauthInvalidClient = AuthServiceOauthError("invalid_client")
ErrOauthInvalidGrant = AuthServiceOauthError("invalid_grant")
ErrOauthInvalidScope = AuthServiceOauthError("invalid_scope")
ErrOauthUnauthorizedClient = AuthServiceOauthError("unauthorized_client")
ErrOauthUnsupportedGrantType = AuthServiceOauthError("unsupported_grant_type")
)
type authService struct{}
// OauthAuthorizationCode 验证授权码
func (s *authService) OauthAuthorizationCode(ctx context.Context, client *models.Client, code, redirectURI, codeVerifier string) (*TokenDetails, error) {
func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Client, code, redirectURI, codeVerifier string) (*TokenDetails, error) {
// TODO: 从数据库验证授权码
return nil, errors.New("TODO")
}
// OauthClientCredentials 验证客户端凭证
func (s *authService) OauthClientCredentials(ctx context.Context, client *models.Client, scope ...string) (*TokenDetails, error) {
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) {
var clientType PayloadType
switch client.Spec {
@@ -75,7 +58,7 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *models
}
// OauthRefreshToken 验证刷新令牌
func (s *authService) OauthRefreshToken(ctx context.Context, client *models.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) {
func (s *authService) OauthRefreshToken(ctx context.Context, _ *m.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) {
// TODO: 从数据库验证刷新令牌
details, err := Session.Refresh(ctx, refreshToken)
if err != nil {
@@ -85,6 +68,114 @@ func (s *authService) OauthRefreshToken(ctx context.Context, client *models.Clie
return details, nil
}
// OauthPassword 验证密码
func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *GrantPasswordData, ip, agent string) (*TokenDetails, error) {
var user *m.User
err := q.Q.Transaction(func(tx *q.Query) error {
switch data.LoginType {
case OauthGrantPasswordTypePhoneCode:
// 验证验证码
err := Verifier.VerifySms(ctx, data.Username, data.Password)
if err != nil {
if errors.Is(err, ErrVerifierServiceInvalid) {
return ErrOauthInvalidRequest
}
return err
}
// 查找用户
user, err =
tx.User.Where(tx.User.Phone.Eq(data.Username)).Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
case OauthGrantPasswordTypeEmailCode:
var err error
user, err = tx.User.Where(tx.User.Email.Eq(data.Username)).Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
case OauthGrantPasswordTypePassword:
var err error
user, err = tx.User.
Where(tx.User.Or(
tx.User.Phone.Eq(data.Username),
tx.User.Email.Eq(data.Username),
tx.User.Username.Eq(data.Username),
)).
Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
default:
return ErrOauthInvalidRequest
}
// 如果用户不存在,初始化用户 todo 初始化默认权限信息
if user == nil {
user = &m.User{
Phone: data.Username,
Username: data.Username,
}
}
// 更新用户的登录时间
user.LastLogin = time.Now()
user.LastLoginHost = ip
user.LastLoginAgent = agent
if err := tx.User.Omit(q.User.AdminID).Save(user); err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// 保存到会话
auth := AuthContext{
Payload: Payload{
Id: user.ID,
Type: PayloadUser,
Name: user.Name,
Avatar: user.Avatar,
},
}
duration := DefaultSessionConfig
if !data.Remember {
duration.RefreshTokenDuration = 0
}
token, err := Session.Create(ctx, auth)
if err != nil {
return nil, err
}
return token, nil
}
type GrantCodeData struct {
Code string `json:"code" form:"code"`
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
}
type GrantClientData struct {
}
type GrantRefreshData struct {
RefreshToken string `json:"refresh_token" form:"refresh_token"`
}
type GrantPasswordData struct {
LoginType OauthGrantLoginType `json:"login_type" form:"login_type"`
Username string `json:"username" form:"username"`
Password string `json:"password" form:"password"`
Remember bool `json:"remember" form:"remember"`
}
type OauthGrantType string
const (
@@ -101,3 +192,18 @@ const (
OauthGrantPasswordTypePhoneCode = OauthGrantLoginType("phone_code")
OauthGrantPasswordTypeEmailCode = OauthGrantLoginType("email_code")
)
type AuthServiceError string
func (e AuthServiceError) Error() string {
return string(e)
}
var (
ErrOauthInvalidRequest = AuthServiceError("invalid_request")
ErrOauthInvalidClient = AuthServiceError("invalid_client")
ErrOauthInvalidGrant = AuthServiceError("invalid_grant")
ErrOauthInvalidScope = AuthServiceError("invalid_scope")
ErrOauthUnauthorizedClient = AuthServiceError("unauthorized_client")
ErrOauthUnsupportedGrantType = AuthServiceError("unsupported_grant_type")
)