package handlers

import (
	"net/netip"
	"platform/pkg/u"
	"platform/web/auth"
	"platform/web/core"
	g "platform/web/globals"
	m "platform/web/models"
	q "platform/web/queries"
	s "platform/web/services"
	"time"

	"github.com/gofiber/fiber/v2"
)

// region ListChannels

// ListChannelsReq is the request body accepted by ListChannels: the embedded
// pagination fields plus optional filters.
type ListChannelsReq struct {
	core.PageReq
	// AuthType selects channels by authentication type. Only
	// s.ChannelAuthTypeIp and s.ChannelAuthTypePass add a filter; any other
	// value leaves the result unfiltered.
	AuthType s.ChannelAuthType `json:"auth_type"`
	// ExpireAfter, when set, keeps channels whose ExpiredAt is >= this time.
	ExpireAfter *time.Time `json:"expire_after"`
	// ExpireBefore, when set, keeps channels whose ExpiredAt is <= this time.
	ExpireBefore *time.Time `json:"expire_before"`
}

// ListChannels responds with one page of the channels owned by the
// authenticated user, ordered by creation time descending, wrapped in a
// core.PageResp together with the total number of matching rows.
//
// It returns the error from the permission check, from body parsing, or from
// either database query unchanged.
func ListChannels(c *fiber.Ctx) error {
	// Check permission.
	authContext, err := auth.GetAuthCtx(c).PermitUser()
	if err != nil {
		return err
	}

	// Parse request parameters.
	req := new(ListChannelsReq)
	if err := c.BodyParser(req); err != nil {
		return err
	}

	// Build the query condition. The result of every Where call is assigned
	// back so the filters are applied whether or not the query builder
	// mutates its receiver in place.
	cond := q.Channel.
		Where(q.Channel.UserID.Eq(authContext.User.ID))
	switch req.AuthType {
	case s.ChannelAuthTypeIp:
		cond = cond.Where(q.Channel.Whitelists.IsNotNull())
	case s.ChannelAuthTypePass:
		cond = cond.Where(q.Channel.Username.IsNotNull(), q.Channel.Password.IsNotNull())
	}
	if req.ExpireAfter != nil {
		cond = cond.Where(q.Channel.ExpiredAt.Gte(*req.ExpireAfter))
	}
	if req.ExpireBefore != nil {
		cond = cond.Where(q.Channel.ExpiredAt.Lte(*req.ExpireBefore))
	}

	offset, limit := req.GetOffset(), req.GetLimit()

	// Fetch the requested page.
	channels, err := q.Channel.
		Where(cond).
		Order(q.Channel.CreatedAt.Desc()).
		Offset(offset).
		Limit(limit).
		Find()
	if err != nil {
		return err
	}

	// Compute the total. A short page means this is the last page, so the
	// total can be derived without a COUNT query. That shortcut is only valid
	// when the page is non-empty or is the first page: an empty page at a
	// positive offset may lie past the end of the data, in which case the
	// offset says nothing about the real total.
	var total int64
	if n := len(channels); n < limit && (n > 0 || offset == 0) {
		total = int64(n + offset)
	} else {
		total, err = q.Channel.Where(cond).Count()
		if err != nil {
			return err
		}
	}

	// Send the result.
	return c.JSON(core.PageResp{
		Total: int(total),
		Page:  req.GetPage(),
		Size:  req.GetSize(),
		List:  channels,
	})
}

// endregion

// region CreateChannel

// CreateChannelReq is the request body accepted by CreateChannel.
type CreateChannelReq struct {
	ResourceId int32             `json:"resource_id" validate:"required"`
	AuthType   s.ChannelAuthType `json:"auth_type" validate:"required"`
	Protocol   int               `json:"protocol" validate:"required"`
	Count      int               `json:"count" validate:"required"`
	// Prov, City and Isp are optional and are forwarded in the s.EdgeFilter
	// passed to the channel service.
	Prov *string `json:"prov"`
	City *string `json:"city"`
	Isp  *int    `json:"isp"`
}

// CreateChannelRespItem describes one created channel in the CreateChannel
// response.
type CreateChannelRespItem struct {
	// Proto echoes the requested protocol; it is excluded from the JSON output.
	Proto int    `json:"-"`
	Host  string `json:"host"`
	Port  uint16 `json:"port"`
	// Username and Password are omitted from the JSON output when nil.
	Username *string `json:"username,omitempty"`
	Password *string `json:"password,omitempty"`
}

// CreateChannel creates req.Count channels through s.Channel.CreateChannels
// and responds with a JSON array of CreateChannelRespItem.
//
// A body that fails parsing/validation, or a client address that cannot be
// parsed, is reported as a business error; an error from the channel service
// is returned unchanged.
//
// NOTE(review): unlike ListChannels and RemoveChannels, this handler performs
// no permission check of its own — confirm that this is enforced elsewhere.
func CreateChannel(c *fiber.Ctx) error {
	// Parse and validate parameters.
	req := new(CreateChannelReq)
	if err := g.Validator.ParseBody(c, req); err != nil {
		return core.NewBizErr("解析参数失败", err)
	}

	ip, err := netip.ParseAddr(c.IP())
	if err != nil {
		return core.NewBizErr("获取客户端地址失败", err)
	}

	// Create the channels. The two booleans tell the service whether IP
	// authentication and password authentication were requested.
	result, err := s.Channel.CreateChannels(
		ip,
		req.ResourceId,
		req.AuthType == s.ChannelAuthTypeIp,
		req.AuthType == s.ChannelAuthTypePass,
		req.Count,
		s.EdgeFilter{
			// presumably converts the optional ISP code via m.ToEdgeISP when
			// it is set — verify against u.ElseTo.
			Isp:  u.ElseTo(req.Isp, m.ToEdgeISP),
			Prov: req.Prov,
			City: req.City,
		},
	)
	if err != nil {
		return err
	}

	// Build the response; credentials are included only for password
	// authentication.
	resp := make([]*CreateChannelRespItem, len(result))
	for i, channel := range result {
		resp[i] = &CreateChannelRespItem{
			Proto: req.Protocol,
			Host:  channel.Host,
			Port:  channel.Port,
		}
		if req.AuthType == s.ChannelAuthTypePass {
			resp[i].Username = channel.Username
			resp[i].Password = channel.Password
		}
	}
	return c.JSON(resp)
}

// CreateChannelResultType is a string-based type that is not referenced
// anywhere in this file.
type CreateChannelResultType string

// endregion

// region RemoveChannels

// RemoveChannelsReq is the request body accepted by RemoveChannels.
type RemoveChannelsReq struct {
	// Batch is passed to s.Channel.RemoveChannels and must be non-empty.
	Batch string `json:"batch" validate:"required"`
}

// RemoveChannels removes channels via s.Channel.RemoveChannels(req.Batch) and
// responds with HTTP 200 on success.
//
// The caller must pass the official-client permission check. The body is
// parsed with the shared validator so that the `validate:"required"` tag on
// Batch is actually enforced; a failure there is reported as a business
// error, as in CreateChannel.
func RemoveChannels(c *fiber.Ctx) error {
	// Check permission.
	_, err := auth.GetAuthCtx(c).PermitOfficialClient()
	if err != nil {
		return err
	}

	// Parse and validate request parameters.
	req := new(RemoveChannelsReq)
	if err := g.Validator.ParseBody(c, req); err != nil {
		return core.NewBizErr("解析参数失败", err)
	}

	// Remove the channels.
	err = s.Channel.RemoveChannels(req.Batch)
	if err != nil {
		return err
	}

	return c.SendStatus(fiber.StatusOK)
}

// endregion