From c3abb42bcea4a3d69ee6377b13a397cadcfeb7db Mon Sep 17 00:00:00 2001 From: luorijun Date: Sat, 22 Mar 2025 16:37:24 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AE=A4=E8=AF=81=E6=8E=88=E6=9D=83=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E4=BB=A3=E7=A0=81=E4=B8=8E=E4=B8=9A=E5=8A=A1=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E8=B4=A8=E9=87=8F=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 13 +- go.sum | 27 +- web/handlers/login.go | 8 +- web/handlers/oauth.go | 2 +- web/services/auth.go | 17 +- web/services/auth_test.go | 146 ++++++++++ web/services/session.go | 14 +- web/services/session_test.go | 486 ++++++++++++++++++++++++++++++++++ web/services/verifier.go | 23 +- web/services/verifier_test.go | 257 ++++++++++++++++++ 10 files changed, 960 insertions(+), 33 deletions(-) create mode 100644 web/services/auth_test.go create mode 100644 web/services/session_test.go create mode 100644 web/services/verifier_test.go diff --git a/go.mod b/go.mod index 578d9ba..a56cc3c 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,13 @@ module platform go 1.24.0 require ( + github.com/alicebob/miniredis/v2 v2.34.0 github.com/gofiber/fiber/v2 v2.52.6 github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/lmittmann/tint v1.0.7 github.com/redis/go-redis/v9 v9.3.0 + github.com/stretchr/testify v1.8.1 golang.org/x/crypto v0.17.0 gorm.io/driver/postgres v1.5.11 gorm.io/gen v0.3.26 @@ -16,8 +18,10 @@ require ( ) require ( + github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -27,18 +31,23 @@ require ( github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/compress v1.17.9 // indirect + github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect + github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect - golang.org/x/mod v0.17.0 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect + golang.org/x/mod v0.21.0 // indirect golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.23.0 // indirect - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + golang.org/x/tools v0.26.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect gorm.io/datatypes v1.1.1-0.20230130040222-c43177d3cf8c // indirect gorm.io/driver/mysql v1.5.7 // indirect gorm.io/hints v1.1.0 // indirect diff --git a/go.sum b/go.sum index 94e8577..010c07b 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= +github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0= +github.com/alicebob/miniredis/v2 v2.34.0/go.mod h1:kWShP4b58T1CW0Y5dViCd5ztzrDqRWqM3nksiyXk5s8= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= @@ -6,6 +10,7 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -38,6 +43,10 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lmittmann/tint v1.0.7 h1:D/0OqWZ0YOGZ6AyC+5Y2kD8PBEzBk6rFHVSfOqCkF9Y= github.com/lmittmann/tint v1.0.7/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= @@ -58,9 +67,15 @@ github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0 github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= @@ -70,14 +85,16 @@ github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdI github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -111,10 +128,12 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= +golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/web/handlers/login.go b/web/handlers/login.go index 713aaf3..32f722b 100644 --- a/web/handlers/login.go +++ b/web/handlers/login.go @@ -42,13 +42,13 @@ func Login(c *fiber.Ctx) error { func loginByPhone(c *fiber.Ctx, req *LoginReq) error { // 验证验证码 - ok, err := services.Verifier.VerifySms(c.Context(), req.Username, req.Password) + err := services.Verifier.VerifySms(c.Context(), req.Username, req.Password) if err != nil { + if errors.Is(err, services.ErrVerifierServiceInvalid) { + return fiber.NewError(fiber.StatusBadRequest, "验证码错误") + } return err } - if !ok { - return fiber.NewError(fiber.StatusBadRequest, "验证码错误") - } // 查找用户 todo 获取权限信息 var tx = q.Q.Begin() diff --git a/web/handlers/oauth.go b/web/handlers/oauth.go index 20dcf7c..6207bca 100644 --- a/web/handlers/oauth.go +++ b/web/handlers/oauth.go @@ -104,7 +104,7 @@ func clientCredentials(c *fiber.Ctx, req *TokenReq) error { } scope := strings.Split(req.Scope, ",") - token, err := services.Auth.OauthClientCredentials(c.Context(), client, scope) + token, err := services.Auth.OauthClientCredentials(c.Context(), client, scope...) if err != nil { return sendError(c, err.(services.AuthServiceOauthError)) } diff --git a/web/services/auth.go b/web/services/auth.go index 4f8c07f..274005c 100644 --- a/web/services/auth.go +++ b/web/services/auth.go @@ -8,8 +8,6 @@ import ( var Auth = &authService{} -type authService struct{} - type AuthServiceError string func (e AuthServiceError) Error() string { @@ -31,6 +29,8 @@ var ( ErrOauthUnsupportedGrantType = AuthServiceOauthError("unsupported_grant_type") ) +type authService struct{} + // OauthAuthorizationCode 验证授权码 func (s *authService) OauthAuthorizationCode(ctx context.Context, client *models.Client, code, redirectURI, codeVerifier string) (*TokenDetails, error) { // TODO: 从数据库验证授权码 @@ -38,7 +38,7 @@ func (s *authService) OauthAuthorizationCode(ctx context.Context, client *models } // OauthClientCredentials 验证客户端凭证 -func (s *authService) OauthClientCredentials(ctx context.Context, client *models.Client, scope ...[]string) (*TokenDetails, error) { +func (s *authService) OauthClientCredentials(ctx context.Context, client *models.Client, scope ...string) (*TokenDetails, error) { var clientType PayloadType switch client.Spec { @@ -47,14 +47,17 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *models case 1: clientType = PayloadClientPublic case 2: - clientType = PayloadClientConfidential + clientType = PayloadClientPublic + } + + var permissions = make(map[string]struct{}, len(scope)) + for _, item := range scope { + permissions[item] = struct{}{} } // 保存会话并返回令牌 auth := AuthContext{ - Permissions: map[string]struct{}{ - "client": {}, - }, + Permissions: permissions, Payload: Payload{ Type: clientType, Id: client.ID, diff --git a/web/services/auth_test.go b/web/services/auth_test.go new file mode 100644 index 0000000..2c2c2f2 --- /dev/null +++ b/web/services/auth_test.go @@ -0,0 +1,146 @@ +package services + +import ( + "context" + "platform/web/models" + "reflect" + "testing" + "time" +) + +// mockSessionService 用于模拟Session服务的行为 +type mockSessionService struct { + createFunc func(ctx context.Context, auth AuthContext) (*TokenDetails, error) +} + +func (m *mockSessionService) Find(ctx context.Context, token string) (*AuthContext, error) { + panic("implement me") +} +func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error) { + panic("implement me") +} +func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error { + panic("implement me") +} +func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error) { + return m.createFunc(ctx, auth) +} + +func Test_authService_OauthClientCredentials(t *testing.T) { + // 暂存原始Session服务 + originalSession := Session + defer func() { + // 测试结束后恢复原始Session服务 + Session = originalSession + }() + + // 预设的令牌详情 + expectedToken := &TokenDetails{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + AccessTokenExpires: time.Now().Add(3600 * time.Second), + } + + type args struct { + ctx context.Context + client *models.Client + scope []string + } + tests := []struct { + name string + args args + mockCreateErr error + want *TokenDetails + wantErr bool + wantPayload Payload + }{ + { + name: "成功 - 机密客户端 (Spec=0)", + args: args{ + ctx: context.Background(), + client: &models.Client{ID: 1, Spec: 0}, + scope: []string{"read", "write"}, + }, + mockCreateErr: nil, + want: expectedToken, + wantErr: false, + wantPayload: Payload{ + Type: PayloadClientConfidential, + Id: 1, + }, + }, + { + name: "成功 - 公共客户端 (Spec=1)", + args: args{ + ctx: context.Background(), + client: &models.Client{ID: 1, Spec: 1}, + scope: []string{"read", "write"}, + }, + mockCreateErr: nil, + want: expectedToken, + wantErr: false, + wantPayload: Payload{ + Type: PayloadClientPublic, + Id: 1, + }, + }, + { + name: "成功 - 公共客户端 (Spec=2)", + args: args{ + ctx: context.Background(), + client: &models.Client{ID: 1, Spec: 2}, + scope: []string{"read", "write"}, + }, + mockCreateErr: nil, + want: expectedToken, + wantErr: false, + wantPayload: Payload{ + Type: PayloadClientPublic, + Id: 1, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // 为每个测试用例设置模拟的Session服务 + mockSession := &mockSessionService{ + createFunc: func(ctx context.Context, auth AuthContext) (*TokenDetails, error) { + // 验证权限映射 + if len(auth.Permissions) != len(tt.args.scope) { + t.Errorf("Permissions length = %v, want %v", len(auth.Permissions), len(tt.args.scope)) + for key := range auth.Permissions { + if _, ok := auth.Permissions[key]; !ok { + t.Errorf("Permissions[%s] not found", key) + } + } + } + + // 验证Payload + if auth.Payload.Type != tt.wantPayload.Type { + t.Errorf("Payload.Type = %v, want %v", auth.Payload.Type, tt.wantPayload.Type) + } + if auth.Payload.Id != tt.wantPayload.Id { + t.Errorf("Payload.Id = %v, want %v", auth.Payload.Id, tt.wantPayload.Id) + } + + return expectedToken, tt.mockCreateErr + }, + } + + // 替换Session服务为模拟实现 + Session = mockSession + + s := &authService{} + got, err := s.OauthClientCredentials(tt.args.ctx, tt.args.client, tt.args.scope...) + if (err != nil) != tt.wantErr { + t.Errorf("OauthClientCredentials() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("OauthClientCredentials() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/web/services/session.go b/web/services/session.go index b116254..033c550 100644 --- a/web/services/session.go +++ b/web/services/session.go @@ -14,9 +14,17 @@ import ( // region SessionService -var Session = &sessionService{} +var Session SessionServiceInter = &sessionService{} -type sessionService struct { +type SessionServiceInter interface { + // Find 通过访问令牌获取会话信息 + Find(ctx context.Context, token string) (*AuthContext, error) + // Create 创建一个新的会话 + Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error) + // Refresh 刷新一个会话 + Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error) + // Remove 删除会话 + Remove(ctx context.Context, accessToken, refreshToken string) error } type SessionServiceError string @@ -29,6 +37,8 @@ var ( ErrInvalidToken = SessionServiceError("invalid_token") ) +type sessionService struct{} + // Find 通过访问令牌获取会话信息 func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, error) { diff --git a/web/services/session_test.go b/web/services/session_test.go new file mode 100644 index 0000000..25645a2 --- /dev/null +++ b/web/services/session_test.go @@ -0,0 +1,486 @@ +package services + +import ( + "context" + "errors" + "platform/init/rds" + "reflect" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +// 设置 Redis 模拟服务器 +func setupTestRedis(t *testing.T) *miniredis.Miniredis { + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("无法启动 miniredis: %v", err) + } + + // 替换 Redis 客户端为测试客户端 + origClient := rds.Client + rds.Client = redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + + t.Cleanup(func() { + mr.Close() + rds.Client = origClient + }) + + return mr +} + +// 创建测试用的认证上下文 +func createTestAuthContext() AuthContext { + return AuthContext{ + Payload: Payload{ + Type: PayloadUser, + Id: 1001, + }, + Permissions: map[string]struct{}{ + "read": {}, + "write": {}, + }, + Metadata: map[string]interface{}{ + "username": "testuser", + "email": "test@example.com", + }, + } +} + +func Test_sessionService_Create(t *testing.T) { + mr := setupTestRedis(t) + ctx := context.Background() + auth := createTestAuthContext() + + type args struct { + ctx context.Context + auth AuthContext + config []SessionConfig + } + tests := []struct { + name string + args args + want func(*TokenDetails) bool + wantErr bool + }{ + { + name: "使用默认配置创建会话", + args: args{ + ctx: ctx, + auth: auth, + }, + want: func(td *TokenDetails) bool { + // 验证令牌存在且格式正确 + if td.AccessToken == "" || td.RefreshToken == "" { + return false + } + // 验证到期时间在未来 + now := time.Now() + if td.AccessTokenExpires.Before(now) || td.RefreshTokenExpires.Before(now) { + return false + } + // 验证认证信息正确 + if !reflect.DeepEqual(td.Auth, auth) { + return false + } + return true + }, + wantErr: false, + }, + { + name: "使用自定义配置创建会话", + args: args{ + ctx: ctx, + auth: auth, + config: []SessionConfig{ + { + AccessTokenDuration: 10 * time.Minute, + RefreshTokenDuration: 24 * time.Hour, + }, + }, + }, + want: func(td *TokenDetails) bool { + // 验证令牌存在且格式正确 + if td.AccessToken == "" || td.RefreshToken == "" { + return false + } + // 验证到期时间在未来且接近预期时间 + now := time.Now() + expectedAccessExpiry := now.Add(10 * time.Minute) + expectedRefreshExpiry := now.Add(24 * time.Hour) + + accessDiff := td.AccessTokenExpires.Sub(expectedAccessExpiry) + refreshDiff := td.RefreshTokenExpires.Sub(expectedRefreshExpiry) + + if accessDiff < -2*time.Second || accessDiff > 2*time.Second { + return false + } + if refreshDiff < -2*time.Second || refreshDiff > 2*time.Second { + return false + } + + // 验证认证信息正确 + if !reflect.DeepEqual(td.Auth, auth) { + return false + } + return true + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mr.FlushAll() + s := &sessionService{} + got, err := s.Create(tt.args.ctx, tt.args.auth, tt.args.config...) + if (err != nil) != tt.wantErr { + t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.want(got) { + t.Errorf("Create() got = %v, want to satisfy conditions", got) + } + + // 验证 Redis 中是否有相应的键 + accessKey := accessKey(got.AccessToken) + refreshKey := refreshKey(got.RefreshToken) + + if !mr.Exists(accessKey) { + t.Errorf("访问令牌键 %s 不存在于 Redis 中", accessKey) + } + + if !mr.Exists(refreshKey) { + t.Errorf("刷新令牌键 %s 不存在于 Redis 中", refreshKey) + } + }) + } +} + +func Test_sessionService_Find(t *testing.T) { + _ = setupTestRedis(t) + ctx := context.Background() + auth := createTestAuthContext() + s := &sessionService{} + + // 创建一个有效的会话 + td, err := s.Create(ctx, auth) + if err != nil { + t.Fatalf("无法创建测试会话: %v", err) + } + + validToken := td.AccessToken + invalidToken := "invalid-token" + + type args struct { + ctx context.Context + token string + } + tests := []struct { + name string + args args + want *AuthContext + wantErr error + }{ + { + name: "查找有效令牌", + args: args{ + ctx: ctx, + token: validToken, + }, + want: &auth, + wantErr: nil, + }, + { + name: "查找无效令牌", + args: args{ + ctx: ctx, + token: invalidToken, + }, + want: nil, + wantErr: ErrInvalidToken, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := s.Find(tt.args.ctx, tt.args.token) + if !errors.Is(err, tt.wantErr) { + t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Find() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_sessionService_Refresh(t *testing.T) { + mr := setupTestRedis(t) + ctx := context.Background() + auth := createTestAuthContext() + s := &sessionService{} + + // 创建一个初始会话 + td, err := s.Create(ctx, auth) + if err != nil { + t.Fatalf("无法创建初始会话: %v", err) + } + + validRefreshToken := td.RefreshToken + invalidRefreshToken := "invalid-refresh-token" + originalAccessToken := td.AccessToken + + type args struct { + ctx context.Context + refreshToken string + config []SessionConfig + } + tests := []struct { + name string + args args + want func(*TokenDetails) bool + wantErr bool + }{ + { + name: "使用有效的刷新令牌", + args: args{ + ctx: ctx, + refreshToken: validRefreshToken, + }, + want: func(td *TokenDetails) bool { + if td.AccessToken == "" || td.RefreshToken == "" { + return false + } + // 新的令牌应该与旧的不同 + if td.AccessToken == originalAccessToken || td.RefreshToken == validRefreshToken { + return false + } + // 验证认证信息一致 + if !reflect.DeepEqual(td.Auth, auth) { + return false + } + return true + }, + wantErr: false, + }, + { + name: "使用无效的刷新令牌", + args: args{ + ctx: ctx, + refreshToken: invalidRefreshToken, + }, + want: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken, tt.args.config...) + if (err != nil) != tt.wantErr { + t.Errorf("Refresh() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.want != nil && !tt.want(got) { + t.Errorf("Refresh() got = %v, want to satisfy conditions", got) + } + + if !tt.wantErr && got != nil { + // 验证旧的令牌已被删除 + if mr.Exists(accessKey(originalAccessToken)) { + t.Errorf("原始访问令牌键应被删除") + } + if mr.Exists(refreshKey(validRefreshToken)) { + t.Errorf("原始刷新令牌键应被删除") + } + + // 验证新的令牌已被添加 + if !mr.Exists(accessKey(got.AccessToken)) { + t.Errorf("新的访问令牌键应存在") + } + if !mr.Exists(refreshKey(got.RefreshToken)) { + t.Errorf("新的刷新令牌键应存在") + } + } + }) + } +} + +func Test_sessionService_Remove(t *testing.T) { + mr := setupTestRedis(t) + ctx := context.Background() + auth := createTestAuthContext() + s := &sessionService{} + + // 创建一个会话 + td, err := s.Create(ctx, auth) + if err != nil { + t.Fatalf("无法创建测试会话: %v", err) + } + + validAccessToken := td.AccessToken + validRefreshToken := td.RefreshToken + + type args struct { + ctx context.Context + accessToken string + refreshToken string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "删除有效会话", + args: args{ + ctx: ctx, + accessToken: validAccessToken, + refreshToken: validRefreshToken, + }, + wantErr: false, + }, + { + name: "删除已删除的会话", + args: args{ + ctx: ctx, + accessToken: validAccessToken, + refreshToken: validRefreshToken, + }, + wantErr: false, // 删除不存在的会话不应报错 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := s.Remove(tt.args.ctx, tt.args.accessToken, tt.args.refreshToken); (err != nil) != tt.wantErr { + t.Errorf("Remove() error = %v, wantErr %v", err, tt.wantErr) + } + + // 验证键已被删除 + if mr.Exists(accessKey(tt.args.accessToken)) { + t.Errorf("访问令牌键应已被删除") + } + if mr.Exists(refreshKey(tt.args.refreshToken)) { + t.Errorf("刷新令牌键应已被删除") + } + }) + } +} + +func TestAuthContext_AnyPermission(t *testing.T) { + type fields struct { + Payload Payload + Permissions map[string]struct{} + Metadata map[string]interface{} + } + type args struct { + requiredPermission []string + } + tests := []struct { + name string + fields fields + args args + want bool + }{ + { + name: "用户拥有所需权限", + fields: fields{ + Payload: Payload{Type: PayloadUser, Id: 1}, + Permissions: map[string]struct{}{ + "read": {}, + "write": {}, + }, + Metadata: nil, + }, + args: args{ + requiredPermission: []string{"read"}, + }, + want: true, + }, + { + name: "用户拥有至少一个所需权限", + fields: fields{ + Payload: Payload{Type: PayloadUser, Id: 1}, + Permissions: map[string]struct{}{ + "read": {}, + }, + Metadata: nil, + }, + args: args{ + requiredPermission: []string{"read", "admin"}, + }, + want: true, + }, + { + name: "用户没有所需权限", + fields: fields{ + Payload: Payload{Type: PayloadUser, Id: 1}, + Permissions: map[string]struct{}{ + "read": {}, + }, + Metadata: nil, + }, + args: args{ + requiredPermission: []string{"admin", "delete"}, + }, + want: false, + }, + { + name: "空权限列表", + fields: fields{ + Payload: Payload{Type: PayloadUser, Id: 1}, + Permissions: map[string]struct{}{}, + Metadata: nil, + }, + args: args{ + requiredPermission: []string{"read"}, + }, + want: false, + }, + { + name: "nil权限列表", + fields: fields{ + Payload: Payload{Type: PayloadUser, Id: 1}, + Permissions: nil, + Metadata: nil, + }, + args: args{ + requiredPermission: []string{"read"}, + }, + want: false, + }, + { + name: "nil认证上下文", + fields: fields{ + Payload: Payload{}, + Permissions: nil, + Metadata: nil, + }, + args: args{ + requiredPermission: []string{"read"}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := &AuthContext{ + Payload: tt.fields.Payload, + Permissions: tt.fields.Permissions, + Metadata: tt.fields.Metadata, + } + if got := a.AnyPermission(tt.args.requiredPermission...); got != tt.want { + t.Errorf("AnyPermission() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/web/services/verifier.go b/web/services/verifier.go index 6d28c28..eb13398 100644 --- a/web/services/verifier.go +++ b/web/services/verifier.go @@ -15,9 +15,6 @@ import ( var Verifier = &verifierService{} -type verifierService struct { -} - type VerifierServiceError string func (e VerifierServiceError) Error() string { @@ -37,11 +34,10 @@ func (e VerifierServiceSendLimitErr) Error() string { type VerifierSmsPurpose int const ( - Login VerifierSmsPurpose = iota + VerifierSmsPurposeLogin VerifierSmsPurpose = iota ) -func smsKey(phone string, purpose VerifierSmsPurpose) string { - return fmt.Sprintf("verify:sms:%d:%s", purpose, phone) +type verifierService struct { } func (s *verifierService) SendSms(ctx context.Context, phone string, purpose VerifierSmsPurpose) error { @@ -83,8 +79,8 @@ func (s *verifierService) SendSms(ctx context.Context, phone string, purpose Ver return nil } -func (s *verifierService) VerifySms(ctx context.Context, phone, code string) (bool, error) { - key := smsKey(phone, Login) +func (s *verifierService) VerifySms(ctx context.Context, phone, code string) error { + key := smsKey(phone, VerifierSmsPurposeLogin) keyLock := key + ":lock" err := rds.Client.Watch(ctx, func(tx *redis.Tx) error { @@ -114,11 +110,12 @@ func (s *verifierService) VerifySms(ctx context.Context, phone, code string) (bo return nil }, key) if err != nil { - if errors.Is(err, ErrVerifierServiceInvalid) { - return false, nil - } - return false, err + return err } - return true, nil + return nil +} + +func smsKey(phone string, purpose VerifierSmsPurpose) string { + return fmt.Sprintf("verify:sms:%d:%s", purpose, phone) } diff --git a/web/services/verifier_test.go b/web/services/verifier_test.go new file mode 100644 index 0000000..2a14b61 --- /dev/null +++ b/web/services/verifier_test.go @@ -0,0 +1,257 @@ +package services + +import ( + "context" + "platform/init/rds" + "strconv" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +// 设置测试的 Redis 环境 +func setupRedisTest(t *testing.T) *miniredis.Miniredis { + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("设置 miniredis 失败: %v", err) + } + + // 替换 redis 客户端为测试客户端 + rds.Client = redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + + return mr +} + +func Test_verifierService_SendSms(t *testing.T) { + type args struct { + ctx context.Context + phone string + purpose VerifierSmsPurpose + } + tests := []struct { + name string + args args + setup func(mr *miniredis.Miniredis) + wantErr bool + wantErrType error + }{ + { + name: "正常发送成功(无旧验证码)", + args: args{ + ctx: context.Background(), + phone: "13812345678", + purpose: VerifierSmsPurposeLogin, + }, + setup: func(mr *miniredis.Miniredis) {}, + wantErr: false, + }, + { + name: "正常发送成功(有旧验证码)", + args: args{ + ctx: context.Background(), + phone: "13812345679", + purpose: VerifierSmsPurposeLogin, + }, + setup: func(mr *miniredis.Miniredis) { + key := smsKey("13812345679", VerifierSmsPurposeLogin) + mr.Set(key, "123456") + mr.SetTTL(key, 10*time.Minute) + }, + wantErr: false, + }, + { + name: "发送频率过快", + args: args{ + ctx: context.Background(), + phone: "13812345680", + purpose: VerifierSmsPurposeLogin, + }, + setup: func(mr *miniredis.Miniredis) { + key := smsKey("13812345680", VerifierSmsPurposeLogin) + ":lock" + mr.Set(key, "") + mr.SetTTL(key, 1*time.Minute) + }, + wantErr: true, + wantErrType: VerifierServiceSendLimitErr(0), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 设置 Redis 测试环境 + mr := setupRedisTest(t) + defer mr.Close() + + // 执行测试前的设置 + if tt.setup != nil { + tt.setup(mr) + } + + s := &verifierService{} + err := s.SendSms(tt.args.ctx, tt.args.phone, tt.args.purpose) + + // 验证错误 + if (err != nil) != tt.wantErr { + t.Errorf("SendSms() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // 验证错误类型 + if tt.wantErr && tt.wantErrType != nil { + if _, isSendLimitErr := err.(VerifierServiceSendLimitErr); isSendLimitErr { + if _, wantSendLimitErr := tt.wantErrType.(VerifierServiceSendLimitErr); !wantSendLimitErr { + t.Errorf("SendSms() error type = %T, wantErrType %T", err, tt.wantErrType) + } + } + } + + // 验证 Redis 中的记录 + if !tt.wantErr { + key := smsKey(tt.args.phone, tt.args.purpose) + keyLock := key + ":lock" + + // 验证码应存在 + val, err := mr.Get(key) + if err != nil { + t.Errorf("验证码应存在但不存在: %v", err) + } + + // 限速锁应存在 + _, err = mr.Get(keyLock) + if err != nil { + t.Errorf("限速锁应存在但不存在: %v", err) + } + + // 验证码应为6位数字 + code, err := strconv.Atoi(val) + if err != nil || code < 100000 || code > 999999 { + t.Errorf("验证码应为6位数字: %v", val) + } + } + }) + } +} + +func Test_verifierService_VerifySms(t *testing.T) { + type args struct { + ctx context.Context + phone string + code string + } + tests := []struct { + name string + args args + setup func(mr *miniredis.Miniredis) + wantErr bool + wantErrType error + }{ + { + name: "验证码正确", + args: args{ + ctx: context.Background(), + phone: "13812345678", + code: "123456", + }, + setup: func(mr *miniredis.Miniredis) { + key := smsKey("13812345678", VerifierSmsPurposeLogin) + keyLock := key + ":lock" + mr.Set(key, "123456") + mr.SetTTL(key, 10*time.Minute) + mr.Set(keyLock, "") + mr.SetTTL(keyLock, 1*time.Minute) + }, + wantErr: false, + }, + { + name: "验证码错误", + args: args{ + ctx: context.Background(), + phone: "13812345679", + code: "654321", + }, + setup: func(mr *miniredis.Miniredis) { + key := smsKey("13812345679", VerifierSmsPurposeLogin) + keyLock := key + ":lock" + mr.Set(key, "123456") + mr.SetTTL(key, 10*time.Minute) + mr.Set(keyLock, "") + mr.SetTTL(keyLock, 1*time.Minute) + }, + wantErr: true, + wantErrType: ErrVerifierServiceInvalid, + }, + { + name: "验证码过期", + args: args{ + ctx: context.Background(), + phone: "13812345680", + code: "123456", + }, + setup: func(mr *miniredis.Miniredis) { + // 不设置验证码,模拟过期情况 + }, + wantErr: true, + wantErrType: ErrVerifierServiceInvalid, + }, + { + name: "手机号错误", + args: args{ + ctx: context.Background(), + phone: "13812345681", + code: "123456", + }, + setup: func(mr *miniredis.Miniredis) { + // 设置一个不同手机号的验证码 + key := smsKey("13800000000", VerifierSmsPurposeLogin) + mr.Set(key, "123456") + mr.SetTTL(key, 10*time.Minute) + }, + wantErr: true, + wantErrType: ErrVerifierServiceInvalid, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 设置 Redis 测试环境 + mr := setupRedisTest(t) + defer mr.Close() + + // 执行测试前的设置 + if tt.setup != nil { + tt.setup(mr) + } + + s := &verifierService{} + err := s.VerifySms(tt.args.ctx, tt.args.phone, tt.args.code) + if (err != nil) != tt.wantErr { + t.Errorf("VerifySms() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // 检查错误类型 + if tt.wantErr && tt.wantErrType != nil && err != tt.wantErrType { + t.Errorf("VerifySms() error = %v, wantErrType %v", err, tt.wantErrType) + return + } + + // 验证成功后 Redis 中应该没有该记录 + if err == nil { + key := smsKey(tt.args.phone, VerifierSmsPurposeLogin) + keyLock := key + ":lock" + + _, redisErr := mr.Get(key) + if redisErr == nil { + t.Errorf("验证码验证成功后应删除,但仍存在") + } + + _, redisErr = mr.Get(keyLock) + if redisErr == nil { + t.Errorf("限速锁验证成功后应删除,但仍存在") + } + } + }) + } +}