修正基本认证解码方式;添加查询 channels 接口
This commit is contained in:
@@ -81,7 +81,7 @@ func authBearer(ctx context.Context, token string) (*services.AuthContext, error
|
|||||||
func authBasic(_ context.Context, token string) (*services.AuthContext, error) {
|
func authBasic(_ context.Context, token string) (*services.AuthContext, error) {
|
||||||
|
|
||||||
// 解析 Basic 认证信息
|
// 解析 Basic 认证信息
|
||||||
var base, err = base64.URLEncoding.DecodeString(token)
|
var base, err = base64.RawURLEncoding.DecodeString(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Debug(err.Error())
|
slog.Debug(err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string
|
|||||||
if header != "" {
|
if header != "" {
|
||||||
basic := strings.TrimPrefix(header, "Basic ")
|
basic := strings.TrimPrefix(header, "Basic ")
|
||||||
if basic != "" {
|
if basic != "" {
|
||||||
base, err := base64.URLEncoding.DecodeString(basic)
|
base, err := base64.RawURLEncoding.DecodeString(basic)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,21 +1,98 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"platform/web/auth"
|
"platform/web/auth"
|
||||||
|
"platform/web/common"
|
||||||
q "platform/web/queries"
|
q "platform/web/queries"
|
||||||
"platform/web/services"
|
s "platform/web/services"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// region ListChannels
|
||||||
|
|
||||||
|
type ListChannelsReq struct {
|
||||||
|
common.PageReq
|
||||||
|
AuthType s.ChannelAuthType `json:"auth_type"`
|
||||||
|
ExpireAfter *time.Time `json:"expire_after"`
|
||||||
|
ExpireBefore *time.Time `json:"expire_before"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListChannels(c *fiber.Ctx) error {
|
||||||
|
// 检查权限
|
||||||
|
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析请求参数
|
||||||
|
req := new(ListChannelsReq)
|
||||||
|
if err := c.BodyParser(req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构造查询条件
|
||||||
|
cond := q.Channel.
|
||||||
|
Where(q.Channel.UserID.Eq(authContext.Payload.Id))
|
||||||
|
switch req.AuthType {
|
||||||
|
case s.ChannelAuthTypeIp:
|
||||||
|
cond.Where(q.Channel.AuthIP.Is(true))
|
||||||
|
case s.ChannelAuthTypePass:
|
||||||
|
cond.Where(q.Channel.AuthPass.Is(true))
|
||||||
|
default:
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.ExpireAfter != nil {
|
||||||
|
cond.Where(q.Channel.Expiration.Gte(common.LocalDateTime(*req.ExpireAfter)))
|
||||||
|
}
|
||||||
|
if req.ExpireBefore != nil {
|
||||||
|
cond.Where(q.Channel.Expiration.Lte(common.LocalDateTime(*req.ExpireBefore)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询数据
|
||||||
|
channels, err := q.Channel.Debug().
|
||||||
|
Where(cond).
|
||||||
|
Order(q.Channel.CreatedAt.Desc()).
|
||||||
|
Offset(req.GetOffset()).
|
||||||
|
Limit(req.GetLimit()).
|
||||||
|
Find()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查询总量
|
||||||
|
var total int64
|
||||||
|
if len(channels) < req.GetLimit() {
|
||||||
|
total = int64(len(channels) + req.GetOffset())
|
||||||
|
} else {
|
||||||
|
total, err = q.Channel.Debug().
|
||||||
|
Where(cond).
|
||||||
|
Count()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回结果
|
||||||
|
return c.JSON(common.PageResp{
|
||||||
|
Total: int(total),
|
||||||
|
Page: req.GetPage(),
|
||||||
|
Size: req.GetSize(),
|
||||||
|
List: channels,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// endregion
|
||||||
|
|
||||||
// region CreateChannel
|
// region CreateChannel
|
||||||
|
|
||||||
type CreateChannelReq struct {
|
type CreateChannelReq struct {
|
||||||
ResourceId int32 `json:"resource_id" validate:"required"`
|
ResourceId int32 `json:"resource_id" validate:"required"`
|
||||||
AuthType services.ChannelAuthType `json:"auth_type" validate:"required"`
|
AuthType s.ChannelAuthType `json:"auth_type" validate:"required"`
|
||||||
Protocol services.ChannelProtocol `json:"protocol" validate:"required"`
|
Protocol s.ChannelProtocol `json:"protocol" validate:"required"`
|
||||||
Count int `json:"count" validate:"required"`
|
Count int `json:"count" validate:"required"`
|
||||||
Prov string `json:"prov"`
|
Prov string `json:"prov"`
|
||||||
City string `json:"city"`
|
City string `json:"city"`
|
||||||
@@ -24,7 +101,7 @@ type CreateChannelReq struct {
|
|||||||
|
|
||||||
func CreateChannel(c *fiber.Ctx) error {
|
func CreateChannel(c *fiber.Ctx) error {
|
||||||
// 检查权限
|
// 检查权限
|
||||||
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
|
authContext, err := auth.Protect(c, []s.PayloadType{s.PayloadUser}, []string{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -65,14 +142,14 @@ func CreateChannel(c *fiber.Ctx) error {
|
|||||||
isp = "移动"
|
isp = "移动"
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := services.Channel.CreateChannel(
|
result, err := s.Channel.CreateChannel(
|
||||||
c.Context(),
|
c.Context(),
|
||||||
authContext,
|
authContext,
|
||||||
req.ResourceId,
|
req.ResourceId,
|
||||||
req.Protocol,
|
req.Protocol,
|
||||||
req.AuthType,
|
req.AuthType,
|
||||||
req.Count,
|
req.Count,
|
||||||
services.NodeFilterConfig{
|
s.NodeFilterConfig{
|
||||||
Isp: isp,
|
Isp: isp,
|
||||||
Prov: req.Prov,
|
Prov: req.Prov,
|
||||||
City: req.City,
|
City: req.City,
|
||||||
@@ -87,11 +164,6 @@ func CreateChannel(c *fiber.Ctx) error {
|
|||||||
|
|
||||||
type CreateChannelResultType string
|
type CreateChannelResultType string
|
||||||
|
|
||||||
const (
|
|
||||||
CreateChannelResultTypeJson CreateChannelResultType = "json"
|
|
||||||
CreateChannelResultTypeText CreateChannelResultType = "text"
|
|
||||||
)
|
|
||||||
|
|
||||||
// endregion
|
// endregion
|
||||||
|
|
||||||
// region RemoveChannels
|
// region RemoveChannels
|
||||||
@@ -101,19 +173,23 @@ type RemoveChannelsReq struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func RemoveChannels(c *fiber.Ctx) error {
|
func RemoveChannels(c *fiber.Ctx) error {
|
||||||
|
// 检查权限
|
||||||
|
authCtx, err := auth.Protect(c, []s.PayloadType{
|
||||||
|
s.PayloadUser,
|
||||||
|
s.PayloadClientConfidential,
|
||||||
|
}, []string{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析请求参数
|
||||||
req := new(RemoveChannelsReq)
|
req := new(RemoveChannelsReq)
|
||||||
if err := c.BodyParser(req); err != nil {
|
if err := c.BodyParser(req); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取用户信息
|
|
||||||
authCtx, ok := c.Locals("auth").(*services.AuthContext)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("user not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 删除通道
|
// 删除通道
|
||||||
err := services.Channel.RemoveChannels(c.Context(), authCtx, req.ByIds...)
|
err = s.Channel.RemoveChannels(c.Context(), authCtx, req.ByIds...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ func ApplyRouters(app *fiber.App) {
|
|||||||
|
|
||||||
// 通道
|
// 通道
|
||||||
channel := api.Group("/channel")
|
channel := api.Group("/channel")
|
||||||
|
channel.Post("/list", handlers.ListChannels)
|
||||||
channel.Post("/create", handlers.CreateChannel)
|
channel.Post("/create", handlers.CreateChannel)
|
||||||
channel.Post("/remove", handlers.RemoveChannels)
|
channel.Post("/remove", handlers.RemoveChannels)
|
||||||
|
|
||||||
|
|||||||
@@ -620,6 +620,7 @@ func assignPort(
|
|||||||
UserID: userId,
|
UserID: userId,
|
||||||
ProxyID: proxy.ID,
|
ProxyID: proxy.ID,
|
||||||
UserHost: item,
|
UserHost: item,
|
||||||
|
ProxyHost: proxy.Host,
|
||||||
ProxyPort: int32(port),
|
ProxyPort: int32(port),
|
||||||
AuthIP: true,
|
AuthIP: true,
|
||||||
AuthPass: false,
|
AuthPass: false,
|
||||||
@@ -639,6 +640,7 @@ func assignPort(
|
|||||||
channels = append(channels, &models.Channel{
|
channels = append(channels, &models.Channel{
|
||||||
UserID: userId,
|
UserID: userId,
|
||||||
ProxyID: proxy.ID,
|
ProxyID: proxy.ID,
|
||||||
|
ProxyHost: proxy.Host,
|
||||||
ProxyPort: int32(port),
|
ProxyPort: int32(port),
|
||||||
AuthIP: false,
|
AuthIP: false,
|
||||||
AuthPass: true,
|
AuthPass: true,
|
||||||
@@ -654,6 +656,8 @@ func assignPort(
|
|||||||
Username: &username,
|
Username: &username,
|
||||||
Password: &password,
|
Password: &password,
|
||||||
})
|
})
|
||||||
|
default:
|
||||||
|
return nil, nil, ChannelServiceErr("不支持的通道认证方式")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user