package control

import (
	"errors"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"gopkg.in/yaml.v3"
)

// Inventory is a collection of nodes, as produced by LoadInventoryDir.
type Inventory struct {
	Nodes []Node
}

// LoadInventoryDir reads every regular *.yaml / *.yml entry (extension matched
// case-insensitively) directly inside dir, loads each one with LoadNodeFile,
// and returns the nodes sorted by ID. Subdirectories are skipped. The first
// file that fails to load aborts the whole operation.
func LoadInventoryDir(dir string) (*Inventory, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}

	nodes := make([]Node, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		ext := strings.ToLower(filepath.Ext(entry.Name()))
		if ext != ".yaml" && ext != ".yml" {
			continue
		}
		node, err := LoadNodeFile(filepath.Join(dir, entry.Name()))
		if err != nil {
			return nil, err
		}
		nodes = append(nodes, *node)
	}

	// os.ReadDir yields entries in filename order; a stable sort keeps that
	// order for nodes that share an ID, so the result (and therefore what
	// NodeByID returns) is deterministic across runs.
	sort.SliceStable(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID })
	return &Inventory{Nodes: nodes}, nil
}

// LoadNodeFile reads the YAML file at path, decodes it into a Node, runs
// ensureRealityProfile and ensureHysteria2Profile on every protocol entry, and
// validates the result with ValidateNode. Parse, prepare and validate errors
// are wrapped with the file path.
func LoadNodeFile(path string) (*Node, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	var node Node
	if err := yaml.Unmarshal(data, &node); err != nil {
		return nil, fmt.Errorf("parse %s: %w", path, err)
	}

	for idx := range node.Protocols {
		if err := ensureRealityProfile(&node.Protocols[idx]); err != nil {
			return nil, fmt.Errorf("prepare %s: %w", path, err)
		}
		if err := ensureHysteria2Profile(&node.Protocols[idx]); err != nil {
			return nil, fmt.Errorf("prepare %s: %w", path, err)
		}
	}

	if err := ValidateNode(node); err != nil {
		return nil, fmt.Errorf("validate %s: %w", path, err)
	}
	return &node, nil
}

// SaveNodeFile prepares and validates node, then writes it as YAML to
// <dir>/<node.ID>.yaml (mode 0600), creating dir (mode 0755) if necessary.
// It returns the path that was written.
//
// The ensure* helpers receive pointers into node.Protocols, so any changes
// they make are applied to the slice elements shared with the caller.
//
// An ID containing a path separator is rejected before anything is created on
// disk, so the file can never land outside dir.
func SaveNodeFile(dir string, node Node) (string, error) {
	for idx := range node.Protocols {
		if err := ensureRealityProfile(&node.Protocols[idx]); err != nil {
			return "", err
		}
		if err := ensureHysteria2Profile(&node.Protocols[idx]); err != nil {
			return "", err
		}
	}
	if err := ValidateNode(node); err != nil {
		return "", err
	}

	path, err := safeNodeFilePath(dir, node.ID)
	if err != nil {
		return "", err
	}
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return "", err
	}

	data, err := yaml.Marshal(node)
	if err != nil {
		return "", err
	}
	if err := os.WriteFile(path, data, 0o600); err != nil {
		return "", err
	}
	return path, nil
}

// NodeByID returns a pointer to the first node whose ID equals id, and true.
// The pointer refers to the element inside i.Nodes, not a copy. It returns
// (nil, false) when no node matches or when the receiver is nil.
func (i *Inventory) NodeByID(id string) (*Node, bool) {
	if i == nil {
		return nil, false
	}
	for idx := range i.Nodes {
		if i.Nodes[idx].ID == id {
			return &i.Nodes[idx], true
		}
	}
	return nil, false
}

// ValidateNode checks that node carries the fields required to manage it:
// identity (id, name, region, host), SSH settings, and at least one protocol
// with a valid port and the per-type settings that type requires. Each
// protocol type may appear only once. It returns the first problem found, or
// nil.
func ValidateNode(node Node) error {
	if strings.TrimSpace(node.ID) == "" {
		return errors.New("id is required")
	}
	if strings.TrimSpace(node.Name) == "" {
		return errors.New("name is required")
	}
	if strings.TrimSpace(node.Region) == "" {
		return errors.New("region is required")
	}
	if strings.TrimSpace(node.Host) == "" {
		return errors.New("host is required")
	}

	if node.SSH.Port < 0 || node.SSH.Port > 65535 {
		return errors.New("ssh.port must be between 0 and 65535")
	}
	if node.SSH.Port == 0 {
		// NOTE(review): node is received by value, so unless Node.SSH is a
		// pointer field this default never reaches the caller. Confirm where
		// the port-22 default is meant to be applied.
		node.SSH.Port = 22
	}
	if strings.TrimSpace(node.SSH.User) == "" {
		return errors.New("ssh.user is required")
	}
	if strings.TrimSpace(node.SSH.Auth) == "" {
		return errors.New("ssh.auth is required")
	}
	switch strings.TrimSpace(node.SSH.Auth) {
	case "key":
		if strings.TrimSpace(node.SSH.IdentityFile) == "" {
			return errors.New("ssh.identity_file is required when ssh.auth=key")
		}
	case "password":
		if strings.TrimSpace(node.SSH.PasswordEnv) == "" {
			return errors.New("ssh.password_env is required when ssh.auth=password")
		}
	default:
		return errors.New("ssh.auth must be either key or password")
	}

	if len(node.Protocols) == 0 {
		return errors.New("at least one protocol is required")
	}

	seen := make(map[string]struct{}, len(node.Protocols))
	for _, protocol := range node.Protocols {
		if strings.TrimSpace(protocol.Type) == "" {
			return errors.New("protocol.type is required")
		}
		if protocol.Port <= 0 || protocol.Port > 65535 {
			return fmt.Errorf("protocol %s has invalid port", protocol.Type)
		}

		// Type-specific requirements. The cases are mutually exclusive, so
		// the switch is equivalent to checking each type in turn.
		switch protocol.Type {
		case "vless":
			if protocol.TLS != nil && protocol.TLS.Enabled && strings.TrimSpace(node.Domain) == "" {
				return errors.New("vless with tls.enabled requires node.domain")
			}
		case "vless-reality":
			if protocol.Auth == nil || strings.TrimSpace(protocol.Auth.UUID) == "" {
				return errors.New("vless-reality requires auth.uuid")
			}
			if protocol.Reality == nil {
				return errors.New("vless-reality requires reality settings")
			}
			if strings.TrimSpace(protocol.Reality.ServerName) == "" {
				return errors.New("vless-reality requires reality.server_name")
			}
			if strings.TrimSpace(protocol.Reality.PrivateKey) == "" {
				return errors.New("vless-reality requires reality.private_key")
			}
			if strings.TrimSpace(protocol.Reality.PublicKey) == "" {
				return errors.New("vless-reality requires reality.public_key")
			}
			if strings.TrimSpace(protocol.Reality.ShortID) == "" {
				return errors.New("vless-reality requires reality.short_id")
			}
		case "vmess":
			if protocol.TLS != nil && protocol.TLS.Enabled && strings.TrimSpace(node.Domain) == "" {
				return errors.New("vmess with tls.enabled requires node.domain")
			}
		case "hysteria2":
			if protocol.Auth == nil || strings.TrimSpace(protocol.Auth.Password) == "" {
				return errors.New("hysteria2 requires auth.password")
			}
			if protocol.Hysteria2 == nil {
				return errors.New("hysteria2 requires hysteria2 settings")
			}
			if protocol.Hysteria2.CertPath == "" || protocol.Hysteria2.KeyPath == "" {
				return errors.New("hysteria2 requires cert_path and key_path")
			}
		}

		if _, ok := seen[protocol.Type]; ok {
			return fmt.Errorf("duplicate protocol %s", protocol.Type)
		}
		seen[protocol.Type] = struct{}{}
	}
	return nil
}

// CopyNodeFile loads and validates the node file at srcPath and saves it into
// inventoryDir under <node.ID>.yaml, returning the destination path.
func CopyNodeFile(srcPath, inventoryDir string) (string, error) {
	node, err := LoadNodeFile(srcPath)
	if err != nil {
		return "", err
	}
	return SaveNodeFile(inventoryDir, *node)
}

// DeleteNodeFile removes <dir>/<nodeID>.yaml. A file that does not exist is
// not an error. A blank nodeID, or one containing a path separator, is
// rejected so the call can never remove a file outside dir.
func DeleteNodeFile(dir, nodeID string) error {
	path, err := safeNodeFilePath(dir, nodeID)
	if err != nil {
		return err
	}
	if err := os.Remove(path); err != nil && !errors.Is(err, fs.ErrNotExist) {
		return err
	}
	return nil
}

// IsNotExist reports whether err is, or wraps, fs.ErrNotExist.
func IsNotExist(err error) bool {
	return errors.Is(err, fs.ErrNotExist)
}

// safeNodeFilePath returns the inventory file path for the node with the given
// ID inside dir. It refuses blank IDs and IDs that would not resolve to a
// single file name directly under dir (for example "../x" or "a/b").
func safeNodeFilePath(dir, id string) (string, error) {
	if strings.TrimSpace(id) == "" {
		return "", errors.New("id is required")
	}
	name := id + ".yaml"
	// Reject both slash styles regardless of host OS so an inventory stays
	// portable; filepath.Base additionally catches OS-specific forms such as
	// Windows volume prefixes.
	if strings.ContainsAny(id, `/\`) || filepath.Base(name) != name {
		return "", fmt.Errorf("invalid node id %q: must not contain path separators", id)
	}
	return filepath.Join(dir, name), nil
}