227 lines
5.4 KiB
Go
227 lines
5.4 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"platform/pkg/env"
|
|
"platform/web/auth"
|
|
g "platform/web/globals"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// region SessionService
|
|
|
|
// Session is the package-level SessionServiceInter implementation. Its methods
// store and read session data through the shared Redis client (g.Redis).
var Session SessionServiceInter = &sessionService{}
|
|
|
|
// SessionServiceInter manages token-based login sessions.
type SessionServiceInter interface {
	// Find returns the session (auth context) associated with an access token.
	Find(ctx context.Context, token string) (*auth.Context, error)

	// Create starts a new session for authCtx and returns the issued tokens.
	Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error)

	// Refresh exchanges a refresh token for a new set of session tokens.
	Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error)

	// Remove deletes the session identified by the given access/refresh tokens.
	Remove(ctx context.Context, accessToken, refreshToken string) error
}
|
|
|
|
// SessionServiceError is the string-backed error type behind the session
// service's sentinel errors. Being a comparable value type, two values with the
// same text compare equal, so errors.Is matches them by equality.
type SessionServiceError string

// Error implements the error interface by returning the underlying text.
func (se SessionServiceError) Error() string { return string(se) }

// ErrInvalidToken is returned when no session data exists for a token.
var ErrInvalidToken = SessionServiceError("invalid_token")
|
|
|
|
// sessionService is the Redis-backed implementation of SessionServiceInter.
// It holds no state of its own; each of its methods goes through g.Redis.
type sessionService struct{}
|
|
|
|
// Find resolves an access token to the auth.Context stored for it.
//
// It returns ErrInvalidToken when the token is malformed or when no session
// exists under it (for example, because the key's TTL elapsed). Redis and
// JSON decoding errors are returned unchanged.
func (s *sessionService) Find(ctx context.Context, token string) (*auth.Context, error) {
	// accessKey("refresh:<rt>") is the very same Redis key as refreshKey("<rt>").
	// Without this guard, a caller could present "refresh:<refresh token>" as an
	// access token and have the stored RefreshData JSON decoded as an
	// auth.Context. json.Unmarshal ignores the unknown fields, so the decode
	// "succeeds". Issued tokens (genToken, UUID strings) never contain ':'.
	for i := 0; i < len(token); i++ {
		if token[i] == ':' {
			return nil, ErrInvalidToken
		}
	}

	// Read the serialized auth data.
	authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, ErrInvalidToken
		}
		return nil, err
	}

	authCtx := new(auth.Context)
	if err := json.Unmarshal([]byte(authJSON), authCtx); err != nil {
		return nil, err
	}
	return authCtx, nil
}
|
|
|
|
// Create starts a new session for authCtx and returns the issued tokens.
//
// The access token is always stored under accessKey with a TTL of
// env.SessionAccessExpire seconds. A refresh token is issued and stored under
// refreshKey (TTL env.SessionRefreshExpire seconds) only when remember is true.
// Otherwise RefreshToken is left empty and RefreshTokenExpires is the zero
// time, so callers are never handed a refresh token that was not persisted and
// that Refresh would therefore always reject. All writes go through a single
// transactional pipeline.
func (s *sessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) {
	now := time.Now()
	accessExpire := time.Duration(env.SessionAccessExpire) * time.Second
	refreshExpire := time.Duration(env.SessionRefreshExpire) * time.Second

	authData, err := json.Marshal(authCtx)
	if err != nil {
		return nil, err
	}

	details := &TokenDetails{
		AccessToken:        genToken(),
		AccessTokenExpires: now.Add(accessExpire),
		Auth:               authCtx,
	}

	var refreshData []byte
	if remember {
		details.RefreshToken = genToken()
		details.RefreshTokenExpires = now.Add(refreshExpire)

		// The refresh record remembers its paired access token so that Refresh
		// can delete that access token when it rotates the session.
		refreshData, err = json.Marshal(RefreshData{
			AuthContext: authCtx,
			AccessToken: details.AccessToken,
		})
		if err != nil {
			return nil, err
		}
	}

	pipe := g.Redis.TxPipeline()
	pipe.Set(ctx, accessKey(details.AccessToken), authData, accessExpire)
	if remember {
		pipe.Set(ctx, refreshKey(details.RefreshToken), refreshData, refreshExpire)
	}
	if _, err := pipe.Exec(ctx); err != nil {
		return nil, err
	}
	return details, nil
}
|
|
|
|
// Refresh rotates a session. It consumes refreshToken and issues a brand-new
// access/refresh token pair carrying the same auth.Context.
//
// The refresh key is WATCHed. All writes happen in one MULTI/EXEC transaction:
// storing the new pair, deleting the old access token, and deleting the old
// refresh token. As a result, a refresh token can be redeemed at most once.
// The returned error wraps ErrInvalidToken when the refresh token does not
// exist, or when a concurrent Refresh consumed it first.
func (s *sessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
	now := time.Now()
	accessExpire := time.Duration(env.SessionAccessExpire) * time.Second
	refreshExpire := time.Duration(env.SessionRefreshExpire) * time.Second

	rKey := refreshKey(refreshToken)
	var tokenDetails *TokenDetails

	err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
		// Load the refresh record; a missing key means the token is unknown.
		refreshJSON, err := tx.Get(ctx, rKey).Result()
		if err != nil {
			if errors.Is(err, redis.Nil) {
				return ErrInvalidToken
			}
			return err
		}

		refreshData := new(RefreshData)
		if err := json.Unmarshal([]byte(refreshJSON), refreshData); err != nil {
			return err
		}

		newAccessToken := genToken()
		newRefreshToken := genToken()

		authData, err := json.Marshal(refreshData.AuthContext)
		if err != nil {
			return err
		}
		newRefreshData, err := json.Marshal(RefreshData{
			AuthContext: refreshData.AuthContext,
			AccessToken: newAccessToken,
		})
		if err != nil {
			return err
		}

		// TxPipelined wraps the queued commands in MULTI/EXEC, which is what
		// makes the WATCH on rKey effective. A plain tx.Pipeline() sends the
		// commands without MULTI/EXEC, so two concurrent calls could both redeem
		// the same refresh token.
		_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
			// Store the new pair.
			pipe.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
			pipe.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
			// Revoke the old pair.
			pipe.Del(ctx, accessKey(refreshData.AccessToken))
			pipe.Del(ctx, rKey)
			return nil
		})
		if errors.Is(err, redis.TxFailedErr) {
			// rKey changed after WATCH: another caller already rotated this token.
			return ErrInvalidToken
		}
		if err != nil {
			return err
		}

		tokenDetails = &TokenDetails{
			AccessToken:         newAccessToken,
			RefreshToken:        newRefreshToken,
			AccessTokenExpires:  now.Add(accessExpire),
			RefreshTokenExpires: now.Add(refreshExpire),
			Auth:                refreshData.AuthContext,
		}
		return nil
	}, rKey)
	if err != nil {
		return nil, fmt.Errorf("刷新令牌失败: %w", err)
	}

	return tokenDetails, nil
}
|
|
|
|
// Remove deletes a session's access-token key and refresh-token key in a
// single DEL. Keys that no longer exist are simply skipped by Redis, so removing
// an already-expired session succeeds. Any Redis error is returned to the
// caller instead of being silently discarded.
func (s *sessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
	return g.Redis.Del(ctx, accessKey(accessToken), refreshKey(refreshToken)).Err()
}
|
|
|
|
// 生成一个新的令牌
|
|
func genToken() string {
|
|
return uuid.NewString()
|
|
}
|
|
|
|
// accessKey returns the Redis key that holds an access token's session data,
// in the form "session:<token>".
func accessKey(token string) string {
	// Plain concatenation gives the same result as fmt.Sprintf without its
	// reflection and argument boxing. This runs on every session lookup.
	return "session:" + token
}
|
|
|
|
// refreshKey returns the Redis key that holds a refresh token's RefreshData,
// in the form "session:refresh:<token>".
func refreshKey(token string) string {
	return "session:refresh:" + token
}
|
|
|
|
// endregion
|
|
|
|
// RefreshData is the record stored (as JSON) under a refresh token's key. It
// holds the session's auth context plus the access token issued alongside the
// refresh token, so that access token can be deleted when the session is
// refreshed.
//
// The fields have no json tags, so they serialize under their Go names
// ("AuthContext", "AccessToken"). Adding tags or renaming fields would change
// the stored format and could break decoding of records already in Redis.
type RefreshData struct {
	// AuthContext is the authenticated identity the session belongs to.
	AuthContext auth.Context
	// AccessToken is the access token currently paired with this refresh token.
	AccessToken string
}
|
|
|
|
// TokenDetails holds the tokens issued for a session and their expiry times.
type TokenDetails struct {
	// AccessToken is the access token.
	AccessToken string
	// RefreshToken is the refresh token.
	RefreshToken string
	// AccessTokenExpires is the time at which the access token expires.
	AccessTokenExpires time.Time
	// RefreshTokenExpires is the time at which the refresh token expires.
	RefreshTokenExpires time.Time
	// Auth is the authentication information bound to the session.
	Auth auth.Context
}
|