diff --git a/pkg/env/env.go b/pkg/env/env.go index 3c1744f..1dae0ab 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -11,19 +11,24 @@ import ( // region app +const ( + RunModeDev = "debug" + RunModeProd = "production" +) + var ( - RunMode = "debug" // debug, production + RunMode = RunModeDev ) func loadApp() { _RunMode := os.Getenv("RUN_MODE") switch _RunMode { - case "debug", "production": + case RunModeDev, RunModeProd: RunMode = _RunMode case "": break default: - panic("环境变量 RUN_MODE 的值只能是 debug 或 production") + panic("环境变量 RUN_MODE 的值只能是 " + RunModeDev + " 或 " + RunModeProd) } } diff --git a/pkg/logs/logs.go b/pkg/logs/logs.go index 19ed09a..01a0da0 100644 --- a/pkg/logs/logs.go +++ b/pkg/logs/logs.go @@ -16,21 +16,19 @@ func Init() { switch env.RunMode { case "debug": handler = tint.NewHandler(writer, &tint.Options{ - AddSource: true, Level: env.LogLevel, TimeFormat: timeFormat, ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { - err, ok := attr.Value.Any().(error) - if ok { - return tint.Err(err) + switch v := attr.Value.Any().(type) { + case error: + return tint.Err(v) } return attr }, }) case "production": handler = slog.NewJSONHandler(writer, &slog.HandlerOptions{ - AddSource: false, - Level: env.LogLevel, + Level: env.LogLevel, ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { if a.Key == "time" { return slog.String("time", a.Value.Time().Format(timeFormat)) diff --git a/web/auth/authenticate.go b/web/auth/authenticate.go index 9aa736b..fa69752 100644 --- a/web/auth/authenticate.go +++ b/web/auth/authenticate.go @@ -44,13 +44,13 @@ func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context, var split = strings.Split(header, " ") if len(split) != 2 { slog.Debug("Authorization 头格式不正确") - return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") + return nil, ErrUnauthorize } var token = strings.TrimSpace(split[1]) if token == "" { slog.Debug("提供的令牌为空") - return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") + return nil, ErrUnauthorize } var auth *Context @@ -61,34 +61,34 @@ func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context, auth, err = authBearer(c.Context(), token) if err != nil { slog.Debug("Bearer 认证失败", "err", err) - return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") + return nil, ErrUnauthorize } case "Basic": if !slices.Contains(types, PayloadInternalServer) { slog.Debug("禁止使用 Basic 认证方式") - return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") + return nil, ErrUnauthorize } auth, err = authBasic(c.Context(), token) if err != nil { slog.Debug("Basic 认证失败", "err", err) - return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") + return nil, ErrUnauthorize } default: slog.Debug("无效的认证方式", "method", split[0]) - return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") + return nil, ErrUnauthorize } // 检查权限 if !slices.Contains(types, auth.Payload.Type) { slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type) - return nil, fiber.NewError(fiber.StatusForbidden, "没有权限") + return nil, ErrForbidden } if len(permissions) > 0 && !auth.AnyPermission(permissions...) { slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions) - return nil, fiber.NewError(fiber.StatusForbidden, "没有权限") + return nil, ErrForbidden } // 保存到上下文 @@ -116,7 +116,10 @@ func authBasic(_ context.Context, token string) (*Context, error) { // 解析 Basic 认证信息 var base, err = base64.RawURLEncoding.DecodeString(token) if err != nil { - return nil, errors.New("令牌格式错误,无法解析令牌") + base, err = base64.URLEncoding.DecodeString(token) + if err != nil { + return nil, errors.New("令牌格式错误,无法解析令牌") + } } var split = strings.Split(string(base), ":") @@ -158,3 +161,14 @@ func authBasic(_ context.Context, token string) (*Context, error) { Metadata: nil, }, nil } + +type AuthenticationErr string + +func (e AuthenticationErr) Error() string { + return string(e) +} + +var ( + ErrUnauthorize = AuthenticationErr("令牌无效") + ErrForbidden = AuthenticationErr("没有权限") +) diff --git a/web/core/types.go b/web/core/types.go index 5a48d22..68499b0 100644 --- a/web/core/types.go +++ b/web/core/types.go @@ -1,5 +1,11 @@ package core +import ( + "log/slog" + "platform/pkg/env" + "runtime" +) + // region page type PageReq struct { @@ -40,3 +46,56 @@ type PageResp struct { } // endregion + +// region error + +type BizErr struct { + msg string + err error + + errFile string + errLine int + errFunc string +} + +func (e *BizErr) Error() string { + if e.err != nil { + return e.msg + ":" + e.err.Error() + } + return e.msg +} + +func (e *BizErr) Unwrap() error { + return e.err +} + +func (e *BizErr) Source() *slog.Source { + return &slog.Source{ + Function: e.errFunc, + File: e.errFile, + Line: e.errLine, + } +} + +func NewBizErr(msg string, err ...error) (biz *BizErr) { + biz = &BizErr{ + msg: msg, + } + + if len(err) > 0 { + biz.err = err[0] + } + + if env.RunMode == env.RunModeDev { + pc, file, line, ok := runtime.Caller(1) + if ok { + biz.errFile = file + biz.errLine = line + biz.errFunc = runtime.FuncForPC(pc).Name() + } + } + + return biz +} + +// endregion diff --git a/web/error.go b/web/error.go index ffc67d1..13eeeba 100644 --- a/web/error.go +++ b/web/error.go @@ -2,8 +2,9 @@ package web import ( "errors" - "gorm.io/gorm" "log/slog" + "platform/web/auth" + "platform/web/core" "reflect" "github.com/gofiber/fiber/v2" @@ -15,6 +16,8 @@ func ErrorHandler(c *fiber.Ctx, err error) error { var message = "服务器异常" var fiberErr *fiber.Error + var authErr auth.AuthenticationErr + var bizErr *core.BizErr switch { @@ -23,8 +26,23 @@ func ErrorHandler(c *fiber.Ctx, err error) error { code = fiberErr.Code message = fiberErr.Message - // gorm 错误,忽略 - case errors.Is(err, gorm.ErrForeignKeyViolated): + // 认证授权错误 + case errors.As(err, &authErr): + switch { + case errors.Is(err, auth.ErrUnauthorize): + code = fiber.StatusUnauthorized + case errors.Is(err, auth.ErrForbidden): + code = fiber.StatusForbidden + default: + code = fiber.StatusBadRequest + } + message = err.Error() + + // 服务错误 + case errors.As(err, &bizErr): + code = fiber.StatusBadRequest + message = err.Error() + slog.Debug("服务错误", slog.Any(slog.SourceKey, bizErr.Source())) // 所有未手动声明的错误类型 default: diff --git a/web/globals/orm/localdatetime.go b/web/globals/orm/localdatetime.go index 0d184ba..17bbd69 100644 --- a/web/globals/orm/localdatetime.go +++ b/web/globals/orm/localdatetime.go @@ -23,7 +23,6 @@ var formats = []string{ //goland:noinspection GoMixedReceiverTypes func (ldt *LocalDateTime) Scan(value interface{}) (err error) { var t time.Time - if strValue, ok := value.(string); ok { var timeValue time.Time for _, format := range formats { diff --git a/web/handlers/proxy.go b/web/handlers/proxy.go index ca94c9c..b6a3210 100644 --- a/web/handlers/proxy.go +++ b/web/handlers/proxy.go @@ -5,11 +5,15 @@ import ( "encoding/base32" "github.com/gofiber/fiber/v2" "log/slog" + "platform/pkg/u" auth2 "platform/web/auth" proxy2 "platform/web/domains/proxy" g "platform/web/globals" + "platform/web/globals/orm" m "platform/web/models" q "platform/web/queries" + "strings" + "time" "gorm.io/gorm/clause" ) @@ -22,8 +26,9 @@ type OnlineProxyReq struct { } type OnlineProxyResp struct { - Id int32 `json:"id"` - Secret string `json:"secret"` + Id int32 `json:"id"` + Secret string `json:"secret"` + Permits []ProxyPermit `json:"permits"` } func OnlineProxy(c *fiber.Ctx) (err error) { @@ -53,8 +58,8 @@ func OnlineProxy(c *fiber.Ctx) (err error) { var secret = base32.StdEncoding. WithPadding(base32.NoPadding). EncodeToString(secretBytes) - slog.Debug("生成随机密钥", "ip", ip, "secret", secret) + var proxy = &m.Proxy{ Name: req.Name, Version: int32(req.Version), @@ -63,7 +68,7 @@ func OnlineProxy(c *fiber.Ctx) (err error) { Secret: secret, Status: 1, } - err = q.Proxy.Debug(). + err = q.Proxy. Clauses(clause.OnConflict{ UpdateAll: true, Columns: []clause.Column{ @@ -75,10 +80,31 @@ func OnlineProxy(c *fiber.Ctx) (err error) { return err } + channels, err := q.Channel.Where( + q.Channel.ProxyID.Eq(proxy.ID), + q.Channel.Expiration.Gt(orm.LocalDateTime(time.Now())), + ).Find() + if err != nil { + return err + } + + var permits []ProxyPermit + for _, channel := range channels { + permit := ProxyPermit{ + Id: channel.EdgeID, + Expire: time.Time(channel.Expiration), + Whitelists: u.P(strings.Split(channel.Whitelists, ",")), + Username: &channel.Username, + Password: &channel.Password, + } + permits = append(permits, permit) + } + slog.Debug("注册转发服务", "ip", ip, "id", proxy.ID) return c.JSON(&OnlineProxyResp{ - Id: proxy.ID, - Secret: secret, + Id: proxy.ID, + Secret: secret, + Permits: permits, }) } @@ -158,3 +184,11 @@ func AssignProxyFwdPort(c *fiber.Ctx) (err error) { } // endregion + +type ProxyPermit struct { + Id int32 `json:"id"` + Expire time.Time `json:"expire"` + Whitelists *[]string `json:"whitelists"` + Username *string `json:"username"` + Password *string `json:"password"` +} diff --git a/web/router.go b/web/router.go index 442c715..0cd8295 100644 --- a/web/router.go +++ b/web/router.go @@ -1,6 +1,7 @@ package web import ( + "platform/web/core" "platform/web/handlers" "github.com/gofiber/fiber/v2" @@ -78,6 +79,6 @@ func ApplyRouters(app *fiber.App) { // 临时 app.Get("/test", func(c *fiber.Ctx) error { - return c.JSON(c.GetReqHeaders()) + return core.NewBizErr("测试错误") }) } diff --git a/web/services/channel.go b/web/services/channel.go index caacbbe..b0e0821 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -4,13 +4,12 @@ import ( "context" "database/sql" "fmt" - "github.com/hibiken/asynq" - "gorm.io/gen/field" "log/slog" "math" "math/rand/v2" "platform/pkg/env" "platform/pkg/u" + "platform/web/core" channel2 "platform/web/domains/channel" edge2 "platform/web/domains/edge" proxy2 "platform/web/domains/proxy" @@ -24,13 +23,15 @@ import ( "strings" "time" + "github.com/hibiken/asynq" + "gorm.io/gen/field" + "github.com/redis/go-redis/v9" ) var Channel = &channelService{} -type channelService struct { -} +type channelService struct{} // region 删除通道 @@ -46,7 +47,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error { } channels, err := tx.Channel.Where(do).Find() if err != nil { - return err + return core.NewBizErr("查找通道失败", err) } proxyMap := make(map[int32]*m.Proxy) @@ -67,7 +68,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error { // 查找资源 resources, err := tx.Resource.Where(tx.Resource.ID.In(resourceIds...)).Find() if err != nil { - return err + return core.NewBizErr("查找资源失败", err) } for _, res := range resources { resourceMap[res.ID] = res @@ -76,7 +77,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error { // 查找代理 proxies, err := tx.Proxy.Where(q.Proxy.ID.In(proxyIds...)).Find() if err != nil { - return err + return core.NewBizErr("查找代理失败", err) } for _, proxy := range proxies { proxyMap[proxy.ID] = proxy @@ -100,10 +101,10 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error { Where(q.Channel.ID.In(id...)). Update(q.Channel.DeletedAt, now) if err != nil { - return err + return core.NewBizErr("删除通道失败", err) } if result.RowsAffected != int64(len(channels)) { - return ChannelServiceErr("删除通道失败") + return core.NewBizErr("删除通道数量不匹配") } // 禁用代理端口并下线用过的节点 @@ -113,7 +114,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error { if len(shortToRemove) > 0 { err := removeShortChannelExternal(proxies, shortToRemove) if err != nil { - return err + return core.NewBizErr("提交删除通道配置失败", err) } } @@ -160,7 +161,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error } proxy, ok := proxyMap[proxyId] if !ok { - return ChannelServiceErr("代理不存在") + return core.NewBizErr("代理不存在") } var secret = strings.Split(proxy.Secret, ":") @@ -173,13 +174,13 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error // 查询节点配置 actives, err := gateway.GatewayPortActive() if err != nil { - return err + return core.NewBizErr("查询节点配置失败", err) } // 更新节点配置 err = gateway.GatewayPortConfigs(configs) if err != nil { - return err + return core.NewBizErr("提交删除通道配置失败", err) } // 下线对应节点 @@ -187,7 +188,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error for portStr, active := range actives { port, err := strconv.Atoi(portStr) if err != nil { - return err + return core.NewBizErr("端口转换失败", err) } key := uint64(proxyId)<<32 | uint64(port) if _, ok := portMap[key]; ok { @@ -200,7 +201,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error Edge: edges, }) if err != nil { - return err + return core.NewBizErr("下线节点失败", err) } } } @@ -232,7 +233,7 @@ func (s *channelService) CreateChannel( // 查找套餐 resource, err = findResource(q, resourceId, userId, count, now) if err != nil { - return err + return core.NewBizErr("查找套餐失败", err) } // 查找白名单 @@ -240,7 +241,7 @@ func (s *channelService) CreateChannel( if authType == ChannelAuthTypeIp { whitelist, err = findWhitelist(q, userId) if err != nil { - return err + return core.NewBizErr("查找白名单失败", err) } } @@ -261,13 +262,13 @@ func (s *channelService) CreateChannel( channels, err = assignLongChannels(q, userId, resourceId, count, config, filter) } if err != nil { - return err + return core.NewBizErr("分配通道失败", err) } // 保存通道开通结果 err = saveAssigns(q, resource, channels, now) if err != nil { - return err + return core.NewBizErr("保存通道失败", err) } return nil @@ -294,7 +295,7 @@ func (s *channelService) CreateChannel( asynq.ProcessIn(duration), ) if err != nil { - return nil, err + return nil, core.NewBizErr("提交异步删除通道任务失败", err) } return channels, nil @@ -352,7 +353,7 @@ func findResource(q *q.Query, resourceId int32, userId int32, count int, now tim // 检查套餐使用情况 switch info.Mode { default: - return nil, ChannelServiceErr("不支持的套餐模式") + return nil, core.NewBizErr("不支持的套餐模式") // 包时 case resource2.ModeTime: @@ -388,10 +389,10 @@ func findWhitelist(q *q.Query, userId int32) ([]string, error) { Select(q.Whitelist.Host). Scan(&whitelist) if err != nil { - return nil, err + return nil, core.NewBizErr("查询白名单失败", err) } if len(whitelist) == 0 { - return nil, ChannelServiceErr("用户没有白名单") + return nil, core.NewBizErr("没有配置白名单") } return whitelist, nil @@ -404,7 +405,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int, Where(q.Proxy.Type.Eq(int32(proxy2.TypeThirdParty))). Find() if err != nil { - return nil, err + return nil, core.NewBizErr("查找网关失败", err) } // 查找已使用的节点 @@ -424,13 +425,13 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int, q.Channel.ProxyID). Find() if err != nil { - return nil, err + return nil, core.NewBizErr("查找已使用的节点失败", err) } // 查询已配置的节点 remoteConfigs, err := g.Cloud.CloudAutoQuery() if err != nil { - return nil, err + return nil, core.NewBizErr("查询远端节点配置失败", err) } // 统计已用节点量与端口查找表 @@ -503,7 +504,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int, AutoConfig: newConfigs, }) if err != nil { - return nil, err + return nil, core.NewBizErr("提交节点配置失败", err) } slog.Debug("提交节点配置", @@ -565,7 +566,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int, newChannels = append(newChannels, newChannel) } if len(portConfigs) < acc { - return nil, ChannelServiceErr("网关端口数量到达上限,无法分配") + return nil, core.NewBizErr("网关端口数量到达上限,无法分配") } // 提交端口配置 @@ -580,7 +581,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int, ) err = gateway.GatewayPortConfigs(portConfigs) if err != nil { - return nil, err + return nil, core.NewBizErr("提交端口配置失败", err) } slog.Debug("提交端口配置", "step", time.Since(step)) @@ -588,7 +589,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int, } if len(newChannels) != count { - return nil, ChannelServiceErr("分配节点失败") + return nil, core.NewBizErr("分配节点失败") } return newChannels, nil @@ -630,7 +631,7 @@ func assignLongChannels(q *q.Query, userId int32, resourceId int32, count int, c Limit(count). Scan(&edges) if err != nil { - return nil, fmt.Errorf("查询符合条件的节点失败: %w", err) + return nil, core.NewBizErr("查询符合条件的节点失败", err) } if len(edges) == 0 { return nil, ErrEdgesNoAvailable @@ -712,7 +713,7 @@ func assignLongChannels(q *q.Query, userId int32, resourceId int32, count int, c proxy := proxies[id] err := g.Proxy.Permit(proxy.Host, proxy.Secret, reqs) if err != nil { - return nil, err + return nil, core.NewBizErr("提交端口配置失败", err) } } slog.Debug("提交端口配置", "step", time.Since(step)) @@ -741,19 +742,14 @@ func saveAssigns(q *q.Query, resource *ResourceInfo, channels []*m.Channel, now _, err = pipe.Exec(context.Background()) if err != nil { - return err + return core.NewBizErr("缓存通道数据失败", err) } // 保存通道 err = q.Channel. - Omit( - q.Channel.EdgeID, - q.Channel.EdgeHost, - q.Channel.DeletedAt, - ). Create(channels...) if err != nil { - return err + return core.NewBizErr("保存通道失败", err) } // 更新套餐使用记录 @@ -785,7 +781,7 @@ func saveAssigns(q *q.Query, resource *ResourceInfo, channels []*m.Channel, now ) } if err != nil { - return err + return core.NewBizErr("更新套餐使用记录失败", err) } return nil @@ -841,17 +837,11 @@ type ResourceInfo struct { Expire time.Time } -type ChannelServiceErr string - -func (c ChannelServiceErr) Error() string { - return string(c) -} - -const ( - ErrResourceNotExist = ChannelServiceErr("套餐不存在") - ErrResourceInvalid = ChannelServiceErr("套餐不可用") - ErrResourceExhausted = ChannelServiceErr("套餐已用完") - ErrResourceExpired = ChannelServiceErr("套餐已过期") - ErrResourceDailyLimit = ChannelServiceErr("套餐每日配额已用完") - ErrEdgesNoAvailable = ChannelServiceErr("没有可用的节点") +var ( + ErrResourceNotExist = core.NewBizErr("套餐不存在") + ErrResourceInvalid = core.NewBizErr("套餐不可用") + ErrResourceExhausted = core.NewBizErr("套餐已用完") + ErrResourceExpired = core.NewBizErr("套餐已过期") + ErrResourceDailyLimit = core.NewBizErr("套餐每日配额已用完") + ErrEdgesNoAvailable = core.NewBizErr("没有可用的节点") ) diff --git a/web/services/edge.go b/web/services/edge.go index 819e0c5..920cd10 100644 --- a/web/services/edge.go +++ b/web/services/edge.go @@ -11,7 +11,7 @@ var Edge = &edgeService{} type edgeService struct{} func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) { - do := q.Edge.Where(q.Edge.Status.Eq(1)) + do := q.Edge.Where() if filter.Prov != "" { do = do.Where(q.Edge.Prov.Eq(filter.Prov)) } diff --git a/web/services/resource.go b/web/services/resource.go index a7fe067..66c4da6 100644 --- a/web/services/resource.go +++ b/web/services/resource.go @@ -5,7 +5,6 @@ import ( "database/sql" "encoding/json" "fmt" - "github.com/shopspring/decimal" bill2 "platform/web/domains/bill" resource2 "platform/web/domains/resource" trade2 "platform/web/domains/trade" @@ -14,6 +13,8 @@ import ( m "platform/web/models" q "platform/web/queries" "time" + + "github.com/shopspring/decimal" ) var Resource = &resourceService{} @@ -75,9 +76,6 @@ func (s *resourceService) CreateResource(uid int32, now time.Time, ser *CreateRe if err != nil { return err } - if err != nil { - return err - } return nil }, &sql.TxOptions{Isolation: sql.LevelRepeatableRead})