重构代码结构与认证体系,集成异步任务消费者

This commit is contained in:
2025-11-17 18:38:10 +08:00
parent a97c970166
commit a245229bc2
70 changed files with 2000 additions and 2334 deletions

View File

@@ -4,112 +4,101 @@ import (
"context"
"encoding/base64"
"errors"
"fmt"
"log/slog"
"platform/web/core"
client2 "platform/web/domains/client"
m "platform/web/models"
q "platform/web/queries"
"slices"
s "platform/web/services"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
)
type ProtectBuilder struct {
c *fiber.Ctx
types []PayloadType
scopes []string
func Authenticate() fiber.Handler {
return func(ctx *fiber.Ctx) error {
header := ctx.Get(fiber.HeaderAuthorization)
authCtx, err := authHeader(ctx.Context(), header)
if err != nil {
return err
}
if authCtx == nil {
authCtx = &AuthCtx{}
}
SetAuthCtx(ctx, authCtx)
return ctx.Next()
}
}
func NewProtect(c *fiber.Ctx) *ProtectBuilder {
return &ProtectBuilder{c, []PayloadType{}, []string{}}
}
func authHeader(ctx context.Context, header string) (*AuthCtx, error) {
if header == "" {
return nil, nil
}
func (p *ProtectBuilder) Payload(types ...PayloadType) *ProtectBuilder {
p.types = types
return p
}
func (p *ProtectBuilder) Scopes(scopes ...string) *ProtectBuilder {
p.scopes = scopes
return p
}
func (p *ProtectBuilder) Do() (*Context, error) {
return Protect(p.c, p.types, p.scopes)
}
func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context, error) {
// 获取令牌
var header = c.Get("Authorization")
var split = strings.Split(header, " ")
if len(split) != 2 {
slog.Debug("Authorization 头格式不正确")
return nil, ErrUnauthorize
return nil, ErrAuthenticateUnauthorize
}
var token = strings.TrimSpace(split[1])
if token == "" {
slog.Debug("提供的令牌为空")
return nil, ErrUnauthorize
return nil, ErrAuthenticateUnauthorize
}
var auth *Context
var authCtx *AuthCtx
var err error
switch split[0] {
case "Bearer":
auth, err = authBearer(c.Context(), token)
authCtx, err = authBearer(ctx, token)
if err != nil {
slog.Debug("Bearer 认证失败", "err", err)
return nil, ErrUnauthorize
return nil, ErrAuthenticateUnauthorize
}
case "Basic":
if !slices.Contains(types, PayloadInternalServer) {
slog.Debug("禁止使用 Basic 认证方式")
return nil, ErrUnauthorize
}
auth, err = authBasic(c.Context(), token)
authCtx, err = authBasic(ctx, token)
if err != nil {
slog.Debug("Basic 认证失败", "err", err)
return nil, ErrUnauthorize
return nil, ErrAuthenticateUnauthorize
}
default:
slog.Debug("无效的认证方式", "method", split[0])
return nil, ErrUnauthorize
return nil, ErrAuthenticateUnauthorize
}
// 检查权限
if !slices.Contains(types, auth.Payload.Type) {
slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type)
return nil, ErrForbidden
}
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions)
return nil, ErrForbidden
}
// 保存到上下文
Locals(c, auth)
return auth, nil
return authCtx, err
}
func Locals(c *fiber.Ctx, auth *Context) {
c.Locals("auth", auth)
}
func authBearer(ctx context.Context, token string) (*Context, error) {
auth, err := FindSession(ctx, token)
func authBearer(_ context.Context, token string) (*AuthCtx, error) {
session, err := FindSession(token, time.Now())
if err != nil {
slog.Debug(err.Error())
return nil, err
slog.Debug("Bearer 认证失败", "err", err)
return nil, ErrAuthenticateUnauthorize
}
return auth, nil
scopes := []string{}
if session.Scopes_ != nil {
scopes = strings.Split(*session.Scopes_, " ")
}
return &AuthCtx{
User: session.User,
Admin: session.Admin,
Client: session.Client,
Scopes: scopes,
Session: session,
}, nil
}
func authBasic(_ context.Context, token string) (*Context, error) {
func authBasic(_ context.Context, token string) (*AuthCtx, error) {
// 解析 Basic 认证信息
var base, err = base64.RawURLEncoding.DecodeString(token)
@@ -125,14 +114,23 @@ func authBasic(_ context.Context, token string) (*Context, error) {
return nil, errors.New("令牌格式错误,必须是 <client_id>:<client_secret> 格式")
}
var clientID = split[0]
client, err := authClient(split[0], split[1])
if err != nil {
return nil, fmt.Errorf("客户端认证失败:%w", err)
}
return &AuthCtx{
Client: client,
Scopes: []string{},
}, nil
}
func authClient(clientId, clientSecret string) (*m.Client, error) {
// 获取客户端信息
client, err := q.Client.
Where(
q.Client.ClientID.Eq(clientID),
q.Client.Spec.In(int32(client2.SpecWeb), int32(client2.SpecTrusted)),
q.Client.GrantClient.Is(true),
q.Client.ClientID.Eq(clientId),
q.Client.Status.Eq(1)).
Take()
if err != nil {
@@ -140,33 +138,57 @@ func authBasic(_ context.Context, token string) (*Context, error) {
}
// 检查客户端密钥
var clientSecret = split[1]
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
return nil, errors.New("客户端密钥错误")
spec := client2.Spec(client.Spec)
if spec == client2.SpecWeb || spec == client2.SpecApi {
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
return nil, errors.New("客户端密钥错误")
}
}
// todo 查询客户端关联权限
// 组织授权信息(一次性请求)
return &Context{
Payload: Payload{
Id: client.ID,
Type: PayloadTypeFromClientSpec(client2.Spec(client.Spec)),
Name: client.Name,
Avatar: client.Icon,
},
Permissions: nil,
Metadata: nil,
}, nil
return client, nil
}
type AuthenticationErr string
func authUserBySms(tx *q.Query, username, code string) (*m.User, error) {
// 验证验证码
err := s.Verifier.VerifySms(context.Background(), username, code)
if err != nil {
if errors.Is(err, s.ErrVerifierServiceInvalid) {
return nil, ErrAuthorizeInvalidRequest
}
return nil, err
}
func (e AuthenticationErr) Error() string {
return string(e)
// 查找用户
return tx.User.Where(tx.User.Phone.Eq(username)).Take()
}
var (
ErrUnauthorize = AuthenticationErr("令牌无效")
ErrForbidden = AuthenticationErr("没有权限")
)
func authUserByEmail(tx *q.Query, username, code string) (*m.User, error) {
return nil, core.NewServErr("邮箱登录不可用")
}
func authUserByPassword(tx *q.Query, username, password string) (*m.User, error) {
user, err := tx.User.
Where(tx.User.Phone.Eq(username)).
Or(tx.User.Email.Eq(username)).
Or(tx.User.Username.Eq(username)).
Take()
if err != nil {
slog.Debug("查找用户失败", "error", err)
return nil, core.NewBizErr("用户不存在或密码错误")
}
// 验证密码
if user.Password == nil || *user.Password == "" {
slog.Debug("用户未设置密码", "username", username)
return nil, core.NewBizErr("用户不存在或密码错误")
}
if bcrypt.CompareHashAndPassword([]byte(*user.Password), []byte(password)) != nil {
slog.Debug("密码验证失败", "username", username)
return nil, core.NewBizErr("用户不存在或密码错误")
}
return user, nil
}

View File

@@ -1,5 +1,28 @@
package auth
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"log/slog"
"platform/pkg/env"
"platform/pkg/u"
"platform/web/core"
user2 "platform/web/domains/user"
g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models"
q "platform/web/queries"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"gorm.io/gorm"
)
type GrantType string
const (
@@ -17,36 +40,352 @@ const (
GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱验证码
)
func Token(grant GrantType) error {
return nil
type TokenReq struct {
GrantType GrantType `json:"grant_type" form:"grant_type"`
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
Scope string `json:"scope" form:"scope"`
GrantCodeData
GrantClientData
GrantRefreshData
GrantPasswordData
}
func authAuthorizationCode() {
type GrantCodeData struct {
Code string `json:"code" form:"code"`
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
}
func authClientCredential() {
type GrantClientData struct {
}
func authRefreshToken() {
type GrantRefreshData struct {
RefreshToken string `json:"refresh_token" form:"refresh_token"`
}
func authPassword() {
type GrantPasswordData struct {
LoginType PasswordGrantType `json:"login_type" form:"login_type"`
Username string `json:"username" form:"username"`
Password string `json:"password" form:"password"`
Remember bool `json:"remember" form:"remember"`
}
func authPasswordSecret() {
type TokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
}
func authPasswordPhone() {
type TokenErrResp struct {
Error string `json:"error"`
Description string `json:"error_description,omitempty"`
}
func authPasswordEmail() {
func Token(c *fiber.Ctx) error {
now := time.Now()
// 验证请求参数
req := new(TokenReq)
if err := c.BodyParser(req); err != nil {
return sendError(c, ErrAuthorizeInvalidRequest, "无法解析请求参数")
}
if req.GrantType == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: grant_type")
}
switch req.GrantType {
// 授权码模式
case GrantAuthorizationCode:
if req.Code == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: code")
}
// 刷新令牌模式
case GrantRefreshToken:
if req.RefreshToken == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: refresh_token")
}
// 密码模式
case GrantPassword:
if req.LoginType == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password_type")
}
if req.Username == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: username")
}
if req.Password == "" {
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password")
}
}
// 验证客户端身份
authCtx := GetAuthCtx(c)
if authCtx == nil {
authCtx = &AuthCtx{}
}
if authCtx.Client == nil {
client, err := authClient(req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
authCtx.Client = client
}
// 处理授权
var session *m.Session
var err error
switch req.GrantType {
// 授权码模式
case GrantAuthorizationCode:
session, err = authAuthorizationCode(c, authCtx, req, now)
// 客户端凭证模式
case GrantClientCredentials:
session, err = authClientCredential(c, authCtx, req, now)
// 刷新令牌模式
case GrantRefreshToken:
session, err = authRefreshToken(c, authCtx, req, now)
// 密码模式
case GrantPassword:
session, err = authPassword(c, authCtx, req, now)
default:
return sendError(c, ErrAuthorizeUnsupportedGrantType)
}
if err != nil {
return sendError(c, err)
}
// 返回响应
return c.JSON(&TokenResp{
TokenType: "Bearer",
AccessToken: session.AccessToken,
RefreshToken: u.Z(session.RefreshToken),
ExpiresIn: int(time.Time(session.AccessTokenExpires).Sub(now).Seconds()),
Scope: u.Z(session.Scopes_),
})
}
func authAuthorizationCode(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
// 检查 code 获取用户授权信息
data, err := g.Redis.Get(context.Background(), req.Code).Result()
if err != nil {
return nil, err
}
var codeCtx CodeContext
if err := json.Unmarshal([]byte(data), &codeCtx); err != nil {
return nil, err
}
// 检查 PKCE
if codeCtx.CodeChallengeMethod != "" {
if req.CodeVerifier == "" {
return nil, ErrAuthorizeInvalidPKCE
}
switch codeCtx.CodeChallengeMethod {
case "plain":
if req.CodeVerifier != codeCtx.CodeChallenge {
return nil, ErrAuthorizeInvalidPKCE
}
case "S256":
hash := sha256.Sum256([]byte(req.CodeVerifier))
verifier := base64.RawURLEncoding.EncodeToString(hash[:])
if verifier != codeCtx.CodeChallenge {
return nil, ErrAuthorizeInvalidPKCE
}
default:
return nil, ErrAuthorizeInvalidPKCE
}
}
user, err := q.User.Where(
q.User.ID.Eq(codeCtx.UserID),
q.User.Status.Eq(int32(user2.StatusEnabled)),
).First()
if err != nil {
return nil, err
}
// todo 检查 scope
// 生成会话
session := &m.Session{
IP: u.X(ctx.IP()),
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
UserID: &user.ID,
ClientID: &auth.Client.ID,
Scopes_: u.P(strings.Join(codeCtx.Scopes, " ")),
AccessToken: uuid.NewString(),
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
}
if codeCtx.Remember {
session.RefreshToken = u.P(uuid.NewString())
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
}
err = SaveSession(session)
if err != nil {
return nil, err
}
return session, nil
}
func authClientCredential(ctx *fiber.Ctx, auth *AuthCtx, _ *TokenReq, now time.Time) (*m.Session, error) {
// todo 检查 scope
// 生成会话
session := &m.Session{
IP: u.X(ctx.IP()),
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
ClientID: &auth.Client.ID,
AccessToken: uuid.NewString(),
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
}
// 保存会话
err := SaveSession(session)
if err != nil {
return nil, err
}
return session, nil
}
func authPassword(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
var user *m.User
err := q.Q.Transaction(func(tx *q.Query) (err error) {
switch req.LoginType {
case GrantPasswordPhone:
user, err = authUserBySms(tx, req.Username, req.Password)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
if user == nil {
user = &m.User{
Phone: req.Username,
Username: u.P(req.Username),
Status: int32(user2.StatusEnabled),
}
}
case GrantPasswordEmail:
user, err = authUserByEmail(tx, req.Username, req.Password)
if err != nil {
return err
}
case GrantPasswordSecret:
user, err = authUserByPassword(tx, req.Username, req.Password)
if err != nil {
return err
}
default:
return ErrAuthorizeInvalidRequest
}
// 账户状态
if user2.Status(user.Status) == user2.StatusDisabled {
slog.Debug("账户状态异常", "username", req.Username, "status", user.Status)
return core.NewBizErr("账号无法登录")
}
// 更新用户的登录时间
user.LastLogin = u.P(orm.LocalDateTime(time.Now()))
user.LastLoginHost = u.X(ctx.IP())
user.LastLoginAgent = u.X(ctx.Get(fiber.HeaderUserAgent))
if err := tx.User.Save(user); err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// 生成会话
session := &m.Session{
IP: u.X(ctx.IP()),
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
UserID: &user.ID,
ClientID: &auth.Client.ID,
Scopes_: u.X(req.Scope),
AccessToken: uuid.NewString(),
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
}
if req.Remember {
session.RefreshToken = u.P(uuid.NewString())
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
}
err = SaveSession(session)
if err != nil {
return nil, err
}
return session, nil
}
func authRefreshToken(_ *fiber.Ctx, _ *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
session, err := FindSessionByRefresh(req.RefreshToken, now)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAuthorizeInvalidGrant
}
return nil, err
}
// todo 检查权限
// 生成令牌
session.AccessToken = uuid.NewString()
session.AccessTokenExpires = orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second))
if session.RefreshToken != nil {
session.RefreshToken = u.P(uuid.NewString())
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
}
// 保存令牌
err = SaveSession(session)
if err != nil {
return nil, err
}
return session, nil
}
func sendError(c *fiber.Ctx, err error, description ...string) error {
var sErr AuthErr
if errors.As(err, &sErr) {
status := fiber.StatusBadRequest
var desc string
switch {
case errors.Is(sErr, ErrAuthorizeInvalidRequest):
desc = "无效的请求"
case errors.Is(sErr, ErrAuthorizeInvalidClient):
status = fiber.StatusUnauthorized
desc = "无效的客户端凭证"
case errors.Is(sErr, ErrAuthorizeInvalidGrant):
desc = "无效的授权凭证"
case errors.Is(sErr, ErrAuthorizeInvalidScope):
desc = "无效的授权范围"
case errors.Is(sErr, ErrAuthorizeUnauthorizedClient):
desc = "未授权的客户端"
case errors.Is(sErr, ErrAuthorizeUnsupportedGrantType):
desc = "不支持的授权类型"
}
if len(description) > 0 {
desc = description[0]
}
return c.Status(status).JSON(TokenErrResp{
Error: string(sErr),
Description: desc,
})
}
return err
}
func Revoke() error {
@@ -56,3 +395,12 @@ func Revoke() error {
func Introspect() error {
return nil
}
type CodeContext struct {
UserID int32 `json:"user_id"`
ClientID int32 `json:"client_id"`
Scopes []string `json:"scopes"`
Remember bool `json:"remember"`
CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
}

99
web/auth/check.go Normal file
View File

@@ -0,0 +1,99 @@
package auth
import (
"platform/web/domains/client"
m "platform/web/models"
"github.com/gofiber/fiber/v2"
)
type AuthCtx struct {
User *m.User `json:"account,omitempty"`
Admin *m.Admin `json:"admin,omitempty"`
Client *m.Client `json:"client,omitempty"`
Scopes []string `json:"scopes,omitempty"`
Session *m.Session `json:"session,omitempty"`
smap map[string]struct{}
}
func (a *AuthCtx) PermitUser(scopes ...string) (*AuthCtx, error) {
if a.User == nil {
return a, ErrAuthenticateForbidden
}
if !a.checkScopes(scopes...) {
return a, ErrAuthenticateForbidden
}
return a, nil
}
func (a *AuthCtx) PermitAdmin(scopes ...string) (*AuthCtx, error) {
if a.Admin == nil {
return a, ErrAuthenticateForbidden
}
if !a.checkScopes(scopes...) {
return a, ErrAuthenticateForbidden
}
return a, nil
}
func (a *AuthCtx) PermitSecretClient(scopes ...string) (*AuthCtx, error) {
if a.Client == nil {
return a, ErrAuthenticateForbidden
}
spec := client.Spec(a.Client.Spec)
if spec != client.SpecApi && spec != client.SpecWeb {
return a, ErrAuthenticateForbidden
}
if !a.checkScopes(scopes...) {
return a, ErrAuthenticateForbidden
}
return a, nil
}
func (a *AuthCtx) PermitInternalClient(scopes ...string) (*AuthCtx, error) {
if a.Client == nil {
return a, ErrAuthenticateForbidden
}
spec := client.Spec(a.Client.Spec)
if spec != client.SpecApi && spec != client.SpecWeb {
return a, ErrAuthenticateForbidden
}
cType := client.Type(a.Client.Type)
if cType != client.TypeInternal {
return a, ErrAuthenticateForbidden
}
if !a.checkScopes(scopes...) {
return a, ErrAuthenticateForbidden
}
return a, nil
}
func (a *AuthCtx) checkScopes(scopes ...string) bool {
if len(scopes) == 0 || len(a.Scopes) == 0 {
return true
}
if len(a.smap) == 0 && len(a.Scopes) > 0 {
for _, scope := range scopes {
a.smap[scope] = struct{}{}
}
}
for _, scope := range scopes {
if _, ok := a.smap[scope]; ok {
return true
}
}
return false
}
const AuthCtxKey = "session"
func SetAuthCtx(c *fiber.Ctx, auth *AuthCtx) {
c.Locals(AuthCtxKey, auth)
}
func GetAuthCtx(c *fiber.Ctx) *AuthCtx {
if authCtx, ok := c.Locals(AuthCtxKey).(*AuthCtx); ok {
return authCtx
}
return nil
}

View File

@@ -1,103 +0,0 @@
package auth
import (
client2 "platform/web/domains/client"
)
// Context 定义认证信息
type Context struct {
Payload Payload `json:"payload"`
Permissions map[string]struct{} `json:"permissions,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
func (a *Context) AnyType(types ...PayloadType) bool {
if a == nil {
return false
}
for _, t := range types {
if a.Payload.Type == t {
return true
}
}
return false
}
// AnyPermission 检查认证是否包含指定权限
func (a *Context) AnyPermission(requiredPermission ...string) bool {
if a == nil || a.Permissions == nil {
return false
}
for _, permission := range requiredPermission {
if _, ok := a.Permissions[permission]; ok {
return true
}
}
return false
}
// Payload 定义负载信息
type Payload struct {
Id int32 `json:"id,omitempty"`
Type PayloadType `json:"type,omitempty"`
Name string `json:"name,omitempty"`
Avatar *string `json:"avatar,omitempty"`
}
type PayloadType int
const (
PayloadNone PayloadType = iota // 游客
PayloadUser // 用户
PayloadAdmin // 管理员
PayloadPublicServer // 公共服务public_client
PayloadSecuredServer // 安全服务credential_client
PayloadInternalServer // 内部服务
)
func (t PayloadType) ToStr() string {
switch t {
case PayloadUser:
return "user"
case PayloadAdmin:
return "admn"
case PayloadPublicServer:
return "cpub"
case PayloadSecuredServer:
return "ccnf"
case PayloadInternalServer:
return "inte"
default:
return "none"
}
}
func PayloadTypeFromStr(name string) PayloadType {
switch name {
case "user":
return PayloadUser
case "admn":
return PayloadAdmin
case "cpub":
return PayloadPublicServer
case "ccnf":
return PayloadSecuredServer
case "inte":
return PayloadInternalServer
default:
return PayloadNone
}
}
func PayloadTypeFromClientSpec(spec client2.Spec) PayloadType {
var clientType PayloadType
switch spec {
case client2.SpecNative, client2.SpecBrowser:
clientType = PayloadPublicServer
case client2.SpecWeb:
clientType = PayloadSecuredServer
case client2.SpecTrusted:
clientType = PayloadInternalServer
}
return clientType
}

24
web/auth/errors.go Normal file
View File

@@ -0,0 +1,24 @@
package auth
type AuthErr string
func (e AuthErr) Error() string {
return string(e)
}
// 认证错误
const (
ErrAuthenticateUnauthorize = AuthErr("令牌无效")
ErrAuthenticateForbidden = AuthErr("没有权限")
)
// 授权错误
const (
ErrAuthorizeInvalidRequest = AuthErr("invalid_request")
ErrAuthorizeInvalidClient = AuthErr("invalid_client")
ErrAuthorizeInvalidGrant = AuthErr("invalid_grant")
ErrAuthorizeInvalidScope = AuthErr("invalid_scope")
ErrAuthorizeUnauthorizedClient = AuthErr("unauthorized_client")
ErrAuthorizeUnsupportedGrantType = AuthErr("unsupported_grant_type")
ErrAuthorizeInvalidPKCE = AuthErr("invalid_pkce")
)

View File

@@ -2,160 +2,36 @@ package auth
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/redis/go-redis/v9"
"platform/pkg/env"
g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models"
q "platform/web/queries"
"time"
"gorm.io/gen/field"
)
type Session struct {
// 认证主体
Payload *Payload
// 令牌信息
TokenDetails *TokenDetails
func FindSession(accessToken string, now time.Time) (*m.Session, error) {
return q.Session.
Preload(field.Associations).
Where(
q.Session.AccessToken.Eq(accessToken),
q.Session.AccessTokenExpires.Gt(orm.LocalDateTime(now)),
).First()
}
func FindSession(ctx context.Context, token string) (*Context, error) {
// 读取认证数据
authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, errors.New("invalid_token")
}
return nil, err
}
// 反序列化
auth := new(Context)
if err := json.Unmarshal([]byte(authJSON), auth); err != nil {
return nil, err
}
return auth, nil
func FindSessionByRefresh(refreshToken string, now time.Time) (*m.Session, error) {
return q.Session.
Preload(field.Associations).
Where(
q.Session.RefreshToken.Eq(refreshToken),
q.Session.RefreshTokenExpires.Gt(orm.LocalDateTime(now)),
).First()
}
func CreateSession(ctx context.Context, authCtx *Context, remember bool) (*TokenDetails, error) {
var now = time.Now()
// 生成令牌组
accessToken := genToken()
refreshToken := genToken()
// 序列化认证数据
authData, err := json.Marshal(authCtx)
if err != nil {
return nil, err
}
// 序列化刷新令牌数据
refreshData, err := json.Marshal(RefreshData{
AuthContext: authCtx,
AccessToken: accessToken,
})
if err != nil {
return nil, err
}
// 事务保存数据到 Redis
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
pipe := g.Redis.TxPipeline()
pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
if remember {
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
}
_, err = pipe.Exec(ctx)
if err != nil {
return nil, err
}
return &TokenDetails{
AccessToken: accessToken,
AccessTokenExpires: now.Add(accessExpire),
RefreshToken: refreshToken,
RefreshTokenExpires: now.Add(refreshExpire),
Auth: authCtx,
}, nil
}
func RefreshSession(ctx context.Context, refreshToken string, renew bool) (*TokenDetails, error) {
var now = time.Now()
rKey := refreshKey(refreshToken)
var tokenDetails *TokenDetails
// 刷新令牌
err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
// 先获取刷新令牌数据
refreshJson, err := tx.Get(ctx, rKey).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return ErrInvalidRefreshToken
}
return err
}
// 解析刷新令牌数据
refreshData := new(RefreshData)
if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil {
return err
}
// 生成新的令牌
newAccessToken := genToken()
newRefreshToken := genToken()
authData, err := json.Marshal(refreshData.AuthContext)
if err != nil {
return err
}
newRefreshData, err := json.Marshal(RefreshData{
AuthContext: refreshData.AuthContext,
AccessToken: newAccessToken,
})
if err != nil {
return err
}
pipeline := tx.Pipeline()
// 保存新的令牌
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
pipeline.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
// 删除旧的令牌
pipeline.Del(ctx, accessKey(refreshData.AccessToken))
pipeline.Del(ctx, refreshKey(refreshToken))
_, err = pipeline.Exec(ctx)
if err != nil {
return err
}
tokenDetails = &TokenDetails{
AccessToken: newAccessToken,
RefreshToken: newRefreshToken,
AccessTokenExpires: now.Add(accessExpire),
RefreshTokenExpires: now.Add(refreshExpire),
Auth: refreshData.AuthContext,
}
return nil
}, rKey)
if err != nil {
return nil, fmt.Errorf("刷新令牌失败: %w", err)
}
return tokenDetails, nil
func SaveSession(session *m.Session) error {
return q.Session.Save(session)
}
func RemoveSession(ctx context.Context, accessToken string, refreshToken string) error {
@@ -163,11 +39,6 @@ func RemoveSession(ctx context.Context, accessToken string, refreshToken string)
return nil
}
// 生成一个新的令牌
func genToken() string {
return uuid.NewString()
}
// 令牌键的格式为 "session:<token>"
func accessKey(token string) string {
return fmt.Sprintf("session:%s", token)
@@ -177,32 +48,3 @@ func accessKey(token string) string {
func refreshKey(token string) string {
return fmt.Sprintf("session:refresh:%s", token)
}
// TokenDetails 存储令牌详细信息
type TokenDetails struct {
// 访问令牌
AccessToken string
// 刷新令牌
RefreshToken string
// 访问令牌过期时间
AccessTokenExpires time.Time
// 刷新令牌过期时间
RefreshTokenExpires time.Time
// 认证信息
Auth *Context
}
type RefreshData struct {
AuthContext *Context
AccessToken string
}
type SessionErr string
func (e SessionErr) Error() string {
return string(e)
}
const (
ErrInvalidRefreshToken = SessionErr("无效的刷新令牌")
)

View File

@@ -60,7 +60,7 @@ type Err struct {
func (e *Err) Error() string {
if e.err != nil {
return e.msg + "" + e.err.Error()
return e.msg + ": " + e.err.Error()
}
return e.msg
}

View File

@@ -6,5 +6,12 @@ const (
SpecNative Spec = iota + 1 // 原生客户端
SpecBrowser // 浏览器客户端
SpecWeb // Web 服务
SpecTrusted // 可信服务
SpecApi // Api 服务
)
type Type int32
const (
TypeNormal Type = iota // 普通客户端
TypeInternal // 内部客户端
)

View File

@@ -16,7 +16,7 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
var message = "服务器异常"
var fiberErr *fiber.Error
var authErr auth.AuthenticationErr
var authErr auth.AuthErr
var bizErr *core.BizErr
var servErr *core.ServErr
@@ -30,9 +30,9 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
// 认证授权错误
case errors.As(err, &authErr):
switch {
case errors.Is(err, auth.ErrUnauthorize):
case errors.Is(err, auth.ErrAuthenticateUnauthorize):
code = fiber.StatusUnauthorized
case errors.Is(err, auth.ErrForbidden):
case errors.Is(err, auth.ErrAuthenticateForbidden):
code = fiber.StatusForbidden
default:
code = fiber.StatusBadRequest

View File

@@ -1,4 +1,4 @@
package tasks
package events
import (
"encoding/json"

View File

@@ -1,4 +1,4 @@
package tasks
package events
import (
"encoding/json"

View File

@@ -1,10 +1,11 @@
package tasks
package events
import (
"encoding/json"
"github.com/hibiken/asynq"
"log/slog"
trade2 "platform/web/domains/trade"
"github.com/hibiken/asynq"
)
const CancelTrade = "trade:update"

View File

@@ -1,6 +1,7 @@
package globals
import (
"fmt"
"platform/pkg/env"
"github.com/smartwalle/alipay/v3"
@@ -8,25 +9,26 @@ import (
var Alipay *alipay.Client
func initAlipay() {
func initAlipay() error {
var client, err = alipay.New(
env.AlipayAppId,
env.AlipayAppPrivateKey,
env.AlipayProduction,
)
if err != nil {
panic("初始化支付宝客户端失败: " + err.Error())
return fmt.Errorf("初始化支付宝客户端失败: %w", err)
}
err = client.LoadAliPayPublicKey(env.AlipayPublicKey)
if err != nil {
panic("加载支付宝公钥失败: " + err.Error())
return fmt.Errorf("加载支付宝公钥失败: %w", err)
}
err = client.SetEncryptKey(env.AlipayApiCert)
if err != nil {
panic("设置支付宝加密密钥失败: " + err.Error())
return fmt.Errorf("设置支付宝加密证书失败: %w", err)
}
Alipay = client
return nil
}

View File

@@ -1,6 +1,7 @@
package globals
import (
"fmt"
"platform/pkg/env"
"platform/pkg/u"
@@ -14,17 +15,18 @@ type aliyunClient struct {
Sms *sms.Client
}
func initAliyun() {
func initAliyun() error {
client, err := sms.NewClient(&openapi.Config{
AccessKeyId: &env.AliyunAccessKey,
AccessKeySecret: &env.AliyunAccessKeySecret,
Endpoint: u.P("dysmsapi.aliyuncs.com"),
})
if err != nil {
panic(err)
return fmt.Errorf("初始化阿里云客户端失败: %w", err)
}
Aliyun = &aliyunClient{
Sms: client,
}
return nil
}

View File

@@ -35,10 +35,11 @@ type cloud struct {
var Cloud CloudClient
func initBaiyin() {
func initBaiyin() error {
Cloud = &cloud{
url: env.BaiyinAddr,
}
return nil
}
type AutoConfig struct {

View File

@@ -1,14 +1,38 @@
package globals
func Init() {
initBaiyin()
initAlipay()
initWechatPay()
initAliyun()
initValidator()
initRedis()
initOrm()
initProxy()
initAsynq()
initSft()
import (
"context"
"platform/pkg/u"
)
func Init(ctx context.Context) error {
errs := make([]error, 0)
errs = append(errs, initBaiyin())
errs = append(errs, initAlipay())
errs = append(errs, initWechatPay())
errs = append(errs, initAliyun())
errs = append(errs, initValidator())
errs = append(errs, initRedis())
errs = append(errs, initOrm())
errs = append(errs, initProxy())
errs = append(errs, initSft())
return u.CombineErrors(errs)
}
func Stop() error {
var errs = make([]error, 0)
err := stopRedis()
if err != nil {
errs = append(errs, err)
}
err = stopOrm()
if err != nil {
errs = append(errs, err)
}
return u.CombineErrors(errs)
}

View File

@@ -1,17 +1,20 @@
package globals
import (
"database/sql"
"fmt"
"platform/pkg/env"
"platform/web/queries"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"log/slog"
"platform/pkg/env"
)
var DB *gorm.DB
var Conn *sql.DB
func initOrm() {
func initOrm() error {
// 连接数据库
dsn := fmt.Sprintf(
@@ -25,27 +28,29 @@ func initOrm() {
},
})
if err != nil {
slog.Error("gorm 初始化数据库失败:", slog.Any("err", err))
panic(err)
return fmt.Errorf("连接数据库失败: %w", err)
}
// 连接池
conn, err := db.DB()
if err != nil {
slog.Error("gorm 初始化数据库失败:", slog.Any("err", err))
panic(err)
return fmt.Errorf("配置连接池失败: %w", err)
}
conn.SetMaxIdleConns(10)
conn.SetMaxOpenConns(100)
queries.SetDefault(db)
DB = db
Conn = conn
return nil
}
func ExitOrm() error {
func stopOrm() error {
if DB != nil {
conn, err := DB.DB()
if err != nil {
return err
return fmt.Errorf("关闭数据库连接失败: %w", err)
}
return conn.Close()
}

View File

@@ -23,8 +23,9 @@ var Proxy *ProxyClient
type ProxyClient struct {
}
func initProxy() {
func initProxy() error {
Proxy = &ProxyClient{}
return nil
}
type ProxyPermitConfig struct {

View File

@@ -1,12 +1,13 @@
package globals
import (
"github.com/go-redsync/redsync/v4/redis/goredis/v9"
"log/slog"
"net"
"platform/pkg/env"
"platform/web/core"
"github.com/go-redsync/redsync/v4/redis/goredis/v9"
"github.com/go-redsync/redsync/v4"
"github.com/redis/go-redis/v9"
)
@@ -18,11 +19,10 @@ type ExtendRedSync struct {
*redsync.Redsync
}
func initRedis() {
func initRedis() error {
client := redis.NewClient(&redis.Options{
Addr: net.JoinHostPort(env.RedisHost, env.RedisPort),
DB: env.RedisDb,
Password: env.RedisPass,
Password: env.RedisPassword,
})
pool := goredis.NewPool(client)
@@ -30,9 +30,11 @@ func initRedis() {
Redis = client
Redsync = &ExtendRedSync{sync}
return nil
}
func ExitRedis() error {
func stopRedis() error {
if Redis != nil {
return Redis.Close()
}

View File

@@ -28,9 +28,9 @@ type SftClient struct {
publicKey *rsa.PublicKey
}
func initSft() {
func initSft() error {
if !env.SftPayEnable {
panic("商福通支付未启用,请检查环境变量 SFTPAY_ENABLE")
return fmt.Errorf("商福通支付未启用,请检查环境变量 SFTPAY_ENABLE")
}
SFTPay = SftClient{
@@ -41,7 +41,7 @@ func initSft() {
// 加载私钥
private, err := base64.StdEncoding.DecodeString(env.SftPayAppPrivateKey)
if err != nil {
panic("解析商福通私钥失败: " + err.Error())
return fmt.Errorf("解析商福通私钥失败: %w", err)
}
var privateKey *rsa.PrivateKey
@@ -49,13 +49,13 @@ func initSft() {
if err != nil {
pkcs8, err := x509.ParsePKCS8PrivateKey(private)
if err != nil {
panic("解析商福通私钥失败: " + err.Error())
return fmt.Errorf("解析商福通私钥失败: %w", err)
}
var ok bool
privateKey, ok = pkcs8.(*rsa.PrivateKey)
if !ok {
panic("解析商福通私钥失败")
return fmt.Errorf("解析商福通私钥失败")
}
}
SFTPay.privateKey = privateKey
@@ -63,35 +63,36 @@ func initSft() {
// 加载公钥
public, err := base64.StdEncoding.DecodeString(env.SftPayPublicKey)
if err != nil {
panic("解析商福通公钥失败: " + err.Error())
return fmt.Errorf("解析商福通公钥失败: %w", err)
}
var publicKey *rsa.PublicKey
pkix, err := x509.ParsePKIXPublicKey(public)
if err != nil {
panic("解析商福通公钥失败: " + err.Error())
return fmt.Errorf("解析商福通公钥失败: %w", err)
}
var ok bool
publicKey, ok = pkix.(*rsa.PublicKey)
if !ok {
panic("解析商福通公钥失败")
return fmt.Errorf("解析商福通公钥失败")
}
SFTPay.publicKey = publicKey
return nil
}
func (s *SftClient) PaymentScanPay(req *PaymentScanPayReq) (*PaymentScanPayResp, error) {
const url = "https://pay.rscygroup.com/api/open/payment/scanpay"
req.ReturnUrl = env.SftReturnUrl
req.NotifyUrl = env.SftNotifyUrl
req.ReturnUrl = u.X(env.SftReturnUrl)
req.NotifyUrl = u.X(env.SftNotifyUrl)
req.RouteNo = u.P(s.routeId)
return call[PaymentScanPayResp](s, url, req)
}
func (s *SftClient) PaymentH5Pay(req *PaymentH5PayReq) (*PaymentH5PayResp, error) {
const url = "https://pay.rscygroup.com/api/open/payment/h5pay"
req.ReturnUrl = env.SftReturnUrl
req.NotifyUrl = env.SftNotifyUrl
req.ReturnUrl = u.X(env.SftReturnUrl)
req.NotifyUrl = u.X(env.SftNotifyUrl)
req.RouteNo = u.P(s.routeId)
return call[PaymentH5PayResp](s, url, req)
}
@@ -256,7 +257,7 @@ func call[T any](s *SftClient, url string, req any) (*T, error) {
encode, err := s.sign(req)
if err != nil {
return nil, fmt.Errorf("加密请求内容失败%w", err)
return nil, fmt.Errorf("加密请求内容失败: %w", err)
}
bytes, err := json.Marshal(encode)
@@ -266,33 +267,33 @@ func call[T any](s *SftClient, url string, req any) (*T, error) {
request, err := http.NewRequest("POST", url, strings.NewReader(string(bytes)))
if err != nil {
return nil, fmt.Errorf("创建请求失败%w", err)
return nil, fmt.Errorf("创建请求失败: %w", err)
}
request.Header.Set("Content-Type", "application/json")
if env.DebugHttpDump == true {
reqDump, err := httputil.DumpRequest(request, true)
if err != nil {
return nil, fmt.Errorf("请求内容转储失败%w", err)
return nil, fmt.Errorf("请求内容转储失败: %w", err)
}
println(string(reqDump) + "\n\n")
}
response, err := http.DefaultClient.Do(request)
if err != nil {
return nil, fmt.Errorf("请求失败%w", err)
return nil, fmt.Errorf("请求失败: %w", err)
}
if env.DebugHttpDump == true {
respDump, err := httputil.DumpResponse(response, true)
if err != nil {
return nil, fmt.Errorf("响应内容转储失败%w", err)
return nil, fmt.Errorf("响应内容转储失败: %w", err)
}
println(string(respDump) + "\n\n")
}
if response.StatusCode != http.StatusOK {
return nil, fmt.Errorf("请求响应失败%d", response.StatusCode)
return nil, fmt.Errorf("请求响应失败: %d", response.StatusCode)
}
defer func(body io.ReadCloser) {
_ = body.Close()
@@ -300,18 +301,18 @@ func call[T any](s *SftClient, url string, req any) (*T, error) {
body, err := io.ReadAll(response.Body)
if err != nil {
return nil, fmt.Errorf("读取响应内容失败%w", err)
return nil, fmt.Errorf("读取响应内容失败: %w", err)
}
decode, err := s.verify(body)
if err != nil {
return nil, fmt.Errorf("解密响应内容失败%w", err)
return nil, fmt.Errorf("解密响应内容失败: %w", err)
}
var resp = new(T)
err = json.Unmarshal([]byte(decode), resp)
if err != nil {
return nil, fmt.Errorf("响应正文解析失败%w", err)
return nil, fmt.Errorf("响应正文解析失败: %w", err)
}
return resp, nil
@@ -321,7 +322,7 @@ func (s *SftClient) sign(msg any) (*request, error) {
bytes, err := json.Marshal(msg)
if err != nil {
return nil, fmt.Errorf("格式化加密正文失败%w", err)
return nil, fmt.Errorf("格式化加密正文失败: %w", err)
}
if env.DebugHttpDump {
@@ -341,7 +342,7 @@ func (s *SftClient) sign(msg any) (*request, error) {
hashed := sha256.Sum256([]byte(body.String()))
signature, err := rsa.SignPKCS1v15(nil, s.privateKey, crypto.SHA256, hashed[:])
if err != nil {
return nil, fmt.Errorf("签名失败%w", err)
return nil, fmt.Errorf("签名失败: %w", err)
}
body.Sign = base64.StdEncoding.EncodeToString(signature)
@@ -353,11 +354,11 @@ func (s *SftClient) verify(str []byte) (string, error) {
var resp = new(response)
err := json.Unmarshal(str, resp)
if err != nil {
return "", fmt.Errorf("解析响应正文失败%w", err)
return "", fmt.Errorf("解析响应正文失败: %w", err)
}
if resp.Code != "000000" {
return "", fmt.Errorf("请求业务响应失败%s", u.Z(resp.Msg))
return "", fmt.Errorf("请求业务响应失败: %s", u.Z(resp.Msg))
}
if resp.Sign == nil {
@@ -371,13 +372,13 @@ func (s *SftClient) verify(str []byte) (string, error) {
ser, err := resp.String()
if err != nil {
return "", fmt.Errorf("格式化响应内容失败%w", err)
return "", fmt.Errorf("格式化响应内容失败: %w", err)
}
hashed := sha256.Sum256([]byte(ser))
err = rsa.VerifyPKCS1v15(s.publicKey, crypto.SHA256, hashed[:], sign)
if err != nil {
return "", fmt.Errorf("验签失败%w", err)
return "", fmt.Errorf("验签失败: %w", err)
}
return *resp.BizData, nil
@@ -412,7 +413,7 @@ type response struct {
func (r response) String() (string, error) {
if r.BizData == nil || r.Msg == nil || r.SignType == nil {
return "", core.NewServErr(fmt.Sprintf(
"上游数据返回有空值BizData %vMsg %v, SignType %v",
"上游数据返回有空值: BizData %vMsg %v, SignType %v",
r.BizData == nil, r.Msg == nil, r.SignType == nil,
))
}

View File

@@ -2,12 +2,14 @@ package globals
import (
"errors"
"fmt"
"strings"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
zhtrans "github.com/go-playground/validator/v10/translations/zh"
"github.com/gofiber/fiber/v2"
"strings"
)
var Validator *ValidatorClient
@@ -38,17 +40,18 @@ func (v *ValidatorClient) Validate(c *fiber.Ctx, data any) error {
return nil
}
func initValidator() {
func initValidator() error {
var validate = validator.New(validator.WithRequiredStructEnabled())
var translator = ut.New(zh.New()).GetFallback()
err := zhtrans.RegisterDefaultTranslations(validate, translator)
if err != nil {
panic(err)
return fmt.Errorf("初始化验证器失败: %w", err)
}
Validator = &ValidatorClient{
validator: validate,
translator: translator,
}
return nil
}

View File

@@ -3,6 +3,7 @@ package globals
import (
"context"
"encoding/base64"
"fmt"
"platform/pkg/env"
"github.com/wechatpay-apiv3/wechatpay-go/core"
@@ -20,28 +21,28 @@ type WechatPayClient struct {
Notify *notify.Handler
}
func initWechatPay() {
func initWechatPay() error {
// 加载商户私钥
private, err := base64.StdEncoding.DecodeString(env.WechatPayMchPrivateKey)
if err != nil {
panic(err)
return fmt.Errorf("加载微信支付商户私钥失败: %w", err)
}
appPrivateKey, err := utils.LoadPrivateKey(string(private))
if err != nil {
panic(err)
return fmt.Errorf("解析微信支付商户私钥失败: %w", err)
}
// 加载微信支付公钥
public, err := base64.StdEncoding.DecodeString(env.WechatPayPublicKey)
if err != nil {
panic(err)
return fmt.Errorf("加载微信支付公钥失败: %w", err)
}
wechatPublicKey, err := utils.LoadPublicKey(string(public))
if err != nil {
panic(err)
return fmt.Errorf("解析微信支付公钥失败: %w", err)
}
// 创建 WechatPay 客户端
@@ -55,7 +56,7 @@ func initWechatPay() {
),
)
if err != nil {
panic(err)
return fmt.Errorf("创建微信支付客户端失败: %w", err)
}
// 创建 WechatPay 通知处理器
@@ -64,7 +65,7 @@ func initWechatPay() {
*wechatPublicKey,
))
if err != nil {
panic(err)
return fmt.Errorf("创建微信支付通知处理器失败: %w", err)
}
// 创建 WechatPay 服务
@@ -72,4 +73,5 @@ func initWechatPay() {
Native: &native.NativeApiService{Client: client},
Notify: handler,
}
return nil
}

View File

@@ -1,10 +1,11 @@
package handlers
import (
"github.com/gofiber/fiber/v2"
"platform/web/auth"
"platform/web/core"
q "platform/web/queries"
"github.com/gofiber/fiber/v2"
)
// region ListAnnouncements
@@ -16,7 +17,7 @@ type ListAnnouncementsRequest struct {
func ListAnnouncements(c *fiber.Ctx) error {
// 检查权限
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
_, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}

View File

@@ -1,277 +1,14 @@
package handlers
import (
"encoding/base64"
"errors"
"log/slog"
"platform/pkg/u"
auth2 "platform/web/auth"
client2 "platform/web/domains/client"
m "platform/web/models"
q "platform/web/queries"
s "platform/web/services"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// region /token
type TokenReq struct {
GrantType auth2.GrantType `json:"grant_type" form:"grant_type"`
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
Scope string `json:"scope" form:"scope"`
s.GrantCodeData
s.GrantClientData
s.GrantRefreshData
s.GrantPasswordData
}
type TokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
}
type TokenErrResp struct {
Error string `json:"error"`
Description string `json:"error_description,omitempty"`
}
// Token 处理 OAuth2.0 授权请求
func Token(c *fiber.Ctx) error {
// 验证请求参数
req := new(TokenReq)
if err := c.BodyParser(req); err != nil {
return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
}
if req.GrantType == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数grant_type")
}
slog.Debug("oauth token", slog.String("grant_type",
string(req.GrantType)),
slog.String("client_id", req.ClientID),
)
// 基于授权类型处理请求
switch req.GrantType {
// 授权码模式
case auth2.GrantAuthorizationCode:
if req.Code == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数code")
}
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
if err != nil {
return sendError(c, err.(s.AuthServiceError))
}
return sendSuccess(c, token)
// 客户端凭证模式
case auth2.GrantClientCredentials:
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
scope := strings.Split(req.Scope, ",")
token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
if err != nil {
return sendError(c, err.(s.AuthServiceError))
}
return sendSuccess(c, token)
// 刷新令牌模式
case auth2.GrantRefreshToken:
if req.RefreshToken == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数refresh_token")
}
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
scope := strings.Split(req.Scope, ",")
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
if err != nil {
if errors.Is(err, auth2.ErrInvalidRefreshToken) {
return sendError(c, s.ErrOauthInvalidGrant)
}
return sendError(c, err)
}
return sendSuccess(c, token)
// 密码模式
case auth2.GrantPassword:
if req.LoginType == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password_type")
}
if req.Username == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数username")
}
if req.Password == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password")
}
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent"))
if err != nil {
return sendError(c, err)
}
return sendSuccess(c, token)
default:
return sendError(c, s.ErrOauthUnsupportedGrantType)
}
}
// 检查客户端凭证
func protect(c *fiber.Ctx, grant auth2.GrantType, clientId, clientSecret string) (*m.Client, error) {
header := c.Get("Authorization")
if header != "" {
basic := strings.TrimPrefix(header, "Basic ")
if basic != "" {
base, err := base64.RawURLEncoding.DecodeString(basic)
if err != nil {
return nil, err
}
parts := strings.SplitN(string(base), ":", 2)
if len(parts) == 2 {
clientId = parts[0]
clientSecret = parts[1]
}
}
}
// 查找客户端
if clientId == "" {
return nil, s.ErrOauthInvalidRequest
}
client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, s.ErrOauthInvalidClient
}
return nil, err
}
// 验证客户端状态
if client.Status != 1 {
return nil, s.ErrOauthUnauthorizedClient
}
// 验证授权类型
switch grant {
case auth2.GrantAuthorizationCode:
if !client.GrantCode {
return nil, s.ErrOauthUnauthorizedClient
}
case auth2.GrantClientCredentials:
if !client.GrantClient || client.Spec != int32(client2.SpecWeb) || client.Spec != int32(client2.SpecTrusted) {
return nil, s.ErrOauthUnauthorizedClient
}
case auth2.GrantRefreshToken:
if !client.GrantRefresh {
return nil, s.ErrOauthUnauthorizedClient
}
case auth2.GrantPassword:
if !client.GrantPassword {
return nil, s.ErrOauthUnauthorizedClient
}
}
// 如果客户端是 confidential验证 client_secret失败返回错误
if client.Spec == int32(client2.SpecWeb) || client.Spec == int32(client2.SpecTrusted) {
if clientSecret == "" {
return nil, s.ErrOauthInvalidRequest
}
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
return nil, s.ErrOauthInvalidClient
}
}
// 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑)
auth2.Locals(c, &auth2.Context{
Payload: auth2.Payload{
Id: client.ID,
Type: auth2.PayloadSecuredServer,
Name: client.Name,
Avatar: client.Icon,
},
})
return client, nil
}
// 发送成功响应
func sendSuccess(c *fiber.Ctx, details *auth2.TokenDetails) error {
return c.JSON(TokenResp{
AccessToken: details.AccessToken,
TokenType: "Bearer",
ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()),
RefreshToken: details.RefreshToken,
})
}
// 发送错误响应
func sendError(c *fiber.Ctx, err error, description ...string) error {
var sErr s.AuthServiceError
if errors.As(err, &sErr) {
status := fiber.StatusBadRequest
var desc string
switch {
case errors.Is(sErr, s.ErrOauthInvalidRequest):
desc = "无效的请求"
case errors.Is(sErr, s.ErrOauthInvalidClient):
status = fiber.StatusUnauthorized
desc = "无效的客户端凭证"
case errors.Is(sErr, s.ErrOauthInvalidGrant):
desc = "无效的授权凭证"
case errors.Is(sErr, s.ErrOauthInvalidScope):
desc = "无效的授权范围"
case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
desc = "未授权的客户端"
case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
desc = "不支持的授权类型"
}
if len(description) > 0 {
desc = description[0]
}
return c.Status(status).JSON(TokenErrResp{
Error: string(sErr),
Description: desc,
})
}
return err
}
// endregion
// region /revoke
type RevokeReq struct {
@@ -280,7 +17,7 @@ type RevokeReq struct {
}
func Revoke(c *fiber.Ctx) error {
_, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
_, err := auth2.GetAuthCtx(c).PermitUser()
if err != nil {
// 用户未登录
return nil
@@ -312,14 +49,14 @@ type IntrospectResp struct {
func Introspect(c *fiber.Ctx) error {
// 验证权限
authCtx, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
authCtx, err := auth2.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
// 获取用户信息
profile, err := q.User.
Where(q.User.ID.Eq(authCtx.Payload.Id)).
Where(q.User.ID.Eq(authCtx.User.ID)).
Omit(q.User.DeletedAt).
Take()
if err != nil {

View File

@@ -23,7 +23,7 @@ type ListBillReq struct {
// ListBill 获取账单列表
func ListBill(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -36,7 +36,7 @@ func ListBill(c *fiber.Ctx) error {
// 查询账单列表
do := q.Bill.
Where(q.Bill.UserID.Eq(authContext.Payload.Id))
Where(q.Bill.UserID.Eq(authCtx.User.ID))
if req.Type != nil {
do.Where(q.Bill.Type.Eq(int32(*req.Type)))

View File

@@ -24,7 +24,7 @@ type ListChannelsReq struct {
func ListChannels(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authContext, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -37,7 +37,7 @@ func ListChannels(c *fiber.Ctx) error {
// 构造查询条件
cond := q.Channel.
Where(q.Channel.UserID.Eq(authContext.Payload.Id))
Where(q.Channel.UserID.Eq(authContext.User.ID))
switch req.AuthType {
case s.ChannelAuthTypeIp:
cond.Where(q.Channel.AuthIP.Is(true))
@@ -110,24 +110,19 @@ type CreateChannelRespItem struct {
func CreateChannel(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
// 检查用户其他权限
user, err := q.User.
Where(q.User.ID.Eq(authContext.Payload.Id)).
Take()
if err != nil {
return err
}
user := authCtx.User
if user.IDToken == nil || *user.IDToken == "" {
return fiber.NewError(fiber.StatusForbidden, "账号未实名")
}
count, err := q.Whitelist.Where(
q.Whitelist.UserID.Eq(authContext.Payload.Id),
q.Whitelist.UserID.Eq(user.ID),
q.Whitelist.Host.Eq(c.IP()),
).Count()
if err != nil {
@@ -155,7 +150,7 @@ func CreateChannel(c *fiber.Ctx) error {
// 创建通道
result, err := s.Channel.CreateChannel(
c,
authContext.Payload.Id,
user.ID,
req.ResourceId,
req.Protocol,
req.AuthType,
@@ -198,7 +193,7 @@ type RemoveChannelsReq struct {
func RemoveChannels(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -210,31 +205,7 @@ func RemoveChannels(c *fiber.Ctx) error {
}
// 删除通道
err = s.Channel.RemoveChannels(req.ByIds, authCtx.Payload.Id)
if err != nil {
return err
}
return c.SendStatus(fiber.StatusOK)
}
type RemoveChannelByTaskReq []int32
func RemoveChannelByTask(c *fiber.Ctx) error {
// 检查权限
_, err := auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
if err != nil {
return err
}
// 解析请求参数
var req RemoveChannelByTaskReq
if err := c.BodyParser(&req); err != nil {
return err
}
// 删除通道
err = s.Channel.RemoveChannels(req)
err = s.Channel.RemoveChannels(req.ByIds, authCtx.User.ID)
if err != nil {
return err
}

View File

@@ -2,8 +2,6 @@ package handlers
import (
"errors"
"gorm.io/gen/field"
"gorm.io/gorm"
"log/slog"
"platform/pkg/u"
"platform/web/auth"
@@ -14,6 +12,9 @@ import (
q "platform/web/queries"
s "platform/web/services"
"gorm.io/gen/field"
"gorm.io/gorm"
"github.com/gofiber/fiber/v2"
)
@@ -120,7 +121,7 @@ type AllEdgesAvailableRespItem struct {
func AllEdgesAvailable(c *fiber.Ctx) (err error) {
// 检查权限
_, err = auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
_, err = auth.GetAuthCtx(c).PermitSecretClient()
if err != nil {
return err
}

View File

@@ -37,17 +37,11 @@ type IdentifyRes struct {
func Identify(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
if err != nil {
return err
}
user, err := q.User.
Where(q.User.ID.Eq(authCtx.Payload.Id)).
Select(q.User.ID, q.User.IDToken).
Take()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
user := authCtx.User
if user.IDToken != nil && *user.IDToken != "" {
// 用户已实名认证
return c.JSON(IdentifyRes{
@@ -86,7 +80,7 @@ func Identify(c *fiber.Ctx) error {
// 保存认证中间状态
info := idenInfo{
Uid: authCtx.Payload.Id,
Uid: user.ID,
Type: req.Type,
Name: req.Name,
IdNo: req.IdenNo,

View File

@@ -40,9 +40,7 @@ type ProxyReportOnlineResp struct {
func ProxyReportOnline(c *fiber.Ctx) (err error) {
// 检查接口权限
_, err = auth2.NewProtect(c).Payload(
auth2.PayloadInternalServer,
).Do()
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
if err != nil {
return err
}
@@ -149,9 +147,7 @@ type ProxyReportOfflineReq struct {
func ProxyReportOffline(c *fiber.Ctx) (err error) {
// 检查接口权限
_, err = auth2.NewProtect(c).Payload(
auth2.PayloadInternalServer,
).Do()
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
if err != nil {
return err
}
@@ -193,9 +189,7 @@ type ProxyReportUpdateReq struct {
func ProxyReportUpdate(c *fiber.Ctx) (err error) {
// 检查接口权限
_, err = auth2.NewProtect(c).Payload(
auth2.PayloadInternalServer,
).Do()
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
if err != nil {
return err
}

View File

@@ -1,7 +1,6 @@
package handlers
import (
"gorm.io/gen/field"
"platform/pkg/u"
"platform/web/auth"
"platform/web/core"
@@ -12,6 +11,8 @@ import (
s "platform/web/services"
"time"
"gorm.io/gen/field"
"github.com/gofiber/fiber/v2"
)
@@ -28,7 +29,7 @@ type ListResourceShortReq struct {
func ListResourceShort(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -41,7 +42,7 @@ func ListResourceShort(c *fiber.Ctx) error {
// 查询套餐列表
do := q.Resource.Where(
q.Resource.UserID.Eq(authContext.Payload.Id),
q.Resource.UserID.Eq(authCtx.User.ID),
q.Resource.Type.Eq(int32(resource2.TypeShort)),
)
if req.ResourceNo != nil && *req.ResourceNo != "" {
@@ -109,7 +110,7 @@ type ListResourceLongReq struct {
func ListResourceLong(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -122,7 +123,7 @@ func ListResourceLong(c *fiber.Ctx) error {
// 查询套餐列表
do := q.Resource.Where(
q.Resource.UserID.Eq(authContext.Payload.Id),
q.Resource.UserID.Eq(authCtx.User.ID),
q.Resource.Type.Eq(int32(resource2.TypeLong)),
)
if req.ResourceNo != nil && *req.ResourceNo != "" {
@@ -182,7 +183,7 @@ type AllResourceReq struct {
func AllActiveResource(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -198,7 +199,7 @@ func AllActiveResource(c *fiber.Ctx) error {
q.Resource.Long,
).
Where(
q.Resource.UserID.Eq(authCtx.Payload.Id),
q.Resource.UserID.Eq(authCtx.User.ID),
q.Resource.Active.Is(true),
q.Resource.Where(
q.Resource.Type.Eq(int32(resource2.TypeShort)),
@@ -254,7 +255,7 @@ type StatisticLong struct {
func StatisticResourceFree(c *fiber.Ctx) error {
// 检查权限
session, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -266,7 +267,7 @@ func StatisticResourceFree(c *fiber.Ctx) error {
q.Resource.Long,
).
Where(
q.Resource.UserID.Eq(session.Payload.Id),
q.Resource.UserID.Eq(authCtx.User.ID),
q.Resource.Active.Is(true),
).
Select(q.Resource.ID, q.Resource.Type).
@@ -347,7 +348,7 @@ type StatisticResourceUsageResp []struct {
func StatisticResourceUsage(c *fiber.Ctx) error {
// 检查权限
session, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -359,12 +360,12 @@ func StatisticResourceUsage(c *fiber.Ctx) error {
}
// 统计套餐提取数量
do := q.LogsUserUsage.Where(q.LogsUserUsage.UserID.Eq(session.Payload.Id))
do := q.LogsUserUsage.Where(q.LogsUserUsage.UserID.Eq(authCtx.User.ID))
if req.ResourceNo != nil && *req.ResourceNo != "" {
var resourceID int32
err := q.Resource.
Where(
q.Resource.UserID.Eq(session.Payload.Id),
q.Resource.UserID.Eq(authCtx.User.ID),
q.Resource.ResourceNo.Eq(*req.ResourceNo),
).
Select(q.Resource.ID).
@@ -409,7 +410,7 @@ type CreateResourceReq struct {
func CreateResource(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -421,7 +422,7 @@ func CreateResource(c *fiber.Ctx) error {
}
// 创建套餐
err = s.Resource.CreateResourceByBalance(authCtx.Payload.Id, time.Now(), req.CreateResourceData)
err = s.Resource.CreateResourceByBalance(authCtx.User.ID, time.Now(), req.CreateResourceData)
if err != nil {
return err
}
@@ -431,7 +432,7 @@ func CreateResource(c *fiber.Ctx) error {
func ResourcePrice(c *fiber.Ctx) error {
// 检查权限
_, err := auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
_, err := auth.GetAuthCtx(c).PermitSecretClient()
if err != nil {
return err
}

View File

@@ -7,7 +7,6 @@ import (
trade2 "platform/web/domains/trade"
g "platform/web/globals"
s "platform/web/services"
"platform/web/tasks"
"time"
"github.com/gofiber/fiber/v2"
@@ -27,7 +26,7 @@ type TradeCreateResp struct {
func TradeCreate(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -52,7 +51,7 @@ func TradeCreate(c *fiber.Ctx) error {
}
// 创建交易
result, err := s.Trade.CreateTrade(authCtx.Payload.Id, time.Now(), &req.CreateTradeData)
result, err := s.Trade.CreateTrade(authCtx.User.ID, time.Now(), &req.CreateTradeData)
if err != nil {
slog.Error("创建交易失败", "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "创建交易失败"})
@@ -70,7 +69,7 @@ type TradeCompleteReq struct {
func TradeComplete(c *fiber.Ctx) error {
// 检查权限
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
_, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -99,7 +98,7 @@ type TradeCancelReq struct {
func TradeCancel(c *fiber.Ctx) error {
// 检查权限
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
_, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -119,29 +118,3 @@ func TradeCancel(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusNoContent)
}
type TradeCheckReq struct {
tasks.CancelTradeData
}
func TradeCancelByTask(c *fiber.Ctx) error {
// 检查权限
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadInternalServer}, []string{})
if err != nil {
return err
}
// 取消交易
req := new(TradeCheckReq)
if err := c.BodyParser(req); err != nil {
return err
}
// 检查订单状态
err = s.Trade.CancelTrade(req.TradeNo, req.Method, time.Now())
if err != nil {
return err
}
return nil
}

View File

@@ -1,12 +1,13 @@
package handlers
import (
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"platform/web/auth"
m "platform/web/models"
q "platform/web/queries"
s "platform/web/services"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
)
// region /update
@@ -20,7 +21,7 @@ type UpdateUserReq struct {
func UpdateUser(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -33,7 +34,7 @@ func UpdateUser(c *fiber.Ctx) error {
// 更新用户信息
_, err = q.User.
Where(q.User.ID.Eq(authCtx.Payload.Id)).
Where(q.User.ID.Eq(authCtx.User.ID)).
Updates(m.User{
Username: &req.Username,
Email: &req.Email,
@@ -59,7 +60,7 @@ type UpdateAccountReq struct {
func UpdateAccount(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -72,7 +73,7 @@ func UpdateAccount(c *fiber.Ctx) error {
// 更新用户信息
_, err = q.User.
Where(q.User.ID.Eq(authCtx.Payload.Id)).
Where(q.User.ID.Eq(authCtx.User.ID)).
Updates(m.User{
Username: &req.Username,
Password: &req.Password,
@@ -97,7 +98,7 @@ type UpdatePasswordReq struct {
func UpdatePassword(c *fiber.Ctx) error {
// 检查权限
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -124,7 +125,7 @@ func UpdatePassword(c *fiber.Ctx) error {
}
_, err = q.User.
Where(q.User.ID.Eq(authCtx.Payload.Id)).
Where(q.User.ID.Eq(authCtx.User.ID)).
UpdateColumn(q.User.Password, newHash)
if err != nil {
return err

View File

@@ -2,12 +2,14 @@ package handlers
import (
"errors"
"platform/pkg/env"
"platform/web/auth"
"platform/web/services"
"regexp"
"strconv"
"github.com/gofiber/fiber/v2"
"github.com/redis/go-redis/v9"
)
type VerifierReq struct {
@@ -17,7 +19,7 @@ type VerifierReq struct {
func SmsCode(c *fiber.Ctx) error {
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadInternalServer}, []string{})
_, err := auth.GetAuthCtx(c).PermitInternalClient()
if err != nil {
return err
}
@@ -48,3 +50,19 @@ func SmsCode(c *fiber.Ctx) error {
// 发送成功
return nil
}
func DebugGetSmsCode(c *fiber.Ctx) error {
if env.RunMode != env.RunModeDev {
return fiber.NewError(fiber.StatusForbidden, "not allowed")
}
code, err := services.Verifier.GetSms(c.Context(), c.Params("phone"))
if err != nil {
if errors.Is(err, redis.Nil) {
return c.SendString("还没有验证码")
}
return err
}
return c.SendString(code)
}

View File

@@ -26,7 +26,7 @@ type ListWhitelistResp struct {
func ListWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -38,8 +38,7 @@ func ListWhitelist(c *fiber.Ctx) error {
}
// 获取白名单信息
do := q.Whitelist.
Where(q.Whitelist.UserID.Eq(authContext.Payload.Id))
do := q.Whitelist.Where(q.Whitelist.UserID.Eq(authCtx.User.ID))
list, err := q.Whitelist.Where(do).
Offset(req.GetOffset()).
@@ -77,7 +76,7 @@ type CreateWhitelistReq struct {
func CreateWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -96,7 +95,7 @@ func CreateWhitelist(c *fiber.Ctx) error {
// 创建白名单
err = q.Whitelist.Create(&m.Whitelist{
UserID: authContext.Payload.Id,
UserID: authCtx.User.ID,
Host: req.Host,
Remark: &req.Remark,
})
@@ -111,7 +110,7 @@ type UpdateWhitelistReq struct {
func UpdateWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -129,7 +128,7 @@ func UpdateWhitelist(c *fiber.Ctx) error {
_, err = q.Whitelist.
Where(
q.Whitelist.ID.Eq(req.ID),
q.Whitelist.UserID.Eq(authContext.Payload.Id),
q.Whitelist.UserID.Eq(authCtx.User.ID),
).
Updates(&m.Whitelist{
ID: req.ID,
@@ -149,7 +148,7 @@ type RemoveWhitelistReq struct {
func RemoveWhitelist(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil {
return err
}
@@ -175,7 +174,7 @@ func RemoveWhitelist(c *fiber.Ctx) error {
_, err = q.Whitelist.
Where(
q.Whitelist.ID.In(ids...),
q.Whitelist.UserID.Eq(authContext.Payload.Id),
q.Whitelist.UserID.Eq(authCtx.User.ID),
).
Update(
q.Whitelist.DeletedAt, time.Now(),

42
web/middlewares.go Normal file
View File

@@ -0,0 +1,42 @@
package web
import (
"platform/web/auth"
"github.com/gofiber/contrib/otelfiber/v2"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/google/uuid"
"github.com/jxskiss/base62"
)
func ApplyMiddlewares(app *fiber.App) {
// recover
app.Use(recover.New(recover.Config{
EnableStackTrace: true,
}))
// metric
app.Use(otelfiber.Middleware())
// logger
app.Use(logger.New(logger.Config{
Next: func(c *fiber.Ctx) bool {
return c.Path() == "/favicon.ico"
},
}))
// request id
app.Use(requestid.New(requestid.Config{
Generator: func() string {
binary, _ := uuid.New().MarshalBinary()
return base62.EncodeToString(binary)
},
}))
// authenticate
app.Use(auth.Authenticate())
}

View File

@@ -17,10 +17,14 @@ type Channel struct {
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:通道ID" json:"id"` // 通道ID
UserID int32 `gorm:"column:user_id;type:integer;not null;comment:用户ID" json:"user_id"` // 用户ID
ProxyID int32 `gorm:"column:proxy_id;type:integer;not null;comment:代理ID" json:"proxy_id"` // 代理ID
EdgeID *int32 `gorm:"column:edge_id;type:integer;comment:节点ID" json:"edge_id"` // 节点ID
ResourceID int32 `gorm:"column:resource_id;type:integer;not null;comment:套餐ID" json:"resource_id"` // 套餐ID
ProxyHost string `gorm:"column:proxy_host;type:character varying(255);not null;comment:代理地址" json:"proxy_host"` // 代理地址
ProxyPort int32 `gorm:"column:proxy_port;type:integer;not null;comment:转发端口" json:"proxy_port"` // 转发端口
EdgeHost *string `gorm:"column:edge_host;type:character varying(255);comment:节点地址" json:"edge_host"` // 节点地址
Protocol *int32 `gorm:"column:protocol;type:integer;comment:协议类型1-http2-https3-socks5" json:"protocol"` // 协议类型1-http2-https3-socks5
AuthIP bool `gorm:"column:auth_ip;type:boolean;not null;comment:IP认证" json:"auth_ip"` // IP认证
Whitelists *string `gorm:"column:whitelists;type:text;comment:IP白名单逗号分隔" json:"whitelists"` // IP白名单逗号分隔
AuthPass bool `gorm:"column:auth_pass;type:boolean;not null;comment:密码认证" json:"auth_pass"` // 密码认证
Username *string `gorm:"column:username;type:character varying(255);comment:用户名" json:"username"` // 用户名
Password *string `gorm:"column:password;type:character varying(255);comment:密码" json:"password"` // 密码
@@ -28,10 +32,6 @@ type Channel struct {
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
EdgeHost *string `gorm:"column:edge_host;type:character varying(255);comment:节点地址" json:"edge_host"` // 节点地址
EdgeID *int32 `gorm:"column:edge_id;type:integer;comment:节点ID" json:"edge_id"` // 节点ID
Whitelists *string `gorm:"column:whitelists;type:text;comment:IP白名单逗号分隔" json:"whitelists"` // IP白名单逗号分隔
ResourceID int32 `gorm:"column:resource_id;type:integer;not null;comment:套餐ID" json:"resource_id"` // 套餐ID
}
// TableName Channel's table name

View File

@@ -14,21 +14,18 @@ const TableNameClient = "client"
// Client mapped from table <client>
type Client struct {
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:客户端ID" json:"id"` // 客户端ID
ClientID string `gorm:"column:client_id;type:character varying(255);not null;comment:OAuth2客户端标识符" json:"client_id"` // OAuth2客户端标识符
ClientSecret string `gorm:"column:client_secret;type:character varying(255);not null;comment:OAuth2客户端密钥" json:"client_secret"` // OAuth2客户端密钥
RedirectURI *string `gorm:"column:redirect_uri;type:character varying(255);comment:OAuth2 重定向URI" json:"redirect_uri"` // OAuth2 重定向URI
GrantCode bool `gorm:"column:grant_code;type:boolean;not null;comment:允许授权码授予" json:"grant_code"` // 允许授权码授予
GrantClient bool `gorm:"column:grant_client;type:boolean;not null;comment:允许客户端凭证授予" json:"grant_client"` // 允许客户端凭证授予
GrantRefresh bool `gorm:"column:grant_refresh;type:boolean;not null;comment:允许刷新令牌授予" json:"grant_refresh"` // 允许刷新令牌授予
GrantPassword bool `gorm:"column:grant_password;type:boolean;not null;comment:允许密码授予" json:"grant_password"` // 允许密码授予
Spec int32 `gorm:"column:spec;type:integer;not null;comment:安全规范1-native2-browser3-web4-trusted" json:"spec"` // 安全规范1-native2-browser3-web4-trusted
Name string `gorm:"column:name;type:character varying(255);not null;comment:名称" json:"name"` // 名称
Icon *string `gorm:"column:icon;type:character varying(255);comment:图标URL" json:"icon"` // 图标URL
Status int32 `gorm:"column:status;type:integer;not null;default:1;comment:状态0-禁用1-正常" json:"status"` // 状态0-禁用1-正常
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:客户端ID" json:"id"` // 客户端ID
ClientID string `gorm:"column:client_id;type:character varying(255);not null;comment:OAuth2客户端标识符" json:"client_id"` // OAuth2客户端标识符
ClientSecret string `gorm:"column:client_secret;type:character varying(255);not null;comment:OAuth2客户端密钥" json:"client_secret"` // OAuth2客户端密钥
RedirectURI *string `gorm:"column:redirect_uri;type:character varying(255);comment:OAuth2 重定向URI" json:"redirect_uri"` // OAuth2 重定向URI
Spec int32 `gorm:"column:spec;type:integer;not null;comment:安全规范1-native2-browser3-web4-api" json:"spec"` // 安全规范1-native2-browser3-web4-api
Name string `gorm:"column:name;type:character varying(255);not null;comment:名称" json:"name"` // 名称
Icon *string `gorm:"column:icon;type:character varying(255);comment:图标URL" json:"icon"` // 图标URL
Status int32 `gorm:"column:status;type:integer;not null;default:1;comment:状态0-禁用1-正常" json:"status"` // 状态0-禁用1-正常
Type int32 `gorm:"column:type;type:integer;not null;comment:类型0-普通1-官方" json:"type"` // 类型0-普通1-官方
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
}
// TableName Client's table name

View File

@@ -12,12 +12,12 @@ const TableNameLogsLogin = "logs_login"
type LogsLogin struct {
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:登录日志ID" json:"id"` // 登录日志ID
IP string `gorm:"column:ip;type:character varying(45);not null;comment:IP地址" json:"ip"` // IP地址
Ua string `gorm:"column:ua;type:character varying(255);not null;comment:用户代理" json:"ua"` // 用户代理
UA string `gorm:"column:ua;type:character varying(255);not null;comment:用户代理" json:"ua"` // 用户代理
GrantType string `gorm:"column:grant_type;type:character varying(255);not null;comment:授权类型authorization_code-授权码模式client_credentials-客户端凭证模式refresh_token-刷新令牌模式password-密码模式" json:"grant_type"` // 授权类型authorization_code-授权码模式client_credentials-客户端凭证模式refresh_token-刷新令牌模式password-密码模式
PasswordGrantType string `gorm:"column:password_grant_type;type:character varying(255);not null;comment:密码模式子授权类型password-账号密码phone_code-手机验证码email_code-邮箱验证码" json:"password_grant_type"` // 密码模式子授权类型password-账号密码phone_code-手机验证码email_code-邮箱验证码
Success bool `gorm:"column:success;type:boolean;not null;comment:登录是否成功" json:"success"` // 登录是否成功
Time orm.LocalDateTime `gorm:"column:time;type:timestamp without time zone;not null;comment:登录时间" json:"time"` // 登录时间
UserID *int32 `gorm:"column:user_id;type:integer;comment:用户ID" json:"user_id"` // 用户ID
Time orm.LocalDateTime `gorm:"column:time;type:timestamp without time zone;not null;comment:登录时间" json:"time"` // 登录时间
}
// TableName LogsLogin's table name

View File

@@ -10,17 +10,17 @@ const TableNameLogsRequest = "logs_request"
// LogsRequest mapped from table <logs_request>
type LogsRequest struct {
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:访问日志ID" json:"id"` // 访问日志ID
Identity int32 `gorm:"column:identity;type:integer;not null;comment:访客身份0-游客1-用户2-管理员3-公共服务4-安全服务5-内部服务" json:"identity"` // 访客身份0-游客1-用户2-管理员3-公共服务4-安全服务5-内部服务
Visitor *int32 `gorm:"column:visitor;type:integer;comment:访客ID" json:"visitor"` // 访客ID
IP string `gorm:"column:ip;type:character varying(45);not null;comment:IP地址" json:"ip"` // IP地址
Ua *string `gorm:"column:ua;type:character varying(255);comment:用户代理" json:"ua"` // 用户代理
Method string `gorm:"column:method;type:character varying(10);not null;comment:请求方法" json:"method"` // 请求方法
Path string `gorm:"column:path;type:character varying(255);not null;comment:请求路径" json:"path"` // 请求路径
Latency string `gorm:"column:latency;type:character varying(255);not null;comment:请求延迟" json:"latency"` // 请求延迟
Status int32 `gorm:"column:status;type:integer;not null;comment:响应状态码" json:"status"` // 响应状态码
Error *string `gorm:"column:error;type:text;comment:错误信息" json:"error"` // 错误信息
Time orm.LocalDateTime `gorm:"column:time;type:timestamp without time zone;not null;comment:请求时间" json:"time"` // 请求时间
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:访问日志ID" json:"id"` // 访问日志ID
IP string `gorm:"column:ip;type:character varying(45);not null;comment:IP地址" json:"ip"` // IP地址
UA string `gorm:"column:ua;type:character varying(255);not null;comment:用户代理" json:"ua"` // 用户代理
UserID *int32 `gorm:"column:user_id;type:integer;comment:用户ID" json:"user_id"` // 用户ID
ClientID *int32 `gorm:"column:client_id;type:integer;comment:客户端ID" json:"client_id"` // 客户端ID
Method string `gorm:"column:method;type:character varying(10);not null;comment:请求方法" json:"method"` // 请求方法
Path string `gorm:"column:path;type:character varying(255);not null;comment:请求路径" json:"path"` // 请求路径
Status int32 `gorm:"column:status;type:integer;not null;comment:响应状态码" json:"status"` // 响应状态码
Error *string `gorm:"column:error;type:text;comment:错误信息" json:"error"` // 错误信息
Time orm.LocalDateTime `gorm:"column:time;type:timestamp without time zone;not null;comment:请求时间" json:"time"` // 请求时间
Latency string `gorm:"column:latency;type:character varying(255);not null;comment:请求延迟" json:"latency"` // 请求延迟
}
// TableName LogsRequest's table name

View File

@@ -18,12 +18,12 @@ type Proxy struct {
Version int32 `gorm:"column:version;type:integer;not null;comment:代理服务版本" json:"version"` // 代理服务版本
Name string `gorm:"column:name;type:character varying(255);not null;comment:代理服务名称" json:"name"` // 代理服务名称
Host string `gorm:"column:host;type:character varying(255);not null;comment:代理服务地址" json:"host"` // 代理服务地址
Type int32 `gorm:"column:type;type:integer;not null;comment:代理服务类型1-三方2-自有" json:"type"` // 代理服务类型1-三方2-自有
Secret *string `gorm:"column:secret;type:character varying(255);comment:代理服务密钥" json:"secret"` // 代理服务密钥
Type int32 `gorm:"column:type;type:integer;not null;comment:代理服务类型1-三方2-自有" json:"type"` // 代理服务类型1-三方2-自有
Status int32 `gorm:"column:status;type:integer;not null;comment:代理服务状态0-离线1-在线" json:"status"` // 代理服务状态0-离线1-在线
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
Status int32 `gorm:"column:status;type:integer;not null;comment:代理服务状态0-离线1-在线" json:"status"` // 代理服务状态0-离线1-在线
Edges []Edge `gorm:"foreignKey:ProxyID;references:ID" json:"edges"`
}

View File

@@ -14,20 +14,23 @@ const TableNameSession = "session"
// Session mapped from table <session>
type Session struct {
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:会话ID" json:"id"` // 会话ID
UserID *int32 `gorm:"column:user_id;type:integer;comment:用户ID" json:"user_id"` // 用户ID
ClientID *int32 `gorm:"column:client_id;type:integer;comment:客户端ID" json:"client_id"` // 客户端ID
IP *string `gorm:"column:ip;type:character varying(45);comment:IP地址" json:"ip"` // IP地址
Ua *string `gorm:"column:ua;type:character varying(255);comment:用户代理" json:"ua"` // 用户代理
GrantType string `gorm:"column:grant_type;type:character varying(255);not null;default:0;comment:授权类型authorization_code-授权码模式client_credentials-客户端凭证模式refresh_token-刷新令牌模式password-密码模式" json:"grant_type"` // 授权类型authorization_code-授权码模式client_credentials-客户端凭证模式refresh_token-刷新令牌模式password-密码模式
AccessToken string `gorm:"column:access_token;type:character varying(255);not null;comment:访问令牌" json:"access_token"` // 访问令牌
AccessTokenExpires orm.LocalDateTime `gorm:"column:access_token_expires;type:timestamp without time zone;not null;comment:访问令牌过期时间" json:"access_token_expires"` // 访问令牌过期时间
RefreshToken *string `gorm:"column:refresh_token;type:character varying(255);comment:刷新令牌" json:"refresh_token"` // 刷新令牌
RefreshTokenExpires *orm.LocalDateTime `gorm:"column:refresh_token_expires;type:timestamp without time zone;comment:刷新令牌过期时间" json:"refresh_token_expires"` // 刷新令牌过期时间
Scopes_ *string `gorm:"column:scopes;type:character varying(255);comment:权限范围" json:"scopes"` // 权限范围
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:会话ID" json:"id"` // 会话ID
UserID *int32 `gorm:"column:user_id;type:integer;comment:用户ID" json:"user_id"` // 用户ID
AdminID *int32 `gorm:"column:admin_id;type:integer;comment:管理员ID" json:"admin_id"` // 管理员ID
ClientID *int32 `gorm:"column:client_id;type:integer;comment:客户端ID" json:"client_id"` // 客户端ID
IP *string `gorm:"column:ip;type:character varying(45);comment:IP地址" json:"ip"` // IP地址
UA *string `gorm:"column:ua;type:character varying(255);comment:用户代理" json:"ua"` // 用户代理
AccessToken string `gorm:"column:access_token;type:character varying(255);not null;comment:访问令牌" json:"access_token"` // 访问令牌
AccessTokenExpires orm.LocalDateTime `gorm:"column:access_token_expires;type:timestamp without time zone;not null;comment:访问令牌过期时间" json:"access_token_expires"` // 访问令牌过期时间
RefreshToken *string `gorm:"column:refresh_token;type:character varying(255);comment:刷新令牌" json:"refresh_token"` // 刷新令牌
RefreshTokenExpires *orm.LocalDateTime `gorm:"column:refresh_token_expires;type:timestamp without time zone;comment:刷新令牌过期时间" json:"refresh_token_expires"` // 刷新令牌过期时间
Scopes_ *string `gorm:"column:scopes;type:character varying(255);comment:权限范围" json:"scopes"` // 权限范围
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
User *User `gorm:"foreignKey:UserID" json:"user"`
Admin *Admin `gorm:"foreignKey:UserID" json:"admin"`
Client *Client `gorm:"belongsTo:ID;foreignKey:ClientID" json:"client"`
}
// TableName Session's table name

View File

@@ -15,26 +15,26 @@ const TableNameTrade = "trade"
// Trade mapped from table <trade>
type Trade struct {
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:订单ID" json:"id"` // 订单ID
UserID int32 `gorm:"column:user_id;type:integer;not null;comment:用户ID" json:"user_id"` // 用户ID
InnerNo string `gorm:"column:inner_no;type:character varying(255);not null;comment:内部订单号" json:"inner_no"` // 内部订单号
OuterNo *string `gorm:"column:outer_no;type:character varying(255);comment:外部订单号" json:"outer_no"` // 外部订单号
Type int32 `gorm:"column:type;type:integer;not null;comment:订单类型1-购买产品2-充值余额" json:"type"` // 订单类型1-购买产品2-充值余额
Subject string `gorm:"column:subject;type:character varying(255);not null;comment:订单主题" json:"subject"` // 订单主题
Remark *string `gorm:"column:remark;type:character varying(255);comment:订单备注" json:"remark"` // 订单备注
Amount decimal.Decimal `gorm:"column:amount;type:numeric(12,2);not null;comment:订单总金额" json:"amount"` // 订单总金额
Payment decimal.Decimal `gorm:"column:payment;type:numeric(12,2);not null;comment:支付金额" json:"payment"` // 支付金额
Method int32 `gorm:"column:method;type:integer;not null;comment:支付方式1-支付宝2-微信3-商福通渠道支付宝,4-商福通渠道微信" json:"method"` // 支付方式1-支付宝2-微信3-商福通渠道支付宝,4-商福通渠道微信
Status int32 `gorm:"column:status;type:integer;not null;comment:订单状态0-待支付1-已支付2-已取消" json:"status"` // 订单状态0-待支付1-已支付2-已取消
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
Acquirer *int32 `gorm:"column:acquirer;type:integer;comment:收单机构1-支付宝2-微信3-银联" json:"acquirer"` // 收单机构1-支付宝2-微信3-银联
Platform int32 `gorm:"column:platform;type:integer;not null;comment:支付平台1-电脑网站2-手机网站" json:"platform"` // 支付平台1-电脑网站2-手机网站
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:订单ID" json:"id"` // 订单ID
UserID int32 `gorm:"column:user_id;type:integer;not null;comment:用户ID" json:"user_id"` // 用户ID
InnerNo string `gorm:"column:inner_no;type:character varying(255);not null;comment:内部订单号" json:"inner_no"` // 内部订单号
OuterNo *string `gorm:"column:outer_no;type:character varying(255);comment:外部订单号" json:"outer_no"` // 外部订单号
Type int32 `gorm:"column:type;type:integer;not null;comment:订单类型1-购买产品2-充值余额" json:"type"` // 订单类型1-购买产品2-充值余额
Subject string `gorm:"column:subject;type:character varying(255);not null;comment:订单主题" json:"subject"` // 订单主题
Remark *string `gorm:"column:remark;type:character varying(255);comment:订单备注" json:"remark"` // 订单备注
Amount decimal.Decimal `gorm:"column:amount;type:numeric(12,2);not null;comment:订单总金额" json:"amount"` // 订单总金额
Payment decimal.Decimal `gorm:"column:payment;type:numeric(12,2);not null;comment:实际支付金额" json:"payment"` // 实际支付金额
Method int32 `gorm:"column:method;type:integer;not null;comment:支付方式1-支付宝2-微信3-商福通4-商福通渠道支付宝,5-商福通渠道微信" json:"method"` // 支付方式1-支付宝2-微信3-商福通4-商福通渠道支付宝,5-商福通渠道微信
Platform int32 `gorm:"column:platform;type:integer;not null;comment:支付平台1-电脑网站2-手机网站" json:"platform"` // 支付平台1-电脑网站2-手机网站
Acquirer *int32 `gorm:"column:acquirer;type:integer;comment:收单机构1-支付宝2-微信3-银联" json:"acquirer"` // 收单机构1-支付宝2-微信3-银联
Status int32 `gorm:"column:status;type:integer;not null;comment:订单状态0-待支付1-已支付2-已取消" json:"status"` // 订单状态0-待支付1-已支付2-已取消
Refunded bool `gorm:"column:refunded;type:boolean;not null" json:"refunded"`
PaymentURL *string `gorm:"column:payment_url;type:text;comment:支付链接" json:"payment_url"` // 支付链接
CompletedAt *orm.LocalDateTime `gorm:"column:completed_at;type:timestamp without time zone;comment:支付时间" json:"completed_at"` // 支付时间
CanceledAt *orm.LocalDateTime `gorm:"column:canceled_at;type:timestamp without time zone;comment:取消时间" json:"canceled_at"` // 取消时间
Refunded bool `gorm:"column:refunded;type:boolean;not null" json:"refunded"`
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
}
// TableName Trade's table name

View File

@@ -30,10 +30,14 @@ func newChannel(db *gorm.DB, opts ...gen.DOOption) channel {
_channel.ID = field.NewInt32(tableName, "id")
_channel.UserID = field.NewInt32(tableName, "user_id")
_channel.ProxyID = field.NewInt32(tableName, "proxy_id")
_channel.EdgeID = field.NewInt32(tableName, "edge_id")
_channel.ResourceID = field.NewInt32(tableName, "resource_id")
_channel.ProxyHost = field.NewString(tableName, "proxy_host")
_channel.ProxyPort = field.NewInt32(tableName, "proxy_port")
_channel.EdgeHost = field.NewString(tableName, "edge_host")
_channel.Protocol = field.NewInt32(tableName, "protocol")
_channel.AuthIP = field.NewBool(tableName, "auth_ip")
_channel.Whitelists = field.NewString(tableName, "whitelists")
_channel.AuthPass = field.NewBool(tableName, "auth_pass")
_channel.Username = field.NewString(tableName, "username")
_channel.Password = field.NewString(tableName, "password")
@@ -41,10 +45,6 @@ func newChannel(db *gorm.DB, opts ...gen.DOOption) channel {
_channel.CreatedAt = field.NewField(tableName, "created_at")
_channel.UpdatedAt = field.NewField(tableName, "updated_at")
_channel.DeletedAt = field.NewField(tableName, "deleted_at")
_channel.EdgeHost = field.NewString(tableName, "edge_host")
_channel.EdgeID = field.NewInt32(tableName, "edge_id")
_channel.Whitelists = field.NewString(tableName, "whitelists")
_channel.ResourceID = field.NewInt32(tableName, "resource_id")
_channel.fillFieldMap()
@@ -58,10 +58,14 @@ type channel struct {
ID field.Int32 // 通道ID
UserID field.Int32 // 用户ID
ProxyID field.Int32 // 代理ID
EdgeID field.Int32 // 节点ID
ResourceID field.Int32 // 套餐ID
ProxyHost field.String // 代理地址
ProxyPort field.Int32 // 转发端口
EdgeHost field.String // 节点地址
Protocol field.Int32 // 协议类型1-http2-https3-socks5
AuthIP field.Bool // IP认证
Whitelists field.String // IP白名单逗号分隔
AuthPass field.Bool // 密码认证
Username field.String // 用户名
Password field.String // 密码
@@ -69,10 +73,6 @@ type channel struct {
CreatedAt field.Field // 创建时间
UpdatedAt field.Field // 更新时间
DeletedAt field.Field // 删除时间
EdgeHost field.String // 节点地址
EdgeID field.Int32 // 节点ID
Whitelists field.String // IP白名单逗号分隔
ResourceID field.Int32 // 套餐ID
fieldMap map[string]field.Expr
}
@@ -92,10 +92,14 @@ func (c *channel) updateTableName(table string) *channel {
c.ID = field.NewInt32(table, "id")
c.UserID = field.NewInt32(table, "user_id")
c.ProxyID = field.NewInt32(table, "proxy_id")
c.EdgeID = field.NewInt32(table, "edge_id")
c.ResourceID = field.NewInt32(table, "resource_id")
c.ProxyHost = field.NewString(table, "proxy_host")
c.ProxyPort = field.NewInt32(table, "proxy_port")
c.EdgeHost = field.NewString(table, "edge_host")
c.Protocol = field.NewInt32(table, "protocol")
c.AuthIP = field.NewBool(table, "auth_ip")
c.Whitelists = field.NewString(table, "whitelists")
c.AuthPass = field.NewBool(table, "auth_pass")
c.Username = field.NewString(table, "username")
c.Password = field.NewString(table, "password")
@@ -103,10 +107,6 @@ func (c *channel) updateTableName(table string) *channel {
c.CreatedAt = field.NewField(table, "created_at")
c.UpdatedAt = field.NewField(table, "updated_at")
c.DeletedAt = field.NewField(table, "deleted_at")
c.EdgeHost = field.NewString(table, "edge_host")
c.EdgeID = field.NewInt32(table, "edge_id")
c.Whitelists = field.NewString(table, "whitelists")
c.ResourceID = field.NewInt32(table, "resource_id")
c.fillFieldMap()
@@ -127,10 +127,14 @@ func (c *channel) fillFieldMap() {
c.fieldMap["id"] = c.ID
c.fieldMap["user_id"] = c.UserID
c.fieldMap["proxy_id"] = c.ProxyID
c.fieldMap["edge_id"] = c.EdgeID
c.fieldMap["resource_id"] = c.ResourceID
c.fieldMap["proxy_host"] = c.ProxyHost
c.fieldMap["proxy_port"] = c.ProxyPort
c.fieldMap["edge_host"] = c.EdgeHost
c.fieldMap["protocol"] = c.Protocol
c.fieldMap["auth_ip"] = c.AuthIP
c.fieldMap["whitelists"] = c.Whitelists
c.fieldMap["auth_pass"] = c.AuthPass
c.fieldMap["username"] = c.Username
c.fieldMap["password"] = c.Password
@@ -138,10 +142,6 @@ func (c *channel) fillFieldMap() {
c.fieldMap["created_at"] = c.CreatedAt
c.fieldMap["updated_at"] = c.UpdatedAt
c.fieldMap["deleted_at"] = c.DeletedAt
c.fieldMap["edge_host"] = c.EdgeHost
c.fieldMap["edge_id"] = c.EdgeID
c.fieldMap["whitelists"] = c.Whitelists
c.fieldMap["resource_id"] = c.ResourceID
}
func (c channel) clone(db *gorm.DB) channel {

View File

@@ -31,14 +31,11 @@ func newClient(db *gorm.DB, opts ...gen.DOOption) client {
_client.ClientID = field.NewString(tableName, "client_id")
_client.ClientSecret = field.NewString(tableName, "client_secret")
_client.RedirectURI = field.NewString(tableName, "redirect_uri")
_client.GrantCode = field.NewBool(tableName, "grant_code")
_client.GrantClient = field.NewBool(tableName, "grant_client")
_client.GrantRefresh = field.NewBool(tableName, "grant_refresh")
_client.GrantPassword = field.NewBool(tableName, "grant_password")
_client.Spec = field.NewInt32(tableName, "spec")
_client.Name = field.NewString(tableName, "name")
_client.Icon = field.NewString(tableName, "icon")
_client.Status = field.NewInt32(tableName, "status")
_client.Type = field.NewInt32(tableName, "type")
_client.CreatedAt = field.NewField(tableName, "created_at")
_client.UpdatedAt = field.NewField(tableName, "updated_at")
_client.DeletedAt = field.NewField(tableName, "deleted_at")
@@ -51,22 +48,19 @@ func newClient(db *gorm.DB, opts ...gen.DOOption) client {
type client struct {
clientDo
ALL field.Asterisk
ID field.Int32 // 客户端ID
ClientID field.String // OAuth2客户端标识符
ClientSecret field.String // OAuth2客户端密钥
RedirectURI field.String // OAuth2 重定向URI
GrantCode field.Bool // 允许授权码授予
GrantClient field.Bool // 允许客户端凭证授予
GrantRefresh field.Bool // 允许刷新令牌授予
GrantPassword field.Bool // 允许密码授予
Spec field.Int32 // 安全规范1-native2-browser3-web4-trusted
Name field.String // 名称
Icon field.String // 图标URL
Status field.Int32 // 状态0-禁用1-正常
CreatedAt field.Field // 创建时间
UpdatedAt field.Field // 更新时间
DeletedAt field.Field // 删除时间
ALL field.Asterisk
ID field.Int32 // 客户端ID
ClientID field.String // OAuth2客户端标识符
ClientSecret field.String // OAuth2客户端密钥
RedirectURI field.String // OAuth2 重定向URI
Spec field.Int32 // 安全规范1-native2-browser3-web4-api
Name field.String // 名称
Icon field.String // 图标URL
Status field.Int32 // 状态0-禁用1-正常
Type field.Int32 // 类型0-普通1-官方
CreatedAt field.Field // 创建时间
UpdatedAt field.Field // 更新时间
DeletedAt field.Field // 删除时间
fieldMap map[string]field.Expr
}
@@ -87,14 +81,11 @@ func (c *client) updateTableName(table string) *client {
c.ClientID = field.NewString(table, "client_id")
c.ClientSecret = field.NewString(table, "client_secret")
c.RedirectURI = field.NewString(table, "redirect_uri")
c.GrantCode = field.NewBool(table, "grant_code")
c.GrantClient = field.NewBool(table, "grant_client")
c.GrantRefresh = field.NewBool(table, "grant_refresh")
c.GrantPassword = field.NewBool(table, "grant_password")
c.Spec = field.NewInt32(table, "spec")
c.Name = field.NewString(table, "name")
c.Icon = field.NewString(table, "icon")
c.Status = field.NewInt32(table, "status")
c.Type = field.NewInt32(table, "type")
c.CreatedAt = field.NewField(table, "created_at")
c.UpdatedAt = field.NewField(table, "updated_at")
c.DeletedAt = field.NewField(table, "deleted_at")
@@ -114,19 +105,16 @@ func (c *client) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
}
func (c *client) fillFieldMap() {
c.fieldMap = make(map[string]field.Expr, 15)
c.fieldMap = make(map[string]field.Expr, 12)
c.fieldMap["id"] = c.ID
c.fieldMap["client_id"] = c.ClientID
c.fieldMap["client_secret"] = c.ClientSecret
c.fieldMap["redirect_uri"] = c.RedirectURI
c.fieldMap["grant_code"] = c.GrantCode
c.fieldMap["grant_client"] = c.GrantClient
c.fieldMap["grant_refresh"] = c.GrantRefresh
c.fieldMap["grant_password"] = c.GrantPassword
c.fieldMap["spec"] = c.Spec
c.fieldMap["name"] = c.Name
c.fieldMap["icon"] = c.Icon
c.fieldMap["status"] = c.Status
c.fieldMap["type"] = c.Type
c.fieldMap["created_at"] = c.CreatedAt
c.fieldMap["updated_at"] = c.UpdatedAt
c.fieldMap["deleted_at"] = c.DeletedAt

View File

@@ -29,12 +29,12 @@ func newLogsLogin(db *gorm.DB, opts ...gen.DOOption) logsLogin {
_logsLogin.ALL = field.NewAsterisk(tableName)
_logsLogin.ID = field.NewInt32(tableName, "id")
_logsLogin.IP = field.NewString(tableName, "ip")
_logsLogin.Ua = field.NewString(tableName, "ua")
_logsLogin.UA = field.NewString(tableName, "ua")
_logsLogin.GrantType = field.NewString(tableName, "grant_type")
_logsLogin.PasswordGrantType = field.NewString(tableName, "password_grant_type")
_logsLogin.Success = field.NewBool(tableName, "success")
_logsLogin.Time = field.NewField(tableName, "time")
_logsLogin.UserID = field.NewInt32(tableName, "user_id")
_logsLogin.Time = field.NewField(tableName, "time")
_logsLogin.fillFieldMap()
@@ -47,12 +47,12 @@ type logsLogin struct {
ALL field.Asterisk
ID field.Int32 // 登录日志ID
IP field.String // IP地址
Ua field.String // 用户代理
UA field.String // 用户代理
GrantType field.String // 授权类型authorization_code-授权码模式client_credentials-客户端凭证模式refresh_token-刷新令牌模式password-密码模式
PasswordGrantType field.String // 密码模式子授权类型password-账号密码phone_code-手机验证码email_code-邮箱验证码
Success field.Bool // 登录是否成功
Time field.Field // 登录时间
UserID field.Int32 // 用户ID
Time field.Field // 登录时间
fieldMap map[string]field.Expr
}
@@ -71,12 +71,12 @@ func (l *logsLogin) updateTableName(table string) *logsLogin {
l.ALL = field.NewAsterisk(table)
l.ID = field.NewInt32(table, "id")
l.IP = field.NewString(table, "ip")
l.Ua = field.NewString(table, "ua")
l.UA = field.NewString(table, "ua")
l.GrantType = field.NewString(table, "grant_type")
l.PasswordGrantType = field.NewString(table, "password_grant_type")
l.Success = field.NewBool(table, "success")
l.Time = field.NewField(table, "time")
l.UserID = field.NewInt32(table, "user_id")
l.Time = field.NewField(table, "time")
l.fillFieldMap()
@@ -96,12 +96,12 @@ func (l *logsLogin) fillFieldMap() {
l.fieldMap = make(map[string]field.Expr, 8)
l.fieldMap["id"] = l.ID
l.fieldMap["ip"] = l.IP
l.fieldMap["ua"] = l.Ua
l.fieldMap["ua"] = l.UA
l.fieldMap["grant_type"] = l.GrantType
l.fieldMap["password_grant_type"] = l.PasswordGrantType
l.fieldMap["success"] = l.Success
l.fieldMap["time"] = l.Time
l.fieldMap["user_id"] = l.UserID
l.fieldMap["time"] = l.Time
}
func (l logsLogin) clone(db *gorm.DB) logsLogin {

View File

@@ -28,16 +28,16 @@ func newLogsRequest(db *gorm.DB, opts ...gen.DOOption) logsRequest {
tableName := _logsRequest.logsRequestDo.TableName()
_logsRequest.ALL = field.NewAsterisk(tableName)
_logsRequest.ID = field.NewInt32(tableName, "id")
_logsRequest.Identity = field.NewInt32(tableName, "identity")
_logsRequest.Visitor = field.NewInt32(tableName, "visitor")
_logsRequest.IP = field.NewString(tableName, "ip")
_logsRequest.Ua = field.NewString(tableName, "ua")
_logsRequest.UA = field.NewString(tableName, "ua")
_logsRequest.UserID = field.NewInt32(tableName, "user_id")
_logsRequest.ClientID = field.NewInt32(tableName, "client_id")
_logsRequest.Method = field.NewString(tableName, "method")
_logsRequest.Path = field.NewString(tableName, "path")
_logsRequest.Latency = field.NewString(tableName, "latency")
_logsRequest.Status = field.NewInt32(tableName, "status")
_logsRequest.Error = field.NewString(tableName, "error")
_logsRequest.Time = field.NewField(tableName, "time")
_logsRequest.Latency = field.NewString(tableName, "latency")
_logsRequest.fillFieldMap()
@@ -49,16 +49,16 @@ type logsRequest struct {
ALL field.Asterisk
ID field.Int32 // 访问日志ID
Identity field.Int32 // 访客身份0-游客1-用户2-管理员3-公共服务4-安全服务5-内部服务
Visitor field.Int32 // 访客ID
IP field.String // IP地址
Ua field.String // 用户代理
UA field.String // 用户代理
UserID field.Int32 // 用户ID
ClientID field.Int32 // 客户端ID
Method field.String // 请求方法
Path field.String // 请求路径
Latency field.String // 请求延迟
Status field.Int32 // 响应状态码
Error field.String // 错误信息
Time field.Field // 请求时间
Latency field.String // 请求延迟
fieldMap map[string]field.Expr
}
@@ -76,16 +76,16 @@ func (l logsRequest) As(alias string) *logsRequest {
func (l *logsRequest) updateTableName(table string) *logsRequest {
l.ALL = field.NewAsterisk(table)
l.ID = field.NewInt32(table, "id")
l.Identity = field.NewInt32(table, "identity")
l.Visitor = field.NewInt32(table, "visitor")
l.IP = field.NewString(table, "ip")
l.Ua = field.NewString(table, "ua")
l.UA = field.NewString(table, "ua")
l.UserID = field.NewInt32(table, "user_id")
l.ClientID = field.NewInt32(table, "client_id")
l.Method = field.NewString(table, "method")
l.Path = field.NewString(table, "path")
l.Latency = field.NewString(table, "latency")
l.Status = field.NewInt32(table, "status")
l.Error = field.NewString(table, "error")
l.Time = field.NewField(table, "time")
l.Latency = field.NewString(table, "latency")
l.fillFieldMap()
@@ -104,16 +104,16 @@ func (l *logsRequest) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
func (l *logsRequest) fillFieldMap() {
l.fieldMap = make(map[string]field.Expr, 11)
l.fieldMap["id"] = l.ID
l.fieldMap["identity"] = l.Identity
l.fieldMap["visitor"] = l.Visitor
l.fieldMap["ip"] = l.IP
l.fieldMap["ua"] = l.Ua
l.fieldMap["ua"] = l.UA
l.fieldMap["user_id"] = l.UserID
l.fieldMap["client_id"] = l.ClientID
l.fieldMap["method"] = l.Method
l.fieldMap["path"] = l.Path
l.fieldMap["latency"] = l.Latency
l.fieldMap["status"] = l.Status
l.fieldMap["error"] = l.Error
l.fieldMap["time"] = l.Time
l.fieldMap["latency"] = l.Latency
}
func (l logsRequest) clone(db *gorm.DB) logsRequest {

View File

@@ -31,12 +31,12 @@ func newProxy(db *gorm.DB, opts ...gen.DOOption) proxy {
_proxy.Version = field.NewInt32(tableName, "version")
_proxy.Name = field.NewString(tableName, "name")
_proxy.Host = field.NewString(tableName, "host")
_proxy.Type = field.NewInt32(tableName, "type")
_proxy.Secret = field.NewString(tableName, "secret")
_proxy.Type = field.NewInt32(tableName, "type")
_proxy.Status = field.NewInt32(tableName, "status")
_proxy.CreatedAt = field.NewField(tableName, "created_at")
_proxy.UpdatedAt = field.NewField(tableName, "updated_at")
_proxy.DeletedAt = field.NewField(tableName, "deleted_at")
_proxy.Status = field.NewInt32(tableName, "status")
_proxy.Edges = proxyHasManyEdges{
db: db.Session(&gorm.Session{}),
@@ -56,12 +56,12 @@ type proxy struct {
Version field.Int32 // 代理服务版本
Name field.String // 代理服务名称
Host field.String // 代理服务地址
Type field.Int32 // 代理服务类型1-三方2-自有
Secret field.String // 代理服务密钥
Type field.Int32 // 代理服务类型1-三方2-自有
Status field.Int32 // 代理服务状态0-离线1-在线
CreatedAt field.Field // 创建时间
UpdatedAt field.Field // 更新时间
DeletedAt field.Field // 删除时间
Status field.Int32 // 代理服务状态0-离线1-在线
Edges proxyHasManyEdges
fieldMap map[string]field.Expr
@@ -83,12 +83,12 @@ func (p *proxy) updateTableName(table string) *proxy {
p.Version = field.NewInt32(table, "version")
p.Name = field.NewString(table, "name")
p.Host = field.NewString(table, "host")
p.Type = field.NewInt32(table, "type")
p.Secret = field.NewString(table, "secret")
p.Type = field.NewInt32(table, "type")
p.Status = field.NewInt32(table, "status")
p.CreatedAt = field.NewField(table, "created_at")
p.UpdatedAt = field.NewField(table, "updated_at")
p.DeletedAt = field.NewField(table, "deleted_at")
p.Status = field.NewInt32(table, "status")
p.fillFieldMap()
@@ -110,12 +110,12 @@ func (p *proxy) fillFieldMap() {
p.fieldMap["version"] = p.Version
p.fieldMap["name"] = p.Name
p.fieldMap["host"] = p.Host
p.fieldMap["type"] = p.Type
p.fieldMap["secret"] = p.Secret
p.fieldMap["type"] = p.Type
p.fieldMap["status"] = p.Status
p.fieldMap["created_at"] = p.CreatedAt
p.fieldMap["updated_at"] = p.UpdatedAt
p.fieldMap["deleted_at"] = p.DeletedAt
p.fieldMap["status"] = p.Status
}

View File

@@ -29,10 +29,10 @@ func newSession(db *gorm.DB, opts ...gen.DOOption) session {
_session.ALL = field.NewAsterisk(tableName)
_session.ID = field.NewInt32(tableName, "id")
_session.UserID = field.NewInt32(tableName, "user_id")
_session.AdminID = field.NewInt32(tableName, "admin_id")
_session.ClientID = field.NewInt32(tableName, "client_id")
_session.IP = field.NewString(tableName, "ip")
_session.Ua = field.NewString(tableName, "ua")
_session.GrantType = field.NewString(tableName, "grant_type")
_session.UA = field.NewString(tableName, "ua")
_session.AccessToken = field.NewString(tableName, "access_token")
_session.AccessTokenExpires = field.NewField(tableName, "access_token_expires")
_session.RefreshToken = field.NewString(tableName, "refresh_token")
@@ -41,6 +41,23 @@ func newSession(db *gorm.DB, opts ...gen.DOOption) session {
_session.CreatedAt = field.NewField(tableName, "created_at")
_session.UpdatedAt = field.NewField(tableName, "updated_at")
_session.DeletedAt = field.NewField(tableName, "deleted_at")
_session.User = sessionBelongsToUser{
db: db.Session(&gorm.Session{}),
RelationField: field.NewRelation("User", "models.User"),
}
_session.Admin = sessionBelongsToAdmin{
db: db.Session(&gorm.Session{}),
RelationField: field.NewRelation("Admin", "models.Admin"),
}
_session.Client = sessionBelongsToClient{
db: db.Session(&gorm.Session{}),
RelationField: field.NewRelation("Client", "models.Client"),
}
_session.fillFieldMap()
@@ -53,10 +70,10 @@ type session struct {
ALL field.Asterisk
ID field.Int32 // 会话ID
UserID field.Int32 // 用户ID
AdminID field.Int32 // 管理员ID
ClientID field.Int32 // 客户端ID
IP field.String // IP地址
Ua field.String // 用户代理
GrantType field.String // 授权类型authorization_code-授权码模式client_credentials-客户端凭证模式refresh_token-刷新令牌模式password-密码模式
UA field.String // 用户代理
AccessToken field.String // 访问令牌
AccessTokenExpires field.Field // 访问令牌过期时间
RefreshToken field.String // 刷新令牌
@@ -65,6 +82,11 @@ type session struct {
CreatedAt field.Field // 创建时间
UpdatedAt field.Field // 更新时间
DeletedAt field.Field // 删除时间
User sessionBelongsToUser
Admin sessionBelongsToAdmin
Client sessionBelongsToClient
fieldMap map[string]field.Expr
}
@@ -83,10 +105,10 @@ func (s *session) updateTableName(table string) *session {
s.ALL = field.NewAsterisk(table)
s.ID = field.NewInt32(table, "id")
s.UserID = field.NewInt32(table, "user_id")
s.AdminID = field.NewInt32(table, "admin_id")
s.ClientID = field.NewInt32(table, "client_id")
s.IP = field.NewString(table, "ip")
s.Ua = field.NewString(table, "ua")
s.GrantType = field.NewString(table, "grant_type")
s.UA = field.NewString(table, "ua")
s.AccessToken = field.NewString(table, "access_token")
s.AccessTokenExpires = field.NewField(table, "access_token_expires")
s.RefreshToken = field.NewString(table, "refresh_token")
@@ -111,13 +133,13 @@ func (s *session) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
}
func (s *session) fillFieldMap() {
s.fieldMap = make(map[string]field.Expr, 14)
s.fieldMap = make(map[string]field.Expr, 17)
s.fieldMap["id"] = s.ID
s.fieldMap["user_id"] = s.UserID
s.fieldMap["admin_id"] = s.AdminID
s.fieldMap["client_id"] = s.ClientID
s.fieldMap["ip"] = s.IP
s.fieldMap["ua"] = s.Ua
s.fieldMap["grant_type"] = s.GrantType
s.fieldMap["ua"] = s.UA
s.fieldMap["access_token"] = s.AccessToken
s.fieldMap["access_token_expires"] = s.AccessTokenExpires
s.fieldMap["refresh_token"] = s.RefreshToken
@@ -126,18 +148,271 @@ func (s *session) fillFieldMap() {
s.fieldMap["created_at"] = s.CreatedAt
s.fieldMap["updated_at"] = s.UpdatedAt
s.fieldMap["deleted_at"] = s.DeletedAt
}
func (s session) clone(db *gorm.DB) session {
s.sessionDo.ReplaceConnPool(db.Statement.ConnPool)
s.User.db = db.Session(&gorm.Session{Initialized: true})
s.User.db.Statement.ConnPool = db.Statement.ConnPool
s.Admin.db = db.Session(&gorm.Session{Initialized: true})
s.Admin.db.Statement.ConnPool = db.Statement.ConnPool
s.Client.db = db.Session(&gorm.Session{Initialized: true})
s.Client.db.Statement.ConnPool = db.Statement.ConnPool
return s
}
func (s session) replaceDB(db *gorm.DB) session {
s.sessionDo.ReplaceDB(db)
s.User.db = db.Session(&gorm.Session{})
s.Admin.db = db.Session(&gorm.Session{})
s.Client.db = db.Session(&gorm.Session{})
return s
}
type sessionBelongsToUser struct {
db *gorm.DB
field.RelationField
}
func (a sessionBelongsToUser) Where(conds ...field.Expr) *sessionBelongsToUser {
if len(conds) == 0 {
return &a
}
exprs := make([]clause.Expression, 0, len(conds))
for _, cond := range conds {
exprs = append(exprs, cond.BeCond().(clause.Expression))
}
a.db = a.db.Clauses(clause.Where{Exprs: exprs})
return &a
}
func (a sessionBelongsToUser) WithContext(ctx context.Context) *sessionBelongsToUser {
a.db = a.db.WithContext(ctx)
return &a
}
func (a sessionBelongsToUser) Session(session *gorm.Session) *sessionBelongsToUser {
a.db = a.db.Session(session)
return &a
}
func (a sessionBelongsToUser) Model(m *models.Session) *sessionBelongsToUserTx {
return &sessionBelongsToUserTx{a.db.Model(m).Association(a.Name())}
}
func (a sessionBelongsToUser) Unscoped() *sessionBelongsToUser {
a.db = a.db.Unscoped()
return &a
}
type sessionBelongsToUserTx struct{ tx *gorm.Association }
func (a sessionBelongsToUserTx) Find() (result *models.User, err error) {
return result, a.tx.Find(&result)
}
func (a sessionBelongsToUserTx) Append(values ...*models.User) (err error) {
targetValues := make([]interface{}, len(values))
for i, v := range values {
targetValues[i] = v
}
return a.tx.Append(targetValues...)
}
func (a sessionBelongsToUserTx) Replace(values ...*models.User) (err error) {
targetValues := make([]interface{}, len(values))
for i, v := range values {
targetValues[i] = v
}
return a.tx.Replace(targetValues...)
}
func (a sessionBelongsToUserTx) Delete(values ...*models.User) (err error) {
targetValues := make([]interface{}, len(values))
for i, v := range values {
targetValues[i] = v
}
return a.tx.Delete(targetValues...)
}
func (a sessionBelongsToUserTx) Clear() error {
return a.tx.Clear()
}
func (a sessionBelongsToUserTx) Count() int64 {
return a.tx.Count()
}
func (a sessionBelongsToUserTx) Unscoped() *sessionBelongsToUserTx {
a.tx = a.tx.Unscoped()
return &a
}
type sessionBelongsToAdmin struct {
db *gorm.DB
field.RelationField
}
func (a sessionBelongsToAdmin) Where(conds ...field.Expr) *sessionBelongsToAdmin {
if len(conds) == 0 {
return &a
}
exprs := make([]clause.Expression, 0, len(conds))
for _, cond := range conds {
exprs = append(exprs, cond.BeCond().(clause.Expression))
}
a.db = a.db.Clauses(clause.Where{Exprs: exprs})
return &a
}
func (a sessionBelongsToAdmin) WithContext(ctx context.Context) *sessionBelongsToAdmin {
a.db = a.db.WithContext(ctx)
return &a
}
func (a sessionBelongsToAdmin) Session(session *gorm.Session) *sessionBelongsToAdmin {
a.db = a.db.Session(session)
return &a
}
func (a sessionBelongsToAdmin) Model(m *models.Session) *sessionBelongsToAdminTx {
return &sessionBelongsToAdminTx{a.db.Model(m).Association(a.Name())}
}
func (a sessionBelongsToAdmin) Unscoped() *sessionBelongsToAdmin {
a.db = a.db.Unscoped()
return &a
}
type sessionBelongsToAdminTx struct{ tx *gorm.Association }
func (a sessionBelongsToAdminTx) Find() (result *models.Admin, err error) {
return result, a.tx.Find(&result)
}
func (a sessionBelongsToAdminTx) Append(values ...*models.Admin) (err error) {
targetValues := make([]interface{}, len(values))
for i, v := range values {
targetValues[i] = v
}
return a.tx.Append(targetValues...)
}
func (a sessionBelongsToAdminTx) Replace(values ...*models.Admin) (err error) {
targetValues := make([]interface{}, len(values))
for i, v := range values {
targetValues[i] = v
}
return a.tx.Replace(targetValues...)
}
func (a sessionBelongsToAdminTx) Delete(values ...*models.Admin) (err error) {
targetValues := make([]interface{}, len(values))
for i, v := range values {
targetValues[i] = v
}
return a.tx.Delete(targetValues...)
}
func (a sessionBelongsToAdminTx) Clear() error {
return a.tx.Clear()
}
func (a sessionBelongsToAdminTx) Count() int64 {
return a.tx.Count()
}
func (a sessionBelongsToAdminTx) Unscoped() *sessionBelongsToAdminTx {
a.tx = a.tx.Unscoped()
return &a
}
type sessionBelongsToClient struct {
db *gorm.DB
field.RelationField
}
func (a sessionBelongsToClient) Where(conds ...field.Expr) *sessionBelongsToClient {
if len(conds) == 0 {
return &a
}
exprs := make([]clause.Expression, 0, len(conds))
for _, cond := range conds {
exprs = append(exprs, cond.BeCond().(clause.Expression))
}
a.db = a.db.Clauses(clause.Where{Exprs: exprs})
return &a
}
func (a sessionBelongsToClient) WithContext(ctx context.Context) *sessionBelongsToClient {
a.db = a.db.WithContext(ctx)
return &a
}
func (a sessionBelongsToClient) Session(session *gorm.Session) *sessionBelongsToClient {
a.db = a.db.Session(session)
return &a
}
func (a sessionBelongsToClient) Model(m *models.Session) *sessionBelongsToClientTx {
return &sessionBelongsToClientTx{a.db.Model(m).Association(a.Name())}
}
func (a sessionBelongsToClient) Unscoped() *sessionBelongsToClient {
a.db = a.db.Unscoped()
return &a
}
type sessionBelongsToClientTx struct{ tx *gorm.Association }
func (a sessionBelongsToClientTx) Find() (result *models.Client, err error) {
return result, a.tx.Find(&result)
}
func (a sessionBelongsToClientTx) Append(values ...*models.Client) (err error) {
targetValues := make([]interface{}, len(values))
for i, v := range values {
targetValues[i] = v
}
return a.tx.Append(targetValues...)
}
func (a sessionBelongsToClientTx) Replace(values ...*models.Client) (err error) {
targetValues := make([]interface{}, len(values))
for i, v := range values {
targetValues[i] = v
}
return a.tx.Replace(targetValues...)
}
func (a sessionBelongsToClientTx) Delete(values ...*models.Client) (err error) {
targetValues := make([]interface{}, len(values))
for i, v := range values {
targetValues[i] = v
}
return a.tx.Delete(targetValues...)
}
func (a sessionBelongsToClientTx) Clear() error {
return a.tx.Clear()
}
func (a sessionBelongsToClientTx) Count() int64 {
return a.tx.Count()
}
func (a sessionBelongsToClientTx) Unscoped() *sessionBelongsToClientTx {
a.tx = a.tx.Unscoped()
return &a
}
type sessionDo struct{ gen.DO }
func (s sessionDo) Debug() *sessionDo {

View File

@@ -37,16 +37,16 @@ func newTrade(db *gorm.DB, opts ...gen.DOOption) trade {
_trade.Amount = field.NewField(tableName, "amount")
_trade.Payment = field.NewField(tableName, "payment")
_trade.Method = field.NewInt32(tableName, "method")
_trade.Status = field.NewInt32(tableName, "status")
_trade.CreatedAt = field.NewField(tableName, "created_at")
_trade.UpdatedAt = field.NewField(tableName, "updated_at")
_trade.DeletedAt = field.NewField(tableName, "deleted_at")
_trade.Acquirer = field.NewInt32(tableName, "acquirer")
_trade.Platform = field.NewInt32(tableName, "platform")
_trade.Acquirer = field.NewInt32(tableName, "acquirer")
_trade.Status = field.NewInt32(tableName, "status")
_trade.Refunded = field.NewBool(tableName, "refunded")
_trade.PaymentURL = field.NewString(tableName, "payment_url")
_trade.CompletedAt = field.NewField(tableName, "completed_at")
_trade.CanceledAt = field.NewField(tableName, "canceled_at")
_trade.Refunded = field.NewBool(tableName, "refunded")
_trade.CreatedAt = field.NewField(tableName, "created_at")
_trade.UpdatedAt = field.NewField(tableName, "updated_at")
_trade.DeletedAt = field.NewField(tableName, "deleted_at")
_trade.fillFieldMap()
@@ -65,18 +65,18 @@ type trade struct {
Subject field.String // 订单主题
Remark field.String // 订单备注
Amount field.Field // 订单总金额
Payment field.Field // 支付金额
Method field.Int32 // 支付方式1-支付宝2-微信3-商福通渠道支付宝,4-商福通渠道微信
Status field.Int32 // 订单状态0-待支付1-已支付2-已取消
CreatedAt field.Field // 创建时间
UpdatedAt field.Field // 更新时间
DeletedAt field.Field // 删除时间
Acquirer field.Int32 // 收单机构1-支付宝2-微信3-银联
Payment field.Field // 实际支付金额
Method field.Int32 // 支付方式1-支付宝2-微信3-商福通4-商福通渠道支付宝,5-商福通渠道微信
Platform field.Int32 // 支付平台1-电脑网站2-手机网站
Acquirer field.Int32 // 收单机构1-支付宝2-微信3-银联
Status field.Int32 // 订单状态0-待支付1-已支付2-已取消
Refunded field.Bool
PaymentURL field.String // 支付链接
CompletedAt field.Field // 支付时间
CanceledAt field.Field // 取消时间
Refunded field.Bool
CreatedAt field.Field // 创建时间
UpdatedAt field.Field // 更新时间
DeletedAt field.Field // 删除时间
fieldMap map[string]field.Expr
}
@@ -103,16 +103,16 @@ func (t *trade) updateTableName(table string) *trade {
t.Amount = field.NewField(table, "amount")
t.Payment = field.NewField(table, "payment")
t.Method = field.NewInt32(table, "method")
t.Status = field.NewInt32(table, "status")
t.CreatedAt = field.NewField(table, "created_at")
t.UpdatedAt = field.NewField(table, "updated_at")
t.DeletedAt = field.NewField(table, "deleted_at")
t.Acquirer = field.NewInt32(table, "acquirer")
t.Platform = field.NewInt32(table, "platform")
t.Acquirer = field.NewInt32(table, "acquirer")
t.Status = field.NewInt32(table, "status")
t.Refunded = field.NewBool(table, "refunded")
t.PaymentURL = field.NewString(table, "payment_url")
t.CompletedAt = field.NewField(table, "completed_at")
t.CanceledAt = field.NewField(table, "canceled_at")
t.Refunded = field.NewBool(table, "refunded")
t.CreatedAt = field.NewField(table, "created_at")
t.UpdatedAt = field.NewField(table, "updated_at")
t.DeletedAt = field.NewField(table, "deleted_at")
t.fillFieldMap()
@@ -140,16 +140,16 @@ func (t *trade) fillFieldMap() {
t.fieldMap["amount"] = t.Amount
t.fieldMap["payment"] = t.Payment
t.fieldMap["method"] = t.Method
t.fieldMap["status"] = t.Status
t.fieldMap["created_at"] = t.CreatedAt
t.fieldMap["updated_at"] = t.UpdatedAt
t.fieldMap["deleted_at"] = t.DeletedAt
t.fieldMap["acquirer"] = t.Acquirer
t.fieldMap["platform"] = t.Platform
t.fieldMap["acquirer"] = t.Acquirer
t.fieldMap["status"] = t.Status
t.fieldMap["refunded"] = t.Refunded
t.fieldMap["payment_url"] = t.PaymentURL
t.fieldMap["completed_at"] = t.CompletedAt
t.fieldMap["canceled_at"] = t.CanceledAt
t.fieldMap["refunded"] = t.Refunded
t.fieldMap["created_at"] = t.CreatedAt
t.fieldMap["updated_at"] = t.UpdatedAt
t.fieldMap["deleted_at"] = t.DeletedAt
}
func (t trade) clone(db *gorm.DB) trade {

View File

@@ -1,7 +1,7 @@
package web
import (
"platform/web/core"
auth2 "platform/web/auth"
"platform/web/handlers"
"github.com/gofiber/fiber/v2"
@@ -12,7 +12,7 @@ func ApplyRouters(app *fiber.App) {
// 认证
auth := api.Group("/auth")
auth.Post("/token", handlers.Token)
auth.Post("/token", auth2.Token)
auth.Post("/revoke", handlers.Revoke)
auth.Post("/introspect", handlers.Introspect)
auth.Post("/verify/sms", handlers.SmsCode)
@@ -47,7 +47,6 @@ func ApplyRouters(app *fiber.App) {
channel.Post("/list", handlers.ListChannels)
channel.Post("/create", handlers.CreateChannel)
channel.Post("/remove", handlers.RemoveChannels)
channel.Post("/remove/by-task", handlers.RemoveChannelByTask)
// 交易
trade := api.Group("/trade")
@@ -75,12 +74,6 @@ func ApplyRouters(app *fiber.App) {
edge.Post("/all", handlers.AllEdgesAvailable)
// 临时
app.Get("/test", func(c *fiber.Ctx) error {
return core.NewBizErr("测试错误")
})
// 异步任务客户端
tasks := api.Group("/tasks")
tasks.Post("/channel/remove", handlers.RemoveChannelByTask)
tasks.Post("/trade/cancel", handlers.TradeCancelByTask)
debug := app.Group("/debug")
debug.Get("/sms/:phone", handlers.DebugGetSmsCode)
}

View File

@@ -1,201 +0,0 @@
package services
import (
"context"
"errors"
"golang.org/x/crypto/bcrypt"
"log/slog"
"platform/pkg/u"
auth2 "platform/web/auth"
"platform/web/core"
client2 "platform/web/domains/client"
user2 "platform/web/domains/user"
"platform/web/globals/orm"
m "platform/web/models"
q "platform/web/queries"
"time"
"gorm.io/gorm"
)
var Auth = &authService{}
type authService struct{}
// OauthAuthorizationCode 验证授权码
func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Client, code, redirectURI, codeVerifier string) (*auth2.TokenDetails, error) {
return nil, errors.New("TODO")
}
// OauthClientCredentials 验证客户端凭证
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*auth2.TokenDetails, error) {
var clientType = auth2.PayloadTypeFromClientSpec(client2.Spec(client.Spec))
var permissions = make(map[string]struct{}, len(scope))
for _, item := range scope {
permissions[item] = struct{}{}
}
// 保存会话并返回令牌
authCtx := auth2.Context{
Permissions: permissions,
Payload: auth2.Payload{
Id: client.ID,
Type: clientType,
Name: client.Name,
},
}
token, err := auth2.CreateSession(ctx, &authCtx, false)
if err != nil {
return nil, err
}
return token, nil
}
// OauthRefreshToken 验证刷新令牌
func (s *authService) OauthRefreshToken(ctx context.Context, _ *m.Client, refreshToken string, scope ...[]string) (*auth2.TokenDetails, error) {
details, err := auth2.RefreshSession(ctx, refreshToken, true)
if err != nil {
return nil, err
}
return details, nil
}
// OauthPassword 验证密码
func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *GrantPasswordData, ip, agent string) (*auth2.TokenDetails, error) {
var user *m.User
err := q.Q.Transaction(func(tx *q.Query) error {
switch data.LoginType {
case auth2.GrantPasswordPhone:
// 验证验证码
err := Verifier.VerifySms(ctx, data.Username, data.Password)
if err != nil {
if errors.Is(err, ErrVerifierServiceInvalid) {
return ErrOauthInvalidRequest
}
return err
}
// 查找用户
user, err =
tx.User.Where(tx.User.Phone.Eq(data.Username)).Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
case auth2.GrantPasswordEmail:
return core.NewServErr("邮箱登录暂不可用")
case auth2.GrantPasswordSecret:
var err error
user, err = tx.User.
Where(tx.User.Phone.Eq(data.Username)).
Or(tx.User.Email.Eq(data.Username)).
Or(tx.User.Username.Eq(data.Username)).
Take()
if err != nil {
slog.Debug("查找用户失败", "error", err)
return core.NewBizErr("用户不存在或密码错误")
}
// 账户状态
if user2.Status(user.Status) == user2.StatusDisabled {
slog.Debug("账户状态异常", "username", data.Username, "status", user.Status)
return core.NewBizErr("用户不存在或密码错误")
}
// 验证密码
if user.Password == nil || *user.Password == "" {
slog.Debug("用户未设置密码", "username", data.Username)
return core.NewBizErr("用户不存在或密码错误")
}
if bcrypt.CompareHashAndPassword([]byte(*user.Password), []byte(data.Password)) != nil {
slog.Debug("密码验证失败", "username", data.Username)
return core.NewBizErr("用户不存在或密码错误")
}
default:
return ErrOauthInvalidRequest
}
// 如果用户不存在,初始化用户 todo 初始化默认权限信息
if user == nil {
user = &m.User{
Phone: data.Username,
Username: u.P(data.Username),
}
}
// 更新用户的登录时间
user.LastLogin = u.P(orm.LocalDateTime(time.Now()))
user.LastLoginHost = u.P(ip)
user.LastLoginAgent = u.P(agent)
if err := tx.User.Save(user); err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// 保存到会话
var name = ""
if user.Name != nil {
name = *user.Name
}
authCtx := auth2.Context{
Payload: auth2.Payload{
Id: user.ID,
Type: auth2.PayloadUser,
Name: name,
Avatar: user.Avatar,
},
}
token, err := auth2.CreateSession(ctx, &authCtx, data.Remember)
if err != nil {
return nil, err
}
return token, nil
}
type GrantCodeData struct {
Code string `json:"code" form:"code"`
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
}
type GrantClientData struct {
}
type GrantRefreshData struct {
RefreshToken string `json:"refresh_token" form:"refresh_token"`
}
type GrantPasswordData struct {
LoginType auth2.PasswordGrantType `json:"login_type" form:"login_type"`
Username string `json:"username" form:"username"`
Password string `json:"password" form:"password"`
Remember bool `json:"remember" form:"remember"`
}
type AuthServiceError string
func (e AuthServiceError) Error() string {
return string(e)
}
const (
ErrOauthInvalidRequest = AuthServiceError("invalid_request")
ErrOauthInvalidClient = AuthServiceError("invalid_client")
ErrOauthInvalidGrant = AuthServiceError("invalid_grant")
ErrOauthInvalidScope = AuthServiceError("invalid_scope")
ErrOauthUnauthorizedClient = AuthServiceError("unauthorized_client")
ErrOauthUnsupportedGrantType = AuthServiceError("unsupported_grant_type")
)

View File

@@ -4,7 +4,6 @@ import (
"context"
"database/sql"
"fmt"
"github.com/gofiber/fiber/v2"
"log/slog"
"math"
"math/rand/v2"
@@ -15,15 +14,17 @@ import (
edge2 "platform/web/domains/edge"
proxy2 "platform/web/domains/proxy"
resource2 "platform/web/domains/resource"
"platform/web/events"
g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models"
q "platform/web/queries"
"platform/web/tasks"
"strconv"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/hibiken/asynq"
"gorm.io/gen/field"
@@ -296,7 +297,7 @@ func (s *channelService) CreateChannel(
ids[i] = channels[i].ID
}
_, err = g.Asynq.Enqueue(
tasks.NewRemoveChannel(ids),
events.NewRemoveChannel(ids),
asynq.ProcessIn(duration),
)
if err != nil {

View File

@@ -12,11 +12,11 @@ import (
"platform/web/core"
coupon2 "platform/web/domains/coupon"
trade2 "platform/web/domains/trade"
"platform/web/events"
g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models"
q "platform/web/queries"
"platform/web/tasks"
"time"
"github.com/shopspring/decimal"
@@ -240,7 +240,7 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, data *CreateTradeDa
}
// 提交异步关闭事件
_, err = g.Asynq.Enqueue(tasks.NewCancelTrade(tasks.CancelTradeData{
_, err = g.Asynq.Enqueue(events.NewCancelTrade(events.CancelTradeData{
TradeNo: tradeNo,
Method: method,
}))
@@ -417,7 +417,7 @@ func (s *tradeService) CancelTrade(tradeNo string, method trade2.Method, now tim
MchOrderNo: &tradeNo,
})
if err != nil {
slog.Debug(fmt.Sprintf("订单无需关闭%s", err.Error()))
slog.Debug(fmt.Sprintf("订单无需关闭: %s", err.Error()))
return nil
}
@@ -546,11 +546,11 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err
return nil, core.NewBizErr("订单不存在")
}
return nil, core.NewServErr(
fmt.Sprintf("微信上游接口异常code=%vmessage=%v", apiErr.Code, apiErr.Message),
fmt.Sprintf("微信上游接口异常: code=%vmessage=%v", apiErr.Code, apiErr.Message),
apiErr,
)
}
return nil, core.NewServErr(fmt.Sprintf("微信上游支付接口异常%s", err.Error()))
return nil, core.NewServErr(fmt.Sprintf("微信上游支付接口异常: %s", err.Error()))
}
// 填充返回值

View File

@@ -19,28 +19,6 @@ import (
var Verifier = &verifierService{}
type VerifierServiceError string
func (e VerifierServiceError) Error() string {
return string(e)
}
var (
ErrVerifierServiceInvalid = VerifierServiceError("验证码错误")
)
type VerifierServiceSendLimitErr int
func (e VerifierServiceSendLimitErr) Error() string {
return "发送频率过快"
}
type VerifierSmsPurpose int
const (
VerifierSmsPurposeLogin VerifierSmsPurpose = iota
)
type verifierService struct {
}
@@ -148,6 +126,43 @@ func (s *verifierService) VerifySms(ctx context.Context, phone, code string) err
return nil
}
func (s *verifierService) GetSms(ctx context.Context, phone string) (string, error) {
key := smsKey(phone, VerifierSmsPurposeLogin)
val, err := g.Redis.Get(ctx, key).Result()
if err != nil {
return "", fmt.Errorf("验证码获取失败: %w", err)
}
return val, nil
}
func smsKey(phone string, purpose VerifierSmsPurpose) string {
return fmt.Sprintf("verify:sms:%d:%s", purpose, phone)
}
// region 短信目的
type VerifierSmsPurpose int
const (
VerifierSmsPurposeLogin VerifierSmsPurpose = iota // 登录
)
// region 服务异常
type VerifierServiceError string
func (e VerifierServiceError) Error() string {
return string(e)
}
var (
ErrVerifierServiceInvalid = VerifierServiceError("验证码错误")
)
type VerifierServiceSendLimitErr int
func (e VerifierServiceSendLimitErr) Error() string {
return "发送频率过快"
}

41
web/tasks/task.go Normal file
View File

@@ -0,0 +1,41 @@
package tasks
import (
"context"
"encoding/json"
"fmt"
"platform/web/events"
s "platform/web/services"
"time"
"github.com/hibiken/asynq"
)
func HandleCancelTrade(_ context.Context, task *asynq.Task) (err error) {
data := new(events.CancelTradeData)
err = json.Unmarshal(task.Payload(), data)
if err != nil {
return fmt.Errorf("解析任务参数失败: %w", err)
}
err = s.Trade.CancelTrade(data.TradeNo, data.Method, time.Now())
if err != nil {
return fmt.Errorf("取消交易失败: %w", err)
}
return nil
}
func HandleRemoveChannel(_ context.Context, task *asynq.Task) (err error) {
data := make([]int32, 0)
err = json.Unmarshal(task.Payload(), &data)
if err != nil {
return fmt.Errorf("解析任务参数失败: %w", err)
}
err = s.Channel.RemoveChannels(data)
if err != nil {
return fmt.Errorf("删除通道失败: %w", err)
}
return nil
}

View File

@@ -1,166 +1,89 @@
package web
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/google/uuid"
"github.com/jxskiss/base62"
"context"
"fmt"
"log/slog"
"net/http"
_ "net/http/pprof"
"platform/web/auth"
g "platform/web/globals"
q "platform/web/queries"
"runtime"
"strings"
"time"
"platform/web/events"
base "platform/web/globals"
"platform/web/tasks"
"github.com/gofiber/fiber/v2"
"github.com/hibiken/asynq"
"golang.org/x/sync/errgroup"
)
// region web
func RunApp(pCtx context.Context) error {
g, ctx := errgroup.WithContext(pCtx)
type Config struct {
Listen string
}
type Server struct {
config *Config
fiber *fiber.App
}
func New(config *Config) (*Server, error) {
_config := config
if config == nil {
_config = &Config{}
// 初始化依赖
err := base.Init(ctx)
if err != nil {
return fmt.Errorf("初始化依赖失败: %w", err)
}
return &Server{
config: _config,
}, nil
// 运行服务
g.Go(func() error {
return RunWeb(ctx)
})
g.Go(func() error {
return RunTask(ctx)
})
return g.Wait()
}
func (s *Server) Run() error {
func RunWeb(ctx context.Context) error {
// inits
g.Init()
q.SetDefault(g.DB)
// config
s.fiber = fiber.New(fiber.Config{
fiber := fiber.New(fiber.Config{
ProxyHeader: fiber.HeaderXForwardedFor,
ErrorHandler: ErrorHandler,
})
// middlewares
s.fiber.Use(newRecover())
s.fiber.Use(newRequestId())
s.fiber.Use(newLogger())
ApplyMiddlewares(fiber)
ApplyRouters(fiber)
// routes
ApplyRouters(s.fiber)
// pprof
// 停止服务
go func() {
runtime.SetBlockProfileRate(1)
err := http.ListenAndServe(":6060", nil)
<-ctx.Done()
err := fiber.Shutdown()
if err != nil {
slog.Error("pprof 服务错误", slog.Any("err", err))
slog.Error("服务停止失败", "error", err)
}
}()
// listen
slog.Info("服务开始监听 :8080")
err := s.fiber.Listen("0.0.0.0:8080")
// 启动服务
slog.Info("web 服务开始监听 :8080")
err := fiber.Listen("0.0.0.0:8080")
if err != nil {
slog.Error("Failed to start server", slog.Any("err", err))
return fmt.Errorf("web 服务监听失败: %w", err)
}
slog.Info("服务已停止")
slog.Info("web 服务已停止")
return nil
}
func (s *Server) Stop() {
err := g.ExitRedis()
func RunTask(ctx context.Context) error {
var server = asynq.NewServerFromRedisClient(base.Redis, asynq.Config{})
var mux = asynq.NewServeMux()
mux.HandleFunc(events.RemoveChannel, tasks.HandleRemoveChannel)
mux.HandleFunc(events.CancelTrade, tasks.HandleCancelTrade)
// 停止服务
go func() {
<-ctx.Done()
server.Shutdown()
}()
// 启动服务
err := server.Run(mux)
if err != nil {
slog.Error("Failed to close Redis connection", slog.Any("err", err))
return fmt.Errorf("任务服务运行失败: %w", err)
}
err = g.ExitOrm()
if err != nil {
slog.Error("Failed to close database connection", slog.Any("err", err))
}
err = s.fiber.Shutdown()
if err != nil {
slog.Error("Failed to shutdown server", slog.Any("err", err))
}
return nil
}
// endregion
// region middlewares
func newRequestId() fiber.Handler {
return requestid.New(requestid.Config{
Generator: func() string {
binary, _ := uuid.New().MarshalBinary()
return base62.EncodeToString(binary)
},
})
}
func newLogger() fiber.Handler {
return logger.New(logger.Config{
DisableColors: true,
Format: "🚀 ${time} | ${locals:authtype} ${locals:authid} | ${method} ${path} | ${status} | ${latency} | ${error}\n",
TimeFormat: "2006-01-02 15:04:05",
TimeZone: "Asia/Shanghai",
Next: func(c *fiber.Ctx) bool {
authCtx, ok := c.Locals("auth").(*auth.Context)
if ok {
c.Locals("authtype", authCtx.Payload.Type.ToStr())
c.Locals("authid", authCtx.Payload.Id)
} else {
c.Locals("authtype", auth.PayloadNone.ToStr())
c.Locals("authid", 0)
}
return false
},
Done: func(c *fiber.Ctx, logBytes []byte) {
var logStr = strings.TrimPrefix(string(logBytes), "🚀")
var logVars = strings.Split(logStr, "|")
var reqTimeStr = strings.TrimSpace(logVars[0])
reqTime, err := time.ParseInLocation("2006-01-02 15:04:05", reqTimeStr, time.Local)
if err != nil {
slog.Error("时间解析错误", slog.Any("err", err))
return
}
var latency = strings.TrimSpace(logVars[4])
var errStr = strings.TrimSpace(logVars[5])
slog.Info("接口请求",
slog.String("identity", c.Locals("authtype").(string)),
slog.Int("visitor", c.Locals("authid").(int)),
slog.String("ip", c.IP()),
slog.String("ua", c.Get("User-Agent")),
slog.String("method", c.Method()),
slog.String("path", c.Path()),
slog.Int("status", c.Response().StatusCode()),
slog.String("error", errStr),
slog.String("latency", latency),
slog.Time("time", reqTime),
)
},
})
}
func newRecover() fiber.Handler {
return recover.New(recover.Config{
EnableStackTrace: true,
})
}
// endregion