Files
platform/web/services/session.go

278 lines
6.6 KiB
Go

package services
import (
"context"
"encoding/json"
"errors"
"fmt"
"platform/pkg/env"
"platform/pkg/rds"
"time"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
// region SessionService
// Session is the package-level session service. It is typed as
// SessionServiceInter and backed by a *sessionService.
var Session SessionServiceInter = &sessionService{}
// SessionServiceInter lists the session operations exposed through the
// package-level Session variable.
type SessionServiceInter interface {
// Find returns the session information associated with an access token.
Find(ctx context.Context, token string) (*AuthContext, error)
// Create creates a new session.
Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error)
// Refresh refreshes a session using its refresh token.
Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error)
// Remove deletes a session.
Remove(ctx context.Context, accessToken, refreshToken string) error
}
// SessionServiceError is a string-based error type used for session service
// failures; the string itself is the error message.
type SessionServiceError string
// Error returns the underlying string, which makes SessionServiceError
// satisfy the error interface.
func (e SessionServiceError) Error() string {
return string(e)
}
var (
// ErrInvalidToken reports that the supplied token does not identify a
// stored session. Its message is "invalid_token".
ErrInvalidToken = SessionServiceError("invalid_token")
)
// sessionService implements SessionServiceInter on top of rds.Client. It has
// no fields of its own.
type sessionService struct{}
// Find returns the AuthContext stored for the given access token.
//
// The value is read from Redis under accessKey(token) and decoded from JSON.
// ErrInvalidToken is returned when the key does not exist (redis.Nil); any
// other Redis error, or a JSON decoding error, is returned unchanged.
func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, error) {
	// Load the serialized authentication data.
	raw, err := rds.Client.Get(ctx, accessKey(token)).Result()
	switch {
	case errors.Is(err, redis.Nil):
		return nil, ErrInvalidToken
	case err != nil:
		return nil, err
	}

	// Decode it into a fresh AuthContext.
	var auth AuthContext
	if err = json.Unmarshal([]byte(raw), &auth); err != nil {
		return nil, err
	}
	return &auth, nil
}
// Create starts a new session for auth and stores it in Redis.
//
// A new access token is always written under accessKey with a TTL of
// env.SessionAccessExpire seconds. A refresh token is generated and written
// under refreshKey (TTL of env.SessionRefreshExpire seconds) only when
// remember is true. All writes are queued on a single TxPipeline.
//
// When remember is false no refresh entry exists, so the returned
// TokenDetails leaves RefreshToken empty and RefreshTokenExpires at its zero
// value instead of handing out a token that could never be refreshed.
//
// JSON encoding errors and Redis errors are returned unchanged.
func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) {
	now := time.Now()
	accessExpire := time.Duration(env.SessionAccessExpire) * time.Second
	refreshExpire := time.Duration(env.SessionRefreshExpire) * time.Second

	// Access token and its serialized authentication data.
	accessToken := genToken()
	authData, err := json.Marshal(auth)
	if err != nil {
		return nil, err
	}

	details := &TokenDetails{
		AccessToken:        accessToken,
		AccessTokenExpires: now.Add(accessExpire),
		Auth:               auth,
	}

	// The refresh token only exists for "remember me" sessions. Its stored
	// value records the access token so Refresh can revoke it later.
	var refreshData []byte
	if remember {
		details.RefreshToken = genToken()
		details.RefreshTokenExpires = now.Add(refreshExpire)
		refreshData, err = json.Marshal(RefreshData{
			AuthContext: auth,
			AccessToken: accessToken,
		})
		if err != nil {
			return nil, err
		}
	}

	// Persist everything through one transactional pipeline.
	pipe := rds.Client.TxPipeline()
	pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
	if remember {
		pipe.Set(ctx, refreshKey(details.RefreshToken), refreshData, refreshExpire)
	}
	if _, err = pipe.Exec(ctx); err != nil {
		return nil, err
	}
	return details, nil
}
// Refresh exchanges a refresh token for a new access/refresh token pair.
//
// The refresh entry is read under WATCH. A new access token and a new refresh
// token are then written, and the old access token and old refresh token are
// deleted, all in one MULTI/EXEC transaction. The new tokens use TTLs of
// env.SessionAccessExpire and env.SessionRefreshExpire seconds.
//
// The returned error is always wrapped with the "刷新令牌失败" prefix; use
// errors.Is to inspect it. ErrInvalidToken is reported when the refresh key
// does not exist, or when the key was modified by someone else between the
// read and the EXEC (for example a concurrent Refresh or Remove using the same
// token). Other Redis and JSON errors are passed through inside the wrap.
func (s *sessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
	now := time.Now()
	rKey := refreshKey(refreshToken)
	var tokenDetails *TokenDetails

	err := rds.Client.Watch(ctx, func(tx *redis.Tx) error {
		// Load the data stored for the refresh token.
		refreshJSON, err := tx.Get(ctx, rKey).Result()
		if err != nil {
			if errors.Is(err, redis.Nil) {
				return ErrInvalidToken
			}
			return err
		}

		// Decode it.
		refreshData := new(RefreshData)
		if err := json.Unmarshal([]byte(refreshJSON), refreshData); err != nil {
			return err
		}

		// Build the replacement token pair and its serialized values.
		newAccessToken := genToken()
		newRefreshToken := genToken()
		authData, err := json.Marshal(refreshData.AuthContext)
		if err != nil {
			return err
		}
		newRefreshData, err := json.Marshal(RefreshData{
			AuthContext: refreshData.AuthContext,
			AccessToken: newAccessToken,
		})
		if err != nil {
			return err
		}

		accessExpire := time.Duration(env.SessionAccessExpire) * time.Second
		refreshExpire := time.Duration(env.SessionRefreshExpire) * time.Second

		// TxPipeline wraps the queued commands in MULTI/EXEC, so EXEC is
		// rejected if the watched refresh key changed since it was read. A
		// plain Pipeline would send the commands outside a transaction and
		// the WATCH would not guard anything.
		pipeline := tx.TxPipeline()
		// Store the new tokens.
		pipeline.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
		pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
		// Drop the old tokens.
		pipeline.Del(ctx, accessKey(refreshData.AccessToken))
		pipeline.Del(ctx, rKey)
		if _, err := pipeline.Exec(ctx); err != nil {
			// Losing the optimistic lock means the refresh token was consumed
			// or removed concurrently, so it is no longer valid.
			if errors.Is(err, redis.TxFailedErr) {
				return ErrInvalidToken
			}
			return err
		}

		tokenDetails = &TokenDetails{
			AccessToken:         newAccessToken,
			RefreshToken:        newRefreshToken,
			AccessTokenExpires:  now.Add(accessExpire),
			RefreshTokenExpires: now.Add(refreshExpire),
			Auth:                refreshData.AuthContext,
		}
		return nil
	}, rKey)
	if err != nil {
		return nil, fmt.Errorf("刷新令牌失败: %w", err)
	}
	return tokenDetails, nil
}
// Remove deletes a session by removing both the access-token key and the
// refresh-token key from Redis in a single DEL command.
//
// It returns the error reported by the DEL command, if any, rather than
// discarding it.
func (s *sessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
	return rds.Client.Del(ctx, accessKey(accessToken), refreshKey(refreshToken)).Err()
}
// genToken generates a new token: the string form of a freshly created UUID.
func genToken() string {
	id := uuid.New()
	return id.String()
}
// accessKey returns the Redis key for an access token, in the form
// "session:<token>".
func accessKey(token string) string {
	const prefix = "session:"
	return fmt.Sprintf("%s%s", prefix, token)
}
// refreshKey returns the Redis key for a refresh token, in the form
// "session:refresh:<token>".
func refreshKey(token string) string {
	const prefix = "session:refresh:"
	return fmt.Sprintf("%s%s", prefix, token)
}
// endregion
// region AuthContext
// AuthContext defines the authentication information of a session. It is the
// value serialized to JSON and stored under the access-token key.
type AuthContext struct {
// Payload describes the authenticated subject.
Payload Payload `json:"payload"`
// Agent is additional agent information attached to the session.
// NOTE(review): presumably the client the session was created from — confirm.
Agent Agent `json:"agent,omitempty"`
// Permissions is the set of permission names; AnyPermission checks
// membership in this map.
Permissions map[string]struct{} `json:"permissions,omitempty"`
// Metadata holds arbitrary extra key/value data.
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// Payload defines the payload information carried in an AuthContext.
// Every field is tagged omitempty, so zero values are left out of the JSON.
type Payload struct {
// Id is the numeric identifier of the subject.
// NOTE(review): presumably a user or client ID depending on Type — confirm.
Id int32 `json:"id,omitempty"`
// Type is the kind of subject; see PayloadType.
Type PayloadType `json:"type,omitempty"`
// Name is the display name.
Name string `json:"name,omitempty"`
// Avatar is the avatar reference.
// NOTE(review): looks like a URL or path — confirm the expected format.
Avatar string `json:"avatar,omitempty"`
}
// PayloadType defines the kind of a Payload.
type PayloadType int
// Values are assigned with iota starting at 0, so PayloadUser is the zero
// value of PayloadType.
const (
// PayloadUser is the user type.
PayloadUser PayloadType = iota
// PayloadAdmin is the administrator type.
PayloadAdmin
// PayloadClientPublic is the public client type.
PayloadClientPublic
// PayloadClientConfidential is the confidential client type.
PayloadClientConfidential
)
// Agent is the agent information embedded in an AuthContext.
// NOTE(review): the meaning of the fields is not visible here — Id is
// presumably an agent identifier and Addr a network address; confirm.
type Agent struct {
Id int32 `json:"id,omitempty"`
Addr string `json:"addr,omitempty"`
}
// AnyPermission reports whether the AuthContext holds at least one of the
// given permissions.
//
// It returns false for a nil receiver, when no permissions are set, or when
// requiredPermission is empty.
func (a *AuthContext) AnyPermission(requiredPermission ...string) bool {
	if a == nil || len(a.Permissions) == 0 {
		return false
	}
	for i := range requiredPermission {
		if _, granted := a.Permissions[requiredPermission[i]]; granted {
			return true
		}
	}
	return false
}
// endregion
// RefreshData is the value serialized to JSON and stored under a refresh-token
// key. The fields have no JSON tags, so they are encoded under their Go names.
type RefreshData struct {
// AuthContext is the authentication information of the session.
AuthContext AuthContext
// AccessToken is the access token issued together with the refresh token;
// Refresh deletes its key when the tokens are replaced.
AccessToken string
}
// TokenDetails holds the details of the tokens issued by Create and Refresh.
type TokenDetails struct {
// AccessToken is the access token.
AccessToken string
// RefreshToken is the refresh token.
RefreshToken string
// AccessTokenExpires is the expiry time of the access token.
AccessTokenExpires time.Time
// RefreshTokenExpires is the expiry time of the refresh token.
RefreshTokenExpires time.Time
// Auth is the authentication information of the session.
Auth AuthContext
}