summaryrefslogtreecommitdiff
path: root/internal/control/hysteria2.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/control/hysteria2.go')
-rw-r--r--internal/control/hysteria2.go179
1 files changed, 179 insertions, 0 deletions
diff --git a/internal/control/hysteria2.go b/internal/control/hysteria2.go
new file mode 100644
index 0000000..78c80e7
--- /dev/null
+++ b/internal/control/hysteria2.go
@@ -0,0 +1,179 @@
+package control
+
+import (
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/pem"
+ "fmt"
+ "math/big"
+ "strings"
+ "time"
+)
+
+const (
+ defaultHysteria2Port = 443
+ defaultHysteria2UpMbps = 100
+ defaultHysteria2DownMbps = 100
+ defaultHysteria2CertPath = "/etc/sing-box/cert.pem"
+ defaultHysteria2KeyPath = "/etc/sing-box/key.pem"
+ defaultHysteria2ALPN = "h3"
+)
+
+func ensureHysteria2Profile(protocol *ProtocolProfile) error {
+ if protocol == nil || protocol.Type != "hysteria2" {
+ return nil
+ }
+ if protocol.Hysteria2 == nil {
+ protocol.Hysteria2 = &Hysteria2Profile{}
+ }
+
+ if protocol.Port > 0 && protocol.Hysteria2.Port == 0 {
+ protocol.Hysteria2.Port = protocol.Port
+ }
+ if protocol.Hysteria2.Port == 0 {
+ protocol.Hysteria2.Port = defaultHysteria2Port
+ }
+ protocol.Port = protocol.Hysteria2.Port
+
+ if protocol.Auth == nil {
+ protocol.Auth = &AuthProfile{}
+ }
+ if protocol.Hysteria2.UserPassword == "" && strings.TrimSpace(protocol.Auth.Password) != "" {
+ protocol.Hysteria2.UserPassword = strings.TrimSpace(protocol.Auth.Password)
+ }
+ if protocol.Hysteria2.UserPassword == "" {
+ password, err := randomBase64(16)
+ if err != nil {
+ return err
+ }
+ protocol.Hysteria2.UserPassword = password
+ }
+ protocol.Auth.Password = protocol.Hysteria2.UserPassword
+
+ if protocol.Extra == nil {
+ protocol.Extra = map[string]any{}
+ }
+ if protocol.Hysteria2.ObfsPassword == "" {
+ if extra := stringFromExtra(protocol.Extra, "obfs_password"); extra != "" {
+ protocol.Hysteria2.ObfsPassword = extra
+ }
+ }
+ if protocol.Hysteria2.ObfsPassword == "" {
+ obfsPassword, err := randomHex(32)
+ if err != nil {
+ return err
+ }
+ protocol.Hysteria2.ObfsPassword = obfsPassword
+ }
+ protocol.Extra["obfs_password"] = protocol.Hysteria2.ObfsPassword
+
+ if protocol.Hysteria2.UpMbps == 0 {
+ protocol.Hysteria2.UpMbps = intFromExtra(protocol.Extra, "up_mbps", defaultHysteria2UpMbps)
+ }
+ if protocol.Hysteria2.DownMbps == 0 {
+ protocol.Hysteria2.DownMbps = intFromExtra(protocol.Extra, "down_mbps", defaultHysteria2DownMbps)
+ }
+ protocol.Extra["up_mbps"] = protocol.Hysteria2.UpMbps
+ protocol.Extra["down_mbps"] = protocol.Hysteria2.DownMbps
+
+ if protocol.Hysteria2.CertPath == "" {
+ protocol.Hysteria2.CertPath = firstNonEmptyString(stringFromExtra(protocol.Extra, "tls_cert_path"), defaultHysteria2CertPath)
+ }
+ if protocol.Hysteria2.KeyPath == "" {
+ protocol.Hysteria2.KeyPath = firstNonEmptyString(stringFromExtra(protocol.Extra, "tls_key_path"), defaultHysteria2KeyPath)
+ }
+ protocol.Extra["tls_cert_path"] = protocol.Hysteria2.CertPath
+ protocol.Extra["tls_key_path"] = protocol.Hysteria2.KeyPath
+
+ if protocol.TLS == nil {
+ protocol.TLS = &TLSProfile{}
+ }
+ protocol.TLS.Enabled = true
+ if strings.TrimSpace(protocol.TLS.ServerName) == "" {
+ protocol.TLS.ServerName = ""
+ }
+ return nil
+}
+
+func GenerateSelfSignedCert() (certPEM, keyPEM []byte, err error) {
+ commonName := "node-" + randomHostnameSuffix() + ".local"
+ return generateSelfSignedCertForHost(commonName)
+}
+
+func generateSelfSignedCertForHost(commonName string) (certPEM, keyPEM []byte, err error) {
+ privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
+ if err != nil {
+ return nil, nil, fmt.Errorf("generate ecdsa key: %w", err)
+ }
+ serialLimit := new(big.Int).Lsh(big.NewInt(1), 128)
+ serialNumber, err := rand.Int(rand.Reader, serialLimit)
+ if err != nil {
+ return nil, nil, fmt.Errorf("generate serial: %w", err)
+ }
+
+ template := &x509.Certificate{
+ SerialNumber: serialNumber,
+ Subject: pkix.Name{
+ CommonName: commonName,
+ },
+ NotBefore: time.Now().UTC().Add(-time.Hour),
+ NotAfter: time.Now().UTC().AddDate(10, 0, 0),
+ KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
+ ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+ BasicConstraintsValid: true,
+ DNSNames: []string{commonName},
+ }
+
+ certDER, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("create certificate: %w", err)
+ }
+ certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
+
+ keyDER, err := x509.MarshalECPrivateKey(privateKey)
+ if err != nil {
+ return nil, nil, fmt.Errorf("marshal private key: %w", err)
+ }
+ keyPEM = pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
+ return certPEM, keyPEM, nil
+}
+
+func randomBase64(size int) (string, error) {
+ buf := make([]byte, size)
+ if _, err := rand.Read(buf); err != nil {
+ return "", err
+ }
+ return base64.RawStdEncoding.EncodeToString(buf), nil
+}
+
+func randomHostnameSuffix() string {
+ buf := make([]byte, 4)
+ if _, err := rand.Read(buf); err != nil {
+ return "local"
+ }
+ return hex.EncodeToString(buf)
+}
+
+func firstNonEmptyString(values ...string) string {
+ for _, value := range values {
+ if strings.TrimSpace(value) != "" {
+ return strings.TrimSpace(value)
+ }
+ }
+ return ""
+}
+
+func EnsureProtocolForUI(protocol *ProtocolProfile) error {
+ if err := ensureRealityProfile(protocol); err != nil {
+ return err
+ }
+ if err := ensureHysteria2Profile(protocol); err != nil {
+ return err
+ }
+ return nil
+}