重构代码结构与认证体系,集成异步任务消费者
This commit is contained in:
@@ -4,112 +4,101 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"platform/web/core"
|
||||
client2 "platform/web/domains/client"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"slices"
|
||||
s "platform/web/services"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type ProtectBuilder struct {
|
||||
c *fiber.Ctx
|
||||
types []PayloadType
|
||||
scopes []string
|
||||
func Authenticate() fiber.Handler {
|
||||
return func(ctx *fiber.Ctx) error {
|
||||
header := ctx.Get(fiber.HeaderAuthorization)
|
||||
authCtx, err := authHeader(ctx.Context(), header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if authCtx == nil {
|
||||
authCtx = &AuthCtx{}
|
||||
}
|
||||
|
||||
SetAuthCtx(ctx, authCtx)
|
||||
return ctx.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func NewProtect(c *fiber.Ctx) *ProtectBuilder {
|
||||
return &ProtectBuilder{c, []PayloadType{}, []string{}}
|
||||
}
|
||||
func authHeader(ctx context.Context, header string) (*AuthCtx, error) {
|
||||
if header == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *ProtectBuilder) Payload(types ...PayloadType) *ProtectBuilder {
|
||||
p.types = types
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ProtectBuilder) Scopes(scopes ...string) *ProtectBuilder {
|
||||
p.scopes = scopes
|
||||
return p
|
||||
}
|
||||
|
||||
func (p *ProtectBuilder) Do() (*Context, error) {
|
||||
return Protect(p.c, p.types, p.scopes)
|
||||
}
|
||||
|
||||
func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context, error) {
|
||||
// 获取令牌
|
||||
var header = c.Get("Authorization")
|
||||
var split = strings.Split(header, " ")
|
||||
if len(split) != 2 {
|
||||
slog.Debug("Authorization 头格式不正确")
|
||||
return nil, ErrUnauthorize
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
|
||||
var token = strings.TrimSpace(split[1])
|
||||
if token == "" {
|
||||
slog.Debug("提供的令牌为空")
|
||||
return nil, ErrUnauthorize
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
|
||||
var auth *Context
|
||||
var authCtx *AuthCtx
|
||||
var err error
|
||||
switch split[0] {
|
||||
|
||||
case "Bearer":
|
||||
auth, err = authBearer(c.Context(), token)
|
||||
authCtx, err = authBearer(ctx, token)
|
||||
if err != nil {
|
||||
slog.Debug("Bearer 认证失败", "err", err)
|
||||
return nil, ErrUnauthorize
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
|
||||
case "Basic":
|
||||
if !slices.Contains(types, PayloadInternalServer) {
|
||||
slog.Debug("禁止使用 Basic 认证方式")
|
||||
return nil, ErrUnauthorize
|
||||
}
|
||||
auth, err = authBasic(c.Context(), token)
|
||||
authCtx, err = authBasic(ctx, token)
|
||||
if err != nil {
|
||||
slog.Debug("Basic 认证失败", "err", err)
|
||||
return nil, ErrUnauthorize
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
|
||||
default:
|
||||
slog.Debug("无效的认证方式", "method", split[0])
|
||||
return nil, ErrUnauthorize
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
|
||||
// 检查权限
|
||||
if !slices.Contains(types, auth.Payload.Type) {
|
||||
slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type)
|
||||
return nil, ErrForbidden
|
||||
}
|
||||
|
||||
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
|
||||
slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions)
|
||||
return nil, ErrForbidden
|
||||
}
|
||||
|
||||
// 保存到上下文
|
||||
Locals(c, auth)
|
||||
return auth, nil
|
||||
return authCtx, err
|
||||
}
|
||||
|
||||
func Locals(c *fiber.Ctx, auth *Context) {
|
||||
c.Locals("auth", auth)
|
||||
}
|
||||
|
||||
func authBearer(ctx context.Context, token string) (*Context, error) {
|
||||
auth, err := FindSession(ctx, token)
|
||||
func authBearer(_ context.Context, token string) (*AuthCtx, error) {
|
||||
session, err := FindSession(token, time.Now())
|
||||
if err != nil {
|
||||
slog.Debug(err.Error())
|
||||
return nil, err
|
||||
slog.Debug("Bearer 认证失败", "err", err)
|
||||
return nil, ErrAuthenticateUnauthorize
|
||||
}
|
||||
return auth, nil
|
||||
|
||||
scopes := []string{}
|
||||
if session.Scopes_ != nil {
|
||||
scopes = strings.Split(*session.Scopes_, " ")
|
||||
}
|
||||
return &AuthCtx{
|
||||
User: session.User,
|
||||
Admin: session.Admin,
|
||||
Client: session.Client,
|
||||
Scopes: scopes,
|
||||
Session: session,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authBasic(_ context.Context, token string) (*Context, error) {
|
||||
func authBasic(_ context.Context, token string) (*AuthCtx, error) {
|
||||
|
||||
// 解析 Basic 认证信息
|
||||
var base, err = base64.RawURLEncoding.DecodeString(token)
|
||||
@@ -125,14 +114,23 @@ func authBasic(_ context.Context, token string) (*Context, error) {
|
||||
return nil, errors.New("令牌格式错误,必须是 <client_id>:<client_secret> 格式")
|
||||
}
|
||||
|
||||
var clientID = split[0]
|
||||
client, err := authClient(split[0], split[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("客户端认证失败:%w", err)
|
||||
}
|
||||
|
||||
return &AuthCtx{
|
||||
Client: client,
|
||||
Scopes: []string{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func authClient(clientId, clientSecret string) (*m.Client, error) {
|
||||
|
||||
// 获取客户端信息
|
||||
client, err := q.Client.
|
||||
Where(
|
||||
q.Client.ClientID.Eq(clientID),
|
||||
q.Client.Spec.In(int32(client2.SpecWeb), int32(client2.SpecTrusted)),
|
||||
q.Client.GrantClient.Is(true),
|
||||
q.Client.ClientID.Eq(clientId),
|
||||
q.Client.Status.Eq(1)).
|
||||
Take()
|
||||
if err != nil {
|
||||
@@ -140,33 +138,57 @@ func authBasic(_ context.Context, token string) (*Context, error) {
|
||||
}
|
||||
|
||||
// 检查客户端密钥
|
||||
var clientSecret = split[1]
|
||||
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
|
||||
return nil, errors.New("客户端密钥错误")
|
||||
spec := client2.Spec(client.Spec)
|
||||
if spec == client2.SpecWeb || spec == client2.SpecApi {
|
||||
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
|
||||
return nil, errors.New("客户端密钥错误")
|
||||
}
|
||||
}
|
||||
|
||||
// todo 查询客户端关联权限
|
||||
|
||||
// 组织授权信息(一次性请求)
|
||||
return &Context{
|
||||
Payload: Payload{
|
||||
Id: client.ID,
|
||||
Type: PayloadTypeFromClientSpec(client2.Spec(client.Spec)),
|
||||
Name: client.Name,
|
||||
Avatar: client.Icon,
|
||||
},
|
||||
Permissions: nil,
|
||||
Metadata: nil,
|
||||
}, nil
|
||||
return client, nil
|
||||
}
|
||||
|
||||
type AuthenticationErr string
|
||||
func authUserBySms(tx *q.Query, username, code string) (*m.User, error) {
|
||||
// 验证验证码
|
||||
err := s.Verifier.VerifySms(context.Background(), username, code)
|
||||
if err != nil {
|
||||
if errors.Is(err, s.ErrVerifierServiceInvalid) {
|
||||
return nil, ErrAuthorizeInvalidRequest
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (e AuthenticationErr) Error() string {
|
||||
return string(e)
|
||||
// 查找用户
|
||||
return tx.User.Where(tx.User.Phone.Eq(username)).Take()
|
||||
}
|
||||
|
||||
var (
|
||||
ErrUnauthorize = AuthenticationErr("令牌无效")
|
||||
ErrForbidden = AuthenticationErr("没有权限")
|
||||
)
|
||||
func authUserByEmail(tx *q.Query, username, code string) (*m.User, error) {
|
||||
return nil, core.NewServErr("邮箱登录不可用")
|
||||
}
|
||||
|
||||
func authUserByPassword(tx *q.Query, username, password string) (*m.User, error) {
|
||||
user, err := tx.User.
|
||||
Where(tx.User.Phone.Eq(username)).
|
||||
Or(tx.User.Email.Eq(username)).
|
||||
Or(tx.User.Username.Eq(username)).
|
||||
Take()
|
||||
if err != nil {
|
||||
slog.Debug("查找用户失败", "error", err)
|
||||
return nil, core.NewBizErr("用户不存在或密码错误")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if user.Password == nil || *user.Password == "" {
|
||||
slog.Debug("用户未设置密码", "username", username)
|
||||
return nil, core.NewBizErr("用户不存在或密码错误")
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(*user.Password), []byte(password)) != nil {
|
||||
slog.Debug("密码验证失败", "username", username)
|
||||
return nil, core.NewBizErr("用户不存在或密码错误")
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -1,5 +1,28 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"platform/pkg/env"
|
||||
"platform/pkg/u"
|
||||
"platform/web/core"
|
||||
user2 "platform/web/domains/user"
|
||||
g "platform/web/globals"
|
||||
"platform/web/globals/orm"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GrantType string
|
||||
|
||||
const (
|
||||
@@ -17,36 +40,352 @@ const (
|
||||
GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱验证码
|
||||
)
|
||||
|
||||
func Token(grant GrantType) error {
|
||||
return nil
|
||||
type TokenReq struct {
|
||||
GrantType GrantType `json:"grant_type" form:"grant_type"`
|
||||
ClientID string `json:"client_id" form:"client_id"`
|
||||
ClientSecret string `json:"client_secret" form:"client_secret"`
|
||||
Scope string `json:"scope" form:"scope"`
|
||||
GrantCodeData
|
||||
GrantClientData
|
||||
GrantRefreshData
|
||||
GrantPasswordData
|
||||
}
|
||||
|
||||
func authAuthorizationCode() {
|
||||
|
||||
type GrantCodeData struct {
|
||||
Code string `json:"code" form:"code"`
|
||||
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
|
||||
}
|
||||
|
||||
func authClientCredential() {
|
||||
|
||||
type GrantClientData struct {
|
||||
}
|
||||
|
||||
func authRefreshToken() {
|
||||
|
||||
type GrantRefreshData struct {
|
||||
RefreshToken string `json:"refresh_token" form:"refresh_token"`
|
||||
}
|
||||
|
||||
func authPassword() {
|
||||
|
||||
type GrantPasswordData struct {
|
||||
LoginType PasswordGrantType `json:"login_type" form:"login_type"`
|
||||
Username string `json:"username" form:"username"`
|
||||
Password string `json:"password" form:"password"`
|
||||
Remember bool `json:"remember" form:"remember"`
|
||||
}
|
||||
|
||||
func authPasswordSecret() {
|
||||
|
||||
type TokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
func authPasswordPhone() {
|
||||
|
||||
type TokenErrResp struct {
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description,omitempty"`
|
||||
}
|
||||
|
||||
func authPasswordEmail() {
|
||||
func Token(c *fiber.Ctx) error {
|
||||
now := time.Now()
|
||||
|
||||
// 验证请求参数
|
||||
req := new(TokenReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "无法解析请求参数")
|
||||
}
|
||||
if req.GrantType == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: grant_type")
|
||||
}
|
||||
|
||||
switch req.GrantType {
|
||||
// 授权码模式
|
||||
case GrantAuthorizationCode:
|
||||
if req.Code == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: code")
|
||||
}
|
||||
// 刷新令牌模式
|
||||
case GrantRefreshToken:
|
||||
if req.RefreshToken == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: refresh_token")
|
||||
}
|
||||
// 密码模式
|
||||
case GrantPassword:
|
||||
if req.LoginType == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password_type")
|
||||
}
|
||||
if req.Username == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: username")
|
||||
}
|
||||
if req.Password == "" {
|
||||
return sendError(c, ErrAuthorizeInvalidRequest, "缺少必要参数: password")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证客户端身份
|
||||
authCtx := GetAuthCtx(c)
|
||||
if authCtx == nil {
|
||||
authCtx = &AuthCtx{}
|
||||
}
|
||||
if authCtx.Client == nil {
|
||||
client, err := authClient(req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
authCtx.Client = client
|
||||
}
|
||||
|
||||
// 处理授权
|
||||
var session *m.Session
|
||||
var err error
|
||||
switch req.GrantType {
|
||||
// 授权码模式
|
||||
case GrantAuthorizationCode:
|
||||
session, err = authAuthorizationCode(c, authCtx, req, now)
|
||||
// 客户端凭证模式
|
||||
case GrantClientCredentials:
|
||||
session, err = authClientCredential(c, authCtx, req, now)
|
||||
// 刷新令牌模式
|
||||
case GrantRefreshToken:
|
||||
session, err = authRefreshToken(c, authCtx, req, now)
|
||||
// 密码模式
|
||||
case GrantPassword:
|
||||
session, err = authPassword(c, authCtx, req, now)
|
||||
default:
|
||||
return sendError(c, ErrAuthorizeUnsupportedGrantType)
|
||||
}
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
return c.JSON(&TokenResp{
|
||||
TokenType: "Bearer",
|
||||
AccessToken: session.AccessToken,
|
||||
RefreshToken: u.Z(session.RefreshToken),
|
||||
ExpiresIn: int(time.Time(session.AccessTokenExpires).Sub(now).Seconds()),
|
||||
Scope: u.Z(session.Scopes_),
|
||||
})
|
||||
}
|
||||
|
||||
func authAuthorizationCode(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
||||
|
||||
// 检查 code 获取用户授权信息
|
||||
data, err := g.Redis.Get(context.Background(), req.Code).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var codeCtx CodeContext
|
||||
if err := json.Unmarshal([]byte(data), &codeCtx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查 PKCE
|
||||
if codeCtx.CodeChallengeMethod != "" {
|
||||
if req.CodeVerifier == "" {
|
||||
return nil, ErrAuthorizeInvalidPKCE
|
||||
}
|
||||
|
||||
switch codeCtx.CodeChallengeMethod {
|
||||
case "plain":
|
||||
if req.CodeVerifier != codeCtx.CodeChallenge {
|
||||
return nil, ErrAuthorizeInvalidPKCE
|
||||
}
|
||||
case "S256":
|
||||
hash := sha256.Sum256([]byte(req.CodeVerifier))
|
||||
verifier := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
if verifier != codeCtx.CodeChallenge {
|
||||
return nil, ErrAuthorizeInvalidPKCE
|
||||
}
|
||||
default:
|
||||
return nil, ErrAuthorizeInvalidPKCE
|
||||
}
|
||||
}
|
||||
|
||||
user, err := q.User.Where(
|
||||
q.User.ID.Eq(codeCtx.UserID),
|
||||
q.User.Status.Eq(int32(user2.StatusEnabled)),
|
||||
).First()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// todo 检查 scope
|
||||
|
||||
// 生成会话
|
||||
session := &m.Session{
|
||||
IP: u.X(ctx.IP()),
|
||||
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
|
||||
UserID: &user.ID,
|
||||
ClientID: &auth.Client.ID,
|
||||
Scopes_: u.P(strings.Join(codeCtx.Scopes, " ")),
|
||||
AccessToken: uuid.NewString(),
|
||||
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
|
||||
}
|
||||
if codeCtx.Remember {
|
||||
session.RefreshToken = u.P(uuid.NewString())
|
||||
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
|
||||
}
|
||||
|
||||
err = SaveSession(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func authClientCredential(ctx *fiber.Ctx, auth *AuthCtx, _ *TokenReq, now time.Time) (*m.Session, error) {
|
||||
// todo 检查 scope
|
||||
|
||||
// 生成会话
|
||||
session := &m.Session{
|
||||
IP: u.X(ctx.IP()),
|
||||
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
|
||||
ClientID: &auth.Client.ID,
|
||||
AccessToken: uuid.NewString(),
|
||||
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
|
||||
}
|
||||
|
||||
// 保存会话
|
||||
err := SaveSession(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func authPassword(ctx *fiber.Ctx, auth *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
||||
var user *m.User
|
||||
err := q.Q.Transaction(func(tx *q.Query) (err error) {
|
||||
switch req.LoginType {
|
||||
case GrantPasswordPhone:
|
||||
user, err = authUserBySms(tx, req.Username, req.Password)
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
if user == nil {
|
||||
user = &m.User{
|
||||
Phone: req.Username,
|
||||
Username: u.P(req.Username),
|
||||
Status: int32(user2.StatusEnabled),
|
||||
}
|
||||
}
|
||||
case GrantPasswordEmail:
|
||||
user, err = authUserByEmail(tx, req.Username, req.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case GrantPasswordSecret:
|
||||
user, err = authUserByPassword(tx, req.Username, req.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return ErrAuthorizeInvalidRequest
|
||||
}
|
||||
|
||||
// 账户状态
|
||||
if user2.Status(user.Status) == user2.StatusDisabled {
|
||||
slog.Debug("账户状态异常", "username", req.Username, "status", user.Status)
|
||||
return core.NewBizErr("账号无法登录")
|
||||
}
|
||||
|
||||
// 更新用户的登录时间
|
||||
user.LastLogin = u.P(orm.LocalDateTime(time.Now()))
|
||||
user.LastLoginHost = u.X(ctx.IP())
|
||||
user.LastLoginAgent = u.X(ctx.Get(fiber.HeaderUserAgent))
|
||||
if err := tx.User.Save(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成会话
|
||||
session := &m.Session{
|
||||
IP: u.X(ctx.IP()),
|
||||
UA: u.X(ctx.Get(fiber.HeaderUserAgent)),
|
||||
UserID: &user.ID,
|
||||
ClientID: &auth.Client.ID,
|
||||
Scopes_: u.X(req.Scope),
|
||||
AccessToken: uuid.NewString(),
|
||||
AccessTokenExpires: orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second)),
|
||||
}
|
||||
if req.Remember {
|
||||
session.RefreshToken = u.P(uuid.NewString())
|
||||
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
|
||||
}
|
||||
|
||||
err = SaveSession(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func authRefreshToken(_ *fiber.Ctx, _ *AuthCtx, req *TokenReq, now time.Time) (*m.Session, error) {
|
||||
session, err := FindSessionByRefresh(req.RefreshToken, now)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAuthorizeInvalidGrant
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// todo 检查权限
|
||||
|
||||
// 生成令牌
|
||||
session.AccessToken = uuid.NewString()
|
||||
session.AccessTokenExpires = orm.LocalDateTime(now.Add(time.Duration(env.SessionAccessExpire) * time.Second))
|
||||
if session.RefreshToken != nil {
|
||||
session.RefreshToken = u.P(uuid.NewString())
|
||||
session.RefreshTokenExpires = u.P(orm.LocalDateTime(now.Add(time.Duration(env.SessionRefreshExpire) * time.Second)))
|
||||
}
|
||||
|
||||
// 保存令牌
|
||||
err = SaveSession(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func sendError(c *fiber.Ctx, err error, description ...string) error {
|
||||
var sErr AuthErr
|
||||
if errors.As(err, &sErr) {
|
||||
status := fiber.StatusBadRequest
|
||||
var desc string
|
||||
switch {
|
||||
case errors.Is(sErr, ErrAuthorizeInvalidRequest):
|
||||
desc = "无效的请求"
|
||||
case errors.Is(sErr, ErrAuthorizeInvalidClient):
|
||||
status = fiber.StatusUnauthorized
|
||||
desc = "无效的客户端凭证"
|
||||
case errors.Is(sErr, ErrAuthorizeInvalidGrant):
|
||||
desc = "无效的授权凭证"
|
||||
case errors.Is(sErr, ErrAuthorizeInvalidScope):
|
||||
desc = "无效的授权范围"
|
||||
case errors.Is(sErr, ErrAuthorizeUnauthorizedClient):
|
||||
desc = "未授权的客户端"
|
||||
case errors.Is(sErr, ErrAuthorizeUnsupportedGrantType):
|
||||
desc = "不支持的授权类型"
|
||||
}
|
||||
if len(description) > 0 {
|
||||
desc = description[0]
|
||||
}
|
||||
|
||||
return c.Status(status).JSON(TokenErrResp{
|
||||
Error: string(sErr),
|
||||
Description: desc,
|
||||
})
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func Revoke() error {
|
||||
@@ -56,3 +395,12 @@ func Revoke() error {
|
||||
func Introspect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type CodeContext struct {
|
||||
UserID int32 `json:"user_id"`
|
||||
ClientID int32 `json:"client_id"`
|
||||
Scopes []string `json:"scopes"`
|
||||
Remember bool `json:"remember"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
CodeChallengeMethod string `json:"code_challenge_method"`
|
||||
}
|
||||
|
||||
99
web/auth/check.go
Normal file
99
web/auth/check.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"platform/web/domains/client"
|
||||
m "platform/web/models"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
type AuthCtx struct {
|
||||
User *m.User `json:"account,omitempty"`
|
||||
Admin *m.Admin `json:"admin,omitempty"`
|
||||
Client *m.Client `json:"client,omitempty"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
Session *m.Session `json:"session,omitempty"`
|
||||
smap map[string]struct{}
|
||||
}
|
||||
|
||||
func (a *AuthCtx) PermitUser(scopes ...string) (*AuthCtx, error) {
|
||||
if a.User == nil {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
if !a.checkScopes(scopes...) {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *AuthCtx) PermitAdmin(scopes ...string) (*AuthCtx, error) {
|
||||
if a.Admin == nil {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
if !a.checkScopes(scopes...) {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *AuthCtx) PermitSecretClient(scopes ...string) (*AuthCtx, error) {
|
||||
if a.Client == nil {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
spec := client.Spec(a.Client.Spec)
|
||||
if spec != client.SpecApi && spec != client.SpecWeb {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
if !a.checkScopes(scopes...) {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *AuthCtx) PermitInternalClient(scopes ...string) (*AuthCtx, error) {
|
||||
if a.Client == nil {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
spec := client.Spec(a.Client.Spec)
|
||||
if spec != client.SpecApi && spec != client.SpecWeb {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
cType := client.Type(a.Client.Type)
|
||||
if cType != client.TypeInternal {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
if !a.checkScopes(scopes...) {
|
||||
return a, ErrAuthenticateForbidden
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *AuthCtx) checkScopes(scopes ...string) bool {
|
||||
if len(scopes) == 0 || len(a.Scopes) == 0 {
|
||||
return true
|
||||
}
|
||||
if len(a.smap) == 0 && len(a.Scopes) > 0 {
|
||||
for _, scope := range scopes {
|
||||
a.smap[scope] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
if _, ok := a.smap[scope]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
const AuthCtxKey = "session"
|
||||
|
||||
func SetAuthCtx(c *fiber.Ctx, auth *AuthCtx) {
|
||||
c.Locals(AuthCtxKey, auth)
|
||||
}
|
||||
|
||||
func GetAuthCtx(c *fiber.Ctx) *AuthCtx {
|
||||
if authCtx, ok := c.Locals(AuthCtxKey).(*AuthCtx); ok {
|
||||
return authCtx
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
client2 "platform/web/domains/client"
|
||||
)
|
||||
|
||||
// Context 定义认证信息
|
||||
type Context struct {
|
||||
Payload Payload `json:"payload"`
|
||||
Permissions map[string]struct{} `json:"permissions,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
func (a *Context) AnyType(types ...PayloadType) bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
for _, t := range types {
|
||||
if a.Payload.Type == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AnyPermission 检查认证是否包含指定权限
|
||||
func (a *Context) AnyPermission(requiredPermission ...string) bool {
|
||||
if a == nil || a.Permissions == nil {
|
||||
return false
|
||||
}
|
||||
for _, permission := range requiredPermission {
|
||||
if _, ok := a.Permissions[permission]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Payload 定义负载信息
|
||||
type Payload struct {
|
||||
Id int32 `json:"id,omitempty"`
|
||||
Type PayloadType `json:"type,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Avatar *string `json:"avatar,omitempty"`
|
||||
}
|
||||
|
||||
type PayloadType int
|
||||
|
||||
const (
|
||||
PayloadNone PayloadType = iota // 游客
|
||||
PayloadUser // 用户
|
||||
PayloadAdmin // 管理员
|
||||
PayloadPublicServer // 公共服务(public_client)
|
||||
PayloadSecuredServer // 安全服务(credential_client)
|
||||
PayloadInternalServer // 内部服务
|
||||
)
|
||||
|
||||
func (t PayloadType) ToStr() string {
|
||||
switch t {
|
||||
case PayloadUser:
|
||||
return "user"
|
||||
case PayloadAdmin:
|
||||
return "admn"
|
||||
case PayloadPublicServer:
|
||||
return "cpub"
|
||||
case PayloadSecuredServer:
|
||||
return "ccnf"
|
||||
case PayloadInternalServer:
|
||||
return "inte"
|
||||
default:
|
||||
return "none"
|
||||
}
|
||||
}
|
||||
|
||||
func PayloadTypeFromStr(name string) PayloadType {
|
||||
switch name {
|
||||
case "user":
|
||||
return PayloadUser
|
||||
case "admn":
|
||||
return PayloadAdmin
|
||||
case "cpub":
|
||||
return PayloadPublicServer
|
||||
case "ccnf":
|
||||
return PayloadSecuredServer
|
||||
case "inte":
|
||||
return PayloadInternalServer
|
||||
default:
|
||||
return PayloadNone
|
||||
}
|
||||
}
|
||||
|
||||
func PayloadTypeFromClientSpec(spec client2.Spec) PayloadType {
|
||||
var clientType PayloadType
|
||||
switch spec {
|
||||
case client2.SpecNative, client2.SpecBrowser:
|
||||
clientType = PayloadPublicServer
|
||||
case client2.SpecWeb:
|
||||
clientType = PayloadSecuredServer
|
||||
case client2.SpecTrusted:
|
||||
clientType = PayloadInternalServer
|
||||
}
|
||||
return clientType
|
||||
}
|
||||
24
web/auth/errors.go
Normal file
24
web/auth/errors.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package auth
|
||||
|
||||
type AuthErr string
|
||||
|
||||
func (e AuthErr) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
// 认证错误
|
||||
const (
|
||||
ErrAuthenticateUnauthorize = AuthErr("令牌无效")
|
||||
ErrAuthenticateForbidden = AuthErr("没有权限")
|
||||
)
|
||||
|
||||
// 授权错误
|
||||
const (
|
||||
ErrAuthorizeInvalidRequest = AuthErr("invalid_request")
|
||||
ErrAuthorizeInvalidClient = AuthErr("invalid_client")
|
||||
ErrAuthorizeInvalidGrant = AuthErr("invalid_grant")
|
||||
ErrAuthorizeInvalidScope = AuthErr("invalid_scope")
|
||||
ErrAuthorizeUnauthorizedClient = AuthErr("unauthorized_client")
|
||||
ErrAuthorizeUnsupportedGrantType = AuthErr("unsupported_grant_type")
|
||||
ErrAuthorizeInvalidPKCE = AuthErr("invalid_pkce")
|
||||
)
|
||||
@@ -2,160 +2,36 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"platform/pkg/env"
|
||||
g "platform/web/globals"
|
||||
"platform/web/globals/orm"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"time"
|
||||
|
||||
"gorm.io/gen/field"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
// 认证主体
|
||||
Payload *Payload
|
||||
// 令牌信息
|
||||
TokenDetails *TokenDetails
|
||||
func FindSession(accessToken string, now time.Time) (*m.Session, error) {
|
||||
return q.Session.
|
||||
Preload(field.Associations).
|
||||
Where(
|
||||
q.Session.AccessToken.Eq(accessToken),
|
||||
q.Session.AccessTokenExpires.Gt(orm.LocalDateTime(now)),
|
||||
).First()
|
||||
}
|
||||
|
||||
func FindSession(ctx context.Context, token string) (*Context, error) {
|
||||
|
||||
// 读取认证数据
|
||||
authJSON, err := g.Redis.Get(ctx, accessKey(token)).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, errors.New("invalid_token")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 反序列化
|
||||
auth := new(Context)
|
||||
if err := json.Unmarshal([]byte(authJSON), auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
func FindSessionByRefresh(refreshToken string, now time.Time) (*m.Session, error) {
|
||||
return q.Session.
|
||||
Preload(field.Associations).
|
||||
Where(
|
||||
q.Session.RefreshToken.Eq(refreshToken),
|
||||
q.Session.RefreshTokenExpires.Gt(orm.LocalDateTime(now)),
|
||||
).First()
|
||||
}
|
||||
|
||||
func CreateSession(ctx context.Context, authCtx *Context, remember bool) (*TokenDetails, error) {
|
||||
var now = time.Now()
|
||||
|
||||
// 生成令牌组
|
||||
accessToken := genToken()
|
||||
refreshToken := genToken()
|
||||
|
||||
// 序列化认证数据
|
||||
authData, err := json.Marshal(authCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 序列化刷新令牌数据
|
||||
refreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: authCtx,
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 事务保存数据到 Redis
|
||||
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
|
||||
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
|
||||
|
||||
pipe := g.Redis.TxPipeline()
|
||||
pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
|
||||
if remember {
|
||||
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &TokenDetails{
|
||||
AccessToken: accessToken,
|
||||
AccessTokenExpires: now.Add(accessExpire),
|
||||
RefreshToken: refreshToken,
|
||||
RefreshTokenExpires: now.Add(refreshExpire),
|
||||
Auth: authCtx,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RefreshSession(ctx context.Context, refreshToken string, renew bool) (*TokenDetails, error) {
|
||||
var now = time.Now()
|
||||
|
||||
rKey := refreshKey(refreshToken)
|
||||
var tokenDetails *TokenDetails
|
||||
|
||||
// 刷新令牌
|
||||
err := g.Redis.Watch(ctx, func(tx *redis.Tx) error {
|
||||
|
||||
// 先获取刷新令牌数据
|
||||
refreshJson, err := tx.Get(ctx, rKey).Result()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return ErrInvalidRefreshToken
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析刷新令牌数据
|
||||
refreshData := new(RefreshData)
|
||||
if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 生成新的令牌
|
||||
newAccessToken := genToken()
|
||||
newRefreshToken := genToken()
|
||||
|
||||
authData, err := json.Marshal(refreshData.AuthContext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newRefreshData, err := json.Marshal(RefreshData{
|
||||
AuthContext: refreshData.AuthContext,
|
||||
AccessToken: newAccessToken,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pipeline := tx.Pipeline()
|
||||
|
||||
// 保存新的令牌
|
||||
var accessExpire = time.Duration(env.SessionAccessExpire) * time.Second
|
||||
var refreshExpire = time.Duration(env.SessionRefreshExpire) * time.Second
|
||||
|
||||
pipeline.Set(ctx, accessKey(newAccessToken), authData, accessExpire)
|
||||
pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, refreshExpire)
|
||||
|
||||
// 删除旧的令牌
|
||||
pipeline.Del(ctx, accessKey(refreshData.AccessToken))
|
||||
pipeline.Del(ctx, refreshKey(refreshToken))
|
||||
|
||||
_, err = pipeline.Exec(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tokenDetails = &TokenDetails{
|
||||
AccessToken: newAccessToken,
|
||||
RefreshToken: newRefreshToken,
|
||||
AccessTokenExpires: now.Add(accessExpire),
|
||||
RefreshTokenExpires: now.Add(refreshExpire),
|
||||
Auth: refreshData.AuthContext,
|
||||
}
|
||||
return nil
|
||||
}, rKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("刷新令牌失败: %w", err)
|
||||
}
|
||||
|
||||
return tokenDetails, nil
|
||||
func SaveSession(session *m.Session) error {
|
||||
return q.Session.Save(session)
|
||||
}
|
||||
|
||||
func RemoveSession(ctx context.Context, accessToken string, refreshToken string) error {
|
||||
@@ -163,11 +39,6 @@ func RemoveSession(ctx context.Context, accessToken string, refreshToken string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 生成一个新的令牌
|
||||
func genToken() string {
|
||||
return uuid.NewString()
|
||||
}
|
||||
|
||||
// 令牌键的格式为 "session:<token>"
|
||||
func accessKey(token string) string {
|
||||
return fmt.Sprintf("session:%s", token)
|
||||
@@ -177,32 +48,3 @@ func accessKey(token string) string {
|
||||
func refreshKey(token string) string {
|
||||
return fmt.Sprintf("session:refresh:%s", token)
|
||||
}
|
||||
|
||||
// TokenDetails 存储令牌详细信息
|
||||
type TokenDetails struct {
|
||||
// 访问令牌
|
||||
AccessToken string
|
||||
// 刷新令牌
|
||||
RefreshToken string
|
||||
// 访问令牌过期时间
|
||||
AccessTokenExpires time.Time
|
||||
// 刷新令牌过期时间
|
||||
RefreshTokenExpires time.Time
|
||||
// 认证信息
|
||||
Auth *Context
|
||||
}
|
||||
|
||||
type RefreshData struct {
|
||||
AuthContext *Context
|
||||
AccessToken string
|
||||
}
|
||||
|
||||
type SessionErr string
|
||||
|
||||
func (e SessionErr) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
const (
|
||||
ErrInvalidRefreshToken = SessionErr("无效的刷新令牌")
|
||||
)
|
||||
|
||||
@@ -60,7 +60,7 @@ type Err struct {
|
||||
|
||||
func (e *Err) Error() string {
|
||||
if e.err != nil {
|
||||
return e.msg + ":" + e.err.Error()
|
||||
return e.msg + ": " + e.err.Error()
|
||||
}
|
||||
return e.msg
|
||||
}
|
||||
|
||||
@@ -6,5 +6,12 @@ const (
|
||||
SpecNative Spec = iota + 1 // 原生客户端
|
||||
SpecBrowser // 浏览器客户端
|
||||
SpecWeb // Web 服务
|
||||
SpecTrusted // 可信服务
|
||||
SpecApi // Api 服务
|
||||
)
|
||||
|
||||
type Type int32
|
||||
|
||||
const (
|
||||
TypeNormal Type = iota // 普通客户端
|
||||
TypeInternal // 内部客户端
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
|
||||
var message = "服务器异常"
|
||||
|
||||
var fiberErr *fiber.Error
|
||||
var authErr auth.AuthenticationErr
|
||||
var authErr auth.AuthErr
|
||||
var bizErr *core.BizErr
|
||||
var servErr *core.ServErr
|
||||
|
||||
@@ -30,9 +30,9 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
|
||||
// 认证授权错误
|
||||
case errors.As(err, &authErr):
|
||||
switch {
|
||||
case errors.Is(err, auth.ErrUnauthorize):
|
||||
case errors.Is(err, auth.ErrAuthenticateUnauthorize):
|
||||
code = fiber.StatusUnauthorized
|
||||
case errors.Is(err, auth.ErrForbidden):
|
||||
case errors.Is(err, auth.ErrAuthenticateForbidden):
|
||||
code = fiber.StatusForbidden
|
||||
default:
|
||||
code = fiber.StatusBadRequest
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package tasks
|
||||
package events
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,4 +1,4 @@
|
||||
package tasks
|
||||
package events
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
@@ -1,10 +1,11 @@
|
||||
package tasks
|
||||
package events
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/hibiken/asynq"
|
||||
"log/slog"
|
||||
trade2 "platform/web/domains/trade"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
)
|
||||
|
||||
const CancelTrade = "trade:update"
|
||||
@@ -1,6 +1,7 @@
|
||||
package globals
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"platform/pkg/env"
|
||||
|
||||
"github.com/smartwalle/alipay/v3"
|
||||
@@ -8,25 +9,26 @@ import (
|
||||
|
||||
var Alipay *alipay.Client
|
||||
|
||||
func initAlipay() {
|
||||
func initAlipay() error {
|
||||
var client, err = alipay.New(
|
||||
env.AlipayAppId,
|
||||
env.AlipayAppPrivateKey,
|
||||
env.AlipayProduction,
|
||||
)
|
||||
if err != nil {
|
||||
panic("初始化支付宝客户端失败: " + err.Error())
|
||||
return fmt.Errorf("初始化支付宝客户端失败: %w", err)
|
||||
}
|
||||
|
||||
err = client.LoadAliPayPublicKey(env.AlipayPublicKey)
|
||||
if err != nil {
|
||||
panic("加载支付宝公钥失败: " + err.Error())
|
||||
return fmt.Errorf("加载支付宝公钥失败: %w", err)
|
||||
}
|
||||
|
||||
err = client.SetEncryptKey(env.AlipayApiCert)
|
||||
if err != nil {
|
||||
panic("设置支付宝加密密钥失败: " + err.Error())
|
||||
return fmt.Errorf("设置支付宝加密证书失败: %w", err)
|
||||
}
|
||||
|
||||
Alipay = client
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package globals
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"platform/pkg/env"
|
||||
"platform/pkg/u"
|
||||
|
||||
@@ -14,17 +15,18 @@ type aliyunClient struct {
|
||||
Sms *sms.Client
|
||||
}
|
||||
|
||||
func initAliyun() {
|
||||
func initAliyun() error {
|
||||
client, err := sms.NewClient(&openapi.Config{
|
||||
AccessKeyId: &env.AliyunAccessKey,
|
||||
AccessKeySecret: &env.AliyunAccessKeySecret,
|
||||
Endpoint: u.P("dysmsapi.aliyuncs.com"),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return fmt.Errorf("初始化阿里云客户端失败: %w", err)
|
||||
}
|
||||
|
||||
Aliyun = &aliyunClient{
|
||||
Sms: client,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -35,10 +35,11 @@ type cloud struct {
|
||||
|
||||
var Cloud CloudClient
|
||||
|
||||
func initBaiyin() {
|
||||
func initBaiyin() error {
|
||||
Cloud = &cloud{
|
||||
url: env.BaiyinAddr,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type AutoConfig struct {
|
||||
|
||||
@@ -1,14 +1,38 @@
|
||||
package globals
|
||||
|
||||
func Init() {
|
||||
initBaiyin()
|
||||
initAlipay()
|
||||
initWechatPay()
|
||||
initAliyun()
|
||||
initValidator()
|
||||
initRedis()
|
||||
initOrm()
|
||||
initProxy()
|
||||
initAsynq()
|
||||
initSft()
|
||||
import (
|
||||
"context"
|
||||
"platform/pkg/u"
|
||||
)
|
||||
|
||||
func Init(ctx context.Context) error {
|
||||
errs := make([]error, 0)
|
||||
|
||||
errs = append(errs, initBaiyin())
|
||||
errs = append(errs, initAlipay())
|
||||
errs = append(errs, initWechatPay())
|
||||
errs = append(errs, initAliyun())
|
||||
errs = append(errs, initValidator())
|
||||
errs = append(errs, initRedis())
|
||||
errs = append(errs, initOrm())
|
||||
errs = append(errs, initProxy())
|
||||
errs = append(errs, initSft())
|
||||
|
||||
return u.CombineErrors(errs)
|
||||
}
|
||||
|
||||
func Stop() error {
|
||||
var errs = make([]error, 0)
|
||||
|
||||
err := stopRedis()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
err = stopOrm()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
return u.CombineErrors(errs)
|
||||
}
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
package globals
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"platform/pkg/env"
|
||||
"platform/web/queries"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
"log/slog"
|
||||
"platform/pkg/env"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
var Conn *sql.DB
|
||||
|
||||
func initOrm() {
|
||||
func initOrm() error {
|
||||
|
||||
// 连接数据库
|
||||
dsn := fmt.Sprintf(
|
||||
@@ -25,27 +28,29 @@ func initOrm() {
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("gorm 初始化数据库失败:", slog.Any("err", err))
|
||||
panic(err)
|
||||
return fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 连接池
|
||||
conn, err := db.DB()
|
||||
if err != nil {
|
||||
slog.Error("gorm 初始化数据库失败:", slog.Any("err", err))
|
||||
panic(err)
|
||||
return fmt.Errorf("配置连接池失败: %w", err)
|
||||
}
|
||||
conn.SetMaxIdleConns(10)
|
||||
conn.SetMaxOpenConns(100)
|
||||
|
||||
queries.SetDefault(db)
|
||||
DB = db
|
||||
Conn = conn
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExitOrm() error {
|
||||
func stopOrm() error {
|
||||
if DB != nil {
|
||||
conn, err := DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("关闭数据库连接失败: %w", err)
|
||||
}
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
@@ -23,8 +23,9 @@ var Proxy *ProxyClient
|
||||
type ProxyClient struct {
|
||||
}
|
||||
|
||||
func initProxy() {
|
||||
func initProxy() error {
|
||||
Proxy = &ProxyClient{}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ProxyPermitConfig struct {
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package globals
|
||||
|
||||
import (
|
||||
"github.com/go-redsync/redsync/v4/redis/goredis/v9"
|
||||
"log/slog"
|
||||
"net"
|
||||
"platform/pkg/env"
|
||||
"platform/web/core"
|
||||
|
||||
"github.com/go-redsync/redsync/v4/redis/goredis/v9"
|
||||
|
||||
"github.com/go-redsync/redsync/v4"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
@@ -18,11 +19,10 @@ type ExtendRedSync struct {
|
||||
*redsync.Redsync
|
||||
}
|
||||
|
||||
func initRedis() {
|
||||
func initRedis() error {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: net.JoinHostPort(env.RedisHost, env.RedisPort),
|
||||
DB: env.RedisDb,
|
||||
Password: env.RedisPass,
|
||||
Password: env.RedisPassword,
|
||||
})
|
||||
|
||||
pool := goredis.NewPool(client)
|
||||
@@ -30,9 +30,11 @@ func initRedis() {
|
||||
|
||||
Redis = client
|
||||
Redsync = &ExtendRedSync{sync}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExitRedis() error {
|
||||
func stopRedis() error {
|
||||
if Redis != nil {
|
||||
return Redis.Close()
|
||||
}
|
||||
|
||||
@@ -28,9 +28,9 @@ type SftClient struct {
|
||||
publicKey *rsa.PublicKey
|
||||
}
|
||||
|
||||
func initSft() {
|
||||
func initSft() error {
|
||||
if !env.SftPayEnable {
|
||||
panic("商福通支付未启用,请检查环境变量 SFTPAY_ENABLE")
|
||||
return fmt.Errorf("商福通支付未启用,请检查环境变量 SFTPAY_ENABLE")
|
||||
}
|
||||
|
||||
SFTPay = SftClient{
|
||||
@@ -41,7 +41,7 @@ func initSft() {
|
||||
// 加载私钥
|
||||
private, err := base64.StdEncoding.DecodeString(env.SftPayAppPrivateKey)
|
||||
if err != nil {
|
||||
panic("解析商福通私钥失败: " + err.Error())
|
||||
return fmt.Errorf("解析商福通私钥失败: %w", err)
|
||||
}
|
||||
|
||||
var privateKey *rsa.PrivateKey
|
||||
@@ -49,13 +49,13 @@ func initSft() {
|
||||
if err != nil {
|
||||
pkcs8, err := x509.ParsePKCS8PrivateKey(private)
|
||||
if err != nil {
|
||||
panic("解析商福通私钥失败: " + err.Error())
|
||||
return fmt.Errorf("解析商福通私钥失败: %w", err)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
privateKey, ok = pkcs8.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
panic("解析商福通私钥失败")
|
||||
return fmt.Errorf("解析商福通私钥失败")
|
||||
}
|
||||
}
|
||||
SFTPay.privateKey = privateKey
|
||||
@@ -63,35 +63,36 @@ func initSft() {
|
||||
// 加载公钥
|
||||
public, err := base64.StdEncoding.DecodeString(env.SftPayPublicKey)
|
||||
if err != nil {
|
||||
panic("解析商福通公钥失败: " + err.Error())
|
||||
return fmt.Errorf("解析商福通公钥失败: %w", err)
|
||||
}
|
||||
|
||||
var publicKey *rsa.PublicKey
|
||||
pkix, err := x509.ParsePKIXPublicKey(public)
|
||||
if err != nil {
|
||||
panic("解析商福通公钥失败: " + err.Error())
|
||||
return fmt.Errorf("解析商福通公钥失败: %w", err)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
publicKey, ok = pkix.(*rsa.PublicKey)
|
||||
if !ok {
|
||||
panic("解析商福通公钥失败")
|
||||
return fmt.Errorf("解析商福通公钥失败")
|
||||
}
|
||||
SFTPay.publicKey = publicKey
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SftClient) PaymentScanPay(req *PaymentScanPayReq) (*PaymentScanPayResp, error) {
|
||||
const url = "https://pay.rscygroup.com/api/open/payment/scanpay"
|
||||
req.ReturnUrl = env.SftReturnUrl
|
||||
req.NotifyUrl = env.SftNotifyUrl
|
||||
req.ReturnUrl = u.X(env.SftReturnUrl)
|
||||
req.NotifyUrl = u.X(env.SftNotifyUrl)
|
||||
req.RouteNo = u.P(s.routeId)
|
||||
return call[PaymentScanPayResp](s, url, req)
|
||||
}
|
||||
|
||||
func (s *SftClient) PaymentH5Pay(req *PaymentH5PayReq) (*PaymentH5PayResp, error) {
|
||||
const url = "https://pay.rscygroup.com/api/open/payment/h5pay"
|
||||
req.ReturnUrl = env.SftReturnUrl
|
||||
req.NotifyUrl = env.SftNotifyUrl
|
||||
req.ReturnUrl = u.X(env.SftReturnUrl)
|
||||
req.NotifyUrl = u.X(env.SftNotifyUrl)
|
||||
req.RouteNo = u.P(s.routeId)
|
||||
return call[PaymentH5PayResp](s, url, req)
|
||||
}
|
||||
@@ -256,7 +257,7 @@ func call[T any](s *SftClient, url string, req any) (*T, error) {
|
||||
|
||||
encode, err := s.sign(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加密请求内容失败:%w", err)
|
||||
return nil, fmt.Errorf("加密请求内容失败: %w", err)
|
||||
}
|
||||
|
||||
bytes, err := json.Marshal(encode)
|
||||
@@ -266,33 +267,33 @@ func call[T any](s *SftClient, url string, req any) (*T, error) {
|
||||
|
||||
request, err := http.NewRequest("POST", url, strings.NewReader(string(bytes)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败:%w", err)
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
if env.DebugHttpDump == true {
|
||||
reqDump, err := httputil.DumpRequest(request, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求内容转储失败:%w", err)
|
||||
return nil, fmt.Errorf("请求内容转储失败: %w", err)
|
||||
}
|
||||
println(string(reqDump) + "\n\n")
|
||||
}
|
||||
|
||||
response, err := http.DefaultClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败:%w", err)
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
|
||||
if env.DebugHttpDump == true {
|
||||
respDump, err := httputil.DumpResponse(response, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("响应内容转储失败:%w", err)
|
||||
return nil, fmt.Errorf("响应内容转储失败: %w", err)
|
||||
}
|
||||
println(string(respDump) + "\n\n")
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("请求响应失败:%d", response.StatusCode)
|
||||
return nil, fmt.Errorf("请求响应失败: %d", response.StatusCode)
|
||||
}
|
||||
defer func(body io.ReadCloser) {
|
||||
_ = body.Close()
|
||||
@@ -300,18 +301,18 @@ func call[T any](s *SftClient, url string, req any) (*T, error) {
|
||||
|
||||
body, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应内容失败:%w", err)
|
||||
return nil, fmt.Errorf("读取响应内容失败: %w", err)
|
||||
}
|
||||
|
||||
decode, err := s.verify(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解密响应内容失败:%w", err)
|
||||
return nil, fmt.Errorf("解密响应内容失败: %w", err)
|
||||
}
|
||||
|
||||
var resp = new(T)
|
||||
err = json.Unmarshal([]byte(decode), resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("响应正文解析失败:%w", err)
|
||||
return nil, fmt.Errorf("响应正文解析失败: %w", err)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
@@ -321,7 +322,7 @@ func (s *SftClient) sign(msg any) (*request, error) {
|
||||
|
||||
bytes, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("格式化加密正文失败:%w", err)
|
||||
return nil, fmt.Errorf("格式化加密正文失败: %w", err)
|
||||
}
|
||||
|
||||
if env.DebugHttpDump {
|
||||
@@ -341,7 +342,7 @@ func (s *SftClient) sign(msg any) (*request, error) {
|
||||
hashed := sha256.Sum256([]byte(body.String()))
|
||||
signature, err := rsa.SignPKCS1v15(nil, s.privateKey, crypto.SHA256, hashed[:])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("签名失败:%w", err)
|
||||
return nil, fmt.Errorf("签名失败: %w", err)
|
||||
}
|
||||
|
||||
body.Sign = base64.StdEncoding.EncodeToString(signature)
|
||||
@@ -353,11 +354,11 @@ func (s *SftClient) verify(str []byte) (string, error) {
|
||||
var resp = new(response)
|
||||
err := json.Unmarshal(str, resp)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解析响应正文失败:%w", err)
|
||||
return "", fmt.Errorf("解析响应正文失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.Code != "000000" {
|
||||
return "", fmt.Errorf("请求业务响应失败:%s", u.Z(resp.Msg))
|
||||
return "", fmt.Errorf("请求业务响应失败: %s", u.Z(resp.Msg))
|
||||
}
|
||||
|
||||
if resp.Sign == nil {
|
||||
@@ -371,13 +372,13 @@ func (s *SftClient) verify(str []byte) (string, error) {
|
||||
|
||||
ser, err := resp.String()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("格式化响应内容失败:%w", err)
|
||||
return "", fmt.Errorf("格式化响应内容失败: %w", err)
|
||||
}
|
||||
|
||||
hashed := sha256.Sum256([]byte(ser))
|
||||
err = rsa.VerifyPKCS1v15(s.publicKey, crypto.SHA256, hashed[:], sign)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("验签失败:%w", err)
|
||||
return "", fmt.Errorf("验签失败: %w", err)
|
||||
}
|
||||
|
||||
return *resp.BizData, nil
|
||||
@@ -412,7 +413,7 @@ type response struct {
|
||||
func (r response) String() (string, error) {
|
||||
if r.BizData == nil || r.Msg == nil || r.SignType == nil {
|
||||
return "", core.NewServErr(fmt.Sprintf(
|
||||
"上游数据返回有空值:BizData %v,Msg %v, SignType %v",
|
||||
"上游数据返回有空值: BizData %v,Msg %v, SignType %v",
|
||||
r.BizData == nil, r.Msg == nil, r.SignType == nil,
|
||||
))
|
||||
}
|
||||
|
||||
@@ -2,12 +2,14 @@ package globals
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/go-playground/locales/zh"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
zhtrans "github.com/go-playground/validator/v10/translations/zh"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var Validator *ValidatorClient
|
||||
@@ -38,17 +40,18 @@ func (v *ValidatorClient) Validate(c *fiber.Ctx, data any) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func initValidator() {
|
||||
func initValidator() error {
|
||||
var validate = validator.New(validator.WithRequiredStructEnabled())
|
||||
|
||||
var translator = ut.New(zh.New()).GetFallback()
|
||||
err := zhtrans.RegisterDefaultTranslations(validate, translator)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return fmt.Errorf("初始化验证器失败: %w", err)
|
||||
}
|
||||
|
||||
Validator = &ValidatorClient{
|
||||
validator: validate,
|
||||
translator: translator,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package globals
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"platform/pkg/env"
|
||||
|
||||
"github.com/wechatpay-apiv3/wechatpay-go/core"
|
||||
@@ -20,28 +21,28 @@ type WechatPayClient struct {
|
||||
Notify *notify.Handler
|
||||
}
|
||||
|
||||
func initWechatPay() {
|
||||
func initWechatPay() error {
|
||||
|
||||
// 加载商户私钥
|
||||
private, err := base64.StdEncoding.DecodeString(env.WechatPayMchPrivateKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return fmt.Errorf("加载微信支付商户私钥失败: %w", err)
|
||||
}
|
||||
|
||||
appPrivateKey, err := utils.LoadPrivateKey(string(private))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return fmt.Errorf("解析微信支付商户私钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 加载微信支付公钥
|
||||
public, err := base64.StdEncoding.DecodeString(env.WechatPayPublicKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return fmt.Errorf("加载微信支付公钥失败: %w", err)
|
||||
}
|
||||
|
||||
wechatPublicKey, err := utils.LoadPublicKey(string(public))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return fmt.Errorf("解析微信支付公钥失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建 WechatPay 客户端
|
||||
@@ -55,7 +56,7 @@ func initWechatPay() {
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return fmt.Errorf("创建微信支付客户端失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建 WechatPay 通知处理器
|
||||
@@ -64,7 +65,7 @@ func initWechatPay() {
|
||||
*wechatPublicKey,
|
||||
))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return fmt.Errorf("创建微信支付通知处理器失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建 WechatPay 服务
|
||||
@@ -72,4 +73,5 @@ func initWechatPay() {
|
||||
Native: &native.NativeApiService{Client: client},
|
||||
Notify: handler,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"platform/web/auth"
|
||||
"platform/web/core"
|
||||
q "platform/web/queries"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
// region ListAnnouncements
|
||||
@@ -16,7 +17,7 @@ type ListAnnouncementsRequest struct {
|
||||
func ListAnnouncements(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
_, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,277 +1,14 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"platform/pkg/u"
|
||||
auth2 "platform/web/auth"
|
||||
client2 "platform/web/domains/client"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
s "platform/web/services"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// region /token
|
||||
|
||||
type TokenReq struct {
|
||||
GrantType auth2.GrantType `json:"grant_type" form:"grant_type"`
|
||||
ClientID string `json:"client_id" form:"client_id"`
|
||||
ClientSecret string `json:"client_secret" form:"client_secret"`
|
||||
Scope string `json:"scope" form:"scope"`
|
||||
s.GrantCodeData
|
||||
s.GrantClientData
|
||||
s.GrantRefreshData
|
||||
s.GrantPasswordData
|
||||
}
|
||||
|
||||
type TokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
}
|
||||
|
||||
type TokenErrResp struct {
|
||||
Error string `json:"error"`
|
||||
Description string `json:"error_description,omitempty"`
|
||||
}
|
||||
|
||||
// Token 处理 OAuth2.0 授权请求
|
||||
func Token(c *fiber.Ctx) error {
|
||||
|
||||
// 验证请求参数
|
||||
req := new(TokenReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
|
||||
}
|
||||
if req.GrantType == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:grant_type")
|
||||
}
|
||||
|
||||
slog.Debug("oauth token", slog.String("grant_type",
|
||||
string(req.GrantType)),
|
||||
slog.String("client_id", req.ClientID),
|
||||
)
|
||||
|
||||
// 基于授权类型处理请求
|
||||
switch req.GrantType {
|
||||
|
||||
// 授权码模式
|
||||
case auth2.GrantAuthorizationCode:
|
||||
if req.Code == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:code")
|
||||
}
|
||||
|
||||
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
|
||||
if err != nil {
|
||||
return sendError(c, err.(s.AuthServiceError))
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
|
||||
// 客户端凭证模式
|
||||
case auth2.GrantClientCredentials:
|
||||
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
scope := strings.Split(req.Scope, ",")
|
||||
token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
|
||||
if err != nil {
|
||||
return sendError(c, err.(s.AuthServiceError))
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
|
||||
// 刷新令牌模式
|
||||
case auth2.GrantRefreshToken:
|
||||
if req.RefreshToken == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:refresh_token")
|
||||
}
|
||||
|
||||
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
scope := strings.Split(req.Scope, ",")
|
||||
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
|
||||
if err != nil {
|
||||
if errors.Is(err, auth2.ErrInvalidRefreshToken) {
|
||||
return sendError(c, s.ErrOauthInvalidGrant)
|
||||
}
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
|
||||
// 密码模式
|
||||
case auth2.GrantPassword:
|
||||
if req.LoginType == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password_type")
|
||||
}
|
||||
if req.Username == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:username")
|
||||
}
|
||||
if req.Password == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password")
|
||||
}
|
||||
|
||||
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent"))
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
|
||||
default:
|
||||
return sendError(c, s.ErrOauthUnsupportedGrantType)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查客户端凭证
|
||||
func protect(c *fiber.Ctx, grant auth2.GrantType, clientId, clientSecret string) (*m.Client, error) {
|
||||
header := c.Get("Authorization")
|
||||
if header != "" {
|
||||
basic := strings.TrimPrefix(header, "Basic ")
|
||||
if basic != "" {
|
||||
base, err := base64.RawURLEncoding.DecodeString(basic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parts := strings.SplitN(string(base), ":", 2)
|
||||
if len(parts) == 2 {
|
||||
clientId = parts[0]
|
||||
clientSecret = parts[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查找客户端
|
||||
if clientId == "" {
|
||||
return nil, s.ErrOauthInvalidRequest
|
||||
}
|
||||
client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, s.ErrOauthInvalidClient
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证客户端状态
|
||||
if client.Status != 1 {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
|
||||
// 验证授权类型
|
||||
switch grant {
|
||||
case auth2.GrantAuthorizationCode:
|
||||
if !client.GrantCode {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
case auth2.GrantClientCredentials:
|
||||
if !client.GrantClient || client.Spec != int32(client2.SpecWeb) || client.Spec != int32(client2.SpecTrusted) {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
case auth2.GrantRefreshToken:
|
||||
if !client.GrantRefresh {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
case auth2.GrantPassword:
|
||||
if !client.GrantPassword {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
}
|
||||
|
||||
// 如果客户端是 confidential,验证 client_secret,失败返回错误
|
||||
if client.Spec == int32(client2.SpecWeb) || client.Spec == int32(client2.SpecTrusted) {
|
||||
if clientSecret == "" {
|
||||
return nil, s.ErrOauthInvalidRequest
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
|
||||
return nil, s.ErrOauthInvalidClient
|
||||
}
|
||||
}
|
||||
|
||||
// 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑)
|
||||
auth2.Locals(c, &auth2.Context{
|
||||
Payload: auth2.Payload{
|
||||
Id: client.ID,
|
||||
Type: auth2.PayloadSecuredServer,
|
||||
Name: client.Name,
|
||||
Avatar: client.Icon,
|
||||
},
|
||||
})
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// 发送成功响应
|
||||
func sendSuccess(c *fiber.Ctx, details *auth2.TokenDetails) error {
|
||||
return c.JSON(TokenResp{
|
||||
AccessToken: details.AccessToken,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()),
|
||||
RefreshToken: details.RefreshToken,
|
||||
})
|
||||
}
|
||||
|
||||
// 发送错误响应
|
||||
func sendError(c *fiber.Ctx, err error, description ...string) error {
|
||||
var sErr s.AuthServiceError
|
||||
if errors.As(err, &sErr) {
|
||||
status := fiber.StatusBadRequest
|
||||
var desc string
|
||||
switch {
|
||||
case errors.Is(sErr, s.ErrOauthInvalidRequest):
|
||||
desc = "无效的请求"
|
||||
case errors.Is(sErr, s.ErrOauthInvalidClient):
|
||||
status = fiber.StatusUnauthorized
|
||||
desc = "无效的客户端凭证"
|
||||
case errors.Is(sErr, s.ErrOauthInvalidGrant):
|
||||
desc = "无效的授权凭证"
|
||||
case errors.Is(sErr, s.ErrOauthInvalidScope):
|
||||
desc = "无效的授权范围"
|
||||
case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
|
||||
desc = "未授权的客户端"
|
||||
case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
|
||||
desc = "不支持的授权类型"
|
||||
}
|
||||
if len(description) > 0 {
|
||||
desc = description[0]
|
||||
}
|
||||
|
||||
return c.Status(status).JSON(TokenErrResp{
|
||||
Error: string(sErr),
|
||||
Description: desc,
|
||||
})
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region /revoke
|
||||
|
||||
type RevokeReq struct {
|
||||
@@ -280,7 +17,7 @@ type RevokeReq struct {
|
||||
}
|
||||
|
||||
func Revoke(c *fiber.Ctx) error {
|
||||
_, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
|
||||
_, err := auth2.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
// 用户未登录
|
||||
return nil
|
||||
@@ -312,14 +49,14 @@ type IntrospectResp struct {
|
||||
|
||||
func Introspect(c *fiber.Ctx) error {
|
||||
// 验证权限
|
||||
authCtx, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
|
||||
authCtx, err := auth2.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
profile, err := q.User.
|
||||
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
||||
Where(q.User.ID.Eq(authCtx.User.ID)).
|
||||
Omit(q.User.DeletedAt).
|
||||
Take()
|
||||
if err != nil {
|
||||
|
||||
@@ -23,7 +23,7 @@ type ListBillReq struct {
|
||||
// ListBill 获取账单列表
|
||||
func ListBill(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -36,7 +36,7 @@ func ListBill(c *fiber.Ctx) error {
|
||||
|
||||
// 查询账单列表
|
||||
do := q.Bill.
|
||||
Where(q.Bill.UserID.Eq(authContext.Payload.Id))
|
||||
Where(q.Bill.UserID.Eq(authCtx.User.ID))
|
||||
|
||||
if req.Type != nil {
|
||||
do.Where(q.Bill.Type.Eq(int32(*req.Type)))
|
||||
|
||||
@@ -24,7 +24,7 @@ type ListChannelsReq struct {
|
||||
|
||||
func ListChannels(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authContext, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -37,7 +37,7 @@ func ListChannels(c *fiber.Ctx) error {
|
||||
|
||||
// 构造查询条件
|
||||
cond := q.Channel.
|
||||
Where(q.Channel.UserID.Eq(authContext.Payload.Id))
|
||||
Where(q.Channel.UserID.Eq(authContext.User.ID))
|
||||
switch req.AuthType {
|
||||
case s.ChannelAuthTypeIp:
|
||||
cond.Where(q.Channel.AuthIP.Is(true))
|
||||
@@ -110,24 +110,19 @@ type CreateChannelRespItem struct {
|
||||
func CreateChannel(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查用户其他权限
|
||||
user, err := q.User.
|
||||
Where(q.User.ID.Eq(authContext.Payload.Id)).
|
||||
Take()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user := authCtx.User
|
||||
if user.IDToken == nil || *user.IDToken == "" {
|
||||
return fiber.NewError(fiber.StatusForbidden, "账号未实名")
|
||||
}
|
||||
|
||||
count, err := q.Whitelist.Where(
|
||||
q.Whitelist.UserID.Eq(authContext.Payload.Id),
|
||||
q.Whitelist.UserID.Eq(user.ID),
|
||||
q.Whitelist.Host.Eq(c.IP()),
|
||||
).Count()
|
||||
if err != nil {
|
||||
@@ -155,7 +150,7 @@ func CreateChannel(c *fiber.Ctx) error {
|
||||
// 创建通道
|
||||
result, err := s.Channel.CreateChannel(
|
||||
c,
|
||||
authContext.Payload.Id,
|
||||
user.ID,
|
||||
req.ResourceId,
|
||||
req.Protocol,
|
||||
req.AuthType,
|
||||
@@ -198,7 +193,7 @@ type RemoveChannelsReq struct {
|
||||
|
||||
func RemoveChannels(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -210,31 +205,7 @@ func RemoveChannels(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 删除通道
|
||||
err = s.Channel.RemoveChannels(req.ByIds, authCtx.Payload.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.SendStatus(fiber.StatusOK)
|
||||
}
|
||||
|
||||
type RemoveChannelByTaskReq []int32
|
||||
|
||||
func RemoveChannelByTask(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
_, err := auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析请求参数
|
||||
var req RemoveChannelByTaskReq
|
||||
if err := c.BodyParser(&req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除通道
|
||||
err = s.Channel.RemoveChannels(req)
|
||||
err = s.Channel.RemoveChannels(req.ByIds, authCtx.User.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2,8 +2,6 @@ package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"gorm.io/gen/field"
|
||||
"gorm.io/gorm"
|
||||
"log/slog"
|
||||
"platform/pkg/u"
|
||||
"platform/web/auth"
|
||||
@@ -14,6 +12,9 @@ import (
|
||||
q "platform/web/queries"
|
||||
s "platform/web/services"
|
||||
|
||||
"gorm.io/gen/field"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
@@ -120,7 +121,7 @@ type AllEdgesAvailableRespItem struct {
|
||||
|
||||
func AllEdgesAvailable(c *fiber.Ctx) (err error) {
|
||||
// 检查权限
|
||||
_, err = auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
|
||||
_, err = auth.GetAuthCtx(c).PermitSecretClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -37,17 +37,11 @@ type IdentifyRes struct {
|
||||
func Identify(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user, err := q.User.
|
||||
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
||||
Select(q.User.ID, q.User.IDToken).
|
||||
Take()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
user := authCtx.User
|
||||
if user.IDToken != nil && *user.IDToken != "" {
|
||||
// 用户已实名认证
|
||||
return c.JSON(IdentifyRes{
|
||||
@@ -86,7 +80,7 @@ func Identify(c *fiber.Ctx) error {
|
||||
|
||||
// 保存认证中间状态
|
||||
info := idenInfo{
|
||||
Uid: authCtx.Payload.Id,
|
||||
Uid: user.ID,
|
||||
Type: req.Type,
|
||||
Name: req.Name,
|
||||
IdNo: req.IdenNo,
|
||||
|
||||
@@ -40,9 +40,7 @@ type ProxyReportOnlineResp struct {
|
||||
func ProxyReportOnline(c *fiber.Ctx) (err error) {
|
||||
|
||||
// 检查接口权限
|
||||
_, err = auth2.NewProtect(c).Payload(
|
||||
auth2.PayloadInternalServer,
|
||||
).Do()
|
||||
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -149,9 +147,7 @@ type ProxyReportOfflineReq struct {
|
||||
|
||||
func ProxyReportOffline(c *fiber.Ctx) (err error) {
|
||||
// 检查接口权限
|
||||
_, err = auth2.NewProtect(c).Payload(
|
||||
auth2.PayloadInternalServer,
|
||||
).Do()
|
||||
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -193,9 +189,7 @@ type ProxyReportUpdateReq struct {
|
||||
|
||||
func ProxyReportUpdate(c *fiber.Ctx) (err error) {
|
||||
// 检查接口权限
|
||||
_, err = auth2.NewProtect(c).Payload(
|
||||
auth2.PayloadInternalServer,
|
||||
).Do()
|
||||
_, err = auth2.GetAuthCtx(c).PermitSecretClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"gorm.io/gen/field"
|
||||
"platform/pkg/u"
|
||||
"platform/web/auth"
|
||||
"platform/web/core"
|
||||
@@ -12,6 +11,8 @@ import (
|
||||
s "platform/web/services"
|
||||
"time"
|
||||
|
||||
"gorm.io/gen/field"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
)
|
||||
|
||||
@@ -28,7 +29,7 @@ type ListResourceShortReq struct {
|
||||
|
||||
func ListResourceShort(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -41,7 +42,7 @@ func ListResourceShort(c *fiber.Ctx) error {
|
||||
|
||||
// 查询套餐列表
|
||||
do := q.Resource.Where(
|
||||
q.Resource.UserID.Eq(authContext.Payload.Id),
|
||||
q.Resource.UserID.Eq(authCtx.User.ID),
|
||||
q.Resource.Type.Eq(int32(resource2.TypeShort)),
|
||||
)
|
||||
if req.ResourceNo != nil && *req.ResourceNo != "" {
|
||||
@@ -109,7 +110,7 @@ type ListResourceLongReq struct {
|
||||
|
||||
func ListResourceLong(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -122,7 +123,7 @@ func ListResourceLong(c *fiber.Ctx) error {
|
||||
|
||||
// 查询套餐列表
|
||||
do := q.Resource.Where(
|
||||
q.Resource.UserID.Eq(authContext.Payload.Id),
|
||||
q.Resource.UserID.Eq(authCtx.User.ID),
|
||||
q.Resource.Type.Eq(int32(resource2.TypeLong)),
|
||||
)
|
||||
if req.ResourceNo != nil && *req.ResourceNo != "" {
|
||||
@@ -182,7 +183,7 @@ type AllResourceReq struct {
|
||||
|
||||
func AllActiveResource(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -198,7 +199,7 @@ func AllActiveResource(c *fiber.Ctx) error {
|
||||
q.Resource.Long,
|
||||
).
|
||||
Where(
|
||||
q.Resource.UserID.Eq(authCtx.Payload.Id),
|
||||
q.Resource.UserID.Eq(authCtx.User.ID),
|
||||
q.Resource.Active.Is(true),
|
||||
q.Resource.Where(
|
||||
q.Resource.Type.Eq(int32(resource2.TypeShort)),
|
||||
@@ -254,7 +255,7 @@ type StatisticLong struct {
|
||||
|
||||
func StatisticResourceFree(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
session, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -266,7 +267,7 @@ func StatisticResourceFree(c *fiber.Ctx) error {
|
||||
q.Resource.Long,
|
||||
).
|
||||
Where(
|
||||
q.Resource.UserID.Eq(session.Payload.Id),
|
||||
q.Resource.UserID.Eq(authCtx.User.ID),
|
||||
q.Resource.Active.Is(true),
|
||||
).
|
||||
Select(q.Resource.ID, q.Resource.Type).
|
||||
@@ -347,7 +348,7 @@ type StatisticResourceUsageResp []struct {
|
||||
|
||||
func StatisticResourceUsage(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
session, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -359,12 +360,12 @@ func StatisticResourceUsage(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 统计套餐提取数量
|
||||
do := q.LogsUserUsage.Where(q.LogsUserUsage.UserID.Eq(session.Payload.Id))
|
||||
do := q.LogsUserUsage.Where(q.LogsUserUsage.UserID.Eq(authCtx.User.ID))
|
||||
if req.ResourceNo != nil && *req.ResourceNo != "" {
|
||||
var resourceID int32
|
||||
err := q.Resource.
|
||||
Where(
|
||||
q.Resource.UserID.Eq(session.Payload.Id),
|
||||
q.Resource.UserID.Eq(authCtx.User.ID),
|
||||
q.Resource.ResourceNo.Eq(*req.ResourceNo),
|
||||
).
|
||||
Select(q.Resource.ID).
|
||||
@@ -409,7 +410,7 @@ type CreateResourceReq struct {
|
||||
func CreateResource(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -421,7 +422,7 @@ func CreateResource(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 创建套餐
|
||||
err = s.Resource.CreateResourceByBalance(authCtx.Payload.Id, time.Now(), req.CreateResourceData)
|
||||
err = s.Resource.CreateResourceByBalance(authCtx.User.ID, time.Now(), req.CreateResourceData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -431,7 +432,7 @@ func CreateResource(c *fiber.Ctx) error {
|
||||
|
||||
func ResourcePrice(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
_, err := auth.NewProtect(c).Payload(auth.PayloadInternalServer).Do()
|
||||
_, err := auth.GetAuthCtx(c).PermitSecretClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
trade2 "platform/web/domains/trade"
|
||||
g "platform/web/globals"
|
||||
s "platform/web/services"
|
||||
"platform/web/tasks"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -27,7 +26,7 @@ type TradeCreateResp struct {
|
||||
|
||||
func TradeCreate(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.NewProtect(c).Payload(auth.PayloadUser).Do()
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -52,7 +51,7 @@ func TradeCreate(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 创建交易
|
||||
result, err := s.Trade.CreateTrade(authCtx.Payload.Id, time.Now(), &req.CreateTradeData)
|
||||
result, err := s.Trade.CreateTrade(authCtx.User.ID, time.Now(), &req.CreateTradeData)
|
||||
if err != nil {
|
||||
slog.Error("创建交易失败", "error", err)
|
||||
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "创建交易失败"})
|
||||
@@ -70,7 +69,7 @@ type TradeCompleteReq struct {
|
||||
|
||||
func TradeComplete(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
_, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -99,7 +98,7 @@ type TradeCancelReq struct {
|
||||
|
||||
func TradeCancel(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
_, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -119,29 +118,3 @@ func TradeCancel(c *fiber.Ctx) error {
|
||||
|
||||
return c.SendStatus(fiber.StatusNoContent)
|
||||
}
|
||||
|
||||
type TradeCheckReq struct {
|
||||
tasks.CancelTradeData
|
||||
}
|
||||
|
||||
func TradeCancelByTask(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadInternalServer}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 取消交易
|
||||
req := new(TradeCheckReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查订单状态
|
||||
err = s.Trade.CancelTrade(req.TradeNo, req.Method, time.Now())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"platform/web/auth"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
s "platform/web/services"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// region /update
|
||||
@@ -20,7 +21,7 @@ type UpdateUserReq struct {
|
||||
|
||||
func UpdateUser(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -33,7 +34,7 @@ func UpdateUser(c *fiber.Ctx) error {
|
||||
|
||||
// 更新用户信息
|
||||
_, err = q.User.
|
||||
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
||||
Where(q.User.ID.Eq(authCtx.User.ID)).
|
||||
Updates(m.User{
|
||||
Username: &req.Username,
|
||||
Email: &req.Email,
|
||||
@@ -59,7 +60,7 @@ type UpdateAccountReq struct {
|
||||
|
||||
func UpdateAccount(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -72,7 +73,7 @@ func UpdateAccount(c *fiber.Ctx) error {
|
||||
|
||||
// 更新用户信息
|
||||
_, err = q.User.
|
||||
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
||||
Where(q.User.ID.Eq(authCtx.User.ID)).
|
||||
Updates(m.User{
|
||||
Username: &req.Username,
|
||||
Password: &req.Password,
|
||||
@@ -97,7 +98,7 @@ type UpdatePasswordReq struct {
|
||||
|
||||
func UpdatePassword(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -124,7 +125,7 @@ func UpdatePassword(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
_, err = q.User.
|
||||
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
||||
Where(q.User.ID.Eq(authCtx.User.ID)).
|
||||
UpdateColumn(q.User.Password, newHash)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -2,12 +2,14 @@ package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"platform/pkg/env"
|
||||
"platform/web/auth"
|
||||
"platform/web/services"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type VerifierReq struct {
|
||||
@@ -17,7 +19,7 @@ type VerifierReq struct {
|
||||
|
||||
func SmsCode(c *fiber.Ctx) error {
|
||||
|
||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadInternalServer}, []string{})
|
||||
_, err := auth.GetAuthCtx(c).PermitInternalClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -48,3 +50,19 @@ func SmsCode(c *fiber.Ctx) error {
|
||||
// 发送成功
|
||||
return nil
|
||||
}
|
||||
|
||||
func DebugGetSmsCode(c *fiber.Ctx) error {
|
||||
if env.RunMode != env.RunModeDev {
|
||||
return fiber.NewError(fiber.StatusForbidden, "not allowed")
|
||||
}
|
||||
|
||||
code, err := services.Verifier.GetSms(c.Context(), c.Params("phone"))
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return c.SendString("还没有验证码")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return c.SendString(code)
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ type ListWhitelistResp struct {
|
||||
func ListWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -38,8 +38,7 @@ func ListWhitelist(c *fiber.Ctx) error {
|
||||
}
|
||||
|
||||
// 获取白名单信息
|
||||
do := q.Whitelist.
|
||||
Where(q.Whitelist.UserID.Eq(authContext.Payload.Id))
|
||||
do := q.Whitelist.Where(q.Whitelist.UserID.Eq(authCtx.User.ID))
|
||||
|
||||
list, err := q.Whitelist.Where(do).
|
||||
Offset(req.GetOffset()).
|
||||
@@ -77,7 +76,7 @@ type CreateWhitelistReq struct {
|
||||
func CreateWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -96,7 +95,7 @@ func CreateWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 创建白名单
|
||||
err = q.Whitelist.Create(&m.Whitelist{
|
||||
UserID: authContext.Payload.Id,
|
||||
UserID: authCtx.User.ID,
|
||||
Host: req.Host,
|
||||
Remark: &req.Remark,
|
||||
})
|
||||
@@ -111,7 +110,7 @@ type UpdateWhitelistReq struct {
|
||||
|
||||
func UpdateWhitelist(c *fiber.Ctx) error {
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -129,7 +128,7 @@ func UpdateWhitelist(c *fiber.Ctx) error {
|
||||
_, err = q.Whitelist.
|
||||
Where(
|
||||
q.Whitelist.ID.Eq(req.ID),
|
||||
q.Whitelist.UserID.Eq(authContext.Payload.Id),
|
||||
q.Whitelist.UserID.Eq(authCtx.User.ID),
|
||||
).
|
||||
Updates(&m.Whitelist{
|
||||
ID: req.ID,
|
||||
@@ -149,7 +148,7 @@ type RemoveWhitelistReq struct {
|
||||
func RemoveWhitelist(c *fiber.Ctx) error {
|
||||
|
||||
// 检查权限
|
||||
authContext, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
||||
authCtx, err := auth.GetAuthCtx(c).PermitUser()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -175,7 +174,7 @@ func RemoveWhitelist(c *fiber.Ctx) error {
|
||||
_, err = q.Whitelist.
|
||||
Where(
|
||||
q.Whitelist.ID.In(ids...),
|
||||
q.Whitelist.UserID.Eq(authContext.Payload.Id),
|
||||
q.Whitelist.UserID.Eq(authCtx.User.ID),
|
||||
).
|
||||
Update(
|
||||
q.Whitelist.DeletedAt, time.Now(),
|
||||
|
||||
42
web/middlewares.go
Normal file
42
web/middlewares.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"platform/web/auth"
|
||||
|
||||
"github.com/gofiber/contrib/otelfiber/v2"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jxskiss/base62"
|
||||
)
|
||||
|
||||
func ApplyMiddlewares(app *fiber.App) {
|
||||
|
||||
// recover
|
||||
app.Use(recover.New(recover.Config{
|
||||
EnableStackTrace: true,
|
||||
}))
|
||||
|
||||
// metric
|
||||
app.Use(otelfiber.Middleware())
|
||||
|
||||
// logger
|
||||
app.Use(logger.New(logger.Config{
|
||||
Next: func(c *fiber.Ctx) bool {
|
||||
return c.Path() == "/favicon.ico"
|
||||
},
|
||||
}))
|
||||
|
||||
// request id
|
||||
app.Use(requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
binary, _ := uuid.New().MarshalBinary()
|
||||
return base62.EncodeToString(binary)
|
||||
},
|
||||
}))
|
||||
|
||||
// authenticate
|
||||
app.Use(auth.Authenticate())
|
||||
}
|
||||
@@ -17,10 +17,14 @@ type Channel struct {
|
||||
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:通道ID" json:"id"` // 通道ID
|
||||
UserID int32 `gorm:"column:user_id;type:integer;not null;comment:用户ID" json:"user_id"` // 用户ID
|
||||
ProxyID int32 `gorm:"column:proxy_id;type:integer;not null;comment:代理ID" json:"proxy_id"` // 代理ID
|
||||
EdgeID *int32 `gorm:"column:edge_id;type:integer;comment:节点ID" json:"edge_id"` // 节点ID
|
||||
ResourceID int32 `gorm:"column:resource_id;type:integer;not null;comment:套餐ID" json:"resource_id"` // 套餐ID
|
||||
ProxyHost string `gorm:"column:proxy_host;type:character varying(255);not null;comment:代理地址" json:"proxy_host"` // 代理地址
|
||||
ProxyPort int32 `gorm:"column:proxy_port;type:integer;not null;comment:转发端口" json:"proxy_port"` // 转发端口
|
||||
EdgeHost *string `gorm:"column:edge_host;type:character varying(255);comment:节点地址" json:"edge_host"` // 节点地址
|
||||
Protocol *int32 `gorm:"column:protocol;type:integer;comment:协议类型:1-http,2-https,3-socks5" json:"protocol"` // 协议类型:1-http,2-https,3-socks5
|
||||
AuthIP bool `gorm:"column:auth_ip;type:boolean;not null;comment:IP认证" json:"auth_ip"` // IP认证
|
||||
Whitelists *string `gorm:"column:whitelists;type:text;comment:IP白名单,逗号分隔" json:"whitelists"` // IP白名单,逗号分隔
|
||||
AuthPass bool `gorm:"column:auth_pass;type:boolean;not null;comment:密码认证" json:"auth_pass"` // 密码认证
|
||||
Username *string `gorm:"column:username;type:character varying(255);comment:用户名" json:"username"` // 用户名
|
||||
Password *string `gorm:"column:password;type:character varying(255);comment:密码" json:"password"` // 密码
|
||||
@@ -28,10 +32,6 @@ type Channel struct {
|
||||
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
EdgeHost *string `gorm:"column:edge_host;type:character varying(255);comment:节点地址" json:"edge_host"` // 节点地址
|
||||
EdgeID *int32 `gorm:"column:edge_id;type:integer;comment:节点ID" json:"edge_id"` // 节点ID
|
||||
Whitelists *string `gorm:"column:whitelists;type:text;comment:IP白名单,逗号分隔" json:"whitelists"` // IP白名单,逗号分隔
|
||||
ResourceID int32 `gorm:"column:resource_id;type:integer;not null;comment:套餐ID" json:"resource_id"` // 套餐ID
|
||||
}
|
||||
|
||||
// TableName Channel's table name
|
||||
|
||||
@@ -14,21 +14,18 @@ const TableNameClient = "client"
|
||||
|
||||
// Client mapped from table <client>
|
||||
type Client struct {
|
||||
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:客户端ID" json:"id"` // 客户端ID
|
||||
ClientID string `gorm:"column:client_id;type:character varying(255);not null;comment:OAuth2客户端标识符" json:"client_id"` // OAuth2客户端标识符
|
||||
ClientSecret string `gorm:"column:client_secret;type:character varying(255);not null;comment:OAuth2客户端密钥" json:"client_secret"` // OAuth2客户端密钥
|
||||
RedirectURI *string `gorm:"column:redirect_uri;type:character varying(255);comment:OAuth2 重定向URI" json:"redirect_uri"` // OAuth2 重定向URI
|
||||
GrantCode bool `gorm:"column:grant_code;type:boolean;not null;comment:允许授权码授予" json:"grant_code"` // 允许授权码授予
|
||||
GrantClient bool `gorm:"column:grant_client;type:boolean;not null;comment:允许客户端凭证授予" json:"grant_client"` // 允许客户端凭证授予
|
||||
GrantRefresh bool `gorm:"column:grant_refresh;type:boolean;not null;comment:允许刷新令牌授予" json:"grant_refresh"` // 允许刷新令牌授予
|
||||
GrantPassword bool `gorm:"column:grant_password;type:boolean;not null;comment:允许密码授予" json:"grant_password"` // 允许密码授予
|
||||
Spec int32 `gorm:"column:spec;type:integer;not null;comment:安全规范:1-native,2-browser,3-web,4-trusted" json:"spec"` // 安全规范:1-native,2-browser,3-web,4-trusted
|
||||
Name string `gorm:"column:name;type:character varying(255);not null;comment:名称" json:"name"` // 名称
|
||||
Icon *string `gorm:"column:icon;type:character varying(255);comment:图标URL" json:"icon"` // 图标URL
|
||||
Status int32 `gorm:"column:status;type:integer;not null;default:1;comment:状态:0-禁用,1-正常" json:"status"` // 状态:0-禁用,1-正常
|
||||
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:客户端ID" json:"id"` // 客户端ID
|
||||
ClientID string `gorm:"column:client_id;type:character varying(255);not null;comment:OAuth2客户端标识符" json:"client_id"` // OAuth2客户端标识符
|
||||
ClientSecret string `gorm:"column:client_secret;type:character varying(255);not null;comment:OAuth2客户端密钥" json:"client_secret"` // OAuth2客户端密钥
|
||||
RedirectURI *string `gorm:"column:redirect_uri;type:character varying(255);comment:OAuth2 重定向URI" json:"redirect_uri"` // OAuth2 重定向URI
|
||||
Spec int32 `gorm:"column:spec;type:integer;not null;comment:安全规范:1-native,2-browser,3-web,4-api" json:"spec"` // 安全规范:1-native,2-browser,3-web,4-api
|
||||
Name string `gorm:"column:name;type:character varying(255);not null;comment:名称" json:"name"` // 名称
|
||||
Icon *string `gorm:"column:icon;type:character varying(255);comment:图标URL" json:"icon"` // 图标URL
|
||||
Status int32 `gorm:"column:status;type:integer;not null;default:1;comment:状态:0-禁用,1-正常" json:"status"` // 状态:0-禁用,1-正常
|
||||
Type int32 `gorm:"column:type;type:integer;not null;comment:类型:0-普通,1-官方" json:"type"` // 类型:0-普通,1-官方
|
||||
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
}
|
||||
|
||||
// TableName Client's table name
|
||||
|
||||
@@ -12,12 +12,12 @@ const TableNameLogsLogin = "logs_login"
|
||||
type LogsLogin struct {
|
||||
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:登录日志ID" json:"id"` // 登录日志ID
|
||||
IP string `gorm:"column:ip;type:character varying(45);not null;comment:IP地址" json:"ip"` // IP地址
|
||||
Ua string `gorm:"column:ua;type:character varying(255);not null;comment:用户代理" json:"ua"` // 用户代理
|
||||
UA string `gorm:"column:ua;type:character varying(255);not null;comment:用户代理" json:"ua"` // 用户代理
|
||||
GrantType string `gorm:"column:grant_type;type:character varying(255);not null;comment:授权类型:authorization_code-授权码模式,client_credentials-客户端凭证模式,refresh_token-刷新令牌模式,password-密码模式" json:"grant_type"` // 授权类型:authorization_code-授权码模式,client_credentials-客户端凭证模式,refresh_token-刷新令牌模式,password-密码模式
|
||||
PasswordGrantType string `gorm:"column:password_grant_type;type:character varying(255);not null;comment:密码模式子授权类型:password-账号密码,phone_code-手机验证码,email_code-邮箱验证码" json:"password_grant_type"` // 密码模式子授权类型:password-账号密码,phone_code-手机验证码,email_code-邮箱验证码
|
||||
Success bool `gorm:"column:success;type:boolean;not null;comment:登录是否成功" json:"success"` // 登录是否成功
|
||||
Time orm.LocalDateTime `gorm:"column:time;type:timestamp without time zone;not null;comment:登录时间" json:"time"` // 登录时间
|
||||
UserID *int32 `gorm:"column:user_id;type:integer;comment:用户ID" json:"user_id"` // 用户ID
|
||||
Time orm.LocalDateTime `gorm:"column:time;type:timestamp without time zone;not null;comment:登录时间" json:"time"` // 登录时间
|
||||
}
|
||||
|
||||
// TableName LogsLogin's table name
|
||||
|
||||
@@ -10,17 +10,17 @@ const TableNameLogsRequest = "logs_request"
|
||||
|
||||
// LogsRequest mapped from table <logs_request>
|
||||
type LogsRequest struct {
|
||||
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:访问日志ID" json:"id"` // 访问日志ID
|
||||
Identity int32 `gorm:"column:identity;type:integer;not null;comment:访客身份:0-游客,1-用户,2-管理员,3-公共服务,4-安全服务,5-内部服务" json:"identity"` // 访客身份:0-游客,1-用户,2-管理员,3-公共服务,4-安全服务,5-内部服务
|
||||
Visitor *int32 `gorm:"column:visitor;type:integer;comment:访客ID" json:"visitor"` // 访客ID
|
||||
IP string `gorm:"column:ip;type:character varying(45);not null;comment:IP地址" json:"ip"` // IP地址
|
||||
Ua *string `gorm:"column:ua;type:character varying(255);comment:用户代理" json:"ua"` // 用户代理
|
||||
Method string `gorm:"column:method;type:character varying(10);not null;comment:请求方法" json:"method"` // 请求方法
|
||||
Path string `gorm:"column:path;type:character varying(255);not null;comment:请求路径" json:"path"` // 请求路径
|
||||
Latency string `gorm:"column:latency;type:character varying(255);not null;comment:请求延迟" json:"latency"` // 请求延迟
|
||||
Status int32 `gorm:"column:status;type:integer;not null;comment:响应状态码" json:"status"` // 响应状态码
|
||||
Error *string `gorm:"column:error;type:text;comment:错误信息" json:"error"` // 错误信息
|
||||
Time orm.LocalDateTime `gorm:"column:time;type:timestamp without time zone;not null;comment:请求时间" json:"time"` // 请求时间
|
||||
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:访问日志ID" json:"id"` // 访问日志ID
|
||||
IP string `gorm:"column:ip;type:character varying(45);not null;comment:IP地址" json:"ip"` // IP地址
|
||||
UA string `gorm:"column:ua;type:character varying(255);not null;comment:用户代理" json:"ua"` // 用户代理
|
||||
UserID *int32 `gorm:"column:user_id;type:integer;comment:用户ID" json:"user_id"` // 用户ID
|
||||
ClientID *int32 `gorm:"column:client_id;type:integer;comment:客户端ID" json:"client_id"` // 客户端ID
|
||||
Method string `gorm:"column:method;type:character varying(10);not null;comment:请求方法" json:"method"` // 请求方法
|
||||
Path string `gorm:"column:path;type:character varying(255);not null;comment:请求路径" json:"path"` // 请求路径
|
||||
Status int32 `gorm:"column:status;type:integer;not null;comment:响应状态码" json:"status"` // 响应状态码
|
||||
Error *string `gorm:"column:error;type:text;comment:错误信息" json:"error"` // 错误信息
|
||||
Time orm.LocalDateTime `gorm:"column:time;type:timestamp without time zone;not null;comment:请求时间" json:"time"` // 请求时间
|
||||
Latency string `gorm:"column:latency;type:character varying(255);not null;comment:请求延迟" json:"latency"` // 请求延迟
|
||||
}
|
||||
|
||||
// TableName LogsRequest's table name
|
||||
|
||||
@@ -18,12 +18,12 @@ type Proxy struct {
|
||||
Version int32 `gorm:"column:version;type:integer;not null;comment:代理服务版本" json:"version"` // 代理服务版本
|
||||
Name string `gorm:"column:name;type:character varying(255);not null;comment:代理服务名称" json:"name"` // 代理服务名称
|
||||
Host string `gorm:"column:host;type:character varying(255);not null;comment:代理服务地址" json:"host"` // 代理服务地址
|
||||
Type int32 `gorm:"column:type;type:integer;not null;comment:代理服务类型:1-三方,2-自有" json:"type"` // 代理服务类型:1-三方,2-自有
|
||||
Secret *string `gorm:"column:secret;type:character varying(255);comment:代理服务密钥" json:"secret"` // 代理服务密钥
|
||||
Type int32 `gorm:"column:type;type:integer;not null;comment:代理服务类型:1-三方,2-自有" json:"type"` // 代理服务类型:1-三方,2-自有
|
||||
Status int32 `gorm:"column:status;type:integer;not null;comment:代理服务状态:0-离线,1-在线" json:"status"` // 代理服务状态:0-离线,1-在线
|
||||
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
Status int32 `gorm:"column:status;type:integer;not null;comment:代理服务状态:0-离线,1-在线" json:"status"` // 代理服务状态:0-离线,1-在线
|
||||
Edges []Edge `gorm:"foreignKey:ProxyID;references:ID" json:"edges"`
|
||||
}
|
||||
|
||||
|
||||
@@ -14,20 +14,23 @@ const TableNameSession = "session"
|
||||
|
||||
// Session mapped from table <session>
|
||||
type Session struct {
|
||||
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:会话ID" json:"id"` // 会话ID
|
||||
UserID *int32 `gorm:"column:user_id;type:integer;comment:用户ID" json:"user_id"` // 用户ID
|
||||
ClientID *int32 `gorm:"column:client_id;type:integer;comment:客户端ID" json:"client_id"` // 客户端ID
|
||||
IP *string `gorm:"column:ip;type:character varying(45);comment:IP地址" json:"ip"` // IP地址
|
||||
Ua *string `gorm:"column:ua;type:character varying(255);comment:用户代理" json:"ua"` // 用户代理
|
||||
GrantType string `gorm:"column:grant_type;type:character varying(255);not null;default:0;comment:授权类型:authorization_code-授权码模式,client_credentials-客户端凭证模式,refresh_token-刷新令牌模式,password-密码模式" json:"grant_type"` // 授权类型:authorization_code-授权码模式,client_credentials-客户端凭证模式,refresh_token-刷新令牌模式,password-密码模式
|
||||
AccessToken string `gorm:"column:access_token;type:character varying(255);not null;comment:访问令牌" json:"access_token"` // 访问令牌
|
||||
AccessTokenExpires orm.LocalDateTime `gorm:"column:access_token_expires;type:timestamp without time zone;not null;comment:访问令牌过期时间" json:"access_token_expires"` // 访问令牌过期时间
|
||||
RefreshToken *string `gorm:"column:refresh_token;type:character varying(255);comment:刷新令牌" json:"refresh_token"` // 刷新令牌
|
||||
RefreshTokenExpires *orm.LocalDateTime `gorm:"column:refresh_token_expires;type:timestamp without time zone;comment:刷新令牌过期时间" json:"refresh_token_expires"` // 刷新令牌过期时间
|
||||
Scopes_ *string `gorm:"column:scopes;type:character varying(255);comment:权限范围" json:"scopes"` // 权限范围
|
||||
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:会话ID" json:"id"` // 会话ID
|
||||
UserID *int32 `gorm:"column:user_id;type:integer;comment:用户ID" json:"user_id"` // 用户ID
|
||||
AdminID *int32 `gorm:"column:admin_id;type:integer;comment:管理员ID" json:"admin_id"` // 管理员ID
|
||||
ClientID *int32 `gorm:"column:client_id;type:integer;comment:客户端ID" json:"client_id"` // 客户端ID
|
||||
IP *string `gorm:"column:ip;type:character varying(45);comment:IP地址" json:"ip"` // IP地址
|
||||
UA *string `gorm:"column:ua;type:character varying(255);comment:用户代理" json:"ua"` // 用户代理
|
||||
AccessToken string `gorm:"column:access_token;type:character varying(255);not null;comment:访问令牌" json:"access_token"` // 访问令牌
|
||||
AccessTokenExpires orm.LocalDateTime `gorm:"column:access_token_expires;type:timestamp without time zone;not null;comment:访问令牌过期时间" json:"access_token_expires"` // 访问令牌过期时间
|
||||
RefreshToken *string `gorm:"column:refresh_token;type:character varying(255);comment:刷新令牌" json:"refresh_token"` // 刷新令牌
|
||||
RefreshTokenExpires *orm.LocalDateTime `gorm:"column:refresh_token_expires;type:timestamp without time zone;comment:刷新令牌过期时间" json:"refresh_token_expires"` // 刷新令牌过期时间
|
||||
Scopes_ *string `gorm:"column:scopes;type:character varying(255);comment:权限范围" json:"scopes"` // 权限范围
|
||||
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
User *User `gorm:"foreignKey:UserID" json:"user"`
|
||||
Admin *Admin `gorm:"foreignKey:UserID" json:"admin"`
|
||||
Client *Client `gorm:"belongsTo:ID;foreignKey:ClientID" json:"client"`
|
||||
}
|
||||
|
||||
// TableName Session's table name
|
||||
|
||||
@@ -15,26 +15,26 @@ const TableNameTrade = "trade"
|
||||
|
||||
// Trade mapped from table <trade>
|
||||
type Trade struct {
|
||||
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:订单ID" json:"id"` // 订单ID
|
||||
UserID int32 `gorm:"column:user_id;type:integer;not null;comment:用户ID" json:"user_id"` // 用户ID
|
||||
InnerNo string `gorm:"column:inner_no;type:character varying(255);not null;comment:内部订单号" json:"inner_no"` // 内部订单号
|
||||
OuterNo *string `gorm:"column:outer_no;type:character varying(255);comment:外部订单号" json:"outer_no"` // 外部订单号
|
||||
Type int32 `gorm:"column:type;type:integer;not null;comment:订单类型:1-购买产品,2-充值余额" json:"type"` // 订单类型:1-购买产品,2-充值余额
|
||||
Subject string `gorm:"column:subject;type:character varying(255);not null;comment:订单主题" json:"subject"` // 订单主题
|
||||
Remark *string `gorm:"column:remark;type:character varying(255);comment:订单备注" json:"remark"` // 订单备注
|
||||
Amount decimal.Decimal `gorm:"column:amount;type:numeric(12,2);not null;comment:订单总金额" json:"amount"` // 订单总金额
|
||||
Payment decimal.Decimal `gorm:"column:payment;type:numeric(12,2);not null;comment:支付金额" json:"payment"` // 支付金额
|
||||
Method int32 `gorm:"column:method;type:integer;not null;comment:支付方式:1-支付宝,2-微信,3-商福通渠道支付宝,4-商福通渠道微信" json:"method"` // 支付方式:1-支付宝,2-微信,3-商福通渠道支付宝,4-商福通渠道微信
|
||||
Status int32 `gorm:"column:status;type:integer;not null;comment:订单状态:0-待支付,1-已支付,2-已取消" json:"status"` // 订单状态:0-待支付,1-已支付,2-已取消
|
||||
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
Acquirer *int32 `gorm:"column:acquirer;type:integer;comment:收单机构:1-支付宝,2-微信,3-银联" json:"acquirer"` // 收单机构:1-支付宝,2-微信,3-银联
|
||||
Platform int32 `gorm:"column:platform;type:integer;not null;comment:支付平台:1-电脑网站,2-手机网站" json:"platform"` // 支付平台:1-电脑网站,2-手机网站
|
||||
ID int32 `gorm:"column:id;type:integer;primaryKey;autoIncrement:true;comment:订单ID" json:"id"` // 订单ID
|
||||
UserID int32 `gorm:"column:user_id;type:integer;not null;comment:用户ID" json:"user_id"` // 用户ID
|
||||
InnerNo string `gorm:"column:inner_no;type:character varying(255);not null;comment:内部订单号" json:"inner_no"` // 内部订单号
|
||||
OuterNo *string `gorm:"column:outer_no;type:character varying(255);comment:外部订单号" json:"outer_no"` // 外部订单号
|
||||
Type int32 `gorm:"column:type;type:integer;not null;comment:订单类型:1-购买产品,2-充值余额" json:"type"` // 订单类型:1-购买产品,2-充值余额
|
||||
Subject string `gorm:"column:subject;type:character varying(255);not null;comment:订单主题" json:"subject"` // 订单主题
|
||||
Remark *string `gorm:"column:remark;type:character varying(255);comment:订单备注" json:"remark"` // 订单备注
|
||||
Amount decimal.Decimal `gorm:"column:amount;type:numeric(12,2);not null;comment:订单总金额" json:"amount"` // 订单总金额
|
||||
Payment decimal.Decimal `gorm:"column:payment;type:numeric(12,2);not null;comment:实际支付金额" json:"payment"` // 实际支付金额
|
||||
Method int32 `gorm:"column:method;type:integer;not null;comment:支付方式:1-支付宝,2-微信,3-商福通,4-商福通渠道支付宝,5-商福通渠道微信" json:"method"` // 支付方式:1-支付宝,2-微信,3-商福通,4-商福通渠道支付宝,5-商福通渠道微信
|
||||
Platform int32 `gorm:"column:platform;type:integer;not null;comment:支付平台:1-电脑网站,2-手机网站" json:"platform"` // 支付平台:1-电脑网站,2-手机网站
|
||||
Acquirer *int32 `gorm:"column:acquirer;type:integer;comment:收单机构:1-支付宝,2-微信,3-银联" json:"acquirer"` // 收单机构:1-支付宝,2-微信,3-银联
|
||||
Status int32 `gorm:"column:status;type:integer;not null;comment:订单状态:0-待支付,1-已支付,2-已取消" json:"status"` // 订单状态:0-待支付,1-已支付,2-已取消
|
||||
Refunded bool `gorm:"column:refunded;type:boolean;not null" json:"refunded"`
|
||||
PaymentURL *string `gorm:"column:payment_url;type:text;comment:支付链接" json:"payment_url"` // 支付链接
|
||||
CompletedAt *orm.LocalDateTime `gorm:"column:completed_at;type:timestamp without time zone;comment:支付时间" json:"completed_at"` // 支付时间
|
||||
CanceledAt *orm.LocalDateTime `gorm:"column:canceled_at;type:timestamp without time zone;comment:取消时间" json:"canceled_at"` // 取消时间
|
||||
Refunded bool `gorm:"column:refunded;type:boolean;not null" json:"refunded"`
|
||||
CreatedAt *orm.LocalDateTime `gorm:"column:created_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt *orm.LocalDateTime `gorm:"column:updated_at;type:timestamp without time zone;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;type:timestamp without time zone;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
}
|
||||
|
||||
// TableName Trade's table name
|
||||
|
||||
@@ -30,10 +30,14 @@ func newChannel(db *gorm.DB, opts ...gen.DOOption) channel {
|
||||
_channel.ID = field.NewInt32(tableName, "id")
|
||||
_channel.UserID = field.NewInt32(tableName, "user_id")
|
||||
_channel.ProxyID = field.NewInt32(tableName, "proxy_id")
|
||||
_channel.EdgeID = field.NewInt32(tableName, "edge_id")
|
||||
_channel.ResourceID = field.NewInt32(tableName, "resource_id")
|
||||
_channel.ProxyHost = field.NewString(tableName, "proxy_host")
|
||||
_channel.ProxyPort = field.NewInt32(tableName, "proxy_port")
|
||||
_channel.EdgeHost = field.NewString(tableName, "edge_host")
|
||||
_channel.Protocol = field.NewInt32(tableName, "protocol")
|
||||
_channel.AuthIP = field.NewBool(tableName, "auth_ip")
|
||||
_channel.Whitelists = field.NewString(tableName, "whitelists")
|
||||
_channel.AuthPass = field.NewBool(tableName, "auth_pass")
|
||||
_channel.Username = field.NewString(tableName, "username")
|
||||
_channel.Password = field.NewString(tableName, "password")
|
||||
@@ -41,10 +45,6 @@ func newChannel(db *gorm.DB, opts ...gen.DOOption) channel {
|
||||
_channel.CreatedAt = field.NewField(tableName, "created_at")
|
||||
_channel.UpdatedAt = field.NewField(tableName, "updated_at")
|
||||
_channel.DeletedAt = field.NewField(tableName, "deleted_at")
|
||||
_channel.EdgeHost = field.NewString(tableName, "edge_host")
|
||||
_channel.EdgeID = field.NewInt32(tableName, "edge_id")
|
||||
_channel.Whitelists = field.NewString(tableName, "whitelists")
|
||||
_channel.ResourceID = field.NewInt32(tableName, "resource_id")
|
||||
|
||||
_channel.fillFieldMap()
|
||||
|
||||
@@ -58,10 +58,14 @@ type channel struct {
|
||||
ID field.Int32 // 通道ID
|
||||
UserID field.Int32 // 用户ID
|
||||
ProxyID field.Int32 // 代理ID
|
||||
EdgeID field.Int32 // 节点ID
|
||||
ResourceID field.Int32 // 套餐ID
|
||||
ProxyHost field.String // 代理地址
|
||||
ProxyPort field.Int32 // 转发端口
|
||||
EdgeHost field.String // 节点地址
|
||||
Protocol field.Int32 // 协议类型:1-http,2-https,3-socks5
|
||||
AuthIP field.Bool // IP认证
|
||||
Whitelists field.String // IP白名单,逗号分隔
|
||||
AuthPass field.Bool // 密码认证
|
||||
Username field.String // 用户名
|
||||
Password field.String // 密码
|
||||
@@ -69,10 +73,6 @@ type channel struct {
|
||||
CreatedAt field.Field // 创建时间
|
||||
UpdatedAt field.Field // 更新时间
|
||||
DeletedAt field.Field // 删除时间
|
||||
EdgeHost field.String // 节点地址
|
||||
EdgeID field.Int32 // 节点ID
|
||||
Whitelists field.String // IP白名单,逗号分隔
|
||||
ResourceID field.Int32 // 套餐ID
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
@@ -92,10 +92,14 @@ func (c *channel) updateTableName(table string) *channel {
|
||||
c.ID = field.NewInt32(table, "id")
|
||||
c.UserID = field.NewInt32(table, "user_id")
|
||||
c.ProxyID = field.NewInt32(table, "proxy_id")
|
||||
c.EdgeID = field.NewInt32(table, "edge_id")
|
||||
c.ResourceID = field.NewInt32(table, "resource_id")
|
||||
c.ProxyHost = field.NewString(table, "proxy_host")
|
||||
c.ProxyPort = field.NewInt32(table, "proxy_port")
|
||||
c.EdgeHost = field.NewString(table, "edge_host")
|
||||
c.Protocol = field.NewInt32(table, "protocol")
|
||||
c.AuthIP = field.NewBool(table, "auth_ip")
|
||||
c.Whitelists = field.NewString(table, "whitelists")
|
||||
c.AuthPass = field.NewBool(table, "auth_pass")
|
||||
c.Username = field.NewString(table, "username")
|
||||
c.Password = field.NewString(table, "password")
|
||||
@@ -103,10 +107,6 @@ func (c *channel) updateTableName(table string) *channel {
|
||||
c.CreatedAt = field.NewField(table, "created_at")
|
||||
c.UpdatedAt = field.NewField(table, "updated_at")
|
||||
c.DeletedAt = field.NewField(table, "deleted_at")
|
||||
c.EdgeHost = field.NewString(table, "edge_host")
|
||||
c.EdgeID = field.NewInt32(table, "edge_id")
|
||||
c.Whitelists = field.NewString(table, "whitelists")
|
||||
c.ResourceID = field.NewInt32(table, "resource_id")
|
||||
|
||||
c.fillFieldMap()
|
||||
|
||||
@@ -127,10 +127,14 @@ func (c *channel) fillFieldMap() {
|
||||
c.fieldMap["id"] = c.ID
|
||||
c.fieldMap["user_id"] = c.UserID
|
||||
c.fieldMap["proxy_id"] = c.ProxyID
|
||||
c.fieldMap["edge_id"] = c.EdgeID
|
||||
c.fieldMap["resource_id"] = c.ResourceID
|
||||
c.fieldMap["proxy_host"] = c.ProxyHost
|
||||
c.fieldMap["proxy_port"] = c.ProxyPort
|
||||
c.fieldMap["edge_host"] = c.EdgeHost
|
||||
c.fieldMap["protocol"] = c.Protocol
|
||||
c.fieldMap["auth_ip"] = c.AuthIP
|
||||
c.fieldMap["whitelists"] = c.Whitelists
|
||||
c.fieldMap["auth_pass"] = c.AuthPass
|
||||
c.fieldMap["username"] = c.Username
|
||||
c.fieldMap["password"] = c.Password
|
||||
@@ -138,10 +142,6 @@ func (c *channel) fillFieldMap() {
|
||||
c.fieldMap["created_at"] = c.CreatedAt
|
||||
c.fieldMap["updated_at"] = c.UpdatedAt
|
||||
c.fieldMap["deleted_at"] = c.DeletedAt
|
||||
c.fieldMap["edge_host"] = c.EdgeHost
|
||||
c.fieldMap["edge_id"] = c.EdgeID
|
||||
c.fieldMap["whitelists"] = c.Whitelists
|
||||
c.fieldMap["resource_id"] = c.ResourceID
|
||||
}
|
||||
|
||||
func (c channel) clone(db *gorm.DB) channel {
|
||||
|
||||
@@ -31,14 +31,11 @@ func newClient(db *gorm.DB, opts ...gen.DOOption) client {
|
||||
_client.ClientID = field.NewString(tableName, "client_id")
|
||||
_client.ClientSecret = field.NewString(tableName, "client_secret")
|
||||
_client.RedirectURI = field.NewString(tableName, "redirect_uri")
|
||||
_client.GrantCode = field.NewBool(tableName, "grant_code")
|
||||
_client.GrantClient = field.NewBool(tableName, "grant_client")
|
||||
_client.GrantRefresh = field.NewBool(tableName, "grant_refresh")
|
||||
_client.GrantPassword = field.NewBool(tableName, "grant_password")
|
||||
_client.Spec = field.NewInt32(tableName, "spec")
|
||||
_client.Name = field.NewString(tableName, "name")
|
||||
_client.Icon = field.NewString(tableName, "icon")
|
||||
_client.Status = field.NewInt32(tableName, "status")
|
||||
_client.Type = field.NewInt32(tableName, "type")
|
||||
_client.CreatedAt = field.NewField(tableName, "created_at")
|
||||
_client.UpdatedAt = field.NewField(tableName, "updated_at")
|
||||
_client.DeletedAt = field.NewField(tableName, "deleted_at")
|
||||
@@ -51,22 +48,19 @@ func newClient(db *gorm.DB, opts ...gen.DOOption) client {
|
||||
type client struct {
|
||||
clientDo
|
||||
|
||||
ALL field.Asterisk
|
||||
ID field.Int32 // 客户端ID
|
||||
ClientID field.String // OAuth2客户端标识符
|
||||
ClientSecret field.String // OAuth2客户端密钥
|
||||
RedirectURI field.String // OAuth2 重定向URI
|
||||
GrantCode field.Bool // 允许授权码授予
|
||||
GrantClient field.Bool // 允许客户端凭证授予
|
||||
GrantRefresh field.Bool // 允许刷新令牌授予
|
||||
GrantPassword field.Bool // 允许密码授予
|
||||
Spec field.Int32 // 安全规范:1-native,2-browser,3-web,4-trusted
|
||||
Name field.String // 名称
|
||||
Icon field.String // 图标URL
|
||||
Status field.Int32 // 状态:0-禁用,1-正常
|
||||
CreatedAt field.Field // 创建时间
|
||||
UpdatedAt field.Field // 更新时间
|
||||
DeletedAt field.Field // 删除时间
|
||||
ALL field.Asterisk
|
||||
ID field.Int32 // 客户端ID
|
||||
ClientID field.String // OAuth2客户端标识符
|
||||
ClientSecret field.String // OAuth2客户端密钥
|
||||
RedirectURI field.String // OAuth2 重定向URI
|
||||
Spec field.Int32 // 安全规范:1-native,2-browser,3-web,4-api
|
||||
Name field.String // 名称
|
||||
Icon field.String // 图标URL
|
||||
Status field.Int32 // 状态:0-禁用,1-正常
|
||||
Type field.Int32 // 类型:0-普通,1-官方
|
||||
CreatedAt field.Field // 创建时间
|
||||
UpdatedAt field.Field // 更新时间
|
||||
DeletedAt field.Field // 删除时间
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
@@ -87,14 +81,11 @@ func (c *client) updateTableName(table string) *client {
|
||||
c.ClientID = field.NewString(table, "client_id")
|
||||
c.ClientSecret = field.NewString(table, "client_secret")
|
||||
c.RedirectURI = field.NewString(table, "redirect_uri")
|
||||
c.GrantCode = field.NewBool(table, "grant_code")
|
||||
c.GrantClient = field.NewBool(table, "grant_client")
|
||||
c.GrantRefresh = field.NewBool(table, "grant_refresh")
|
||||
c.GrantPassword = field.NewBool(table, "grant_password")
|
||||
c.Spec = field.NewInt32(table, "spec")
|
||||
c.Name = field.NewString(table, "name")
|
||||
c.Icon = field.NewString(table, "icon")
|
||||
c.Status = field.NewInt32(table, "status")
|
||||
c.Type = field.NewInt32(table, "type")
|
||||
c.CreatedAt = field.NewField(table, "created_at")
|
||||
c.UpdatedAt = field.NewField(table, "updated_at")
|
||||
c.DeletedAt = field.NewField(table, "deleted_at")
|
||||
@@ -114,19 +105,16 @@ func (c *client) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
|
||||
}
|
||||
|
||||
func (c *client) fillFieldMap() {
|
||||
c.fieldMap = make(map[string]field.Expr, 15)
|
||||
c.fieldMap = make(map[string]field.Expr, 12)
|
||||
c.fieldMap["id"] = c.ID
|
||||
c.fieldMap["client_id"] = c.ClientID
|
||||
c.fieldMap["client_secret"] = c.ClientSecret
|
||||
c.fieldMap["redirect_uri"] = c.RedirectURI
|
||||
c.fieldMap["grant_code"] = c.GrantCode
|
||||
c.fieldMap["grant_client"] = c.GrantClient
|
||||
c.fieldMap["grant_refresh"] = c.GrantRefresh
|
||||
c.fieldMap["grant_password"] = c.GrantPassword
|
||||
c.fieldMap["spec"] = c.Spec
|
||||
c.fieldMap["name"] = c.Name
|
||||
c.fieldMap["icon"] = c.Icon
|
||||
c.fieldMap["status"] = c.Status
|
||||
c.fieldMap["type"] = c.Type
|
||||
c.fieldMap["created_at"] = c.CreatedAt
|
||||
c.fieldMap["updated_at"] = c.UpdatedAt
|
||||
c.fieldMap["deleted_at"] = c.DeletedAt
|
||||
|
||||
@@ -29,12 +29,12 @@ func newLogsLogin(db *gorm.DB, opts ...gen.DOOption) logsLogin {
|
||||
_logsLogin.ALL = field.NewAsterisk(tableName)
|
||||
_logsLogin.ID = field.NewInt32(tableName, "id")
|
||||
_logsLogin.IP = field.NewString(tableName, "ip")
|
||||
_logsLogin.Ua = field.NewString(tableName, "ua")
|
||||
_logsLogin.UA = field.NewString(tableName, "ua")
|
||||
_logsLogin.GrantType = field.NewString(tableName, "grant_type")
|
||||
_logsLogin.PasswordGrantType = field.NewString(tableName, "password_grant_type")
|
||||
_logsLogin.Success = field.NewBool(tableName, "success")
|
||||
_logsLogin.Time = field.NewField(tableName, "time")
|
||||
_logsLogin.UserID = field.NewInt32(tableName, "user_id")
|
||||
_logsLogin.Time = field.NewField(tableName, "time")
|
||||
|
||||
_logsLogin.fillFieldMap()
|
||||
|
||||
@@ -47,12 +47,12 @@ type logsLogin struct {
|
||||
ALL field.Asterisk
|
||||
ID field.Int32 // 登录日志ID
|
||||
IP field.String // IP地址
|
||||
Ua field.String // 用户代理
|
||||
UA field.String // 用户代理
|
||||
GrantType field.String // 授权类型:authorization_code-授权码模式,client_credentials-客户端凭证模式,refresh_token-刷新令牌模式,password-密码模式
|
||||
PasswordGrantType field.String // 密码模式子授权类型:password-账号密码,phone_code-手机验证码,email_code-邮箱验证码
|
||||
Success field.Bool // 登录是否成功
|
||||
Time field.Field // 登录时间
|
||||
UserID field.Int32 // 用户ID
|
||||
Time field.Field // 登录时间
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
@@ -71,12 +71,12 @@ func (l *logsLogin) updateTableName(table string) *logsLogin {
|
||||
l.ALL = field.NewAsterisk(table)
|
||||
l.ID = field.NewInt32(table, "id")
|
||||
l.IP = field.NewString(table, "ip")
|
||||
l.Ua = field.NewString(table, "ua")
|
||||
l.UA = field.NewString(table, "ua")
|
||||
l.GrantType = field.NewString(table, "grant_type")
|
||||
l.PasswordGrantType = field.NewString(table, "password_grant_type")
|
||||
l.Success = field.NewBool(table, "success")
|
||||
l.Time = field.NewField(table, "time")
|
||||
l.UserID = field.NewInt32(table, "user_id")
|
||||
l.Time = field.NewField(table, "time")
|
||||
|
||||
l.fillFieldMap()
|
||||
|
||||
@@ -96,12 +96,12 @@ func (l *logsLogin) fillFieldMap() {
|
||||
l.fieldMap = make(map[string]field.Expr, 8)
|
||||
l.fieldMap["id"] = l.ID
|
||||
l.fieldMap["ip"] = l.IP
|
||||
l.fieldMap["ua"] = l.Ua
|
||||
l.fieldMap["ua"] = l.UA
|
||||
l.fieldMap["grant_type"] = l.GrantType
|
||||
l.fieldMap["password_grant_type"] = l.PasswordGrantType
|
||||
l.fieldMap["success"] = l.Success
|
||||
l.fieldMap["time"] = l.Time
|
||||
l.fieldMap["user_id"] = l.UserID
|
||||
l.fieldMap["time"] = l.Time
|
||||
}
|
||||
|
||||
func (l logsLogin) clone(db *gorm.DB) logsLogin {
|
||||
|
||||
@@ -28,16 +28,16 @@ func newLogsRequest(db *gorm.DB, opts ...gen.DOOption) logsRequest {
|
||||
tableName := _logsRequest.logsRequestDo.TableName()
|
||||
_logsRequest.ALL = field.NewAsterisk(tableName)
|
||||
_logsRequest.ID = field.NewInt32(tableName, "id")
|
||||
_logsRequest.Identity = field.NewInt32(tableName, "identity")
|
||||
_logsRequest.Visitor = field.NewInt32(tableName, "visitor")
|
||||
_logsRequest.IP = field.NewString(tableName, "ip")
|
||||
_logsRequest.Ua = field.NewString(tableName, "ua")
|
||||
_logsRequest.UA = field.NewString(tableName, "ua")
|
||||
_logsRequest.UserID = field.NewInt32(tableName, "user_id")
|
||||
_logsRequest.ClientID = field.NewInt32(tableName, "client_id")
|
||||
_logsRequest.Method = field.NewString(tableName, "method")
|
||||
_logsRequest.Path = field.NewString(tableName, "path")
|
||||
_logsRequest.Latency = field.NewString(tableName, "latency")
|
||||
_logsRequest.Status = field.NewInt32(tableName, "status")
|
||||
_logsRequest.Error = field.NewString(tableName, "error")
|
||||
_logsRequest.Time = field.NewField(tableName, "time")
|
||||
_logsRequest.Latency = field.NewString(tableName, "latency")
|
||||
|
||||
_logsRequest.fillFieldMap()
|
||||
|
||||
@@ -49,16 +49,16 @@ type logsRequest struct {
|
||||
|
||||
ALL field.Asterisk
|
||||
ID field.Int32 // 访问日志ID
|
||||
Identity field.Int32 // 访客身份:0-游客,1-用户,2-管理员,3-公共服务,4-安全服务,5-内部服务
|
||||
Visitor field.Int32 // 访客ID
|
||||
IP field.String // IP地址
|
||||
Ua field.String // 用户代理
|
||||
UA field.String // 用户代理
|
||||
UserID field.Int32 // 用户ID
|
||||
ClientID field.Int32 // 客户端ID
|
||||
Method field.String // 请求方法
|
||||
Path field.String // 请求路径
|
||||
Latency field.String // 请求延迟
|
||||
Status field.Int32 // 响应状态码
|
||||
Error field.String // 错误信息
|
||||
Time field.Field // 请求时间
|
||||
Latency field.String // 请求延迟
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
@@ -76,16 +76,16 @@ func (l logsRequest) As(alias string) *logsRequest {
|
||||
func (l *logsRequest) updateTableName(table string) *logsRequest {
|
||||
l.ALL = field.NewAsterisk(table)
|
||||
l.ID = field.NewInt32(table, "id")
|
||||
l.Identity = field.NewInt32(table, "identity")
|
||||
l.Visitor = field.NewInt32(table, "visitor")
|
||||
l.IP = field.NewString(table, "ip")
|
||||
l.Ua = field.NewString(table, "ua")
|
||||
l.UA = field.NewString(table, "ua")
|
||||
l.UserID = field.NewInt32(table, "user_id")
|
||||
l.ClientID = field.NewInt32(table, "client_id")
|
||||
l.Method = field.NewString(table, "method")
|
||||
l.Path = field.NewString(table, "path")
|
||||
l.Latency = field.NewString(table, "latency")
|
||||
l.Status = field.NewInt32(table, "status")
|
||||
l.Error = field.NewString(table, "error")
|
||||
l.Time = field.NewField(table, "time")
|
||||
l.Latency = field.NewString(table, "latency")
|
||||
|
||||
l.fillFieldMap()
|
||||
|
||||
@@ -104,16 +104,16 @@ func (l *logsRequest) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
|
||||
func (l *logsRequest) fillFieldMap() {
|
||||
l.fieldMap = make(map[string]field.Expr, 11)
|
||||
l.fieldMap["id"] = l.ID
|
||||
l.fieldMap["identity"] = l.Identity
|
||||
l.fieldMap["visitor"] = l.Visitor
|
||||
l.fieldMap["ip"] = l.IP
|
||||
l.fieldMap["ua"] = l.Ua
|
||||
l.fieldMap["ua"] = l.UA
|
||||
l.fieldMap["user_id"] = l.UserID
|
||||
l.fieldMap["client_id"] = l.ClientID
|
||||
l.fieldMap["method"] = l.Method
|
||||
l.fieldMap["path"] = l.Path
|
||||
l.fieldMap["latency"] = l.Latency
|
||||
l.fieldMap["status"] = l.Status
|
||||
l.fieldMap["error"] = l.Error
|
||||
l.fieldMap["time"] = l.Time
|
||||
l.fieldMap["latency"] = l.Latency
|
||||
}
|
||||
|
||||
func (l logsRequest) clone(db *gorm.DB) logsRequest {
|
||||
|
||||
@@ -31,12 +31,12 @@ func newProxy(db *gorm.DB, opts ...gen.DOOption) proxy {
|
||||
_proxy.Version = field.NewInt32(tableName, "version")
|
||||
_proxy.Name = field.NewString(tableName, "name")
|
||||
_proxy.Host = field.NewString(tableName, "host")
|
||||
_proxy.Type = field.NewInt32(tableName, "type")
|
||||
_proxy.Secret = field.NewString(tableName, "secret")
|
||||
_proxy.Type = field.NewInt32(tableName, "type")
|
||||
_proxy.Status = field.NewInt32(tableName, "status")
|
||||
_proxy.CreatedAt = field.NewField(tableName, "created_at")
|
||||
_proxy.UpdatedAt = field.NewField(tableName, "updated_at")
|
||||
_proxy.DeletedAt = field.NewField(tableName, "deleted_at")
|
||||
_proxy.Status = field.NewInt32(tableName, "status")
|
||||
_proxy.Edges = proxyHasManyEdges{
|
||||
db: db.Session(&gorm.Session{}),
|
||||
|
||||
@@ -56,12 +56,12 @@ type proxy struct {
|
||||
Version field.Int32 // 代理服务版本
|
||||
Name field.String // 代理服务名称
|
||||
Host field.String // 代理服务地址
|
||||
Type field.Int32 // 代理服务类型:1-三方,2-自有
|
||||
Secret field.String // 代理服务密钥
|
||||
Type field.Int32 // 代理服务类型:1-三方,2-自有
|
||||
Status field.Int32 // 代理服务状态:0-离线,1-在线
|
||||
CreatedAt field.Field // 创建时间
|
||||
UpdatedAt field.Field // 更新时间
|
||||
DeletedAt field.Field // 删除时间
|
||||
Status field.Int32 // 代理服务状态:0-离线,1-在线
|
||||
Edges proxyHasManyEdges
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
@@ -83,12 +83,12 @@ func (p *proxy) updateTableName(table string) *proxy {
|
||||
p.Version = field.NewInt32(table, "version")
|
||||
p.Name = field.NewString(table, "name")
|
||||
p.Host = field.NewString(table, "host")
|
||||
p.Type = field.NewInt32(table, "type")
|
||||
p.Secret = field.NewString(table, "secret")
|
||||
p.Type = field.NewInt32(table, "type")
|
||||
p.Status = field.NewInt32(table, "status")
|
||||
p.CreatedAt = field.NewField(table, "created_at")
|
||||
p.UpdatedAt = field.NewField(table, "updated_at")
|
||||
p.DeletedAt = field.NewField(table, "deleted_at")
|
||||
p.Status = field.NewInt32(table, "status")
|
||||
|
||||
p.fillFieldMap()
|
||||
|
||||
@@ -110,12 +110,12 @@ func (p *proxy) fillFieldMap() {
|
||||
p.fieldMap["version"] = p.Version
|
||||
p.fieldMap["name"] = p.Name
|
||||
p.fieldMap["host"] = p.Host
|
||||
p.fieldMap["type"] = p.Type
|
||||
p.fieldMap["secret"] = p.Secret
|
||||
p.fieldMap["type"] = p.Type
|
||||
p.fieldMap["status"] = p.Status
|
||||
p.fieldMap["created_at"] = p.CreatedAt
|
||||
p.fieldMap["updated_at"] = p.UpdatedAt
|
||||
p.fieldMap["deleted_at"] = p.DeletedAt
|
||||
p.fieldMap["status"] = p.Status
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -29,10 +29,10 @@ func newSession(db *gorm.DB, opts ...gen.DOOption) session {
|
||||
_session.ALL = field.NewAsterisk(tableName)
|
||||
_session.ID = field.NewInt32(tableName, "id")
|
||||
_session.UserID = field.NewInt32(tableName, "user_id")
|
||||
_session.AdminID = field.NewInt32(tableName, "admin_id")
|
||||
_session.ClientID = field.NewInt32(tableName, "client_id")
|
||||
_session.IP = field.NewString(tableName, "ip")
|
||||
_session.Ua = field.NewString(tableName, "ua")
|
||||
_session.GrantType = field.NewString(tableName, "grant_type")
|
||||
_session.UA = field.NewString(tableName, "ua")
|
||||
_session.AccessToken = field.NewString(tableName, "access_token")
|
||||
_session.AccessTokenExpires = field.NewField(tableName, "access_token_expires")
|
||||
_session.RefreshToken = field.NewString(tableName, "refresh_token")
|
||||
@@ -41,6 +41,23 @@ func newSession(db *gorm.DB, opts ...gen.DOOption) session {
|
||||
_session.CreatedAt = field.NewField(tableName, "created_at")
|
||||
_session.UpdatedAt = field.NewField(tableName, "updated_at")
|
||||
_session.DeletedAt = field.NewField(tableName, "deleted_at")
|
||||
_session.User = sessionBelongsToUser{
|
||||
db: db.Session(&gorm.Session{}),
|
||||
|
||||
RelationField: field.NewRelation("User", "models.User"),
|
||||
}
|
||||
|
||||
_session.Admin = sessionBelongsToAdmin{
|
||||
db: db.Session(&gorm.Session{}),
|
||||
|
||||
RelationField: field.NewRelation("Admin", "models.Admin"),
|
||||
}
|
||||
|
||||
_session.Client = sessionBelongsToClient{
|
||||
db: db.Session(&gorm.Session{}),
|
||||
|
||||
RelationField: field.NewRelation("Client", "models.Client"),
|
||||
}
|
||||
|
||||
_session.fillFieldMap()
|
||||
|
||||
@@ -53,10 +70,10 @@ type session struct {
|
||||
ALL field.Asterisk
|
||||
ID field.Int32 // 会话ID
|
||||
UserID field.Int32 // 用户ID
|
||||
AdminID field.Int32 // 管理员ID
|
||||
ClientID field.Int32 // 客户端ID
|
||||
IP field.String // IP地址
|
||||
Ua field.String // 用户代理
|
||||
GrantType field.String // 授权类型:authorization_code-授权码模式,client_credentials-客户端凭证模式,refresh_token-刷新令牌模式,password-密码模式
|
||||
UA field.String // 用户代理
|
||||
AccessToken field.String // 访问令牌
|
||||
AccessTokenExpires field.Field // 访问令牌过期时间
|
||||
RefreshToken field.String // 刷新令牌
|
||||
@@ -65,6 +82,11 @@ type session struct {
|
||||
CreatedAt field.Field // 创建时间
|
||||
UpdatedAt field.Field // 更新时间
|
||||
DeletedAt field.Field // 删除时间
|
||||
User sessionBelongsToUser
|
||||
|
||||
Admin sessionBelongsToAdmin
|
||||
|
||||
Client sessionBelongsToClient
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
@@ -83,10 +105,10 @@ func (s *session) updateTableName(table string) *session {
|
||||
s.ALL = field.NewAsterisk(table)
|
||||
s.ID = field.NewInt32(table, "id")
|
||||
s.UserID = field.NewInt32(table, "user_id")
|
||||
s.AdminID = field.NewInt32(table, "admin_id")
|
||||
s.ClientID = field.NewInt32(table, "client_id")
|
||||
s.IP = field.NewString(table, "ip")
|
||||
s.Ua = field.NewString(table, "ua")
|
||||
s.GrantType = field.NewString(table, "grant_type")
|
||||
s.UA = field.NewString(table, "ua")
|
||||
s.AccessToken = field.NewString(table, "access_token")
|
||||
s.AccessTokenExpires = field.NewField(table, "access_token_expires")
|
||||
s.RefreshToken = field.NewString(table, "refresh_token")
|
||||
@@ -111,13 +133,13 @@ func (s *session) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
|
||||
}
|
||||
|
||||
func (s *session) fillFieldMap() {
|
||||
s.fieldMap = make(map[string]field.Expr, 14)
|
||||
s.fieldMap = make(map[string]field.Expr, 17)
|
||||
s.fieldMap["id"] = s.ID
|
||||
s.fieldMap["user_id"] = s.UserID
|
||||
s.fieldMap["admin_id"] = s.AdminID
|
||||
s.fieldMap["client_id"] = s.ClientID
|
||||
s.fieldMap["ip"] = s.IP
|
||||
s.fieldMap["ua"] = s.Ua
|
||||
s.fieldMap["grant_type"] = s.GrantType
|
||||
s.fieldMap["ua"] = s.UA
|
||||
s.fieldMap["access_token"] = s.AccessToken
|
||||
s.fieldMap["access_token_expires"] = s.AccessTokenExpires
|
||||
s.fieldMap["refresh_token"] = s.RefreshToken
|
||||
@@ -126,18 +148,271 @@ func (s *session) fillFieldMap() {
|
||||
s.fieldMap["created_at"] = s.CreatedAt
|
||||
s.fieldMap["updated_at"] = s.UpdatedAt
|
||||
s.fieldMap["deleted_at"] = s.DeletedAt
|
||||
|
||||
}
|
||||
|
||||
func (s session) clone(db *gorm.DB) session {
|
||||
s.sessionDo.ReplaceConnPool(db.Statement.ConnPool)
|
||||
s.User.db = db.Session(&gorm.Session{Initialized: true})
|
||||
s.User.db.Statement.ConnPool = db.Statement.ConnPool
|
||||
s.Admin.db = db.Session(&gorm.Session{Initialized: true})
|
||||
s.Admin.db.Statement.ConnPool = db.Statement.ConnPool
|
||||
s.Client.db = db.Session(&gorm.Session{Initialized: true})
|
||||
s.Client.db.Statement.ConnPool = db.Statement.ConnPool
|
||||
return s
|
||||
}
|
||||
|
||||
func (s session) replaceDB(db *gorm.DB) session {
|
||||
s.sessionDo.ReplaceDB(db)
|
||||
s.User.db = db.Session(&gorm.Session{})
|
||||
s.Admin.db = db.Session(&gorm.Session{})
|
||||
s.Client.db = db.Session(&gorm.Session{})
|
||||
return s
|
||||
}
|
||||
|
||||
type sessionBelongsToUser struct {
|
||||
db *gorm.DB
|
||||
|
||||
field.RelationField
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUser) Where(conds ...field.Expr) *sessionBelongsToUser {
|
||||
if len(conds) == 0 {
|
||||
return &a
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, 0, len(conds))
|
||||
for _, cond := range conds {
|
||||
exprs = append(exprs, cond.BeCond().(clause.Expression))
|
||||
}
|
||||
a.db = a.db.Clauses(clause.Where{Exprs: exprs})
|
||||
return &a
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUser) WithContext(ctx context.Context) *sessionBelongsToUser {
|
||||
a.db = a.db.WithContext(ctx)
|
||||
return &a
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUser) Session(session *gorm.Session) *sessionBelongsToUser {
|
||||
a.db = a.db.Session(session)
|
||||
return &a
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUser) Model(m *models.Session) *sessionBelongsToUserTx {
|
||||
return &sessionBelongsToUserTx{a.db.Model(m).Association(a.Name())}
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUser) Unscoped() *sessionBelongsToUser {
|
||||
a.db = a.db.Unscoped()
|
||||
return &a
|
||||
}
|
||||
|
||||
type sessionBelongsToUserTx struct{ tx *gorm.Association }
|
||||
|
||||
func (a sessionBelongsToUserTx) Find() (result *models.User, err error) {
|
||||
return result, a.tx.Find(&result)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUserTx) Append(values ...*models.User) (err error) {
|
||||
targetValues := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
targetValues[i] = v
|
||||
}
|
||||
return a.tx.Append(targetValues...)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUserTx) Replace(values ...*models.User) (err error) {
|
||||
targetValues := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
targetValues[i] = v
|
||||
}
|
||||
return a.tx.Replace(targetValues...)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUserTx) Delete(values ...*models.User) (err error) {
|
||||
targetValues := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
targetValues[i] = v
|
||||
}
|
||||
return a.tx.Delete(targetValues...)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUserTx) Clear() error {
|
||||
return a.tx.Clear()
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUserTx) Count() int64 {
|
||||
return a.tx.Count()
|
||||
}
|
||||
|
||||
func (a sessionBelongsToUserTx) Unscoped() *sessionBelongsToUserTx {
|
||||
a.tx = a.tx.Unscoped()
|
||||
return &a
|
||||
}
|
||||
|
||||
type sessionBelongsToAdmin struct {
|
||||
db *gorm.DB
|
||||
|
||||
field.RelationField
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdmin) Where(conds ...field.Expr) *sessionBelongsToAdmin {
|
||||
if len(conds) == 0 {
|
||||
return &a
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, 0, len(conds))
|
||||
for _, cond := range conds {
|
||||
exprs = append(exprs, cond.BeCond().(clause.Expression))
|
||||
}
|
||||
a.db = a.db.Clauses(clause.Where{Exprs: exprs})
|
||||
return &a
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdmin) WithContext(ctx context.Context) *sessionBelongsToAdmin {
|
||||
a.db = a.db.WithContext(ctx)
|
||||
return &a
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdmin) Session(session *gorm.Session) *sessionBelongsToAdmin {
|
||||
a.db = a.db.Session(session)
|
||||
return &a
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdmin) Model(m *models.Session) *sessionBelongsToAdminTx {
|
||||
return &sessionBelongsToAdminTx{a.db.Model(m).Association(a.Name())}
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdmin) Unscoped() *sessionBelongsToAdmin {
|
||||
a.db = a.db.Unscoped()
|
||||
return &a
|
||||
}
|
||||
|
||||
type sessionBelongsToAdminTx struct{ tx *gorm.Association }
|
||||
|
||||
func (a sessionBelongsToAdminTx) Find() (result *models.Admin, err error) {
|
||||
return result, a.tx.Find(&result)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdminTx) Append(values ...*models.Admin) (err error) {
|
||||
targetValues := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
targetValues[i] = v
|
||||
}
|
||||
return a.tx.Append(targetValues...)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdminTx) Replace(values ...*models.Admin) (err error) {
|
||||
targetValues := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
targetValues[i] = v
|
||||
}
|
||||
return a.tx.Replace(targetValues...)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdminTx) Delete(values ...*models.Admin) (err error) {
|
||||
targetValues := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
targetValues[i] = v
|
||||
}
|
||||
return a.tx.Delete(targetValues...)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdminTx) Clear() error {
|
||||
return a.tx.Clear()
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdminTx) Count() int64 {
|
||||
return a.tx.Count()
|
||||
}
|
||||
|
||||
func (a sessionBelongsToAdminTx) Unscoped() *sessionBelongsToAdminTx {
|
||||
a.tx = a.tx.Unscoped()
|
||||
return &a
|
||||
}
|
||||
|
||||
type sessionBelongsToClient struct {
|
||||
db *gorm.DB
|
||||
|
||||
field.RelationField
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClient) Where(conds ...field.Expr) *sessionBelongsToClient {
|
||||
if len(conds) == 0 {
|
||||
return &a
|
||||
}
|
||||
|
||||
exprs := make([]clause.Expression, 0, len(conds))
|
||||
for _, cond := range conds {
|
||||
exprs = append(exprs, cond.BeCond().(clause.Expression))
|
||||
}
|
||||
a.db = a.db.Clauses(clause.Where{Exprs: exprs})
|
||||
return &a
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClient) WithContext(ctx context.Context) *sessionBelongsToClient {
|
||||
a.db = a.db.WithContext(ctx)
|
||||
return &a
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClient) Session(session *gorm.Session) *sessionBelongsToClient {
|
||||
a.db = a.db.Session(session)
|
||||
return &a
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClient) Model(m *models.Session) *sessionBelongsToClientTx {
|
||||
return &sessionBelongsToClientTx{a.db.Model(m).Association(a.Name())}
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClient) Unscoped() *sessionBelongsToClient {
|
||||
a.db = a.db.Unscoped()
|
||||
return &a
|
||||
}
|
||||
|
||||
type sessionBelongsToClientTx struct{ tx *gorm.Association }
|
||||
|
||||
func (a sessionBelongsToClientTx) Find() (result *models.Client, err error) {
|
||||
return result, a.tx.Find(&result)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClientTx) Append(values ...*models.Client) (err error) {
|
||||
targetValues := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
targetValues[i] = v
|
||||
}
|
||||
return a.tx.Append(targetValues...)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClientTx) Replace(values ...*models.Client) (err error) {
|
||||
targetValues := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
targetValues[i] = v
|
||||
}
|
||||
return a.tx.Replace(targetValues...)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClientTx) Delete(values ...*models.Client) (err error) {
|
||||
targetValues := make([]interface{}, len(values))
|
||||
for i, v := range values {
|
||||
targetValues[i] = v
|
||||
}
|
||||
return a.tx.Delete(targetValues...)
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClientTx) Clear() error {
|
||||
return a.tx.Clear()
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClientTx) Count() int64 {
|
||||
return a.tx.Count()
|
||||
}
|
||||
|
||||
func (a sessionBelongsToClientTx) Unscoped() *sessionBelongsToClientTx {
|
||||
a.tx = a.tx.Unscoped()
|
||||
return &a
|
||||
}
|
||||
|
||||
type sessionDo struct{ gen.DO }
|
||||
|
||||
func (s sessionDo) Debug() *sessionDo {
|
||||
|
||||
@@ -37,16 +37,16 @@ func newTrade(db *gorm.DB, opts ...gen.DOOption) trade {
|
||||
_trade.Amount = field.NewField(tableName, "amount")
|
||||
_trade.Payment = field.NewField(tableName, "payment")
|
||||
_trade.Method = field.NewInt32(tableName, "method")
|
||||
_trade.Status = field.NewInt32(tableName, "status")
|
||||
_trade.CreatedAt = field.NewField(tableName, "created_at")
|
||||
_trade.UpdatedAt = field.NewField(tableName, "updated_at")
|
||||
_trade.DeletedAt = field.NewField(tableName, "deleted_at")
|
||||
_trade.Acquirer = field.NewInt32(tableName, "acquirer")
|
||||
_trade.Platform = field.NewInt32(tableName, "platform")
|
||||
_trade.Acquirer = field.NewInt32(tableName, "acquirer")
|
||||
_trade.Status = field.NewInt32(tableName, "status")
|
||||
_trade.Refunded = field.NewBool(tableName, "refunded")
|
||||
_trade.PaymentURL = field.NewString(tableName, "payment_url")
|
||||
_trade.CompletedAt = field.NewField(tableName, "completed_at")
|
||||
_trade.CanceledAt = field.NewField(tableName, "canceled_at")
|
||||
_trade.Refunded = field.NewBool(tableName, "refunded")
|
||||
_trade.CreatedAt = field.NewField(tableName, "created_at")
|
||||
_trade.UpdatedAt = field.NewField(tableName, "updated_at")
|
||||
_trade.DeletedAt = field.NewField(tableName, "deleted_at")
|
||||
|
||||
_trade.fillFieldMap()
|
||||
|
||||
@@ -65,18 +65,18 @@ type trade struct {
|
||||
Subject field.String // 订单主题
|
||||
Remark field.String // 订单备注
|
||||
Amount field.Field // 订单总金额
|
||||
Payment field.Field // 支付金额
|
||||
Method field.Int32 // 支付方式:1-支付宝,2-微信,3-商福通渠道支付宝,4-商福通渠道微信
|
||||
Status field.Int32 // 订单状态:0-待支付,1-已支付,2-已取消
|
||||
CreatedAt field.Field // 创建时间
|
||||
UpdatedAt field.Field // 更新时间
|
||||
DeletedAt field.Field // 删除时间
|
||||
Acquirer field.Int32 // 收单机构:1-支付宝,2-微信,3-银联
|
||||
Payment field.Field // 实际支付金额
|
||||
Method field.Int32 // 支付方式:1-支付宝,2-微信,3-商福通,4-商福通渠道支付宝,5-商福通渠道微信
|
||||
Platform field.Int32 // 支付平台:1-电脑网站,2-手机网站
|
||||
Acquirer field.Int32 // 收单机构:1-支付宝,2-微信,3-银联
|
||||
Status field.Int32 // 订单状态:0-待支付,1-已支付,2-已取消
|
||||
Refunded field.Bool
|
||||
PaymentURL field.String // 支付链接
|
||||
CompletedAt field.Field // 支付时间
|
||||
CanceledAt field.Field // 取消时间
|
||||
Refunded field.Bool
|
||||
CreatedAt field.Field // 创建时间
|
||||
UpdatedAt field.Field // 更新时间
|
||||
DeletedAt field.Field // 删除时间
|
||||
|
||||
fieldMap map[string]field.Expr
|
||||
}
|
||||
@@ -103,16 +103,16 @@ func (t *trade) updateTableName(table string) *trade {
|
||||
t.Amount = field.NewField(table, "amount")
|
||||
t.Payment = field.NewField(table, "payment")
|
||||
t.Method = field.NewInt32(table, "method")
|
||||
t.Status = field.NewInt32(table, "status")
|
||||
t.CreatedAt = field.NewField(table, "created_at")
|
||||
t.UpdatedAt = field.NewField(table, "updated_at")
|
||||
t.DeletedAt = field.NewField(table, "deleted_at")
|
||||
t.Acquirer = field.NewInt32(table, "acquirer")
|
||||
t.Platform = field.NewInt32(table, "platform")
|
||||
t.Acquirer = field.NewInt32(table, "acquirer")
|
||||
t.Status = field.NewInt32(table, "status")
|
||||
t.Refunded = field.NewBool(table, "refunded")
|
||||
t.PaymentURL = field.NewString(table, "payment_url")
|
||||
t.CompletedAt = field.NewField(table, "completed_at")
|
||||
t.CanceledAt = field.NewField(table, "canceled_at")
|
||||
t.Refunded = field.NewBool(table, "refunded")
|
||||
t.CreatedAt = field.NewField(table, "created_at")
|
||||
t.UpdatedAt = field.NewField(table, "updated_at")
|
||||
t.DeletedAt = field.NewField(table, "deleted_at")
|
||||
|
||||
t.fillFieldMap()
|
||||
|
||||
@@ -140,16 +140,16 @@ func (t *trade) fillFieldMap() {
|
||||
t.fieldMap["amount"] = t.Amount
|
||||
t.fieldMap["payment"] = t.Payment
|
||||
t.fieldMap["method"] = t.Method
|
||||
t.fieldMap["status"] = t.Status
|
||||
t.fieldMap["created_at"] = t.CreatedAt
|
||||
t.fieldMap["updated_at"] = t.UpdatedAt
|
||||
t.fieldMap["deleted_at"] = t.DeletedAt
|
||||
t.fieldMap["acquirer"] = t.Acquirer
|
||||
t.fieldMap["platform"] = t.Platform
|
||||
t.fieldMap["acquirer"] = t.Acquirer
|
||||
t.fieldMap["status"] = t.Status
|
||||
t.fieldMap["refunded"] = t.Refunded
|
||||
t.fieldMap["payment_url"] = t.PaymentURL
|
||||
t.fieldMap["completed_at"] = t.CompletedAt
|
||||
t.fieldMap["canceled_at"] = t.CanceledAt
|
||||
t.fieldMap["refunded"] = t.Refunded
|
||||
t.fieldMap["created_at"] = t.CreatedAt
|
||||
t.fieldMap["updated_at"] = t.UpdatedAt
|
||||
t.fieldMap["deleted_at"] = t.DeletedAt
|
||||
}
|
||||
|
||||
func (t trade) clone(db *gorm.DB) trade {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"platform/web/core"
|
||||
auth2 "platform/web/auth"
|
||||
"platform/web/handlers"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
@@ -12,7 +12,7 @@ func ApplyRouters(app *fiber.App) {
|
||||
|
||||
// 认证
|
||||
auth := api.Group("/auth")
|
||||
auth.Post("/token", handlers.Token)
|
||||
auth.Post("/token", auth2.Token)
|
||||
auth.Post("/revoke", handlers.Revoke)
|
||||
auth.Post("/introspect", handlers.Introspect)
|
||||
auth.Post("/verify/sms", handlers.SmsCode)
|
||||
@@ -47,7 +47,6 @@ func ApplyRouters(app *fiber.App) {
|
||||
channel.Post("/list", handlers.ListChannels)
|
||||
channel.Post("/create", handlers.CreateChannel)
|
||||
channel.Post("/remove", handlers.RemoveChannels)
|
||||
channel.Post("/remove/by-task", handlers.RemoveChannelByTask)
|
||||
|
||||
// 交易
|
||||
trade := api.Group("/trade")
|
||||
@@ -75,12 +74,6 @@ func ApplyRouters(app *fiber.App) {
|
||||
edge.Post("/all", handlers.AllEdgesAvailable)
|
||||
|
||||
// 临时
|
||||
app.Get("/test", func(c *fiber.Ctx) error {
|
||||
return core.NewBizErr("测试错误")
|
||||
})
|
||||
|
||||
// 异步任务客户端
|
||||
tasks := api.Group("/tasks")
|
||||
tasks.Post("/channel/remove", handlers.RemoveChannelByTask)
|
||||
tasks.Post("/trade/cancel", handlers.TradeCancelByTask)
|
||||
debug := app.Group("/debug")
|
||||
debug.Get("/sms/:phone", handlers.DebugGetSmsCode)
|
||||
}
|
||||
@@ -1,201 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"log/slog"
|
||||
"platform/pkg/u"
|
||||
auth2 "platform/web/auth"
|
||||
"platform/web/core"
|
||||
client2 "platform/web/domains/client"
|
||||
user2 "platform/web/domains/user"
|
||||
"platform/web/globals/orm"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var Auth = &authService{}
|
||||
|
||||
type authService struct{}
|
||||
|
||||
// OauthAuthorizationCode 验证授权码
|
||||
func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Client, code, redirectURI, codeVerifier string) (*auth2.TokenDetails, error) {
|
||||
return nil, errors.New("TODO")
|
||||
}
|
||||
|
||||
// OauthClientCredentials 验证客户端凭证
|
||||
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*auth2.TokenDetails, error) {
|
||||
|
||||
var clientType = auth2.PayloadTypeFromClientSpec(client2.Spec(client.Spec))
|
||||
|
||||
var permissions = make(map[string]struct{}, len(scope))
|
||||
for _, item := range scope {
|
||||
permissions[item] = struct{}{}
|
||||
}
|
||||
|
||||
// 保存会话并返回令牌
|
||||
authCtx := auth2.Context{
|
||||
Permissions: permissions,
|
||||
Payload: auth2.Payload{
|
||||
Id: client.ID,
|
||||
Type: clientType,
|
||||
Name: client.Name,
|
||||
},
|
||||
}
|
||||
|
||||
token, err := auth2.CreateSession(ctx, &authCtx, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// OauthRefreshToken 验证刷新令牌
|
||||
func (s *authService) OauthRefreshToken(ctx context.Context, _ *m.Client, refreshToken string, scope ...[]string) (*auth2.TokenDetails, error) {
|
||||
details, err := auth2.RefreshSession(ctx, refreshToken, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return details, nil
|
||||
}
|
||||
|
||||
// OauthPassword 验证密码
|
||||
func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *GrantPasswordData, ip, agent string) (*auth2.TokenDetails, error) {
|
||||
var user *m.User
|
||||
err := q.Q.Transaction(func(tx *q.Query) error {
|
||||
|
||||
switch data.LoginType {
|
||||
case auth2.GrantPasswordPhone:
|
||||
// 验证验证码
|
||||
err := Verifier.VerifySms(ctx, data.Username, data.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrVerifierServiceInvalid) {
|
||||
return ErrOauthInvalidRequest
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 查找用户
|
||||
user, err =
|
||||
tx.User.Where(tx.User.Phone.Eq(data.Username)).Take()
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
case auth2.GrantPasswordEmail:
|
||||
return core.NewServErr("邮箱登录暂不可用")
|
||||
case auth2.GrantPasswordSecret:
|
||||
var err error
|
||||
user, err = tx.User.
|
||||
Where(tx.User.Phone.Eq(data.Username)).
|
||||
Or(tx.User.Email.Eq(data.Username)).
|
||||
Or(tx.User.Username.Eq(data.Username)).
|
||||
Take()
|
||||
if err != nil {
|
||||
slog.Debug("查找用户失败", "error", err)
|
||||
return core.NewBizErr("用户不存在或密码错误")
|
||||
}
|
||||
|
||||
// 账户状态
|
||||
if user2.Status(user.Status) == user2.StatusDisabled {
|
||||
slog.Debug("账户状态异常", "username", data.Username, "status", user.Status)
|
||||
return core.NewBizErr("用户不存在或密码错误")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if user.Password == nil || *user.Password == "" {
|
||||
slog.Debug("用户未设置密码", "username", data.Username)
|
||||
return core.NewBizErr("用户不存在或密码错误")
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(*user.Password), []byte(data.Password)) != nil {
|
||||
slog.Debug("密码验证失败", "username", data.Username)
|
||||
return core.NewBizErr("用户不存在或密码错误")
|
||||
}
|
||||
|
||||
default:
|
||||
return ErrOauthInvalidRequest
|
||||
}
|
||||
|
||||
// 如果用户不存在,初始化用户 todo 初始化默认权限信息
|
||||
if user == nil {
|
||||
user = &m.User{
|
||||
Phone: data.Username,
|
||||
Username: u.P(data.Username),
|
||||
}
|
||||
}
|
||||
|
||||
// 更新用户的登录时间
|
||||
user.LastLogin = u.P(orm.LocalDateTime(time.Now()))
|
||||
user.LastLoginHost = u.P(ip)
|
||||
user.LastLoginAgent = u.P(agent)
|
||||
if err := tx.User.Save(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保存到会话
|
||||
var name = ""
|
||||
if user.Name != nil {
|
||||
name = *user.Name
|
||||
}
|
||||
authCtx := auth2.Context{
|
||||
Payload: auth2.Payload{
|
||||
Id: user.ID,
|
||||
Type: auth2.PayloadUser,
|
||||
Name: name,
|
||||
Avatar: user.Avatar,
|
||||
},
|
||||
}
|
||||
|
||||
token, err := auth2.CreateSession(ctx, &authCtx, data.Remember)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
type GrantCodeData struct {
|
||||
Code string `json:"code" form:"code"`
|
||||
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
|
||||
}
|
||||
|
||||
type GrantClientData struct {
|
||||
}
|
||||
|
||||
type GrantRefreshData struct {
|
||||
RefreshToken string `json:"refresh_token" form:"refresh_token"`
|
||||
}
|
||||
|
||||
type GrantPasswordData struct {
|
||||
LoginType auth2.PasswordGrantType `json:"login_type" form:"login_type"`
|
||||
Username string `json:"username" form:"username"`
|
||||
Password string `json:"password" form:"password"`
|
||||
Remember bool `json:"remember" form:"remember"`
|
||||
}
|
||||
|
||||
type AuthServiceError string
|
||||
|
||||
func (e AuthServiceError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
const (
|
||||
ErrOauthInvalidRequest = AuthServiceError("invalid_request")
|
||||
ErrOauthInvalidClient = AuthServiceError("invalid_client")
|
||||
ErrOauthInvalidGrant = AuthServiceError("invalid_grant")
|
||||
ErrOauthInvalidScope = AuthServiceError("invalid_scope")
|
||||
ErrOauthUnauthorizedClient = AuthServiceError("unauthorized_client")
|
||||
ErrOauthUnsupportedGrantType = AuthServiceError("unsupported_grant_type")
|
||||
)
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
@@ -15,15 +14,17 @@ import (
|
||||
edge2 "platform/web/domains/edge"
|
||||
proxy2 "platform/web/domains/proxy"
|
||||
resource2 "platform/web/domains/resource"
|
||||
"platform/web/events"
|
||||
g "platform/web/globals"
|
||||
"platform/web/globals/orm"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"platform/web/tasks"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"gorm.io/gen/field"
|
||||
|
||||
@@ -296,7 +297,7 @@ func (s *channelService) CreateChannel(
|
||||
ids[i] = channels[i].ID
|
||||
}
|
||||
_, err = g.Asynq.Enqueue(
|
||||
tasks.NewRemoveChannel(ids),
|
||||
events.NewRemoveChannel(ids),
|
||||
asynq.ProcessIn(duration),
|
||||
)
|
||||
if err != nil {
|
||||
|
||||
@@ -12,11 +12,11 @@ import (
|
||||
"platform/web/core"
|
||||
coupon2 "platform/web/domains/coupon"
|
||||
trade2 "platform/web/domains/trade"
|
||||
"platform/web/events"
|
||||
g "platform/web/globals"
|
||||
"platform/web/globals/orm"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"platform/web/tasks"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
@@ -240,7 +240,7 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, data *CreateTradeDa
|
||||
}
|
||||
|
||||
// 提交异步关闭事件
|
||||
_, err = g.Asynq.Enqueue(tasks.NewCancelTrade(tasks.CancelTradeData{
|
||||
_, err = g.Asynq.Enqueue(events.NewCancelTrade(events.CancelTradeData{
|
||||
TradeNo: tradeNo,
|
||||
Method: method,
|
||||
}))
|
||||
@@ -417,7 +417,7 @@ func (s *tradeService) CancelTrade(tradeNo string, method trade2.Method, now tim
|
||||
MchOrderNo: &tradeNo,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Debug(fmt.Sprintf("订单无需关闭:%s", err.Error()))
|
||||
slog.Debug(fmt.Sprintf("订单无需关闭: %s", err.Error()))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -546,11 +546,11 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err
|
||||
return nil, core.NewBizErr("订单不存在")
|
||||
}
|
||||
return nil, core.NewServErr(
|
||||
fmt.Sprintf("微信上游接口异常:code=%v,message=%v", apiErr.Code, apiErr.Message),
|
||||
fmt.Sprintf("微信上游接口异常: code=%v,message=%v", apiErr.Code, apiErr.Message),
|
||||
apiErr,
|
||||
)
|
||||
}
|
||||
return nil, core.NewServErr(fmt.Sprintf("微信上游支付接口异常:%s", err.Error()))
|
||||
return nil, core.NewServErr(fmt.Sprintf("微信上游支付接口异常: %s", err.Error()))
|
||||
}
|
||||
|
||||
// 填充返回值
|
||||
|
||||
@@ -19,28 +19,6 @@ import (
|
||||
|
||||
var Verifier = &verifierService{}
|
||||
|
||||
type VerifierServiceError string
|
||||
|
||||
func (e VerifierServiceError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVerifierServiceInvalid = VerifierServiceError("验证码错误")
|
||||
)
|
||||
|
||||
type VerifierServiceSendLimitErr int
|
||||
|
||||
func (e VerifierServiceSendLimitErr) Error() string {
|
||||
return "发送频率过快"
|
||||
}
|
||||
|
||||
type VerifierSmsPurpose int
|
||||
|
||||
const (
|
||||
VerifierSmsPurposeLogin VerifierSmsPurpose = iota
|
||||
)
|
||||
|
||||
type verifierService struct {
|
||||
}
|
||||
|
||||
@@ -148,6 +126,43 @@ func (s *verifierService) VerifySms(ctx context.Context, phone, code string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *verifierService) GetSms(ctx context.Context, phone string) (string, error) {
|
||||
key := smsKey(phone, VerifierSmsPurposeLogin)
|
||||
|
||||
val, err := g.Redis.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("验证码获取失败: %w", err)
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func smsKey(phone string, purpose VerifierSmsPurpose) string {
|
||||
return fmt.Sprintf("verify:sms:%d:%s", purpose, phone)
|
||||
}
|
||||
|
||||
// region 短信目的
|
||||
|
||||
type VerifierSmsPurpose int
|
||||
|
||||
const (
|
||||
VerifierSmsPurposeLogin VerifierSmsPurpose = iota // 登录
|
||||
)
|
||||
|
||||
// region 服务异常
|
||||
|
||||
type VerifierServiceError string
|
||||
|
||||
func (e VerifierServiceError) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
|
||||
var (
|
||||
ErrVerifierServiceInvalid = VerifierServiceError("验证码错误")
|
||||
)
|
||||
|
||||
type VerifierServiceSendLimitErr int
|
||||
|
||||
func (e VerifierServiceSendLimitErr) Error() string {
|
||||
return "发送频率过快"
|
||||
}
|
||||
|
||||
41
web/tasks/task.go
Normal file
41
web/tasks/task.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package tasks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"platform/web/events"
|
||||
s "platform/web/services"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
)
|
||||
|
||||
func HandleCancelTrade(_ context.Context, task *asynq.Task) (err error) {
|
||||
data := new(events.CancelTradeData)
|
||||
err = json.Unmarshal(task.Payload(), data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析任务参数失败: %w", err)
|
||||
}
|
||||
|
||||
err = s.Trade.CancelTrade(data.TradeNo, data.Method, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("取消交易失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func HandleRemoveChannel(_ context.Context, task *asynq.Task) (err error) {
|
||||
data := make([]int32, 0)
|
||||
err = json.Unmarshal(task.Payload(), &data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("解析任务参数失败: %w", err)
|
||||
}
|
||||
|
||||
err = s.Channel.RemoveChannels(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除通道失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
189
web/web.go
189
web/web.go
@@ -1,166 +1,89 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/middleware/logger"
|
||||
"github.com/gofiber/fiber/v2/middleware/recover"
|
||||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jxskiss/base62"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"platform/web/auth"
|
||||
g "platform/web/globals"
|
||||
q "platform/web/queries"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
"platform/web/events"
|
||||
base "platform/web/globals"
|
||||
"platform/web/tasks"
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/hibiken/asynq"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// region web
|
||||
func RunApp(pCtx context.Context) error {
|
||||
g, ctx := errgroup.WithContext(pCtx)
|
||||
|
||||
type Config struct {
|
||||
Listen string
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
config *Config
|
||||
fiber *fiber.App
|
||||
}
|
||||
|
||||
func New(config *Config) (*Server, error) {
|
||||
_config := config
|
||||
if config == nil {
|
||||
_config = &Config{}
|
||||
// 初始化依赖
|
||||
err := base.Init(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("初始化依赖失败: %w", err)
|
||||
}
|
||||
|
||||
return &Server{
|
||||
config: _config,
|
||||
}, nil
|
||||
// 运行服务
|
||||
g.Go(func() error {
|
||||
return RunWeb(ctx)
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
return RunTask(ctx)
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (s *Server) Run() error {
|
||||
func RunWeb(ctx context.Context) error {
|
||||
|
||||
// inits
|
||||
g.Init()
|
||||
q.SetDefault(g.DB)
|
||||
|
||||
// config
|
||||
s.fiber = fiber.New(fiber.Config{
|
||||
fiber := fiber.New(fiber.Config{
|
||||
ProxyHeader: fiber.HeaderXForwardedFor,
|
||||
ErrorHandler: ErrorHandler,
|
||||
})
|
||||
|
||||
// middlewares
|
||||
s.fiber.Use(newRecover())
|
||||
s.fiber.Use(newRequestId())
|
||||
s.fiber.Use(newLogger())
|
||||
ApplyMiddlewares(fiber)
|
||||
ApplyRouters(fiber)
|
||||
|
||||
// routes
|
||||
ApplyRouters(s.fiber)
|
||||
|
||||
// pprof
|
||||
// 停止服务
|
||||
go func() {
|
||||
runtime.SetBlockProfileRate(1)
|
||||
err := http.ListenAndServe(":6060", nil)
|
||||
<-ctx.Done()
|
||||
err := fiber.Shutdown()
|
||||
if err != nil {
|
||||
slog.Error("pprof 服务错误", slog.Any("err", err))
|
||||
slog.Error("服务停止失败", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// listen
|
||||
slog.Info("服务开始监听 :8080")
|
||||
err := s.fiber.Listen("0.0.0.0:8080")
|
||||
// 启动服务
|
||||
slog.Info("web 服务开始监听 :8080")
|
||||
err := fiber.Listen("0.0.0.0:8080")
|
||||
if err != nil {
|
||||
slog.Error("Failed to start server", slog.Any("err", err))
|
||||
return fmt.Errorf("web 服务监听失败: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("服务已停止")
|
||||
slog.Info("web 服务已停止")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Stop() {
|
||||
err := g.ExitRedis()
|
||||
func RunTask(ctx context.Context) error {
|
||||
|
||||
var server = asynq.NewServerFromRedisClient(base.Redis, asynq.Config{})
|
||||
|
||||
var mux = asynq.NewServeMux()
|
||||
mux.HandleFunc(events.RemoveChannel, tasks.HandleRemoveChannel)
|
||||
mux.HandleFunc(events.CancelTrade, tasks.HandleCancelTrade)
|
||||
|
||||
// 停止服务
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
server.Shutdown()
|
||||
}()
|
||||
|
||||
// 启动服务
|
||||
err := server.Run(mux)
|
||||
if err != nil {
|
||||
slog.Error("Failed to close Redis connection", slog.Any("err", err))
|
||||
return fmt.Errorf("任务服务运行失败: %w", err)
|
||||
}
|
||||
|
||||
err = g.ExitOrm()
|
||||
if err != nil {
|
||||
slog.Error("Failed to close database connection", slog.Any("err", err))
|
||||
}
|
||||
|
||||
err = s.fiber.Shutdown()
|
||||
if err != nil {
|
||||
slog.Error("Failed to shutdown server", slog.Any("err", err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region middlewares
|
||||
|
||||
func newRequestId() fiber.Handler {
|
||||
return requestid.New(requestid.Config{
|
||||
Generator: func() string {
|
||||
binary, _ := uuid.New().MarshalBinary()
|
||||
return base62.EncodeToString(binary)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func newLogger() fiber.Handler {
|
||||
return logger.New(logger.Config{
|
||||
DisableColors: true,
|
||||
Format: "🚀 ${time} | ${locals:authtype} ${locals:authid} | ${method} ${path} | ${status} | ${latency} | ${error}\n",
|
||||
TimeFormat: "2006-01-02 15:04:05",
|
||||
TimeZone: "Asia/Shanghai",
|
||||
Next: func(c *fiber.Ctx) bool {
|
||||
authCtx, ok := c.Locals("auth").(*auth.Context)
|
||||
if ok {
|
||||
c.Locals("authtype", authCtx.Payload.Type.ToStr())
|
||||
c.Locals("authid", authCtx.Payload.Id)
|
||||
} else {
|
||||
c.Locals("authtype", auth.PayloadNone.ToStr())
|
||||
c.Locals("authid", 0)
|
||||
}
|
||||
return false
|
||||
},
|
||||
Done: func(c *fiber.Ctx, logBytes []byte) {
|
||||
var logStr = strings.TrimPrefix(string(logBytes), "🚀")
|
||||
var logVars = strings.Split(logStr, "|")
|
||||
|
||||
var reqTimeStr = strings.TrimSpace(logVars[0])
|
||||
reqTime, err := time.ParseInLocation("2006-01-02 15:04:05", reqTimeStr, time.Local)
|
||||
if err != nil {
|
||||
slog.Error("时间解析错误", slog.Any("err", err))
|
||||
return
|
||||
}
|
||||
|
||||
var latency = strings.TrimSpace(logVars[4])
|
||||
var errStr = strings.TrimSpace(logVars[5])
|
||||
|
||||
slog.Info("接口请求",
|
||||
slog.String("identity", c.Locals("authtype").(string)),
|
||||
slog.Int("visitor", c.Locals("authid").(int)),
|
||||
slog.String("ip", c.IP()),
|
||||
slog.String("ua", c.Get("User-Agent")),
|
||||
slog.String("method", c.Method()),
|
||||
slog.String("path", c.Path()),
|
||||
slog.Int("status", c.Response().StatusCode()),
|
||||
slog.String("error", errStr),
|
||||
slog.String("latency", latency),
|
||||
slog.Time("time", reqTime),
|
||||
)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func newRecover() fiber.Handler {
|
||||
return recover.New(recover.Config{
|
||||
EnableStackTrace: true,
|
||||
})
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
Reference in New Issue
Block a user