重构认证授权逻辑,集中到 auth 包中
This commit is contained in:
27
README.md
27
README.md
@@ -1,28 +1,29 @@
|
||||
## todo
|
||||
|
||||
- 长效业务接入
|
||||
- 页面 账户总览
|
||||
- 页面 提取记录
|
||||
- 页面 使用记录
|
||||
|
||||
- 公众号的到期提示
|
||||
- 支付回调处理
|
||||
- 保存 session 到数据库
|
||||
- 移除 PayloadType,使用 Grant_Type
|
||||
|
||||
### 下阶段
|
||||
### 长效业务
|
||||
|
||||
- 代理数据表的 secret 字段 aes 加密存储
|
||||
- 扩展 device 权限验证方式,提供一种方法区分内部和外部服务
|
||||
- 废弃 password 授权模式,迁移到 authorization code 授权模式
|
||||
- oauth token 验证授权范围
|
||||
- 实现白银节点的 warp 服务,用来去重与端口分配,保证测试与生产环境不会产生端口竞争
|
||||
- callback 结果直接由 api 端提供,不通过前端转发
|
||||
- debug:白银节点提供一些工具接口,方便快速操作
|
||||
- 批量下线端口
|
||||
- 统一使用 validator 进行参数验证
|
||||
- 支付回调处理
|
||||
- 公众号的到期提示
|
||||
|
||||
### 长期
|
||||
|
||||
- 修改日志输出提高可读性
|
||||
- 用户最后登录的数据可以通过 session 表进行查询,不再保存在 user 表里
|
||||
- callback 结果直接由 api 端提供,不通过前端转发
|
||||
- debug:白银节点提供一些工具接口,方便快速操作
|
||||
- 批量下线端口
|
||||
- 代理数据表的 secret 字段 aes 加密存储
|
||||
- 废弃 password 授权模式,迁移到 authorization code 授权模式
|
||||
- oauth token 验证授权范围
|
||||
- 实现白银节点的 warp 服务,用来去重与端口分配,保证测试与生产环境不会产生端口竞争
|
||||
- 统一使用 validator 进行参数验证
|
||||
- 分离项目脚手架(env,logs,Server 结构体)
|
||||
- 业务代码和测试代码共用的控制变量可以优化为环境变量
|
||||
- 考虑统计接口调用频率并通过接口展示
|
||||
|
||||
@@ -295,6 +295,52 @@ comment on column client.deleted_at is '删除时间';
|
||||
-- region 权限信息
|
||||
-- ====================
|
||||
|
||||
-- session
|
||||
drop table if exists session cascade;
|
||||
create table session (
|
||||
id serial primary key,
|
||||
user_id int references "user" (id)
|
||||
on update cascade
|
||||
on delete cascade,
|
||||
client_id int references client (id)
|
||||
on update cascade
|
||||
on delete cascade,
|
||||
ip varchar(45),
|
||||
ua varchar(255),
|
||||
grant_type varchar(255) not null default 0,
|
||||
access_token varchar(255) not null unique,
|
||||
access_token_expires timestamp not null,
|
||||
refresh_token varchar(255) unique,
|
||||
refresh_token_expires timestamp,
|
||||
scopes varchar(255),
|
||||
created_at timestamp default current_timestamp,
|
||||
updated_at timestamp default current_timestamp,
|
||||
deleted_at timestamp
|
||||
);
|
||||
create index session_user_id_index on session (user_id);
|
||||
create index session_client_id_index on session (client_id);
|
||||
create index session_access_token_index on session (access_token);
|
||||
create index session_refresh_token_index on session (refresh_token);
|
||||
create index session_created_at_index on session (created_at);
|
||||
create index session_deleted_at_index on session (deleted_at);
|
||||
|
||||
-- session表字段注释
|
||||
comment on table session is '会话表';
|
||||
comment on column session.id is '会话ID';
|
||||
comment on column session.user_id is '用户ID';
|
||||
comment on column session.client_id is '客户端ID';
|
||||
comment on column session.ip is 'IP地址';
|
||||
comment on column session.ua is '用户代理';
|
||||
comment on column session.grant_type is '授权类型:authorization_code-授权码模式,client_credentials-客户端凭证模式,refresh_token-刷新令牌模式,password-密码模式';
|
||||
comment on column session.access_token is '访问令牌';
|
||||
comment on column session.access_token_expires is '访问令牌过期时间';
|
||||
comment on column session.refresh_token is '刷新令牌';
|
||||
comment on column session.refresh_token_expires is '刷新令牌过期时间';
|
||||
comment on column session.scopes is '权限范围';
|
||||
comment on column session.created_at is '创建时间';
|
||||
comment on column session.updated_at is '更新时间';
|
||||
comment on column session.deleted_at is '删除时间';
|
||||
|
||||
-- permission
|
||||
drop table if exists permission cascade;
|
||||
create table permission (
|
||||
|
||||
@@ -77,7 +77,7 @@ func Locals(c *fiber.Ctx, auth *Context) {
|
||||
}
|
||||
|
||||
func authBearer(ctx context.Context, token string) (*Context, error) {
|
||||
auth, err := find(ctx, token)
|
||||
auth, err := FindSession(ctx, token)
|
||||
if err != nil {
|
||||
slog.Debug(err.Error())
|
||||
return nil, err
|
||||
|
||||
@@ -16,3 +16,43 @@ const (
|
||||
GrantPasswordPhone = PasswordGrantType("phone_code") // 手机号模式
|
||||
GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱模式
|
||||
)
|
||||
|
||||
func Token(grant GrantType) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func authAuthorizationCode() {
|
||||
|
||||
}
|
||||
|
||||
func authClientCredential() {
|
||||
|
||||
}
|
||||
|
||||
func authRefreshToken() {
|
||||
|
||||
}
|
||||
|
||||
func authPassword() {
|
||||
|
||||
}
|
||||
|
||||
func authPasswordSecret() {
|
||||
|
||||
}
|
||||
|
||||
func authPasswordPhone() {
|
||||
|
||||
}
|
||||
|
||||
func authPasswordEmail() {
|
||||
|
||||
}
|
||||
|
||||
func Revoke() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Introspect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,13 +1,28 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
client2 "platform/web/domains/client"
|
||||
)
|
||||
|
||||
// Context 定义认证信息
|
||||
type Context struct {
|
||||
Payload Payload `json:"payload"`
|
||||
Agent Agent `json:"agent,omitempty"`
|
||||
Permissions map[string]struct{} `json:"permissions,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func (a *Context) AnyType(types ...PayloadType) bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
for _, t := range types {
|
||||
if a.Payload.Type == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AnyPermission 检查认证是否包含指定权限
|
||||
func (a *Context) AnyPermission(requiredPermission ...string) bool {
|
||||
if a == nil || a.Permissions == nil {
|
||||
@@ -29,26 +44,15 @@ type Payload struct {
|
||||
Avatar string `json:"avatar,omitempty"`
|
||||
}
|
||||
|
||||
type Agent struct {
|
||||
Id int32 `json:"id,omitempty"`
|
||||
Addr string `json:"addr,omitempty"`
|
||||
}
|
||||
|
||||
type PayloadType int
|
||||
|
||||
const (
|
||||
// PayloadNone 游客
|
||||
PayloadNone PayloadType = iota
|
||||
// PayloadUser 用户
|
||||
PayloadUser
|
||||
// PayloadAdmin 管理员
|
||||
PayloadAdmin
|
||||
// PayloadPublicServer 公共服务(public_client)
|
||||
PayloadPublicServer
|
||||
// PayloadSecuredServer 安全服务(credential_client)
|
||||
PayloadSecuredServer
|
||||
// PayloadInternalServer 内部服务
|
||||
PayloadInternalServer
|
||||
PayloadNone PayloadType = iota // 游客
|
||||
PayloadUser // 用户
|
||||
PayloadAdmin // 管理员
|
||||
PayloadPublicServer // 公共服务(public_client)
|
||||
PayloadSecuredServer // 安全服务(credential_client)
|
||||
PayloadInternalServer // 内部服务
|
||||
)
|
||||
|
||||
func (t PayloadType) ToStr() string {
|
||||
@@ -80,3 +84,16 @@ func PayloadTypeFromStr(name string) PayloadType {
|
||||
return PayloadNone
|
||||
}
|
||||
}
|
||||
|
||||
func PayloadTypeFromClientSpec(spec client2.Spec) PayloadType {
|
||||
var clientType PayloadType
|
||||
switch spec {
|
||||
case client2.SpecNative, client2.SpecBrowser:
|
||||
clientType = PayloadPublicServer
|
||||
case client2.SpecWeb:
|
||||
clientType = PayloadSecuredServer
|
||||
case client2.SpecTrusted:
|
||||
clientType = PayloadInternalServer
|
||||
}
|
||||
return clientType
|
||||
}
|
||||
|
||||
@@ -5,11 +5,21 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"platform/pkg/env"
|
||||
g "platform/web/globals"
|
||||
"time"
|
||||
)
|
||||
|
||||
func find(ctx context.Context, token string) (*Context, error) {
|
||||
type Session struct {
|
||||
// 认证主体
|
||||
Payload *Payload
|
||||
// 令牌信息
|
||||
TokenDetails *TokenDetails
|
||||
}
|
||||
|
||||
func FindSession(ctx context.Context, token string) (*Context, error) {
|
||||
|
||||
// 读取认证数据
|
||||
authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
|
||||
@@ -29,6 +39,170 @@ func find(ctx context.Context, token string) (*Context, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func CreateSession(ctx context.Context, authCtx *Context, remember bool) (*TokenDetails, error) {
|
||||
var now = time.Now()
|
||||
|
||||
// 生成令牌组
|
||||
accessToken := genToken()
|
||||
refreshToken := genToken()
|
||||
|
||||
// 序列化认证数据
|
||||
authData, err := json.Marshal(authCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 序列化刷新令牌数据
|
||||
refreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: authCtx,
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 事务保存数据到 Redis
|
||||
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
|
||||
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
|
||||
|
||||
pipe := g.Redis.TxPipeline()
|
||||
pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
|
||||
if remember {
|
||||
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenDetails{
|
||||
AccessToken: accessToken,
|
||||
AccessTokenExpires: now.Add(accessExpire),
|
||||
RefreshToken: refreshToken,
|
||||
RefreshTokenExpires: now.Add(refreshExpire),
|
||||
Auth: authCtx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RefreshSession(ctx context.Context, refreshToken string, renew bool) (*TokenDetails, error) {
|
||||
var now = time.Now()
|
||||
|
||||
rKey := refreshKey(refreshToken)
|
||||
var tokenDetails *TokenDetails
|
||||
|
||||
// 刷新令牌
|
||||
err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
|
||||
|
||||
// 先获取刷新令牌数据
|
||||
refreshJson, err := tx.Get(ctx, rKey).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return ErrInvalidRefreshToken
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析刷新令牌数据
|
||||
refreshData := new(RefreshData)
|
||||
if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 生成新的令牌
|
||||
newAccessToken := genToken()
|
||||
newRefreshToken := genToken()
|
||||
|
||||
authData, err := json.Marshal(refreshData.AuthContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newRefreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: refreshData.AuthContext,
|
||||
AccessToken: newAccessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pipeline := tx.Pipeline()
|
||||
|
||||
// 保存新的令牌
|
||||
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
|
||||
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
|
||||
|
||||
pipeline.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
|
||||
pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
|
||||
|
||||
// 删除旧的令牌
|
||||
pipeline.Del(ctx, accessKey(refreshData.AccessToken))
|
||||
pipeline.Del(ctx, refreshKey(refreshToken))
|
||||
|
||||
_, err = pipeline.Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenDetails = &TokenDetails{
|
||||
AccessToken: newAccessToken,
|
||||
RefreshToken: newRefreshToken,
|
||||
AccessTokenExpires: now.Add(accessExpire),
|
||||
RefreshTokenExpires: now.Add(refreshExpire),
|
||||
Auth: refreshData.AuthContext,
|
||||
}
|
||||
return nil
|
||||
}, rKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("刷新令牌失败: %w", err)
|
||||
}
|
||||
|
||||
return tokenDetails, nil
|
||||
}
|
||||
|
||||
func RemoveSession(ctx context.Context, accessToken string, refreshToken string) error {
|
||||
g.Redis.Del(ctx, accessKey(accessToken), refreshKey(refreshToken))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 生成一个新的令牌
|
||||
func genToken() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
// 令牌键的格式为 "session:<token>"
|
||||
func accessKey(token string) string {
|
||||
return fmt.Sprintf("session:%s", token)
|
||||
}
|
||||
|
||||
// 刷新令牌键的格式为 "session:refreshKey:<token>"
|
||||
func refreshKey(token string) string {
|
||||
return fmt.Sprintf("session:refresh:%s", token)
|
||||
}
|
||||
|
||||
// TokenDetails 存储令牌详细信息
|
||||
type TokenDetails struct {
|
||||
// 访问令牌
|
||||
AccessToken string
|
||||
// 刷新令牌
|
||||
RefreshToken string
|
||||
// 访问令牌过期时间
|
||||
AccessTokenExpires time.Time
|
||||
// 刷新令牌过期时间
|
||||
RefreshTokenExpires time.Time
|
||||
// 认证信息
|
||||
Auth *Context
|
||||
}
|
||||
|
||||
type RefreshData struct {
|
||||
AuthContext *Context
|
||||
AccessToken string
|
||||
}
|
||||
|
||||
type SessionErr string
|
||||
|
||||
func (e SessionErr) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
const (
|
||||
ErrInvalidRefreshToken = SessionErr("无效的刷新令牌")
|
||||
)
|
||||
|
||||
@@ -110,7 +110,7 @@ func Token(c *fiber.Ctx) error {
|
||||
scope := strings.Split(req.Scope, ",")
|
||||
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
|
||||
if err != nil {
|
||||
if errors.Is(err, s.ErrInvalidToken) {
|
||||
if errors.Is(err, auth2.ErrInvalidRefreshToken) {
|
||||
return sendError(c, s.ErrOauthInvalidGrant)
|
||||
}
|
||||
return sendError(c, err)
|
||||
@@ -226,7 +226,7 @@ func protect(c *fiber.Ctx, grant auth2.GrantType, clientId, clientSecret string)
|
||||
}
|
||||
|
||||
// 发送成功响应
|
||||
func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error {
|
||||
func sendSuccess(c *fiber.Ctx, details *auth2.TokenDetails) error {
|
||||
return c.JSON(TokenResp{
|
||||
AccessToken: details.AccessToken,
|
||||
TokenType: "Bearer",
|
||||
@@ -292,7 +292,7 @@ func Revoke(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 删除会话
|
||||
err = s.Session.Remove(c.Context(), req.AccessToken, req.RefreshToken)
|
||||
err = auth2.RemoveSession(c.Context(), req.AccessToken, req.RefreshToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -198,10 +198,7 @@ type RemoveChannelsReq struct {
|
||||
|
||||
func RemoveChannels(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.Protect(c, []auth.PayloadType{
|
||||
auth.PayloadUser,
|
||||
auth.PayloadSecuredServer,
|
||||
}, []string{})
|
||||
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser, auth.PayloadInternalServer}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -17,9 +17,7 @@ type VerifierReq struct {
|
||||
|
||||
func SmsCode(c *fiber.Ctx) error {
|
||||
|
||||
_, err := auth.Protect(c, []auth.PayloadType{
|
||||
auth.PayloadSecuredServer,
|
||||
}, []string{})
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadSecuredServer}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -4,9 +4,7 @@
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"platform/web/globals/orm"
|
||||
)
|
||||
import "platform/web/globals/orm"
|
||||
|
||||
const TableNameLogsRequest = "logs_request"
|
||||
|
||||
|
||||
@@ -4,9 +4,7 @@
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"platform/web/globals/orm"
|
||||
)
|
||||
import "platform/web/globals/orm"
|
||||
|
||||
const TableNameResourcePsr = "resource_psr"
|
||||
|
||||
|
||||
@@ -4,9 +4,7 @@
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"platform/web/globals/orm"
|
||||
)
|
||||
import "platform/web/globals/orm"
|
||||
|
||||
const TableNameResourcePss = "resource_pss"
|
||||
|
||||
|
||||
36
web/models/session.gen.go
Normal file
36
web/models/session.gen.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package models
|
||||
|
||||
import (
|
||||
"platform/web/globals/orm"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const TableNameSession = "session"
|
||||
|
||||
// Session mapped from table <session>
|
||||
type Session struct {
|
||||
ID int32 `gorm:"column:id;primaryKey;autoIncrement:true;comment:会话ID" json:"id"` // 会话ID
|
||||
UserID int32 `gorm:"column:user_id;comment:用户ID" json:"user_id"` // 用户ID
|
||||
ClientID int32 `gorm:"column:client_id;comment:客户端ID" json:"client_id"` // 客户端ID
|
||||
IP string `gorm:"column:ip;comment:IP地址" json:"ip"` // IP地址
|
||||
Ua string `gorm:"column:ua;comment:用户代理" json:"ua"` // 用户代理
|
||||
GrantType string `gorm:"column:grant_type;not null;default:0;comment:授权类型:authorization_code-授权码模式,client_credentials-客户端凭证模式,refresh_token-刷新令牌模式,password-密码模式" json:"grant_type"` // 授权类型:authorization_code-授权码模式,client_credentials-客户端凭证模式,refresh_token-刷新令牌模式,password-密码模式
|
||||
AccessToken string `gorm:"column:access_token;not null;comment:访问令牌" json:"access_token"` // 访问令牌
|
||||
AccessTokenExpires orm.LocalDateTime `gorm:"column:access_token_expires;not null;comment:访问令牌过期时间" json:"access_token_expires"` // 访问令牌过期时间
|
||||
RefreshToken string `gorm:"column:refresh_token;comment:刷新令牌" json:"refresh_token"` // 刷新令牌
|
||||
RefreshTokenExpires orm.LocalDateTime `gorm:"column:refresh_token_expires;comment:刷新令牌过期时间" json:"refresh_token_expires"` // 刷新令牌过期时间
|
||||
Scopes_ string `gorm:"column:scopes;comment:权限范围" json:"scopes"` // 权限范围
|
||||
CreatedAt orm.LocalDateTime `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt orm.LocalDateTime `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
}
|
||||
|
||||
// TableName Session's table name
|
||||
func (*Session) TableName() string {
|
||||
return TableNameSession
|
||||
}
|
||||
@@ -37,6 +37,7 @@ var (
|
||||
ResourcePps *resourcePps
|
||||
ResourcePsr *resourcePsr
|
||||
ResourcePss *resourcePss
|
||||
Session *session
|
||||
Trade *trade
|
||||
User *user
|
||||
UserRole *userRole
|
||||
@@ -67,6 +68,7 @@ func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
|
||||
ResourcePps = &Q.ResourcePps
|
||||
ResourcePsr = &Q.ResourcePsr
|
||||
ResourcePss = &Q.ResourcePss
|
||||
Session = &Q.Session
|
||||
Trade = &Q.Trade
|
||||
User = &Q.User
|
||||
UserRole = &Q.UserRole
|
||||
@@ -98,6 +100,7 @@ func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
|
||||
ResourcePps: newResourcePps(db, opts...),
|
||||
ResourcePsr: newResourcePsr(db, opts...),
|
||||
ResourcePss: newResourcePss(db, opts...),
|
||||
Session: newSession(db, opts...),
|
||||
Trade: newTrade(db, opts...),
|
||||
User: newUser(db, opts...),
|
||||
UserRole: newUserRole(db, opts...),
|
||||
@@ -130,6 +133,7 @@ type Query struct {
|
||||
ResourcePps resourcePps
|
||||
ResourcePsr resourcePsr
|
||||
ResourcePss resourcePss
|
||||
Session session
|
||||
Trade trade
|
||||
User user
|
||||
UserRole userRole
|
||||
@@ -163,6 +167,7 @@ func (q *Query) clone(db *gorm.DB) *Query {
|
||||
ResourcePps: q.ResourcePps.clone(db),
|
||||
ResourcePsr: q.ResourcePsr.clone(db),
|
||||
ResourcePss: q.ResourcePss.clone(db),
|
||||
Session: q.Session.clone(db),
|
||||
Trade: q.Trade.clone(db),
|
||||
User: q.User.clone(db),
|
||||
UserRole: q.UserRole.clone(db),
|
||||
@@ -203,6 +208,7 @@ func (q *Query) ReplaceDB(db *gorm.DB) *Query {
|
||||
ResourcePps: q.ResourcePps.replaceDB(db),
|
||||
ResourcePsr: q.ResourcePsr.replaceDB(db),
|
||||
ResourcePss: q.ResourcePss.replaceDB(db),
|
||||
Session: q.Session.replaceDB(db),
|
||||
Trade: q.Trade.replaceDB(db),
|
||||
User: q.User.replaceDB(db),
|
||||
UserRole: q.UserRole.replaceDB(db),
|
||||
@@ -233,6 +239,7 @@ type queryCtx struct {
|
||||
ResourcePps *resourcePpsDo
|
||||
ResourcePsr *resourcePsrDo
|
||||
ResourcePss *resourcePssDo
|
||||
Session *sessionDo
|
||||
Trade *tradeDo
|
||||
User *userDo
|
||||
UserRole *userRoleDo
|
||||
@@ -263,6 +270,7 @@ func (q *Query) WithContext(ctx context.Context) *queryCtx {
|
||||
ResourcePps: q.ResourcePps.WithContext(ctx),
|
||||
ResourcePsr: q.ResourcePsr.WithContext(ctx),
|
||||
ResourcePss: q.ResourcePss.WithContext(ctx),
|
||||
Session: q.Session.WithContext(ctx),
|
||||
Trade: q.Trade.WithContext(ctx),
|
||||
User: q.User.WithContext(ctx),
|
||||
UserRole: q.UserRole.WithContext(ctx),
|
||||
|
||||
371
web/queries/session.gen.go
Normal file
371
web/queries/session.gen.go
Normal file
@@ -0,0 +1,371 @@
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
// Code generated by gorm.io/gen. DO NOT EDIT.
|
||||
|
||||
package queries
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
"gorm.io/gen"
|
||||
"gorm.io/gen/field"
|
||||
|
||||
"gorm.io/plugin/dbresolver"
|
||||
|
||||
"platform/web/models"
|
||||
)
|
||||
|
||||
func newSession(db *gorm.DB, opts ...gen.DOOption) session {
|
||||
_session := session{}
|
||||
|
||||
_session.sessionDo.UseDB(db, opts...)
|
||||
_session.sessionDo.UseModel(&models.Session{})
|
||||
|
||||
tableName := _session.sessionDo.TableName()
|
||||
_session.ALL = field.NewAsterisk(tableName)
|
||||
_session.ID = field.NewInt32(tableName, "id")
|
||||
_session.UserID = field.NewInt32(tableName, "user_id")
|
||||
_session.ClientID = field.NewInt32(tableName, "client_id")
|
||||
_session.IP = field.NewString(tableName, "ip")
|
||||
_session.Ua = field.NewString(tableName, "ua")
|
||||
_session.GrantType = field.NewString(tableName, "grant_type")
|
||||
_session.AccessToken = field.NewString(tableName, "access_token")
|
||||
_session.AccessTokenExpires = field.NewField(tableName, "access_token_expires")
|
||||
_session.RefreshToken = field.NewString(tableName, "refresh_token")
|
||||
_session.RefreshTokenExpires = field.NewField(tableName, "refresh_token_expires")
|
||||
_session.Scopes_ = field.NewString(tableName, "scopes")
|
||||
_session.CreatedAt = field.NewField(tableName, "created_at")
|
||||
_session.UpdatedAt = field.NewField(tableName, "updated_at")
|
||||
_session.DeletedAt = field.NewField(tableName, "deleted_at")
|
||||
|
||||
_session.fillFieldMap()
|
||||
|
||||
return _session
|
||||
}
|
||||
|
||||
type session struct {
|
||||
sessionDo
|
||||
|
||||
ALL field.Asterisk
|
||||
ID field.Int32 // 会话ID
|
||||
UserID field.Int32 // 用户ID
|
||||
ClientID field.Int32 // 客户端ID
|
||||
IP field.String // IP地址
|
||||
Ua field.String // 用户代理
|
||||
GrantType field.String // 授权类型:authorization_code-授权码模式,client_credentials-客户端凭证模式,refresh_token-刷新令牌模式,password-密码模式
|
||||
AccessToken field.String // 访问令牌
|
||||
AccessTokenExpires field.Field // 访问令牌过期时间
|
||||
RefreshToken field.String // 刷新令牌
|
||||
RefreshTokenExpires field.Field // 刷新令牌过期时间
|
||||
Scopes_ field.String // 权限范围
|
||||
CreatedAt field.Field // 创建时间
|
||||
UpdatedAt field.Field // 更新时间
|
||||
DeletedAt field.Field // 删除时间
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
|
||||
func (s session) Table(newTableName string) *session {
|
||||
s.sessionDo.UseTable(newTableName)
|
||||
return s.updateTableName(newTableName)
|
||||
}
|
||||
|
||||
func (s session) As(alias string) *session {
|
||||
s.sessionDo.DO = *(s.sessionDo.As(alias).(*gen.DO))
|
||||
return s.updateTableName(alias)
|
||||
}
|
||||
|
||||
func (s *session) updateTableName(table string) *session {
|
||||
s.ALL = field.NewAsterisk(table)
|
||||
s.ID = field.NewInt32(table, "id")
|
||||
s.UserID = field.NewInt32(table, "user_id")
|
||||
s.ClientID = field.NewInt32(table, "client_id")
|
||||
s.IP = field.NewString(table, "ip")
|
||||
s.Ua = field.NewString(table, "ua")
|
||||
s.GrantType = field.NewString(table, "grant_type")
|
||||
s.AccessToken = field.NewString(table, "access_token")
|
||||
s.AccessTokenExpires = field.NewField(table, "access_token_expires")
|
||||
s.RefreshToken = field.NewString(table, "refresh_token")
|
||||
s.RefreshTokenExpires = field.NewField(table, "refresh_token_expires")
|
||||
s.Scopes_ = field.NewString(table, "scopes")
|
||||
s.CreatedAt = field.NewField(table, "created_at")
|
||||
s.UpdatedAt = field.NewField(table, "updated_at")
|
||||
s.DeletedAt = field.NewField(table, "deleted_at")
|
||||
|
||||
s.fillFieldMap()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *session) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
|
||||
_f, ok := s.fieldMap[fieldName]
|
||||
if !ok || _f == nil {
|
||||
return nil, false
|
||||
}
|
||||
_oe, ok := _f.(field.OrderExpr)
|
||||
return _oe, ok
|
||||
}
|
||||
|
||||
func (s *session) fillFieldMap() {
|
||||
s.fieldMap = make(map[string]field.Expr, 14)
|
||||
s.fieldMap["id"] = s.ID
|
||||
s.fieldMap["user_id"] = s.UserID
|
||||
s.fieldMap["client_id"] = s.ClientID
|
||||
s.fieldMap["ip"] = s.IP
|
||||
s.fieldMap["ua"] = s.Ua
|
||||
s.fieldMap["grant_type"] = s.GrantType
|
||||
s.fieldMap["access_token"] = s.AccessToken
|
||||
s.fieldMap["access_token_expires"] = s.AccessTokenExpires
|
||||
s.fieldMap["refresh_token"] = s.RefreshToken
|
||||
s.fieldMap["refresh_token_expires"] = s.RefreshTokenExpires
|
||||
s.fieldMap["scopes"] = s.Scopes_
|
||||
s.fieldMap["created_at"] = s.CreatedAt
|
||||
s.fieldMap["updated_at"] = s.UpdatedAt
|
||||
s.fieldMap["deleted_at"] = s.DeletedAt
|
||||
}
|
||||
|
||||
func (s session) clone(db *gorm.DB) session {
|
||||
s.sessionDo.ReplaceConnPool(db.Statement.ConnPool)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s session) replaceDB(db *gorm.DB) session {
|
||||
s.sessionDo.ReplaceDB(db)
|
||||
return s
|
||||
}
|
||||
|
||||
type sessionDo struct{ gen.DO }
|
||||
|
||||
func (s sessionDo) Debug() *sessionDo {
|
||||
return s.withDO(s.DO.Debug())
|
||||
}
|
||||
|
||||
func (s sessionDo) WithContext(ctx context.Context) *sessionDo {
|
||||
return s.withDO(s.DO.WithContext(ctx))
|
||||
}
|
||||
|
||||
func (s sessionDo) ReadDB() *sessionDo {
|
||||
return s.Clauses(dbresolver.Read)
|
||||
}
|
||||
|
||||
func (s sessionDo) WriteDB() *sessionDo {
|
||||
return s.Clauses(dbresolver.Write)
|
||||
}
|
||||
|
||||
func (s sessionDo) Session(config *gorm.Session) *sessionDo {
|
||||
return s.withDO(s.DO.Session(config))
|
||||
}
|
||||
|
||||
func (s sessionDo) Clauses(conds ...clause.Expression) *sessionDo {
|
||||
return s.withDO(s.DO.Clauses(conds...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Returning(value interface{}, columns ...string) *sessionDo {
|
||||
return s.withDO(s.DO.Returning(value, columns...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Not(conds ...gen.Condition) *sessionDo {
|
||||
return s.withDO(s.DO.Not(conds...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Or(conds ...gen.Condition) *sessionDo {
|
||||
return s.withDO(s.DO.Or(conds...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Select(conds ...field.Expr) *sessionDo {
|
||||
return s.withDO(s.DO.Select(conds...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Where(conds ...gen.Condition) *sessionDo {
|
||||
return s.withDO(s.DO.Where(conds...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Order(conds ...field.Expr) *sessionDo {
|
||||
return s.withDO(s.DO.Order(conds...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Distinct(cols ...field.Expr) *sessionDo {
|
||||
return s.withDO(s.DO.Distinct(cols...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Omit(cols ...field.Expr) *sessionDo {
|
||||
return s.withDO(s.DO.Omit(cols...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Join(table schema.Tabler, on ...field.Expr) *sessionDo {
|
||||
return s.withDO(s.DO.Join(table, on...))
|
||||
}
|
||||
|
||||
func (s sessionDo) LeftJoin(table schema.Tabler, on ...field.Expr) *sessionDo {
|
||||
return s.withDO(s.DO.LeftJoin(table, on...))
|
||||
}
|
||||
|
||||
func (s sessionDo) RightJoin(table schema.Tabler, on ...field.Expr) *sessionDo {
|
||||
return s.withDO(s.DO.RightJoin(table, on...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Group(cols ...field.Expr) *sessionDo {
|
||||
return s.withDO(s.DO.Group(cols...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Having(conds ...gen.Condition) *sessionDo {
|
||||
return s.withDO(s.DO.Having(conds...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Limit(limit int) *sessionDo {
|
||||
return s.withDO(s.DO.Limit(limit))
|
||||
}
|
||||
|
||||
func (s sessionDo) Offset(offset int) *sessionDo {
|
||||
return s.withDO(s.DO.Offset(offset))
|
||||
}
|
||||
|
||||
func (s sessionDo) Scopes(funcs ...func(gen.Dao) gen.Dao) *sessionDo {
|
||||
return s.withDO(s.DO.Scopes(funcs...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Unscoped() *sessionDo {
|
||||
return s.withDO(s.DO.Unscoped())
|
||||
}
|
||||
|
||||
func (s sessionDo) Create(values ...*models.Session) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.DO.Create(values)
|
||||
}
|
||||
|
||||
func (s sessionDo) CreateInBatches(values []*models.Session, batchSize int) error {
|
||||
return s.DO.CreateInBatches(values, batchSize)
|
||||
}
|
||||
|
||||
// Save : !!! underlying implementation is different with GORM
|
||||
// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values)
|
||||
func (s sessionDo) Save(values ...*models.Session) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
return s.DO.Save(values)
|
||||
}
|
||||
|
||||
func (s sessionDo) First() (*models.Session, error) {
|
||||
if result, err := s.DO.First(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*models.Session), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s sessionDo) Take() (*models.Session, error) {
|
||||
if result, err := s.DO.Take(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*models.Session), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s sessionDo) Last() (*models.Session, error) {
|
||||
if result, err := s.DO.Last(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*models.Session), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s sessionDo) Find() ([]*models.Session, error) {
|
||||
result, err := s.DO.Find()
|
||||
return result.([]*models.Session), err
|
||||
}
|
||||
|
||||
func (s sessionDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*models.Session, err error) {
|
||||
buf := make([]*models.Session, 0, batchSize)
|
||||
err = s.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error {
|
||||
defer func() { results = append(results, buf...) }()
|
||||
return fc(tx, batch)
|
||||
})
|
||||
return results, err
|
||||
}
|
||||
|
||||
func (s sessionDo) FindInBatches(result *[]*models.Session, batchSize int, fc func(tx gen.Dao, batch int) error) error {
|
||||
return s.DO.FindInBatches(result, batchSize, fc)
|
||||
}
|
||||
|
||||
func (s sessionDo) Attrs(attrs ...field.AssignExpr) *sessionDo {
|
||||
return s.withDO(s.DO.Attrs(attrs...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Assign(attrs ...field.AssignExpr) *sessionDo {
|
||||
return s.withDO(s.DO.Assign(attrs...))
|
||||
}
|
||||
|
||||
func (s sessionDo) Joins(fields ...field.RelationField) *sessionDo {
|
||||
for _, _f := range fields {
|
||||
s = *s.withDO(s.DO.Joins(_f))
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s sessionDo) Preload(fields ...field.RelationField) *sessionDo {
|
||||
for _, _f := range fields {
|
||||
s = *s.withDO(s.DO.Preload(_f))
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s sessionDo) FirstOrInit() (*models.Session, error) {
|
||||
if result, err := s.DO.FirstOrInit(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*models.Session), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s sessionDo) FirstOrCreate() (*models.Session, error) {
|
||||
if result, err := s.DO.FirstOrCreate(); err != nil {
|
||||
return nil, err
|
||||
} else {
|
||||
return result.(*models.Session), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s sessionDo) FindByPage(offset int, limit int) (result []*models.Session, count int64, err error) {
|
||||
result, err = s.Offset(offset).Limit(limit).Find()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if size := len(result); 0 < limit && 0 < size && size < limit {
|
||||
count = int64(size + offset)
|
||||
return
|
||||
}
|
||||
|
||||
count, err = s.Offset(-1).Limit(-1).Count()
|
||||
return
|
||||
}
|
||||
|
||||
func (s sessionDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) {
|
||||
count, err = s.Count()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = s.Offset(offset).Limit(limit).Scan(result)
|
||||
return
|
||||
}
|
||||
|
||||
func (s sessionDo) Scan(result interface{}) (err error) {
|
||||
return s.DO.Scan(result)
|
||||
}
|
||||
|
||||
func (s sessionDo) Delete(models ...*models.Session) (result gen.ResultInfo, err error) {
|
||||
return s.DO.Delete(models)
|
||||
}
|
||||
|
||||
func (s *sessionDo) withDO(do gen.Dao) *sessionDo {
|
||||
s.DO = *do.(*gen.DO)
|
||||
return s
|
||||
}
|
||||
@@ -18,21 +18,15 @@ var Auth = &authService{}
|
||||
type authService struct{}
|
||||
|
||||
// OauthAuthorizationCode 验证授权码
|
||||
func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Client, code, redirectURI, codeVerifier string) (*TokenDetails, error) {
|
||||
func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Client, code, redirectURI, codeVerifier string) (*auth2.TokenDetails, error) {
|
||||
// TODO: 从数据库验证授权码
|
||||
return nil, errors.New("TODO")
|
||||
}
|
||||
|
||||
// OauthClientCredentials 验证客户端凭证
|
||||
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) {
|
||||
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*auth2.TokenDetails, error) {
|
||||
|
||||
var clientType auth2.PayloadType
|
||||
switch client2.Spec(client.Spec) {
|
||||
case client2.SpecNative, client2.SpecBrowser:
|
||||
clientType = auth2.PayloadPublicServer
|
||||
case client2.SpecWeb, client2.SpecTrusted:
|
||||
clientType = auth2.PayloadSecuredServer
|
||||
}
|
||||
var clientType = auth2.PayloadTypeFromClientSpec(client2.Spec(client.Spec))
|
||||
|
||||
var permissions = make(map[string]struct{}, len(scope))
|
||||
for _, item := range scope {
|
||||
@@ -50,7 +44,7 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
|
||||
}
|
||||
|
||||
// todo 数据库定义会话持续时间
|
||||
token, err := Session.Create(ctx, authCtx, false)
|
||||
token, err := auth2.CreateSession(ctx, &authCtx, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -59,9 +53,9 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
|
||||
}
|
||||
|
||||
// OauthRefreshToken 验证刷新令牌
|
||||
func (s *authService) OauthRefreshToken(ctx context.Context, _ *m.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) {
|
||||
func (s *authService) OauthRefreshToken(ctx context.Context, _ *m.Client, refreshToken string, scope ...[]string) (*auth2.TokenDetails, error) {
|
||||
// TODO: 从数据库验证刷新令牌
|
||||
details, err := Session.Refresh(ctx, refreshToken)
|
||||
details, err := auth2.RefreshSession(ctx, refreshToken, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -70,7 +64,7 @@ func (s *authService) OauthRefreshToken(ctx context.Context, _ *m.Client, refres
|
||||
}
|
||||
|
||||
// OauthPassword 验证密码
|
||||
func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *GrantPasswordData, ip, agent string) (*TokenDetails, error) {
|
||||
func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *GrantPasswordData, ip, agent string) (*auth2.TokenDetails, error) {
|
||||
var user *m.User
|
||||
err := q.Q.Transaction(func(tx *q.Query) error {
|
||||
|
||||
@@ -145,7 +139,7 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
|
||||
},
|
||||
}
|
||||
|
||||
token, err := Session.Create(ctx, authCtx, data.Remember)
|
||||
token, err := auth2.CreateSession(ctx, &authCtx, data.Remember)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -179,7 +173,7 @@ func (e AuthServiceError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
var (
|
||||
const (
|
||||
ErrOauthInvalidRequest = AuthServiceError("invalid_request")
|
||||
ErrOauthInvalidClient = AuthServiceError("invalid_client")
|
||||
ErrOauthInvalidGrant = AuthServiceError("invalid_grant")
|
||||
|
||||
@@ -443,7 +443,7 @@ func calcChannels(
|
||||
}
|
||||
|
||||
if env.DebugExternalChange && next > count {
|
||||
step = time.Now()
|
||||
var step = time.Now()
|
||||
|
||||
var multiple float64 = 2 // 扩张倍数
|
||||
var newConfig = g.AutoConfig{
|
||||
@@ -550,7 +550,7 @@ func calcChannels(
|
||||
|
||||
// 提交端口配置并更新节点列表
|
||||
if env.DebugExternalChange {
|
||||
step = time.Now()
|
||||
var step = time.Now()
|
||||
|
||||
var secret = strings.Split(proxy.Secret, ":")
|
||||
gateway := g.NewGateway(
|
||||
|
||||
@@ -1,226 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"platform/pkg/env"
|
||||
"platform/web/auth"
|
||||
g "platform/web/globals"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// region SessionService
|
||||
|
||||
var Session SessionServiceInter = &sessionService{}
|
||||
|
||||
type SessionServiceInter interface {
|
||||
// Find 通过访问令牌获取会话信息
|
||||
Find(ctx context.Context, token string) (*auth.Context, error)
|
||||
// Create 创建一个新的会话
|
||||
Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error)
|
||||
// Refresh 刷新一个会话
|
||||
Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error)
|
||||
// Remove 删除会话
|
||||
Remove(ctx context.Context, accessToken, refreshToken string) error
|
||||
}
|
||||
|
||||
type SessionServiceError string
|
||||
|
||||
func (e SessionServiceError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrInvalidToken = SessionServiceError("invalid_token")
|
||||
)
|
||||
|
||||
type sessionService struct{}
|
||||
|
||||
// Find 通过访问令牌获取会话信息
|
||||
func (s *sessionService) Find(ctx context.Context, token string) (*auth.Context, error) {
|
||||
|
||||
// 读取认证数据
|
||||
authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 反序列化
|
||||
authCtx := new(auth.Context)
|
||||
if err := json.Unmarshal([]byte(authJSON), authCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return authCtx, nil
|
||||
}
|
||||
|
||||
// Create 创建一个新的会话
|
||||
func (s *sessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) {
|
||||
var now = time.Now()
|
||||
|
||||
// 生成令牌组
|
||||
accessToken := genToken()
|
||||
refreshToken := genToken()
|
||||
|
||||
// 序列化认证数据
|
||||
authData, err := json.Marshal(authCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 序列化刷新令牌数据
|
||||
refreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: authCtx,
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 事务保存数据到 Redis
|
||||
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
|
||||
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
|
||||
|
||||
pipe := g.Redis.TxPipeline()
|
||||
pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
|
||||
if remember {
|
||||
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenDetails{
|
||||
AccessToken: accessToken,
|
||||
AccessTokenExpires: now.Add(accessExpire),
|
||||
RefreshToken: refreshToken,
|
||||
RefreshTokenExpires: now.Add(refreshExpire),
|
||||
Auth: authCtx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Refresh 刷新一个会话
|
||||
func (s *sessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
|
||||
var now = time.Now()
|
||||
|
||||
rKey := refreshKey(refreshToken)
|
||||
var tokenDetails *TokenDetails
|
||||
|
||||
// 刷新令牌
|
||||
err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
|
||||
|
||||
// 先获取刷新令牌数据
|
||||
refreshJson, err := tx.Get(ctx, rKey).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析刷新令牌数据
|
||||
refreshData := new(RefreshData)
|
||||
if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 生成新的令牌
|
||||
newAccessToken := genToken()
|
||||
newRefreshToken := genToken()
|
||||
|
||||
authData, err := json.Marshal(refreshData.AuthContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newRefreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: refreshData.AuthContext,
|
||||
AccessToken: newAccessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pipeline := tx.Pipeline()
|
||||
|
||||
// 保存新的令牌
|
||||
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
|
||||
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
|
||||
|
||||
pipeline.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
|
||||
pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
|
||||
|
||||
// 删除旧的令牌
|
||||
pipeline.Del(ctx, accessKey(refreshData.AccessToken))
|
||||
pipeline.Del(ctx, refreshKey(refreshToken))
|
||||
|
||||
_, err = pipeline.Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenDetails = &TokenDetails{
|
||||
AccessToken: newAccessToken,
|
||||
RefreshToken: newRefreshToken,
|
||||
AccessTokenExpires: now.Add(accessExpire),
|
||||
RefreshTokenExpires: now.Add(refreshExpire),
|
||||
Auth: refreshData.AuthContext,
|
||||
}
|
||||
return nil
|
||||
}, rKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("刷新令牌失败: %w", err)
|
||||
}
|
||||
|
||||
return tokenDetails, nil
|
||||
}
|
||||
|
||||
// Remove 删除会话
|
||||
func (s *sessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
|
||||
g.Redis.Del(ctx, accessKey(accessToken), refreshKey(refreshToken))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 生成一个新的令牌
|
||||
func genToken() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
// 令牌键的格式为 "session:<token>"
|
||||
func accessKey(token string) string {
|
||||
return fmt.Sprintf("session:%s", token)
|
||||
}
|
||||
|
||||
// 刷新令牌键的格式为 "session:refreshKey:<token>"
|
||||
func refreshKey(token string) string {
|
||||
return fmt.Sprintf("session:refresh:%s", token)
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
type RefreshData struct {
|
||||
AuthContext auth.Context
|
||||
AccessToken string
|
||||
}
|
||||
|
||||
// TokenDetails 存储令牌详细信息
|
||||
type TokenDetails struct {
|
||||
// 访问令牌
|
||||
AccessToken string
|
||||
// 刷新令牌
|
||||
RefreshToken string
|
||||
// 访问令牌过期时间
|
||||
AccessTokenExpires time.Time
|
||||
// 刷新令牌过期时间
|
||||
RefreshTokenExpires time.Time
|
||||
// 认证信息
|
||||
Auth auth.Context
|
||||
}
|
||||
Reference in New Issue
Block a user