重构代码结构与认证体系,集成异步任务消费者

This commit is contained in:
2025-11-17 18:38:10 +08:00
parent a97c970166
commit a245229bc2
70 changed files with 2000 additions and 2334 deletions

View File

@@ -1,10 +1,11 @@
package handlers
import (
"github.com/gofiber/fiber/v2"
"platform/web/auth"
"platform/web/core"
q "platform/web/queries"
"github.com/gofiber/fiber/v2"
)
// region ListAnnouncements
@@ -16,7 +17,7 @@ type ListAnnouncementsRequest struct {
func ListAnnouncements(c *fiber.Ctx) error {
// 检查权限
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
_, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}

View File

@@ -1,277 +1,14 @@
package handlers
import (
"encoding/base64"
"errors"
"log/slog"
"platform/pkg/u"
auth2 "platform/web/auth"
client2 "platform/web/domains/client"
m "platform/web/models"
q "platform/web/queries"
s "platform/web/services"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// region /token
type TokenReq struct {
GrantType auth2.GrantType `json:"grant_type" form:"grant_type"`
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
Scope string `json:"scope" form:"scope"`
s.GrantCodeData
s.GrantClientData
s.GrantRefreshData
s.GrantPasswordData
}
type TokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
}
type TokenErrResp struct {
Error string `json:"error"`
Description string `json:"error_description,omitempty"`
}
// Token 处理 OAuth2.0 授权请求
func Token(c *fiber.Ctx) error {
// 验证请求参数
req := new(TokenReq)
if err := c.BodyParser(req); err != nil {
return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
}
if req.GrantType == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数grant_type")
}
slog.Debug("oauth token", slog.String("grant_type",
string(req.GrantType)),
slog.String("client_id", req.ClientID),
)
// 基于授权类型处理请求
switch req.GrantType {
// 授权码模式
case auth2.GrantAuthorizationCode:
if req.Code == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数code")
}
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
if err != nil {
return sendError(c, err.(s.AuthServiceError))
}
return sendSuccess(c, token)
// 客户端凭证模式
case auth2.GrantClientCredentials:
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
scope := strings.Split(req.Scope, ",")
token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
if err != nil {
return sendError(c, err.(s.AuthServiceError))
}
return sendSuccess(c, token)
// 刷新令牌模式
case auth2.GrantRefreshToken:
if req.RefreshToken == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数refresh_token")
}
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
scope := strings.Split(req.Scope, ",")
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
if err != nil {
if errors.Is(err, auth2.ErrInvalidRefreshToken) {
return sendError(c, s.ErrOauthInvalidGrant)
}
return sendError(c, err)
}
return sendSuccess(c, token)
// 密码模式
case auth2.GrantPassword:
if req.LoginType == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password_type")
}
if req.Username == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数username")
}
if req.Password == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password")
}
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent"))
if err != nil {
return sendError(c, err)
}
return sendSuccess(c, token)
default:
return sendError(c, s.ErrOauthUnsupportedGrantType)
}
}
// 检查客户端凭证
func protect(c *fiber.Ctx, grant auth2.GrantType, clientId, clientSecret string) (*m.Client, error) {
header := c.Get("Authorization")
if header != "" {
basic := strings.TrimPrefix(header, "Basic ")
if basic != "" {
base, err := base64.RawURLEncoding.DecodeString(basic)
if err != nil {
return nil, err
}
parts := strings.SplitN(string(base), ":", 2)
if len(parts) == 2 {
clientId = parts[0]
clientSecret = parts[1]
}
}
}
// 查找客户端
if clientId == "" {
return nil, s.ErrOauthInvalidRequest
}
client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, s.ErrOauthInvalidClient
}
return nil, err
}
// 验证客户端状态
if client.Status != 1 {
return nil, s.ErrOauthUnauthorizedClient
}
// 验证授权类型
switch grant {
case auth2.GrantAuthorizationCode:
if !client.GrantCode {
return nil, s.ErrOauthUnauthorizedClient
}
case auth2.GrantClientCredentials:
if !client.GrantClient || client.Spec != int32(client2.SpecWeb) || client.Spec != int32(client2.SpecTrusted) {
return nil, s.ErrOauthUnauthorizedClient
}
case auth2.GrantRefreshToken:
if !client.GrantRefresh {
return nil, s.ErrOauthUnauthorizedClient
}
case auth2.GrantPassword:
if !client.GrantPassword {
return nil, s.ErrOauthUnauthorizedClient
}
}
// 如果客户端是 confidential验证 client_secret失败返回错误
if client.Spec == int32(client2.SpecWeb) || client.Spec == int32(client2.SpecTrusted) {
if clientSecret == "" {
return nil, s.ErrOauthInvalidRequest
}
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
return nil, s.ErrOauthInvalidClient
}
}
// 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑)
auth2.Locals(c, &auth2.Context{
Payload: auth2.Payload{
Id: client.ID,
Type: auth2.PayloadSecuredServer,
Name: client.Name,
Avatar: client.Icon,
},
})
return client, nil
}
// 发送成功响应
func sendSuccess(c *fiber.Ctx, details *auth2.TokenDetails) error {
return c.JSON(TokenResp{
AccessToken: details.AccessToken,
TokenType: "Bearer",
ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()),
RefreshToken: details.RefreshToken,
})
}
// 发送错误响应
func sendError(c *fiber.Ctx, err error, description ...string) error {
var sErr s.AuthServiceError
if errors.As(err, &sErr) {
status := fiber.StatusBadRequest
var desc string
switch {
case errors.Is(sErr, s.ErrOauthInvalidRequest):
desc = "无效的请求"
case errors.Is(sErr, s.ErrOauthInvalidClient):
status = fiber.StatusUnauthorized
desc = "无效的客户端凭证"
case errors.Is(sErr, s.ErrOauthInvalidGrant):
desc = "无效的授权凭证"
case errors.Is(sErr, s.ErrOauthInvalidScope):
desc = "无效的授权范围"
case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
desc = "未授权的客户端"
case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
desc = "不支持的授权类型"
}
if len(description) > 0 {
desc = description[0]
}
return c.Status(status).JSON(TokenErrResp{
Error: string(sErr),
Description: desc,
})
}
return err
}
// endregion
// region /revoke
type RevokeReq struct {
@@ -280,7 +17,7 @@ type RevokeReq struct {
}
func Revoke(c *fiber.Ctx) error {
_, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
_, err := auth2.GetAuthCtx(c).PermitUser()
if err != nil {
// 用户未登录
return nil
@@ -312,14 +49,14 @@ type IntrospectResp struct {
func Introspect(c *fiber.Ctx) error {
// 验证权限
authCtx, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
authCtx, err := auth2.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
// 获取用户信息
profile, err := q.User.
Where(q.User.ID.Eq(authCtx.Payload.Id)).
Where(q.User.ID.Eq(authCtx.User.ID)).
Omit(q.User.DeletedAt).
Take()
if err != nil {

View File

@@ -23,7 +23,7 @@ type ListBillReq struct {
// ListBill 获取账单列表
func ListBill(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -36,7 +36,7 @@ func ListBill(c *fiber.Ctx) error {
// 查询账单列表
do := q.Bill.
Where(q.Bill.UserID.Eq(authContext.Payload.Id))
Where(q.Bill.UserID.Eq(authCtx.User.ID))
if req.Type != nil {
do.Where(q.Bill.Type.Eq(int32(*req.Type)))

View File

@@ -24,7 +24,7 @@ type ListChannelsReq struct {
func ListChannels(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authContext, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -37,7 +37,7 @@ func ListChannels(c *fiber.Ctx) error {
// 构造查询条件
cond := q.Channel.
Where(q.Channel.UserID.Eq(authContext.Payload.Id))
Where(q.Channel.UserID.Eq(authContext.User.ID))
switch req.AuthType {
case s.ChannelAuthTypeIp:
cond.Where(q.Channel.AuthIP.Is(true))
@@ -110,24 +110,19 @@ type CreateChannelRespItem struct {
func CreateChannel(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
// 检查用户其他权限
user, err := q.User.
Where(q.User.ID.Eq(authContext.Payload.Id)).
Take()
if err != nil {
return err
}
user := authCtx.User
if user.IDToken == nil || *user.IDToken == "" {
return fiber.NewError(fiber.StatusForbidden, "账号未实名")
}
count, err := q.Whitelist.Where(
q.Whitelist.UserID.Eq(authContext.Payload.Id),
q.Whitelist.UserID.Eq(user.ID),
q.Whitelist.Host.Eq(c.IP()),
).Count()
if err != nil {
@@ -155,7 +150,7 @@ func CreateChannel(c *fiber.Ctx) error {
// 创建通道
result, err := s.Channel.CreateChannel(
c,
authContext.Payload.Id,
user.ID,
req.ResourceId,
req.Protocol,
req.AuthType,
@@ -198,7 +193,7 @@ type RemoveChannelsReq struct {
func RemoveChannels(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -210,31 +205,7 @@ func RemoveChannels(c *fiber.Ctx) error {
}
// 删除通道
err = s.Channel.RemoveChannels(req.ByIds, authCtx.Payload.Id)
if err != nil {
return err
}
return c.SendStatus(fiber.StatusOK)
}
type RemoveChannelByTaskReq []int32
func RemoveChannelByTask(c *fiber.Ctx) error {
// 检查权限
_, err := auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
if err != nil {
return err
}
// 解析请求参数
var req RemoveChannelByTaskReq
if err := c.BodyParser(&req); err != nil {
return err
}
// 删除通道
err = s.Channel.RemoveChannels(req)
err = s.Channel.RemoveChannels(req.ByIds, authCtx.User.ID)
if err != nil {
return err
}

View File

@@ -2,8 +2,6 @@ package handlers
import (
"errors"
"gorm.io/gen/field"
"gorm.io/gorm"
"log/slog"
"platform/pkg/u"
"platform/web/auth"
@@ -14,6 +12,9 @@ import (
q "platform/web/queries"
s "platform/web/services"
"gorm.io/gen/field"
"gorm.io/gorm"
"github.com/gofiber/fiber/v2"
)
@@ -120,7 +121,7 @@ type AllEdgesAvailableRespItem struct {
func AllEdgesAvailable(c *fiber.Ctx) (err error) {
// 检查权限
_, err = auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
_, err = auth.GetAuthCtx(c).PermitSecretClient()
if err != nil {
return err
}

View File

@@ -37,17 +37,11 @@ type IdentifyRes struct {
func Identify(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil {
return err
}
user, err := q.User.
Where(q.User.ID.Eq(authCtx.Payload.Id)).
Select(q.User.ID, q.User.IDToken).
Take()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
user := authCtx.User
if user.IDToken != nil && *user.IDToken != "" {
// 用户已实名认证
return c.JSON(IdentifyRes{
@@ -86,7 +80,7 @@ func Identify(c *fiber.Ctx) error {
// 保存认证中间状态
info := idenInfo{
Uid: authCtx.Payload.Id,
Uid: user.ID,
Type: req.Type,
Name: req.Name,
IdNo: req.IdenNo,

View File

@@ -40,9 +40,7 @@ type ProxyReportOnlineResp struct {
func ProxyReportOnline(c *fiber.Ctx) (err error) {
// 检查接口权限
_, err = auth2.NewProtect(c).Payload(
auth2.PayloadInternalServer,
).Do()
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
if err != nil {
return err
}
@@ -149,9 +147,7 @@ type ProxyReportOfflineReq struct {
func ProxyReportOffline(c *fiber.Ctx) (err error) {
// 检查接口权限
_, err = auth2.NewProtect(c).Payload(
auth2.PayloadInternalServer,
).Do()
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
if err != nil {
return err
}
@@ -193,9 +189,7 @@ type ProxyReportUpdateReq struct {
func ProxyReportUpdate(c *fiber.Ctx) (err error) {
// 检查接口权限
_, err = auth2.NewProtect(c).Payload(
auth2.PayloadInternalServer,
).Do()
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
if err != nil {
return err
}

View File

@@ -1,7 +1,6 @@
package handlers
import (
"gorm.io/gen/field"
"platform/pkg/u"
"platform/web/auth"
"platform/web/core"
@@ -12,6 +11,8 @@ import (
s "platform/web/services"
"time"
"gorm.io/gen/field"
"github.com/gofiber/fiber/v2"
)
@@ -28,7 +29,7 @@ type ListResourceShortReq struct {
func ListResourceShort(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -41,7 +42,7 @@ func ListResourceShort(c *fiber.Ctx) error {
// 查询套餐列表
do := q.Resource.Where(
q.Resource.UserID.Eq(authContext.Payload.Id),
q.Resource.UserID.Eq(authCtx.User.ID),
q.Resource.Type.Eq(int32(resource2.TypeShort)),
)
if req.ResourceNo != nil && *req.ResourceNo != "" {
@@ -109,7 +110,7 @@ type ListResourceLongReq struct {
func ListResourceLong(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -122,7 +123,7 @@ func ListResourceLong(c *fiber.Ctx) error {
// 查询套餐列表
do := q.Resource.Where(
q.Resource.UserID.Eq(authContext.Payload.Id),
q.Resource.UserID.Eq(authCtx.User.ID),
q.Resource.Type.Eq(int32(resource2.TypeLong)),
)
if req.ResourceNo != nil && *req.ResourceNo != "" {
@@ -182,7 +183,7 @@ type AllResourceReq struct {
func AllActiveResource(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -198,7 +199,7 @@ func AllActiveResource(c *fiber.Ctx) error {
q.Resource.Long,
).
Where(
q.Resource.UserID.Eq(authCtx.Payload.Id),
q.Resource.UserID.Eq(authCtx.User.ID),
q.Resource.Active.Is(true),
q.Resource.Where(
q.Resource.Type.Eq(int32(resource2.TypeShort)),
@@ -254,7 +255,7 @@ type StatisticLong struct {
func StatisticResourceFree(c *fiber.Ctx) error {
// 检查权限
session, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -266,7 +267,7 @@ func StatisticResourceFree(c *fiber.Ctx) error {
q.Resource.Long,
).
Where(
q.Resource.UserID.Eq(session.Payload.Id),
q.Resource.UserID.Eq(authCtx.User.ID),
q.Resource.Active.Is(true),
).
Select(q.Resource.ID, q.Resource.Type).
@@ -347,7 +348,7 @@ type StatisticResourceUsageResp []struct {
func StatisticResourceUsage(c *fiber.Ctx) error {
// 检查权限
session, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -359,12 +360,12 @@ func StatisticResourceUsage(c *fiber.Ctx) error {
}
// 统计套餐提取数量
do := q.LogsUserUsage.Where(q.LogsUserUsage.UserID.Eq(session.Payload.Id))
do := q.LogsUserUsage.Where(q.LogsUserUsage.UserID.Eq(authCtx.User.ID))
if req.ResourceNo != nil && *req.ResourceNo != "" {
var resourceID int32
err := q.Resource.
Where(
q.Resource.UserID.Eq(session.Payload.Id),
q.Resource.UserID.Eq(authCtx.User.ID),
q.Resource.ResourceNo.Eq(*req.ResourceNo),
).
Select(q.Resource.ID).
@@ -409,7 +410,7 @@ type CreateResourceReq struct {
func CreateResource(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -421,7 +422,7 @@ func CreateResource(c *fiber.Ctx) error {
}
// 创建套餐
err = s.Resource.CreateResourceByBalance(authCtx.Payload.Id, time.Now(), req.CreateResourceData)
err = s.Resource.CreateResourceByBalance(authCtx.User.ID, time.Now(), req.CreateResourceData)
if err != nil {
return err
}
@@ -431,7 +432,7 @@ func CreateResource(c *fiber.Ctx) error {
func ResourcePrice(c *fiber.Ctx) error {
// 检查权限
_, err := auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
_, err := auth.GetAuthCtx(c).PermitSecretClient()
if err != nil {
return err
}

View File

@@ -7,7 +7,6 @@ import (
trade2 "platform/web/domains/trade"
g "platform/web/globals"
s "platform/web/services"
"platform/web/tasks"
"time"
"github.com/gofiber/fiber/v2"
@@ -27,7 +26,7 @@ type TradeCreateResp struct {
func TradeCreate(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -52,7 +51,7 @@ func TradeCreate(c *fiber.Ctx) error {
}
// 创建交易
result, err := s.Trade.CreateTrade(authCtx.Payload.Id, time.Now(), &req.CreateTradeData)
result, err := s.Trade.CreateTrade(authCtx.User.ID, time.Now(), &req.CreateTradeData)
if err != nil {
slog.Error("创建交易失败", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "创建交易失败"})
@@ -70,7 +69,7 @@ type TradeCompleteReq struct {
func TradeComplete(c *fiber.Ctx) error {
// 检查权限
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
_, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -99,7 +98,7 @@ type TradeCancelReq struct {
func TradeCancel(c *fiber.Ctx) error {
// 检查权限
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
_, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -119,29 +118,3 @@ func TradeCancel(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent)
}
type TradeCheckReq struct {
tasks.CancelTradeData
}
func TradeCancelByTask(c *fiber.Ctx) error {
// 检查权限
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadInternalServer}, []string{})
if err != nil {
return err
}
// 取消交易
req := new(TradeCheckReq)
if err := c.BodyParser(req); err != nil {
return err
}
// 检查订单状态
err = s.Trade.CancelTrade(req.TradeNo, req.Method, time.Now())
if err != nil {
return err
}
return nil
}

View File

@@ -1,12 +1,13 @@
package handlers
import (
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"platform/web/auth"
m "platform/web/models"
q "platform/web/queries"
s "platform/web/services"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
)
// region /update
@@ -20,7 +21,7 @@ type UpdateUserReq struct {
func UpdateUser(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -33,7 +34,7 @@ func UpdateUser(c *fiber.Ctx) error {
// 更新用户信息
_, err = q.User.
Where(q.User.ID.Eq(authCtx.Payload.Id)).
Where(q.User.ID.Eq(authCtx.User.ID)).
Updates(m.User{
Username: &req.Username,
Email: &req.Email,
@@ -59,7 +60,7 @@ type UpdateAccountReq struct {
func UpdateAccount(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -72,7 +73,7 @@ func UpdateAccount(c *fiber.Ctx) error {
// 更新用户信息
_, err = q.User.
Where(q.User.ID.Eq(authCtx.Payload.Id)).
Where(q.User.ID.Eq(authCtx.User.ID)).
Updates(m.User{
Username: &req.Username,
Password: &req.Password,
@@ -97,7 +98,7 @@ type UpdatePasswordReq struct {
func UpdatePassword(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -124,7 +125,7 @@ func UpdatePassword(c *fiber.Ctx) error {
}
_, err = q.User.
Where(q.User.ID.Eq(authCtx.Payload.Id)).
Where(q.User.ID.Eq(authCtx.User.ID)).
UpdateColumn(q.User.Password, newHash)
if err != nil {
return err

View File

@@ -2,12 +2,14 @@ package handlers
import (
"errors"
"platform/pkg/env"
"platform/web/auth"
"platform/web/services"
"regexp"
"strconv"
"github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9"
)
type VerifierReq struct {
@@ -17,7 +19,7 @@ type VerifierReq struct {
func SmsCode(c *fiber.Ctx) error {
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadInternalServer}, []string{})
_, err := auth.GetAuthCtx(c).PermitInternalClient()
if err != nil {
return err
}
@@ -48,3 +50,19 @@ func SmsCode(c *fiber.Ctx) error {
// 发送成功
return nil
}
func DebugGetSmsCode(c *fiber.Ctx) error {
if env.RunMode != env.RunModeDev {
return fiber.NewError(fiber.StatusForbidden, "not allowed")
}
code, err := services.Verifier.GetSms(c.Context(), c.Params("phone"))
if err != nil {
if errors.Is(err, redis.Nil) {
return c.SendString("还没有验证码")
}
return err
}
return c.SendString(code)
}

View File

@@ -26,7 +26,7 @@ type ListWhitelistResp struct {
func ListWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -38,8 +38,7 @@ func ListWhitelist(c *fiber.Ctx) error {
}
// 获取白名单信息
do := q.Whitelist.
Where(q.Whitelist.UserID.Eq(authContext.Payload.Id))
do := q.Whitelist.Where(q.Whitelist.UserID.Eq(authCtx.User.ID))
list, err := q.Whitelist.Where(do).
Offset(req.GetOffset()).
@@ -77,7 +76,7 @@ type CreateWhitelistReq struct {
func CreateWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -96,7 +95,7 @@ func CreateWhitelist(c *fiber.Ctx) error {
// 创建白名单
err = q.Whitelist.Create(&m.Whitelist{
UserID: authContext.Payload.Id,
UserID: authCtx.User.ID,
Host: req.Host,
Remark: &req.Remark,
})
@@ -111,7 +110,7 @@ type UpdateWhitelistReq struct {
func UpdateWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -129,7 +128,7 @@ func UpdateWhitelist(c *fiber.Ctx) error {
_, err = q.Whitelist.
Where(
q.Whitelist.ID.Eq(req.ID),
q.Whitelist.UserID.Eq(authContext.Payload.Id),
q.Whitelist.UserID.Eq(authCtx.User.ID),
).
Updates(&m.Whitelist{
ID: req.ID,
@@ -149,7 +148,7 @@ type RemoveWhitelistReq struct {
func RemoveWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -175,7 +174,7 @@ func RemoveWhitelist(c *fiber.Ctx) error {
_, err = q.Whitelist.
Where(
q.Whitelist.ID.In(ids...),
q.Whitelist.UserID.Eq(authContext.Payload.Id),
q.Whitelist.UserID.Eq(authCtx.User.ID),
).
Update(
q.Whitelist.DeletedAt, time.Now(),