package sync import ( "crypto/sha256" "encoding/hex" "fmt" "io" "log" "net/http" "os" "path/filepath" "runtime" "time" ) // Updater checks for and downloads client updates. type Updater struct { fetcher *Fetcher currentVer string dataDir string } // NewUpdater creates an updater. func NewUpdater(fetcher *Fetcher, currentVersion, dataDir string) *Updater { return &Updater{ fetcher: fetcher, currentVer: currentVersion, dataDir: dataDir, } } // UpdateInfo describes an available update. type UpdateInfo struct { Available bool `json:"available"` Version string `json:"version"` Changelog string `json:"changelog"` CurrentVer string `json:"current_version"` } // Check returns info about available updates. func (u *Updater) Check() (*UpdateInfo, error) { ver, err := u.fetcher.FetchVersion() if err != nil { return nil, fmt.Errorf("check update: %w", err) } return &UpdateInfo{ Available: ver.Version != u.currentVer, Version: ver.Version, Changelog: ver.Changelog, CurrentVer: u.currentVer, }, nil } // Download fetches the new binary, verifies checksum, and prepares for restart. // On success it returns "restart_pending" and the caller should exit gracefully. func (u *Updater) Download() (string, error) { ver, err := u.fetcher.FetchVersion() if err != nil { return "", fmt.Errorf("fetch version: %w", err) } if ver.URL == "" { suffix := "linux-amd64" if runtime.GOOS == "windows" { suffix = "windows-amd64.exe" } ver.URL = u.fetcher.baseURL + "/releases/vpnem-" + suffix } client := &http.Client{Timeout: 5 * time.Minute} resp, err := client.Get(ver.URL) if err != nil { return "", fmt.Errorf("download: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return "", fmt.Errorf("download: HTTP %d", resp.StatusCode) } ext := "" if runtime.GOOS == "windows" { ext = ".exe" } // Download to temp file first downloadPath := filepath.Join(u.dataDir, "vpnem-new"+ext) f, err := os.Create(downloadPath) if err != nil { return "", fmt.Errorf("create file: %w", err) } // Track SHA256 while downloading hasher := sha256.New() written, err := io.Copy(io.MultiWriter(f, hasher), resp.Body) f.Close() if err != nil { os.Remove(downloadPath) return "", fmt.Errorf("write update: %w", err) } // Verify SHA256 checksum if provided if ver.SHA256 != "" { gotHash := hex.EncodeToString(hasher.Sum(nil)) if gotHash != ver.SHA256 { os.Remove(downloadPath) return "", fmt.Errorf("checksum mismatch: expected %s, got %s (%.1f MB)", ver.SHA256, gotHash, float64(written)/1024/1024) } log.Printf("update: sha256 verified (%.1f MB)", float64(written)/1024/1024) } // Clean stale configs so new version starts fresh os.Remove(filepath.Join(u.dataDir, "state.json")) os.Remove(filepath.Join(u.dataDir, "config.json")) os.Remove(filepath.Join(u.dataDir, "cache.db")) currentBin, _ := os.Executable() if currentBin == "" { return "", fmt.Errorf("could not determine current binary") } if runtime.GOOS == "windows" { // Windows: can't overwrite running exe // Strategy: rename old to .old, copy new in place // If .old already exists from a previous failed update, remove it oldBin := currentBin + ".old" os.Remove(oldBin) if err := os.Rename(currentBin, oldBin); err != nil { return "", fmt.Errorf("rename old binary: %w", err) } if err := copyFile(downloadPath, currentBin); err != nil { // Restore old binary os.Remove(currentBin) os.Rename(oldBin, currentBin) os.Remove(downloadPath) return "", fmt.Errorf("copy new binary: %w", err) } os.Remove(downloadPath) log.Printf("update: binary replaced, version %s", ver.Version) return "restart_pending", nil } // Linux: overwrite in place if err := copyFile(downloadPath, currentBin); err != nil { os.Remove(downloadPath) return "", fmt.Errorf("copy new binary: %w", err) } os.Remove(downloadPath) log.Printf("update: binary replaced, version %s", ver.Version) return "restart_pending", nil } func copyFile(src, dst string) error { in, err := os.Open(src) if err != nil { return err } defer in.Close() out, err := os.Create(dst) if err != nil { return err } defer out.Close() _, err = io.Copy(out, in) if err != nil { return err } // Preserve executable permission on Linux if runtime.GOOS != "windows" { os.Chmod(dst, 0o755) } return nil }