重构迁移核心数据结构到认证模块;完善中间件初始化逻辑以及 logger 记录过程

This commit is contained in:
2025-05-08 13:18:54 +08:00
parent c93d0bf467
commit e2cc318560
24 changed files with 353 additions and 215 deletions

View File

@@ -1,11 +1,9 @@
## todo ## todo
- 规范化日志
- 接口日志采集
- channel 接口 - channel 接口
- 每个用户-节点为一条数据,联查白名单
- 重新梳理逻辑流程,简化循环 - 重新梳理逻辑流程,简化循环
- 端口分配时加锁 - 端口分配时加锁
- 每个用户-节点为一条数据,联查白名单
- 长效业务接入 - 长效业务接入
- 微信支付 - 微信支付
- 页面 账户总览 - 页面 账户总览

View File

@@ -54,21 +54,23 @@ func main() {
q.Client.ClientSecret, q.Client.ClientSecret,
q.Client.GrantClient, q.Client.GrantClient,
q.Client.GrantRefresh, q.Client.GrantRefresh,
q.Client.GrantPassword,
q.Client.Spec, q.Client.Spec,
q.Client.Name, q.Client.Name,
). ).
Create(&m.Client{ Create(&m.Client{
ClientID: "test", ClientID: "test",
ClientSecret: string(testSecret), ClientSecret: string(testSecret),
GrantCode: true,
GrantClient: true, GrantClient: true,
GrantRefresh: true, GrantRefresh: true,
GrantPassword: true,
Spec: 3, Spec: 3,
Name: "默认客户端", Name: "默认客户端",
}, &m.Client{ }, &m.Client{
ClientID: "tasks", ClientID: "tasks",
ClientSecret: string(tasksSecret), ClientSecret: string(tasksSecret),
GrantClient: true, GrantClient: true,
GrantRefresh: true,
Spec: 3, Spec: 3,
Name: "异步任务处理服务", Name: "异步任务处理服务",
}) })

View File

@@ -34,6 +34,7 @@ create table logs_request (
method varchar(10) not null, method varchar(10) not null,
path varchar(255) not null, path varchar(255) not null,
latency varchar(255),
status int not null, status int not null,
error varchar(255), error varchar(255),
@@ -49,6 +50,7 @@ comment on column logs_request.ip is 'IP地址';
comment on column logs_request.ua is '用户代理'; comment on column logs_request.ua is '用户代理';
comment on column logs_request.method is '请求方法'; comment on column logs_request.method is '请求方法';
comment on column logs_request.path is '请求路径'; comment on column logs_request.path is '请求路径';
comment on column logs_request.latency is '请求延迟';
comment on column logs_request.status is '响应状态码'; comment on column logs_request.status is '响应状态码';
comment on column logs_request.error is '错误信息'; comment on column logs_request.error is '错误信息';
comment on column logs_request.time is '请求时间'; comment on column logs_request.time is '请求时间';

View File

@@ -9,13 +9,11 @@ import (
"slices" "slices"
"strings" "strings"
"platform/web/services"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (*services.AuthContext, error) { func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context, error) {
// 获取令牌 // 获取令牌
var header = c.Get("Authorization") var header = c.Get("Authorization")
var split = strings.Split(header, " ") var split = strings.Split(header, " ")
@@ -28,7 +26,7 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
} }
var auth *services.AuthContext var auth *Context
var err error var err error
switch split[0] { switch split[0] {
@@ -36,23 +34,23 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (
auth, err = authBearer(c.Context(), token) auth, err = authBearer(c.Context(), token)
if err != nil { if err != nil {
slog.Debug("Bearer 认证失败", "err", err) slog.Debug("Bearer 认证失败", "err", err)
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
} }
case "Basic": case "Basic":
if !slices.Contains(types, services.PayloadClientConfidential) { if !slices.Contains(types, PayloadClientConfidential) {
slog.Debug("禁止使用 Basic 认证方式") slog.Debug("禁止使用 Basic 认证方式")
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
} }
auth, err = authBasic(c.Context(), token) auth, err = authBasic(c.Context(), token)
if err != nil { if err != nil {
slog.Debug("Basic 认证失败", "err", err) slog.Debug("Basic 认证失败", "err", err)
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
} }
default: default:
slog.Debug("无效的认证方式") slog.Debug("无效的认证方式")
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
} }
// 检查权限 // 检查权限
@@ -67,15 +65,18 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (
} }
// 保存到上下文 // 保存到上下文
c.Locals("auth", auth) Locals(c, auth)
c.Locals("authid", auth.Payload.Id)
c.Locals("authtype", auth.Payload.Type.Name())
return auth, nil return auth, nil
} }
func authBearer(ctx context.Context, token string) (*services.AuthContext, error) { func Locals(c *fiber.Ctx, auth *Context) {
auth, err := services.Session.Find(ctx, token) c.Locals("auth", auth)
c.Locals("authtype", auth.Payload.Type.ToStr())
c.Locals("authid", auth.Payload.Id)
}
func authBearer(ctx context.Context, token string) (*Context, error) {
auth, err := find(ctx, token)
if err != nil { if err != nil {
slog.Debug(err.Error()) slog.Debug(err.Error())
return nil, err return nil, err
@@ -83,7 +84,7 @@ func authBearer(ctx context.Context, token string) (*services.AuthContext, error
return auth, nil return auth, nil
} }
func authBasic(_ context.Context, token string) (*services.AuthContext, error) { func authBasic(_ context.Context, token string) (*Context, error) {
// 解析 Basic 认证信息 // 解析 Basic 认证信息
var base, err = base64.RawURLEncoding.DecodeString(token) var base, err = base64.RawURLEncoding.DecodeString(token)
@@ -122,10 +123,10 @@ func authBasic(_ context.Context, token string) (*services.AuthContext, error) {
// todo 查询客户端关联权限 // todo 查询客户端关联权限
// 组织授权信息(一次性请求) // 组织授权信息(一次性请求)
return &services.AuthContext{ return &Context{
Payload: services.Payload{ Payload: Payload{
Id: client.ID, Id: client.ID,
Type: services.PayloadClientConfidential, Type: PayloadClientConfidential,
Name: client.Name, Name: client.Name,
Avatar: client.Icon, Avatar: client.Icon,
}, },

78
web/auth/context.go Normal file
View File

@@ -0,0 +1,78 @@
package auth
import "platform/pkg/u"
// Context 定义认证信息
type Context struct {
Payload Payload `json:"payload"`
Agent Agent `json:"agent,omitempty"`
Permissions map[string]struct{} `json:"permissions,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// AnyPermission 检查认证是否包含指定权限
func (a *Context) AnyPermission(requiredPermission ...string) bool {
if a == nil || a.Permissions == nil {
return false
}
for _, permission := range requiredPermission {
if _, ok := a.Permissions[permission]; ok {
return true
}
}
return false
}
// Payload 定义负载信息
type Payload struct {
Id int32 `json:"id,omitempty"`
Type PayloadType `json:"type,omitempty"`
Name string `json:"name,omitempty"`
Avatar string `json:"avatar,omitempty"`
}
type Agent struct {
Id int32 `json:"id,omitempty"`
Addr string `json:"addr,omitempty"`
}
type PayloadType int
const (
// PayloadUser 用户类型
PayloadUser PayloadType = iota + 1
// PayloadAdmin 管理员类型
PayloadAdmin
// PayloadClientPublic 公共客户端类型
PayloadClientPublic
// PayloadClientConfidential 机密客户端类型
PayloadClientConfidential
)
func (t PayloadType) ToStr() string {
switch t {
case PayloadUser:
return "user"
case PayloadAdmin:
return "admn"
case PayloadClientPublic:
return "cpub"
case PayloadClientConfidential:
return "ccnf"
}
return "none"
}
func PayloadTypeFromStr(name string) *PayloadType {
switch name {
case "user":
return u.P(PayloadUser)
case "admn":
return u.P(PayloadAdmin)
case "cpub":
return u.P(PayloadClientPublic)
case "ccnf":
return u.P(PayloadClientConfidential)
}
return nil
}

34
web/auth/session.go Normal file
View File

@@ -0,0 +1,34 @@
package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/redis/go-redis/v9"
"platform/pkg/rds"
)
func find(ctx context.Context, token string) (*Context, error) {
// 读取认证数据
authJSON, err := rds.Client.Get(ctx, accessKey(token)).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, errors.New("invalid_token")
}
return nil, err
}
// 反序列化
auth := new(Context)
if err := json.Unmarshal([]byte(authJSON), auth); err != nil {
return nil, err
}
return auth, nil
}
func accessKey(token string) string {
return fmt.Sprintf("session:%s", token)
}

View File

@@ -1,14 +1,14 @@
package core package core
type AuthUnAuthorizedErr string type UnAuthorizedErr string
func (e AuthUnAuthorizedErr) Error() string { func (e UnAuthorizedErr) Error() string {
return string(e) return string(e)
} }
type AuthForbiddenErr string type ForbiddenErr string
func (e AuthForbiddenErr) Error() string { func (e ForbiddenErr) Error() string {
return string(e) return string(e)
} }

View File

@@ -5,7 +5,6 @@ import (
"platform/web/auth" "platform/web/auth"
"platform/web/core" "platform/web/core"
q "platform/web/queries" q "platform/web/queries"
"platform/web/services"
) )
// region ListAnnouncements // region ListAnnouncements
@@ -17,7 +16,7 @@ type ListAnnouncementsRequest struct {
func ListAnnouncements(c *fiber.Ctx) error { func ListAnnouncements(c *fiber.Ctx) error {
// 检查权限 // 检查权限
_, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -136,7 +136,7 @@ func Token(c *fiber.Ctx) error {
token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent")) token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent"))
if err != nil { if err != nil {
return err return sendError(c, err)
} }
return sendSuccess(c, token) return sendSuccess(c, token)
@@ -211,6 +211,16 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string
} }
} }
// 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑)
auth.Locals(c, &auth.Context{
Payload: auth.Payload{
Id: client.ID,
Type: auth.PayloadClientConfidential,
Name: client.Name,
Avatar: client.Icon,
},
})
return client, nil return client, nil
} }
@@ -268,7 +278,7 @@ type RevokeReq struct {
} }
func Revoke(c *fiber.Ctx) error { func Revoke(c *fiber.Ctx) error {
_, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
// 用户未登录 // 用户未登录
return nil return nil
@@ -299,7 +309,7 @@ type IntrospectResp struct {
func Introspect(c *fiber.Ctx) error { func Introspect(c *fiber.Ctx) error {
// 验证权限 // 验证权限
authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -4,7 +4,6 @@ import (
"platform/web/auth" "platform/web/auth"
"platform/web/core" "platform/web/core"
q "platform/web/queries" q "platform/web/queries"
"platform/web/services"
"time" "time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@@ -23,7 +22,7 @@ type ListBillReq struct {
// ListBill 获取账单列表 // ListBill 获取账单列表
func ListBill(c *fiber.Ctx) error { func ListBill(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -22,7 +22,7 @@ type ListChannelsReq struct {
func ListChannels(c *fiber.Ctx) error { func ListChannels(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -99,7 +99,7 @@ type CreateChannelReq struct {
func CreateChannel(c *fiber.Ctx) error { func CreateChannel(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -172,9 +172,9 @@ type RemoveChannelsReq struct {
func RemoveChannels(c *fiber.Ctx) error { func RemoveChannels(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authCtx, err := auth.Protect(c, []s.PayloadType{ authCtx, err := auth.Protect(c, []auth.PayloadType{
s.PayloadUser, auth.PayloadUser,
s.PayloadClientConfidential, auth.PayloadClientConfidential,
}, []string{}) }, []string{})
if err != nil { if err != nil {
return err return err

View File

@@ -36,7 +36,7 @@ type IdentifyRes struct {
func Identify(c *fiber.Ctx) error { func Identify(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authCtx, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -27,7 +27,7 @@ type ListResourcePssReq struct {
// ListResourcePss 获取套餐列表 // ListResourcePss 获取套餐列表
func ListResourcePss(c *fiber.Ctx) error { func ListResourcePss(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -102,7 +102,7 @@ type AllResourceReq struct {
func AllResource(c *fiber.Ctx) error { func AllResource(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -158,7 +158,7 @@ type PaidCreateResourceReq struct {
func PrepareResourceByAlipay(c *fiber.Ctx) error { func PrepareResourceByAlipay(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -190,7 +190,7 @@ func PrepareResourceByAlipay(c *fiber.Ctx) error {
func PrepareResourceByWechat(c *fiber.Ctx) error { func PrepareResourceByWechat(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -221,7 +221,7 @@ func PrepareResourceByWechat(c *fiber.Ctx) error {
func CreateResourceByAlipay(c *fiber.Ctx) error { func CreateResourceByAlipay(c *fiber.Ctx) error {
// 检查权限 // 检查权限
_, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -252,7 +252,7 @@ func CreateResourceByAlipay(c *fiber.Ctx) error {
func CreateResourceByWechat(c *fiber.Ctx) error { func CreateResourceByWechat(c *fiber.Ctx) error {
// 检查权限 // 检查权限
_, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -284,7 +284,7 @@ func CreateResourceByWechat(c *fiber.Ctx) error {
func CreateResourceByBalance(c *fiber.Ctx) error { func CreateResourceByBalance(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -24,7 +24,7 @@ type UpdateUserReq struct {
func UpdateUser(c *fiber.Ctx) error { func UpdateUser(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -63,7 +63,7 @@ type UpdateAccountReq struct {
func UpdateAccount(c *fiber.Ctx) error { func UpdateAccount(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -101,7 +101,7 @@ type UpdatePasswordReq struct {
func UpdatePassword(c *fiber.Ctx) error { func UpdatePassword(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -162,7 +162,7 @@ type RechargeConfirmResp struct {
// RechargePrepareAlipay 通过支付宝充值 // RechargePrepareAlipay 通过支付宝充值
func RechargePrepareAlipay(c *fiber.Ctx) error { func RechargePrepareAlipay(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -198,7 +198,7 @@ func RechargePrepareAlipay(c *fiber.Ctx) error {
func RechargeConfirmAlipay(c *fiber.Ctx) error { func RechargeConfirmAlipay(c *fiber.Ctx) error {
// 检查权限 // 检查权限
_, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -229,7 +229,7 @@ func RechargeConfirmAlipay(c *fiber.Ctx) error {
func RechargePrepareWechat(c *fiber.Ctx) error { func RechargePrepareWechat(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -265,7 +265,7 @@ func RechargePrepareWechat(c *fiber.Ctx) error {
func RechargeConfirmWechat(c *fiber.Ctx) error { func RechargeConfirmWechat(c *fiber.Ctx) error {
// 检查权限 // 检查权限
_, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -17,8 +17,8 @@ type VerifierReq struct {
func SmsCode(c *fiber.Ctx) error { func SmsCode(c *fiber.Ctx) error {
_, err := auth.Protect(c, []services.PayloadType{ _, err := auth.Protect(c, []auth.PayloadType{
services.PayloadClientConfidential, auth.PayloadClientConfidential,
}, []string{}) }, []string{})
if err != nil { if err != nil {
return err return err

View File

@@ -7,7 +7,6 @@ import (
g "platform/web/globals" g "platform/web/globals"
m "platform/web/models" m "platform/web/models"
q "platform/web/queries" q "platform/web/queries"
"platform/web/services"
"time" "time"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@@ -27,7 +26,7 @@ type ListWhitelistResp struct {
func ListWhitelist(c *fiber.Ctx) error { func ListWhitelist(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -78,7 +77,7 @@ type CreateWhitelistReq struct {
func CreateWhitelist(c *fiber.Ctx) error { func CreateWhitelist(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -112,7 +111,7 @@ type UpdateWhitelistReq struct {
func UpdateWhitelist(c *fiber.Ctx) error { func UpdateWhitelist(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }
@@ -150,7 +149,7 @@ type RemoveWhitelistReq struct {
func RemoveWhitelist(c *fiber.Ctx) error { func RemoveWhitelist(c *fiber.Ctx) error {
// 检查权限 // 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -17,6 +17,7 @@ type LogsRequest struct {
Ua string `gorm:"column:ua;comment:用户代理" json:"ua"` // 用户代理 Ua string `gorm:"column:ua;comment:用户代理" json:"ua"` // 用户代理
Method string `gorm:"column:method;not null;comment:请求方法" json:"method"` // 请求方法 Method string `gorm:"column:method;not null;comment:请求方法" json:"method"` // 请求方法
Path string `gorm:"column:path;not null;comment:请求路径" json:"path"` // 请求路径 Path string `gorm:"column:path;not null;comment:请求路径" json:"path"` // 请求路径
Latency string `gorm:"column:latency;comment:请求延迟" json:"latency"` // 请求延迟
Status int32 `gorm:"column:status;not null;comment:响应状态码" json:"status"` // 响应状态码 Status int32 `gorm:"column:status;not null;comment:响应状态码" json:"status"` // 响应状态码
Error string `gorm:"column:error;comment:错误信息" json:"error"` // 错误信息 Error string `gorm:"column:error;comment:错误信息" json:"error"` // 错误信息
Time core.LocalDateTime `gorm:"column:time;default:CURRENT_TIMESTAMP;comment:请求时间" json:"time"` // 请求时间 Time core.LocalDateTime `gorm:"column:time;default:CURRENT_TIMESTAMP;comment:请求时间" json:"time"` // 请求时间

View File

@@ -34,6 +34,7 @@ func newLogsRequest(db *gorm.DB, opts ...gen.DOOption) logsRequest {
_logsRequest.Ua = field.NewString(tableName, "ua") _logsRequest.Ua = field.NewString(tableName, "ua")
_logsRequest.Method = field.NewString(tableName, "method") _logsRequest.Method = field.NewString(tableName, "method")
_logsRequest.Path = field.NewString(tableName, "path") _logsRequest.Path = field.NewString(tableName, "path")
_logsRequest.Latency = field.NewString(tableName, "latency")
_logsRequest.Status = field.NewInt32(tableName, "status") _logsRequest.Status = field.NewInt32(tableName, "status")
_logsRequest.Error = field.NewString(tableName, "error") _logsRequest.Error = field.NewString(tableName, "error")
_logsRequest.Time = field.NewField(tableName, "time") _logsRequest.Time = field.NewField(tableName, "time")
@@ -54,6 +55,7 @@ type logsRequest struct {
Ua field.String // 用户代理 Ua field.String // 用户代理
Method field.String // 请求方法 Method field.String // 请求方法
Path field.String // 请求路径 Path field.String // 请求路径
Latency field.String // 请求延迟
Status field.Int32 // 响应状态码 Status field.Int32 // 响应状态码
Error field.String // 错误信息 Error field.String // 错误信息
Time field.Field // 请求时间 Time field.Field // 请求时间
@@ -80,6 +82,7 @@ func (l *logsRequest) updateTableName(table string) *logsRequest {
l.Ua = field.NewString(table, "ua") l.Ua = field.NewString(table, "ua")
l.Method = field.NewString(table, "method") l.Method = field.NewString(table, "method")
l.Path = field.NewString(table, "path") l.Path = field.NewString(table, "path")
l.Latency = field.NewString(table, "latency")
l.Status = field.NewInt32(table, "status") l.Status = field.NewInt32(table, "status")
l.Error = field.NewString(table, "error") l.Error = field.NewString(table, "error")
l.Time = field.NewField(table, "time") l.Time = field.NewField(table, "time")
@@ -99,7 +102,7 @@ func (l *logsRequest) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
} }
func (l *logsRequest) fillFieldMap() { func (l *logsRequest) fillFieldMap() {
l.fieldMap = make(map[string]field.Expr, 10) l.fieldMap = make(map[string]field.Expr, 11)
l.fieldMap["id"] = l.ID l.fieldMap["id"] = l.ID
l.fieldMap["identity"] = l.Identity l.fieldMap["identity"] = l.Identity
l.fieldMap["visitor"] = l.Visitor l.fieldMap["visitor"] = l.Visitor
@@ -107,6 +110,7 @@ func (l *logsRequest) fillFieldMap() {
l.fieldMap["ua"] = l.Ua l.fieldMap["ua"] = l.Ua
l.fieldMap["method"] = l.Method l.fieldMap["method"] = l.Method
l.fieldMap["path"] = l.Path l.fieldMap["path"] = l.Path
l.fieldMap["latency"] = l.Latency
l.fieldMap["status"] = l.Status l.fieldMap["status"] = l.Status
l.fieldMap["error"] = l.Error l.fieldMap["error"] = l.Error
l.fieldMap["time"] = l.Time l.fieldMap["time"] = l.Time

View File

@@ -3,6 +3,7 @@ package services
import ( import (
"context" "context"
"errors" "errors"
"platform/web/auth"
"platform/web/core" "platform/web/core"
m "platform/web/models" m "platform/web/models"
q "platform/web/queries" q "platform/web/queries"
@@ -24,14 +25,14 @@ func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Clie
// OauthClientCredentials 验证客户端凭证 // OauthClientCredentials 验证客户端凭证
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) { func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) {
var clientType PayloadType var clientType auth.PayloadType
switch client.Spec { switch client.Spec {
case 1: case 1:
clientType = PayloadClientPublic clientType = auth.PayloadClientPublic
case 2: case 2:
clientType = PayloadClientPublic clientType = auth.PayloadClientPublic
case 3: case 3:
clientType = PayloadClientConfidential clientType = auth.PayloadClientConfidential
} }
var permissions = make(map[string]struct{}, len(scope)) var permissions = make(map[string]struct{}, len(scope))
@@ -40,9 +41,9 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
} }
// 保存会话并返回令牌 // 保存会话并返回令牌
auth := AuthContext{ authCtx := auth.Context{
Permissions: permissions, Permissions: permissions,
Payload: Payload{ Payload: auth.Payload{
Id: client.ID, Id: client.ID,
Type: clientType, Type: clientType,
Name: client.Name, Name: client.Name,
@@ -50,7 +51,7 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
} }
// todo 数据库定义会话持续时间 // todo 数据库定义会话持续时间
token, err := Session.Create(ctx, auth, false) token, err := Session.Create(ctx, authCtx, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -136,16 +137,16 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
} }
// 保存到会话 // 保存到会话
auth := AuthContext{ authCtx := auth.Context{
Payload: Payload{ Payload: auth.Payload{
Id: user.ID, Id: user.ID,
Type: PayloadUser, Type: auth.PayloadUser,
Name: user.Name, Name: user.Name,
Avatar: user.Avatar, Avatar: user.Avatar,
}, },
} }
token, err := Session.Create(ctx, auth, data.Remember) token, err := Session.Create(ctx, authCtx, data.Remember)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -13,6 +13,7 @@ import (
"platform/pkg/orm" "platform/pkg/orm"
"platform/pkg/rds" "platform/pkg/rds"
"platform/pkg/u" "platform/pkg/u"
"platform/web/auth"
"platform/web/core" "platform/web/core"
g "platform/web/globals" g "platform/web/globals"
"platform/web/models" "platform/web/models"
@@ -64,7 +65,7 @@ type ResourceInfo struct {
// region RemoveChannel // region RemoveChannel
func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, id ...int32) error { func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Context, id ...int32) error {
var step = time.Now() var step = time.Now()
var rid = ctx.Value(requestid.ConfigDefault.ContextKey).(string) var rid = ctx.Value(requestid.ConfigDefault.ContextKey).(string)
@@ -82,8 +83,8 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
// 检查权限,如果为用户操作的话,则只能删除自己的通道 // 检查权限,如果为用户操作的话,则只能删除自己的通道
for _, channel := range channels { for _, channel := range channels {
if auth.Payload.Type == PayloadUser && auth.Payload.Id != channel.UserID { if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID {
return core.AuthForbiddenErr("无权限访问") return core.ForbiddenErr("无权限访问")
} }
} }
@@ -238,7 +239,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
func (s *channelService) CreateChannel( func (s *channelService) CreateChannel(
ctx context.Context, ctx context.Context,
auth *AuthContext, authCtx *auth.Context,
resourceId int32, resourceId int32,
protocol ChannelProtocol, protocol ChannelProtocol,
authType ChannelAuthType, authType ChannelAuthType,
@@ -283,7 +284,7 @@ func (s *channelService) CreateChannel(
slog.Debug("查找套餐", "rid", rid, "step", time.Since(step)) slog.Debug("查找套餐", "rid", rid, "step", time.Since(step))
// 检查用户权限 // 检查用户权限
err = checkUser(auth, resource, count) err = checkUser(authCtx, resource, count)
if err != nil { if err != nil {
return err return err
} }
@@ -302,7 +303,7 @@ func (s *channelService) CreateChannel(
step = time.Now() step = time.Now()
expiration := core.LocalDateTime(now.Add(time.Duration(resource.Live) * time.Second)) expiration := core.LocalDateTime(now.Add(time.Duration(resource.Live) * time.Second))
_addr, channels, err := assignPort(q, edgeAssigns, auth.Payload.Id, protocol, authType, expiration, filter) _addr, channels, err := assignPort(q, edgeAssigns, authCtx.Payload.Id, protocol, authType, expiration, filter)
if err != nil { if err != nil {
return err return err
} }
@@ -356,11 +357,11 @@ func (s *channelService) CreateChannel(
return addr, nil return addr, nil
} }
func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error { func checkUser(authCtx *auth.Context, resource *ResourceInfo, count int) error {
// 检查使用人 // 检查使用人
if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.UserId { if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != resource.UserId {
return core.AuthForbiddenErr("无权限访问") return core.ForbiddenErr("无权限访问")
} }
// 检查套餐状态 // 检查套餐状态

View File

@@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"platform/pkg/testutil" "platform/pkg/testutil"
"platform/web/auth"
"platform/web/core" "platform/web/core"
g "platform/web/globals" g "platform/web/globals"
"platform/web/models" "platform/web/models"
@@ -276,7 +277,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
type args struct { type args struct {
ctx context.Context ctx context.Context
auth *AuthContext auth *auth.Context
resourceId int32 resourceId int32
protocol ChannelProtocol protocol ChannelProtocol
authType ChannelAuthType authType ChannelAuthType
@@ -286,8 +287,8 @@ func Test_channelService_CreateChannel(t *testing.T) {
// 准备测试数据 // 准备测试数据
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id") ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}} var adminAuth = &auth.Context{Payload: auth.Payload{Id: 100, Type: auth.PayloadAdmin}}
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}} var userAuth = &auth.Context{Payload: auth.Payload{Id: 101, Type: auth.PayloadUser}}
mc.AutoQueryMock = func() (g.CloudConnectResp, error) { mc.AutoQueryMock = func() (g.CloudConnectResp, error) {
return g.CloudConnectResp{ return g.CloudConnectResp{
"test-proxy": []g.AutoConfig{ "test-proxy": []g.AutoConfig{
@@ -967,7 +968,7 @@ func Test_channelService_RemoveChannels(t *testing.T) {
type args struct { type args struct {
ctx context.Context ctx context.Context
auth *AuthContext auth *auth.Context
id []int32 id []int32
} }
@@ -989,8 +990,8 @@ func Test_channelService_RemoveChannels(t *testing.T) {
md.Create(adminUser) md.Create(adminUser)
// 认证上下文 // 认证上下文
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}} var adminAuth = &auth.Context{Payload: auth.Payload{Id: 100, Type: auth.PayloadAdmin}}
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}} var userAuth = &auth.Context{Payload: auth.Payload{Id: 101, Type: auth.PayloadUser}}
// 创建代理 // 创建代理
var proxy = &models.Proxy{ var proxy = &models.Proxy{

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"platform/pkg/env" "platform/pkg/env"
"platform/pkg/rds" "platform/pkg/rds"
"platform/web/auth"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@@ -19,9 +20,9 @@ var Session SessionServiceInter = &sessionService{}
type SessionServiceInter interface { type SessionServiceInter interface {
// Find 通过访问令牌获取会话信息 // Find 通过访问令牌获取会话信息
Find(ctx context.Context, token string) (*AuthContext, error) Find(ctx context.Context, token string) (*auth.Context, error)
// Create 创建一个新的会话 // Create 创建一个新的会话
Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error)
// Refresh 刷新一个会话 // Refresh 刷新一个会话
Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error)
// Remove 删除会话 // Remove 删除会话
@@ -41,7 +42,7 @@ var (
type sessionService struct{} type sessionService struct{}
// Find 通过访问令牌获取会话信息 // Find 通过访问令牌获取会话信息
func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, error) { func (s *sessionService) Find(ctx context.Context, token string) (*auth.Context, error) {
// 读取认证数据 // 读取认证数据
authJSON, err := rds.Client.Get(ctx, accessKey(token)).Result() authJSON, err := rds.Client.Get(ctx, accessKey(token)).Result()
@@ -53,16 +54,16 @@ func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext,
} }
// 反序列化 // 反序列化
auth := new(AuthContext) authCtx := new(auth.Context)
if err := json.Unmarshal([]byte(authJSON), auth); err != nil { if err := json.Unmarshal([]byte(authJSON), authCtx); err != nil {
return nil, err return nil, err
} }
return auth, nil return authCtx, nil
} }
// Create 创建一个新的会话 // Create 创建一个新的会话
func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) { func (s *sessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) {
var now = time.Now() var now = time.Now()
// 生成令牌组 // 生成令牌组
@@ -70,14 +71,14 @@ func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember
refreshToken := genToken() refreshToken := genToken()
// 序列化认证数据 // 序列化认证数据
authData, err := json.Marshal(auth) authData, err := json.Marshal(authCtx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 序列化刷新令牌数据 // 序列化刷新令牌数据
refreshData, err := json.Marshal(RefreshData{ refreshData, err := json.Marshal(RefreshData{
AuthContext: auth, AuthContext: authCtx,
AccessToken: accessToken, AccessToken: accessToken,
}) })
if err != nil { if err != nil {
@@ -103,7 +104,7 @@ func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember
AccessTokenExpires: now.Add(accessExpire), AccessTokenExpires: now.Add(accessExpire),
RefreshToken: refreshToken, RefreshToken: refreshToken,
RefreshTokenExpires: now.Add(refreshExpire), RefreshTokenExpires: now.Add(refreshExpire),
Auth: auth, Auth: authCtx,
}, nil }, nil
} }
@@ -205,74 +206,8 @@ func refreshKey(token string) string {
// endregion // endregion
// region AuthContext
// AuthContext 定义认证信息
type AuthContext struct {
Payload Payload `json:"payload"`
Agent Agent `json:"agent,omitempty"`
Permissions map[string]struct{} `json:"permissions,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// Payload 定义负载信息
type Payload struct {
Id int32 `json:"id,omitempty"`
Type PayloadType `json:"type,omitempty"`
Name string `json:"name,omitempty"`
Avatar string `json:"avatar,omitempty"`
}
// PayloadType 定义负载类型
type PayloadType int
const (
// PayloadUser 用户类型
PayloadUser PayloadType = iota
// PayloadAdmin 管理员类型
PayloadAdmin
// PayloadClientPublic 公共客户端类型
PayloadClientPublic
// PayloadClientConfidential 机密客户端类型
PayloadClientConfidential
)
func (t PayloadType) Name() string {
switch t {
case PayloadUser:
return "user"
case PayloadAdmin:
return "admn"
case PayloadClientPublic:
return "cpub"
case PayloadClientConfidential:
return "ccnf"
}
return "unknown"
}
type Agent struct {
Id int32 `json:"id,omitempty"`
Addr string `json:"addr,omitempty"`
}
// AnyPermission 检查认证是否包含指定权限
func (a *AuthContext) AnyPermission(requiredPermission ...string) bool {
if a == nil || a.Permissions == nil {
return false
}
for _, permission := range requiredPermission {
if _, ok := a.Permissions[permission]; ok {
return true
}
}
return false
}
// endregion
type RefreshData struct { type RefreshData struct {
AuthContext AuthContext AuthContext auth.Context
AccessToken string AccessToken string
} }
@@ -287,5 +222,5 @@ type TokenDetails struct {
// 刷新令牌过期时间 // 刷新令牌过期时间
RefreshTokenExpires time.Time RefreshTokenExpires time.Time
// 认证信息 // 认证信息
Auth AuthContext Auth auth.Context
} }

View File

@@ -4,17 +4,18 @@ import (
"context" "context"
"errors" "errors"
"platform/pkg/testutil" "platform/pkg/testutil"
"platform/web/auth"
"reflect" "reflect"
"testing" "testing"
"time" "time"
) )
// 创建测试用的认证上下文 // 创建测试用的认证上下文
func createTestAuthContext() AuthContext { func createTestAuthContext() auth.Context {
//goland:noinspection ALL //goland:noinspection ALL
return AuthContext{ return auth.Context{
Payload: Payload{ Payload: auth.Payload{
Type: PayloadUser, Type: auth.PayloadUser,
Id: 1001, Id: 1001,
}, },
Permissions: map[string]struct{}{ Permissions: map[string]struct{}{
@@ -31,11 +32,11 @@ func createTestAuthContext() AuthContext {
func Test_sessionService_Create(t *testing.T) { func Test_sessionService_Create(t *testing.T) {
mr := testutil.SetupRedisTest(t) mr := testutil.SetupRedisTest(t)
ctx := context.Background() ctx := context.Background()
auth := createTestAuthContext() authCtx := createTestAuthContext()
type args struct { type args struct {
ctx context.Context ctx context.Context
auth AuthContext auth auth.Context
} }
tests := []struct { tests := []struct {
name string name string
@@ -47,7 +48,7 @@ func Test_sessionService_Create(t *testing.T) {
name: "创建会话", name: "创建会话",
args: args{ args: args{
ctx: ctx, ctx: ctx,
auth: auth, auth: authCtx,
}, },
want: func(td *TokenDetails) bool { want: func(td *TokenDetails) bool {
// 验证令牌存在且格式正确 // 验证令牌存在且格式正确
@@ -60,7 +61,7 @@ func Test_sessionService_Create(t *testing.T) {
return false return false
} }
// 验证认证信息正确 // 验证认证信息正确
if !reflect.DeepEqual(td.Auth, auth) { if !reflect.DeepEqual(td.Auth, authCtx) {
return false return false
} }
return true return true
@@ -100,11 +101,11 @@ func Test_sessionService_Create(t *testing.T) {
func Test_sessionService_Find(t *testing.T) { func Test_sessionService_Find(t *testing.T) {
testutil.SetupRedisTest(t) testutil.SetupRedisTest(t)
ctx := context.Background() ctx := context.Background()
auth := createTestAuthContext() authCtx := createTestAuthContext()
s := &sessionService{} s := &sessionService{}
// 创建一个有效的会话 // 创建一个有效的会话
td, err := s.Create(ctx, auth, true) td, err := s.Create(ctx, authCtx, true)
if err != nil { if err != nil {
t.Fatalf("无法创建测试会话: %v", err) t.Fatalf("无法创建测试会话: %v", err)
} }
@@ -119,7 +120,7 @@ func Test_sessionService_Find(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
args args args args
want *AuthContext want *auth.Context
wantErr error wantErr error
}{ }{
{ {
@@ -128,7 +129,7 @@ func Test_sessionService_Find(t *testing.T) {
ctx: ctx, ctx: ctx,
token: validToken, token: validToken,
}, },
want: &auth, want: &authCtx,
wantErr: nil, wantErr: nil,
}, },
{ {
@@ -159,11 +160,11 @@ func Test_sessionService_Find(t *testing.T) {
func Test_sessionService_Refresh(t *testing.T) { func Test_sessionService_Refresh(t *testing.T) {
mr := testutil.SetupRedisTest(t) mr := testutil.SetupRedisTest(t)
ctx := context.Background() ctx := context.Background()
auth := createTestAuthContext() authCtx := createTestAuthContext()
s := &sessionService{} s := &sessionService{}
// 创建一个初始会话 // 创建一个初始会话
td, err := s.Create(ctx, auth, true) td, err := s.Create(ctx, authCtx, true)
if err != nil { if err != nil {
t.Fatalf("无法创建初始会话: %v", err) t.Fatalf("无法创建初始会话: %v", err)
} }
@@ -197,7 +198,7 @@ func Test_sessionService_Refresh(t *testing.T) {
return false return false
} }
// 验证认证信息一致 // 验证认证信息一致
if !reflect.DeepEqual(td.Auth, auth) { if !reflect.DeepEqual(td.Auth, authCtx) {
return false return false
} }
return true return true
@@ -251,11 +252,11 @@ func Test_sessionService_Refresh(t *testing.T) {
func Test_sessionService_Remove(t *testing.T) { func Test_sessionService_Remove(t *testing.T) {
mr := testutil.SetupRedisTest(t) mr := testutil.SetupRedisTest(t)
ctx := context.Background() ctx := context.Background()
auth := createTestAuthContext() authCtx := createTestAuthContext()
s := &sessionService{} s := &sessionService{}
// 创建一个会话 // 创建一个会话
td, err := s.Create(ctx, auth, true) td, err := s.Create(ctx, authCtx, true)
if err != nil { if err != nil {
t.Fatalf("无法创建测试会话: %v", err) t.Fatalf("无法创建测试会话: %v", err)
} }
@@ -312,7 +313,7 @@ func Test_sessionService_Remove(t *testing.T) {
func TestAuthContext_AnyPermission(t *testing.T) { func TestAuthContext_AnyPermission(t *testing.T) {
type fields struct { type fields struct {
Payload Payload Payload auth.Payload
Permissions map[string]struct{} Permissions map[string]struct{}
Metadata map[string]interface{} Metadata map[string]interface{}
} }
@@ -328,7 +329,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
{ {
name: "用户拥有所需权限", name: "用户拥有所需权限",
fields: fields{ fields: fields{
Payload: Payload{Type: PayloadUser, Id: 1}, Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
Permissions: map[string]struct{}{ Permissions: map[string]struct{}{
"read": {}, "read": {},
"write": {}, "write": {},
@@ -343,7 +344,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
{ {
name: "用户拥有至少一个所需权限", name: "用户拥有至少一个所需权限",
fields: fields{ fields: fields{
Payload: Payload{Type: PayloadUser, Id: 1}, Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
Permissions: map[string]struct{}{ Permissions: map[string]struct{}{
"read": {}, "read": {},
}, },
@@ -357,7 +358,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
{ {
name: "用户没有所需权限", name: "用户没有所需权限",
fields: fields{ fields: fields{
Payload: Payload{Type: PayloadUser, Id: 1}, Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
Permissions: map[string]struct{}{ Permissions: map[string]struct{}{
"read": {}, "read": {},
}, },
@@ -371,7 +372,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
{ {
name: "空权限列表", name: "空权限列表",
fields: fields{ fields: fields{
Payload: Payload{Type: PayloadUser, Id: 1}, Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
Permissions: map[string]struct{}{}, Permissions: map[string]struct{}{},
Metadata: nil, Metadata: nil,
}, },
@@ -383,7 +384,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
{ {
name: "nil权限列表", name: "nil权限列表",
fields: fields{ fields: fields{
Payload: Payload{Type: PayloadUser, Id: 1}, Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
Permissions: nil, Permissions: nil,
Metadata: nil, Metadata: nil,
}, },
@@ -395,7 +396,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
{ {
name: "nil认证上下文", name: "nil认证上下文",
fields: fields{ fields: fields{
Payload: Payload{}, Payload: auth.Payload{},
Permissions: nil, Permissions: nil,
Metadata: nil, Metadata: nil,
}, },
@@ -408,7 +409,7 @@ func TestAuthContext_AnyPermission(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
a := &AuthContext{ a := &auth.Context{
Payload: tt.fields.Payload, Payload: tt.fields.Payload,
Permissions: tt.fields.Permissions, Permissions: tt.fields.Permissions,
Metadata: tt.fields.Metadata, Metadata: tt.fields.Metadata,

View File

@@ -1,21 +1,27 @@
package web package web
import ( import (
"net/http"
g "platform/web/globals"
"runtime"
"log/slog"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/requestid" "github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jxskiss/base62" "github.com/jxskiss/base62"
"log/slog"
"net/http"
_ "net/http/pprof" _ "net/http/pprof"
"platform/web/auth"
"platform/web/core"
g "platform/web/globals"
m "platform/web/models"
q "platform/web/queries"
"runtime"
"strconv"
"strings"
"time"
) )
// region web
type Config struct { type Config struct {
Listen string Listen string
} }
@@ -51,18 +57,11 @@ func (s *Server) Run() error {
ErrorHandler: ErrorHandler, ErrorHandler: ErrorHandler,
}) })
s.fiber.Use(requestid.New(requestid.Config{ // middlewares
Generator: func() string { s.fiber.Use(useRequestId())
binary, _ := uuid.New().MarshalBinary() s.fiber.Use(useLogger())
return base62.EncodeToString(binary)
},
}))
s.fiber.Use(logger.New(logger.Config{
Format: "🚀 ${time} | ${locals:authtype} ${locals:authid} | ${method} ${path} | ${status} | ${latency} | ${error}\n",
TimeFormat: "2006-01-02 15:04:05",
TimeZone: "Asia/Shanghai",
}))
// routes
ApplyRouters(s.fiber) ApplyRouters(s.fiber)
// pprof // pprof
@@ -91,3 +90,76 @@ func (s *Server) Stop() {
slog.Error("Failed to shutdown server", slog.Any("err", err)) slog.Error("Failed to shutdown server", slog.Any("err", err))
} }
} }
// endregion
// region requestid
func useRequestId() fiber.Handler {
return requestid.New(requestid.Config{
Generator: func() string {
binary, _ := uuid.New().MarshalBinary()
return base62.EncodeToString(binary)
},
})
}
// endregion
// region logger
func useLogger() fiber.Handler {
return logger.New(logger.Config{
Format: "🚀 ${time} | ${locals:authtype} ${locals:authid} | ${method} ${path} | ${status} | ${latency} | ${error}\n",
TimeFormat: "2006-01-02 15:04:05",
TimeZone: "Asia/Shanghai",
Done: func(c *fiber.Ctx, logBytes []byte) {
var logStr = strings.TrimPrefix(string(logBytes), "🚀")
var logVars = strings.Split(logStr, "|")
var reqTimeStr = strings.TrimSpace(logVars[0])
reqTime, err := time.ParseInLocation("2006-01-02 15:04:05", reqTimeStr, time.Local)
if err != nil {
slog.Error("时间解析错误", slog.Any("err", err))
return
}
var authInfo = strings.Split(strings.TrimSpace(logVars[1]), " ")
var authType = auth.PayloadTypeFromStr(strings.TrimSpace(authInfo[0]))
authID, err := strconv.Atoi(strings.TrimSpace(authInfo[1]))
if err != nil {
slog.Error("负载ID解析错误", slog.Any("err", err))
return
}
var latency = strings.TrimSpace(logVars[4])
var errStr = strings.TrimSpace(logVars[5])
var item = &m.LogsRequest{
IP: c.IP(),
Ua: c.Get("User-Agent"),
Method: c.Method(),
Path: c.Path(),
Latency: latency,
Status: int32(c.Response().StatusCode()),
Error: errStr,
Time: core.LocalDateTime(reqTime),
}
if authType != nil {
item.Identity = int32(*authType)
}
if authID != 0 {
item.Visitor = int32(authID)
}
err = q.LogsRequest.Create(item)
if err != nil {
slog.Error("日志记录错误", slog.Any("err", err))
return
}
},
})
}
// endregion