重构项目结构,将 orm 和 rds 包迁移到 web/globals
This commit is contained in:
10
pkg/env/env.go
vendored
10
pkg/env/env.go
vendored
@@ -268,7 +268,6 @@ var (
|
||||
WechatPayPublicKey string
|
||||
WechatPayApiCert string
|
||||
WechatPayCallbackUrl string
|
||||
WechatPayProduction = false
|
||||
)
|
||||
|
||||
func loadWechatPay() {
|
||||
@@ -312,15 +311,6 @@ func loadWechatPay() {
|
||||
if WechatPayCallbackUrl == "" {
|
||||
panic("环境变量 WECHATPAY_CALLBACK_URL 的值不能为空")
|
||||
}
|
||||
|
||||
_WechatPayProduction := os.Getenv("WECHATPAY_PRODUCTION")
|
||||
if _WechatPayProduction != "" {
|
||||
value, err := strconv.ParseBool(_WechatPayProduction)
|
||||
if err != nil {
|
||||
panic("环境变量 WECHATPAY_PRODUCTION 的值不是布尔值")
|
||||
}
|
||||
WechatPayProduction = value
|
||||
}
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
)
|
||||
|
||||
type LocalDateTime time.Time
|
||||
|
||||
var formats = []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02T15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02 15:04",
|
||||
"2006-01-02T15:04",
|
||||
"2006-01-02",
|
||||
}
|
||||
|
||||
//goland:noinspection GoMixedReceiverTypes
|
||||
func (ldt *LocalDateTime) Scan(value interface{}) (err error) {
|
||||
var t time.Time
|
||||
|
||||
if strValue, ok := value.(string); ok {
|
||||
var timeValue time.Time
|
||||
for _, format := range formats {
|
||||
timeValue, err = time.Parse(format, strValue)
|
||||
if err == nil {
|
||||
t = timeValue
|
||||
break
|
||||
}
|
||||
}
|
||||
t = timeValue
|
||||
} else {
|
||||
nullTime := &sql.NullTime{}
|
||||
err = nullTime.Scan(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nullTime == nil {
|
||||
return nil
|
||||
}
|
||||
t = nullTime.Time
|
||||
}
|
||||
*ldt = LocalDateTime(time.Date(
|
||||
t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.Local,
|
||||
))
|
||||
return
|
||||
}
|
||||
|
||||
//goland:noinspection GoMixedReceiverTypes
|
||||
func (ldt LocalDateTime) Value() (driver.Value, error) {
|
||||
return time.Time(ldt).Local(), nil
|
||||
}
|
||||
|
||||
// GormDataType gorm common data type
|
||||
//
|
||||
//goland:noinspection GoMixedReceiverTy
|
||||
//goland:noinspection GoMixedReceiverTypes
|
||||
func (ldt LocalDateTime) GormDataType() string {
|
||||
return "ldt"
|
||||
}
|
||||
|
||||
//goland:noinspection GoMixedReceiverTypes
|
||||
func (ldt LocalDateTime) GobEncode() ([]byte, error) {
|
||||
return time.Time(ldt).GobEncode()
|
||||
}
|
||||
|
||||
//goland:noinspection GoMixedReceiverTypes
|
||||
func (ldt *LocalDateTime) GobDecode(b []byte) error {
|
||||
return (*time.Time)(ldt).GobDecode(b)
|
||||
}
|
||||
|
||||
//goland:noinspection GoMixedReceiverTypes
|
||||
func (ldt LocalDateTime) MarshalJSON() ([]byte, error) {
|
||||
return time.Time(ldt).MarshalJSON()
|
||||
}
|
||||
|
||||
//goland:noinspection GoMixedReceiverTypes
|
||||
func (ldt *LocalDateTime) UnmarshalJSON(b []byte) error {
|
||||
return (*time.Time)(ldt).UnmarshalJSON(b)
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package orm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gen"
|
||||
"gorm.io/gen/field"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
"log/slog"
|
||||
"platform/pkg/env"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
func Init() {
|
||||
|
||||
// 连接数据库
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s user=%s password=%s dbname=%s port=%s sslmode=disable TimeZone=Asia/Shanghai",
|
||||
env.DbHost, env.DbUserName, env.DbPassword, env.DbName, env.DbPort,
|
||||
)
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
TranslateError: true,
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
SingularTable: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("gorm 初始化数据库失败:", slog.Any("err", err))
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// 连接池
|
||||
conn, err := db.DB()
|
||||
if err != nil {
|
||||
slog.Error("gorm 初始化数据库失败:", slog.Any("err", err))
|
||||
panic(err)
|
||||
}
|
||||
conn.SetMaxIdleConns(10)
|
||||
conn.SetMaxOpenConns(100)
|
||||
|
||||
DB = db
|
||||
}
|
||||
|
||||
func Exit() error {
|
||||
if DB != nil {
|
||||
conn, err := DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return conn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type WithAlias interface {
|
||||
Alias() string
|
||||
}
|
||||
|
||||
func Alias(model WithAlias) func(db gen.Dao) gen.Dao {
|
||||
return func(db gen.Dao) gen.Dao {
|
||||
return db.Unscoped().Where(field.NewBool(model.Alias(), "deleted_at").IsNull())
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package rds
|
||||
|
||||
import (
|
||||
"net"
|
||||
"platform/pkg/env"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
var Client *redis.Client
|
||||
|
||||
func Init() {
|
||||
Client = redis.NewClient(&redis.Options{
|
||||
Addr: net.JoinHostPort(env.RedisHost, env.RedisPort),
|
||||
DB: env.RedisDb,
|
||||
Password: env.RedisPass,
|
||||
})
|
||||
}
|
||||
|
||||
func Exit() error {
|
||||
if Client != nil {
|
||||
return Client.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"platform/pkg/orm"
|
||||
"platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"testing"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SetupDBTest 创建一个基于 SQLite 内存数据库的 GORM 连接
|
||||
func SetupDBTest(t *testing.T) *gorm.DB {
|
||||
// 使用 SQLite 内存数据库
|
||||
gormDB, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("gorm 打开 SQLite 内存数据库失败: %v", err)
|
||||
}
|
||||
|
||||
// 自动迁移数据表结构
|
||||
err = gormDB.AutoMigrate(
|
||||
&models.User{},
|
||||
&models.Whitelist{},
|
||||
&models.Resource{},
|
||||
&models.ResourcePss{},
|
||||
&models.Proxy{},
|
||||
&models.Channel{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("自动迁移表结构失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置全局数据库连接
|
||||
q.SetDefault(gormDB)
|
||||
orm.DB = gormDB
|
||||
|
||||
return gormDB
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"platform/pkg/rds"
|
||||
"testing"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// SetupRedisTest 创建一个测试用的Redis实例
|
||||
// 返回miniredis实例,使用t.Cleanup自动清理资源
|
||||
func SetupRedisTest(t *testing.T) *miniredis.Miniredis {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("设置 miniredis 失败: %v", err)
|
||||
}
|
||||
|
||||
// 替换 Redis 客户端为测试客户端
|
||||
rds.Client = redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
|
||||
// 使用t.Cleanup确保测试结束后恢复原始客户端并关闭miniredis
|
||||
t.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
return mr
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
g "platform/web/globals"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// MockCloudClient 是CloudClient接口的测试实现
|
||||
type MockCloudClient struct {
|
||||
// 存储预期结果的字段
|
||||
EdgesMock func(param g.CloudEdgesReq) (*g.CloudEdgesResp, error)
|
||||
ConnectMock func(param g.CloudConnectReq) error
|
||||
DisconnectMock func(param g.CloudDisconnectReq) (int, error)
|
||||
AutoQueryMock func() (g.CloudConnectResp, error)
|
||||
|
||||
// 记录调用历史
|
||||
EdgesCalls []g.CloudEdgesReq
|
||||
ConnectCalls []g.CloudConnectReq
|
||||
DisconnectCalls []g.CloudDisconnectReq
|
||||
AutoQueryCalls int
|
||||
|
||||
// 用于并发安全
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// 确保MockCloudClient实现了CloudClient接口
|
||||
var _ g.CloudClient = (*MockCloudClient)(nil)
|
||||
|
||||
func (m *MockCloudClient) CloudEdges(param g.CloudEdgesReq) (*g.CloudEdgesResp, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.EdgesCalls = append(m.EdgesCalls, param)
|
||||
if m.EdgesMock != nil {
|
||||
return m.EdgesMock(param)
|
||||
}
|
||||
return &g.CloudEdgesResp{}, nil
|
||||
}
|
||||
|
||||
func (m *MockCloudClient) CloudConnect(param g.CloudConnectReq) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.ConnectCalls = append(m.ConnectCalls, param)
|
||||
if m.ConnectMock != nil {
|
||||
return m.ConnectMock(param)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockCloudClient) CloudDisconnect(param g.CloudDisconnectReq) (int, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.DisconnectCalls = append(m.DisconnectCalls, param)
|
||||
if m.DisconnectMock != nil {
|
||||
return m.DisconnectMock(param)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockCloudClient) CloudAutoQuery() (g.CloudConnectResp, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.AutoQueryCalls++
|
||||
if m.AutoQueryMock != nil {
|
||||
return m.AutoQueryMock()
|
||||
}
|
||||
return g.CloudConnectResp{}, nil
|
||||
}
|
||||
|
||||
// SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复
|
||||
func SetupCloudClientMock(t *testing.T) *MockCloudClient {
|
||||
mock := &MockCloudClient{
|
||||
EdgesMock: func(param g.CloudEdgesReq) (*g.CloudEdgesResp, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
ConnectMock: func(param g.CloudConnectReq) error {
|
||||
panic("not implemented")
|
||||
},
|
||||
DisconnectMock: func(param g.CloudDisconnectReq) (int, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
AutoQueryMock: func() (g.CloudConnectResp, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
g.Cloud = mock
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// MockGatewayClient 是GatewayClient接口的测试实现
|
||||
type MockGatewayClient struct {
|
||||
Host string
|
||||
}
|
||||
|
||||
// 确保MockGatewayClient实现了GatewayClient接口
|
||||
var _ g.GatewayClient = (*MockGatewayClient)(nil)
|
||||
|
||||
func (m *MockGatewayClient) GatewayPortConfigs(params []g.PortConfigsReq) error {
|
||||
testGatewayBase.mu.Lock()
|
||||
defer testGatewayBase.mu.Unlock()
|
||||
testGatewayBase.PortConfigsCalls = append(testGatewayBase.PortConfigsCalls, params)
|
||||
if testGatewayBase.PortConfigsMock != nil {
|
||||
return testGatewayBase.PortConfigsMock(m, params)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockGatewayClient) GatewayPortActive(param ...g.PortActiveReq) (map[string]g.PortData, error) {
|
||||
testGatewayBase.mu.Lock()
|
||||
defer testGatewayBase.mu.Unlock()
|
||||
testGatewayBase.PortActiveCalls = append(testGatewayBase.PortActiveCalls, param)
|
||||
if testGatewayBase.PortActiveMock != nil {
|
||||
return testGatewayBase.PortActiveMock(m, param...)
|
||||
}
|
||||
return map[string]g.PortData{}, nil
|
||||
}
|
||||
|
||||
type GatewayClientIns struct {
|
||||
|
||||
// 存储预期结果的字段
|
||||
PortConfigsMock func(c *MockGatewayClient, params []g.PortConfigsReq) error
|
||||
PortActiveMock func(c *MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error)
|
||||
|
||||
// 记录调用历史
|
||||
PortConfigsCalls [][]g.PortConfigsReq
|
||||
PortActiveCalls [][]g.PortActiveReq
|
||||
|
||||
// 用于并发安全
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var testGatewayBase = &GatewayClientIns{
|
||||
PortConfigsMock: func(c *MockGatewayClient, params []g.PortConfigsReq) error {
|
||||
panic("not implemented")
|
||||
},
|
||||
PortActiveMock: func(c *MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
// SetupGatewayClientMock 创建一个MockGatewayClient并提供替换函数
|
||||
func SetupGatewayClientMock(t *testing.T) *GatewayClientIns {
|
||||
g.GatewayInitializer = func(url, username, password string) g.GatewayClient {
|
||||
return &MockGatewayClient{
|
||||
Host: url,
|
||||
}
|
||||
}
|
||||
return testGatewayBase
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// SliceEqual 检查两个字符串切片是否完全相等(忽略顺序)
|
||||
func SliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 复制切片以避免修改原始数据
|
||||
aCopy := make([]string, len(a))
|
||||
bCopy := make([]string, len(b))
|
||||
copy(aCopy, a)
|
||||
copy(bCopy, b)
|
||||
|
||||
// 排序两个切片
|
||||
sort.Strings(aCopy)
|
||||
sort.Strings(bCopy)
|
||||
|
||||
// 比较排序后的切片
|
||||
return reflect.DeepEqual(aCopy, bCopy)
|
||||
}
|
||||
Reference in New Issue
Block a user