重构代码结构与认证体系,集成异步任务消费者
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"platform/web/auth"
|
||||
"platform/web/core"
|
||||
q "platform/web/queries"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// region ListAnnouncements
|
||||
@@ -16,7 +17,7 @@ type ListAnnouncementsRequest struct {
|
||||
func ListAnnouncements(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
_, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,277 +1,14 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"platform/pkg/u"
|
||||
auth2 "platform/web/auth"
|
||||
client2 "platform/web/domains/client"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
s "platform/web/services"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// region /token
|
||||
|
||||
type TokenReq struct {
|
||||
GrantType auth2.GrantType `json:"grant_type" form:"grant_type"`
|
||||
ClientID string `json:"client_id" form:"client_id"`
|
||||
ClientSecret string `json:"client_secret" form:"client_secret"`
|
||||
Scope string `json:"scope" form:"scope"`
|
||||
s.GrantCodeData
|
||||
s.GrantClientData
|
||||
s.GrantRefreshData
|
||||
s.GrantPasswordData
|
||||
}
|
||||
|
||||
type TokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
type TokenErrResp struct {
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description,omitempty"`
|
||||
}
|
||||
|
||||
// Token 处理 OAuth2.0 授权请求
|
||||
func Token(c *fiber.Ctx) error {
|
||||
|
||||
// 验证请求参数
|
||||
req := new(TokenReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
|
||||
}
|
||||
if req.GrantType == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:grant_type")
|
||||
}
|
||||
|
||||
slog.Debug("oauth token", slog.String("grant_type",
|
||||
string(req.GrantType)),
|
||||
slog.String("client_id", req.ClientID),
|
||||
)
|
||||
|
||||
// 基于授权类型处理请求
|
||||
switch req.GrantType {
|
||||
|
||||
// 授权码模式
|
||||
case auth2.GrantAuthorizationCode:
|
||||
if req.Code == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:code")
|
||||
}
|
||||
|
||||
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
|
||||
if err != nil {
|
||||
return sendError(c, err.(s.AuthServiceError))
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
|
||||
// 客户端凭证模式
|
||||
case auth2.GrantClientCredentials:
|
||||
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
scope := strings.Split(req.Scope, ",")
|
||||
token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
|
||||
if err != nil {
|
||||
return sendError(c, err.(s.AuthServiceError))
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
|
||||
// 刷新令牌模式
|
||||
case auth2.GrantRefreshToken:
|
||||
if req.RefreshToken == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:refresh_token")
|
||||
}
|
||||
|
||||
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
scope := strings.Split(req.Scope, ",")
|
||||
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth2.ErrInvalidRefreshToken) {
|
||||
return sendError(c, s.ErrOauthInvalidGrant)
|
||||
}
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
|
||||
// 密码模式
|
||||
case auth2.GrantPassword:
|
||||
if req.LoginType == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password_type")
|
||||
}
|
||||
if req.Username == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:username")
|
||||
}
|
||||
if req.Password == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password")
|
||||
}
|
||||
|
||||
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent"))
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
|
||||
default:
|
||||
return sendError(c, s.ErrOauthUnsupportedGrantType)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查客户端凭证
|
||||
func protect(c *fiber.Ctx, grant auth2.GrantType, clientId, clientSecret string) (*m.Client, error) {
|
||||
header := c.Get("Authorization")
|
||||
if header != "" {
|
||||
basic := strings.TrimPrefix(header, "Basic ")
|
||||
if basic != "" {
|
||||
base, err := base64.RawURLEncoding.DecodeString(basic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := strings.SplitN(string(base), ":", 2)
|
||||
if len(parts) == 2 {
|
||||
clientId = parts[0]
|
||||
clientSecret = parts[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查找客户端
|
||||
if clientId == "" {
|
||||
return nil, s.ErrOauthInvalidRequest
|
||||
}
|
||||
client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, s.ErrOauthInvalidClient
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证客户端状态
|
||||
if client.Status != 1 {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
|
||||
// 验证授权类型
|
||||
switch grant {
|
||||
case auth2.GrantAuthorizationCode:
|
||||
if !client.GrantCode {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
case auth2.GrantClientCredentials:
|
||||
if !client.GrantClient || client.Spec != int32(client2.SpecWeb) || client.Spec != int32(client2.SpecTrusted) {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
case auth2.GrantRefreshToken:
|
||||
if !client.GrantRefresh {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
case auth2.GrantPassword:
|
||||
if !client.GrantPassword {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
}
|
||||
|
||||
// 如果客户端是 confidential,验证 client_secret,失败返回错误
|
||||
if client.Spec == int32(client2.SpecWeb) || client.Spec == int32(client2.SpecTrusted) {
|
||||
if clientSecret == "" {
|
||||
return nil, s.ErrOauthInvalidRequest
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
|
||||
return nil, s.ErrOauthInvalidClient
|
||||
}
|
||||
}
|
||||
|
||||
// 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑)
|
||||
auth2.Locals(c, &auth2.Context{
|
||||
Payload: auth2.Payload{
|
||||
Id: client.ID,
|
||||
Type: auth2.PayloadSecuredServer,
|
||||
Name: client.Name,
|
||||
Avatar: client.Icon,
|
||||
},
|
||||
})
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// 发送成功响应
|
||||
func sendSuccess(c *fiber.Ctx, details *auth2.TokenDetails) error {
|
||||
return c.JSON(TokenResp{
|
||||
AccessToken: details.AccessToken,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()),
|
||||
RefreshToken: details.RefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
// 发送错误响应
|
||||
func sendError(c *fiber.Ctx, err error, description ...string) error {
|
||||
var sErr s.AuthServiceError
|
||||
if errors.As(err, &sErr) {
|
||||
status := fiber.StatusBadRequest
|
||||
var desc string
|
||||
switch {
|
||||
case errors.Is(sErr, s.ErrOauthInvalidRequest):
|
||||
desc = "无效的请求"
|
||||
case errors.Is(sErr, s.ErrOauthInvalidClient):
|
||||
status = fiber.StatusUnauthorized
|
||||
desc = "无效的客户端凭证"
|
||||
case errors.Is(sErr, s.ErrOauthInvalidGrant):
|
||||
desc = "无效的授权凭证"
|
||||
case errors.Is(sErr, s.ErrOauthInvalidScope):
|
||||
desc = "无效的授权范围"
|
||||
case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
|
||||
desc = "未授权的客户端"
|
||||
case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
|
||||
desc = "不支持的授权类型"
|
||||
}
|
||||
if len(description) > 0 {
|
||||
desc = description[0]
|
||||
}
|
||||
|
||||
return c.Status(status).JSON(TokenErrResp{
|
||||
Error: string(sErr),
|
||||
Description: desc,
|
||||
})
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region /revoke
|
||||
|
||||
type RevokeReq struct {
|
||||
@@ -280,7 +17,7 @@ type RevokeReq struct {
|
||||
}
|
||||
|
||||
func Revoke(c *fiber.Ctx) error {
|
||||
_, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
|
||||
_, err := auth2.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
// 用户未登录
|
||||
return nil
|
||||
@@ -312,14 +49,14 @@ type IntrospectResp struct {
|
||||
|
||||
func Introspect(c *fiber.Ctx) error {
|
||||
// 验证权限
|
||||
authCtx, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
|
||||
authCtx, err := auth2.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
profile, err := q.User.
|
||||
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
||||
Where(q.User.ID.Eq(authCtx.User.ID)).
|
||||
Omit(q.User.DeletedAt).
|
||||
Take()
|
||||
if err != nil {
|
||||
|
||||
@@ -23,7 +23,7 @@ type ListBillReq struct {
|
||||
// ListBill 获取账单列表
|
||||
func ListBill(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -36,7 +36,7 @@ func ListBill(c *fiber.Ctx) error {
|
||||
|
||||
// 查询账单列表
|
||||
do := q.Bill.
|
||||
Where(q.Bill.UserID.Eq(authContext.Payload.Id))
|
||||
Where(q.Bill.UserID.Eq(authCtx.User.ID))
|
||||
|
||||
if req.Type != nil {
|
||||
do.Where(q.Bill.Type.Eq(int32(*req.Type)))
|
||||
|
||||
@@ -24,7 +24,7 @@ type ListChannelsReq struct {
|
||||
|
||||
func ListChannels(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authContext, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func ListChannels(c *fiber.Ctx) error {
|
||||
|
||||
// 构造查询条件
|
||||
cond := q.Channel.
|
||||
Where(q.Channel.UserID.Eq(authContext.Payload.Id))
|
||||
Where(q.Channel.UserID.Eq(authContext.User.ID))
|
||||
switch req.AuthType {
|
||||
case s.ChannelAuthTypeIp:
|
||||
cond.Where(q.Channel.AuthIP.Is(true))
|
||||
@@ -110,24 +110,19 @@ type CreateChannelRespItem struct {
|
||||
func CreateChannel(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查用户其他权限
|
||||
user, err := q.User.
|
||||
Where(q.User.ID.Eq(authContext.Payload.Id)).
|
||||
Take()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user := authCtx.User
|
||||
if user.IDToken == nil || *user.IDToken == "" {
|
||||
return fiber.NewError(fiber.StatusForbidden, "账号未实名")
|
||||
}
|
||||
|
||||
count, err := q.Whitelist.Where(
|
||||
q.Whitelist.UserID.Eq(authContext.Payload.Id),
|
||||
q.Whitelist.UserID.Eq(user.ID),
|
||||
q.Whitelist.Host.Eq(c.IP()),
|
||||
).Count()
|
||||
if err != nil {
|
||||
@@ -155,7 +150,7 @@ func CreateChannel(c *fiber.Ctx) error {
|
||||
// 创建通道
|
||||
result, err := s.Channel.CreateChannel(
|
||||
c,
|
||||
authContext.Payload.Id,
|
||||
user.ID,
|
||||
req.ResourceId,
|
||||
req.Protocol,
|
||||
req.AuthType,
|
||||
@@ -198,7 +193,7 @@ type RemoveChannelsReq struct {
|
||||
|
||||
func RemoveChannels(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -210,31 +205,7 @@ func RemoveChannels(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 删除通道
|
||||
err = s.Channel.RemoveChannels(req.ByIds, authCtx.Payload.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
type RemoveChannelByTaskReq []int32
|
||||
|
||||
func RemoveChannelByTask(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
_, err := auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析请求参数
|
||||
var req RemoveChannelByTaskReq
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除通道
|
||||
err = s.Channel.RemoveChannels(req)
|
||||
err = s.Channel.RemoveChannels(req.ByIds, authCtx.User.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@ package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"gorm.io/gen/field"
|
||||
"gorm.io/gorm"
|
||||
"log/slog"
|
||||
"platform/pkg/u"
|
||||
"platform/web/auth"
|
||||
@@ -14,6 +12,9 @@ import (
|
||||
q "platform/web/queries"
|
||||
s "platform/web/services"
|
||||
|
||||
"gorm.io/gen/field"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
@@ -120,7 +121,7 @@ type AllEdgesAvailableRespItem struct {
|
||||
|
||||
func AllEdgesAvailable(c *fiber.Ctx) (err error) {
|
||||
// 检查权限
|
||||
_, err = auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
|
||||
_, err = auth.GetAuthCtx(c).PermitSecretClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -37,17 +37,11 @@ type IdentifyRes struct {
|
||||
func Identify(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, err := q.User.
|
||||
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
||||
Select(q.User.ID, q.User.IDToken).
|
||||
Take()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user := authCtx.User
|
||||
if user.IDToken != nil && *user.IDToken != "" {
|
||||
// 用户已实名认证
|
||||
return c.JSON(IdentifyRes{
|
||||
@@ -86,7 +80,7 @@ func Identify(c *fiber.Ctx) error {
|
||||
|
||||
// 保存认证中间状态
|
||||
info := idenInfo{
|
||||
Uid: authCtx.Payload.Id,
|
||||
Uid: user.ID,
|
||||
Type: req.Type,
|
||||
Name: req.Name,
|
||||
IdNo: req.IdenNo,
|
||||
|
||||
@@ -40,9 +40,7 @@ type ProxyReportOnlineResp struct {
|
||||
func ProxyReportOnline(c *fiber.Ctx) (err error) {
|
||||
|
||||
// 检查接口权限
|
||||
_, err = auth2.NewProtect(c).Payload(
|
||||
auth2.PayloadInternalServer,
|
||||
).Do()
|
||||
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -149,9 +147,7 @@ type ProxyReportOfflineReq struct {
|
||||
|
||||
func ProxyReportOffline(c *fiber.Ctx) (err error) {
|
||||
// 检查接口权限
|
||||
_, err = auth2.NewProtect(c).Payload(
|
||||
auth2.PayloadInternalServer,
|
||||
).Do()
|
||||
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -193,9 +189,7 @@ type ProxyReportUpdateReq struct {
|
||||
|
||||
func ProxyReportUpdate(c *fiber.Ctx) (err error) {
|
||||
// 检查接口权限
|
||||
_, err = auth2.NewProtect(c).Payload(
|
||||
auth2.PayloadInternalServer,
|
||||
).Do()
|
||||
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"gorm.io/gen/field"
|
||||
"platform/pkg/u"
|
||||
"platform/web/auth"
|
||||
"platform/web/core"
|
||||
@@ -12,6 +11,8 @@ import (
|
||||
s "platform/web/services"
|
||||
"time"
|
||||
|
||||
"gorm.io/gen/field"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
@@ -28,7 +29,7 @@ type ListResourceShortReq struct {
|
||||
|
||||
func ListResourceShort(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -41,7 +42,7 @@ func ListResourceShort(c *fiber.Ctx) error {
|
||||
|
||||
// 查询套餐列表
|
||||
do := q.Resource.Where(
|
||||
q.Resource.UserID.Eq(authContext.Payload.Id),
|
||||
q.Resource.UserID.Eq(authCtx.User.ID),
|
||||
q.Resource.Type.Eq(int32(resource2.TypeShort)),
|
||||
)
|
||||
if req.ResourceNo != nil && *req.ResourceNo != "" {
|
||||
@@ -109,7 +110,7 @@ type ListResourceLongReq struct {
|
||||
|
||||
func ListResourceLong(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -122,7 +123,7 @@ func ListResourceLong(c *fiber.Ctx) error {
|
||||
|
||||
// 查询套餐列表
|
||||
do := q.Resource.Where(
|
||||
q.Resource.UserID.Eq(authContext.Payload.Id),
|
||||
q.Resource.UserID.Eq(authCtx.User.ID),
|
||||
q.Resource.Type.Eq(int32(resource2.TypeLong)),
|
||||
)
|
||||
if req.ResourceNo != nil && *req.ResourceNo != "" {
|
||||
@@ -182,7 +183,7 @@ type AllResourceReq struct {
|
||||
|
||||
func AllActiveResource(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -198,7 +199,7 @@ func AllActiveResource(c *fiber.Ctx) error {
|
||||
q.Resource.Long,
|
||||
).
|
||||
Where(
|
||||
q.Resource.UserID.Eq(authCtx.Payload.Id),
|
||||
q.Resource.UserID.Eq(authCtx.User.ID),
|
||||
q.Resource.Active.Is(true),
|
||||
q.Resource.Where(
|
||||
q.Resource.Type.Eq(int32(resource2.TypeShort)),
|
||||
@@ -254,7 +255,7 @@ type StatisticLong struct {
|
||||
|
||||
func StatisticResourceFree(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
session, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -266,7 +267,7 @@ func StatisticResourceFree(c *fiber.Ctx) error {
|
||||
q.Resource.Long,
|
||||
).
|
||||
Where(
|
||||
q.Resource.UserID.Eq(session.Payload.Id),
|
||||
q.Resource.UserID.Eq(authCtx.User.ID),
|
||||
q.Resource.Active.Is(true),
|
||||
).
|
||||
Select(q.Resource.ID, q.Resource.Type).
|
||||
@@ -347,7 +348,7 @@ type StatisticResourceUsageResp []struct {
|
||||
|
||||
func StatisticResourceUsage(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
session, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -359,12 +360,12 @@ func StatisticResourceUsage(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 统计套餐提取数量
|
||||
do := q.LogsUserUsage.Where(q.LogsUserUsage.UserID.Eq(session.Payload.Id))
|
||||
do := q.LogsUserUsage.Where(q.LogsUserUsage.UserID.Eq(authCtx.User.ID))
|
||||
if req.ResourceNo != nil && *req.ResourceNo != "" {
|
||||
var resourceID int32
|
||||
err := q.Resource.
|
||||
Where(
|
||||
q.Resource.UserID.Eq(session.Payload.Id),
|
||||
q.Resource.UserID.Eq(authCtx.User.ID),
|
||||
q.Resource.ResourceNo.Eq(*req.ResourceNo),
|
||||
).
|
||||
Select(q.Resource.ID).
|
||||
@@ -409,7 +410,7 @@ type CreateResourceReq struct {
|
||||
func CreateResource(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -421,7 +422,7 @@ func CreateResource(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 创建套餐
|
||||
err = s.Resource.CreateResourceByBalance(authCtx.Payload.Id, time.Now(), req.CreateResourceData)
|
||||
err = s.Resource.CreateResourceByBalance(authCtx.User.ID, time.Now(), req.CreateResourceData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -431,7 +432,7 @@ func CreateResource(c *fiber.Ctx) error {
|
||||
|
||||
func ResourcePrice(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
_, err := auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
|
||||
_, err := auth.GetAuthCtx(c).PermitSecretClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
trade2 "platform/web/domains/trade"
|
||||
g "platform/web/globals"
|
||||
s "platform/web/services"
|
||||
"platform/web/tasks"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -27,7 +26,7 @@ type TradeCreateResp struct {
|
||||
|
||||
func TradeCreate(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -52,7 +51,7 @@ func TradeCreate(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 创建交易
|
||||
result, err := s.Trade.CreateTrade(authCtx.Payload.Id, time.Now(), &req.CreateTradeData)
|
||||
result, err := s.Trade.CreateTrade(authCtx.User.ID, time.Now(), &req.CreateTradeData)
|
||||
if err != nil {
|
||||
slog.Error("创建交易失败", "error", err)
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "创建交易失败"})
|
||||
@@ -70,7 +69,7 @@ type TradeCompleteReq struct {
|
||||
|
||||
func TradeComplete(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
_, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -99,7 +98,7 @@ type TradeCancelReq struct {
|
||||
|
||||
func TradeCancel(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
_, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -119,29 +118,3 @@ func TradeCancel(c *fiber.Ctx) error {
|
||||
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
|
||||
type TradeCheckReq struct {
|
||||
tasks.CancelTradeData
|
||||
}
|
||||
|
||||
func TradeCancelByTask(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadInternalServer}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 取消交易
|
||||
req := new(TradeCheckReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查订单状态
|
||||
err = s.Trade.CancelTrade(req.TradeNo, req.Method, time.Now())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"platform/web/auth"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
s "platform/web/services"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// region /update
|
||||
@@ -20,7 +21,7 @@ type UpdateUserReq struct {
|
||||
|
||||
func UpdateUser(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -33,7 +34,7 @@ func UpdateUser(c *fiber.Ctx) error {
|
||||
|
||||
// 更新用户信息
|
||||
_, err = q.User.
|
||||
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
||||
Where(q.User.ID.Eq(authCtx.User.ID)).
|
||||
Updates(m.User{
|
||||
Username: &req.Username,
|
||||
Email: &req.Email,
|
||||
@@ -59,7 +60,7 @@ type UpdateAccountReq struct {
|
||||
|
||||
func UpdateAccount(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -72,7 +73,7 @@ func UpdateAccount(c *fiber.Ctx) error {
|
||||
|
||||
// 更新用户信息
|
||||
_, err = q.User.
|
||||
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
||||
Where(q.User.ID.Eq(authCtx.User.ID)).
|
||||
Updates(m.User{
|
||||
Username: &req.Username,
|
||||
Password: &req.Password,
|
||||
@@ -97,7 +98,7 @@ type UpdatePasswordReq struct {
|
||||
|
||||
func UpdatePassword(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -124,7 +125,7 @@ func UpdatePassword(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
_, err = q.User.
|
||||
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
||||
Where(q.User.ID.Eq(authCtx.User.ID)).
|
||||
UpdateColumn(q.User.Password, newHash)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -2,12 +2,14 @@ package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"platform/pkg/env"
|
||||
"platform/web/auth"
|
||||
"platform/web/services"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type VerifierReq struct {
|
||||
@@ -17,7 +19,7 @@ type VerifierReq struct {
|
||||
|
||||
func SmsCode(c *fiber.Ctx) error {
|
||||
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadInternalServer}, []string{})
|
||||
_, err := auth.GetAuthCtx(c).PermitInternalClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -48,3 +50,19 @@ func SmsCode(c *fiber.Ctx) error {
|
||||
// 发送成功
|
||||
return nil
|
||||
}
|
||||
|
||||
func DebugGetSmsCode(c *fiber.Ctx) error {
|
||||
if env.RunMode != env.RunModeDev {
|
||||
return fiber.NewError(fiber.StatusForbidden, "not allowed")
|
||||
}
|
||||
|
||||
code, err := services.Verifier.GetSms(c.Context(), c.Params("phone"))
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return c.SendString("还没有验证码")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return c.SendString(code)
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ type ListWhitelistResp struct {
|
||||
func ListWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -38,8 +38,7 @@ func ListWhitelist(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 获取白名单信息
|
||||
do := q.Whitelist.
|
||||
Where(q.Whitelist.UserID.Eq(authContext.Payload.Id))
|
||||
do := q.Whitelist.Where(q.Whitelist.UserID.Eq(authCtx.User.ID))
|
||||
|
||||
list, err := q.Whitelist.Where(do).
|
||||
Offset(req.GetOffset()).
|
||||
@@ -77,7 +76,7 @@ type CreateWhitelistReq struct {
|
||||
func CreateWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -96,7 +95,7 @@ func CreateWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 创建白名单
|
||||
err = q.Whitelist.Create(&m.Whitelist{
|
||||
UserID: authContext.Payload.Id,
|
||||
UserID: authCtx.User.ID,
|
||||
Host: req.Host,
|
||||
Remark: &req.Remark,
|
||||
})
|
||||
@@ -111,7 +110,7 @@ type UpdateWhitelistReq struct {
|
||||
|
||||
func UpdateWhitelist(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -129,7 +128,7 @@ func UpdateWhitelist(c *fiber.Ctx) error {
|
||||
_, err = q.Whitelist.
|
||||
Where(
|
||||
q.Whitelist.ID.Eq(req.ID),
|
||||
q.Whitelist.UserID.Eq(authContext.Payload.Id),
|
||||
q.Whitelist.UserID.Eq(authCtx.User.ID),
|
||||
).
|
||||
Updates(&m.Whitelist{
|
||||
ID: req.ID,
|
||||
@@ -149,7 +148,7 @@ type RemoveWhitelistReq struct {
|
||||
func RemoveWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -175,7 +174,7 @@ func RemoveWhitelist(c *fiber.Ctx) error {
|
||||
_, err = q.Whitelist.
|
||||
Where(
|
||||
q.Whitelist.ID.In(ids...),
|
||||
q.Whitelist.UserID.Eq(authContext.Payload.Id),
|
||||
q.Whitelist.UserID.Eq(authCtx.User.ID),
|
||||
).
|
||||
Update(
|
||||
q.Whitelist.DeletedAt, time.Now(),
|
||||
|
||||
Reference in New Issue
Block a user