diff --git a/web/handlers/resource.go b/web/handlers/resource.go index a307b66..d96dc72 100644 --- a/web/handlers/resource.go +++ b/web/handlers/resource.go @@ -84,6 +84,7 @@ func PageResourceShort(c *fiber.Ctx) error { total = int64(len(resource) + req.GetOffset()) } else { total, err = q.Resource. + Joins(q.Resource.Short). Where(do). Count() if err != nil { @@ -180,6 +181,7 @@ func PageResourceLong(c *fiber.Ctx) error { total = int64(len(resource) + req.GetOffset()) } else { total, err = q.Resource. + Joins(q.Resource.Long). Where(do). Count() if err != nil { diff --git a/web/services/channel.go b/web/services/channel.go index de38666..e078091 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -147,7 +147,7 @@ type ResourceView struct { } // 检查用户是否可提取 -func ensure(now time.Time, source netip.Addr, resourceId int32, count int) (*ResourceView, []string, error) { +func ensure(now time.Time, source netip.Addr, resourceId int32, authWhitelist bool, count int) (*ResourceView, []string, error) { if count > 400 { return nil, nil, core.NewBizErr("单次最多提取 400 个") } @@ -172,6 +172,10 @@ func ensure(now time.Time, source netip.Addr, resourceId int32, count int) (*Res return nil, nil, err } + if authWhitelist && len(whitelists) == 0 { + return nil, nil, core.NewBizErr("当前白名单为空,请先添加白名单") + } + ips := make([]string, len(whitelists)) pass := false for i, item := range whitelists { diff --git a/web/services/channel_baiyin.go b/web/services/channel_baiyin.go index ce096e6..92fecd9 100644 --- a/web/services/channel_baiyin.go +++ b/web/services/channel_baiyin.go @@ -32,7 +32,7 @@ func (s *channelBaiyinProvider) CreateChannels(source netip.Addr, resourceId int batch := ID.GenReadable("bat") // 检查并获取套餐与白名单 - resource, whitelists, err := ensure(now, source, resourceId, count) + resource, whitelists, err := ensure(now, source, resourceId, authWhitelist, count) if err != nil { return nil, err } @@ -240,6 +240,9 @@ func (s *channelBaiyinProvider) CreateChannels(source netip.Addr, resourceId int // 提交配置 secret := strings.Split(u.Z(proxy.Secret), ":") + if len(secret) != 2 { + return nil, core.NewServErr(fmt.Sprintf("代理 %s 密钥格式错误", proxy.IP.String()), nil) + } gateway := g.NewGateway(proxy.IP.String(), secret[0], secret[1]) if env.RunMode == env.RunModeProd {