修正基本认证解码方式;添加查询 channels 接口
This commit is contained in:
@@ -81,7 +81,7 @@ func authBearer(ctx context.Context, token string) (*services.AuthContext, error
|
||||
func authBasic(_ context.Context, token string) (*services.AuthContext, error) {
|
||||
|
||||
// 解析 Basic 认证信息
|
||||
var base, err = base64.URLEncoding.DecodeString(token)
|
||||
var base, err = base64.RawURLEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
slog.Debug(err.Error())
|
||||
return nil, err
|
||||
|
||||
@@ -152,7 +152,7 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string
|
||||
if header != "" {
|
||||
basic := strings.TrimPrefix(header, "Basic ")
|
||||
if basic != "" {
|
||||
base, err := base64.URLEncoding.DecodeString(basic)
|
||||
base, err := base64.RawURLEncoding.DecodeString(basic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,30 +1,107 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"platform/web/auth"
|
||||
"platform/web/common"
|
||||
q "platform/web/queries"
|
||||
"platform/web/services"
|
||||
s "platform/web/services"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// region ListChannels
|
||||
|
||||
type ListChannelsReq struct {
|
||||
common.PageReq
|
||||
AuthType s.ChannelAuthType `json:"auth_type"`
|
||||
ExpireAfter *time.Time `json:"expire_after"`
|
||||
ExpireBefore *time.Time `json:"expire_before"`
|
||||
}
|
||||
|
||||
func ListChannels(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析请求参数
|
||||
req := new(ListChannelsReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 构造查询条件
|
||||
cond := q.Channel.
|
||||
Where(q.Channel.UserID.Eq(authContext.Payload.Id))
|
||||
switch req.AuthType {
|
||||
case s.ChannelAuthTypeIp:
|
||||
cond.Where(q.Channel.AuthIP.Is(true))
|
||||
case s.ChannelAuthTypePass:
|
||||
cond.Where(q.Channel.AuthPass.Is(true))
|
||||
default:
|
||||
break
|
||||
}
|
||||
|
||||
if req.ExpireAfter != nil {
|
||||
cond.Where(q.Channel.Expiration.Gte(common.LocalDateTime(*req.ExpireAfter)))
|
||||
}
|
||||
if req.ExpireBefore != nil {
|
||||
cond.Where(q.Channel.Expiration.Lte(common.LocalDateTime(*req.ExpireBefore)))
|
||||
}
|
||||
|
||||
// 查询数据
|
||||
channels, err := q.Channel.Debug().
|
||||
Where(cond).
|
||||
Order(q.Channel.CreatedAt.Desc()).
|
||||
Offset(req.GetOffset()).
|
||||
Limit(req.GetLimit()).
|
||||
Find()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 查询总量
|
||||
var total int64
|
||||
if len(channels) < req.GetLimit() {
|
||||
total = int64(len(channels) + req.GetOffset())
|
||||
} else {
|
||||
total, err = q.Channel.Debug().
|
||||
Where(cond).
|
||||
Count()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 返回结果
|
||||
return c.JSON(common.PageResp{
|
||||
Total: int(total),
|
||||
Page: req.GetPage(),
|
||||
Size: req.GetSize(),
|
||||
List: channels,
|
||||
})
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region CreateChannel
|
||||
|
||||
type CreateChannelReq struct {
|
||||
ResourceId int32 `json:"resource_id" validate:"required"`
|
||||
AuthType services.ChannelAuthType `json:"auth_type" validate:"required"`
|
||||
Protocol services.ChannelProtocol `json:"protocol" validate:"required"`
|
||||
Count int `json:"count" validate:"required"`
|
||||
Prov string `json:"prov"`
|
||||
City string `json:"city"`
|
||||
Isp string `json:"isp"`
|
||||
ResourceId int32 `json:"resource_id" validate:"required"`
|
||||
AuthType s.ChannelAuthType `json:"auth_type" validate:"required"`
|
||||
Protocol s.ChannelProtocol `json:"protocol" validate:"required"`
|
||||
Count int `json:"count" validate:"required"`
|
||||
Prov string `json:"prov"`
|
||||
City string `json:"city"`
|
||||
Isp string `json:"isp"`
|
||||
}
|
||||
|
||||
func CreateChannel(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
|
||||
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -65,14 +142,14 @@ func CreateChannel(c *fiber.Ctx) error {
|
||||
isp = "移动"
|
||||
}
|
||||
|
||||
result, err := services.Channel.CreateChannel(
|
||||
result, err := s.Channel.CreateChannel(
|
||||
c.Context(),
|
||||
authContext,
|
||||
req.ResourceId,
|
||||
req.Protocol,
|
||||
req.AuthType,
|
||||
req.Count,
|
||||
services.NodeFilterConfig{
|
||||
s.NodeFilterConfig{
|
||||
Isp: isp,
|
||||
Prov: req.Prov,
|
||||
City: req.City,
|
||||
@@ -87,11 +164,6 @@ func CreateChannel(c *fiber.Ctx) error {
|
||||
|
||||
type CreateChannelResultType string
|
||||
|
||||
const (
|
||||
CreateChannelResultTypeJson CreateChannelResultType = "json"
|
||||
CreateChannelResultTypeText CreateChannelResultType = "text"
|
||||
)
|
||||
|
||||
// endregion
|
||||
|
||||
// region RemoveChannels
|
||||
@@ -101,19 +173,23 @@ type RemoveChannelsReq struct {
|
||||
}
|
||||
|
||||
func RemoveChannels(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.Protect(c, []s.PayloadType{
|
||||
s.PayloadUser,
|
||||
s.PayloadClientConfidential,
|
||||
}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析请求参数
|
||||
req := new(RemoveChannelsReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
authCtx, ok := c.Locals("auth").(*services.AuthContext)
|
||||
if !ok {
|
||||
return errors.New("user not found")
|
||||
}
|
||||
|
||||
// 删除通道
|
||||
err := services.Channel.RemoveChannels(c.Context(), authCtx, req.ByIds...)
|
||||
err = s.Channel.RemoveChannels(c.Context(), authCtx, req.ByIds...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ func ApplyRouters(app *fiber.App) {
|
||||
|
||||
// 通道
|
||||
channel := api.Group("/channel")
|
||||
channel.Post("/list", handlers.ListChannels)
|
||||
channel.Post("/create", handlers.CreateChannel)
|
||||
channel.Post("/remove", handlers.RemoveChannels)
|
||||
|
||||
|
||||
@@ -620,6 +620,7 @@ func assignPort(
|
||||
UserID: userId,
|
||||
ProxyID: proxy.ID,
|
||||
UserHost: item,
|
||||
ProxyHost: proxy.Host,
|
||||
ProxyPort: int32(port),
|
||||
AuthIP: true,
|
||||
AuthPass: false,
|
||||
@@ -639,6 +640,7 @@ func assignPort(
|
||||
channels = append(channels, &models.Channel{
|
||||
UserID: userId,
|
||||
ProxyID: proxy.ID,
|
||||
ProxyHost: proxy.Host,
|
||||
ProxyPort: int32(port),
|
||||
AuthIP: false,
|
||||
AuthPass: true,
|
||||
@@ -654,6 +656,8 @@ func assignPort(
|
||||
Username: &username,
|
||||
Password: &password,
|
||||
})
|
||||
default:
|
||||
return nil, nil, ChannelServiceErr("不支持的通道认证方式")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user