认证授权主要流程实现
This commit is contained in:
85
web/services/auth.go
Normal file
85
web/services/auth.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"platform/web/models"
|
||||
)
|
||||
|
||||
var Auth = &authService{}
|
||||
|
||||
type authService struct{}
|
||||
|
||||
type AuthServiceError string
|
||||
|
||||
func (e AuthServiceError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
type AuthServiceOauthError string
|
||||
|
||||
func (e AuthServiceOauthError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrOauthInvalidRequest = AuthServiceOauthError("invalid_request")
|
||||
ErrOauthInvalidClient = AuthServiceOauthError("invalid_client")
|
||||
ErrOauthInvalidGrant = AuthServiceOauthError("invalid_grant")
|
||||
ErrOauthInvalidScope = AuthServiceOauthError("invalid_scope")
|
||||
ErrOauthUnauthorizedClient = AuthServiceOauthError("unauthorized_client")
|
||||
ErrOauthUnsupportedGrantType = AuthServiceOauthError("unsupported_grant_type")
|
||||
)
|
||||
|
||||
// OauthAuthorizationCode 验证授权码
|
||||
func (s *authService) OauthAuthorizationCode(ctx context.Context, client *models.Client, code, redirectURI, codeVerifier string) (*TokenDetails, error) {
|
||||
// TODO: 从数据库验证授权码
|
||||
return nil, errors.New("TODO")
|
||||
}
|
||||
|
||||
// OauthClientCredentials 验证客户端凭证
|
||||
func (s *authService) OauthClientCredentials(ctx context.Context, client *models.Client, scope ...[]string) (*TokenDetails, error) {
|
||||
|
||||
var clientType PayloadType
|
||||
switch client.Spec {
|
||||
case 0:
|
||||
clientType = PayloadClientConfidential
|
||||
case 1:
|
||||
clientType = PayloadClientPublic
|
||||
case 2:
|
||||
clientType = PayloadClientConfidential
|
||||
}
|
||||
|
||||
// 保存会话并返回令牌
|
||||
auth := AuthContext{
|
||||
Permissions: map[string]struct{}{
|
||||
"client": {},
|
||||
},
|
||||
Payload: Payload{
|
||||
Type: clientType,
|
||||
Id: client.ID,
|
||||
},
|
||||
}
|
||||
|
||||
// todo 数据库定义会话持续时间
|
||||
token, err := Session.Create(ctx, auth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// OauthRefreshToken 验证刷新令牌
|
||||
func (s *authService) OauthRefreshToken(ctx context.Context, client *models.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) {
|
||||
// TODO: 从数据库验证刷新令牌
|
||||
return nil, errors.New("TODO")
|
||||
}
|
||||
|
||||
type GrantType int
|
||||
|
||||
const (
|
||||
GrantTypeAuthorizationCode GrantType = iota
|
||||
GrantTypeClientCredentials
|
||||
GrantTypeRefreshToken
|
||||
)
|
||||
288
web/services/session.go
Normal file
288
web/services/session.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"platform/init/rds"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// region SessionService
|
||||
|
||||
var Session = &sessionService{}
|
||||
|
||||
type sessionService struct {
|
||||
}
|
||||
|
||||
type SessionServiceError string
|
||||
|
||||
func (e SessionServiceError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidToken = SessionServiceError("invalid_token")
|
||||
)
|
||||
|
||||
// Find 通过访问令牌获取会话信息
|
||||
func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, error) {
|
||||
|
||||
// 读取认证数据
|
||||
authJSON, err := rds.Client.Get(ctx, accessKey(token)).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 反序列化
|
||||
auth := new(AuthContext)
|
||||
if err := json.Unmarshal([]byte(authJSON), auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// Create 创建一个新的会话
|
||||
func (s *sessionService) Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error) {
|
||||
// 解析可选配置
|
||||
cfg := DefaultSessionConfig
|
||||
if len(config) > 0 {
|
||||
cfg = mergeConfig(DefaultSessionConfig, config[0])
|
||||
}
|
||||
|
||||
// 生成令牌组
|
||||
accessToken := genToken()
|
||||
refreshToken := genToken()
|
||||
|
||||
// 序列化认证数据
|
||||
authData, err := json.Marshal(auth)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 序列化刷新令牌数据
|
||||
refreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: auth,
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 事务保存数据到 Redis
|
||||
pipe := rds.Client.TxPipeline()
|
||||
pipe.Set(ctx, accessKey(accessToken), authData, cfg.AccessTokenDuration)
|
||||
pipe.Set(ctx, refreshKey(refreshToken), refreshData, cfg.RefreshTokenDuration)
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenDetails{
|
||||
AccessToken: accessToken,
|
||||
AccessTokenExpires: time.Now().Add(cfg.AccessTokenDuration),
|
||||
RefreshToken: refreshToken,
|
||||
RefreshTokenExpires: time.Now().Add(cfg.RefreshTokenDuration),
|
||||
Auth: auth,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh 刷新一个会话
|
||||
func (s *sessionService) Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error) {
|
||||
// 解析可选配置
|
||||
cfg := DefaultSessionConfig
|
||||
if len(config) > 0 {
|
||||
cfg = mergeConfig(DefaultSessionConfig, config[0])
|
||||
}
|
||||
|
||||
rKey := refreshKey(refreshToken)
|
||||
var tokenDetails *TokenDetails
|
||||
|
||||
// 刷新令牌
|
||||
err := rds.Client.Watch(ctx, func(tx *redis.Tx) error {
|
||||
|
||||
// 先获取刷新令牌数据
|
||||
refreshJson, err := tx.Get(ctx, rKey).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析刷新令牌数据
|
||||
refreshData := new(RefreshData)
|
||||
if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除旧的令牌
|
||||
pipeline := tx.Pipeline()
|
||||
pipeline.Del(ctx, accessKey(refreshData.AccessToken))
|
||||
pipeline.Del(ctx, refreshKey(refreshToken))
|
||||
|
||||
// 生成新的令牌
|
||||
newAccessToken := genToken()
|
||||
newRefreshToken := genToken()
|
||||
|
||||
authData, err := json.Marshal(refreshData.AuthContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newRefreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: refreshData.AuthContext,
|
||||
AccessToken: newAccessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pipeline.Set(ctx, accessKey(newAccessToken), authData, cfg.AccessTokenDuration)
|
||||
pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, cfg.RefreshTokenDuration)
|
||||
|
||||
_, err = pipeline.Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenDetails = &TokenDetails{
|
||||
AccessToken: newAccessToken,
|
||||
RefreshToken: newRefreshToken,
|
||||
AccessTokenExpires: time.Now().Add(cfg.AccessTokenDuration),
|
||||
RefreshTokenExpires: time.Now().Add(cfg.RefreshTokenDuration),
|
||||
Auth: refreshData.AuthContext,
|
||||
}
|
||||
return nil
|
||||
}, rKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("刷新令牌失败: %w", err)
|
||||
}
|
||||
|
||||
return tokenDetails, nil
|
||||
}
|
||||
|
||||
// Remove 删除会话
|
||||
func (s *sessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
|
||||
rds.Client.Del(ctx, accessKey(accessToken), refreshKey(refreshToken))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 令牌键的格式为 "session:<token>"
|
||||
func accessKey(token string) string {
|
||||
return fmt.Sprintf("session:%s", token)
|
||||
}
|
||||
|
||||
// 刷新令牌键的格式为 "session:refreshKey:<token>"
|
||||
func refreshKey(token string) string {
|
||||
return fmt.Sprintf("session:refresh:%s", token)
|
||||
}
|
||||
|
||||
// 生成一个新的令牌
|
||||
func genToken() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region SessionConfig
|
||||
|
||||
// SessionConfig 定义会话管理的配置选项
|
||||
type SessionConfig struct {
|
||||
// 令牌配置
|
||||
AccessTokenDuration time.Duration
|
||||
RefreshTokenDuration time.Duration
|
||||
}
|
||||
|
||||
// DefaultSessionConfig 默认会话配置
|
||||
var DefaultSessionConfig = SessionConfig{
|
||||
AccessTokenDuration: 2 * time.Hour,
|
||||
RefreshTokenDuration: 7 * 24 * time.Hour,
|
||||
}
|
||||
|
||||
// 合并配置,保留非零值
|
||||
func mergeConfig(defaultCfg SessionConfig, customCfg SessionConfig) SessionConfig {
|
||||
result := defaultCfg
|
||||
|
||||
if customCfg.AccessTokenDuration != 0 {
|
||||
result.AccessTokenDuration = customCfg.AccessTokenDuration
|
||||
}
|
||||
|
||||
if customCfg.RefreshTokenDuration != 0 {
|
||||
result.RefreshTokenDuration = customCfg.RefreshTokenDuration
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region AuthContext
|
||||
|
||||
// AuthContext 定义认证信息
|
||||
type AuthContext struct {
|
||||
Payload Payload
|
||||
Permissions map[string]struct{}
|
||||
Metadata map[string]interface{}
|
||||
}
|
||||
|
||||
// Payload 定义负载信息
|
||||
type Payload struct {
|
||||
Type PayloadType
|
||||
Id int32
|
||||
}
|
||||
|
||||
// PayloadType 定义负载类型
|
||||
type PayloadType int
|
||||
|
||||
const (
|
||||
// PayloadUser 用户类型
|
||||
PayloadUser PayloadType = iota
|
||||
// PayloadAdmin 管理员类型
|
||||
PayloadAdmin
|
||||
// PayloadClientPublic 公共客户端类型
|
||||
PayloadClientPublic
|
||||
// PayloadClientConfidential 机密客户端类型
|
||||
PayloadClientConfidential
|
||||
)
|
||||
|
||||
// AnyPermission 检查认证是否包含指定权限
|
||||
func (a *AuthContext) AnyPermission(requiredPermission ...string) bool {
|
||||
if a == nil || a.Permissions == nil {
|
||||
return false
|
||||
}
|
||||
for _, permission := range requiredPermission {
|
||||
if _, ok := a.Permissions[permission]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
type RefreshData struct {
|
||||
AuthContext AuthContext
|
||||
AccessToken string
|
||||
}
|
||||
|
||||
// TokenDetails 存储令牌详细信息
|
||||
type TokenDetails struct {
|
||||
// 访问令牌
|
||||
AccessToken string
|
||||
// 刷新令牌
|
||||
RefreshToken string
|
||||
// 访问令牌过期时间
|
||||
AccessTokenExpires time.Time
|
||||
// 刷新令牌过期时间
|
||||
RefreshTokenExpires time.Time
|
||||
// 认证信息
|
||||
Auth AuthContext
|
||||
}
|
||||
124
web/services/verifier.go
Normal file
124
web/services/verifier.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"platform/init/rds"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var Verifier = &verifierService{}
|
||||
|
||||
type verifierService struct {
|
||||
}
|
||||
|
||||
type VerifierServiceError string
|
||||
|
||||
func (e VerifierServiceError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVerifierServiceInvalid = VerifierServiceError("验证码错误")
|
||||
)
|
||||
|
||||
type VerifierServiceSendLimitErr int
|
||||
|
||||
func (e VerifierServiceSendLimitErr) Error() string {
|
||||
return "发送频率过快"
|
||||
}
|
||||
|
||||
type VerifierSmsPurpose int
|
||||
|
||||
const (
|
||||
Login VerifierSmsPurpose = iota
|
||||
)
|
||||
|
||||
func smsKey(phone string, purpose VerifierSmsPurpose) string {
|
||||
return fmt.Sprintf("verify:sms:%d:%s", purpose, phone)
|
||||
}
|
||||
|
||||
func (s *verifierService) SendSms(ctx context.Context, phone string, purpose VerifierSmsPurpose) error {
|
||||
key := smsKey(phone, purpose)
|
||||
keyLock := key + ":lock"
|
||||
|
||||
// 生成验证码
|
||||
code := rand.Intn(900000) + 100000 // 6-digit code between 100000-999999
|
||||
|
||||
// 检查发送频率,1 分钟内只能发送一次
|
||||
err := rds.Client.Watch(ctx, func(tx *redis.Tx) error {
|
||||
result, err := tx.TTL(ctx, keyLock).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if result > 0 {
|
||||
return VerifierServiceSendLimitErr(result.Seconds())
|
||||
}
|
||||
if result != -2 {
|
||||
return VerifierServiceError("验证码检查异常")
|
||||
}
|
||||
|
||||
pipe := rds.Client.Pipeline()
|
||||
pipe.Set(ctx, key, code, 10*time.Minute)
|
||||
pipe.Set(ctx, keyLock, "", 1*time.Minute)
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}, keyLock)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: 发送短信验证码
|
||||
slog.Debug("发送验证码", slog.String("phone", phone), slog.String("code", strconv.Itoa(code)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *verifierService) VerifySms(ctx context.Context, phone, code string) (bool, error) {
|
||||
key := smsKey(phone, Login)
|
||||
keyLock := key + ":lock"
|
||||
|
||||
err := rds.Client.Watch(ctx, func(tx *redis.Tx) error {
|
||||
|
||||
// 检查验证码
|
||||
val, err := rds.Client.Get(ctx, key).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
slog.Error("验证码获取失败", slog.Any("err", err))
|
||||
return err
|
||||
}
|
||||
|
||||
if val != code {
|
||||
return ErrVerifierServiceInvalid
|
||||
}
|
||||
|
||||
// 删除验证码
|
||||
_, err = tx.Pipelined(ctx, func(pipe redis.Pipeliner) error {
|
||||
pipe.Del(ctx, key)
|
||||
pipe.Del(ctx, keyLock)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("验证码删除失败", slog.Any("err", err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}, key)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrVerifierServiceInvalid) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
Reference in New Issue
Block a user