From 3d51aa455006903345f554a2dd90034993796114 Mon Sep 17 00:00:00 2001 From: sergei Date: Tue, 14 Apr 2026 06:23:55 +0400 Subject: vpnem: VPN infrastructure with load-balanced multi-protocol nodes - Multi-protocol VPS nodes (VLESS-REALITY + Hysteria2 + SOCKS5) - Smart load balancing via recommendation API - Windows/Linux client (Go + Wails + sing-box) - Server API with RealIP detection and connection tracking - Auto-deployment via vpnui control plane - Silent Windows installer with UAC elevation - Load-based server recommendation (no sticky sessions) - Best Server one-click connection workflow --- internal/sync/fetcher.go | 643 ++++++++++++++++++++++++++++++++++++++++++ internal/sync/fetcher_test.go | 300 ++++++++++++++++++++ internal/sync/health.go | 33 +++ internal/sync/latency.go | 62 ++++ internal/sync/updater.go | 180 ++++++++++++ 5 files changed, 1218 insertions(+) create mode 100644 internal/sync/fetcher.go create mode 100644 internal/sync/fetcher_test.go create mode 100644 internal/sync/health.go create mode 100644 internal/sync/latency.go create mode 100644 internal/sync/updater.go (limited to 'internal/sync') diff --git a/internal/sync/fetcher.go b/internal/sync/fetcher.go new file mode 100644 index 0000000..3ea6df4 --- /dev/null +++ b/internal/sync/fetcher.go @@ -0,0 +1,643 @@ +package sync + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "time" + + "vpnem/internal/config" + "vpnem/internal/models" +) + +// Fetcher pulls configuration from the vpnem server API. +type Fetcher struct { + baseURL string + client *http.Client +} + +// NewFetcher creates a new Fetcher. +func NewFetcher(baseURL string) *Fetcher { + return &Fetcher{ + baseURL: baseURL, + client: &http.Client{ + Timeout: 15 * time.Second, + }, + } +} + +// FetchServers retrieves the server list from the API. +func (f *Fetcher) FetchServers() (*models.ServersResponse, error) { + catalog, err := f.FetchCatalog() + if err == nil { + return &models.ServersResponse{Servers: CatalogToServers(catalog)}, nil + } + return nil, fmt.Errorf("fetch catalog: %w", err) +} + +func (f *Fetcher) FetchCatalogV2() (*models.CatalogV2, error) { + var resp models.CatalogV2 + if err := f.getJSON("/api/v2/catalog", &resp); err != nil { + return nil, err + } + return &resp, nil +} + +func (f *Fetcher) FetchCatalog() (*models.CatalogV2, error) { + catalog, err := f.FetchCatalogV2() + if err == nil { + return catalog, nil + } + var statusErr *HTTPStatusError + if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound { + return nil, fmt.Errorf("fetch catalog v2: %w", err) + } + + var resp models.ServersResponse + if err := f.getJSON("/api/v1/servers", &resp); err != nil { + return nil, fmt.Errorf("fetch servers: %w", err) + } + return ServersToCatalog(resp.Servers), nil +} + +func (f *Fetcher) FetchRoutingPolicy() (*models.RoutingPolicy, error) { + var resp models.RoutingPolicy + if err := f.getJSON("/api/v1/routing-policy", &resp); err != nil { + var statusErr *HTTPStatusError + if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound { + return config.DefaultRoutingPolicy(), nil + } + return nil, fmt.Errorf("fetch routing policy: %w", err) + } + return config.EffectiveRoutingPolicy(&resp), nil +} + +// FetchRuleSets retrieves the rule-set manifest from the API. +func (f *Fetcher) FetchRuleSets() (*models.RuleSetManifest, error) { + var resp models.RuleSetManifest + if err := f.getJSON("/api/v1/ruleset/manifest", &resp); err != nil { + return nil, fmt.Errorf("fetch rulesets: %w", err) + } + return &resp, nil +} + +// DownloadRuleSets downloads all non-optional .srs files to dataDir/rules/. +// Returns the updated RuleSet list with LocalPath populated. +func (f *Fetcher) DownloadRuleSets(ruleSets []models.RuleSet, dataDir string) ([]models.RuleSet, error) { + rulesDir := filepath.Join(dataDir, "rules") + if err := os.MkdirAll(rulesDir, 0755); err != nil { + return nil, fmt.Errorf("create rules dir: %w", err) + } + + var downloaded []models.RuleSet + for _, rs := range ruleSets { + if rs.URL == "" { + if rs.Optional { + continue + } + return nil, fmt.Errorf("rule-set %s has no URL", rs.Tag) + } + localPath := filepath.Join(rulesDir, rs.Tag+".srs") + resp, err := f.client.Get(rs.URL) + if err != nil { + if rs.Optional { + continue + } + return nil, fmt.Errorf("download %s: %w", rs.URL, err) + } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + if rs.Optional { + continue + } + return nil, fmt.Errorf("read %s: %w", rs.URL, err) + } + if resp.StatusCode != http.StatusOK { + if rs.Optional { + continue + } + return nil, fmt.Errorf("download %s: HTTP %d", rs.URL, resp.StatusCode) + } + if err := os.WriteFile(localPath, body, 0644); err != nil { + return nil, fmt.Errorf("write %s: %w", localPath, err) + } + rs.LocalPath = localPath + downloaded = append(downloaded, rs) + } + return downloaded, nil +} + +// FetchVersion retrieves the latest client version info. +func (f *Fetcher) FetchVersion() (*models.VersionResponse, error) { + var resp models.VersionResponse + if err := f.getJSON("/api/v1/version", &resp); err != nil { + return nil, fmt.Errorf("fetch version: %w", err) + } + return &resp, nil +} + +// ServerIPs extracts all unique server IPs from the server list. +func ServerIPs(servers []models.Server) []string { + seen := make(map[string]bool) + var ips []string + for _, s := range servers { + if !seen[s.Server] { + seen[s.Server] = true + ips = append(ips, s.Server) + } + } + return ips +} + +// ReportError sends error logs to the server (best-effort, non-blocking). +func (f *Fetcher) ReportError(version, osName string, lines []string) { + payload := map[string]any{ + "version": version, + "os": osName, + "lines": lines, + } + data, err := json.Marshal(payload) + if err != nil { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, f.baseURL+"/logs2026vpnem/errors", bytes.NewReader(data)) + if err != nil { + return + } + req.Header.Set("Content-Type", "application/json") + resp, err := f.client.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + io.Copy(io.Discard, resp.Body) +} + +func (f *Fetcher) getJSON(path string, v any) error { + resp, err := f.client.Get(f.baseURL + path) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return &HTTPStatusError{StatusCode: resp.StatusCode, Body: string(body)} + } + + return json.NewDecoder(resp.Body).Decode(v) +} + +type HTTPStatusError struct { + StatusCode int + Body string +} + +func (e *HTTPStatusError) Error() string { + return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Body) +} + +// ReportConnect sends a connection report to the server. +// Server auto-detects client real IP from X-Forwarded-For. +func (f *Fetcher) ReportConnect(serverIP, nodeID string) (*models.RecommendationResponse, error) { + req := models.ConnectRequest{ + ServerIP: serverIP, + NodeID: nodeID, + OS: func() string { + if runtime.GOOS == "windows" { + return "windows" + } + return "linux" + }(), + Version: "", // server doesn't need version for rebalancing + } + + payload, err := json.Marshal(req) + if err != nil { + return nil, err + } + + resp, err := f.client.Post(f.baseURL+"/api/v1/connect", "application/json", bytes.NewReader(payload)) + if err != nil { + return nil, fmt.Errorf("report connect: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Body: string(body)} + } + + var recommendation models.RecommendationResponse + if err := json.NewDecoder(resp.Body).Decode(&recommendation); err != nil { + return nil, fmt.Errorf("decode recommendation: %w", err) + } + return &recommendation, nil +} + +// ReportDisconnect notifies the server that a client disconnected. +func (f *Fetcher) ReportDisconnect(serverIP, nodeID string) error { + req := models.DisconnectRequest{ + ServerIP: serverIP, + NodeID: nodeID, + } + + payload, err := json.Marshal(req) + if err != nil { + return err + } + + resp, err := f.client.Post(f.baseURL+"/api/v1/disconnect", "application/json", bytes.NewReader(payload)) + if err != nil { + return fmt.Errorf("report disconnect: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return &HTTPStatusError{StatusCode: resp.StatusCode, Body: string(body)} + } + return nil +} + +// GetRecommendation fetches a recommendation for the client. +// Server auto-detects client real IP from X-Forwarded-For. +func (f *Fetcher) GetRecommendation() (*models.RecommendationResponse, error) { + url := f.baseURL + "/api/v1/recommend" + resp, err := f.client.Get(url) + if err != nil { + return nil, fmt.Errorf("get recommendation: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Body: string(body)} + } + + var recommendation models.RecommendationResponse + if err := json.NewDecoder(resp.Body).Decode(&recommendation); err != nil { + return nil, fmt.Errorf("decode recommendation: %w", err) + } + return &recommendation, nil +} + +func CatalogToServers(catalog *models.CatalogV2) []models.Server { + if catalog == nil { + return nil + } + servers := make([]models.Server, 0) + for _, node := range catalog.Nodes { + if multi, ok := nodeToSplitServer(node); ok { + servers = append(servers, multi) + } + host := node.PublicHost + if host == "" { + if node.Domain != "" { + host = node.Domain + } else { + host = node.Host + } + } + for _, protocol := range node.Protocols { + if !protocol.Enabled { + continue + } + if isSplitProtocol(protocol.Type) && hasSplitPair(node) { + continue + } + server := models.Server{ + Tag: legacyTag(node, protocol), + Region: node.Region, + Type: protocol.Type, + Server: host, + ServerPort: protocol.Port, + } + if protocol.TLS != nil { + server.TLS = &models.TLS{ + Enabled: protocol.TLS.Enabled, + ServerName: protocol.TLS.ServerName, + Insecure: protocol.TLS.Insecure, + ALPN: protocol.TLS.ALPN, + MinVersion: protocol.TLS.MinVersion, + MaxVersion: protocol.TLS.MaxVersion, + } + if protocol.TLS.Reality != nil && protocol.TLS.Reality.Enabled { + server.TLS.Reality = &models.Reality{ + Enabled: true, + PublicKey: protocol.TLS.Reality.PublicKey, + ShortID: protocol.TLS.Reality.ShortID, + Fingerprint: protocol.TLS.Reality.Fingerprint, + } + } + } + switch protocol.Type { + case "vless", "vless-reality", "vmess": + if protocol.Auth != nil { + server.UUID = protocol.Auth.UUID + } + if transportType, _ := protocol.Extra["transport_type"].(string); transportType != "" { + server.Transport = &models.Transport{ + Type: transportType, + Path: extraString(protocol.Extra, "path"), + } + } else if path := extraString(protocol.Extra, "path"); path != "" { + server.Transport = &models.Transport{ + Type: "ws", + Path: path, + } + } + case "shadowsocks": + if protocol.Auth != nil { + server.Method = protocol.Auth.Method + server.Password = protocol.Auth.Password + } + case "hysteria2": + if protocol.Auth != nil { + server.Password = protocol.Auth.Password + } + server.ObfsPassword = extraString(protocol.Extra, "obfs_password") + server.UpMbps = extraInt(protocol.Extra, "up_mbps", 0) + server.DownMbps = extraInt(protocol.Extra, "down_mbps", 0) + if server.TLS == nil { + server.TLS = &models.TLS{} + } + server.TLS.Enabled = true + server.TLS.Insecure = true + if len(server.TLS.ALPN) == 0 { + server.TLS.ALPN = []string{"h3"} + } + if server.TLS.MinVersion == "" { + server.TLS.MinVersion = "1.3" + } + if server.TLS.MaxVersion == "" { + server.TLS.MaxVersion = "1.3" + } + case "socks5": + server.Type = "socks" + } + servers = append(servers, server) + } + } + return servers +} + +func nodeToSplitServer(node models.CatalogNode) (models.Server, bool) { + if !hasSplitPair(node) { + return models.Server{}, false + } + host := node.PublicHost + if host == "" { + if node.Domain != "" { + host = node.Domain + } else { + host = node.Host + } + } + var main models.Server + var hy2 models.Server + for _, protocol := range node.Protocols { + if !protocol.Enabled { + continue + } + switch protocol.Type { + case "vless-reality": + main = catalogProtocolToServer(node, protocol, host) + case "hysteria2": + hy2 = catalogProtocolToServer(node, protocol, host) + } + } + if main.Type == "" || hy2.Type == "" { + return models.Server{}, false + } + main.Tag = node.ID + "-multi" + main.Companions = []models.Server{hy2} + return main, true +} + +func hasSplitPair(node models.CatalogNode) bool { + hasReality := false + hasHy2 := false + for _, protocol := range node.Protocols { + if !protocol.Enabled { + continue + } + switch protocol.Type { + case "vless-reality": + hasReality = true + case "hysteria2": + hasHy2 = true + } + } + return hasReality && hasHy2 +} + +func isSplitProtocol(protocolType string) bool { + return protocolType == "vless-reality" || protocolType == "hysteria2" +} + +func catalogProtocolToServer(node models.CatalogNode, protocol models.CatalogProtocol, host string) models.Server { + server := models.Server{ + Tag: legacyTag(node, protocol), + Region: node.Region, + Type: protocol.Type, + Server: host, + ServerPort: protocol.Port, + } + if protocol.TLS != nil { + server.TLS = &models.TLS{ + Enabled: protocol.TLS.Enabled, + ServerName: protocol.TLS.ServerName, + Insecure: protocol.TLS.Insecure, + ALPN: protocol.TLS.ALPN, + MinVersion: protocol.TLS.MinVersion, + MaxVersion: protocol.TLS.MaxVersion, + } + if protocol.TLS.Reality != nil && protocol.TLS.Reality.Enabled { + server.TLS.Reality = &models.Reality{ + Enabled: true, + PublicKey: protocol.TLS.Reality.PublicKey, + ShortID: protocol.TLS.Reality.ShortID, + Fingerprint: protocol.TLS.Reality.Fingerprint, + } + } + } + switch protocol.Type { + case "vless", "vless-reality", "vmess": + if protocol.Auth != nil { + server.UUID = protocol.Auth.UUID + } + if transportType, _ := protocol.Extra["transport_type"].(string); transportType != "" { + server.Transport = &models.Transport{ + Type: transportType, + Path: extraString(protocol.Extra, "path"), + } + } else if path := extraString(protocol.Extra, "path"); path != "" { + server.Transport = &models.Transport{ + Type: "ws", + Path: path, + } + } + case "shadowsocks": + if protocol.Auth != nil { + server.Method = protocol.Auth.Method + server.Password = protocol.Auth.Password + } + case "hysteria2": + if protocol.Auth != nil { + server.Password = protocol.Auth.Password + } + server.ObfsPassword = extraString(protocol.Extra, "obfs_password") + server.UpMbps = extraInt(protocol.Extra, "up_mbps", 0) + server.DownMbps = extraInt(protocol.Extra, "down_mbps", 0) + if server.TLS == nil { + server.TLS = &models.TLS{} + } + server.TLS.Enabled = true + server.TLS.Insecure = true + if len(server.TLS.ALPN) == 0 { + server.TLS.ALPN = []string{"h3"} + } + if server.TLS.MinVersion == "" { + server.TLS.MinVersion = "1.3" + } + if server.TLS.MaxVersion == "" { + server.TLS.MaxVersion = "1.3" + } + case "socks5": + server.Type = "socks" + } + return server +} + +func ServersToCatalog(servers []models.Server) *models.CatalogV2 { + nodesByID := make(map[string]*models.CatalogNode, len(servers)) + order := make([]string, 0, len(servers)) + for _, server := range servers { + nodeID := server.Tag + if existingID, protocolType, ok := splitLegacyTag(server.Tag); ok && existingID != "" { + nodeID = existingID + server.Type = protocolType + } + + node := nodesByID[nodeID] + if node == nil { + node = &models.CatalogNode{ + ID: nodeID, + Name: nodeID, + Region: server.Region, + Host: server.Server, + PublicHost: server.Server, + Status: "published", + } + nodesByID[nodeID] = node + order = append(order, nodeID) + } + node.Protocols = append(node.Protocols, serverToCatalogProtocol(server)) + } + + nodes := make([]models.CatalogNode, 0, len(order)) + for _, id := range order { + nodes = append(nodes, *nodesByID[id]) + } + return &models.CatalogV2{ + Version: "legacy-adapter", + Nodes: nodes, + } +} + +func serverToCatalogProtocol(server models.Server) models.CatalogProtocol { + protocolType := server.Type + if protocolType == "socks" { + protocolType = "socks5" + } + protocol := models.CatalogProtocol{ + Type: protocolType, + Enabled: true, + Port: server.ServerPort, + TLS: server.TLS, + Extra: make(map[string]any), + } + switch server.Type { + case "vless", "vless-reality", "vmess": + protocol.Auth = &models.CatalogAuth{UUID: server.UUID} + protocol.Extra["legacy_tag"] = server.Tag + if server.Transport != nil { + protocol.Extra["transport_type"] = server.Transport.Type + if server.Transport.Path != "" { + protocol.Extra["path"] = server.Transport.Path + } + } + case "shadowsocks": + protocol.Auth = &models.CatalogAuth{Method: server.Method, Password: server.Password} + protocol.Extra["legacy_tag"] = server.Tag + case "hysteria2": + protocol.Auth = &models.CatalogAuth{Password: server.Password} + protocol.Extra["legacy_tag"] = server.Tag + if server.ObfsPassword != "" { + protocol.Extra["obfs_password"] = server.ObfsPassword + } + if server.UpMbps > 0 { + protocol.Extra["up_mbps"] = server.UpMbps + } + if server.DownMbps > 0 { + protocol.Extra["down_mbps"] = server.DownMbps + } + case "socks": + protocol.Extra["legacy_tag"] = server.Tag + protocol.Extra["udp_over_tcp"] = server.UDPOverTCP + } + if len(protocol.Extra) == 0 { + protocol.Extra = nil + } + return protocol +} + +func splitLegacyTag(tag string) (nodeID, protocolType string, ok bool) { + for _, candidate := range []string{"vless-reality", "vless", "vmess", "shadowsocks", "hysteria2", "socks", "socks5"} { + suffix := "-" + candidate + if len(tag) > len(suffix) && tag[len(tag)-len(suffix):] == suffix { + return tag[:len(tag)-len(suffix)], candidate, true + } + } + return "", "", false +} + +func legacyTag(node models.CatalogNode, protocol models.CatalogProtocol) string { + if tag := extraString(protocol.Extra, "legacy_tag"); tag != "" { + return tag + } + return node.ID + "-" + protocol.Type +} + +func extraString(extra map[string]any, key string) string { + if extra == nil { + return "" + } + value, _ := extra[key].(string) + return value +} + +func extraInt(extra map[string]any, key string, fallback int) int { + if extra == nil { + return fallback + } + switch value := extra[key].(type) { + case int: + return value + case float64: + return int(value) + default: + return fallback + } +} diff --git a/internal/sync/fetcher_test.go b/internal/sync/fetcher_test.go new file mode 100644 index 0000000..cdf3e73 --- /dev/null +++ b/internal/sync/fetcher_test.go @@ -0,0 +1,300 @@ +package sync + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "vpnem/internal/models" +) + +func TestCatalogToServers(t *testing.T) { + catalog := &models.CatalogV2{ + Version: "2", + Nodes: []models.CatalogNode{ + { + ID: "nl-01", + Name: "NL 01", + Region: "nl", + Host: "203.0.113.10", + PublicHost: "nl-01.example.com", + Protocols: []models.CatalogProtocol{ + { + Type: "vless", + Enabled: true, + Port: 443, + TLS: &models.TLS{Enabled: true, ServerName: "nl-01.example.com"}, + Auth: &models.CatalogAuth{UUID: "11111111-1111-1111-1111-111111111111"}, + Extra: map[string]any{"transport_type": "ws", "path": "/ws"}, + }, + { + Type: "vmess", + Enabled: true, + Port: 8444, + TLS: &models.TLS{Enabled: true, ServerName: "nl-01.example.com"}, + Auth: &models.CatalogAuth{UUID: "22222222-2222-2222-2222-222222222222"}, + Extra: map[string]any{"path": "/vmess"}, + }, + { + Type: "hysteria2", + Enabled: true, + Port: 9443, + TLS: &models.TLS{Enabled: true, ServerName: "nl-01.example.com"}, + Auth: &models.CatalogAuth{Password: "hy2-secret"}, + Extra: map[string]any{"obfs_password": "obfs-secret", "up_mbps": 80, "down_mbps": 90}, + }, + }, + }, + }, + } + + servers := CatalogToServers(catalog) + if len(servers) != 3 { + t.Fatalf("len(servers) = %d, want 3", len(servers)) + } + if servers[1].Type != "vmess" { + t.Fatalf("expected vmess, got %q", servers[1].Type) + } + if servers[2].Type != "hysteria2" { + t.Fatalf("expected hysteria2, got %q", servers[2].Type) + } + if servers[2].ObfsPassword != "obfs-secret" { + t.Fatalf("unexpected hysteria2 obfs password") + } +} + +func TestFetchServersPrefersCatalogV2(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v2/catalog": + _ = json.NewEncoder(w).Encode(models.CatalogV2{ + Version: "2", + Nodes: []models.CatalogNode{ + { + ID: "nl-01", + Name: "NL 01", + Region: "nl", + Host: "203.0.113.10", + PublicHost: "nl-01.example.com", + Protocols: []models.CatalogProtocol{ + {Type: "vmess", Enabled: true, Port: 8444, Auth: &models.CatalogAuth{UUID: "22222222-2222-2222-2222-222222222222"}}, + }, + }, + }, + }) + case "/api/v1/servers": + t.Fatal("legacy servers endpoint should not be used when catalog-v2 is available") + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + fetcher := NewFetcher(server.URL) + resp, err := fetcher.FetchServers() + if err != nil { + t.Fatalf("FetchServers error = %v", err) + } + if len(resp.Servers) != 1 { + t.Fatalf("expected 1 server, got %d", len(resp.Servers)) + } + if resp.Servers[0].Type != "vmess" { + t.Fatalf("expected vmess, got %q", resp.Servers[0].Type) + } +} + +func TestFetchServersFallsBackToLegacy(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v2/catalog": + http.NotFound(w, r) + case "/api/v1/servers": + _ = json.NewEncoder(w).Encode(models.ServersResponse{ + Servers: []models.Server{{Tag: "legacy", Type: "socks", Server: "1.2.3.4", ServerPort: 1080}}, + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + fetcher := NewFetcher(server.URL) + resp, err := fetcher.FetchServers() + if err != nil { + t.Fatalf("FetchServers error = %v", err) + } + if len(resp.Servers) != 1 || !strings.EqualFold(resp.Servers[0].Tag, "legacy") { + t.Fatalf("unexpected legacy response: %+v", resp.Servers) + } +} + +func TestFetchCatalogFallsBackToLegacy(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v2/catalog": + http.NotFound(w, r) + case "/api/v1/servers": + _ = json.NewEncoder(w).Encode(models.ServersResponse{ + Servers: []models.Server{ + {Tag: "legacy-vless", Region: "nl", Type: "vless", Server: "legacy.example.com", ServerPort: 443, UUID: "11111111-1111-1111-1111-111111111111"}, + }, + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + fetcher := NewFetcher(server.URL) + catalog, err := fetcher.FetchCatalog() + if err != nil { + t.Fatalf("FetchCatalog error = %v", err) + } + if catalog.Version != "legacy-adapter" { + t.Fatalf("expected legacy-adapter version, got %q", catalog.Version) + } + if len(catalog.Nodes) != 1 || len(catalog.Nodes[0].Protocols) != 1 { + t.Fatalf("unexpected catalog shape: %+v", catalog) + } + if catalog.Nodes[0].Protocols[0].Type != "vless" { + t.Fatalf("expected vless protocol, got %q", catalog.Nodes[0].Protocols[0].Type) + } +} + +func TestFetchRoutingPolicyFallsBackToDefault(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer server.Close() + + fetcher := NewFetcher(server.URL) + policy, err := fetcher.FetchRoutingPolicy() + if err != nil { + t.Fatalf("FetchRoutingPolicy error = %v", err) + } + if policy.Version == "" { + t.Fatalf("expected default policy version") + } + if len(policy.AlwaysDirectProcesses) == 0 { + t.Fatalf("expected default direct processes") + } +} + +func TestFetchRoutingPolicy(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/routing-policy": + _ = json.NewEncoder(w).Encode(models.RoutingPolicy{ + Version: "remote-policy", + AlwaysDirectProcesses: []string{"chromium.exe"}, + BlockedDomains: []string{"example.com"}, + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + fetcher := NewFetcher(server.URL) + policy, err := fetcher.FetchRoutingPolicy() + if err != nil { + t.Fatalf("FetchRoutingPolicy error = %v", err) + } + if policy.Version != "remote-policy" { + t.Fatalf("expected remote-policy, got %q", policy.Version) + } + if len(policy.AlwaysDirectProcesses) != 1 || policy.AlwaysDirectProcesses[0] != "chromium.exe" { + t.Fatalf("unexpected routing policy: %+v", policy) + } +} + +func TestServersToCatalog(t *testing.T) { + catalog := ServersToCatalog([]models.Server{ + { + Tag: "nl-01-vless", + Region: "nl", + Type: "vless", + Server: "nl-01.example.com", + ServerPort: 443, + UUID: "11111111-1111-1111-1111-111111111111", + TLS: &models.TLS{Enabled: true, ServerName: "nl-01.example.com"}, + Transport: &models.Transport{Type: "ws", Path: "/ws"}, + }, + { + Tag: "nl-01-hysteria2", + Region: "nl", + Type: "hysteria2", + Server: "nl-01.example.com", + ServerPort: 9443, + Password: "hy2-secret", + ObfsPassword: "obfs-secret", + }, + }) + + if catalog.Version != "legacy-adapter" { + t.Fatalf("unexpected version %q", catalog.Version) + } + if len(catalog.Nodes) != 1 { + t.Fatalf("expected one node, got %d", len(catalog.Nodes)) + } + if len(catalog.Nodes[0].Protocols) != 2 { + t.Fatalf("expected two protocols, got %d", len(catalog.Nodes[0].Protocols)) + } + if catalog.Nodes[0].Protocols[1].Extra["obfs_password"] != "obfs-secret" { + t.Fatalf("expected obfs password in extra") + } +} + +func TestCatalogToServersMultiProtocolNode(t *testing.T) { + catalog := &models.CatalogV2{ + Version: "2", + Nodes: []models.CatalogNode{ + { + ID: "nl-multi-01", + Name: "NL Multi", + Region: "nl", + Host: "203.0.113.55", + PublicHost: "203.0.113.55", + Protocols: []models.CatalogProtocol{ + { + Type: "vless-reality", + Enabled: true, + Port: 443, + Auth: &models.CatalogAuth{UUID: "11111111-1111-1111-1111-111111111111"}, + TLS: &models.TLS{ + Enabled: true, + ServerName: "www.microsoft.com", + Reality: &models.Reality{ + Enabled: true, + PublicKey: "pubkey", + ShortID: "shortid", + Fingerprint: "chrome", + }, + }, + }, + { + Type: "hysteria2", + Enabled: true, + Port: 443, + Auth: &models.CatalogAuth{Password: "hy2-secret"}, + TLS: &models.TLS{Enabled: true, Insecure: true, ALPN: []string{"h3"}}, + Extra: map[string]any{"obfs_password": "obfs-secret", "up_mbps": 100, "down_mbps": 100}, + }, + }, + }, + }, + } + + servers := CatalogToServers(catalog) + if len(servers) != 1 { + t.Fatalf("expected 1 synthetic multi server, got %d", len(servers)) + } + if servers[0].Tag != "nl-multi-01-multi" { + t.Fatalf("unexpected synthetic tag %q", servers[0].Tag) + } + if len(servers[0].Companions) != 1 || servers[0].Companions[0].Type != "hysteria2" { + t.Fatalf("expected hysteria2 companion, got %+v", servers[0].Companions) + } +} diff --git a/internal/sync/health.go b/internal/sync/health.go new file mode 100644 index 0000000..4d6ceca --- /dev/null +++ b/internal/sync/health.go @@ -0,0 +1,33 @@ +package sync + +import ( + "fmt" + "net" + "time" + + "vpnem/internal/models" +) + +// HealthCheck tests if a server's proxy port is reachable. +func HealthCheck(server models.Server, timeout time.Duration) error { + addr := fmt.Sprintf("%s:%d", server.Server, server.ServerPort) + conn, err := net.DialTimeout("tcp", addr, timeout) + if err != nil { + return fmt.Errorf("server %s unreachable: %w", server.Tag, err) + } + conn.Close() + return nil +} + +// FindHealthyServer returns the first healthy non-RU server from the list. +func FindHealthyServer(servers []models.Server, timeout time.Duration) *models.Server { + for _, s := range servers { + if s.Region == "RU" { + continue + } + if err := HealthCheck(s, timeout); err == nil { + return &s + } + } + return nil +} diff --git a/internal/sync/latency.go b/internal/sync/latency.go new file mode 100644 index 0000000..dd3268b --- /dev/null +++ b/internal/sync/latency.go @@ -0,0 +1,62 @@ +package sync + +import ( + "fmt" + "net" + "sort" + "sync" + "time" + + "vpnem/internal/models" +) + +// LatencyResult holds a server's latency measurement. +type LatencyResult struct { + Tag string `json:"tag"` + Region string `json:"region"` + Latency int `json:"latency_ms"` // -1 means unreachable +} + +// MeasureLatency pings all servers concurrently and returns results sorted by latency. +func MeasureLatency(servers []models.Server, timeout time.Duration) []LatencyResult { + var wg sync.WaitGroup + results := make([]LatencyResult, len(servers)) + + for i, s := range servers { + wg.Add(1) + go func(idx int, srv models.Server) { + defer wg.Done() + ms := tcpPing(srv.Server, srv.ServerPort, timeout) + results[idx] = LatencyResult{ + Tag: srv.Tag, + Region: srv.Region, + Latency: ms, + } + }(i, s) + } + + wg.Wait() + + sort.Slice(results, func(i, j int) bool { + if results[i].Latency == -1 { + return false + } + if results[j].Latency == -1 { + return true + } + return results[i].Latency < results[j].Latency + }) + + return results +} + +func tcpPing(host string, port int, timeout time.Duration) int { + addr := fmt.Sprintf("%s:%d", host, port) + start := time.Now() + conn, err := net.DialTimeout("tcp", addr, timeout) + if err != nil { + return -1 + } + conn.Close() + return int(time.Since(start).Milliseconds()) +} diff --git a/internal/sync/updater.go b/internal/sync/updater.go new file mode 100644 index 0000000..983ce43 --- /dev/null +++ b/internal/sync/updater.go @@ -0,0 +1,180 @@ +package sync + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "log" + "net/http" + "os" + "path/filepath" + "runtime" + "time" +) + +// Updater checks for and downloads client updates. +type Updater struct { + fetcher *Fetcher + currentVer string + dataDir string +} + +// NewUpdater creates an updater. +func NewUpdater(fetcher *Fetcher, currentVersion, dataDir string) *Updater { + return &Updater{ + fetcher: fetcher, + currentVer: currentVersion, + dataDir: dataDir, + } +} + +// UpdateInfo describes an available update. +type UpdateInfo struct { + Available bool `json:"available"` + Version string `json:"version"` + Changelog string `json:"changelog"` + CurrentVer string `json:"current_version"` +} + +// Check returns info about available updates. +func (u *Updater) Check() (*UpdateInfo, error) { + ver, err := u.fetcher.FetchVersion() + if err != nil { + return nil, fmt.Errorf("check update: %w", err) + } + + return &UpdateInfo{ + Available: ver.Version != u.currentVer, + Version: ver.Version, + Changelog: ver.Changelog, + CurrentVer: u.currentVer, + }, nil +} + +// Download fetches the new binary, verifies checksum, and prepares for restart. +// On success it returns "restart_pending" and the caller should exit gracefully. +func (u *Updater) Download() (string, error) { + ver, err := u.fetcher.FetchVersion() + if err != nil { + return "", fmt.Errorf("fetch version: %w", err) + } + + if ver.URL == "" { + suffix := "linux-amd64" + if runtime.GOOS == "windows" { + suffix = "windows-amd64.exe" + } + ver.URL = u.fetcher.baseURL + "/releases/vpnem-" + suffix + } + + client := &http.Client{Timeout: 5 * time.Minute} + resp, err := client.Get(ver.URL) + if err != nil { + return "", fmt.Errorf("download: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download: HTTP %d", resp.StatusCode) + } + + ext := "" + if runtime.GOOS == "windows" { + ext = ".exe" + } + + // Download to temp file first + downloadPath := filepath.Join(u.dataDir, "vpnem-new"+ext) + f, err := os.Create(downloadPath) + if err != nil { + return "", fmt.Errorf("create file: %w", err) + } + + // Track SHA256 while downloading + hasher := sha256.New() + written, err := io.Copy(io.MultiWriter(f, hasher), resp.Body) + f.Close() + if err != nil { + os.Remove(downloadPath) + return "", fmt.Errorf("write update: %w", err) + } + + // Verify SHA256 checksum if provided + if ver.SHA256 != "" { + gotHash := hex.EncodeToString(hasher.Sum(nil)) + if gotHash != ver.SHA256 { + os.Remove(downloadPath) + return "", fmt.Errorf("checksum mismatch: expected %s, got %s (%.1f MB)", ver.SHA256, gotHash, float64(written)/1024/1024) + } + log.Printf("update: sha256 verified (%.1f MB)", float64(written)/1024/1024) + } + + // Clean stale configs so new version starts fresh + os.Remove(filepath.Join(u.dataDir, "state.json")) + os.Remove(filepath.Join(u.dataDir, "config.json")) + os.Remove(filepath.Join(u.dataDir, "cache.db")) + + currentBin, _ := os.Executable() + if currentBin == "" { + return "", fmt.Errorf("could not determine current binary") + } + + if runtime.GOOS == "windows" { + // Windows: can't overwrite running exe + // Strategy: rename old to .old, copy new in place + // If .old already exists from a previous failed update, remove it + oldBin := currentBin + ".old" + os.Remove(oldBin) + + if err := os.Rename(currentBin, oldBin); err != nil { + return "", fmt.Errorf("rename old binary: %w", err) + } + + if err := copyFile(downloadPath, currentBin); err != nil { + // Restore old binary + os.Remove(currentBin) + os.Rename(oldBin, currentBin) + os.Remove(downloadPath) + return "", fmt.Errorf("copy new binary: %w", err) + } + + os.Remove(downloadPath) + log.Printf("update: binary replaced, version %s", ver.Version) + return "restart_pending", nil + } + + // Linux: overwrite in place + if err := copyFile(downloadPath, currentBin); err != nil { + os.Remove(downloadPath) + return "", fmt.Errorf("copy new binary: %w", err) + } + os.Remove(downloadPath) + log.Printf("update: binary replaced, version %s", ver.Version) + return "restart_pending", nil +} + +func copyFile(src, dst string) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, in) + if err != nil { + return err + } + + // Preserve executable permission on Linux + if runtime.GOOS != "windows" { + os.Chmod(dst, 0o755) + } + return nil +} -- cgit v1.2.3