完善错误处理逻辑,统一使用 BizErr 包装业务错误,提供打印源码跳转并返回合适的 http 状态码

This commit is contained in:
2025-05-24 12:37:16 +08:00
parent 928d78d41b
commit 1e7b5777a2
11 changed files with 203 additions and 87 deletions

11
pkg/env/env.go vendored
View File

@@ -11,19 +11,24 @@ import (
// region app // region app
const (
RunModeDev = "debug"
RunModeProd = "production"
)
var ( var (
RunMode = "debug" // debug, production RunMode = RunModeDev
) )
func loadApp() { func loadApp() {
_RunMode := os.Getenv("RUN_MODE") _RunMode := os.Getenv("RUN_MODE")
switch _RunMode { switch _RunMode {
case "debug", "production": case RunModeDev, RunModeProd:
RunMode = _RunMode RunMode = _RunMode
case "": case "":
break break
default: default:
panic("环境变量 RUN_MODE 的值只能是 debug 或 production") panic("环境变量 RUN_MODE 的值只能是 " + RunModeDev + " 或 " + RunModeProd)
} }
} }

View File

@@ -16,21 +16,19 @@ func Init() {
switch env.RunMode { switch env.RunMode {
case "debug": case "debug":
handler = tint.NewHandler(writer, &tint.Options{ handler = tint.NewHandler(writer, &tint.Options{
AddSource: true,
Level: env.LogLevel, Level: env.LogLevel,
TimeFormat: timeFormat, TimeFormat: timeFormat,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
err, ok := attr.Value.Any().(error) switch v := attr.Value.Any().(type) {
if ok { case error:
return tint.Err(err) return tint.Err(v)
} }
return attr return attr
}, },
}) })
case "production": case "production":
handler = slog.NewJSONHandler(writer, &slog.HandlerOptions{ handler = slog.NewJSONHandler(writer, &slog.HandlerOptions{
AddSource: false, Level: env.LogLevel,
Level: env.LogLevel,
ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr { ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr {
if a.Key == "time" { if a.Key == "time" {
return slog.String("time", a.Value.Time().Format(timeFormat)) return slog.String("time", a.Value.Time().Format(timeFormat))

View File

@@ -44,13 +44,13 @@ func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context,
var split = strings.Split(header, " ") var split = strings.Split(header, " ")
if len(split) != 2 { if len(split) != 2 {
slog.Debug("Authorization 头格式不正确") slog.Debug("Authorization 头格式不正确")
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") return nil, ErrUnauthorize
} }
var token = strings.TrimSpace(split[1]) var token = strings.TrimSpace(split[1])
if token == "" { if token == "" {
slog.Debug("提供的令牌为空") slog.Debug("提供的令牌为空")
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") return nil, ErrUnauthorize
} }
var auth *Context var auth *Context
@@ -61,34 +61,34 @@ func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context,
auth, err = authBearer(c.Context(), token) auth, err = authBearer(c.Context(), token)
if err != nil { if err != nil {
slog.Debug("Bearer 认证失败", "err", err) slog.Debug("Bearer 认证失败", "err", err)
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") return nil, ErrUnauthorize
} }
case "Basic": case "Basic":
if !slices.Contains(types, PayloadInternalServer) { if !slices.Contains(types, PayloadInternalServer) {
slog.Debug("禁止使用 Basic 认证方式") slog.Debug("禁止使用 Basic 认证方式")
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") return nil, ErrUnauthorize
} }
auth, err = authBasic(c.Context(), token) auth, err = authBasic(c.Context(), token)
if err != nil { if err != nil {
slog.Debug("Basic 认证失败", "err", err) slog.Debug("Basic 认证失败", "err", err)
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") return nil, ErrUnauthorize
} }
default: default:
slog.Debug("无效的认证方式", "method", split[0]) slog.Debug("无效的认证方式", "method", split[0])
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") return nil, ErrUnauthorize
} }
// 检查权限 // 检查权限
if !slices.Contains(types, auth.Payload.Type) { if !slices.Contains(types, auth.Payload.Type) {
slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type) slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type)
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限") return nil, ErrForbidden
} }
if len(permissions) > 0 && !auth.AnyPermission(permissions...) { if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions) slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions)
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限") return nil, ErrForbidden
} }
// 保存到上下文 // 保存到上下文
@@ -116,7 +116,10 @@ func authBasic(_ context.Context, token string) (*Context, error) {
// 解析 Basic 认证信息 // 解析 Basic 认证信息
var base, err = base64.RawURLEncoding.DecodeString(token) var base, err = base64.RawURLEncoding.DecodeString(token)
if err != nil { if err != nil {
return nil, errors.New("令牌格式错误,无法解析令牌") base, err = base64.URLEncoding.DecodeString(token)
if err != nil {
return nil, errors.New("令牌格式错误,无法解析令牌")
}
} }
var split = strings.Split(string(base), ":") var split = strings.Split(string(base), ":")
@@ -158,3 +161,14 @@ func authBasic(_ context.Context, token string) (*Context, error) {
Metadata: nil, Metadata: nil,
}, nil }, nil
} }
type AuthenticationErr string
func (e AuthenticationErr) Error() string {
return string(e)
}
var (
ErrUnauthorize = AuthenticationErr("令牌无效")
ErrForbidden = AuthenticationErr("没有权限")
)

View File

@@ -1,5 +1,11 @@
package core package core
import (
"log/slog"
"platform/pkg/env"
"runtime"
)
// region page // region page
type PageReq struct { type PageReq struct {
@@ -40,3 +46,56 @@ type PageResp struct {
} }
// endregion // endregion
// region error
type BizErr struct {
msg string
err error
errFile string
errLine int
errFunc string
}
func (e *BizErr) Error() string {
if e.err != nil {
return e.msg + "" + e.err.Error()
}
return e.msg
}
func (e *BizErr) Unwrap() error {
return e.err
}
func (e *BizErr) Source() *slog.Source {
return &slog.Source{
Function: e.errFunc,
File: e.errFile,
Line: e.errLine,
}
}
func NewBizErr(msg string, err ...error) (biz *BizErr) {
biz = &BizErr{
msg: msg,
}
if len(err) > 0 {
biz.err = err[0]
}
if env.RunMode == env.RunModeDev {
pc, file, line, ok := runtime.Caller(1)
if ok {
biz.errFile = file
biz.errLine = line
biz.errFunc = runtime.FuncForPC(pc).Name()
}
}
return biz
}
// endregion

View File

@@ -2,8 +2,9 @@ package web
import ( import (
"errors" "errors"
"gorm.io/gorm"
"log/slog" "log/slog"
"platform/web/auth"
"platform/web/core"
"reflect" "reflect"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@@ -15,6 +16,8 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
var message = "服务器异常" var message = "服务器异常"
var fiberErr *fiber.Error var fiberErr *fiber.Error
var authErr auth.AuthenticationErr
var bizErr *core.BizErr
switch { switch {
@@ -23,8 +26,23 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
code = fiberErr.Code code = fiberErr.Code
message = fiberErr.Message message = fiberErr.Message
// gorm 错误,忽略 // 认证授权错误
case errors.Is(err, gorm.ErrForeignKeyViolated): case errors.As(err, &authErr):
switch {
case errors.Is(err, auth.ErrUnauthorize):
code = fiber.StatusUnauthorized
case errors.Is(err, auth.ErrForbidden):
code = fiber.StatusForbidden
default:
code = fiber.StatusBadRequest
}
message = err.Error()
// 服务错误
case errors.As(err, &bizErr):
code = fiber.StatusBadRequest
message = err.Error()
slog.Debug("服务错误", slog.Any(slog.SourceKey, bizErr.Source()))
// 所有未手动声明的错误类型 // 所有未手动声明的错误类型
default: default:

View File

@@ -23,7 +23,6 @@ var formats = []string{
//goland:noinspection GoMixedReceiverTypes //goland:noinspection GoMixedReceiverTypes
func (ldt *LocalDateTime) Scan(value interface{}) (err error) { func (ldt *LocalDateTime) Scan(value interface{}) (err error) {
var t time.Time var t time.Time
if strValue, ok := value.(string); ok { if strValue, ok := value.(string); ok {
var timeValue time.Time var timeValue time.Time
for _, format := range formats { for _, format := range formats {

View File

@@ -5,11 +5,15 @@ import (
"encoding/base32" "encoding/base32"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"log/slog" "log/slog"
"platform/pkg/u"
auth2 "platform/web/auth" auth2 "platform/web/auth"
proxy2 "platform/web/domains/proxy" proxy2 "platform/web/domains/proxy"
g "platform/web/globals" g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models" m "platform/web/models"
q "platform/web/queries" q "platform/web/queries"
"strings"
"time"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
@@ -22,8 +26,9 @@ type OnlineProxyReq struct {
} }
type OnlineProxyResp struct { type OnlineProxyResp struct {
Id int32 `json:"id"` Id int32 `json:"id"`
Secret string `json:"secret"` Secret string `json:"secret"`
Permits []ProxyPermit `json:"permits"`
} }
func OnlineProxy(c *fiber.Ctx) (err error) { func OnlineProxy(c *fiber.Ctx) (err error) {
@@ -53,8 +58,8 @@ func OnlineProxy(c *fiber.Ctx) (err error) {
var secret = base32.StdEncoding. var secret = base32.StdEncoding.
WithPadding(base32.NoPadding). WithPadding(base32.NoPadding).
EncodeToString(secretBytes) EncodeToString(secretBytes)
slog.Debug("生成随机密钥", "ip", ip, "secret", secret) slog.Debug("生成随机密钥", "ip", ip, "secret", secret)
var proxy = &m.Proxy{ var proxy = &m.Proxy{
Name: req.Name, Name: req.Name,
Version: int32(req.Version), Version: int32(req.Version),
@@ -63,7 +68,7 @@ func OnlineProxy(c *fiber.Ctx) (err error) {
Secret: secret, Secret: secret,
Status: 1, Status: 1,
} }
err = q.Proxy.Debug(). err = q.Proxy.
Clauses(clause.OnConflict{ Clauses(clause.OnConflict{
UpdateAll: true, UpdateAll: true,
Columns: []clause.Column{ Columns: []clause.Column{
@@ -75,10 +80,31 @@ func OnlineProxy(c *fiber.Ctx) (err error) {
return err return err
} }
channels, err := q.Channel.Where(
q.Channel.ProxyID.Eq(proxy.ID),
q.Channel.Expiration.Gt(orm.LocalDateTime(time.Now())),
).Find()
if err != nil {
return err
}
var permits []ProxyPermit
for _, channel := range channels {
permit := ProxyPermit{
Id: channel.EdgeID,
Expire: time.Time(channel.Expiration),
Whitelists: u.P(strings.Split(channel.Whitelists, ",")),
Username: &channel.Username,
Password: &channel.Password,
}
permits = append(permits, permit)
}
slog.Debug("注册转发服务", "ip", ip, "id", proxy.ID) slog.Debug("注册转发服务", "ip", ip, "id", proxy.ID)
return c.JSON(&OnlineProxyResp{ return c.JSON(&OnlineProxyResp{
Id: proxy.ID, Id: proxy.ID,
Secret: secret, Secret: secret,
Permits: permits,
}) })
} }
@@ -158,3 +184,11 @@ func AssignProxyFwdPort(c *fiber.Ctx) (err error) {
} }
// endregion // endregion
type ProxyPermit struct {
Id int32 `json:"id"`
Expire time.Time `json:"expire"`
Whitelists *[]string `json:"whitelists"`
Username *string `json:"username"`
Password *string `json:"password"`
}

View File

@@ -1,6 +1,7 @@
package web package web
import ( import (
"platform/web/core"
"platform/web/handlers" "platform/web/handlers"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
@@ -78,6 +79,6 @@ func ApplyRouters(app *fiber.App) {
// 临时 // 临时
app.Get("/test", func(c *fiber.Ctx) error { app.Get("/test", func(c *fiber.Ctx) error {
return c.JSON(c.GetReqHeaders()) return core.NewBizErr("测试错误")
}) })
} }

View File

@@ -4,13 +4,12 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/hibiken/asynq"
"gorm.io/gen/field"
"log/slog" "log/slog"
"math" "math"
"math/rand/v2" "math/rand/v2"
"platform/pkg/env" "platform/pkg/env"
"platform/pkg/u" "platform/pkg/u"
"platform/web/core"
channel2 "platform/web/domains/channel" channel2 "platform/web/domains/channel"
edge2 "platform/web/domains/edge" edge2 "platform/web/domains/edge"
proxy2 "platform/web/domains/proxy" proxy2 "platform/web/domains/proxy"
@@ -24,13 +23,15 @@ import (
"strings" "strings"
"time" "time"
"github.com/hibiken/asynq"
"gorm.io/gen/field"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
var Channel = &channelService{} var Channel = &channelService{}
type channelService struct { type channelService struct{}
}
// region 删除通道 // region 删除通道
@@ -46,7 +47,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
} }
channels, err := tx.Channel.Where(do).Find() channels, err := tx.Channel.Where(do).Find()
if err != nil { if err != nil {
return err return core.NewBizErr("查找通道失败", err)
} }
proxyMap := make(map[int32]*m.Proxy) proxyMap := make(map[int32]*m.Proxy)
@@ -67,7 +68,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
// 查找资源 // 查找资源
resources, err := tx.Resource.Where(tx.Resource.ID.In(resourceIds...)).Find() resources, err := tx.Resource.Where(tx.Resource.ID.In(resourceIds...)).Find()
if err != nil { if err != nil {
return err return core.NewBizErr("查找资源失败", err)
} }
for _, res := range resources { for _, res := range resources {
resourceMap[res.ID] = res resourceMap[res.ID] = res
@@ -76,7 +77,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
// 查找代理 // 查找代理
proxies, err := tx.Proxy.Where(q.Proxy.ID.In(proxyIds...)).Find() proxies, err := tx.Proxy.Where(q.Proxy.ID.In(proxyIds...)).Find()
if err != nil { if err != nil {
return err return core.NewBizErr("查找代理失败", err)
} }
for _, proxy := range proxies { for _, proxy := range proxies {
proxyMap[proxy.ID] = proxy proxyMap[proxy.ID] = proxy
@@ -100,10 +101,10 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
Where(q.Channel.ID.In(id...)). Where(q.Channel.ID.In(id...)).
Update(q.Channel.DeletedAt, now) Update(q.Channel.DeletedAt, now)
if err != nil { if err != nil {
return err return core.NewBizErr("删除通道失败", err)
} }
if result.RowsAffected != int64(len(channels)) { if result.RowsAffected != int64(len(channels)) {
return ChannelServiceErr("删除通道失败") return core.NewBizErr("删除通道数量不匹配")
} }
// 禁用代理端口并下线用过的节点 // 禁用代理端口并下线用过的节点
@@ -113,7 +114,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
if len(shortToRemove) > 0 { if len(shortToRemove) > 0 {
err := removeShortChannelExternal(proxies, shortToRemove) err := removeShortChannelExternal(proxies, shortToRemove)
if err != nil { if err != nil {
return err return core.NewBizErr("提交删除通道配置失败", err)
} }
} }
@@ -160,7 +161,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
} }
proxy, ok := proxyMap[proxyId] proxy, ok := proxyMap[proxyId]
if !ok { if !ok {
return ChannelServiceErr("代理不存在") return core.NewBizErr("代理不存在")
} }
var secret = strings.Split(proxy.Secret, ":") var secret = strings.Split(proxy.Secret, ":")
@@ -173,13 +174,13 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
// 查询节点配置 // 查询节点配置
actives, err := gateway.GatewayPortActive() actives, err := gateway.GatewayPortActive()
if err != nil { if err != nil {
return err return core.NewBizErr("查询节点配置失败", err)
} }
// 更新节点配置 // 更新节点配置
err = gateway.GatewayPortConfigs(configs) err = gateway.GatewayPortConfigs(configs)
if err != nil { if err != nil {
return err return core.NewBizErr("提交删除通道配置失败", err)
} }
// 下线对应节点 // 下线对应节点
@@ -187,7 +188,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
for portStr, active := range actives { for portStr, active := range actives {
port, err := strconv.Atoi(portStr) port, err := strconv.Atoi(portStr)
if err != nil { if err != nil {
return err return core.NewBizErr("端口转换失败", err)
} }
key := uint64(proxyId)<<32 | uint64(port) key := uint64(proxyId)<<32 | uint64(port)
if _, ok := portMap[key]; ok { if _, ok := portMap[key]; ok {
@@ -200,7 +201,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
Edge: edges, Edge: edges,
}) })
if err != nil { if err != nil {
return err return core.NewBizErr("下线节点失败", err)
} }
} }
} }
@@ -232,7 +233,7 @@ func (s *channelService) CreateChannel(
// 查找套餐 // 查找套餐
resource, err = findResource(q, resourceId, userId, count, now) resource, err = findResource(q, resourceId, userId, count, now)
if err != nil { if err != nil {
return err return core.NewBizErr("查找套餐失败", err)
} }
// 查找白名单 // 查找白名单
@@ -240,7 +241,7 @@ func (s *channelService) CreateChannel(
if authType == ChannelAuthTypeIp { if authType == ChannelAuthTypeIp {
whitelist, err = findWhitelist(q, userId) whitelist, err = findWhitelist(q, userId)
if err != nil { if err != nil {
return err return core.NewBizErr("查找白名单失败", err)
} }
} }
@@ -261,13 +262,13 @@ func (s *channelService) CreateChannel(
channels, err = assignLongChannels(q, userId, resourceId, count, config, filter) channels, err = assignLongChannels(q, userId, resourceId, count, config, filter)
} }
if err != nil { if err != nil {
return err return core.NewBizErr("分配通道失败", err)
} }
// 保存通道开通结果 // 保存通道开通结果
err = saveAssigns(q, resource, channels, now) err = saveAssigns(q, resource, channels, now)
if err != nil { if err != nil {
return err return core.NewBizErr("保存通道失败", err)
} }
return nil return nil
@@ -294,7 +295,7 @@ func (s *channelService) CreateChannel(
asynq.ProcessIn(duration), asynq.ProcessIn(duration),
) )
if err != nil { if err != nil {
return nil, err return nil, core.NewBizErr("提交异步删除通道任务失败", err)
} }
return channels, nil return channels, nil
@@ -352,7 +353,7 @@ func findResource(q *q.Query, resourceId int32, userId int32, count int, now tim
// 检查套餐使用情况 // 检查套餐使用情况
switch info.Mode { switch info.Mode {
default: default:
return nil, ChannelServiceErr("不支持的套餐模式") return nil, core.NewBizErr("不支持的套餐模式")
// 包时 // 包时
case resource2.ModeTime: case resource2.ModeTime:
@@ -388,10 +389,10 @@ func findWhitelist(q *q.Query, userId int32) ([]string, error) {
Select(q.Whitelist.Host). Select(q.Whitelist.Host).
Scan(&whitelist) Scan(&whitelist)
if err != nil { if err != nil {
return nil, err return nil, core.NewBizErr("查询白名单失败", err)
} }
if len(whitelist) == 0 { if len(whitelist) == 0 {
return nil, ChannelServiceErr("用户没有白名单") return nil, core.NewBizErr("没有配置白名单")
} }
return whitelist, nil return whitelist, nil
@@ -404,7 +405,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
Where(q.Proxy.Type.Eq(int32(proxy2.TypeThirdParty))). Where(q.Proxy.Type.Eq(int32(proxy2.TypeThirdParty))).
Find() Find()
if err != nil { if err != nil {
return nil, err return nil, core.NewBizErr("查找网关失败", err)
} }
// 查找已使用的节点 // 查找已使用的节点
@@ -424,13 +425,13 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
q.Channel.ProxyID). q.Channel.ProxyID).
Find() Find()
if err != nil { if err != nil {
return nil, err return nil, core.NewBizErr("查找已使用的节点失败", err)
} }
// 查询已配置的节点 // 查询已配置的节点
remoteConfigs, err := g.Cloud.CloudAutoQuery() remoteConfigs, err := g.Cloud.CloudAutoQuery()
if err != nil { if err != nil {
return nil, err return nil, core.NewBizErr("查询远端节点配置失败", err)
} }
// 统计已用节点量与端口查找表 // 统计已用节点量与端口查找表
@@ -503,7 +504,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
AutoConfig: newConfigs, AutoConfig: newConfigs,
}) })
if err != nil { if err != nil {
return nil, err return nil, core.NewBizErr("提交节点配置失败", err)
} }
slog.Debug("提交节点配置", slog.Debug("提交节点配置",
@@ -565,7 +566,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
newChannels = append(newChannels, newChannel) newChannels = append(newChannels, newChannel)
} }
if len(portConfigs) < acc { if len(portConfigs) < acc {
return nil, ChannelServiceErr("网关端口数量到达上限,无法分配") return nil, core.NewBizErr("网关端口数量到达上限,无法分配")
} }
// 提交端口配置 // 提交端口配置
@@ -580,7 +581,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
) )
err = gateway.GatewayPortConfigs(portConfigs) err = gateway.GatewayPortConfigs(portConfigs)
if err != nil { if err != nil {
return nil, err return nil, core.NewBizErr("提交端口配置失败", err)
} }
slog.Debug("提交端口配置", "step", time.Since(step)) slog.Debug("提交端口配置", "step", time.Since(step))
@@ -588,7 +589,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
} }
if len(newChannels) != count { if len(newChannels) != count {
return nil, ChannelServiceErr("分配节点失败") return nil, core.NewBizErr("分配节点失败")
} }
return newChannels, nil return newChannels, nil
@@ -630,7 +631,7 @@ func assignLongChannels(q *q.Query, userId int32, resourceId int32, count int, c
Limit(count). Limit(count).
Scan(&edges) Scan(&edges)
if err != nil { if err != nil {
return nil, fmt.Errorf("查询符合条件的节点失败: %w", err) return nil, core.NewBizErr("查询符合条件的节点失败", err)
} }
if len(edges) == 0 { if len(edges) == 0 {
return nil, ErrEdgesNoAvailable return nil, ErrEdgesNoAvailable
@@ -712,7 +713,7 @@ func assignLongChannels(q *q.Query, userId int32, resourceId int32, count int, c
proxy := proxies[id] proxy := proxies[id]
err := g.Proxy.Permit(proxy.Host, proxy.Secret, reqs) err := g.Proxy.Permit(proxy.Host, proxy.Secret, reqs)
if err != nil { if err != nil {
return nil, err return nil, core.NewBizErr("提交端口配置失败", err)
} }
} }
slog.Debug("提交端口配置", "step", time.Since(step)) slog.Debug("提交端口配置", "step", time.Since(step))
@@ -741,19 +742,14 @@ func saveAssigns(q *q.Query, resource *ResourceInfo, channels []*m.Channel, now
_, err = pipe.Exec(context.Background()) _, err = pipe.Exec(context.Background())
if err != nil { if err != nil {
return err return core.NewBizErr("缓存通道数据失败", err)
} }
// 保存通道 // 保存通道
err = q.Channel. err = q.Channel.
Omit(
q.Channel.EdgeID,
q.Channel.EdgeHost,
q.Channel.DeletedAt,
).
Create(channels...) Create(channels...)
if err != nil { if err != nil {
return err return core.NewBizErr("保存通道失败", err)
} }
// 更新套餐使用记录 // 更新套餐使用记录
@@ -785,7 +781,7 @@ func saveAssigns(q *q.Query, resource *ResourceInfo, channels []*m.Channel, now
) )
} }
if err != nil { if err != nil {
return err return core.NewBizErr("更新套餐使用记录失败", err)
} }
return nil return nil
@@ -841,17 +837,11 @@ type ResourceInfo struct {
Expire time.Time Expire time.Time
} }
type ChannelServiceErr string var (
ErrResourceNotExist = core.NewBizErr("套餐不存在")
func (c ChannelServiceErr) Error() string { ErrResourceInvalid = core.NewBizErr("套餐不可用")
return string(c) ErrResourceExhausted = core.NewBizErr("套餐已用完")
} ErrResourceExpired = core.NewBizErr("套餐已过期")
ErrResourceDailyLimit = core.NewBizErr("套餐每日配额已用完")
const ( ErrEdgesNoAvailable = core.NewBizErr("没有可用的节点")
ErrResourceNotExist = ChannelServiceErr("套餐不存在")
ErrResourceInvalid = ChannelServiceErr("套餐不可用")
ErrResourceExhausted = ChannelServiceErr("套餐已用完")
ErrResourceExpired = ChannelServiceErr("套餐已过期")
ErrResourceDailyLimit = ChannelServiceErr("套餐每日配额已用完")
ErrEdgesNoAvailable = ChannelServiceErr("没有可用的节点")
) )

View File

@@ -11,7 +11,7 @@ var Edge = &edgeService{}
type edgeService struct{} type edgeService struct{}
func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) { func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) {
do := q.Edge.Where(q.Edge.Status.Eq(1)) do := q.Edge.Where()
if filter.Prov != "" { if filter.Prov != "" {
do = do.Where(q.Edge.Prov.Eq(filter.Prov)) do = do.Where(q.Edge.Prov.Eq(filter.Prov))
} }

View File

@@ -5,7 +5,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/shopspring/decimal"
bill2 "platform/web/domains/bill" bill2 "platform/web/domains/bill"
resource2 "platform/web/domains/resource" resource2 "platform/web/domains/resource"
trade2 "platform/web/domains/trade" trade2 "platform/web/domains/trade"
@@ -14,6 +13,8 @@ import (
m "platform/web/models" m "platform/web/models"
q "platform/web/queries" q "platform/web/queries"
"time" "time"
"github.com/shopspring/decimal"
) )
var Resource = &resourceService{} var Resource = &resourceService{}
@@ -75,9 +76,6 @@ func (s *resourceService) CreateResource(uid int32, now time.Time, ser *CreateRe
if err != nil { if err != nil {
return err return err
} }
if err != nil {
return err
}
return nil return nil
}, &sql.TxOptions{Isolation: sql.LevelRepeatableRead}) }, &sql.TxOptions{Isolation: sql.LevelRepeatableRead})