diff --git a/README.md b/README.md index 314b345..f630721 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,23 @@ ## todo +可配置 logger,不直接使用 slog + +授权测试,两种协议,三种认证方式 + +连接计时数据清理,避免堆泄露 + 找一个其他方式即时关闭未成功建立数据通道的连接 排查下套接字重复的问题 -鉴权时判断授权的协议 - 建立通道时,发送的 dst 和 tag 等信息,可以用字节表示而非 string,提高效率 建立数据通道失败后,根据用户所选协议返回对应失败响应 -测试跳过认证时的最大 qps(需要注意单机连接数上限,会导致连接失败) - -数据通道池化 可配配置环境变量 +- 输出级别 - 退出等待时间 - 数据通道连接超时等待时间 - 目标地址连接超时等待时间 @@ -24,13 +26,13 @@ 协程池化 -需要测试,考虑是否切换到 gnet +数据通道池化 数据通道支持 tcp 多路复用(分离逻辑流) 👆 进阶黑魔法 multipath tcp + 多路复用 -考虑一下连接安全性 +切换到 gnet ## 开发相关 diff --git a/scripts/sql/init.sql b/scripts/sql/init.sql index 98cb48d..bb4a80f 100644 --- a/scripts/sql/init.sql +++ b/scripts/sql/init.sql @@ -47,29 +47,83 @@ create index user_ips_ip_address_index on user_ips (ip_address); drop table if exists channels cascade; create table channels ( id serial primary key, - user_id int not null references users (id) + user_id int not null references users (id) on update cascade on delete cascade, - node_id int not null references nodes (id) -- - on update cascade -- - on delete set null, -- 节点删除后,用户侧需要保留提取记录 - node_port int, - protocol varchar(255), + node_id int not null references nodes (id) -- + on update cascade -- + on delete set null, -- 节点删除后,用户侧需要保留提取记录 + user_addr varchar(255) not null, -- 快照数据 + node_port int, -- 快照数据 auth_ip bool, auth_pass bool, + protocol varchar(255), username varchar(255) unique, password varchar(255), - expiration timestamp not null, + expiration timestamp not null, created_at timestamp default current_timestamp, updated_at timestamp default current_timestamp, deleted_at timestamp ); create index channel_user_id_index on channels (user_id); create index channel_node_id_index on channels (node_id); -create index channel_username_index on channels (username); +create index channel_user_addr_index on channels (user_addr); +create index channel_node_port_index on channels (node_port); -- ==================== -- 填充数据 -- ==================== +do +$$ + declare + r_user_id int; + r_node_id int; + begin + -- 用户信息 + insert into users ( + username, password, email, phone, name + ) + values ( + 'test', 'test', 'test@user.email', '12345678901', 'test_user' + ) + returning id into r_user_id; + insert into user_ips ( + user_id, ip_address + ) + values ( + r_user_id, '::1' + ), ( + r_user_id, '127.0.0.1' + ); + + -- 节点信息 + insert into nodes ( + name, version, fwd_port, provider, location + ) + values ( + 'qwer', 1, 20001, 'test_provider', 'test_location' + ) + returning id into r_node_id; + + -- 权限信息 + insert into channels ( + user_id, node_id, user_addr, node_port, auth_ip, auth_pass, + protocol, username, password, expiration + ) + values ( + r_user_id, r_node_id, '::1', 20001, true, false, + 'http', 'ip6http', 'asdf', now() + interval '1 year' + ), ( + r_user_id, r_node_id, '::1', 20001, true, false, + 'socks5', 'ip6socks', 'asdf', now() + interval '1 year' + ), ( + r_user_id, r_node_id, '127.0.0.1', 20001, true, false, + 'http', 'ip4http', 'asdf', now() + interval '1 year' + ), ( + r_user_id, r_node_id, '127.0.0.1', 20001, true, false, + 'socks5', 'ip4socks', 'asdf', now() + interval '1 year' + ); + end +$$; diff --git a/server/fwd/core/auth.go b/server/fwd/auth/auth.go similarity index 51% rename from server/fwd/core/auth.go rename to server/fwd/auth/auth.go index 27ee4a4..48f6d67 100644 --- a/server/fwd/core/auth.go +++ b/server/fwd/auth/auth.go @@ -1,32 +1,25 @@ -package core +package auth import ( "log/slog" "net" - models2 "proxy-server/server/pkg/models" + "proxy-server/server/fwd/core" + "proxy-server/server/pkg/models" "proxy-server/server/pkg/orm" + "strconv" "time" "github.com/pkg/errors" ) -type Payload struct { - ID uint -} +type Protocol string -type AuthContext struct { - Timeout float64 - Payload Payload - Meta map[string]any -} +const ( + Socks5 = Protocol("socks5") + Http = Protocol("http") +) -func CheckIp(conn net.Conn) (*AuthContext, error) { - return &AuthContext{ - Timeout: 0, - Payload: Payload{ - 1, - }, - }, nil +func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) { // 获取用户地址 remoteAddr := conn.RemoteAddr().String() @@ -37,23 +30,24 @@ func CheckIp(conn net.Conn) (*AuthContext, error) { // 获取服务端口 localAddr := conn.LocalAddr().String() - _, localPort, err := net.SplitHostPort(localAddr) + _, _localPort, err := net.SplitHostPort(localAddr) + localPort, err := strconv.Atoi(_localPort) + if err != nil { + return nil, errors.Wrap(err, "noAuth 认证失败") + } // 查询权限记录 - slog.Debug("用户 " + remoteHost + " 请求连接到 " + localPort) - var channels []models2.Channel - err = orm.DB. - Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort). - Joins("INNER JOIN public.users u ON channels.user_id = u.id"). - Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", remoteHost). - Where(&models2.Channel{ - AuthIp: true, - }). - Find(&channels).Error + slog.Debug("用户 " + remoteHost + " 请求连接到 " + _localPort) + var channels []models.Channel + err = orm.DB.Find(&channels, &models.Channel{ + AuthIp: true, + UserAddr: remoteHost, + NodePort: localPort, + Protocol: string(proto), + }).Error if err != nil { return nil, errors.New("查询用户权限失败") } - // 记录应该只有一条 channel, err := orm.MaySingle(channels) if err != nil { @@ -71,30 +65,32 @@ func CheckIp(conn net.Conn) (*AuthContext, error) { return nil, errors.New("权限已过期") } - return &AuthContext{ + return &core.AuthContext{ Timeout: timeout, - Payload: Payload{ - channel.UserId, + Payload: core.Payload{ + ID: channel.UserId, }, }, nil } -func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) { - return &AuthContext{ - Timeout: 0, - Payload: Payload{ - 1, - }, - }, nil +func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.AuthContext, error) { - // 查询通道配置 - var channel models2.Channel - err := orm.DB. - Where(&models2.Channel{ - Username: username, - AuthPass: true, - }). - First(&channel).Error + // 获取服务端口 + localAddr := conn.LocalAddr().String() + _, _localPort, err := net.SplitHostPort(localAddr) + localPort, err := strconv.Atoi(_localPort) + if err != nil { + return nil, errors.Wrap(err, "noAuth 认证失败") + } + + // 查询权限记录 + var channel models.Channel + err = orm.DB.Take(&channel, &models.Channel{ + AuthPass: true, + Username: username, + NodePort: localPort, + Protocol: string(proto), + }).Error if err != nil { return nil, errors.Wrap(err, "用户不存在") } @@ -104,15 +100,8 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) { return nil, errors.New("密码错误") } - // 检查权限是否过期 - timeout := channel.Expiration.Sub(time.Now()).Seconds() - if timeout <= 0 { - return nil, errors.New("权限已过期") - } - // 如果用户设置了双验证则检查 ip 是否在白名单中 if channel.AuthIp { - slog.Debug("验证用户 ip") // 获取用户地址 remoteAddr := conn.RemoteAddr().String() @@ -121,28 +110,22 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) { return nil, errors.Wrap(err, "无法获取连接信息") } - // 查询通道配置 - - var ips int64 - err = orm.DB. - Where(&models2.UserIp{ - UserId: channel.UserId, - IpAddress: remoteHost, - }). - Count(&ips).Error - if err != nil { - return nil, errors.Wrap(err, "查询白名单失败") - } - - if ips == 0 { + // 查询权限记录 + if channel.UserAddr != remoteHost { return nil, errors.New("不在白名单内") } } - return &AuthContext{ + // 检查权限是否过期 + timeout := channel.Expiration.Sub(time.Now()).Seconds() + if timeout <= 0 { + return nil, errors.New("权限已过期") + } + + return &core.AuthContext{ Timeout: timeout, - Payload: Payload{ - channel.UserId, + Payload: core.Payload{ + ID: channel.UserId, }, }, nil } diff --git a/server/fwd/auth/auth_test.go b/server/fwd/auth/auth_test.go new file mode 100644 index 0000000..09607be --- /dev/null +++ b/server/fwd/auth/auth_test.go @@ -0,0 +1,151 @@ +package auth + +import ( + "net" + "proxy-server/server/fwd/core" + "proxy-server/server/pkg/orm" + "reflect" + "testing" + "time" +) + +type MockConn struct { + local net.Addr + remote net.Addr +} + +func (m MockConn) LocalAddr() net.Addr { + return m.local +} + +func (m MockConn) RemoteAddr() net.Addr { + return m.remote +} + +func (m MockConn) Read(b []byte) (n int, err error) { + return 0, nil +} + +func (m MockConn) Write(b []byte) (n int, err error) { + return 0, nil +} + +func (m MockConn) Close() error { + return nil +} + +func (m MockConn) SetDeadline(t time.Time) error { + return nil +} + +func (m MockConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (m MockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +func TestCheckIp(t *testing.T) { + + orm.InitForTest( + "host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai", + ) + + type args struct { + conn net.Conn + proto Protocol + } + tests := []struct { + name string + args args + want *core.AuthContext + wantErr bool + }{ + { + name: "test-ok", + args: args{ + conn: MockConn{ + remote: &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: 12345, + }, + local: &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: 20001, + }, + }, + proto: Http, + }, + want: &core.AuthContext{ + Payload: core.Payload{ID: 1}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CheckIp(tt.args.conn, tt.args.proto) + if (err != nil) != tt.wantErr { + t.Errorf("CheckIp() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) { + t.Errorf("CheckIp() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCheckPass(t *testing.T) { + + orm.InitForTest( + "host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai", + ) + + type args struct { + conn net.Conn + username string + password string + proto Protocol + } + tests := []struct { + name string + args args + want *core.AuthContext + wantErr bool + }{ + { + name: "test-ok", + args: args{ + conn: MockConn{ + remote: &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: 12345, + }, + local: &net.TCPAddr{ + IP: []byte{127, 0, 0, 1}, + Port: 20001, + }, + }, + proto: Http, + username: "ip4http", + password: "asdf", + }, + want: &core.AuthContext{ + Payload: core.Payload{ID: 1}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := CheckPass(tt.args.conn, tt.args.proto, tt.args.username, tt.args.password) + if (err != nil) != tt.wantErr { + t.Errorf("CheckPass() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) { + t.Errorf("CheckPass() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/server/fwd/core/conn.go b/server/fwd/core/conn.go index 3cc9dd5..614c082 100644 --- a/server/fwd/core/conn.go +++ b/server/fwd/core/conn.go @@ -92,3 +92,13 @@ func (a FwdAddr) Network() string { func (a FwdAddr) String() string { return fmt.Sprintf("%s:%d", a.IP, a.Port) } + +type AuthContext struct { + Timeout float64 + Payload Payload + Meta map[string]any +} + +type Payload struct { + ID uint +} diff --git a/server/fwd/ctrl.go b/server/fwd/ctrl.go index 7f07adc..e9e93b3 100644 --- a/server/fwd/ctrl.go +++ b/server/fwd/ctrl.go @@ -94,7 +94,7 @@ func (s *Service) processCtrlConn(conn net.Conn) error { // 检查客户端 var node models.Node - err = orm.DB.First(&node, &models.Node{ + err = orm.DB.Take(&node, &models.Node{ Name: name, }).Error if err != nil { diff --git a/server/fwd/dispatcher/dispatch.go b/server/fwd/dispatcher/dispatch.go index 3c05b22..4000e6f 100644 --- a/server/fwd/dispatcher/dispatch.go +++ b/server/fwd/dispatcher/dispatch.go @@ -118,7 +118,7 @@ func (s *Server) acceptHttp(ls net.Listener) error { go func() { user, err := http.Process(s.ctx, conn) if err != nil { - slog.Error("dispatcher http process error", "err", err) + slog.Error("处理 http 连接失败", "err", err) utils.Close(conn) return } diff --git a/server/fwd/http/http.go b/server/fwd/http/http.go index 8bf4df3..9e159b8 100644 --- a/server/fwd/http/http.go +++ b/server/fwd/http/http.go @@ -8,6 +8,7 @@ import ( "net" "net/textproto" "net/url" + "proxy-server/server/fwd/auth" "proxy-server/server/fwd/core" "strings" @@ -47,11 +48,16 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) { // 验证账号 authInfo := headers.Get("Proxy-Authorization") - var auth *core.AuthContext + var authCtx *core.AuthContext + var authErr error if authInfo == "" { - auth, err = core.CheckIp(conn) - if err != nil { - return nil, errors.Wrap(err, "验证账号失败") + authCtx, authErr = auth.CheckIp(conn, auth.Http) + if authErr != nil { + _, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n")) + if err != nil { + return nil, errors.Wrap(err, "响应 407 失败") + } + return nil, errors.Wrap(authErr, "验证账号失败") } } else { authParts := strings.Split(authInfo, " ") @@ -66,7 +72,14 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) { return nil, errors.Wrap(err, "解码认证信息失败") } authPair := strings.Split(string(authBytes), ":") - auth, err = core.CheckPass(conn, authPair[0], authPair[1]) + authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1]) + if authErr != nil { + _, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n")) + if err != nil { + return nil, errors.Wrap(err, "响应 407 失败") + } + return nil, errors.Wrap(authErr, "验证账号失败") + } } // 获取 Host @@ -94,7 +107,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) { Port: addr.Port, Domain: host, }, - auth: auth, + auth: authCtx, } var user *core.Conn diff --git a/server/fwd/socks/socks.go b/server/fwd/socks/socks.go index 9615bfa..7e81315 100644 --- a/server/fwd/socks/socks.go +++ b/server/fwd/socks/socks.go @@ -9,6 +9,7 @@ import ( "log/slog" "net" "proxy-server/pkg/utils" + "proxy-server/server/fwd/auth" "proxy-server/server/fwd/core" "slices" @@ -60,7 +61,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) { reader := bufio.NewReader(conn) // 认证 - auth, err := authenticate(ctx, reader, conn) + authCtx, err := authenticate(ctx, reader, conn) if err != nil { return nil, errors.Wrap(err, "认证失败") } @@ -85,7 +86,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) { Protocol: "socks5", Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(), Dest: request.DestAddr, - Auth: auth, + Auth: authCtx, }, nil } @@ -167,7 +168,7 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co password := string(passwordBuf) // 检查权限 - authContext, err := core.CheckPass(conn, username, password) + authContext, err := auth.CheckPass(conn, auth.Socks5, username, password) if err != nil { return nil, errors.Wrap(err, "权限检查失败") } @@ -188,12 +189,12 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co return nil, errors.Wrap(err, "响应认证方式失败") } - authContext, err := core.CheckIp(conn) + authCtx, err := auth.CheckIp(conn, auth.Socks5) if err != nil { return nil, errors.Wrap(err, "权限检查失败") } - return authContext, nil + return authCtx, nil } // 无适用的认证方式 diff --git a/server/pkg/models/channel.go b/server/pkg/models/channel.go index 5c6941b..f0870bd 100644 --- a/server/pkg/models/channel.go +++ b/server/pkg/models/channel.go @@ -11,10 +11,12 @@ type Channel struct { gorm.Model UserId uint NodeId uint + UserAddr string + NodePort int + AuthIp bool + AuthPass bool Protocol string Username string Password string - AuthIp bool - AuthPass bool Expiration time.Time } diff --git a/server/pkg/orm/orm.go b/server/pkg/orm/orm.go index 60887e5..141a1da 100644 --- a/server/pkg/orm/orm.go +++ b/server/pkg/orm/orm.go @@ -29,6 +29,17 @@ func Init() { DB = db } +func InitForTest(dsn string) { + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: logger.Default, + }) + if err != nil { + panic(err) + } + + DB = db +} + func MaySingle[T any](results []T) (*T, error) { rsLen := len(results) if rsLen == 0 {