diff --git a/README.md b/README.md index 315fec0..e3ef105 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,39 @@ ## todo -购买使用流程 +核心流程: -- [ ] 注册 -- [ ] 登录 +- [x] 注册与登录 + - [ ] 对接短信接口 + - [ ] 人机风险分级验证 + - [ ] jwt 签发 +- [ ] 鉴权 - [ ] 实名认证 - [ ] 充值余额 - [ ] 选择套餐 - [ ] 提取 IP - [ ] 连接 -确认登录方式,决定表结构 +中间件: + +- [ ] CORS +- [ ] Limiter +- [ ] Compress + +环境变量配置默认会话配置 + +oauth token 验证授权范围 + +保存 session 到数据库 + +账单数据表结构修改 + +captcha 自定义生成流程,弃用 store + +短信发送日志 + +captcha_id 关联用户本机信息,实现验证码设备绑定(或者其他方式) + +退出时主动断开数据库缓存等连接 固有字段统一放在最开始 diff --git a/cmd/gen/main.go b/cmd/gen/main.go index 5d4fa35..2adc6e6 100644 --- a/cmd/gen/main.go +++ b/cmd/gen/main.go @@ -8,12 +8,6 @@ import ( ) func main() { - g := gen.NewGenerator(gen.Config{ - OutPath: "web/queries", - ModelPkgPath: "models", - Mode: gen.WithDefaultQuery | gen.WithoutContext, - }) - db, _ := gorm.Open( postgres.Open("host=localhost user=test password=test dbname=app port=5432 sslmode=disable TimeZone=Asia/Shanghai"), &gorm.Config{ @@ -22,6 +16,12 @@ func main() { }, }, ) + + g := gen.NewGenerator(gen.Config{ + OutPath: "web/queries", + ModelPkgPath: "models", + Mode: gen.WithDefaultQuery | gen.WithoutContext, + }) g.UseDB(db) models := g.GenerateAllTable() diff --git a/cmd/main/main.go b/cmd/main/main.go index a7598cd..b3687fd 100644 --- a/cmd/main/main.go +++ b/cmd/main/main.go @@ -7,12 +7,12 @@ import ( "platform/init/env" "platform/init/logs" "platform/init/orm" + "platform/init/rds" "platform/web" "syscall" ) func main() { - logger := slog.Default() // 退出信号 shutdown := make(chan os.Signal, 1) @@ -22,14 +22,14 @@ func main() { env.Init() logs.Init() orm.Init() + rds.Init() // web 服务 app, err := web.New(&web.Config{ - Logger: logger, Listen: ":8080", }) if err != nil { - logger.Error("Failed to create server", slog.Any("error", err)) + slog.Error("Failed to create server", slog.Any("err", err)) return } @@ -38,7 +38,7 @@ func main() { go func() { err = app.Run() if err != nil { - logger.Error("Failed to run server", slog.Any("error", err)) + slog.Error("Failed to run server", slog.Any("err", err)) errCh <- err } errCh <- nil @@ -48,19 +48,19 @@ func main() { exit := false select { case <-shutdown: - logger.Info("Received shutdown signal") + slog.Info("Received shutdown signal") app.Stop() exit = true case err := <-errCh: if err != nil { - logger.Error("Server error", slog.Any("error", err)) + slog.Error("Server error", slog.Any("err", err)) } } if exit { err := <-errCh if err != nil { - logger.Error("Server error", slog.Any("error", err)) + slog.Error("Server error", slog.Any("err", err)) } } } diff --git a/go.mod b/go.mod index bba4715..578d9ba 100644 --- a/go.mod +++ b/go.mod @@ -4,17 +4,22 @@ go 1.24.0 require ( github.com/gofiber/fiber/v2 v2.52.6 + github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/lmittmann/tint v1.0.7 + github.com/redis/go-redis/v9 v9.3.0 + golang.org/x/crypto v0.17.0 gorm.io/driver/postgres v1.5.11 gorm.io/gen v0.3.26 gorm.io/gorm v1.25.12 + gorm.io/plugin/dbresolver v1.5.3 ) require ( github.com/andybalholm/brotli v1.1.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/go-sql-driver/mysql v1.7.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.5.5 // indirect @@ -29,14 +34,12 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect - golang.org/x/crypto v0.17.0 // indirect - golang.org/x/mod v0.14.0 // indirect - golang.org/x/sync v0.6.0 // indirect + golang.org/x/mod v0.17.0 // indirect + golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.28.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.17.0 // indirect + golang.org/x/text v0.23.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect gorm.io/datatypes v1.1.1-0.20230130040222-c43177d3cf8c // indirect gorm.io/driver/mysql v1.5.7 // indirect gorm.io/hints v1.1.0 // indirect - gorm.io/plugin/dbresolver v1.5.3 // indirect ) diff --git a/go.sum b/go.sum index a23c275..94e8577 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,16 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/gofiber/fiber/v2 v2.52.6 h1:Rfp+ILPiYSvvVuIPvxrBns+HJp8qGLDnLJawAu27XVI= @@ -46,6 +54,8 @@ github.com/microsoft/go-mssqldb v0.17.0 h1:Fto83dMZPnYv1Zwx5vHHxpNraeEaUlQ/hhHLg github.com/microsoft/go-mssqldb v0.17.0/go.mod h1:OkoNGhGEs8EZqchVTtochlXruEhEOaO4S0d2sB5aeGQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.3.0 h1:RiVDjmig62jIWp7Kk4XVLs0hzV6pI3PyTnnL0cnn0u0= +github.com/redis/go-redis/v9 v9.3.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -66,8 +76,8 @@ golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -75,8 +85,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= -golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw= +golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -94,14 +104,15 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= +golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= -golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/init/env/env.go b/init/env/env.go index 1c5b6ca..05cbf80 100644 --- a/init/env/env.go +++ b/init/env/env.go @@ -3,40 +3,20 @@ package env import ( "log/slog" "os" + "strconv" "github.com/gofiber/fiber/v2/log" "github.com/joho/godotenv" ) +// region app + var ( AppName = "platform" AppPort = "8080" ) -var ( - DbHost = "localhost" - DbPort = "3306" - DbName string - DbUserName string - DbPassword string -) - -var ( - LogLevel = slog.LevelDebug -) - -func Init() { - err := godotenv.Load() - if err != nil { - log.Debug("❓ 没有本地环境变量") - } else { - log.Debug("✔ 加载本地环境变量") - } - - check() -} - -func check() { +func loadApp() { _AppName := os.Getenv("APP_NAME") if _AppName != "" { AppName = _AppName @@ -46,7 +26,21 @@ func check() { if _AppPort != "" { AppPort = _AppPort } +} +// endregion + +// region db + +var ( + DbHost = "localhost" + DbPort = "5432" + DbName string + DbUserName string + DbPassword string +) + +func loadDb() { _DbHost := os.Getenv("DB_HOST") if _DbHost != "" { DbHost = _DbHost @@ -77,7 +71,54 @@ func check() { } else { panic("环境变量 DB_PASSWORD 的值为空") } +} +// endregion + +// region redis + +var ( + RedisHost = "localhost" + RedisPort = "6379" + RedisDb = 0 + RedisPass = "" +) + +func loadRedis() { + _RedisHost := os.Getenv("REDIS_HOST") + if _RedisHost != "" { + RedisHost = _RedisHost + } + + _RedisPort := os.Getenv("REDIS_PORT") + if _RedisPort != "" { + RedisPort = _RedisPort + } + + _RedisDb := os.Getenv("REDIS_DB") + if _RedisDb != "" { + atoi, err := strconv.Atoi(_RedisDb) + if err != nil { + panic("环境变量 REDIS_DB 的值不是数字") + } + RedisDb = atoi + } + + _RedisPass := os.Getenv("REDIS_PASS") + if _RedisPass != "" { + RedisPass = _RedisPass + } +} + +// endregion + +// region log + +var ( + LogLevel = slog.LevelDebug +) + +func loadLog() { _LogLevel := os.Getenv("LOG_LEVEL") switch _LogLevel { case "debug": @@ -90,3 +131,19 @@ func check() { LogLevel = slog.LevelError } } + +// endregion + +func Init() { + err := godotenv.Load() + if err != nil { + log.Debug("❓ 没有本地环境变量") + } else { + log.Debug("✔ 加载本地环境变量") + } + + loadApp() + loadDb() + loadRedis() + loadLog() +} diff --git a/init/logs/logs.go b/init/logs/logs.go index 911a2e9..74ae61b 100644 --- a/init/logs/logs.go +++ b/init/logs/logs.go @@ -9,14 +9,18 @@ import ( "github.com/lmittmann/tint" ) -var Default *slog.Logger - func Init() { - Default = slog.New( + slog.SetDefault(slog.New( tint.NewHandler(os.Stdout, &tint.Options{ Level: env.LogLevel, TimeFormat: time.Kitchen, + ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { + err, ok := attr.Value.Any().(error) + if ok { + return tint.Err(err) + } + return attr + }, }), - ) - slog.SetDefault(Default) + )) } diff --git a/init/orm/orm.go b/init/orm/orm.go index aaf8df3..a62016e 100644 --- a/init/orm/orm.go +++ b/init/orm/orm.go @@ -4,40 +4,39 @@ import ( "fmt" "log/slog" "platform/init/env" - "platform/init/logs" + "platform/web/queries" "gorm.io/gorm" "gorm.io/gorm/schema" ) import "gorm.io/driver/postgres" -var DB *gorm.DB - func Init() { - logger := logs.Default + // 连接数据库 dsn := fmt.Sprintf( "host=%s user=%s password=%s dbname=%s port=%s sslmode=disable TimeZone=Asia/Shanghai", - env.DbName, env.DbUserName, env.DbPassword, env.DbName, env.DbPort, + env.DbHost, env.DbUserName, env.DbPassword, env.DbName, env.DbPort, ) - - open, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ NamingStrategy: schema.NamingStrategy{ SingularTable: true, }, }) if err != nil { - logger.Error("gorm 打开数据库失败", slog.Any("err", err)) + slog.Error("gorm 初始化数据库失败:", slog.Any("err", err)) panic(err) } - sql, err := open.DB() + // 连接池 + conn, err := db.DB() if err != nil { - logger.Error("gorm open db error: ", slog.Any("err", err)) + slog.Error("gorm 初始化数据库失败:", slog.Any("err", err)) panic(err) } - sql.SetMaxIdleConns(10) - sql.SetMaxOpenConns(100) + conn.SetMaxIdleConns(10) + conn.SetMaxOpenConns(100) - DB = open + // 初始化查询工具 + queries.SetDefault(db) } diff --git a/init/rds/rds.go b/init/rds/rds.go new file mode 100644 index 0000000..4b92615 --- /dev/null +++ b/init/rds/rds.go @@ -0,0 +1,18 @@ +package rds + +import ( + "net" + "platform/init/env" + + "github.com/redis/go-redis/v9" +) + +var Client *redis.Client + +func Init() { + Client = redis.NewClient(&redis.Options{ + Addr: net.JoinHostPort(env.RedisHost, env.RedisPort), + DB: env.RedisDb, + Password: env.RedisPass, + }) +} diff --git a/scripts/dev/docker-compose.yaml b/scripts/dev/docker-compose.yaml index 0454a5f..6eef6b0 100644 --- a/scripts/dev/docker-compose.yaml +++ b/scripts/dev/docker-compose.yaml @@ -1,6 +1,7 @@ name: server-dev services: + postgres: image: postgres:17 restart: always @@ -12,6 +13,11 @@ services: - "5432:5432" volumes: - postgres_data:/var/lib/postgresql/data + redis: + image: redis:7.4 + restart: always + ports: + - "6379:6379" volumes: postgres_data: diff --git a/scripts/sql/init.sql b/scripts/sql/init.sql index af5cb32..b88b89b 100644 --- a/scripts/sql/init.sql +++ b/scripts/sql/init.sql @@ -1,3 +1,22 @@ +-- 清空数据表 +do +$$ + declare + r record; + begin + for r in ( + select + tablename + from + pg_tables + where + schemaname = 'public' + ) loop + execute 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE'; + end loop; + end +$$; + -- ==================== -- region 管理员信息 -- ==================== @@ -14,7 +33,7 @@ create table admin ( email varchar(255), status int not null default 1, last_login timestamp, - last_login_addr varchar(45), + last_login_host varchar(45), last_login_agent varchar(255), created_at timestamp default current_timestamp, updated_at timestamp default current_timestamp, @@ -33,7 +52,7 @@ comment on column admin.phone is '手机号码'; comment on column admin.email is '邮箱'; comment on column admin.status is '状态:1-正常,0-禁用'; comment on column admin.last_login is '最后登录时间'; -comment on column admin.last_login_addr is '最后登录地址'; +comment on column admin.last_login_host is '最后登录地址'; comment on column admin.last_login_agent is '最后登录代理'; comment on column admin.created_at is '创建时间'; comment on column admin.updated_at is '更新时间'; @@ -90,7 +109,7 @@ create table "user" ( contact_qq varchar(255), contact_wechat varchar(255), last_login timestamp, - last_login_addr varchar(45), + last_login_host varchar(45), last_login_agent varchar(255), created_at timestamp default current_timestamp, updated_at timestamp default current_timestamp, @@ -103,26 +122,26 @@ create index user_status_index on "user" (status); -- user表字段注释 comment on table "user" is '用户表'; -comment on column user.id is '用户ID'; -comment on column user.admin_id is '管理员ID'; -comment on column user.password is '用户密码'; -comment on column user.username is '用户名'; -comment on column user.phone is '手机号码'; -comment on column user.name is '真实姓名'; -comment on column user.avatar is '头像URL'; -comment on column user.status is '用户状态:1-正常,0-禁用'; -comment on column user.balance is '账户余额'; -comment on column user.id_type is '认证类型:0-未认证,1-个人认证,2-企业认证'; -comment on column user.id_no is '身份证号或营业执照号'; -comment on column user.id_token is '身份验证标识'; -comment on column user.contact_qq is 'QQ联系方式'; -comment on column user.contact_wechat is '微信联系方式'; -comment on column user.last_login is '最后登录时间'; -comment on column user.last_login_addr is '最后登录地址'; -comment on column user.last_login_agent is '最后登录代理'; -comment on column user.created_at is '创建时间'; -comment on column user.updated_at is '更新时间'; -comment on column user.deleted_at is '删除时间'; +comment on column "user".id is '用户ID'; +comment on column "user".admin_id is '管理员ID'; +comment on column "user".password is '用户密码'; +comment on column "user".username is '用户名'; +comment on column "user".phone is '手机号码'; +comment on column "user".name is '真实姓名'; +comment on column "user".avatar is '头像URL'; +comment on column "user".status is '用户状态:1-正常,0-禁用'; +comment on column "user".balance is '账户余额'; +comment on column "user".id_type is '认证类型:0-未认证,1-个人认证,2-企业认证'; +comment on column "user".id_no is '身份证号或营业执照号'; +comment on column "user".id_token is '身份验证标识'; +comment on column "user".contact_qq is 'QQ联系方式'; +comment on column "user".contact_wechat is '微信联系方式'; +comment on column "user".last_login is '最后登录时间'; +comment on column "user".last_login_host is '最后登录地址'; +comment on column "user".last_login_agent is '最后登录代理'; +comment on column "user".created_at is '创建时间'; +comment on column "user".updated_at is '更新时间'; +comment on column "user".deleted_at is '删除时间'; -- user_role drop table if exists user_role cascade; @@ -150,6 +169,51 @@ comment on column user_role.deleted_at is '删除时间'; -- endregion +-- ==================== +-- region 客户端信息 +-- ==================== + +drop table if exists client cascade; +create table client ( + id serial primary key, + client_id varchar(255) not null unique, + client_secret varchar(255) not null, + redirect_uri varchar(255), + grant_code bool not null default false, + grant_client bool not null default false, + grant_refresh bool not null default false, + spec int not null, + name varchar(255) not null, + version int not null, + status int not null default 1, + created_at timestamp default current_timestamp, + updated_at timestamp default current_timestamp, + deleted_at timestamp +); + +create index client_client_id_index on client (client_id); +create index client_name_index on client (name); +create index client_status_index on client (status); + +-- client表字段注释 +comment on table client is '客户端表'; +comment on column client.id is '客户端ID'; +comment on column client.client_id is 'OAuth2客户端标识符'; +comment on column client.client_secret is 'OAuth2客户端密钥'; +comment on column client.redirect_uri is 'OAuth2 重定向URI'; +comment on column client.grant_code is '允许授权码授予'; +comment on column client.grant_client is '允许客户端凭证授予'; +comment on column client.grant_refresh is '允许刷新令牌授予'; +comment on column client.spec is '安全规范:0-web,1-native,2-browser'; +comment on column client.name is '名称'; +comment on column client.version is '版本'; +comment on column client.status is '状态:1-正常,0-禁用'; +comment on column client.created_at is '创建时间'; +comment on column client.updated_at is '更新时间'; +comment on column client.deleted_at is '删除时间'; + +-- endregion + -- ==================== -- region 权限信息 -- ==================== @@ -168,6 +232,7 @@ create table permission ( deleted_at timestamp ); create index permission_parent_id_index on permission (parent_id); +create index permission_name_index on permission (name); -- permission表字段注释 comment on table permission is '权限表'; @@ -283,6 +348,32 @@ comment on column admin_role_permission_link.created_at is '创建时间'; comment on column admin_role_permission_link.updated_at is '更新时间'; comment on column admin_role_permission_link.deleted_at is '删除时间'; +-- client_permission_link +drop table if exists client_permission_link cascade; +create table client_permission_link ( + id serial primary key, + client_id int not null references client (id) + on update cascade + on delete cascade, + permission_id int not null references permission (id) + on update cascade + on delete cascade, + created_at timestamp default current_timestamp, + updated_at timestamp default current_timestamp, + deleted_at timestamp +); +create index client_permission_link_client_id_index on client_permission_link (client_id); +create index client_permission_link_permission_id_index on client_permission_link (permission_id); + +-- client_permission_link表字段注释 +comment on table client_permission_link is '客户端权限关联表'; +comment on column client_permission_link.id is '关联ID'; +comment on column client_permission_link.client_id is '客户端ID'; +comment on column client_permission_link.permission_id is '权限ID'; +comment on column client_permission_link.created_at is '创建时间'; +comment on column client_permission_link.updated_at is '更新时间'; +comment on column client_permission_link.deleted_at is '删除时间'; + -- endregion -- ==================== @@ -324,19 +415,19 @@ create table whitelist ( user_id int not null references "user" (id) on update cascade on delete cascade, - address varchar(45) not null, + host varchar(45) not null, created_at timestamp default current_timestamp, updated_at timestamp default current_timestamp, deleted_at timestamp ); create index whitelist_user_id_index on whitelist (user_id); -create index whitelist_address_index on whitelist (address); +create index whitelist_host_index on whitelist (host); -- whitelist表字段注释 comment on table whitelist is '白名单表'; comment on column whitelist.id is '白名单ID'; comment on column whitelist.user_id is '用户ID'; -comment on column whitelist.address is 'IP地址'; +comment on column whitelist.host is 'IP地址'; comment on column whitelist.created_at is '创建时间'; comment on column whitelist.updated_at is '更新时间'; comment on column whitelist.deleted_at is '删除时间'; @@ -351,7 +442,7 @@ create table channel ( node_id int references node (id) -- on update cascade -- on delete set null, - user_addr varchar(255) not null, + user_host varchar(255) not null, node_port int, auth_ip bool not null default false, auth_pass bool not null default false, @@ -365,7 +456,7 @@ create table channel ( ); create index channel_user_id_index on channel (user_id); create index channel_node_id_index on channel (node_id); -create index channel_user_addr_index on channel (user_addr); +create index channel_user_host_index on channel (user_host); create index channel_node_port_index on channel (node_port); create index channel_expiration_index on channel (expiration); @@ -374,7 +465,7 @@ comment on table channel is '通道表'; comment on column channel.id is '通道ID'; comment on column channel.user_id is '用户ID'; comment on column channel.node_id is '节点ID'; -comment on column channel.user_addr is '用户地址'; +comment on column channel.user_host is '用户地址'; comment on column channel.node_port is '节点端口'; comment on column channel.auth_ip is 'IP认证'; comment on column channel.auth_pass is '密码认证'; diff --git a/web/auth.go b/web/auth.go index efb3895..9c50676 100644 --- a/web/auth.go +++ b/web/auth.go @@ -1 +1,48 @@ package web + +import ( + "platform/web/common" + "strings" + + "platform/web/services" + + "github.com/gofiber/fiber/v2" +) + +// Protect 创建针对单个路由的鉴权中间件 +func Protect(permissions ...string) fiber.Handler { + return func(c *fiber.Ctx) error { + // 获取令牌 + var header = c.Get("Authorization") + var token = strings.TrimPrefix(header, "Bearer ") + if token == "" { + return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ + Error: true, + Message: "没有权限", + }) + } + + // 验证令牌 + auth, err := services.Session.Find(c.Context(), token) + if err != nil { + return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ + Error: true, + Message: "没有权限", + }) + } + + // 检查权限 + if len(permissions) > 0 && !auth.AnyPermission(permissions...) { + return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ + Error: true, + Message: "拒绝访问", + }) + } + + // 将认证信息存储在上下文中 + c.Locals("auth", auth) + c.Locals("access_token", token) // 存储原始令牌,便于后续操作 + + return c.Next() + } +} diff --git a/web/common/types.go b/web/common/types.go new file mode 100644 index 0000000..3e6348c --- /dev/null +++ b/web/common/types.go @@ -0,0 +1,7 @@ +package common + +// ErrResp 定义通用错误响应格式 +type ErrResp struct { + Message string `json:"message"` + Error bool `json:"error"` +} diff --git a/web/error.go b/web/error.go new file mode 100644 index 0000000..2c13f6d --- /dev/null +++ b/web/error.go @@ -0,0 +1,19 @@ +package web + +import ( + "errors" + + "github.com/gofiber/fiber/v2" +) + +func ErrorHandler(c *fiber.Ctx, err error) error { + code := fiber.StatusInternalServerError + message := "服务器异常" + var e *fiber.Error + if errors.As(err, &e) { + code = e.Code + message = e.Message + } + c.Set(fiber.HeaderContentType, fiber.MIMETextPlainCharsetUTF8) + return c.Status(code).SendString(message) +} diff --git a/web/handlers/client.go b/web/handlers/client.go new file mode 100644 index 0000000..d527f24 --- /dev/null +++ b/web/handlers/client.go @@ -0,0 +1,52 @@ +package handlers + +import ( + "platform/web/models" + q "platform/web/queries" + "time" + + "github.com/gofiber/fiber/v2" + "golang.org/x/crypto/bcrypt" +) + +type CreateClientReq struct { + ClientID string `query:"client_id"` + ClientSecret string `query:"client_secret"` +} + +func CreateClient(c *fiber.Ctx) error { + // 验证请求参数 + req := new(CreateClientReq) + if err := c.QueryParser(req); err != nil { + return err + } + if req.ClientID == "" { + return fiber.NewError(fiber.StatusBadRequest, "client_id不能为空") + } + if req.ClientSecret == "" { + return fiber.NewError(fiber.StatusBadRequest, "client_secret不能为空") + } + + // 创建客户端 + hashedSecret, err := bcrypt.GenerateFromPassword([]byte(req.ClientSecret), bcrypt.DefaultCost) + if err != nil { + return err + } + client := &models.Client{ + ClientID: req.ClientID, + ClientSecret: string(hashedSecret), + Name: "默认客户端 - " + time.Now().String(), + Spec: 0, + GrantCode: true, + GrantClient: true, + GrantRefresh: true, + Version: 0, + } + + err = q.Client.Create(client) + if err != nil { + return err + } + + return c.JSON(client) +} diff --git a/web/handlers/login.go b/web/handlers/login.go new file mode 100644 index 0000000..713aaf3 --- /dev/null +++ b/web/handlers/login.go @@ -0,0 +1,107 @@ +package handlers + +import ( + "errors" + "platform/web/models" + q "platform/web/queries" + "platform/web/services" + "time" + + "github.com/gofiber/fiber/v2" + "gorm.io/gorm" +) + +type LoginReq struct { + Username string `json:"username"` + Password string `json:"password"` + Remember bool `json:"remember"` +} + +type LoginResp struct { + Token string `json:"token"` + Expires int64 `json:"expires"` +} + +func Login(c *fiber.Ctx) error { + + // 验证请求参数 + req := new(LoginReq) + if err := c.BodyParser(req); err != nil { + return err + } + if req.Username == "" { + return fiber.NewError(fiber.StatusBadRequest, "手机号不能为空") + } + if req.Password == "" { + return fiber.NewError(fiber.StatusBadRequest, "验证码不能为空") + } + + return loginByPhone(c, req) +} + +func loginByPhone(c *fiber.Ctx, req *LoginReq) error { + + // 验证验证码 + ok, err := services.Verifier.VerifySms(c.Context(), req.Username, req.Password) + if err != nil { + return err + } + if !ok { + return fiber.NewError(fiber.StatusBadRequest, "验证码错误") + } + + // 查找用户 todo 获取权限信息 + var tx = q.Q.Begin() + + var user *models.User + user, err = tx.User. + Where(tx.User.Phone.Eq(req.Username)). + Take() + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + // 如果用户不存在,初始化用户 todo 保存默认权限信息 + if user == nil { + user = &models.User{ + Phone: req.Username, + } + } + + // 更新用户的登录时间 + user.LastLogin = time.Now() + user.LastLoginHost = c.IP() + user.LastLoginAgent = c.Get("User-Agent") + if err := tx.User.Omit(q.User.AdminID).Save(user); err != nil { + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + // 保存到会话 + auth := services.AuthContext{ + Permissions: map[string]struct{}{ + "user": {}, + }, + Payload: services.Payload{ + Type: services.PayloadUser, + Id: user.ID, + }, + } + duration := time.Hour * 24 + if req.Remember { + duration *= 7 + } + token, err := services.Session.Create(c.Context(), auth) + if err != nil { + return err + } + + return c.JSON(LoginResp{ + Token: token.AccessToken, + Expires: token.AccessTokenExpires.Unix(), + }) +} diff --git a/web/handlers/oauth.go b/web/handlers/oauth.go new file mode 100644 index 0000000..20dcf7c --- /dev/null +++ b/web/handlers/oauth.go @@ -0,0 +1,243 @@ +package handlers + +import ( + "encoding/base64" + "errors" + "platform/web/models" + q "platform/web/queries" + "platform/web/services" + "strings" + "time" + + "github.com/gofiber/fiber/v2" + "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" +) + +// region Token + +type TokenReq struct { + ClientID string `json:"client_id" form:"client_id"` + ClientSecret string `json:"client_secret" form:"client_secret"` + GrantType TokenGrantType `json:"grant_type" form:"grant_type"` + Code string `json:"code" form:"code"` + RedirectURI string `json:"redirect_uri" form:"redirect_uri"` + CodeVerifier string `json:"code_verifier" form:"code_verifier"` + RefreshToken string `json:"refresh_token" form:"refresh_token"` + Scope string `json:"scope" form:"scope"` +} + +type TokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type"` + Scope string `json:"scope,omitempty"` + ExpiresIn int `json:"expires_in"` +} + +type TokenErrResp struct { + Error string `json:"error"` + Description string `json:"error_description,omitempty"` +} + +type TokenGrantType string + +const ( + AuthorizationCode = TokenGrantType("authorization_code") + ClientCredentials = TokenGrantType("client_credentials") + RefreshToken = TokenGrantType("refresh_token") +) + +// Token 处理 OAuth2.0 授权请求 +func Token(c *fiber.Ctx) error { + + // 验证请求参数 + req := new(TokenReq) + if err := c.BodyParser(req); err != nil { + return sendError(c, services.ErrOauthInvalidRequest, "无法解析请求参数") + } + if req.GrantType == "" { + return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:grant_type") + } + + // 基于授权类型处理请求 + switch req.GrantType { + + case AuthorizationCode: + return authorizationCode(c, req) + + case ClientCredentials: + return clientCredentials(c, req) + + case RefreshToken: + return refreshToken(c, req) + + default: + return sendError(c, services.ErrOauthUnsupportedGrantType) + } +} + +// 授权码 +func authorizationCode(c *fiber.Ctx, req *TokenReq) error { + if req.Code == "" { + return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:code") + } + + client, err := protect(c, services.GrantTypeAuthorizationCode, req.ClientID, req.ClientSecret) + if err != nil { + return sendError(c, err) + } + + token, err := services.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier) + if err != nil { + return sendError(c, err.(services.AuthServiceOauthError)) + } + + return sendSuccess(c, token) +} + +// 客户端凭证 +func clientCredentials(c *fiber.Ctx, req *TokenReq) error { + client, err := protect(c, services.GrantTypeClientCredentials, req.ClientID, req.ClientSecret) + if err != nil { + return sendError(c, err) + } + + scope := strings.Split(req.Scope, ",") + token, err := services.Auth.OauthClientCredentials(c.Context(), client, scope) + if err != nil { + return sendError(c, err.(services.AuthServiceOauthError)) + } + + return sendSuccess(c, token) +} + +// 刷新令牌 +func refreshToken(c *fiber.Ctx, req *TokenReq) error { + if req.RefreshToken == "" { + return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:refresh_token") + } + + client, err := protect(c, services.GrantTypeRefreshToken, req.ClientID, req.ClientSecret) + if err != nil { + return sendError(c, err) + } + + scope := strings.Split(req.Scope, ",") + token, err := services.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope) + if err != nil { + return sendError(c, err.(services.AuthServiceOauthError)) + } + + return sendSuccess(c, token) +} + +// 检查客户端凭证 +func protect(c *fiber.Ctx, grant services.GrantType, clientId, clientSecret string) (*models.Client, error) { + header := c.Get("Authorization") + if header != "" { + basic := strings.TrimPrefix(header, "Basic ") + if basic != "" { + base, err := base64.URLEncoding.DecodeString(basic) + if err != nil { + return nil, err + } + parts := strings.SplitN(string(base), ":", 2) + if len(parts) == 2 { + clientId = parts[0] + clientSecret = parts[1] + } + } + } + + // 查找客户端 + if clientId == "" { + return nil, services.ErrOauthInvalidRequest + } + client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take() + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, services.ErrOauthInvalidClient + } + return nil, err + } + + // 验证客户端状态 + if client.Status != 1 { + return nil, services.ErrOauthUnauthorizedClient + } + + // 验证授权类型 + switch grant { + case services.GrantTypeAuthorizationCode: + if !client.GrantCode { + return nil, services.ErrOauthUnauthorizedClient + } + case services.GrantTypeClientCredentials: + if !client.GrantClient || client.Spec != 0 { + return nil, services.ErrOauthUnauthorizedClient + } + case services.GrantTypeRefreshToken: + if !client.GrantRefresh { + return nil, services.ErrOauthUnauthorizedClient + } + } + + // 如果客户端是 confidential,验证 client_secret,失败返回错误 + if client.Spec == 0 { + if clientSecret == "" { + return nil, services.ErrOauthInvalidRequest + } + if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil { + return nil, services.ErrOauthInvalidClient + } + } + + return client, nil +} + +// 发送成功响应 +func sendSuccess(c *fiber.Ctx, details *services.TokenDetails) error { + return c.JSON(TokenResp{ + AccessToken: details.AccessToken, + TokenType: "Bearer", + ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()), + RefreshToken: details.RefreshToken, + }) +} + +// 发送错误响应 +func sendError(c *fiber.Ctx, err error, description ...string) error { + var sErr services.AuthServiceOauthError + if errors.As(err, &sErr) { + status := fiber.StatusBadRequest + var desc string + switch { + case errors.Is(sErr, services.ErrOauthInvalidRequest): + desc = "无效的请求" + case errors.Is(sErr, services.ErrOauthInvalidClient): + status = fiber.StatusUnauthorized + desc = "无效的客户端凭证" + case errors.Is(sErr, services.ErrOauthInvalidGrant): + desc = "无效的授权凭证" + case errors.Is(sErr, services.ErrOauthInvalidScope): + desc = "无效的授权范围" + case errors.Is(sErr, services.ErrOauthUnauthorizedClient): + desc = "未授权的客户端" + case errors.Is(sErr, services.ErrOauthUnsupportedGrantType): + desc = "不支持的授权类型" + } + if len(description) > 0 { + desc = description[0] + } + + return c.Status(status).JSON(TokenErrResp{ + Error: string(sErr), + Description: desc, + }) + } + + return err +} + +// endregion diff --git a/web/handlers/verifier.go b/web/handlers/verifier.go new file mode 100644 index 0000000..c64beb1 --- /dev/null +++ b/web/handlers/verifier.go @@ -0,0 +1,44 @@ +package handlers + +import ( + "errors" + "platform/web/services" + "regexp" + "strconv" + + "github.com/gofiber/fiber/v2" +) + +type VerifierReq struct { + Purpose services.VerifierSmsPurpose `json:"purpose"` + Phone string `json:"phone"` +} + +func SmsCode(c *fiber.Ctx) error { + + // 解析请求参数 + req := new(VerifierReq) + if err := c.BodyParser(req); err != nil { + return err + } + match, err := regexp.MatchString(`^1[3-9]\d{9}$`, req.Phone) + if err != nil { + return err + } + if !match { + return fiber.NewError(fiber.StatusBadRequest, "手机号格式错误") + } + + // 发送身份验证码 + err = services.Verifier.SendSms(c.Context(), req.Phone, req.Purpose) + if err != nil { + var sErr services.VerifierServiceSendLimitErr + if errors.As(err, &sErr) { + return fiber.NewError(fiber.StatusTooManyRequests, strconv.Itoa(int(sErr))) + } + return err + } + + // 发送成功 + return nil +} diff --git a/web/models/admin.gen.go b/web/models/admin.gen.go index a66a9cc..aee901a 100644 --- a/web/models/admin.gen.go +++ b/web/models/admin.gen.go @@ -23,7 +23,7 @@ type Admin struct { Email string `gorm:"column:email;comment:邮箱" json:"email"` // 邮箱 Status int32 `gorm:"column:status;not null;default:1;comment:状态:1-正常,0-禁用" json:"status"` // 状态:1-正常,0-禁用 LastLogin time.Time `gorm:"column:last_login;comment:最后登录时间" json:"last_login"` // 最后登录时间 - LastLoginAddr string `gorm:"column:last_login_addr;comment:最后登录地址" json:"last_login_addr"` // 最后登录地址 + LastLoginHost string `gorm:"column:last_login_host;comment:最后登录地址" json:"last_login_host"` // 最后登录地址 LastLoginAgent string `gorm:"column:last_login_agent;comment:最后登录代理" json:"last_login_agent"` // 最后登录代理 CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 diff --git a/web/models/channel.gen.go b/web/models/channel.gen.go index 2c28896..f8fc279 100644 --- a/web/models/channel.gen.go +++ b/web/models/channel.gen.go @@ -17,7 +17,7 @@ type Channel struct { ID int32 `gorm:"column:id;primaryKey;autoIncrement:true;comment:通道ID" json:"id"` // 通道ID UserID int32 `gorm:"column:user_id;not null;comment:用户ID" json:"user_id"` // 用户ID NodeID int32 `gorm:"column:node_id;comment:节点ID" json:"node_id"` // 节点ID - UserAddr string `gorm:"column:user_addr;not null;comment:用户地址" json:"user_addr"` // 用户地址 + UserHost string `gorm:"column:user_host;not null;comment:用户地址" json:"user_host"` // 用户地址 NodePort int32 `gorm:"column:node_port;comment:节点端口" json:"node_port"` // 节点端口 AuthIP bool `gorm:"column:auth_ip;not null;comment:IP认证" json:"auth_ip"` // IP认证 AuthPass bool `gorm:"column:auth_pass;not null;comment:密码认证" json:"auth_pass"` // 密码认证 diff --git a/web/models/client.gen.go b/web/models/client.gen.go new file mode 100644 index 0000000..8f729cb --- /dev/null +++ b/web/models/client.gen.go @@ -0,0 +1,36 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package models + +import ( + "time" + + "gorm.io/gorm" +) + +const TableNameClient = "client" + +// Client mapped from table +type Client struct { + ID int32 `gorm:"column:id;primaryKey;autoIncrement:true;comment:客户端ID" json:"id"` // 客户端ID + ClientID string `gorm:"column:client_id;not null;comment:OAuth2客户端标识符" json:"client_id"` // OAuth2客户端标识符 + ClientSecret string `gorm:"column:client_secret;not null;comment:OAuth2客户端密钥" json:"client_secret"` // OAuth2客户端密钥 + RedirectURI string `gorm:"column:redirect_uri;comment:OAuth2 重定向URI" json:"redirect_uri"` // OAuth2 重定向URI + GrantCode bool `gorm:"column:grant_code;not null;comment:允许授权码授予" json:"grant_code"` // 允许授权码授予 + GrantClient bool `gorm:"column:grant_client;not null;comment:允许客户端凭证授予" json:"grant_client"` // 允许客户端凭证授予 + GrantRefresh bool `gorm:"column:grant_refresh;not null;comment:允许刷新令牌授予" json:"grant_refresh"` // 允许刷新令牌授予 + Spec int32 `gorm:"column:spec;not null;comment:安全规范:0-web,1-native,2-browser" json:"spec"` // 安全规范:0-web,1-native,2-browser + Name string `gorm:"column:name;not null;comment:名称" json:"name"` // 名称 + Version int32 `gorm:"column:version;not null;comment:版本" json:"version"` // 版本 + Status int32 `gorm:"column:status;not null;default:1;comment:状态:1-正常,0-禁用" json:"status"` // 状态:1-正常,0-禁用 + CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 + UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 + DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间 +} + +// TableName Client's table name +func (*Client) TableName() string { + return TableNameClient +} diff --git a/web/models/client_permission_link.gen.go b/web/models/client_permission_link.gen.go new file mode 100644 index 0000000..9fd5f28 --- /dev/null +++ b/web/models/client_permission_link.gen.go @@ -0,0 +1,28 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package models + +import ( + "time" + + "gorm.io/gorm" +) + +const TableNameClientPermissionLink = "client_permission_link" + +// ClientPermissionLink mapped from table +type ClientPermissionLink struct { + ID int32 `gorm:"column:id;primaryKey;autoIncrement:true;comment:关联ID" json:"id"` // 关联ID + ClientID int32 `gorm:"column:client_id;not null;comment:客户端ID" json:"client_id"` // 客户端ID + PermissionID int32 `gorm:"column:permission_id;not null;comment:权限ID" json:"permission_id"` // 权限ID + CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 + UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 + DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间 +} + +// TableName ClientPermissionLink's table name +func (*ClientPermissionLink) TableName() string { + return TableNameClientPermissionLink +} diff --git a/web/models/user.gen.go b/web/models/user.gen.go index 0ba9442..8fb5752 100644 --- a/web/models/user.gen.go +++ b/web/models/user.gen.go @@ -14,27 +14,27 @@ const TableNameUser = "user" // User mapped from table type User struct { - ID int32 `gorm:"column:id;primaryKey;autoIncrement:true" json:"id"` - AdminID int32 `gorm:"column:admin_id" json:"admin_id"` - Phone string `gorm:"column:phone;not null" json:"phone"` - Username string `gorm:"column:username" json:"username"` + ID int32 `gorm:"column:id;primaryKey;autoIncrement:true;comment:用户ID" json:"id"` // 用户ID + AdminID int32 `gorm:"column:admin_id;comment:管理员ID" json:"admin_id"` // 管理员ID + Phone string `gorm:"column:phone;not null;comment:手机号码" json:"phone"` // 手机号码 + Username string `gorm:"column:username;comment:用户名" json:"username"` // 用户名 Email string `gorm:"column:email" json:"email"` - Password string `gorm:"column:password" json:"password"` - Name string `gorm:"column:name" json:"name"` - Avatar string `gorm:"column:avatar" json:"avatar"` - Status int32 `gorm:"column:status;not null;default:1" json:"status"` - Balance float64 `gorm:"column:balance;not null" json:"balance"` - IDType int32 `gorm:"column:id_type;not null" json:"id_type"` - IDNo string `gorm:"column:id_no" json:"id_no"` - IDToken string `gorm:"column:id_token" json:"id_token"` - ContactQq string `gorm:"column:contact_qq" json:"contact_qq"` - ContactWechat string `gorm:"column:contact_wechat" json:"contact_wechat"` - LastLogin time.Time `gorm:"column:last_login" json:"last_login"` - LastLoginAddr string `gorm:"column:last_login_addr" json:"last_login_addr"` - LastLoginAgent string `gorm:"column:last_login_agent" json:"last_login_agent"` - CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP" json:"created_at"` - UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP" json:"updated_at"` - DeletedAt gorm.DeletedAt `gorm:"column:deleted_at" json:"deleted_at"` + Password string `gorm:"column:password;comment:用户密码" json:"password"` // 用户密码 + Name string `gorm:"column:name;comment:真实姓名" json:"name"` // 真实姓名 + Avatar string `gorm:"column:avatar;comment:头像URL" json:"avatar"` // 头像URL + Status int32 `gorm:"column:status;not null;default:1;comment:用户状态:1-正常,0-禁用" json:"status"` // 用户状态:1-正常,0-禁用 + Balance float64 `gorm:"column:balance;not null;comment:账户余额" json:"balance"` // 账户余额 + IDType int32 `gorm:"column:id_type;not null;comment:认证类型:0-未认证,1-个人认证,2-企业认证" json:"id_type"` // 认证类型:0-未认证,1-个人认证,2-企业认证 + IDNo string `gorm:"column:id_no;comment:身份证号或营业执照号" json:"id_no"` // 身份证号或营业执照号 + IDToken string `gorm:"column:id_token;comment:身份验证标识" json:"id_token"` // 身份验证标识 + ContactQq string `gorm:"column:contact_qq;comment:QQ联系方式" json:"contact_qq"` // QQ联系方式 + ContactWechat string `gorm:"column:contact_wechat;comment:微信联系方式" json:"contact_wechat"` // 微信联系方式 + LastLogin time.Time `gorm:"column:last_login;comment:最后登录时间" json:"last_login"` // 最后登录时间 + LastLoginHost string `gorm:"column:last_login_host;comment:最后登录地址" json:"last_login_host"` // 最后登录地址 + LastLoginAgent string `gorm:"column:last_login_agent;comment:最后登录代理" json:"last_login_agent"` // 最后登录代理 + CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 + UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 + DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间 } // TableName User's table name diff --git a/web/models/whitelist.gen.go b/web/models/whitelist.gen.go index 1c8da3d..3e56ac2 100644 --- a/web/models/whitelist.gen.go +++ b/web/models/whitelist.gen.go @@ -16,7 +16,7 @@ const TableNameWhitelist = "whitelist" type Whitelist struct { ID int32 `gorm:"column:id;primaryKey;autoIncrement:true;comment:白名单ID" json:"id"` // 白名单ID UserID int32 `gorm:"column:user_id;not null;comment:用户ID" json:"user_id"` // 用户ID - Address string `gorm:"column:address;not null;comment:IP地址" json:"address"` // IP地址 + Host string `gorm:"column:host;not null;comment:IP地址" json:"host"` // IP地址 CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间 diff --git a/web/queries/admin.gen.go b/web/queries/admin.gen.go index 29dae1d..417719f 100644 --- a/web/queries/admin.gen.go +++ b/web/queries/admin.gen.go @@ -36,7 +36,7 @@ func newAdmin(db *gorm.DB, opts ...gen.DOOption) admin { _admin.Email = field.NewString(tableName, "email") _admin.Status = field.NewInt32(tableName, "status") _admin.LastLogin = field.NewTime(tableName, "last_login") - _admin.LastLoginAddr = field.NewString(tableName, "last_login_addr") + _admin.LastLoginHost = field.NewString(tableName, "last_login_host") _admin.LastLoginAgent = field.NewString(tableName, "last_login_agent") _admin.CreatedAt = field.NewTime(tableName, "created_at") _admin.UpdatedAt = field.NewTime(tableName, "updated_at") @@ -60,7 +60,7 @@ type admin struct { Email field.String // 邮箱 Status field.Int32 // 状态:1-正常,0-禁用 LastLogin field.Time // 最后登录时间 - LastLoginAddr field.String // 最后登录地址 + LastLoginHost field.String // 最后登录地址 LastLoginAgent field.String // 最后登录代理 CreatedAt field.Time // 创建时间 UpdatedAt field.Time // 更新时间 @@ -90,7 +90,7 @@ func (a *admin) updateTableName(table string) *admin { a.Email = field.NewString(table, "email") a.Status = field.NewInt32(table, "status") a.LastLogin = field.NewTime(table, "last_login") - a.LastLoginAddr = field.NewString(table, "last_login_addr") + a.LastLoginHost = field.NewString(table, "last_login_host") a.LastLoginAgent = field.NewString(table, "last_login_agent") a.CreatedAt = field.NewTime(table, "created_at") a.UpdatedAt = field.NewTime(table, "updated_at") @@ -121,7 +121,7 @@ func (a *admin) fillFieldMap() { a.fieldMap["email"] = a.Email a.fieldMap["status"] = a.Status a.fieldMap["last_login"] = a.LastLogin - a.fieldMap["last_login_addr"] = a.LastLoginAddr + a.fieldMap["last_login_host"] = a.LastLoginHost a.fieldMap["last_login_agent"] = a.LastLoginAgent a.fieldMap["created_at"] = a.CreatedAt a.fieldMap["updated_at"] = a.UpdatedAt diff --git a/web/queries/channel.gen.go b/web/queries/channel.gen.go index d27bb17..537edf3 100644 --- a/web/queries/channel.gen.go +++ b/web/queries/channel.gen.go @@ -30,7 +30,7 @@ func newChannel(db *gorm.DB, opts ...gen.DOOption) channel { _channel.ID = field.NewInt32(tableName, "id") _channel.UserID = field.NewInt32(tableName, "user_id") _channel.NodeID = field.NewInt32(tableName, "node_id") - _channel.UserAddr = field.NewString(tableName, "user_addr") + _channel.UserHost = field.NewString(tableName, "user_host") _channel.NodePort = field.NewInt32(tableName, "node_port") _channel.AuthIP = field.NewBool(tableName, "auth_ip") _channel.AuthPass = field.NewBool(tableName, "auth_pass") @@ -54,7 +54,7 @@ type channel struct { ID field.Int32 // 通道ID UserID field.Int32 // 用户ID NodeID field.Int32 // 节点ID - UserAddr field.String // 用户地址 + UserHost field.String // 用户地址 NodePort field.Int32 // 节点端口 AuthIP field.Bool // IP认证 AuthPass field.Bool // 密码认证 @@ -84,7 +84,7 @@ func (c *channel) updateTableName(table string) *channel { c.ID = field.NewInt32(table, "id") c.UserID = field.NewInt32(table, "user_id") c.NodeID = field.NewInt32(table, "node_id") - c.UserAddr = field.NewString(table, "user_addr") + c.UserHost = field.NewString(table, "user_host") c.NodePort = field.NewInt32(table, "node_port") c.AuthIP = field.NewBool(table, "auth_ip") c.AuthPass = field.NewBool(table, "auth_pass") @@ -115,7 +115,7 @@ func (c *channel) fillFieldMap() { c.fieldMap["id"] = c.ID c.fieldMap["user_id"] = c.UserID c.fieldMap["node_id"] = c.NodeID - c.fieldMap["user_addr"] = c.UserAddr + c.fieldMap["user_host"] = c.UserHost c.fieldMap["node_port"] = c.NodePort c.fieldMap["auth_ip"] = c.AuthIP c.fieldMap["auth_pass"] = c.AuthPass diff --git a/web/queries/client.gen.go b/web/queries/client.gen.go new file mode 100644 index 0000000..10dc2a3 --- /dev/null +++ b/web/queries/client.gen.go @@ -0,0 +1,371 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package queries + +import ( + "context" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "gorm.io/gen" + "gorm.io/gen/field" + + "gorm.io/plugin/dbresolver" + + "platform/web/models" +) + +func newClient(db *gorm.DB, opts ...gen.DOOption) client { + _client := client{} + + _client.clientDo.UseDB(db, opts...) + _client.clientDo.UseModel(&models.Client{}) + + tableName := _client.clientDo.TableName() + _client.ALL = field.NewAsterisk(tableName) + _client.ID = field.NewInt32(tableName, "id") + _client.ClientID = field.NewString(tableName, "client_id") + _client.ClientSecret = field.NewString(tableName, "client_secret") + _client.RedirectURI = field.NewString(tableName, "redirect_uri") + _client.GrantCode = field.NewBool(tableName, "grant_code") + _client.GrantClient = field.NewBool(tableName, "grant_client") + _client.GrantRefresh = field.NewBool(tableName, "grant_refresh") + _client.Spec = field.NewInt32(tableName, "spec") + _client.Name = field.NewString(tableName, "name") + _client.Version = field.NewInt32(tableName, "version") + _client.Status = field.NewInt32(tableName, "status") + _client.CreatedAt = field.NewTime(tableName, "created_at") + _client.UpdatedAt = field.NewTime(tableName, "updated_at") + _client.DeletedAt = field.NewField(tableName, "deleted_at") + + _client.fillFieldMap() + + return _client +} + +type client struct { + clientDo + + ALL field.Asterisk + ID field.Int32 // 客户端ID + ClientID field.String // OAuth2客户端标识符 + ClientSecret field.String // OAuth2客户端密钥 + RedirectURI field.String // OAuth2 重定向URI + GrantCode field.Bool // 允许授权码授予 + GrantClient field.Bool // 允许客户端凭证授予 + GrantRefresh field.Bool // 允许刷新令牌授予 + Spec field.Int32 // 安全规范:0-web,1-native,2-browser + Name field.String // 名称 + Version field.Int32 // 版本 + Status field.Int32 // 状态:1-正常,0-禁用 + CreatedAt field.Time // 创建时间 + UpdatedAt field.Time // 更新时间 + DeletedAt field.Field // 删除时间 + + fieldMap map[string]field.Expr +} + +func (c client) Table(newTableName string) *client { + c.clientDo.UseTable(newTableName) + return c.updateTableName(newTableName) +} + +func (c client) As(alias string) *client { + c.clientDo.DO = *(c.clientDo.As(alias).(*gen.DO)) + return c.updateTableName(alias) +} + +func (c *client) updateTableName(table string) *client { + c.ALL = field.NewAsterisk(table) + c.ID = field.NewInt32(table, "id") + c.ClientID = field.NewString(table, "client_id") + c.ClientSecret = field.NewString(table, "client_secret") + c.RedirectURI = field.NewString(table, "redirect_uri") + c.GrantCode = field.NewBool(table, "grant_code") + c.GrantClient = field.NewBool(table, "grant_client") + c.GrantRefresh = field.NewBool(table, "grant_refresh") + c.Spec = field.NewInt32(table, "spec") + c.Name = field.NewString(table, "name") + c.Version = field.NewInt32(table, "version") + c.Status = field.NewInt32(table, "status") + c.CreatedAt = field.NewTime(table, "created_at") + c.UpdatedAt = field.NewTime(table, "updated_at") + c.DeletedAt = field.NewField(table, "deleted_at") + + c.fillFieldMap() + + return c +} + +func (c *client) GetFieldByName(fieldName string) (field.OrderExpr, bool) { + _f, ok := c.fieldMap[fieldName] + if !ok || _f == nil { + return nil, false + } + _oe, ok := _f.(field.OrderExpr) + return _oe, ok +} + +func (c *client) fillFieldMap() { + c.fieldMap = make(map[string]field.Expr, 14) + c.fieldMap["id"] = c.ID + c.fieldMap["client_id"] = c.ClientID + c.fieldMap["client_secret"] = c.ClientSecret + c.fieldMap["redirect_uri"] = c.RedirectURI + c.fieldMap["grant_code"] = c.GrantCode + c.fieldMap["grant_client"] = c.GrantClient + c.fieldMap["grant_refresh"] = c.GrantRefresh + c.fieldMap["spec"] = c.Spec + c.fieldMap["name"] = c.Name + c.fieldMap["version"] = c.Version + c.fieldMap["status"] = c.Status + c.fieldMap["created_at"] = c.CreatedAt + c.fieldMap["updated_at"] = c.UpdatedAt + c.fieldMap["deleted_at"] = c.DeletedAt +} + +func (c client) clone(db *gorm.DB) client { + c.clientDo.ReplaceConnPool(db.Statement.ConnPool) + return c +} + +func (c client) replaceDB(db *gorm.DB) client { + c.clientDo.ReplaceDB(db) + return c +} + +type clientDo struct{ gen.DO } + +func (c clientDo) Debug() *clientDo { + return c.withDO(c.DO.Debug()) +} + +func (c clientDo) WithContext(ctx context.Context) *clientDo { + return c.withDO(c.DO.WithContext(ctx)) +} + +func (c clientDo) ReadDB() *clientDo { + return c.Clauses(dbresolver.Read) +} + +func (c clientDo) WriteDB() *clientDo { + return c.Clauses(dbresolver.Write) +} + +func (c clientDo) Session(config *gorm.Session) *clientDo { + return c.withDO(c.DO.Session(config)) +} + +func (c clientDo) Clauses(conds ...clause.Expression) *clientDo { + return c.withDO(c.DO.Clauses(conds...)) +} + +func (c clientDo) Returning(value interface{}, columns ...string) *clientDo { + return c.withDO(c.DO.Returning(value, columns...)) +} + +func (c clientDo) Not(conds ...gen.Condition) *clientDo { + return c.withDO(c.DO.Not(conds...)) +} + +func (c clientDo) Or(conds ...gen.Condition) *clientDo { + return c.withDO(c.DO.Or(conds...)) +} + +func (c clientDo) Select(conds ...field.Expr) *clientDo { + return c.withDO(c.DO.Select(conds...)) +} + +func (c clientDo) Where(conds ...gen.Condition) *clientDo { + return c.withDO(c.DO.Where(conds...)) +} + +func (c clientDo) Order(conds ...field.Expr) *clientDo { + return c.withDO(c.DO.Order(conds...)) +} + +func (c clientDo) Distinct(cols ...field.Expr) *clientDo { + return c.withDO(c.DO.Distinct(cols...)) +} + +func (c clientDo) Omit(cols ...field.Expr) *clientDo { + return c.withDO(c.DO.Omit(cols...)) +} + +func (c clientDo) Join(table schema.Tabler, on ...field.Expr) *clientDo { + return c.withDO(c.DO.Join(table, on...)) +} + +func (c clientDo) LeftJoin(table schema.Tabler, on ...field.Expr) *clientDo { + return c.withDO(c.DO.LeftJoin(table, on...)) +} + +func (c clientDo) RightJoin(table schema.Tabler, on ...field.Expr) *clientDo { + return c.withDO(c.DO.RightJoin(table, on...)) +} + +func (c clientDo) Group(cols ...field.Expr) *clientDo { + return c.withDO(c.DO.Group(cols...)) +} + +func (c clientDo) Having(conds ...gen.Condition) *clientDo { + return c.withDO(c.DO.Having(conds...)) +} + +func (c clientDo) Limit(limit int) *clientDo { + return c.withDO(c.DO.Limit(limit)) +} + +func (c clientDo) Offset(offset int) *clientDo { + return c.withDO(c.DO.Offset(offset)) +} + +func (c clientDo) Scopes(funcs ...func(gen.Dao) gen.Dao) *clientDo { + return c.withDO(c.DO.Scopes(funcs...)) +} + +func (c clientDo) Unscoped() *clientDo { + return c.withDO(c.DO.Unscoped()) +} + +func (c clientDo) Create(values ...*models.Client) error { + if len(values) == 0 { + return nil + } + return c.DO.Create(values) +} + +func (c clientDo) CreateInBatches(values []*models.Client, batchSize int) error { + return c.DO.CreateInBatches(values, batchSize) +} + +// Save : !!! underlying implementation is different with GORM +// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values) +func (c clientDo) Save(values ...*models.Client) error { + if len(values) == 0 { + return nil + } + return c.DO.Save(values) +} + +func (c clientDo) First() (*models.Client, error) { + if result, err := c.DO.First(); err != nil { + return nil, err + } else { + return result.(*models.Client), nil + } +} + +func (c clientDo) Take() (*models.Client, error) { + if result, err := c.DO.Take(); err != nil { + return nil, err + } else { + return result.(*models.Client), nil + } +} + +func (c clientDo) Last() (*models.Client, error) { + if result, err := c.DO.Last(); err != nil { + return nil, err + } else { + return result.(*models.Client), nil + } +} + +func (c clientDo) Find() ([]*models.Client, error) { + result, err := c.DO.Find() + return result.([]*models.Client), err +} + +func (c clientDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*models.Client, err error) { + buf := make([]*models.Client, 0, batchSize) + err = c.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error { + defer func() { results = append(results, buf...) }() + return fc(tx, batch) + }) + return results, err +} + +func (c clientDo) FindInBatches(result *[]*models.Client, batchSize int, fc func(tx gen.Dao, batch int) error) error { + return c.DO.FindInBatches(result, batchSize, fc) +} + +func (c clientDo) Attrs(attrs ...field.AssignExpr) *clientDo { + return c.withDO(c.DO.Attrs(attrs...)) +} + +func (c clientDo) Assign(attrs ...field.AssignExpr) *clientDo { + return c.withDO(c.DO.Assign(attrs...)) +} + +func (c clientDo) Joins(fields ...field.RelationField) *clientDo { + for _, _f := range fields { + c = *c.withDO(c.DO.Joins(_f)) + } + return &c +} + +func (c clientDo) Preload(fields ...field.RelationField) *clientDo { + for _, _f := range fields { + c = *c.withDO(c.DO.Preload(_f)) + } + return &c +} + +func (c clientDo) FirstOrInit() (*models.Client, error) { + if result, err := c.DO.FirstOrInit(); err != nil { + return nil, err + } else { + return result.(*models.Client), nil + } +} + +func (c clientDo) FirstOrCreate() (*models.Client, error) { + if result, err := c.DO.FirstOrCreate(); err != nil { + return nil, err + } else { + return result.(*models.Client), nil + } +} + +func (c clientDo) FindByPage(offset int, limit int) (result []*models.Client, count int64, err error) { + result, err = c.Offset(offset).Limit(limit).Find() + if err != nil { + return + } + + if size := len(result); 0 < limit && 0 < size && size < limit { + count = int64(size + offset) + return + } + + count, err = c.Offset(-1).Limit(-1).Count() + return +} + +func (c clientDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) { + count, err = c.Count() + if err != nil { + return + } + + err = c.Offset(offset).Limit(limit).Scan(result) + return +} + +func (c clientDo) Scan(result interface{}) (err error) { + return c.DO.Scan(result) +} + +func (c clientDo) Delete(models ...*models.Client) (result gen.ResultInfo, err error) { + return c.DO.Delete(models) +} + +func (c *clientDo) withDO(do gen.Dao) *clientDo { + c.DO = *do.(*gen.DO) + return c +} diff --git a/web/queries/client_permission_link.gen.go b/web/queries/client_permission_link.gen.go new file mode 100644 index 0000000..3cea478 --- /dev/null +++ b/web/queries/client_permission_link.gen.go @@ -0,0 +1,339 @@ +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. +// Code generated by gorm.io/gen. DO NOT EDIT. + +package queries + +import ( + "context" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + + "gorm.io/gen" + "gorm.io/gen/field" + + "gorm.io/plugin/dbresolver" + + "platform/web/models" +) + +func newClientPermissionLink(db *gorm.DB, opts ...gen.DOOption) clientPermissionLink { + _clientPermissionLink := clientPermissionLink{} + + _clientPermissionLink.clientPermissionLinkDo.UseDB(db, opts...) + _clientPermissionLink.clientPermissionLinkDo.UseModel(&models.ClientPermissionLink{}) + + tableName := _clientPermissionLink.clientPermissionLinkDo.TableName() + _clientPermissionLink.ALL = field.NewAsterisk(tableName) + _clientPermissionLink.ID = field.NewInt32(tableName, "id") + _clientPermissionLink.ClientID = field.NewInt32(tableName, "client_id") + _clientPermissionLink.PermissionID = field.NewInt32(tableName, "permission_id") + _clientPermissionLink.CreatedAt = field.NewTime(tableName, "created_at") + _clientPermissionLink.UpdatedAt = field.NewTime(tableName, "updated_at") + _clientPermissionLink.DeletedAt = field.NewField(tableName, "deleted_at") + + _clientPermissionLink.fillFieldMap() + + return _clientPermissionLink +} + +type clientPermissionLink struct { + clientPermissionLinkDo + + ALL field.Asterisk + ID field.Int32 // 关联ID + ClientID field.Int32 // 客户端ID + PermissionID field.Int32 // 权限ID + CreatedAt field.Time // 创建时间 + UpdatedAt field.Time // 更新时间 + DeletedAt field.Field // 删除时间 + + fieldMap map[string]field.Expr +} + +func (c clientPermissionLink) Table(newTableName string) *clientPermissionLink { + c.clientPermissionLinkDo.UseTable(newTableName) + return c.updateTableName(newTableName) +} + +func (c clientPermissionLink) As(alias string) *clientPermissionLink { + c.clientPermissionLinkDo.DO = *(c.clientPermissionLinkDo.As(alias).(*gen.DO)) + return c.updateTableName(alias) +} + +func (c *clientPermissionLink) updateTableName(table string) *clientPermissionLink { + c.ALL = field.NewAsterisk(table) + c.ID = field.NewInt32(table, "id") + c.ClientID = field.NewInt32(table, "client_id") + c.PermissionID = field.NewInt32(table, "permission_id") + c.CreatedAt = field.NewTime(table, "created_at") + c.UpdatedAt = field.NewTime(table, "updated_at") + c.DeletedAt = field.NewField(table, "deleted_at") + + c.fillFieldMap() + + return c +} + +func (c *clientPermissionLink) GetFieldByName(fieldName string) (field.OrderExpr, bool) { + _f, ok := c.fieldMap[fieldName] + if !ok || _f == nil { + return nil, false + } + _oe, ok := _f.(field.OrderExpr) + return _oe, ok +} + +func (c *clientPermissionLink) fillFieldMap() { + c.fieldMap = make(map[string]field.Expr, 6) + c.fieldMap["id"] = c.ID + c.fieldMap["client_id"] = c.ClientID + c.fieldMap["permission_id"] = c.PermissionID + c.fieldMap["created_at"] = c.CreatedAt + c.fieldMap["updated_at"] = c.UpdatedAt + c.fieldMap["deleted_at"] = c.DeletedAt +} + +func (c clientPermissionLink) clone(db *gorm.DB) clientPermissionLink { + c.clientPermissionLinkDo.ReplaceConnPool(db.Statement.ConnPool) + return c +} + +func (c clientPermissionLink) replaceDB(db *gorm.DB) clientPermissionLink { + c.clientPermissionLinkDo.ReplaceDB(db) + return c +} + +type clientPermissionLinkDo struct{ gen.DO } + +func (c clientPermissionLinkDo) Debug() *clientPermissionLinkDo { + return c.withDO(c.DO.Debug()) +} + +func (c clientPermissionLinkDo) WithContext(ctx context.Context) *clientPermissionLinkDo { + return c.withDO(c.DO.WithContext(ctx)) +} + +func (c clientPermissionLinkDo) ReadDB() *clientPermissionLinkDo { + return c.Clauses(dbresolver.Read) +} + +func (c clientPermissionLinkDo) WriteDB() *clientPermissionLinkDo { + return c.Clauses(dbresolver.Write) +} + +func (c clientPermissionLinkDo) Session(config *gorm.Session) *clientPermissionLinkDo { + return c.withDO(c.DO.Session(config)) +} + +func (c clientPermissionLinkDo) Clauses(conds ...clause.Expression) *clientPermissionLinkDo { + return c.withDO(c.DO.Clauses(conds...)) +} + +func (c clientPermissionLinkDo) Returning(value interface{}, columns ...string) *clientPermissionLinkDo { + return c.withDO(c.DO.Returning(value, columns...)) +} + +func (c clientPermissionLinkDo) Not(conds ...gen.Condition) *clientPermissionLinkDo { + return c.withDO(c.DO.Not(conds...)) +} + +func (c clientPermissionLinkDo) Or(conds ...gen.Condition) *clientPermissionLinkDo { + return c.withDO(c.DO.Or(conds...)) +} + +func (c clientPermissionLinkDo) Select(conds ...field.Expr) *clientPermissionLinkDo { + return c.withDO(c.DO.Select(conds...)) +} + +func (c clientPermissionLinkDo) Where(conds ...gen.Condition) *clientPermissionLinkDo { + return c.withDO(c.DO.Where(conds...)) +} + +func (c clientPermissionLinkDo) Order(conds ...field.Expr) *clientPermissionLinkDo { + return c.withDO(c.DO.Order(conds...)) +} + +func (c clientPermissionLinkDo) Distinct(cols ...field.Expr) *clientPermissionLinkDo { + return c.withDO(c.DO.Distinct(cols...)) +} + +func (c clientPermissionLinkDo) Omit(cols ...field.Expr) *clientPermissionLinkDo { + return c.withDO(c.DO.Omit(cols...)) +} + +func (c clientPermissionLinkDo) Join(table schema.Tabler, on ...field.Expr) *clientPermissionLinkDo { + return c.withDO(c.DO.Join(table, on...)) +} + +func (c clientPermissionLinkDo) LeftJoin(table schema.Tabler, on ...field.Expr) *clientPermissionLinkDo { + return c.withDO(c.DO.LeftJoin(table, on...)) +} + +func (c clientPermissionLinkDo) RightJoin(table schema.Tabler, on ...field.Expr) *clientPermissionLinkDo { + return c.withDO(c.DO.RightJoin(table, on...)) +} + +func (c clientPermissionLinkDo) Group(cols ...field.Expr) *clientPermissionLinkDo { + return c.withDO(c.DO.Group(cols...)) +} + +func (c clientPermissionLinkDo) Having(conds ...gen.Condition) *clientPermissionLinkDo { + return c.withDO(c.DO.Having(conds...)) +} + +func (c clientPermissionLinkDo) Limit(limit int) *clientPermissionLinkDo { + return c.withDO(c.DO.Limit(limit)) +} + +func (c clientPermissionLinkDo) Offset(offset int) *clientPermissionLinkDo { + return c.withDO(c.DO.Offset(offset)) +} + +func (c clientPermissionLinkDo) Scopes(funcs ...func(gen.Dao) gen.Dao) *clientPermissionLinkDo { + return c.withDO(c.DO.Scopes(funcs...)) +} + +func (c clientPermissionLinkDo) Unscoped() *clientPermissionLinkDo { + return c.withDO(c.DO.Unscoped()) +} + +func (c clientPermissionLinkDo) Create(values ...*models.ClientPermissionLink) error { + if len(values) == 0 { + return nil + } + return c.DO.Create(values) +} + +func (c clientPermissionLinkDo) CreateInBatches(values []*models.ClientPermissionLink, batchSize int) error { + return c.DO.CreateInBatches(values, batchSize) +} + +// Save : !!! underlying implementation is different with GORM +// The method is equivalent to executing the statement: db.Clauses(clause.OnConflict{UpdateAll: true}).Create(values) +func (c clientPermissionLinkDo) Save(values ...*models.ClientPermissionLink) error { + if len(values) == 0 { + return nil + } + return c.DO.Save(values) +} + +func (c clientPermissionLinkDo) First() (*models.ClientPermissionLink, error) { + if result, err := c.DO.First(); err != nil { + return nil, err + } else { + return result.(*models.ClientPermissionLink), nil + } +} + +func (c clientPermissionLinkDo) Take() (*models.ClientPermissionLink, error) { + if result, err := c.DO.Take(); err != nil { + return nil, err + } else { + return result.(*models.ClientPermissionLink), nil + } +} + +func (c clientPermissionLinkDo) Last() (*models.ClientPermissionLink, error) { + if result, err := c.DO.Last(); err != nil { + return nil, err + } else { + return result.(*models.ClientPermissionLink), nil + } +} + +func (c clientPermissionLinkDo) Find() ([]*models.ClientPermissionLink, error) { + result, err := c.DO.Find() + return result.([]*models.ClientPermissionLink), err +} + +func (c clientPermissionLinkDo) FindInBatch(batchSize int, fc func(tx gen.Dao, batch int) error) (results []*models.ClientPermissionLink, err error) { + buf := make([]*models.ClientPermissionLink, 0, batchSize) + err = c.DO.FindInBatches(&buf, batchSize, func(tx gen.Dao, batch int) error { + defer func() { results = append(results, buf...) }() + return fc(tx, batch) + }) + return results, err +} + +func (c clientPermissionLinkDo) FindInBatches(result *[]*models.ClientPermissionLink, batchSize int, fc func(tx gen.Dao, batch int) error) error { + return c.DO.FindInBatches(result, batchSize, fc) +} + +func (c clientPermissionLinkDo) Attrs(attrs ...field.AssignExpr) *clientPermissionLinkDo { + return c.withDO(c.DO.Attrs(attrs...)) +} + +func (c clientPermissionLinkDo) Assign(attrs ...field.AssignExpr) *clientPermissionLinkDo { + return c.withDO(c.DO.Assign(attrs...)) +} + +func (c clientPermissionLinkDo) Joins(fields ...field.RelationField) *clientPermissionLinkDo { + for _, _f := range fields { + c = *c.withDO(c.DO.Joins(_f)) + } + return &c +} + +func (c clientPermissionLinkDo) Preload(fields ...field.RelationField) *clientPermissionLinkDo { + for _, _f := range fields { + c = *c.withDO(c.DO.Preload(_f)) + } + return &c +} + +func (c clientPermissionLinkDo) FirstOrInit() (*models.ClientPermissionLink, error) { + if result, err := c.DO.FirstOrInit(); err != nil { + return nil, err + } else { + return result.(*models.ClientPermissionLink), nil + } +} + +func (c clientPermissionLinkDo) FirstOrCreate() (*models.ClientPermissionLink, error) { + if result, err := c.DO.FirstOrCreate(); err != nil { + return nil, err + } else { + return result.(*models.ClientPermissionLink), nil + } +} + +func (c clientPermissionLinkDo) FindByPage(offset int, limit int) (result []*models.ClientPermissionLink, count int64, err error) { + result, err = c.Offset(offset).Limit(limit).Find() + if err != nil { + return + } + + if size := len(result); 0 < limit && 0 < size && size < limit { + count = int64(size + offset) + return + } + + count, err = c.Offset(-1).Limit(-1).Count() + return +} + +func (c clientPermissionLinkDo) ScanByPage(result interface{}, offset int, limit int) (count int64, err error) { + count, err = c.Count() + if err != nil { + return + } + + err = c.Offset(offset).Limit(limit).Scan(result) + return +} + +func (c clientPermissionLinkDo) Scan(result interface{}) (err error) { + return c.DO.Scan(result) +} + +func (c clientPermissionLinkDo) Delete(models ...*models.ClientPermissionLink) (result gen.ResultInfo, err error) { + return c.DO.Delete(models) +} + +func (c *clientPermissionLinkDo) withDO(do gen.Dao) *clientPermissionLinkDo { + c.DO = *do.(*gen.DO) + return c +} diff --git a/web/queries/gen.go b/web/queries/gen.go index df83828..eace030 100644 --- a/web/queries/gen.go +++ b/web/queries/gen.go @@ -23,6 +23,8 @@ var ( AdminRolePermissionLink *adminRolePermissionLink Bill *bill Channel *channel + Client *client + ClientPermissionLink *clientPermissionLink Node *node Permission *permission Product *product @@ -47,6 +49,8 @@ func SetDefault(db *gorm.DB, opts ...gen.DOOption) { AdminRolePermissionLink = &Q.AdminRolePermissionLink Bill = &Q.Bill Channel = &Q.Channel + Client = &Q.Client + ClientPermissionLink = &Q.ClientPermissionLink Node = &Q.Node Permission = &Q.Permission Product = &Q.Product @@ -72,6 +76,8 @@ func Use(db *gorm.DB, opts ...gen.DOOption) *Query { AdminRolePermissionLink: newAdminRolePermissionLink(db, opts...), Bill: newBill(db, opts...), Channel: newChannel(db, opts...), + Client: newClient(db, opts...), + ClientPermissionLink: newClientPermissionLink(db, opts...), Node: newNode(db, opts...), Permission: newPermission(db, opts...), Product: newProduct(db, opts...), @@ -98,6 +104,8 @@ type Query struct { AdminRolePermissionLink adminRolePermissionLink Bill bill Channel channel + Client client + ClientPermissionLink clientPermissionLink Node node Permission permission Product product @@ -125,6 +133,8 @@ func (q *Query) clone(db *gorm.DB) *Query { AdminRolePermissionLink: q.AdminRolePermissionLink.clone(db), Bill: q.Bill.clone(db), Channel: q.Channel.clone(db), + Client: q.Client.clone(db), + ClientPermissionLink: q.ClientPermissionLink.clone(db), Node: q.Node.clone(db), Permission: q.Permission.clone(db), Product: q.Product.clone(db), @@ -159,6 +169,8 @@ func (q *Query) ReplaceDB(db *gorm.DB) *Query { AdminRolePermissionLink: q.AdminRolePermissionLink.replaceDB(db), Bill: q.Bill.replaceDB(db), Channel: q.Channel.replaceDB(db), + Client: q.Client.replaceDB(db), + ClientPermissionLink: q.ClientPermissionLink.replaceDB(db), Node: q.Node.replaceDB(db), Permission: q.Permission.replaceDB(db), Product: q.Product.replaceDB(db), @@ -183,6 +195,8 @@ type queryCtx struct { AdminRolePermissionLink *adminRolePermissionLinkDo Bill *billDo Channel *channelDo + Client *clientDo + ClientPermissionLink *clientPermissionLinkDo Node *nodeDo Permission *permissionDo Product *productDo @@ -207,6 +221,8 @@ func (q *Query) WithContext(ctx context.Context) *queryCtx { AdminRolePermissionLink: q.AdminRolePermissionLink.WithContext(ctx), Bill: q.Bill.WithContext(ctx), Channel: q.Channel.WithContext(ctx), + Client: q.Client.WithContext(ctx), + ClientPermissionLink: q.ClientPermissionLink.WithContext(ctx), Node: q.Node.WithContext(ctx), Permission: q.Permission.WithContext(ctx), Product: q.Product.WithContext(ctx), diff --git a/web/queries/user.gen.go b/web/queries/user.gen.go index d9fa0f6..9bd79bf 100644 --- a/web/queries/user.gen.go +++ b/web/queries/user.gen.go @@ -43,7 +43,7 @@ func newUser(db *gorm.DB, opts ...gen.DOOption) user { _user.ContactQq = field.NewString(tableName, "contact_qq") _user.ContactWechat = field.NewString(tableName, "contact_wechat") _user.LastLogin = field.NewTime(tableName, "last_login") - _user.LastLoginAddr = field.NewString(tableName, "last_login_addr") + _user.LastLoginHost = field.NewString(tableName, "last_login_host") _user.LastLoginAgent = field.NewString(tableName, "last_login_agent") _user.CreatedAt = field.NewTime(tableName, "created_at") _user.UpdatedAt = field.NewTime(tableName, "updated_at") @@ -58,27 +58,27 @@ type user struct { userDo ALL field.Asterisk - ID field.Int32 - AdminID field.Int32 - Phone field.String - Username field.String + ID field.Int32 // 用户ID + AdminID field.Int32 // 管理员ID + Phone field.String // 手机号码 + Username field.String // 用户名 Email field.String - Password field.String - Name field.String - Avatar field.String - Status field.Int32 - Balance field.Float64 - IDType field.Int32 - IDNo field.String - IDToken field.String - ContactQq field.String - ContactWechat field.String - LastLogin field.Time - LastLoginAddr field.String - LastLoginAgent field.String - CreatedAt field.Time - UpdatedAt field.Time - DeletedAt field.Field + Password field.String // 用户密码 + Name field.String // 真实姓名 + Avatar field.String // 头像URL + Status field.Int32 // 用户状态:1-正常,0-禁用 + Balance field.Float64 // 账户余额 + IDType field.Int32 // 认证类型:0-未认证,1-个人认证,2-企业认证 + IDNo field.String // 身份证号或营业执照号 + IDToken field.String // 身份验证标识 + ContactQq field.String // QQ联系方式 + ContactWechat field.String // 微信联系方式 + LastLogin field.Time // 最后登录时间 + LastLoginHost field.String // 最后登录地址 + LastLoginAgent field.String // 最后登录代理 + CreatedAt field.Time // 创建时间 + UpdatedAt field.Time // 更新时间 + DeletedAt field.Field // 删除时间 fieldMap map[string]field.Expr } @@ -111,7 +111,7 @@ func (u *user) updateTableName(table string) *user { u.ContactQq = field.NewString(table, "contact_qq") u.ContactWechat = field.NewString(table, "contact_wechat") u.LastLogin = field.NewTime(table, "last_login") - u.LastLoginAddr = field.NewString(table, "last_login_addr") + u.LastLoginHost = field.NewString(table, "last_login_host") u.LastLoginAgent = field.NewString(table, "last_login_agent") u.CreatedAt = field.NewTime(table, "created_at") u.UpdatedAt = field.NewTime(table, "updated_at") @@ -149,7 +149,7 @@ func (u *user) fillFieldMap() { u.fieldMap["contact_qq"] = u.ContactQq u.fieldMap["contact_wechat"] = u.ContactWechat u.fieldMap["last_login"] = u.LastLogin - u.fieldMap["last_login_addr"] = u.LastLoginAddr + u.fieldMap["last_login_host"] = u.LastLoginHost u.fieldMap["last_login_agent"] = u.LastLoginAgent u.fieldMap["created_at"] = u.CreatedAt u.fieldMap["updated_at"] = u.UpdatedAt diff --git a/web/queries/whitelist.gen.go b/web/queries/whitelist.gen.go index 622364e..7f9e612 100644 --- a/web/queries/whitelist.gen.go +++ b/web/queries/whitelist.gen.go @@ -29,7 +29,7 @@ func newWhitelist(db *gorm.DB, opts ...gen.DOOption) whitelist { _whitelist.ALL = field.NewAsterisk(tableName) _whitelist.ID = field.NewInt32(tableName, "id") _whitelist.UserID = field.NewInt32(tableName, "user_id") - _whitelist.Address = field.NewString(tableName, "address") + _whitelist.Host = field.NewString(tableName, "host") _whitelist.CreatedAt = field.NewTime(tableName, "created_at") _whitelist.UpdatedAt = field.NewTime(tableName, "updated_at") _whitelist.DeletedAt = field.NewField(tableName, "deleted_at") @@ -45,7 +45,7 @@ type whitelist struct { ALL field.Asterisk ID field.Int32 // 白名单ID UserID field.Int32 // 用户ID - Address field.String // IP地址 + Host field.String // IP地址 CreatedAt field.Time // 创建时间 UpdatedAt field.Time // 更新时间 DeletedAt field.Field // 删除时间 @@ -67,7 +67,7 @@ func (w *whitelist) updateTableName(table string) *whitelist { w.ALL = field.NewAsterisk(table) w.ID = field.NewInt32(table, "id") w.UserID = field.NewInt32(table, "user_id") - w.Address = field.NewString(table, "address") + w.Host = field.NewString(table, "host") w.CreatedAt = field.NewTime(table, "created_at") w.UpdatedAt = field.NewTime(table, "updated_at") w.DeletedAt = field.NewField(table, "deleted_at") @@ -90,7 +90,7 @@ func (w *whitelist) fillFieldMap() { w.fieldMap = make(map[string]field.Expr, 6) w.fieldMap["id"] = w.ID w.fieldMap["user_id"] = w.UserID - w.fieldMap["address"] = w.Address + w.fieldMap["host"] = w.Host w.fieldMap["created_at"] = w.CreatedAt w.fieldMap["updated_at"] = w.UpdatedAt w.fieldMap["deleted_at"] = w.DeletedAt diff --git a/web/router.go b/web/router.go index 36aced1..9dcee64 100644 --- a/web/router.go +++ b/web/router.go @@ -1,9 +1,21 @@ package web import ( + "platform/web/handlers" + "github.com/gofiber/fiber/v2" ) -func UseRoute(app *fiber.App) { +func ApplyRouters(app *fiber.App) { + api := app.Group("/api") + // 认证路由 + auth := api.Group("/auth") + auth.Post("/verify/sms", Protect(), handlers.SmsCode) + auth.Post("/login/sms", Protect(), handlers.Login) + auth.Post("/token", handlers.Token) + + // 客户端路由 + client := api.Group("/client") + client.Get("/test/create", handlers.CreateClient) } diff --git a/web/services/auth.go b/web/services/auth.go new file mode 100644 index 0000000..4f8c07f --- /dev/null +++ b/web/services/auth.go @@ -0,0 +1,85 @@ +package services + +import ( + "context" + "errors" + "platform/web/models" +) + +var Auth = &authService{} + +type authService struct{} + +type AuthServiceError string + +func (e AuthServiceError) Error() string { + return string(e) +} + +type AuthServiceOauthError string + +func (e AuthServiceOauthError) Error() string { + return string(e) +} + +var ( + ErrOauthInvalidRequest = AuthServiceOauthError("invalid_request") + ErrOauthInvalidClient = AuthServiceOauthError("invalid_client") + ErrOauthInvalidGrant = AuthServiceOauthError("invalid_grant") + ErrOauthInvalidScope = AuthServiceOauthError("invalid_scope") + ErrOauthUnauthorizedClient = AuthServiceOauthError("unauthorized_client") + ErrOauthUnsupportedGrantType = AuthServiceOauthError("unsupported_grant_type") +) + +// OauthAuthorizationCode 验证授权码 +func (s *authService) OauthAuthorizationCode(ctx context.Context, client *models.Client, code, redirectURI, codeVerifier string) (*TokenDetails, error) { + // TODO: 从数据库验证授权码 + return nil, errors.New("TODO") +} + +// OauthClientCredentials 验证客户端凭证 +func (s *authService) OauthClientCredentials(ctx context.Context, client *models.Client, scope ...[]string) (*TokenDetails, error) { + + var clientType PayloadType + switch client.Spec { + case 0: + clientType = PayloadClientConfidential + case 1: + clientType = PayloadClientPublic + case 2: + clientType = PayloadClientConfidential + } + + // 保存会话并返回令牌 + auth := AuthContext{ + Permissions: map[string]struct{}{ + "client": {}, + }, + Payload: Payload{ + Type: clientType, + Id: client.ID, + }, + } + + // todo 数据库定义会话持续时间 + token, err := Session.Create(ctx, auth) + if err != nil { + return nil, err + } + + return token, nil +} + +// OauthRefreshToken 验证刷新令牌 +func (s *authService) OauthRefreshToken(ctx context.Context, client *models.Client, refreshToken string, scope ...[]string) (*TokenDetails, error) { + // TODO: 从数据库验证刷新令牌 + return nil, errors.New("TODO") +} + +type GrantType int + +const ( + GrantTypeAuthorizationCode GrantType = iota + GrantTypeClientCredentials + GrantTypeRefreshToken +) diff --git a/web/services/session.go b/web/services/session.go new file mode 100644 index 0000000..b116254 --- /dev/null +++ b/web/services/session.go @@ -0,0 +1,288 @@ +package services + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "platform/init/rds" + "time" + + "github.com/google/uuid" + "github.com/redis/go-redis/v9" +) + +// region SessionService + +var Session = &sessionService{} + +type sessionService struct { +} + +type SessionServiceError string + +func (e SessionServiceError) Error() string { + return string(e) +} + +var ( + ErrInvalidToken = SessionServiceError("invalid_token") +) + +// Find 通过访问令牌获取会话信息 +func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, error) { + + // 读取认证数据 + authJSON, err := rds.Client.Get(ctx, accessKey(token)).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, ErrInvalidToken + } + return nil, err + } + + // 反序列化 + auth := new(AuthContext) + if err := json.Unmarshal([]byte(authJSON), auth); err != nil { + return nil, err + } + + return auth, nil +} + +// Create 创建一个新的会话 +func (s *sessionService) Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error) { + // 解析可选配置 + cfg := DefaultSessionConfig + if len(config) > 0 { + cfg = mergeConfig(DefaultSessionConfig, config[0]) + } + + // 生成令牌组 + accessToken := genToken() + refreshToken := genToken() + + // 序列化认证数据 + authData, err := json.Marshal(auth) + if err != nil { + return nil, err + } + + // 序列化刷新令牌数据 + refreshData, err := json.Marshal(RefreshData{ + AuthContext: auth, + AccessToken: accessToken, + }) + if err != nil { + return nil, err + } + + // 事务保存数据到 Redis + pipe := rds.Client.TxPipeline() + pipe.Set(ctx, accessKey(accessToken), authData, cfg.AccessTokenDuration) + pipe.Set(ctx, refreshKey(refreshToken), refreshData, cfg.RefreshTokenDuration) + _, err = pipe.Exec(ctx) + if err != nil { + return nil, err + } + + return &TokenDetails{ + AccessToken: accessToken, + AccessTokenExpires: time.Now().Add(cfg.AccessTokenDuration), + RefreshToken: refreshToken, + RefreshTokenExpires: time.Now().Add(cfg.RefreshTokenDuration), + Auth: auth, + }, nil +} + +// Refresh 刷新一个会话 +func (s *sessionService) Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error) { + // 解析可选配置 + cfg := DefaultSessionConfig + if len(config) > 0 { + cfg = mergeConfig(DefaultSessionConfig, config[0]) + } + + rKey := refreshKey(refreshToken) + var tokenDetails *TokenDetails + + // 刷新令牌 + err := rds.Client.Watch(ctx, func(tx *redis.Tx) error { + + // 先获取刷新令牌数据 + refreshJson, err := tx.Get(ctx, rKey).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return ErrInvalidToken + } + return err + } + + // 解析刷新令牌数据 + refreshData := new(RefreshData) + if err := json.Unmarshal([]byte(refreshJson), refreshData); err != nil { + return err + } + + // 删除旧的令牌 + pipeline := tx.Pipeline() + pipeline.Del(ctx, accessKey(refreshData.AccessToken)) + pipeline.Del(ctx, refreshKey(refreshToken)) + + // 生成新的令牌 + newAccessToken := genToken() + newRefreshToken := genToken() + + authData, err := json.Marshal(refreshData.AuthContext) + if err != nil { + return err + } + newRefreshData, err := json.Marshal(RefreshData{ + AuthContext: refreshData.AuthContext, + AccessToken: newAccessToken, + }) + if err != nil { + return err + } + + pipeline.Set(ctx, accessKey(newAccessToken), authData, cfg.AccessTokenDuration) + pipeline.Set(ctx, refreshKey(newRefreshToken), newRefreshData, cfg.RefreshTokenDuration) + + _, err = pipeline.Exec(ctx) + if err != nil { + return err + } + + tokenDetails = &TokenDetails{ + AccessToken: newAccessToken, + RefreshToken: newRefreshToken, + AccessTokenExpires: time.Now().Add(cfg.AccessTokenDuration), + RefreshTokenExpires: time.Now().Add(cfg.RefreshTokenDuration), + Auth: refreshData.AuthContext, + } + return nil + }, rKey) + if err != nil { + return nil, fmt.Errorf("刷新令牌失败: %w", err) + } + + return tokenDetails, nil +} + +// Remove 删除会话 +func (s *sessionService) Remove(ctx context.Context, accessToken, refreshToken string) error { + rds.Client.Del(ctx, accessKey(accessToken), refreshKey(refreshToken)) + return nil +} + +// 令牌键的格式为 "session:" +func accessKey(token string) string { + return fmt.Sprintf("session:%s", token) +} + +// 刷新令牌键的格式为 "session:refreshKey:" +func refreshKey(token string) string { + return fmt.Sprintf("session:refresh:%s", token) +} + +// 生成一个新的令牌 +func genToken() string { + return uuid.NewString() +} + +// endregion + +// region SessionConfig + +// SessionConfig 定义会话管理的配置选项 +type SessionConfig struct { + // 令牌配置 + AccessTokenDuration time.Duration + RefreshTokenDuration time.Duration +} + +// DefaultSessionConfig 默认会话配置 +var DefaultSessionConfig = SessionConfig{ + AccessTokenDuration: 2 * time.Hour, + RefreshTokenDuration: 7 * 24 * time.Hour, +} + +// 合并配置,保留非零值 +func mergeConfig(defaultCfg SessionConfig, customCfg SessionConfig) SessionConfig { + result := defaultCfg + + if customCfg.AccessTokenDuration != 0 { + result.AccessTokenDuration = customCfg.AccessTokenDuration + } + + if customCfg.RefreshTokenDuration != 0 { + result.RefreshTokenDuration = customCfg.RefreshTokenDuration + } + + return result +} + +// endregion + +// region AuthContext + +// AuthContext 定义认证信息 +type AuthContext struct { + Payload Payload + Permissions map[string]struct{} + Metadata map[string]interface{} +} + +// Payload 定义负载信息 +type Payload struct { + Type PayloadType + Id int32 +} + +// PayloadType 定义负载类型 +type PayloadType int + +const ( + // PayloadUser 用户类型 + PayloadUser PayloadType = iota + // PayloadAdmin 管理员类型 + PayloadAdmin + // PayloadClientPublic 公共客户端类型 + PayloadClientPublic + // PayloadClientConfidential 机密客户端类型 + PayloadClientConfidential +) + +// AnyPermission 检查认证是否包含指定权限 +func (a *AuthContext) AnyPermission(requiredPermission ...string) bool { + if a == nil || a.Permissions == nil { + return false + } + for _, permission := range requiredPermission { + if _, ok := a.Permissions[permission]; ok { + return true + } + } + return false +} + +// endregion + +type RefreshData struct { + AuthContext AuthContext + AccessToken string +} + +// TokenDetails 存储令牌详细信息 +type TokenDetails struct { + // 访问令牌 + AccessToken string + // 刷新令牌 + RefreshToken string + // 访问令牌过期时间 + AccessTokenExpires time.Time + // 刷新令牌过期时间 + RefreshTokenExpires time.Time + // 认证信息 + Auth AuthContext +} diff --git a/web/services/verifier.go b/web/services/verifier.go new file mode 100644 index 0000000..6d28c28 --- /dev/null +++ b/web/services/verifier.go @@ -0,0 +1,124 @@ +package services + +import ( + "context" + "errors" + "fmt" + "log/slog" + "math/rand" + "platform/init/rds" + "strconv" + "time" + + "github.com/redis/go-redis/v9" +) + +var Verifier = &verifierService{} + +type verifierService struct { +} + +type VerifierServiceError string + +func (e VerifierServiceError) Error() string { + return string(e) +} + +var ( + ErrVerifierServiceInvalid = VerifierServiceError("验证码错误") +) + +type VerifierServiceSendLimitErr int + +func (e VerifierServiceSendLimitErr) Error() string { + return "发送频率过快" +} + +type VerifierSmsPurpose int + +const ( + Login VerifierSmsPurpose = iota +) + +func smsKey(phone string, purpose VerifierSmsPurpose) string { + return fmt.Sprintf("verify:sms:%d:%s", purpose, phone) +} + +func (s *verifierService) SendSms(ctx context.Context, phone string, purpose VerifierSmsPurpose) error { + key := smsKey(phone, purpose) + keyLock := key + ":lock" + + // 生成验证码 + code := rand.Intn(900000) + 100000 // 6-digit code between 100000-999999 + + // 检查发送频率,1 分钟内只能发送一次 + err := rds.Client.Watch(ctx, func(tx *redis.Tx) error { + result, err := tx.TTL(ctx, keyLock).Result() + if err != nil { + return err + } + if result > 0 { + return VerifierServiceSendLimitErr(result.Seconds()) + } + if result != -2 { + return VerifierServiceError("验证码检查异常") + } + + pipe := rds.Client.Pipeline() + pipe.Set(ctx, key, code, 10*time.Minute) + pipe.Set(ctx, keyLock, "", 1*time.Minute) + _, err = pipe.Exec(ctx) + if err != nil { + return err + } + return nil + }, keyLock) + if err != nil { + return err + } + + // TODO: 发送短信验证码 + slog.Debug("发送验证码", slog.String("phone", phone), slog.String("code", strconv.Itoa(code))) + + return nil +} + +func (s *verifierService) VerifySms(ctx context.Context, phone, code string) (bool, error) { + key := smsKey(phone, Login) + keyLock := key + ":lock" + + err := rds.Client.Watch(ctx, func(tx *redis.Tx) error { + + // 检查验证码 + val, err := rds.Client.Get(ctx, key).Result() + if err != nil && !errors.Is(err, redis.Nil) { + slog.Error("验证码获取失败", slog.Any("err", err)) + return err + } + + if val != code { + return ErrVerifierServiceInvalid + } + + // 删除验证码 + _, err = tx.Pipelined(ctx, func(pipe redis.Pipeliner) error { + pipe.Del(ctx, key) + pipe.Del(ctx, keyLock) + return nil + }) + if err != nil { + slog.Error("验证码删除失败", slog.Any("err", err)) + return err + } + + return nil + }, key) + if err != nil { + if errors.Is(err, ErrVerifierServiceInvalid) { + return false, nil + } + return false, err + } + + return true, nil +} diff --git a/web/utils/resp.go b/web/utils/resp.go deleted file mode 100644 index df6d356..0000000 --- a/web/utils/resp.go +++ /dev/null @@ -1,5 +0,0 @@ -package utils - -type ErrResp struct { - Cause string `json:"cause"` -} diff --git a/web/web.go b/web/web.go index a2893f1..410ba9b 100644 --- a/web/web.go +++ b/web/web.go @@ -1,18 +1,20 @@ package web import ( + "platform/init/env" + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/logger" + "github.com/gofiber/fiber/v2/middleware/requestid" ) import "log/slog" type Config struct { - Logger *slog.Logger Listen string } type Server struct { config *Config - log *slog.Logger fiber *fiber.App } @@ -22,33 +24,35 @@ func New(config *Config) (*Server, error) { _config = &Config{} } - if _config.Logger == nil { - _config.Logger = slog.Default() - } - return &Server{ config: _config, - log: _config.Logger, }, nil } func (s *Server) Run() error { - s.fiber = fiber.New(fiber.Config{}) - UseRoute(s.fiber) + s.fiber = fiber.New(fiber.Config{ + ErrorHandler: ErrorHandler, + }) - s.log.Info("Server started on :8080") - err := s.fiber.Listen(":8080") + s.fiber.Use(requestid.New()) + s.fiber.Use(logger.New()) + + ApplyRouters(s.fiber) + + port := env.AppPort + slog.Info("Server started on :" + port) + err := s.fiber.Listen(":" + port) if err != nil { - s.log.Error("Failed to start server", slog.Any("error", err)) + slog.Error("Failed to start server", slog.Any("err", err)) } - s.log.Info("Server stopped") + slog.Info("Server stopped") return nil } func (s *Server) Stop() { err := s.fiber.Shutdown() if err != nil { - s.log.Error("Failed to shutdown server", slog.Any("error", err)) + slog.Error("Failed to shutdown server", slog.Any("err", err)) } }