diff options
Diffstat (limited to 'internal/control')
| -rw-r--r-- | internal/control/bootstrap.go | 369 | ||||
| -rw-r--r-- | internal/control/bootstrap_test.go | 58 | ||||
| -rw-r--r-- | internal/control/catalog.go | 229 | ||||
| -rw-r--r-- | internal/control/catalog_test.go | 332 | ||||
| -rw-r--r-- | internal/control/dns.go | 163 | ||||
| -rw-r--r-- | internal/control/dns_test.go | 58 | ||||
| -rw-r--r-- | internal/control/health_test.go | 49 | ||||
| -rw-r--r-- | internal/control/hysteria2.go | 179 | ||||
| -rw-r--r-- | internal/control/inventory.go | 225 | ||||
| -rw-r--r-- | internal/control/inventory_test.go | 50 | ||||
| -rw-r--r-- | internal/control/lifecycle.go | 205 | ||||
| -rw-r--r-- | internal/control/lifecycle_test.go | 149 | ||||
| -rw-r--r-- | internal/control/models.go | 66 | ||||
| -rw-r--r-- | internal/control/preflight.go | 68 | ||||
| -rw-r--r-- | internal/control/publish.go | 321 | ||||
| -rw-r--r-- | internal/control/reality.go | 64 | ||||
| -rw-r--r-- | internal/control/runtime.go | 586 | ||||
| -rw-r--r-- | internal/control/runtime_test.go | 307 | ||||
| -rw-r--r-- | internal/control/ssh.go | 182 | ||||
| -rw-r--r-- | internal/control/ssh_test.go | 80 | ||||
| -rw-r--r-- | internal/control/state.go | 71 | ||||
| -rw-r--r-- | internal/control/upgrade_test.go | 55 |
22 files changed, 3866 insertions, 0 deletions
diff --git a/internal/control/bootstrap.go b/internal/control/bootstrap.go new file mode 100644 index 0000000..5eb6f4f --- /dev/null +++ b/internal/control/bootstrap.go @@ -0,0 +1,369 @@ +package control + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +type BootstrapOptions struct { + StateDir string + DryRun bool +} + +func BootstrapNode(ctx context.Context, runner SSHExecutor, node Node, opts BootstrapOptions) (*NodeState, error) { + for idx := range node.Protocols { + if err := ensureRealityProfile(&node.Protocols[idx]); err != nil { + return nil, err + } + if err := ensureHysteria2Profile(&node.Protocols[idx]); err != nil { + return nil, err + } + } + if err := ValidateNode(node); err != nil { + return nil, err + } + + now := time.Now().UTC() + state := &NodeState{ + NodeID: node.ID, + BootstrapStatus: "pending", + PublicHost: publicHost(node), + Services: serviceStatuses(node.Protocols, "configured"), + Metadata: map[string]any{ + "provider": node.Provider, + "region": node.Region, + "dry_run": opts.DryRun, + }, + } + + if opts.DryRun { + state.BootstrapStatus = "planned" + state.LastBootstrapAt = &now + state.Metadata["release_id"] = buildReleaseID(now) + if err := SaveNodeState(opts.StateDir, *state); err != nil { + return nil, err + } + return state, nil + } + + relID := buildReleaseID(now) + bundleDir, tarballPath, err := buildRuntimeBundle(node, relID) + if err != nil { + return nil, err + } + defer os.RemoveAll(bundleDir) + defer os.Remove(tarballPath) + + result, err := runner.Run(ctx, node, RenderBootstrapPrepareScript()) + if err != nil { + state.BootstrapStatus = "failed" + state.LastBootstrapAt = &now + state.Metadata["stderr"] = strings.TrimSpace(result.Stderr) + state.Metadata["stdout"] = strings.TrimSpace(result.Stdout) + if saveErr := SaveNodeState(opts.StateDir, *state); saveErr != nil { + return nil, fmt.Errorf("%w; save state: %v", err, saveErr) + } + return nil, err + } + + remoteTarballPath := "/tmp/vpnem-node-" + node.ID + ".tar.gz" + if err := runner.CopyFile(ctx, node, tarballPath, remoteTarballPath); err != nil { + state.BootstrapStatus = "failed" + state.LastBootstrapAt = &now + state.Metadata["release_id"] = relID + state.Metadata["copy_error"] = err.Error() + if saveErr := SaveNodeState(opts.StateDir, *state); saveErr != nil { + return nil, fmt.Errorf("%w; save state: %v", err, saveErr) + } + return nil, err + } + + result, err = runner.Run(ctx, node, RenderBootstrapFinalizeScript(node, relID, remoteTarballPath)) + if err != nil { + state.BootstrapStatus = "failed" + state.LastBootstrapAt = &now + state.Metadata["stderr"] = strings.TrimSpace(result.Stderr) + state.Metadata["stdout"] = strings.TrimSpace(result.Stdout) + if saveErr := SaveNodeState(opts.StateDir, *state); saveErr != nil { + return nil, fmt.Errorf("%w; save state: %v", err, saveErr) + } + return nil, err + } + + state.BootstrapStatus = "ready" + state.LastBootstrapAt = &now + state.Metadata["release_id"] = relID + state.Metadata["stdout"] = strings.TrimSpace(result.Stdout) + if err := SaveNodeState(opts.StateDir, *state); err != nil { + return nil, err + } + return state, nil +} + +func CheckNode(ctx context.Context, runner SSHExecutor, node Node, stateDir string) (*NodeState, error) { + now := time.Now().UTC() + result, err := runner.Check(ctx, node) + state := &NodeState{ + NodeID: node.ID, + PublicHost: publicHost(node), + LastHealthCheckAt: &now, + Services: serviceStatuses(node.Protocols, "unknown"), + Metadata: map[string]any{}, + } + + if err != nil { + state.BootstrapStatus = "unreachable" + state.Metadata["stderr"] = strings.TrimSpace(result.Stderr) + if saveErr := SaveNodeState(stateDir, *state); saveErr != nil { + return nil, fmt.Errorf("%w; save state: %v", err, saveErr) + } + return nil, err + } + + state.BootstrapStatus = "reachable" + state.Metadata["stdout"] = strings.TrimSpace(result.Stdout) + + runtimeResult, runtimeErr := runner.Run(ctx, node, RenderHealthCheckScript(node)) + if runtimeErr != nil { + state.Metadata["runtime_stderr"] = strings.TrimSpace(runtimeResult.Stderr) + state.Metadata["runtime_stdout"] = strings.TrimSpace(runtimeResult.Stdout) + } else { + services, metadata := parseHealthCheckOutput(runtimeResult.Stdout, node.Protocols) + if len(services) > 0 { + state.Services = services + } + for k, v := range metadata { + state.Metadata[k] = v + } + if healthy, ok := metadata["healthz_http_code"].(int); ok && healthy == 200 { + state.BootstrapStatus = "healthy" + } else if allServicesRunning(state.Services) { + state.BootstrapStatus = "ready" + } + } + if err := SaveNodeState(stateDir, *state); err != nil { + return nil, err + } + return state, nil +} + +func RenderBootstrapPrepareScript() string { + var b strings.Builder + b.WriteString("set -eu\n") + b.WriteString("export DEBIAN_FRONTEND=noninteractive\n") + b.WriteString("mkdir -p /opt/vpnem-node/releases\n") + b.WriteString("if command -v apt-get >/dev/null 2>&1; then\n") + b.WriteString(" apt-get update\n") + b.WriteString(" apt-get install -y ca-certificates curl tar gzip openssl docker.io docker-compose || true\n") + b.WriteString("elif command -v dnf >/dev/null 2>&1; then\n") + b.WriteString(" dnf install -y ca-certificates curl tar gzip openssl docker docker-compose-plugin docker-compose || true\n") + b.WriteString("elif command -v pacman >/dev/null 2>&1; then\n") + b.WriteString(" pacman -Sy --noconfirm ca-certificates curl tar gzip openssl docker docker-compose || true\n") + b.WriteString("elif command -v apk >/dev/null 2>&1; then\n") + b.WriteString(" apk add --no-cache ca-certificates curl tar gzip openssl docker-cli-compose || true\n") + b.WriteString("fi\n") + b.WriteString("if command -v systemctl >/dev/null 2>&1; then systemctl enable --now docker || true; fi\n") + b.WriteString("if ! command -v docker >/dev/null 2>&1; then\n") + b.WriteString(" echo 'docker is not installed after bootstrap prepare' >&2\n") + b.WriteString(" exit 1\n") + b.WriteString("fi\n") + b.WriteString("printf 'vpnem-node bootstrap prepared\\n'\n") + return b.String() +} + +func RenderBootstrapFinalizeScript(node Node, releaseID, remoteTarballPath string) string { + var b strings.Builder + releaseDir := "/opt/vpnem-node/releases/" + releaseID + b.WriteString("set -eu\n") + b.WriteString("mkdir -p " + releaseDir + "\n") + b.WriteString("tar -xzf " + remoteTarballPath + " -C " + releaseDir + "\n") + b.WriteString("ln -sfn " + releaseDir + " /opt/vpnem-node/current\n") + b.WriteString("rm -f " + remoteTarballPath + "\n") + b.WriteString("if ! command -v docker >/dev/null 2>&1; then\n") + b.WriteString(" echo 'docker is not installed on target node' >&2\n") + b.WriteString(" exit 1\n") + b.WriteString("fi\n") + b.WriteString("if docker compose version >/dev/null 2>&1; then\n") + b.WriteString(" docker compose -f /opt/vpnem-node/current/docker-compose.yml up -d --force-recreate\n") + b.WriteString("elif command -v docker-compose >/dev/null 2>&1; then\n") + b.WriteString(" docker-compose -f /opt/vpnem-node/current/docker-compose.yml up -d --force-recreate\n") + b.WriteString("else\n") + b.WriteString(" echo 'docker compose is not available on target node' >&2\n") + b.WriteString(" exit 1\n") + b.WriteString("fi\n") + b.WriteString("printf 'vpnem-node release ") + b.WriteString(shellQuoteValue(releaseID)) + b.WriteString(" ready for ") + b.WriteString(shellQuoteValue(node.ID)) + b.WriteString("\\n'\n") + return b.String() +} + +func RenderHealthCheckScript(node Node) string { + var b strings.Builder + b.WriteString("set -eu\n") + b.WriteString("if [ -f /opt/vpnem-node/current/docker-compose.yml ]; then\n") + b.WriteString(" if command -v docker >/dev/null 2>&1 && docker compose version >/dev/null 2>&1; then\n") + b.WriteString(" docker compose -f /opt/vpnem-node/current/docker-compose.yml ps --format json 2>/dev/null || true\n") + b.WriteString(" elif command -v docker-compose >/dev/null 2>&1; then\n") + b.WriteString(" docker-compose -f /opt/vpnem-node/current/docker-compose.yml ps --format json 2>/dev/null || true\n") + b.WriteString(" fi\n") + b.WriteString(" if command -v docker >/dev/null 2>&1; then\n") + b.WriteString(" docker ps --format '{{json .}}' 2>/dev/null || true\n") + b.WriteString(" fi\n") + b.WriteString("fi\n") + if needsEdgeProxy(node) { + b.WriteString("printf 'HEALTHZ_HTTP_CODE='; ") + b.WriteString("curl -ks --resolve ") + b.WriteString(shellQuoteValue(node.Domain)) + b.WriteString(":443:127.0.0.1 -o /dev/null -w '%{http_code}' https://") + b.WriteString(shellQuoteValue(node.Domain)) + b.WriteString("/healthz || true\n") + } + if needsHysteria2HealthInbound(node) { + b.WriteString("printf 'HY2_MIXED_PORT='; ") + b.WriteString("curl -sS --max-time 5 --proxy socks5h://127.0.0.1:1080 https://ifconfig.me/ip || true\n") + } + return b.String() +} + +func serviceStatuses(protocols []ProtocolProfile, status string) []ServiceStatus { + services := make([]ServiceStatus, 0, len(protocols)) + for _, protocol := range protocols { + if !protocol.Enabled { + continue + } + services = append(services, ServiceStatus{ + Type: protocol.Type, + Status: status, + Port: protocol.Port, + }) + } + return services +} + +func parseHealthCheckOutput(stdout string, protocols []ProtocolProfile) ([]ServiceStatus, map[string]any) { + services := serviceStatuses(protocols, "unknown") + metadata := map[string]any{} + lines := strings.Split(stdout, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + if strings.HasPrefix(line, "HEALTHZ_HTTP_CODE=") { + codeStr := strings.TrimPrefix(line, "HEALTHZ_HTTP_CODE=") + if code, err := strconv.Atoi(codeStr); err == nil { + metadata["healthz_http_code"] = code + } + continue + } + if strings.HasPrefix(line, "HY2_MIXED_PORT=") { + value := strings.TrimSpace(strings.TrimPrefix(line, "HY2_MIXED_PORT=")) + metadata["hy2_mixed_port"] = value + if value != "" { + markServicesByTypes(services, []string{"hysteria2"}, "running") + } + continue + } + + var entry map[string]any + if err := jsonUnmarshalLine(line, &entry); err != nil { + continue + } + serviceName, _ := entry["Service"].(string) + state, _ := entry["State"].(string) + if serviceName == "" { + if labels, _ := entry["Labels"].(string); strings.Contains(labels, "com.docker.compose.service=sing-box") { + serviceName = "sing-box" + } else if names, _ := entry["Names"].(string); strings.Contains(names, "sing-box") { + serviceName = "sing-box" + } + } + if state == "" { + if status, _ := entry["Status"].(string); strings.HasPrefix(strings.ToLower(status), "up") { + state = "running" + } + } + if serviceName == "" || state == "" { + continue + } + metadata["docker_"+serviceName] = state + switch serviceName { + case "sing-box": + markServicesByTypes(services, []string{"vless", "vless-reality", "shadowsocks", "socks", "socks5", "vmess", "hysteria2"}, state) + case "caddy": + markServicesByTypes(services, []string{"vless", "vmess"}, state) + } + } + return services, metadata +} + +func allServicesRunning(services []ServiceStatus) bool { + if len(services) == 0 { + return false + } + for _, service := range services { + if service.Status != "running" { + return false + } + } + return true +} + +func markServicesByTypes(services []ServiceStatus, kinds []string, state string) { + set := make(map[string]struct{}, len(kinds)) + for _, kind := range kinds { + set[kind] = struct{}{} + } + for idx := range services { + if _, ok := set[services[idx].Type]; ok { + services[idx].Status = state + } + } +} + +func jsonUnmarshalLine(line string, out *map[string]any) error { + decoder := strings.NewReader(line) + return json.NewDecoder(decoder).Decode(out) +} + +func publicHost(node Node) string { + if strings.TrimSpace(node.Domain) != "" { + return node.Domain + } + return node.Host +} + +func shellQuoteValue(value string) string { + value = strings.ReplaceAll(value, "\n", "") + return value +} + +func buildRuntimeBundle(node Node, releaseID string) (string, string, error) { + rootDir, err := os.MkdirTemp("", "vpnem-node-bundle-*") + if err != nil { + return "", "", err + } + bundleDir := filepath.Join(rootDir, "bundle") + if err := RenderRuntimeBundle(bundleDir, node, releaseID); err != nil { + os.RemoveAll(rootDir) + return "", "", err + } + tarballPath := filepath.Join(rootDir, "bundle.tar.gz") + if err := CreateTarGzFromDir(bundleDir, tarballPath); err != nil { + os.RemoveAll(rootDir) + return "", "", err + } + return rootDir, tarballPath, nil +} + +func buildReleaseID(now time.Time) string { + return now.UTC().Format("20060102-150405") +} diff --git a/internal/control/bootstrap_test.go b/internal/control/bootstrap_test.go new file mode 100644 index 0000000..70e5ccb --- /dev/null +++ b/internal/control/bootstrap_test.go @@ -0,0 +1,58 @@ +package control + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestRenderBootstrapScript(t *testing.T) { + t.Parallel() + + script := RenderBootstrapPrepareScript() + script += RenderBootstrapFinalizeScript(Node{ + ID: "nl-01", + Name: "NL 01", + Region: "nl", + Host: "203.0.113.10", + Domain: "nl-01.example.com", + Enabled: true, + SSH: SSHConfig{ + User: "root", + Port: 22, + Auth: "key", + }, + }, "20260401-123000", "/tmp/vpnem-node-nl-01.tar.gz") + + if !strings.Contains(script, "mkdir -p /opt/vpnem-node/releases") { + t.Fatal("expected remote workdir creation") + } + if !strings.Contains(script, "vpnem-node release 20260401-123000 ready for nl-01") { + t.Fatal("expected release finalize message") + } +} + +func TestSaveNodeState(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + err := SaveNodeState(dir, NodeState{ + NodeID: "nl-01", + BootstrapStatus: "ready", + Services: []ServiceStatus{ + {Type: "vless", Status: "configured", Port: 443}, + }, + }) + if err != nil { + t.Fatalf("SaveNodeState error = %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "nl-01.json")) + if err != nil { + t.Fatalf("ReadFile error = %v", err) + } + if !strings.Contains(string(data), `"bootstrap_status": "ready"`) { + t.Fatal("expected bootstrap_status in state file") + } +} diff --git a/internal/control/catalog.go b/internal/control/catalog.go new file mode 100644 index 0000000..9ef3c35 --- /dev/null +++ b/internal/control/catalog.go @@ -0,0 +1,229 @@ +package control + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "vpnem/internal/models" +) + +func BuildLegacyCatalog(nodes []Node) (*models.ServersResponse, error) { + servers := make([]models.Server, 0) + + for _, node := range nodes { + if !node.Enabled { + continue + } + + publicHost := node.Host + if strings.TrimSpace(node.Domain) != "" { + publicHost = node.Domain + } + + for _, protocol := range node.Protocols { + if !protocol.Enabled { + continue + } + if err := ensureRealityProfile(&protocol); err != nil { + return nil, err + } + + server, err := legacyServerFromNode(node, publicHost, protocol) + if err != nil { + return nil, err + } + servers = append(servers, server) + } + } + + sort.Slice(servers, func(i, j int) bool { + return servers[i].Tag < servers[j].Tag + }) + + return &models.ServersResponse{Servers: servers}, nil +} + +func WriteLegacyCatalog(path string, nodes []Node) error { + resp, err := BuildLegacyCatalog(nodes) + if err != nil { + return err + } + staticResp, err := LoadStaticLegacyCatalog(filepath.Join(filepath.Dir(path), "static-servers.json")) + if err != nil { + return err + } + resp.Servers = MergeLegacyServers(staticResp.Servers, resp.Servers) + + data, err := json.MarshalIndent(resp, "", " ") + if err != nil { + return err + } + data = append(data, '\n') + + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + tmpPath := path + ".tmp" + if err := os.WriteFile(tmpPath, data, 0o644); err != nil { + return err + } + return os.Rename(tmpPath, path) +} + +func LoadStaticLegacyCatalog(path string) (*models.ServersResponse, error) { + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &models.ServersResponse{Servers: nil}, nil + } + return nil, err + } + + var resp models.ServersResponse + if err := json.Unmarshal(data, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +func MergeLegacyServers(primary, secondary []models.Server) []models.Server { + merged := make([]models.Server, 0, len(primary)+len(secondary)) + seen := make(map[string]struct{}, len(primary)+len(secondary)) + for _, item := range primary { + if _, ok := seen[item.Tag]; ok { + continue + } + seen[item.Tag] = struct{}{} + merged = append(merged, item) + } + for _, item := range secondary { + if _, ok := seen[item.Tag]; ok { + continue + } + seen[item.Tag] = struct{}{} + merged = append(merged, item) + } + sort.Slice(merged, func(i, j int) bool { + return merged[i].Tag < merged[j].Tag + }) + return merged +} + +func legacyServerFromNode(node Node, publicHost string, protocol ProtocolProfile) (models.Server, error) { + switch protocol.Type { + case "socks", "socks5": + return models.Server{ + Tag: node.ID + "-socks5", + Region: node.Region, + Type: "socks", + Server: publicHost, + ServerPort: protocol.Port, + }, nil + case "vless": + if protocol.Auth == nil || strings.TrimSpace(protocol.Auth.UUID) == "" { + return models.Server{}, fmt.Errorf("node %s protocol vless requires auth.uuid", node.ID) + } + server := models.Server{ + Tag: node.ID + "-vless", + Region: node.Region, + Type: "vless", + Server: publicHost, + ServerPort: protocol.Port, + UUID: protocol.Auth.UUID, + } + if protocol.TLS != nil { + server.TLS = &models.TLS{ + Enabled: protocol.TLS.Enabled, + ServerName: protocol.TLS.ServerName, + Insecure: false, + } + } + if transportType, _ := protocol.Extra["transport_type"].(string); transportType != "" { + server.Transport = &models.Transport{ + Type: transportType, + Path: stringFromExtra(protocol.Extra, "path"), + } + } + return server, nil + case "vless-reality": + if protocol.Auth == nil || strings.TrimSpace(protocol.Auth.UUID) == "" { + return models.Server{}, fmt.Errorf("node %s protocol vless-reality requires auth.uuid", node.ID) + } + if protocol.Reality == nil { + return models.Server{}, fmt.Errorf("node %s protocol vless-reality requires reality settings", node.ID) + } + server := models.Server{ + Tag: node.ID + "-vless-reality", + Region: node.Region, + Type: "vless-reality", + Server: publicHost, + ServerPort: protocol.Port, + UUID: protocol.Auth.UUID, + TLS: &models.TLS{ + Enabled: true, + ServerName: protocol.Reality.ServerName, + Reality: &models.Reality{ + Enabled: true, + PublicKey: protocol.Reality.PublicKey, + ShortID: protocol.Reality.ShortID, + Fingerprint: protocol.Reality.Fingerprint, + }, + }, + } + return server, nil + case "shadowsocks": + if protocol.Auth == nil || strings.TrimSpace(protocol.Auth.Method) == "" || strings.TrimSpace(protocol.Auth.Password) == "" { + return models.Server{}, fmt.Errorf("node %s protocol shadowsocks requires auth.method and auth.password", node.ID) + } + return models.Server{ + Tag: node.ID + "-shadowsocks", + Region: node.Region, + Type: "shadowsocks", + Server: publicHost, + ServerPort: protocol.Port, + Method: protocol.Auth.Method, + Password: protocol.Auth.Password, + }, nil + case "hysteria2": + if protocol.Auth == nil || strings.TrimSpace(protocol.Auth.Password) == "" { + return models.Server{}, fmt.Errorf("node %s protocol hysteria2 requires auth.password", node.ID) + } + server := models.Server{ + Tag: node.ID + "-hysteria2", + Region: node.Region, + Type: "hysteria2", + Server: publicHost, + ServerPort: protocol.Port, + Password: protocol.Auth.Password, + ObfsPassword: stringFromExtra(protocol.Extra, "obfs_password"), + UpMbps: intFromExtra(protocol.Extra, "up_mbps", 0), + DownMbps: intFromExtra(protocol.Extra, "down_mbps", 0), + TLS: &models.TLS{ + Enabled: true, + Insecure: true, + ServerName: "", + ALPN: []string{defaultHysteria2ALPN}, + MinVersion: "1.3", + MaxVersion: "1.3", + }, + } + if protocol.TLS != nil && protocol.TLS.ServerName != "" { + server.TLS.ServerName = protocol.TLS.ServerName + } + return server, nil + default: + return models.Server{}, fmt.Errorf("node %s uses unsupported legacy protocol %q", node.ID, protocol.Type) + } +} + +func stringFromExtra(extra map[string]any, key string) string { + if extra == nil { + return "" + } + value, _ := extra[key].(string) + return value +} diff --git a/internal/control/catalog_test.go b/internal/control/catalog_test.go new file mode 100644 index 0000000..facaaf7 --- /dev/null +++ b/internal/control/catalog_test.go @@ -0,0 +1,332 @@ +package control + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "vpnem/internal/models" +) + +func TestBuildLegacyCatalog(t *testing.T) { + t.Parallel() + + nodes := []Node{ + { + ID: "nl-01", + Name: "NL 01", + Region: "nl", + Host: "203.0.113.10", + Domain: "nl-01.example.com", + Enabled: true, + SSH: SSHConfig{ + User: "root", + Port: 22, + Auth: "key", + }, + Protocols: []ProtocolProfile{ + { + Type: "vless", + Enabled: true, + Port: 443, + TLS: &TLSProfile{ + Enabled: true, + ServerName: "nl-01.example.com", + }, + Auth: &AuthProfile{ + UUID: "11111111-1111-1111-1111-111111111111", + }, + Extra: map[string]any{ + "transport_type": "ws", + "path": "/ws", + }, + }, + { + Type: "shadowsocks", + Enabled: true, + Port: 8443, + Auth: &AuthProfile{ + Method: "2022-blake3-aes-128-gcm", + Password: "secret", + }, + }, + }, + }, + } + + resp, err := BuildLegacyCatalog(nodes) + if err != nil { + t.Fatalf("BuildLegacyCatalog error = %v", err) + } + if len(resp.Servers) != 2 { + t.Fatalf("len(resp.Servers) = %d, want 2", len(resp.Servers)) + } + if resp.Servers[0].Tag != "nl-01-shadowsocks" { + t.Fatalf("unexpected first tag %q", resp.Servers[0].Tag) + } + if resp.Servers[1].Tag != "nl-01-vless" { + t.Fatalf("unexpected second tag %q", resp.Servers[1].Tag) + } + if resp.Servers[1].Transport == nil || resp.Servers[1].Transport.Type != "ws" { + t.Fatalf("expected ws transport, got %+v", resp.Servers[1].Transport) + } +} + +func TestBuildLegacyCatalogRejectsUnsupportedProtocol(t *testing.T) { + t.Parallel() + + _, err := BuildLegacyCatalog([]Node{ + { + ID: "nl-01", + Name: "NL 01", + Region: "nl", + Host: "203.0.113.10", + Enabled: true, + SSH: SSHConfig{ + User: "root", + Port: 22, + Auth: "key", + }, + Protocols: []ProtocolProfile{ + {Type: "hysteria2", Enabled: true, Port: 443}, + }, + }, + }) + if err == nil { + t.Fatal("expected unsupported protocol error") + } +} + +func TestPublishableNodes(t *testing.T) { + t.Parallel() + + nodes := []Node{ + {ID: "healthy", Name: "healthy", Region: "nl", Host: "1.1.1.1", Enabled: true, SSH: SSHConfig{User: "root", Port: 22, Auth: "key"}, Protocols: []ProtocolProfile{{Type: "socks5", Enabled: true, Port: 1080}}}, + {ID: "failed", Name: "failed", Region: "nl", Host: "1.1.1.2", Enabled: true, SSH: SSHConfig{User: "root", Port: 22, Auth: "key"}, Protocols: []ProtocolProfile{{Type: "socks5", Enabled: true, Port: 1080}}}, + {ID: "nostate", Name: "nostate", Region: "nl", Host: "1.1.1.3", Enabled: true, SSH: SSHConfig{User: "root", Port: 22, Auth: "key"}, Protocols: []ProtocolProfile{{Type: "socks5", Enabled: true, Port: 1080}}}, + } + states := map[string]*NodeState{ + "healthy": {NodeID: "healthy", BootstrapStatus: "healthy"}, + "failed": {NodeID: "failed", BootstrapStatus: "failed"}, + } + + got := PublishableNodes(nodes, states) + if len(got) != 1 { + t.Fatalf("len(PublishableNodes) = %d, want 1", len(got)) + } + if got[0].ID != "healthy" { + t.Fatalf("expected healthy node, got %s", got[0].ID) + } +} + +func TestPublishableNodesRequiresRunningServicesWhenKnown(t *testing.T) { + t.Parallel() + + nodes := []Node{ + {ID: "healthy", Name: "healthy", Region: "nl", Host: "1.1.1.1", Enabled: true, SSH: SSHConfig{User: "root", Port: 22, Auth: "key"}, Protocols: []ProtocolProfile{{Type: "socks5", Enabled: true, Port: 1080}}}, + {ID: "degraded", Name: "degraded", Region: "nl", Host: "1.1.1.2", Enabled: true, SSH: SSHConfig{User: "root", Port: 22, Auth: "key"}, Protocols: []ProtocolProfile{{Type: "socks5", Enabled: true, Port: 1080}}}, + } + states := map[string]*NodeState{ + "healthy": { + NodeID: "healthy", + BootstrapStatus: "healthy", + Services: []ServiceStatus{{Type: "socks5", Status: "running", Port: 1080}}, + Metadata: map[string]any{"healthz_http_code": 200}, + }, + "degraded": { + NodeID: "degraded", + BootstrapStatus: "healthy", + Services: []ServiceStatus{{Type: "socks5", Status: "unknown", Port: 1080}}, + Metadata: map[string]any{"healthz_http_code": 503}, + }, + } + + got := PublishableNodes(nodes, states) + if len(got) != 1 { + t.Fatalf("len(PublishableNodes) = %d, want 1", len(got)) + } + if got[0].ID != "healthy" { + t.Fatalf("expected healthy node, got %s", got[0].ID) + } +} + +func TestPublishDecisionForNode(t *testing.T) { + t.Parallel() + + node := Node{ + ID: "nl-01", + Name: "NL 01", + Region: "nl", + Host: "203.0.113.10", + Domain: "nl-01.example.com", + Enabled: true, + SSH: SSHConfig{User: "root", Port: 22, Auth: "key"}, + Protocols: []ProtocolProfile{ + {Type: "vless", Enabled: true, Port: 443}, + }, + } + + blocked := PublishDecisionForNode(node, &NodeState{ + NodeID: "nl-01", + BootstrapStatus: "healthy", + Services: []ServiceStatus{{Type: "vless", Status: "configured", Port: 443}}, + Metadata: map[string]any{"healthz_http_code": 503}, + }) + if blocked.Eligible { + t.Fatal("expected blocked publish decision") + } + if len(blocked.Reasons) == 0 { + t.Fatal("expected reasons for blocked decision") + } + + ready := PublishDecisionForNode(node, &NodeState{ + NodeID: "nl-01", + BootstrapStatus: "healthy", + PublicHost: "nl-01.example.com", + Services: []ServiceStatus{{Type: "vless", Status: "running", Port: 443}}, + Metadata: map[string]any{"healthz_http_code": 200}, + }) + if !ready.Eligible { + t.Fatalf("expected ready decision, got reasons: %v", ready.Reasons) + } + if ready.PublicHost != "nl-01.example.com" { + t.Fatalf("unexpected public host %q", ready.PublicHost) + } +} + +func TestBuildCatalogV2(t *testing.T) { + t.Parallel() + + nodes := []Node{ + { + ID: "nl-01", + Name: "NL 01", + Provider: "custom-vps", + Region: "nl", + Host: "203.0.113.10", + Domain: "nl-01.example.com", + Enabled: true, + SSH: SSHConfig{User: "root", Port: 22, Auth: "key"}, + Protocols: []ProtocolProfile{ + {Type: "vless", Enabled: true, Port: 443, TLS: &TLSProfile{Enabled: true, ServerName: "nl-01.example.com"}, Auth: &AuthProfile{UUID: "11111111-1111-1111-1111-111111111111"}}, + {Type: "hysteria2", Enabled: true, Port: 9443, Auth: &AuthProfile{Password: "hidden"}, Extra: map[string]any{"obfs_password": "masked"}}, + }, + }, + } + states := map[string]*NodeState{ + "nl-01": {NodeID: "nl-01", BootstrapStatus: "healthy", PublicHost: "nl-01.example.com", Metadata: map[string]any{"healthz_http_code": 200}}, + } + + catalog := BuildCatalogV2(nodes, states) + if catalog.Version != "2" { + t.Fatalf("catalog.Version = %q, want 2", catalog.Version) + } + if len(catalog.Nodes) != 1 { + t.Fatalf("len(catalog.Nodes) = %d, want 1", len(catalog.Nodes)) + } + if catalog.Nodes[0].PublicHost != "nl-01.example.com" { + t.Fatalf("unexpected public host %q", catalog.Nodes[0].PublicHost) + } + if len(catalog.Nodes[0].Protocols) != 2 { + t.Fatalf("expected 2 protocols, got %d", len(catalog.Nodes[0].Protocols)) + } + if catalog.Nodes[0].Protocols[0].Type != "vless" { + t.Fatalf("unexpected first protocol %q", catalog.Nodes[0].Protocols[0].Type) + } +} + +func TestMergeLegacyServersPreservesStaticEntries(t *testing.T) { + t.Parallel() + + static := []models.Server{ + {Tag: "nl-1", Type: "socks", Server: "1.1.1.1", ServerPort: 1080}, + {Tag: "nl-ss-1", Type: "shadowsocks", Server: "ss.example.com", ServerPort: 443}, + } + dynamic := []models.Server{ + {Tag: "node-1-vless", Type: "vless", Server: "2.2.2.2", ServerPort: 443}, + } + + merged := MergeLegacyServers(static, dynamic) + if len(merged) != 3 { + t.Fatalf("len(merged) = %d, want 3", len(merged)) + } +} + +func TestWriteLegacyCatalogMergesStaticServers(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + staticPath := filepath.Join(dir, "static-servers.json") + if err := os.WriteFile(staticPath, []byte(`{"servers":[{"tag":"nl-1","region":"NL","type":"socks","server":"1.1.1.1","server_port":1080}]}`), 0o644); err != nil { + t.Fatalf("write static servers: %v", err) + } + + err := WriteLegacyCatalog(filepath.Join(dir, "servers.json"), []Node{ + { + ID: "node-1", + Name: "Node 1", + Region: "nl", + Host: "2.2.2.2", + Enabled: true, + SSH: SSHConfig{User: "root", Port: 22, Auth: "key"}, + Protocols: []ProtocolProfile{ + {Type: "socks5", Enabled: true, Port: 1081}, + }, + }, + }) + if err != nil { + t.Fatalf("WriteLegacyCatalog error = %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "servers.json")) + if err != nil { + t.Fatalf("read merged servers: %v", err) + } + text := string(data) + if !strings.Contains(text, `"tag": "nl-1"`) { + t.Fatalf("expected static server in merged catalog: %s", text) + } + if !strings.Contains(text, `"tag": "node-1-socks5"`) { + t.Fatalf("expected dynamic server in merged catalog: %s", text) + } +} + +func TestWriteCatalogV2DoesNotMergeStaticLegacyServers(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + staticPath := filepath.Join(dir, "static-servers.json") + if err := os.WriteFile(staticPath, []byte(`{"servers":[{"tag":"nl-ss-1","region":"NL","type":"shadowsocks","server":"ss.example.com","server_port":443,"method":"chacha20-ietf-poly1305","password":"secret"}]}`), 0o644); err != nil { + t.Fatalf("write static servers: %v", err) + } + + err := WriteCatalogV2(filepath.Join(dir, "catalog-v2.json"), []Node{ + { + ID: "node-1", + Name: "Node 1", + Region: "nl", + Host: "2.2.2.2", + Enabled: true, + SSH: SSHConfig{User: "root", Port: 22, Auth: "key"}, + Protocols: []ProtocolProfile{ + {Type: "vless", Enabled: true, Port: 443, Auth: &AuthProfile{UUID: "11111111-1111-1111-1111-111111111111"}}, + }, + }, + }, map[string]*NodeState{}) + if err != nil { + t.Fatalf("WriteCatalogV2 error = %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "catalog-v2.json")) + if err != nil { + t.Fatalf("read catalog v2: %v", err) + } + text := string(data) + if strings.Contains(text, `"id": "nl-ss-1"`) { + t.Fatalf("did not expect static legacy node in catalog v2: %s", text) + } + if !strings.Contains(text, `"id": "node-1"`) { + t.Fatalf("expected dynamic node in catalog v2: %s", text) + } +} diff --git a/internal/control/dns.go b/internal/control/dns.go new file mode 100644 index 0000000..f841d91 --- /dev/null +++ b/internal/control/dns.go @@ -0,0 +1,163 @@ +package control + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" +) + +const porkbunAPIHost = "https://api.porkbun.com/api/json/v3" + +var porkbunAPIHostOverride string + +type DNSProvider interface { + EnsureRandomARecord(ctx context.Context, zone, prefix, ip string, ttl int) (string, error) + DeleteARecord(ctx context.Context, zone, name string) error +} + +type PorkbunClient struct { + APIKey string + SecretAPIKey string + HTTPClient *http.Client +} + +type porkbunResponse struct { + Status string `json:"status"` + Message string `json:"message"` + Records []map[string]any `json:"records"` + ID string `json:"id"` +} + +func (c PorkbunClient) EnsureRandomARecord(ctx context.Context, zone, prefix, ip string, ttl int) (string, error) { + if err := c.validate(); err != nil { + return "", err + } + if ttl == 0 { + ttl = 600 + } + + for range 10 { + name := randomSubdomain(prefix) + records, err := c.retrieveRecordsByNameType(ctx, zone, "A", name) + if err != nil { + return "", err + } + if len(records) > 0 { + continue + } + if err := c.createRecord(ctx, zone, name, "A", ip, ttl); err != nil { + return "", err + } + return name + "." + zone, nil + } + + return "", errors.New("failed to allocate unique subdomain") +} + +func (c PorkbunClient) DeleteARecord(ctx context.Context, zone, name string) error { + if err := c.validate(); err != nil { + return err + } + return c.deleteByNameType(ctx, zone, "A", name) +} + +func (c PorkbunClient) validate() error { + if strings.TrimSpace(c.APIKey) == "" || strings.TrimSpace(c.SecretAPIKey) == "" { + return errors.New("porkbun api keys are not configured") + } + return nil +} + +func (c PorkbunClient) createRecord(ctx context.Context, zone, name, recordType, content string, ttl int) error { + payload := map[string]string{ + "secretapikey": c.SecretAPIKey, + "apikey": c.APIKey, + "name": name, + "type": recordType, + "content": content, + "ttl": fmt.Sprintf("%d", ttl), + } + _, err := c.post(ctx, "/dns/create/"+zone, payload) + return err +} + +func (c PorkbunClient) deleteByNameType(ctx context.Context, zone, recordType, name string) error { + payload := map[string]string{ + "secretapikey": c.SecretAPIKey, + "apikey": c.APIKey, + } + _, err := c.post(ctx, "/dns/deleteByNameType/"+zone+"/"+recordType+"/"+name, payload) + return err +} + +func (c PorkbunClient) retrieveRecordsByNameType(ctx context.Context, zone, recordType, name string) ([]map[string]any, error) { + payload := map[string]string{ + "secretapikey": c.SecretAPIKey, + "apikey": c.APIKey, + } + resp, err := c.post(ctx, "/dns/retrieveByNameType/"+zone+"/"+recordType+"/"+name, payload) + if err != nil { + return nil, err + } + return resp.Records, nil +} + +func (c PorkbunClient) post(ctx context.Context, path string, payload map[string]string) (*porkbunResponse, error) { + data, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + client := c.HTTPClient + if client == nil { + client = &http.Client{Timeout: 15 * time.Second} + } + + baseURL := porkbunAPIHost + if porkbunAPIHostOverride != "" { + baseURL = porkbunAPIHostOverride + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+path, bytes.NewReader(data)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var out porkbunResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("porkbun http %d: %s", resp.StatusCode, out.Message) + } + if strings.ToUpper(out.Status) != "SUCCESS" { + return nil, fmt.Errorf("porkbun api error: %s", out.Message) + } + return &out, nil +} + +func randomSubdomain(prefix string) string { + if prefix == "" { + prefix = "vpn" + } + var buf [4]byte + if _, err := rand.Read(buf[:]); err != nil { + now := time.Now().UTC().UnixNano() + return fmt.Sprintf("%s-%x", prefix, now) + } + return prefix + "-" + hex.EncodeToString(buf[:]) +} diff --git a/internal/control/dns_test.go b/internal/control/dns_test.go new file mode 100644 index 0000000..cf44639 --- /dev/null +++ b/internal/control/dns_test.go @@ -0,0 +1,58 @@ +package control + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strings" + "testing" +) + +func TestPorkbunEnsureRandomARecord(t *testing.T) { + t.Parallel() + + retrieveCalls := 0 + client := &http.Client{Transport: roundTripFunc(func(r *http.Request) (*http.Response, error) { + recorder := map[string]any{} + switch { + case strings.Contains(r.URL.Path, "/dns/retrieveByNameType/"): + retrieveCalls++ + recorder = map[string]any{ + "status": "SUCCESS", + "records": []map[string]any{}, + } + case strings.Contains(r.URL.Path, "/dns/create/"): + recorder = map[string]any{ + "status": "SUCCESS", + "id": "123", + } + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + body, _ := json.Marshal(recorder) + return &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(string(body))), + }, nil + })} + + clientAPI := PorkbunClient{APIKey: "a", SecretAPIKey: "b", HTTPClient: client} + name, err := clientAPI.EnsureRandomARecord(context.Background(), "em-sysadmin.xyz", "vpn", "203.0.113.10", 600) + if err != nil { + t.Fatalf("EnsureRandomARecord error = %v", err) + } + if !strings.HasSuffix(name, ".em-sysadmin.xyz") { + t.Fatalf("expected fqdn suffix, got %q", name) + } + if retrieveCalls == 0 { + t.Fatal("expected retrieve call") + } +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} diff --git a/internal/control/health_test.go b/internal/control/health_test.go new file mode 100644 index 0000000..a0f488f --- /dev/null +++ b/internal/control/health_test.go @@ -0,0 +1,49 @@ +package control + +import "testing" + +func TestParseHealthCheckOutput(t *testing.T) { + t.Parallel() + + stdout := `{"Service":"sing-box","State":"running"} +{"Service":"caddy","State":"running"} +HEALTHZ_HTTP_CODE=200 +` + services, metadata := parseHealthCheckOutput(stdout, []ProtocolProfile{ + {Type: "vless", Enabled: true, Port: 443}, + {Type: "vmess", Enabled: true, Port: 443}, + {Type: "shadowsocks", Enabled: true, Port: 8443}, + }) + + if len(services) != 3 { + t.Fatalf("len(services) = %d, want 3", len(services)) + } + if metadata["healthz_http_code"] != 200 { + t.Fatalf("healthz_http_code = %v, want 200", metadata["healthz_http_code"]) + } + if services[0].Status != "running" && services[1].Status != "running" && services[2].Status != "running" { + t.Fatal("expected at least one service marked running") + } +} + +func TestParseHealthCheckOutputDockerPSFallback(t *testing.T) { + t.Parallel() + + stdout := `{"Names":"current_sing-box_1","Labels":"com.docker.compose.service=sing-box,com.docker.compose.project=current","State":"running","Status":"Up 52 seconds"} +HY2_MIXED_PORT=5.180.97.199 +` + services, _ := parseHealthCheckOutput(stdout, []ProtocolProfile{ + {Type: "vless-reality", Enabled: true, Port: 443}, + {Type: "hysteria2", Enabled: true, Port: 443}, + }) + + if len(services) != 2 { + t.Fatalf("len(services) = %d, want 2", len(services)) + } + if services[0].Status != "running" { + t.Fatalf("vless-reality status = %q, want running", services[0].Status) + } + if services[1].Status != "running" { + t.Fatalf("hysteria2 status = %q, want running", services[1].Status) + } +} diff --git a/internal/control/hysteria2.go b/internal/control/hysteria2.go new file mode 100644 index 0000000..78c80e7 --- /dev/null +++ b/internal/control/hysteria2.go @@ -0,0 +1,179 @@ +package control + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "fmt" + "math/big" + "strings" + "time" +) + +const ( + defaultHysteria2Port = 443 + defaultHysteria2UpMbps = 100 + defaultHysteria2DownMbps = 100 + defaultHysteria2CertPath = "/etc/sing-box/cert.pem" + defaultHysteria2KeyPath = "/etc/sing-box/key.pem" + defaultHysteria2ALPN = "h3" +) + +func ensureHysteria2Profile(protocol *ProtocolProfile) error { + if protocol == nil || protocol.Type != "hysteria2" { + return nil + } + if protocol.Hysteria2 == nil { + protocol.Hysteria2 = &Hysteria2Profile{} + } + + if protocol.Port > 0 && protocol.Hysteria2.Port == 0 { + protocol.Hysteria2.Port = protocol.Port + } + if protocol.Hysteria2.Port == 0 { + protocol.Hysteria2.Port = defaultHysteria2Port + } + protocol.Port = protocol.Hysteria2.Port + + if protocol.Auth == nil { + protocol.Auth = &AuthProfile{} + } + if protocol.Hysteria2.UserPassword == "" && strings.TrimSpace(protocol.Auth.Password) != "" { + protocol.Hysteria2.UserPassword = strings.TrimSpace(protocol.Auth.Password) + } + if protocol.Hysteria2.UserPassword == "" { + password, err := randomBase64(16) + if err != nil { + return err + } + protocol.Hysteria2.UserPassword = password + } + protocol.Auth.Password = protocol.Hysteria2.UserPassword + + if protocol.Extra == nil { + protocol.Extra = map[string]any{} + } + if protocol.Hysteria2.ObfsPassword == "" { + if extra := stringFromExtra(protocol.Extra, "obfs_password"); extra != "" { + protocol.Hysteria2.ObfsPassword = extra + } + } + if protocol.Hysteria2.ObfsPassword == "" { + obfsPassword, err := randomHex(32) + if err != nil { + return err + } + protocol.Hysteria2.ObfsPassword = obfsPassword + } + protocol.Extra["obfs_password"] = protocol.Hysteria2.ObfsPassword + + if protocol.Hysteria2.UpMbps == 0 { + protocol.Hysteria2.UpMbps = intFromExtra(protocol.Extra, "up_mbps", defaultHysteria2UpMbps) + } + if protocol.Hysteria2.DownMbps == 0 { + protocol.Hysteria2.DownMbps = intFromExtra(protocol.Extra, "down_mbps", defaultHysteria2DownMbps) + } + protocol.Extra["up_mbps"] = protocol.Hysteria2.UpMbps + protocol.Extra["down_mbps"] = protocol.Hysteria2.DownMbps + + if protocol.Hysteria2.CertPath == "" { + protocol.Hysteria2.CertPath = firstNonEmptyString(stringFromExtra(protocol.Extra, "tls_cert_path"), defaultHysteria2CertPath) + } + if protocol.Hysteria2.KeyPath == "" { + protocol.Hysteria2.KeyPath = firstNonEmptyString(stringFromExtra(protocol.Extra, "tls_key_path"), defaultHysteria2KeyPath) + } + protocol.Extra["tls_cert_path"] = protocol.Hysteria2.CertPath + protocol.Extra["tls_key_path"] = protocol.Hysteria2.KeyPath + + if protocol.TLS == nil { + protocol.TLS = &TLSProfile{} + } + protocol.TLS.Enabled = true + if strings.TrimSpace(protocol.TLS.ServerName) == "" { + protocol.TLS.ServerName = "" + } + return nil +} + +func GenerateSelfSignedCert() (certPEM, keyPEM []byte, err error) { + commonName := "node-" + randomHostnameSuffix() + ".local" + return generateSelfSignedCertForHost(commonName) +} + +func generateSelfSignedCertForHost(commonName string) (certPEM, keyPEM []byte, err error) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, fmt.Errorf("generate ecdsa key: %w", err) + } + serialLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialLimit) + if err != nil { + return nil, nil, fmt.Errorf("generate serial: %w", err) + } + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: commonName, + }, + NotBefore: time.Now().UTC().Add(-time.Hour), + NotAfter: time.Now().UTC().AddDate(10, 0, 0), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{commonName}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, nil, fmt.Errorf("create certificate: %w", err) + } + certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + keyDER, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return nil, nil, fmt.Errorf("marshal private key: %w", err) + } + keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + return certPEM, keyPEM, nil +} + +func randomBase64(size int) (string, error) { + buf := make([]byte, size) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return base64.RawStdEncoding.EncodeToString(buf), nil +} + +func randomHostnameSuffix() string { + buf := make([]byte, 4) + if _, err := rand.Read(buf); err != nil { + return "local" + } + return hex.EncodeToString(buf) +} + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} + +func EnsureProtocolForUI(protocol *ProtocolProfile) error { + if err := ensureRealityProfile(protocol); err != nil { + return err + } + if err := ensureHysteria2Profile(protocol); err != nil { + return err + } + return nil +} diff --git a/internal/control/inventory.go b/internal/control/inventory.go new file mode 100644 index 0000000..22d9179 --- /dev/null +++ b/internal/control/inventory.go @@ -0,0 +1,225 @@ +package control + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + + "gopkg.in/yaml.v3" +) + +type Inventory struct { + Nodes []Node +} + +func LoadInventoryDir(dir string) (*Inventory, error) { + entries, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + + nodes := make([]Node, 0, len(entries)) + for _, entry := range entries { + if entry.IsDir() { + continue + } + ext := strings.ToLower(filepath.Ext(entry.Name())) + if ext != ".yaml" && ext != ".yml" { + continue + } + + node, err := LoadNodeFile(filepath.Join(dir, entry.Name())) + if err != nil { + return nil, err + } + nodes = append(nodes, *node) + } + + sort.Slice(nodes, func(i, j int) bool { + return nodes[i].ID < nodes[j].ID + }) + + return &Inventory{Nodes: nodes}, nil +} + +func LoadNodeFile(path string) (*Node, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var node Node + if err := yaml.Unmarshal(data, &node); err != nil { + return nil, fmt.Errorf("parse %s: %w", path, err) + } + for idx := range node.Protocols { + if err := ensureRealityProfile(&node.Protocols[idx]); err != nil { + return nil, fmt.Errorf("prepare %s: %w", path, err) + } + if err := ensureHysteria2Profile(&node.Protocols[idx]); err != nil { + return nil, fmt.Errorf("prepare %s: %w", path, err) + } + } + if err := ValidateNode(node); err != nil { + return nil, fmt.Errorf("validate %s: %w", path, err) + } + + return &node, nil +} + +func SaveNodeFile(dir string, node Node) (string, error) { + for idx := range node.Protocols { + if err := ensureRealityProfile(&node.Protocols[idx]); err != nil { + return "", err + } + if err := ensureHysteria2Profile(&node.Protocols[idx]); err != nil { + return "", err + } + } + if err := ValidateNode(node); err != nil { + return "", err + } + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", err + } + + path := filepath.Join(dir, node.ID+".yaml") + data, err := yaml.Marshal(node) + if err != nil { + return "", err + } + if err := os.WriteFile(path, data, 0o600); err != nil { + return "", err + } + + return path, nil +} + +func (i *Inventory) NodeByID(id string) (*Node, bool) { + for idx := range i.Nodes { + if i.Nodes[idx].ID == id { + return &i.Nodes[idx], true + } + } + return nil, false +} + +func ValidateNode(node Node) error { + if strings.TrimSpace(node.ID) == "" { + return errors.New("id is required") + } + if strings.TrimSpace(node.Name) == "" { + return errors.New("name is required") + } + if strings.TrimSpace(node.Region) == "" { + return errors.New("region is required") + } + if strings.TrimSpace(node.Host) == "" { + return errors.New("host is required") + } + if node.SSH.Port < 0 || node.SSH.Port > 65535 { + return errors.New("ssh.port must be between 0 and 65535") + } + if node.SSH.Port == 0 { + node.SSH.Port = 22 + } + if strings.TrimSpace(node.SSH.User) == "" { + return errors.New("ssh.user is required") + } + if strings.TrimSpace(node.SSH.Auth) == "" { + return errors.New("ssh.auth is required") + } + switch strings.TrimSpace(node.SSH.Auth) { + case "key": + if strings.TrimSpace(node.SSH.IdentityFile) == "" { + return errors.New("ssh.identity_file is required when ssh.auth=key") + } + case "password": + if strings.TrimSpace(node.SSH.PasswordEnv) == "" { + return errors.New("ssh.password_env is required when ssh.auth=password") + } + default: + return errors.New("ssh.auth must be either key or password") + } + if len(node.Protocols) == 0 { + return errors.New("at least one protocol is required") + } + + seen := make(map[string]struct{}, len(node.Protocols)) + for _, protocol := range node.Protocols { + if strings.TrimSpace(protocol.Type) == "" { + return errors.New("protocol.type is required") + } + if protocol.Port <= 0 || protocol.Port > 65535 { + return fmt.Errorf("protocol %s has invalid port", protocol.Type) + } + if protocol.Type == "vless" && protocol.TLS != nil && protocol.TLS.Enabled && strings.TrimSpace(node.Domain) == "" { + return errors.New("vless with tls.enabled requires node.domain") + } + if protocol.Type == "vless-reality" { + if protocol.Auth == nil || strings.TrimSpace(protocol.Auth.UUID) == "" { + return errors.New("vless-reality requires auth.uuid") + } + if protocol.Reality == nil { + return errors.New("vless-reality requires reality settings") + } + if strings.TrimSpace(protocol.Reality.ServerName) == "" { + return errors.New("vless-reality requires reality.server_name") + } + if strings.TrimSpace(protocol.Reality.PrivateKey) == "" { + return errors.New("vless-reality requires reality.private_key") + } + if strings.TrimSpace(protocol.Reality.PublicKey) == "" { + return errors.New("vless-reality requires reality.public_key") + } + if strings.TrimSpace(protocol.Reality.ShortID) == "" { + return errors.New("vless-reality requires reality.short_id") + } + } + if protocol.Type == "vmess" && protocol.TLS != nil && protocol.TLS.Enabled && strings.TrimSpace(node.Domain) == "" { + return errors.New("vmess with tls.enabled requires node.domain") + } + if protocol.Type == "hysteria2" { + if protocol.Auth == nil || strings.TrimSpace(protocol.Auth.Password) == "" { + return errors.New("hysteria2 requires auth.password") + } + if protocol.Hysteria2 == nil { + return errors.New("hysteria2 requires hysteria2 settings") + } + if protocol.Hysteria2.CertPath == "" || protocol.Hysteria2.KeyPath == "" { + return errors.New("hysteria2 requires cert_path and key_path") + } + } + key := protocol.Type + if _, ok := seen[key]; ok { + return fmt.Errorf("duplicate protocol %s", protocol.Type) + } + seen[key] = struct{}{} + } + + return nil +} + +func CopyNodeFile(srcPath, inventoryDir string) (string, error) { + node, err := LoadNodeFile(srcPath) + if err != nil { + return "", err + } + return SaveNodeFile(inventoryDir, *node) +} + +func DeleteNodeFile(dir, nodeID string) error { + err := os.Remove(filepath.Join(dir, nodeID+".yaml")) + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return err +} + +func IsNotExist(err error) bool { + return errors.Is(err, fs.ErrNotExist) +} diff --git a/internal/control/inventory_test.go b/internal/control/inventory_test.go new file mode 100644 index 0000000..eb03979 --- /dev/null +++ b/internal/control/inventory_test.go @@ -0,0 +1,50 @@ +package control + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadInventoryDir(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + input := `id: nl-01 +name: NL 01 +provider: custom-vps +region: nl +host: 203.0.113.10 +domain: nl-01.example.com +acme_email: admin@example.com +enabled: true +ssh: + user: root + port: 22 + auth: key + identity_file: ~/.ssh/id_ed25519 +protocols: + - type: vless + enabled: true + port: 443 + tls: + enabled: true + server_name: nl-01.example.com + auth: + uuid: 11111111-1111-1111-1111-111111111111 +` + if err := os.WriteFile(filepath.Join(dir, "nl-01.yaml"), []byte(input), 0o600); err != nil { + t.Fatal(err) + } + + inventory, err := LoadInventoryDir(dir) + if err != nil { + t.Fatalf("LoadInventoryDir error = %v", err) + } + if len(inventory.Nodes) != 1 { + t.Fatalf("len(inventory.Nodes) = %d, want 1", len(inventory.Nodes)) + } + if inventory.Nodes[0].ID != "nl-01" { + t.Fatalf("inventory.Nodes[0].ID = %q, want nl-01", inventory.Nodes[0].ID) + } +} diff --git a/internal/control/lifecycle.go b/internal/control/lifecycle.go new file mode 100644 index 0000000..a45339f --- /dev/null +++ b/internal/control/lifecycle.go @@ -0,0 +1,205 @@ +package control + +import ( + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "strings" +) + +func SetNodeEnabled(node Node, enabled bool) Node { + node.Enabled = enabled + return node +} + +func RotateNodeSecrets(node Node) (Node, error) { + for idx := range node.Protocols { + protocol := &node.Protocols[idx] + switch protocol.Type { + case "vless", "vmess": + if protocol.Auth == nil { + protocol.Auth = &AuthProfile{} + } + uuid, err := randomUUID() + if err != nil { + return node, err + } + protocol.Auth.UUID = uuid + case "shadowsocks": + if protocol.Auth == nil { + protocol.Auth = &AuthProfile{} + } + password, err := randomHex(16) + if err != nil { + return node, err + } + protocol.Auth.Password = password + case "hysteria2": + if err := ensureHysteria2Profile(protocol); err != nil { + return node, err + } + password, err := randomBase64(16) + if err != nil { + return node, err + } + protocol.Auth.Password = password + protocol.Hysteria2.UserPassword = password + obfsPassword, err := randomHex(32) + if err != nil { + return node, err + } + protocol.Hysteria2.ObfsPassword = obfsPassword + if protocol.Extra == nil { + protocol.Extra = map[string]any{} + } + protocol.Extra["obfs_password"] = obfsPassword + } + } + return node, nil +} + +func AddSocks5Protocol(node Node, port int) (Node, error) { + if port <= 0 { + port = 54101 + } + for _, protocol := range node.Protocols { + if protocol.Type == "socks5" || protocol.Type == "socks" { + return node, fmt.Errorf("node %s already has SOCKS5 enabled", node.ID) + } + } + node.Protocols = append(node.Protocols, ProtocolProfile{ + Type: "socks5", + Enabled: true, + Port: port, + }) + return node, nil +} + +func DestroyNode(ctx context.Context, runner SSHExecutor, dnsClient DNSProvider, zone string, node Node, inventoryDir, stateDir string) []string { + var warnings []string + + if dnsClient != nil && strings.TrimSpace(node.Domain) != "" && strings.HasSuffix(node.Domain, "."+zone) { + name := strings.TrimSuffix(node.Domain, "."+zone) + name = strings.TrimSuffix(name, ".") + if err := dnsClient.DeleteARecord(ctx, zone, name); err != nil { + warnings = append(warnings, "dns cleanup failed: "+err.Error()) + } + } + + if strings.TrimSpace(node.Host) != "" { + if _, err := runner.Run(ctx, node, RenderDestroyScript()); err != nil { + warnings = append(warnings, "remote cleanup failed: "+err.Error()) + } + } + + if err := DeleteNodeState(stateDir, node.ID); err != nil { + warnings = append(warnings, "state cleanup failed: "+err.Error()) + } + if err := DeleteNodeFile(inventoryDir, node.ID); err != nil { + warnings = append(warnings, "inventory cleanup failed: "+err.Error()) + } + + return warnings +} + +func UpgradeNode(ctx context.Context, runner SSHExecutor, node Node, stateDir string) (*NodeState, error) { + if _, err := BootstrapNode(ctx, runner, node, BootstrapOptions{ + StateDir: stateDir, + DryRun: false, + }); err != nil { + return nil, err + } + + state, err := CheckNode(ctx, runner, node, stateDir) + if state != nil { + if state.Metadata == nil { + state.Metadata = map[string]any{} + } + state.Metadata["lifecycle_action"] = "upgrade" + _ = SaveNodeState(stateDir, *state) + } + return state, err +} + +func RepairReinstallNode(ctx context.Context, runner SSHExecutor, node Node, stateDir string) (*NodeState, error) { + return reinstallNode(ctx, runner, node, stateDir, "repair_reinstall") +} + +func CleanReinstallNode(ctx context.Context, runner SSHExecutor, node Node, stateDir string) (*NodeState, error) { + return reinstallNode(ctx, runner, node, stateDir, "clean_reinstall") +} + +func reinstallNode(ctx context.Context, runner SSHExecutor, node Node, stateDir, action string) (*NodeState, error) { + cleanupWarning := "" + if strings.TrimSpace(node.Host) != "" { + if _, err := runner.Run(ctx, node, RenderDestroyScript()); err != nil { + cleanupWarning = err.Error() + } + } + + if _, err := BootstrapNode(ctx, runner, node, BootstrapOptions{ + StateDir: stateDir, + DryRun: false, + }); err != nil { + return nil, err + } + + state, err := CheckNode(ctx, runner, node, stateDir) + if state != nil { + if state.Metadata == nil { + state.Metadata = map[string]any{} + } + state.Metadata["lifecycle_action"] = action + if cleanupWarning != "" { + state.Metadata["cleanup_warning"] = cleanupWarning + } + _ = SaveNodeState(stateDir, *state) + } + return state, err +} + +func RenderDestroyScript() string { + return `set -eu +if [ -f /opt/vpnem-node/current/docker-compose.yml ]; then + if command -v docker >/dev/null 2>&1 && docker compose version >/dev/null 2>&1; then + docker compose -f /opt/vpnem-node/current/docker-compose.yml down -v || true + elif command -v docker-compose >/dev/null 2>&1; then + docker-compose -f /opt/vpnem-node/current/docker-compose.yml down -v || true + fi +fi +rm -rf /opt/vpnem-node +printf 'vpnem-node removed\n' +` +} + +func randomHex(size int) (string, error) { + buf := make([]byte, size) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func randomUUID() (string, error) { + buf := make([]byte, 16) + if _, err := rand.Read(buf); err != nil { + return "", err + } + buf[6] = (buf[6] & 0x0f) | 0x40 + buf[8] = (buf[8] & 0x3f) | 0x80 + hexID := hex.EncodeToString(buf) + return fmt.Sprintf("%s-%s-%s-%s-%s", + hexID[0:8], + hexID[8:12], + hexID[12:16], + hexID[16:20], + hexID[20:32], + ), nil +} + +func RandomHexForAPI(size int) (string, error) { return randomHex(size) } + +func RandomBase64ForAPI(size int) (string, error) { return randomBase64(size) } + +func RandomUUIDForAPI() (string, error) { return randomUUID() } diff --git a/internal/control/lifecycle_test.go b/internal/control/lifecycle_test.go new file mode 100644 index 0000000..2d9958c --- /dev/null +++ b/internal/control/lifecycle_test.go @@ -0,0 +1,149 @@ +package control + +import ( + "context" + "testing" +) + +func TestSetNodeEnabled(t *testing.T) { + t.Parallel() + + node := Node{ID: "nl-01", Enabled: true} + disabled := SetNodeEnabled(node, false) + if disabled.Enabled { + t.Fatal("expected node to be disabled") + } + if node.Enabled != true { + t.Fatal("expected original node to stay unchanged") + } +} + +func TestRotateNodeSecrets(t *testing.T) { + t.Parallel() + + node := Node{ + ID: "nl-01", + Protocols: []ProtocolProfile{ + {Type: "vless", Enabled: true, Port: 443, Auth: &AuthProfile{UUID: "old-vless"}}, + {Type: "vmess", Enabled: true, Port: 8444, Auth: &AuthProfile{UUID: "old-vmess"}}, + {Type: "shadowsocks", Enabled: true, Port: 8443, Auth: &AuthProfile{Method: "2022-blake3-aes-128-gcm", Password: "old-ss"}}, + {Type: "hysteria2", Enabled: true, Port: 9443, Auth: &AuthProfile{Password: "old-hy2"}, Extra: map[string]any{"obfs_password": "old-obfs"}}, + }, + } + + rotated, err := RotateNodeSecrets(node) + if err != nil { + t.Fatalf("RotateNodeSecrets() error = %v", err) + } + + if rotated.Protocols[0].Auth.UUID == "old-vless" || rotated.Protocols[0].Auth.UUID == "" { + t.Fatal("expected rotated vless uuid") + } + if rotated.Protocols[1].Auth.UUID == "old-vmess" || rotated.Protocols[1].Auth.UUID == "" { + t.Fatal("expected rotated vmess uuid") + } + if rotated.Protocols[2].Auth.Password == "old-ss" || rotated.Protocols[2].Auth.Password == "" { + t.Fatal("expected rotated shadowsocks password") + } + if rotated.Protocols[3].Auth.Password == "old-hy2" || rotated.Protocols[3].Auth.Password == "" { + t.Fatal("expected rotated hysteria2 password") + } + if rotated.Protocols[3].Extra["obfs_password"] == "old-obfs" || rotated.Protocols[3].Extra["obfs_password"] == "" { + t.Fatal("expected rotated hysteria2 obfs password") + } +} + +func TestAddSocks5Protocol(t *testing.T) { + t.Parallel() + + node, err := AddSocks5Protocol(Node{ + ID: "nl-01", + Protocols: []ProtocolProfile{ + {Type: "vless-reality", Enabled: true, Port: 443}, + {Type: "hysteria2", Enabled: true, Port: 443}, + }, + }, 54101) + if err != nil { + t.Fatalf("AddSocks5Protocol() error = %v", err) + } + if len(node.Protocols) != 3 { + t.Fatalf("expected 3 protocols, got %d", len(node.Protocols)) + } + last := node.Protocols[len(node.Protocols)-1] + if last.Type != "socks5" || last.Port != 54101 || !last.Enabled { + t.Fatalf("unexpected socks5 protocol: %+v", last) + } +} + +func TestRepairReinstallNode(t *testing.T) { + t.Parallel() + + state, err := RepairReinstallNode(context.Background(), fakeRunner{}, Node{ + ID: "nl-01", + Name: "NL 01", + Region: "nl", + Host: "203.0.113.10", + Domain: "nl-01.example.com", + Enabled: true, + SSH: SSHConfig{User: "root", Port: 22, Auth: "key", IdentityFile: "~/.ssh/id_ed25519"}, + Protocols: []ProtocolProfile{ + {Type: "vless", Enabled: true, Port: 443, TLS: &TLSProfile{Enabled: true, ServerName: "nl-01.example.com"}, Auth: &AuthProfile{UUID: "11111111-1111-1111-1111-111111111111"}, Extra: map[string]any{"path": "/ws"}}, + }, + }, t.TempDir()) + if err != nil { + t.Fatalf("RepairReinstallNode() error = %v", err) + } + if state == nil { + t.Fatal("expected state") + } + if state.BootstrapStatus != "healthy" { + t.Fatalf("BootstrapStatus = %q, want healthy", state.BootstrapStatus) + } + if got := state.Metadata["lifecycle_action"]; got != "repair_reinstall" { + t.Fatalf("lifecycle_action = %v, want repair_reinstall", got) + } +} + +func TestCleanReinstallNode(t *testing.T) { + t.Parallel() + + state, err := CleanReinstallNode(context.Background(), fakeRunner{}, Node{ + ID: "nl-01", + Name: "NL 01", + Region: "nl", + Host: "203.0.113.10", + Domain: "nl-01.example.com", + Enabled: true, + SSH: SSHConfig{User: "root", Port: 22, Auth: "key", IdentityFile: "~/.ssh/id_ed25519"}, + Protocols: []ProtocolProfile{ + {Type: "vless", Enabled: true, Port: 443, TLS: &TLSProfile{Enabled: true, ServerName: "nl-01.example.com"}, Auth: &AuthProfile{UUID: "11111111-1111-1111-1111-111111111111"}, Extra: map[string]any{"path": "/ws"}}, + }, + }, t.TempDir()) + if err != nil { + t.Fatalf("CleanReinstallNode() error = %v", err) + } + if state == nil { + t.Fatal("expected state") + } + if state.BootstrapStatus != "healthy" { + t.Fatalf("BootstrapStatus = %q, want healthy", state.BootstrapStatus) + } + if got := state.Metadata["lifecycle_action"]; got != "clean_reinstall" { + t.Fatalf("lifecycle_action = %v, want clean_reinstall", got) + } +} + +func TestParsePreflightInspectOutput(t *testing.T) { + t.Parallel() + + data := ParsePreflightInspectOutput("OS_ID=ubuntu\nMANAGED=1\nTCP_443=0\n") + if data["OS_ID"] != "ubuntu" { + t.Fatalf("OS_ID = %q, want ubuntu", data["OS_ID"]) + } + if data["MANAGED"] != "1" { + t.Fatalf("MANAGED = %q, want 1", data["MANAGED"]) + } + if data["TCP_443"] != "0" { + t.Fatalf("TCP_443 = %q, want 0", data["TCP_443"]) + } +} diff --git a/internal/control/models.go b/internal/control/models.go new file mode 100644 index 0000000..bec8e89 --- /dev/null +++ b/internal/control/models.go @@ -0,0 +1,66 @@ +package control + +type Node struct { + ID string `yaml:"id" json:"id"` + Name string `yaml:"name" json:"name"` + Provider string `yaml:"provider" json:"provider"` + Region string `yaml:"region" json:"region"` + Host string `yaml:"host" json:"host"` + Domain string `yaml:"domain,omitempty" json:"domain,omitempty"` + ACMEEmail string `yaml:"acme_email,omitempty" json:"acme_email,omitempty"` + Enabled bool `yaml:"enabled" json:"enabled"` + SSH SSHConfig `yaml:"ssh" json:"ssh"` + Protocols []ProtocolProfile `yaml:"protocols" json:"protocols"` + Tags []string `yaml:"tags,omitempty" json:"tags,omitempty"` + Metadata map[string]string `yaml:"metadata,omitempty" json:"metadata,omitempty"` +} + +type SSHConfig struct { + User string `yaml:"user" json:"user"` + Port int `yaml:"port" json:"port"` + Auth string `yaml:"auth" json:"auth"` + IdentityFile string `yaml:"identity_file,omitempty" json:"identity_file,omitempty"` + PasswordEnv string `yaml:"password_env,omitempty" json:"password_env,omitempty"` + Password string `yaml:"-" json:"-"` +} + +type ProtocolProfile struct { + Type string `yaml:"type" json:"type"` + Enabled bool `yaml:"enabled" json:"enabled"` + Port int `yaml:"port" json:"port"` + TLS *TLSProfile `yaml:"tls,omitempty" json:"tls,omitempty"` + Auth *AuthProfile `yaml:"auth,omitempty" json:"auth,omitempty"` + Reality *VLESSRealityProfile `yaml:"reality,omitempty" json:"reality,omitempty"` + Hysteria2 *Hysteria2Profile `yaml:"hysteria2,omitempty" json:"hysteria2,omitempty"` + Extra map[string]any `yaml:"extra,omitempty" json:"extra,omitempty"` +} + +type TLSProfile struct { + Enabled bool `yaml:"enabled" json:"enabled"` + ServerName string `yaml:"server_name,omitempty" json:"server_name,omitempty"` +} + +type AuthProfile struct { + UUID string `yaml:"uuid,omitempty" json:"uuid,omitempty"` + Method string `yaml:"method,omitempty" json:"method,omitempty"` + Password string `yaml:"password,omitempty" json:"password,omitempty"` +} + +type VLESSRealityProfile struct { + ServerName string `yaml:"server_name" json:"server_name"` + ServerPort int `yaml:"server_port,omitempty" json:"server_port,omitempty"` + PrivateKey string `yaml:"private_key,omitempty" json:"private_key,omitempty"` + PublicKey string `yaml:"public_key,omitempty" json:"public_key,omitempty"` + ShortID string `yaml:"short_id,omitempty" json:"short_id,omitempty"` + Fingerprint string `yaml:"fingerprint,omitempty" json:"fingerprint,omitempty"` +} + +type Hysteria2Profile struct { + Port int `yaml:"port,omitempty" json:"port,omitempty"` + UpMbps int `yaml:"up_mbps,omitempty" json:"up_mbps,omitempty"` + DownMbps int `yaml:"down_mbps,omitempty" json:"down_mbps,omitempty"` + ObfsPassword string `yaml:"obfs_password,omitempty" json:"obfs_password,omitempty"` + UserPassword string `yaml:"user_password,omitempty" json:"user_password,omitempty"` + CertPath string `yaml:"cert_path,omitempty" json:"cert_path,omitempty"` + KeyPath string `yaml:"key_path,omitempty" json:"key_path,omitempty"` +} diff --git a/internal/control/preflight.go b/internal/control/preflight.go new file mode 100644 index 0000000..44db7d0 --- /dev/null +++ b/internal/control/preflight.go @@ -0,0 +1,68 @@ +package control + +import "strings" + +func RenderPreflightInspectScript() string { + return `set -eu +if [ -r /etc/os-release ]; then + . /etc/os-release +fi +printf 'OS_ID=%s\n' "${ID:-}" +printf 'OS_PRETTY=%s\n' "${PRETTY_NAME:-}" +printf 'OS_LIKE=%s\n' "${ID_LIKE:-}" +if [ -d /opt/vpnem-node/current ]; then + printf 'MANAGED=1\n' +else + printf 'MANAGED=0\n' +fi +if command -v docker >/dev/null 2>&1; then + printf 'DOCKER=1\n' +else + printf 'DOCKER=0\n' +fi +if command -v docker >/dev/null 2>&1 && docker compose version >/dev/null 2>&1; then + printf 'COMPOSE=1\n' +elif command -v docker-compose >/dev/null 2>&1; then + printf 'COMPOSE=1\n' +else + printf 'COMPOSE=0\n' +fi +if command -v ss >/dev/null 2>&1; then + if ss -lnt 2>/dev/null | awk 'NR>1 {print $4}' | grep -Eq '(^|[:.])443$'; then + printf 'TCP_443=1\n' + else + printf 'TCP_443=0\n' + fi + if ss -lnu 2>/dev/null | awk 'NR>1 {print $4}' | grep -Eq '(^|[:.])443$'; then + printf 'UDP_443=1\n' + else + printf 'UDP_443=0\n' + fi + if ss -lnt 2>/dev/null | awk 'NR>1 {print $4}' | grep -Eq '(^|[:.])54101$'; then + printf 'TCP_54101=1\n' + else + printf 'TCP_54101=0\n' + fi +else + printf 'TCP_443=unknown\n' + printf 'UDP_443=unknown\n' + printf 'TCP_54101=unknown\n' +fi +` +} + +func ParsePreflightInspectOutput(stdout string) map[string]string { + values := map[string]string{} + for _, line := range strings.Split(stdout, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + key, value, ok := strings.Cut(line, "=") + if !ok { + continue + } + values[strings.TrimSpace(key)] = strings.TrimSpace(value) + } + return values +} diff --git a/internal/control/publish.go b/internal/control/publish.go new file mode 100644 index 0000000..d05e98b --- /dev/null +++ b/internal/control/publish.go @@ -0,0 +1,321 @@ +package control + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strconv" + + "vpnem/internal/models" +) + +type PublishDecision struct { + NodeID string `json:"node_id"` + Eligible bool `json:"eligible"` + Reasons []string `json:"reasons,omitempty"` + PublicHost string `json:"public_host,omitempty"` + Status string `json:"status,omitempty"` +} + +func PublishableNodes(nodes []Node, states map[string]*NodeState) []Node { + filtered := make([]Node, 0, len(nodes)) + for _, node := range nodes { + if PublishDecisionForNode(node, states[node.ID]).Eligible { + filtered = append(filtered, node) + } + } + return filtered +} + +func NodeStateReadyForPublish(state NodeState) bool { + if state.BootstrapStatus != "healthy" && state.BootstrapStatus != "ready" { + return false + } + + if code, ok := state.Metadata["healthz_http_code"]; ok { + switch v := code.(type) { + case int: + if v != 200 { + return false + } + case float64: + if int(v) != 200 { + return false + } + } + } + + if len(state.Services) == 0 { + return true + } + for _, service := range state.Services { + if service.Status != "running" { + return false + } + } + return true +} + +func PublishDecisionForNode(node Node, state *NodeState) PublishDecision { + decision := PublishDecision{ + NodeID: node.ID, + Eligible: false, + PublicHost: publicHost(node), + } + + if !node.Enabled { + decision.Reasons = append(decision.Reasons, "узел выключен") + return decision + } + if state == nil { + decision.Reasons = append(decision.Reasons, "нет сохранённого состояния узла") + return decision + } + + decision.Status = state.BootstrapStatus + if state.PublicHost != "" { + decision.PublicHost = state.PublicHost + } + + if state.BootstrapStatus != "healthy" && state.BootstrapStatus != "ready" { + decision.Reasons = append(decision.Reasons, "статус bootstrap: "+state.BootstrapStatus) + return decision + } + + if code, ok := state.Metadata["healthz_http_code"]; ok { + switch v := code.(type) { + case int: + if v != 200 { + decision.Reasons = append(decision.Reasons, "healthz_http_code: "+itoa(v)) + } + case float64: + if int(v) != 200 { + decision.Reasons = append(decision.Reasons, "healthz_http_code: "+itoa(int(v))) + } + } + } + + for _, service := range state.Services { + if service.Status != "running" { + decision.Reasons = append(decision.Reasons, "сервис "+service.Type+" имеет статус "+service.Status) + } + } + + decision.Eligible = len(decision.Reasons) == 0 + return decision +} + +func PublishDecisions(nodes []Node, states map[string]*NodeState) map[string]PublishDecision { + decisions := make(map[string]PublishDecision, len(nodes)) + for _, node := range nodes { + decisions[node.ID] = PublishDecisionForNode(node, states[node.ID]) + } + return decisions +} + +func itoa(v int) string { return strconv.Itoa(v) } + +func BuildCatalogV2(nodes []Node, states map[string]*NodeState) *models.CatalogV2 { + result := &models.CatalogV2{ + Version: "2", + Nodes: make([]models.CatalogNode, 0, len(nodes)), + } + + for _, node := range nodes { + publicHost := node.Host + if state := states[node.ID]; state != nil && state.PublicHost != "" { + publicHost = state.PublicHost + } else if node.Domain != "" { + publicHost = node.Domain + } + + catalogNode := models.CatalogNode{ + ID: node.ID, + Name: node.Name, + Provider: node.Provider, + Region: node.Region, + Host: node.Host, + Domain: node.Domain, + PublicHost: publicHost, + Tags: node.Tags, + Metadata: map[string]any{}, + Protocols: make([]models.CatalogProtocol, 0, len(node.Protocols)), + } + if state := states[node.ID]; state != nil { + catalogNode.Status = state.BootstrapStatus + for k, v := range state.Metadata { + catalogNode.Metadata[k] = v + } + } + + for _, protocol := range node.Protocols { + if !protocol.Enabled { + continue + } + if err := ensureRealityProfile(&protocol); err != nil { + continue + } + item := models.CatalogProtocol{ + Type: protocol.Type, + Enabled: protocol.Enabled, + Port: protocol.Port, + Extra: protocol.Extra, + } + if protocol.TLS != nil { + item.TLS = &models.TLS{ + Enabled: protocol.TLS.Enabled, + ServerName: protocol.TLS.ServerName, + Insecure: false, + } + } + if protocol.Type == "vless-reality" && protocol.Reality != nil { + item.TLS = &models.TLS{ + Enabled: true, + ServerName: protocol.Reality.ServerName, + Reality: &models.Reality{ + Enabled: true, + PublicKey: protocol.Reality.PublicKey, + ShortID: protocol.Reality.ShortID, + Fingerprint: protocol.Reality.Fingerprint, + }, + } + } + if protocol.Type == "hysteria2" { + if item.TLS == nil { + item.TLS = &models.TLS{} + } + item.TLS.Enabled = true + item.TLS.Insecure = true + if len(item.TLS.ALPN) == 0 { + item.TLS.ALPN = []string{defaultHysteria2ALPN} + } + if item.TLS.MinVersion == "" { + item.TLS.MinVersion = "1.3" + } + if item.TLS.MaxVersion == "" { + item.TLS.MaxVersion = "1.3" + } + } + if protocol.Auth != nil { + item.Auth = &models.CatalogAuth{ + UUID: protocol.Auth.UUID, + Method: protocol.Auth.Method, + Password: protocol.Auth.Password, + } + } + catalogNode.Protocols = append(catalogNode.Protocols, item) + } + result.Nodes = append(result.Nodes, catalogNode) + } + + return result +} + +func WriteCatalogV2(path string, nodes []Node, states map[string]*NodeState) error { + catalog := BuildCatalogV2(nodes, states) + + data, err := json.MarshalIndent(catalog, "", " ") + if err != nil { + return err + } + data = append(data, '\n') + + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + tmpPath := path + ".tmp" + if err := os.WriteFile(tmpPath, data, 0o644); err != nil { + return err + } + return os.Rename(tmpPath, path) +} + +func StaticCatalogNodesFromLegacy(servers []models.Server) []models.CatalogNode { + nodes := make([]models.CatalogNode, 0, len(servers)) + for _, server := range servers { + node := models.CatalogNode{ + ID: server.Tag, + Name: server.Tag, + Region: server.Region, + Host: server.Server, + PublicHost: server.Server, + Status: "static", + Metadata: map[string]any{ + "static_legacy": true, + }, + Protocols: []models.CatalogProtocol{ + { + Type: server.Type, + Enabled: true, + Port: server.ServerPort, + Auth: &models.CatalogAuth{ + UUID: server.UUID, + Method: server.Method, + Password: server.Password, + }, + TLS: server.TLS, + Extra: map[string]any{}, + }, + }, + } + if server.UDPOverTCP { + node.Protocols[0].Extra["udp_over_tcp"] = true + } + if server.ObfsPassword != "" { + node.Protocols[0].Extra["obfs_password"] = server.ObfsPassword + } + if server.UpMbps > 0 { + node.Protocols[0].Extra["up_mbps"] = server.UpMbps + } + if server.DownMbps > 0 { + node.Protocols[0].Extra["down_mbps"] = server.DownMbps + } + if server.Transport != nil { + node.Protocols[0].Extra["transport_type"] = server.Transport.Type + if server.Transport.Path != "" { + node.Protocols[0].Extra["path"] = server.Transport.Path + } + } + nodes = append(nodes, node) + } + return nodes +} + +func MergeCatalogNodes(primary, secondary []models.CatalogNode) []models.CatalogNode { + merged := make([]models.CatalogNode, 0, len(primary)+len(secondary)) + seen := make(map[string]struct{}, len(primary)+len(secondary)) + for _, item := range primary { + if _, ok := seen[item.ID]; ok { + continue + } + seen[item.ID] = struct{}{} + merged = append(merged, item) + } + for _, item := range secondary { + if _, ok := seen[item.ID]; ok { + continue + } + seen[item.ID] = struct{}{} + merged = append(merged, item) + } + return merged +} + +func PublishLegacyCatalog(ctx context.Context, nodes []Node, targetPath string, remoteNode *Node) error { + if remoteNode == nil { + return WriteLegacyCatalog(targetPath, nodes) + } + + tmpDir, err := os.MkdirTemp("", "vpnem-publish-*") + if err != nil { + return err + } + defer os.RemoveAll(tmpDir) + + localPath := filepath.Join(tmpDir, "servers.json") + if err := WriteLegacyCatalog(localPath, nodes); err != nil { + return err + } + return CopyFileOverSCP(ctx, *remoteNode, localPath, targetPath) +} diff --git a/internal/control/reality.go b/internal/control/reality.go new file mode 100644 index 0000000..301a674 --- /dev/null +++ b/internal/control/reality.go @@ -0,0 +1,64 @@ +package control + +import ( + "crypto/rand" + "encoding/base64" + "encoding/hex" + "fmt" + "strings" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const defaultRealityServerName = "www.nokia.com" + +func ensureRealityProfile(protocol *ProtocolProfile) error { + if protocol == nil || protocol.Type != "vless-reality" { + return nil + } + if protocol.Reality == nil { + protocol.Reality = &VLESSRealityProfile{} + } + if strings.TrimSpace(protocol.Reality.ServerName) == "" { + protocol.Reality.ServerName = defaultRealityServerName + } + if protocol.Reality.ServerPort == 0 { + protocol.Reality.ServerPort = 443 + } + if strings.TrimSpace(protocol.Reality.Fingerprint) == "" { + protocol.Reality.Fingerprint = "chrome" + } + if strings.TrimSpace(protocol.Reality.PrivateKey) == "" || strings.TrimSpace(protocol.Reality.PublicKey) == "" { + privateKey, publicKey, err := generateRealityKeyPair() + if err != nil { + return err + } + protocol.Reality.PrivateKey = privateKey + protocol.Reality.PublicKey = publicKey + } + if strings.TrimSpace(protocol.Reality.ShortID) == "" { + shortID, err := generateRealityShortID() + if err != nil { + return err + } + protocol.Reality.ShortID = shortID + } + return nil +} + +func generateRealityKeyPair() (privateKey string, publicKey string, err error) { + privateKeyPair, err := wgtypes.GeneratePrivateKey() + if err != nil { + return "", "", err + } + publicKeyPair := privateKeyPair.PublicKey() + return base64.RawURLEncoding.EncodeToString(privateKeyPair[:]), base64.RawURLEncoding.EncodeToString(publicKeyPair[:]), nil +} + +func generateRealityShortID() (string, error) { + var raw [8]byte + if _, err := rand.Read(raw[:]); err != nil { + return "", fmt.Errorf("generate reality short id: %w", err) + } + return hex.EncodeToString(raw[:]), nil +} diff --git a/internal/control/runtime.go b/internal/control/runtime.go new file mode 100644 index 0000000..93138b6 --- /dev/null +++ b/internal/control/runtime.go @@ -0,0 +1,586 @@ +package control + +import ( + "archive/tar" + "compress/gzip" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "vpnem/internal/config" +) + +type RuntimeBundleMeta struct { + ReleaseID string `json:"release_id"` + CreatedAt string `json:"created_at"` + NodeID string `json:"node_id"` +} + +func RenderRuntimeBundle(dir string, node Node, releaseID string) error { + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + for idx := range node.Protocols { + if err := ensureRealityProfile(&node.Protocols[idx]); err != nil { + return err + } + if err := ensureHysteria2Profile(&node.Protocols[idx]); err != nil { + return err + } + } + + meta := RuntimeBundleMeta{ + ReleaseID: releaseID, + CreatedAt: time.Now().UTC().Format(time.RFC3339), + NodeID: node.ID, + } + + files := map[string][]byte{} + + nodeJSON, err := json.MarshalIndent(node, "", " ") + if err != nil { + return err + } + files["node.json"] = append(nodeJSON, '\n') + + metaJSON, err := json.MarshalIndent(meta, "", " ") + if err != nil { + return err + } + files["bundle-meta.json"] = append(metaJSON, '\n') + + files["node.env"] = []byte(renderNodeEnv(node)) + files["docker-compose.yml"] = []byte(renderRuntimeCompose(node)) + files["README.md"] = []byte(renderRuntimeReadme(node)) + if hasHysteria2(node) { + certHost := hysteria2CertificateHost(node) + certPEM, keyPEM, err := generateSelfSignedCertForHost(certHost) + if err != nil { + return err + } + files["cert.pem"] = certPEM + files["key.pem"] = keyPEM + } + if config, ok, err := renderSingBoxServerConfig(node); err != nil { + return err + } else if ok { + files["sing-box.server.json"] = []byte(config) + if needsEdgeProxy(node) { + files["Caddyfile"] = []byte(renderCaddyfile(node)) + } + } + + for name, data := range files { + path := filepath.Join(dir, name) + if err := os.WriteFile(path, data, 0o644); err != nil { + return err + } + } + + return nil +} + +func CreateTarGzFromDir(srcDir, outPath string) error { + outFile, err := os.Create(outPath) + if err != nil { + return err + } + defer outFile.Close() + + gzw := gzip.NewWriter(outFile) + defer gzw.Close() + + tw := tar.NewWriter(gzw) + defer tw.Close() + + return filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + + relPath, err := filepath.Rel(srcDir, path) + if err != nil { + return err + } + + header, err := tar.FileInfoHeader(info, "") + if err != nil { + return err + } + header.Name = filepath.ToSlash(relPath) + + if err := tw.WriteHeader(header); err != nil { + return err + } + + data, err := os.ReadFile(path) + if err != nil { + return err + } + _, err = tw.Write(data) + return err + }) +} + +func renderNodeEnv(node Node) string { + var b strings.Builder + writeEnv := func(key, value string) { + b.WriteString(key) + b.WriteString("=") + b.WriteString(sanitizeEnv(value)) + b.WriteString("\n") + } + + writeEnv("NODE_ID", node.ID) + writeEnv("NODE_NAME", node.Name) + writeEnv("NODE_PROVIDER", node.Provider) + writeEnv("NODE_REGION", node.Region) + writeEnv("NODE_HOST", node.Host) + writeEnv("NODE_DOMAIN", node.Domain) + writeEnv("NODE_ACME_EMAIL", node.ACMEEmail) + return b.String() +} + +func renderRuntimeCompose(node Node) string { + if needsSingBoxRuntime(node) { + return renderSingBoxCompose(node) + } + + var b strings.Builder + b.WriteString("services:\n") + b.WriteString(" node-info:\n") + b.WriteString(" image: nginx:alpine\n") + b.WriteString(" restart: unless-stopped\n") + b.WriteString(" ports:\n") + b.WriteString(" - \"127.0.0.1:18080:80\"\n") + b.WriteString(" volumes:\n") + b.WriteString(" - ./node.json:/usr/share/nginx/html/index.json:ro\n") + b.WriteString(" - ./README.md:/usr/share/nginx/html/README.md:ro\n") + return b.String() +} + +func renderRuntimeReadme(node Node) string { + if hasHysteria2(node) { + profile := firstHysteria2Profile(node) + return fmt.Sprintf( + "# vpnem node bundle\n\nThis bundle was generated for node `%s` in region `%s`.\n\nIncluded runtime:\n- sing-box server with a Hysteria2 inbound on UDP `%d`\n- embedded self-signed TLS certificate\n- Salamander obfuscation enabled\n- local mixed health inbound on `127.0.0.1:1080`\n", + node.ID, + node.Region, + defaultInt(profile.Port, defaultHysteria2Port), + ) + } + if usesVLESSReality(node) { + reality := firstRealityProfile(node) + return fmt.Sprintf( + "# vpnem node bundle\n\nThis bundle was generated for node `%s` in region `%s`.\n\nIncluded runtime:\n- sing-box server with a VLESS REALITY inbound on `%d`\n- no ACME or Caddy layer is required\n- REALITY handshake destination `%s:%d`\n", + node.ID, + node.Region, + realityPort(node), + reality.ServerName, + reality.ServerPort, + ) + } + if usesVLESSTLS(node) { + return fmt.Sprintf( + "# vpnem node bundle\n\nThis bundle was generated for node `%s` in region `%s`.\n\nIncluded runtime:\n- sing-box server with a VLESS inbound on loopback\n- Caddy terminating HTTPS with ACME certificates for `%s`\n\nRequirements:\n- the domain must resolve to this VPS\n- ports 80 and 443 must be reachable from the internet\n- acme_email should be set for certificate issuance\n", + node.ID, + node.Region, + node.Domain, + ) + } + + return fmt.Sprintf( + "# vpnem node bundle\n\nThis bundle was generated for node `%s` in region `%s`.\n\nIt contains inventory metadata and a minimal runtime placeholder. Replace or extend the runtime services as protocol-specific deployers are added.\n", + node.ID, + node.Region, + ) +} + +func sanitizeEnv(value string) string { + value = strings.ReplaceAll(value, "\n", "") + return value +} + +func usesVLESSTLS(node Node) bool { + for _, protocol := range node.Protocols { + if protocol.Type == "vless" && protocol.Enabled && protocol.TLS != nil && protocol.TLS.Enabled && strings.TrimSpace(node.Domain) != "" { + return true + } + } + return false +} + +func usesVLESSReality(node Node) bool { + for _, protocol := range node.Protocols { + if protocol.Type == "vless-reality" && protocol.Enabled { + return true + } + } + return false +} + +func usesVMessTLS(node Node) bool { + for _, protocol := range node.Protocols { + if protocol.Type == "vmess" && protocol.Enabled && protocol.TLS != nil && protocol.TLS.Enabled && strings.TrimSpace(node.Domain) != "" { + return true + } + } + return false +} + +func needsEdgeProxy(node Node) bool { + return usesVLESSTLS(node) || usesVMessTLS(node) +} + +func needsSingBoxRuntime(node Node) bool { + for _, protocol := range node.Protocols { + if protocol.Enabled { + return true + } + } + return false +} + +func renderSingBoxCompose(node Node) string { + var b strings.Builder + b.WriteString("services:\n") + b.WriteString(" sing-box:\n") + b.WriteString(" image: ghcr.io/sagernet/sing-box:v1.12.20\n") + b.WriteString(" restart: unless-stopped\n") + b.WriteString(" command: [\"run\", \"-c\", \"/etc/sing-box/config.json\"]\n") + if isHysteria2Only(node) { + hy2Port := defaultInt(firstHysteria2Profile(node).Port, defaultHysteria2Port) + b.WriteString(" ports:\n") + b.WriteString(fmt.Sprintf(" - \"%d:%d/udp\"\n", hy2Port, hy2Port)) + b.WriteString(" - \"127.0.0.1:1080:1080/tcp\"\n") + } else { + b.WriteString(" network_mode: host\n") + } + b.WriteString(" volumes:\n") + b.WriteString(" - ./sing-box.server.json:/etc/sing-box/config.json:ro\n") + if hasHysteria2(node) { + b.WriteString(" - ./cert.pem:/etc/sing-box/cert.pem:ro\n") + b.WriteString(" - ./key.pem:/etc/sing-box/key.pem:ro\n") + } + b.WriteString("\n") + if needsEdgeProxy(node) { + b.WriteString(" caddy:\n") + b.WriteString(" image: caddy:2\n") + b.WriteString(" restart: unless-stopped\n") + b.WriteString(" network_mode: host\n") + b.WriteString(" depends_on:\n") + b.WriteString(" - sing-box\n") + b.WriteString(" environment:\n") + if strings.TrimSpace(node.ACMEEmail) != "" { + b.WriteString(" ACME_EMAIL: " + node.ACMEEmail + "\n") + } + b.WriteString(" volumes:\n") + b.WriteString(" - ./Caddyfile:/etc/caddy/Caddyfile:ro\n") + b.WriteString(" - caddy_data:/data\n") + b.WriteString(" - caddy_config:/config\n") + b.WriteString("\n") + b.WriteString("volumes:\n") + b.WriteString(" caddy_data:\n") + b.WriteString(" caddy_config:\n") + } + return b.String() +} + +func renderSingBoxServerConfig(node Node) (string, bool, error) { + inbounds := make([]map[string]any, 0) + if !needsSingBoxRuntime(node) { + return "", false, nil + } + + if vless, ok := findProtocol(node, "vless"); ok && vless.Enabled { + if vless.Auth == nil || strings.TrimSpace(vless.Auth.UUID) == "" { + return "", false, fmt.Errorf("vless runtime requires auth.uuid") + } + inbound := map[string]any{ + "type": "vless", + "tag": "vless-in", + "users": []map[string]any{ + {"uuid": vless.Auth.UUID}, + }, + } + path := stringFromExtra(vless.Extra, "path") + if path == "" { + path = "/ws" + } + if vless.TLS != nil && vless.TLS.Enabled && strings.TrimSpace(node.Domain) != "" { + inbound["listen"] = "127.0.0.1" + inbound["listen_port"] = 10443 + inbound["transport"] = map[string]any{ + "type": "ws", + "path": path, + } + } else { + inbound["listen"] = "0.0.0.0" + inbound["listen_port"] = vless.Port + } + inbounds = append(inbounds, inbound) + } + + if reality, ok := findProtocol(node, "vless-reality"); ok && reality.Enabled { + if reality.Auth == nil || strings.TrimSpace(reality.Auth.UUID) == "" { + return "", false, fmt.Errorf("vless-reality runtime requires auth.uuid") + } + if err := ensureRealityProfile(&reality); err != nil { + return "", false, err + } + inbound := map[string]any{ + "type": "vless", + "tag": "vless-reality-in", + "listen": "::", + "listen_port": reality.Port, + "users": []map[string]any{ + {"uuid": reality.Auth.UUID}, + }, + "tls": map[string]any{ + "enabled": true, + "server_name": reality.Reality.ServerName, + "reality": map[string]any{ + "enabled": true, + "handshake": map[string]any{ + "server": reality.Reality.ServerName, + "server_port": defaultInt(reality.Reality.ServerPort, 443), + }, + "private_key": reality.Reality.PrivateKey, + "short_id": []string{reality.Reality.ShortID}, + }, + }, + } + inbounds = append(inbounds, inbound) + } + + if ss, ok := findProtocol(node, "shadowsocks"); ok && ss.Enabled { + if ss.Auth == nil || strings.TrimSpace(ss.Auth.Method) == "" || strings.TrimSpace(ss.Auth.Password) == "" { + return "", false, fmt.Errorf("shadowsocks runtime requires auth.method and auth.password") + } + inbounds = append(inbounds, map[string]any{ + "type": "shadowsocks", + "tag": "ss-in", + "listen": "0.0.0.0", + "listen_port": ss.Port, + "method": ss.Auth.Method, + "password": ss.Auth.Password, + }) + } + + if socks, ok := findProtocol(node, "socks"); ok && socks.Enabled { + inbounds = append(inbounds, map[string]any{ + "type": "socks", + "tag": "socks-in", + "listen": "0.0.0.0", + "listen_port": socks.Port, + }) + } + if socks, ok := findProtocol(node, "socks5"); ok && socks.Enabled { + inbounds = append(inbounds, map[string]any{ + "type": "socks", + "tag": "socks5-in", + "listen": "0.0.0.0", + "listen_port": socks.Port, + }) + } + + if vmess, ok := findProtocol(node, "vmess"); ok && vmess.Enabled { + if vmess.Auth == nil || strings.TrimSpace(vmess.Auth.UUID) == "" { + return "", false, fmt.Errorf("vmess runtime requires auth.uuid") + } + inbound := map[string]any{ + "type": "vmess", + "tag": "vmess-in", + "users": []map[string]any{ + {"uuid": vmess.Auth.UUID, "alterId": 0}, + }, + } + path := stringFromExtra(vmess.Extra, "path") + if path == "" { + path = "/vmess" + } + if vmess.TLS != nil && vmess.TLS.Enabled && strings.TrimSpace(node.Domain) != "" { + inbound["listen"] = "127.0.0.1" + inbound["listen_port"] = 10444 + inbound["transport"] = map[string]any{ + "type": "ws", + "path": path, + } + } else { + inbound["listen"] = "0.0.0.0" + inbound["listen_port"] = vmess.Port + } + inbounds = append(inbounds, inbound) + } + + if hy2, ok := findProtocol(node, "hysteria2"); ok && hy2.Enabled { + profile := hy2.Hysteria2 + if profile == nil { + return "", false, fmt.Errorf("hysteria2 runtime requires hysteria2 settings") + } + inboundConfig, err := config.BuildHysteria2Inbound(node, hy2.Port, profile.UserPassword, profile.ObfsPassword, profile.UpMbps, profile.DownMbps, profile.CertPath, profile.KeyPath) + if err != nil { + return "", false, err + } + inbound := map[string]any(*inboundConfig) + inbound["users"] = []map[string]any{ + {"name": node.ID, "password": profile.UserPassword}, + } + inbounds = append(inbounds, inbound) + if needsHysteria2HealthInbound(node) { + inbounds = append(inbounds, map[string]any{ + "type": "mixed", + "tag": "hy2-health-in", + "listen": "127.0.0.1", + "listen_port": 1080, + }) + } + } + + config := map[string]any{ + "log": map[string]any{"level": "info"}, + "inbounds": inbounds, + "outbounds": []map[string]any{ + {"type": "direct", "tag": "direct"}, + }, + } + + data, err := json.MarshalIndent(config, "", " ") + if err != nil { + return "", false, err + } + return string(data) + "\n", true, nil +} + +func renderCaddyfile(node Node) string { + var b strings.Builder + b.WriteString("{\n") + if strings.TrimSpace(node.ACMEEmail) != "" { + b.WriteString(" email ") + b.WriteString(node.ACMEEmail) + b.WriteString("\n") + } + b.WriteString("}\n\n") + b.WriteString(node.Domain) + b.WriteString(" {\n") + b.WriteString(" encode zstd gzip\n") + if vless, ok := findProtocol(node, "vless"); ok && vless.Enabled && vless.TLS != nil && vless.TLS.Enabled { + path := stringFromExtra(vless.Extra, "path") + if path == "" { + path = "/ws" + } + b.WriteString(" @vless path ") + b.WriteString(path) + b.WriteString("\n") + b.WriteString(" reverse_proxy @vless 127.0.0.1:10443\n") + } + if vmess, ok := findProtocol(node, "vmess"); ok && vmess.Enabled && vmess.TLS != nil && vmess.TLS.Enabled { + path := stringFromExtra(vmess.Extra, "path") + if path == "" { + path = "/vmess" + } + b.WriteString(" @vmess path ") + b.WriteString(path) + b.WriteString("\n") + b.WriteString(" reverse_proxy @vmess 127.0.0.1:10444\n") + } + b.WriteString(" respond /healthz 200\n") + b.WriteString("}\n") + return b.String() +} + +func firstRealityProfile(node Node) VLESSRealityProfile { + for _, protocol := range node.Protocols { + if protocol.Type == "vless-reality" && protocol.Enabled && protocol.Reality != nil { + return *protocol.Reality + } + } + return VLESSRealityProfile{} +} + +func firstHysteria2Profile(node Node) Hysteria2Profile { + for _, protocol := range node.Protocols { + if protocol.Type == "hysteria2" && protocol.Enabled && protocol.Hysteria2 != nil { + return *protocol.Hysteria2 + } + } + return Hysteria2Profile{} +} + +func realityPort(node Node) int { + for _, protocol := range node.Protocols { + if protocol.Type == "vless-reality" && protocol.Enabled { + return protocol.Port + } + } + return 443 +} + +func defaultInt(value, fallback int) int { + if value > 0 { + return value + } + return fallback +} + +func findProtocol(node Node, kind string) (ProtocolProfile, bool) { + for _, protocol := range node.Protocols { + if protocol.Type == kind { + return protocol, true + } + } + return ProtocolProfile{}, false +} + +func hasHysteria2(node Node) bool { + hy2, ok := findProtocol(node, "hysteria2") + return ok && hy2.Enabled +} + +func isHysteria2Only(node Node) bool { + enabled := 0 + hy2Enabled := false + for _, protocol := range node.Protocols { + if !protocol.Enabled { + continue + } + enabled++ + if protocol.Type == "hysteria2" { + hy2Enabled = true + } + } + return enabled == 1 && hy2Enabled +} + +func needsHysteria2HealthInbound(node Node) bool { + return hasHysteria2(node) +} + +func hysteria2CertificateHost(node Node) string { + if tls, ok := findProtocol(node, "hysteria2"); ok && tls.TLS != nil && strings.TrimSpace(tls.TLS.ServerName) != "" { + return strings.TrimSpace(tls.TLS.ServerName) + } + suffix := strings.ReplaceAll(strings.ToLower(node.ID), "_", "-") + suffix = strings.ReplaceAll(suffix, " ", "-") + return "node-" + suffix + ".local" +} + +func intFromExtra(extra map[string]any, key string, fallback int) int { + if extra == nil { + return fallback + } + switch value := extra[key].(type) { + case int: + return value + case float64: + return int(value) + default: + return fallback + } +} diff --git a/internal/control/runtime_test.go b/internal/control/runtime_test.go new file mode 100644 index 0000000..1f38dd7 --- /dev/null +++ b/internal/control/runtime_test.go @@ -0,0 +1,307 @@ +package control + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestRenderRuntimeBundle(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + node := Node{ + ID: "nl-01", + Name: "NL 01", + Region: "nl", + Host: "203.0.113.10", + Domain: "nl-01.example.com", + ACMEEmail: "admin@example.com", + Enabled: true, + SSH: SSHConfig{ + User: "root", + Port: 22, + Auth: "key", + }, + Protocols: []ProtocolProfile{ + { + Type: "vless", + Enabled: true, + Port: 443, + TLS: &TLSProfile{ + Enabled: true, + ServerName: "nl-01.example.com", + }, + Auth: &AuthProfile{ + UUID: "11111111-1111-1111-1111-111111111111", + }, + Extra: map[string]any{ + "path": "/ws", + }, + }, + }, + } + + if err := RenderRuntimeBundle(dir, node, "20260401-123000"); err != nil { + t.Fatalf("RenderRuntimeBundle error = %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "docker-compose.yml")) + if err != nil { + t.Fatalf("ReadFile docker-compose.yml error = %v", err) + } + if !strings.Contains(string(data), "sing-box:") { + t.Fatal("expected sing-box service in runtime compose") + } + if !strings.Contains(string(data), "caddy:") { + t.Fatal("expected caddy service in runtime compose") + } + + caddyfile, err := os.ReadFile(filepath.Join(dir, "Caddyfile")) + if err != nil { + t.Fatalf("ReadFile Caddyfile error = %v", err) + } + if !strings.Contains(string(caddyfile), "nl-01.example.com") { + t.Fatal("expected domain in Caddyfile") + } + + serverConfig, err := os.ReadFile(filepath.Join(dir, "sing-box.server.json")) + if err != nil { + t.Fatalf("ReadFile sing-box.server.json error = %v", err) + } + if !strings.Contains(string(serverConfig), "\"type\": \"vless\"") { + t.Fatal("expected vless inbound in sing-box config") + } +} + +func TestRenderRuntimeBundleReality(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + node := Node{ + ID: "nl-reality", + Name: "NL Reality", + Region: "nl", + Host: "203.0.113.20", + Enabled: true, + SSH: SSHConfig{ + User: "root", + Port: 22, + Auth: "key", + }, + Protocols: []ProtocolProfile{ + { + Type: "vless-reality", + Enabled: true, + Port: 443, + Auth: &AuthProfile{ + UUID: "33333333-3333-3333-3333-333333333333", + }, + Reality: &VLESSRealityProfile{ + ServerName: "login.microsoftonline.com", + ServerPort: 443, + PrivateKey: "UuMBgl7MXTPx9inmQp2UC7Jcnwc6XYbwDNebonM-FCc", + PublicKey: "jNXHt1yRo0vDuchQlIP6Z0ZvjT3KtzVI-T4E7RoLJS0", + ShortID: "0123456789abcdef", + Fingerprint: "chrome", + }, + }, + }, + } + + if err := RenderRuntimeBundle(dir, node, "20260408-180000"); err != nil { + t.Fatalf("RenderRuntimeBundle error = %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "docker-compose.yml")) + if err != nil { + t.Fatalf("ReadFile docker-compose.yml error = %v", err) + } + if !strings.Contains(string(data), "sing-box:") { + t.Fatal("expected sing-box service in runtime compose") + } + if strings.Contains(string(data), "caddy:") { + t.Fatal("did not expect caddy service for reality runtime") + } + + if _, err := os.Stat(filepath.Join(dir, "Caddyfile")); !os.IsNotExist(err) { + t.Fatal("did not expect Caddyfile for reality runtime") + } + + serverConfig, err := os.ReadFile(filepath.Join(dir, "sing-box.server.json")) + if err != nil { + t.Fatalf("ReadFile sing-box.server.json error = %v", err) + } + s := string(serverConfig) + if !strings.Contains(s, "\"private_key\": \"UuMBgl7MXTPx9inmQp2UC7Jcnwc6XYbwDNebonM-FCc\"") { + t.Fatal("expected reality private key in sing-box config") + } + if !strings.Contains(s, "\"short_id\": [") || !strings.Contains(s, "0123456789abcdef") { + t.Fatal("expected reality short id in sing-box config") + } + if !strings.Contains(s, "login.microsoftonline.com") { + t.Fatal("expected reality handshake destination in sing-box config") + } +} + +func TestHysteria2Bundle(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + node := Node{ + ID: "nl-hy2", + Name: "NL Hysteria2", + Region: "nl", + Host: "203.0.113.30", + Enabled: true, + SSH: SSHConfig{ + User: "root", + Port: 22, + Auth: "key", + }, + Protocols: []ProtocolProfile{ + { + Type: "hysteria2", + Enabled: true, + Port: 443, + Auth: &AuthProfile{ + Password: "user-password", + }, + Hysteria2: &Hysteria2Profile{ + Port: 443, + UpMbps: 100, + DownMbps: 100, + ObfsPassword: "obfs-password", + UserPassword: "user-password", + CertPath: "/etc/sing-box/cert.pem", + KeyPath: "/etc/sing-box/key.pem", + }, + }, + }, + } + + if err := RenderRuntimeBundle(dir, node, "20260408-220000"); err != nil { + t.Fatalf("RenderRuntimeBundle error = %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "docker-compose.yml")) + if err != nil { + t.Fatalf("ReadFile docker-compose.yml error = %v", err) + } + compose := string(data) + if !strings.Contains(compose, "443:443/udp") { + t.Fatal("expected udp port mapping for hysteria2 runtime") + } + if !strings.Contains(compose, "127.0.0.1:1080:1080/tcp") { + t.Fatal("expected local tcp health port mapping for hysteria2 runtime") + } + if strings.Contains(compose, "caddy:") { + t.Fatal("did not expect caddy service for hysteria2 runtime") + } + + serverConfig, err := os.ReadFile(filepath.Join(dir, "sing-box.server.json")) + if err != nil { + t.Fatalf("ReadFile sing-box.server.json error = %v", err) + } + config := string(serverConfig) + if !strings.Contains(config, "\"type\": \"hysteria2\"") { + t.Fatal("expected hysteria2 inbound in sing-box config") + } + if !strings.Contains(config, "\"salamander\"") { + t.Fatal("expected salamander obfuscation in sing-box config") + } + if !strings.Contains(config, "\"listen_port\": 1080") { + t.Fatal("expected mixed health inbound in sing-box config") + } + if !strings.Contains(config, "\"certificate_path\": \"/etc/sing-box/cert.pem\"") { + t.Fatal("expected embedded certificate path in sing-box config") + } + if _, err := os.Stat(filepath.Join(dir, "cert.pem")); err != nil { + t.Fatalf("expected generated cert.pem: %v", err) + } + if _, err := os.Stat(filepath.Join(dir, "key.pem")); err != nil { + t.Fatalf("expected generated key.pem: %v", err) + } +} + +func TestRenderRuntimeBundleMultiProtocol(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + node := Node{ + ID: "nl-multi", + Name: "NL Multi", + Region: "nl", + Host: "203.0.113.40", + Enabled: true, + SSH: SSHConfig{ + User: "root", + Port: 22, + Auth: "key", + }, + Protocols: []ProtocolProfile{ + { + Type: "vless-reality", + Enabled: true, + Port: 443, + Auth: &AuthProfile{ + UUID: "33333333-3333-3333-3333-333333333333", + }, + Reality: &VLESSRealityProfile{ + ServerName: "www.microsoft.com", + ServerPort: 443, + PrivateKey: "UuMBgl7MXTPx9inmQp2UC7Jcnwc6XYbwDNebonM-FCc", + PublicKey: "jNXHt1yRo0vDuchQlIP6Z0ZvjT3KtzVI-T4E7RoLJS0", + ShortID: "0123456789abcdef", + Fingerprint: "chrome", + }, + }, + { + Type: "hysteria2", + Enabled: true, + Port: 443, + Auth: &AuthProfile{ + Password: "user-password", + }, + Hysteria2: &Hysteria2Profile{ + Port: 443, + UpMbps: 100, + DownMbps: 100, + ObfsPassword: "obfs-password", + UserPassword: "user-password", + CertPath: "/etc/sing-box/cert.pem", + KeyPath: "/etc/sing-box/key.pem", + }, + }, + }, + } + + if err := RenderRuntimeBundle(dir, node, "20260409-120000"); err != nil { + t.Fatalf("RenderRuntimeBundle error = %v", err) + } + + data, err := os.ReadFile(filepath.Join(dir, "docker-compose.yml")) + if err != nil { + t.Fatalf("ReadFile docker-compose.yml error = %v", err) + } + compose := string(data) + if !strings.Contains(compose, "network_mode: host") { + t.Fatal("expected host networking for multi protocol runtime") + } + + serverConfig, err := os.ReadFile(filepath.Join(dir, "sing-box.server.json")) + if err != nil { + t.Fatalf("ReadFile sing-box.server.json error = %v", err) + } + config := string(serverConfig) + if !strings.Contains(config, "\"tag\": \"vless-reality-in\"") { + t.Fatal("expected reality inbound in sing-box config") + } + if !strings.Contains(config, "\"tag\": \"hysteria2-in\"") { + t.Fatal("expected hysteria2 inbound in sing-box config") + } + if !strings.Contains(config, "\"tag\": \"hy2-health-in\"") { + t.Fatal("expected hysteria2 health inbound for multi runtime") + } +} diff --git a/internal/control/ssh.go b/internal/control/ssh.go new file mode 100644 index 0000000..b7d7dd5 --- /dev/null +++ b/internal/control/ssh.go @@ -0,0 +1,182 @@ +package control + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" +) + +type SSHRunner struct{} + +type SSHExecutor interface { + Run(ctx context.Context, node Node, script string) (*CommandResult, error) + Check(ctx context.Context, node Node) (*CommandResult, error) + CopyFile(ctx context.Context, node Node, localPath, remotePath string) error +} + +type CommandResult struct { + Stdout string + Stderr string +} + +func (r SSHRunner) Run(ctx context.Context, node Node, script string) (*CommandResult, error) { + target := sshTarget(node) + cmd, err := sshCommand(ctx, node, target, "sh -s") + if err != nil { + return &CommandResult{}, err + } + cmd.Stdin = strings.NewReader(script) + + var stdout bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return &CommandResult{Stdout: stdout.String(), Stderr: stderr.String()}, fmt.Errorf("ssh %s: %w", target, err) + } + + return &CommandResult{Stdout: stdout.String(), Stderr: stderr.String()}, nil +} + +func (r SSHRunner) Check(ctx context.Context, node Node) (*CommandResult, error) { + target := sshTarget(node) + cmd, err := sshCommand(ctx, node, target, "printf ok") + if err != nil { + return &CommandResult{}, err + } + + var stdout bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return &CommandResult{Stdout: stdout.String(), Stderr: stderr.String()}, fmt.Errorf("ssh %s: %w", target, err) + } + + return &CommandResult{Stdout: stdout.String(), Stderr: stderr.String()}, nil +} + +func CopyFileOverSCP(ctx context.Context, node Node, localPath, remotePath string) error { + target := fmt.Sprintf("%s:%s", sshTarget(node), remotePath) + cmd, err := scpCommand(ctx, node, localPath, target) + if err != nil { + return err + } + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("scp %s -> %s: %w: %s", localPath, target, err, string(output)) + } + return nil +} + +func (r SSHRunner) CopyFile(ctx context.Context, node Node, localPath, remotePath string) error { + return CopyFileOverSCP(ctx, node, localPath, remotePath) +} + +func CopyDirContentsOverSCP(ctx context.Context, node Node, localDir, remoteDir string) error { + target := fmt.Sprintf("%s:%s", sshTarget(node), remoteDir) + cmd, err := scpCommand(ctx, node, "-r", filepath.Clean(localDir)+"/.", target) + if err != nil { + return err + } + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("scp %s -> %s: %w: %s", localDir, target, err, string(output)) + } + return nil +} + +func sshBaseArgs(node Node) []string { + args := []string{ + "-o", "StrictHostKeyChecking=accept-new", + "-p", strconv.Itoa(defaultSSHPort(node.SSH.Port)), + } + if strings.TrimSpace(node.SSH.Auth) == "password" { + args = append(args, "-o", "BatchMode=no") + } else { + args = append(args, "-o", "BatchMode=yes") + } + if strings.TrimSpace(node.SSH.IdentityFile) != "" { + args = append(args, "-i", expandHome(node.SSH.IdentityFile)) + } + return args +} + +func scpBaseArgs(node Node) []string { + args := []string{ + "-o", "StrictHostKeyChecking=accept-new", + "-P", strconv.Itoa(defaultSSHPort(node.SSH.Port)), + } + if strings.TrimSpace(node.SSH.Auth) == "password" { + args = append(args, "-o", "BatchMode=no") + } else { + args = append(args, "-o", "BatchMode=yes") + } + if strings.TrimSpace(node.SSH.IdentityFile) != "" { + args = append(args, "-i", expandHome(node.SSH.IdentityFile)) + } + return args +} + +func sshTarget(node Node) string { + return fmt.Sprintf("%s@%s", node.SSH.User, node.Host) +} + +func defaultSSHPort(port int) int { + if port == 0 { + return 22 + } + return port +} + +func expandHome(path string) string { + if path == "" || path[0] != '~' { + return path + } + home, err := exec.Command("sh", "-lc", "printf %s \"$HOME\"").Output() + if err != nil { + return path + } + return filepath.Join(strings.TrimSpace(string(home)), strings.TrimPrefix(path, "~/")) +} + +func sshCommand(ctx context.Context, node Node, extraArgs ...string) (*exec.Cmd, error) { + args := sshBaseArgs(node) + args = append(args, extraArgs...) + return wrapWithPassword(ctx, node, "ssh", args...) +} + +func scpCommand(ctx context.Context, node Node, extraArgs ...string) (*exec.Cmd, error) { + args := scpBaseArgs(node) + args = append(args, extraArgs...) + return wrapWithPassword(ctx, node, "scp", args...) +} + +func wrapWithPassword(ctx context.Context, node Node, command string, args ...string) (*exec.Cmd, error) { + if strings.TrimSpace(node.SSH.Auth) != "password" { + return exec.CommandContext(ctx, command, args...), nil + } + + password := node.SSH.Password + if password == "" { + envName := strings.TrimSpace(node.SSH.PasswordEnv) + if envName == "" { + return nil, fmt.Errorf("ssh password auth for %s requires ssh.password_env", sshTarget(node)) + } + password = os.Getenv(envName) + if password == "" { + return nil, fmt.Errorf("ssh password env %s is empty", envName) + } + } + + wrappedArgs := append([]string{"-p", password, command}, args...) + cmd := exec.CommandContext(ctx, "sshpass", wrappedArgs...) + return cmd, nil +} diff --git a/internal/control/ssh_test.go b/internal/control/ssh_test.go new file mode 100644 index 0000000..8b7afd0 --- /dev/null +++ b/internal/control/ssh_test.go @@ -0,0 +1,80 @@ +package control + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func TestValidateNodeSSHPasswordAuth(t *testing.T) { + t.Parallel() + + node := Node{ + ID: "pw-01", + Name: "Password Node", + Provider: "custom-vps", + Region: "nl", + Host: "203.0.113.20", + Enabled: true, + SSH: SSHConfig{ + User: "root", + Port: 22, + Auth: "password", + PasswordEnv: "VPNEM_TEST_PASSWORD", + }, + Protocols: []ProtocolProfile{ + {Type: "socks5", Enabled: true, Port: 1080}, + }, + } + + if err := ValidateNode(node); err != nil { + t.Fatalf("ValidateNode() error = %v", err) + } +} + +func TestWrapWithPasswordUsesSSHPass(t *testing.T) { + t.Setenv("VPNEM_TEST_PASSWORD", "secret") + node := Node{ + ID: "pw-01", + Name: "Password Node", + SSH: SSHConfig{ + User: "root", + Port: 22, + Auth: "password", + PasswordEnv: "VPNEM_TEST_PASSWORD", + }, + } + + cmd, err := wrapWithPassword(context.Background(), node, "ssh", "-V") + if err != nil { + t.Fatalf("wrapWithPassword() error = %v", err) + } + if got := filepath.Base(cmd.Path); got != "sshpass" { + t.Fatalf("filepath.Base(cmd.Path) = %q, want sshpass", got) + } + if len(cmd.Args) < 4 { + t.Fatalf("cmd.Args too short: %#v", cmd.Args) + } + if cmd.Args[1] != "-p" || cmd.Args[2] != "secret" || cmd.Args[3] != "ssh" { + t.Fatalf("unexpected cmd.Args: %#v", cmd.Args) + } +} + +func TestWrapWithPasswordRequiresEnv(t *testing.T) { + _ = os.Unsetenv("VPNEM_TEST_PASSWORD_MISSING") + node := Node{ + ID: "pw-01", + Name: "Password Node", + SSH: SSHConfig{ + User: "root", + Port: 22, + Auth: "password", + PasswordEnv: "VPNEM_TEST_PASSWORD_MISSING", + }, + } + + if _, err := wrapWithPassword(context.Background(), node, "ssh", "-V"); err == nil { + t.Fatal("expected error for missing password env") + } +} diff --git a/internal/control/state.go b/internal/control/state.go new file mode 100644 index 0000000..7fc7827 --- /dev/null +++ b/internal/control/state.go @@ -0,0 +1,71 @@ +package control + +import ( + "encoding/json" + "errors" + "os" + "path/filepath" + "sort" + "time" +) + +type NodeState struct { + NodeID string `json:"node_id"` + BootstrapStatus string `json:"bootstrap_status"` + LastBootstrapAt *time.Time `json:"last_bootstrap_at,omitempty"` + LastHealthCheckAt *time.Time `json:"last_health_check_at,omitempty"` + LastDNSSyncAt *time.Time `json:"last_dns_sync_at,omitempty"` + PublicHost string `json:"public_host,omitempty"` + Services []ServiceStatus `json:"services,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +type ServiceStatus struct { + Type string `json:"type"` + Status string `json:"status"` + Port int `json:"port"` +} + +func LoadNodeState(dir, nodeID string) (*NodeState, error) { + data, err := os.ReadFile(filepath.Join(dir, nodeID+".json")) + if err != nil { + return nil, err + } + + var state NodeState + if err := json.Unmarshal(data, &state); err != nil { + return nil, err + } + return &state, nil +} + +func SaveNodeState(dir string, state NodeState) error { + if err := os.MkdirAll(dir, 0o755); err != nil { + return err + } + + sort.Slice(state.Services, func(i, j int) bool { + return state.Services[i].Type < state.Services[j].Type + }) + + data, err := json.MarshalIndent(state, "", " ") + if err != nil { + return err + } + data = append(data, '\n') + + tmpPath := filepath.Join(dir, state.NodeID+".json.tmp") + finalPath := filepath.Join(dir, state.NodeID+".json") + if err := os.WriteFile(tmpPath, data, 0o600); err != nil { + return err + } + return os.Rename(tmpPath, finalPath) +} + +func DeleteNodeState(dir, nodeID string) error { + err := os.Remove(filepath.Join(dir, nodeID+".json")) + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err +} diff --git a/internal/control/upgrade_test.go b/internal/control/upgrade_test.go new file mode 100644 index 0000000..c3e5404 --- /dev/null +++ b/internal/control/upgrade_test.go @@ -0,0 +1,55 @@ +package control + +import ( + "context" + "strings" + "testing" +) + +type fakeRunner struct{} + +func (fakeRunner) Run(ctx context.Context, node Node, script string) (*CommandResult, error) { + if strings.Contains(script, "HEALTHZ_HTTP_CODE=") { + return &CommandResult{ + Stdout: "{\"Service\":\"sing-box\",\"Status\":\"running\"}\nHEALTHZ_HTTP_CODE=200\n", + }, nil + } + return &CommandResult{Stdout: "ok\n"}, nil +} + +func (fakeRunner) Check(ctx context.Context, node Node) (*CommandResult, error) { + return &CommandResult{Stdout: "ok"}, nil +} + +func (fakeRunner) CopyFile(ctx context.Context, node Node, localPath, remotePath string) error { + return nil +} + +func TestUpgradeNode(t *testing.T) { + t.Parallel() + + state, err := UpgradeNode(context.Background(), fakeRunner{}, Node{ + ID: "nl-01", + Name: "NL 01", + Region: "nl", + Host: "203.0.113.10", + Domain: "nl-01.example.com", + Enabled: true, + SSH: SSHConfig{User: "root", Port: 22, Auth: "key", IdentityFile: "~/.ssh/id_ed25519"}, + Protocols: []ProtocolProfile{ + {Type: "vless", Enabled: true, Port: 443, TLS: &TLSProfile{Enabled: true, ServerName: "nl-01.example.com"}, Auth: &AuthProfile{UUID: "11111111-1111-1111-1111-111111111111"}, Extra: map[string]any{"path": "/ws"}}, + }, + }, t.TempDir()) + if err != nil { + t.Fatalf("UpgradeNode() error = %v", err) + } + if state == nil { + t.Fatal("expected state") + } + if state.BootstrapStatus != "healthy" { + t.Fatalf("BootstrapStatus = %q, want healthy", state.BootstrapStatus) + } + if got := state.Metadata["lifecycle_action"]; got != "upgrade" { + t.Fatalf("lifecycle_action = %v, want upgrade", got) + } +} |
