package handlers import ( "fmt" "platform/web/auth" "platform/web/core" q "platform/web/queries" s "platform/web/services" "time" "github.com/gofiber/fiber/v2" ) // region ListChannels type ListChannelsReq struct { core.PageReq AuthType s.ChannelAuthType `json:"auth_type"` ExpireAfter *time.Time `json:"expire_after"` ExpireBefore *time.Time `json:"expire_before"` } func ListChannels(c *fiber.Ctx) error { // 检查权限 authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) if err != nil { return err } // 解析请求参数 req := new(ListChannelsReq) if err := c.BodyParser(req); err != nil { return err } // 构造查询条件 cond := q.Channel. Where(q.Channel.UserID.Eq(authContext.Payload.Id)) switch req.AuthType { case s.ChannelAuthTypeIp: cond.Where(q.Channel.AuthIP.Is(true)) case s.ChannelAuthTypePass: cond.Where(q.Channel.AuthPass.Is(true)) default: break } if req.ExpireAfter != nil { cond.Where(q.Channel.Expiration.Gte(core.LocalDateTime(*req.ExpireAfter))) } if req.ExpireBefore != nil { cond.Where(q.Channel.Expiration.Lte(core.LocalDateTime(*req.ExpireBefore))) } // 查询数据 channels, err := q.Channel.Debug(). Where(cond). Order(q.Channel.CreatedAt.Desc()). Offset(req.GetOffset()). Limit(req.GetLimit()). Find() if err != nil { return err } // 查询总量 var total int64 if len(channels) < req.GetLimit() { total = int64(len(channels) + req.GetOffset()) } else { total, err = q.Channel.Where(cond).Count() if err != nil { return err } } // 返回结果 return c.JSON(core.PageResp{ Total: int(total), Page: req.GetPage(), Size: req.GetSize(), List: channels, }) } // endregion // region CreateChannel type CreateChannelReq struct { ResourceId int32 `json:"resource_id" validate:"required"` AuthType s.ChannelAuthType `json:"auth_type" validate:"required"` Protocol s.ChannelProtocol `json:"protocol" validate:"required"` Count int `json:"count" validate:"required"` Prov string `json:"prov"` City string `json:"city"` Isp string `json:"isp"` } func CreateChannel(c *fiber.Ctx) error { // 检查权限 authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{}) if err != nil { return err } // 获取用户信息 user, err := q.User. Where(q.User.ID.Eq(authContext.Payload.Id)). Take() if err != nil { return err } if user.IDToken == "" { return fiber.NewError(fiber.StatusForbidden, "账号未实名") } count, err := q.Whitelist.Where( q.Whitelist.UserID.Eq(authContext.Payload.Id), q.Whitelist.Host.Eq(c.IP()), ).Count() if err != nil { return err } if count == 0 { return fiber.NewError(fiber.StatusForbidden, fmt.Sprintf("非白名单IP %s", c.IP())) } req := new(CreateChannelReq) if err := c.BodyParser(req); err != nil { return err } var isp string switch req.Isp { case "1": isp = "电信" case "2": isp = "联通" case "3": isp = "移动" } result, err := s.Channel.CreateChannel( c.Context(), authContext, req.ResourceId, req.Protocol, req.AuthType, req.Count, s.NodeFilterConfig{ Isp: isp, Prov: req.Prov, City: req.City, }, ) if err != nil { return err } return c.JSON(result) } type CreateChannelResultType string // endregion // region RemoveChannels type RemoveChannelsReq struct { ByIds []int32 `json:"by_ids" validate:"required"` } func RemoveChannels(c *fiber.Ctx) error { // 检查权限 authCtx, err := auth.Protect(c, []s.PayloadType{ s.PayloadUser, s.PayloadClientConfidential, }, []string{}) if err != nil { return err } // 解析请求参数 req := new(RemoveChannelsReq) if err := c.BodyParser(req); err != nil { return err } // 删除通道 err = s.Channel.RemoveChannels(c.Context(), authCtx, req.ByIds...) if err != nil { return err } return c.SendStatus(fiber.StatusOK) } // endregion