diff --git a/go.mod b/go.mod index 81cc11a..7ed15e8 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,6 @@ go 1.24.0 require ( github.com/alibabacloud-go/darabonba-openapi/v2 v2.1.7 github.com/alibabacloud-go/dysmsapi-20170525/v4 v4.1.3 - github.com/alicebob/miniredis/v2 v2.34.0 - github.com/glebarez/sqlite v1.11.0 github.com/go-playground/locales v0.14.1 github.com/go-playground/universal-translator v0.18.1 github.com/go-playground/validator/v10 v10.26.0 @@ -35,15 +33,12 @@ require ( github.com/alibabacloud-go/openapi-util v0.1.1 // indirect github.com/alibabacloud-go/tea v1.3.8 // indirect github.com/alibabacloud-go/tea-utils/v2 v2.0.7 // indirect - github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/aliyun/credentials-go v1.4.5 // indirect github.com/andybalholm/brotli v1.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/clbanning/mxj/v2 v2.7.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dustin/go-humanize v1.0.1 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect - github.com/glebarez/go-sqlite v1.21.2 // indirect github.com/go-sql-driver/mysql v1.9.1 // indirect github.com/gofrs/uuid v4.4.0+incompatible // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -61,14 +56,12 @@ require ( github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect - github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/smartwalle/ncrypto v1.0.4 // indirect github.com/smartwalle/ngx v1.0.9 // indirect github.com/smartwalle/nsign v1.0.9 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/yuin/gopher-lua v1.1.1 // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/net v0.37.0 // indirect golang.org/x/sync v0.12.0 // indirect @@ -80,8 +73,4 @@ require ( gorm.io/driver/mysql v1.5.7 // indirect gorm.io/driver/sqlite v1.5.7 // indirect gorm.io/hints v1.1.2 // indirect - modernc.org/libc v1.22.5 // indirect - modernc.org/mathutil v1.5.0 // indirect - modernc.org/memory v1.5.0 // indirect - modernc.org/sqlite v1.23.1 // indirect ) diff --git a/go.sum b/go.sum index 5015d9a..8c2551c 100644 --- a/go.sum +++ b/go.sum @@ -48,10 +48,6 @@ github.com/alibabacloud-go/tea-utils/v2 v2.0.6/go.mod h1:qxn986l+q33J5VkialKMqT/ github.com/alibabacloud-go/tea-utils/v2 v2.0.7 h1:WDx5qW3Xa5ZgJ1c8NfqJkF6w+AU5wB8835UdhPr6Ax0= github.com/alibabacloud-go/tea-utils/v2 v2.0.7/go.mod h1:qxn986l+q33J5VkialKMqT/TTs3E+U9MJpd001iWQ9I= github.com/alibabacloud-go/tea-xml v1.1.3/go.mod h1:Rq08vgCcCAjHyRi/M7xlHKUykZCEtyBy9+DPF6GgEu8= -github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= -github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= -github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0= -github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8= github.com/aliyun/credentials-go v1.1.2/go.mod h1:ozcZaMR5kLM7pwtCMEpVmQ242suV6qTJya2bDq4X1Tw= github.com/aliyun/credentials-go v1.3.1/go.mod h1:8jKYhQuDawt8x2+fusqa1Y6mPxemTsBEN04dgcAcYz0= github.com/aliyun/credentials-go v1.3.6/go.mod h1:1LxUuX7L5YrZUWzBrRyk0SwSdH4OmPrib8NVePL3fxM= @@ -76,17 +72,11 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= -github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= -github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= -github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= -github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= -github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -124,8 +114,6 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= @@ -185,9 +173,6 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= -github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= @@ -228,8 +213,6 @@ github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3i github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.30/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= -github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191219195013-becbf705a915/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -405,11 +388,3 @@ gorm.io/plugin/dbresolver v1.5.3 h1:wFwINGZZmttuu9h7XpvbDHd8Lf9bb8GNzp/NpAMV2wU= gorm.io/plugin/dbresolver v1.5.3/go.mod h1:TSrVhaUg2DZAWP3PrHlDlITEJmNOkL0tFTjvTEsQ4XE= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE= -modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY= -modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= -modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= -modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= -modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= -modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM= -modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk= diff --git a/web/handlers/whitelist_test.go b/web/handlers/whitelist_test.go deleted file mode 100644 index 01a45cd..0000000 --- a/web/handlers/whitelist_test.go +++ /dev/null @@ -1,127 +0,0 @@ -package handlers - -import "testing" - -func Test_secureAddr(t *testing.T) { - type args struct { - str string - } - tests := []struct { - name string - args args - wantErr bool - }{ - // 有效的公网 IP 地址 - { - name: "有效公网IPv4地址", - args: args{str: "203.0.113.1"}, - wantErr: false, - }, - { - name: "有效公网IPv6地址", - args: args{str: "2001:db8::1"}, - wantErr: false, - }, - - // 私有地址 - { - name: "IPv4私有地址(10.x.x.x)", - args: args{str: "10.0.0.1"}, - wantErr: false, // 取决于需求,通常私有地址是被允许的全局单播地址 - }, - { - name: "IPv4私有地址(172.16.x.x)", - args: args{str: "172.16.0.1"}, - wantErr: false, - }, - { - name: "IPv4私有地址(192.168.x.x)", - args: args{str: "192.168.0.1"}, - wantErr: false, - }, - { - name: "IPv6私有地址(ULA)", - args: args{str: "fd00::1"}, - wantErr: false, - }, - - // 广播地址 - { - name: "IPv4本地广播地址", - args: args{str: "255.255.255.255"}, - wantErr: true, - }, - - // 未指定地址 - { - name: "IPv4未指定地址", - args: args{str: "0.0.0.0"}, - wantErr: true, - }, - { - name: "IPv6未指定地址", - args: args{str: "::"}, - wantErr: true, - }, - - // 回环地址 - { - name: "IPv4回环地址", - args: args{str: "127.0.0.1"}, - wantErr: true, - }, - { - name: "IPv6回环地址", - args: args{str: "::1"}, - wantErr: true, - }, - - // 组播地址 - { - name: "IPv4组播地址", - args: args{str: "224.0.0.1"}, - wantErr: true, - }, - { - name: "IPv6组播地址", - args: args{str: "ff00::1"}, - wantErr: true, - }, - - // 链路本地地址 - { - name: "IPv4链路本地地址", - args: args{str: "169.254.0.1"}, - wantErr: true, - }, - { - name: "IPv6链路本地地址", - args: args{str: "fe80::1"}, - wantErr: true, - }, - - // 格式错误的地址 - { - name: "格式错误的IP地址", - args: args{str: "not-an-ip"}, - wantErr: true, - }, - { - name: "不完整的IP地址", - args: args{str: "192.168.0"}, - wantErr: true, - }, - { - name: "超出范围的IP地址", - args: args{str: "256.256.256.256"}, - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := secureAddr(tt.args.str); (err != nil) != tt.wantErr { - t.Errorf("secureAddr() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} diff --git a/web/services/auth_test.go b/web/services/auth_test.go deleted file mode 100644 index fb0e1bb..0000000 --- a/web/services/auth_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package services - -import ( - "context" - "platform/web/auth" - "platform/web/models" - "reflect" - "testing" - "time" -) - -// mockSessionService 用于模拟Session服务的行为 -type mockSessionService struct { - createFunc func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error) -} - -func (m *mockSessionService) Find(ctx context.Context, token string) (*auth.Context, error) { - panic("implement me") -} -func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) { - panic("implement me") -} -func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error { - panic("implement me") -} -func (m *mockSessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) { - return m.createFunc(ctx, authCtx) -} - -func Test_authService_OauthClientCredentials(t *testing.T) { - // 暂存原始Session服务 - originalSession := Session - defer func() { - // 测试结束后恢复原始Session服务 - Session = originalSession - }() - - // 预设的令牌详情 - expectedToken := &TokenDetails{ - AccessToken: "test-access-token", - RefreshToken: "test-refresh-token", - AccessTokenExpires: time.Now().Add(3600 * time.Second), - } - - type args struct { - ctx context.Context - client *models.Client - scope []string - } - tests := []struct { - name string - args args - mockCreateErr error - want *TokenDetails - wantErr bool - wantPayload auth.Payload - }{ - { - name: "成功 - 机密客户端 (Spec=0)", - args: args{ - ctx: context.Background(), - client: &models.Client{ID: 1, Spec: 3}, - scope: []string{"read", "write"}, - }, - mockCreateErr: nil, - want: expectedToken, - wantErr: false, - wantPayload: auth.Payload{ - Type: auth.PayloadSecuredServer, - Id: 1, - }, - }, - { - name: "成功 - 公共客户端 (Spec=1)", - args: args{ - ctx: context.Background(), - client: &models.Client{ID: 1, Spec: 1}, - scope: []string{"read", "write"}, - }, - mockCreateErr: nil, - want: expectedToken, - wantErr: false, - wantPayload: auth.Payload{ - Type: auth.PayloadPublicServer, - Id: 1, - }, - }, - { - name: "成功 - 公共客户端 (Spec=2)", - args: args{ - ctx: context.Background(), - client: &models.Client{ID: 1, Spec: 2}, - scope: []string{"read", "write"}, - }, - mockCreateErr: nil, - want: expectedToken, - wantErr: false, - wantPayload: auth.Payload{ - Type: auth.PayloadPublicServer, - Id: 1, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - // 为每个测试用例设置模拟的Session服务 - mockSession := &mockSessionService{ - createFunc: func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error) { - // 验证权限映射 - if len(authCtx.Permissions) != len(tt.args.scope) { - t.Errorf("Permissions length = %v, want %v", len(authCtx.Permissions), len(tt.args.scope)) - for key := range authCtx.Permissions { - if _, ok := authCtx.Permissions[key]; !ok { - t.Errorf("Permissions[%s] not found", key) - } - } - } - - // 验证Payload - if authCtx.Payload.Type != tt.wantPayload.Type { - t.Errorf("Payload.Type = %v, want %v", authCtx.Payload.Type, tt.wantPayload.Type) - } - if authCtx.Payload.Id != tt.wantPayload.Id { - t.Errorf("Payload.Id = %v, want %v", authCtx.Payload.Id, tt.wantPayload.Id) - } - - return expectedToken, tt.mockCreateErr - }, - } - - // 替换Session服务为模拟实现 - Session = mockSession - - s := &authService{} - got, err := s.OauthClientCredentials(tt.args.ctx, tt.args.client, tt.args.scope...) - if (err != nil) != tt.wantErr { - t.Errorf("OauthClientCredentials() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("OauthClientCredentials() got = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/web/services/channel_test.go b/web/services/channel_test.go deleted file mode 100644 index cb0a895..0000000 --- a/web/services/channel_test.go +++ /dev/null @@ -1,1323 +0,0 @@ -package services - -import ( - "context" - "encoding/json" - "fmt" - "platform/web/auth" - g "platform/web/globals" - "platform/web/globals/orm" - "platform/web/models" - testutil2 "platform/web/testutil" - "reflect" - "strings" - "testing" - "time" - - "github.com/gofiber/fiber/v2/middleware/requestid" -) - -func Test_genPassPair(t *testing.T) { - tests := []struct { - name string - }{ - { - name: "正常生成随机用户名和密码", - }, - { - name: "多次调用生成不同的值", - }, - } - - // 第一个测试:检查生成的用户名和密码是否有效 - t.Run(tests[0].name, func(t *testing.T) { - username, password := genPassPair() - if username == "" { - t.Errorf("genPassPair() username is empty") - } - if password == "" { - t.Errorf("genPassPair() password is empty") - } - }) - - // 第二个测试:确保多次调用生成不同的值 - t.Run(tests[1].name, func(t *testing.T) { - username1, password1 := genPassPair() - username2, password2 := genPassPair() - - if username1 == username2 { - t.Errorf("genPassPair() generated the same username twice: %v", username1) - } - if password1 == password2 { - t.Errorf("genPassPair() generated the same password twice: %v", password1) - } - }) -} - -func Test_chKey(t *testing.T) { - type args struct { - channel *models.Channel - } - tests := []struct { - name string - args args - want string - }{ - { - name: "ID为1的通道", - args: args{ - channel: &models.Channel{ - ID: 1, - }, - }, - want: "channel:1", - }, - { - name: "ID为100的通道", - args: args{ - channel: &models.Channel{ - ID: 100, - }, - }, - want: "channel:100", - }, - { - name: "ID为0的通道", - args: args{ - channel: &models.Channel{ - ID: 0, - }, - }, - want: "channel:0", - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := chKey(tt.args.channel); got != tt.want { - t.Errorf("chKey() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_cache(t *testing.T) { - mr := testutil2.SetupRedisTest(t) - - type args struct { - ctx context.Context - channels []*models.Channel - } - - // 准备测试数据 - now := time.Now() - expiration := orm.LocalDateTime(now.Add(24 * time.Hour)) - - testChannels := []*models.Channel{ - { - ID: 1, - UserID: 100, - ProxyID: 10, - ProxyPort: 8080, - Protocol: 1, - Expiration: expiration, - }, - { - ID: 2, - UserID: 101, - ProxyID: 11, - ProxyPort: 8081, - Protocol: 3, - Expiration: expiration, - }, - } - - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "正常缓存多个通道", - args: args{ - ctx: context.Background(), - channels: testChannels, - }, - wantErr: false, - }, - { - name: "空通道列表", - args: args{ - ctx: context.Background(), - channels: []*models.Channel{}, - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mr.FlushAll() // 清空 Redis 数据 - - if err := cache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr { - t.Errorf("cache() error = %v, wantErr %v", err, tt.wantErr) - return - } - - // 验证缓存结果 - if len(tt.args.channels) > 0 { - for _, channel := range tt.args.channels { - key := fmt.Sprintf("channel:%d", channel.ID) - if !mr.Exists(key) { - t.Errorf("缓存未包含通道键 %s", key) - } else { - // 验证缓存的数据是否正确 - data, _ := mr.Get(key) - var cachedChannel models.Channel - err := json.Unmarshal([]byte(data), &cachedChannel) - if err != nil { - t.Errorf("无法解析缓存数据: %v", err) - } - if cachedChannel.ID != channel.ID { - t.Errorf("缓存数据不匹配: 期望 ID %d, 得到 %d", channel.ID, cachedChannel.ID) - } - } - } - - // 验证是否设置了过期时间 - for _, channel := range tt.args.channels { - key := fmt.Sprintf("channel:%d", channel.ID) - ttl := mr.TTL(key) - if ttl <= 0 { - t.Errorf("键 %s 没有设置过期时间", key) - } - } - - // 验证是否添加了有序集合 - if !mr.Exists("tasks:channel") { - t.Errorf("ZAdd未创建有序集合 tasks:channel") - } - } - }) - } -} - -func Test_deleteCache(t *testing.T) { - mr := testutil2.SetupRedisTest(t) - - type args struct { - ctx context.Context - channels []*models.Channel - } - - // 准备测试数据 - testChannels := []*models.Channel{ - {ID: 1}, - {ID: 2}, - {ID: 3}, - } - - ctx := context.Background() - - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "正常删除多个通道缓存", - args: args{ - ctx: ctx, - channels: testChannels, - }, - wantErr: false, - }, - { - name: "空通道列表", - args: args{ - ctx: ctx, - channels: []*models.Channel{}, - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mr.FlushAll() // 清空 Redis 数据 - - // 预先设置缓存数据 - for _, channel := range testChannels { - key := fmt.Sprintf("channel:%d", channel.ID) - data, _ := json.Marshal(channel) - mr.Set(key, string(data)) - mr.SetTTL(key, 1*time.Hour) // 设置1小时的过期时间 - } - - if err := deleteCache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr { - t.Errorf("deleteCache() error = %v, wantErr %v", err, tt.wantErr) - return - } - - // 验证删除结果 - for _, channel := range tt.args.channels { - key := fmt.Sprintf("channel:%d", channel.ID) - if mr.Exists(key) { - t.Errorf("通道键 %s 未被删除", key) - } - } - }) - } -} - -func Test_channelService_CreateChannel(t *testing.T) { - mr := testutil2.SetupRedisTest(t) - db := testutil2.SetupDBTest(t) - mc := testutil2.SetupCloudClientMock(t) - mg := testutil2.SetupGatewayClientMock(t) - - type args struct { - ctx context.Context - auth *auth.Context - resourceId int32 - protocol ChannelProtocol - authType ChannelAuthType - count int - nodeFilter []NodeFilterConfig - } - - // 准备测试数据 - ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id") - var adminAuth = &auth.Context{Payload: auth.Payload{Id: 100, Type: auth.PayloadAdmin}} - var userAuth = &auth.Context{Payload: auth.Payload{Id: 101, Type: auth.PayloadUser}} - mc.AutoQueryMock = func() (g.CloudConnectResp, error) { - return g.CloudConnectResp{ - "test-proxy": []g.AutoConfig{ - {Province: "河南省", Count: 10}, - }, - }, nil - } - - var user *models.User - var whitelists []*models.Whitelist - var proxy *models.Proxy - var resource *models.Resource - var resourcePss *models.ResourcePss - var resetDb = func() { - user = &models.User{ - ID: 101, - Phone: "12312341234", - } - db.Exec("delete from user where true") - db.Create(user) - - whitelists = []*models.Whitelist{ - {ID: 1, UserID: 101, Host: "123.123.123.123"}, - {ID: 2, UserID: 101, Host: "456.456.456.456"}, - {ID: 3, UserID: 101, Host: "789.789.789.789"}, - } - db.Exec("delete from whitelist where true") - db.Create(whitelists) - - proxy = &models.Proxy{ - ID: 1, - Version: 1, - Name: "test-proxy", - Host: "111.111.111.111", - Type: 1, - Secret: "test:secret", - } - db.Exec("delete from proxy where true") - db.Create(proxy) - - resource = &models.Resource{ - ID: 1, - UserID: 101, - Active: true, - } - db.Exec("delete from resource where true") - db.Create(resource) - - resourcePss = &models.ResourcePss{ - ID: 1, - ResourceID: 1, - Type: 1, - Live: 180, - Expire: orm.LocalDateTime(time.Now().AddDate(1, 0, 0)), - DailyLimit: 10000, - } - db.Exec("delete from resource_pss where true") - db.Create(resourcePss) - - db.Exec("delete from channel where true") - } - - tests := []struct { - name string - args args - setup func() - wantErr bool - wantErrContains string - want func(t *testing.T, got []*PortInfo) error - }{ - { - name: "用户创建HTTP密码通道", - args: args{ - ctx: ctx, - auth: userAuth, - resourceId: 1, - protocol: ProtocolHTTP, - authType: ChannelAuthTypePass, - count: 3, - nodeFilter: []NodeFilterConfig{{Prov: "北京市"}}, - }, - setup: func() { - mr.FlushAll() - resetDb() - - mc.ConnectMock = func(param g.CloudConnectReq) error { - if param.Uuid != proxy.Name { - return fmt.Errorf("代理名称不符合预期: %s", param.Uuid) - } - if len(param.Edge) != 0 { - return fmt.Errorf("边缘节点不符合预期: %v", param.Edge) - } - if !reflect.DeepEqual(param.AutoConfig, []g.AutoConfig{ - {Province: "河南省", Count: 10}, - {Province: "北京市", Count: 6}, - }) { - return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig) - } - return nil - } - - mg.PortConfigsMock = func(c *testutil2.MockGatewayClient, params []g.PortConfigsReq) error { - if c.Host != proxy.Host { - return fmt.Errorf("代理主机不符合预期: %s", c.Host) - } - if len(params) != 3 { - return fmt.Errorf("端口数量不符合预期: %d", len(params)) - } - for _, param := range params { - if param.Status != true { - return fmt.Errorf("端口状态不符合预期: %v", param.Status) - } - if param.AutoEdgeConfig == nil { - return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig) - } - if param.Userpass == nil || *param.Userpass == "" { - return fmt.Errorf("用户名密码不符合预期: %v", param.Userpass) - } - if param.Whitelist == nil || len(*param.Whitelist) != 0 { - return fmt.Errorf("白名单不符合预期: %v", param.Whitelist) - } - config := param.AutoEdgeConfig - if config.Province != "北京市" { - return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province) - } - if *config.Count != 1 { - return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count) - } - if config.PacketLoss != 30 { - return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss) - } - } - return nil - } - }, - want: func(t *testing.T, got []*PortInfo) error { - // 验证返回结果 - if len(got) != 3 { - return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3,得到 %d", len(got)) - } - - // 验证结果 - var gotMap = make(map[int]PortInfo) - for _, port := range got { - if port.Proto != 1 { - return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto) - } - if port.Host != proxy.Host { - return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host) - } - gotMap[port.Port] = *port - } - - // 验证数据库字段 - var channels []*models.Channel - db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels) - for _, ch := range channels { - if ch.Protocol != 1 { - return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol) - } - if ch.UserID != userAuth.Payload.Id { - return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID) - } - // todo 多代理分配策略,验证 proxy_host - if ch.ProxyID != proxy.ID { - return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID) - } - var info, ok = gotMap[int(ch.ProxyPort)] - if !ok { - return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort) - } - if ch.AuthPass != true && ch.AuthIP != false { - return fmt.Errorf("通道认证类型不正确,期望 Pass,得到 %v", ch.AuthPass) - } - if ch.Protocol != int32(info.Proto) { - return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol) - } - if ch.Username != *info.Username { - return fmt.Errorf("通道用户名不正确,期望 %s,得到 %s", *info.Username, ch.Username) - } - if ch.Password != *info.Password { - return fmt.Errorf("通道密码不正确,期望 %s,得到 %s", *info.Password, ch.Password) - } - if time.Time(ch.Expiration).IsZero() { - return fmt.Errorf("通道过期时间不应为空") - } - - // 检查Redis缓存中的字段 - key := fmt.Sprintf("channel:%d", ch.ID) - if !mr.Exists(key) { - return fmt.Errorf("redis缓存中应有键 %s", key) - } - - data, _ := mr.Get(key) - var cache models.Channel - err := json.Unmarshal([]byte(data), &cache) - if err != nil { - return fmt.Errorf("无法解析缓存数据: %v", err) - } - if reflect.DeepEqual(cache, *ch) { - return fmt.Errorf("缓存数据与数据库不匹配: %v", cache) - } - } - - // 检查跨天用量更新 - var pss models.ResourcePss - db.Where("resource_id = ?", 1).First(&pss) - if pss.DailyUsed != 3 { - return fmt.Errorf("套餐每日用量不正确,期望 3,得到 %d", pss.DailyUsed) - } - if time.Time(pss.DailyLast).IsZero() { - return fmt.Errorf("套餐每日最后更新时间不应为空") - } - if pss.Used != 3 { - return fmt.Errorf("套餐总用量不正确,期望 3,得到 %d", pss.Used) - } - - return nil - }, - }, - { - name: "用户创建HTTP白名单通道", - args: args{ - ctx: ctx, - auth: userAuth, - resourceId: 1, - protocol: ProtocolHTTP, - authType: ChannelAuthTypeIp, - count: 3, - nodeFilter: []NodeFilterConfig{{Prov: "北京市"}}, - }, - setup: func() { - mr.FlushAll() - resetDb() - - mc.ConnectMock = func(param g.CloudConnectReq) error { - if param.Uuid != proxy.Name { - return fmt.Errorf("代理名称不符合预期: %s", param.Uuid) - } - if len(param.Edge) != 0 { - return fmt.Errorf("边缘节点不符合预期: %v", param.Edge) - } - if !reflect.DeepEqual(param.AutoConfig, []g.AutoConfig{ - {Province: "河南省", Count: 10}, - {Province: "北京市", Count: 6}, - }) { - return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig) - } - return nil - } - - mg.PortConfigsMock = func(c *testutil2.MockGatewayClient, params []g.PortConfigsReq) error { - if c.Host != proxy.Host { - return fmt.Errorf("代理主机不符合预期: %s", c.Host) - } - if len(params) != 3 { - return fmt.Errorf("端口数量不符合预期: %d", len(params)) - } - for _, param := range params { - if param.Status != true { - return fmt.Errorf("端口状态不符合预期: %v", param.Status) - } - if param.AutoEdgeConfig == nil { - return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig) - } - if param.Userpass == nil || *param.Userpass != "" { - return fmt.Errorf("用户名密码不符合预期: %v", *param.Userpass) - } - if param.Whitelist == nil || len(*param.Whitelist) == 0 { - return fmt.Errorf("白名单不符合预期: %v", param.Whitelist) - } - config := param.AutoEdgeConfig - if config.Province != "北京市" { - return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province) - } - if *config.Count != 1 { - return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count) - } - if config.PacketLoss != 30 { - return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss) - } - } - return nil - } - }, - want: func(t *testing.T, got []*PortInfo) error { - // 验证返回结果 - if len(got) != 3 { - return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3,得到 %d", len(got)) - } - - // 验证结果 - var gotMap = make(map[int]PortInfo) - for _, port := range got { - if port.Proto != 1 { - return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto) - } - if port.Host != proxy.Host { - return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host) - } - gotMap[port.Port] = *port - } - - // 验证数据库字段 - var channels []*models.Channel - db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels) - for _, ch := range channels { - if ch.Protocol != 1 { - return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol) - } - if ch.UserID != userAuth.Payload.Id { - return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID) - } - // todo 多代理分配策略,验证 proxy_host - if ch.ProxyID != proxy.ID { - return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID) - } - var info, ok = gotMap[int(ch.ProxyPort)] - if !ok { - return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort) - } - if ch.AuthPass != false && ch.AuthIP != true { - return fmt.Errorf("通道认证类型不正确,期望 Pass,得到 %v", ch.AuthPass) - } - if ch.Protocol != int32(info.Proto) { - return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol) - } - if time.Time(ch.Expiration).IsZero() { - return fmt.Errorf("通道过期时间不应为空") - } - - // 检查Redis缓存中的字段 - key := fmt.Sprintf("channel:%d", ch.ID) - if !mr.Exists(key) { - return fmt.Errorf("redis缓存中应有键 %s", key) - } - - data, _ := mr.Get(key) - var cache models.Channel - err := json.Unmarshal([]byte(data), &cache) - if err != nil { - return fmt.Errorf("无法解析缓存数据: %v", err) - } - if reflect.DeepEqual(cache, *ch) { - return fmt.Errorf("缓存数据与数据库不匹配: %v", cache) - } - } - - // 检查跨天用量更新 - var pss models.ResourcePss - db.Where("resource_id = ?", 1).First(&pss) - if pss.DailyUsed != 3 { - return fmt.Errorf("套餐每日用量不正确,期望 3,得到 %d", pss.DailyUsed) - } - if time.Time(pss.DailyLast).IsZero() { - return fmt.Errorf("套餐每日最后更新时间不应为空") - } - if pss.Used != 3 { - return fmt.Errorf("套餐总用量不正确,期望 3,得到 %d", pss.Used) - } - - return nil - }, - }, - { - name: "管理员替用户创建HTTP密码通道", - args: args{ - ctx: ctx, - auth: adminAuth, - resourceId: 1, - protocol: ProtocolSocks5, - authType: ChannelAuthTypePass, - count: 3, - nodeFilter: []NodeFilterConfig{{Prov: "北京市"}}, - }, - setup: func() { - mr.FlushAll() - resetDb() - - mc.ConnectMock = func(param g.CloudConnectReq) error { - if param.Uuid != proxy.Name { - return fmt.Errorf("代理名称不符合预期: %s", param.Uuid) - } - if len(param.Edge) != 0 { - return fmt.Errorf("边缘节点不符合预期: %v", param.Edge) - } - if !reflect.DeepEqual(param.AutoConfig, []g.AutoConfig{ - {Province: "河南省", Count: 10}, - {Province: "北京市", Count: 6}, - }) { - return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig) - } - return nil - } - - mg.PortConfigsMock = func(c *testutil2.MockGatewayClient, params []g.PortConfigsReq) error { - if c.Host != proxy.Host { - return fmt.Errorf("代理主机不符合预期: %s", c.Host) - } - if len(params) != 3 { - return fmt.Errorf("端口数量不符合预期: %d", len(params)) - } - for _, param := range params { - if param.Status != true { - return fmt.Errorf("端口状态不符合预期: %v", param.Status) - } - if param.AutoEdgeConfig == nil { - return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig) - } - if param.Userpass == nil || *param.Userpass == "" { - return fmt.Errorf("用户名密码不符合预期: %v", param.Userpass) - } - if param.Whitelist == nil || len(*param.Whitelist) != 0 { - return fmt.Errorf("白名单不符合预期: %v", param.Whitelist) - } - config := param.AutoEdgeConfig - if config.Province != "北京市" { - return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province) - } - if *config.Count != 1 { - return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count) - } - if config.PacketLoss != 30 { - return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss) - } - } - return nil - } - }, - want: func(t *testing.T, got []*PortInfo) error { - // 验证返回结果 - if len(got) != 3 { - return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3,得到 %d", len(got)) - } - - // 验证结果 - var gotMap = make(map[int]PortInfo) - for _, port := range got { - if port.Proto != 3 { - return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto) - } - if port.Host != proxy.Host { - return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host) - } - gotMap[port.Port] = *port - } - - // 验证数据库字段 - var channels []*models.Channel - db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels) - for _, ch := range channels { - if ch.Protocol != 1 { - return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol) - } - if ch.UserID != userAuth.Payload.Id { - return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID) - } - // todo 多代理分配策略,验证 proxy_host - if ch.ProxyID != proxy.ID { - return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID) - } - var info, ok = gotMap[int(ch.ProxyPort)] - if !ok { - return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort) - } - if ch.AuthPass != true && ch.AuthIP != false { - return fmt.Errorf("通道认证类型不正确,期望 Pass,得到 %v", ch.AuthPass) - } - if ch.Protocol != int32(info.Proto) { - return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol) - } - if ch.Username != *info.Username { - return fmt.Errorf("通道用户名不正确,期望 %s,得到 %s", *info.Username, ch.Username) - } - if ch.Password != *info.Password { - return fmt.Errorf("通道密码不正确,期望 %s,得到 %s", *info.Password, ch.Password) - } - if time.Time(ch.Expiration).IsZero() { - return fmt.Errorf("通道过期时间不应为空") - } - - // 检查Redis缓存中的字段 - key := fmt.Sprintf("channel:%d", ch.ID) - if !mr.Exists(key) { - return fmt.Errorf("redis缓存中应有键 %s", key) - } - - data, _ := mr.Get(key) - var cache models.Channel - err := json.Unmarshal([]byte(data), &cache) - if err != nil { - return fmt.Errorf("无法解析缓存数据: %v", err) - } - if reflect.DeepEqual(cache, *ch) { - return fmt.Errorf("缓存数据与数据库不匹配: %v", cache) - } - } - - // 检查跨天用量更新 - var pss models.ResourcePss - db.Where("resource_id = ?", 1).First(&pss) - if pss.DailyUsed != 3 { - return fmt.Errorf("套餐每日用量不正确,期望 3,得到 %d", pss.DailyUsed) - } - if time.Time(pss.DailyLast).IsZero() { - return fmt.Errorf("套餐每日最后更新时间不应为空") - } - if pss.Used != 3 { - return fmt.Errorf("套餐总用量不正确,期望 3,得到 %d", pss.Used) - } - - return nil - }, - }, - { - name: "套餐不存在", - args: args{ - ctx: ctx, - auth: userAuth, - resourceId: 999, - protocol: ProtocolHTTP, - authType: ChannelAuthTypeIp, - count: 1, - }, - setup: func() { - mr.FlushAll() - resetDb() - }, - wantErr: true, - wantErrContains: "无权限访问", - }, - { - name: "套餐没有权限", - args: args{ - ctx: ctx, - auth: userAuth, - resourceId: 2, - protocol: ProtocolHTTP, - authType: ChannelAuthTypeIp, - count: 1, - }, - setup: func() { - mr.FlushAll() - resetDb() - - resource2 := &models.Resource{ - ID: 2, - UserID: 102, - Active: true, - } - db.Create(resource2) - var resourcePss2 = &models.ResourcePss{ - ID: 2, - ResourceID: 2, - Type: 1, - Live: 180, - Expire: orm.LocalDateTime(time.Now().AddDate(1, 0, 0)), - DailyLimit: 10000, - } - db.Create(resourcePss2) - }, - wantErr: true, - wantErrContains: "无权限访问", - }, - { - name: "套餐配额不足", - args: args{ - ctx: ctx, - auth: userAuth, - resourceId: 2, - protocol: ProtocolHTTP, - authType: ChannelAuthTypeIp, - count: 10, - }, - setup: func() { - mr.FlushAll() - resetDb() - - // 创建一个配额几乎用完的资源包 - resource2 := models.Resource{ - ID: 2, - UserID: 101, - Active: true, - } - resourcePss2 := models.ResourcePss{ - ID: 2, - ResourceID: 2, - Type: 2, - Quota: 100, - Used: 91, - Live: 180, - DailyLimit: 10000, - } - db.Create(&resource2).Create(&resourcePss2) - }, - wantErr: true, - wantErrContains: "套餐配额不足", - }, - { - name: "端口数量达到上限", - args: args{ - ctx: ctx, - auth: userAuth, - resourceId: 1, - protocol: ProtocolHTTP, - authType: ChannelAuthTypeIp, - count: 1, - }, - setup: func() { - mr.FlushAll() - resetDb() - mc.AutoQueryMock = func() (g.CloudConnectResp, error) { - return g.CloudConnectResp{ - "test-proxy": []g.AutoConfig{ - {Count: 20000}, - }, - }, nil - } - // 创建大量占用端口的通道 - var channels = make([]models.Channel, 10000) - var expr = time.Now().Add(time.Hour) - for i := range channels { - channels[i] = models.Channel{ - ProxyID: 1, - ProxyPort: int32(i + 10000), - UserID: 101, - Expiration: orm.LocalDateTime(expr), - } - } - db.CreateInBatches(channels, 1000) - }, - wantErr: true, - wantErrContains: "端口数量到达上限", - }, - // todo 多地区混杂条件提取 - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.setup != nil { - tt.setup() - } - - s := &channelService{} - got, err := s.CreateChannel(tt.args.ctx, tt.args.auth, tt.args.resourceId, tt.args.protocol, tt.args.authType, tt.args.count, tt.args.nodeFilter...) - - // 检查错误或结果 - if tt.wantErr { - if err == nil { - t.Errorf("CreateChannel() 应当返回错误") - return - } - if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) { - t.Errorf("CreateChannel() 错误 = %v, 应包含 %v", err, tt.wantErrContains) - } - return - } - - if err != nil { - t.Errorf("CreateChannel() 错误 = %v, wantErr %v", err, tt.wantErr) - return - } - - // 使用检查函数验证结果 - if err := tt.want(t, got); err != nil { - t.Errorf("结果验证失败: %v", err) - } - }) - } -} - -func Test_channelService_RemoveChannels(t *testing.T) { - mr := testutil2.SetupRedisTest(t) - md := testutil2.SetupDBTest(t) - mg := testutil2.SetupGatewayClientMock(t) - mc := testutil2.SetupCloudClientMock(t) - - type args struct { - ctx context.Context - auth *auth.Context - id []int32 - } - - // 准备测试数据 - ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id") - - // 创建用户 - var user = &models.User{ - ID: 101, - Phone: "12312341234", - } - md.Create(user) - - // 创建管理员 - var adminUser = &models.User{ - ID: 100, - Phone: "99999999999", - } - md.Create(adminUser) - - // 认证上下文 - var adminAuth = &auth.Context{Payload: auth.Payload{Id: 100, Type: auth.PayloadAdmin}} - var userAuth = &auth.Context{Payload: auth.Payload{Id: 101, Type: auth.PayloadUser}} - - // 创建代理 - var proxy = &models.Proxy{ - ID: 1, - Version: 1, - Name: "test-proxy", - Host: "111.111.111.111", - Type: 1, - Secret: "test:secret", - } - md.Create(proxy) - - var proxy2 = &models.Proxy{ - ID: 2, - Version: 1, - Name: "test-proxy-2", - Host: "222.222.222.222", - Type: 1, - Secret: "test:secret2", - } - md.Create(proxy2) - - // 清空数据库函数 - var clearDb = func() { - md.Exec("delete from channel where true") - mr.FlushAll() - } - - tests := []struct { - name string - args args - setup func() - wantErr bool - wantErrContains string - want func(t *testing.T) error - }{ - { - name: "管理员删除多个通道", - args: args{ - ctx: ctx, - auth: adminAuth, - id: []int32{1, 2, 3}, - }, - setup: func() { - mr.FlushAll() - clearDb() - - // 创建通道 - channels := []models.Channel{ - {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: orm.LocalDateTime(time.Now().Add(24 * time.Hour))}, - {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: orm.LocalDateTime(time.Now().Add(24 * time.Hour))}, - {ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: orm.LocalDateTime(time.Now().Add(24 * time.Hour))}, - } - - // 保存预设数据 - md.Create(channels) - for _, channel := range channels { - key := fmt.Sprintf("channel:%d", channel.ID) - data, _ := json.Marshal(channel) - _ = mr.Set(key, string(data)) - } - - // 模拟网关客户端的响应 - mg.PortActiveMock = func(m *testutil2.MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error) { - switch { - case m.Host == proxy.Host: - return map[string]g.PortData{ - "10001": {Edge: []string{"edge1", "edge4"}}, - "10002": {Edge: []string{"edge2"}}, - }, nil - case m.Host == proxy2.Host: - return map[string]g.PortData{ - "10001": {Edge: []string{"edge3"}}, - }, nil - } - return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host) - } - mg.PortConfigsMock = func(m *testutil2.MockGatewayClient, params []g.PortConfigsReq) error { - switch { - case m.Host == proxy.Host: - for _, param := range params { - if param.Port != 10001 && param.Port != 10002 { - return fmt.Errorf("端口配置不符合预期: %d", param.Port) - } - if param.Status != false { - return fmt.Errorf("端口状态不符合预期: %v", param.Status) - } - if param.Edge == nil || len(*param.Edge) != 0 { - return fmt.Errorf("边缘节点不符合预期1: %v", param.Edge) - } - } - return nil - case m.Host == proxy2.Host: - for _, param := range params { - if param.Port != 10001 { - return fmt.Errorf("端口配置不符合预期: %d", param.Port) - } - if param.Status != false { - return fmt.Errorf("端口状态不符合预期: %v", param.Status) - } - if param.Edge == nil || len(*param.Edge) != 0 { - return fmt.Errorf("边缘节点不符合预期2: %v", param.Edge) - } - } - return nil - } - return fmt.Errorf("代理主机不符合预期: %s", m.Host) - } - mc.DisconnectMock = func(param g.CloudDisconnectReq) (int, error) { - switch { - case param.Uuid == proxy.Name: - var edges = []string{"edge1", "edge2", "edge4"} - if !testutil2.SliceEqual(edges, param.Edge) { - return 0, fmt.Errorf("边缘节点不符合预期3: %v", param.Edge) - } - if len(param.Config) != 0 { - return 0, fmt.Errorf("配置不符合预期: %v", param.Config) - } - return len(param.Edge), nil - case param.Uuid == proxy2.Name: - var edges = []string{"edge3"} - if !testutil2.SliceEqual(edges, param.Edge) { - return 0, fmt.Errorf("边缘节点不符合预期4: %v", param.Edge) - } - if len(param.Config) != 0 { - return 0, fmt.Errorf("配置不符合预期: %v", param.Config) - } - return len(param.Edge), nil - } - return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid) - } - }, - want: func(t *testing.T) error { - // 检查通道是否被软删除 - var count int64 - md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count) - if count > 0 { - return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count) - } - - // 检查Redis缓存是否被删除 - for _, id := range []int32{1, 2, 3} { - key := fmt.Sprintf("channel:%d", id) - if mr.Exists(key) { - return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key) - } - } - return nil - }, - }, - { - name: "用户删除自己的通道", - args: args{ - ctx: ctx, - auth: userAuth, - id: []int32{1, 2, 3}, - }, - setup: func() { - mr.FlushAll() - clearDb() - - // 创建通道 - channels := []models.Channel{ - {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: orm.LocalDateTime(time.Now().Add(24 * time.Hour))}, - {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: orm.LocalDateTime(time.Now().Add(24 * time.Hour))}, - {ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: orm.LocalDateTime(time.Now().Add(24 * time.Hour))}, - } - - // 保存预设数据 - md.Create(channels) - for _, channel := range channels { - key := fmt.Sprintf("channel:%d", channel.ID) - data, _ := json.Marshal(channel) - _ = mr.Set(key, string(data)) - } - - // 模拟网关客户端的响应 - mg.PortActiveMock = func(m *testutil2.MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error) { - switch { - case m.Host == proxy.Host: - return map[string]g.PortData{ - "10001": {Edge: []string{"edge1", "edge4"}}, - "10002": {Edge: []string{"edge2"}}, - }, nil - case m.Host == proxy2.Host: - return map[string]g.PortData{ - "10001": {Edge: []string{"edge3"}}, - }, nil - } - return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host) - } - mg.PortConfigsMock = func(m *testutil2.MockGatewayClient, params []g.PortConfigsReq) error { - switch { - case m.Host == proxy.Host: - for _, param := range params { - if param.Port != 10001 && param.Port != 10002 { - return fmt.Errorf("端口配置不符合预期: %d", param.Port) - } - if param.Status != false { - return fmt.Errorf("端口状态不符合预期: %v", param.Status) - } - if param.Edge == nil || len(*param.Edge) != 0 { - return fmt.Errorf("边缘节点不符合预期5: %v", param.Edge) - } - } - return nil - case m.Host == proxy2.Host: - for _, param := range params { - if param.Port != 10001 { - return fmt.Errorf("端口配置不符合预期: %d", param.Port) - } - if param.Status != false { - return fmt.Errorf("端口状态不符合预期: %v", param.Status) - } - if param.Edge == nil || len(*param.Edge) != 0 { - return fmt.Errorf("边缘节点不符合预期6: %v", param.Edge) - } - } - return nil - } - return fmt.Errorf("代理主机不符合预期: %s", m.Host) - } - mc.DisconnectMock = func(param g.CloudDisconnectReq) (int, error) { - switch { - case param.Uuid == proxy.Name: - var edges = []string{"edge1", "edge2", "edge4"} - if !testutil2.SliceEqual(edges, param.Edge) { - return 0, fmt.Errorf("边缘节点不符合预期7: %v", param.Edge) - } - if len(param.Config) != 0 { - return 0, fmt.Errorf("配置不符合预期: %v", param.Config) - } - return len(param.Edge), nil - case param.Uuid == proxy2.Name: - var edges = []string{"edge3"} - if !testutil2.SliceEqual(edges, param.Edge) { - return 0, fmt.Errorf("边缘节点不符合预期8: %v", param.Edge) - } - if len(param.Config) != 0 { - return 0, fmt.Errorf("配置不符合预期: %v", param.Config) - } - return len(param.Edge), nil - } - return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid) - } - }, - want: func(t *testing.T) error { - // 检查通道是否被软删除 - var count int64 - md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count) - if count > 0 { - return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count) - } - - // 检查Redis缓存是否被删除 - for _, id := range []int32{1, 2, 3} { - key := fmt.Sprintf("channel:%d", id) - if mr.Exists(key) { - return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key) - } - } - return nil - }, - }, - { - name: "用户删除不属于自己的通道", - args: args{ - ctx: ctx, - auth: userAuth, - id: []int32{1, 2, 3}, - }, - setup: func() { - mr.FlushAll() - clearDb() - - // 创建通道 - channels := []models.Channel{ - {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: orm.LocalDateTime(time.Now().Add(24 * time.Hour))}, - {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: orm.LocalDateTime(time.Now().Add(24 * time.Hour))}, - {ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: orm.LocalDateTime(time.Now().Add(24 * time.Hour))}, - } - - // 保存预设数据 - md.Create(channels) - for _, channel := range channels { - key := fmt.Sprintf("channel:%d", channel.ID) - data, _ := json.Marshal(channel) - _ = mr.Set(key, string(data)) - } - }, - wantErr: true, - wantErrContains: "无权限访问", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if tt.setup != nil { - tt.setup() - } - - s := &channelService{} - err := s.RemoveChannels(tt.args.ctx, tt.args.auth, tt.args.id...) - - // 检查错误 - if tt.wantErr { - if err == nil { - t.Errorf("RemoveChannels() 应当返回错误") - return - } - if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) { - t.Errorf("RemoveChannels() 错误 = %v, 应包含 %v", err, tt.wantErrContains) - } - return - } - - if err != nil { - t.Errorf("RemoveChannels() 错误 = %v, wantErr %v", err, tt.wantErr) - return - } - - // 检查数据库和缓存是否正确设置 - - if err := tt.want(t); err != nil { - t.Errorf("RemoveChannels() 结果验证失败: %v", err) - } - }) - } -} diff --git a/web/services/session_test.go b/web/services/session_test.go deleted file mode 100644 index 3750604..0000000 --- a/web/services/session_test.go +++ /dev/null @@ -1,422 +0,0 @@ -package services - -import ( - "context" - "errors" - "platform/web/auth" - "platform/web/testutil" - "reflect" - "testing" - "time" -) - -// 创建测试用的认证上下文 -func createTestAuthContext() auth.Context { - //goland:noinspection ALL - return auth.Context{ - Payload: auth.Payload{ - Type: auth.PayloadUser, - Id: 1001, - }, - Permissions: map[string]struct{}{ - "read": {}, - "write": {}, - }, - Metadata: map[string]interface{}{ - "username": "testuser", - "email": "test@example.com", - }, - } -} - -func Test_sessionService_Create(t *testing.T) { - mr := testutil.SetupRedisTest(t) - ctx := context.Background() - authCtx := createTestAuthContext() - - type args struct { - ctx context.Context - auth auth.Context - } - tests := []struct { - name string - args args - want func(*TokenDetails) bool - wantErr bool - }{ - { - name: "创建会话", - args: args{ - ctx: ctx, - auth: authCtx, - }, - want: func(td *TokenDetails) bool { - // 验证令牌存在且格式正确 - if td.AccessToken == "" || td.RefreshToken == "" { - return false - } - // 验证到期时间在未来 - now := time.Now() - if td.AccessTokenExpires.Before(now) || td.RefreshTokenExpires.Before(now) { - return false - } - // 验证认证信息正确 - if !reflect.DeepEqual(td.Auth, authCtx) { - return false - } - return true - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mr.FlushAll() - s := &sessionService{} - got, err := s.Create(tt.args.ctx, tt.args.auth, true) - if (err != nil) != tt.wantErr { - t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !tt.want(got) { - t.Errorf("Create() got = %v, want to satisfy conditions", got) - } - - // 验证 Redis 中是否有相应的键 - accessKey := accessKey(got.AccessToken) - refreshKey := refreshKey(got.RefreshToken) - - if !mr.Exists(accessKey) { - t.Errorf("访问令牌键 %s 不存在于 Redis 中", accessKey) - } - - if !mr.Exists(refreshKey) { - t.Errorf("刷新令牌键 %s 不存在于 Redis 中", refreshKey) - } - }) - } -} - -func Test_sessionService_Find(t *testing.T) { - testutil.SetupRedisTest(t) - ctx := context.Background() - authCtx := createTestAuthContext() - s := &sessionService{} - - // 创建一个有效的会话 - td, err := s.Create(ctx, authCtx, true) - if err != nil { - t.Fatalf("无法创建测试会话: %v", err) - } - - validToken := td.AccessToken - invalidToken := "invalid-token" - - type args struct { - ctx context.Context - token string - } - tests := []struct { - name string - args args - want *auth.Context - wantErr error - }{ - { - name: "查找有效令牌", - args: args{ - ctx: ctx, - token: validToken, - }, - want: &authCtx, - wantErr: nil, - }, - { - name: "查找无效令牌", - args: args{ - ctx: ctx, - token: invalidToken, - }, - want: nil, - wantErr: ErrInvalidToken, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := s.Find(tt.args.ctx, tt.args.token) - if !errors.Is(err, tt.wantErr) { - t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr) - return - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("Find() got = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_sessionService_Refresh(t *testing.T) { - mr := testutil.SetupRedisTest(t) - ctx := context.Background() - authCtx := createTestAuthContext() - s := &sessionService{} - - // 创建一个初始会话 - td, err := s.Create(ctx, authCtx, true) - if err != nil { - t.Fatalf("无法创建初始会话: %v", err) - } - - validRefreshToken := td.RefreshToken - invalidRefreshToken := "invalid-refresh-token" - originalAccessToken := td.AccessToken - - type args struct { - ctx context.Context - refreshToken string - } - tests := []struct { - name string - args args - want func(*TokenDetails) bool - wantErr bool - }{ - { - name: "使用有效的刷新令牌", - args: args{ - ctx: ctx, - refreshToken: validRefreshToken, - }, - want: func(td *TokenDetails) bool { - if td.AccessToken == "" || td.RefreshToken == "" { - return false - } - // 新的令牌应该与旧的不同 - if td.AccessToken == originalAccessToken || td.RefreshToken == validRefreshToken { - return false - } - // 验证认证信息一致 - if !reflect.DeepEqual(td.Auth, authCtx) { - return false - } - return true - }, - wantErr: false, - }, - { - name: "使用无效的刷新令牌", - args: args{ - ctx: ctx, - refreshToken: invalidRefreshToken, - }, - want: nil, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken) - if (err != nil) != tt.wantErr { - t.Errorf("Refresh() error = %v, wantErr %v", err, tt.wantErr) - return - } - - if tt.want != nil && !tt.want(got) { - t.Errorf("Refresh() got = %v, want to satisfy conditions", got) - } - - if !tt.wantErr && got != nil { - // 验证旧的令牌已被删除 - if mr.Exists(accessKey(originalAccessToken)) { - t.Errorf("原始访问令牌键应被删除") - } - if mr.Exists(refreshKey(validRefreshToken)) { - t.Errorf("原始刷新令牌键应被删除") - } - - // 验证新的令牌已被添加 - if !mr.Exists(accessKey(got.AccessToken)) { - t.Errorf("新的访问令牌键应存在") - } - if !mr.Exists(refreshKey(got.RefreshToken)) { - t.Errorf("新的刷新令牌键应存在") - } - } - }) - } -} - -func Test_sessionService_Remove(t *testing.T) { - mr := testutil.SetupRedisTest(t) - ctx := context.Background() - authCtx := createTestAuthContext() - s := &sessionService{} - - // 创建一个会话 - td, err := s.Create(ctx, authCtx, true) - if err != nil { - t.Fatalf("无法创建测试会话: %v", err) - } - - validAccessToken := td.AccessToken - validRefreshToken := td.RefreshToken - - type args struct { - ctx context.Context - accessToken string - refreshToken string - } - tests := []struct { - name string - args args - wantErr bool - }{ - { - name: "删除有效会话", - args: args{ - ctx: ctx, - accessToken: validAccessToken, - refreshToken: validRefreshToken, - }, - wantErr: false, - }, - { - name: "删除已删除的会话", - args: args{ - ctx: ctx, - accessToken: validAccessToken, - refreshToken: validRefreshToken, - }, - wantErr: false, // 删除不存在的会话不应报错 - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := s.Remove(tt.args.ctx, tt.args.accessToken, tt.args.refreshToken); (err != nil) != tt.wantErr { - t.Errorf("Remove() error = %v, wantErr %v", err, tt.wantErr) - } - - // 验证键已被删除 - if mr.Exists(accessKey(tt.args.accessToken)) { - t.Errorf("访问令牌键应已被删除") - } - if mr.Exists(refreshKey(tt.args.refreshToken)) { - t.Errorf("刷新令牌键应已被删除") - } - }) - } -} - -func TestAuthContext_AnyPermission(t *testing.T) { - type fields struct { - Payload auth.Payload - Permissions map[string]struct{} - Metadata map[string]interface{} - } - type args struct { - requiredPermission []string - } - tests := []struct { - name string - fields fields - args args - want bool - }{ - { - name: "用户拥有所需权限", - fields: fields{ - Payload: auth.Payload{Type: auth.PayloadUser, Id: 1}, - Permissions: map[string]struct{}{ - "read": {}, - "write": {}, - }, - Metadata: nil, - }, - args: args{ - requiredPermission: []string{"read"}, - }, - want: true, - }, - { - name: "用户拥有至少一个所需权限", - fields: fields{ - Payload: auth.Payload{Type: auth.PayloadUser, Id: 1}, - Permissions: map[string]struct{}{ - "read": {}, - }, - Metadata: nil, - }, - args: args{ - requiredPermission: []string{"read", "admin"}, - }, - want: true, - }, - { - name: "用户没有所需权限", - fields: fields{ - Payload: auth.Payload{Type: auth.PayloadUser, Id: 1}, - Permissions: map[string]struct{}{ - "read": {}, - }, - Metadata: nil, - }, - args: args{ - requiredPermission: []string{"admin", "delete"}, - }, - want: false, - }, - { - name: "空权限列表", - fields: fields{ - Payload: auth.Payload{Type: auth.PayloadUser, Id: 1}, - Permissions: map[string]struct{}{}, - Metadata: nil, - }, - args: args{ - requiredPermission: []string{"read"}, - }, - want: false, - }, - { - name: "nil权限列表", - fields: fields{ - Payload: auth.Payload{Type: auth.PayloadUser, Id: 1}, - Permissions: nil, - Metadata: nil, - }, - args: args{ - requiredPermission: []string{"read"}, - }, - want: false, - }, - { - name: "nil认证上下文", - fields: fields{ - Payload: auth.Payload{}, - Permissions: nil, - Metadata: nil, - }, - args: args{ - requiredPermission: []string{"read"}, - }, - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - a := &auth.Context{ - Payload: tt.fields.Payload, - Permissions: tt.fields.Permissions, - Metadata: tt.fields.Metadata, - } - if got := a.AnyPermission(tt.args.requiredPermission...); got != tt.want { - t.Errorf("AnyPermission() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/web/services/verifier_test.go b/web/services/verifier_test.go deleted file mode 100644 index 123622c..0000000 --- a/web/services/verifier_test.go +++ /dev/null @@ -1,244 +0,0 @@ -package services - -import ( - "context" - "errors" - "platform/web/testutil" - "strconv" - "testing" - "time" - - "github.com/alicebob/miniredis/v2" -) - -func Test_verifierService_SendSms(t *testing.T) { - type args struct { - ctx context.Context - phone string - purpose VerifierSmsPurpose - } - tests := []struct { - name string - args args - setup func(mr *miniredis.Miniredis) - wantErr bool - wantErrType error - }{ - { - name: "正常发送成功(无旧验证码)", - args: args{ - ctx: context.Background(), - phone: "13812345678", - purpose: VerifierSmsPurposeLogin, - }, - setup: func(mr *miniredis.Miniredis) {}, - wantErr: false, - }, - { - name: "正常发送成功(有旧验证码)", - args: args{ - ctx: context.Background(), - phone: "13812345679", - purpose: VerifierSmsPurposeLogin, - }, - setup: func(mr *miniredis.Miniredis) { - key := smsKey("13812345679", VerifierSmsPurposeLogin) - mr.Set(key, "123456") - mr.SetTTL(key, 10*time.Minute) - }, - wantErr: false, - }, - { - name: "发送频率过快", - args: args{ - ctx: context.Background(), - phone: "13812345680", - purpose: VerifierSmsPurposeLogin, - }, - setup: func(mr *miniredis.Miniredis) { - key := smsKey("13812345680", VerifierSmsPurposeLogin) + ":lock" - mr.Set(key, "") - mr.SetTTL(key, 1*time.Minute) - }, - wantErr: true, - wantErrType: VerifierServiceSendLimitErr(0), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 设置 Redis 测试环境 - mr := testutil.SetupRedisTest(t) - defer mr.Close() - - // 执行测试前的设置 - if tt.setup != nil { - tt.setup(mr) - } - - s := &verifierService{} - err := s.SendSms(tt.args.ctx, tt.args.phone, tt.args.purpose) - - // 验证错误 - if (err != nil) != tt.wantErr { - t.Errorf("SendSms() error = %v, wantErr %v", err, tt.wantErr) - return - } - - // 验证错误类型 - if tt.wantErr && tt.wantErrType != nil { - var verifierServiceSendLimitErr VerifierServiceSendLimitErr - if errors.As(err, &verifierServiceSendLimitErr) { - var verifierServiceSendLimitErr VerifierServiceSendLimitErr - if !errors.As(tt.wantErrType, &verifierServiceSendLimitErr) { - t.Errorf("SendSms() error type = %T, wantErrType %T", err, tt.wantErrType) - } - } - } - - // 验证 Redis 中的记录 - if !tt.wantErr { - key := smsKey(tt.args.phone, tt.args.purpose) - keyLock := key + ":lock" - - // 验证码应存在 - val, err := mr.Get(key) - if err != nil { - t.Errorf("验证码应存在但不存在: %v", err) - } - - // 限速锁应存在 - _, err = mr.Get(keyLock) - if err != nil { - t.Errorf("限速锁应存在但不存在: %v", err) - } - - // 验证码应为6位数字 - code, err := strconv.Atoi(val) - if err != nil || code < 100000 || code > 999999 { - t.Errorf("验证码应为6位数字: %v", val) - } - } - }) - } -} - -func Test_verifierService_VerifySms(t *testing.T) { - type args struct { - ctx context.Context - phone string - code string - } - tests := []struct { - name string - args args - setup func(mr *miniredis.Miniredis) - wantErr bool - wantErrType error - }{ - { - name: "验证码正确", - args: args{ - ctx: context.Background(), - phone: "13812345678", - code: "123456", - }, - setup: func(mr *miniredis.Miniredis) { - key := smsKey("13812345678", VerifierSmsPurposeLogin) - keyLock := key + ":lock" - mr.Set(key, "123456") - mr.SetTTL(key, 10*time.Minute) - mr.Set(keyLock, "") - mr.SetTTL(keyLock, 1*time.Minute) - }, - wantErr: false, - }, - { - name: "验证码错误", - args: args{ - ctx: context.Background(), - phone: "13812345679", - code: "654321", - }, - setup: func(mr *miniredis.Miniredis) { - key := smsKey("13812345679", VerifierSmsPurposeLogin) - keyLock := key + ":lock" - mr.Set(key, "123456") - mr.SetTTL(key, 10*time.Minute) - mr.Set(keyLock, "") - mr.SetTTL(keyLock, 1*time.Minute) - }, - wantErr: true, - wantErrType: ErrVerifierServiceInvalid, - }, - { - name: "验证码过期", - args: args{ - ctx: context.Background(), - phone: "13812345680", - code: "123456", - }, - setup: func(mr *miniredis.Miniredis) { - // 不设置验证码,模拟过期情况 - }, - wantErr: true, - wantErrType: ErrVerifierServiceInvalid, - }, - { - name: "手机号错误", - args: args{ - ctx: context.Background(), - phone: "13812345681", - code: "123456", - }, - setup: func(mr *miniredis.Miniredis) { - // 设置一个不同手机号的验证码 - key := smsKey("13800000000", VerifierSmsPurposeLogin) - mr.Set(key, "123456") - mr.SetTTL(key, 10*time.Minute) - }, - wantErr: true, - wantErrType: ErrVerifierServiceInvalid, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // 设置 Redis 测试环境 - mr := testutil.SetupRedisTest(t) - defer mr.Close() - - // 执行测试前的设置 - if tt.setup != nil { - tt.setup(mr) - } - - s := &verifierService{} - err := s.VerifySms(tt.args.ctx, tt.args.phone, tt.args.code) - if (err != nil) != tt.wantErr { - t.Errorf("VerifySms() error = %v, wantErr %v", err, tt.wantErr) - return - } - - // 检查错误类型 - if tt.wantErr && tt.wantErrType != nil && !errors.Is(err, tt.wantErrType) { - t.Errorf("VerifySms() error = %v, wantErrType %v", err, tt.wantErrType) - return - } - - // 验证成功后 Redis 中应该没有该记录 - if err == nil { - key := smsKey(tt.args.phone, VerifierSmsPurposeLogin) - keyLock := key + ":lock" - - _, redisErr := mr.Get(key) - if redisErr == nil { - t.Errorf("验证码验证成功后应删除,但仍存在") - } - - _, redisErr = mr.Get(keyLock) - if redisErr == nil { - t.Errorf("限速锁验证成功后应删除,但仍存在") - } - } - }) - } -} diff --git a/web/testutil/db.go b/web/testutil/db.go deleted file mode 100644 index 874b44f..0000000 --- a/web/testutil/db.go +++ /dev/null @@ -1,39 +0,0 @@ -package testutil - -import ( - g"platform/web/globals" - m"platform/web/models" - q "platform/web/queries" - "testing" - - "github.com/glebarez/sqlite" - "gorm.io/gorm" -) - -// SetupDBTest 创建一个基于 SQLite 内存数据库的 GORM 连接 -func SetupDBTest(t *testing.T) *gorm.DB { - // 使用 SQLite 内存数据库 - gormDB, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) - if err != nil { - t.Fatalf("gorm 打开 SQLite 内存数据库失败: %v", err) - } - - // 自动迁移数据表结构 - err = gormDB.AutoMigrate( - &m.User{}, - &m.Whitelist{}, - &m.Resource{}, - &m.ResourcePss{}, - &m.Proxy{}, - &m.Channel{}, - ) - if err != nil { - t.Fatalf("自动迁移表结构失败: %v", err) - } - - // 设置全局数据库连接 - q.SetDefault(gormDB) - g.DB = gormDB - - return gormDB -} diff --git a/web/testutil/redis.go b/web/testutil/redis.go deleted file mode 100644 index 9e9fd08..0000000 --- a/web/testutil/redis.go +++ /dev/null @@ -1,30 +0,0 @@ -package testutil - -import ( - g "platform/web/globals" - "testing" - - "github.com/alicebob/miniredis/v2" - "github.com/redis/go-redis/v9" -) - -// SetupRedisTest 创建一个测试用的Redis实例 -// 返回miniredis实例,使用t.Cleanup自动清理资源 -func SetupRedisTest(t *testing.T) *miniredis.Miniredis { - mr, err := miniredis.Run() - if err != nil { - t.Fatalf("设置 miniredis 失败: %v", err) - } - - // 替换 Redis 客户端为测试客户端 - g.Redis = redis.NewClient(&redis.Options{ - Addr: mr.Addr(), - }) - - // 使用t.Cleanup确保测试结束后恢复原始客户端并关闭miniredis - t.Cleanup(func() { - mr.Close() - }) - - return mr -} diff --git a/web/testutil/remote.go b/web/testutil/remote.go deleted file mode 100644 index eb0735b..0000000 --- a/web/testutil/remote.go +++ /dev/null @@ -1,150 +0,0 @@ -package testutil - -import ( - g "platform/web/globals" - "sync" - "testing" -) - -// MockCloudClient 是CloudClient接口的测试实现 -type MockCloudClient struct { - // 存储预期结果的字段 - EdgesMock func(param g.CloudEdgesReq) (*g.CloudEdgesResp, error) - ConnectMock func(param g.CloudConnectReq) error - DisconnectMock func(param g.CloudDisconnectReq) (int, error) - AutoQueryMock func() (g.CloudConnectResp, error) - - // 记录调用历史 - EdgesCalls []g.CloudEdgesReq - ConnectCalls []g.CloudConnectReq - DisconnectCalls []g.CloudDisconnectReq - AutoQueryCalls int - - // 用于并发安全 - mu sync.Mutex -} - -// 确保MockCloudClient实现了CloudClient接口 -var _ g.CloudClient = (*MockCloudClient)(nil) - -func (m *MockCloudClient) CloudEdges(param g.CloudEdgesReq) (*g.CloudEdgesResp, error) { - m.mu.Lock() - defer m.mu.Unlock() - m.EdgesCalls = append(m.EdgesCalls, param) - if m.EdgesMock != nil { - return m.EdgesMock(param) - } - return &g.CloudEdgesResp{}, nil -} - -func (m *MockCloudClient) CloudConnect(param g.CloudConnectReq) error { - m.mu.Lock() - defer m.mu.Unlock() - m.ConnectCalls = append(m.ConnectCalls, param) - if m.ConnectMock != nil { - return m.ConnectMock(param) - } - return nil -} - -func (m *MockCloudClient) CloudDisconnect(param g.CloudDisconnectReq) (int, error) { - m.mu.Lock() - defer m.mu.Unlock() - m.DisconnectCalls = append(m.DisconnectCalls, param) - if m.DisconnectMock != nil { - return m.DisconnectMock(param) - } - return 0, nil -} - -func (m *MockCloudClient) CloudAutoQuery() (g.CloudConnectResp, error) { - m.mu.Lock() - defer m.mu.Unlock() - m.AutoQueryCalls++ - if m.AutoQueryMock != nil { - return m.AutoQueryMock() - } - return g.CloudConnectResp{}, nil -} - -// SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复 -func SetupCloudClientMock(t *testing.T) *MockCloudClient { - mock := &MockCloudClient{ - EdgesMock: func(param g.CloudEdgesReq) (*g.CloudEdgesResp, error) { - panic("not implemented") - }, - ConnectMock: func(param g.CloudConnectReq) error { - panic("not implemented") - }, - DisconnectMock: func(param g.CloudDisconnectReq) (int, error) { - panic("not implemented") - }, - AutoQueryMock: func() (g.CloudConnectResp, error) { - panic("not implemented") - }, - } - g.Cloud = mock - - return mock -} - -// MockGatewayClient 是GatewayClient接口的测试实现 -type MockGatewayClient struct { - Host string -} - -// 确保MockGatewayClient实现了GatewayClient接口 -var _ g.GatewayClient = (*MockGatewayClient)(nil) - -func (m *MockGatewayClient) GatewayPortConfigs(params []g.PortConfigsReq) error { - testGatewayBase.mu.Lock() - defer testGatewayBase.mu.Unlock() - testGatewayBase.PortConfigsCalls = append(testGatewayBase.PortConfigsCalls, params) - if testGatewayBase.PortConfigsMock != nil { - return testGatewayBase.PortConfigsMock(m, params) - } - return nil -} - -func (m *MockGatewayClient) GatewayPortActive(param ...g.PortActiveReq) (map[string]g.PortData, error) { - testGatewayBase.mu.Lock() - defer testGatewayBase.mu.Unlock() - testGatewayBase.PortActiveCalls = append(testGatewayBase.PortActiveCalls, param) - if testGatewayBase.PortActiveMock != nil { - return testGatewayBase.PortActiveMock(m, param...) - } - return map[string]g.PortData{}, nil -} - -type GatewayClientIns struct { - - // 存储预期结果的字段 - PortConfigsMock func(c *MockGatewayClient, params []g.PortConfigsReq) error - PortActiveMock func(c *MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error) - - // 记录调用历史 - PortConfigsCalls [][]g.PortConfigsReq - PortActiveCalls [][]g.PortActiveReq - - // 用于并发安全 - mu sync.Mutex -} - -var testGatewayBase = &GatewayClientIns{ - PortConfigsMock: func(c *MockGatewayClient, params []g.PortConfigsReq) error { - panic("not implemented") - }, - PortActiveMock: func(c *MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error) { - panic("not implemented") - }, -} - -// SetupGatewayClientMock 创建一个MockGatewayClient并提供替换函数 -func SetupGatewayClientMock(t *testing.T) *GatewayClientIns { - g.GatewayInitializer = func(url, username, password string) g.GatewayClient { - return &MockGatewayClient{ - Host: url, - } - } - return testGatewayBase -} diff --git a/web/testutil/tools.go b/web/testutil/tools.go deleted file mode 100644 index 64fb3b6..0000000 --- a/web/testutil/tools.go +++ /dev/null @@ -1,26 +0,0 @@ -package testutil - -import ( - "reflect" - "sort" -) - -// SliceEqual 检查两个字符串切片是否完全相等(忽略顺序) -func SliceEqual(a, b []string) bool { - if len(a) != len(b) { - return false - } - - // 复制切片以避免修改原始数据 - aCopy := make([]string, len(a)) - bCopy := make([]string, len(b)) - copy(aCopy, a) - copy(bCopy, b) - - // 排序两个切片 - sort.Strings(aCopy) - sort.Strings(bCopy) - - // 比较排序后的切片 - return reflect.DeepEqual(aCopy, bCopy) -}