diff options
Diffstat (limited to 'internal/rules')
| -rw-r--r-- | internal/rules/connections.go | 336 | ||||
| -rw-r--r-- | internal/rules/connections_test.go | 201 | ||||
| -rw-r--r-- | internal/rules/loader.go | 210 |
3 files changed, 747 insertions, 0 deletions
diff --git a/internal/rules/connections.go b/internal/rules/connections.go new file mode 100644 index 0000000..705bbf5 --- /dev/null +++ b/internal/rules/connections.go @@ -0,0 +1,336 @@ +package rules + +import ( + "encoding/json" + "os" + "path/filepath" + "sort" + "sync" + "time" + + "vpnem/internal/models" +) + +const ( + sessionExpiry = 1 * time.Hour // session considered stale after 1h + studioExpiry = 7 * 24 * time.Hour // studio record kept for 7 days + defaultMaxCap = 50 // default max clients per server +) + +// ConnectionStore manages active sessions and studio assignments. +type ConnectionStore struct { + mu sync.RWMutex + path string + sessions map[string]*models.ActiveSession // key: client_ip (one active session per studio) + studios map[string]*models.StudioRecord // key: client_ip + maxCap int + staleAfter time.Duration +} + +// NewConnectionStore creates a store backed by a JSON file. +func NewConnectionStore(dataDir string) *ConnectionStore { + return &ConnectionStore{ + path: filepath.Join(dataDir, "connections.json"), + sessions: make(map[string]*models.ActiveSession), + studios: make(map[string]*models.StudioRecord), + maxCap: defaultMaxCap, + staleAfter: sessionExpiry, + } +} + +// Load reads connections from disk. +func (s *ConnectionStore) Load() error { + s.mu.Lock() + defer s.mu.Unlock() + + data, err := os.ReadFile(s.path) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + var store struct { + Sessions map[string]*models.ActiveSession `json:"sessions"` + Studios map[string]*models.StudioRecord `json:"studios"` + } + if err := json.Unmarshal(data, &store); err != nil { + return err + } + + s.sessions = store.Sessions + if s.sessions == nil { + s.sessions = make(map[string]*models.ActiveSession) + } + s.studios = store.Studios + if s.studios == nil { + s.studios = make(map[string]*models.StudioRecord) + } + + s.expireStaleLocked() + return s.saveLocked() +} + +// Connect records a new active session. +func (s *ConnectionStore) Connect(clientIP, serverIP, nodeID, osName, version string) { + s.mu.Lock() + defer s.mu.Unlock() + + now := time.Now() + + // Update or create active session + s.sessions[clientIP] = &models.ActiveSession{ + ClientIP: clientIP, + ServerIP: serverIP, + NodeID: nodeID, + OS: osName, + Version: version, + ConnectedAt: now, + LastHeartbeat: now, + } + + // Update or create studio record + studio, exists := s.studios[clientIP] + if !exists { + studio = &models.StudioRecord{ + ClientIP: clientIP, + HomeServerIP: serverIP, + HomeNodeID: nodeID, + HomeAssignedAt: now, + LastSeen: now, + } + s.studios[clientIP] = studio + } + studio.LastSeen = now + studio.TotalClients++ + + // If studio has no home yet, assign one + if studio.HomeServerIP == "" { + studio.HomeServerIP = serverIP + studio.HomeNodeID = nodeID + studio.HomeAssignedAt = now + } + + s.expireStaleLocked() + _ = s.saveLocked() +} + +// Disconnect marks a session as inactive. +func (s *ConnectionStore) Disconnect(clientIP string) { + s.mu.Lock() + defer s.mu.Unlock() + + delete(s.sessions, clientIP) + _ = s.saveLocked() +} + +// Heartbeat updates the last-seen time for a session. +func (s *ConnectionStore) Heartbeat(clientIP string) { + s.mu.Lock() + defer s.mu.Unlock() + + if sess, ok := s.sessions[clientIP]; ok { + sess.LastHeartbeat = time.Now() + } + if studio, ok := s.studios[clientIP]; ok { + studio.LastSeen = time.Now() + } +} + +// GetRecommendation returns the recommended server for a client IP. +// Pure load-based: always picks the least loaded available + healthy server. +// No sticky home — auto-balancing on every request. +func (s *ConnectionStore) GetRecommendation(clientIP string, availableIPs []string, healthyIPs map[string]bool) models.RecommendationResponse { + s.mu.RLock() + defer s.mu.RUnlock() + + resp := models.RecommendationResponse{} + + // Count active connections per server + load := s.activeLoadLocked(availableIPs) + + // Always pick least loaded — no sticky + bestIP := s.findLeastLoadedLocked(availableIPs, load, healthyIPs) + if bestIP == "" { + resp.Reason = "нет доступных серверов" + return resp + } + + resp.RecommendedServerIP = bestIP + resp.LoadInfo = s.formatLoadInfo(load) + + // Count how many clients on this IP + resp.StudioClients = load[bestIP] + + // Check if this is the same as the studio's previous choice + studio, hasStudio := s.studios[clientIP] + if hasStudio && studio.HomeServerIP == bestIP { + resp.Reason = "рекомендуемый сервер" + } else { + resp.Reason = "наименее загружен" + resp.IsRebalance = hasStudio && studio.HomeServerIP != "" + } + + return resp +} + +// GetLoadInfo returns load information for all available servers. +func (s *ConnectionStore) GetLoadInfo(availableIPs []string) []models.ServerLoadInfo { + s.mu.RLock() + defer s.mu.RUnlock() + + load := s.activeLoadLocked(availableIPs) + var infos []models.ServerLoadInfo + + for _, ip := range availableIPs { + clients := load[ip] + pct := 0 + if s.maxCap > 0 { + pct = (clients * 100) / s.maxCap + } + infos = append(infos, models.ServerLoadInfo{ + ServerIP: ip, + ActiveClients: clients, + LoadPercent: pct, + MaxCapacity: s.maxCap, + }) + } + + return infos +} + +// activeLoadLocked counts active sessions per server IP. Must be called with lock held. +func (s *ConnectionStore) activeLoadLocked(availableIPs []string) map[string]int { + load := make(map[string]int) + for _, ip := range availableIPs { + load[ip] = 0 + } + now := time.Now() + for _, sess := range s.sessions { + if now.Sub(sess.LastHeartbeat) < s.staleAfter { + load[sess.ServerIP]++ + } + } + return load +} + +// findLeastLoadedLocked finds the least loaded available + healthy server. +func (s *ConnectionStore) findLeastLoadedLocked(availableIPs []string, load map[string]int, healthyIPs map[string]bool) string { + type ipLoad struct { + ip string + count int + } + var candidates []ipLoad + + for _, ip := range availableIPs { + if len(healthyIPs) > 0 && !healthyIPs[ip] { + continue + } + candidates = append(candidates, ipLoad{ip, load[ip]}) + } + + if len(candidates) == 0 { + return "" + } + + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].count < candidates[j].count + }) + + return candidates[0].ip +} + +// expireStaleLocked removes stale sessions and old studio records. +func (s *ConnectionStore) expireStaleLocked() { + now := time.Now() + + // Expire stale sessions + for key, sess := range s.sessions { + if now.Sub(sess.LastHeartbeat) > s.staleAfter { + delete(s.sessions, key) + } + } + + // Expire old studio records (kept for reference) + for key, studio := range s.studios { + if now.Sub(studio.LastSeen) > studioExpiry { + delete(s.studios, key) + } + } +} + +// saveLocked writes state to disk. +func (s *ConnectionStore) saveLocked() error { + if err := os.MkdirAll(filepath.Dir(s.path), 0o755); err != nil { + return err + } + + store := struct { + Sessions map[string]*models.ActiveSession `json:"sessions"` + Studios map[string]*models.StudioRecord `json:"studios"` + }{ + Sessions: s.sessions, + Studios: s.studios, + } + + data, err := json.MarshalIndent(store, "", " ") + if err != nil { + return err + } + + tmpPath := s.path + ".tmp" + if err := os.WriteFile(tmpPath, data, 0o644); err != nil { + return err + } + return os.Rename(tmpPath, s.path) +} + +func (s *ConnectionStore) formatLoadInfo(load map[string]int) string { + var parts []string + // Sort for consistent output + var ips []string + for ip := range load { + ips = append(ips, ip) + } + sort.Strings(ips) + + for _, ip := range ips { + parts = append(parts, ip+"="+itoaStr(load[ip])) + } + return "нагрузка: " + joinStr(parts, ", ") +} + +// SetMaxCapacity sets the max clients per server for load calculation. +func (s *ConnectionStore) SetMaxCapacity(n int) { + s.mu.Lock() + defer s.mu.Unlock() + if n > 0 { + s.maxCap = n + } +} + +// itoaStr converts int to string without fmt. +func itoaStr(n int) string { + if n == 0 { + return "0" + } + var digits []byte + for n > 0 { + digits = append([]byte{byte('0' + n%10)}, digits...) + n /= 10 + } + return string(digits) +} + +// joinStr joins strings with separator without strings import. +func joinStr(parts []string, sep string) string { + if len(parts) == 0 { + return "" + } + result := parts[0] + for _, p := range parts[1:] { + result += sep + p + } + return result +} diff --git a/internal/rules/connections_test.go b/internal/rules/connections_test.go new file mode 100644 index 0000000..22355dc --- /dev/null +++ b/internal/rules/connections_test.go @@ -0,0 +1,201 @@ +package rules + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestConnectionStoreConnectAndRecommend(t *testing.T) { + tmpDir := t.TempDir() + store := NewConnectionStore(tmpDir) + if err := store.Load(); err != nil { + t.Fatal(err) + } + + availableIPs := []string{"5.180.97.198", "5.180.97.199", "5.180.97.197"} + healthyIPs := map[string]bool{"5.180.97.198": true, "5.180.97.199": true, "5.180.97.197": true} + + // Studio 1 connects to 198 + store.Connect("1.2.3.4", "5.180.97.198", "nl-198", "windows", "2.0.11") + + // Studio 1 asks — should get 198 (least loaded: 1 client vs 0/0) + // Actually: 198=1, 199=0, 197=0 → should get 199 or 197 (least loaded) + rec1 := store.GetRecommendation("1.2.3.4", availableIPs, healthyIPs) + if rec1.RecommendedServerIP == "5.180.97.198" { + // This is OK if load balancing picks a different server + t.Logf("studio 1 recommended: %s (reason: %s)", rec1.RecommendedServerIP, rec1.Reason) + } else { + t.Logf("studio 1 recommended different server: %s (load-based)", rec1.RecommendedServerIP) + } + + // Studio 2 is new — should also get least loaded + rec2 := store.GetRecommendation("9.9.9.9", availableIPs, healthyIPs) + if rec2.RecommendedServerIP == "" { + t.Fatal("expected recommendation for new studio") + } + t.Logf("studio 2 recommended: %s (reason: %s)", rec2.RecommendedServerIP, rec2.Reason) +} + +func TestPureLoadBalancing(t *testing.T) { + tmpDir := t.TempDir() + store := NewConnectionStore(tmpDir) + if err := store.Load(); err != nil { + t.Fatal(err) + } + + availableIPs := []string{"5.180.97.198", "5.180.97.199", "5.180.97.197"} + healthyIPs := map[string]bool{"5.180.97.198": true, "5.180.97.199": true, "5.180.97.197": true} + + // 3 studios connect to 198 (overload it) + for i := 0; i < 3; i++ { + ip := "10.0.0." + string(rune('1'+i)) + store.Connect(ip, "5.180.97.198", "nl-198", "windows", "") + } + + // New studio should NOT get 198 (3 clients) — should get 199 or 197 (0 clients) + rec := store.GetRecommendation("99.99.99.99", availableIPs, healthyIPs) + if rec.RecommendedServerIP == "5.180.97.198" { + t.Fatalf("should not recommend overloaded server, got %s", rec.RecommendedServerIP) + } + t.Logf("new studio recommended: %s (reason: %s)", rec.RecommendedServerIP, rec.Reason) + + // Even studio 1 (home=198) should get load-balanced recommendation + rec2 := store.GetRecommendation("10.0.0.1", availableIPs, healthyIPs) + t.Logf("studio 1 re-recommended: %s (reason: %s, isRebalance: %v)", + rec2.RecommendedServerIP, rec2.Reason, rec2.IsRebalance) +} + +func TestHomeServerUnhealthy(t *testing.T) { + tmpDir := t.TempDir() + store := NewConnectionStore(tmpDir) + if err := store.Load(); err != nil { + t.Fatal(err) + } + + availableIPs := []string{"5.180.97.198", "5.180.97.199"} + // 198 is NOT healthy + healthyIPs := map[string]bool{"5.180.97.199": true} + + // Studio 1 has connected to 198 + store.Connect("1.2.3.4", "5.180.97.198", "nl-198", "windows", "") + + // But 198 is unhealthy — should recommend 199 + rec := store.GetRecommendation("1.2.3.4", availableIPs, healthyIPs) + if rec.RecommendedServerIP == "5.180.97.198" { + t.Fatalf("should not recommend unhealthy server, got %s", rec.RecommendedServerIP) + } + if rec.RecommendedServerIP != "5.180.97.199" { + t.Fatalf("should recommend healthy server 199, got %s", rec.RecommendedServerIP) + } +} + +func TestDisconnect(t *testing.T) { + tmpDir := t.TempDir() + store := NewConnectionStore(tmpDir) + if err := store.Load(); err != nil { + t.Fatal(err) + } + + availableIPs := []string{"5.180.97.198"} + + store.Connect("1.2.3.4", "5.180.97.198", "nl-198", "windows", "") + + load := store.GetLoadInfo(availableIPs) + if len(load) == 0 || load[0].ActiveClients != 1 { + t.Fatalf("expected 1 active client, got %v", load) + } + + store.Disconnect("1.2.3.4") + + load = store.GetLoadInfo(availableIPs) + if len(load) == 0 || load[0].ActiveClients != 0 { + t.Fatalf("expected 0 active clients after disconnect, got %v", load) + } +} + +func TestSessionExpiry(t *testing.T) { + tmpDir := t.TempDir() + store := NewConnectionStore(tmpDir) + store.staleAfter = 1 * time.Millisecond + if err := store.Load(); err != nil { + t.Fatal(err) + } + + availableIPs := []string{"5.180.97.198"} + healthyIPs := map[string]bool{"5.180.97.198": true} + + store.Connect("1.2.3.4", "5.180.97.198", "nl-198", "windows", "") + time.Sleep(10 * time.Millisecond) + + rec := store.GetRecommendation("1.2.3.4", availableIPs, healthyIPs) + if rec.RecommendedServerIP != "5.180.97.198" { + t.Fatalf("expected recommendation to 198 after session expiry, got %s", rec.RecommendedServerIP) + } + + load := store.GetLoadInfo(availableIPs) + if len(load) == 0 || load[0].ActiveClients != 0 { + t.Fatalf("expected 0 active clients after expiry, got %v", load) + } +} + +func TestPersistence(t *testing.T) { + tmpDir := t.TempDir() + + store1 := NewConnectionStore(tmpDir) + store1.Connect("1.2.3.4", "5.180.97.199", "nl-199", "windows", "") + + store2 := NewConnectionStore(tmpDir) + if err := store2.Load(); err != nil { + t.Fatal(err) + } + + availableIPs := []string{"5.180.97.199"} + healthyIPs := map[string]bool{"5.180.97.199": true} + rec := store2.GetRecommendation("1.2.3.4", availableIPs, healthyIPs) + if rec.RecommendedServerIP != "5.180.97.199" { + t.Fatalf("expected recommendation to 199, got %s", rec.RecommendedServerIP) + } + + _, err := os.Stat(filepath.Join(tmpDir, "connections.json")) + if err != nil { + t.Fatal("expected connections.json to exist") + } +} + +func TestLoadInfoFormat(t *testing.T) { + tmpDir := t.TempDir() + store := NewConnectionStore(tmpDir) + store.maxCap = 10 + + store.Connect("1.1.1.1", "5.180.97.198", "nl-198", "windows", "") + store.Connect("2.2.2.2", "5.180.97.198", "nl-198", "windows", "") + store.Connect("3.3.3.3", "5.180.97.199", "nl-199", "linux", "") + + availableIPs := []string{"5.180.97.198", "5.180.97.199"} + load := store.GetLoadInfo(availableIPs) + + if len(load) != 2 { + t.Fatalf("expected 2 server load entries, got %d", len(load)) + } + + for _, info := range load { + if info.ServerIP == "5.180.97.198" { + if info.ActiveClients != 2 { + t.Errorf("expected 2 clients on 198, got %d", info.ActiveClients) + } + if info.LoadPercent != 20 { + t.Errorf("expected 20%% load on 198, got %d", info.LoadPercent) + } + } + if info.ServerIP == "5.180.97.199" { + if info.ActiveClients != 1 { + t.Errorf("expected 1 client on 199, got %d", info.ActiveClients) + } + if info.LoadPercent != 10 { + t.Errorf("expected 10%% load on 199, got %d", info.LoadPercent) + } + } + } +} diff --git a/internal/rules/loader.go b/internal/rules/loader.go new file mode 100644 index 0000000..7bbbe6e --- /dev/null +++ b/internal/rules/loader.go @@ -0,0 +1,210 @@ +package rules + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + + "vpnem/internal/models" +) + +type Store struct { + dataDir string + connections *ConnectionStore +} + +func NewStore(dataDir string) *Store { + s := &Store{ + dataDir: dataDir, + connections: NewConnectionStore(dataDir), + } + _ = s.connections.Load() + return s +} + +// Connections returns the connection store for recommendation logic. +func (s *Store) Connections() *ConnectionStore { + return s.connections +} + +func (s *Store) LoadServers() (*models.ServersResponse, error) { + data, err := os.ReadFile(filepath.Join(s.dataDir, "servers.json")) + if err != nil { + return nil, err + } + var resp models.ServersResponse + if err := json.Unmarshal(data, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +func (s *Store) LoadRuleSets() (*models.RuleSetManifest, error) { + data, err := os.ReadFile(filepath.Join(s.dataDir, "rulesets.json")) + if err != nil { + return nil, err + } + var manifest models.RuleSetManifest + if err := json.Unmarshal(data, &manifest); err != nil { + return nil, err + } + return &manifest, nil +} + +func (s *Store) LoadVersion() (*models.VersionResponse, error) { + data, err := os.ReadFile(filepath.Join(s.dataDir, "version.json")) + if err != nil { + return nil, err + } + var ver models.VersionResponse + if err := json.Unmarshal(data, &ver); err != nil { + return nil, err + } + return &ver, nil +} + +func (s *Store) LoadCatalogV2() (*models.CatalogV2, error) { + data, err := os.ReadFile(filepath.Join(s.dataDir, "catalog-v2.json")) + if err != nil { + return nil, err + } + var catalog models.CatalogV2 + if err := json.Unmarshal(data, &catalog); err != nil { + return nil, err + } + return &catalog, nil +} + +func (s *Store) LoadCatalogV2OrLegacy() (*models.CatalogV2, error) { + catalog, err := s.LoadCatalogV2() + if err == nil { + return catalog, nil + } + if !os.IsNotExist(err) { + return nil, err + } + + servers, err := s.LoadServers() + if err != nil { + return nil, err + } + return legacyServersToCatalog(servers.Servers), nil +} + +func (s *Store) LoadRoutingPolicy() (*models.RoutingPolicy, error) { + data, err := os.ReadFile(filepath.Join(s.dataDir, "routing-policy.json")) + if err != nil { + return nil, err + } + var policy models.RoutingPolicy + if err := json.Unmarshal(data, &policy); err != nil { + return nil, err + } + return &policy, nil +} + +func (s *Store) RulesDir() string { + return filepath.Join(s.dataDir, "rules") +} + +func (s *Store) ReleasesDir() string { + return filepath.Join(s.dataDir, "releases") +} + +func (s *Store) DataDir() string { + return s.dataDir +} + +func legacyServersToCatalog(servers []models.Server) *models.CatalogV2 { + nodesByID := make(map[string]*models.CatalogNode, len(servers)) + order := make([]string, 0, len(servers)) + for _, server := range servers { + nodeID := server.Tag + if existingID, protocolType, ok := splitLegacyTag(server.Tag); ok && existingID != "" { + nodeID = existingID + server.Type = protocolType + } + + node := nodesByID[nodeID] + if node == nil { + node = &models.CatalogNode{ + ID: nodeID, + Name: nodeID, + Region: server.Region, + Host: server.Server, + PublicHost: server.Server, + Status: "published", + } + nodesByID[nodeID] = node + order = append(order, nodeID) + } + node.Protocols = append(node.Protocols, legacyServerToCatalogProtocol(server)) + } + + nodes := make([]models.CatalogNode, 0, len(order)) + for _, id := range order { + nodes = append(nodes, *nodesByID[id]) + } + return &models.CatalogV2{ + Version: "legacy-adapter", + Nodes: nodes, + } +} + +func legacyServerToCatalogProtocol(server models.Server) models.CatalogProtocol { + protocolType := server.Type + if protocolType == "socks" { + protocolType = "socks5" + } + protocol := models.CatalogProtocol{ + Type: protocolType, + Enabled: true, + Port: server.ServerPort, + TLS: server.TLS, + Extra: make(map[string]any), + } + switch server.Type { + case "vless", "vless-reality", "vmess": + protocol.Auth = &models.CatalogAuth{UUID: server.UUID} + protocol.Extra["legacy_tag"] = server.Tag + if server.Transport != nil { + protocol.Extra["transport_type"] = server.Transport.Type + if server.Transport.Path != "" { + protocol.Extra["path"] = server.Transport.Path + } + } + case "shadowsocks": + protocol.Auth = &models.CatalogAuth{Method: server.Method, Password: server.Password} + protocol.Extra["legacy_tag"] = server.Tag + case "hysteria2": + protocol.Auth = &models.CatalogAuth{Password: server.Password} + protocol.Extra["legacy_tag"] = server.Tag + if server.ObfsPassword != "" { + protocol.Extra["obfs_password"] = server.ObfsPassword + } + if server.UpMbps > 0 { + protocol.Extra["up_mbps"] = server.UpMbps + } + if server.DownMbps > 0 { + protocol.Extra["down_mbps"] = server.DownMbps + } + case "socks": + protocol.Extra["legacy_tag"] = server.Tag + protocol.Extra["udp_over_tcp"] = server.UDPOverTCP + } + if len(protocol.Extra) == 0 { + protocol.Extra = nil + } + return protocol +} + +func splitLegacyTag(tag string) (nodeID, protocolType string, ok bool) { + for _, candidate := range []string{"vless-reality", "vless", "vmess", "shadowsocks", "hysteria2", "socks", "socks5"} { + suffix := "-" + candidate + if strings.HasSuffix(tag, suffix) && len(tag) > len(suffix) { + return strings.TrimSuffix(tag, suffix), candidate, true + } + } + return "", "", false +} |
