package services

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/redis/go-redis/v9"

	"platform/pkg/rds"
)

// region SessionService

// Session is the package-level session service instance.
var Session SessionServiceInter = &sessionService{}

// SessionServiceInter describes the operations available on sessions that
// are stored in Redis as an access-token / refresh-token pair.
type SessionServiceInter interface {
	// Find returns the auth context stored for the given access token.
	Find(ctx context.Context, token string) (*AuthContext, error)
	// Create creates a new session and returns its tokens.
	Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error)
	// Refresh exchanges a refresh token for a new token pair.
	Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error)
	// Remove deletes the session identified by the two tokens.
	Remove(ctx context.Context, accessToken, refreshToken string) error
}

// SessionServiceError is the string-based error type returned by the
// session service for its own failure modes.
type SessionServiceError string

// Error implements the error interface.
func (e SessionServiceError) Error() string {
	return string(e)
}

var (
	// ErrInvalidToken is returned when a token has no matching key in Redis.
	ErrInvalidToken = SessionServiceError("invalid_token")
)

// sessionService is the Redis-backed implementation of SessionServiceInter.
type sessionService struct{}

// Find returns the auth context stored for the given access token.
//
// It returns ErrInvalidToken when no key exists for the token; any other
// Redis error or a JSON decoding error is returned unchanged.
func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, error) {
	// Load the serialized auth data.
	authJSON, err := rds.Client.Get(ctx, accessKey(token)).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, ErrInvalidToken
		}
		return nil, err
	}

	// Deserialize.
	auth := new(AuthContext)
	if err := json.Unmarshal([]byte(authJSON), auth); err != nil {
		return nil, err
	}
	return auth, nil
}

// Create creates a new session for auth and returns the issued tokens.
//
// Only the first element of config is considered; its non-zero fields
// override DefaultSessionConfig. The access-token key and the refresh-token
// key are written in a single MULTI/EXEC transaction.
func (s *sessionService) Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error) {
	// Resolve the optional configuration.
	cfg := DefaultSessionConfig
	if len(config) > 0 {
		cfg = mergeConfig(DefaultSessionConfig, config[0])
	}

	// Generate the token pair.
	accessToken := genToken()
	refreshToken := genToken()

	// Serialize the auth data stored under the access token.
	authData, err := json.Marshal(auth)
	if err != nil {
		return nil, err
	}

	// Serialize the data stored under the refresh token.
	refreshData, err := json.Marshal(RefreshData{
		AuthContext: auth,
		AccessToken: accessToken,
	})
	if err != nil {
		return nil, err
	}

	// Persist both keys to Redis transactionally.
	pipe := rds.Client.TxPipeline()
	pipe.Set(ctx, accessKey(accessToken), authData, cfg.AccessTokenDuration)
	pipe.Set(ctx, refreshKey(refreshToken), refreshData, cfg.RefreshTokenDuration)
	if _, err := pipe.Exec(ctx); err != nil {
		return nil, err
	}

	// Read the clock once so both expiries share the same base instant.
	now := time.Now()
	return &TokenDetails{
		AccessToken:         accessToken,
		AccessTokenExpires:  now.Add(cfg.AccessTokenDuration),
		RefreshToken:        refreshToken,
		RefreshTokenExpires: now.Add(cfg.RefreshTokenDuration),
		Auth:                auth,
	}, nil
}

// Refresh exchanges refreshToken for a new access/refresh token pair.
//
// The old access-token and refresh-token keys are deleted and the new ones
// are written in one MULTI/EXEC transaction while the refresh key is
// WATCHed, so a refresh token can be consumed at most once even under
// concurrent calls. Every error is wrapped; ErrInvalidToken (detectable with
// errors.Is) is reported when the refresh token does not exist or was
// changed by another client during the call.
func (s *sessionService) Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error) {
	// Resolve the optional configuration.
	cfg := DefaultSessionConfig
	if len(config) > 0 {
		cfg = mergeConfig(DefaultSessionConfig, config[0])
	}

	rKey := refreshKey(refreshToken)
	var tokenDetails *TokenDetails

	err := rds.Client.Watch(ctx, func(tx *redis.Tx) error {
		// Load the data stored under the refresh token.
		refreshJSON, err := tx.Get(ctx, rKey).Result()
		if err != nil {
			if errors.Is(err, redis.Nil) {
				return ErrInvalidToken
			}
			return err
		}

		// Decode it.
		current := new(RefreshData)
		if err := json.Unmarshal([]byte(refreshJSON), current); err != nil {
			return err
		}

		// Generate and serialize the replacement token pair.
		newAccessToken := genToken()
		newRefreshToken := genToken()

		authData, err := json.Marshal(current.AuthContext)
		if err != nil {
			return err
		}
		newRefreshData, err := json.Marshal(RefreshData{
			AuthContext: current.AuthContext,
			AccessToken: newAccessToken,
		})
		if err != nil {
			return err
		}

		// A MULTI/EXEC pipeline is required here: a plain tx.Pipeline() is
		// not aborted when the watched key changes, which would allow the
		// same refresh token to be redeemed twice.
		pipe := tx.TxPipeline()
		pipe.Del(ctx, accessKey(current.AccessToken))
		pipe.Del(ctx, rKey)
		pipe.Set(ctx, accessKey(newAccessToken), authData, cfg.AccessTokenDuration)
		pipe.Set(ctx, refreshKey(newRefreshToken), newRefreshData, cfg.RefreshTokenDuration)
		if _, err := pipe.Exec(ctx); err != nil {
			return err
		}

		// Read the clock once so both expiries share the same base instant.
		now := time.Now()
		tokenDetails = &TokenDetails{
			AccessToken:         newAccessToken,
			RefreshToken:        newRefreshToken,
			AccessTokenExpires:  now.Add(cfg.AccessTokenDuration),
			RefreshTokenExpires: now.Add(cfg.RefreshTokenDuration),
			Auth:                current.AuthContext,
		}
		return nil
	}, rKey)
	if err != nil {
		// The watched refresh key was modified before EXEC, i.e. the token
		// was consumed or removed concurrently and is no longer valid.
		if errors.Is(err, redis.TxFailedErr) {
			err = ErrInvalidToken
		}
		return nil, fmt.Errorf("刷新令牌失败: %w", err)
	}

	return tokenDetails, nil
}

// Remove deletes the access-token key and the refresh-token key.
//
// It returns the error reported by Redis, if any. Deleting keys that do not
// exist is not an error.
func (s *sessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
	return rds.Client.Del(ctx, accessKey(accessToken), refreshKey(refreshToken)).Err()
}

// genToken generates a new token (a UUID string).
func genToken() string {
	return uuid.NewString()
}

// accessKey returns the Redis key for an access token: "session:<token>".
func accessKey(token string) string {
	return "session:" + token
}

// refreshKey returns the Redis key for a refresh token:
// "session:refresh:<token>".
func refreshKey(token string) string {
	return "session:refresh:" + token
}

// endregion

// region SessionConfig

// SessionConfig defines the configuration options for session management.
type SessionConfig struct {
	// AccessTokenDuration is the TTL applied to the access-token key.
	AccessTokenDuration time.Duration
	// RefreshTokenDuration is the TTL applied to the refresh-token key.
	RefreshTokenDuration time.Duration
}

// DefaultSessionConfig is the default session configuration.
var DefaultSessionConfig = SessionConfig{
	AccessTokenDuration:  2 * time.Hour,
	RefreshTokenDuration: 7 * 24 * time.Hour,
}

// mergeConfig returns defaultCfg with every non-zero field of customCfg
// applied on top of it.
func mergeConfig(defaultCfg SessionConfig, customCfg SessionConfig) SessionConfig {
	result := defaultCfg
	if customCfg.AccessTokenDuration != 0 {
		result.AccessTokenDuration = customCfg.AccessTokenDuration
	}
	if customCfg.RefreshTokenDuration != 0 {
		result.RefreshTokenDuration = customCfg.RefreshTokenDuration
	}
	return result
}

// endregion

// region AuthContext

// AuthContext holds the authentication information stored for a session.
type AuthContext struct {
	Payload     Payload
	Permissions map[string]struct{}
	Metadata    map[string]interface{}
}

// Payload identifies the subject of an AuthContext.
type Payload struct {
	Type PayloadType
	Id   int32
}

// PayloadType enumerates the kinds of payload.
type PayloadType int

const (
	// PayloadUser is the user type.
	PayloadUser PayloadType = iota
	// PayloadAdmin is the administrator type.
	PayloadAdmin
	// PayloadClientPublic is the public client type.
	PayloadClientPublic
	// PayloadClientConfidential is the confidential client type.
	PayloadClientConfidential
)

// AnyPermission reports whether the auth context holds at least one of the
// given permissions. It returns false for a nil receiver, for an empty
// permission set, and when no permissions are passed.
func (a *AuthContext) AnyPermission(requiredPermission ...string) bool {
	if a == nil {
		return false
	}
	for _, permission := range requiredPermission {
		// Lookups on a nil map are safe and simply report "not found".
		if _, ok := a.Permissions[permission]; ok {
			return true
		}
	}
	return false
}

// endregion

// RefreshData is the value stored under a refresh-token key: the auth
// context plus the access token that was issued together with it.
type RefreshData struct {
	AuthContext AuthContext
	AccessToken string
}

// TokenDetails holds the tokens issued for a session.
type TokenDetails struct {
	// AccessToken is the access token.
	AccessToken string
	// RefreshToken is the refresh token.
	RefreshToken string
	// AccessTokenExpires is the expiry time of the access token.
	AccessTokenExpires time.Time
	// RefreshTokenExpires is the expiry time of the refresh token.
	RefreshTokenExpires time.Time
	// Auth is the authentication information of the session.
	Auth AuthContext
}