实现余额购买接口 & 实现全局 id 生成器

This commit is contained in:
2025-04-08 17:15:23 +08:00
parent c02d843dbc
commit 4c47a71f30
10 changed files with 506 additions and 116 deletions

140
web/handlers/resource.go Normal file
View File

@@ -0,0 +1,140 @@
package handlers
import (
"errors"
"platform/web/auth"
m "platform/web/models"
q "platform/web/queries"
"platform/web/services"
"time"
"github.com/gofiber/fiber/v2"
)
// region CreateResourceByBalance
type CreateResourceByBalanceReq struct {
Type int32 `json:"type" validate:"required"`
Live int32 `json:"live" validate:"required"`
Expire int32 `json:"expire" validate:"required"`
Quota int32 `json:"quota" validate:"required"`
DailyLimit int32 `json:"daily_limit" validate:"required"`
}
// CreateResourceByBalance 通过余额创建资源
func CreateResourceByBalance(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
if err != nil {
return err
}
// 解析请求参数
req := new(CreateResourceByBalanceReq)
if err := c.BodyParser(req); err != nil {
return err
}
err = q.Q.Transaction(func(q *q.Query) error {
// 检查用户
user, err := q.User.Where(q.User.ID.Eq(authContext.Payload.Id)).Take()
if err != nil {
return err
}
// 计算价格
var amount = 0
var payment = 0
// 检查余额
if user.Balance < float64(req.Quota)/100 {
return errors.New("余额不足")
}
// 创建资源
resource := m.Resource{
UserID: authContext.Payload.Id,
}
err = q.Resource.Save(&resource)
if err != nil {
return err
}
resourcePss := m.ResourcePss{
ResourceID: resource.ID,
Type: req.Type,
Live: req.Live,
Quota: req.Quota,
Expire: time.Now().Add(time.Duration(req.Expire) * time.Second),
DailyLimit: req.DailyLimit,
}
err = q.ResourcePss.Save(&resourcePss)
if err != nil {
return err
}
// 更新用户余额
user.Balance -= float64(payment)
_, err = q.User.
Where(q.User.ID.Eq(authContext.Payload.Id)).
Update(q.User.Balance, user.Balance)
if err != nil {
return err
}
// 生成账单
bill := m.Bill{
UserID: authContext.Payload.Id,
ResourceID: resource.ID,
BillNo: services.ID.GenReadable("bil"),
Type: 1,
Info: "购买套餐",
Amount: float64(amount),
Payment: float64(payment),
}
err = q.Bill.Save(&bill)
if err != nil {
return err
}
return nil
})
if err != nil {
return err
}
return errors.New("not implemented")
}
// endregion
// region CreateResourceByAlipayCallback
type CreateResourceByAlipayCallbackReq struct {
}
// CreateResourceByAlipayCallback 支付宝支付回调
func CreateResourceByAlipayCallback(c *fiber.Ctx) error {
// 根据支付类型执行不同流程:
// 1. 支付宝或微信(即时支付)
// - 更新订单状态
// - 生成账单
// - 生成资源
return errors.New("not implemented")
}
// endregion
// region CreateResourceByWechatCallback
type CreateResourceByWechatCallbackReq struct {
}
// CreateResourceByWechatCallback 微信支付回调
func CreateResourceByWechatCallback(c *fiber.Ctx) error {
return errors.New("not implemented")
}
// endregion

61
web/handlers/trade.go Normal file
View File

@@ -0,0 +1,61 @@
package handlers
import (
"platform/web/auth"
m "platform/web/models"
q "platform/web/queries"
"platform/web/services"
"strconv"
"github.com/gofiber/fiber/v2"
)
// region CreateTrade
type CreateTradeReq struct {
Subject string `json:"subject" validate:"required"`
Remark string `json:"remark"`
Amount int `json:"amount" validate:"required"`
Method int `json:"method" validate:"required"` // 支付方式1.支付宝2.微信
}
func CreateTrade(c *fiber.Ctx) error {
// 检查权限
authContext, err := auth.Protect(c, []services.PayloadType{services.PayloadUser}, []string{})
if err != nil {
return err
}
// 解析请求参数
req := new(CreateTradeReq)
if err := c.BodyParser(req); err != nil {
return err
}
// 创建交易订单
num, err := services.ID.GenSerial(c.Context())
if err != nil {
return err
}
var trade = m.Trade{
UserID: authContext.Payload.Id,
InnerNo: strconv.FormatUint(num, 10),
Subject: req.Subject,
Remark: req.Remark,
Amount: float64(req.Amount) / 100,
Method: int32(req.Method),
}
// 调用外部接口
// 保存交易订单
err = q.Trade.Create(&trade)
if err != nil {
return err
}
// 返回结果,外部支付链接
return nil
}
// endregion

View File

@@ -14,18 +14,18 @@ const TableNameBill = "bill"
// Bill mapped from table <bill>
type Bill struct {
ID int32 `gorm:"column:id;primaryKey;autoIncrement:true;comment:账单ID" json:"id"` // 账单ID
OrderID int32 `gorm:"column:order_id;not null;comment:订单ID" json:"order_id"` // 订单ID
UserID int32 `gorm:"column:user_id;not null;comment:用户ID" json:"user_id"` // 用户ID
ProductID int32 `gorm:"column:product_id;comment:产品ID" json:"product_id"` // 产品ID
Info string `gorm:"column:info;comment:产品可读信息" json:"info"` // 产品可读信息
Count_ int32 `gorm:"column:count;comment:购买数量" json:"count"` // 购买数量
Price float64 `gorm:"column:price;not null;comment:单价" json:"price"` // 单价
Amount float64 `gorm:"column:amount;not null;comment:总金额" json:"amount"` // 总金额
Payment float64 `gorm:"column:payment;not null;comment:支付金额" json:"payment"` // 支付金额
CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间
ID int32 `gorm:"column:id;primaryKey;autoIncrement:true;comment:账单ID" json:"id"` // 账单ID
UserID int32 `gorm:"column:user_id;not null;comment:用户ID" json:"user_id"` // 用户ID
Info string `gorm:"column:info;comment:产品可读信息" json:"info"` // 产品可读信息
Amount float64 `gorm:"column:amount;not null;comment:总金额" json:"amount"` // 总金额
Payment float64 `gorm:"column:payment;not null;comment:支付金额" json:"payment"` // 支付金额
CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间
TradeID int32 `gorm:"column:trade_id" json:"trade_id"`
ResourceID int32 `gorm:"column:resource_id" json:"resource_id"`
Type int32 `gorm:"column:type;not null" json:"type"`
BillNo string `gorm:"column:bill_no;not null" json:"bill_no"`
}
// TableName Bill's table name

View File

@@ -15,12 +15,12 @@ const TableNameRefund = "refund"
// Refund mapped from table <refund>
type Refund struct {
ID int32 `gorm:"column:id;primaryKey;autoIncrement:true;comment:退款ID" json:"id"` // 退款ID
OrderID int32 `gorm:"column:order_id;not null;comment:订单ID" json:"order_id"` // 订单ID
ProductID int32 `gorm:"column:product_id;comment:产品ID" json:"product_id"` // 产品ID
Amount float64 `gorm:"column:amount;not null;comment:退款金额" json:"amount"` // 退款金额
CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间
TradeID int32 `gorm:"column:trade_id;not null" json:"trade_id"`
}
// TableName Refund's table name

View File

@@ -28,17 +28,17 @@ func newBill(db *gorm.DB, opts ...gen.DOOption) bill {
tableName := _bill.billDo.TableName()
_bill.ALL = field.NewAsterisk(tableName)
_bill.ID = field.NewInt32(tableName, "id")
_bill.OrderID = field.NewInt32(tableName, "order_id")
_bill.UserID = field.NewInt32(tableName, "user_id")
_bill.ProductID = field.NewInt32(tableName, "product_id")
_bill.Info = field.NewString(tableName, "info")
_bill.Count_ = field.NewInt32(tableName, "count")
_bill.Price = field.NewFloat64(tableName, "price")
_bill.Amount = field.NewFloat64(tableName, "amount")
_bill.Payment = field.NewFloat64(tableName, "payment")
_bill.CreatedAt = field.NewTime(tableName, "created_at")
_bill.UpdatedAt = field.NewTime(tableName, "updated_at")
_bill.DeletedAt = field.NewField(tableName, "deleted_at")
_bill.TradeID = field.NewInt32(tableName, "trade_id")
_bill.ResourceID = field.NewInt32(tableName, "resource_id")
_bill.Type = field.NewInt32(tableName, "type")
_bill.BillNo = field.NewString(tableName, "bill_no")
_bill.fillFieldMap()
@@ -48,19 +48,19 @@ func newBill(db *gorm.DB, opts ...gen.DOOption) bill {
type bill struct {
billDo
ALL field.Asterisk
ID field.Int32 // 账单ID
OrderID field.Int32 // 订单ID
UserID field.Int32 // 用户ID
ProductID field.Int32 // 产品ID
Info field.String // 产品可读信息
Count_ field.Int32 // 购买数量
Price field.Float64 // 单价
Amount field.Float64 // 总金额
Payment field.Float64 // 支付金额
CreatedAt field.Time // 创建时间
UpdatedAt field.Time // 更新时间
DeletedAt field.Field // 删除时间
ALL field.Asterisk
ID field.Int32 // 账单ID
UserID field.Int32 // 用户ID
Info field.String // 产品可读信息
Amount field.Float64 // 总金额
Payment field.Float64 // 支付金额
CreatedAt field.Time // 创建时间
UpdatedAt field.Time // 更新时间
DeletedAt field.Field // 删除时间
TradeID field.Int32
ResourceID field.Int32
Type field.Int32
BillNo field.String
fieldMap map[string]field.Expr
}
@@ -78,17 +78,17 @@ func (b bill) As(alias string) *bill {
func (b *bill) updateTableName(table string) *bill {
b.ALL = field.NewAsterisk(table)
b.ID = field.NewInt32(table, "id")
b.OrderID = field.NewInt32(table, "order_id")
b.UserID = field.NewInt32(table, "user_id")
b.ProductID = field.NewInt32(table, "product_id")
b.Info = field.NewString(table, "info")
b.Count_ = field.NewInt32(table, "count")
b.Price = field.NewFloat64(table, "price")
b.Amount = field.NewFloat64(table, "amount")
b.Payment = field.NewFloat64(table, "payment")
b.CreatedAt = field.NewTime(table, "created_at")
b.UpdatedAt = field.NewTime(table, "updated_at")
b.DeletedAt = field.NewField(table, "deleted_at")
b.TradeID = field.NewInt32(table, "trade_id")
b.ResourceID = field.NewInt32(table, "resource_id")
b.Type = field.NewInt32(table, "type")
b.BillNo = field.NewString(table, "bill_no")
b.fillFieldMap()
@@ -107,17 +107,17 @@ func (b *bill) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
func (b *bill) fillFieldMap() {
b.fieldMap = make(map[string]field.Expr, 12)
b.fieldMap["id"] = b.ID
b.fieldMap["order_id"] = b.OrderID
b.fieldMap["user_id"] = b.UserID
b.fieldMap["product_id"] = b.ProductID
b.fieldMap["info"] = b.Info
b.fieldMap["count"] = b.Count_
b.fieldMap["price"] = b.Price
b.fieldMap["amount"] = b.Amount
b.fieldMap["payment"] = b.Payment
b.fieldMap["created_at"] = b.CreatedAt
b.fieldMap["updated_at"] = b.UpdatedAt
b.fieldMap["deleted_at"] = b.DeletedAt
b.fieldMap["trade_id"] = b.TradeID
b.fieldMap["resource_id"] = b.ResourceID
b.fieldMap["type"] = b.Type
b.fieldMap["bill_no"] = b.BillNo
}
func (b bill) clone(db *gorm.DB) bill {

View File

@@ -28,12 +28,12 @@ func newRefund(db *gorm.DB, opts ...gen.DOOption) refund {
tableName := _refund.refundDo.TableName()
_refund.ALL = field.NewAsterisk(tableName)
_refund.ID = field.NewInt32(tableName, "id")
_refund.OrderID = field.NewInt32(tableName, "order_id")
_refund.ProductID = field.NewInt32(tableName, "product_id")
_refund.Amount = field.NewFloat64(tableName, "amount")
_refund.CreatedAt = field.NewTime(tableName, "created_at")
_refund.UpdatedAt = field.NewTime(tableName, "updated_at")
_refund.DeletedAt = field.NewField(tableName, "deleted_at")
_refund.TradeID = field.NewInt32(tableName, "trade_id")
_refund.fillFieldMap()
@@ -45,12 +45,12 @@ type refund struct {
ALL field.Asterisk
ID field.Int32 // 退款ID
OrderID field.Int32 // 订单ID
ProductID field.Int32 // 产品ID
Amount field.Float64 // 退款金额
CreatedAt field.Time // 创建时间
UpdatedAt field.Time // 更新时间
DeletedAt field.Field // 删除时间
TradeID field.Int32
fieldMap map[string]field.Expr
}
@@ -68,12 +68,12 @@ func (r refund) As(alias string) *refund {
func (r *refund) updateTableName(table string) *refund {
r.ALL = field.NewAsterisk(table)
r.ID = field.NewInt32(table, "id")
r.OrderID = field.NewInt32(table, "order_id")
r.ProductID = field.NewInt32(table, "product_id")
r.Amount = field.NewFloat64(table, "amount")
r.CreatedAt = field.NewTime(table, "created_at")
r.UpdatedAt = field.NewTime(table, "updated_at")
r.DeletedAt = field.NewField(table, "deleted_at")
r.TradeID = field.NewInt32(table, "trade_id")
r.fillFieldMap()
@@ -92,12 +92,12 @@ func (r *refund) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
func (r *refund) fillFieldMap() {
r.fieldMap = make(map[string]field.Expr, 7)
r.fieldMap["id"] = r.ID
r.fieldMap["order_id"] = r.OrderID
r.fieldMap["product_id"] = r.ProductID
r.fieldMap["amount"] = r.Amount
r.fieldMap["created_at"] = r.CreatedAt
r.fieldMap["updated_at"] = r.UpdatedAt
r.fieldMap["deleted_at"] = r.DeletedAt
r.fieldMap["trade_id"] = r.TradeID
}
func (r refund) clone(db *gorm.DB) refund {

178
web/services/id.go Normal file
View File

@@ -0,0 +1,178 @@
package services
import (
"context"
"errors"
"fmt"
"platform/pkg/rds"
"strings"
"time"
"github.com/google/uuid"
"github.com/jxskiss/base62"
"github.com/redis/go-redis/v9"
)
var ID IdService = IdService{}
type IdService struct {
}
// region SerialID
const (
// 保留位确保最高位为0防止产生负值
reservedBits = 1
// 时间戳位数
timestampBits = 41
// 序列号位数
sequenceBits = 22
// 最大序列号掩码2^22 - 1
maxSequence = (1 << sequenceBits) - 1
// 位移计算常量
timestampShift = sequenceBits
// Redis 缓存过期时间(秒)
redisTTL = 5
)
var (
ErrSequenceOverflow = errors.New("sequence overflow")
)
func (s *IdService) GenSerial(ctx context.Context) (uint64, error) {
// 构造Redis键
now := time.Now().Unix()
key := idSerialKey(now)
// 使用Redis事务确保原子操作
var sequence int64
err := rds.Client.Watch(ctx, func(tx *redis.Tx) error {
// 获取当前序列号
currentVal, err := tx.Get(ctx, key).Int64()
if err != nil && !errors.Is(err, redis.Nil) {
return err
}
if errors.Is(err, redis.Nil) {
currentVal = 0
}
sequence = currentVal + 1
// 检查序列号是否溢出
if sequence > maxSequence {
return ErrSequenceOverflow
}
// 将更新后的序列号保存回Redis设置5秒过期时间
pipe := tx.Pipeline()
pipe.Set(ctx, key, sequence, redisTTL*time.Second)
_, err = pipe.Exec(ctx)
return err
}, key)
if err != nil {
return 0, err
}
// 组装最终ID
id := uint64((now << timestampShift) | sequence)
return id, nil
}
// ParseSerial 解析ID返回其组成部分
func (s *IdService) ParseSerial(id uint64) (timestamp int64, sequence int64) {
// 通过位运算和掩码提取各部分
timestamp = int64(id >> timestampShift)
sequence = int64(id & maxSequence)
return
}
// idSerialKey 根据时间戳生成Redis键
func idSerialKey(timestamp int64) string {
return fmt.Sprintf("global:id:serial:%d", timestamp)
}
// endregion
// region ReadableID
// GenReadable 根据给定的标签生成易读的全局唯一标识符
// tag 参数用于标识 ID 的用途,如 "usr" 表示用户ID"ord" 表示订单ID等
// 生成的 ID 格式为:<tag>_<encoded-uuid>例如usr_7NLmVLeHwqS73enFZ1i8tB
func (s *IdService) GenReadable(tag string) string {
// 生成 UUID
id := uuid.New()
// 将 UUID 编码为 Base62 字符串(更短,更易读)
encoded := base62.EncodeToString(id[:])
// 如标签为空,则直接返回编码后的字符串
if tag == "" {
return encoded
}
// 标准化标签:转换为小写并移除特殊字符
tag = normalizeTag(tag)
// 组合最终 ID
return fmt.Sprintf("%s_%s", tag, encoded)
}
// ParseReadableID 解析易读ID返回其标签和编码部分
func (s *IdService) ParseReadableID(id string) (tag string, encoded string) {
parts := strings.SplitN(id, "_", 2)
if len(parts) != 2 {
return "", id
}
return parts[0], parts[1]
}
// TryDecodeID 尝试将编码部分解码回 UUID
// 如果解码失败,返回错误
func (s *IdService) TryDecodeID(encoded string) (uuid.UUID, error) {
// 尝试解码 Base62 编码
bytes, err := base62.DecodeString(encoded)
if err != nil {
return uuid.UUID{}, err
}
// 确保长度正确
if len(bytes) != 16 {
return uuid.UUID{}, fmt.Errorf("invalid UUID length after decoding: %d", len(bytes))
}
// 转换为 UUID
var result uuid.UUID
copy(result[:], bytes)
return result, nil
}
// normalizeTag 标准化标签
// 转换为小写,移除特殊字符,最多保留 5 个字符
func normalizeTag(tag string) string {
// 转换为小写
tag = strings.ToLower(tag)
// 移除特殊字符
var sb strings.Builder
for _, c := range tag {
if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') {
sb.WriteRune(c)
}
}
// 截取最多 5 个字符
result := sb.String()
if len(result) > 5 {
result = result[:5]
}
return result
}
// endregion