完善错误处理逻辑,统一使用 BizErr 包装业务错误,提供打印源码跳转并返回合适的 http 状态码

This commit is contained in:
2025-05-24 12:37:16 +08:00
parent 928d78d41b
commit 1e7b5777a2
11 changed files with 203 additions and 87 deletions

View File

@@ -44,13 +44,13 @@ func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context,
var split = strings.Split(header, " ")
if len(split) != 2 {
slog.Debug("Authorization 头格式不正确")
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
return nil, ErrUnauthorize
}
var token = strings.TrimSpace(split[1])
if token == "" {
slog.Debug("提供的令牌为空")
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
return nil, ErrUnauthorize
}
var auth *Context
@@ -61,34 +61,34 @@ func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context,
auth, err = authBearer(c.Context(), token)
if err != nil {
slog.Debug("Bearer 认证失败", "err", err)
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
return nil, ErrUnauthorize
}
case "Basic":
if !slices.Contains(types, PayloadInternalServer) {
slog.Debug("禁止使用 Basic 认证方式")
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
return nil, ErrUnauthorize
}
auth, err = authBasic(c.Context(), token)
if err != nil {
slog.Debug("Basic 认证失败", "err", err)
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
return nil, ErrUnauthorize
}
default:
slog.Debug("无效的认证方式", "method", split[0])
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
return nil, ErrUnauthorize
}
// 检查权限
if !slices.Contains(types, auth.Payload.Type) {
slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type)
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
return nil, ErrForbidden
}
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions)
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
return nil, ErrForbidden
}
// 保存到上下文
@@ -116,7 +116,10 @@ func authBasic(_ context.Context, token string) (*Context, error) {
// 解析 Basic 认证信息
var base, err = base64.RawURLEncoding.DecodeString(token)
if err != nil {
return nil, errors.New("令牌格式错误,无法解析令牌")
base, err = base64.URLEncoding.DecodeString(token)
if err != nil {
return nil, errors.New("令牌格式错误,无法解析令牌")
}
}
var split = strings.Split(string(base), ":")
@@ -158,3 +161,14 @@ func authBasic(_ context.Context, token string) (*Context, error) {
Metadata: nil,
}, nil
}
type AuthenticationErr string
func (e AuthenticationErr) Error() string {
return string(e)
}
var (
ErrUnauthorize = AuthenticationErr("令牌无效")
ErrForbidden = AuthenticationErr("没有权限")
)

View File

@@ -1,5 +1,11 @@
package core
import (
"log/slog"
"platform/pkg/env"
"runtime"
)
// region page
type PageReq struct {
@@ -40,3 +46,56 @@ type PageResp struct {
}
// endregion
// region error
type BizErr struct {
msg string
err error
errFile string
errLine int
errFunc string
}
func (e *BizErr) Error() string {
if e.err != nil {
return e.msg + "" + e.err.Error()
}
return e.msg
}
func (e *BizErr) Unwrap() error {
return e.err
}
func (e *BizErr) Source() *slog.Source {
return &slog.Source{
Function: e.errFunc,
File: e.errFile,
Line: e.errLine,
}
}
func NewBizErr(msg string, err ...error) (biz *BizErr) {
biz = &BizErr{
msg: msg,
}
if len(err) > 0 {
biz.err = err[0]
}
if env.RunMode == env.RunModeDev {
pc, file, line, ok := runtime.Caller(1)
if ok {
biz.errFile = file
biz.errLine = line
biz.errFunc = runtime.FuncForPC(pc).Name()
}
}
return biz
}
// endregion

View File

@@ -2,8 +2,9 @@ package web
import (
"errors"
"gorm.io/gorm"
"log/slog"
"platform/web/auth"
"platform/web/core"
"reflect"
"github.com/gofiber/fiber/v2"
@@ -15,6 +16,8 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
var message = "服务器异常"
var fiberErr *fiber.Error
var authErr auth.AuthenticationErr
var bizErr *core.BizErr
switch {
@@ -23,8 +26,23 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
code = fiberErr.Code
message = fiberErr.Message
// gorm 错误,忽略
case errors.Is(err, gorm.ErrForeignKeyViolated):
// 认证授权错误
case errors.As(err, &authErr):
switch {
case errors.Is(err, auth.ErrUnauthorize):
code = fiber.StatusUnauthorized
case errors.Is(err, auth.ErrForbidden):
code = fiber.StatusForbidden
default:
code = fiber.StatusBadRequest
}
message = err.Error()
// 服务错误
case errors.As(err, &bizErr):
code = fiber.StatusBadRequest
message = err.Error()
slog.Debug("服务错误", slog.Any(slog.SourceKey, bizErr.Source()))
// 所有未手动声明的错误类型
default:

View File

@@ -23,7 +23,6 @@ var formats = []string{
//goland:noinspection GoMixedReceiverTypes
func (ldt *LocalDateTime) Scan(value interface{}) (err error) {
var t time.Time
if strValue, ok := value.(string); ok {
var timeValue time.Time
for _, format := range formats {

View File

@@ -5,11 +5,15 @@ import (
"encoding/base32"
"github.com/gofiber/fiber/v2"
"log/slog"
"platform/pkg/u"
auth2 "platform/web/auth"
proxy2 "platform/web/domains/proxy"
g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models"
q "platform/web/queries"
"strings"
"time"
"gorm.io/gorm/clause"
)
@@ -22,8 +26,9 @@ type OnlineProxyReq struct {
}
type OnlineProxyResp struct {
Id int32 `json:"id"`
Secret string `json:"secret"`
Id int32 `json:"id"`
Secret string `json:"secret"`
Permits []ProxyPermit `json:"permits"`
}
func OnlineProxy(c *fiber.Ctx) (err error) {
@@ -53,8 +58,8 @@ func OnlineProxy(c *fiber.Ctx) (err error) {
var secret = base32.StdEncoding.
WithPadding(base32.NoPadding).
EncodeToString(secretBytes)
slog.Debug("生成随机密钥", "ip", ip, "secret", secret)
var proxy = &m.Proxy{
Name: req.Name,
Version: int32(req.Version),
@@ -63,7 +68,7 @@ func OnlineProxy(c *fiber.Ctx) (err error) {
Secret: secret,
Status: 1,
}
err = q.Proxy.Debug().
err = q.Proxy.
Clauses(clause.OnConflict{
UpdateAll: true,
Columns: []clause.Column{
@@ -75,10 +80,31 @@ func OnlineProxy(c *fiber.Ctx) (err error) {
return err
}
channels, err := q.Channel.Where(
q.Channel.ProxyID.Eq(proxy.ID),
q.Channel.Expiration.Gt(orm.LocalDateTime(time.Now())),
).Find()
if err != nil {
return err
}
var permits []ProxyPermit
for _, channel := range channels {
permit := ProxyPermit{
Id: channel.EdgeID,
Expire: time.Time(channel.Expiration),
Whitelists: u.P(strings.Split(channel.Whitelists, ",")),
Username: &channel.Username,
Password: &channel.Password,
}
permits = append(permits, permit)
}
slog.Debug("注册转发服务", "ip", ip, "id", proxy.ID)
return c.JSON(&OnlineProxyResp{
Id: proxy.ID,
Secret: secret,
Id: proxy.ID,
Secret: secret,
Permits: permits,
})
}
@@ -158,3 +184,11 @@ func AssignProxyFwdPort(c *fiber.Ctx) (err error) {
}
// endregion
type ProxyPermit struct {
Id int32 `json:"id"`
Expire time.Time `json:"expire"`
Whitelists *[]string `json:"whitelists"`
Username *string `json:"username"`
Password *string `json:"password"`
}

View File

@@ -1,6 +1,7 @@
package web
import (
"platform/web/core"
"platform/web/handlers"
"github.com/gofiber/fiber/v2"
@@ -78,6 +79,6 @@ func ApplyRouters(app *fiber.App) {
// 临时
app.Get("/test", func(c *fiber.Ctx) error {
return c.JSON(c.GetReqHeaders())
return core.NewBizErr("测试错误")
})
}

View File

@@ -4,13 +4,12 @@ import (
"context"
"database/sql"
"fmt"
"github.com/hibiken/asynq"
"gorm.io/gen/field"
"log/slog"
"math"
"math/rand/v2"
"platform/pkg/env"
"platform/pkg/u"
"platform/web/core"
channel2 "platform/web/domains/channel"
edge2 "platform/web/domains/edge"
proxy2 "platform/web/domains/proxy"
@@ -24,13 +23,15 @@ import (
"strings"
"time"
"github.com/hibiken/asynq"
"gorm.io/gen/field"
"github.com/redis/go-redis/v9"
)
var Channel = &channelService{}
type channelService struct {
}
type channelService struct{}
// region 删除通道
@@ -46,7 +47,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
}
channels, err := tx.Channel.Where(do).Find()
if err != nil {
return err
return core.NewBizErr("查找通道失败", err)
}
proxyMap := make(map[int32]*m.Proxy)
@@ -67,7 +68,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
// 查找资源
resources, err := tx.Resource.Where(tx.Resource.ID.In(resourceIds...)).Find()
if err != nil {
return err
return core.NewBizErr("查找资源失败", err)
}
for _, res := range resources {
resourceMap[res.ID] = res
@@ -76,7 +77,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
// 查找代理
proxies, err := tx.Proxy.Where(q.Proxy.ID.In(proxyIds...)).Find()
if err != nil {
return err
return core.NewBizErr("查找代理失败", err)
}
for _, proxy := range proxies {
proxyMap[proxy.ID] = proxy
@@ -100,10 +101,10 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
Where(q.Channel.ID.In(id...)).
Update(q.Channel.DeletedAt, now)
if err != nil {
return err
return core.NewBizErr("删除通道失败", err)
}
if result.RowsAffected != int64(len(channels)) {
return ChannelServiceErr("删除通道失败")
return core.NewBizErr("删除通道数量不匹配")
}
// 禁用代理端口并下线用过的节点
@@ -113,7 +114,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
if len(shortToRemove) > 0 {
err := removeShortChannelExternal(proxies, shortToRemove)
if err != nil {
return err
return core.NewBizErr("提交删除通道配置失败", err)
}
}
@@ -160,7 +161,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
}
proxy, ok := proxyMap[proxyId]
if !ok {
return ChannelServiceErr("代理不存在")
return core.NewBizErr("代理不存在")
}
var secret = strings.Split(proxy.Secret, ":")
@@ -173,13 +174,13 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
// 查询节点配置
actives, err := gateway.GatewayPortActive()
if err != nil {
return err
return core.NewBizErr("查询节点配置失败", err)
}
// 更新节点配置
err = gateway.GatewayPortConfigs(configs)
if err != nil {
return err
return core.NewBizErr("提交删除通道配置失败", err)
}
// 下线对应节点
@@ -187,7 +188,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
for portStr, active := range actives {
port, err := strconv.Atoi(portStr)
if err != nil {
return err
return core.NewBizErr("端口转换失败", err)
}
key := uint64(proxyId)<<32 | uint64(port)
if _, ok := portMap[key]; ok {
@@ -200,7 +201,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
Edge: edges,
})
if err != nil {
return err
return core.NewBizErr("下线节点失败", err)
}
}
}
@@ -232,7 +233,7 @@ func (s *channelService) CreateChannel(
// 查找套餐
resource, err = findResource(q, resourceId, userId, count, now)
if err != nil {
return err
return core.NewBizErr("查找套餐失败", err)
}
// 查找白名单
@@ -240,7 +241,7 @@ func (s *channelService) CreateChannel(
if authType == ChannelAuthTypeIp {
whitelist, err = findWhitelist(q, userId)
if err != nil {
return err
return core.NewBizErr("查找白名单失败", err)
}
}
@@ -261,13 +262,13 @@ func (s *channelService) CreateChannel(
channels, err = assignLongChannels(q, userId, resourceId, count, config, filter)
}
if err != nil {
return err
return core.NewBizErr("分配通道失败", err)
}
// 保存通道开通结果
err = saveAssigns(q, resource, channels, now)
if err != nil {
return err
return core.NewBizErr("保存通道失败", err)
}
return nil
@@ -294,7 +295,7 @@ func (s *channelService) CreateChannel(
asynq.ProcessIn(duration),
)
if err != nil {
return nil, err
return nil, core.NewBizErr("提交异步删除通道任务失败", err)
}
return channels, nil
@@ -352,7 +353,7 @@ func findResource(q *q.Query, resourceId int32, userId int32, count int, now tim
// 检查套餐使用情况
switch info.Mode {
default:
return nil, ChannelServiceErr("不支持的套餐模式")
return nil, core.NewBizErr("不支持的套餐模式")
// 包时
case resource2.ModeTime:
@@ -388,10 +389,10 @@ func findWhitelist(q *q.Query, userId int32) ([]string, error) {
Select(q.Whitelist.Host).
Scan(&whitelist)
if err != nil {
return nil, err
return nil, core.NewBizErr("查询白名单失败", err)
}
if len(whitelist) == 0 {
return nil, ChannelServiceErr("用户没有白名单")
return nil, core.NewBizErr("没有配置白名单")
}
return whitelist, nil
@@ -404,7 +405,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
Where(q.Proxy.Type.Eq(int32(proxy2.TypeThirdParty))).
Find()
if err != nil {
return nil, err
return nil, core.NewBizErr("查找网关失败", err)
}
// 查找已使用的节点
@@ -424,13 +425,13 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
q.Channel.ProxyID).
Find()
if err != nil {
return nil, err
return nil, core.NewBizErr("查找已使用的节点失败", err)
}
// 查询已配置的节点
remoteConfigs, err := g.Cloud.CloudAutoQuery()
if err != nil {
return nil, err
return nil, core.NewBizErr("查询远端节点配置失败", err)
}
// 统计已用节点量与端口查找表
@@ -503,7 +504,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
AutoConfig: newConfigs,
})
if err != nil {
return nil, err
return nil, core.NewBizErr("提交节点配置失败", err)
}
slog.Debug("提交节点配置",
@@ -565,7 +566,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
newChannels = append(newChannels, newChannel)
}
if len(portConfigs) < acc {
return nil, ChannelServiceErr("网关端口数量到达上限,无法分配")
return nil, core.NewBizErr("网关端口数量到达上限,无法分配")
}
// 提交端口配置
@@ -580,7 +581,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
)
err = gateway.GatewayPortConfigs(portConfigs)
if err != nil {
return nil, err
return nil, core.NewBizErr("提交端口配置失败", err)
}
slog.Debug("提交端口配置", "step", time.Since(step))
@@ -588,7 +589,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
}
if len(newChannels) != count {
return nil, ChannelServiceErr("分配节点失败")
return nil, core.NewBizErr("分配节点失败")
}
return newChannels, nil
@@ -630,7 +631,7 @@ func assignLongChannels(q *q.Query, userId int32, resourceId int32, count int, c
Limit(count).
Scan(&edges)
if err != nil {
return nil, fmt.Errorf("查询符合条件的节点失败: %w", err)
return nil, core.NewBizErr("查询符合条件的节点失败", err)
}
if len(edges) == 0 {
return nil, ErrEdgesNoAvailable
@@ -712,7 +713,7 @@ func assignLongChannels(q *q.Query, userId int32, resourceId int32, count int, c
proxy := proxies[id]
err := g.Proxy.Permit(proxy.Host, proxy.Secret, reqs)
if err != nil {
return nil, err
return nil, core.NewBizErr("提交端口配置失败", err)
}
}
slog.Debug("提交端口配置", "step", time.Since(step))
@@ -741,19 +742,14 @@ func saveAssigns(q *q.Query, resource *ResourceInfo, channels []*m.Channel, now
_, err = pipe.Exec(context.Background())
if err != nil {
return err
return core.NewBizErr("缓存通道数据失败", err)
}
// 保存通道
err = q.Channel.
Omit(
q.Channel.EdgeID,
q.Channel.EdgeHost,
q.Channel.DeletedAt,
).
Create(channels...)
if err != nil {
return err
return core.NewBizErr("保存通道失败", err)
}
// 更新套餐使用记录
@@ -785,7 +781,7 @@ func saveAssigns(q *q.Query, resource *ResourceInfo, channels []*m.Channel, now
)
}
if err != nil {
return err
return core.NewBizErr("更新套餐使用记录失败", err)
}
return nil
@@ -841,17 +837,11 @@ type ResourceInfo struct {
Expire time.Time
}
type ChannelServiceErr string
func (c ChannelServiceErr) Error() string {
return string(c)
}
const (
ErrResourceNotExist = ChannelServiceErr("套餐不存在")
ErrResourceInvalid = ChannelServiceErr("套餐不可用")
ErrResourceExhausted = ChannelServiceErr("套餐已用完")
ErrResourceExpired = ChannelServiceErr("套餐已过期")
ErrResourceDailyLimit = ChannelServiceErr("套餐每日配额已用完")
ErrEdgesNoAvailable = ChannelServiceErr("没有可用的节点")
var (
ErrResourceNotExist = core.NewBizErr("套餐不存在")
ErrResourceInvalid = core.NewBizErr("套餐不可用")
ErrResourceExhausted = core.NewBizErr("套餐已用完")
ErrResourceExpired = core.NewBizErr("套餐已过期")
ErrResourceDailyLimit = core.NewBizErr("套餐每日配额已用完")
ErrEdgesNoAvailable = core.NewBizErr("没有可用的节点")
)

View File

@@ -11,7 +11,7 @@ var Edge = &edgeService{}
type edgeService struct{}
func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) {
do := q.Edge.Where(q.Edge.Status.Eq(1))
do := q.Edge.Where()
if filter.Prov != "" {
do = do.Where(q.Edge.Prov.Eq(filter.Prov))
}

View File

@@ -5,7 +5,6 @@ import (
"database/sql"
"encoding/json"
"fmt"
"github.com/shopspring/decimal"
bill2 "platform/web/domains/bill"
resource2 "platform/web/domains/resource"
trade2 "platform/web/domains/trade"
@@ -14,6 +13,8 @@ import (
m "platform/web/models"
q "platform/web/queries"
"time"
"github.com/shopspring/decimal"
)
var Resource = &resourceService{}
@@ -75,9 +76,6 @@ func (s *resourceService) CreateResource(uid int32, now time.Time, ser *CreateRe
if err != nil {
return err
}
if err != nil {
return err
}
return nil
}, &sql.TxOptions{Isolation: sql.LevelRepeatableRead})