From 75ad12efb3061516dfb3066293323e98b56414f0 Mon Sep 17 00:00:00 2001 From: luorijun Date: Thu, 26 Mar 2026 14:39:19 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E5=A5=97=E9=A4=90=E4=B8=8E?= =?UTF-8?q?=E8=B4=A6=E5=8D=95=E6=8E=A5=E5=8F=A3=20&=20=E5=AE=8C=E5=96=84?= =?UTF-8?q?=E6=94=AF=E4=BB=98=E6=95=B0=E6=8D=AE=E4=BF=9D=E5=AD=98=EF=BC=8C?= =?UTF-8?q?=E8=AE=B0=E5=BD=95=E5=AE=9E=E4=BB=98=E4=BB=B7=E6=A0=BC=E5=B9=B6?= =?UTF-8?q?=E5=85=B3=E8=81=94=E4=BC=98=E6=83=A0=E5=88=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 + pkg/env/env.go | 2 +- scripts/sql/fill.sql | 4 +- scripts/sql/init.sql | 18 +- web/core/scopes.go | 28 +- web/error.go | 7 + web/events/trade.go | 15 +- web/globals/orm/timez.go | 30 ++ web/handlers/bill.go | 18 +- web/handlers/product.go | 23 ++ web/handlers/resource.go | 96 ++++++- web/handlers/trade.go | 47 ++- web/models/bill.go | 4 +- web/queries/bill.gen.go | 10 +- web/routes.go | 12 + web/services/bill.go | 56 ++-- web/services/coupon.go | 64 +++++ web/services/product_sku.go | 8 + web/services/resource.go | 201 ++++++------- web/services/trade.go | 560 +++++++++++++++--------------------- web/services/user.go | 89 ++---- web/tasks/task.go | 23 +- web/web.go | 2 +- 23 files changed, 706 insertions(+), 613 deletions(-) create mode 100644 web/globals/orm/timez.go create mode 100644 web/services/coupon.go diff --git a/README.md b/README.md index a03f067..0097629 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ ## TODO +增删改数据权限排查 + 后端默认用户名不能是完整手机号 前端需要 token 化改造,以避免每次 basic 认证流程中 bcrypt 对比导致的性能对比 diff --git a/pkg/env/env.go b/pkg/env/env.go index 5b4a8f0..6d732e8 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -20,7 +20,7 @@ const ( var ( RunMode = RunModeProd LogLevel = slog.LevelDebug - TradeExpire = 15 * 60 // 交易过期时间,单位秒。默认 15 分钟 + TradeExpire = 15 * 60 // 交易过期时间,单位秒。默认 900 秒(15 分钟) SessionAccessExpire = 60 * 60 * 2 // 访问令牌过期时间,单位秒。默认 2 小时 SessionRefreshExpire = 60 * 60 * 24 * 7 // 刷新令牌过期时间,单位秒。默认 7 天 DebugHttpDump = false // 是否打印请求和响应的原始数据 diff --git a/scripts/sql/fill.sql b/scripts/sql/fill.sql index edef7f2..d2a39a7 100644 --- a/scripts/sql/fill.sql +++ b/scripts/sql/fill.sql @@ -5,8 +5,8 @@ insert into client (type, spec, name, client_id, client_secret, redirect_uri) values (1, 3, 'web', 'web', '$2a$10$Ss12mXQgpYyo1CKIZ3URouDm.Lc2KcYJzsvEK2PTIXlv6fHQht45a', ''); insert into client (type, spec, name, client_id, client_secret, redirect_uri) values (1, 3, 'admin', 'admin', '$2a$10$dlfvX5Uf3iVsUWgwlb0Wt.oYsw/OEXgS.Aior3yoT63Ju7ZSsJr/2', ''); -insert into product (code, name, description) values ('dynamic-short', '短效动态', '短效动态'); -insert into product (code, name, description) values ('dynamic-long', '长效动态', '长效动态'); +insert into product (code, name, description) values ('short', '短效动态', '短效动态'); +insert into product (code, name, description) values ('long', '长效动态', '长效动态'); insert into product (code, name, description) values ('static', '长效静态', '长效静态'); delete from permission where true; diff --git a/scripts/sql/init.sql b/scripts/sql/init.sql index 6e497fa..cf7a6eb 100644 --- a/scripts/sql/init.sql +++ b/scripts/sql/init.sql @@ -747,7 +747,7 @@ create table product_sku ( id int generated by default as identity primary key, product_id int not null, discount_id int, - code text not null, + code text not null unique, name text not null, price decimal not null, created_at timestamptz default current_timestamp, @@ -756,7 +756,7 @@ create table product_sku ( ); create index idx_product_sku_product_id on product_sku (product_id) where deleted_at is null; create index idx_product_sku_discount_id on product_sku (discount_id) where deleted_at is null; -create index idx_product_sku_code on product_sku (code) where deleted_at is null; +create unique index idx_product_sku_code on product_sku (code) where deleted_at is null; -- product_sku表字段注释 comment on table product_sku is '产品SKU表'; @@ -977,10 +977,12 @@ create table bill ( trade_id int, resource_id int, refund_id int, + coupon_id int, bill_no text not null, info text, type int not null, amount decimal(12, 2) not null default 0, + actual decimal(12, 2) not null default 0, created_at timestamptz default current_timestamp, updated_at timestamptz default current_timestamp, deleted_at timestamptz @@ -990,6 +992,7 @@ create index idx_bill_user_id on bill (user_id) where deleted_at is null; create index idx_bill_trade_id on bill (trade_id) where deleted_at is null; create index idx_bill_resource_id on bill (resource_id) where deleted_at is null; create index idx_bill_refund_id on bill (refund_id) where deleted_at is null; +create index idx_bill_coupon_id on bill (coupon_id) where deleted_at is null; create index idx_bill_created_at on bill (created_at) where deleted_at is null; -- bill表字段注释 @@ -1002,7 +1005,8 @@ comment on column bill.refund_id is '退款ID'; comment on column bill.bill_no is '易读账单号'; comment on column bill.info is '产品可读信息'; comment on column bill.type is '账单类型:1-消费,2-退款,3-充值'; -comment on column bill.amount is '账单金额'; +comment on column bill.amount is '应付金额'; +comment on column bill.actual is '实付金额'; comment on column bill.created_at is '创建时间'; comment on column bill.updated_at is '更新时间'; comment on column bill.deleted_at is '删除时间'; @@ -1107,14 +1111,20 @@ alter table channel -- resource表外键 alter table resource add constraint fk_resource_user_id foreign key (user_id) references "user" (id) on delete cascade; +alter table resource + add constraint fk_product_code foreign key (code) references product (code) on update cascade on delete restrict; -- resource_short表外键 alter table resource_short add constraint fk_resource_short_resource_id foreign key (resource_id) references resource (id) on delete cascade; +alter table resource_short + add constraint fk_resource_short_code foreign key (code) references product_sku (code) on update cascade on delete restrict; -- resource_long表外键 alter table resource_long add constraint fk_resource_long_resource_id foreign key (resource_id) references resource (id) on delete cascade; +alter table resource_long + add constraint fk_resource_long_code foreign key (code) references product_sku (code) on update cascade on delete restrict; -- trade表外键 alter table trade @@ -1135,6 +1145,8 @@ alter table bill add constraint fk_bill_resource_id foreign key (resource_id) references resource (id) on delete set null; alter table bill add constraint fk_bill_refund_id foreign key (refund_id) references refund (id) on delete set null; +alter table bill + add constraint fk_bill_coupon_id foreign key (coupon_id) references coupon (id) on delete set null; -- coupon表外键 alter table coupon diff --git a/web/core/scopes.go b/web/core/scopes.go index a4dd983..35b7cef 100644 --- a/web/core/scopes.go +++ b/web/core/scopes.go @@ -1,16 +1,24 @@ package core const ( - ScopePermissionRead = string("permission:read") - ScopePermissionWrite = string("permission:write") - ScopeAdminRoleRead = string("admin_role:read") - ScopeAdminRoleWrite = string("admin_role:write") - ScopeAdminRead = string("admin:read") - ScopeAdminWrite = string("admin:write") - ScopeProductRead = string("product:read") - ScopeProductWrite = string("product:write") - ScopeProductSkuRead = string("product_sku:read") - ScopeProductSkuWrite = string("product_sku:write") + ScopePermissionRead = string("permission:read") + ScopePermissionWrite = string("permission:write") + + ScopeAdminRoleRead = string("admin_role:read") + ScopeAdminRoleWrite = string("admin_role:write") + + ScopeAdminRead = string("admin:read") + ScopeAdminWrite = string("admin:write") + + ScopeProductRead = string("product:read") + ScopeProductWrite = string("product:write") + + ScopeProductSkuRead = string("product_sku:read") + ScopeProductSkuWrite = string("product_sku:write") + ScopeProductDiscountRead = string("product_discount:read") ScopeProductDiscountWrite = string("product_discount:write") + + ScopeResourceRead = string("resource:read") + ScopeResourceWrite = string("resource:write") ) diff --git a/web/error.go b/web/error.go index b32d481..f946a6f 100644 --- a/web/error.go +++ b/web/error.go @@ -1,7 +1,9 @@ package web import ( + "encoding/json" "errors" + "fmt" "log/slog" "platform/web/auth" "platform/web/core" @@ -19,6 +21,7 @@ func ErrorHandler(c *fiber.Ctx, err error) error { var authErr auth.AuthErr var bizErr *core.BizErr var servErr *core.ServErr + var jsonErr *json.UnmarshalTypeError switch { @@ -48,6 +51,10 @@ func ErrorHandler(c *fiber.Ctx, err error) error { code = fiber.StatusInternalServerError message = err.Error() + case errors.As(err, &jsonErr): + code = fiber.StatusBadRequest + message = fmt.Sprintf("参数 %s 类型不正确,传入类型为 %s,正确类型应该为 %s", jsonErr.Field, jsonErr.Value, jsonErr.Type.Name()) + // 所有未手动声明的错误类型 default: slog.Warn("未处理的异常", slog.String("type", reflect.TypeOf(err).Name()), slog.String("error", err.Error())) diff --git a/web/events/trade.go b/web/events/trade.go index 6824693..27f8967 100644 --- a/web/events/trade.go +++ b/web/events/trade.go @@ -9,18 +9,23 @@ import ( "github.com/hibiken/asynq" ) -const CompleteTrade = "trade:update" +const CloseTrade = "trade:update" -type CompleteTradeData struct { +type CloseTradeData struct { + UserId int32 `json:"user_id" validate:"required"` TradeNo string `json:"trade_no" validate:"required"` Method m.TradeMethod `json:"method" validate:"required"` } -func NewCancelTrade(data CompleteTradeData) *asynq.Task { - bytes, err := json.Marshal(data) +func NewCloseTradeTask(uid int32, tradeNo string, method m.TradeMethod) *asynq.Task { + bytes, err := json.Marshal(CloseTradeData{ + UserId: uid, + TradeNo: tradeNo, + Method: method, + }) if err != nil { slog.Error("序列化更新交易任务失败", "error", err) return nil } - return asynq.NewTask(CompleteTrade, bytes) + return asynq.NewTask(CloseTrade, bytes) } diff --git a/web/globals/orm/timez.go b/web/globals/orm/timez.go new file mode 100644 index 0000000..618dfaf --- /dev/null +++ b/web/globals/orm/timez.go @@ -0,0 +1,30 @@ +package orm + +import ( + "fmt" + "time" +) + +type DateTime struct { + time.Time +} + +func (dt *DateTime) Scan(value any) error { + switch v := value.(type) { + case time.Time: + dt.Time = v + case string: + t, err := time.Parse(time.RFC3339, v) + if err != nil { + return err + } + dt.Time = t + default: + return fmt.Errorf("unsupported type: %T", value) + } + return nil +} + +func (dt DateTime) Value() (any, error) { + return dt.Time.Format(time.RFC3339), nil +} diff --git a/web/handlers/bill.go b/web/handlers/bill.go index 6ba0cdb..6ea0b0d 100644 --- a/web/handlers/bill.go +++ b/web/handlers/bill.go @@ -47,10 +47,24 @@ func PageBillByAdmin(c *fiber.Ctx) error { time := u.DateHead(*req.CreatedAtEnd) do = do.Where(q.Bill.CreatedAt.Lte(time)) } + if req.ProductCode != nil { + do = do.Where(q.Resource.As("Resource").Code.Eq(*req.ProductCode)) + } + if req.SkuCode != nil { + do = do.Where(q.Bill. + Where(q.ResourceShort.As("Resource__Short").Code.Eq(*req.SkuCode)). + Or(q.ResourceLong.As("Resource__Long").Code.Eq(*req.SkuCode))) + } // 查询用户列表 list, total, err := q.Bill.Debug(). - Joins(q.Bill.User, q.Bill.Resource, q.Bill.Trade). + Joins( + q.Bill.User, + q.Bill.Resource, + q.Bill.Trade, + q.Bill.Resource.Short, + q.Bill.Resource.Long, + ). Select( q.Bill.ALL, q.User.As("User").Phone.As("User__phone"), @@ -82,6 +96,8 @@ type PageBillByAdminReq struct { BillNo *string `json:"bill_no,omitempty"` CreatedAtStart *time.Time `json:"created_at_start,omitempty"` CreatedAtEnd *time.Time `json:"created_at_end,omitempty"` + ProductCode *string `json:"product_code,omitempty"` + SkuCode *string `json:"sku_code,omitempty"` } // ListBill 获取账单列表 diff --git a/web/handlers/product.go b/web/handlers/product.go index fad4e50..89dc197 100644 --- a/web/handlers/product.go +++ b/web/handlers/product.go @@ -91,6 +91,29 @@ func DeleteProduct(c *fiber.Ctx) error { return nil } +func AllProductSkuByAdmin(c *fiber.Ctx) error { + _, err := auth.GetAuthCtx(c).PermitAdmin(core.ScopeProductSkuRead) + if err != nil { + return err + } + + var req AllProductSkuByAdminReq + if err := g.Validator.ParseBody(c, &req); err != nil { + return err + } + + list, err := s.ProductSku.All(req.Code) + if err != nil { + return err + } + + return c.JSON(list) +} + +type AllProductSkuByAdminReq struct { + Code string `json:"product_code"` +} + func PageProductSkuByAdmin(c *fiber.Ctx) error { _, err := auth.GetAuthCtx(c).PermitAdmin(core.ScopeProductSkuRead) if err != nil { diff --git a/web/handlers/resource.go b/web/handlers/resource.go index 084d5bf..8b04ccd 100644 --- a/web/handlers/resource.go +++ b/web/handlers/resource.go @@ -70,7 +70,7 @@ func PageResourceShort(c *fiber.Ctx) error { } resource, err := q.Resource.Where(do). - Joins(q.Resource.Short, q.ResourceShort.Sku). + Joins(q.Resource.Short). Order(q.Resource.CreatedAt.Desc()). Offset(req.GetOffset()). Limit(req.GetLimit()). @@ -240,9 +240,28 @@ func PageResourceShortByAdmin(c *fiber.Ctx) error { time := u.DateTail(*req.CreatedAtEnd) do = do.Where(q.Resource.CreatedAt.Lte(time)) } + if req.Expired != nil { + if *req.Expired { + do = do.Where(q.Resource.Where( + q.ResourceShort.As("Short").Type.Eq(int(m.ResourceModeTime)), + q.ResourceShort.As("Short").ExpireAt.Lte(time.Now()), + ).Or( + q.ResourceShort.As("Short").Type.Eq(int(m.ResourceModeQuota)), + q.ResourceShort.As("Short").Quota.LteCol(q.ResourceShort.As("Short").Used), + )) + } else { + do = do.Where(q.Resource.Where( + q.ResourceShort.As("Short").Type.Eq(int(m.ResourceModeTime)), + q.ResourceShort.As("Short").ExpireAt.Gt(time.Now()), + ).Or( + q.ResourceShort.As("Short").Type.Eq(int(m.ResourceModeQuota)), + q.ResourceShort.As("Short").Quota.GtCol(q.ResourceShort.As("Short").Used), + )) + } + } list, total, err := q.Resource.Debug(). - Joins(q.Resource.User, q.Resource.Short). + Joins(q.Resource.User, q.Resource.Short, q.Resource.Short.Sku). Select( q.Resource.ALL, q.User.As("User").Phone.As("User__phone"), @@ -254,9 +273,14 @@ func PageResourceShortByAdmin(c *fiber.Ctx) error { q.ResourceShort.As("Short").Daily.As("Short__daily"), q.ResourceShort.As("Short").LastAt.As("Short__last_at"), q.ResourceShort.As("Short").ExpireAt.As("Short__expire_at"), + q.ProductSku.As("Short__Sku").Name.As("Short__Sku__name"), ). Where(q.Resource.Type.Eq(int(m.ResourceTypeShort)), do). + Order(q.Resource.CreatedAt.Desc()). FindByPage(req.GetOffset(), req.GetLimit()) + if err != nil { + return err + } return c.JSON(core.PageResp{ List: list, @@ -274,9 +298,10 @@ type PageResourceShortByAdminReq struct { Mode *int `json:"mode" form:"mode"` CreatedAtStart *time.Time `json:"created_at_start" form:"created_at_start"` CreatedAtEnd *time.Time `json:"created_at_end" form:"created_at_end"` + Expired *bool `json:"expired" form:"expired"` } -// PageResourceLongByAdmin 分页查询全部短效套餐 +// PageResourceLongByAdmin 分页查询全部长效套餐 func PageResourceLongByAdmin(c *fiber.Ctx) error { _, err := auth.GetAuthCtx(c).PermitAdmin() if err != nil { @@ -307,9 +332,28 @@ func PageResourceLongByAdmin(c *fiber.Ctx) error { if req.CreatedAtEnd != nil { do = do.Where(q.Resource.CreatedAt.Lte(*req.CreatedAtEnd)) } + if req.Expired != nil { + if *req.Expired { + do = do.Where(q.Resource.Where( + q.ResourceLong.As("Long").Type.Eq(int(m.ResourceModeTime)), + q.ResourceLong.As("Long").ExpireAt.Lte(time.Now()), + ).Or( + q.ResourceLong.As("Long").Type.Eq(int(m.ResourceModeQuota)), + q.ResourceLong.As("Long").Quota.LteCol(q.ResourceLong.As("Long").Used), + )) + } else { + do = do.Where(q.Resource.Where( + q.ResourceLong.As("Long").Type.Eq(int(m.ResourceModeTime)), + q.ResourceLong.As("Long").ExpireAt.Gt(time.Now()), + ).Or( + q.ResourceLong.As("Long").Type.Eq(int(m.ResourceModeQuota)), + q.ResourceLong.As("Long").Quota.GtCol(q.ResourceLong.As("Long").Used), + )) + } + } - list, total, err := q.Resource. - Joins(q.Resource.User, q.Resource.Long). + list, total, err := q.Resource.Debug(). + Joins(q.Resource.User, q.Resource.Long, q.Resource.Long.Sku). Select( q.Resource.ALL, q.User.As("User").Phone.As("User__phone"), @@ -321,9 +365,14 @@ func PageResourceLongByAdmin(c *fiber.Ctx) error { q.ResourceLong.As("Long").Daily.As("Long__daily"), q.ResourceLong.As("Long").LastAt.As("Long__last_at"), q.ResourceLong.As("Long").ExpireAt.As("Long__expire_at"), + q.ProductSku.As("Long__Sku").Name.As("Long__Sku__name"), ). Where(q.Resource.Type.Eq(int(m.ResourceTypeLong)), do). + Order(q.Resource.CreatedAt.Desc()). FindByPage(req.GetOffset(), req.GetLimit()) + if err != nil { + return err + } return c.JSON(core.PageResp{ List: list, @@ -341,6 +390,7 @@ type PageResourceLongByAdminReq struct { Mode *int `json:"mode" form:"mode"` CreatedAtStart *time.Time `json:"created_at_start" form:"created_at_start"` CreatedAtEnd *time.Time `json:"created_at_end" form:"created_at_end"` + Expired *bool `json:"expired" form:"expired"` } // AllActiveResource 所有可用套餐 @@ -402,6 +452,24 @@ func AllActiveResource(c *fiber.Ctx) error { type AllResourceReq struct { } +func UpdateResourceByAdmin(c *fiber.Ctx) error { + _, err := auth.GetAuthCtx(c).PermitAdmin(core.ScopeResourceWrite) + if err != nil { + return err + } + + var req s.UpdateResourceData + if err := c.BodyParser(&req); err != nil { + return err + } + + if err := s.Resource.Update(&req); err != nil { + return err + } + + return c.JSON(nil) +} + // StatisticResourceFree 统计每日可用 func StatisticResourceFree(c *fiber.Ctx) error { // 检查权限 @@ -602,26 +670,28 @@ func ResourcePrice(c *fiber.Ctx) error { } // 获取套餐价格 - sku, err := s.Resource.GetSku(req.CreateResourceData) + sku, err := s.Resource.GetSku(req.CreateResourceData.Code()) if err != nil { return err } - before, after, err := s.Resource.GetPrice(sku, req.Count(), nil) + _, amount, discounted, couponApplied, err := s.Resource.GetPrice(sku, req.Count(), nil, nil) if err != nil { return err } // 计算折扣 return c.JSON(ResourcePriceResp{ - Price: before.StringFixed(2), - Discounted: float32(sku.Discount.Discount) / 100, - DiscountedPrice: after.StringFixed(2), + Discount: float32(sku.Discount.Discount) / 100, + Price: amount.StringFixed(2), + Discounted: discounted.StringFixed(2), + CouponApplied: couponApplied.StringFixed(2), }) } type ResourcePriceResp struct { - Price string `json:"price"` - Discounted float32 `json:"discounted"` - DiscountedPrice string `json:"discounted_price"` + Price string `json:"price"` + Discount float32 `json:"discounted"` + Discounted string `json:"discounted_price"` + CouponApplied string `json:"coupon_applied"` } diff --git a/web/handlers/trade.go b/web/handlers/trade.go index baedb2e..a6f474f 100644 --- a/web/handlers/trade.go +++ b/web/handlers/trade.go @@ -109,53 +109,38 @@ func TradeCreate(c *fiber.Ctx) error { if err := g.Validator.ParseBody(c, req); err != nil { return err } - - var product s.ProductInfo switch req.Type { case m.TradeTypePurchase: if req.Resource == nil { return core.NewBizErr("购买信息不能为空") } - product, err = s.NewCreateResourceByTradeData(req.Resource) - if err != nil { - return core.NewServErr("处理购买产品信息失败", err) - } case m.TradeTypeRecharge: if req.Recharge == nil { return core.NewBizErr("充值信息不能为空") } - product = req.Recharge } - // 创建交易 - result, err := s.Trade.CreateTrade(authCtx.User.ID, time.Now(), &req.CreateTradeData, product) + // 处理订单 + uid := authCtx.User.ID + result, err := s.Trade.Create(uid, req.CreateTradeData, req.Resource) if err != nil { - slog.Error("创建交易失败", "error", err) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "创建交易失败"}) + return core.NewServErr("处理购买产品信息失败", err) } - return c.JSON(&TradeCreateResp{ - PayUrl: result.PaymentUrl, - TradeNo: result.TradeNo, - }) + return c.JSON(result) } type TradeCreateReq struct { - s.CreateTradeData - Type m.TradeType `json:"type" validate:"required"` - Resource *s.CreateResourceData `json:"resource,omitempty"` - Recharge *s.RechargeProductInfo `json:"recharge,omitempty"` -} - -type TradeCreateResp struct { - PayUrl string `json:"pay_url"` - TradeNo string `json:"trade_no"` + *s.CreateTradeData + Type m.TradeType `json:"type" validate:"required"` + Resource *s.CreateResourceData `json:"resource,omitempty"` + Recharge *s.UpdateBalanceData `json:"recharge,omitempty"` } // 完成订单 func TradeComplete(c *fiber.Ctx) error { // 检查权限 - _, err := auth.GetAuthCtx(c).PermitUser() + authCtx, err := auth.GetAuthCtx(c).PermitUser() if err != nil { return err } @@ -167,7 +152,7 @@ func TradeComplete(c *fiber.Ctx) error { } // 检查订单状态 - err = s.Trade.CompleteTrade(&req.ModifyTradeData) + err = s.Trade.CompleteTrade(authCtx.User, &req.TradeRef) if err != nil { return err } @@ -176,7 +161,7 @@ func TradeComplete(c *fiber.Ctx) error { } type TradeCompleteReq struct { - s.ModifyTradeData + s.TradeRef } // 取消订单 @@ -194,7 +179,7 @@ func TradeCancel(c *fiber.Ctx) error { } // 取消交易 - err = s.Trade.CancelTrade(&req.ModifyTradeData, time.Now()) + err = s.Trade.CancelTrade(&req.TradeRef) if err != nil { slog.Error("取消交易失败", "trade_no", req.TradeNo, "error", err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "取消交易失败"}) @@ -204,7 +189,7 @@ func TradeCancel(c *fiber.Ctx) error { } type TradeCancelReq struct { - s.ModifyTradeData + s.TradeRef } // 检查订单 @@ -225,7 +210,7 @@ func TradeCheck(c *fiber.Ctx) error { interval := 5 for range expire / interval { // 检查订单状态 - result, err := s.Trade.CheckTrade(&req.ModifyTradeData) + result, err := s.Trade.CheckTrade(&req.TradeRef) if err != nil { slog.Error("检查订单状态失败", "trade_no", req.TradeNo, "error", err) return @@ -256,5 +241,5 @@ func TradeCheck(c *fiber.Ctx) error { } type TradeCheckReq struct { - s.ModifyTradeData + s.TradeRef } diff --git a/web/models/bill.go b/web/models/bill.go index b717d17..2444409 100644 --- a/web/models/bill.go +++ b/web/models/bill.go @@ -13,10 +13,12 @@ type Bill struct { TradeID *int32 `json:"trade_id,omitempty" gorm:"column:trade_id"` // 订单ID ResourceID *int32 `json:"resource_id,omitempty" gorm:"column:resource_id"` // 套餐ID RefundID *int32 `json:"refund_id,omitempty" gorm:"column:refund_id"` // 退款ID + CouponID *int32 `json:"coupon_id,omitempty" gorm:"column:coupon_id"` // 优惠券ID BillNo string `json:"bill_no" gorm:"column:bill_no"` // 易读账单号 Info *string `json:"info,omitempty" gorm:"column:info"` // 产品可读信息 Type BillType `json:"type" gorm:"column:type"` // 账单类型:1-消费,2-退款,3-充值 - Amount decimal.Decimal `json:"amount" gorm:"column:amount"` // 账单金额 + Amount decimal.Decimal `json:"amount" gorm:"column:amount"` // 应付金额 + Actual decimal.Decimal `json:"actual" gorm:"column:actual"` // 实付金额 User *User `json:"user,omitempty" gorm:"foreignKey:UserID"` Trade *Trade `json:"trade,omitempty" gorm:"foreignKey:TradeID"` diff --git a/web/queries/bill.gen.go b/web/queries/bill.gen.go index ab4aa9e..c17e4e0 100644 --- a/web/queries/bill.gen.go +++ b/web/queries/bill.gen.go @@ -35,10 +35,12 @@ func newBill(db *gorm.DB, opts ...gen.DOOption) bill { _bill.TradeID = field.NewInt32(tableName, "trade_id") _bill.ResourceID = field.NewInt32(tableName, "resource_id") _bill.RefundID = field.NewInt32(tableName, "refund_id") + _bill.CouponID = field.NewInt32(tableName, "coupon_id") _bill.BillNo = field.NewString(tableName, "bill_no") _bill.Info = field.NewString(tableName, "info") _bill.Type = field.NewInt(tableName, "type") _bill.Amount = field.NewField(tableName, "amount") + _bill.Actual = field.NewField(tableName, "actual") _bill.User = billBelongsToUser{ db: db.Session(&gorm.Session{}), @@ -208,10 +210,12 @@ type bill struct { TradeID field.Int32 ResourceID field.Int32 RefundID field.Int32 + CouponID field.Int32 BillNo field.String Info field.String Type field.Int Amount field.Field + Actual field.Field User billBelongsToUser Trade billBelongsToTrade @@ -243,10 +247,12 @@ func (b *bill) updateTableName(table string) *bill { b.TradeID = field.NewInt32(table, "trade_id") b.ResourceID = field.NewInt32(table, "resource_id") b.RefundID = field.NewInt32(table, "refund_id") + b.CouponID = field.NewInt32(table, "coupon_id") b.BillNo = field.NewString(table, "bill_no") b.Info = field.NewString(table, "info") b.Type = field.NewInt(table, "type") b.Amount = field.NewField(table, "amount") + b.Actual = field.NewField(table, "actual") b.fillFieldMap() @@ -263,7 +269,7 @@ func (b *bill) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (b *bill) fillFieldMap() { - b.fieldMap = make(map[string]field.Expr, 16) + b.fieldMap = make(map[string]field.Expr, 18) b.fieldMap["id"] = b.ID b.fieldMap["created_at"] = b.CreatedAt b.fieldMap["updated_at"] = b.UpdatedAt @@ -272,10 +278,12 @@ func (b *bill) fillFieldMap() { b.fieldMap["trade_id"] = b.TradeID b.fieldMap["resource_id"] = b.ResourceID b.fieldMap["refund_id"] = b.RefundID + b.fieldMap["coupon_id"] = b.CouponID b.fieldMap["bill_no"] = b.BillNo b.fieldMap["info"] = b.Info b.fieldMap["type"] = b.Type b.fieldMap["amount"] = b.Amount + b.fieldMap["actual"] = b.Actual } diff --git a/web/routes.go b/web/routes.go index 7a5b052..1f6782d 100644 --- a/web/routes.go +++ b/web/routes.go @@ -4,6 +4,9 @@ import ( "platform/pkg/env" auth2 "platform/web/auth" "platform/web/handlers" + "time" + + q "platform/web/queries" "github.com/gofiber/fiber/v2" ) @@ -23,6 +26,13 @@ func ApplyRouters(app *fiber.App) { debug.Get("/sms/:phone", handlers.DebugGetSmsCode) debug.Get("/proxy/register", handlers.DebugRegisterProxyBaiYin) debug.Get("/iden/clear/:phone", handlers.DebugIdentifyClear) + debug.Get("/session/now", func(ctx *fiber.Ctx) error { + rs, err := q.Session.Where(q.Session.AccessTokenExpires.Gt(time.Now())).Find() + if err != nil { + return err + } + return ctx.JSON(rs) + }) } } @@ -136,6 +146,7 @@ func adminRouter(api fiber.Router) { var resource = api.Group("/resource") resource.Post("/short/page", handlers.PageResourceShortByAdmin) resource.Post("/long/page", handlers.PageResourceLongByAdmin) + resource.Post("/update", handlers.UpdateResourceByAdmin) // batch 批次 var usage = api.Group("batch") @@ -159,6 +170,7 @@ func adminRouter(api fiber.Router) { product.Post("/create", handlers.CreateProduct) product.Post("/update", handlers.UpdateProduct) product.Post("/remove", handlers.DeleteProduct) + product.Post("/sku/all", handlers.AllProductSkuByAdmin) product.Post("/sku/page", handlers.PageProductSkuByAdmin) product.Post("/sku/create", handlers.CreateProductSku) product.Post("/sku/update", handlers.UpdateProductSku) diff --git a/web/services/bill.go b/web/services/bill.go index 23d6bbf..0646ec7 100644 --- a/web/services/bill.go +++ b/web/services/bill.go @@ -2,6 +2,7 @@ package services import ( m "platform/web/models" + q "platform/web/queries" "github.com/shopspring/decimal" ) @@ -10,34 +11,41 @@ var Bill = &billService{} type billService struct{} -func (s *billService) GenNo() string { - return ID.GenReadable("bil") -} - -func newForRecharge(uid int32, billNo string, info string, amount decimal.Decimal, trade *m.Trade) *m.Bill { - return &m.Bill{ +func (s *billService) CreateForBalance(q *q.Query, uid, tradeId int32, detail *TradeDetail) error { + return q.Bill.Create(&m.Bill{ UserID: uid, - BillNo: billNo, - TradeID: &trade.ID, + BillNo: ID.GenReadable("bil"), + TradeID: &tradeId, Type: m.BillTypeRecharge, - Info: &info, - Amount: amount, - } + Info: &detail.Subject, + Amount: detail.Amount, + Actual: detail.Actual, + }) } -func newForConsume(uid int32, billNo string, info string, amount decimal.Decimal, resource *m.Resource, trade ...*m.Trade) *m.Bill { - var bill = &m.Bill{ +func (s *billService) CreateForResourceByTrade(q *q.Query, uid, tradeId, resourceId int32, detail *TradeDetail) error { + return q.Bill.Create(&m.Bill{ UserID: uid, - BillNo: billNo, - ResourceID: &resource.ID, + BillNo: ID.GenReadable("bil"), + ResourceID: &resourceId, + TradeID: &tradeId, + CouponID: detail.CouponId, Type: m.BillTypeConsume, - Info: &info, - Amount: amount, - } - - if len(trade) > 0 { - bill.TradeID = &trade[0].ID - } - - return bill + Info: &detail.Subject, + Amount: detail.Amount, + Actual: detail.Actual, + }) +} + +func (s *billService) CreateForResourceByBalance(q *q.Query, uid, resourceId int32, couponId *int32, subject string, amount, actual decimal.Decimal) error { + return q.Bill.Create(&m.Bill{ + UserID: uid, + BillNo: ID.GenReadable("bil"), + ResourceID: &resourceId, + CouponID: couponId, + Type: m.BillTypeConsume, + Info: &subject, + Amount: amount, + Actual: actual, + }) } diff --git a/web/services/coupon.go b/web/services/coupon.go new file mode 100644 index 0000000..2bbd804 --- /dev/null +++ b/web/services/coupon.go @@ -0,0 +1,64 @@ +package services + +import ( + "errors" + "fmt" + "platform/web/core" + m "platform/web/models" + q "platform/web/queries" + "time" + + "github.com/shopspring/decimal" + "gorm.io/gorm" +) + +var Coupon = &couponService{} + +type couponService struct{} + +func (s *couponService) GetCouponAvailableByCode(code string, amount decimal.Decimal, uid *int32) (*m.Coupon, error) { + // 获取优惠券 + coupon, err := q.Coupon.Where( + q.Coupon.Code.Eq(code), + q.Coupon.Status.Eq(int(m.CouponStatusUnused)), + q.Coupon. + Where(q.Coupon.ExpireAt.Gt(time.Now())). + Or(q.Coupon.ExpireAt.IsNull()), + ).Take() + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, core.NewBizErr("优惠券不存在或已失效") + } + if err != nil { + return nil, core.NewBizErr("获取优惠券数据失败", err) + } + + // 检查最小使用额度 + if amount.Cmp(coupon.MinAmount) < 0 { + return nil, core.NewBizErr(fmt.Sprintf("使用此优惠券的最小额度为 %s", coupon.MinAmount)) + } + + // 检查所属 + if coupon.UserID != nil { + if uid == nil { + return nil, core.NewBizErr("检查优惠券所属用户失败") + } + if *coupon.UserID != *uid { + return nil, core.NewBizErr("优惠券不属于当前用户") + } + } + + return coupon, nil +} + +func (s *couponService) UseCoupon(q *q.Query, id int32) error { + _, err := q.Coupon. + Where( + q.Coupon.ID.Eq(id), + q.Coupon.Status.Eq(int(m.CouponStatusUnused)), + q.Coupon.ExpireAt.Gt(time.Now()), + ). + UpdateSimple( + q.Coupon.Status.Value(int(m.CouponStatusUsed)), + ) + return err +} diff --git a/web/services/product_sku.go b/web/services/product_sku.go index 7dbed0d..93372ff 100644 --- a/web/services/product_sku.go +++ b/web/services/product_sku.go @@ -15,6 +15,14 @@ var ProductSku = &productSkuService{} type productSkuService struct{} +func (s *productSkuService) All(product_code string) (result []*m.ProductSku, err error) { + return q.ProductSku. + Joins(q.ProductSku.Product). + Where(q.Product.As("Product").Code.Eq(product_code)). + Select(q.ProductSku.ALL). + Find() +} + func (s *productSkuService) Page(req *core.PageReq, productId *int32) (result []*m.ProductSku, count int64, err error) { do := make([]gen.Condition, 0) if productId != nil { diff --git a/web/services/resource.go b/web/services/resource.go index e655ae0..cd5fd21 100644 --- a/web/services/resource.go +++ b/web/services/resource.go @@ -1,7 +1,6 @@ package services import ( - "encoding/json" "errors" "fmt" "platform/pkg/u" @@ -11,6 +10,7 @@ import ( "time" "github.com/shopspring/decimal" + "gorm.io/gen/field" "gorm.io/gorm" ) @@ -18,6 +18,7 @@ var Resource = &resourceService{} type resourceService struct{} +// CreateResourceByBalance 通过余额购买套餐 func (s *resourceService) CreateResourceByBalance(uid int32, now time.Time, data *CreateResourceData) error { // 找到用户 @@ -29,16 +30,21 @@ func (s *resourceService) CreateResourceByBalance(uid int32, now time.Time, data } // 获取 sku - sku, err := s.GetSku(data) + sku, err := s.GetSku(data.Code()) if err != nil { return err } // 检查余额 - _, amount, err := s.GetPrice(sku, data.Count(), &uid) + coupon, _, amount, actual, err := s.GetPrice(sku, data.Count(), &uid, data.CouponCode) if err != nil { return err } + couponId := (*int32)(nil) + if coupon != nil { + couponId = &coupon.ID + } + newBalance := user.Balance.Sub(amount) if newBalance.IsNegative() { return ErrBalanceNotEnough @@ -58,49 +64,30 @@ func (s *resourceService) CreateResourceByBalance(uid int32, now time.Time, data } // 保存套餐 - resource, err := createResource(q, uid, now, data) + resource, err := s.Create(q, uid, now, data) if err != nil { return core.NewServErr("创建套餐失败", err) } // 生成账单 - err = q.Bill.Create(newForConsume(uid, Bill.GenNo(), sku.Name, amount, resource)) + err = Bill.CreateForResourceByBalance(q, uid, resource.ID, couponId, sku.Name, amount, actual) if err != nil { return core.NewServErr("生成账单失败", err) } + // 核销优惠券 + if coupon != nil { + err = Coupon.UseCoupon(q, coupon.ID) + if err != nil { + return core.NewServErr("核销优惠券失败", err) + } + } + return nil }) } -func (s *resourceService) CreateResourceByTrade(uid int32, now time.Time, data *CreateResourceByTradeData, trade *m.Trade) error { // 检查交易 - if trade == nil { - return core.NewBizErr("交易数据不能为空") - } - if trade.Status != m.TradeStatusSuccess { - return core.NewBizErr("交易状态不正确") - } - - return q.Q.Transaction(func(q *q.Query) error { - - // 保存套餐 - resource, err := createResource(q, uid, now, data.Req) - if err != nil { - return core.NewServErr("创建套餐失败", err) - } - - // 生成账单 - err = q.Bill.Create(newForConsume(uid, Bill.GenNo(), data.GetSubject(), data.GetAmount(), resource, trade)) - if err != nil { - return core.NewServErr("生成账单失败", err) - } - - return nil - }) -} - -func createResource(q *q.Query, uid int32, now time.Time, data *CreateResourceData) (*m.Resource, error) { - +func (s *resourceService) Create(q *q.Query, uid int32, now time.Time, data *CreateResourceData) (*m.Resource, error) { // 套餐基本信息 var resource = m.Resource{ UserID: uid, @@ -162,10 +149,35 @@ func createResource(q *q.Query, uid int32, now time.Time, data *CreateResourceDa return &resource, nil } -func (s *resourceService) GetSku(data *CreateResourceData) (*m.ProductSku, error) { +func (s *resourceService) Update(data *UpdateResourceData) error { + if data.Active == nil { + return core.NewBizErr("更新套餐失败,active 不能为空") + } + + do := make([]field.AssignExpr, 0) + if data.Active != nil { + do = append(do, q.Resource.Active.Value(*data.Active)) + } + + _, err := q.Resource. + Where(q.Resource.ID.Eq(data.Id)). + UpdateSimple(do...) + if err != nil { + return core.NewServErr("更新套餐失败", err) + } + + return nil +} + +type UpdateResourceData struct { + core.IdReq + Active *bool `json:"active"` +} + +func (s *resourceService) GetSku(code string) (*m.ProductSku, error) { sku, err := q.ProductSku. Joins(q.ProductSku.Discount). - Where(q.ProductSku.Code.Eq(data.Code())). + Where(q.ProductSku.Code.Eq(code)). Take() if err != nil { return nil, core.NewServErr("产品不可用", err) @@ -178,43 +190,55 @@ func (s *resourceService) GetSku(data *CreateResourceData) (*m.ProductSku, error return sku, nil } -func (s *resourceService) GetPrice(sku *m.ProductSku, count int32, uid *int32) (decimal.Decimal, decimal.Decimal, error) { +func (s *resourceService) GetPrice(sku *m.ProductSku, count int32, uid *int32, couponCode *string) (*m.Coupon, decimal.Decimal, decimal.Decimal, decimal.Decimal, error) { - // 根据用户 id 查询特殊优惠 - var uSku *m.ProductSkuUser - if uid != nil { + // 原价 + price := sku.Price + amount := price.Mul(decimal.NewFromInt32(count)) + + // 折扣价 + discount := sku.Discount.Decimal() + if uid != nil { // 用户特殊优惠 var err error - uSku, err = q.ProductSkuUser. + uSku, err := q.ProductSkuUser. Joins(q.ProductSkuUser.Discount). Where( q.ProductSkuUser.UserID.Eq(*uid), q.ProductSkuUser.ProductSkuID.Eq(sku.ID)). Take() if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return decimal.Zero, decimal.Zero, core.NewServErr("客户特殊价查询失败", err) + return nil, decimal.Zero, decimal.Zero, decimal.Zero, core.NewServErr("客户特殊价查询失败", err) + } + if uSku.Discount == nil { + return nil, decimal.Zero, decimal.Zero, decimal.Zero, core.NewServErr("价格获取失败") + } + uDiscount := uSku.Discount.Decimal() + if uDiscount.Cmp(discount) > 0 { + discount = uDiscount } } + discounted := amount.Mul(discount) - if uSku.Discount == nil { - return decimal.Decimal{}, decimal.Decimal{}, core.NewServErr("价格获取失败") + // 优惠价 + coupon := (*m.Coupon)(nil) + couponApplied := discounted.Copy() + if couponCode != nil { + var err error + coupon, err = Coupon.GetCouponAvailableByCode(*couponCode, discounted, uid) + if err != nil { + return nil, decimal.Zero, decimal.Zero, decimal.Zero, err + } + couponApplied = discounted.Sub(coupon.Amount) } - // 返回计算价格 - price := sku.Price - discount := sku.Discount.Decimal() - if uSku != nil { - discount = uSku.Discount.Decimal() - } - - before := price.Mul(decimal.NewFromInt32(count)) - after := before.Mul(discount) - return before, after, nil + return coupon, amount, discounted, couponApplied, nil } type CreateResourceData struct { - Type m.ResourceType `json:"type" validate:"required"` - Short *CreateShortResourceData `json:"short,omitempty"` - Long *CreateLongResourceData `json:"long,omitempty"` + Type m.ResourceType `json:"type" validate:"required"` + Short *CreateShortResourceData `json:"short,omitempty"` + Long *CreateLongResourceData `json:"long,omitempty"` + CouponCode *string `json:"coupon,omitempty"` } type CreateShortResourceData struct { @@ -267,71 +291,22 @@ func (c *CreateResourceData) Code() string { } } -func (c *CreateResourceData) Serialize() (string, error) { - bytes, err := json.Marshal(c) - return string(bytes), err -} - -func (c *CreateResourceData) Deserialize(str string) error { - return json.Unmarshal([]byte(str), c) -} - -// 交易后创建套餐 -type ResourceOnTradeComplete struct{} - -func (r ResourceOnTradeComplete) Check(t m.TradeType) (ProductInfo, bool) { - if t == m.TradeTypePurchase { - return &CreateResourceByTradeData{}, true - } - return nil, false -} - -func (r ResourceOnTradeComplete) OnTradeComplete(info ProductInfo, trade *m.Trade) error { - return Resource.CreateResourceByTrade(trade.UserID, time.Time(*trade.CompletedAt), info.(*CreateResourceByTradeData), trade) -} - -type CreateResourceByTradeData struct { - Subject string `json:"subject"` - Amount decimal.Decimal `json:"amount"` - Req *CreateResourceData `json:"data"` -} - -func (e CreateResourceByTradeData) GetType() m.TradeType { - return m.TradeTypePurchase -} - -func (e CreateResourceByTradeData) GetSubject() string { - return e.Subject -} - -func (e CreateResourceByTradeData) GetAmount() decimal.Decimal { - return e.Amount -} - -func (e CreateResourceByTradeData) Serialize() (string, error) { - bytes, err := json.Marshal(e) - return string(bytes), err -} - -func (e *CreateResourceByTradeData) Deserialize(str string) error { - return json.Unmarshal([]byte(str), e) -} - -func NewCreateResourceByTradeData(req *CreateResourceData) (*CreateResourceByTradeData, error) { - sku, err := Resource.GetSku(req) +func (c *CreateResourceData) TradeDetail() (*TradeDetail, error) { + sku, err := Resource.GetSku(c.Code()) if err != nil { return nil, err } - _, amount, err := Resource.GetPrice(sku, req.Count(), nil) + coupon, _, amount, actual, err := Resource.GetPrice(sku, c.Count(), nil, c.CouponCode) if err != nil { return nil, err } - return &CreateResourceByTradeData{ - Subject: sku.Name, - Amount: amount, - Req: req, + return &TradeDetail{ + m.TradeTypePurchase, + sku.Name, + amount, actual, + &coupon.ID, c, }, nil } diff --git a/web/services/trade.go b/web/services/trade.go index 39e7879..156fd26 100644 --- a/web/services/trade.go +++ b/web/services/trade.go @@ -2,6 +2,7 @@ package services import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -23,7 +24,6 @@ import ( "github.com/smartwalle/alipay/v3" "github.com/wechatpay-apiv3/wechatpay-go/services/partnerpayments/h5" "github.com/wechatpay-apiv3/wechatpay-go/services/payments/native" - "gorm.io/gorm" ) var Trade = &tradeService{} @@ -32,72 +32,17 @@ type tradeService struct { } // 创建交易 -func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTradeData, product ProductInfo) (*CreateTradeResult, error) { - platform := payment.Platform - method := payment.Method - tType := product.GetType() - expire := time.Now().Add(30 * time.Minute) - subject := product.GetSubject() - amount := product.GetAmount() - - // 实际支付金额,只在创建真实订单时使用 - amountReal := amount - if env.RunMode == env.RunModeDev { - amountReal = decimal.NewFromFloat(0.01) +func (s *tradeService) Create(uid int32, tradeData *CreateTradeData, productData *CreateResourceData) (*CreateTradeResult, error) { + detail, err := productData.TradeDetail() + if err != nil { + return nil, core.NewServErr("获取产品支付信息失败", err) } - // 附加优惠券 - if payment.CouponCode != nil { - coupon, err := q.Coupon. - Where( - q.Coupon.Code.Eq(*payment.CouponCode), - q.Coupon.Status.Eq(int(m.CouponStatusUnused)), - ). - Take() - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, errors.New("优惠券不存在或已失效") - } - return nil, err - } - - expireAt := time.Time(u.Z(coupon.ExpireAt)) - if !expireAt.IsZero() && expireAt.Before(now) { - _, err = q.Coupon. - Where(q.Coupon.ID.Eq(coupon.ID)). - Update(q.Coupon.Status, m.CouponStatusExpired) - if err != nil { - return nil, err - } - return nil, errors.New("优惠券已过期") - } - - if amount.Cmp(coupon.MinAmount) < 0 { - return nil, errors.New("订单金额未达到使用优惠券的条件") - } - - if coupon.UserID != nil { - switch *coupon.UserID { - // 指定用户的优惠券 - case uid: - amount = amount.Sub(coupon.Amount) - if expireAt.IsZero() { - _, err = q.Coupon. - Where(q.Coupon.ID.Eq(coupon.ID)). - Update(q.Coupon.Status, int(m.CouponStatusUsed)) - if err != nil { - return nil, err - } - } - // 该优惠券不属于当前用户 - default: - return nil, errors.New("优惠券不属于当前用户") - } - } else { - // 公开优惠券 - amount = amount.Sub(coupon.Amount) - } - } + now := time.Now() + platform := tradeData.Platform + method := tradeData.Method + expireIn := time.Duration(env.TradeExpire) * time.Second + expireAt := now.Add(expireIn) // 生成订单号 tradeNo, err := ID.GenSerial() @@ -105,6 +50,12 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad return nil, core.NewServErr("生成订单号失败", err) } + // 实际支付金额,只在创建真实订单时使用 + amountReal := detail.Actual + if env.RunMode == env.RunModeDev { + amountReal = decimal.NewFromFloat(0.01) + } + // 提交支付订单 var paymentUrl string switch { @@ -117,9 +68,9 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad Trade: alipay.Trade{ ProductCode: "FAST_INSTANT_TRADE_PAY", OutTradeNo: tradeNo, - Subject: subject, + Subject: detail.Subject, TotalAmount: amountReal.StringFixed(2), - TimeExpire: expire.Format("2006-01-02 15:04:05"), + TimeExpire: expireAt.Format("2006-01-02 15:04:05"), }, }) if err != nil { @@ -133,8 +84,8 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad Appid: &env.WechatPayAppId, Mchid: &env.WechatPayMchId, OutTradeNo: &tradeNo, - Description: &subject, - TimeExpire: &expire, + Description: &detail.Subject, + TimeExpire: &expireAt, NotifyUrl: &env.WechatPayCallbackUrl, Amount: &native.Amount{ Total: u.P(amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart()), @@ -151,8 +102,8 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad SpAppid: &env.WechatPayAppId, SpMchid: &env.WechatPayMchId, OutTradeNo: &tradeNo, - Description: &subject, - TimeExpire: &expire, + Description: &detail.Subject, + TimeExpire: &expireAt, NotifyUrl: &env.WechatPayCallbackUrl, Amount: &h5.Amount{ Total: u.P(amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart()), @@ -174,18 +125,17 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad payType = g.SftAlipay case m.TradeMethodSftWechat: payType = g.SftWeChat - default: - panic("unhandled default case") } + resp, err := g.SFTPay.PaymentScanPay(&g.PaymentScanPayReq{ MchOrderNo: tradeNo, - Subject: subject, - Body: subject, + Subject: detail.Subject, + Body: detail.Subject, Amount: amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart(), PayType: payType, Currency: "cny", ClientIp: "123.52.74.23", - OrderTimeout: u.P(expire.Format("2006-01-02 15:04:05")), + OrderTimeout: u.P(expireAt.Format("2006-01-02 15:04:05")), }) if err != nil { return nil, err @@ -196,24 +146,24 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad case method == m.TradeMethodSftAlipay && platform == m.TradePlatformMobile, method == m.TradeMethodSftWechat && platform == m.TradePlatformMobile: + var payType g.SftPayType switch method { case m.TradeMethodSftAlipay: payType = g.SftAlipay case m.TradeMethodSftWechat: payType = g.SftWeChat - default: - panic("unhandled default case") } + resp, err := g.SFTPay.PaymentH5Pay(&g.PaymentH5PayReq{ MchOrderNo: tradeNo, - Subject: subject, - Body: subject, + Subject: detail.Subject, + Body: detail.Subject, Amount: amountReal.Mul(decimal.NewFromInt(100)).Round(0).IntPart(), PayType: payType, Currency: "cny", ClientIp: "123.52.74.23", - OrderTimeout: u.P(expire.Format("2006-01-02 15:04:05")), + OrderTimeout: u.P(expireAt.Format("2006-01-02 15:04:05")), }) if err != nil { return nil, err @@ -230,9 +180,9 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad err = q.Trade.Create(&m.Trade{ UserID: uid, InnerNo: tradeNo, - Type: tType, - Subject: subject, - Amount: amount, + Type: detail.Type, + Subject: detail.Subject, + Amount: detail.Actual, Method: method, Platform: platform, PaymentURL: &paymentUrl, @@ -242,7 +192,7 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad } // 缓存产品数据 - serialized, err := product.Serialize() + serialized, err := json.Marshal(detail) if err != nil { return nil, core.NewServErr("序列化产品信息失败", err) } @@ -251,286 +201,233 @@ func (s *tradeService) CreateTrade(uid int32, now time.Time, payment *CreateTrad context.Background(), tradeProductKey(tradeNo), serialized, - time.Duration(env.TradeExpire+10)*time.Second, + expireIn, ).Err() if err != nil { return nil, core.NewServErr("保存购买信息失败", err) } // 提交异步关闭事件 - closeAt := now.Add(time.Duration(env.TradeExpire) * time.Second) - _, err = g.Asynq.Enqueue(e.NewCancelTrade(e.CompleteTradeData{ - TradeNo: tradeNo, - Method: method, - }), asynq.ProcessAt(closeAt)) + _, err = g.Asynq.Enqueue(e.NewCloseTradeTask(uid, tradeNo, method), asynq.ProcessAt(expireAt)) if err != nil { return nil, core.NewServErr("提交异步关闭事件失败", err) } return &CreateTradeResult{ - PaymentUrl: paymentUrl, - TradeNo: tradeNo, + PayUrl: paymentUrl, + TradeNo: tradeNo, }, nil } // 完成交易 -func (s *tradeService) CompleteTrade(data *ModifyTradeData) error { - return g.Redsync.WithLock(tradeLockKey(data.TradeNo), func() error { +func (s *tradeService) CompleteTrade(user *m.User, ref *TradeRef) error { - // 检查订单状态 - result, err := s.CheckTrade(data) - if err != nil { - return core.NewServErr("检查订单状态失败", err) - } - if result.Status != m.TradeStatusSuccess { - switch result.Status { - case m.TradeStatusPending: - return core.NewBizErr("订单未支付") - case m.TradeStatusCanceled: - return core.NewBizErr("订单已过期") - } - } - - // 更新交易状态 - trade, err := completeTrade(&OnTradeCompletedData{ - data.TradeNo, - result.TransId, - result.Success, - }) - if err != nil { - return core.NewServErr("处理交易失败", err) - } - - // 处理交易完成事件 - err = afterTradeComplete(trade) - if err != nil { - return core.NewServErr("处理交易完成事件失败", err) - } - - return nil - }) -} -func (s *tradeService) OnTradeCompleted(data *OnTradeCompletedData) error { - return g.Redsync.WithLock(tradeLockKey(data.TradeNo), func() error { - - // 更新交易状态 - trade, err := completeTrade(data) - if err != nil { - return core.NewServErr("处理交易失败", err) - } - - // 处理交易完成事件 - err = afterTradeComplete(trade) - if err != nil { - return core.NewServErr("处理交易完成事件失败", err) - } - - return nil - }) -} -func completeTrade(data *OnTradeCompletedData) (*m.Trade, error) { - var trade = new(m.Trade) - var err = q.Q.Transaction(func(tx *q.Query) error { - var tradeNo = data.TradeNo - var transId = data.TransId - var payment = data.Payment - var acquirer = data.Acquirer - var paidAt = data.Time - - // 获取交易信息 - var err error - trade, err = q.Trade. - Where(q.Trade.InnerNo.Eq(tradeNo)). - Take() - if err != nil { - return core.NewBizErr("获取交易信息失败", err) - } - - // 检查交易状态 - switch trade.Status { - case m.TradeStatusCanceled: - return core.NewBizErr("交易已取消") - case m.TradeStatusSuccess: - return nil // 跳过更新交易信息 - case m.TradeStatusPending: - } - - // 更新交易信息 - trade.Status = m.TradeStatusSuccess - trade.OuterNo = &transId - trade.Payment = payment - trade.Acquirer = u.P(acquirer) - trade.CompletedAt = u.P(paidAt) - rs, err := q.Trade. - Where(q.Trade.InnerNo.Eq(tradeNo), q.Trade.Status.Eq(int(m.TradeStatusPending))). - Updates(trade) - if rs.RowsAffected == 0 { - return core.NewBizErr("交易状态已发生变化") - } - if err != nil { - return core.NewServErr("更新交易信息失败", err) - } - - return nil - }) + // 检查订单状态 + result, err := s.CheckTrade(ref) if err != nil { - return nil, err - } else { - return trade, err + return core.NewServErr("检查订单状态失败", err) } + if result.Status != m.TradeStatusSuccess { + switch result.Status { + case m.TradeStatusPending: + return core.NewBizErr("订单未支付") + case m.TradeStatusCanceled: + return core.NewBizErr("订单已过期") + } + } + + // 更新交易状态 + err = s.OnCompleteTrade(user, ref.TradeNo, result.TransId, &result.Success) + if err != nil { + return core.NewServErr("处理交易失败", err) + } + + return nil } -func afterTradeComplete(trade *m.Trade) error { +func (s *tradeService) OnCompleteTrade(user *m.User, interNo string, outerNo string, result *TradeSuccessResult) error { + + // 获取交易信息 + trade, err := q.Trade. + Where(q.Trade.InnerNo.Eq(interNo)). + Take() + if err != nil { + return core.NewBizErr("获取交易信息失败", err) + } + + // 检查交易状态 + switch trade.Status { + case m.TradeStatusCanceled: + return core.NewBizErr("交易已取消") + case m.TradeStatusSuccess: + return nil // 跳过更新交易信息 + case m.TradeStatusPending: + } // 恢复购买信息 - productData, err := g.Redis.Get(context.Background(), tradeProductKey(trade.InnerNo)).Result() + detailStr, err := g.Redis.Get(context.Background(), tradeProductKey(interNo)).Result() if err != nil { return core.NewServErr("恢复购买信息失败", err) } - // 执行资源创建 - var ComplementEvents = []CompleteEvent{ - ResourceOnTradeComplete{}, - UserOnTradeComplete{}, + var detail TradeDetail + if err := json.Unmarshal([]byte(detailStr), &detail); err != nil { + return core.NewServErr("解析购买信息失败", err) } - for _, event := range ComplementEvents { - info, ok := event.Check(trade.Type) - if !ok { - continue + err = q.Q.Transaction(func(q *q.Query) error { + // 更新交易信息 + _, err := q.Trade. + Where( + q.Trade.InnerNo.Eq(interNo), + q.Trade.Status.Eq(int(m.TradeStatusPending)), + ). + UpdateSimple( + q.Trade.Status.Value(int(m.TradeStatusSuccess)), + q.Trade.OuterNo.Value(outerNo), + q.Trade.Payment.Value(result.Actual), + q.Trade.Acquirer.Value(int(result.Acquirer)), + q.Trade.CompletedAt.Value(result.Time), + ) + if err != nil { + return core.NewServErr("更新交易信息失败", err) } - err = info.Deserialize(productData) - if err != nil { - return core.NewServErr("反序列化购买信息失败", err) + switch trade.Type { + case m.TradeTypeRecharge: + // 更新用户余额 + if err := User.UpdateBalance(q, user, detail.Actual); err != nil { + return err + } + + // 生成账单 + err = Bill.CreateForBalance(q, user.ID, trade.ID, &detail) + if err != nil { + return core.NewServErr("生成账单失败", err) + } + + case m.TradeTypePurchase: + data, ok := detail.Product.(*CreateResourceData) + if !ok { + return core.NewServErr("购买信息解析失败", nil) + } + + // 保存套餐 + resource, err := Resource.Create(q, user.ID, result.Time, data) + if err != nil { + return core.NewServErr("创建套餐失败", err) + } + + // 生成账单 + err = Bill.CreateForResourceByTrade(q, user.ID, resource.ID, trade.ID, &detail) + if err != nil { + return core.NewServErr("生成账单失败", err) + } + + // 核销优惠券 + if detail.CouponId != nil { + err = Coupon.UseCoupon(q, *detail.CouponId) + if err != nil { + return core.NewServErr("核销优惠券失败", err) + } + } } - err = event.OnTradeComplete(info, trade) - if err != nil { - return core.NewServErr("处理交易完成事件失败", err) - } + return nil + }) + if err != nil { + return err } return nil } // 取消交易 -func (s *tradeService) CancelTrade(data *ModifyTradeData, now time.Time) error { - tradeNo := data.TradeNo - method := data.Method +func (s *tradeService) CancelTrade(ref *TradeRef) error { + now := time.Now() - return g.Redsync.WithLock(tradeLockKey(tradeNo), func() error { - switch method { - - case m.TradeMethodAlipay: - resp, err := g.Alipay.TradeCancel(context.Background(), alipay.TradeCancel{ - OutTradeNo: tradeNo, - }) - if err != nil { - return core.NewServErr("上游取消交易失败", err) - } - if resp.Code != alipay.CodeSuccess { - slog.Error("支付宝交易取消失败", "code", resp.Code, "sub_code", resp.SubCode, "msg", resp.Msg) - return errors.New("上游取消交易失败") - } - - case m.TradeMethodWechat: - resp, err := g.WechatPay.Native.CloseOrder(context.Background(), native.CloseOrderRequest{ - Mchid: &env.WechatPayMchId, - OutTradeNo: &tradeNo, - }) - if err != nil { - return core.NewServErr("上游取消交易失败", err) - } - if resp.Response.StatusCode != http.StatusNoContent { - body, err := io.ReadAll(resp.Response.Body) - if err != nil { - slog.Error("读取微信交易取消响应失败", "error", err) - return core.NewServErr("上游取消交易失败", err) - } - slog.Error("微信交易取消失败", "code", resp.Response.StatusCode, "body", string(body)) - return errors.New("上游取消交易失败") - } - - case m.TradeMethodSft, m.TradeMethodSftAlipay, m.TradeMethodSftWechat: - _, err := g.SFTPay.OrderClose(&g.OrderCloseReq{ - MchOrderNo: &tradeNo, - }) - if err != nil { - slog.Debug(fmt.Sprintf("订单无需关闭: %s", err.Error())) - return nil - } - - default: - return ErrTransactionNotSupported - } - - err := cancelTrade(tradeNo, now) + switch ref.Method { + case m.TradeMethodAlipay: + resp, err := g.Alipay.TradeCancel(context.Background(), alipay.TradeCancel{ + OutTradeNo: ref.TradeNo, + }) if err != nil { - return err + return core.NewServErr("上游取消交易失败", err) + } + if resp.Code != alipay.CodeSuccess { + slog.Error("支付宝交易取消失败", "code", resp.Code, "sub_code", resp.SubCode, "msg", resp.Msg) + return errors.New("上游取消交易失败") } - return nil - }) -} -func (s *tradeService) OnTradeCanceled(tradeNo string, now time.Time) error { - err := g.Redsync.WithLock(tradeLockKey(tradeNo), func() error { - return cancelTrade(tradeNo, now) - }) - if err != nil { - return core.NewServErr("处理交易取消失败", err) + case m.TradeMethodWechat: + resp, err := g.WechatPay.Native.CloseOrder(context.Background(), native.CloseOrderRequest{ + Mchid: &env.WechatPayMchId, + OutTradeNo: &ref.TradeNo, + }) + if err != nil { + return core.NewServErr("上游取消交易失败", err) + } + if resp.Response.StatusCode != http.StatusNoContent { + body, err := io.ReadAll(resp.Response.Body) + if err != nil { + slog.Error("读取微信交易取消响应失败", "error", err) + return core.NewServErr("上游取消交易失败", err) + } + slog.Error("微信交易取消失败", "code", resp.Response.StatusCode, "body", string(body)) + return errors.New("上游取消交易失败") + } + + case m.TradeMethodSft, m.TradeMethodSftAlipay, m.TradeMethodSftWechat: + _, err := g.SFTPay.OrderClose(&g.OrderCloseReq{ + MchOrderNo: &ref.TradeNo, + }) + if err != nil { + slog.Debug(fmt.Sprintf("订单无需关闭: %s", err.Error())) + return nil + } + + default: + return ErrTransactionNotSupported } + + err := s.OnCancelTrade(ref.TradeNo, now) + if err != nil { + return err + } + return nil } -func cancelTrade(tradeNo string, now time.Time) error { - return q.Q.Transaction(func(q *q.Query) error { - // 获取交易信息 - var status m.TradeStatus - err := q.Trade. - Where(q.Trade.InnerNo.Eq(tradeNo)). - Select(q.Trade.Status). - Scan(&status) - if err != nil { - return core.NewBizErr("获取交易信息失败", err) - } +func (s *tradeService) OnCancelTrade(tradeNo string, now time.Time) error { + _, err := q.Trade. + Where( + q.Trade.InnerNo.Eq(tradeNo), + q.Trade.Status.Eq(int(m.TradeStatusPending)), + ). + UpdateSimple( + q.Trade.Status.Value(int(m.TradeStatusCanceled)), + q.Trade.CanceledAt.Value(now), + ) + if err != nil { + return core.NewServErr("更新交易状态失败", err) + } - // 检查交易状态 - switch status { - case m.TradeStatusCanceled: - return core.NewBizErr("交易已取消") - case m.TradeStatusSuccess: - return core.NewBizErr("交易已完成") - case m.TradeStatusPending: - } - - // 更新交易状态 - _, err = q.Trade. - Where(q.Trade.InnerNo.Eq(tradeNo)). - UpdateSimple( - q.Trade.Status.Value(int(m.TradeStatusCanceled)), - q.Trade.CanceledAt.Value(now), - ) - if err != nil { - return core.NewServErr("更新交易状态失败", err) - } - return nil - }) + return nil } // 交易退款 -func (s *tradeService) RefundTrade(data *ModifyTradeData) error { +func (s *tradeService) RefundTrade(ref *TradeRef) error { panic("todo") } -func (s *tradeService) OnTradeRefunded(q *q.Query, tradeNo string, now time.Time) error { +func (s *tradeService) OnRefundTrade(q *q.Query, tradeNo string, now time.Time) error { panic("todo") } // 检查交易状态 -func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, error) { - var tradeNo = data.TradeNo - var method = data.Method +func (s *tradeService) CheckTrade(ref *TradeRef) (*CheckTradeResult, error) { + var tradeNo = ref.TradeNo + var method = ref.Method // 检查交易号是否存在 - var result = new(CheckTradeResult) + var result CheckTradeResult switch method { // 支付宝 @@ -560,9 +457,8 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err case alipay.TradeStatusSuccess, alipay.TradeStatusFinished: result.Status = m.TradeStatusSuccess - result.Success = &TradeSuccessResult{} result.Success.Acquirer = m.TradeAcquirerAlipay - result.Success.Payment, err = decimal.NewFromString(resp.TotalAmount) + result.Success.Actual, err = decimal.NewFromString(resp.ReceiptAmount) if err != nil { return nil, err } @@ -606,9 +502,8 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err case "SUCCESS", "REFUND": result.Status = m.TradeStatusSuccess - result.Success = &TradeSuccessResult{} result.Success.Acquirer = m.TradeAcquirerWechat - result.Success.Payment = decimal.NewFromInt(*resp.Amount.PayerTotal).Div(decimal.NewFromInt(100)) + result.Success.Actual = decimal.NewFromInt(*resp.Amount.PayerTotal).Div(decimal.NewFromInt(100)) result.Success.Time, err = time.Parse(time.RFC3339, *resp.SuccessTime) if err != nil { return nil, err @@ -626,12 +521,12 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err return nil, err } + // 填充返回值 if resp.PayOrderId == nil { return nil, errors.New("商福通交易号不存在") } - - // 填充返回值 result.TransId = *resp.PayOrderId + switch resp.State { case g.SftInit, g.SftTradeAwait, g.SftTradeFail: @@ -642,7 +537,6 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err case g.SftTradeSuccess, g.SftTradeRefund, g.SftRefundIng: result.Status = m.TradeStatusSuccess - result.Success = &TradeSuccessResult{} switch resp.PayType { case "WECHAT": result.Success.Acquirer = m.TradeAcquirerWechat @@ -651,7 +545,7 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err case "UNIONPAY": result.Success.Acquirer = m.TradeAcquirerUnionPay } - result.Success.Payment = decimal.NewFromInt(resp.Amount).Div(decimal.NewFromInt(100)) + result.Success.Actual = decimal.NewFromInt(resp.Amount).Div(decimal.NewFromInt(100)) result.Success.Time, err = time.Parse("2006-01-02 15:04:05", *resp.PayTime) if err != nil { return nil, err @@ -663,7 +557,7 @@ func (s *tradeService) CheckTrade(data *ModifyTradeData) (*CheckTradeResult, err return nil, ErrTransactionNotSupported } - return result, nil + return &result, nil } func tradeProductKey(no string) string { @@ -675,17 +569,16 @@ func tradeLockKey(no string) string { } type CreateTradeData struct { - Platform m.TradePlatform `json:"platform" validate:"required"` - Method m.TradeMethod `json:"method" validate:"required"` - CouponCode *string `json:"coupon_code"` + Platform m.TradePlatform `json:"platform" validate:"required"` + Method m.TradeMethod `json:"method" validate:"required"` } type CreateTradeResult struct { - TradeNo string - PaymentUrl string + PayUrl string `json:"pay_url"` + TradeNo string `json:"trade_no"` } -type ModifyTradeData struct { +type TradeRef struct { TradeNo string `json:"trade_no" query:"trade_no" validate:"required"` Method m.TradeMethod `json:"method" validate:"required"` } @@ -693,12 +586,12 @@ type ModifyTradeData struct { type CheckTradeResult struct { TransId string Status m.TradeStatus - Success *TradeSuccessResult + Success TradeSuccessResult } type TradeSuccessResult struct { Acquirer m.TradeAcquirer - Payment decimal.Decimal + Actual decimal.Decimal Time time.Time } @@ -709,11 +602,16 @@ type OnTradeCompletedData struct { } type ProductInfo interface { - GetType() m.TradeType - GetSubject() string - GetAmount() decimal.Decimal - Serialize() (string, error) - Deserialize(str string) error + TradeDetail() (*TradeDetail, error) +} + +type TradeDetail struct { + Type m.TradeType `json:"type"` + Subject string `json:"subject"` + Amount decimal.Decimal `json:"amount"` + Actual decimal.Decimal `json:"actual"` + CouponId *int32 `json:"coupon_id,omitempty"` + Product ProductInfo `json:"product"` } type CompleteEvent interface { diff --git a/web/services/user.go b/web/services/user.go index 6e7a2dd..111c872 100644 --- a/web/services/user.go +++ b/web/services/user.go @@ -1,10 +1,8 @@ package services import ( - "encoding/json" "fmt" "platform/web/core" - g "platform/web/globals" m "platform/web/models" q "platform/web/queries" @@ -15,48 +13,29 @@ var User = &userService{} type userService struct{} -func (s *userService) UpdateBalanceByTrade(uid int32, info *RechargeProductInfo, trade *m.Trade) (err error) { - err = g.Redsync.WithLock(userBalanceKey(uid), func() error { - return q.Q.Transaction(func(q *q.Query) error { - - err = updateBalance(q, uid, info) - if err != nil { - return err - } - - // 生成账单 - subject := info.GetSubject() - amount := info.GetAmount() - err = q.Bill.Create(newForRecharge(uid, Bill.GenNo(), subject, amount, trade)) - if err != nil { - return core.NewServErr("生成账单失败", err) - } - - return nil - }) - }) - if err != nil { - return core.NewServErr("更新用户余额失败") - } - - return nil -} -func updateBalance(q *q.Query, uid int32, info *RechargeProductInfo) error { +func (s *userService) Get(q *q.Query, uid int32) (*m.User, error) { user, err := q.User. Where(q.User.ID.Eq(uid)).Take() if err != nil { - return core.NewServErr("查询用户失败", err) + return nil, core.NewServErr("查询用户失败", err) } + return user, nil +} - amount := info.GetAmount() +func (s *userService) UpdateBalance(q *q.Query, user *m.User, amount decimal.Decimal) error { balance := user.Balance.Add(amount) if balance.IsNegative() { return core.NewServErr("用户余额不足") } - _, err = q.User. - Where(q.User.ID.Eq(user.ID)). - UpdateSimple(q.User.Balance.Value(balance)) + _, err := q.User. + Where( + q.User.ID.Eq(user.ID), + q.User.Balance.Eq(user.Balance), + ). + UpdateSimple( + q.User.Balance.Value(balance), + ) if err != nil { return core.NewServErr("更新用户余额失败", err) } @@ -68,40 +47,16 @@ func userBalanceKey(uid int32) string { return fmt.Sprintf("user:%d:balance", uid) } -type RechargeProductInfo struct { +type UpdateBalanceData struct { Amount int `json:"amount"` } -func (r *RechargeProductInfo) GetType() m.TradeType { - return m.TradeTypeRecharge -} - -func (r *RechargeProductInfo) GetSubject() string { - return fmt.Sprintf("账户充值 - %s元", r.GetAmount().StringFixed(2)) -} - -func (r *RechargeProductInfo) GetAmount() decimal.Decimal { - return decimal.NewFromInt(int64(r.Amount)).Div(decimal.NewFromInt(100)) -} - -func (r *RechargeProductInfo) Serialize() (string, error) { - bytes, err := json.Marshal(r) - return string(bytes), err -} - -func (r *RechargeProductInfo) Deserialize(str string) error { - return json.Unmarshal([]byte(str), r) -} - -type UserOnTradeComplete struct{} - -func (u UserOnTradeComplete) Check(t m.TradeType) (ProductInfo, bool) { - if t == m.TradeTypeRecharge { - return &RechargeProductInfo{}, true - } - return nil, false -} - -func (u UserOnTradeComplete) OnTradeComplete(info ProductInfo, trade *m.Trade) error { - return User.UpdateBalanceByTrade(trade.UserID, info.(*RechargeProductInfo), trade) +func (c *UpdateBalanceData) TradeDetail() (*TradeDetail, error) { + amount := decimal.NewFromInt(int64(c.Amount)).Div(decimal.NewFromInt(100)) + return &TradeDetail{ + m.TradeTypeRecharge, + fmt.Sprintf("账户充值 - %s元", amount.StringFixed(2)), + amount, amount, + nil, c, + }, nil } diff --git a/web/tasks/task.go b/web/tasks/task.go index 8424183..55ef830 100644 --- a/web/tasks/task.go +++ b/web/tasks/task.go @@ -6,29 +6,34 @@ import ( "fmt" "log/slog" "platform/web/events" + q "platform/web/queries" s "platform/web/services" - "time" "github.com/hibiken/asynq" ) -func HandleCompleteTrade(_ context.Context, task *asynq.Task) (err error) { - event := new(events.CompleteTradeData) - err = json.Unmarshal(task.Payload(), event) - if err != nil { +func HandleCompleteTrade(_ context.Context, task *asynq.Task) error { + var event events.CloseTradeData + if err := json.Unmarshal(task.Payload(), &event); err != nil { return fmt.Errorf("解析任务参数失败: %w", err) } - data := &s.ModifyTradeData{ + data := s.TradeRef{ TradeNo: event.TradeNo, Method: event.Method, } - err = s.Trade.CompleteTrade(data) + // 尝试完成交易 + user, err := s.User.Get(q.Q, event.UserId) if err != nil { + return fmt.Errorf("获取用户失败: %w", err) + } + + if err := s.Trade.CompleteTrade(user, &data); err != nil { slog.Debug("完成交易失败[异步结束订单]", "err", err) - err = s.Trade.CancelTrade(data, time.Now()) - if err != nil { + + // 交易无法完成,关闭交易 + if err := s.Trade.CancelTrade(&data); err != nil { return fmt.Errorf("取消交易失败[异步结束订单]: %w", err) } } diff --git a/web/web.go b/web/web.go index ec22b99..ae5e512 100644 --- a/web/web.go +++ b/web/web.go @@ -89,7 +89,7 @@ func RunTask(ctx context.Context) error { var mux = asynq.NewServeMux() mux.HandleFunc(events.RemoveChannel, tasks.HandleRemoveChannel) - mux.HandleFunc(events.CompleteTrade, tasks.HandleCompleteTrade) + mux.HandleFunc(events.CloseTrade, tasks.HandleCompleteTrade) // 停止服务 go func() {