package services

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/google/uuid"
	"github.com/redis/go-redis/v9"
	"platform/pkg/env"
	"platform/web/auth"
	g "platform/web/globals"
)

// region SessionService

// Session is the package-level session service instance.
var Session SessionServiceInter = &sessionService{}

// SessionServiceInter describes the operations available on sessions that
// are stored in Redis and addressed by access/refresh tokens.
type SessionServiceInter interface {
	// Find returns the auth context stored for the given access token.
	Find(ctx context.Context, token string) (*auth.Context, error)
	// Create creates a new session.
	Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error)
	// Refresh exchanges a refresh token for a new pair of tokens.
	Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error)
	// Remove deletes a session.
	Remove(ctx context.Context, accessToken, refreshToken string) error
}

// SessionServiceError is the error type for sentinel errors of this service.
type SessionServiceError string

// Error implements the error interface.
func (e SessionServiceError) Error() string { return string(e) }

var (
	// ErrInvalidToken is returned when no data is stored under the token.
	ErrInvalidToken = SessionServiceError("invalid_token")
)

type sessionService struct{}

// Find returns the auth context stored for the given access token.
//
// It returns ErrInvalidToken when Redis holds no value under the access key;
// any other Redis error or a JSON decoding error is returned as-is.
func (s *sessionService) Find(ctx context.Context, token string) (*auth.Context, error) {
	// Load the serialized auth data.
	authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return nil, ErrInvalidToken
		}
		return nil, err
	}

	// Deserialize.
	authCtx := new(auth.Context)
	if err := json.Unmarshal([]byte(authJSON), authCtx); err != nil {
		return nil, err
	}
	return authCtx, nil
}

// Create creates a new session for authCtx and returns the generated tokens.
//
// The access token is always stored. The refresh token is stored only when
// remember is true.
//
// NOTE(review): when remember is false the returned TokenDetails still
// carries a RefreshToken and RefreshTokenExpires even though nothing was
// stored under that token, so a later Refresh with it yields ErrInvalidToken.
// Confirm that callers expect this.
func (s *sessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) {
	now := time.Now()

	// Generate the token pair.
	accessToken := genToken()
	refreshToken := genToken()

	// Serialize the auth data.
	authData, err := json.Marshal(authCtx)
	if err != nil {
		return nil, err
	}

	// Serialize the refresh-token data.
	refreshData, err := json.Marshal(RefreshData{
		AuthContext: authCtx,
		AccessToken: accessToken,
	})
	if err != nil {
		return nil, err
	}

	// The configured lifetimes are expressed in seconds.
	accessExpire := time.Duration(env.SessionAccessExpire) * time.Second
	refreshExpire := time.Duration(env.SessionRefreshExpire) * time.Second

	// Store everything in a single Redis transaction.
	pipe := g.Redis.TxPipeline()
	pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
	if remember {
		pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
	}
	if _, err := pipe.Exec(ctx); err != nil {
		return nil, err
	}

	return &TokenDetails{
		AccessToken:         accessToken,
		AccessTokenExpires:  now.Add(accessExpire),
		RefreshToken:        refreshToken,
		RefreshTokenExpires: now.Add(refreshExpire),
		Auth:                authCtx,
	}, nil
}

// Refresh exchanges refreshToken for a new access/refresh token pair and
// deletes the old pair.
//
// The refresh key is WATCHed and all writes are queued in a MULTI/EXEC
// transaction, so if the key is modified between the read and the write
// (for example by a concurrent Refresh or Remove using the same token) the
// transaction is aborted and redis.TxFailedErr is returned instead of a
// second token pair being issued.
//
// Every error is wrapped; use errors.Is to test for ErrInvalidToken (no data
// stored under refreshToken) or redis.TxFailedErr.
func (s *sessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
	now := time.Now()
	rKey := refreshKey(refreshToken)

	// The configured lifetimes are expressed in seconds.
	accessExpire := time.Duration(env.SessionAccessExpire) * time.Second
	refreshExpire := time.Duration(env.SessionRefreshExpire) * time.Second

	var tokenDetails *TokenDetails

	err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
		// Read the data stored for the refresh token first.
		refreshJSON, err := tx.Get(ctx, rKey).Result()
		if err != nil {
			if errors.Is(err, redis.Nil) {
				return ErrInvalidToken
			}
			return err
		}

		// Decode the refresh-token data.
		refreshData := new(RefreshData)
		if err := json.Unmarshal([]byte(refreshJSON), refreshData); err != nil {
			return err
		}

		// Generate the new tokens and their payloads.
		newAccessToken := genToken()
		newRefreshToken := genToken()

		authData, err := json.Marshal(refreshData.AuthContext)
		if err != nil {
			return err
		}
		newRefreshData, err := json.Marshal(RefreshData{
			AuthContext: refreshData.AuthContext,
			AccessToken: newAccessToken,
		})
		if err != nil {
			return err
		}

		// TxPipelined (not Pipelined) is required here: only a MULTI/EXEC
		// transaction is aborted when a WATCHed key has changed.
		_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
			// Store the new tokens.
			pipe.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
			pipe.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
			// Delete the old tokens.
			pipe.Del(ctx, accessKey(refreshData.AccessToken))
			pipe.Del(ctx, rKey)
			return nil
		})
		if err != nil {
			return err
		}

		tokenDetails = &TokenDetails{
			AccessToken:         newAccessToken,
			RefreshToken:        newRefreshToken,
			AccessTokenExpires:  now.Add(accessExpire),
			RefreshTokenExpires: now.Add(refreshExpire),
			Auth:                refreshData.AuthContext,
		}
		return nil
	}, rKey)
	if err != nil {
		return nil, fmt.Errorf("刷新令牌失败: %w", err)
	}
	return tokenDetails, nil
}

// Remove deletes the access-token and refresh-token entries of a session.
//
// It returns the error reported by the Redis DEL command, if any.
func (s *sessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
	return g.Redis.Del(ctx, accessKey(accessToken), refreshKey(refreshToken)).Err()
}

// genToken generates a new token as a UUID string.
func genToken() string {
	return uuid.NewString()
}

// accessKey builds the Redis key of an access token: "session:<token>".
func accessKey(token string) string {
	return fmt.Sprintf("session:%s", token)
}

// refreshKey builds the Redis key of a refresh token:
// "session:refresh:<token>".
func refreshKey(token string) string {
	return fmt.Sprintf("session:refresh:%s", token)
}

// endregion

// RefreshData is the payload stored in Redis under a refresh token.
type RefreshData struct {
	// AuthContext is the auth context the session was created with.
	AuthContext auth.Context
	// AccessToken is the access token issued together with the refresh
	// token; Refresh deletes it when the pair is rotated.
	AccessToken string
}

// TokenDetails holds the tokens of a session and their expiry times.
type TokenDetails struct {
	// AccessToken is the access token.
	AccessToken string
	// RefreshToken is the refresh token.
	RefreshToken string
	// AccessTokenExpires is the time at which the access token expires.
	AccessTokenExpires time.Time
	// RefreshTokenExpires is the time at which the refresh token expires.
	RefreshTokenExpires time.Time
	// Auth is the auth context of the session.
	Auth auth.Context
}