package handlers import ( "fmt" "platform/web/auth" "platform/web/core" channel2 "platform/web/domains/channel" "platform/web/globals/orm" q "platform/web/queries" s "platform/web/services" "time" "github.com/gofiber/fiber/v2" ) // region ListChannels type ListChannelsReq struct { core.PageReq AuthType s.ChannelAuthType `json:"auth_type"` ExpireAfter *time.Time `json:"expire_after"` ExpireBefore *time.Time `json:"expire_before"` } func ListChannels(c *fiber.Ctx) error { // 检查权限 authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } // 解析请求参数 req := new(ListChannelsReq) if err := c.BodyParser(req); err != nil { return err } // 构造查询条件 cond := q.Channel. Where(q.Channel.UserID.Eq(authContext.Payload.Id)) switch req.AuthType { case s.ChannelAuthTypeIp: cond.Where(q.Channel.AuthIP.Is(true)) case s.ChannelAuthTypePass: cond.Where(q.Channel.AuthPass.Is(true)) default: break } if req.ExpireAfter != nil { cond.Where(q.Channel.Expiration.Gte(orm.LocalDateTime(*req.ExpireAfter))) } if req.ExpireBefore != nil { cond.Where(q.Channel.Expiration.Lte(orm.LocalDateTime(*req.ExpireBefore))) } // 查询数据 channels, err := q.Channel.Debug(). Where(cond). Order(q.Channel.CreatedAt.Desc()). Offset(req.GetOffset()). Limit(req.GetLimit()). Find() if err != nil { return err } // 查询总量 var total int64 if len(channels) < req.GetLimit() { total = int64(len(channels) + req.GetOffset()) } else { total, err = q.Channel.Where(cond).Count() if err != nil { return err } } // 返回结果 return c.JSON(core.PageResp{ Total: int(total), Page: req.GetPage(), Size: req.GetSize(), List: channels, }) } // endregion // region CreateChannel type CreateChannelReq struct { ResourceId int32 `json:"resource_id" validate:"required"` AuthType s.ChannelAuthType `json:"auth_type" validate:"required"` Protocol channel2.Protocol `json:"protocol" validate:"required"` Count int `json:"count" validate:"required"` Prov string `json:"prov"` City string `json:"city"` Isp string `json:"isp"` } type CreateChannelRespItem struct { Proto channel2.Protocol `json:"-"` Host string `json:"host"` Port int32 `json:"port"` Username *string `json:"username,omitempty"` Password *string `json:"password,omitempty"` } func CreateChannel(c *fiber.Ctx) error { // 检查权限 authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) if err != nil { return err } // 检查用户其他权限 user, err := q.User. Where(q.User.ID.Eq(authContext.Payload.Id)). Take() if err != nil { return err } if user.IDToken == nil || *user.IDToken == "" { return fiber.NewError(fiber.StatusForbidden, "账号未实名") } count, err := q.Whitelist.Where( q.Whitelist.UserID.Eq(authContext.Payload.Id), q.Whitelist.Host.Eq(c.IP()), ).Count() if err != nil { return err } if count == 0 { return fiber.NewError(fiber.StatusForbidden, fmt.Sprintf("非白名单IP %s", c.IP())) } req := new(CreateChannelReq) if err := c.BodyParser(req); err != nil { return err } var isp string switch req.Isp { case "1": isp = "电信" case "2": isp = "联通" case "3": isp = "移动" } // 创建通道 result, err := s.Channel.CreateChannel( authContext.Payload.Id, req.ResourceId, req.Protocol, req.AuthType, req.Count, s.EdgeFilter{ Isp: isp, Prov: req.Prov, City: req.City, }, ) if err != nil { return err } // 返回结果 var resp = make([]*CreateChannelRespItem, len(result)) for i, channel := range result { resp[i] = &CreateChannelRespItem{ Proto: req.Protocol, Host: channel.ProxyHost, Port: channel.ProxyPort, } if req.AuthType == s.ChannelAuthTypePass { resp[i].Username = channel.Username resp[i].Password = channel.Password } } return c.JSON(resp) } type CreateChannelResultType string // endregion // region RemoveChannels type RemoveChannelsReq struct { ByIds []int32 `json:"by_ids" validate:"required"` } func RemoveChannels(c *fiber.Ctx) error { // 检查权限 authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do() if err != nil { return err } // 解析请求参数 req := new(RemoveChannelsReq) if err := c.BodyParser(req); err != nil { return err } // 删除通道 err = s.Channel.RemoveChannels(req.ByIds, authCtx.Payload.Id) if err != nil { return err } return c.SendStatus(fiber.StatusOK) } type RemoveChannelByTaskReq []int32 func RemoveChannelByTask(c *fiber.Ctx) error { // 检查权限 _, err := auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do() if err != nil { return err } // 解析请求参数 var req RemoveChannelByTaskReq if err := c.BodyParser(&req); err != nil { return err } // 删除通道 err = s.Channel.RemoveChannels(req) if err != nil { return err } return c.SendStatus(fiber.StatusOK) } // endregion