From 2fa8b4d5406f8049ddaf6b5ed3cd42ac8615865e Mon Sep 17 00:00:00 2001 From: luorijun Date: Mon, 28 Apr 2025 11:44:54 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E5=9F=BA=E6=9C=AC=E8=AE=A4?= =?UTF-8?q?=E8=AF=81=E8=A7=A3=E7=A0=81=E6=96=B9=E5=BC=8F=EF=BC=9B=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E6=9F=A5=E8=AF=A2=20channels=20=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/auth/auth.go | 2 +- web/handlers/auth.go | 2 +- web/handlers/channel.go | 124 ++++++++++++++++++++++++++++++++-------- web/router.go | 1 + web/services/channel.go | 4 ++ 5 files changed, 107 insertions(+), 26 deletions(-) diff --git a/web/auth/auth.go b/web/auth/auth.go index f73c399..5f20bb4 100644 --- a/web/auth/auth.go +++ b/web/auth/auth.go @@ -81,7 +81,7 @@ func authBearer(ctx context.Context, token string) (*services.AuthContext, error func authBasic(_ context.Context, token string) (*services.AuthContext, error) { // 解析 Basic 认证信息 - var base, err = base64.URLEncoding.DecodeString(token) + var base, err = base64.RawURLEncoding.DecodeString(token) if err != nil { slog.Debug(err.Error()) return nil, err diff --git a/web/handlers/auth.go b/web/handlers/auth.go index 9851105..aa09854 100644 --- a/web/handlers/auth.go +++ b/web/handlers/auth.go @@ -152,7 +152,7 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string if header != "" { basic := strings.TrimPrefix(header, "Basic ") if basic != "" { - base, err := base64.URLEncoding.DecodeString(basic) + base, err := base64.RawURLEncoding.DecodeString(basic) if err != nil { return nil, err } diff --git a/web/handlers/channel.go b/web/handlers/channel.go index 0ce5bce..c79c16d 100644 --- a/web/handlers/channel.go +++ b/web/handlers/channel.go @@ -1,30 +1,107 @@ package handlers import ( - "errors" "fmt" "platform/web/auth" + "platform/web/common" q "platform/web/queries" - "platform/web/services" + s "platform/web/services" + "time" "github.com/gofiber/fiber/v2" ) +// region ListChannels + +type ListChannelsReq struct { + common.PageReq + AuthType s.ChannelAuthType `json:"auth_type"` + ExpireAfter *time.Time `json:"expire_after"` + ExpireBefore *time.Time `json:"expire_before"` +} + +func ListChannels(c *fiber.Ctx) error { + // 检查权限 + authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) + if err != nil { + return err + } + + // 解析请求参数 + req := new(ListChannelsReq) + if err := c.BodyParser(req); err != nil { + return err + } + + // 构造查询条件 + cond := q.Channel. + Where(q.Channel.UserID.Eq(authContext.Payload.Id)) + switch req.AuthType { + case s.ChannelAuthTypeIp: + cond.Where(q.Channel.AuthIP.Is(true)) + case s.ChannelAuthTypePass: + cond.Where(q.Channel.AuthPass.Is(true)) + default: + break + } + + if req.ExpireAfter != nil { + cond.Where(q.Channel.Expiration.Gte(common.LocalDateTime(*req.ExpireAfter))) + } + if req.ExpireBefore != nil { + cond.Where(q.Channel.Expiration.Lte(common.LocalDateTime(*req.ExpireBefore))) + } + + // 查询数据 + channels, err := q.Channel.Debug(). + Where(cond). + Order(q.Channel.CreatedAt.Desc()). + Offset(req.GetOffset()). + Limit(req.GetLimit()). + Find() + if err != nil { + return err + } + + // 查询总量 + var total int64 + if len(channels) < req.GetLimit() { + total = int64(len(channels) + req.GetOffset()) + } else { + total, err = q.Channel.Debug(). + Where(cond). + Count() + if err != nil { + return err + } + } + + // 返回结果 + return c.JSON(common.PageResp{ + Total: int(total), + Page: req.GetPage(), + Size: req.GetSize(), + List: channels, + }) +} + +// endregion + // region CreateChannel type CreateChannelReq struct { - ResourceId int32 `json:"resource_id" validate:"required"` - AuthType services.ChannelAuthType `json:"auth_type" validate:"required"` - Protocol services.ChannelProtocol `json:"protocol" validate:"required"` - Count int `json:"count" validate:"required"` - Prov string `json:"prov"` - City string `json:"city"` - Isp string `json:"isp"` + ResourceId int32 `json:"resource_id" validate:"required"` + AuthType s.ChannelAuthType `json:"auth_type" validate:"required"` + Protocol s.ChannelProtocol `json:"protocol" validate:"required"` + Count int `json:"count" validate:"required"` + Prov string `json:"prov"` + City string `json:"city"` + Isp string `json:"isp"` } func CreateChannel(c *fiber.Ctx) error { // 检查权限 - authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{}) + authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) if err != nil { return err } @@ -65,14 +142,14 @@ func CreateChannel(c *fiber.Ctx) error { isp = "移动" } - result, err := services.Channel.CreateChannel( + result, err := s.Channel.CreateChannel( c.Context(), authContext, req.ResourceId, req.Protocol, req.AuthType, req.Count, - services.NodeFilterConfig{ + s.NodeFilterConfig{ Isp: isp, Prov: req.Prov, City: req.City, @@ -87,11 +164,6 @@ func CreateChannel(c *fiber.Ctx) error { type CreateChannelResultType string -const ( - CreateChannelResultTypeJson CreateChannelResultType = "json" - CreateChannelResultTypeText CreateChannelResultType = "text" -) - // endregion // region RemoveChannels @@ -101,19 +173,23 @@ type RemoveChannelsReq struct { } func RemoveChannels(c *fiber.Ctx) error { + // 检查权限 + authCtx, err := auth.Protect(c, []s.PayloadType{ + s.PayloadUser, + s.PayloadClientConfidential, + }, []string{}) + if err != nil { + return err + } + + // 解析请求参数 req := new(RemoveChannelsReq) if err := c.BodyParser(req); err != nil { return err } - // 获取用户信息 - authCtx, ok := c.Locals("auth").(*services.AuthContext) - if !ok { - return errors.New("user not found") - } - // 删除通道 - err := services.Channel.RemoveChannels(c.Context(), authCtx, req.ByIds...) + err = s.Channel.RemoveChannels(c.Context(), authCtx, req.ByIds...) if err != nil { return err } diff --git a/web/router.go b/web/router.go index e8cbd79..518d522 100644 --- a/web/router.go +++ b/web/router.go @@ -27,6 +27,7 @@ func ApplyRouters(app *fiber.App) { // 通道 channel := api.Group("/channel") + channel.Post("/list", handlers.ListChannels) channel.Post("/create", handlers.CreateChannel) channel.Post("/remove", handlers.RemoveChannels) diff --git a/web/services/channel.go b/web/services/channel.go index cf9c971..7aa1aa6 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -620,6 +620,7 @@ func assignPort( UserID: userId, ProxyID: proxy.ID, UserHost: item, + ProxyHost: proxy.Host, ProxyPort: int32(port), AuthIP: true, AuthPass: false, @@ -639,6 +640,7 @@ func assignPort( channels = append(channels, &models.Channel{ UserID: userId, ProxyID: proxy.ID, + ProxyHost: proxy.Host, ProxyPort: int32(port), AuthIP: false, AuthPass: true, @@ -654,6 +656,8 @@ func assignPort( Username: &username, Password: &password, }) + default: + return nil, nil, ChannelServiceErr("不支持的通道认证方式") } }