diff --git a/README.md b/README.md index e3ef105..8f9e308 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ - [ ] 对接短信接口 - [ ] 人机风险分级验证 - [ ] jwt 签发 -- [ ] 鉴权 +- [x] 鉴权 - [ ] 实名认证 - [ ] 充值余额 - [ ] 选择套餐 diff --git a/web/handlers/channel.go b/web/handlers/channel.go new file mode 100644 index 0000000..c5df6f9 --- /dev/null +++ b/web/handlers/channel.go @@ -0,0 +1,92 @@ +package handlers + +import ( + "errors" + "fmt" + "platform/web/services" + "strings" + + "github.com/gofiber/fiber/v2" +) + +// region CreateChannel + +type CreateChannelReq struct { + Region string `json:"region" validate:"required"` + Provider string `json:"provider" validate:"required"` + Protocol services.Protocol `json:"protocol" validate:"required,oneof=socks5 http https"` + ResourceId int `json:"resource_id" validate:"required"` + Count int `json:"count" validate:"required"` + ResultType CreateChannelResultType `json:"return_type" validate:"required,oneof=json text"` + ResultSeparator CreateChannelResultSeparator `json:"return_separator" validate:"required,oneof=enter line both tab"` +} + +func CreateChannel(c *fiber.Ctx) error { + req := new(CreateChannelReq) + if err := c.BodyParser(req); err != nil { + return err + } + + // 建立连接通道 + auth, ok := c.Locals("user").(*services.AuthContext) + if !ok { + return errors.New("user not found") + } + + channels, err := services.Channel.CreateChannel( + auth.Payload.Id, + req.Region, + req.Provider, + req.Protocol, + req.ResourceId, + req.Count, + ) + if err != nil { + return err + } + + // 返回连接通道列表 + var result []string + for _, channel := range channels { + url := fmt.Sprintf("%s://%s:%d", channel.Protocol, channel.UserHost, channel.NodePort) + result = append(result, url) + } + + switch req.ResultType { + case CreateChannelResultTypeJson: + return c.JSON(fiber.Map{ + "result": result, + }) + case CreateChannelResultTypeText: + switch req.ResultSeparator { + case CreateChannelResultSeparatorEnter: + return c.SendString(strings.Join(result, "\r")) + case CreateChannelResultSeparatorLine: + return c.SendString(strings.Join(result, "\n")) + case CreateChannelResultSeparatorBoth: + return c.SendString(strings.Join(result, "\r\n")) + case CreateChannelResultSeparatorTab: + return c.SendString(strings.Join(result, "\t")) + } + } + + return errors.New("无效的返回类型") +} + +type CreateChannelResultType string + +const ( + CreateChannelResultTypeJson CreateChannelResultType = "json" + CreateChannelResultTypeText CreateChannelResultType = "text" +) + +type CreateChannelResultSeparator string + +const ( + CreateChannelResultSeparatorEnter CreateChannelResultSeparator = "enter" + CreateChannelResultSeparatorLine CreateChannelResultSeparator = "line" + CreateChannelResultSeparatorBoth CreateChannelResultSeparator = "both" + CreateChannelResultSeparatorTab CreateChannelResultSeparator = "tab" +) + +// endregion diff --git a/web/services/channel.go b/web/services/channel.go new file mode 100644 index 0000000..8586ad3 --- /dev/null +++ b/web/services/channel.go @@ -0,0 +1,58 @@ +package services + +import ( + "errors" + "log/slog" + "platform/web/models" + q "platform/web/queries" +) + +var Channel = &channelService{} + +type channelService struct { +} + +func (s *channelService) CreateChannel( + userID int32, + region string, + provider string, + protocol Protocol, + resourceId int, + count int, +) ([]*models.Channel, error) { + + // 检查并扣减套餐余额 + var resourceInfo = struct { + models.Resource + models.ResourcePss + }{} + err := q.Resource. + Where(q.Resource.UserID.Eq(userID)). + LeftJoin(q.ResourcePss, q.ResourcePss.ResourceID.EqCol(q.Resource.ID)). + Scan(&resourceInfo) + if err != nil { + return nil, err + } + + slog.Debug("查询资源", slog.Any("info", resourceInfo)) + + // 创建连接通道 + + // 保存到数据库与缓存,以及计时关闭 + + // 组织请求数据 + + // 发送请求到远端配置服务 + + // 返回连接通道列表 + + return nil, errors.New("not implemented") +} + +type Protocol string + +const ( + ProtocolSocks5 = Protocol("socks5") + ProtocolHTTP = Protocol("http") + ProtocolHttps = Protocol("https") +) diff --git a/web/services/channel_test.go b/web/services/channel_test.go new file mode 100644 index 0000000..33064f3 --- /dev/null +++ b/web/services/channel_test.go @@ -0,0 +1,37 @@ +package services + +import ( + "testing" +) + +func Test_channelService_CreateChannel(t *testing.T) { + // type args struct { + // userID int32 + // region string + // provider string + // protocol Protocol + // resourceId int + // count int + // } + // tests := []struct { + // name string + // args args + // want []*models.Channel + // wantErr bool + // }{ + // // TODO: Add test cases. + // } + // for _, tt := range tests { + // t.Run(tt.name, func(t *testing.T) { + // s := &channelService{} + // got, err := s.CreateChannel(tt.args.userID, tt.args.region, tt.args.provider, tt.args.protocol, tt.args.resourceId, tt.args.count) + // if (err != nil) != tt.wantErr { + // t.Errorf("CreateChannel() error = %v, wantErr %v", err, tt.wantErr) + // return + // } + // if !reflect.DeepEqual(got, tt.want) { + // t.Errorf("CreateChannel() got = %v, want %v", got, tt.want) + // } + // }) + // } +} diff --git a/web/services/session.go b/web/services/session.go index 033c550..b93dccf 100644 --- a/web/services/session.go +++ b/web/services/session.go @@ -185,6 +185,11 @@ func (s *sessionService) Remove(ctx context.Context, accessToken, refreshToken s return nil } +// 生成一个新的令牌 +func genToken() string { + return uuid.NewString() +} + // 令牌键的格式为 "session:" func accessKey(token string) string { return fmt.Sprintf("session:%s", token) @@ -195,11 +200,6 @@ func refreshKey(token string) string { return fmt.Sprintf("session:refresh:%s", token) } -// 生成一个新的令牌 -func genToken() string { - return uuid.NewString() -} - // endregion // region SessionConfig