完善套餐与账单接口 & 完善支付数据保存,记录实付价格并关联优惠券

This commit is contained in:
2026-03-26 14:39:19 +08:00
parent 5ffa151f58
commit 75ad12efb3
23 changed files with 706 additions and 613 deletions

View File

@@ -1,5 +1,7 @@
## TODO ## TODO
增删改数据权限排查
后端默认用户名不能是完整手机号 后端默认用户名不能是完整手机号
前端需要 token 化改造,以避免每次 basic 认证流程中 bcrypt 对比导致的性能对比 前端需要 token 化改造,以避免每次 basic 认证流程中 bcrypt 对比导致的性能对比

2
pkg/env/env.go vendored
View File

@@ -20,7 +20,7 @@ const (
var ( var (
RunMode = RunModeProd RunMode = RunModeProd
LogLevel = slog.LevelDebug LogLevel = slog.LevelDebug
TradeExpire = 15 * 60 // 交易过期时间,单位秒。默认 15 分钟 TradeExpire = 15 * 60 // 交易过期时间,单位秒。默认 900 秒(15 分钟
SessionAccessExpire = 60 * 60 * 2 // 访问令牌过期时间,单位秒。默认 2 小时 SessionAccessExpire = 60 * 60 * 2 // 访问令牌过期时间,单位秒。默认 2 小时
SessionRefreshExpire = 60 * 60 * 24 * 7 // 刷新令牌过期时间,单位秒。默认 7 天 SessionRefreshExpire = 60 * 60 * 24 * 7 // 刷新令牌过期时间,单位秒。默认 7 天
DebugHttpDump = false // 是否打印请求和响应的原始数据 DebugHttpDump = false // 是否打印请求和响应的原始数据

View File

@@ -5,8 +5,8 @@
insert into client (type, spec, name, client_id, client_secret, redirect_uri) values (1, 3, 'web', 'web', '$2a$10$Ss12mXQgpYyo1CKIZ3URouDm.Lc2KcYJzsvEK2PTIXlv6fHQht45a', ''); insert into client (type, spec, name, client_id, client_secret, redirect_uri) values (1, 3, 'web', 'web', '$2a$10$Ss12mXQgpYyo1CKIZ3URouDm.Lc2KcYJzsvEK2PTIXlv6fHQht45a', '');
insert into client (type, spec, name, client_id, client_secret, redirect_uri) values (1, 3, 'admin', 'admin', '$2a$10$dlfvX5Uf3iVsUWgwlb0Wt.oYsw/OEXgS.Aior3yoT63Ju7ZSsJr/2', ''); insert into client (type, spec, name, client_id, client_secret, redirect_uri) values (1, 3, 'admin', 'admin', '$2a$10$dlfvX5Uf3iVsUWgwlb0Wt.oYsw/OEXgS.Aior3yoT63Ju7ZSsJr/2', '');
insert into product (code, name, description) values ('dynamic-short', '短效动态', '短效动态'); insert into product (code, name, description) values ('short', '短效动态', '短效动态');
insert into product (code, name, description) values ('dynamic-long', '长效动态', '长效动态'); insert into product (code, name, description) values ('long', '长效动态', '长效动态');
insert into product (code, name, description) values ('static', '长效静态', '长效静态'); insert into product (code, name, description) values ('static', '长效静态', '长效静态');
delete from permission where true; delete from permission where true;

View File

@@ -747,7 +747,7 @@ create table product_sku (
id int generated by default as identity primary key, id int generated by default as identity primary key,
product_id int not null, product_id int not null,
discount_id int, discount_id int,
code text not null, code text not null unique,
name text not null, name text not null,
price decimal not null, price decimal not null,
created_at timestamptz default current_timestamp, created_at timestamptz default current_timestamp,
@@ -756,7 +756,7 @@ create table product_sku (
); );
create index idx_product_sku_product_id on product_sku (product_id) where deleted_at is null; create index idx_product_sku_product_id on product_sku (product_id) where deleted_at is null;
create index idx_product_sku_discount_id on product_sku (discount_id) where deleted_at is null; create index idx_product_sku_discount_id on product_sku (discount_id) where deleted_at is null;
create index idx_product_sku_code on product_sku (code) where deleted_at is null; create unique index idx_product_sku_code on product_sku (code) where deleted_at is null;
-- product_sku表字段注释 -- product_sku表字段注释
comment on table product_sku is '产品SKU表'; comment on table product_sku is '产品SKU表';
@@ -977,10 +977,12 @@ create table bill (
trade_id int, trade_id int,
resource_id int, resource_id int,
refund_id int, refund_id int,
coupon_id int,
bill_no text not null, bill_no text not null,
info text, info text,
type int not null, type int not null,
amount decimal(12, 2) not null default 0, amount decimal(12, 2) not null default 0,
actual decimal(12, 2) not null default 0,
created_at timestamptz default current_timestamp, created_at timestamptz default current_timestamp,
updated_at timestamptz default current_timestamp, updated_at timestamptz default current_timestamp,
deleted_at timestamptz deleted_at timestamptz
@@ -990,6 +992,7 @@ create index idx_bill_user_id on bill (user_id) where deleted_at is null;
create index idx_bill_trade_id on bill (trade_id) where deleted_at is null; create index idx_bill_trade_id on bill (trade_id) where deleted_at is null;
create index idx_bill_resource_id on bill (resource_id) where deleted_at is null; create index idx_bill_resource_id on bill (resource_id) where deleted_at is null;
create index idx_bill_refund_id on bill (refund_id) where deleted_at is null; create index idx_bill_refund_id on bill (refund_id) where deleted_at is null;
create index idx_bill_coupon_id on bill (coupon_id) where deleted_at is null;
create index idx_bill_created_at on bill (created_at) where deleted_at is null; create index idx_bill_created_at on bill (created_at) where deleted_at is null;
-- bill表字段注释 -- bill表字段注释
@@ -1002,7 +1005,8 @@ comment on column bill.refund_id is '退款ID';
comment on column bill.bill_no is '易读账单号'; comment on column bill.bill_no is '易读账单号';
comment on column bill.info is '产品可读信息'; comment on column bill.info is '产品可读信息';
comment on column bill.type is '账单类型1-消费2-退款3-充值'; comment on column bill.type is '账单类型1-消费2-退款3-充值';
comment on column bill.amount is '账单金额'; comment on column bill.amount is '应付金额';
comment on column bill.actual is '实付金额';
comment on column bill.created_at is '创建时间'; comment on column bill.created_at is '创建时间';
comment on column bill.updated_at is '更新时间'; comment on column bill.updated_at is '更新时间';
comment on column bill.deleted_at is '删除时间'; comment on column bill.deleted_at is '删除时间';
@@ -1107,14 +1111,20 @@ alter table channel
-- resource表外键 -- resource表外键
alter table resource alter table resource
add constraint fk_resource_user_id foreign key (user_id) references "user" (id) on delete cascade; add constraint fk_resource_user_id foreign key (user_id) references "user" (id) on delete cascade;
alter table resource
add constraint fk_product_code foreign key (code) references product (code) on update cascade on delete restrict;
-- resource_short表外键 -- resource_short表外键
alter table resource_short alter table resource_short
add constraint fk_resource_short_resource_id foreign key (resource_id) references resource (id) on delete cascade; add constraint fk_resource_short_resource_id foreign key (resource_id) references resource (id) on delete cascade;
alter table resource_short
add constraint fk_resource_short_code foreign key (code) references product_sku (code) on update cascade on delete restrict;
-- resource_long表外键 -- resource_long表外键
alter table resource_long alter table resource_long
add constraint fk_resource_long_resource_id foreign key (resource_id) references resource (id) on delete cascade; add constraint fk_resource_long_resource_id foreign key (resource_id) references resource (id) on delete cascade;
alter table resource_long
add constraint fk_resource_long_code foreign key (code) references product_sku (code) on update cascade on delete restrict;
-- trade表外键 -- trade表外键
alter table trade alter table trade
@@ -1135,6 +1145,8 @@ alter table bill
add constraint fk_bill_resource_id foreign key (resource_id) references resource (id) on delete set null; add constraint fk_bill_resource_id foreign key (resource_id) references resource (id) on delete set null;
alter table bill alter table bill
add constraint fk_bill_refund_id foreign key (refund_id) references refund (id) on delete set null; add constraint fk_bill_refund_id foreign key (refund_id) references refund (id) on delete set null;
alter table bill
add constraint fk_bill_coupon_id foreign key (coupon_id) references coupon (id) on delete set null;
-- coupon表外键 -- coupon表外键
alter table coupon alter table coupon

View File

@@ -3,14 +3,22 @@ package core
const ( const (
ScopePermissionRead = string("permission:read") ScopePermissionRead = string("permission:read")
ScopePermissionWrite = string("permission:write") ScopePermissionWrite = string("permission:write")
ScopeAdminRoleRead = string("admin_role:read") ScopeAdminRoleRead = string("admin_role:read")
ScopeAdminRoleWrite = string("admin_role:write") ScopeAdminRoleWrite = string("admin_role:write")
ScopeAdminRead = string("admin:read") ScopeAdminRead = string("admin:read")
ScopeAdminWrite = string("admin:write") ScopeAdminWrite = string("admin:write")
ScopeProductRead = string("product:read") ScopeProductRead = string("product:read")
ScopeProductWrite = string("product:write") ScopeProductWrite = string("product:write")
ScopeProductSkuRead = string("product_sku:read") ScopeProductSkuRead = string("product_sku:read")
ScopeProductSkuWrite = string("product_sku:write") ScopeProductSkuWrite = string("product_sku:write")
ScopeProductDiscountRead = string("product_discount:read") ScopeProductDiscountRead = string("product_discount:read")
ScopeProductDiscountWrite = string("product_discount:write") ScopeProductDiscountWrite = string("product_discount:write")
ScopeResourceRead = string("resource:read")
ScopeResourceWrite = string("resource:write")
) )

View File

@@ -1,7 +1,9 @@
package web package web
import ( import (
"encoding/json"
"errors" "errors"
"fmt"
"log/slog" "log/slog"
"platform/web/auth" "platform/web/auth"
"platform/web/core" "platform/web/core"
@@ -19,6 +21,7 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
var authErr auth.AuthErr var authErr auth.AuthErr
var bizErr *core.BizErr var bizErr *core.BizErr
var servErr *core.ServErr var servErr *core.ServErr
var jsonErr *json.UnmarshalTypeError
switch { switch {
@@ -48,6 +51,10 @@ func ErrorHandler(c *fiber.Ctx, err error) error {
code = fiber.StatusInternalServerError code = fiber.StatusInternalServerError
message = err.Error() message = err.Error()
case errors.As(err, &jsonErr):
code = fiber.StatusBadRequest
message = fmt.Sprintf("参数 %s 类型不正确,传入类型为 %s正确类型应该为 %s", jsonErr.Field, jsonErr.Value, jsonErr.Type.Name())
// 所有未手动声明的错误类型 // 所有未手动声明的错误类型
default: default:
slog.Warn("未处理的异常", slog.String("type", reflect.TypeOf(err).Name()), slog.String("error", err.Error())) slog.Warn("未处理的异常", slog.String("type", reflect.TypeOf(err).Name()), slog.String("error", err.Error()))

View File

@@ -9,18 +9,23 @@ import (
"github.com/hibiken/asynq" "github.com/hibiken/asynq"
) )
const CompleteTrade = "trade:update" const CloseTrade = "trade:update"
type CompleteTradeData struct { type CloseTradeData struct {
UserId int32 `json:"user_id" validate:"required"`
TradeNo string `json:"trade_no" validate:"required"` TradeNo string `json:"trade_no" validate:"required"`
Method m.TradeMethod `json:"method" validate:"required"` Method m.TradeMethod `json:"method" validate:"required"`
} }
func NewCancelTrade(data CompleteTradeData) *asynq.Task { func NewCloseTradeTask(uid int32, tradeNo string, method m.TradeMethod) *asynq.Task {
bytes, err := json.Marshal(data) bytes, err := json.Marshal(CloseTradeData{
UserId: uid,
TradeNo: tradeNo,
Method: method,
})
if err != nil { if err != nil {
slog.Error("序列化更新交易任务失败", "error", err) slog.Error("序列化更新交易任务失败", "error", err)
return nil return nil
} }
return asynq.NewTask(CompleteTrade, bytes) return asynq.NewTask(CloseTrade, bytes)
} }

30
web/globals/orm/timez.go Normal file
View File

@@ -0,0 +1,30 @@
package orm
import (
"fmt"
"time"
)
type DateTime struct {
time.Time
}
func (dt *DateTime) Scan(value any) error {
switch v := value.(type) {
case time.Time:
dt.Time = v
case string:
t, err := time.Parse(time.RFC3339, v)
if err != nil {
return err
}
dt.Time = t
default:
return fmt.Errorf("unsupported type: %T", value)
}
return nil
}
func (dt DateTime) Value() (any, error) {
return dt.Time.Format(time.RFC3339), nil
}

View File

@@ -47,10 +47,24 @@ func PageBillByAdmin(c *fiber.Ctx) error {
time := u.DateHead(*req.CreatedAtEnd) time := u.DateHead(*req.CreatedAtEnd)
do = do.Where(q.Bill.CreatedAt.Lte(time)) do = do.Where(q.Bill.CreatedAt.Lte(time))
} }
if req.ProductCode != nil {
do = do.Where(q.Resource.As("Resource").Code.Eq(*req.ProductCode))
}
if req.SkuCode != nil {
do = do.Where(q.Bill.
Where(q.ResourceShort.As("Resource__Short").Code.Eq(*req.SkuCode)).
Or(q.ResourceLong.As("Resource__Long").Code.Eq(*req.SkuCode)))
}
// 查询用户列表 // 查询用户列表
list, total, err := q.Bill.Debug(). list, total, err := q.Bill.Debug().
Joins(q.Bill.User, q.Bill.Resource, q.Bill.Trade). Joins(
q.Bill.User,
q.Bill.Resource,
q.Bill.Trade,
q.Bill.Resource.Short,
q.Bill.Resource.Long,
).
Select( Select(
q.Bill.ALL, q.Bill.ALL,
q.User.As("User").Phone.As("User__phone"), q.User.As("User").Phone.As("User__phone"),
@@ -82,6 +96,8 @@ type PageBillByAdminReq struct {
BillNo *string `json:"bill_no,omitempty"` BillNo *string `json:"bill_no,omitempty"`
CreatedAtStart *time.Time `json:"created_at_start,omitempty"` CreatedAtStart *time.Time `json:"created_at_start,omitempty"`
CreatedAtEnd *time.Time `json:"created_at_end,omitempty"` CreatedAtEnd *time.Time `json:"created_at_end,omitempty"`
ProductCode *string `json:"product_code,omitempty"`
SkuCode *string `json:"sku_code,omitempty"`
} }
// ListBill 获取账单列表 // ListBill 获取账单列表

View File

@@ -91,6 +91,29 @@ func DeleteProduct(c *fiber.Ctx) error {
return nil return nil
} }
func AllProductSkuByAdmin(c *fiber.Ctx) error {
_, err := auth.GetAuthCtx(c).PermitAdmin(core.ScopeProductSkuRead)
if err != nil {
return err
}
var req AllProductSkuByAdminReq
if err := g.Validator.ParseBody(c, &req); err != nil {
return err
}
list, err := s.ProductSku.All(req.Code)
if err != nil {
return err
}
return c.JSON(list)
}
type AllProductSkuByAdminReq struct {
Code string `json:"product_code"`
}
func PageProductSkuByAdmin(c *fiber.Ctx) error { func PageProductSkuByAdmin(c *fiber.Ctx) error {
_, err := auth.GetAuthCtx(c).PermitAdmin(core.ScopeProductSkuRead) _, err := auth.GetAuthCtx(c).PermitAdmin(core.ScopeProductSkuRead)
if err != nil { if err != nil {

View File

@@ -70,7 +70,7 @@ func PageResourceShort(c *fiber.Ctx) error {
} }
resource, err := q.Resource.Where(do). resource, err := q.Resource.Where(do).
Joins(q.Resource.Short, q.ResourceShort.Sku). Joins(q.Resource.Short).
Order(q.Resource.CreatedAt.Desc()). Order(q.Resource.CreatedAt.Desc()).
Offset(req.GetOffset()). Offset(req.GetOffset()).
Limit(req.GetLimit()). Limit(req.GetLimit()).
@@ -240,9 +240,28 @@ func PageResourceShortByAdmin(c *fiber.Ctx) error {
time := u.DateTail(*req.CreatedAtEnd) time := u.DateTail(*req.CreatedAtEnd)
do = do.Where(q.Resource.CreatedAt.Lte(time)) do = do.Where(q.Resource.CreatedAt.Lte(time))
} }
if req.Expired != nil {
if *req.Expired {
do = do.Where(q.Resource.Where(
q.ResourceShort.As("Short").Type.Eq(int(m.ResourceModeTime)),
q.ResourceShort.As("Short").ExpireAt.Lte(time.Now()),
).Or(
q.ResourceShort.As("Short").Type.Eq(int(m.ResourceModeQuota)),
q.ResourceShort.As("Short").Quota.LteCol(q.ResourceShort.As("Short").Used),
))
} else {
do = do.Where(q.Resource.Where(
q.ResourceShort.As("Short").Type.Eq(int(m.ResourceModeTime)),
q.ResourceShort.As("Short").ExpireAt.Gt(time.Now()),
).Or(
q.ResourceShort.As("Short").Type.Eq(int(m.ResourceModeQuota)),
q.ResourceShort.As("Short").Quota.GtCol(q.ResourceShort.As("Short").Used),
))
}
}
list, total, err := q.Resource.Debug(). list, total, err := q.Resource.Debug().
Joins(q.Resource.User, q.Resource.Short). Joins(q.Resource.User, q.Resource.Short, q.Resource.Short.Sku).
Select( Select(
q.Resource.ALL, q.Resource.ALL,
q.User.As("User").Phone.As("User__phone"), q.User.As("User").Phone.As("User__phone"),
@@ -254,9 +273,14 @@ func PageResourceShortByAdmin(c *fiber.Ctx) error {
q.ResourceShort.As("Short").Daily.As("Short__daily"), q.ResourceShort.As("Short").Daily.As("Short__daily"),
q.ResourceShort.As("Short").LastAt.As("Short__last_at"), q.ResourceShort.As("Short").LastAt.As("Short__last_at"),
q.ResourceShort.As("Short").ExpireAt.As("Short__expire_at"), q.ResourceShort.As("Short").ExpireAt.As("Short__expire_at"),
q.ProductSku.As("Short__Sku").Name.As("Short__Sku__name"),
). ).
Where(q.Resource.Type.Eq(int(m.ResourceTypeShort)), do). Where(q.Resource.Type.Eq(int(m.ResourceTypeShort)), do).
Order(q.Resource.CreatedAt.Desc()).
FindByPage(req.GetOffset(), req.GetLimit()) FindByPage(req.GetOffset(), req.GetLimit())
if err != nil {
return err
}
return c.JSON(core.PageResp{ return c.JSON(core.PageResp{
List: list, List: list,
@@ -274,9 +298,10 @@ type PageResourceShortByAdminReq struct {
Mode *int `json:"mode" form:"mode"` Mode *int `json:"mode" form:"mode"`
CreatedAtStart *time.Time `json:"created_at_start" form:"created_at_start"` CreatedAtStart *time.Time `json:"created_at_start" form:"created_at_start"`
CreatedAtEnd *time.Time `json:"created_at_end" form:"created_at_end"` CreatedAtEnd *time.Time `json:"created_at_end" form:"created_at_end"`
Expired *bool `json:"expired" form:"expired"`
} }
// PageResourceLongByAdmin 分页查询全部效套餐 // PageResourceLongByAdmin 分页查询全部效套餐
func PageResourceLongByAdmin(c *fiber.Ctx) error { func PageResourceLongByAdmin(c *fiber.Ctx) error {
_, err := auth.GetAuthCtx(c).PermitAdmin() _, err := auth.GetAuthCtx(c).PermitAdmin()
if err != nil { if err != nil {
@@ -307,9 +332,28 @@ func PageResourceLongByAdmin(c *fiber.Ctx) error {
if req.CreatedAtEnd != nil { if req.CreatedAtEnd != nil {
do = do.Where(q.Resource.CreatedAt.Lte(*req.CreatedAtEnd)) do = do.Where(q.Resource.CreatedAt.Lte(*req.CreatedAtEnd))
} }
if req.Expired != nil {
if *req.Expired {
do = do.Where(q.Resource.Where(
q.ResourceLong.As("Long").Type.Eq(int(m.ResourceModeTime)),
q.ResourceLong.As("Long").ExpireAt.Lte(time.Now()),
).Or(
q.ResourceLong.As("Long").Type.Eq(int(m.ResourceModeQuota)),
q.ResourceLong.As("Long").Quota.LteCol(q.ResourceLong.As("Long").Used),
))
} else {
do = do.Where(q.Resource.Where(
q.ResourceLong.As("Long").Type.Eq(int(m.ResourceModeTime)),
q.ResourceLong.As("Long").ExpireAt.Gt(time.Now()),
).Or(
q.ResourceLong.As("Long").Type.Eq(int(m.ResourceModeQuota)),
q.ResourceLong.As("Long").Quota.GtCol(q.ResourceLong.As("Long").Used),
))
}
}
list, total, err := q.Resource. list, total, err := q.Resource.Debug().
Joins(q.Resource.User, q.Resource.Long). Joins(q.Resource.User, q.Resource.Long, q.Resource.Long.Sku).
Select( Select(
q.Resource.ALL, q.Resource.ALL,
q.User.As("User").Phone.As("User__phone"), q.User.As("User").Phone.As("User__phone"),
@@ -321,9 +365,14 @@ func PageResourceLongByAdmin(c *fiber.Ctx) error {
q.ResourceLong.As("Long").Daily.As("Long__daily"), q.ResourceLong.As("Long").Daily.As("Long__daily"),
q.ResourceLong.As("Long").LastAt.As("Long__last_at"), q.ResourceLong.As("Long").LastAt.As("Long__last_at"),
q.ResourceLong.As("Long").ExpireAt.As("Long__expire_at"), q.ResourceLong.As("Long").ExpireAt.As("Long__expire_at"),
q.ProductSku.As("Long__Sku").Name.As("Long__Sku__name"),
). ).
Where(q.Resource.Type.Eq(int(m.ResourceTypeLong)), do). Where(q.Resource.Type.Eq(int(m.ResourceTypeLong)), do).
Order(q.Resource.CreatedAt.Desc()).
FindByPage(req.GetOffset(), req.GetLimit()) FindByPage(req.GetOffset(), req.GetLimit())
if err != nil {
return err
}
return c.JSON(core.PageResp{ return c.JSON(core.PageResp{
List: list, List: list,
@@ -341,6 +390,7 @@ type PageResourceLongByAdminReq struct {
Mode *int `json:"mode" form:"mode"` Mode *int `json:"mode" form:"mode"`
CreatedAtStart *time.Time `json:"created_at_start" form:"created_at_start"` CreatedAtStart *time.Time `json:"created_at_start" form:"created_at_start"`
CreatedAtEnd *time.Time `json:"created_at_end" form:"created_at_end"` CreatedAtEnd *time.Time `json:"created_at_end" form:"created_at_end"`
Expired *bool `json:"expired" form:"expired"`
} }
// AllActiveResource 所有可用套餐 // AllActiveResource 所有可用套餐
@@ -402,6 +452,24 @@ func AllActiveResource(c *fiber.Ctx) error {
type AllResourceReq struct { type AllResourceReq struct {
} }
func UpdateResourceByAdmin(c *fiber.Ctx) error {
_, err := auth.GetAuthCtx(c).PermitAdmin(core.ScopeResourceWrite)
if err != nil {
return err
}
var req s.UpdateResourceData
if err := c.BodyParser(&req); err != nil {
return err
}
if err := s.Resource.Update(&req); err != nil {
return err
}
return c.JSON(nil)
}
// StatisticResourceFree 统计每日可用 // StatisticResourceFree 统计每日可用
func StatisticResourceFree(c *fiber.Ctx) error { func StatisticResourceFree(c *fiber.Ctx) error {
// 检查权限 // 检查权限
@@ -602,26 +670,28 @@ func ResourcePrice(c *fiber.Ctx) error {
} }
// 获取套餐价格 // 获取套餐价格
sku, err := s.Resource.GetSku(req.CreateResourceData) sku, err := s.Resource.GetSku(req.CreateResourceData.Code())
if err != nil { if err != nil {
return err return err
} }
before, after, err := s.Resource.GetPrice(sku, req.Count(), nil) _, amount, discounted, couponApplied, err := s.Resource.GetPrice(sku, req.Count(), nil, nil)
if err != nil { if err != nil {
return err return err
} }
// 计算折扣 // 计算折扣
return c.JSON(ResourcePriceResp{ return c.JSON(ResourcePriceResp{
Price: before.StringFixed(2), Discount: float32(sku.Discount.Discount) / 100,
Discounted: float32(sku.Discount.Discount) / 100, Price: amount.StringFixed(2),
DiscountedPrice: after.StringFixed(2), Discounted: discounted.StringFixed(2),
CouponApplied: couponApplied.StringFixed(2),
}) })
} }
type ResourcePriceResp struct { type ResourcePriceResp struct {
Price string `json:"price"` Price string `json:"price"`
Discounted float32 `json:"discounted"` Discount float32 `json:"discounted"`
DiscountedPrice string `json:"discounted_price"` Discounted string `json:"discounted_price"`
CouponApplied string `json:"coupon_applied"`
} }

View File

@@ -109,53 +109,38 @@ func TradeCreate(c *fiber.Ctx) error {
if err := g.Validator.ParseBody(c, req); err != nil { if err := g.Validator.ParseBody(c, req); err != nil {
return err return err
} }
var product s.ProductInfo
switch req.Type { switch req.Type {
case m.TradeTypePurchase: case m.TradeTypePurchase:
if req.Resource == nil { if req.Resource == nil {
return core.NewBizErr("购买信息不能为空") return core.NewBizErr("购买信息不能为空")
} }
product, err = s.NewCreateResourceByTradeData(req.Resource)
if err != nil {
return core.NewServErr("处理购买产品信息失败", err)
}
case m.TradeTypeRecharge: case m.TradeTypeRecharge:
if req.Recharge == nil { if req.Recharge == nil {
return core.NewBizErr("充值信息不能为空") return core.NewBizErr("充值信息不能为空")
} }
product = req.Recharge
} }
// 创建交易 // 处理订单
result, err := s.Trade.CreateTrade(authCtx.User.ID, time.Now(), &req.CreateTradeData, product) uid := authCtx.User.ID
result, err := s.Trade.Create(uid, req.CreateTradeData, req.Resource)
if err != nil { if err != nil {
slog.Error("创建交易失败", "error", err) return core.NewServErr("处理购买产品信息失败", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "创建交易失败"})
} }
return c.JSON(&TradeCreateResp{ return c.JSON(result)
PayUrl: result.PaymentUrl,
TradeNo: result.TradeNo,
})
} }
type TradeCreateReq struct { type TradeCreateReq struct {
s.CreateTradeData *s.CreateTradeData
Type m.TradeType `json:"type" validate:"required"` Type m.TradeType `json:"type" validate:"required"`
Resource *s.CreateResourceData `json:"resource,omitempty"` Resource *s.CreateResourceData `json:"resource,omitempty"`
Recharge *s.RechargeProductInfo `json:"recharge,omitempty"` Recharge *s.UpdateBalanceData `json:"recharge,omitempty"`
}
type TradeCreateResp struct {
PayUrl string `json:"pay_url"`
TradeNo string `json:"trade_no"`
} }
// 完成订单 // 完成订单
func TradeComplete(c *fiber.Ctx) error { func TradeComplete(c *fiber.Ctx) error {
// 检查权限 // 检查权限
_, err := auth.GetAuthCtx(c).PermitUser() authCtx, err := auth.GetAuthCtx(c).PermitUser()
if err != nil { if err != nil {
return err return err
} }
@@ -167,7 +152,7 @@ func TradeComplete(c *fiber.Ctx) error {
} }
// 检查订单状态 // 检查订单状态
err = s.Trade.CompleteTrade(&req.ModifyTradeData) err = s.Trade.CompleteTrade(authCtx.User, &req.TradeRef)
if err != nil { if err != nil {
return err return err
} }
@@ -176,7 +161,7 @@ func TradeComplete(c *fiber.Ctx) error {
} }
type TradeCompleteReq struct { type TradeCompleteReq struct {
s.ModifyTradeData s.TradeRef
} }
// 取消订单 // 取消订单
@@ -194,7 +179,7 @@ func TradeCancel(c *fiber.Ctx) error {
} }
// 取消交易 // 取消交易
err = s.Trade.CancelTrade(&req.ModifyTradeData, time.Now()) err = s.Trade.CancelTrade(&req.TradeRef)
if err != nil { if err != nil {
slog.Error("取消交易失败", "trade_no", req.TradeNo, "error", err) slog.Error("取消交易失败", "trade_no", req.TradeNo, "error", err)
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "取消交易失败"}) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "取消交易失败"})
@@ -204,7 +189,7 @@ func TradeCancel(c *fiber.Ctx) error {
} }
type TradeCancelReq struct { type TradeCancelReq struct {
s.ModifyTradeData s.TradeRef
} }
// 检查订单 // 检查订单
@@ -225,7 +210,7 @@ func TradeCheck(c *fiber.Ctx) error {
interval := 5 interval := 5
for range expire / interval { for range expire / interval {
// 检查订单状态 // 检查订单状态
result, err := s.Trade.CheckTrade(&req.ModifyTradeData) result, err := s.Trade.CheckTrade(&req.TradeRef)
if err != nil { if err != nil {
slog.Error("检查订单状态失败", "trade_no", req.TradeNo, "error", err) slog.Error("检查订单状态失败", "trade_no", req.TradeNo, "error", err)
return return
@@ -256,5 +241,5 @@ func TradeCheck(c *fiber.Ctx) error {
} }
type TradeCheckReq struct { type TradeCheckReq struct {
s.ModifyTradeData s.TradeRef
} }

View File

@@ -13,10 +13,12 @@ type Bill struct {
TradeID *int32 `json:"trade_id,omitempty" gorm:"column:trade_id"` // 订单ID TradeID *int32 `json:"trade_id,omitempty" gorm:"column:trade_id"` // 订单ID
ResourceID *int32 `json:"resource_id,omitempty" gorm:"column:resource_id"` // 套餐ID ResourceID *int32 `json:"resource_id,omitempty" gorm:"column:resource_id"` // 套餐ID
RefundID *int32 `json:"refund_id,omitempty" gorm:"column:refund_id"` // 退款ID RefundID *int32 `json:"refund_id,omitempty" gorm:"column:refund_id"` // 退款ID
CouponID *int32 `json:"coupon_id,omitempty" gorm:"column:coupon_id"` // 优惠券ID
BillNo string `json:"bill_no" gorm:"column:bill_no"` // 易读账单号 BillNo string `json:"bill_no" gorm:"column:bill_no"` // 易读账单号
Info *string `json:"info,omitempty" gorm:"column:info"` // 产品可读信息 Info *string `json:"info,omitempty" gorm:"column:info"` // 产品可读信息
Type BillType `json:"type" gorm:"column:type"` // 账单类型1-消费2-退款3-充值 Type BillType `json:"type" gorm:"column:type"` // 账单类型1-消费2-退款3-充值
Amount decimal.Decimal `json:"amount" gorm:"column:amount"` // 账单金额 Amount decimal.Decimal `json:"amount" gorm:"column:amount"` // 应付金额
Actual decimal.Decimal `json:"actual" gorm:"column:actual"` // 实付金额
User *User `json:"user,omitempty" gorm:"foreignKey:UserID"` User *User `json:"user,omitempty" gorm:"foreignKey:UserID"`
Trade *Trade `json:"trade,omitempty" gorm:"foreignKey:TradeID"` Trade *Trade `json:"trade,omitempty" gorm:"foreignKey:TradeID"`

View File

@@ -35,10 +35,12 @@ func newBill(db *gorm.DB, opts ...gen.DOOption) bill {
_bill.TradeID = field.NewInt32(tableName, "trade_id") _bill.TradeID = field.NewInt32(tableName, "trade_id")
_bill.ResourceID = field.NewInt32(tableName, "resource_id") _bill.ResourceID = field.NewInt32(tableName, "resource_id")
_bill.RefundID = field.NewInt32(tableName, "refund_id") _bill.RefundID = field.NewInt32(tableName, "refund_id")
_bill.CouponID = field.NewInt32(tableName, "coupon_id")
_bill.BillNo = field.NewString(tableName, "bill_no") _bill.BillNo = field.NewString(tableName, "bill_no")
_bill.Info = field.NewString(tableName, "info") _bill.Info = field.NewString(tableName, "info")
_bill.Type = field.NewInt(tableName, "type") _bill.Type = field.NewInt(tableName, "type")
_bill.Amount = field.NewField(tableName, "amount") _bill.Amount = field.NewField(tableName, "amount")
_bill.Actual = field.NewField(tableName, "actual")
_bill.User = billBelongsToUser{ _bill.User = billBelongsToUser{
db: db.Session(&gorm.Session{}), db: db.Session(&gorm.Session{}),
@@ -208,10 +210,12 @@ type bill struct {
TradeID field.Int32 TradeID field.Int32
ResourceID field.Int32 ResourceID field.Int32
RefundID field.Int32 RefundID field.Int32
CouponID field.Int32
BillNo field.String BillNo field.String
Info field.String Info field.String
Type field.Int Type field.Int
Amount field.Field Amount field.Field
Actual field.Field
User billBelongsToUser User billBelongsToUser
Trade billBelongsToTrade Trade billBelongsToTrade
@@ -243,10 +247,12 @@ func (b *bill) updateTableName(table string) *bill {
b.TradeID = field.NewInt32(table, "trade_id") b.TradeID = field.NewInt32(table, "trade_id")
b.ResourceID = field.NewInt32(table, "resource_id") b.ResourceID = field.NewInt32(table, "resource_id")
b.RefundID = field.NewInt32(table, "refund_id") b.RefundID = field.NewInt32(table, "refund_id")
b.CouponID = field.NewInt32(table, "coupon_id")
b.BillNo = field.NewString(table, "bill_no") b.BillNo = field.NewString(table, "bill_no")
b.Info = field.NewString(table, "info") b.Info = field.NewString(table, "info")
b.Type = field.NewInt(table, "type") b.Type = field.NewInt(table, "type")
b.Amount = field.NewField(table, "amount") b.Amount = field.NewField(table, "amount")
b.Actual = field.NewField(table, "actual")
b.fillFieldMap() b.fillFieldMap()
@@ -263,7 +269,7 @@ func (b *bill) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
} }
func (b *bill) fillFieldMap() { func (b *bill) fillFieldMap() {
b.fieldMap = make(map[string]field.Expr, 16) b.fieldMap = make(map[string]field.Expr, 18)
b.fieldMap["id"] = b.ID b.fieldMap["id"] = b.ID
b.fieldMap["created_at"] = b.CreatedAt b.fieldMap["created_at"] = b.CreatedAt
b.fieldMap["updated_at"] = b.UpdatedAt b.fieldMap["updated_at"] = b.UpdatedAt
@@ -272,10 +278,12 @@ func (b *bill) fillFieldMap() {
b.fieldMap["trade_id"] = b.TradeID b.fieldMap["trade_id"] = b.TradeID
b.fieldMap["resource_id"] = b.ResourceID b.fieldMap["resource_id"] = b.ResourceID
b.fieldMap["refund_id"] = b.RefundID b.fieldMap["refund_id"] = b.RefundID
b.fieldMap["coupon_id"] = b.CouponID
b.fieldMap["bill_no"] = b.BillNo b.fieldMap["bill_no"] = b.BillNo
b.fieldMap["info"] = b.Info b.fieldMap["info"] = b.Info
b.fieldMap["type"] = b.Type b.fieldMap["type"] = b.Type
b.fieldMap["amount"] = b.Amount b.fieldMap["amount"] = b.Amount
b.fieldMap["actual"] = b.Actual
} }

View File

@@ -4,6 +4,9 @@ import (
"platform/pkg/env" "platform/pkg/env"
auth2 "platform/web/auth" auth2 "platform/web/auth"
"platform/web/handlers" "platform/web/handlers"
"time"
q "platform/web/queries"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
@@ -23,6 +26,13 @@ func ApplyRouters(app *fiber.App) {
debug.Get("/sms/:phone", handlers.DebugGetSmsCode) debug.Get("/sms/:phone", handlers.DebugGetSmsCode)
debug.Get("/proxy/register", handlers.DebugRegisterProxyBaiYin) debug.Get("/proxy/register", handlers.DebugRegisterProxyBaiYin)
debug.Get("/iden/clear/:phone", handlers.DebugIdentifyClear) debug.Get("/iden/clear/:phone", handlers.DebugIdentifyClear)
debug.Get("/session/now", func(ctx *fiber.Ctx) error {
rs, err := q.Session.Where(q.Session.AccessTokenExpires.Gt(time.Now())).Find()
if err != nil {
return err
}
return ctx.JSON(rs)
})
} }
} }
@@ -136,6 +146,7 @@ func adminRouter(api fiber.Router) {
var resource = api.Group("/resource") var resource = api.Group("/resource")
resource.Post("/short/page", handlers.PageResourceShortByAdmin) resource.Post("/short/page", handlers.PageResourceShortByAdmin)
resource.Post("/long/page", handlers.PageResourceLongByAdmin) resource.Post("/long/page", handlers.PageResourceLongByAdmin)
resource.Post("/update", handlers.UpdateResourceByAdmin)
// batch 批次 // batch 批次
var usage = api.Group("batch") var usage = api.Group("batch")
@@ -159,6 +170,7 @@ func adminRouter(api fiber.Router) {
product.Post("/create", handlers.CreateProduct) product.Post("/create", handlers.CreateProduct)
product.Post("/update", handlers.UpdateProduct) product.Post("/update", handlers.UpdateProduct)
product.Post("/remove", handlers.DeleteProduct) product.Post("/remove", handlers.DeleteProduct)
product.Post("/sku/all", handlers.AllProductSkuByAdmin)
product.Post("/sku/page", handlers.PageProductSkuByAdmin) product.Post("/sku/page", handlers.PageProductSkuByAdmin)
product.Post("/sku/create", handlers.CreateProductSku) product.Post("/sku/create", handlers.CreateProductSku)
product.Post("/sku/update", handlers.UpdateProductSku) product.Post("/sku/update", handlers.UpdateProductSku)

View File

@@ -2,6 +2,7 @@ package services
import ( import (
m "platform/web/models" m "platform/web/models"
q "platform/web/queries"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
) )
@@ -10,34 +11,41 @@ var Bill = &billService{}
type billService struct{} type billService struct{}
func (s *billService) GenNo() string { func (s *billService) CreateForBalance(q *q.Query, uid, tradeId int32, detail *TradeDetail) error {
return ID.GenReadable("bil") return q.Bill.Create(&m.Bill{
}
func newForRecharge(uid int32, billNo string, info string, amount decimal.Decimal, trade *m.Trade) *m.Bill {
return &m.Bill{
UserID: uid, UserID: uid,
BillNo: billNo, BillNo: ID.GenReadable("bil"),
TradeID: &trade.ID, TradeID: &tradeId,
Type: m.BillTypeRecharge, Type: m.BillTypeRecharge,
Info: &info, Info: &detail.Subject,
Amount: amount, Amount: detail.Amount,
} Actual: detail.Actual,
})
} }
func newForConsume(uid int32, billNo string, info string, amount decimal.Decimal, resource *m.Resource, trade ...*m.Trade) *m.Bill { func (s *billService) CreateForResourceByTrade(q *q.Query, uid, tradeId, resourceId int32, detail *TradeDetail) error {
var bill = &m.Bill{ return q.Bill.Create(&m.Bill{
UserID: uid, UserID: uid,
BillNo: billNo, BillNo: ID.GenReadable("bil"),
ResourceID: &resource.ID, ResourceID: &resourceId,
TradeID: &tradeId,
CouponID: detail.CouponId,
Type: m.BillTypeConsume, Type: m.BillTypeConsume,
Info: &info, Info: &detail.Subject,
Amount: amount, Amount: detail.Amount,
} Actual: detail.Actual,
})
if len(trade) > 0 { }
bill.TradeID = &trade[0].ID
} func (s *billService) CreateForResourceByBalance(q *q.Query, uid, resourceId int32, couponId *int32, subject string, amount, actual decimal.Decimal) error {
return q.Bill.Create(&m.Bill{
return bill UserID: uid,
BillNo: ID.GenReadable("bil"),
ResourceID: &resourceId,
CouponID: couponId,
Type: m.BillTypeConsume,
Info: &subject,
Amount: amount,
Actual: actual,
})
} }

64
web/services/coupon.go Normal file
View File

@@ -0,0 +1,64 @@
package services
import (
"errors"
"fmt"
"platform/web/core"
m "platform/web/models"
q "platform/web/queries"
"time"
"github.com/shopspring/decimal"
"gorm.io/gorm"
)
var Coupon = &couponService{}
type couponService struct{}
func (s *couponService) GetCouponAvailableByCode(code string, amount decimal.Decimal, uid *int32) (*m.Coupon, error) {
// 获取优惠券
coupon, err := q.Coupon.Where(
q.Coupon.Code.Eq(code),
q.Coupon.Status.Eq(int(m.CouponStatusUnused)),
q.Coupon.
Where(q.Coupon.ExpireAt.Gt(time.Now())).
Or(q.Coupon.ExpireAt.IsNull()),
).Take()
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, core.NewBizErr("优惠券不存在或已失效")
}
if err != nil {
return nil, core.NewBizErr("获取优惠券数据失败", err)
}
// 检查最小使用额度
if amount.Cmp(coupon.MinAmount) < 0 {
return nil, core.NewBizErr(fmt.Sprintf("使用此优惠券的最小额度为 %s", coupon.MinAmount))
}
// 检查所属
if coupon.UserID != nil {
if uid == nil {
return nil, core.NewBizErr("检查优惠券所属用户失败")
}
if *coupon.UserID != *uid {
return nil, core.NewBizErr("优惠券不属于当前用户")
}
}
return coupon, nil
}
func (s *couponService) UseCoupon(q *q.Query, id int32) error {
_, err := q.Coupon.
Where(
q.Coupon.ID.Eq(id),
q.Coupon.Status.Eq(int(m.CouponStatusUnused)),
q.Coupon.ExpireAt.Gt(time.Now()),
).
UpdateSimple(
q.Coupon.Status.Value(int(m.CouponStatusUsed)),
)
return err
}

View File

@@ -15,6 +15,14 @@ var ProductSku = &productSkuService{}
type productSkuService struct{} type productSkuService struct{}
func (s *productSkuService) All(product_code string) (result []*m.ProductSku, err error) {
return q.ProductSku.
Joins(q.ProductSku.Product).
Where(q.Product.As("Product").Code.Eq(product_code)).
Select(q.ProductSku.ALL).
Find()
}
func (s *productSkuService) Page(req *core.PageReq, productId *int32) (result []*m.ProductSku, count int64, err error) { func (s *productSkuService) Page(req *core.PageReq, productId *int32) (result []*m.ProductSku, count int64, err error) {
do := make([]gen.Condition, 0) do := make([]gen.Condition, 0)
if productId != nil { if productId != nil {

View File

@@ -1,7 +1,6 @@
package services package services
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"platform/pkg/u" "platform/pkg/u"
@@ -11,6 +10,7 @@ import (
"time" "time"
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
"gorm.io/gen/field"
"gorm.io/gorm" "gorm.io/gorm"
) )
@@ -18,6 +18,7 @@ var Resource = &resourceService{}
type resourceService struct{} type resourceService struct{}
// CreateResourceByBalance 通过余额购买套餐
func (s *resourceService) CreateResourceByBalance(uid int32, now time.Time, data *CreateResourceData) error { func (s *resourceService) CreateResourceByBalance(uid int32, now time.Time, data *CreateResourceData) error {
// 找到用户 // 找到用户
@@ -29,16 +30,21 @@ func (s *resourceService) CreateResourceByBalance(uid int32, now time.Time, data
} }
// 获取 sku // 获取 sku
sku, err := s.GetSku(data) sku, err := s.GetSku(data.Code())
if err != nil { if err != nil {
return err return err
} }
// 检查余额 // 检查余额
_, amount, err := s.GetPrice(sku, data.Count(), &uid) coupon, _, amount, actual, err := s.GetPrice(sku, data.Count(), &uid, data.CouponCode)
if err != nil { if err != nil {
return err return err
} }
couponId := (*int32)(nil)
if coupon != nil {
couponId = &coupon.ID
}
newBalance := user.Balance.Sub(amount) newBalance := user.Balance.Sub(amount)
if newBalance.IsNegative() { if newBalance.IsNegative() {
return ErrBalanceNotEnough return ErrBalanceNotEnough
@@ -58,49 +64,30 @@ func (s *resourceService) CreateResourceByBalance(uid int32, now time.Time, data
} }
// 保存套餐 // 保存套餐
resource, err := createResource(q, uid, now, data) resource, err := s.Create(q, uid, now, data)
if err != nil { if err != nil {
return core.NewServErr("创建套餐失败", err) return core.NewServErr("创建套餐失败", err)
} }
// 生成账单 // 生成账单
err = q.Bill.Create(newForConsume(uid, Bill.GenNo(), sku.Name, amount, resource)) err = Bill.CreateForResourceByBalance(q, uid, resource.ID, couponId, sku.Name, amount, actual)
if err != nil { if err != nil {
return core.NewServErr("生成账单失败", err) return core.NewServErr("生成账单失败", err)
} }
// 核销优惠券
if coupon != nil {
err = Coupon.UseCoupon(q, coupon.ID)
if err != nil {
return core.NewServErr("核销优惠券失败", err)
}
}
return nil return nil
}) })
} }
func (s *resourceService) CreateResourceByTrade(uid int32, now time.Time, data *CreateResourceByTradeData, trade *m.Trade) error { // 检查交易 func (s *resourceService) Create(q *q.Query, uid int32, now time.Time, data *CreateResourceData) (*m.Resource, error) {
if trade == nil {
return core.NewBizErr("交易数据不能为空")
}
if trade.Status != m.TradeStatusSuccess {
return core.NewBizErr("交易状态不正确")
}
return q.Q.Transaction(func(q *q.Query) error {
// 保存套餐
resource, err := createResource(q, uid, now, data.Req)
if err != nil {
return core.NewServErr("创建套餐失败", err)
}
// 生成账单
err = q.Bill.Create(newForConsume(uid, Bill.GenNo(), data.GetSubject(), data.GetAmount(), resource, trade))
if err != nil {
return core.NewServErr("生成账单失败", err)
}
return nil
})
}
func createResource(q *q.Query, uid int32, now time.Time, data *CreateResourceData) (*m.Resource, error) {
// 套餐基本信息 // 套餐基本信息
var resource = m.Resource{ var resource = m.Resource{
UserID: uid, UserID: uid,
@@ -162,10 +149,35 @@ func createResource(q *q.Query, uid int32, now time.Time, data *CreateResourceDa
return &resource, nil return &resource, nil
} }
func (s *resourceService) GetSku(data *CreateResourceData) (*m.ProductSku, error) { func (s *resourceService) Update(data *UpdateResourceData) error {
if data.Active == nil {
return core.NewBizErr("更新套餐失败active 不能为空")
}
do := make([]field.AssignExpr, 0)
if data.Active != nil {
do = append(do, q.Resource.Active.Value(*data.Active))
}
_, err := q.Resource.
Where(q.Resource.ID.Eq(data.Id)).
UpdateSimple(do...)
if err != nil {
return core.NewServErr("更新套餐失败", err)
}
return nil
}
type UpdateResourceData struct {
core.IdReq
Active *bool `json:"active"`
}
func (s *resourceService) GetSku(code string) (*m.ProductSku, error) {
sku, err := q.ProductSku. sku, err := q.ProductSku.
Joins(q.ProductSku.Discount). Joins(q.ProductSku.Discount).
Where(q.ProductSku.Code.Eq(data.Code())). Where(q.ProductSku.Code.Eq(code)).
Take() Take()
if err != nil { if err != nil {
return nil, core.NewServErr("产品不可用", err) return nil, core.NewServErr("产品不可用", err)
@@ -178,43 +190,55 @@ func (s *resourceService) GetSku(data *CreateResourceData) (*m.ProductSku, error
return sku, nil return sku, nil
} }
func (s *resourceService) GetPrice(sku *m.ProductSku, count int32, uid *int32) (decimal.Decimal, decimal.Decimal, error) { func (s *resourceService) GetPrice(sku *m.ProductSku, count int32, uid *int32, couponCode *string) (*m.Coupon, decimal.Decimal, decimal.Decimal, decimal.Decimal, error) {
// 根据用户 id 查询特殊优惠 // 原价
var uSku *m.ProductSkuUser price := sku.Price
if uid != nil { amount := price.Mul(decimal.NewFromInt32(count))
// 折扣价
discount := sku.Discount.Decimal()
if uid != nil { // 用户特殊优惠
var err error var err error
uSku, err = q.ProductSkuUser. uSku, err := q.ProductSkuUser.
Joins(q.ProductSkuUser.Discount). Joins(q.ProductSkuUser.Discount).
Where( Where(
q.ProductSkuUser.UserID.Eq(*uid), q.ProductSkuUser.UserID.Eq(*uid),
q.ProductSkuUser.ProductSkuID.Eq(sku.ID)). q.ProductSkuUser.ProductSkuID.Eq(sku.ID)).
Take() Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return decimal.Zero, decimal.Zero, core.NewServErr("客户特殊价查询失败", err) return nil, decimal.Zero, decimal.Zero, decimal.Zero, core.NewServErr("客户特殊价查询失败", err)
} }
}
if uSku.Discount == nil { if uSku.Discount == nil {
return decimal.Decimal{}, decimal.Decimal{}, core.NewServErr("价格获取失败") return nil, decimal.Zero, decimal.Zero, decimal.Zero, core.NewServErr("价格获取失败")
}
uDiscount := uSku.Discount.Decimal()
if uDiscount.Cmp(discount) > 0 {
discount = uDiscount
}
}
discounted := amount.Mul(discount)
// 优惠价
coupon := (*m.Coupon)(nil)
couponApplied := discounted.Copy()
if couponCode != nil {
var err error
coupon, err = Coupon.GetCouponAvailableByCode(*couponCode, discounted, uid)
if err != nil {
return nil, decimal.Zero, decimal.Zero, decimal.Zero, err
}
couponApplied = discounted.Sub(coupon.Amount)
} }
// 返回计算价格 return coupon, amount, discounted, couponApplied, nil
price := sku.Price
discount := sku.Discount.Decimal()
if uSku != nil {
discount = uSku.Discount.Decimal()
}
before := price.Mul(decimal.NewFromInt32(count))
after := before.Mul(discount)
return before, after, nil
} }
type CreateResourceData struct { type CreateResourceData struct {
Type m.ResourceType `json:"type" validate:"required"` Type m.ResourceType `json:"type" validate:"required"`
Short *CreateShortResourceData `json:"short,omitempty"` Short *CreateShortResourceData `json:"short,omitempty"`
Long *CreateLongResourceData `json:"long,omitempty"` Long *CreateLongResourceData `json:"long,omitempty"`
CouponCode *string `json:"coupon,omitempty"`
} }
type CreateShortResourceData struct { type CreateShortResourceData struct {
@@ -267,71 +291,22 @@ func (c *CreateResourceData) Code() string {
} }
} }
func (c *CreateResourceData) Serialize() (string, error) { func (c *CreateResourceData) TradeDetail() (*TradeDetail, error) {
bytes, err := json.Marshal(c) sku, err := Resource.GetSku(c.Code())
return string(bytes), err
}
func (c *CreateResourceData) Deserialize(str string) error {
return json.Unmarshal([]byte(str), c)
}
// 交易后创建套餐
type ResourceOnTradeComplete struct{}
func (r ResourceOnTradeComplete) Check(t m.TradeType) (ProductInfo, bool) {
if t == m.TradeTypePurchase {
return &CreateResourceByTradeData{}, true
}
return nil, false
}
func (r ResourceOnTradeComplete) OnTradeComplete(info ProductInfo, trade *m.Trade) error {
return Resource.CreateResourceByTrade(trade.UserID, time.Time(*trade.CompletedAt), info.(*CreateResourceByTradeData), trade)
}
type CreateResourceByTradeData struct {
Subject string `json:"subject"`
Amount decimal.Decimal `json:"amount"`
Req *CreateResourceData `json:"data"`
}
func (e CreateResourceByTradeData) GetType() m.TradeType {
return m.TradeTypePurchase
}
func (e CreateResourceByTradeData) GetSubject() string {
return e.Subject
}
func (e CreateResourceByTradeData) GetAmount() decimal.Decimal {
return e.Amount
}
func (e CreateResourceByTradeData) Serialize() (string, error) {
bytes, err := json.Marshal(e)
return string(bytes), err
}
func (e *CreateResourceByTradeData) Deserialize(str string) error {
return json.Unmarshal([]byte(str), e)
}
func NewCreateResourceByTradeData(req *CreateResourceData) (*CreateResourceByTradeData, error) {
sku, err := Resource.GetSku(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, amount, err := Resource.GetPrice(sku, req.Count(), nil) coupon, _, amount, actual, err := Resource.GetPrice(sku, c.Count(), nil, c.CouponCode)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &CreateResourceByTradeData{ return &TradeDetail{
Subject: sku.Name, m.TradeTypePurchase,
Amount: amount, sku.Name,
Req: req, amount, actual,
&coupon.ID, c,
}, nil }, nil
} }

View File

@@ -2,6 +2,7 @@ package services
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -23,7 +24,6 @@ import (
"github.com/smartwalle/alipay/v3" "github.com/smartwalle/alipay/v3"
"github.com/wechatpay-apiv3/wechatpay-go/services/partnerpayments/h5" "github.com/wechatpay-apiv3/wechatpay-go/services/partnerpayments/h5"
"github.com/wechatpay-apiv3/wechatpay-go/services/payments/native" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native"
"gorm.io/gorm"
) )
var Trade = &tradeService{} var Trade = &tradeService{}
@@ -32,72 +32,17 @@ type tradeService struct {
} }
// 创建交易 // 创建交易
func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTradeData, product ProductInfo) (*CreateTradeResult, error) { func (s *tradeService) Create(uid int32, tradeData *CreateTradeData, productData *CreateResourceData) (*CreateTradeResult, error) {
platform := payment.Platform detail, err := productData.TradeDetail()
method := payment.Method
tType := product.GetType()
expire := time.Now().Add(30 * time.Minute)
subject := product.GetSubject()
amount := product.GetAmount()
// 实际支付金额,只在创建真实订单时使用
amountReal := amount
if env.RunMode == env.RunModeDev {
amountReal = decimal.NewFromFloat(0.01)
}
// 附加优惠券
if payment.CouponCode != nil {
coupon, err := q.Coupon.
Where(
q.Coupon.Code.Eq(*payment.CouponCode),
q.Coupon.Status.Eq(int(m.CouponStatusUnused)),
).
Take()
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { return nil, core.NewServErr("获取产品支付信息失败", err)
return nil, errors.New("优惠券不存在或已失效")
}
return nil, err
} }
expireAt := time.Time(u.Z(coupon.ExpireAt)) now := time.Now()
if !expireAt.IsZero() && expireAt.Before(now) { platform := tradeData.Platform
_, err = q.Coupon. method := tradeData.Method
Where(q.Coupon.ID.Eq(coupon.ID)). expireIn := time.Duration(env.TradeExpire) * time.Second
Update(q.Coupon.Status, m.CouponStatusExpired) expireAt := now.Add(expireIn)
if err != nil {
return nil, err
}
return nil, errors.New("优惠券已过期")
}
if amount.Cmp(coupon.MinAmount) < 0 {
return nil, errors.New("订单金额未达到使用优惠券的条件")
}
if coupon.UserID != nil {
switch *coupon.UserID {
// 指定用户的优惠券
case uid:
amount = amount.Sub(coupon.Amount)
if expireAt.IsZero() {
_, err = q.Coupon.
Where(q.Coupon.ID.Eq(coupon.ID)).
Update(q.Coupon.Status, int(m.CouponStatusUsed))
if err != nil {
return nil, err
}
}
// 该优惠券不属于当前用户
default:
return nil, errors.New("优惠券不属于当前用户")
}
} else {
// 公开优惠券
amount = amount.Sub(coupon.Amount)
}
}
// 生成订单号 // 生成订单号
tradeNo, err := ID.GenSerial() tradeNo, err := ID.GenSerial()
@@ -105,6 +50,12 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad
return nil, core.NewServErr("生成订单号失败", err) return nil, core.NewServErr("生成订单号失败", err)
} }
// 实际支付金额,只在创建真实订单时使用
amountReal := detail.Actual
if env.RunMode == env.RunModeDev {
amountReal = decimal.NewFromFloat(0.01)
}
// 提交支付订单 // 提交支付订单
var paymentUrl string var paymentUrl string
switch { switch {
@@ -117,9 +68,9 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad
Trade: alipay.Trade{ Trade: alipay.Trade{
ProductCode: "FAST_INSTANT_TRADE_PAY", ProductCode: "FAST_INSTANT_TRADE_PAY",
OutTradeNo: tradeNo, OutTradeNo: tradeNo,
Subject: subject, Subject: detail.Subject,
TotalAmount: amountReal.StringFixed(2), TotalAmount: amountReal.StringFixed(2),
TimeExpire: expire.Format("2006-01-02 15:04:05"), TimeExpire: expireAt.Format("2006-01-02 15:04:05"),
}, },
}) })
if err != nil { if err != nil {
@@ -133,8 +84,8 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad
Appid: &env.WechatPayAppId, Appid: &env.WechatPayAppId,
Mchid: &env.WechatPayMchId, Mchid: &env.WechatPayMchId,
OutTradeNo: &tradeNo, OutTradeNo: &tradeNo,
Description: &subject, Description: &detail.Subject,
TimeExpire: &expire, TimeExpire: &expireAt,
NotifyUrl: &env.WechatPayCallbackUrl, NotifyUrl: &env.WechatPayCallbackUrl,
Amount: &native.Amount{ Amount: &native.Amount{
Total: u.P(amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart()), Total: u.P(amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart()),
@@ -151,8 +102,8 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad
SpAppid: &env.WechatPayAppId, SpAppid: &env.WechatPayAppId,
SpMchid: &env.WechatPayMchId, SpMchid: &env.WechatPayMchId,
OutTradeNo: &tradeNo, OutTradeNo: &tradeNo,
Description: &subject, Description: &detail.Subject,
TimeExpire: &expire, TimeExpire: &expireAt,
NotifyUrl: &env.WechatPayCallbackUrl, NotifyUrl: &env.WechatPayCallbackUrl,
Amount: &h5.Amount{ Amount: &h5.Amount{
Total: u.P(amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart()), Total: u.P(amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart()),
@@ -174,18 +125,17 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad
payType = g.SftAlipay payType = g.SftAlipay
case m.TradeMethodSftWechat: case m.TradeMethodSftWechat:
payType = g.SftWeChat payType = g.SftWeChat
default:
panic("unhandled default case")
} }
resp, err := g.SFTPay.PaymentScanPay(&g.PaymentScanPayReq{ resp, err := g.SFTPay.PaymentScanPay(&g.PaymentScanPayReq{
MchOrderNo: tradeNo, MchOrderNo: tradeNo,
Subject: subject, Subject: detail.Subject,
Body: subject, Body: detail.Subject,
Amount: amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart(), Amount: amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart(),
PayType: payType, PayType: payType,
Currency: "cny", Currency: "cny",
ClientIp: "123.52.74.23", ClientIp: "123.52.74.23",
OrderTimeout: u.P(expire.Format("2006-01-02 15:04:05")), OrderTimeout: u.P(expireAt.Format("2006-01-02 15:04:05")),
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -196,24 +146,24 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad
case case
method == m.TradeMethodSftAlipay && platform == m.TradePlatformMobile, method == m.TradeMethodSftAlipay && platform == m.TradePlatformMobile,
method == m.TradeMethodSftWechat && platform == m.TradePlatformMobile: method == m.TradeMethodSftWechat && platform == m.TradePlatformMobile:
var payType g.SftPayType var payType g.SftPayType
switch method { switch method {
case m.TradeMethodSftAlipay: case m.TradeMethodSftAlipay:
payType = g.SftAlipay payType = g.SftAlipay
case m.TradeMethodSftWechat: case m.TradeMethodSftWechat:
payType = g.SftWeChat payType = g.SftWeChat
default:
panic("unhandled default case")
} }
resp, err := g.SFTPay.PaymentH5Pay(&g.PaymentH5PayReq{ resp, err := g.SFTPay.PaymentH5Pay(&g.PaymentH5PayReq{
MchOrderNo: tradeNo, MchOrderNo: tradeNo,
Subject: subject, Subject: detail.Subject,
Body: subject, Body: detail.Subject,
Amount: amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart(), Amount: amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart(),
PayType: payType, PayType: payType,
Currency: "cny", Currency: "cny",
ClientIp: "123.52.74.23", ClientIp: "123.52.74.23",
OrderTimeout: u.P(expire.Format("2006-01-02 15:04:05")), OrderTimeout: u.P(expireAt.Format("2006-01-02 15:04:05")),
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -230,9 +180,9 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad
err = q.Trade.Create(&m.Trade{ err = q.Trade.Create(&m.Trade{
UserID: uid, UserID: uid,
InnerNo: tradeNo, InnerNo: tradeNo,
Type: tType, Type: detail.Type,
Subject: subject, Subject: detail.Subject,
Amount: amount, Amount: detail.Actual,
Method: method, Method: method,
Platform: platform, Platform: platform,
PaymentURL: &paymentUrl, PaymentURL: &paymentUrl,
@@ -242,7 +192,7 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad
} }
// 缓存产品数据 // 缓存产品数据
serialized, err := product.Serialize() serialized, err := json.Marshal(detail)
if err != nil { if err != nil {
return nil, core.NewServErr("序列化产品信息失败", err) return nil, core.NewServErr("序列化产品信息失败", err)
} }
@@ -251,34 +201,29 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad
context.Background(), context.Background(),
tradeProductKey(tradeNo), tradeProductKey(tradeNo),
serialized, serialized,
time.Duration(env.TradeExpire+10)*time.Second, expireIn,
).Err() ).Err()
if err != nil { if err != nil {
return nil, core.NewServErr("保存购买信息失败", err) return nil, core.NewServErr("保存购买信息失败", err)
} }
// 提交异步关闭事件 // 提交异步关闭事件
closeAt := now.Add(time.Duration(env.TradeExpire) * time.Second) _, err = g.Asynq.Enqueue(e.NewCloseTradeTask(uid, tradeNo, method), asynq.ProcessAt(expireAt))
_, err = g.Asynq.Enqueue(e.NewCancelTrade(e.CompleteTradeData{
TradeNo: tradeNo,
Method: method,
}), asynq.ProcessAt(closeAt))
if err != nil { if err != nil {
return nil, core.NewServErr("提交异步关闭事件失败", err) return nil, core.NewServErr("提交异步关闭事件失败", err)
} }
return &CreateTradeResult{ return &CreateTradeResult{
PaymentUrl: paymentUrl, PayUrl: paymentUrl,
TradeNo: tradeNo, TradeNo: tradeNo,
}, nil }, nil
} }
// 完成交易 // 完成交易
func (s *tradeService) CompleteTrade(data *ModifyTradeData) error { func (s *tradeService) CompleteTrade(user *m.User, ref *TradeRef) error {
return g.Redsync.WithLock(tradeLockKey(data.TradeNo), func() error {
// 检查订单状态 // 检查订单状态
result, err := s.CheckTrade(data) result, err := s.CheckTrade(ref)
if err != nil { if err != nil {
return core.NewServErr("检查订单状态失败", err) return core.NewServErr("检查订单状态失败", err)
} }
@@ -292,55 +237,18 @@ func (s *tradeService) CompleteTrade(data *ModifyTradeData) error {
} }
// 更新交易状态 // 更新交易状态
trade, err := completeTrade(&OnTradeCompletedData{ err = s.OnCompleteTrade(user, ref.TradeNo, result.TransId, &result.Success)
data.TradeNo,
result.TransId,
result.Success,
})
if err != nil { if err != nil {
return core.NewServErr("处理交易失败", err) return core.NewServErr("处理交易失败", err)
} }
// 处理交易完成事件
err = afterTradeComplete(trade)
if err != nil {
return core.NewServErr("处理交易完成事件失败", err)
}
return nil return nil
})
} }
func (s *tradeService) OnTradeCompleted(data *OnTradeCompletedData) error { func (s *tradeService) OnCompleteTrade(user *m.User, interNo string, outerNo string, result *TradeSuccessResult) error {
return g.Redsync.WithLock(tradeLockKey(data.TradeNo), func() error {
// 更新交易状态
trade, err := completeTrade(data)
if err != nil {
return core.NewServErr("处理交易失败", err)
}
// 处理交易完成事件
err = afterTradeComplete(trade)
if err != nil {
return core.NewServErr("处理交易完成事件失败", err)
}
return nil
})
}
func completeTrade(data *OnTradeCompletedData) (*m.Trade, error) {
var trade = new(m.Trade)
var err = q.Q.Transaction(func(tx *q.Query) error {
var tradeNo = data.TradeNo
var transId = data.TransId
var payment = data.Payment
var acquirer = data.Acquirer
var paidAt = data.Time
// 获取交易信息 // 获取交易信息
var err error trade, err := q.Trade.
trade, err = q.Trade. Where(q.Trade.InnerNo.Eq(interNo)).
Where(q.Trade.InnerNo.Eq(tradeNo)).
Take() Take()
if err != nil { if err != nil {
return core.NewBizErr("获取交易信息失败", err) return core.NewBizErr("获取交易信息失败", err)
@@ -355,75 +263,92 @@ func completeTrade(data *OnTradeCompletedData) (*m.Trade, error) {
case m.TradeStatusPending: case m.TradeStatusPending:
} }
// 更新交易信息 // 恢复购买信息
trade.Status = m.TradeStatusSuccess detailStr, err := g.Redis.Get(context.Background(), tradeProductKey(interNo)).Result()
trade.OuterNo = &transId if err != nil {
trade.Payment = payment return core.NewServErr("恢复购买信息失败", err)
trade.Acquirer = u.P(acquirer)
trade.CompletedAt = u.P(paidAt)
rs, err := q.Trade.
Where(q.Trade.InnerNo.Eq(tradeNo), q.Trade.Status.Eq(int(m.TradeStatusPending))).
Updates(trade)
if rs.RowsAffected == 0 {
return core.NewBizErr("交易状态已发生变化")
} }
var detail TradeDetail
if err := json.Unmarshal([]byte(detailStr), &detail); err != nil {
return core.NewServErr("解析购买信息失败", err)
}
err = q.Q.Transaction(func(q *q.Query) error {
// 更新交易信息
_, err := q.Trade.
Where(
q.Trade.InnerNo.Eq(interNo),
q.Trade.Status.Eq(int(m.TradeStatusPending)),
).
UpdateSimple(
q.Trade.Status.Value(int(m.TradeStatusSuccess)),
q.Trade.OuterNo.Value(outerNo),
q.Trade.Payment.Value(result.Actual),
q.Trade.Acquirer.Value(int(result.Acquirer)),
q.Trade.CompletedAt.Value(result.Time),
)
if err != nil { if err != nil {
return core.NewServErr("更新交易信息失败", err) return core.NewServErr("更新交易信息失败", err)
} }
switch trade.Type {
case m.TradeTypeRecharge:
// 更新用户余额
if err := User.UpdateBalance(q, user, detail.Actual); err != nil {
return err
}
// 生成账单
err = Bill.CreateForBalance(q, user.ID, trade.ID, &detail)
if err != nil {
return core.NewServErr("生成账单失败", err)
}
case m.TradeTypePurchase:
data, ok := detail.Product.(*CreateResourceData)
if !ok {
return core.NewServErr("购买信息解析失败", nil)
}
// 保存套餐
resource, err := Resource.Create(q, user.ID, result.Time, data)
if err != nil {
return core.NewServErr("创建套餐失败", err)
}
// 生成账单
err = Bill.CreateForResourceByTrade(q, user.ID, resource.ID, trade.ID, &detail)
if err != nil {
return core.NewServErr("生成账单失败", err)
}
// 核销优惠券
if detail.CouponId != nil {
err = Coupon.UseCoupon(q, *detail.CouponId)
if err != nil {
return core.NewServErr("核销优惠券失败", err)
}
}
}
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return err
} else {
return trade, err
}
}
func afterTradeComplete(trade *m.Trade) error {
// 恢复购买信息
productData, err := g.Redis.Get(context.Background(), tradeProductKey(trade.InnerNo)).Result()
if err != nil {
return core.NewServErr("恢复购买信息失败", err)
}
// 执行资源创建
var ComplementEvents = []CompleteEvent{
ResourceOnTradeComplete{},
UserOnTradeComplete{},
}
for _, event := range ComplementEvents {
info, ok := event.Check(trade.Type)
if !ok {
continue
}
err = info.Deserialize(productData)
if err != nil {
return core.NewServErr("反序列化购买信息失败", err)
}
err = event.OnTradeComplete(info, trade)
if err != nil {
return core.NewServErr("处理交易完成事件失败", err)
}
} }
return nil return nil
} }
// 取消交易 // 取消交易
func (s *tradeService) CancelTrade(data *ModifyTradeData, now time.Time) error { func (s *tradeService) CancelTrade(ref *TradeRef) error {
tradeNo := data.TradeNo now := time.Now()
method := data.Method
return g.Redsync.WithLock(tradeLockKey(tradeNo), func() error {
switch method {
switch ref.Method {
case m.TradeMethodAlipay: case m.TradeMethodAlipay:
resp, err := g.Alipay.TradeCancel(context.Background(), alipay.TradeCancel{ resp, err := g.Alipay.TradeCancel(context.Background(), alipay.TradeCancel{
OutTradeNo: tradeNo, OutTradeNo: ref.TradeNo,
}) })
if err != nil { if err != nil {
return core.NewServErr("上游取消交易失败", err) return core.NewServErr("上游取消交易失败", err)
@@ -436,7 +361,7 @@ func (s *tradeService) CancelTrade(data *ModifyTradeData, now time.Time) error {
case m.TradeMethodWechat: case m.TradeMethodWechat:
resp, err := g.WechatPay.Native.CloseOrder(context.Background(), native.CloseOrderRequest{ resp, err := g.WechatPay.Native.CloseOrder(context.Background(), native.CloseOrderRequest{
Mchid: &env.WechatPayMchId, Mchid: &env.WechatPayMchId,
OutTradeNo: &tradeNo, OutTradeNo: &ref.TradeNo,
}) })
if err != nil { if err != nil {
return core.NewServErr("上游取消交易失败", err) return core.NewServErr("上游取消交易失败", err)
@@ -453,7 +378,7 @@ func (s *tradeService) CancelTrade(data *ModifyTradeData, now time.Time) error {
case m.TradeMethodSft, m.TradeMethodSftAlipay, m.TradeMethodSftWechat: case m.TradeMethodSft, m.TradeMethodSftAlipay, m.TradeMethodSftWechat:
_, err := g.SFTPay.OrderClose(&g.OrderCloseReq{ _, err := g.SFTPay.OrderClose(&g.OrderCloseReq{
MchOrderNo: &tradeNo, MchOrderNo: &ref.TradeNo,
}) })
if err != nil { if err != nil {
slog.Debug(fmt.Sprintf("订单无需关闭: %s", err.Error())) slog.Debug(fmt.Sprintf("订单无需关闭: %s", err.Error()))
@@ -464,47 +389,19 @@ func (s *tradeService) CancelTrade(data *ModifyTradeData, now time.Time) error {
return ErrTransactionNotSupported return ErrTransactionNotSupported
} }
err := cancelTrade(tradeNo, now) err := s.OnCancelTrade(ref.TradeNo, now)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
})
} }
func (s *tradeService) OnTradeCanceled(tradeNo string, now time.Time) error { func (s *tradeService) OnCancelTrade(tradeNo string, now time.Time) error {
err := g.Redsync.WithLock(tradeLockKey(tradeNo), func() error { _, err := q.Trade.
return cancelTrade(tradeNo, now) Where(
}) q.Trade.InnerNo.Eq(tradeNo),
if err != nil { q.Trade.Status.Eq(int(m.TradeStatusPending)),
return core.NewServErr("处理交易取消失败", err) ).
}
return nil
}
func cancelTrade(tradeNo string, now time.Time) error {
return q.Q.Transaction(func(q *q.Query) error {
// 获取交易信息
var status m.TradeStatus
err := q.Trade.
Where(q.Trade.InnerNo.Eq(tradeNo)).
Select(q.Trade.Status).
Scan(&status)
if err != nil {
return core.NewBizErr("获取交易信息失败", err)
}
// 检查交易状态
switch status {
case m.TradeStatusCanceled:
return core.NewBizErr("交易已取消")
case m.TradeStatusSuccess:
return core.NewBizErr("交易已完成")
case m.TradeStatusPending:
}
// 更新交易状态
_, err = q.Trade.
Where(q.Trade.InnerNo.Eq(tradeNo)).
UpdateSimple( UpdateSimple(
q.Trade.Status.Value(int(m.TradeStatusCanceled)), q.Trade.Status.Value(int(m.TradeStatusCanceled)),
q.Trade.CanceledAt.Value(now), q.Trade.CanceledAt.Value(now),
@@ -512,25 +409,25 @@ func cancelTrade(tradeNo string, now time.Time) error {
if err != nil { if err != nil {
return core.NewServErr("更新交易状态失败", err) return core.NewServErr("更新交易状态失败", err)
} }
return nil return nil
})
} }
// 交易退款 // 交易退款
func (s *tradeService) RefundTrade(data *ModifyTradeData) error { func (s *tradeService) RefundTrade(ref *TradeRef) error {
panic("todo") panic("todo")
} }
func (s *tradeService) OnTradeRefunded(q *q.Query, tradeNo string, now time.Time) error { func (s *tradeService) OnRefundTrade(q *q.Query, tradeNo string, now time.Time) error {
panic("todo") panic("todo")
} }
// 检查交易状态 // 检查交易状态
func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, error) { func (s *tradeService) CheckTrade(ref *TradeRef) (*CheckTradeResult, error) {
var tradeNo = data.TradeNo var tradeNo = ref.TradeNo
var method = data.Method var method = ref.Method
// 检查交易号是否存在 // 检查交易号是否存在
var result = new(CheckTradeResult) var result CheckTradeResult
switch method { switch method {
// 支付宝 // 支付宝
@@ -560,9 +457,8 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err
case alipay.TradeStatusSuccess, alipay.TradeStatusFinished: case alipay.TradeStatusSuccess, alipay.TradeStatusFinished:
result.Status = m.TradeStatusSuccess result.Status = m.TradeStatusSuccess
result.Success = &TradeSuccessResult{}
result.Success.Acquirer = m.TradeAcquirerAlipay result.Success.Acquirer = m.TradeAcquirerAlipay
result.Success.Payment, err = decimal.NewFromString(resp.TotalAmount) result.Success.Actual, err = decimal.NewFromString(resp.ReceiptAmount)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -606,9 +502,8 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err
case "SUCCESS", "REFUND": case "SUCCESS", "REFUND":
result.Status = m.TradeStatusSuccess result.Status = m.TradeStatusSuccess
result.Success = &TradeSuccessResult{}
result.Success.Acquirer = m.TradeAcquirerWechat result.Success.Acquirer = m.TradeAcquirerWechat
result.Success.Payment = decimal.NewFromInt(*resp.Amount.PayerTotal).Div(decimal.NewFromInt(100)) result.Success.Actual = decimal.NewFromInt(*resp.Amount.PayerTotal).Div(decimal.NewFromInt(100))
result.Success.Time, err = time.Parse(time.RFC3339, *resp.SuccessTime) result.Success.Time, err = time.Parse(time.RFC3339, *resp.SuccessTime)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -626,12 +521,12 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err
return nil, err return nil, err
} }
// 填充返回值
if resp.PayOrderId == nil { if resp.PayOrderId == nil {
return nil, errors.New("商福通交易号不存在") return nil, errors.New("商福通交易号不存在")
} }
// 填充返回值
result.TransId = *resp.PayOrderId result.TransId = *resp.PayOrderId
switch resp.State { switch resp.State {
case g.SftInit, g.SftTradeAwait, g.SftTradeFail: case g.SftInit, g.SftTradeAwait, g.SftTradeFail:
@@ -642,7 +537,6 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err
case g.SftTradeSuccess, g.SftTradeRefund, g.SftRefundIng: case g.SftTradeSuccess, g.SftTradeRefund, g.SftRefundIng:
result.Status = m.TradeStatusSuccess result.Status = m.TradeStatusSuccess
result.Success = &TradeSuccessResult{}
switch resp.PayType { switch resp.PayType {
case "WECHAT": case "WECHAT":
result.Success.Acquirer = m.TradeAcquirerWechat result.Success.Acquirer = m.TradeAcquirerWechat
@@ -651,7 +545,7 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err
case "UNIONPAY": case "UNIONPAY":
result.Success.Acquirer = m.TradeAcquirerUnionPay result.Success.Acquirer = m.TradeAcquirerUnionPay
} }
result.Success.Payment = decimal.NewFromInt(resp.Amount).Div(decimal.NewFromInt(100)) result.Success.Actual = decimal.NewFromInt(resp.Amount).Div(decimal.NewFromInt(100))
result.Success.Time, err = time.Parse("2006-01-02 15:04:05", *resp.PayTime) result.Success.Time, err = time.Parse("2006-01-02 15:04:05", *resp.PayTime)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -663,7 +557,7 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err
return nil, ErrTransactionNotSupported return nil, ErrTransactionNotSupported
} }
return result, nil return &result, nil
} }
func tradeProductKey(no string) string { func tradeProductKey(no string) string {
@@ -677,15 +571,14 @@ func tradeLockKey(no string) string {
type CreateTradeData struct { type CreateTradeData struct {
Platform m.TradePlatform `json:"platform" validate:"required"` Platform m.TradePlatform `json:"platform" validate:"required"`
Method m.TradeMethod `json:"method" validate:"required"` Method m.TradeMethod `json:"method" validate:"required"`
CouponCode *string `json:"coupon_code"`
} }
type CreateTradeResult struct { type CreateTradeResult struct {
TradeNo string PayUrl string `json:"pay_url"`
PaymentUrl string TradeNo string `json:"trade_no"`
} }
type ModifyTradeData struct { type TradeRef struct {
TradeNo string `json:"trade_no" query:"trade_no" validate:"required"` TradeNo string `json:"trade_no" query:"trade_no" validate:"required"`
Method m.TradeMethod `json:"method" validate:"required"` Method m.TradeMethod `json:"method" validate:"required"`
} }
@@ -693,12 +586,12 @@ type ModifyTradeData struct {
type CheckTradeResult struct { type CheckTradeResult struct {
TransId string TransId string
Status m.TradeStatus Status m.TradeStatus
Success *TradeSuccessResult Success TradeSuccessResult
} }
type TradeSuccessResult struct { type TradeSuccessResult struct {
Acquirer m.TradeAcquirer Acquirer m.TradeAcquirer
Payment decimal.Decimal Actual decimal.Decimal
Time time.Time Time time.Time
} }
@@ -709,11 +602,16 @@ type OnTradeCompletedData struct {
} }
type ProductInfo interface { type ProductInfo interface {
GetType() m.TradeType TradeDetail() (*TradeDetail, error)
GetSubject() string }
GetAmount() decimal.Decimal
Serialize() (string, error) type TradeDetail struct {
Deserialize(str string) error Type m.TradeType `json:"type"`
Subject string `json:"subject"`
Amount decimal.Decimal `json:"amount"`
Actual decimal.Decimal `json:"actual"`
CouponId *int32 `json:"coupon_id,omitempty"`
Product ProductInfo `json:"product"`
} }
type CompleteEvent interface { type CompleteEvent interface {

View File

@@ -1,10 +1,8 @@
package services package services
import ( import (
"encoding/json"
"fmt" "fmt"
"platform/web/core" "platform/web/core"
g "platform/web/globals"
m "platform/web/models" m "platform/web/models"
q "platform/web/queries" q "platform/web/queries"
@@ -15,48 +13,29 @@ var User = &userService{}
type userService struct{} type userService struct{}
func (s *userService) UpdateBalanceByTrade(uid int32, info *RechargeProductInfo, trade *m.Trade) (err error) { func (s *userService) Get(q *q.Query, uid int32) (*m.User, error) {
err = g.Redsync.WithLock(userBalanceKey(uid), func() error {
return q.Q.Transaction(func(q *q.Query) error {
err = updateBalance(q, uid, info)
if err != nil {
return err
}
// 生成账单
subject := info.GetSubject()
amount := info.GetAmount()
err = q.Bill.Create(newForRecharge(uid, Bill.GenNo(), subject, amount, trade))
if err != nil {
return core.NewServErr("生成账单失败", err)
}
return nil
})
})
if err != nil {
return core.NewServErr("更新用户余额失败")
}
return nil
}
func updateBalance(q *q.Query, uid int32, info *RechargeProductInfo) error {
user, err := q.User. user, err := q.User.
Where(q.User.ID.Eq(uid)).Take() Where(q.User.ID.Eq(uid)).Take()
if err != nil { if err != nil {
return core.NewServErr("查询用户失败", err) return nil, core.NewServErr("查询用户失败", err)
} }
return user, nil
}
amount := info.GetAmount() func (s *userService) UpdateBalance(q *q.Query, user *m.User, amount decimal.Decimal) error {
balance := user.Balance.Add(amount) balance := user.Balance.Add(amount)
if balance.IsNegative() { if balance.IsNegative() {
return core.NewServErr("用户余额不足") return core.NewServErr("用户余额不足")
} }
_, err = q.User. _, err := q.User.
Where(q.User.ID.Eq(user.ID)). Where(
UpdateSimple(q.User.Balance.Value(balance)) q.User.ID.Eq(user.ID),
q.User.Balance.Eq(user.Balance),
).
UpdateSimple(
q.User.Balance.Value(balance),
)
if err != nil { if err != nil {
return core.NewServErr("更新用户余额失败", err) return core.NewServErr("更新用户余额失败", err)
} }
@@ -68,40 +47,16 @@ func userBalanceKey(uid int32) string {
return fmt.Sprintf("user:%d:balance", uid) return fmt.Sprintf("user:%d:balance", uid)
} }
type RechargeProductInfo struct { type UpdateBalanceData struct {
Amount int `json:"amount"` Amount int `json:"amount"`
} }
func (r *RechargeProductInfo) GetType() m.TradeType { func (c *UpdateBalanceData) TradeDetail() (*TradeDetail, error) {
return m.TradeTypeRecharge amount := decimal.NewFromInt(int64(c.Amount)).Div(decimal.NewFromInt(100))
} return &TradeDetail{
m.TradeTypeRecharge,
func (r *RechargeProductInfo) GetSubject() string { fmt.Sprintf("账户充值 - %s元", amount.StringFixed(2)),
return fmt.Sprintf("账户充值 - %s元", r.GetAmount().StringFixed(2)) amount, amount,
} nil, c,
}, nil
func (r *RechargeProductInfo) GetAmount() decimal.Decimal {
return decimal.NewFromInt(int64(r.Amount)).Div(decimal.NewFromInt(100))
}
func (r *RechargeProductInfo) Serialize() (string, error) {
bytes, err := json.Marshal(r)
return string(bytes), err
}
func (r *RechargeProductInfo) Deserialize(str string) error {
return json.Unmarshal([]byte(str), r)
}
type UserOnTradeComplete struct{}
func (u UserOnTradeComplete) Check(t m.TradeType) (ProductInfo, bool) {
if t == m.TradeTypeRecharge {
return &RechargeProductInfo{}, true
}
return nil, false
}
func (u UserOnTradeComplete) OnTradeComplete(info ProductInfo, trade *m.Trade) error {
return User.UpdateBalanceByTrade(trade.UserID, info.(*RechargeProductInfo), trade)
} }

View File

@@ -6,29 +6,34 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"platform/web/events" "platform/web/events"
q "platform/web/queries"
s "platform/web/services" s "platform/web/services"
"time"
"github.com/hibiken/asynq" "github.com/hibiken/asynq"
) )
func HandleCompleteTrade(_ context.Context, task *asynq.Task) (err error) { func HandleCompleteTrade(_ context.Context, task *asynq.Task) error {
event := new(events.CompleteTradeData) var event events.CloseTradeData
err = json.Unmarshal(task.Payload(), event) if err := json.Unmarshal(task.Payload(), &event); err != nil {
if err != nil {
return fmt.Errorf("解析任务参数失败: %w", err) return fmt.Errorf("解析任务参数失败: %w", err)
} }
data := &s.ModifyTradeData{ data := s.TradeRef{
TradeNo: event.TradeNo, TradeNo: event.TradeNo,
Method: event.Method, Method: event.Method,
} }
err = s.Trade.CompleteTrade(data) // 尝试完成交易
user, err := s.User.Get(q.Q, event.UserId)
if err != nil { if err != nil {
return fmt.Errorf("获取用户失败: %w", err)
}
if err := s.Trade.CompleteTrade(user, &data); err != nil {
slog.Debug("完成交易失败[异步结束订单]", "err", err) slog.Debug("完成交易失败[异步结束订单]", "err", err)
err = s.Trade.CancelTrade(data, time.Now())
if err != nil { // 交易无法完成,关闭交易
if err := s.Trade.CancelTrade(&data); err != nil {
return fmt.Errorf("取消交易失败[异步结束订单]: %w", err) return fmt.Errorf("取消交易失败[异步结束订单]: %w", err)
} }
} }

View File

@@ -89,7 +89,7 @@ func RunTask(ctx context.Context) error {
var mux = asynq.NewServeMux() var mux = asynq.NewServeMux()
mux.HandleFunc(events.RemoveChannel, tasks.HandleRemoveChannel) mux.HandleFunc(events.RemoveChannel, tasks.HandleRemoveChannel)
mux.HandleFunc(events.CompleteTrade, tasks.HandleCompleteTrade) mux.HandleFunc(events.CloseTrade, tasks.HandleCompleteTrade)
// 停止服务 // 停止服务
go func() { go func() {