diff --git a/README.md b/README.md index 9782294..bb67e71 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,9 @@ ## todo -- 规范化日志 - - 接口日志采集 - channel 接口 + - 每个用户-节点为一条数据,联查白名单 - 重新梳理逻辑流程,简化循环 - 端口分配时加锁 - - 每个用户-节点为一条数据,联查白名单 - 长效业务接入 - 微信支付 - 页面 账户总览 diff --git a/cmd/fill/main.go b/cmd/fill/main.go index e2d6c00..a9aa122 100644 --- a/cmd/fill/main.go +++ b/cmd/fill/main.go @@ -54,21 +54,23 @@ func main() { q.Client.ClientSecret, q.Client.GrantClient, q.Client.GrantRefresh, + q.Client.GrantPassword, q.Client.Spec, q.Client.Name, ). Create(&m.Client{ - ClientID: "test", - ClientSecret: string(testSecret), - GrantClient: true, - GrantRefresh: true, - Spec: 3, - Name: "默认客户端", + ClientID: "test", + ClientSecret: string(testSecret), + GrantCode: true, + GrantClient: true, + GrantRefresh: true, + GrantPassword: true, + Spec: 3, + Name: "默认客户端", }, &m.Client{ ClientID: "tasks", ClientSecret: string(tasksSecret), GrantClient: true, - GrantRefresh: true, Spec: 3, Name: "异步任务处理服务", }) diff --git a/scripts/sql/init.sql b/scripts/sql/init.sql index e3e8517..d08cb9f 100644 --- a/scripts/sql/init.sql +++ b/scripts/sql/init.sql @@ -34,6 +34,7 @@ create table logs_request ( method varchar(10) not null, path varchar(255) not null, + latency varchar(255), status int not null, error varchar(255), @@ -49,6 +50,7 @@ comment on column logs_request.ip is 'IP地址'; comment on column logs_request.ua is '用户代理'; comment on column logs_request.method is '请求方法'; comment on column logs_request.path is '请求路径'; +comment on column logs_request.latency is '请求延迟'; comment on column logs_request.status is '响应状态码'; comment on column logs_request.error is '错误信息'; comment on column logs_request.time is '请求时间'; diff --git a/web/auth/auth.go b/web/auth/authenticate.go similarity index 73% rename from web/auth/auth.go rename to web/auth/authenticate.go index 0fd674b..8c5c5ba 100644 --- a/web/auth/auth.go +++ b/web/auth/authenticate.go @@ -9,13 +9,11 @@ import ( "slices" "strings" - "platform/web/services" - "github.com/gofiber/fiber/v2" "golang.org/x/crypto/bcrypt" ) -func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (*services.AuthContext, error) { +func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context, error) { // 获取令牌 var header = c.Get("Authorization") var split = strings.Split(header, " ") @@ -28,7 +26,7 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) ( return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } - var auth *services.AuthContext + var auth *Context var err error switch split[0] { @@ -36,23 +34,23 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) ( auth, err = authBearer(c.Context(), token) if err != nil { slog.Debug("Bearer 认证失败", "err", err) - return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") + return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } case "Basic": - if !slices.Contains(types, services.PayloadClientConfidential) { + if !slices.Contains(types, PayloadClientConfidential) { slog.Debug("禁止使用 Basic 认证方式") - return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") + return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } auth, err = authBasic(c.Context(), token) if err != nil { slog.Debug("Basic 认证失败", "err", err) - return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") + return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } default: slog.Debug("无效的认证方式") - return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限") + return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } // 检查权限 @@ -67,15 +65,18 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) ( } // 保存到上下文 - c.Locals("auth", auth) - c.Locals("authid", auth.Payload.Id) - c.Locals("authtype", auth.Payload.Type.Name()) - + Locals(c, auth) return auth, nil } -func authBearer(ctx context.Context, token string) (*services.AuthContext, error) { - auth, err := services.Session.Find(ctx, token) +func Locals(c *fiber.Ctx, auth *Context) { + c.Locals("auth", auth) + c.Locals("authtype", auth.Payload.Type.ToStr()) + c.Locals("authid", auth.Payload.Id) +} + +func authBearer(ctx context.Context, token string) (*Context, error) { + auth, err := find(ctx, token) if err != nil { slog.Debug(err.Error()) return nil, err @@ -83,7 +84,7 @@ func authBearer(ctx context.Context, token string) (*services.AuthContext, error return auth, nil } -func authBasic(_ context.Context, token string) (*services.AuthContext, error) { +func authBasic(_ context.Context, token string) (*Context, error) { // 解析 Basic 认证信息 var base, err = base64.RawURLEncoding.DecodeString(token) @@ -122,10 +123,10 @@ func authBasic(_ context.Context, token string) (*services.AuthContext, error) { // todo 查询客户端关联权限 // 组织授权信息(一次性请求) - return &services.AuthContext{ - Payload: services.Payload{ + return &Context{ + Payload: Payload{ Id: client.ID, - Type: services.PayloadClientConfidential, + Type: PayloadClientConfidential, Name: client.Name, Avatar: client.Icon, }, diff --git a/web/auth/context.go b/web/auth/context.go new file mode 100644 index 0000000..0e92e38 --- /dev/null +++ b/web/auth/context.go @@ -0,0 +1,78 @@ +package auth + +import "platform/pkg/u" + +// Context 定义认证信息 +type Context struct { + Payload Payload `json:"payload"` + Agent Agent `json:"agent,omitempty"` + Permissions map[string]struct{} `json:"permissions,omitempty"` + Metadata map[string]interface{} `json:"metadata,omitempty"` +} + +// AnyPermission 检查认证是否包含指定权限 +func (a *Context) AnyPermission(requiredPermission ...string) bool { + if a == nil || a.Permissions == nil { + return false + } + for _, permission := range requiredPermission { + if _, ok := a.Permissions[permission]; ok { + return true + } + } + return false +} + +// Payload 定义负载信息 +type Payload struct { + Id int32 `json:"id,omitempty"` + Type PayloadType `json:"type,omitempty"` + Name string `json:"name,omitempty"` + Avatar string `json:"avatar,omitempty"` +} + +type Agent struct { + Id int32 `json:"id,omitempty"` + Addr string `json:"addr,omitempty"` +} + +type PayloadType int + +const ( + // PayloadUser 用户类型 + PayloadUser PayloadType = iota + 1 + // PayloadAdmin 管理员类型 + PayloadAdmin + // PayloadClientPublic 公共客户端类型 + PayloadClientPublic + // PayloadClientConfidential 机密客户端类型 + PayloadClientConfidential +) + +func (t PayloadType) ToStr() string { + switch t { + case PayloadUser: + return "user" + case PayloadAdmin: + return "admn" + case PayloadClientPublic: + return "cpub" + case PayloadClientConfidential: + return "ccnf" + } + return "none" +} + +func PayloadTypeFromStr(name string) *PayloadType { + switch name { + case "user": + return u.P(PayloadUser) + case "admn": + return u.P(PayloadAdmin) + case "cpub": + return u.P(PayloadClientPublic) + case "ccnf": + return u.P(PayloadClientConfidential) + } + return nil +} diff --git a/web/auth/session.go b/web/auth/session.go new file mode 100644 index 0000000..01baa9a --- /dev/null +++ b/web/auth/session.go @@ -0,0 +1,34 @@ +package auth + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/redis/go-redis/v9" + "platform/pkg/rds" +) + +func find(ctx context.Context, token string) (*Context, error) { + + // 读取认证数据 + authJSON, err := rds.Client.Get(ctx, accessKey(token)).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, errors.New("invalid_token") + } + return nil, err + } + + // 反序列化 + auth := new(Context) + if err := json.Unmarshal([]byte(authJSON), auth); err != nil { + return nil, err + } + + return auth, nil +} + +func accessKey(token string) string { + return fmt.Sprintf("session:%s", token) +} diff --git a/web/core/errors.go b/web/core/errors.go index b619113..f59decb 100644 --- a/web/core/errors.go +++ b/web/core/errors.go @@ -1,14 +1,14 @@ package core -type AuthUnAuthorizedErr string +type UnAuthorizedErr string -func (e AuthUnAuthorizedErr) Error() string { +func (e UnAuthorizedErr) Error() string { return string(e) } -type AuthForbiddenErr string +type ForbiddenErr string -func (e AuthForbiddenErr) Error() string { +func (e ForbiddenErr) Error() string { return string(e) } diff --git a/web/handlers/announcement.go b/web/handlers/announcement.go index ec9c431..1e08fc2 100644 --- a/web/handlers/announcement.go +++ b/web/handlers/announcement.go @@ -5,7 +5,6 @@ import ( "platform/web/auth" "platform/web/core" q "platform/web/queries" - "platform/web/services" ) // region ListAnnouncements @@ -17,7 +16,7 @@ type ListAnnouncementsRequest struct { func ListAnnouncements(c *fiber.Ctx) error { // 检查权限 - _, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } diff --git a/web/handlers/auth.go b/web/handlers/auth.go index 9de8040..1417b30 100644 --- a/web/handlers/auth.go +++ b/web/handlers/auth.go @@ -136,7 +136,7 @@ func Token(c *fiber.Ctx) error { token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent")) if err != nil { - return err + return sendError(c, err) } return sendSuccess(c, token) @@ -211,6 +211,16 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string } } + // 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑) + auth.Locals(c, &auth.Context{ + Payload: auth.Payload{ + Id: client.ID, + Type: auth.PayloadClientConfidential, + Name: client.Name, + Avatar: client.Icon, + }, + }) + return client, nil } @@ -268,7 +278,7 @@ type RevokeReq struct { } func Revoke(c *fiber.Ctx) error { - _, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { // 用户未登录 return nil @@ -299,7 +309,7 @@ type IntrospectResp struct { func Introspect(c *fiber.Ctx) error { // 验证权限 - authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } diff --git a/web/handlers/bill.go b/web/handlers/bill.go index e7892c1..4446e05 100644 --- a/web/handlers/bill.go +++ b/web/handlers/bill.go @@ -4,7 +4,6 @@ import ( "platform/web/auth" "platform/web/core" q "platform/web/queries" - "platform/web/services" "time" "github.com/gofiber/fiber/v2" @@ -23,7 +22,7 @@ type ListBillReq struct { // ListBill 获取账单列表 func ListBill(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } diff --git a/web/handlers/channel.go b/web/handlers/channel.go index 58a5698..a763ebd 100644 --- a/web/handlers/channel.go +++ b/web/handlers/channel.go @@ -22,7 +22,7 @@ type ListChannelsReq struct { func ListChannels(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -99,7 +99,7 @@ type CreateChannelReq struct { func CreateChannel(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -172,9 +172,9 @@ type RemoveChannelsReq struct { func RemoveChannels(c *fiber.Ctx) error { // 检查权限 - authCtx, err := auth.Protect(c, []s.PayloadType{ - s.PayloadUser, - s.PayloadClientConfidential, + authCtx, err := auth.Protect(c, []auth.PayloadType{ + auth.PayloadUser, + auth.PayloadClientConfidential, }, []string{}) if err != nil { return err diff --git a/web/handlers/iden.go b/web/handlers/iden.go index bc53755..878e4fc 100644 --- a/web/handlers/iden.go +++ b/web/handlers/iden.go @@ -36,7 +36,7 @@ type IdentifyRes struct { func Identify(c *fiber.Ctx) error { // 检查权限 - authCtx, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } diff --git a/web/handlers/resource.go b/web/handlers/resource.go index 557bbbb..5ddd013 100644 --- a/web/handlers/resource.go +++ b/web/handlers/resource.go @@ -27,7 +27,7 @@ type ListResourcePssReq struct { // ListResourcePss 获取套餐列表 func ListResourcePss(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -102,7 +102,7 @@ type AllResourceReq struct { func AllResource(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -158,7 +158,7 @@ type PaidCreateResourceReq struct { func PrepareResourceByAlipay(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -190,7 +190,7 @@ func PrepareResourceByAlipay(c *fiber.Ctx) error { func PrepareResourceByWechat(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -221,7 +221,7 @@ func PrepareResourceByWechat(c *fiber.Ctx) error { func CreateResourceByAlipay(c *fiber.Ctx) error { // 检查权限 - _, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -252,7 +252,7 @@ func CreateResourceByAlipay(c *fiber.Ctx) error { func CreateResourceByWechat(c *fiber.Ctx) error { // 检查权限 - _, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -284,7 +284,7 @@ func CreateResourceByWechat(c *fiber.Ctx) error { func CreateResourceByBalance(c *fiber.Ctx) error { // 检查权限 - authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } diff --git a/web/handlers/user.go b/web/handlers/user.go index 0f1ba4c..2e509b9 100644 --- a/web/handlers/user.go +++ b/web/handlers/user.go @@ -24,7 +24,7 @@ type UpdateUserReq struct { func UpdateUser(c *fiber.Ctx) error { // 检查权限 - authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -63,7 +63,7 @@ type UpdateAccountReq struct { func UpdateAccount(c *fiber.Ctx) error { // 检查权限 - authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -101,7 +101,7 @@ type UpdatePasswordReq struct { func UpdatePassword(c *fiber.Ctx) error { // 检查权限 - authCtx, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -162,7 +162,7 @@ type RechargeConfirmResp struct { // RechargePrepareAlipay 通过支付宝充值 func RechargePrepareAlipay(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -198,7 +198,7 @@ func RechargePrepareAlipay(c *fiber.Ctx) error { func RechargeConfirmAlipay(c *fiber.Ctx) error { // 检查权限 - _, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -229,7 +229,7 @@ func RechargeConfirmAlipay(c *fiber.Ctx) error { func RechargePrepareWechat(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -265,7 +265,7 @@ func RechargePrepareWechat(c *fiber.Ctx) error { func RechargeConfirmWechat(c *fiber.Ctx) error { // 检查权限 - _, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } diff --git a/web/handlers/verifier.go b/web/handlers/verifier.go index d366784..4edfb6d 100644 --- a/web/handlers/verifier.go +++ b/web/handlers/verifier.go @@ -17,8 +17,8 @@ type VerifierReq struct { func SmsCode(c *fiber.Ctx) error { - _, err := auth.Protect(c, []services.PayloadType{ - services.PayloadClientConfidential, + _, err := auth.Protect(c, []auth.PayloadType{ + auth.PayloadClientConfidential, }, []string{}) if err != nil { return err diff --git a/web/handlers/whitelist.go b/web/handlers/whitelist.go index 2e10a13..9a602ce 100644 --- a/web/handlers/whitelist.go +++ b/web/handlers/whitelist.go @@ -7,7 +7,6 @@ import ( g "platform/web/globals" m "platform/web/models" q "platform/web/queries" - "platform/web/services" "time" "github.com/gofiber/fiber/v2" @@ -27,7 +26,7 @@ type ListWhitelistResp struct { func ListWhitelist(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -78,7 +77,7 @@ type CreateWhitelistReq struct { func CreateWhitelist(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -112,7 +111,7 @@ type UpdateWhitelistReq struct { func UpdateWhitelist(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } @@ -150,7 +149,7 @@ type RemoveWhitelistReq struct { func RemoveWhitelist(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } diff --git a/web/models/logs_request.gen.go b/web/models/logs_request.gen.go index e1c046a..598f13d 100644 --- a/web/models/logs_request.gen.go +++ b/web/models/logs_request.gen.go @@ -17,6 +17,7 @@ type LogsRequest struct { Ua string `gorm:"column:ua;comment:用户代理" json:"ua"` // 用户代理 Method string `gorm:"column:method;not null;comment:请求方法" json:"method"` // 请求方法 Path string `gorm:"column:path;not null;comment:请求路径" json:"path"` // 请求路径 + Latency string `gorm:"column:latency;comment:请求延迟" json:"latency"` // 请求延迟 Status int32 `gorm:"column:status;not null;comment:响应状态码" json:"status"` // 响应状态码 Error string `gorm:"column:error;comment:错误信息" json:"error"` // 错误信息 Time core.LocalDateTime `gorm:"column:time;default:CURRENT_TIMESTAMP;comment:请求时间" json:"time"` // 请求时间 diff --git a/web/queries/logs_request.gen.go b/web/queries/logs_request.gen.go index 804f3c2..6ec7576 100644 --- a/web/queries/logs_request.gen.go +++ b/web/queries/logs_request.gen.go @@ -34,6 +34,7 @@ func newLogsRequest(db *gorm.DB, opts ...gen.DOOption) logsRequest { _logsRequest.Ua = field.NewString(tableName, "ua") _logsRequest.Method = field.NewString(tableName, "method") _logsRequest.Path = field.NewString(tableName, "path") + _logsRequest.Latency = field.NewString(tableName, "latency") _logsRequest.Status = field.NewInt32(tableName, "status") _logsRequest.Error = field.NewString(tableName, "error") _logsRequest.Time = field.NewField(tableName, "time") @@ -54,6 +55,7 @@ type logsRequest struct { Ua field.String // 用户代理 Method field.String // 请求方法 Path field.String // 请求路径 + Latency field.String // 请求延迟 Status field.Int32 // 响应状态码 Error field.String // 错误信息 Time field.Field // 请求时间 @@ -80,6 +82,7 @@ func (l *logsRequest) updateTableName(table string) *logsRequest { l.Ua = field.NewString(table, "ua") l.Method = field.NewString(table, "method") l.Path = field.NewString(table, "path") + l.Latency = field.NewString(table, "latency") l.Status = field.NewInt32(table, "status") l.Error = field.NewString(table, "error") l.Time = field.NewField(table, "time") @@ -99,7 +102,7 @@ func (l *logsRequest) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (l *logsRequest) fillFieldMap() { - l.fieldMap = make(map[string]field.Expr, 10) + l.fieldMap = make(map[string]field.Expr, 11) l.fieldMap["id"] = l.ID l.fieldMap["identity"] = l.Identity l.fieldMap["visitor"] = l.Visitor @@ -107,6 +110,7 @@ func (l *logsRequest) fillFieldMap() { l.fieldMap["ua"] = l.Ua l.fieldMap["method"] = l.Method l.fieldMap["path"] = l.Path + l.fieldMap["latency"] = l.Latency l.fieldMap["status"] = l.Status l.fieldMap["error"] = l.Error l.fieldMap["time"] = l.Time diff --git a/web/services/auth.go b/web/services/auth.go index 749dc1d..15102c3 100644 --- a/web/services/auth.go +++ b/web/services/auth.go @@ -3,6 +3,7 @@ package services import ( "context" "errors" + "platform/web/auth" "platform/web/core" m "platform/web/models" q "platform/web/queries" @@ -24,14 +25,14 @@ func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Clie // OauthClientCredentials 验证客户端凭证 func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) { - var clientType PayloadType + var clientType auth.PayloadType switch client.Spec { case 1: - clientType = PayloadClientPublic + clientType = auth.PayloadClientPublic case 2: - clientType = PayloadClientPublic + clientType = auth.PayloadClientPublic case 3: - clientType = PayloadClientConfidential + clientType = auth.PayloadClientConfidential } var permissions = make(map[string]struct{}, len(scope)) @@ -40,9 +41,9 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie } // 保存会话并返回令牌 - auth := AuthContext{ + authCtx := auth.Context{ Permissions: permissions, - Payload: Payload{ + Payload: auth.Payload{ Id: client.ID, Type: clientType, Name: client.Name, @@ -50,7 +51,7 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie } // todo 数据库定义会话持续时间 - token, err := Session.Create(ctx, auth, false) + token, err := Session.Create(ctx, authCtx, false) if err != nil { return nil, err } @@ -136,16 +137,16 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran } // 保存到会话 - auth := AuthContext{ - Payload: Payload{ + authCtx := auth.Context{ + Payload: auth.Payload{ Id: user.ID, - Type: PayloadUser, + Type: auth.PayloadUser, Name: user.Name, Avatar: user.Avatar, }, } - token, err := Session.Create(ctx, auth, data.Remember) + token, err := Session.Create(ctx, authCtx, data.Remember) if err != nil { return nil, err } diff --git a/web/services/channel.go b/web/services/channel.go index dfc40ae..5ae5f64 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -13,6 +13,7 @@ import ( "platform/pkg/orm" "platform/pkg/rds" "platform/pkg/u" + "platform/web/auth" "platform/web/core" g "platform/web/globals" "platform/web/models" @@ -64,7 +65,7 @@ type ResourceInfo struct { // region RemoveChannel -func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, id ...int32) error { +func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Context, id ...int32) error { var step = time.Now() var rid = ctx.Value(requestid.ConfigDefault.ContextKey).(string) @@ -82,8 +83,8 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, // 检查权限,如果为用户操作的话,则只能删除自己的通道 for _, channel := range channels { - if auth.Payload.Type == PayloadUser && auth.Payload.Id != channel.UserID { - return core.AuthForbiddenErr("无权限访问") + if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID { + return core.ForbiddenErr("无权限访问") } } @@ -238,7 +239,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, func (s *channelService) CreateChannel( ctx context.Context, - auth *AuthContext, + authCtx *auth.Context, resourceId int32, protocol ChannelProtocol, authType ChannelAuthType, @@ -283,7 +284,7 @@ func (s *channelService) CreateChannel( slog.Debug("查找套餐", "rid", rid, "step", time.Since(step)) // 检查用户权限 - err = checkUser(auth, resource, count) + err = checkUser(authCtx, resource, count) if err != nil { return err } @@ -302,7 +303,7 @@ func (s *channelService) CreateChannel( step = time.Now() expiration := core.LocalDateTime(now.Add(time.Duration(resource.Live) * time.Second)) - _addr, channels, err := assignPort(q, edgeAssigns, auth.Payload.Id, protocol, authType, expiration, filter) + _addr, channels, err := assignPort(q, edgeAssigns, authCtx.Payload.Id, protocol, authType, expiration, filter) if err != nil { return err } @@ -356,11 +357,11 @@ func (s *channelService) CreateChannel( return addr, nil } -func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error { +func checkUser(authCtx *auth.Context, resource *ResourceInfo, count int) error { // 检查使用人 - if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.UserId { - return core.AuthForbiddenErr("无权限访问") + if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != resource.UserId { + return core.ForbiddenErr("无权限访问") } // 检查套餐状态 diff --git a/web/services/channel_test.go b/web/services/channel_test.go index 44f6ca0..27082fe 100644 --- a/web/services/channel_test.go +++ b/web/services/channel_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "platform/pkg/testutil" + "platform/web/auth" "platform/web/core" g "platform/web/globals" "platform/web/models" @@ -276,7 +277,7 @@ func Test_channelService_CreateChannel(t *testing.T) { type args struct { ctx context.Context - auth *AuthContext + auth *auth.Context resourceId int32 protocol ChannelProtocol authType ChannelAuthType @@ -286,8 +287,8 @@ func Test_channelService_CreateChannel(t *testing.T) { // 准备测试数据 ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id") - var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}} - var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}} + var adminAuth = &auth.Context{Payload: auth.Payload{Id: 100, Type: auth.PayloadAdmin}} + var userAuth = &auth.Context{Payload: auth.Payload{Id: 101, Type: auth.PayloadUser}} mc.AutoQueryMock = func() (g.CloudConnectResp, error) { return g.CloudConnectResp{ "test-proxy": []g.AutoConfig{ @@ -967,7 +968,7 @@ func Test_channelService_RemoveChannels(t *testing.T) { type args struct { ctx context.Context - auth *AuthContext + auth *auth.Context id []int32 } @@ -989,8 +990,8 @@ func Test_channelService_RemoveChannels(t *testing.T) { md.Create(adminUser) // 认证上下文 - var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}} - var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}} + var adminAuth = &auth.Context{Payload: auth.Payload{Id: 100, Type: auth.PayloadAdmin}} + var userAuth = &auth.Context{Payload: auth.Payload{Id: 101, Type: auth.PayloadUser}} // 创建代理 var proxy = &models.Proxy{ diff --git a/web/services/session.go b/web/services/session.go index f2ec7fb..8b7a816 100644 --- a/web/services/session.go +++ b/web/services/session.go @@ -7,6 +7,7 @@ import ( "fmt" "platform/pkg/env" "platform/pkg/rds" + "platform/web/auth" "time" "github.com/google/uuid" @@ -19,9 +20,9 @@ var Session SessionServiceInter = &sessionService{} type SessionServiceInter interface { // Find 通过访问令牌获取会话信息 - Find(ctx context.Context, token string) (*AuthContext, error) + Find(ctx context.Context, token string) (*auth.Context, error) // Create 创建一个新的会话 - Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) + Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) // Refresh 刷新一个会话 Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) // Remove 删除会话 @@ -41,7 +42,7 @@ var ( type sessionService struct{} // Find 通过访问令牌获取会话信息 -func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, error) { +func (s *sessionService) Find(ctx context.Context, token string) (*auth.Context, error) { // 读取认证数据 authJSON, err := rds.Client.Get(ctx, accessKey(token)).Result() @@ -53,16 +54,16 @@ func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, } // 反序列化 - auth := new(AuthContext) - if err := json.Unmarshal([]byte(authJSON), auth); err != nil { + authCtx := new(auth.Context) + if err := json.Unmarshal([]byte(authJSON), authCtx); err != nil { return nil, err } - return auth, nil + return authCtx, nil } // Create 创建一个新的会话 -func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) { +func (s *sessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) { var now = time.Now() // 生成令牌组 @@ -70,14 +71,14 @@ func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember refreshToken := genToken() // 序列化认证数据 - authData, err := json.Marshal(auth) + authData, err := json.Marshal(authCtx) if err != nil { return nil, err } // 序列化刷新令牌数据 refreshData, err := json.Marshal(RefreshData{ - AuthContext: auth, + AuthContext: authCtx, AccessToken: accessToken, }) if err != nil { @@ -103,7 +104,7 @@ func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember AccessTokenExpires: now.Add(accessExpire), RefreshToken: refreshToken, RefreshTokenExpires: now.Add(refreshExpire), - Auth: auth, + Auth: authCtx, }, nil } @@ -205,74 +206,8 @@ func refreshKey(token string) string { // endregion -// region AuthContext - -// AuthContext 定义认证信息 -type AuthContext struct { - Payload Payload `json:"payload"` - Agent Agent `json:"agent,omitempty"` - Permissions map[string]struct{} `json:"permissions,omitempty"` - Metadata map[string]interface{} `json:"metadata,omitempty"` -} - -// Payload 定义负载信息 -type Payload struct { - Id int32 `json:"id,omitempty"` - Type PayloadType `json:"type,omitempty"` - Name string `json:"name,omitempty"` - Avatar string `json:"avatar,omitempty"` -} - -// PayloadType 定义负载类型 -type PayloadType int - -const ( - // PayloadUser 用户类型 - PayloadUser PayloadType = iota - // PayloadAdmin 管理员类型 - PayloadAdmin - // PayloadClientPublic 公共客户端类型 - PayloadClientPublic - // PayloadClientConfidential 机密客户端类型 - PayloadClientConfidential -) - -func (t PayloadType) Name() string { - switch t { - case PayloadUser: - return "user" - case PayloadAdmin: - return "admn" - case PayloadClientPublic: - return "cpub" - case PayloadClientConfidential: - return "ccnf" - } - return "unknown" -} - -type Agent struct { - Id int32 `json:"id,omitempty"` - Addr string `json:"addr,omitempty"` -} - -// AnyPermission 检查认证是否包含指定权限 -func (a *AuthContext) AnyPermission(requiredPermission ...string) bool { - if a == nil || a.Permissions == nil { - return false - } - for _, permission := range requiredPermission { - if _, ok := a.Permissions[permission]; ok { - return true - } - } - return false -} - -// endregion - type RefreshData struct { - AuthContext AuthContext + AuthContext auth.Context AccessToken string } @@ -287,5 +222,5 @@ type TokenDetails struct { // 刷新令牌过期时间 RefreshTokenExpires time.Time // 认证信息 - Auth AuthContext + Auth auth.Context } diff --git a/web/services/session_test.go b/web/services/session_test.go index 16aa120..31d4e11 100644 --- a/web/services/session_test.go +++ b/web/services/session_test.go @@ -4,17 +4,18 @@ import ( "context" "errors" "platform/pkg/testutil" + "platform/web/auth" "reflect" "testing" "time" ) // 创建测试用的认证上下文 -func createTestAuthContext() AuthContext { +func createTestAuthContext() auth.Context { //goland:noinspection ALL - return AuthContext{ - Payload: Payload{ - Type: PayloadUser, + return auth.Context{ + Payload: auth.Payload{ + Type: auth.PayloadUser, Id: 1001, }, Permissions: map[string]struct{}{ @@ -31,11 +32,11 @@ func createTestAuthContext() AuthContext { func Test_sessionService_Create(t *testing.T) { mr := testutil.SetupRedisTest(t) ctx := context.Background() - auth := createTestAuthContext() + authCtx := createTestAuthContext() type args struct { ctx context.Context - auth AuthContext + auth auth.Context } tests := []struct { name string @@ -47,7 +48,7 @@ func Test_sessionService_Create(t *testing.T) { name: "创建会话", args: args{ ctx: ctx, - auth: auth, + auth: authCtx, }, want: func(td *TokenDetails) bool { // 验证令牌存在且格式正确 @@ -60,7 +61,7 @@ func Test_sessionService_Create(t *testing.T) { return false } // 验证认证信息正确 - if !reflect.DeepEqual(td.Auth, auth) { + if !reflect.DeepEqual(td.Auth, authCtx) { return false } return true @@ -100,11 +101,11 @@ func Test_sessionService_Create(t *testing.T) { func Test_sessionService_Find(t *testing.T) { testutil.SetupRedisTest(t) ctx := context.Background() - auth := createTestAuthContext() + authCtx := createTestAuthContext() s := &sessionService{} // 创建一个有效的会话 - td, err := s.Create(ctx, auth, true) + td, err := s.Create(ctx, authCtx, true) if err != nil { t.Fatalf("无法创建测试会话: %v", err) } @@ -119,7 +120,7 @@ func Test_sessionService_Find(t *testing.T) { tests := []struct { name string args args - want *AuthContext + want *auth.Context wantErr error }{ { @@ -128,7 +129,7 @@ func Test_sessionService_Find(t *testing.T) { ctx: ctx, token: validToken, }, - want: &auth, + want: &authCtx, wantErr: nil, }, { @@ -159,11 +160,11 @@ func Test_sessionService_Find(t *testing.T) { func Test_sessionService_Refresh(t *testing.T) { mr := testutil.SetupRedisTest(t) ctx := context.Background() - auth := createTestAuthContext() + authCtx := createTestAuthContext() s := &sessionService{} // 创建一个初始会话 - td, err := s.Create(ctx, auth, true) + td, err := s.Create(ctx, authCtx, true) if err != nil { t.Fatalf("无法创建初始会话: %v", err) } @@ -197,7 +198,7 @@ func Test_sessionService_Refresh(t *testing.T) { return false } // 验证认证信息一致 - if !reflect.DeepEqual(td.Auth, auth) { + if !reflect.DeepEqual(td.Auth, authCtx) { return false } return true @@ -251,11 +252,11 @@ func Test_sessionService_Refresh(t *testing.T) { func Test_sessionService_Remove(t *testing.T) { mr := testutil.SetupRedisTest(t) ctx := context.Background() - auth := createTestAuthContext() + authCtx := createTestAuthContext() s := &sessionService{} // 创建一个会话 - td, err := s.Create(ctx, auth, true) + td, err := s.Create(ctx, authCtx, true) if err != nil { t.Fatalf("无法创建测试会话: %v", err) } @@ -312,7 +313,7 @@ func Test_sessionService_Remove(t *testing.T) { func TestAuthContext_AnyPermission(t *testing.T) { type fields struct { - Payload Payload + Payload auth.Payload Permissions map[string]struct{} Metadata map[string]interface{} } @@ -328,7 +329,7 @@ func TestAuthContext_AnyPermission(t *testing.T) { { name: "用户拥有所需权限", fields: fields{ - Payload: Payload{Type: PayloadUser, Id: 1}, + Payload: auth.Payload{Type: auth.PayloadUser, Id: 1}, Permissions: map[string]struct{}{ "read": {}, "write": {}, @@ -343,7 +344,7 @@ func TestAuthContext_AnyPermission(t *testing.T) { { name: "用户拥有至少一个所需权限", fields: fields{ - Payload: Payload{Type: PayloadUser, Id: 1}, + Payload: auth.Payload{Type: auth.PayloadUser, Id: 1}, Permissions: map[string]struct{}{ "read": {}, }, @@ -357,7 +358,7 @@ func TestAuthContext_AnyPermission(t *testing.T) { { name: "用户没有所需权限", fields: fields{ - Payload: Payload{Type: PayloadUser, Id: 1}, + Payload: auth.Payload{Type: auth.PayloadUser, Id: 1}, Permissions: map[string]struct{}{ "read": {}, }, @@ -371,7 +372,7 @@ func TestAuthContext_AnyPermission(t *testing.T) { { name: "空权限列表", fields: fields{ - Payload: Payload{Type: PayloadUser, Id: 1}, + Payload: auth.Payload{Type: auth.PayloadUser, Id: 1}, Permissions: map[string]struct{}{}, Metadata: nil, }, @@ -383,7 +384,7 @@ func TestAuthContext_AnyPermission(t *testing.T) { { name: "nil权限列表", fields: fields{ - Payload: Payload{Type: PayloadUser, Id: 1}, + Payload: auth.Payload{Type: auth.PayloadUser, Id: 1}, Permissions: nil, Metadata: nil, }, @@ -395,7 +396,7 @@ func TestAuthContext_AnyPermission(t *testing.T) { { name: "nil认证上下文", fields: fields{ - Payload: Payload{}, + Payload: auth.Payload{}, Permissions: nil, Metadata: nil, }, @@ -408,7 +409,7 @@ func TestAuthContext_AnyPermission(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := &AuthContext{ + a := &auth.Context{ Payload: tt.fields.Payload, Permissions: tt.fields.Permissions, Metadata: tt.fields.Metadata, diff --git a/web/web.go b/web/web.go index 1749c23..7daabd7 100644 --- a/web/web.go +++ b/web/web.go @@ -1,21 +1,27 @@ package web import ( - "net/http" - g "platform/web/globals" - "runtime" - - "log/slog" - "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/requestid" "github.com/google/uuid" "github.com/jxskiss/base62" - + "log/slog" + "net/http" _ "net/http/pprof" + "platform/web/auth" + "platform/web/core" + g "platform/web/globals" + m "platform/web/models" + q "platform/web/queries" + "runtime" + "strconv" + "strings" + "time" ) +// region web + type Config struct { Listen string } @@ -51,18 +57,11 @@ func (s *Server) Run() error { ErrorHandler: ErrorHandler, }) - s.fiber.Use(requestid.New(requestid.Config{ - Generator: func() string { - binary, _ := uuid.New().MarshalBinary() - return base62.EncodeToString(binary) - }, - })) - s.fiber.Use(logger.New(logger.Config{ - Format: "🚀 ${time} | ${locals:authtype} ${locals:authid} | ${method} ${path} | ${status} | ${latency} | ${error}\n", - TimeFormat: "2006-01-02 15:04:05", - TimeZone: "Asia/Shanghai", - })) + // middlewares + s.fiber.Use(useRequestId()) + s.fiber.Use(useLogger()) + // routes ApplyRouters(s.fiber) // pprof @@ -91,3 +90,76 @@ func (s *Server) Stop() { slog.Error("Failed to shutdown server", slog.Any("err", err)) } } + +// endregion + +// region requestid + +func useRequestId() fiber.Handler { + return requestid.New(requestid.Config{ + Generator: func() string { + binary, _ := uuid.New().MarshalBinary() + return base62.EncodeToString(binary) + }, + }) +} + +// endregion + +// region logger + +func useLogger() fiber.Handler { + return logger.New(logger.Config{ + Format: "🚀 ${time} | ${locals:authtype} ${locals:authid} | ${method} ${path} | ${status} | ${latency} | ${error}\n", + TimeFormat: "2006-01-02 15:04:05", + TimeZone: "Asia/Shanghai", + Done: func(c *fiber.Ctx, logBytes []byte) { + var logStr = strings.TrimPrefix(string(logBytes), "🚀") + var logVars = strings.Split(logStr, "|") + + var reqTimeStr = strings.TrimSpace(logVars[0]) + reqTime, err := time.ParseInLocation("2006-01-02 15:04:05", reqTimeStr, time.Local) + if err != nil { + slog.Error("时间解析错误", slog.Any("err", err)) + return + } + + var authInfo = strings.Split(strings.TrimSpace(logVars[1]), " ") + var authType = auth.PayloadTypeFromStr(strings.TrimSpace(authInfo[0])) + authID, err := strconv.Atoi(strings.TrimSpace(authInfo[1])) + if err != nil { + slog.Error("负载ID解析错误", slog.Any("err", err)) + return + } + + var latency = strings.TrimSpace(logVars[4]) + + var errStr = strings.TrimSpace(logVars[5]) + + var item = &m.LogsRequest{ + IP: c.IP(), + Ua: c.Get("User-Agent"), + Method: c.Method(), + Path: c.Path(), + Latency: latency, + Status: int32(c.Response().StatusCode()), + Error: errStr, + Time: core.LocalDateTime(reqTime), + } + if authType != nil { + item.Identity = int32(*authType) + } + if authID != 0 { + item.Visitor = int32(authID) + } + + err = q.LogsRequest.Create(item) + if err != nil { + slog.Error("日志记录错误", slog.Any("err", err)) + return + } + }, + }) +} + +// endregion