完善错误处理逻辑,统一使用 BizErr 包装业务错误,提供打印源码跳转并返回合适的 http 状态码
This commit is contained in:
11
pkg/env/env.go
vendored
11
pkg/env/env.go
vendored
@@ -11,19 +11,24 @@ import (
|
|||||||
|
|
||||||
// region app
|
// region app
|
||||||
|
|
||||||
|
const (
|
||||||
|
RunModeDev = "debug"
|
||||||
|
RunModeProd = "production"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
RunMode = "debug" // debug, production
|
RunMode = RunModeDev
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadApp() {
|
func loadApp() {
|
||||||
_RunMode := os.Getenv("RUN_MODE")
|
_RunMode := os.Getenv("RUN_MODE")
|
||||||
switch _RunMode {
|
switch _RunMode {
|
||||||
case "debug", "production":
|
case RunModeDev, RunModeProd:
|
||||||
RunMode = _RunMode
|
RunMode = _RunMode
|
||||||
case "":
|
case "":
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
panic("环境变量 RUN_MODE 的值只能是 debug 或 production")
|
panic("环境变量 RUN_MODE 的值只能是 " + RunModeDev + " 或 " + RunModeProd)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,20 +16,18 @@ func Init() {
|
|||||||
switch env.RunMode {
|
switch env.RunMode {
|
||||||
case "debug":
|
case "debug":
|
||||||
handler = tint.NewHandler(writer, &tint.Options{
|
handler = tint.NewHandler(writer, &tint.Options{
|
||||||
AddSource: true,
|
|
||||||
Level: env.LogLevel,
|
Level: env.LogLevel,
|
||||||
TimeFormat: timeFormat,
|
TimeFormat: timeFormat,
|
||||||
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
|
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
|
||||||
err, ok := attr.Value.Any().(error)
|
switch v := attr.Value.Any().(type) {
|
||||||
if ok {
|
case error:
|
||||||
return tint.Err(err)
|
return tint.Err(v)
|
||||||
}
|
}
|
||||||
return attr
|
return attr
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
case "production":
|
case "production":
|
||||||
handler = slog.NewJSONHandler(writer, &slog.HandlerOptions{
|
handler = slog.NewJSONHandler(writer, &slog.HandlerOptions{
|
||||||
AddSource: false,
|
|
||||||
Level: env.LogLevel,
|
Level: env.LogLevel,
|
||||||
ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr {
|
ReplaceAttr: func(_ []string, a slog.Attr) slog.Attr {
|
||||||
if a.Key == "time" {
|
if a.Key == "time" {
|
||||||
|
|||||||
@@ -44,13 +44,13 @@ func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context,
|
|||||||
var split = strings.Split(header, " ")
|
var split = strings.Split(header, " ")
|
||||||
if len(split) != 2 {
|
if len(split) != 2 {
|
||||||
slog.Debug("Authorization 头格式不正确")
|
slog.Debug("Authorization 头格式不正确")
|
||||||
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
|
return nil, ErrUnauthorize
|
||||||
}
|
}
|
||||||
|
|
||||||
var token = strings.TrimSpace(split[1])
|
var token = strings.TrimSpace(split[1])
|
||||||
if token == "" {
|
if token == "" {
|
||||||
slog.Debug("提供的令牌为空")
|
slog.Debug("提供的令牌为空")
|
||||||
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
|
return nil, ErrUnauthorize
|
||||||
}
|
}
|
||||||
|
|
||||||
var auth *Context
|
var auth *Context
|
||||||
@@ -61,34 +61,34 @@ func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context,
|
|||||||
auth, err = authBearer(c.Context(), token)
|
auth, err = authBearer(c.Context(), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug("Bearer 认证失败", "err", err)
|
slog.Debug("Bearer 认证失败", "err", err)
|
||||||
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
|
return nil, ErrUnauthorize
|
||||||
}
|
}
|
||||||
|
|
||||||
case "Basic":
|
case "Basic":
|
||||||
if !slices.Contains(types, PayloadInternalServer) {
|
if !slices.Contains(types, PayloadInternalServer) {
|
||||||
slog.Debug("禁止使用 Basic 认证方式")
|
slog.Debug("禁止使用 Basic 认证方式")
|
||||||
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
|
return nil, ErrUnauthorize
|
||||||
}
|
}
|
||||||
auth, err = authBasic(c.Context(), token)
|
auth, err = authBasic(c.Context(), token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug("Basic 认证失败", "err", err)
|
slog.Debug("Basic 认证失败", "err", err)
|
||||||
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
|
return nil, ErrUnauthorize
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
slog.Debug("无效的认证方式", "method", split[0])
|
slog.Debug("无效的认证方式", "method", split[0])
|
||||||
return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
|
return nil, ErrUnauthorize
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查权限
|
// 检查权限
|
||||||
if !slices.Contains(types, auth.Payload.Type) {
|
if !slices.Contains(types, auth.Payload.Type) {
|
||||||
slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type)
|
slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type)
|
||||||
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
|
return nil, ErrForbidden
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
|
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
|
||||||
slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions)
|
slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions)
|
||||||
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
|
return nil, ErrForbidden
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存到上下文
|
// 保存到上下文
|
||||||
@@ -115,9 +115,12 @@ func authBasic(_ context.Context, token string) (*Context, error) {
|
|||||||
|
|
||||||
// 解析 Basic 认证信息
|
// 解析 Basic 认证信息
|
||||||
var base, err = base64.RawURLEncoding.DecodeString(token)
|
var base, err = base64.RawURLEncoding.DecodeString(token)
|
||||||
|
if err != nil {
|
||||||
|
base, err = base64.URLEncoding.DecodeString(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("令牌格式错误,无法解析令牌")
|
return nil, errors.New("令牌格式错误,无法解析令牌")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var split = strings.Split(string(base), ":")
|
var split = strings.Split(string(base), ":")
|
||||||
if len(split) != 2 {
|
if len(split) != 2 {
|
||||||
@@ -158,3 +161,14 @@ func authBasic(_ context.Context, token string) (*Context, error) {
|
|||||||
Metadata: nil,
|
Metadata: nil,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AuthenticationErr string
|
||||||
|
|
||||||
|
func (e AuthenticationErr) Error() string {
|
||||||
|
return string(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrUnauthorize = AuthenticationErr("令牌无效")
|
||||||
|
ErrForbidden = AuthenticationErr("没有权限")
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,5 +1,11 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"platform/pkg/env"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
// region page
|
// region page
|
||||||
|
|
||||||
type PageReq struct {
|
type PageReq struct {
|
||||||
@@ -40,3 +46,56 @@ type PageResp struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// endregion
|
// endregion
|
||||||
|
|
||||||
|
// region error
|
||||||
|
|
||||||
|
type BizErr struct {
|
||||||
|
msg string
|
||||||
|
err error
|
||||||
|
|
||||||
|
errFile string
|
||||||
|
errLine int
|
||||||
|
errFunc string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *BizErr) Error() string {
|
||||||
|
if e.err != nil {
|
||||||
|
return e.msg + ":" + e.err.Error()
|
||||||
|
}
|
||||||
|
return e.msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *BizErr) Unwrap() error {
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *BizErr) Source() *slog.Source {
|
||||||
|
return &slog.Source{
|
||||||
|
Function: e.errFunc,
|
||||||
|
File: e.errFile,
|
||||||
|
Line: e.errLine,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBizErr(msg string, err ...error) (biz *BizErr) {
|
||||||
|
biz = &BizErr{
|
||||||
|
msg: msg,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(err) > 0 {
|
||||||
|
biz.err = err[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if env.RunMode == env.RunModeDev {
|
||||||
|
pc, file, line, ok := runtime.Caller(1)
|
||||||
|
if ok {
|
||||||
|
biz.errFile = file
|
||||||
|
biz.errLine = line
|
||||||
|
biz.errFunc = runtime.FuncForPC(pc).Name()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return biz
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion
|
||||||
|
|||||||
24
web/error.go
24
web/error.go
@@ -2,8 +2,9 @@ package web
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"gorm.io/gorm"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"platform/web/auth"
|
||||||
|
"platform/web/core"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
@@ -15,6 +16,8 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
|
|||||||
var message = "服务器异常"
|
var message = "服务器异常"
|
||||||
|
|
||||||
var fiberErr *fiber.Error
|
var fiberErr *fiber.Error
|
||||||
|
var authErr auth.AuthenticationErr
|
||||||
|
var bizErr *core.BizErr
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
|
|
||||||
@@ -23,8 +26,23 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
|
|||||||
code = fiberErr.Code
|
code = fiberErr.Code
|
||||||
message = fiberErr.Message
|
message = fiberErr.Message
|
||||||
|
|
||||||
// gorm 错误,忽略
|
// 认证授权错误
|
||||||
case errors.Is(err, gorm.ErrForeignKeyViolated):
|
case errors.As(err, &authErr):
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, auth.ErrUnauthorize):
|
||||||
|
code = fiber.StatusUnauthorized
|
||||||
|
case errors.Is(err, auth.ErrForbidden):
|
||||||
|
code = fiber.StatusForbidden
|
||||||
|
default:
|
||||||
|
code = fiber.StatusBadRequest
|
||||||
|
}
|
||||||
|
message = err.Error()
|
||||||
|
|
||||||
|
// 服务错误
|
||||||
|
case errors.As(err, &bizErr):
|
||||||
|
code = fiber.StatusBadRequest
|
||||||
|
message = err.Error()
|
||||||
|
slog.Debug("服务错误", slog.Any(slog.SourceKey, bizErr.Source()))
|
||||||
|
|
||||||
// 所有未手动声明的错误类型
|
// 所有未手动声明的错误类型
|
||||||
default:
|
default:
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ var formats = []string{
|
|||||||
//goland:noinspection GoMixedReceiverTypes
|
//goland:noinspection GoMixedReceiverTypes
|
||||||
func (ldt *LocalDateTime) Scan(value interface{}) (err error) {
|
func (ldt *LocalDateTime) Scan(value interface{}) (err error) {
|
||||||
var t time.Time
|
var t time.Time
|
||||||
|
|
||||||
if strValue, ok := value.(string); ok {
|
if strValue, ok := value.(string); ok {
|
||||||
var timeValue time.Time
|
var timeValue time.Time
|
||||||
for _, format := range formats {
|
for _, format := range formats {
|
||||||
|
|||||||
@@ -5,11 +5,15 @@ import (
|
|||||||
"encoding/base32"
|
"encoding/base32"
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"platform/pkg/u"
|
||||||
auth2 "platform/web/auth"
|
auth2 "platform/web/auth"
|
||||||
proxy2 "platform/web/domains/proxy"
|
proxy2 "platform/web/domains/proxy"
|
||||||
g "platform/web/globals"
|
g "platform/web/globals"
|
||||||
|
"platform/web/globals/orm"
|
||||||
m "platform/web/models"
|
m "platform/web/models"
|
||||||
q "platform/web/queries"
|
q "platform/web/queries"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gorm.io/gorm/clause"
|
"gorm.io/gorm/clause"
|
||||||
)
|
)
|
||||||
@@ -24,6 +28,7 @@ type OnlineProxyReq struct {
|
|||||||
type OnlineProxyResp struct {
|
type OnlineProxyResp struct {
|
||||||
Id int32 `json:"id"`
|
Id int32 `json:"id"`
|
||||||
Secret string `json:"secret"`
|
Secret string `json:"secret"`
|
||||||
|
Permits []ProxyPermit `json:"permits"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func OnlineProxy(c *fiber.Ctx) (err error) {
|
func OnlineProxy(c *fiber.Ctx) (err error) {
|
||||||
@@ -53,8 +58,8 @@ func OnlineProxy(c *fiber.Ctx) (err error) {
|
|||||||
var secret = base32.StdEncoding.
|
var secret = base32.StdEncoding.
|
||||||
WithPadding(base32.NoPadding).
|
WithPadding(base32.NoPadding).
|
||||||
EncodeToString(secretBytes)
|
EncodeToString(secretBytes)
|
||||||
|
|
||||||
slog.Debug("生成随机密钥", "ip", ip, "secret", secret)
|
slog.Debug("生成随机密钥", "ip", ip, "secret", secret)
|
||||||
|
|
||||||
var proxy = &m.Proxy{
|
var proxy = &m.Proxy{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Version: int32(req.Version),
|
Version: int32(req.Version),
|
||||||
@@ -63,7 +68,7 @@ func OnlineProxy(c *fiber.Ctx) (err error) {
|
|||||||
Secret: secret,
|
Secret: secret,
|
||||||
Status: 1,
|
Status: 1,
|
||||||
}
|
}
|
||||||
err = q.Proxy.Debug().
|
err = q.Proxy.
|
||||||
Clauses(clause.OnConflict{
|
Clauses(clause.OnConflict{
|
||||||
UpdateAll: true,
|
UpdateAll: true,
|
||||||
Columns: []clause.Column{
|
Columns: []clause.Column{
|
||||||
@@ -75,10 +80,31 @@ func OnlineProxy(c *fiber.Ctx) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
channels, err := q.Channel.Where(
|
||||||
|
q.Channel.ProxyID.Eq(proxy.ID),
|
||||||
|
q.Channel.Expiration.Gt(orm.LocalDateTime(time.Now())),
|
||||||
|
).Find()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var permits []ProxyPermit
|
||||||
|
for _, channel := range channels {
|
||||||
|
permit := ProxyPermit{
|
||||||
|
Id: channel.EdgeID,
|
||||||
|
Expire: time.Time(channel.Expiration),
|
||||||
|
Whitelists: u.P(strings.Split(channel.Whitelists, ",")),
|
||||||
|
Username: &channel.Username,
|
||||||
|
Password: &channel.Password,
|
||||||
|
}
|
||||||
|
permits = append(permits, permit)
|
||||||
|
}
|
||||||
|
|
||||||
slog.Debug("注册转发服务", "ip", ip, "id", proxy.ID)
|
slog.Debug("注册转发服务", "ip", ip, "id", proxy.ID)
|
||||||
return c.JSON(&OnlineProxyResp{
|
return c.JSON(&OnlineProxyResp{
|
||||||
Id: proxy.ID,
|
Id: proxy.ID,
|
||||||
Secret: secret,
|
Secret: secret,
|
||||||
|
Permits: permits,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -158,3 +184,11 @@ func AssignProxyFwdPort(c *fiber.Ctx) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// endregion
|
// endregion
|
||||||
|
|
||||||
|
type ProxyPermit struct {
|
||||||
|
Id int32 `json:"id"`
|
||||||
|
Expire time.Time `json:"expire"`
|
||||||
|
Whitelists *[]string `json:"whitelists"`
|
||||||
|
Username *string `json:"username"`
|
||||||
|
Password *string `json:"password"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package web
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"platform/web/core"
|
||||||
"platform/web/handlers"
|
"platform/web/handlers"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
@@ -78,6 +79,6 @@ func ApplyRouters(app *fiber.App) {
|
|||||||
|
|
||||||
// 临时
|
// 临时
|
||||||
app.Get("/test", func(c *fiber.Ctx) error {
|
app.Get("/test", func(c *fiber.Ctx) error {
|
||||||
return c.JSON(c.GetReqHeaders())
|
return core.NewBizErr("测试错误")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,13 +4,12 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/hibiken/asynq"
|
|
||||||
"gorm.io/gen/field"
|
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"math"
|
"math"
|
||||||
"math/rand/v2"
|
"math/rand/v2"
|
||||||
"platform/pkg/env"
|
"platform/pkg/env"
|
||||||
"platform/pkg/u"
|
"platform/pkg/u"
|
||||||
|
"platform/web/core"
|
||||||
channel2 "platform/web/domains/channel"
|
channel2 "platform/web/domains/channel"
|
||||||
edge2 "platform/web/domains/edge"
|
edge2 "platform/web/domains/edge"
|
||||||
proxy2 "platform/web/domains/proxy"
|
proxy2 "platform/web/domains/proxy"
|
||||||
@@ -24,13 +23,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/hibiken/asynq"
|
||||||
|
"gorm.io/gen/field"
|
||||||
|
|
||||||
"github.com/redis/go-redis/v9"
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Channel = &channelService{}
|
var Channel = &channelService{}
|
||||||
|
|
||||||
type channelService struct {
|
type channelService struct{}
|
||||||
}
|
|
||||||
|
|
||||||
// region 删除通道
|
// region 删除通道
|
||||||
|
|
||||||
@@ -46,7 +47,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
|
|||||||
}
|
}
|
||||||
channels, err := tx.Channel.Where(do).Find()
|
channels, err := tx.Channel.Where(do).Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("查找通道失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyMap := make(map[int32]*m.Proxy)
|
proxyMap := make(map[int32]*m.Proxy)
|
||||||
@@ -67,7 +68,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
|
|||||||
// 查找资源
|
// 查找资源
|
||||||
resources, err := tx.Resource.Where(tx.Resource.ID.In(resourceIds...)).Find()
|
resources, err := tx.Resource.Where(tx.Resource.ID.In(resourceIds...)).Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("查找资源失败", err)
|
||||||
}
|
}
|
||||||
for _, res := range resources {
|
for _, res := range resources {
|
||||||
resourceMap[res.ID] = res
|
resourceMap[res.ID] = res
|
||||||
@@ -76,7 +77,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
|
|||||||
// 查找代理
|
// 查找代理
|
||||||
proxies, err := tx.Proxy.Where(q.Proxy.ID.In(proxyIds...)).Find()
|
proxies, err := tx.Proxy.Where(q.Proxy.ID.In(proxyIds...)).Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("查找代理失败", err)
|
||||||
}
|
}
|
||||||
for _, proxy := range proxies {
|
for _, proxy := range proxies {
|
||||||
proxyMap[proxy.ID] = proxy
|
proxyMap[proxy.ID] = proxy
|
||||||
@@ -100,10 +101,10 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
|
|||||||
Where(q.Channel.ID.In(id...)).
|
Where(q.Channel.ID.In(id...)).
|
||||||
Update(q.Channel.DeletedAt, now)
|
Update(q.Channel.DeletedAt, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("删除通道失败", err)
|
||||||
}
|
}
|
||||||
if result.RowsAffected != int64(len(channels)) {
|
if result.RowsAffected != int64(len(channels)) {
|
||||||
return ChannelServiceErr("删除通道失败")
|
return core.NewBizErr("删除通道数量不匹配")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 禁用代理端口并下线用过的节点
|
// 禁用代理端口并下线用过的节点
|
||||||
@@ -113,7 +114,7 @@ func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
|
|||||||
if len(shortToRemove) > 0 {
|
if len(shortToRemove) > 0 {
|
||||||
err := removeShortChannelExternal(proxies, shortToRemove)
|
err := removeShortChannelExternal(proxies, shortToRemove)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("提交删除通道配置失败", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -160,7 +161,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
|
|||||||
}
|
}
|
||||||
proxy, ok := proxyMap[proxyId]
|
proxy, ok := proxyMap[proxyId]
|
||||||
if !ok {
|
if !ok {
|
||||||
return ChannelServiceErr("代理不存在")
|
return core.NewBizErr("代理不存在")
|
||||||
}
|
}
|
||||||
|
|
||||||
var secret = strings.Split(proxy.Secret, ":")
|
var secret = strings.Split(proxy.Secret, ":")
|
||||||
@@ -173,13 +174,13 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
|
|||||||
// 查询节点配置
|
// 查询节点配置
|
||||||
actives, err := gateway.GatewayPortActive()
|
actives, err := gateway.GatewayPortActive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("查询节点配置失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新节点配置
|
// 更新节点配置
|
||||||
err = gateway.GatewayPortConfigs(configs)
|
err = gateway.GatewayPortConfigs(configs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("提交删除通道配置失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 下线对应节点
|
// 下线对应节点
|
||||||
@@ -187,7 +188,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
|
|||||||
for portStr, active := range actives {
|
for portStr, active := range actives {
|
||||||
port, err := strconv.Atoi(portStr)
|
port, err := strconv.Atoi(portStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("端口转换失败", err)
|
||||||
}
|
}
|
||||||
key := uint64(proxyId)<<32 | uint64(port)
|
key := uint64(proxyId)<<32 | uint64(port)
|
||||||
if _, ok := portMap[key]; ok {
|
if _, ok := portMap[key]; ok {
|
||||||
@@ -200,7 +201,7 @@ func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error
|
|||||||
Edge: edges,
|
Edge: edges,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("下线节点失败", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -232,7 +233,7 @@ func (s *channelService) CreateChannel(
|
|||||||
// 查找套餐
|
// 查找套餐
|
||||||
resource, err = findResource(q, resourceId, userId, count, now)
|
resource, err = findResource(q, resourceId, userId, count, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("查找套餐失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查找白名单
|
// 查找白名单
|
||||||
@@ -240,7 +241,7 @@ func (s *channelService) CreateChannel(
|
|||||||
if authType == ChannelAuthTypeIp {
|
if authType == ChannelAuthTypeIp {
|
||||||
whitelist, err = findWhitelist(q, userId)
|
whitelist, err = findWhitelist(q, userId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("查找白名单失败", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,13 +262,13 @@ func (s *channelService) CreateChannel(
|
|||||||
channels, err = assignLongChannels(q, userId, resourceId, count, config, filter)
|
channels, err = assignLongChannels(q, userId, resourceId, count, config, filter)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("分配通道失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存通道开通结果
|
// 保存通道开通结果
|
||||||
err = saveAssigns(q, resource, channels, now)
|
err = saveAssigns(q, resource, channels, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("保存通道失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -294,7 +295,7 @@ func (s *channelService) CreateChannel(
|
|||||||
asynq.ProcessIn(duration),
|
asynq.ProcessIn(duration),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, core.NewBizErr("提交异步删除通道任务失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return channels, nil
|
return channels, nil
|
||||||
@@ -352,7 +353,7 @@ func findResource(q *q.Query, resourceId int32, userId int32, count int, now tim
|
|||||||
// 检查套餐使用情况
|
// 检查套餐使用情况
|
||||||
switch info.Mode {
|
switch info.Mode {
|
||||||
default:
|
default:
|
||||||
return nil, ChannelServiceErr("不支持的套餐模式")
|
return nil, core.NewBizErr("不支持的套餐模式")
|
||||||
|
|
||||||
// 包时
|
// 包时
|
||||||
case resource2.ModeTime:
|
case resource2.ModeTime:
|
||||||
@@ -388,10 +389,10 @@ func findWhitelist(q *q.Query, userId int32) ([]string, error) {
|
|||||||
Select(q.Whitelist.Host).
|
Select(q.Whitelist.Host).
|
||||||
Scan(&whitelist)
|
Scan(&whitelist)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, core.NewBizErr("查询白名单失败", err)
|
||||||
}
|
}
|
||||||
if len(whitelist) == 0 {
|
if len(whitelist) == 0 {
|
||||||
return nil, ChannelServiceErr("用户没有白名单")
|
return nil, core.NewBizErr("没有配置白名单")
|
||||||
}
|
}
|
||||||
|
|
||||||
return whitelist, nil
|
return whitelist, nil
|
||||||
@@ -404,7 +405,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
|
|||||||
Where(q.Proxy.Type.Eq(int32(proxy2.TypeThirdParty))).
|
Where(q.Proxy.Type.Eq(int32(proxy2.TypeThirdParty))).
|
||||||
Find()
|
Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, core.NewBizErr("查找网关失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查找已使用的节点
|
// 查找已使用的节点
|
||||||
@@ -424,13 +425,13 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
|
|||||||
q.Channel.ProxyID).
|
q.Channel.ProxyID).
|
||||||
Find()
|
Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, core.NewBizErr("查找已使用的节点失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查询已配置的节点
|
// 查询已配置的节点
|
||||||
remoteConfigs, err := g.Cloud.CloudAutoQuery()
|
remoteConfigs, err := g.Cloud.CloudAutoQuery()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, core.NewBizErr("查询远端节点配置失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 统计已用节点量与端口查找表
|
// 统计已用节点量与端口查找表
|
||||||
@@ -503,7 +504,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
|
|||||||
AutoConfig: newConfigs,
|
AutoConfig: newConfigs,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, core.NewBizErr("提交节点配置失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("提交节点配置",
|
slog.Debug("提交节点配置",
|
||||||
@@ -565,7 +566,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
|
|||||||
newChannels = append(newChannels, newChannel)
|
newChannels = append(newChannels, newChannel)
|
||||||
}
|
}
|
||||||
if len(portConfigs) < acc {
|
if len(portConfigs) < acc {
|
||||||
return nil, ChannelServiceErr("网关端口数量到达上限,无法分配")
|
return nil, core.NewBizErr("网关端口数量到达上限,无法分配")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提交端口配置
|
// 提交端口配置
|
||||||
@@ -580,7 +581,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
|
|||||||
)
|
)
|
||||||
err = gateway.GatewayPortConfigs(portConfigs)
|
err = gateway.GatewayPortConfigs(portConfigs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, core.NewBizErr("提交端口配置失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("提交端口配置", "step", time.Since(step))
|
slog.Debug("提交端口配置", "step", time.Since(step))
|
||||||
@@ -588,7 +589,7 @@ func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(newChannels) != count {
|
if len(newChannels) != count {
|
||||||
return nil, ChannelServiceErr("分配节点失败")
|
return nil, core.NewBizErr("分配节点失败")
|
||||||
}
|
}
|
||||||
|
|
||||||
return newChannels, nil
|
return newChannels, nil
|
||||||
@@ -630,7 +631,7 @@ func assignLongChannels(q *q.Query, userId int32, resourceId int32, count int, c
|
|||||||
Limit(count).
|
Limit(count).
|
||||||
Scan(&edges)
|
Scan(&edges)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("查询符合条件的节点失败: %w", err)
|
return nil, core.NewBizErr("查询符合条件的节点失败", err)
|
||||||
}
|
}
|
||||||
if len(edges) == 0 {
|
if len(edges) == 0 {
|
||||||
return nil, ErrEdgesNoAvailable
|
return nil, ErrEdgesNoAvailable
|
||||||
@@ -712,7 +713,7 @@ func assignLongChannels(q *q.Query, userId int32, resourceId int32, count int, c
|
|||||||
proxy := proxies[id]
|
proxy := proxies[id]
|
||||||
err := g.Proxy.Permit(proxy.Host, proxy.Secret, reqs)
|
err := g.Proxy.Permit(proxy.Host, proxy.Secret, reqs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, core.NewBizErr("提交端口配置失败", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
slog.Debug("提交端口配置", "step", time.Since(step))
|
slog.Debug("提交端口配置", "step", time.Since(step))
|
||||||
@@ -741,19 +742,14 @@ func saveAssigns(q *q.Query, resource *ResourceInfo, channels []*m.Channel, now
|
|||||||
|
|
||||||
_, err = pipe.Exec(context.Background())
|
_, err = pipe.Exec(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("缓存通道数据失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存通道
|
// 保存通道
|
||||||
err = q.Channel.
|
err = q.Channel.
|
||||||
Omit(
|
|
||||||
q.Channel.EdgeID,
|
|
||||||
q.Channel.EdgeHost,
|
|
||||||
q.Channel.DeletedAt,
|
|
||||||
).
|
|
||||||
Create(channels...)
|
Create(channels...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("保存通道失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新套餐使用记录
|
// 更新套餐使用记录
|
||||||
@@ -785,7 +781,7 @@ func saveAssigns(q *q.Query, resource *ResourceInfo, channels []*m.Channel, now
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return core.NewBizErr("更新套餐使用记录失败", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -841,17 +837,11 @@ type ResourceInfo struct {
|
|||||||
Expire time.Time
|
Expire time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChannelServiceErr string
|
var (
|
||||||
|
ErrResourceNotExist = core.NewBizErr("套餐不存在")
|
||||||
func (c ChannelServiceErr) Error() string {
|
ErrResourceInvalid = core.NewBizErr("套餐不可用")
|
||||||
return string(c)
|
ErrResourceExhausted = core.NewBizErr("套餐已用完")
|
||||||
}
|
ErrResourceExpired = core.NewBizErr("套餐已过期")
|
||||||
|
ErrResourceDailyLimit = core.NewBizErr("套餐每日配额已用完")
|
||||||
const (
|
ErrEdgesNoAvailable = core.NewBizErr("没有可用的节点")
|
||||||
ErrResourceNotExist = ChannelServiceErr("套餐不存在")
|
|
||||||
ErrResourceInvalid = ChannelServiceErr("套餐不可用")
|
|
||||||
ErrResourceExhausted = ChannelServiceErr("套餐已用完")
|
|
||||||
ErrResourceExpired = ChannelServiceErr("套餐已过期")
|
|
||||||
ErrResourceDailyLimit = ChannelServiceErr("套餐每日配额已用完")
|
|
||||||
ErrEdgesNoAvailable = ChannelServiceErr("没有可用的节点")
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ var Edge = &edgeService{}
|
|||||||
type edgeService struct{}
|
type edgeService struct{}
|
||||||
|
|
||||||
func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) {
|
func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) {
|
||||||
do := q.Edge.Where(q.Edge.Status.Eq(1))
|
do := q.Edge.Where()
|
||||||
if filter.Prov != "" {
|
if filter.Prov != "" {
|
||||||
do = do.Where(q.Edge.Prov.Eq(filter.Prov))
|
do = do.Where(q.Edge.Prov.Eq(filter.Prov))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/shopspring/decimal"
|
|
||||||
bill2 "platform/web/domains/bill"
|
bill2 "platform/web/domains/bill"
|
||||||
resource2 "platform/web/domains/resource"
|
resource2 "platform/web/domains/resource"
|
||||||
trade2 "platform/web/domains/trade"
|
trade2 "platform/web/domains/trade"
|
||||||
@@ -14,6 +13,8 @@ import (
|
|||||||
m "platform/web/models"
|
m "platform/web/models"
|
||||||
q "platform/web/queries"
|
q "platform/web/queries"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/shopspring/decimal"
|
||||||
)
|
)
|
||||||
|
|
||||||
var Resource = &resourceService{}
|
var Resource = &resourceService{}
|
||||||
@@ -75,9 +76,6 @@ func (s *resourceService) CreateResource(uid int32, now time.Time, ser *CreateRe
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}, &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
|
}, &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
|
||||||
|
|||||||
Reference in New Issue
Block a user