重构认证授权逻辑,集中到 auth 包中

This commit is contained in:
2025-05-12 10:07:12 +08:00
parent cfdee98a1b
commit 2c37dcc2be
40 changed files with 905 additions and 455 deletions

View File

@@ -1,28 +1,29 @@
## todo ## todo
- 长效业务接入
- 页面 账户总览 - 页面 账户总览
- 页面 提取记录 - 页面 提取记录
- 页面 使用记录 - 页面 使用记录
- 公众号的到期提示
- 支付回调处理
- 保存 session 到数据库 - 保存 session 到数据库
- 移除 PayloadType使用 Grant_Type
### 下阶段 ### 长效业务
- 代理数据表的 secret 字段 aes 加密存储 - 支付回调处理
- 扩展 device 权限验证方式,提供一种方法区分内部和外部服务 - 公众号的到期提示
- 废弃 password 授权模式,迁移到 authorization code 授权模式
- oauth token 验证授权范围
- 实现白银节点的 warp 服务,用来去重与端口分配,保证测试与生产环境不会产生端口竞争
- callback 结果直接由 api 端提供,不通过前端转发
- debug白银节点提供一些工具接口方便快速操作
- 批量下线端口
- 统一使用 validator 进行参数验证
### 长期 ### 长期
- 修改日志输出提高可读性
- 用户最后登录的数据可以通过 session 表进行查询,不再保存在 user 表里
- callback 结果直接由 api 端提供,不通过前端转发
- debug白银节点提供一些工具接口方便快速操作
- 批量下线端口
- 代理数据表的 secret 字段 aes 加密存储
- 废弃 password 授权模式,迁移到 authorization code 授权模式
- oauth token 验证授权范围
- 实现白银节点的 warp 服务,用来去重与端口分配,保证测试与生产环境不会产生端口竞争
- 统一使用 validator 进行参数验证
- 分离项目脚手架envlogsServer 结构体) - 分离项目脚手架envlogsServer 结构体)
- 业务代码和测试代码共用的控制变量可以优化为环境变量 - 业务代码和测试代码共用的控制变量可以优化为环境变量
- 考虑统计接口调用频率并通过接口展示 - 考虑统计接口调用频率并通过接口展示

View File

@@ -295,6 +295,52 @@ comment on column client.deleted_at is '删除时间';
-- region 权限信息 -- region 权限信息
-- ==================== -- ====================
-- session
drop table if exists session cascade;
create table session (
id serial primary key,
user_id int references "user" (id)
on update cascade
on delete cascade,
client_id int references client (id)
on update cascade
on delete cascade,
ip varchar(45),
ua varchar(255),
grant_type varchar(255) not null default 0,
access_token varchar(255) not null unique,
access_token_expires timestamp not null,
refresh_token varchar(255) unique,
refresh_token_expires timestamp,
scopes varchar(255),
created_at timestamp default current_timestamp,
updated_at timestamp default current_timestamp,
deleted_at timestamp
);
create index session_user_id_index on session (user_id);
create index session_client_id_index on session (client_id);
create index session_access_token_index on session (access_token);
create index session_refresh_token_index on session (refresh_token);
create index session_created_at_index on session (created_at);
create index session_deleted_at_index on session (deleted_at);
-- session表字段注释
comment on table session is '会话表';
comment on column session.id is '会话ID';
comment on column session.user_id is '用户ID';
comment on column session.client_id is '客户端ID';
comment on column session.ip is 'IP地址';
comment on column session.ua is '用户代理';
comment on column session.grant_type is '授权类型authorization_code-授权码模式client_credentials-客户端凭证模式refresh_token-刷新令牌模式password-密码模式';
comment on column session.access_token is '访问令牌';
comment on column session.access_token_expires is '访问令牌过期时间';
comment on column session.refresh_token is '刷新令牌';
comment on column session.refresh_token_expires is '刷新令牌过期时间';
comment on column session.scopes is '权限范围';
comment on column session.created_at is '创建时间';
comment on column session.updated_at is '更新时间';
comment on column session.deleted_at is '删除时间';
-- permission -- permission
drop table if exists permission cascade; drop table if exists permission cascade;
create table permission ( create table permission (

View File

@@ -77,7 +77,7 @@ func Locals(c *fiber.Ctx, auth *Context) {
} }
func authBearer(ctx context.Context, token string) (*Context, error) { func authBearer(ctx context.Context, token string) (*Context, error) {
auth, err := find(ctx, token) auth, err := FindSession(ctx, token)
if err != nil { if err != nil {
slog.Debug(err.Error()) slog.Debug(err.Error())
return nil, err return nil, err

View File

@@ -16,3 +16,43 @@ const (
GrantPasswordPhone = PasswordGrantType("phone_code") // 手机号模式 GrantPasswordPhone = PasswordGrantType("phone_code") // 手机号模式
GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱模式 GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱模式
) )
func Token(grant GrantType) error {
return nil
}
func authAuthorizationCode() {
}
func authClientCredential() {
}
func authRefreshToken() {
}
func authPassword() {
}
func authPasswordSecret() {
}
func authPasswordPhone() {
}
func authPasswordEmail() {
}
func Revoke() error {
return nil
}
func Introspect() error {
return nil
}

View File

@@ -1,13 +1,28 @@
package auth package auth
import (
client2 "platform/web/domains/client"
)
// Context 定义认证信息 // Context 定义认证信息
type Context struct { type Context struct {
Payload Payload `json:"payload"` Payload Payload `json:"payload"`
Agent Agent `json:"agent,omitempty"`
Permissions map[string]struct{} `json:"permissions,omitempty"` Permissions map[string]struct{} `json:"permissions,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"` Metadata map[string]interface{} `json:"metadata,omitempty"`
} }
func (a *Context) AnyType(types ...PayloadType) bool {
if a == nil {
return false
}
for _, t := range types {
if a.Payload.Type == t {
return true
}
}
return false
}
// AnyPermission 检查认证是否包含指定权限 // AnyPermission 检查认证是否包含指定权限
func (a *Context) AnyPermission(requiredPermission ...string) bool { func (a *Context) AnyPermission(requiredPermission ...string) bool {
if a == nil || a.Permissions == nil { if a == nil || a.Permissions == nil {
@@ -29,26 +44,15 @@ type Payload struct {
Avatar string `json:"avatar,omitempty"` Avatar string `json:"avatar,omitempty"`
} }
type Agent struct {
Id int32 `json:"id,omitempty"`
Addr string `json:"addr,omitempty"`
}
type PayloadType int type PayloadType int
const ( const (
// PayloadNone 游客 PayloadNone PayloadType = iota // 游客
PayloadNone PayloadType = iota PayloadUser // 用户
// PayloadUser 用户 PayloadAdmin // 管理员
PayloadUser PayloadPublicServer // 公共服务public_client
// PayloadAdmin 管理员 PayloadSecuredServer // 安全服务credential_client
PayloadAdmin PayloadInternalServer // 内部服务
// PayloadPublicServer 公共服务public_client
PayloadPublicServer
// PayloadSecuredServer 安全服务credential_client
PayloadSecuredServer
// PayloadInternalServer 内部服务
PayloadInternalServer
) )
func (t PayloadType) ToStr() string { func (t PayloadType) ToStr() string {
@@ -80,3 +84,16 @@ func PayloadTypeFromStr(name string) PayloadType {
return PayloadNone return PayloadNone
} }
} }
func PayloadTypeFromClientSpec(spec client2.Spec) PayloadType {
var clientType PayloadType
switch spec {
case client2.SpecNative, client2.SpecBrowser:
clientType = PayloadPublicServer
case client2.SpecWeb:
clientType = PayloadSecuredServer
case client2.SpecTrusted:
clientType = PayloadInternalServer
}
return clientType
}

View File

@@ -5,11 +5,21 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/google/uuid"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"platform/pkg/env"
g "platform/web/globals" g "platform/web/globals"
"time"
) )
func find(ctx context.Context, token string) (*Context, error) { type Session struct {
// 认证主体
Payload *Payload
// 令牌信息
TokenDetails *TokenDetails
}
func FindSession(ctx context.Context, token string) (*Context, error) {
// 读取认证数据 // 读取认证数据
authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result() authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
@@ -29,6 +39,170 @@ func find(ctx context.Context, token string) (*Context, error) {
return auth, nil return auth, nil
} }
func CreateSession(ctx context.Context, authCtx *Context, remember bool) (*TokenDetails, error) {
var now = time.Now()
// 生成令牌组
accessToken := genToken()
refreshToken := genToken()
// 序列化认证数据
authData, err := json.Marshal(authCtx)
if err != nil {
return nil, err
}
// 序列化刷新令牌数据
refreshData, err := json.Marshal(RefreshData{
AuthContext: authCtx,
AccessToken: accessToken,
})
if err != nil {
return nil, err
}
// 事务保存数据到 Redis
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
pipe := g.Redis.TxPipeline()
pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
if remember {
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
}
_, err = pipe.Exec(ctx)
if err != nil {
return nil, err
}
return &TokenDetails{
AccessToken: accessToken,
AccessTokenExpires: now.Add(accessExpire),
RefreshToken: refreshToken,
RefreshTokenExpires: now.Add(refreshExpire),
Auth: authCtx,
}, nil
}
func RefreshSession(ctx context.Context, refreshToken string, renew bool) (*TokenDetails, error) {
var now = time.Now()
rKey := refreshKey(refreshToken)
var tokenDetails *TokenDetails
// 刷新令牌
err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
// 先获取刷新令牌数据
refreshJson, err := tx.Get(ctx, rKey).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return ErrInvalidRefreshToken
}
return err
}
// 解析刷新令牌数据
refreshData := new(RefreshData)
if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil {
return err
}
// 生成新的令牌
newAccessToken := genToken()
newRefreshToken := genToken()
authData, err := json.Marshal(refreshData.AuthContext)
if err != nil {
return err
}
newRefreshData, err := json.Marshal(RefreshData{
AuthContext: refreshData.AuthContext,
AccessToken: newAccessToken,
})
if err != nil {
return err
}
pipeline := tx.Pipeline()
// 保存新的令牌
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
pipeline.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
// 删除旧的令牌
pipeline.Del(ctx, accessKey(refreshData.AccessToken))
pipeline.Del(ctx, refreshKey(refreshToken))
_, err = pipeline.Exec(ctx)
if err != nil {
return err
}
tokenDetails = &TokenDetails{
AccessToken: newAccessToken,
RefreshToken: newRefreshToken,
AccessTokenExpires: now.Add(accessExpire),
RefreshTokenExpires: now.Add(refreshExpire),
Auth: refreshData.AuthContext,
}
return nil
}, rKey)
if err != nil {
return nil, fmt.Errorf("刷新令牌失败: %w", err)
}
return tokenDetails, nil
}
func RemoveSession(ctx context.Context, accessToken string, refreshToken string) error {
g.Redis.Del(ctx, accessKey(accessToken), refreshKey(refreshToken))
return nil
}
// 生成一个新的令牌
func genToken() string {
return uuid.NewString()
}
// 令牌键的格式为 "session:<token>"
func accessKey(token string) string { func accessKey(token string) string {
return fmt.Sprintf("session:%s", token) return fmt.Sprintf("session:%s", token)
} }
// 刷新令牌键的格式为 "session:refreshKey:<token>"
func refreshKey(token string) string {
return fmt.Sprintf("session:refresh:%s", token)
}
// TokenDetails 存储令牌详细信息
type TokenDetails struct {
// 访问令牌
AccessToken string
// 刷新令牌
RefreshToken string
// 访问令牌过期时间
AccessTokenExpires time.Time
// 刷新令牌过期时间
RefreshTokenExpires time.Time
// 认证信息
Auth *Context
}
type RefreshData struct {
AuthContext *Context
AccessToken string
}
type SessionErr string
func (e SessionErr) Error() string {
return string(e)
}
const (
ErrInvalidRefreshToken = SessionErr("无效的刷新令牌")
)

View File

@@ -110,7 +110,7 @@ func Token(c *fiber.Ctx) error {
scope := strings.Split(req.Scope, ",") scope := strings.Split(req.Scope, ",")
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope) token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
if err != nil { if err != nil {
if errors.Is(err, s.ErrInvalidToken) { if errors.Is(err, auth2.ErrInvalidRefreshToken) {
return sendError(c, s.ErrOauthInvalidGrant) return sendError(c, s.ErrOauthInvalidGrant)
} }
return sendError(c, err) return sendError(c, err)
@@ -226,7 +226,7 @@ func protect(c *fiber.Ctx, grant auth2.GrantType, clientId, clientSecret string)
} }
// 发送成功响应 // 发送成功响应
func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error { func sendSuccess(c *fiber.Ctx, details *auth2.TokenDetails) error {
return c.JSON(TokenResp{ return c.JSON(TokenResp{
AccessToken: details.AccessToken, AccessToken: details.AccessToken,
TokenType: "Bearer", TokenType: "Bearer",
@@ -292,7 +292,7 @@ func Revoke(c *fiber.Ctx) error {
} }
// 删除会话 // 删除会话
err = s.Session.Remove(c.Context(), req.AccessToken, req.RefreshToken) err = auth2.RemoveSession(c.Context(), req.AccessToken, req.RefreshToken)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -198,10 +198,7 @@ type RemoveChannelsReq struct {
func RemoveChannels(c *fiber.Ctx) error { func RemoveChannels(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authCtx, err := auth.Protect(c, []auth.PayloadType{ authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser, auth.PayloadInternalServer}, []string{})
auth.PayloadUser,
auth.PayloadSecuredServer,
}, []string{})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -17,9 +17,7 @@ type VerifierReq struct {
func SmsCode(c *fiber.Ctx) error { func SmsCode(c *fiber.Ctx) error {
_, err := auth.Protect(c, []auth.PayloadType{ _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadSecuredServer}, []string{})
auth.PayloadSecuredServer,
}, []string{})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -4,9 +4,7 @@
package models package models
import ( import "platform/web/globals/orm"
"platform/web/globals/orm"
)
const TableNameLogsRequest = "logs_request" const TableNameLogsRequest = "logs_request"

View File

@@ -4,9 +4,7 @@
package models package models
import ( import "platform/web/globals/orm"
"platform/web/globals/orm"
)
const TableNameResourcePsr = "resource_psr" const TableNameResourcePsr = "resource_psr"

View File

@@ -4,9 +4,7 @@
package models package models
import ( import "platform/web/globals/orm"
"platform/web/globals/orm"
)
const TableNameResourcePss = "resource_pss" const TableNameResourcePss = "resource_pss"

36
web/models/session.gen.go Normal file
View File

@@ -0,0 +1,36 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package models
import (
"platform/web/globals/orm"
"gorm.io/gorm"
)
const TableNameSession = "session"
// Session mapped from table <session>
type Session struct {
ID int32 `gorm:"column:id;primaryKey;autoIncrement:true;comment:会话ID" json:"id"` // 会话ID
UserID int32 `gorm:"column:user_id;comment:用户ID" json:"user_id"` // 用户ID
ClientID int32 `gorm:"column:client_id;comment:客户端ID" json:"client_id"` // 客户端ID
IP string `gorm:"column:ip;comment:IP地址" json:"ip"` // IP地址
Ua string `gorm:"column:ua;comment:用户代理" json:"ua"` // 用户代理
GrantType string `gorm:"column:grant_type;not null;default:0;comment:授权类型authorization_code-授权码模式client_credentials-客户端凭证模式refresh_token-刷新令牌模式password-密码模式" json:"grant_type"` // 授权类型authorization_code-授权码模式client_credentials-客户端凭证模式refresh_token-刷新令牌模式password-密码模式
AccessToken string `gorm:"column:access_token;not null;comment:访问令牌" json:"access_token"` // 访问令牌
AccessTokenExpires orm.LocalDateTime `gorm:"column:access_token_expires;not null;comment:访问令牌过期时间" json:"access_token_expires"` // 访问令牌过期时间
RefreshToken string `gorm:"column:refresh_token;comment:刷新令牌" json:"refresh_token"` // 刷新令牌
RefreshTokenExpires orm.LocalDateTime `gorm:"column:refresh_token_expires;comment:刷新令牌过期时间" json:"refresh_token_expires"` // 刷新令牌过期时间
Scopes_ string `gorm:"column:scopes;comment:权限范围" json:"scopes"` // 权限范围
CreatedAt orm.LocalDateTime `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt orm.LocalDateTime `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间
}
// TableName Session's table name
func (*Session) TableName() string {
return TableNameSession
}

View File

@@ -37,6 +37,7 @@ var (
ResourcePps *resourcePps ResourcePps *resourcePps
ResourcePsr *resourcePsr ResourcePsr *resourcePsr
ResourcePss *resourcePss ResourcePss *resourcePss
Session *session
Trade *trade Trade *trade
User *user User *user
UserRole *userRole UserRole *userRole
@@ -67,6 +68,7 @@ func SetDefault(db *gorm.DB, opts ...gen.DOOption) {
ResourcePps = &Q.ResourcePps ResourcePps = &Q.ResourcePps
ResourcePsr = &Q.ResourcePsr ResourcePsr = &Q.ResourcePsr
ResourcePss = &Q.ResourcePss ResourcePss = &Q.ResourcePss
Session = &Q.Session
Trade = &Q.Trade Trade = &Q.Trade
User = &Q.User User = &Q.User
UserRole = &Q.UserRole UserRole = &Q.UserRole
@@ -98,6 +100,7 @@ func Use(db *gorm.DB, opts ...gen.DOOption) *Query {
ResourcePps: newResourcePps(db, opts...), ResourcePps: newResourcePps(db, opts...),
ResourcePsr: newResourcePsr(db, opts...), ResourcePsr: newResourcePsr(db, opts...),
ResourcePss: newResourcePss(db, opts...), ResourcePss: newResourcePss(db, opts...),
Session: newSession(db, opts...),
Trade: newTrade(db, opts...), Trade: newTrade(db, opts...),
User: newUser(db, opts...), User: newUser(db, opts...),
UserRole: newUserRole(db, opts...), UserRole: newUserRole(db, opts...),
@@ -130,6 +133,7 @@ type Query struct {
ResourcePps resourcePps ResourcePps resourcePps
ResourcePsr resourcePsr ResourcePsr resourcePsr
ResourcePss resourcePss ResourcePss resourcePss
Session session
Trade trade Trade trade
User user User user
UserRole userRole UserRole userRole
@@ -163,6 +167,7 @@ func (q *Query) clone(db *gorm.DB) *Query {
ResourcePps: q.ResourcePps.clone(db), ResourcePps: q.ResourcePps.clone(db),
ResourcePsr: q.ResourcePsr.clone(db), ResourcePsr: q.ResourcePsr.clone(db),
ResourcePss: q.ResourcePss.clone(db), ResourcePss: q.ResourcePss.clone(db),
Session: q.Session.clone(db),
Trade: q.Trade.clone(db), Trade: q.Trade.clone(db),
User: q.User.clone(db), User: q.User.clone(db),
UserRole: q.UserRole.clone(db), UserRole: q.UserRole.clone(db),
@@ -203,6 +208,7 @@ func (q *Query) ReplaceDB(db *gorm.DB) *Query {
ResourcePps: q.ResourcePps.replaceDB(db), ResourcePps: q.ResourcePps.replaceDB(db),
ResourcePsr: q.ResourcePsr.replaceDB(db), ResourcePsr: q.ResourcePsr.replaceDB(db),
ResourcePss: q.ResourcePss.replaceDB(db), ResourcePss: q.ResourcePss.replaceDB(db),
Session: q.Session.replaceDB(db),
Trade: q.Trade.replaceDB(db), Trade: q.Trade.replaceDB(db),
User: q.User.replaceDB(db), User: q.User.replaceDB(db),
UserRole: q.UserRole.replaceDB(db), UserRole: q.UserRole.replaceDB(db),
@@ -233,6 +239,7 @@ type queryCtx struct {
ResourcePps *resourcePpsDo ResourcePps *resourcePpsDo
ResourcePsr *resourcePsrDo ResourcePsr *resourcePsrDo
ResourcePss *resourcePssDo ResourcePss *resourcePssDo
Session *sessionDo
Trade *tradeDo Trade *tradeDo
User *userDo User *userDo
UserRole *userRoleDo UserRole *userRoleDo
@@ -263,6 +270,7 @@ func (q *Query) WithContext(ctx context.Context) *queryCtx {
ResourcePps: q.ResourcePps.WithContext(ctx), ResourcePps: q.ResourcePps.WithContext(ctx),
ResourcePsr: q.ResourcePsr.WithContext(ctx), ResourcePsr: q.ResourcePsr.WithContext(ctx),
ResourcePss: q.ResourcePss.WithContext(ctx), ResourcePss: q.ResourcePss.WithContext(ctx),
Session: q.Session.WithContext(ctx),
Trade: q.Trade.WithContext(ctx), Trade: q.Trade.WithContext(ctx),
User: q.User.WithContext(ctx), User: q.User.WithContext(ctx),
UserRole: q.UserRole.WithContext(ctx), UserRole: q.UserRole.WithContext(ctx),

371
web/queries/session.gen.go Normal file
View File

@@ -0,0 +1,371 @@
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package queries
import (
"context"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gen"
"gorm.io/gen/field"
"gorm.io/plugin/dbresolver"
"platform/web/models"
)
func newSession(db *gorm.DB, opts ...gen.DOOption) session {
_session := session{}
_session.sessionDo.UseDB(db, opts...)
_session.sessionDo.UseModel(&models.Session{})
tableName := _session.sessionDo.TableName()
_session.ALL = field.NewAsterisk(tableName)
_session.ID = field.NewInt32(tableName, "id")
_session.UserID = field.NewInt32(tableName, "user_id")
_session.ClientID = field.NewInt32(tableName, "client_id")
_session.IP = field.NewString(tableName, "ip")
_session.Ua = field.NewString(tableName, "ua")
_session.GrantType = field.NewString(tableName, "grant_type")
_session.AccessToken = field.NewString(tableName, "access_token")
_session.AccessTokenExpires = field.NewField(tableName, "access_token_expires")
_session.RefreshToken = field.NewString(tableName, "refresh_token")
_session.RefreshTokenExpires = field.NewField(tableName, "refresh_token_expires")
_session.Scopes_ = field.NewString(tableName, "scopes")
_session.CreatedAt = field.NewField(tableName, "created_at")
_session.UpdatedAt = field.NewField(tableName, "updated_at")
_session.DeletedAt = field.NewField(tableName, "deleted_at")
_session.fillFieldMap()
return _session
}
type session struct {
sessionDo
ALL field.Asterisk
ID field.Int32 // 会话ID
UserID field.Int32 // 用户ID
ClientID field.Int32 // 客户端ID
IP field.String // IP地址
Ua field.String // 用户代理
GrantType field.String // 授权类型authorization_code-授权码模式client_credentials-客户端凭证模式refresh_token-刷新令牌模式password-密码模式
AccessToken field.String // 访问令牌
AccessTokenExpires field.Field // 访问令牌过期时间
RefreshToken field.String // 刷新令牌
RefreshTokenExpires field.Field // 刷新令牌过期时间
Scopes_ field.String // 权限范围
CreatedAt field.Field // 创建时间
UpdatedAt field.Field // 更新时间
DeletedAt field.Field // 删除时间
fieldMap map[string]field.Expr
}
func (s session) Table(newTableName string) *session {
s.sessionDo.UseTable(newTableName)
return s.updateTableName(newTableName)
}
func (s session) As(alias string) *session {
s.sessionDo.DO = *(s.sessionDo.As(alias).(*gen.DO))
return s.updateTableName(alias)
}
func (s *session) updateTableName(table string) *session {
s.ALL = field.NewAsterisk(table)
s.ID = field.NewInt32(table, "id")
s.UserID = field.NewInt32(table, "user_id")
s.ClientID = field.NewInt32(table, "client_id")
s.IP = field.NewString(table, "ip")
s.Ua = field.NewString(table, "ua")
s.GrantType = field.NewString(table, "grant_type")
s.AccessToken = field.NewString(table, "access_token")
s.AccessTokenExpires = field.NewField(table, "access_token_expires")
s.RefreshToken = field.NewString(table, "refresh_token")
s.RefreshTokenExpires = field.NewField(table, "refresh_token_expires")
s.Scopes_ = field.NewString(table, "scopes")
s.CreatedAt = field.NewField(table, "created_at")
s.UpdatedAt = field.NewField(table, "updated_at")
s.DeletedAt = field.NewField(table, "deleted_at")
s.fillFieldMap()
return s
}
func (s *session) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
_f, ok := s.fieldMap[fieldName]
if !ok || _f == nil {
return nil, false
}
_oe, ok := _f.(field.OrderExpr)
return _oe, ok
}
func (s *session) fillFieldMap() {
s.fieldMap = make(map[string]field.Expr, 14)
s.fieldMap["id"] = s.ID
s.fieldMap["user_id"] = s.UserID
s.fieldMap["client_id"] = s.ClientID
s.fieldMap["ip"] = s.IP
s.fieldMap["ua"] = s.Ua
s.fieldMap["grant_type"] = s.GrantType
s.fieldMap["access_token"] = s.AccessToken
s.fieldMap["access_token_expires"] = s.AccessTokenExpires
s.fieldMap["refresh_token"] = s.RefreshToken
s.fieldMap["refresh_token_expires"] = s.RefreshTokenExpires
s.fieldMap["scopes"] = s.Scopes_
s.fieldMap["created_at"] = s.CreatedAt
s.fieldMap["updated_at"] = s.UpdatedAt
s.fieldMap["deleted_at"] = s.DeletedAt
}
func (s session) clone(db *gorm.DB) session {
s.sessionDo.ReplaceConnPool(db.Statement.ConnPool)
return s
}
func (s session) replaceDB(db *gorm.DB) session {
s.sessionDo.ReplaceDB(db)
return s
}
type sessionDo struct{ gen.DO }
func (s sessionDo) Debug() *sessionDo {
return s.withDO(s.DO.Debug())
}
func (s sessionDo) WithContext(ctx context.Context) *sessionDo {
return s.withDO(s.DO.WithContext(ctx))
}
func (s sessionDo) ReadDB() *sessionDo {
return s.Clauses(dbresolver.Read)
}
func (s sessionDo) WriteDB() *sessionDo {
return s.Clauses(dbresolver.Write)
}
func (s sessionDo) Session(config *gorm.Session) *sessionDo {
return s.withDO(s.DO.Session(config))
}
func (s sessionDo) Clauses(conds ...clause.Expression) *sessionDo {
return s.withDO(s.DO.Clauses(conds...))
}
func (s sessionDo) Returning(value interface{}, columns ...string) *sessionDo {
return s.withDO(s.DO.Returning(value, columns...))
}
func (s sessionDo) Not(conds ...gen.Condition) *sessionDo {
return s.withDO(s.DO.Not(conds...))
}
func (s sessionDo) Or(conds ...gen.Condition) *sessionDo {
return s.withDO(s.DO.Or(conds...))
}
func (s sessionDo) Select(conds ...field.Expr) *sessionDo {
return s.withDO(s.DO.Select(conds...))
}
func (s sessionDo) Where(conds ...gen.Condition) *sessionDo {
return s.withDO(s.DO.Where(conds...))
}
func (s sessionDo) Order(conds ...field.Expr) *sessionDo {
return s.withDO(s.DO.Order(conds...))
}
func (s sessionDo) Distinct(cols ...field.Expr) *sessionDo {
return s.withDO(s.DO.Distinct(cols...))
}
func (s sessionDo) Omit(cols ...field.Expr) *sessionDo {
return s.withDO(s.DO.Omit(cols...))
}
func (s sessionDo) Join(table schema.Tabler, on ...field.Expr) *sessionDo {
return s.withDO(s.DO.Join(table, on...))
}
func (s sessionDo) LeftJoin(table schema.Tabler, on ...field.Expr) *sessionDo {
return s.withDO(s.DO.LeftJoin(table, on...))
}
func (s sessionDo) RightJoin(table schema.Tabler, on ...field.Expr) *sessionDo {
return s.withDO(s.DO.RightJoin(table, on...))
}
func (s sessionDo) Group(cols ...field.Expr) *sessionDo {
return s.withDO(s.DO.Group(cols...))
}
func (s sessionDo) Having(conds ...gen.Condition) *sessionDo {
return s.withDO(s.DO.Having(conds...))
}
func (s sessionDo) Limit(limit int) *sessionDo {
return s.withDO(s.DO.Limit(limit))
}
func (s sessionDo) Offset(offset int) *sessionDo {
return s.withDO(s.DO.Offset(offset))
}
func (s sessionDo) Scopes(funcs ...func(gen.Dao) gen.Dao) *sessionDo {
return s.withDO(s.DO.Scopes(funcs...))
}
func (s sessionDo) Unscoped() *sessionDo {
return s.withDO(s.DO.Unscoped())
}
func (s sessionDo) Create(values ...*models.Session) error {
if len(values) == 0 {
return nil
}
return s.DO.Create(values)
}
func (s sessionDo) CreateInBatches(values []*models.Session, batchSize int) error {
return s.DO.CreateInBatches(values, batchSize)
}
// Save : !!! underlying implementation is different with GORM
// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values)
func (s sessionDo) Save(values ...*models.Session) error {
if len(values) == 0 {
return nil
}
return s.DO.Save(values)
}
func (s sessionDo) First() (*models.Session, error) {
if result, err := s.DO.First(); err != nil {
return nil, err
} else {
return result.(*models.Session), nil
}
}
func (s sessionDo) Take() (*models.Session, error) {
if result, err := s.DO.Take(); err != nil {
return nil, err
} else {
return result.(*models.Session), nil
}
}
func (s sessionDo) Last() (*models.Session, error) {
if result, err := s.DO.Last(); err != nil {
return nil, err
} else {
return result.(*models.Session), nil
}
}
func (s sessionDo) Find() ([]*models.Session, error) {
result, err := s.DO.Find()
return result.([]*models.Session), err
}
func (s sessionDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*models.Session, err error) {
buf := make([]*models.Session, 0, batchSize)
err = s.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error {
defer func() { results = append(results, buf...) }()
return fc(tx, batch)
})
return results, err
}
func (s sessionDo) FindInBatches(result *[]*models.Session, batchSize int, fc func(tx gen.Dao, batch int) error) error {
return s.DO.FindInBatches(result, batchSize, fc)
}
func (s sessionDo) Attrs(attrs ...field.AssignExpr) *sessionDo {
return s.withDO(s.DO.Attrs(attrs...))
}
func (s sessionDo) Assign(attrs ...field.AssignExpr) *sessionDo {
return s.withDO(s.DO.Assign(attrs...))
}
func (s sessionDo) Joins(fields ...field.RelationField) *sessionDo {
for _, _f := range fields {
s = *s.withDO(s.DO.Joins(_f))
}
return &s
}
func (s sessionDo) Preload(fields ...field.RelationField) *sessionDo {
for _, _f := range fields {
s = *s.withDO(s.DO.Preload(_f))
}
return &s
}
func (s sessionDo) FirstOrInit() (*models.Session, error) {
if result, err := s.DO.FirstOrInit(); err != nil {
return nil, err
} else {
return result.(*models.Session), nil
}
}
func (s sessionDo) FirstOrCreate() (*models.Session, error) {
if result, err := s.DO.FirstOrCreate(); err != nil {
return nil, err
} else {
return result.(*models.Session), nil
}
}
func (s sessionDo) FindByPage(offset int, limit int) (result []*models.Session, count int64, err error) {
result, err = s.Offset(offset).Limit(limit).Find()
if err != nil {
return
}
if size := len(result); 0 < limit && 0 < size && size < limit {
count = int64(size + offset)
return
}
count, err = s.Offset(-1).Limit(-1).Count()
return
}
func (s sessionDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) {
count, err = s.Count()
if err != nil {
return
}
err = s.Offset(offset).Limit(limit).Scan(result)
return
}
func (s sessionDo) Scan(result interface{}) (err error) {
return s.DO.Scan(result)
}
func (s sessionDo) Delete(models ...*models.Session) (result gen.ResultInfo, err error) {
return s.DO.Delete(models)
}
func (s *sessionDo) withDO(do gen.Dao) *sessionDo {
s.DO = *do.(*gen.DO)
return s
}

View File

@@ -18,21 +18,15 @@ var Auth = &authService{}
type authService struct{} type authService struct{}
// OauthAuthorizationCode 验证授权码 // OauthAuthorizationCode 验证授权码
func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Client, code, redirectURI, codeVerifier string) (*TokenDetails, error) { func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Client, code, redirectURI, codeVerifier string) (*auth2.TokenDetails, error) {
// TODO: 从数据库验证授权码 // TODO: 从数据库验证授权码
return nil, errors.New("TODO") return nil, errors.New("TODO")
} }
// OauthClientCredentials 验证客户端凭证 // OauthClientCredentials 验证客户端凭证
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) { func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*auth2.TokenDetails, error) {
var clientType auth2.PayloadType var clientType = auth2.PayloadTypeFromClientSpec(client2.Spec(client.Spec))
switch client2.Spec(client.Spec) {
case client2.SpecNative, client2.SpecBrowser:
clientType = auth2.PayloadPublicServer
case client2.SpecWeb, client2.SpecTrusted:
clientType = auth2.PayloadSecuredServer
}
var permissions = make(map[string]struct{}, len(scope)) var permissions = make(map[string]struct{}, len(scope))
for _, item := range scope { for _, item := range scope {
@@ -50,7 +44,7 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
} }
// todo 数据库定义会话持续时间 // todo 数据库定义会话持续时间
token, err := Session.Create(ctx, authCtx, false) token, err := auth2.CreateSession(ctx, &authCtx, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -59,9 +53,9 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
} }
// OauthRefreshToken 验证刷新令牌 // OauthRefreshToken 验证刷新令牌
func (s *authService) OauthRefreshToken(ctx context.Context, _ *m.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) { func (s *authService) OauthRefreshToken(ctx context.Context, _ *m.Client, refreshToken string, scope ...[]string) (*auth2.TokenDetails, error) {
// TODO: 从数据库验证刷新令牌 // TODO: 从数据库验证刷新令牌
details, err := Session.Refresh(ctx, refreshToken) details, err := auth2.RefreshSession(ctx, refreshToken, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -70,7 +64,7 @@ func (s *authService) OauthRefreshToken(ctx context.Context, _ *m.Client, refres
} }
// OauthPassword 验证密码 // OauthPassword 验证密码
func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *GrantPasswordData, ip, agent string) (*TokenDetails, error) { func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *GrantPasswordData, ip, agent string) (*auth2.TokenDetails, error) {
var user *m.User var user *m.User
err := q.Q.Transaction(func(tx *q.Query) error { err := q.Q.Transaction(func(tx *q.Query) error {
@@ -145,7 +139,7 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
}, },
} }
token, err := Session.Create(ctx, authCtx, data.Remember) token, err := auth2.CreateSession(ctx, &authCtx, data.Remember)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -179,7 +173,7 @@ func (e AuthServiceError) Error() string {
return string(e) return string(e)
} }
var ( const (
ErrOauthInvalidRequest = AuthServiceError("invalid_request") ErrOauthInvalidRequest = AuthServiceError("invalid_request")
ErrOauthInvalidClient = AuthServiceError("invalid_client") ErrOauthInvalidClient = AuthServiceError("invalid_client")
ErrOauthInvalidGrant = AuthServiceError("invalid_grant") ErrOauthInvalidGrant = AuthServiceError("invalid_grant")

View File

@@ -443,7 +443,7 @@ func calcChannels(
} }
if env.DebugExternalChange && next > count { if env.DebugExternalChange && next > count {
step = time.Now() var step = time.Now()
var multiple float64 = 2 // 扩张倍数 var multiple float64 = 2 // 扩张倍数
var newConfig = g.AutoConfig{ var newConfig = g.AutoConfig{
@@ -550,7 +550,7 @@ func calcChannels(
// 提交端口配置并更新节点列表 // 提交端口配置并更新节点列表
if env.DebugExternalChange { if env.DebugExternalChange {
step = time.Now() var step = time.Now()
var secret = strings.Split(proxy.Secret, ":") var secret = strings.Split(proxy.Secret, ":")
gateway := g.NewGateway( gateway := g.NewGateway(

View File

@@ -1,226 +0,0 @@
package services
import (
"context"
"encoding/json"
"errors"
"fmt"
"platform/pkg/env"
"platform/web/auth"
g "platform/web/globals"
"time"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
)
// region SessionService
var Session SessionServiceInter = &sessionService{}
type SessionServiceInter interface {
// Find 通过访问令牌获取会话信息
Find(ctx context.Context, token string) (*auth.Context, error)
// Create 创建一个新的会话
Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error)
// Refresh 刷新一个会话
Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error)
// Remove 删除会话
Remove(ctx context.Context, accessToken, refreshToken string) error
}
type SessionServiceError string
func (e SessionServiceError) Error() string {
return string(e)
}
var (
ErrInvalidToken = SessionServiceError("invalid_token")
)
type sessionService struct{}
// Find 通过访问令牌获取会话信息
func (s *sessionService) Find(ctx context.Context, token string) (*auth.Context, error) {
// 读取认证数据
authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, ErrInvalidToken
}
return nil, err
}
// 反序列化
authCtx := new(auth.Context)
if err := json.Unmarshal([]byte(authJSON), authCtx); err != nil {
return nil, err
}
return authCtx, nil
}
// Create 创建一个新的会话
func (s *sessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) {
var now = time.Now()
// 生成令牌组
accessToken := genToken()
refreshToken := genToken()
// 序列化认证数据
authData, err := json.Marshal(authCtx)
if err != nil {
return nil, err
}
// 序列化刷新令牌数据
refreshData, err := json.Marshal(RefreshData{
AuthContext: authCtx,
AccessToken: accessToken,
})
if err != nil {
return nil, err
}
// 事务保存数据到 Redis
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
pipe := g.Redis.TxPipeline()
pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
if remember {
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
}
_, err = pipe.Exec(ctx)
if err != nil {
return nil, err
}
return &TokenDetails{
AccessToken: accessToken,
AccessTokenExpires: now.Add(accessExpire),
RefreshToken: refreshToken,
RefreshTokenExpires: now.Add(refreshExpire),
Auth: authCtx,
}, nil
}
// Refresh 刷新一个会话
func (s *sessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
var now = time.Now()
rKey := refreshKey(refreshToken)
var tokenDetails *TokenDetails
// 刷新令牌
err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
// 先获取刷新令牌数据
refreshJson, err := tx.Get(ctx, rKey).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return ErrInvalidToken
}
return err
}
// 解析刷新令牌数据
refreshData := new(RefreshData)
if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil {
return err
}
// 生成新的令牌
newAccessToken := genToken()
newRefreshToken := genToken()
authData, err := json.Marshal(refreshData.AuthContext)
if err != nil {
return err
}
newRefreshData, err := json.Marshal(RefreshData{
AuthContext: refreshData.AuthContext,
AccessToken: newAccessToken,
})
if err != nil {
return err
}
pipeline := tx.Pipeline()
// 保存新的令牌
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
pipeline.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
// 删除旧的令牌
pipeline.Del(ctx, accessKey(refreshData.AccessToken))
pipeline.Del(ctx, refreshKey(refreshToken))
_, err = pipeline.Exec(ctx)
if err != nil {
return err
}
tokenDetails = &TokenDetails{
AccessToken: newAccessToken,
RefreshToken: newRefreshToken,
AccessTokenExpires: now.Add(accessExpire),
RefreshTokenExpires: now.Add(refreshExpire),
Auth: refreshData.AuthContext,
}
return nil
}, rKey)
if err != nil {
return nil, fmt.Errorf("刷新令牌失败: %w", err)
}
return tokenDetails, nil
}
// Remove 删除会话
func (s *sessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
g.Redis.Del(ctx, accessKey(accessToken), refreshKey(refreshToken))
return nil
}
// 生成一个新的令牌
func genToken() string {
return uuid.NewString()
}
// 令牌键的格式为 "session:<token>"
func accessKey(token string) string {
return fmt.Sprintf("session:%s", token)
}
// 刷新令牌键的格式为 "session:refreshKey:<token>"
func refreshKey(token string) string {
return fmt.Sprintf("session:refresh:%s", token)
}
// endregion
type RefreshData struct {
AuthContext auth.Context
AccessToken string
}
// TokenDetails 存储令牌详细信息
type TokenDetails struct {
// 访问令牌
AccessToken string
// 刷新令牌
RefreshToken string
// 访问令牌过期时间
AccessTokenExpires time.Time
// 刷新令牌过期时间
RefreshTokenExpires time.Time
// 认证信息
Auth auth.Context
}