summaryrefslogtreecommitdiff
path: root/internal/sync/fetcher.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/sync/fetcher.go')
-rw-r--r--internal/sync/fetcher.go643
1 files changed, 643 insertions, 0 deletions
diff --git a/internal/sync/fetcher.go b/internal/sync/fetcher.go
new file mode 100644
index 0000000..3ea6df4
--- /dev/null
+++ b/internal/sync/fetcher.go
@@ -0,0 +1,643 @@
+package sync
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "runtime"
+ "time"
+
+ "vpnem/internal/config"
+ "vpnem/internal/models"
+)
+
+// Fetcher pulls configuration from the vpnem server API.
+type Fetcher struct {
+ baseURL string
+ client *http.Client
+}
+
+// NewFetcher creates a new Fetcher.
+func NewFetcher(baseURL string) *Fetcher {
+ return &Fetcher{
+ baseURL: baseURL,
+ client: &http.Client{
+ Timeout: 15 * time.Second,
+ },
+ }
+}
+
+// FetchServers retrieves the server list from the API.
+func (f *Fetcher) FetchServers() (*models.ServersResponse, error) {
+ catalog, err := f.FetchCatalog()
+ if err == nil {
+ return &models.ServersResponse{Servers: CatalogToServers(catalog)}, nil
+ }
+ return nil, fmt.Errorf("fetch catalog: %w", err)
+}
+
+func (f *Fetcher) FetchCatalogV2() (*models.CatalogV2, error) {
+ var resp models.CatalogV2
+ if err := f.getJSON("/api/v2/catalog", &resp); err != nil {
+ return nil, err
+ }
+ return &resp, nil
+}
+
+func (f *Fetcher) FetchCatalog() (*models.CatalogV2, error) {
+ catalog, err := f.FetchCatalogV2()
+ if err == nil {
+ return catalog, nil
+ }
+ var statusErr *HTTPStatusError
+ if !errors.As(err, &statusErr) || statusErr.StatusCode != http.StatusNotFound {
+ return nil, fmt.Errorf("fetch catalog v2: %w", err)
+ }
+
+ var resp models.ServersResponse
+ if err := f.getJSON("/api/v1/servers", &resp); err != nil {
+ return nil, fmt.Errorf("fetch servers: %w", err)
+ }
+ return ServersToCatalog(resp.Servers), nil
+}
+
+func (f *Fetcher) FetchRoutingPolicy() (*models.RoutingPolicy, error) {
+ var resp models.RoutingPolicy
+ if err := f.getJSON("/api/v1/routing-policy", &resp); err != nil {
+ var statusErr *HTTPStatusError
+ if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
+ return config.DefaultRoutingPolicy(), nil
+ }
+ return nil, fmt.Errorf("fetch routing policy: %w", err)
+ }
+ return config.EffectiveRoutingPolicy(&resp), nil
+}
+
+// FetchRuleSets retrieves the rule-set manifest from the API.
+func (f *Fetcher) FetchRuleSets() (*models.RuleSetManifest, error) {
+ var resp models.RuleSetManifest
+ if err := f.getJSON("/api/v1/ruleset/manifest", &resp); err != nil {
+ return nil, fmt.Errorf("fetch rulesets: %w", err)
+ }
+ return &resp, nil
+}
+
+// DownloadRuleSets downloads all non-optional .srs files to dataDir/rules/.
+// Returns the updated RuleSet list with LocalPath populated.
+func (f *Fetcher) DownloadRuleSets(ruleSets []models.RuleSet, dataDir string) ([]models.RuleSet, error) {
+ rulesDir := filepath.Join(dataDir, "rules")
+ if err := os.MkdirAll(rulesDir, 0755); err != nil {
+ return nil, fmt.Errorf("create rules dir: %w", err)
+ }
+
+ var downloaded []models.RuleSet
+ for _, rs := range ruleSets {
+ if rs.URL == "" {
+ if rs.Optional {
+ continue
+ }
+ return nil, fmt.Errorf("rule-set %s has no URL", rs.Tag)
+ }
+ localPath := filepath.Join(rulesDir, rs.Tag+".srs")
+ resp, err := f.client.Get(rs.URL)
+ if err != nil {
+ if rs.Optional {
+ continue
+ }
+ return nil, fmt.Errorf("download %s: %w", rs.URL, err)
+ }
+ body, err := io.ReadAll(resp.Body)
+ resp.Body.Close()
+ if err != nil {
+ if rs.Optional {
+ continue
+ }
+ return nil, fmt.Errorf("read %s: %w", rs.URL, err)
+ }
+ if resp.StatusCode != http.StatusOK {
+ if rs.Optional {
+ continue
+ }
+ return nil, fmt.Errorf("download %s: HTTP %d", rs.URL, resp.StatusCode)
+ }
+ if err := os.WriteFile(localPath, body, 0644); err != nil {
+ return nil, fmt.Errorf("write %s: %w", localPath, err)
+ }
+ rs.LocalPath = localPath
+ downloaded = append(downloaded, rs)
+ }
+ return downloaded, nil
+}
+
+// FetchVersion retrieves the latest client version info.
+func (f *Fetcher) FetchVersion() (*models.VersionResponse, error) {
+ var resp models.VersionResponse
+ if err := f.getJSON("/api/v1/version", &resp); err != nil {
+ return nil, fmt.Errorf("fetch version: %w", err)
+ }
+ return &resp, nil
+}
+
+// ServerIPs extracts all unique server IPs from the server list.
+func ServerIPs(servers []models.Server) []string {
+ seen := make(map[string]bool)
+ var ips []string
+ for _, s := range servers {
+ if !seen[s.Server] {
+ seen[s.Server] = true
+ ips = append(ips, s.Server)
+ }
+ }
+ return ips
+}
+
+// ReportError sends error logs to the server (best-effort, non-blocking).
+func (f *Fetcher) ReportError(version, osName string, lines []string) {
+ payload := map[string]any{
+ "version": version,
+ "os": osName,
+ "lines": lines,
+ }
+ data, err := json.Marshal(payload)
+ if err != nil {
+ return
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ req, err := http.NewRequestWithContext(ctx, http.MethodPost, f.baseURL+"/logs2026vpnem/errors", bytes.NewReader(data))
+ if err != nil {
+ return
+ }
+ req.Header.Set("Content-Type", "application/json")
+ resp, err := f.client.Do(req)
+ if err != nil {
+ return
+ }
+ defer resp.Body.Close()
+ io.Copy(io.Discard, resp.Body)
+}
+
+func (f *Fetcher) getJSON(path string, v any) error {
+ resp, err := f.client.Get(f.baseURL + path)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return &HTTPStatusError{StatusCode: resp.StatusCode, Body: string(body)}
+ }
+
+ return json.NewDecoder(resp.Body).Decode(v)
+}
+
+type HTTPStatusError struct {
+ StatusCode int
+ Body string
+}
+
+func (e *HTTPStatusError) Error() string {
+ return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Body)
+}
+
+// ReportConnect sends a connection report to the server.
+// Server auto-detects client real IP from X-Forwarded-For.
+func (f *Fetcher) ReportConnect(serverIP, nodeID string) (*models.RecommendationResponse, error) {
+ req := models.ConnectRequest{
+ ServerIP: serverIP,
+ NodeID: nodeID,
+ OS: func() string {
+ if runtime.GOOS == "windows" {
+ return "windows"
+ }
+ return "linux"
+ }(),
+ Version: "", // server doesn't need version for rebalancing
+ }
+
+ payload, err := json.Marshal(req)
+ if err != nil {
+ return nil, err
+ }
+
+ resp, err := f.client.Post(f.baseURL+"/api/v1/connect", "application/json", bytes.NewReader(payload))
+ if err != nil {
+ return nil, fmt.Errorf("report connect: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Body: string(body)}
+ }
+
+ var recommendation models.RecommendationResponse
+ if err := json.NewDecoder(resp.Body).Decode(&recommendation); err != nil {
+ return nil, fmt.Errorf("decode recommendation: %w", err)
+ }
+ return &recommendation, nil
+}
+
+// ReportDisconnect notifies the server that a client disconnected.
+func (f *Fetcher) ReportDisconnect(serverIP, nodeID string) error {
+ req := models.DisconnectRequest{
+ ServerIP: serverIP,
+ NodeID: nodeID,
+ }
+
+ payload, err := json.Marshal(req)
+ if err != nil {
+ return err
+ }
+
+ resp, err := f.client.Post(f.baseURL+"/api/v1/disconnect", "application/json", bytes.NewReader(payload))
+ if err != nil {
+ return fmt.Errorf("report disconnect: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return &HTTPStatusError{StatusCode: resp.StatusCode, Body: string(body)}
+ }
+ return nil
+}
+
+// GetRecommendation fetches a recommendation for the client.
+// Server auto-detects client real IP from X-Forwarded-For.
+func (f *Fetcher) GetRecommendation() (*models.RecommendationResponse, error) {
+ url := f.baseURL + "/api/v1/recommend"
+ resp, err := f.client.Get(url)
+ if err != nil {
+ return nil, fmt.Errorf("get recommendation: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ body, _ := io.ReadAll(resp.Body)
+ return nil, &HTTPStatusError{StatusCode: resp.StatusCode, Body: string(body)}
+ }
+
+ var recommendation models.RecommendationResponse
+ if err := json.NewDecoder(resp.Body).Decode(&recommendation); err != nil {
+ return nil, fmt.Errorf("decode recommendation: %w", err)
+ }
+ return &recommendation, nil
+}
+
+func CatalogToServers(catalog *models.CatalogV2) []models.Server {
+ if catalog == nil {
+ return nil
+ }
+ servers := make([]models.Server, 0)
+ for _, node := range catalog.Nodes {
+ if multi, ok := nodeToSplitServer(node); ok {
+ servers = append(servers, multi)
+ }
+ host := node.PublicHost
+ if host == "" {
+ if node.Domain != "" {
+ host = node.Domain
+ } else {
+ host = node.Host
+ }
+ }
+ for _, protocol := range node.Protocols {
+ if !protocol.Enabled {
+ continue
+ }
+ if isSplitProtocol(protocol.Type) && hasSplitPair(node) {
+ continue
+ }
+ server := models.Server{
+ Tag: legacyTag(node, protocol),
+ Region: node.Region,
+ Type: protocol.Type,
+ Server: host,
+ ServerPort: protocol.Port,
+ }
+ if protocol.TLS != nil {
+ server.TLS = &models.TLS{
+ Enabled: protocol.TLS.Enabled,
+ ServerName: protocol.TLS.ServerName,
+ Insecure: protocol.TLS.Insecure,
+ ALPN: protocol.TLS.ALPN,
+ MinVersion: protocol.TLS.MinVersion,
+ MaxVersion: protocol.TLS.MaxVersion,
+ }
+ if protocol.TLS.Reality != nil && protocol.TLS.Reality.Enabled {
+ server.TLS.Reality = &models.Reality{
+ Enabled: true,
+ PublicKey: protocol.TLS.Reality.PublicKey,
+ ShortID: protocol.TLS.Reality.ShortID,
+ Fingerprint: protocol.TLS.Reality.Fingerprint,
+ }
+ }
+ }
+ switch protocol.Type {
+ case "vless", "vless-reality", "vmess":
+ if protocol.Auth != nil {
+ server.UUID = protocol.Auth.UUID
+ }
+ if transportType, _ := protocol.Extra["transport_type"].(string); transportType != "" {
+ server.Transport = &models.Transport{
+ Type: transportType,
+ Path: extraString(protocol.Extra, "path"),
+ }
+ } else if path := extraString(protocol.Extra, "path"); path != "" {
+ server.Transport = &models.Transport{
+ Type: "ws",
+ Path: path,
+ }
+ }
+ case "shadowsocks":
+ if protocol.Auth != nil {
+ server.Method = protocol.Auth.Method
+ server.Password = protocol.Auth.Password
+ }
+ case "hysteria2":
+ if protocol.Auth != nil {
+ server.Password = protocol.Auth.Password
+ }
+ server.ObfsPassword = extraString(protocol.Extra, "obfs_password")
+ server.UpMbps = extraInt(protocol.Extra, "up_mbps", 0)
+ server.DownMbps = extraInt(protocol.Extra, "down_mbps", 0)
+ if server.TLS == nil {
+ server.TLS = &models.TLS{}
+ }
+ server.TLS.Enabled = true
+ server.TLS.Insecure = true
+ if len(server.TLS.ALPN) == 0 {
+ server.TLS.ALPN = []string{"h3"}
+ }
+ if server.TLS.MinVersion == "" {
+ server.TLS.MinVersion = "1.3"
+ }
+ if server.TLS.MaxVersion == "" {
+ server.TLS.MaxVersion = "1.3"
+ }
+ case "socks5":
+ server.Type = "socks"
+ }
+ servers = append(servers, server)
+ }
+ }
+ return servers
+}
+
+func nodeToSplitServer(node models.CatalogNode) (models.Server, bool) {
+ if !hasSplitPair(node) {
+ return models.Server{}, false
+ }
+ host := node.PublicHost
+ if host == "" {
+ if node.Domain != "" {
+ host = node.Domain
+ } else {
+ host = node.Host
+ }
+ }
+ var main models.Server
+ var hy2 models.Server
+ for _, protocol := range node.Protocols {
+ if !protocol.Enabled {
+ continue
+ }
+ switch protocol.Type {
+ case "vless-reality":
+ main = catalogProtocolToServer(node, protocol, host)
+ case "hysteria2":
+ hy2 = catalogProtocolToServer(node, protocol, host)
+ }
+ }
+ if main.Type == "" || hy2.Type == "" {
+ return models.Server{}, false
+ }
+ main.Tag = node.ID + "-multi"
+ main.Companions = []models.Server{hy2}
+ return main, true
+}
+
+func hasSplitPair(node models.CatalogNode) bool {
+ hasReality := false
+ hasHy2 := false
+ for _, protocol := range node.Protocols {
+ if !protocol.Enabled {
+ continue
+ }
+ switch protocol.Type {
+ case "vless-reality":
+ hasReality = true
+ case "hysteria2":
+ hasHy2 = true
+ }
+ }
+ return hasReality && hasHy2
+}
+
+func isSplitProtocol(protocolType string) bool {
+ return protocolType == "vless-reality" || protocolType == "hysteria2"
+}
+
+func catalogProtocolToServer(node models.CatalogNode, protocol models.CatalogProtocol, host string) models.Server {
+ server := models.Server{
+ Tag: legacyTag(node, protocol),
+ Region: node.Region,
+ Type: protocol.Type,
+ Server: host,
+ ServerPort: protocol.Port,
+ }
+ if protocol.TLS != nil {
+ server.TLS = &models.TLS{
+ Enabled: protocol.TLS.Enabled,
+ ServerName: protocol.TLS.ServerName,
+ Insecure: protocol.TLS.Insecure,
+ ALPN: protocol.TLS.ALPN,
+ MinVersion: protocol.TLS.MinVersion,
+ MaxVersion: protocol.TLS.MaxVersion,
+ }
+ if protocol.TLS.Reality != nil && protocol.TLS.Reality.Enabled {
+ server.TLS.Reality = &models.Reality{
+ Enabled: true,
+ PublicKey: protocol.TLS.Reality.PublicKey,
+ ShortID: protocol.TLS.Reality.ShortID,
+ Fingerprint: protocol.TLS.Reality.Fingerprint,
+ }
+ }
+ }
+ switch protocol.Type {
+ case "vless", "vless-reality", "vmess":
+ if protocol.Auth != nil {
+ server.UUID = protocol.Auth.UUID
+ }
+ if transportType, _ := protocol.Extra["transport_type"].(string); transportType != "" {
+ server.Transport = &models.Transport{
+ Type: transportType,
+ Path: extraString(protocol.Extra, "path"),
+ }
+ } else if path := extraString(protocol.Extra, "path"); path != "" {
+ server.Transport = &models.Transport{
+ Type: "ws",
+ Path: path,
+ }
+ }
+ case "shadowsocks":
+ if protocol.Auth != nil {
+ server.Method = protocol.Auth.Method
+ server.Password = protocol.Auth.Password
+ }
+ case "hysteria2":
+ if protocol.Auth != nil {
+ server.Password = protocol.Auth.Password
+ }
+ server.ObfsPassword = extraString(protocol.Extra, "obfs_password")
+ server.UpMbps = extraInt(protocol.Extra, "up_mbps", 0)
+ server.DownMbps = extraInt(protocol.Extra, "down_mbps", 0)
+ if server.TLS == nil {
+ server.TLS = &models.TLS{}
+ }
+ server.TLS.Enabled = true
+ server.TLS.Insecure = true
+ if len(server.TLS.ALPN) == 0 {
+ server.TLS.ALPN = []string{"h3"}
+ }
+ if server.TLS.MinVersion == "" {
+ server.TLS.MinVersion = "1.3"
+ }
+ if server.TLS.MaxVersion == "" {
+ server.TLS.MaxVersion = "1.3"
+ }
+ case "socks5":
+ server.Type = "socks"
+ }
+ return server
+}
+
+func ServersToCatalog(servers []models.Server) *models.CatalogV2 {
+ nodesByID := make(map[string]*models.CatalogNode, len(servers))
+ order := make([]string, 0, len(servers))
+ for _, server := range servers {
+ nodeID := server.Tag
+ if existingID, protocolType, ok := splitLegacyTag(server.Tag); ok && existingID != "" {
+ nodeID = existingID
+ server.Type = protocolType
+ }
+
+ node := nodesByID[nodeID]
+ if node == nil {
+ node = &models.CatalogNode{
+ ID: nodeID,
+ Name: nodeID,
+ Region: server.Region,
+ Host: server.Server,
+ PublicHost: server.Server,
+ Status: "published",
+ }
+ nodesByID[nodeID] = node
+ order = append(order, nodeID)
+ }
+ node.Protocols = append(node.Protocols, serverToCatalogProtocol(server))
+ }
+
+ nodes := make([]models.CatalogNode, 0, len(order))
+ for _, id := range order {
+ nodes = append(nodes, *nodesByID[id])
+ }
+ return &models.CatalogV2{
+ Version: "legacy-adapter",
+ Nodes: nodes,
+ }
+}
+
+func serverToCatalogProtocol(server models.Server) models.CatalogProtocol {
+ protocolType := server.Type
+ if protocolType == "socks" {
+ protocolType = "socks5"
+ }
+ protocol := models.CatalogProtocol{
+ Type: protocolType,
+ Enabled: true,
+ Port: server.ServerPort,
+ TLS: server.TLS,
+ Extra: make(map[string]any),
+ }
+ switch server.Type {
+ case "vless", "vless-reality", "vmess":
+ protocol.Auth = &models.CatalogAuth{UUID: server.UUID}
+ protocol.Extra["legacy_tag"] = server.Tag
+ if server.Transport != nil {
+ protocol.Extra["transport_type"] = server.Transport.Type
+ if server.Transport.Path != "" {
+ protocol.Extra["path"] = server.Transport.Path
+ }
+ }
+ case "shadowsocks":
+ protocol.Auth = &models.CatalogAuth{Method: server.Method, Password: server.Password}
+ protocol.Extra["legacy_tag"] = server.Tag
+ case "hysteria2":
+ protocol.Auth = &models.CatalogAuth{Password: server.Password}
+ protocol.Extra["legacy_tag"] = server.Tag
+ if server.ObfsPassword != "" {
+ protocol.Extra["obfs_password"] = server.ObfsPassword
+ }
+ if server.UpMbps > 0 {
+ protocol.Extra["up_mbps"] = server.UpMbps
+ }
+ if server.DownMbps > 0 {
+ protocol.Extra["down_mbps"] = server.DownMbps
+ }
+ case "socks":
+ protocol.Extra["legacy_tag"] = server.Tag
+ protocol.Extra["udp_over_tcp"] = server.UDPOverTCP
+ }
+ if len(protocol.Extra) == 0 {
+ protocol.Extra = nil
+ }
+ return protocol
+}
+
+func splitLegacyTag(tag string) (nodeID, protocolType string, ok bool) {
+ for _, candidate := range []string{"vless-reality", "vless", "vmess", "shadowsocks", "hysteria2", "socks", "socks5"} {
+ suffix := "-" + candidate
+ if len(tag) > len(suffix) && tag[len(tag)-len(suffix):] == suffix {
+ return tag[:len(tag)-len(suffix)], candidate, true
+ }
+ }
+ return "", "", false
+}
+
+func legacyTag(node models.CatalogNode, protocol models.CatalogProtocol) string {
+ if tag := extraString(protocol.Extra, "legacy_tag"); tag != "" {
+ return tag
+ }
+ return node.ID + "-" + protocol.Type
+}
+
+func extraString(extra map[string]any, key string) string {
+ if extra == nil {
+ return ""
+ }
+ value, _ := extra[key].(string)
+ return value
+}
+
+func extraInt(extra map[string]any, key string, fallback int) int {
+ if extra == nil {
+ return fallback
+ }
+ switch value := extra[key].(type) {
+ case int:
+ return value
+ case float64:
+ return int(value)
+ default:
+ return fallback
+ }
+}