summaryrefslogtreecommitdiff
path: root/internal/control/inventory.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/control/inventory.go')
-rw-r--r--internal/control/inventory.go225
1 files changed, 225 insertions, 0 deletions
diff --git a/internal/control/inventory.go b/internal/control/inventory.go
new file mode 100644
index 0000000..22d9179
--- /dev/null
+++ b/internal/control/inventory.go
@@ -0,0 +1,225 @@
+package control
+
+import (
+ "errors"
+ "fmt"
+ "io/fs"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+
+ "gopkg.in/yaml.v3"
+)
+
+type Inventory struct {
+ Nodes []Node
+}
+
+func LoadInventoryDir(dir string) (*Inventory, error) {
+ entries, err := os.ReadDir(dir)
+ if err != nil {
+ return nil, err
+ }
+
+ nodes := make([]Node, 0, len(entries))
+ for _, entry := range entries {
+ if entry.IsDir() {
+ continue
+ }
+ ext := strings.ToLower(filepath.Ext(entry.Name()))
+ if ext != ".yaml" && ext != ".yml" {
+ continue
+ }
+
+ node, err := LoadNodeFile(filepath.Join(dir, entry.Name()))
+ if err != nil {
+ return nil, err
+ }
+ nodes = append(nodes, *node)
+ }
+
+ sort.Slice(nodes, func(i, j int) bool {
+ return nodes[i].ID < nodes[j].ID
+ })
+
+ return &Inventory{Nodes: nodes}, nil
+}
+
+func LoadNodeFile(path string) (*Node, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+
+ var node Node
+ if err := yaml.Unmarshal(data, &node); err != nil {
+ return nil, fmt.Errorf("parse %s: %w", path, err)
+ }
+ for idx := range node.Protocols {
+ if err := ensureRealityProfile(&node.Protocols[idx]); err != nil {
+ return nil, fmt.Errorf("prepare %s: %w", path, err)
+ }
+ if err := ensureHysteria2Profile(&node.Protocols[idx]); err != nil {
+ return nil, fmt.Errorf("prepare %s: %w", path, err)
+ }
+ }
+ if err := ValidateNode(node); err != nil {
+ return nil, fmt.Errorf("validate %s: %w", path, err)
+ }
+
+ return &node, nil
+}
+
+func SaveNodeFile(dir string, node Node) (string, error) {
+ for idx := range node.Protocols {
+ if err := ensureRealityProfile(&node.Protocols[idx]); err != nil {
+ return "", err
+ }
+ if err := ensureHysteria2Profile(&node.Protocols[idx]); err != nil {
+ return "", err
+ }
+ }
+ if err := ValidateNode(node); err != nil {
+ return "", err
+ }
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return "", err
+ }
+
+ path := filepath.Join(dir, node.ID+".yaml")
+ data, err := yaml.Marshal(node)
+ if err != nil {
+ return "", err
+ }
+ if err := os.WriteFile(path, data, 0o600); err != nil {
+ return "", err
+ }
+
+ return path, nil
+}
+
+func (i *Inventory) NodeByID(id string) (*Node, bool) {
+ for idx := range i.Nodes {
+ if i.Nodes[idx].ID == id {
+ return &i.Nodes[idx], true
+ }
+ }
+ return nil, false
+}
+
+func ValidateNode(node Node) error {
+ if strings.TrimSpace(node.ID) == "" {
+ return errors.New("id is required")
+ }
+ if strings.TrimSpace(node.Name) == "" {
+ return errors.New("name is required")
+ }
+ if strings.TrimSpace(node.Region) == "" {
+ return errors.New("region is required")
+ }
+ if strings.TrimSpace(node.Host) == "" {
+ return errors.New("host is required")
+ }
+ if node.SSH.Port < 0 || node.SSH.Port > 65535 {
+ return errors.New("ssh.port must be between 0 and 65535")
+ }
+ if node.SSH.Port == 0 {
+ node.SSH.Port = 22
+ }
+ if strings.TrimSpace(node.SSH.User) == "" {
+ return errors.New("ssh.user is required")
+ }
+ if strings.TrimSpace(node.SSH.Auth) == "" {
+ return errors.New("ssh.auth is required")
+ }
+ switch strings.TrimSpace(node.SSH.Auth) {
+ case "key":
+ if strings.TrimSpace(node.SSH.IdentityFile) == "" {
+ return errors.New("ssh.identity_file is required when ssh.auth=key")
+ }
+ case "password":
+ if strings.TrimSpace(node.SSH.PasswordEnv) == "" {
+ return errors.New("ssh.password_env is required when ssh.auth=password")
+ }
+ default:
+ return errors.New("ssh.auth must be either key or password")
+ }
+ if len(node.Protocols) == 0 {
+ return errors.New("at least one protocol is required")
+ }
+
+ seen := make(map[string]struct{}, len(node.Protocols))
+ for _, protocol := range node.Protocols {
+ if strings.TrimSpace(protocol.Type) == "" {
+ return errors.New("protocol.type is required")
+ }
+ if protocol.Port <= 0 || protocol.Port > 65535 {
+ return fmt.Errorf("protocol %s has invalid port", protocol.Type)
+ }
+ if protocol.Type == "vless" && protocol.TLS != nil && protocol.TLS.Enabled && strings.TrimSpace(node.Domain) == "" {
+ return errors.New("vless with tls.enabled requires node.domain")
+ }
+ if protocol.Type == "vless-reality" {
+ if protocol.Auth == nil || strings.TrimSpace(protocol.Auth.UUID) == "" {
+ return errors.New("vless-reality requires auth.uuid")
+ }
+ if protocol.Reality == nil {
+ return errors.New("vless-reality requires reality settings")
+ }
+ if strings.TrimSpace(protocol.Reality.ServerName) == "" {
+ return errors.New("vless-reality requires reality.server_name")
+ }
+ if strings.TrimSpace(protocol.Reality.PrivateKey) == "" {
+ return errors.New("vless-reality requires reality.private_key")
+ }
+ if strings.TrimSpace(protocol.Reality.PublicKey) == "" {
+ return errors.New("vless-reality requires reality.public_key")
+ }
+ if strings.TrimSpace(protocol.Reality.ShortID) == "" {
+ return errors.New("vless-reality requires reality.short_id")
+ }
+ }
+ if protocol.Type == "vmess" && protocol.TLS != nil && protocol.TLS.Enabled && strings.TrimSpace(node.Domain) == "" {
+ return errors.New("vmess with tls.enabled requires node.domain")
+ }
+ if protocol.Type == "hysteria2" {
+ if protocol.Auth == nil || strings.TrimSpace(protocol.Auth.Password) == "" {
+ return errors.New("hysteria2 requires auth.password")
+ }
+ if protocol.Hysteria2 == nil {
+ return errors.New("hysteria2 requires hysteria2 settings")
+ }
+ if protocol.Hysteria2.CertPath == "" || protocol.Hysteria2.KeyPath == "" {
+ return errors.New("hysteria2 requires cert_path and key_path")
+ }
+ }
+ key := protocol.Type
+ if _, ok := seen[key]; ok {
+ return fmt.Errorf("duplicate protocol %s", protocol.Type)
+ }
+ seen[key] = struct{}{}
+ }
+
+ return nil
+}
+
+func CopyNodeFile(srcPath, inventoryDir string) (string, error) {
+ node, err := LoadNodeFile(srcPath)
+ if err != nil {
+ return "", err
+ }
+ return SaveNodeFile(inventoryDir, *node)
+}
+
+func DeleteNodeFile(dir, nodeID string) error {
+ err := os.Remove(filepath.Join(dir, nodeID+".yaml"))
+ if errors.Is(err, fs.ErrNotExist) {
+ return nil
+ }
+ return err
+}
+
+func IsNotExist(err error) bool {
+ return errors.Is(err, fs.ErrNotExist)
+}