From e5f059fbe9fbc6ac4872387517047dd5e47a16dc Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 12 Sep 2024 23:05:35 -0500 Subject: [PATCH 01/17] Add in v2 certificate --- Makefile | 2 +- cert/asn1.go | 52 ++++ cert/cert.go | 63 ++++- cert/cert_test.go | 4 +- cert/cert_v1.go | 312 +++++++++++------------- cert/cert_v2.asn1 | 38 +++ cert/cert_v2.go | 605 ++++++++++++++++++++++++++++++++++++++++++++++ cert/errors.go | 3 + cert/pem.go | 19 +- cert/sign.go | 80 +++++- handshake_ix.go | 7 +- nebula.pb.go | 127 ++++++---- nebula.proto | 1 + outside.go | 26 -- 14 files changed, 1072 insertions(+), 267 deletions(-) create mode 100644 cert/asn1.go create mode 100644 cert/cert_v2.asn1 create mode 100644 cert/cert_v2.go diff --git a/Makefile b/Makefile index 6922cc3e8..d3fbcaa57 100644 --- a/Makefile +++ b/Makefile @@ -196,7 +196,7 @@ bench-cpu-long: go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof go tool pprof go-audit.test cpu.pprof -proto: nebula.pb.go cert/cert.pb.go +proto: nebula.pb.go cert/cert_v1.pb.go nebula.pb.go: nebula.proto .FORCE go build github.com/gogo/protobuf/protoc-gen-gogofaster diff --git a/cert/asn1.go b/cert/asn1.go new file mode 100644 index 000000000..6bf6a8ded --- /dev/null +++ b/cert/asn1.go @@ -0,0 +1,52 @@ +package cert + +import ( + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value +// https://github.com/golang/go/issues/64811#issuecomment-1944446920 +func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool { + var present bool + var child cryptobyte.String + if !b.ReadOptionalASN1(&child, &present, tag) { + return false + } + + if !present { + *out = defaultValue + return true + } + + // Ensure we have 1 byte + if len(child) == 1 { + *out = child[0] > 0 + return true + } + + return false +} + +// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value +// Similar issue as with readOptionalASN1Boolean +func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool { + var present bool + var child cryptobyte.String + if !b.ReadOptionalASN1(&child, &present, tag) { + return false + } + + if !present { + *out = defaultValue + return true + } + + // Ensure we have 1 byte + if len(child) == 1 { + *out = child[0] + return true + } + + return false +} diff --git a/cert/cert.go b/cert/cert.go index 4e41c4349..5d5abd646 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -1,15 +1,17 @@ package cert import ( + "fmt" "net/netip" "time" ) -type Version int +type Version uint32 const ( - Version1 Version = 1 - Version2 Version = 2 + VersionPre1 Version = 0 + Version1 Version = 1 + Version2 Version = 2 ) type Certificate interface { @@ -109,21 +111,62 @@ type CachedCertificate struct { // UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate. func UnmarshalCertificate(b []byte) (Certificate, error) { - c, err := unmarshalCertificateV1(b, true) - if err != nil { - return nil, err + //TODO: you left off here, no one uses this function but it might be beneficial to export _something_ that someone can use, maybe the Versioned unmarshallsers? + var c Certificate + c, err := unmarshalCertificateV2(b, nil) + if err == nil { + return c, nil } - return c, nil + + c, err = unmarshalCertificateV1(b, nil) + if err == nil { + return c, nil + } + + return nil, fmt.Errorf("could not unmarshal certificate") } // UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake. // Handshakes save space by placing the peers public key in a different part of the packet, we have to // reassemble the actual certificate structure with that in mind. -func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) { - c, err := unmarshalCertificateV1(b, false) +func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (Certificate, error) { + var c Certificate + var err error + + switch v { + case VersionPre1, Version1: + c, err = unmarshalCertificateV1(b, publicKey) + case Version2: + c, err = unmarshalCertificateV2(b, publicKey) + default: + //TODO: make a static var + return nil, fmt.Errorf("unknown certificate version %d", v) + } + if err != nil { return nil, err } - c.details.PublicKey = publicKey return c, nil } + +func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAPool) (*CachedCertificate, error) { + if publicKey == nil { + return nil, ErrNoPeerStaticKey + } + + if rawCertBytes == nil { + return nil, ErrNoPayload + } + + c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey) + if err != nil { + return nil, fmt.Errorf("error unmarshaling cert: %w", err) + } + + cc, err := caPool.VerifyCertificate(time.Now(), c) + if err != nil { + return nil, fmt.Errorf("certificate validation failed: %w", err) + } + + return cc, nil +} diff --git a/cert/cert_test.go b/cert/cert_test.go index c9bb3f32f..b2ea406cb 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -51,7 +51,7 @@ func TestMarshalingNebulaCertificate(t *testing.T) { assert.Nil(t, err) //t.Log("Cert size:", len(b)) - nc2, err := unmarshalCertificateV1(b, true) + nc2, err := unmarshalCertificateV1(b, nil) assert.Nil(t, err) assert.Equal(t, nc.signature, nc2.Signature()) @@ -534,7 +534,7 @@ func TestNebulaCertificate_Copy(t *testing.T) { func TestUnmarshalNebulaCertificate(t *testing.T) { // Test that we don't panic with an invalid certificate (#332) data := []byte("\x98\x00\x00") - _, err := unmarshalCertificateV1(data, true) + _, err := unmarshalCertificateV1(data, nil) assert.EqualError(t, err, "encoded Details was nil") } diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 032caec5a..83f7e3b54 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -6,19 +6,16 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" - "crypto/rand" "crypto/sha256" "encoding/binary" "encoding/hex" "encoding/json" "encoding/pem" "fmt" - "math/big" "net" "net/netip" "time" - "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/curve25519" "google.golang.org/protobuf/proto" ) @@ -46,56 +43,56 @@ type detailsV1 struct { type m map[string]interface{} -func (nc *certificateV1) Version() Version { +func (c *certificateV1) Version() Version { return Version1 } -func (nc *certificateV1) Curve() Curve { - return nc.details.Curve +func (c *certificateV1) Curve() Curve { + return c.details.Curve } -func (nc *certificateV1) Groups() []string { - return nc.details.Groups +func (c *certificateV1) Groups() []string { + return c.details.Groups } -func (nc *certificateV1) IsCA() bool { - return nc.details.IsCA +func (c *certificateV1) IsCA() bool { + return c.details.IsCA } -func (nc *certificateV1) Issuer() string { - return nc.details.Issuer +func (c *certificateV1) Issuer() string { + return c.details.Issuer } -func (nc *certificateV1) Name() string { - return nc.details.Name +func (c *certificateV1) Name() string { + return c.details.Name } -func (nc *certificateV1) Networks() []netip.Prefix { - return nc.details.Ips +func (c *certificateV1) Networks() []netip.Prefix { + return c.details.Ips } -func (nc *certificateV1) NotAfter() time.Time { - return nc.details.NotAfter +func (c *certificateV1) NotAfter() time.Time { + return c.details.NotAfter } -func (nc *certificateV1) NotBefore() time.Time { - return nc.details.NotBefore +func (c *certificateV1) NotBefore() time.Time { + return c.details.NotBefore } -func (nc *certificateV1) PublicKey() []byte { - return nc.details.PublicKey +func (c *certificateV1) PublicKey() []byte { + return c.details.PublicKey } -func (nc *certificateV1) Signature() []byte { - return nc.signature +func (c *certificateV1) Signature() []byte { + return c.signature } -func (nc *certificateV1) UnsafeNetworks() []netip.Prefix { - return nc.details.Subnets +func (c *certificateV1) UnsafeNetworks() []netip.Prefix { + return c.details.Subnets } -func (nc *certificateV1) Fingerprint() (string, error) { - b, err := nc.Marshal() +func (c *certificateV1) Fingerprint() (string, error) { + b, err := c.Marshal() if err != nil { return "", err } @@ -104,33 +101,33 @@ func (nc *certificateV1) Fingerprint() (string, error) { return hex.EncodeToString(sum[:]), nil } -func (nc *certificateV1) CheckSignature(key []byte) bool { - b, err := proto.Marshal(nc.getRawDetails()) +func (c *certificateV1) CheckSignature(key []byte) bool { + b, err := proto.Marshal(c.getRawDetails()) if err != nil { return false } - switch nc.details.Curve { + switch c.details.Curve { case Curve_CURVE25519: - return ed25519.Verify(key, b, nc.signature) + return ed25519.Verify(key, b, c.signature) case Curve_P256: x, y := elliptic.Unmarshal(elliptic.P256(), key) pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} hashed := sha256.Sum256(b) - return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature) + return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) default: return false } } -func (nc *certificateV1) Expired(t time.Time) bool { - return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t) +func (c *certificateV1) Expired(t time.Time) bool { + return c.details.NotBefore.After(t) || c.details.NotAfter.Before(t) } -func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { - if curve != nc.details.Curve { +func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != c.details.Curve { return fmt.Errorf("curve in cert and private key supplied don't match") } - if nc.details.IsCA { + if c.details.IsCA { switch curve { case Curve_CURVE25519: // the call to PublicKey below will panic slice bounds out of range otherwise @@ -138,7 +135,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") } - if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { + if !ed25519.PublicKey(c.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { return fmt.Errorf("public key in cert and private key supplied don't match") } case Curve_P256: @@ -147,7 +144,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { return fmt.Errorf("cannot parse private key as P256") } pub := privkey.PublicKey().Bytes() - if !bytes.Equal(pub, nc.details.PublicKey) { + if !bytes.Equal(pub, c.details.PublicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } default: @@ -173,7 +170,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { default: return fmt.Errorf("invalid curve: %s", curve) } - if !bytes.Equal(pub, nc.details.PublicKey) { + if !bytes.Equal(pub, c.details.PublicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } @@ -181,47 +178,47 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { } // getRawDetails marshals the raw details into protobuf ready struct -func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails { +func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails { rd := &RawNebulaCertificateDetails{ - Name: nc.details.Name, - Groups: nc.details.Groups, - NotBefore: nc.details.NotBefore.Unix(), - NotAfter: nc.details.NotAfter.Unix(), - PublicKey: make([]byte, len(nc.details.PublicKey)), - IsCA: nc.details.IsCA, - Curve: nc.details.Curve, + Name: c.details.Name, + Groups: c.details.Groups, + NotBefore: c.details.NotBefore.Unix(), + NotAfter: c.details.NotAfter.Unix(), + PublicKey: make([]byte, len(c.details.PublicKey)), + IsCA: c.details.IsCA, + Curve: c.details.Curve, } - for _, ipNet := range nc.details.Ips { + for _, ipNet := range c.details.Ips { mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) } - for _, ipNet := range nc.details.Subnets { + for _, ipNet := range c.details.Subnets { mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) } - copy(rd.PublicKey, nc.details.PublicKey[:]) + copy(rd.PublicKey, c.details.PublicKey[:]) // I know, this is terrible - rd.Issuer, _ = hex.DecodeString(nc.details.Issuer) + rd.Issuer, _ = hex.DecodeString(c.details.Issuer) return rd } -func (nc *certificateV1) String() string { - if nc == nil { +func (c *certificateV1) String() string { + if c == nil { return "Certificate {}\n" } s := "NebulaCertificate {\n" s += "\tDetails {\n" - s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name) + s += fmt.Sprintf("\t\tName: %v\n", c.details.Name) - if len(nc.details.Ips) > 0 { + if len(c.details.Ips) > 0 { s += "\t\tIps: [\n" - for _, ip := range nc.details.Ips { + for _, ip := range c.details.Ips { s += fmt.Sprintf("\t\t\t%v\n", ip.String()) } s += "\t\t]\n" @@ -229,9 +226,9 @@ func (nc *certificateV1) String() string { s += "\t\tIps: []\n" } - if len(nc.details.Subnets) > 0 { + if len(c.details.Subnets) > 0 { s += "\t\tSubnets: [\n" - for _, ip := range nc.details.Subnets { + for _, ip := range c.details.Subnets { s += fmt.Sprintf("\t\t\t%v\n", ip.String()) } s += "\t\t]\n" @@ -239,9 +236,9 @@ func (nc *certificateV1) String() string { s += "\t\tSubnets: []\n" } - if len(nc.details.Groups) > 0 { + if len(c.details.Groups) > 0 { s += "\t\tGroups: [\n" - for _, g := range nc.details.Groups { + for _, g := range c.details.Groups { s += fmt.Sprintf("\t\t\t\"%v\"\n", g) } s += "\t\t]\n" @@ -249,105 +246,131 @@ func (nc *certificateV1) String() string { s += "\t\tGroups: []\n" } - s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore) - s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter) - s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA) - s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer) - s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey) - s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve) + s += fmt.Sprintf("\t\tNot before: %v\n", c.details.NotBefore) + s += fmt.Sprintf("\t\tNot After: %v\n", c.details.NotAfter) + s += fmt.Sprintf("\t\tIs CA: %v\n", c.details.IsCA) + s += fmt.Sprintf("\t\tIssuer: %s\n", c.details.Issuer) + s += fmt.Sprintf("\t\tPublic key: %x\n", c.details.PublicKey) + s += fmt.Sprintf("\t\tCurve: %s\n", c.details.Curve) s += "\t}\n" - fp, err := nc.Fingerprint() + fp, err := c.Fingerprint() if err == nil { s += fmt.Sprintf("\tFingerprint: %s\n", fp) } - s += fmt.Sprintf("\tSignature: %x\n", nc.Signature()) + s += fmt.Sprintf("\tSignature: %x\n", c.Signature()) s += "}" return s } -func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) { - pubKey := nc.details.PublicKey - nc.details.PublicKey = nil - rawCertNoKey, err := nc.Marshal() +func (c *certificateV1) MarshalForHandshakes() ([]byte, error) { + pubKey := c.details.PublicKey + c.details.PublicKey = nil + rawCertNoKey, err := c.Marshal() if err != nil { return nil, err } - nc.details.PublicKey = pubKey + c.details.PublicKey = pubKey return rawCertNoKey, nil } -func (nc *certificateV1) Marshal() ([]byte, error) { +func (c *certificateV1) Marshal() ([]byte, error) { rc := RawNebulaCertificate{ - Details: nc.getRawDetails(), - Signature: nc.signature, + Details: c.getRawDetails(), + Signature: c.signature, } return proto.Marshal(&rc) } -func (nc *certificateV1) MarshalPEM() ([]byte, error) { - b, err := nc.Marshal() +func (c *certificateV1) MarshalPEM() ([]byte, error) { + b, err := c.Marshal() if err != nil { return nil, err } return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil } -func (nc *certificateV1) MarshalJSON() ([]byte, error) { - fp, _ := nc.Fingerprint() +func (c *certificateV1) MarshalJSON() ([]byte, error) { + fp, _ := c.Fingerprint() jc := m{ "details": m{ - "name": nc.details.Name, - "ips": nc.details.Ips, - "subnets": nc.details.Subnets, - "groups": nc.details.Groups, - "notBefore": nc.details.NotBefore, - "notAfter": nc.details.NotAfter, - "publicKey": fmt.Sprintf("%x", nc.details.PublicKey), - "isCa": nc.details.IsCA, - "issuer": nc.details.Issuer, - "curve": nc.details.Curve.String(), + "name": c.details.Name, + "ips": c.details.Ips, + "subnets": c.details.Subnets, + "groups": c.details.Groups, + "notBefore": c.details.NotBefore, + "notAfter": c.details.NotAfter, + "publicKey": fmt.Sprintf("%x", c.details.PublicKey), + "isCa": c.details.IsCA, + "issuer": c.details.Issuer, + "curve": c.details.Curve.String(), }, "fingerprint": fp, - "signature": fmt.Sprintf("%x", nc.Signature()), + "signature": fmt.Sprintf("%x", c.Signature()), } return json.Marshal(jc) } -func (nc *certificateV1) Copy() Certificate { - c := &certificateV1{ +func (c *certificateV1) Copy() Certificate { + nc := &certificateV1{ details: detailsV1{ - Name: nc.details.Name, - Groups: make([]string, len(nc.details.Groups)), - Ips: make([]netip.Prefix, len(nc.details.Ips)), - Subnets: make([]netip.Prefix, len(nc.details.Subnets)), - NotBefore: nc.details.NotBefore, - NotAfter: nc.details.NotAfter, - PublicKey: make([]byte, len(nc.details.PublicKey)), - IsCA: nc.details.IsCA, - Issuer: nc.details.Issuer, + Name: c.details.Name, + Groups: make([]string, len(c.details.Groups)), + Ips: make([]netip.Prefix, len(c.details.Ips)), + Subnets: make([]netip.Prefix, len(c.details.Subnets)), + NotBefore: c.details.NotBefore, + NotAfter: c.details.NotAfter, + PublicKey: make([]byte, len(c.details.PublicKey)), + IsCA: c.details.IsCA, + Issuer: c.details.Issuer, + Curve: c.details.Curve, }, - signature: make([]byte, len(nc.signature)), + signature: make([]byte, len(c.signature)), } - copy(c.signature, nc.signature) - copy(c.details.Groups, nc.details.Groups) - copy(c.details.PublicKey, nc.details.PublicKey) + copy(nc.signature, c.signature) + copy(nc.details.Groups, c.details.Groups) + copy(nc.details.PublicKey, c.details.PublicKey) + copy(nc.details.Ips, c.details.Ips) + copy(nc.details.Subnets, c.details.Subnets) + + return nc +} - for i, p := range nc.details.Ips { - c.details.Ips[i] = p +func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error { + c.details = detailsV1{ + Name: t.Name, + Ips: t.Networks, + Subnets: t.UnsafeNetworks, + Groups: t.Groups, + NotBefore: t.NotBefore, + NotAfter: t.NotAfter, + PublicKey: t.PublicKey, + IsCA: t.IsCA, + Curve: t.Curve, + Issuer: t.issuer, } - for i, p := range nc.details.Subnets { - c.details.Subnets[i] = p + return nil +} + +func (c *certificateV1) marshalForSigning() ([]byte, error) { + b, err := proto.Marshal(c.getRawDetails()) + if err != nil { + return nil, err } + return b, nil +} - return c +func (c *certificateV1) setSignature(b []byte) error { + c.signature = b + return nil } // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert -func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) { +// if the publicKey is provided here then it is not required to be present in `b` +func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) { if len(b) == 0 { return nil, fmt.Errorf("nil byte array") } @@ -388,9 +411,10 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err copy(nc.details.Groups, rc.Details.Groups) nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer) - if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey { - return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) + if len(publicKey) > 0 { + nc.details.PublicKey = publicKey } + copy(nc.details.PublicKey, rc.Details.PublicKey) var ip netip.Addr @@ -415,62 +439,6 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err return &nc, nil } -func signV1(t *TBSCertificate, curve Curve, key []byte, client *pkclient.PKClient) (*certificateV1, error) { - c := &certificateV1{ - details: detailsV1{ - Name: t.Name, - Ips: t.Networks, - Subnets: t.UnsafeNetworks, - Groups: t.Groups, - NotBefore: t.NotBefore, - NotAfter: t.NotAfter, - PublicKey: t.PublicKey, - IsCA: t.IsCA, - Curve: t.Curve, - Issuer: t.issuer, - }, - } - b, err := proto.Marshal(c.getRawDetails()) - if err != nil { - return nil, err - } - - var sig []byte - - switch curve { - case Curve_CURVE25519: - signer := ed25519.PrivateKey(key) - sig = ed25519.Sign(signer, b) - case Curve_P256: - if client != nil { - sig, err = client.SignASN1(b) - } else { - signer := &ecdsa.PrivateKey{ - PublicKey: ecdsa.PublicKey{ - Curve: elliptic.P256(), - }, - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 - D: new(big.Int).SetBytes(key), - } - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 - signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) - - // We need to hash first for ECDSA - // - https://pkg.go.dev/crypto/ecdsa#SignASN1 - hashed := sha256.Sum256(b) - sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) - if err != nil { - return nil, err - } - } - default: - return nil, fmt.Errorf("invalid curve: %s", c.details.Curve) - } - - c.signature = sig - return c, nil -} - func ip2int(ip []byte) uint32 { if len(ip) == 16 { return binary.BigEndian.Uint32(ip[12:16]) diff --git a/cert/cert_v2.asn1 b/cert/cert_v2.asn1 new file mode 100644 index 000000000..32a6e735e --- /dev/null +++ b/cert/cert_v2.asn1 @@ -0,0 +1,38 @@ +Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN + +Name ::= UTF8String (SIZE (1..253)) +Time ::= INTEGER (0..18446744073709551615) -- uint64 maximum +Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length +Curve ::= ENUMERATED { + curve25519 (0), + p256 (1) +} + +-- The maximum size of a certificate must not exceed 65536 bytes +Certificate ::= SEQUENCE { + details OCTET STRING, + curve Curve DEFAULT curve25519, + publicKey OCTET STRING, + -- signature(details + curve + publicKey) using the appropriate method for curve + signature OCTET STRING +} + +Details ::= SEQUENCE { + name Name, + + -- At least 1 ipv4 or ipv6 address must be present if isCA is false + networks SEQUENCE OF Network, + unsafeNetworks SEQUENCE OF Network OPTIONAL, + groups SEQUENCE OF Name OPTIONAL, + isCA BOOLEAN DEFAULT false, + + -- ASN.1 time formats are all string representations so we use our own uint64 limited one + notBefore Time, + notAfter Time, + + issuer OCTET STRING, + ... + -- New fields can be added below here +} + +END \ No newline at end of file diff --git a/cert/cert_v2.go b/cert/cert_v2.go new file mode 100644 index 000000000..96a1b9b2b --- /dev/null +++ b/cert/cert_v2.go @@ -0,0 +1,605 @@ +package cert + +import ( + "bytes" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "encoding/pem" + "fmt" + "net/netip" + "time" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" + "golang.org/x/crypto/curve25519" +) + +//TODO: should we avoid hex encoding shit on output? Just let it be base64? + +const ( + classConstructed = 0x20 + classContextSpecific = 0x80 + + TagCertDetails = 0 | classConstructed | classContextSpecific + TagCertCurve = 1 | classContextSpecific + TagCertPublicKey = 2 | classContextSpecific + TagCertSignature = 3 | classContextSpecific + + TagDetailsName = 0 | classContextSpecific + TagDetailsIps = 1 | classConstructed | classContextSpecific + TagDetailsSubnets = 2 | classConstructed | classContextSpecific + TagDetailsGroups = 3 | classConstructed | classContextSpecific + TagDetailsIsCA = 4 | classContextSpecific + TagDetailsNotBefore = 5 | classContextSpecific + TagDetailsNotAfter = 6 | classContextSpecific + TagDetailsIssuer = 7 | classContextSpecific +) + +const ( + // MaxCertificateSize is the maximum length a valid certificate can be + MaxCertificateSize = 65536 + + // MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems + MaxNameLength = 253 + + // MaxSubnetLength is the maximum length a subnet value can be. + // 16 bytes for an ipv6 address + 1 byte for the prefix length + MaxSubnetLength = 17 +) + +type certificateV2 struct { + details detailsV2 + + // RawDetails contains the entire asn.1 DER encoded Details struct + // This is to benefit forwards compatibility in signature checking. + // signature(RawDetails + Curve + PublicKey) == Signature + rawDetails []byte + curve Curve + publicKey []byte + signature []byte +} + +type detailsV2 struct { + name string + networks []netip.Prefix + unsafeNetworks []netip.Prefix + groups []string + isCA bool + notBefore time.Time + notAfter time.Time + issuer string +} + +func (c *certificateV2) Version() Version { + return Version2 +} + +func (c *certificateV2) Curve() Curve { + return c.curve +} + +func (c *certificateV2) Groups() []string { + return c.details.groups +} + +func (c *certificateV2) IsCA() bool { + return c.details.isCA +} + +func (c *certificateV2) Issuer() string { + return c.details.issuer +} + +func (c *certificateV2) Name() string { + return c.details.name +} + +func (c *certificateV2) Networks() []netip.Prefix { + return c.details.networks +} + +func (c *certificateV2) NotAfter() time.Time { + return c.details.notAfter +} + +func (c *certificateV2) NotBefore() time.Time { + return c.details.notBefore +} + +func (c *certificateV2) PublicKey() []byte { + return c.publicKey +} + +func (c *certificateV2) Signature() []byte { + return c.signature +} + +func (c *certificateV2) UnsafeNetworks() []netip.Prefix { + return c.details.unsafeNetworks +} + +func (c *certificateV2) Fingerprint() (string, error) { + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + //TODO: double check this, panic on empty raw details + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature) + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil +} + +func (c *certificateV2) CheckSignature(key []byte) bool { + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + //TODO: double check this, panic on empty raw details + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + + switch c.curve { + case Curve_CURVE25519: + return ed25519.Verify(key, b, c.signature) + case Curve_P256: + //TODO: NewPublicKey + x, y := elliptic.Unmarshal(elliptic.P256(), key) + pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + hashed := sha256.Sum256(b) + return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) + default: + return false + } +} + +func (c *certificateV2) Expired(t time.Time) bool { + return c.details.notBefore.After(t) || c.details.notAfter.Before(t) +} + +func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != c.curve { + return fmt.Errorf("curve in cert and private key supplied don't match") + } + if c.details.isCA { + switch curve { + case Curve_CURVE25519: + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") + } + + if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return fmt.Errorf("cannot parse private key as P256") + } + pub := privkey.PublicKey().Bytes() + if !bytes.Equal(pub, c.publicKey) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + default: + return fmt.Errorf("invalid curve: %s", curve) + } + return nil + } + + var pub []byte + switch curve { + case Curve_CURVE25519: + var err error + pub, err = curve25519.X25519(key, curve25519.Basepoint) + if err != nil { + return err + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return err + } + pub = privkey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", curve) + } + if !bytes.Equal(pub, c.publicKey) { + return fmt.Errorf("public key in cert and private key supplied don't match") + } + + return nil +} + +func (c *certificateV2) String() string { + b, err := json.MarshalIndent(c.marshalJSON(), "", "\t") + if err != nil { + return "" + } + return string(b) +} + +func (c *certificateV2) MarshalForHandshakes() ([]byte, error) { + panic("TODO") +} + +func (c *certificateV2) Marshal() ([]byte, error) { + var b cryptobyte.Builder + // Outermost certificate + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + + // Add the cert details which is already marshalled + b.AddBytes(c.rawDetails) + + // Add the curve only if its not the default value + if c.curve != Curve_CURVE25519 { + b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) { + b.AddBytes([]byte{byte(c.curve)}) + }) + } + + // Add the public key if it is not empty + if c.publicKey != nil { + b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) { + b.AddBytes(c.publicKey) + }) + } + + // Add the signature + b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { + b.AddBytes(c.signature) + }) + }) + + return b.Bytes() +} + +func (c *certificateV2) MarshalPEM() ([]byte, error) { + b, err := c.Marshal() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil +} + +func (c *certificateV2) MarshalJSON() ([]byte, error) { + return json.Marshal(c.marshalJSON()) +} + +func (c *certificateV2) marshalJSON() m { + fp, _ := c.Fingerprint() + return m{ + "details": m{ + "name": c.details.name, + "networks": c.details.networks, + "unsafeNetworks": c.details.unsafeNetworks, + "groups": c.details.groups, + "notBefore": c.details.notBefore, + "notAfter": c.details.notAfter, + "isCa": c.details.isCA, + "issuer": c.details.issuer, + }, + "publicKey": fmt.Sprintf("%x", c.publicKey), + "curve": c.curve.String(), + "fingerprint": fp, + "signature": fmt.Sprintf("%x", c.Signature()), + } +} + +func (c *certificateV2) Copy() Certificate { + nc := &certificateV2{ + details: detailsV2{ + name: c.details.name, + groups: make([]string, len(c.details.groups)), + networks: make([]netip.Prefix, len(c.details.networks)), + unsafeNetworks: make([]netip.Prefix, len(c.details.unsafeNetworks)), + notBefore: c.details.notBefore, + notAfter: c.details.notAfter, + isCA: c.details.isCA, + issuer: c.details.issuer, + }, + curve: c.curve, + publicKey: make([]byte, len(c.publicKey)), + signature: make([]byte, len(c.signature)), + } + + copy(c.signature, c.signature) + copy(c.details.groups, c.details.groups) + copy(c.publicKey, c.publicKey) + + for i, p := range c.details.networks { + c.details.networks[i] = p + } + + for i, p := range c.details.unsafeNetworks { + c.details.unsafeNetworks[i] = p + } + + return nc +} + +func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error { + c.details = detailsV2{ + name: t.Name, + networks: t.Networks, + unsafeNetworks: t.UnsafeNetworks, + groups: t.Groups, + isCA: t.IsCA, + notBefore: t.NotBefore, + notAfter: t.NotAfter, + issuer: t.issuer, + } + c.curve = t.Curve + c.publicKey = t.PublicKey + return nil +} + +func (c *certificateV2) marshalForSigning() ([]byte, error) { + d, err := c.details.Marshal() + if err != nil { + //TODO: annotate? + return nil, err + } + c.rawDetails = d + + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + //TODO: double check this + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + return b, nil +} + +func (c *certificateV2) setSignature(b []byte) error { + c.signature = b + return nil +} + +func (d *detailsV2) Marshal() ([]byte, error) { + var b cryptobyte.Builder + var err error + + // Details are a structure + b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) { + + // Add the name + b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) { + b.AddBytes([]byte(d.name)) + }) + + // Add the ips if any exist + if len(d.networks) > 0 { + b.AddASN1(TagDetailsIps, func(b *cryptobyte.Builder) { + for _, subnet := range d.networks { + sb, innerErr := subnet.MarshalBinary() + if innerErr != nil { + // MarshalBinary never returns an error + err = fmt.Errorf("unable to marshal ip: %w", innerErr) + return + } + b.AddASN1OctetString(sb) + } + }) + } + + // Add the subnets if any exist + if len(d.unsafeNetworks) > 0 { + b.AddASN1(TagDetailsSubnets, func(b *cryptobyte.Builder) { + for _, subnet := range d.unsafeNetworks { + sb, innerErr := subnet.MarshalBinary() + if innerErr != nil { + // MarshalBinary never returns an error + err = fmt.Errorf("unable to marshal subnet: %w", innerErr) + return + } + b.AddASN1OctetString(sb) + } + }) + } + + // Add groups if any exist + if len(d.groups) > 0 { + b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) { + for _, group := range d.groups { + b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) { + b.AddBytes([]byte(group)) + }) + } + }) + } + + // Add IsCA only if true + if d.isCA { + b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) { + b.AddUint8(0xff) + }) + } + + // Add not before + b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore) + + // Add not after + b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter) + + // Add the issuer if present + if d.issuer != "" { + issuerBytes, innerErr := hex.DecodeString(d.issuer) + if innerErr != nil { + err = fmt.Errorf("failed to decode issuer: %w", innerErr) + return + } + b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) { + b.AddBytes(issuerBytes) + }) + } + }) + + if err != nil { + return nil, err + } + + return b.Bytes() +} + +func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error) { + l := len(b) + if l == 0 || l > MaxCertificateSize { + return nil, ErrBadFormat + } + + input := cryptobyte.String(b) + // Open the envelope + if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() { + return nil, ErrBadFormat + } + + // Grab the cert details, we need to preserve the tag and length + var rawDetails cryptobyte.String + if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() { + return nil, ErrBadFormat + } + + var rawCurve byte + if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(Curve_CURVE25519)) { + return nil, ErrBadFormat + } + curve := Curve(rawCurve) + + // Maybe grab the public key + var rawPublicKey cryptobyte.String + if len(publicKey) > 0 { + rawPublicKey = publicKey + } else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) { + return nil, ErrBadFormat + } + + //TODO: Assert public key length + + // Grab the signature + var rawSignature cryptobyte.String + if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() { + return nil, ErrBadFormat + } + + // Finally unmarshal the details + details, err := unmarshalDetails(rawDetails) + if err != nil { + return nil, err + } + + return &certificateV2{ + details: details, + rawDetails: rawDetails, + curve: curve, + publicKey: rawPublicKey, + signature: rawSignature, + }, nil +} + +func unmarshalDetails(b cryptobyte.String) (detailsV2, error) { + // Open the envelope + if !b.ReadASN1(&b, TagCertDetails) || b.Empty() { + return detailsV2{}, ErrBadFormat + } + + // Read the name + var name cryptobyte.String + if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength { + return detailsV2{}, ErrBadFormat + } + + // Read the ip addresses + var subString cryptobyte.String + var found bool + + if !b.ReadOptionalASN1(&subString, &found, TagDetailsIps) { + return detailsV2{}, ErrBadFormat + } + + var ips []netip.Prefix + var val cryptobyte.String + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxSubnetLength { + return detailsV2{}, ErrBadFormat + } + + var ip netip.Prefix + if err := ip.UnmarshalBinary(val); err != nil { + return detailsV2{}, ErrBadFormat + } + ips = append(ips, ip) + } + } + + // Read out any subnets + if !b.ReadOptionalASN1(&subString, &found, TagDetailsSubnets) { + return detailsV2{}, ErrBadFormat + } + + var subnets []netip.Prefix + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxSubnetLength { + return detailsV2{}, ErrBadFormat + } + + var subnet netip.Prefix + if err := subnet.UnmarshalBinary(val); err != nil { + return detailsV2{}, ErrBadFormat + } + subnets = append(subnets, subnet) + } + } + + // Read out any groups + if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) { + return detailsV2{}, ErrBadFormat + } + + var groups []string + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() { + return detailsV2{}, ErrBadFormat + } + groups = append(groups, string(val)) + } + } + + // Read out IsCA + var isCa bool + if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) { + return detailsV2{}, ErrBadFormat + } + + // Read not before and not after + var notBefore int64 + if !b.ReadASN1Int64WithTag(¬Before, TagDetailsNotBefore) { + return detailsV2{}, ErrBadFormat + } + + var notAfter int64 + if !b.ReadASN1Int64WithTag(¬After, TagDetailsNotAfter) { + return detailsV2{}, ErrBadFormat + } + + // Read issuer + var issuer cryptobyte.String + if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) { + return detailsV2{}, ErrBadFormat + } + + return detailsV2{ + name: string(name), + networks: ips, + unsafeNetworks: subnets, + groups: groups, + isCA: isCa, + notBefore: time.Unix(notBefore, 0), + notAfter: time.Unix(notAfter, 0), + issuer: hex.EncodeToString(issuer), + }, nil +} diff --git a/cert/errors.go b/cert/errors.go index 06ce99612..a590b1674 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -23,4 +23,7 @@ var ( ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner") ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner") ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner") + + ErrNoPeerStaticKey = errors.New("no peer static key was present") + ErrNoPayload = errors.New("provided payload was empty") ) diff --git a/cert/pem.go b/cert/pem.go index 744ae2edf..8f9fe8e99 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -30,19 +30,24 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { return nil, r, ErrInvalidPEMBlock } + var c Certificate + var err error + switch p.Type { case CertificateBanner: - c, err := unmarshalCertificateV1(p.Bytes, true) - if err != nil { - return nil, nil, err - } - return c, r, nil + c, err = unmarshalCertificateV1(p.Bytes, nil) case CertificateV2Banner: - //TODO - panic("TODO") + c, err = unmarshalCertificateV2(p.Bytes, nil) default: return nil, r, ErrInvalidPEMCertificateBanner } + + if err != nil { + return nil, r, err + } + + return c, r, nil + } func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { diff --git a/cert/sign.go b/cert/sign.go index e446aa131..dcb5a5d0e 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -1,7 +1,13 @@ package cert import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" "fmt" + "math/big" "net/netip" "time" @@ -24,6 +30,17 @@ type TBSCertificate struct { issuer string } +type beingSignedCertificate interface { + // fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation + fromTBSCertificate(*TBSCertificate) error + + // marshalForSigning returns the bytes that should be signed + marshalForSigning() ([]byte, error) + + // setSignature sets the signature for the certificate that has just been signed + setSignature([]byte) error +} + // Sign will create a sealed certificate using details provided by the TBSCertificate as long as those // details do not violate constraints of the signing certificate. // If the TBSCertificate is a CA then signer must be nil. @@ -67,10 +84,71 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien } } + var c beingSignedCertificate switch t.Version { case Version1: - return signV1(t, curve, key, client) + c = &certificateV1{} + err := c.fromTBSCertificate(t) + if err != nil { + return nil, err + } + case Version2: + c = &certificateV2{} + err := c.fromTBSCertificate(t) + if err != nil { + return nil, err + } default: return nil, fmt.Errorf("unknown cert version %d", t.Version) } + + certBytes, err := c.marshalForSigning() + if err != nil { + return nil, err + } + + var sig []byte + switch t.Curve { + case Curve_CURVE25519: + signer := ed25519.PrivateKey(key) + sig = ed25519.Sign(signer, certBytes) + case Curve_P256: + if client != nil { + sig, err = client.SignASN1(certBytes) + } else { + signer := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + }, + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 + D: new(big.Int).SetBytes(key), + } + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 + signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) + + // We need to hash first for ECDSA + // - https://pkg.go.dev/crypto/ecdsa#SignASN1 + hashed := sha256.Sum256(certBytes) + sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) + } + default: + return nil, fmt.Errorf("invalid curve: %s", t.Curve) + } + + if err != nil { + return nil, err + } + + //TODO: check if we have sig bytes? + err = c.setSignature(sig) + if err != nil { + return nil, err + } + + sc, ok := c.(Certificate) + if !ok { + return nil, fmt.Errorf("invalid certificate") + } + + return sc, nil } diff --git a/handshake_ix.go b/handshake_ix.go index 24c423d64..0448385c3 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -6,6 +6,7 @@ import ( "github.com/flynn/noise" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" ) @@ -29,6 +30,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { InitiatorIndex: hh.hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), Cert: certState.RawCertificateNoKey, + CertVersion: uint32(certState.Certificate.Version()), } hsBytes := []byte{} @@ -86,7 +88,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool()) if err != nil { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -166,6 +168,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hs.Details.ResponderIndex = myIndex hs.Details.Cert = certState.RawCertificateNoKey + hs.Details.CertVersion = uint32(certState.Certificate.Version()) // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) @@ -386,7 +389,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool()) if err != nil { e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) diff --git a/nebula.pb.go b/nebula.pb.go index b3c723a46..3ae0371ef 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -477,6 +477,7 @@ type NebulaHandshakeDetails struct { ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"` Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"` Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"` + CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"` } func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} } @@ -547,6 +548,13 @@ func (m *NebulaHandshakeDetails) GetTime() uint64 { return 0 } +func (m *NebulaHandshakeDetails) GetCertVersion() uint32 { + if m != nil { + return m.CertVersion + } + return 0 +} + type NebulaControl struct { Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` @@ -640,52 +648,52 @@ func init() { func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 707 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x54, 0x4d, 0x6f, 0xda, 0x4a, - 0x14, 0xc5, 0xc6, 0x7c, 0x5d, 0x02, 0xf1, 0xbb, 0x79, 0x8f, 0x07, 0x4f, 0xaf, 0x16, 0xf5, 0xa2, - 0x62, 0x45, 0x22, 0x92, 0x46, 0x5d, 0x36, 0xa5, 0xaa, 0x20, 0x4a, 0x22, 0x3a, 0x4a, 0x5b, 0xa9, - 0x9b, 0x6a, 0x62, 0xa6, 0xc1, 0x02, 0x3c, 0x8e, 0x3d, 0x54, 0xe1, 0x5f, 0xf4, 0xc7, 0xe4, 0x47, - 0x74, 0xd7, 0x2c, 0xbb, 0xac, 0x92, 0x65, 0x97, 0xfd, 0x03, 0xd5, 0x8c, 0xc1, 0x36, 0x84, 0x76, - 0x37, 0xe7, 0xde, 0x73, 0x66, 0xce, 0x9c, 0xb9, 0x36, 0x6c, 0x79, 0xec, 0x62, 0x36, 0xa1, 0x6d, - 0x3f, 0xe0, 0x82, 0x63, 0x3e, 0x42, 0xf6, 0x0f, 0x1d, 0xe0, 0x4c, 0x2d, 0x4f, 0x99, 0xa0, 0xd8, - 0x01, 0xe3, 0x7c, 0xee, 0xb3, 0xba, 0xd6, 0xd4, 0x5a, 0xd5, 0x8e, 0xd5, 0x5e, 0x68, 0x12, 0x46, - 0xfb, 0x94, 0x85, 0x21, 0xbd, 0x64, 0x92, 0x45, 0x14, 0x17, 0xf7, 0xa1, 0xf0, 0x92, 0x09, 0xea, - 0x4e, 0xc2, 0xba, 0xde, 0xd4, 0x5a, 0xe5, 0x4e, 0xe3, 0xa1, 0x6c, 0x41, 0x20, 0x4b, 0xa6, 0xfd, - 0x53, 0x83, 0x72, 0x6a, 0x2b, 0x2c, 0x82, 0x71, 0xc6, 0x3d, 0x66, 0x66, 0xb0, 0x02, 0xa5, 0x1e, - 0x0f, 0xc5, 0xeb, 0x19, 0x0b, 0xe6, 0xa6, 0x86, 0x08, 0xd5, 0x18, 0x12, 0xe6, 0x4f, 0xe6, 0xa6, - 0x8e, 0xff, 0x41, 0x4d, 0xd6, 0xde, 0xf8, 0x43, 0x2a, 0xd8, 0x19, 0x17, 0xee, 0x47, 0xd7, 0xa1, - 0xc2, 0xe5, 0x9e, 0x99, 0xc5, 0x06, 0xfc, 0x23, 0x7b, 0xa7, 0xfc, 0x13, 0x1b, 0xae, 0xb4, 0x8c, - 0x65, 0x6b, 0x30, 0xf3, 0x9c, 0xd1, 0x4a, 0x2b, 0x87, 0x55, 0x00, 0xd9, 0x7a, 0x37, 0xe2, 0x74, - 0xea, 0x9a, 0x79, 0xdc, 0x81, 0xed, 0x04, 0x47, 0xc7, 0x16, 0xa4, 0xb3, 0x01, 0x15, 0xa3, 0xee, - 0x88, 0x39, 0x63, 0xb3, 0x28, 0x9d, 0xc5, 0x30, 0xa2, 0x94, 0xf0, 0x11, 0x34, 0x36, 0x3b, 0x3b, - 0x72, 0xc6, 0x26, 0xd8, 0x5f, 0x35, 0xf8, 0xeb, 0x41, 0x28, 0xf8, 0x37, 0xe4, 0xde, 0xfa, 0x5e, - 0xdf, 0x57, 0xa9, 0x57, 0x48, 0x04, 0xf0, 0x00, 0xca, 0x7d, 0xff, 0xe0, 0xc8, 0x1b, 0x0e, 0x78, - 0x20, 0x64, 0xb4, 0xd9, 0x56, 0xb9, 0x83, 0xcb, 0x68, 0x93, 0x16, 0x49, 0xd3, 0x22, 0xd5, 0x61, - 0xac, 0x32, 0xd6, 0x55, 0x87, 0x29, 0x55, 0x4c, 0x43, 0x0b, 0x80, 0xb0, 0x09, 0x9d, 0x47, 0x36, - 0x72, 0xcd, 0x6c, 0xab, 0x42, 0x52, 0x15, 0xac, 0x43, 0xc1, 0xe1, 0x33, 0x4f, 0xb0, 0xa0, 0x9e, - 0x55, 0x1e, 0x97, 0xd0, 0xde, 0x03, 0x48, 0x8e, 0xc7, 0x2a, 0xe8, 0xf1, 0x35, 0xf4, 0xbe, 0x8f, - 0x08, 0x86, 0xac, 0xab, 0xb9, 0xa8, 0x10, 0xb5, 0xb6, 0x9f, 0x4b, 0xc5, 0x61, 0x4a, 0xd1, 0x73, - 0x95, 0xc2, 0x20, 0x7a, 0xcf, 0x95, 0xf8, 0x84, 0x2b, 0xbe, 0x41, 0xf4, 0x13, 0x1e, 0xef, 0x90, - 0x4d, 0xed, 0x70, 0xbd, 0x1c, 0xd9, 0x81, 0xeb, 0x5d, 0xfe, 0x79, 0x64, 0x25, 0x63, 0xc3, 0xc8, - 0x22, 0x18, 0xe7, 0xee, 0x94, 0x2d, 0xce, 0x51, 0x6b, 0xdb, 0x7e, 0x30, 0x90, 0x52, 0x6c, 0x66, - 0xb0, 0x04, 0xb9, 0xe8, 0x79, 0x35, 0xfb, 0x03, 0x6c, 0x47, 0xfb, 0xf6, 0xa8, 0x37, 0x0c, 0x47, - 0x74, 0xcc, 0xf0, 0x59, 0x32, 0xfd, 0x9a, 0x9a, 0xfe, 0x35, 0x07, 0x31, 0x73, 0xfd, 0x13, 0x90, - 0x26, 0x7a, 0x53, 0xea, 0x28, 0x13, 0x5b, 0x44, 0xad, 0xed, 0x1b, 0x0d, 0x6a, 0x9b, 0x75, 0x92, - 0xde, 0x65, 0x81, 0x50, 0xa7, 0x6c, 0x11, 0xb5, 0xc6, 0x27, 0x50, 0xed, 0x7b, 0xae, 0x70, 0xa9, - 0xe0, 0x41, 0xdf, 0x1b, 0xb2, 0xeb, 0x45, 0xd2, 0x6b, 0x55, 0xc9, 0x23, 0x2c, 0xf4, 0xb9, 0x37, - 0x64, 0x0b, 0x5e, 0x94, 0xe7, 0x5a, 0x15, 0x6b, 0x90, 0xef, 0x72, 0x3e, 0x76, 0x59, 0xdd, 0x50, - 0xc9, 0x2c, 0x50, 0x9c, 0x57, 0x2e, 0xc9, 0xeb, 0xd8, 0x28, 0xe6, 0xcd, 0xc2, 0xb1, 0x51, 0x2c, - 0x98, 0x45, 0xfb, 0x46, 0x87, 0x4a, 0x64, 0xbb, 0xcb, 0x3d, 0x11, 0xf0, 0x09, 0x3e, 0x5d, 0x79, - 0x95, 0xc7, 0xab, 0x99, 0x2c, 0x48, 0x1b, 0x1e, 0x66, 0x0f, 0x76, 0x62, 0xeb, 0x6a, 0xfe, 0xd2, - 0xb7, 0xda, 0xd4, 0x92, 0x8a, 0xf8, 0x12, 0x29, 0x45, 0x74, 0xbf, 0x4d, 0x2d, 0xfc, 0x1f, 0x4a, - 0x0a, 0x9d, 0xf3, 0xbe, 0xaf, 0xee, 0x59, 0x21, 0x49, 0x01, 0x9b, 0x50, 0x56, 0xe0, 0x55, 0xc0, - 0xa7, 0xea, 0x5b, 0x90, 0xfd, 0x74, 0xc9, 0xee, 0xfd, 0xee, 0xcf, 0x55, 0x03, 0xec, 0x06, 0x8c, - 0x0a, 0xa6, 0xd8, 0x84, 0x5d, 0xcd, 0x58, 0x28, 0x4c, 0x0d, 0xff, 0x85, 0x9d, 0x95, 0xba, 0xb4, - 0x14, 0x32, 0x53, 0x7f, 0xb1, 0xff, 0xe5, 0xce, 0xd2, 0x6e, 0xef, 0x2c, 0xed, 0xfb, 0x9d, 0xa5, - 0x7d, 0xbe, 0xb7, 0x32, 0xb7, 0xf7, 0x56, 0xe6, 0xdb, 0xbd, 0x95, 0x79, 0xdf, 0xb8, 0x74, 0xc5, - 0x68, 0x76, 0xd1, 0x76, 0xf8, 0x74, 0x37, 0x9c, 0x50, 0x67, 0x3c, 0xba, 0xda, 0x8d, 0x22, 0xbc, - 0xc8, 0xab, 0x1f, 0xf8, 0xfe, 0xaf, 0x00, 0x00, 0x00, 0xff, 0xff, 0x17, 0x56, 0x28, 0x74, 0xd0, - 0x05, 0x00, 0x00, + // 720 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x54, 0xcd, 0x6e, 0xd3, 0x40, + 0x10, 0x8e, 0x1d, 0xe7, 0x6f, 0xd2, 0xa4, 0x66, 0x0a, 0x21, 0x41, 0x60, 0x05, 0x1f, 0x50, 0x4e, + 0x69, 0x95, 0x96, 0x8a, 0x23, 0x25, 0x08, 0x25, 0x55, 0x5b, 0x85, 0x55, 0x29, 0x12, 0x17, 0xb4, + 0x75, 0x96, 0xc6, 0x4a, 0xe2, 0x75, 0xed, 0x0d, 0x6a, 0xde, 0x82, 0x87, 0xe1, 0x21, 0xb8, 0xd1, + 0x13, 0xe2, 0x88, 0xda, 0x23, 0x47, 0x5e, 0x00, 0xed, 0x3a, 0x71, 0x9c, 0x34, 0x70, 0xdb, 0x99, + 0xf9, 0xbe, 0xd9, 0x6f, 0xbe, 0x1d, 0x1b, 0x36, 0x3c, 0x76, 0x3e, 0x19, 0xd1, 0xa6, 0x1f, 0x70, + 0xc1, 0x31, 0x1b, 0x45, 0xf6, 0x6f, 0x1d, 0xe0, 0x44, 0x1d, 0x8f, 0x99, 0xa0, 0xd8, 0x02, 0xe3, + 0x74, 0xea, 0xb3, 0xaa, 0x56, 0xd7, 0x1a, 0xe5, 0x96, 0xd5, 0x9c, 0x71, 0x16, 0x88, 0xe6, 0x31, + 0x0b, 0x43, 0x7a, 0xc1, 0x24, 0x8a, 0x28, 0x2c, 0xee, 0x42, 0xee, 0x35, 0x13, 0xd4, 0x1d, 0x85, + 0x55, 0xbd, 0xae, 0x35, 0x8a, 0xad, 0xda, 0x5d, 0xda, 0x0c, 0x40, 0xe6, 0x48, 0xfb, 0x8f, 0x06, + 0xc5, 0x44, 0x2b, 0xcc, 0x83, 0x71, 0xc2, 0x3d, 0x66, 0xa6, 0xb0, 0x04, 0x85, 0x0e, 0x0f, 0xc5, + 0xdb, 0x09, 0x0b, 0xa6, 0xa6, 0x86, 0x08, 0xe5, 0x38, 0x24, 0xcc, 0x1f, 0x4d, 0x4d, 0x1d, 0x1f, + 0x41, 0x45, 0xe6, 0xde, 0xf9, 0x7d, 0x2a, 0xd8, 0x09, 0x17, 0xee, 0x27, 0xd7, 0xa1, 0xc2, 0xe5, + 0x9e, 0x99, 0xc6, 0x1a, 0x3c, 0x90, 0xb5, 0x63, 0xfe, 0x99, 0xf5, 0x97, 0x4a, 0xc6, 0xbc, 0xd4, + 0x9b, 0x78, 0xce, 0x60, 0xa9, 0x94, 0xc1, 0x32, 0x80, 0x2c, 0xbd, 0x1f, 0x70, 0x3a, 0x76, 0xcd, + 0x2c, 0x6e, 0xc1, 0xe6, 0x22, 0x8e, 0xae, 0xcd, 0x49, 0x65, 0x3d, 0x2a, 0x06, 0xed, 0x01, 0x73, + 0x86, 0x66, 0x5e, 0x2a, 0x8b, 0xc3, 0x08, 0x52, 0xc0, 0x27, 0x50, 0x5b, 0xaf, 0xec, 0xc0, 0x19, + 0x9a, 0x60, 0x7f, 0xd7, 0xe0, 0xde, 0x1d, 0x53, 0xf0, 0x3e, 0x64, 0xce, 0x7c, 0xaf, 0xeb, 0x2b, + 0xd7, 0x4b, 0x24, 0x0a, 0x70, 0x0f, 0x8a, 0x5d, 0x7f, 0xef, 0xc0, 0xeb, 0xf7, 0x78, 0x20, 0xa4, + 0xb5, 0xe9, 0x46, 0xb1, 0x85, 0x73, 0x6b, 0x17, 0x25, 0x92, 0x84, 0x45, 0xac, 0xfd, 0x98, 0x65, + 0xac, 0xb2, 0xf6, 0x13, 0xac, 0x18, 0x86, 0x16, 0x00, 0x61, 0x23, 0x3a, 0x8d, 0x64, 0x64, 0xea, + 0xe9, 0x46, 0x89, 0x24, 0x32, 0x58, 0x85, 0x9c, 0xc3, 0x27, 0x9e, 0x60, 0x41, 0x35, 0xad, 0x34, + 0xce, 0x43, 0x7b, 0x07, 0x60, 0x71, 0x3d, 0x96, 0x41, 0x8f, 0xc7, 0xd0, 0xbb, 0x3e, 0x22, 0x18, + 0x32, 0xaf, 0xf6, 0xa2, 0x44, 0xd4, 0xd9, 0x7e, 0x29, 0x19, 0xfb, 0x09, 0x46, 0xc7, 0x55, 0x0c, + 0x83, 0xe8, 0x1d, 0x57, 0xc6, 0x47, 0x5c, 0xe1, 0x0d, 0xa2, 0x1f, 0xf1, 0xb8, 0x43, 0x3a, 0xd1, + 0xe1, 0x6a, 0xbe, 0xb2, 0x3d, 0xd7, 0xbb, 0xf8, 0xff, 0xca, 0x4a, 0xc4, 0x9a, 0x95, 0x45, 0x30, + 0x4e, 0xdd, 0x31, 0x9b, 0xdd, 0xa3, 0xce, 0xb6, 0x7d, 0x67, 0x21, 0x25, 0xd9, 0x4c, 0x61, 0x01, + 0x32, 0xd1, 0xf3, 0x6a, 0xf6, 0x47, 0xd8, 0x8c, 0xfa, 0x76, 0xa8, 0xd7, 0x0f, 0x07, 0x74, 0xc8, + 0xf0, 0xc5, 0x62, 0xfb, 0x35, 0xb5, 0xfd, 0x2b, 0x0a, 0x62, 0xe4, 0xea, 0x27, 0x20, 0x45, 0x74, + 0xc6, 0xd4, 0x51, 0x22, 0x36, 0x88, 0x3a, 0xdb, 0x3f, 0x34, 0xa8, 0xac, 0xe7, 0x49, 0x78, 0x9b, + 0x05, 0x42, 0xdd, 0xb2, 0x41, 0xd4, 0x19, 0x9f, 0x41, 0xb9, 0xeb, 0xb9, 0xc2, 0xa5, 0x82, 0x07, + 0x5d, 0xaf, 0xcf, 0xae, 0x66, 0x4e, 0xaf, 0x64, 0x25, 0x8e, 0xb0, 0xd0, 0xe7, 0x5e, 0x9f, 0xcd, + 0x70, 0x91, 0x9f, 0x2b, 0x59, 0xac, 0x40, 0xb6, 0xcd, 0xf9, 0xd0, 0x65, 0x55, 0x43, 0x39, 0x33, + 0x8b, 0x62, 0xbf, 0x32, 0x0b, 0xbf, 0xb0, 0x0e, 0x45, 0xa9, 0xe1, 0x8c, 0x05, 0xa1, 0xcb, 0xbd, + 0x6a, 0x5e, 0x35, 0x4c, 0xa6, 0x0e, 0x8d, 0x7c, 0xd6, 0xcc, 0x1d, 0x1a, 0xf9, 0x9c, 0x99, 0xb7, + 0xbf, 0xea, 0x50, 0x8a, 0x06, 0x6b, 0x73, 0x4f, 0x04, 0x7c, 0x84, 0xcf, 0x97, 0xde, 0xed, 0xe9, + 0xb2, 0x6b, 0x33, 0xd0, 0x9a, 0xa7, 0xdb, 0x81, 0xad, 0x78, 0x38, 0xb5, 0xa1, 0xc9, 0xb9, 0xd7, + 0x95, 0x24, 0x23, 0x1e, 0x33, 0xc1, 0x88, 0x1c, 0x58, 0x57, 0xc2, 0xc7, 0x50, 0x50, 0xd1, 0x29, + 0xef, 0xfa, 0xca, 0x89, 0x12, 0x59, 0x24, 0xe4, 0xe0, 0x2a, 0x78, 0x13, 0xf0, 0xb1, 0xfa, 0x5a, + 0xd4, 0xe0, 0x89, 0x94, 0xdd, 0xf9, 0xd7, 0xbf, 0xad, 0x02, 0xd8, 0x0e, 0x18, 0x15, 0x4c, 0xa1, + 0x09, 0xbb, 0x9c, 0xb0, 0x50, 0x98, 0x1a, 0x3e, 0x84, 0xad, 0xa5, 0xbc, 0x94, 0x14, 0x32, 0x53, + 0x7f, 0xb5, 0xfb, 0xed, 0xc6, 0xd2, 0xae, 0x6f, 0x2c, 0xed, 0xd7, 0x8d, 0xa5, 0x7d, 0xb9, 0xb5, + 0x52, 0xd7, 0xb7, 0x56, 0xea, 0xe7, 0xad, 0x95, 0xfa, 0x50, 0xbb, 0x70, 0xc5, 0x60, 0x72, 0xde, + 0x74, 0xf8, 0x78, 0x3b, 0x1c, 0x51, 0x67, 0x38, 0xb8, 0xdc, 0x8e, 0x2c, 0x3c, 0xcf, 0xaa, 0x5f, + 0xfc, 0xee, 0xdf, 0x00, 0x00, 0x00, 0xff, 0xff, 0xa3, 0xa5, 0xef, 0x45, 0xf2, 0x05, 0x00, 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -973,6 +981,11 @@ func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) _ = i var l int _ = l + if m.CertVersion != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.CertVersion)) + i-- + dAtA[i] = 0x40 + } if m.Time != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Time)) i-- @@ -1199,6 +1212,9 @@ func (m *NebulaHandshakeDetails) Size() (n int) { if m.Time != 0 { n += 1 + sovNebula(uint64(m.Time)) } + if m.CertVersion != 0 { + n += 1 + sovNebula(uint64(m.CertVersion)) + } return n } @@ -2111,6 +2127,25 @@ func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error { break } } + case 8: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType) + } + m.CertVersion = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.CertVersion |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) diff --git a/nebula.proto b/nebula.proto index 88e33b7e9..4dc15f193 100644 --- a/nebula.proto +++ b/nebula.proto @@ -62,6 +62,7 @@ message NebulaHandshakeDetails { uint32 ResponderIndex = 3; uint64 Cookie = 4; uint64 Time = 5; + uint32 CertVersion = 8; // reserved for WIP multiport reserved 6, 7; } diff --git a/outside.go b/outside.go index 6a71fe77f..c83d77cdb 100644 --- a/outside.go +++ b/outside.go @@ -7,9 +7,7 @@ import ( "net/netip" "time" - "github.com/flynn/noise" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" @@ -492,27 +490,3 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N f.outside.WriteTo(msg, endpoint) } */ - -func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) { - pk := h.PeerStatic() - - if pk == nil { - return nil, errors.New("no peer static key was present") - } - - if rawCertBytes == nil { - return nil, errors.New("provided payload was empty") - } - - c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk) - if err != nil { - return nil, fmt.Errorf("error unmarshaling cert: %w", err) - } - - cc, err := caPool.VerifyCertificate(time.Now(), c) - if err != nil { - return nil, fmt.Errorf("certificate validation failed: %w", err) - } - - return cc, nil -} From 5bfb9c4c9fb82f358529653b92e1c175c3102474 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 13 Sep 2024 13:50:46 -0500 Subject: [PATCH 02/17] Ignore private key verification for pkcs11 private keys --- pki.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pki.go b/pki.go index fe64ea5ee..490b30c8b 100644 --- a/pki.go +++ b/pki.go @@ -211,8 +211,12 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, fmt.Errorf("no networks encoded in certificate") } - if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { - return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + if isPkcs11 { + //TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } } return newCertState(nebulaCert, isPkcs11, rawKey) From 442fe47416374d24f949dec56434e84fca282b2d Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 16 Sep 2024 17:27:39 -0500 Subject: [PATCH 03/17] Fixup nebula-cert tests for v1 certs --- cmd/nebula-cert/ca.go | 66 +++++++++++++++-------- cmd/nebula-cert/ca_test.go | 34 +++++++----- cmd/nebula-cert/sign.go | 100 ++++++++++++++++++++++++----------- cmd/nebula-cert/sign_test.go | 79 +++++++++++++++------------ 4 files changed, 180 insertions(+), 99 deletions(-) diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index d098160c8..78416698a 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -27,34 +27,43 @@ type caFlags struct { outCertPath *string outQRPath *string groups *string - ips *string - subnets *string + networks *string + unsafeNetworks *string argonMemory *uint argonIterations *uint argonParallelism *uint encryption *bool + version *uint curve *string p11url *string + + // Deprecated options + ips *string + subnets *string } func newCaFlags() *caFlags { cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)} cf.set.Usage = func() {} cf.name = cf.set.String("name", "", "Required: name of the certificate authority") + cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use") cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to") cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to") cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") - cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses") - cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets") + cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks") + cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks") cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)") cf.p11url = p11Flag(cf.set) + + cf.ips = cf.set.String("ips", "", "Deprecated, see -networks") + cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks") return &cf } @@ -113,36 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } } - var ips []netip.Prefix - if *cf.ips != "" { - for _, rs := range strings.Split(*cf.ips, ",") { + version := cert.Version(*cf.version) + if version != cert.Version1 && version != cert.Version2 { + return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) + } + + var networks []netip.Prefix + if *cf.networks == "" && *cf.ips != "" { + // Pull up deprecated -ips flag if needed + *cf.networks = *cf.ips + } + + if *cf.networks != "" { + for _, rs := range strings.Split(*cf.networks, ",") { rs := strings.Trim(rs, " ") if rs != "" { n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid ip definition: %s", err) + return newHelpErrorf("invalid -networks definition: %s", rs) } - if !n.Addr().Is4() { - return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs) } - ips = append(ips, n) + networks = append(networks, n) } } } - var subnets []netip.Prefix - if *cf.subnets != "" { - for _, rs := range strings.Split(*cf.subnets, ",") { + var unsafeNetworks []netip.Prefix + if *cf.unsafeNetworks == "" && *cf.subnets != "" { + // Pull up deprecated -subnets flag if needed + *cf.unsafeNetworks = *cf.subnets + } + + if *cf.unsafeNetworks != "" { + for _, rs := range strings.Split(*cf.unsafeNetworks, ",") { rs := strings.Trim(rs, " ") if rs != "" { n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid subnet definition: %s", err) + return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) } - if !n.Addr().Is4() { - return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs) } - subnets = append(subnets, n) + unsafeNetworks = append(unsafeNetworks, n) } } } @@ -222,11 +246,11 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } t := &cert.TBSCertificate{ - Version: cert.Version1, + Version: version, Name: *cf.name, Groups: groups, - Networks: ips, - UnsafeNetworks: subnets, + Networks: networks, + UnsafeNetworks: unsafeNetworks, NotBefore: time.Now(), NotAfter: time.Now().Add(*cf.duration), PublicKey: pub, diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 06a24edd2..d32044340 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -43,9 +43,11 @@ func Test_caHelp(t *testing.T) { " -groups string\n"+ " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ " -ips string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+ + " Deprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the certificate authority\n"+ + " -networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to (default \"ca.crt\")\n"+ " -out-key string\n"+ @@ -54,7 +56,11 @@ func Test_caHelp(t *testing.T) { " \tOptional: output a qr code image (png) of the certificate\n"+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n", + " \tDeprecated, see -unsafe-networks\n"+ + " -unsafe-networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+ + " -version uint\n"+ + " \tOptional: version of the certificate format to use (default 2)\n", ob.String(), ) } @@ -83,25 +89,25 @@ func Test_ca(t *testing.T) { // required args assertHelpError(t, ca( - []string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, + []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, ), "-name is required") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only ips - assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only subnets - assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // failed key write ob.Reset() eb.Reset() - args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} + args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -114,7 +120,7 @@ func Test_ca(t *testing.T) { // failed cert write ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -128,7 +134,7 @@ func Test_ca(t *testing.T) { // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -161,7 +167,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -189,7 +195,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -199,7 +205,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, "", eb.String()) @@ -209,13 +215,13 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -224,7 +230,7 @@ func Test_ca(t *testing.T) { os.Remove(keyF.Name()) ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 13e807f3b..345c0f3c5 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -18,36 +18,46 @@ import ( ) type signFlags struct { - set *flag.FlagSet - caKeyPath *string - caCertPath *string - name *string - ip *string - duration *time.Duration - inPubPath *string - outKeyPath *string - outCertPath *string - outQRPath *string - groups *string - subnets *string - p11url *string + set *flag.FlagSet + version *uint + caKeyPath *string + caCertPath *string + name *string + networks *string + unsafeNetworks *string + duration *time.Duration + inPubPath *string + outKeyPath *string + outCertPath *string + outQRPath *string + groups *string + + p11url *string + + // Deprecated options + ip *string + subnets *string } func newSignFlags() *signFlags { sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf.set.Usage = func() {} + sf.version = sf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") - sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert") + sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert") + sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for") sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key") sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to") sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to") sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups") - sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for") sf.p11url = p11Flag(sf.set) + + sf.ip = sf.set.String("ip", "", "Deprecated, see -networks") + sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks") return &sf } @@ -78,6 +88,11 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return newHelpErrorf("cannot set both -in-pub and -out-key") } + version := cert.Version(*sf.version) + if version != cert.Version1 && version != cert.Version2 { + return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) + } + var curve cert.Curve var caKey []byte @@ -146,12 +161,30 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 } - network, err := netip.ParsePrefix(*sf.ip) - if err != nil { - return newHelpErrorf("invalid ip definition: %s", *sf.ip) + var networks []netip.Prefix + if *sf.networks == "" && *sf.ip != "" { + // Pull up deprecated -ip flag if needed + *sf.networks = *sf.ip + } + + if *sf.networks != "" { + for _, rs := range strings.Split(*sf.networks, ",") { + rs := strings.Trim(rs, " ") + if rs != "" { + n, err := netip.ParsePrefix(rs) + if err != nil { + return newHelpErrorf("invalid -networks definition: %s", rs) + } + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs) + } + networks = append(networks, n) + } + } } - if !network.Addr().Is4() { - return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip) + + if len(networks) > 1 && version == cert.Version1 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") } var groups []string @@ -164,19 +197,24 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - var subnets []netip.Prefix - if *sf.subnets != "" { - for _, rs := range strings.Split(*sf.subnets, ",") { + var unsafeNetworks []netip.Prefix + if *sf.unsafeNetworks == "" && *sf.subnets != "" { + // Pull up deprecated -subnets flag if needed + *sf.unsafeNetworks = *sf.subnets + } + + if *sf.unsafeNetworks != "" { + for _, rs := range strings.Split(*sf.unsafeNetworks, ",") { rs := strings.Trim(rs, " ") if rs != "" { - s, err := netip.ParsePrefix(rs) + n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid subnet definition: %s", rs) + return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) } - if !s.Addr().Is4() { - return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -unsafe-networks definition: can only be ipv4, have %s", rs) } - subnets = append(subnets, s) + unsafeNetworks = append(unsafeNetworks, n) } } } @@ -219,11 +257,11 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } t := &cert.TBSCertificate{ - Version: cert.Version1, + Version: version, Name: *sf.name, - Networks: []netip.Prefix{network}, + Networks: networks, Groups: groups, - UnsafeNetworks: subnets, + UnsafeNetworks: unsafeNetworks, NotBefore: time.Now(), NotAfter: time.Now().Add(*sf.duration), PublicKey: pub, diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index b68434df7..6985f2473 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -39,9 +39,11 @@ func Test_signHelp(t *testing.T) { " -in-pub string\n"+ " \tOptional (if out-key not set): path to read a previously generated public key\n"+ " -ip string\n"+ - " \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+ + " \tDeprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the cert, usually a hostname\n"+ + " -networks string\n"+ + " \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to\n"+ " -out-key string\n"+ @@ -50,7 +52,11 @@ func Test_signHelp(t *testing.T) { " \tOptional: output a qr code image (png) of the certificate\n"+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n", + " \tDeprecated, see -unsafe-networks\n"+ + " -unsafe-networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ + " -version uint\n"+ + " \tOptional: version of the certificate format to use (default 2)\n", ob.String(), ) } @@ -77,20 +83,20 @@ func Test_signCert(t *testing.T) { // required args assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, ), "-name is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, ), "-ip is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // cannot set -in-pub and -out-key assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, ), "cannot set both -in-pub and -out-key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -98,7 +104,7 @@ func Test_signCert(t *testing.T) { // failed to read key ob.Reset() eb.Reset() - args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key @@ -108,7 +114,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(caKeyF.Name()) - args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -120,7 +126,7 @@ func Test_signCert(t *testing.T) { caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)) // failed to read cert - args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -132,7 +138,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(caCrtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -143,7 +149,7 @@ func Test_signCert(t *testing.T) { caCrtF.Write(b) // failed to read pub - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -155,7 +161,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(inPubF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -169,30 +175,37 @@ func Test_signCert(t *testing.T) { // bad ip cidr ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: a1.1.1.1/24") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // bad subnet cidr ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: a") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: can only be ipv4, have 100::100/100") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -205,7 +218,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -213,7 +226,7 @@ func Test_signCert(t *testing.T) { // failed key write ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -226,7 +239,7 @@ func Test_signCert(t *testing.T) { // failed cert write ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -240,7 +253,7 @@ func Test_signCert(t *testing.T) { // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -283,7 +296,7 @@ func Test_signCert(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -300,7 +313,7 @@ func Test_signCert(t *testing.T) { eb.Reset() os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -308,14 +321,14 @@ func Test_signCert(t *testing.T) { // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -323,14 +336,14 @@ func Test_signCert(t *testing.T) { // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -362,7 +375,7 @@ func Test_signCert(t *testing.T) { caCrtF.Write(b) // test with the proper password - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -372,7 +385,7 @@ func Test_signCert(t *testing.T) { eb.Reset() testpw.password = []byte("invalid password") - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -381,7 +394,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) @@ -391,7 +404,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, errpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) From 76170ff850bb760429ac5bad744512a48a73350e Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 18 Sep 2024 16:11:07 -0500 Subject: [PATCH 04/17] Initial dual cert version nebula-cert support --- cert/cert_v1.go | 244 ++++++++++++++++------------------------ cert/cert_v2.go | 1 + cert/sign.go | 13 +++ cmd/nebula-cert/sign.go | 168 ++++++++++++++++++--------- 4 files changed, 228 insertions(+), 198 deletions(-) diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 83f7e3b54..62c1cc66f 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -28,17 +28,17 @@ type certificateV1 struct { } type detailsV1 struct { - Name string - Ips []netip.Prefix - Subnets []netip.Prefix - Groups []string - NotBefore time.Time - NotAfter time.Time - PublicKey []byte - IsCA bool - Issuer string + name string + networks []netip.Prefix + unsafeNetworks []netip.Prefix + groups []string + notBefore time.Time + notAfter time.Time + publicKey []byte + isCA bool + issuer string - Curve Curve + curve Curve } type m map[string]interface{} @@ -48,39 +48,39 @@ func (c *certificateV1) Version() Version { } func (c *certificateV1) Curve() Curve { - return c.details.Curve + return c.details.curve } func (c *certificateV1) Groups() []string { - return c.details.Groups + return c.details.groups } func (c *certificateV1) IsCA() bool { - return c.details.IsCA + return c.details.isCA } func (c *certificateV1) Issuer() string { - return c.details.Issuer + return c.details.issuer } func (c *certificateV1) Name() string { - return c.details.Name + return c.details.name } func (c *certificateV1) Networks() []netip.Prefix { - return c.details.Ips + return c.details.networks } func (c *certificateV1) NotAfter() time.Time { - return c.details.NotAfter + return c.details.notAfter } func (c *certificateV1) NotBefore() time.Time { - return c.details.NotBefore + return c.details.notBefore } func (c *certificateV1) PublicKey() []byte { - return c.details.PublicKey + return c.details.publicKey } func (c *certificateV1) Signature() []byte { @@ -88,7 +88,7 @@ func (c *certificateV1) Signature() []byte { } func (c *certificateV1) UnsafeNetworks() []netip.Prefix { - return c.details.Subnets + return c.details.unsafeNetworks } func (c *certificateV1) Fingerprint() (string, error) { @@ -106,7 +106,7 @@ func (c *certificateV1) CheckSignature(key []byte) bool { if err != nil { return false } - switch c.details.Curve { + switch c.details.curve { case Curve_CURVE25519: return ed25519.Verify(key, b, c.signature) case Curve_P256: @@ -120,14 +120,14 @@ func (c *certificateV1) CheckSignature(key []byte) bool { } func (c *certificateV1) Expired(t time.Time) bool { - return c.details.NotBefore.After(t) || c.details.NotAfter.Before(t) + return c.details.notBefore.After(t) || c.details.notAfter.Before(t) } func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { - if curve != c.details.Curve { + if curve != c.details.curve { return fmt.Errorf("curve in cert and private key supplied don't match") } - if c.details.IsCA { + if c.details.isCA { switch curve { case Curve_CURVE25519: // the call to PublicKey below will panic slice bounds out of range otherwise @@ -135,7 +135,7 @@ func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") } - if !ed25519.PublicKey(c.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { + if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) { return fmt.Errorf("public key in cert and private key supplied don't match") } case Curve_P256: @@ -144,7 +144,7 @@ func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { return fmt.Errorf("cannot parse private key as P256") } pub := privkey.PublicKey().Bytes() - if !bytes.Equal(pub, c.details.PublicKey) { + if !bytes.Equal(pub, c.details.publicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } default: @@ -170,7 +170,7 @@ func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { default: return fmt.Errorf("invalid curve: %s", curve) } - if !bytes.Equal(pub, c.details.PublicKey) { + if !bytes.Equal(pub, c.details.publicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } @@ -180,97 +180,49 @@ func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { // getRawDetails marshals the raw details into protobuf ready struct func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails { rd := &RawNebulaCertificateDetails{ - Name: c.details.Name, - Groups: c.details.Groups, - NotBefore: c.details.NotBefore.Unix(), - NotAfter: c.details.NotAfter.Unix(), - PublicKey: make([]byte, len(c.details.PublicKey)), - IsCA: c.details.IsCA, - Curve: c.details.Curve, + Name: c.details.name, + Groups: c.details.groups, + NotBefore: c.details.notBefore.Unix(), + NotAfter: c.details.notAfter.Unix(), + PublicKey: make([]byte, len(c.details.publicKey)), + IsCA: c.details.isCA, + Curve: c.details.curve, } - for _, ipNet := range c.details.Ips { + for _, ipNet := range c.details.networks { mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) } - for _, ipNet := range c.details.Subnets { + for _, ipNet := range c.details.unsafeNetworks { mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) } - copy(rd.PublicKey, c.details.PublicKey[:]) + copy(rd.PublicKey, c.details.publicKey[:]) // I know, this is terrible - rd.Issuer, _ = hex.DecodeString(c.details.Issuer) + rd.Issuer, _ = hex.DecodeString(c.details.issuer) return rd } func (c *certificateV1) String() string { - if c == nil { - return "Certificate {}\n" - } - - s := "NebulaCertificate {\n" - s += "\tDetails {\n" - s += fmt.Sprintf("\t\tName: %v\n", c.details.Name) - - if len(c.details.Ips) > 0 { - s += "\t\tIps: [\n" - for _, ip := range c.details.Ips { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tIps: []\n" - } - - if len(c.details.Subnets) > 0 { - s += "\t\tSubnets: [\n" - for _, ip := range c.details.Subnets { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tSubnets: []\n" - } - - if len(c.details.Groups) > 0 { - s += "\t\tGroups: [\n" - for _, g := range c.details.Groups { - s += fmt.Sprintf("\t\t\t\"%v\"\n", g) - } - s += "\t\t]\n" - } else { - s += "\t\tGroups: []\n" - } - - s += fmt.Sprintf("\t\tNot before: %v\n", c.details.NotBefore) - s += fmt.Sprintf("\t\tNot After: %v\n", c.details.NotAfter) - s += fmt.Sprintf("\t\tIs CA: %v\n", c.details.IsCA) - s += fmt.Sprintf("\t\tIssuer: %s\n", c.details.Issuer) - s += fmt.Sprintf("\t\tPublic key: %x\n", c.details.PublicKey) - s += fmt.Sprintf("\t\tCurve: %s\n", c.details.Curve) - s += "\t}\n" - fp, err := c.Fingerprint() - if err == nil { - s += fmt.Sprintf("\tFingerprint: %s\n", fp) + b, err := json.MarshalIndent(c.marshalJSON(), "", "\t") + if err != nil { + return "" } - s += fmt.Sprintf("\tSignature: %x\n", c.Signature()) - s += "}" - - return s + return string(b) } func (c *certificateV1) MarshalForHandshakes() ([]byte, error) { - pubKey := c.details.PublicKey - c.details.PublicKey = nil + pubKey := c.details.publicKey + c.details.publicKey = nil rawCertNoKey, err := c.Marshal() if err != nil { return nil, err } - c.details.PublicKey = pubKey + c.details.publicKey = pubKey return rawCertNoKey, nil } @@ -292,64 +244,68 @@ func (c *certificateV1) MarshalPEM() ([]byte, error) { } func (c *certificateV1) MarshalJSON() ([]byte, error) { + return json.Marshal(c.marshalJSON()) +} + +func (c *certificateV1) marshalJSON() m { fp, _ := c.Fingerprint() - jc := m{ + return m{ + "version": Version1, "details": m{ - "name": c.details.Name, - "ips": c.details.Ips, - "subnets": c.details.Subnets, - "groups": c.details.Groups, - "notBefore": c.details.NotBefore, - "notAfter": c.details.NotAfter, - "publicKey": fmt.Sprintf("%x", c.details.PublicKey), - "isCa": c.details.IsCA, - "issuer": c.details.Issuer, - "curve": c.details.Curve.String(), + "name": c.details.name, + "networks": c.details.networks, + "unsafeNetworks": c.details.unsafeNetworks, + "groups": c.details.groups, + "notBefore": c.details.notBefore, + "notAfter": c.details.notAfter, + "publicKey": fmt.Sprintf("%x", c.details.publicKey), + "isCa": c.details.isCA, + "issuer": c.details.issuer, + "curve": c.details.curve.String(), }, "fingerprint": fp, "signature": fmt.Sprintf("%x", c.Signature()), } - return json.Marshal(jc) } func (c *certificateV1) Copy() Certificate { nc := &certificateV1{ details: detailsV1{ - Name: c.details.Name, - Groups: make([]string, len(c.details.Groups)), - Ips: make([]netip.Prefix, len(c.details.Ips)), - Subnets: make([]netip.Prefix, len(c.details.Subnets)), - NotBefore: c.details.NotBefore, - NotAfter: c.details.NotAfter, - PublicKey: make([]byte, len(c.details.PublicKey)), - IsCA: c.details.IsCA, - Issuer: c.details.Issuer, - Curve: c.details.Curve, + name: c.details.name, + groups: make([]string, len(c.details.groups)), + networks: make([]netip.Prefix, len(c.details.networks)), + unsafeNetworks: make([]netip.Prefix, len(c.details.unsafeNetworks)), + notBefore: c.details.notBefore, + notAfter: c.details.notAfter, + publicKey: make([]byte, len(c.details.publicKey)), + isCA: c.details.isCA, + issuer: c.details.issuer, + curve: c.details.curve, }, signature: make([]byte, len(c.signature)), } copy(nc.signature, c.signature) - copy(nc.details.Groups, c.details.Groups) - copy(nc.details.PublicKey, c.details.PublicKey) - copy(nc.details.Ips, c.details.Ips) - copy(nc.details.Subnets, c.details.Subnets) + copy(nc.details.groups, c.details.groups) + copy(nc.details.publicKey, c.details.publicKey) + copy(nc.details.networks, c.details.networks) + copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) return nc } func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error { c.details = detailsV1{ - Name: t.Name, - Ips: t.Networks, - Subnets: t.UnsafeNetworks, - Groups: t.Groups, - NotBefore: t.NotBefore, - NotAfter: t.NotAfter, - PublicKey: t.PublicKey, - IsCA: t.IsCA, - Curve: t.Curve, - Issuer: t.issuer, + name: t.Name, + networks: t.Networks, + unsafeNetworks: t.UnsafeNetworks, + groups: t.Groups, + notBefore: t.NotBefore, + notAfter: t.NotAfter, + publicKey: t.PublicKey, + isCA: t.IsCA, + curve: t.Curve, + issuer: t.issuer, } return nil @@ -394,28 +350,28 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) nc := certificateV1{ details: detailsV1{ - Name: rc.Details.Name, - Groups: make([]string, len(rc.Details.Groups)), - Ips: make([]netip.Prefix, len(rc.Details.Ips)/2), - Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2), - NotBefore: time.Unix(rc.Details.NotBefore, 0), - NotAfter: time.Unix(rc.Details.NotAfter, 0), - PublicKey: make([]byte, len(rc.Details.PublicKey)), - IsCA: rc.Details.IsCA, - Curve: rc.Details.Curve, + name: rc.Details.Name, + groups: make([]string, len(rc.Details.Groups)), + networks: make([]netip.Prefix, len(rc.Details.Ips)/2), + unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2), + notBefore: time.Unix(rc.Details.NotBefore, 0), + notAfter: time.Unix(rc.Details.NotAfter, 0), + publicKey: make([]byte, len(rc.Details.PublicKey)), + isCA: rc.Details.IsCA, + curve: rc.Details.Curve, }, signature: make([]byte, len(rc.Signature)), } copy(nc.signature, rc.Signature) - copy(nc.details.Groups, rc.Details.Groups) - nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer) + copy(nc.details.groups, rc.Details.Groups) + nc.details.issuer = hex.EncodeToString(rc.Details.Issuer) if len(publicKey) > 0 { - nc.details.PublicKey = publicKey + nc.details.publicKey = publicKey } - copy(nc.details.PublicKey, rc.Details.PublicKey) + copy(nc.details.publicKey, rc.Details.PublicKey) var ip netip.Addr for i, rawIp := range rc.Details.Ips { @@ -423,7 +379,7 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) ip = int2addr(rawIp) } else { ones, _ := net.IPMask(int2ip(rawIp)).Size() - nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones) + nc.details.networks[i/2] = netip.PrefixFrom(ip, ones) } } @@ -432,7 +388,7 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) ip = int2addr(rawIp) } else { ones, _ := net.IPMask(int2ip(rawIp)).Size() - nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones) + nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones) } } diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 96a1b9b2b..04f221339 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -281,6 +281,7 @@ func (c *certificateV2) marshalJSON() m { "isCa": c.details.isCA, "issuer": c.details.issuer, }, + "version": Version2, "publicKey": fmt.Sprintf("%x", c.publicKey), "curve": c.curve.String(), "fingerprint": fp, diff --git a/cert/sign.go b/cert/sign.go index dcb5a5d0e..2f768d4ec 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -9,6 +9,7 @@ import ( "fmt" "math/big" "net/netip" + "slices" "time" "github.com/slackhq/nebula/pkclient" @@ -62,6 +63,7 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien } //TODO: make sure we have all minimum properties to sign, like a public key + //TODO: we need to verify networks and unsafe networks (no duplicates, max of 1 of each version for v2 certs if signer != nil { if t.IsCA { @@ -84,6 +86,9 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien } } + slices.SortFunc(t.Networks, comparePrefix) + slices.SortFunc(t.UnsafeNetworks, comparePrefix) + var c beingSignedCertificate switch t.Version { case Version1: @@ -152,3 +157,11 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien return sc, nil } + +func comparePrefix(a, b netip.Prefix) int { + addr := a.Addr().Compare(b.Addr()) + if addr == 0 { + return a.Bits() - b.Bits() + } + return addr +} diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 345c0f3c5..1b62cb9d3 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -3,6 +3,7 @@ package main import ( "crypto/ecdh" "crypto/rand" + "errors" "flag" "fmt" "io" @@ -42,7 +43,7 @@ type signFlags struct { func newSignFlags() *signFlags { sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf.set.Usage = func() {} - sf.version = sf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use") + sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") @@ -81,15 +82,12 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if err := mustFlagString("name", sf.name); err != nil { return err } - if err := mustFlagString("ip", sf.ip); err != nil { - return err - } if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { return newHelpErrorf("cannot set both -in-pub and -out-key") } version := cert.Version(*sf.version) - if version != cert.Version1 && version != cert.Version2 { + if version != 0 && version != cert.Version1 && version != cert.Version2 { return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) } @@ -106,14 +104,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) // naively attempt to decode the private key as though it is not encrypted caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) - if err == cert.ErrPrivateKeyEncrypted { + if errors.Is(err, cert.ErrPrivateKeyEncrypted) { // ask for a passphrase until we get one var passphrase []byte for i := 0; i < 5; i++ { out.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() - if err == ErrNoTerminal { + if errors.Is(err, ErrNoTerminal) { return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") } else if err != nil { return fmt.Errorf("error reading password: %s", err) @@ -161,7 +159,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 } - var networks []netip.Prefix + var v4Networks []netip.Prefix + var v6Networks []netip.Prefix if *sf.networks == "" && *sf.ip != "" { // Pull up deprecated -ip flag if needed *sf.networks = *sf.ip @@ -169,41 +168,32 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if *sf.networks != "" { for _, rs := range strings.Split(*sf.networks, ",") { + //TODO: error on duplicates? Mainly only addr matters, having two of the same addr in the same or different prefix space is strange rs := strings.Trim(rs, " ") if rs != "" { n, err := netip.ParsePrefix(rs) if err != nil { return newHelpErrorf("invalid -networks definition: %s", rs) } - if version == cert.Version1 && !n.Addr().Is4() { - return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs) - } - networks = append(networks, n) - } - } - } - - if len(networks) > 1 && version == cert.Version1 { - return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") - } - var groups []string - if *sf.groups != "" { - for _, rg := range strings.Split(*sf.groups, ",") { - g := strings.TrimSpace(rg) - if g != "" { - groups = append(groups, g) + if n.Addr().Is4() { + v4Networks = append(v4Networks, n) + } else { + v6Networks = append(v6Networks, n) + } } } } - var unsafeNetworks []netip.Prefix + var v4UnsafeNetworks []netip.Prefix + var v6UnsafeNetworks []netip.Prefix if *sf.unsafeNetworks == "" && *sf.subnets != "" { // Pull up deprecated -subnets flag if needed *sf.unsafeNetworks = *sf.subnets } if *sf.unsafeNetworks != "" { + //TODO: error on duplicates? for _, rs := range strings.Split(*sf.unsafeNetworks, ",") { rs := strings.Trim(rs, " ") if rs != "" { @@ -211,10 +201,22 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if err != nil { return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) } - if version == cert.Version1 && !n.Addr().Is4() { - return newHelpErrorf("invalid -unsafe-networks definition: can only be ipv4, have %s", rs) + + if n.Addr().Is4() { + v4UnsafeNetworks = append(v4UnsafeNetworks, n) + } else { + v6UnsafeNetworks = append(v6UnsafeNetworks, n) } - unsafeNetworks = append(unsafeNetworks, n) + } + } + } + + var groups []string + if *sf.groups != "" { + for _, rg := range strings.Split(*sf.groups, ",") { + g := strings.TrimSpace(rg) + if g != "" { + groups = append(groups, g) } } } @@ -256,19 +258,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) pub, rawPriv = newKeypair(curve) } - t := &cert.TBSCertificate{ - Version: version, - Name: *sf.name, - Networks: networks, - Groups: groups, - UnsafeNetworks: unsafeNetworks, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*sf.duration), - PublicKey: pub, - IsCA: false, - Curve: curve, - } - if *sf.outKeyPath == "" { *sf.outKeyPath = *sf.name + ".key" } @@ -281,18 +270,85 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) } - var c cert.Certificate + var crts []cert.Certificate - if p11Client == nil { - c, err = t.Sign(caCert, curve, caKey) - if err != nil { - return fmt.Errorf("error while signing: %w", err) + notBefore := time.Now() + notAfter := notBefore.Add(*sf.duration) + + if version == 0 || version == cert.Version1 { + // Make sure we at least have an ip + if len(v4Networks) != 1 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") } - } else { - c, err = t.SignPkcs11(caCert, curve, p11Client) - if err != nil { - return fmt.Errorf("error while signing with PKCS#11: %w", err) + + if version == cert.Version1 { + // If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses + if len(v6Networks) > 0 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4") + } + + if len(v6UnsafeNetworks) > 0 { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4") + } + } + + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: *sf.name, + Networks: []netip.Prefix{v4Networks[0]}, + Groups: groups, + UnsafeNetworks: v4UnsafeNetworks, + NotBefore: notBefore, + NotAfter: notAfter, + PublicKey: pub, + IsCA: false, + Curve: curve, + } + + var nc cert.Certificate + if p11Client == nil { + nc, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + nc, err = t.SignPkcs11(caCert, curve, p11Client) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } + } + + crts = append(crts, nc) + } + + if version == 0 || version == cert.Version2 { + t := &cert.TBSCertificate{ + Version: cert.Version2, + Name: *sf.name, + Networks: append(v4Networks, v6Networks...), + Groups: groups, + UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...), + NotBefore: notBefore, + NotAfter: notAfter, + PublicKey: pub, + IsCA: false, + Curve: curve, + } + + var nc cert.Certificate + if p11Client == nil { + nc, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + nc, err = t.SignPkcs11(caCert, curve, p11Client) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } } + + crts = append(crts, nc) } if !isP11 && *sf.inPubPath == "" { @@ -306,9 +362,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - b, err := c.MarshalPEM() - if err != nil { - return fmt.Errorf("error while marshalling certificate: %s", err) + var b []byte + for _, c := range crts { + sb, err := c.MarshalPEM() + if err != nil { + return fmt.Errorf("error while marshalling certificate: %s", err) + } + b = append(b, sb...) } err = os.WriteFile(*sf.outCertPath, b, 0600) From fda8f9636c234a7119de1207858d87db0999f3d6 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 19 Sep 2024 21:23:03 -0500 Subject: [PATCH 05/17] Dual cert handling --- calculated_remote.go | 8 +- calculated_remote_test.go | 2 +- cert/ca_pool_test.go | 20 +- cert/cert_test.go | 58 +-- cert/cert_v1.go | 4 + cert/cert_v2.asn1 | 4 +- cert/cert_v2.go | 76 ++-- cmd/nebula-cert/print_test.go | 65 ++- cmd/nebula-cert/sign.go | 18 +- cmd/nebula-cert/sign_test.go | 8 +- connection_manager.go | 82 ++-- connection_manager_test.go | 55 ++- connection_state.go | 36 +- control.go | 46 +- control_test.go | 30 +- control_tester.go | 16 +- dns_server.go | 22 +- dns_server_test.go | 2 +- e2e/handshakes_test.go | 240 +++++----- e2e/helpers_test.go | 54 ++- e2e/router/hostmap.go | 7 +- e2e/router/router.go | 9 +- examples/config.yml | 6 + firewall.go | 4 +- firewall_test.go | 12 +- handshake_ix.go | 170 +++---- handshake_manager.go | 218 +++++---- handshake_manager_test.go | 16 +- hostmap.go | 86 ++-- hostmap_test.go | 56 +-- hostmap_tester.go | 4 +- inside.go | 20 +- interface.go | 50 ++- lighthouse.go | 590 +++++++++++++++--------- lighthouse_test.go | 142 +++--- main.go | 24 +- nebula.pb.go | 812 ++++++++++++++++++++++++++-------- nebula.proto | 31 +- outside.go | 27 +- pki.go | 384 +++++++++++++--- relay_manager.go | 185 +++++--- remote_list.go | 48 +- remote_list_test.go | 34 +- ssh.go | 28 +- udp/temp.go | 2 +- 45 files changed, 2468 insertions(+), 1343 deletions(-) diff --git a/calculated_remote.go b/calculated_remote.go index ae2ed500c..99b2b368f 100644 --- a/calculated_remote.go +++ b/calculated_remote.go @@ -38,7 +38,7 @@ func (c *calculatedRemote) String() string { return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) } -func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { +func (c *calculatedRemote) Apply(ip netip.Addr) *V4AddrPort { // Combine the masked bytes of the "mask" IP with the unmasked bytes // of the overlay IP if c.ipNet.Addr().Is4() { @@ -47,7 +47,7 @@ func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { return c.apply6(ip) } -func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort { +func (c *calculatedRemote) apply4(ip netip.Addr) *V4AddrPort { //TODO: IPV6-WORK this can be less crappy maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) mask := binary.BigEndian.Uint32(maskb[:]) @@ -58,10 +58,10 @@ func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort { b = ip.As4() intIp := binary.BigEndian.Uint32(b[:]) - return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port} + return &V4AddrPort{(maskIp & mask) | (intIp & ^mask), c.port} } -func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort { +func (c *calculatedRemote) apply6(ip netip.Addr) *V4AddrPort { //TODO: IPV6-WORK panic("Can not calculate ipv6 remote addresses") } diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 6ff1cb0bd..0d34e2fa8 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -21,5 +21,5 @@ func TestCalculatedRemoteApply(t *testing.T) { expected, err := netip.ParseAddr("192.168.1.182") assert.NoError(t, err) - assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input)) + assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.Apply(input)) } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index 053640d98..292d2f9c7 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -63,31 +63,31 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX rootCA := certificateV1{ details: detailsV1{ - Name: "nebula root ca", + name: "nebula root ca", }, } rootCA01 := certificateV1{ details: detailsV1{ - Name: "nebula root ca 01", + name: "nebula root ca 01", }, } rootCAP256 := certificateV1{ details: detailsV1{ - Name: "nebula P256 test", + name: "nebula P256 test", }, } p, err := NewCAPoolFromPEM([]byte(noNewLines)) assert.Nil(t, err) - assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) - assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name) + assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name) pp, err := NewCAPoolFromPEM([]byte(withNewLines)) assert.Nil(t, err) - assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) - assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name) + assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name) // expired cert, no valid certs ppp, err := NewCAPoolFromPEM([]byte(expired)) @@ -97,13 +97,13 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX // expired cert, with valid certs pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) assert.Equal(t, ErrExpired, err) - assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) - assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.name) + assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.name) assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") assert.Equal(t, len(pppp.CAs), 3) ppppp, err := NewCAPoolFromPEM([]byte(p256)) assert.Nil(t, err) - assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name) + assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.name) assert.Equal(t, len(ppppp.CAs), 1) } diff --git a/cert/cert_test.go b/cert/cert_test.go index b2ea406cb..b5f21761a 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -24,25 +24,25 @@ func TestMarshalingNebulaCertificate(t *testing.T) { nc := certificateV1{ details: detailsV1{ - Name: "testing", - Ips: []netip.Prefix{ + name: "testing", + networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.1/24"), mustParsePrefixUnmapped("10.1.1.2/16"), //TODO: netip cant represent this netmask //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, }, - Subnets: []netip.Prefix{ + unsafeNetworks: []netip.Prefix{ //TODO: netip cant represent this netmask //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, mustParsePrefixUnmapped("9.1.1.2/24"), mustParsePrefixUnmapped("9.1.1.3/16"), }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", }, signature: []byte("1234567890abcedfghij1234567890ab"), } @@ -55,16 +55,16 @@ func TestMarshalingNebulaCertificate(t *testing.T) { assert.Nil(t, err) assert.Equal(t, nc.signature, nc2.Signature()) - assert.Equal(t, nc.details.Name, nc2.Name()) - assert.Equal(t, nc.details.NotBefore, nc2.NotBefore()) - assert.Equal(t, nc.details.NotAfter, nc2.NotAfter()) - assert.Equal(t, nc.details.PublicKey, nc2.PublicKey()) - assert.Equal(t, nc.details.IsCA, nc2.IsCA()) + assert.Equal(t, nc.details.name, nc2.Name()) + assert.Equal(t, nc.details.notBefore, nc2.NotBefore()) + assert.Equal(t, nc.details.notAfter, nc2.NotAfter()) + assert.Equal(t, nc.details.publicKey, nc2.PublicKey()) + assert.Equal(t, nc.details.isCA, nc2.IsCA()) - assert.Equal(t, nc.details.Ips, nc2.Networks()) - assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks()) + assert.Equal(t, nc.details.networks, nc2.Networks()) + assert.Equal(t, nc.details.unsafeNetworks, nc2.UnsafeNetworks()) - assert.Equal(t, nc.details.Groups, nc2.Groups()) + assert.Equal(t, nc.details.groups, nc2.Groups()) } //func TestNebulaCertificate_Sign(t *testing.T) { @@ -154,8 +154,8 @@ func TestMarshalingNebulaCertificate(t *testing.T) { func TestNebulaCertificate_Expired(t *testing.T) { nc := certificateV1{ details: detailsV1{ - NotBefore: time.Now().Add(time.Second * -60).Round(time.Second), - NotAfter: time.Now().Add(time.Second * 60).Round(time.Second), + notBefore: time.Now().Add(time.Second * -60).Round(time.Second), + notAfter: time.Now().Add(time.Second * 60).Round(time.Second), }, } @@ -170,25 +170,25 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) { nc := certificateV1{ details: detailsV1{ - Name: "testing", - Ips: []netip.Prefix{ + name: "testing", + networks: []netip.Prefix{ mustParsePrefixUnmapped("10.1.1.1/24"), mustParsePrefixUnmapped("10.1.1.2/16"), //TODO: netip bad //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, }, - Subnets: []netip.Prefix{ + unsafeNetworks: []netip.Prefix{ //TODO: netip bad //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, mustParsePrefixUnmapped("9.1.1.2/24"), mustParsePrefixUnmapped("9.1.1.3/16"), }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), - NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), + notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", }, signature: []byte("1234567890abcedfghij1234567890ab"), } @@ -197,7 +197,7 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", + "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", string(b), ) } diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 62c1cc66f..4dc38fcef 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -14,6 +14,7 @@ import ( "fmt" "net" "net/netip" + "slices" "time" "golang.org/x/crypto/curve25519" @@ -392,6 +393,9 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) } } + slices.SortFunc(nc.details.networks, comparePrefix) + slices.SortFunc(nc.details.unsafeNetworks, comparePrefix) + return &nc, nil } diff --git a/cert/cert_v2.asn1 b/cert/cert_v2.asn1 index 32a6e735e..0aadef193 100644 --- a/cert/cert_v2.asn1 +++ b/cert/cert_v2.asn1 @@ -21,7 +21,7 @@ Details ::= SEQUENCE { name Name, -- At least 1 ipv4 or ipv6 address must be present if isCA is false - networks SEQUENCE OF Network, + networks SEQUENCE OF Network OPTIONAL, unsafeNetworks SEQUENCE OF Network OPTIONAL, groups SEQUENCE OF Name OPTIONAL, isCA BOOLEAN DEFAULT false, @@ -30,7 +30,7 @@ Details ::= SEQUENCE { notBefore Time, notAfter Time, - issuer OCTET STRING, + issuer OCTET STRING OPTIONAL, ... -- New fields can be added below here } diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 04f221339..117de22e1 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -12,6 +12,7 @@ import ( "encoding/pem" "fmt" "net/netip" + "slices" "time" "golang.org/x/crypto/cryptobyte" @@ -30,14 +31,14 @@ const ( TagCertPublicKey = 2 | classContextSpecific TagCertSignature = 3 | classContextSpecific - TagDetailsName = 0 | classContextSpecific - TagDetailsIps = 1 | classConstructed | classContextSpecific - TagDetailsSubnets = 2 | classConstructed | classContextSpecific - TagDetailsGroups = 3 | classConstructed | classContextSpecific - TagDetailsIsCA = 4 | classContextSpecific - TagDetailsNotBefore = 5 | classContextSpecific - TagDetailsNotAfter = 6 | classContextSpecific - TagDetailsIssuer = 7 | classContextSpecific + TagDetailsName = 0 | classContextSpecific + TagDetailsNetworks = 1 | classConstructed | classContextSpecific + TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific + TagDetailsGroups = 3 | classConstructed | classContextSpecific + TagDetailsIsCA = 4 | classContextSpecific + TagDetailsNotBefore = 5 | classContextSpecific + TagDetailsNotAfter = 6 | classContextSpecific + TagDetailsIssuer = 7 | classContextSpecific ) const ( @@ -47,9 +48,9 @@ const ( // MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems MaxNameLength = 253 - // MaxSubnetLength is the maximum length a subnet value can be. + // MaxNetworkLength is the maximum length a network value can be. // 16 bytes for an ipv6 address + 1 byte for the prefix length - MaxSubnetLength = 17 + MaxNetworkLength = 17 ) type certificateV2 struct { @@ -370,14 +371,14 @@ func (d *detailsV2) Marshal() ([]byte, error) { b.AddBytes([]byte(d.name)) }) - // Add the ips if any exist + // Add the networks if any exist if len(d.networks) > 0 { - b.AddASN1(TagDetailsIps, func(b *cryptobyte.Builder) { - for _, subnet := range d.networks { - sb, innerErr := subnet.MarshalBinary() + b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) { + for _, n := range d.networks { + sb, innerErr := n.MarshalBinary() if innerErr != nil { // MarshalBinary never returns an error - err = fmt.Errorf("unable to marshal ip: %w", innerErr) + err = fmt.Errorf("unable to marshal network: %w", innerErr) return } b.AddASN1OctetString(sb) @@ -385,14 +386,14 @@ func (d *detailsV2) Marshal() ([]byte, error) { }) } - // Add the subnets if any exist + // Add the unsafe networks if any exist if len(d.unsafeNetworks) > 0 { - b.AddASN1(TagDetailsSubnets, func(b *cryptobyte.Builder) { - for _, subnet := range d.unsafeNetworks { - sb, innerErr := subnet.MarshalBinary() + b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) { + for _, n := range d.unsafeNetworks { + sb, innerErr := n.MarshalBinary() if innerErr != nil { // MarshalBinary never returns an error - err = fmt.Errorf("unable to marshal subnet: %w", innerErr) + err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr) return } b.AddASN1OctetString(sb) @@ -511,47 +512,47 @@ func unmarshalDetails(b cryptobyte.String) (detailsV2, error) { return detailsV2{}, ErrBadFormat } - // Read the ip addresses + // Read the network addresses var subString cryptobyte.String var found bool - if !b.ReadOptionalASN1(&subString, &found, TagDetailsIps) { + if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) { return detailsV2{}, ErrBadFormat } - var ips []netip.Prefix + var networks []netip.Prefix var val cryptobyte.String if found { for !subString.Empty() { - if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxSubnetLength { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { return detailsV2{}, ErrBadFormat } - var ip netip.Prefix - if err := ip.UnmarshalBinary(val); err != nil { + var n netip.Prefix + if err := n.UnmarshalBinary(val); err != nil { return detailsV2{}, ErrBadFormat } - ips = append(ips, ip) + networks = append(networks, n) } } - // Read out any subnets - if !b.ReadOptionalASN1(&subString, &found, TagDetailsSubnets) { + // Read out any unsafe networks + if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) { return detailsV2{}, ErrBadFormat } - var subnets []netip.Prefix + var unsafeNetworks []netip.Prefix if found { for !subString.Empty() { - if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxSubnetLength { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { return detailsV2{}, ErrBadFormat } - var subnet netip.Prefix - if err := subnet.UnmarshalBinary(val); err != nil { + var n netip.Prefix + if err := n.UnmarshalBinary(val); err != nil { return detailsV2{}, ErrBadFormat } - subnets = append(subnets, subnet) + unsafeNetworks = append(unsafeNetworks, n) } } @@ -593,10 +594,13 @@ func unmarshalDetails(b cryptobyte.String) (detailsV2, error) { return detailsV2{}, ErrBadFormat } + slices.SortFunc(networks, comparePrefix) + slices.SortFunc(unsafeNetworks, comparePrefix) + return detailsV2{ name: string(name), - networks: ips, - unsafeNetworks: subnets, + networks: networks, + unsafeNetworks: unsafeNetworks, groups: groups, isCA: isCa, notBefore: time.Unix(notBefore, 0), diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 4c9a72db4..7d6a40702 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -87,7 +87,65 @@ func Test_printCert(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - "NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", + //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", + `{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +`, ob.String(), ) assert.Equal(t, "", eb.String()) @@ -108,7 +166,10 @@ func Test_printCert(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n", + `{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1} +{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1} +{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1} +`, ob.String(), ) assert.Equal(t, "", eb.String()) diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 1b62cb9d3..6ac045214 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -86,6 +86,17 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return newHelpErrorf("cannot set both -in-pub and -out-key") } + var v4Networks []netip.Prefix + var v6Networks []netip.Prefix + if *sf.networks == "" && *sf.ip != "" { + // Pull up deprecated -ip flag if needed + *sf.networks = *sf.ip + } + + if len(*sf.networks) == 0 { + return newHelpErrorf("-networks is required") + } + version := cert.Version(*sf.version) if version != 0 && version != cert.Version1 && version != cert.Version2 { return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) @@ -159,13 +170,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 } - var v4Networks []netip.Prefix - var v6Networks []netip.Prefix - if *sf.networks == "" && *sf.ip != "" { - // Pull up deprecated -ip flag if needed - *sf.networks = *sf.ip - } - if *sf.networks != "" { for _, rs := range strings.Split(*sf.networks, ",") { //TODO: error on duplicates? Mainly only addr matters, having two of the same addr in the same or different prefix space is strange diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 6985f2473..b4fdc43be 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -56,7 +56,7 @@ func Test_signHelp(t *testing.T) { " -unsafe-networks string\n"+ " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ " -version uint\n"+ - " \tOptional: version of the certificate format to use (default 2)\n", + " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n", ob.String(), ) } @@ -90,7 +90,7 @@ func Test_signCert(t *testing.T) { assertHelpError(t, signCert( []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, - ), "-ip is required") + ), "-networks is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -183,7 +183,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -205,7 +205,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: can only be ipv4, have 100::100/100") + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) diff --git a/connection_manager.go b/connection_manager.go index a0de842db..eecfd7d46 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, case deleteTunnel: if n.hostMap.DeleteHostInfo(hostinfo) { // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap - n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp) + n.intf.lightHouse.DeleteVpnAddr(hostinfo.vpnAddrs[0]) } case closeTunnel: @@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) relayFor := oldhostinfo.relayState.CopyAllRelayFor() for _, r := range relayFor { - existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) + existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr) var index uint32 var relayFrom netip.Addr @@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnNet.Addr() - relayTo = existing.PeerIp + relayFrom = n.intf.myVpnAddrs[0] + relayTo = existing.PeerAddr case ForwardingType: - relayFrom = existing.PeerIp - relayTo = newhostinfo.vpnIp + relayFrom = existing.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] default: // should never happen } @@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) n.relayUsedLock.RUnlock() // The relay doesn't exist at all; create some relay state and send the request. var err error - index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) + index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { n.l.WithError(err).Error("failed to migrate relay to new hostinfo") continue } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnNet.Addr() - relayTo = r.PeerIp + relayFrom = n.intf.myVpnAddrs[0] + relayTo = r.PeerAddr case ForwardingType: - relayFrom = r.PeerIp - relayTo = newhostinfo.vpnIp + relayFrom = r.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] default: // should never happen } } - //TODO: IPV6-WORK - relayFromB := relayFrom.As4() - relayToB := relayTo.As4() - // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]), - RelayToIp: binary.BigEndian.Uint32(relayToB[:]), } + + switch newhostinfo.GetCert().Certificate.Version() { + case cert.Version1: + if !relayFrom.Is4() { + n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !relayTo.Is4() { + n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := relayFrom.As4() + req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = relayTo.As4() + req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + req.RelayFromAddr = netAddrToProtoAddr(relayFrom) + req.RelayToAddr = netAddrToProtoAddr(relayTo) + default: + newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay") + continue + } + msg, err := req.Marshal() if err != nil { n.l.WithError(err).Error("failed to marshal Control message to migrate relay") } else { n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.l.WithFields(logrus.Fields{ - "relayFrom": req.RelayFromIp, - "relayTo": req.RelayToIp, + "relayFrom": req.RelayFromAddr, + "relayTo": req.RelayToAddr, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, - "vpnIp": newhostinfo.vpnIp}). + "vpnAddrs": newhostinfo.vpnAddrs}). Info("send CreateRelayRequest") } } @@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time return closeTunnel, hostinfo, nil } - primary := n.hostMap.Hosts[hostinfo.vpnIp] + primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]] mainHostInfo := true if primary != nil && primary != hostinfo { mainHostInfo = false @@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. - if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 { + //TODO: current.vpnIp should become an array of vpnIps + if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 { // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. // The remotes vpn ip is lower than mine. I will not flip. return false } - certState := n.intf.pki.GetCertState() - return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) + //TODO: we should favor v2 over v1 certificates if configured to send them + + crt := n.intf.pki.getCertificate(current.ConnectionState.myCert.Version()) + return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { n.hostMap.Lock() // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. - if n.hostMap.Hosts[current.vpnIp] == primary { + if n.hostMap.Hosts[current.vpnAddrs[0]] == primary { n.hostMap.unlockedMakePrimary(current) } n.hostMap.Unlock() @@ -473,14 +495,16 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - certState := n.intf.pki.GetCertState() - if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) { + crt := n.intf.pki.getCertificate(hostinfo.ConnectionState.myCert.Version()) + if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), crt.Signature()) { return } - n.l.WithField("vpnIp", hostinfo.vpnIp). + //TODO: we should favor v2 over v1 certificates if configured to send them + + n.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("reason", "local certificate is not current"). Info("Re-handshaking with remote") - n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) + n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) } diff --git a/connection_manager_test.go b/connection_manager_test.go index 9f222c8b4..8e2ef15ad 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -34,20 +34,19 @@ func newTestLighthouse() *LightHouse { func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -74,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &dummyCert{}, + myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -88,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { nc.Out(hostinfo.localIndexId) nc.In(hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.out, hostinfo.localIndexId) @@ -105,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // Do a final traffic check tick, the host should now be removed nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) } func Test_NewConnectionManagerTest2(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -157,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &dummyCert{}, + myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -170,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) nc.In(hostinfo.localIndexId) - assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0]) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded @@ -187,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // We saw traffic, should no longer be pending deletion nc.In(hostinfo.localIndexId) @@ -196,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) } // Check if we can disconnect the peer. @@ -210,7 +208,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) // Generate keys for CA and peer's cert. @@ -244,10 +242,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + privateKey: []byte{}, + v1Cert: &dummyCert{}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -273,7 +270,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.connectionManager = nc hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, ConnectionState: &ConnectionState{ myCert: &dummyCert{}, peerCert: cachedPeerCert, diff --git a/connection_state.go b/connection_state.go index bcc9e5d9a..cfd86eb35 100644 --- a/connection_state.go +++ b/connection_state.go @@ -26,43 +26,45 @@ type ConnectionState struct { writeLock sync.Mutex } -func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { +func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern noise.HandshakePattern) *ConnectionState { + crt := cs.GetDefaultCertificate() var dhFunc noise.DHFunc - switch certState.Certificate.Curve() { + switch crt.Curve() { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: - if certState.pkcs11Backed { + if cs.pkcs11Backed { dhFunc = noiseutil.DHP256PKCS11 } else { dhFunc = noiseutil.DHP256 } default: - l.Errorf("invalid curve: %s", certState.Certificate.Curve()) + l.Errorf("invalid curve: %s", crt.Curve()) return nil } - var cs noise.CipherSuite - if cipher == "chachapoly" { - cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) + var ncs noise.CipherSuite + if cs.cipher == "chachapoly" { + ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) } else { - cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) + ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) } - static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey} + static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} b := NewBits(ReplayWindow) // Clear out bit 0, we never transmit it and we don't want it showing as packet loss b.Update(l, 0) hs, err := noise.NewHandshakeState(noise.Config{ - CipherSuite: cs, - Random: rand.Reader, - Pattern: pattern, - Initiator: initiator, - StaticKeypair: static, - PresharedKey: psk, - PresharedKeyPlacement: pskStage, + CipherSuite: ncs, + Random: rand.Reader, + Pattern: pattern, + Initiator: initiator, + StaticKeypair: static, + //NOTE: These should come from CertState (pki.go) when we finally implement it + PresharedKey: []byte{}, + PresharedKeyPlacement: 0, }) if err != nil { return nil @@ -74,7 +76,7 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i H: hs, initiator: initiator, window: b, - myCert: certState.Certificate, + myCert: crt, } // always start the counter from 2, as packet 1 and packet 2 are handshake packets. ci.messageCounter.Add(2) diff --git a/control.go b/control.go index 839c46f99..75bdbd771 100644 --- a/control.go +++ b/control.go @@ -19,9 +19,9 @@ import ( type controlEach func(h *HostInfo) type controlHostLister interface { - QueryVpnIp(vpnIp netip.Addr) *HostInfo + QueryVpnAddr(vpnAddr netip.Addr) *HostInfo ForEachIndex(each controlEach) - ForEachVpnIp(each controlEach) + ForEachVpnAddr(each controlEach) GetPreferredRanges() []netip.Prefix } @@ -37,7 +37,7 @@ type Control struct { } type ControlHostInfo struct { - VpnIp netip.Addr `json:"vpnIp"` + VpnAddrs []netip.Addr `json:"vpnAddrs"` LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` @@ -130,16 +130,17 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found -// TODO: this should copy! func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { - if c.f.myVpnNet.Addr() == vpnIp { - return c.f.pki.GetCertState().Certificate + _, found := c.f.myVpnAddrsTable.Lookup(vpnIp) + if found { + //TODO: we might have 2 certs.... + return c.f.pki.getDefaultCertificate().Copy() } - hi := c.f.hostMap.QueryVpnIp(vpnIp) + hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { return nil } - return hi.GetCert().Certificate + return hi.GetCert().Certificate.Copy() } // CreateTunnel creates a new tunnel to the given vpn ip. @@ -149,7 +150,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) { // PrintTunnel creates a new tunnel to the given vpn ip. func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo { - hi := c.f.hostMap.QueryVpnIp(vpnIp) + hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { return nil } @@ -166,9 +167,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap { return hi.CopyCache() } -// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found +// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found // Caller should take care to Unmap() any 4in6 addresses prior to calling. -func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { +func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo { var hl controlHostLister if pending { hl = c.f.handshakeManager @@ -176,7 +177,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos hl = c.f.hostMap } - h := hl.QueryVpnIp(vpnIp) + h := hl.QueryVpnAddr(vpnAddr) if h == nil { return nil } @@ -188,7 +189,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos // SetRemoteForTunnel forces a tunnel to use a specific remote // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { - hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return nil } @@ -201,7 +202,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { - hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return false } @@ -230,14 +231,14 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { shutdown := func(h *HostInfo) { if excludeLighthouses { - if _, ok := lighthouses[h.vpnIp]; ok { + if _, ok := lighthouses[h.vpnAddrs[0]]; ok { return } } c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) - c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote). + c.l.WithField("vpnIp", h.vpnAddrs[0]).WithField("udpAddr", h.remote). Debug("Sending close tunnel message") closed++ } @@ -247,7 +248,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { // Grab the hostMap lock to access the Relays map c.f.hostMap.Lock() for _, relayingHost := range c.f.hostMap.Relays { - relayingHosts[relayingHost.vpnIp] = relayingHost + relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost } c.f.hostMap.Unlock() @@ -255,7 +256,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { // Grab the hostMap lock to access the Hosts map c.f.hostMap.Lock() for _, relayHost := range c.f.hostMap.Indexes { - if _, ok := relayingHosts[relayHost.vpnIp]; !ok { + if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok { hostInfos = append(hostInfos, relayHost) } } @@ -275,9 +276,8 @@ func (c *Control) Device() overlay.Device { } func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { - chi := ControlHostInfo{ - VpnIp: h.vpnIp, + VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)), LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), @@ -286,6 +286,10 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { CurrentRemote: h.remote, } + for i, a := range h.vpnAddrs { + chi.VpnAddrs[i] = a + } + if h.ConnectionState != nil { chi.MessageCounter = h.ConnectionState.messageCounter.Load() } @@ -300,7 +304,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { func listHostMapHosts(hl controlHostLister) []ControlHostInfo { hosts := make([]ControlHostInfo, 0) pr := hl.GetPreferredRanges() - hl.ForEachVpnIp(func(hostinfo *HostInfo) { + hl.ForEachVpnAddr(func(hostinfo *HostInfo) { hosts = append(hosts, copyHostInfo(hostinfo, pr)) }) return hosts diff --git a/control_test.go b/control_test.go index fdfc0a57e..cd6364068 100644 --- a/control_test.go +++ b/control_test.go @@ -19,7 +19,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := newHostMap(l, netip.Prefix{}) + hm := newHostMap(l) hm.preferredRanges.Store(&[]netip.Prefix{}) remote1 := netip.MustParseAddrPort("0.0.0.100:4444") @@ -36,8 +36,8 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } remotes := NewRemoteList(nil) - remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) - remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) + remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port())) + remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port())) vpnIp, ok := netip.AddrFromSlice(ipNet.IP) assert.True(t, ok) @@ -51,11 +51,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -70,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: vpnIp2, + vpnAddrs: []netip.Addr{vpnIp2}, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -85,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l: logrus.New(), } - thi := c.GetHostInfoByVpnIp(vpnIp, false) + thi := c.GetHostInfoByVpnAddr(vpnIp, false) expectedInfo := ControlHostInfo{ - VpnIp: vpnIp, + VpnAddrs: []netip.Addr{vpnIp}, LocalIndex: 201, RemoteIndex: 200, RemoteAddrs: []netip.AddrPort{remote2, remote1}, @@ -100,13 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } // Make sure we don't have any unexpected fields - assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) + assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) assert.EqualValues(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { - thi = c.GetHostInfoByVpnIp(vpnIp2, false) + thi = c.GetHostInfoByVpnAddr(vpnIp2, false) }) } diff --git a/control_tester.go b/control_tester.go index fa87e5300..cde0919f7 100644 --- a/control_tester.go +++ b/control_tester.go @@ -6,8 +6,6 @@ package nebula import ( "net/netip" - "github.com/slackhq/nebula/cert" - "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" @@ -57,9 +55,9 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) c.f.lightHouse.Unlock() if toAddr.Addr().Is4() { - remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) + remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port())) } else { - remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) + remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port())) } } @@ -131,8 +129,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) } -func (c *Control) GetVpnIp() netip.Addr { - return c.f.myVpnNet.Addr() +func (c *Control) GetVpnAddrs() []netip.Addr { + return c.f.myVpnAddrs } func (c *Control) GetUDPAddr() netip.AddrPort { @@ -140,7 +138,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort { } func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { - hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp) + hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp) if hostinfo == nil { return false } @@ -153,8 +151,8 @@ func (c *Control) GetHostmap() *HostMap { return c.f.hostMap } -func (c *Control) GetCert() cert.Certificate { - return c.f.pki.GetCertState().Certificate +func (c *Control) GetCertState() *CertState { + return c.f.pki.getCertState() } func (c *Control) ReHandshake(vpnIp netip.Addr) { diff --git a/dns_server.go b/dns_server.go index 750123122..991f27068 100644 --- a/dns_server.go +++ b/dns_server.go @@ -8,6 +8,7 @@ import ( "strings" "sync" + "github.com/gaissmai/bart" "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -21,14 +22,16 @@ var dnsAddr string type dnsRecords struct { sync.RWMutex - dnsMap map[string]string - hostMap *HostMap + dnsMap map[string]string + hostMap *HostMap + myVpnAddrsTable *bart.Table[struct{}] } -func newDnsRecords(hostMap *HostMap) *dnsRecords { +func newDnsRecords(cs *CertState, hostMap *HostMap) *dnsRecords { return &dnsRecords{ - dnsMap: make(map[string]string), - hostMap: hostMap, + dnsMap: make(map[string]string), + hostMap: hostMap, + myVpnAddrsTable: cs.myVpnAddrsTable, } } @@ -47,7 +50,7 @@ func (d *dnsRecords) QueryCert(data string) string { return "" } - hostinfo := d.hostMap.QueryVpnIp(ip) + hostinfo := d.hostMap.QueryVpnAddr(ip) if hostinfo == nil { return "" } @@ -91,7 +94,8 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { // We don't answer these queries from non nebula nodes or localhost //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) - if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { + _, found := dnsR.myVpnAddrsTable.Lookup(b) + if !found && a != "127.0.0.1" { return } l.Debugf("Query for TXT %s", q.Name) @@ -123,8 +127,8 @@ func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { - dnsR = newDnsRecords(hostMap) +func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { + dnsR = newDnsRecords(cs, hostMap) // attach request handler func dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { diff --git a/dns_server_test.go b/dns_server_test.go index 69f6ae84f..ce0f419ac 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -11,7 +11,7 @@ import ( func TestParsequery(t *testing.T) { //TODO: This test is basically pointless hostMap := &HostMap{} - ds := newDnsRecords(hostMap) + ds := newDnsRecords(&CertState{}, hostMap) ds.Add("test.com.com", "1.2.3.4") m := new(dns.Msg) diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 6be94adc4..a3717bcad 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -25,7 +25,7 @@ func BenchmarkHotPath(b *testing.B) { theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() @@ -35,7 +35,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -49,14 +49,14 @@ func TestGoodHandshake(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -77,16 +77,16 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() @@ -105,10 +105,10 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) // Add their real udp addr, which should be tried after evil. - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -120,7 +120,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { h := &header.H{} err := h.Parse(p.Data) @@ -139,18 +139,18 @@ func TestWrongResponderHandshake(t *testing.T) { t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone @@ -169,8 +169,8 @@ func TestStage1Race(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -181,8 +181,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -194,14 +194,14 @@ func TestStage1Race(t *testing.T) { r.Log("Route until they receive a message packet") myCachedPacket := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Their cached packet should be received by me") theirCachedPacket := r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapIndexes := myControl.ListHostmapIndexes(false) @@ -219,7 +219,7 @@ func TestStage1Race(t *testing.T) { r.Log("Spin until connection manager tears down a tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } @@ -246,8 +246,8 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -258,10 +258,10 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Nuke my hostmap") myHostmap := myControl.GetHostmap() @@ -269,17 +269,17 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(theirControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(theirControl.GetHostmap().Indexes) < start { break } @@ -295,8 +295,8 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -307,10 +307,10 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.Log("Nuke my hostmap") @@ -319,18 +319,18 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(myControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(myControl.GetHostmap().Indexes) < start { break } @@ -347,9 +347,9 @@ func TestRelays(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -361,11 +361,11 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it } @@ -378,14 +378,14 @@ func TestStage1RaceRelays(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -397,14 +397,14 @@ func TestStage1RaceRelays(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, 80, []byte("Hi from them")) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -428,14 +428,14 @@ func TestStage1RaceRelays2(t *testing.T) { l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -448,16 +448,16 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -470,7 +470,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels") @@ -490,7 +490,7 @@ func TestStage1RaceRelays2(t *testing.T) { "theirControl": len(theirControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes), }).Info("Waiting for hostinfos to be removed...") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) retries-- @@ -498,7 +498,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) myControl.Stop() theirControl.Stop() @@ -515,9 +515,9 @@ func TestRehandshakingRelays(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -529,17 +529,17 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -557,8 +557,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -570,8 +570,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -582,13 +582,13 @@ func TestRehandshakingRelays(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -596,7 +596,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -604,7 +604,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -619,9 +619,9 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -633,17 +633,17 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -661,8 +661,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -674,8 +674,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -686,13 +686,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -700,7 +700,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -708,7 +708,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -721,8 +721,8 @@ func TestRehandshaking(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -733,12 +733,12 @@ func TestRehandshaking(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -755,8 +755,8 @@ func TestRehandshaking(t *testing.T) { myConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now break @@ -783,19 +783,19 @@ func TestRehandshaking(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) assert.Contains(t, c.Cert.Groups(), "new group") // We should only have a single tunnel now on both sides @@ -818,8 +818,8 @@ func TestRehandshakingLoser(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -830,16 +830,16 @@ func TestRehandshakingLoser(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) - tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) - tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) + tt1 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) + tt2 := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) fmt.Println(tt1.LocalIndex, tt2.LocalIndex) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -856,8 +856,8 @@ func TestRehandshakingLoser(t *testing.T) { theirConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") { break @@ -883,19 +883,19 @@ func TestRehandshakingLoser(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group") // We should only have a single tunnel now on both sides @@ -918,8 +918,8 @@ func TestRaceRegression(t *testing.T) { theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the servers myControl.Start() @@ -933,8 +933,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, 80, []byte("Hi from them")) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -964,7 +964,7 @@ func TestRaceRegression(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) t.Log("Make sure the tunnel still works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myControl.Stop() theirControl.Stop() diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 77996f3da..f8a224366 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -8,6 +8,7 @@ import ( "io" "net/netip" "os" + "strings" "testing" "time" @@ -26,25 +27,35 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { +func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() - vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) - if err != nil { - panic(err) + var vpnNetworks []netip.Prefix + for _, sn := range strings.Split(sVpnNetworks, ",") { + vpnIpNet, err := netip.ParsePrefix(sn) + if err != nil { + panic(err) + } + vpnNetworks = append(vpnNetworks, vpnIpNet) + } + + if len(vpnNetworks) == 0 { + panic("no vpn networks") } var udpAddr netip.AddrPort - if vpnIpNet.Addr().Is4() { - budpIp := vpnIpNet.Addr().As4() + if vpnNetworks[0].Addr().Is4() { + budpIp := vpnNetworks[0].Addr().As4() budpIp[1] -= 128 udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) } else { - budpIp := vpnIpNet.Addr().As16() - budpIp[13] -= 128 + budpIp := vpnNetworks[0].Addr().As16() + // beef for funsies + budpIp[2] = 190 + budpIp[3] = 239 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } - _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{}) + _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) caB, err := caCrt.MarshalPEM() if err != nil { @@ -109,7 +120,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe panic(err) } - return control, vpnIpNet, udpAddr, c + return control, vpnNetworks, udpAddr, c } type doneCb func() @@ -142,17 +153,18 @@ func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *n assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } -func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) { +func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) { // Get both host infos - hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false) - assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") + //TODO: we may want to loop over each vpnAddr and assert all the things + hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) + assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") - hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false) - assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") + hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) + assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") // Check that both vpn and real addr are correct - assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") - assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") + assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") + assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B") assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") @@ -197,6 +209,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, assert.Equal(t, expected, data.Payload(), "Data was incorrect") } +func getAddrs(ns []netip.Prefix) []netip.Addr { + var a []netip.Addr + for _, n := range ns { + a = append(a, n.Addr()) + } + return a +} + func NewTestLogger() *logrus.Logger { l := logrus.New() diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index 29fa95991..f2805d0c8 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { var lines []string var globalLines []*edge - clusterName := strings.Trim(c.GetCert().Name(), " ") - clusterVpnIp := c.GetCert().Networks()[0].Addr() + crt := c.GetCertState().GetDefaultCertificate() + clusterName := strings.Trim(crt.Name(), " ") + clusterVpnIp := crt.Networks()[0].Addr() r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) hm := c.GetHostmap() @@ -101,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { for _, idx := range indexes { hi, ok := hm.Indexes[idx] if ok { - r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) + r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs()) remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ") globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) _ = hi diff --git a/e2e/router/router.go b/e2e/router/router.go index 08905705c..5fa382344 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { panic("Duplicate listen address: " + addr.String()) } - r.vpnControls[c.GetVpnIp()] = c + for _, vpnAddr := range c.GetVpnAddrs() { + r.vpnControls[vpnAddr] = c + } + r.controls[addr] = c } @@ -217,7 +220,7 @@ func (r *R) renderFlow() { participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", - sanAddr, e.packet.from.GetVpnIp(), sanAddr, + sanAddr, e.packet.from.GetVpnAddrs(), sanAddr, ) } @@ -303,7 +306,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { func (r *R) renderHostmaps(title string) { c := maps.Values(r.controls) sort.SliceStable(c, func(i, j int) bool { - return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0 + return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0 }) s := renderHostmaps(c...) diff --git a/examples/config.yml b/examples/config.yml index c74ffc68f..f3db510f7 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -13,6 +13,12 @@ pki: # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. #disconnect_invalid: true + # default_version controls which certificate version is used in handshakes. + # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`. + # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`. + # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed. + # default_version: 1 + # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. # The syntax is: diff --git a/firewall.go b/firewall.go index 80a828057..3a615de8a 100644 --- a/firewall.go +++ b/firewall.go @@ -8,6 +8,7 @@ import ( "hash/fnv" "net/netip" "reflect" + "slices" "strconv" "strings" "sync" @@ -433,7 +434,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } } else { // Simple case: Certificate has one IP and no subnets - if fp.RemoteIP != h.vpnIp { + //TODO: we can make this more performant + if !slices.Contains(h.vpnAddrs, fp.RemoteIP) { f.metrics(incoming).droppedRemoteIP.Inc(1) return ErrInvalidRemoteIP } diff --git a/firewall_test.go b/firewall_test.go index 57cd32ae5..79e90b692 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -149,7 +149,7 @@ func TestFirewall_Drop(t *testing.T) { InvertedGroups: map[string]struct{}{"default-group": {}}, }, }, - vpnIp: netip.MustParseAddr("1.2.3.4"), + vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } h.CreateRemoteCIDR(&c) @@ -329,7 +329,7 @@ func TestFirewall_Drop2(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } h.CreateRemoteCIDR(c.Certificate) @@ -391,7 +391,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c1, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } h1.CreateRemoteCIDR(c1.Certificate) @@ -406,7 +406,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c2, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } h2.CreateRemoteCIDR(c2.Certificate) @@ -421,7 +421,7 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c3, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } h3.CreateRemoteCIDR(c3.Certificate) @@ -468,7 +468,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } h.CreateRemoteCIDR(c.Certificate) diff --git a/handshake_ix.go b/handshake_ix.go index 0448385c3..4cb642ffa 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -17,31 +17,26 @@ import ( func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return false } - certState := f.pki.GetCertState() - ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) + cs := f.pki.getCertState() + ci := NewConnectionState(f.l, cs, true, noise.HandshakeIX) hh.hostinfo.ConnectionState = ci - hsProto := &NebulaHandshakeDetails{ - InitiatorIndex: hh.hostinfo.localIndexId, - Time: uint64(time.Now().UnixNano()), - Cert: certState.RawCertificateNoKey, - CertVersion: uint32(certState.Certificate.Version()), - } - - hsBytes := []byte{} - hs := &NebulaHandshake{ - Details: hsProto, + Details: &NebulaHandshakeDetails{ + InitiatorIndex: hh.hostinfo.localIndexId, + Time: uint64(time.Now().UnixNano()), + Cert: cs.getDefaultHandshakeBytes(), + }, } - hsBytes, err = hs.Marshal() + hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return false } @@ -50,7 +45,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return false } @@ -65,8 +60,8 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { } func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { - certState := f.pki.GetCertState() - ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) + cs := f.pki.getCertState() + ci := NewConnectionState(f.l, cs, false, noise.HandshakeIX) // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) @@ -79,9 +74,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hs := &NebulaHandshake{} err = hs.Unmarshal(msg) - /* - l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) - */ if err != nil || hs.Details == nil { f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") @@ -101,6 +93,21 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } + if remoteCert.Certificate.Version() != ci.myCert.Version() { + // We started off using the wrong certificate version, lets see if we can match the version that was sent to us + rc := cs.getCertificate(remoteCert.Certificate.Version()) + //TODO: anywhere we are logging remoteCert needs to be remoteCert.Certificate OR we make a pass through func on CachedCertificate + if rc == nil { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). + Info("Unable to handshake with host due to missing certificate version") + return + } + + // Record the certificate we are actually using + ci.myCert = rc + } + if len(remoteCert.Certificate.Networks()) == 0 { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -113,30 +120,36 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() + var vpnAddrs []netip.Addr certName := remoteCert.Certificate.Name() fingerprint := remoteCert.ShaSum issuer := remoteCert.Certificate.Issuer() - if vpnIp == f.myVpnNet.Addr() { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") - return - } - - if addr.IsValid() { - if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + for _, network := range remoteCert.Certificate.Networks() { + vpnAddr := network.Addr() + _, found := f.myVpnAddrsTable.Lookup(vpnAddr) + if found { + f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") return } + + if addr.IsValid() { + if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) { + f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + return + } + } + + vpnAddrs = append(vpnAddrs, vpnAddr) } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -148,17 +161,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ConnectionState: ci, localIndexId: myIndex, remoteIndexId: hs.Details.InitiatorIndex, - vpnIp: vpnIp, + vpnAddrs: vpnAddrs, HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, } - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -167,14 +180,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake message received") hs.Details.ResponderIndex = myIndex - hs.Details.Cert = certState.RawCertificateNoKey - hs.Details.CertVersion = uint32(certState.Certificate.Version()) + hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) + hs.Details.CertVersion = uint32(ci.myCert.Version()) // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -185,14 +198,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -216,7 +229,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ci.dKey = NewNebulaCipherState(dKey) ci.eKey = NewNebulaCipherState(eKey) - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs[0]) hostinfo.SetRemote(addr) hostinfo.CreateRemoteCIDR(remoteCert.Certificate) @@ -228,7 +241,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if existing.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote - f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) } msg = existing.HandshakePacket[2] @@ -236,11 +249,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if addr.IsValid() { err := f.outside.WriteTo(msg, addr) if err != nil { - f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithError(err).Error("Failed to send handshake message") } else { - f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") } @@ -250,16 +263,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") return } case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime). @@ -270,23 +283,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake too old") // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). + WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs). Error("Failed to add HostInfo due to localIndex collision") return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -302,7 +315,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if addr.IsValid() { err = f.outside.WriteTo(msg, addr) if err != nil { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -310,7 +323,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake") } else { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -323,9 +336,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -352,8 +365,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo := hh.hostinfo if addr.IsValid() { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + //TODO: this is kind of nonsense now + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) { + f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } } @@ -361,7 +375,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci := hostinfo.ConnectionState msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Error("Failed to call noise.ReadMessage") @@ -370,7 +384,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Noise did not arrive at a key") @@ -382,7 +396,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again @@ -391,7 +405,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool()) if err != nil { - e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) if f.l.Level > logrus.DebugLevel { @@ -416,14 +430,15 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() + vpnNetworks := remoteCert.Certificate.Networks() certName := remoteCert.Certificate.Name() fingerprint := remoteCert.ShaSum issuer := remoteCert.Certificate.Issuer() // Ensure the right host responded - if vpnIp != hostinfo.vpnIp { - f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp). + //TODO: this is a horribly broken test + if vpnNetworks[0].Addr() != hostinfo.vpnAddrs[0] { + f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). WithField("udpAddr", addr).WithField("certName", certName). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Info("Incorrect host responded to handshake") @@ -432,16 +447,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip - f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) { + f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { //TODO: this doesnt know if its being added or is being used for caching a packet // Block the current used address newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes.BlockRemote(addr) // Get the correct remote list for the host we did handshake with - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = f.lightHouse.QueryCache(vpnNetworks[0].Addr()) - f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp). + f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnNetworks", vpnNetworks). WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). Info("Blocked addresses for handshakes") @@ -449,8 +464,11 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha newHH.packetStore = hh.packetStore hh.packetStore = []*cachedPacket{} - // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnIp = vpnIp + // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down + hostinfo.vpnAddrs = nil + for _, n := range vpnNetworks { + hostinfo.vpnAddrs = append(hostinfo.vpnAddrs, n.Addr()) + } f.sendCloseTunnel(hostinfo) }) @@ -461,7 +479,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci.window.Update(f.l, 2) duration := time.Since(hh.startTime).Nanoseconds() - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -483,7 +501,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha if addr.IsValid() { hostinfo.SetRemote(addr) } else { - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) } // Build up the radix for the firewall if we have subnets in the cert diff --git a/handshake_manager.go b/handshake_manager.go index 48348939b..258d5ae94 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -13,6 +13,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" ) @@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig } } -func (c *HandshakeManager) Run(ctx context.Context) { - clockSource := time.NewTicker(c.config.tryInterval) +func (hm *HandshakeManager) Run(ctx context.Context) { + clockSource := time.NewTicker(hm.config.tryInterval) defer clockSource.Stop() for { select { case <-ctx.Done(): return - case vpnIP := <-c.trigger: - c.handleOutbound(vpnIP, true) + case vpnIP := <-hm.trigger: + hm.handleOutbound(vpnIP, true) case now := <-clockSource.C: - c.NextOutboundHandshakeTimerTick(now) + hm.NextOutboundHandshakeTimerTick(now) } } } @@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { - c.OutboundHandshakeTimer.Advance(now) +func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { + hm.OutboundHandshakeTimer.Advance(now) for { - vpnIp, has := c.OutboundHandshakeTimer.Purge() + vpnIp, has := hm.OutboundHandshakeTimer.Purge() if !has { break } - c.handleOutbound(vpnIp, false) + hm.handleOutbound(vpnIp, false) } } @@ -267,11 +268,18 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { - // Don't relay to myself, and don't relay through the host I'm trying to connect to - if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() { + // Don't relay to myself + if relay == vpnIp { continue } - relayHostInfo := hm.mainHostMap.QueryVpnIp(relay) + + // Don't relay through the host I'm trying to connect to + _, found := hm.f.myVpnAddrsTable.Lookup(relay) + if found { + continue + } + + relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") hm.f.Handshake(relay) @@ -286,17 +294,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered case Requested: hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") - //TODO: IPV6-WORK - myVpnIpB := hm.f.myVpnNet.Addr().As4() - theirVpnIpB := vpnIp.As4() - - // Re-send the CreateRelay request, in case the previous one was lost. m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), - RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !hm.f.myVpnAddrs[0].Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := hm.f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() if err != nil { hostinfo.logger(hm.l). @@ -306,7 +332,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // This must send over the hostinfo, not over hm.Hosts[ip] hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnNet.Addr(), + "relayFrom": hm.f.myVpnAddrs[0], "relayTo": vpnIp, "initiatorRelayIndex": existingRelay.LocalIndex, "relay": relay}). @@ -316,7 +342,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l). WithField("vpnIp", vpnIp). WithField("state", existingRelay.State). - WithField("relay", relayHostInfo.vpnIp). + WithField("relay", relayHostInfo.vpnAddrs[0]). Errorf("Relay unexpected state") } } else { @@ -327,16 +353,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } - //TODO: IPV6-WORK - myVpnIpB := hm.f.myVpnNet.Addr().As4() - theirVpnIpB := vpnIp.As4() - m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), - RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !hm.f.myVpnAddrs[0].Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := hm.f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() if err != nil { hostinfo.logger(hm.l). @@ -345,7 +390,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnNet.Addr(), + "relayFrom": hm.f.myVpnAddrs[0], "relayTo": vpnIp, "initiatorRelayIndex": idx, "relay": relay}). @@ -381,10 +426,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() - if hh, ok := hm.vpnIps[vpnIp]; ok { + if hh, ok := hm.vpnIps[vpnAddr]; ok { // We are already trying to handshake with this vpn ip if cacheCb != nil { cacheCb(hh) @@ -394,12 +439,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands } hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnAddr}, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, } @@ -407,9 +452,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands hostinfo: hostinfo, startTime: time.Now(), } - hm.vpnIps[vpnIp] = hh + hm.vpnIps[vpnAddr] = hh hm.metricInitiated.Inc(1) - hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) + hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval) if cacheCb != nil { cacheCb(hh) @@ -417,21 +462,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands // If this is a static host, we don't need to wait for the HostQueryReply // We can trigger the handshake right now - _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp] + _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr] if !doTrigger { // Add any calculated remotes, and trigger early handshake if one found - doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp) + doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr) } if doTrigger { select { - case hm.trigger <- vpnIp: + case hm.trigger <- vpnAddr: default: } } hm.Unlock() - hm.lightHouse.QueryServer(vpnIp) + hm.lightHouse.QueryServer(vpnAddr) return hostinfo } @@ -452,14 +497,14 @@ var ( // // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. -func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { - c.mainHostMap.Lock() - defer c.mainHostMap.Unlock() - c.Lock() - defer c.Unlock() +func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { + hm.mainHostMap.Lock() + defer hm.mainHostMap.Unlock() + hm.Lock() + defer hm.Unlock() // Check if we already have a tunnel with this vpn ip - existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] + existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]] if found && existingHostInfo != nil { testHostInfo := existingHostInfo for testHostInfo != nil { @@ -476,31 +521,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingHostInfo, ErrExistingHostInfo } - existingHostInfo.logger(c.l).Info("Taking new handshake") + existingHostInfo.logger(hm.l).Info("Taking new handshake") } - existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] + existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId] if found { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision } - existingPendingIndex, found := c.indexes[hostinfo.localIndexId] + existingPendingIndex, found := hm.indexes[hostinfo.localIndexId] if found && existingPendingIndex.hostinfo != hostinfo { // We have a collision, but for a different hostinfo return existingPendingIndex.hostinfo, ErrLocalIndexCollision } - existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] - if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp { + existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(c.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). + hostinfo.logger(hm.l). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } - c.mainHostMap.unlockedAddHostInfo(hostinfo, f) + hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) return existingHostInfo, nil } @@ -518,7 +563,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } @@ -555,31 +600,32 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { return errors.New("failed to generate unique localIndexId") } -func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { - c.Lock() - defer c.Unlock() - c.unlockedDeleteHostInfo(hostinfo) +func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { + hm.Lock() + defer hm.Unlock() + hm.unlockedDeleteHostInfo(hostinfo) } -func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { - delete(c.vpnIps, hostinfo.vpnIp) - if len(c.vpnIps) == 0 { - c.vpnIps = map[netip.Addr]*HandshakeHostInfo{} +func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { + //TODO: need to iterate hostinfo.vpnAddrs + delete(hm.vpnIps, hostinfo.vpnAddrs[0]) + if len(hm.vpnIps) == 0 { + hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } - delete(c.indexes, hostinfo.localIndexId) - if len(c.vpnIps) == 0 { - c.indexes = map[uint32]*HandshakeHostInfo{} + delete(hm.indexes, hostinfo.localIndexId) + if len(hm.vpnIps) == 0 { + hm.indexes = map[uint32]*HandshakeHostInfo{} } - if c.l.Level >= logrus.DebugLevel { - c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps), - "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Pending hostmap hostInfo deleted") } } -func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo { +func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { hh := hm.queryVpnIp(vpnIp) if hh != nil { return hh.hostinfo @@ -608,37 +654,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { return hm.indexes[index] } -func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix { - return c.mainHostMap.GetPreferredRanges() +func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix { + return hm.mainHostMap.GetPreferredRanges() } -func (c *HandshakeManager) ForEachVpnIp(f controlEach) { - c.RLock() - defer c.RUnlock() +func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) { + hm.RLock() + defer hm.RUnlock() - for _, v := range c.vpnIps { + for _, v := range hm.vpnIps { f(v.hostinfo) } } -func (c *HandshakeManager) ForEachIndex(f controlEach) { - c.RLock() - defer c.RUnlock() +func (hm *HandshakeManager) ForEachIndex(f controlEach) { + hm.RLock() + defer hm.RUnlock() - for _, v := range c.indexes { + for _, v := range hm.indexes { f(v.hostinfo) } } -func (c *HandshakeManager) EmitStats() { - c.RLock() - hostLen := len(c.vpnIps) - indexLen := len(c.indexes) - c.RUnlock() +func (hm *HandshakeManager) EmitStats() { + hm.RLock() + hostLen := len(hm.vpnIps) + indexLen := len(hm.indexes) + hm.RUnlock() metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) - c.mainHostMap.EmitStats() + hm.mainHostMap.EmitStats() } // Utility functions below diff --git a/handshake_manager_test.go b/handshake_manager_test.go index daa867564..ef6a88893 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" @@ -13,21 +14,20 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") ip := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} - mainHM := newHostMap(l, vpncidr) + mainHM := newHostMap(l) mainHM.preferredRanges.Store(&preferredRanges) lh := newTestLighthouse() cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) @@ -92,3 +92,7 @@ func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M } func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {} + +func (mw *mockEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return nil +} diff --git a/hostmap.go b/hostmap.go index d83151eb3..fbafc06d5 100644 --- a/hostmap.go +++ b/hostmap.go @@ -48,7 +48,7 @@ type Relay struct { State int LocalIndex uint32 RemoteIndex uint32 - PeerIp netip.Addr + PeerAddr netip.Addr } type HostMap struct { @@ -58,7 +58,6 @@ type HostMap struct { RemoteIndexes map[uint32]*HostInfo Hosts map[netip.Addr]*HostInfo preferredRanges atomic.Pointer[[]netip.Prefix] - vpnCIDR netip.Prefix l *logrus.Logger } @@ -68,9 +67,9 @@ type HostMap struct { type RelayState struct { sync.RWMutex - relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer - relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info + relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer + relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } func (rs *RelayState) DeleteRelay(ip netip.Addr) { @@ -89,10 +88,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay { return ret } -func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { +func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() - r, ok := rs.relayForByIp[ip] + r, ok := rs.relayForByAddr[addr] return r, ok } @@ -115,8 +114,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr { func (rs *RelayState) CopyRelayForIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp)) - for relayIp := range rs.relayForByIp { + currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr)) + for relayIp := range rs.relayForByAddr { currentRelays = append(currentRelays, relayIp) } return currentRelays @@ -135,7 +134,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 { func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { rs.Lock() defer rs.Unlock() - r, ok := rs.relayForByIp[vpnIp] + r, ok := rs.relayForByAddr[vpnIp] if !ok { return false } @@ -143,7 +142,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByIp[r.PeerIp] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay return true } @@ -158,14 +157,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByIp[r.PeerIp] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay return &newRelay, true } func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() - r, ok := rs.relayForByIp[vpnIp] + r, ok := rs.relayForByAddr[vpnIp] return r, ok } @@ -179,7 +178,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() - rs.relayForByIp[ip] = r + rs.relayForByAddr[ip] = r rs.relayForByIdx[idx] = r } @@ -190,7 +189,7 @@ type HostInfo struct { ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 - vpnIp netip.Addr + vpnAddrs []netip.Addr recvError atomic.Uint32 remoteCidr *bart.Table[struct{}] relayState RelayState @@ -241,28 +240,26 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap { - hm := newHostMap(l, vpnCIDR) +func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { + hm := newHostMap(l) hm.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { hm.reload(c, false) }) - l.WithField("network", hm.vpnCIDR.String()). - WithField("preferredRanges", hm.GetPreferredRanges()). + l.WithField("preferredRanges", hm.GetPreferredRanges()). Info("Main HostMap created") return hm } -func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap { +func newHostMap(l *logrus.Logger) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{}, Hosts: map[netip.Addr]*HostInfo{}, - vpnCIDR: vpnCIDR, l: l, } } @@ -305,17 +302,6 @@ func (hm *HostMap) EmitStats() { metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) } -func (hm *HostMap) RemoveRelay(localIdx uint32) { - hm.Lock() - _, ok := hm.Relays[localIdx] - if !ok { - hm.Unlock() - return - } - delete(hm.Relays, localIdx) - hm.Unlock() -} - // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { // Delete the host itself, ensuring it's not modified anymore @@ -335,7 +321,7 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { } func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { - oldHostinfo := hm.Hosts[hostinfo.vpnIp] + oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]] if oldHostinfo == hostinfo { return } @@ -348,7 +334,7 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { hostinfo.next.prev = hostinfo.prev } - hm.Hosts[hostinfo.vpnIp] = hostinfo + hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo if oldHostinfo == nil { return @@ -360,17 +346,17 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { } func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { - primary, ok := hm.Hosts[hostinfo.vpnIp] + primary, ok := hm.Hosts[hostinfo.vpnAddrs[0]] if ok && primary == hostinfo { // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it - delete(hm.Hosts, hostinfo.vpnIp) + delete(hm.Hosts, hostinfo.vpnAddrs[0]) if len(hm.Hosts) == 0 { hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { // We had more than 1 hostinfo at this vpnip, promote the next in the list to primary - hm.Hosts[hostinfo.vpnIp] = hostinfo.next + hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo.next // It is primary, there is no previous hostinfo now hostinfo.next.prev = nil } @@ -406,7 +392,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), - "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Hostmap hostInfo deleted") } @@ -448,11 +434,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { } } -func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo { - return hm.queryVpnIp(vpnIp, nil) +func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { + return hm.queryVpnAddr(vpnIp, nil) } -func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { +func (hm *HostMap) QueryVpnAddrRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { hm.RLock() defer hm.RUnlock() @@ -470,7 +456,7 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn return nil, nil, errors.New("unable to find host with relay") } -func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { +func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -494,8 +480,8 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String()) } - existing := hm.Hosts[hostinfo.vpnIp] - hm.Hosts[hostinfo.vpnIp] = hostinfo + existing := hm.Hosts[hostinfo.vpnAddrs[0]] + hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo if existing != nil { hostinfo.next = existing @@ -506,8 +492,8 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). + hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). Debug("Hostmap vpnIp added") } @@ -527,7 +513,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix { return *hm.preferredRanges.Load() } -func (hm *HostMap) ForEachVpnIp(f controlEach) { +func (hm *HostMap) ForEachVpnAddr(f controlEach) { hm.RLock() defer hm.RUnlock() @@ -581,7 +567,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac } i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) - ifce.lightHouse.QueryServer(i.vpnIp) + ifce.lightHouse.QueryServer(i.vpnAddrs[0]) } } @@ -596,7 +582,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object if i.remote != remote { i.remote = remote - i.remotes.LearnRemote(i.vpnIp, remote) + i.remotes.LearnRemote(i.vpnAddrs[0], remote) } } @@ -669,7 +655,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { return logrus.NewEntry(l) } - li := l.WithField("vpnIp", i.vpnIp). + li := l.WithField("vpnAddrs", i.vpnAddrs). WithField("localIndex", i.localIndexId). WithField("remoteIndex", i.remoteIndexId) diff --git a/hostmap_test.go b/hostmap_test.go index 7e2feb810..e974340d0 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -11,17 +11,14 @@ import ( func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() - hm := newHostMap( - l, - netip.MustParsePrefix("10.0.0.1/24"), - ) + hm := newHostMap(l) f := &Interface{} - h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} - h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} - h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} - h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} + h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} + h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) @@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() - hm := newHostMap( - l, - netip.MustParsePrefix("10.0.0.1/24"), - ) + hm := newHostMap(l) f := &Interface{} - h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} - h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} - h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} - h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} - h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5} - h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6} + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} + h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} + h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} + h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5} + h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6} hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h5, f) @@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Nil(t, prim) } @@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - hm := NewHostMapFromConfig( - l, - netip.MustParsePrefix("10.0.0.1/24"), - c, - ) + hm := NewHostMapFromConfig(l, c) toS := func(ipn []netip.Prefix) []string { var s []string diff --git a/hostmap_tester.go b/hostmap_tester.go index b2d1d1b5b..fe40c5334 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -9,8 +9,8 @@ import ( "net/netip" ) -func (i *HostInfo) GetVpnIp() netip.Addr { - return i.vpnIp +func (i *HostInfo) GetVpnAddrs() []netip.Addr { + return i.vpnAddrs } func (i *HostInfo) GetLocalIndex() uint32 { diff --git a/inside.go b/inside.go index 0ccd17909..1b75f0f46 100644 --- a/inside.go +++ b/inside.go @@ -20,11 +20,16 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Ignore local broadcast packets - if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr { - return + if f.dropLocalBroadcast { + _, found := f.myBroadcastAddr.Lookup(fwPacket.RemoteIP) + if found { + return + } } - if fwPacket.RemoteIP == f.myVpnNet.Addr() { + //TODO: seems like a huge bummer + _, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteIP) + if found { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which // routes packets from the Nebula IP to the Nebula IP through the Nebula @@ -124,7 +129,8 @@ func (f *Interface) Handshake(vpnIp netip.Addr) { // getOrHandshake returns nil if the vpnIp is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - if !f.myVpnNet.Contains(vpnIp) { + _, found := f.myVpnNetworks.Lookup(vpnIp) + if !found { vpnIp = f.inside.RouteFor(vpnIp) if !vpnIp.IsValid() { return nil, false @@ -289,10 +295,10 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. - f.lightHouse.QueryServer(hostinfo.vpnIp) + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) hostinfo.lastRebindCount = f.rebindCount if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") } } @@ -324,7 +330,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } else { // Try to send via a relay for _, relayIP := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP) + relayHostInfo, relay, err := f.hostMap.QueryVpnAddrRelayFor(hostinfo.vpnAddrs[0], relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") diff --git a/interface.go b/interface.go index 9308aae84..9686d128b 100644 --- a/interface.go +++ b/interface.go @@ -11,8 +11,10 @@ import ( "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -27,7 +29,6 @@ type InterfaceConfig struct { Outside udp.Conn Inside overlay.Device pki *PKI - Cipher string Firewall *Firewall ServeDns bool HandshakeManager *HandshakeManager @@ -55,15 +56,16 @@ type Interface struct { outside udp.Conn inside overlay.Device pki *PKI - cipher string firewall *Firewall connectionManager *connectionManager handshakeManager *HandshakeManager serveDns bool createTime time.Time lightHouse *LightHouse - myBroadcastAddr netip.Addr - myVpnNet netip.Prefix + myBroadcastAddr *bart.Table[struct{}] + myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate + myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate + myVpnNetworks *bart.Table[struct{}] // A table of networks assigned to us via our certificate dropLocalBroadcast bool dropMulticast bool routines int @@ -104,6 +106,7 @@ type EncWriter interface { SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) Handshake(vpnIp netip.Addr) + GetHostInfo(vpnIp netip.Addr) *HostInfo } type sendRecvErrorConfig uint8 @@ -154,14 +157,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return nil, errors.New("no firewall rules") } - certificate := c.pki.GetCertState().Certificate - ifce := &Interface{ pki: c.pki, hostMap: c.HostMap, outside: c.Outside, inside: c.Inside, - cipher: c.Cipher, firewall: c.Firewall, serveDns: c.ServeDns, handshakeManager: c.HandshakeManager, @@ -173,7 +173,8 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - myVpnNet: certificate.Networks()[0], + myVpnNetworks: new(bart.Table[struct{}]), + myVpnAddrsTable: new(bart.Table[struct{}]), relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -188,11 +189,25 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - if ifce.myVpnNet.Addr().Is4() { - //TODO: - //addr := myVpnNet.Masked().Addr().As4() - //binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) - //ifce.myBroadcastAddr = netip.AddrFrom4(addr) + var crt cert.Certificate + cs := c.pki.getCertState() + crt = cs.getCertificate(cert.Version2) + if crt == nil { + // v2 certificates are a superset, only look at v1 if its all we have + crt = cs.getCertificate(cert.Version1) + } + + for _, network := range crt.Networks() { + ifce.myVpnNetworks.Insert(network, struct{}{}) + ifce.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) + ifce.myVpnAddrs = append(ifce.myVpnAddrs, network.Addr()) + + if network.Addr().Is4() { + //TODO: finish calculating the broadcast ips + //addr := network.Masked().Addr().As4() + //binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) + //ifce.myBroadcastAddr = netip.AddrFrom4(addr) + } } ifce.tryPromoteEvery.Store(c.tryPromoteEvery) @@ -322,7 +337,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.pki.getDefaultCertificate(), c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return @@ -414,11 +429,16 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() - certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second)) + certExpirationGauge.Update(int64(f.pki.getDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second)) + //TODO: we should also report the default certificate version } } } +func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return f.hostMap.QueryVpnAddr(vpnIp) +} + func (f *Interface) Close() error { f.closed.Store(true) diff --git a/lighthouse.go b/lighthouse.go index 62f406560..623ed0b9e 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -15,6 +15,7 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" @@ -31,9 +32,11 @@ type LightHouse struct { sync.RWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool - myVpnNet netip.Prefix - punchConn udp.Conn - punchy *Punchy + + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + punchConn udp.Conn + punchy *Punchy // Local cache of answers from light houses // map of vpn Ip to answers @@ -57,10 +60,11 @@ type LightHouse struct { staticList atomic.Pointer[map[netip.Addr]struct{}] lighthouses atomic.Pointer[map[netip.Addr]struct{}] - interval atomic.Int64 - updateCancel context.CancelFunc - ifce EncWriter - nebulaPort uint32 // 32 bits because protobuf does not have a uint16 + interval atomic.Int64 + updateCancel context.CancelFunc + ifce EncWriter + nebulaPort uint32 // 32 bits because protobuf does not have a uint16 + protocolVersion atomic.Uint32 // The default protocol version to use if we can't determine which to use from the tunnel advertiseAddrs atomic.Pointer[[]netip.AddrPort] @@ -78,7 +82,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -95,15 +99,16 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, } h := LightHouse{ - ctx: ctx, - amLighthouse: amLighthouse, - myVpnNet: myVpnNet, - addrMap: make(map[netip.Addr]*RemoteList), - nebulaPort: nebulaPort, - punchConn: pc, - punchy: p, - queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), - l: l, + ctx: ctx, + amLighthouse: amLighthouse, + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + addrMap: make(map[netip.Addr]*RemoteList), + nebulaPort: nebulaPort, + punchConn: pc, + punchy: p, + queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), + l: l, } lighthouses := make(map[netip.Addr]struct{}) h.lighthouses.Store(&lighthouses) @@ -199,7 +204,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { //TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used ip := ips[0].Unmap() - if lh.myVpnNet.Contains(ip) { + _, found := lh.myVpnNetworksTable.Lookup(ip) + if found { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue @@ -345,6 +351,16 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { } } + v := c.GetUint32("pki.default_version", 1) + switch v { + case 1: + lh.protocolVersion.Store(1) + case 2: + lh.protocolVersion.Store(2) + default: + return fmt.Errorf("invalid version for lighthouse: %v", v) + } + return nil } @@ -359,8 +375,10 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{ if err != nil { return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - if !lh.myVpnNet.Contains(ip) { - return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil) + + _, found := lh.myVpnNetworksTable.Lookup(ip) + if !found { + return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnIp": ip, "networks": lh.myVpnNetworks}, nil) } lhMap[ip] = struct{}{} } @@ -430,8 +448,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - if !lh.myVpnNet.Contains(vpnIp) { - return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil) + _, found := lh.myVpnNetworksTable.Lookup(vpnIp) + if !found { + return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) } vals, ok := v.([]interface{}) @@ -516,7 +535,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, return false, 0, nil } -func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) { +func (lh *LightHouse) DeleteVpnAddr(vpnIp netip.Addr) { // First we check the static mapping // and do nothing if it is there if _, ok := lh.GetStaticHostList()[vpnIp]; ok { @@ -563,9 +582,9 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t } switch { case addrPort.Addr().Is4(): - am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) + am.unlockedPrependV4(lh.myVpnNetworks[0].Addr(), netAddrToProtoV4AddrPort(addrPort.Addr(), addrPort.Port())) case addrPort.Addr().Is6(): - am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) + am.unlockedPrependV6(lh.myVpnNetworks[0].Addr(), netAddrToProtoV6AddrPort(addrPort.Addr(), addrPort.Port())) } } @@ -578,6 +597,7 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { + //TODO: this needs to support v6 addresses too tree := lh.getCalculatedRemotes() if tree == nil { return false @@ -587,7 +607,7 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { return false } - var calculated []*Ip4AndPort + var calculated []*V4AddrPort for _, cr := range calculatedRemotes { c := cr.Apply(vpnIp) if c != nil { @@ -601,7 +621,7 @@ func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) + am.unlockedSetV4(lh.myVpnNetworks[0].Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) return len(calculated) > 0 } @@ -621,7 +641,12 @@ func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") } - if !allow || lh.myVpnNet.Contains(to) { + if !allow { + return false + } + + _, found := lh.myVpnNetworksTable.Lookup(to) + if found { return false } @@ -629,14 +654,19 @@ func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool { - ip := AddrPortFromIp4AndPort(to) +func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *V4AddrPort) bool { + ip := protoV4AddrPortToNetAddrPort(to) allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) if lh.l.Level >= logrus.TraceLevel { lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") } - if !allow || lh.myVpnNet.Contains(ip.Addr()) { + if !allow { + return false + } + + _, found := lh.myVpnNetworksTable.Lookup(ip.Addr()) + if found { return false } @@ -644,25 +674,23 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool { - ip := AddrPortFromIp6AndPort(to) +func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *V6AddrPort) bool { + ip := protoV6AddrPortToNetAddrPort(to) allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("remoteIp", protoV6AddrPortToNetAddrPort(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") } - if !allow || lh.myVpnNet.Contains(ip.Addr()) { + if !allow { return false } - return true -} + _, found := lh.myVpnNetworksTable.Lookup(ip.Addr()) + if found { + return false + } -func lhIp6ToIp(v *Ip6AndPort) net.IP { - ip := make(net.IP, 16) - binary.BigEndian.PutUint64(ip[:8], v.Hi) - binary.BigEndian.PutUint64(ip[8:], v.Lo) - return ip + return true } func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { @@ -672,52 +700,6 @@ func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { return false } -func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta { - if vpnIp.Is6() { - //TODO: need to support ipv6 - panic("ipv6 is not yet supported") - } - - b := vpnIp.As4() - return &NebulaMeta{ - Type: NebulaMeta_HostQuery, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(b[:]), - }, - } -} - -func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort { - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], ip.Ip) - return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port)) -} - -func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort { - b := [16]byte{} - binary.BigEndian.PutUint64(b[:8], ip.Hi) - binary.BigEndian.PutUint64(b[8:], ip.Lo) - return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port)) -} - -func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { - v4Addr := ip.As4() - return &Ip4AndPort{ - Ip: binary.BigEndian.Uint32(v4Addr[:]), - Port: uint32(port), - } -} - -// TODO: IPV6-WORK we can delete some more of these -func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { - ip6Addr := ip.As16() - return &Ip6AndPort{ - Hi: binary.BigEndian.Uint64(ip6Addr[:8]), - Lo: binary.BigEndian.Uint64(ip6Addr[8:]), - Port: uint32(port), - } -} - func (lh *LightHouse) startQueryWorker() { if lh.amLighthouse { return @@ -738,15 +720,36 @@ func (lh *LightHouse) startQueryWorker() { }() } -func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) { - if lh.IsLighthouseIP(ip) { +func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { + if lh.IsLighthouseIP(addr) { return } // Send a query to the lighthouses and hope for the best next time - query, err := NewLhQueryByInt(ip).Marshal() + v := lh.protocolVersion.Load() + msg := &NebulaMeta{ + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{}, + } + + if v == 1 { + if !addr.Is4() { + lh.l.WithField("vpnAddr", addr).Error("Can't query lighthouse for v6 address using a v1 protocol") + return + } + b := addr.As4() + msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + + } else if v == 2 { + msg.Details.VpnAddr = netAddrToProtoAddr(addr) + + } else { + panic("unsupported version") + } + + query, err := msg.Marshal() if err != nil { - lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") + lh.l.WithError(err).WithField("vpnAddr", addr).Error("Failed to marshal lighthouse query payload") return } @@ -754,6 +757,8 @@ func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) { lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) for n := range lighthouses { + //TODO: there is a slight possibility this lighthouse is using a v2 protocol even if our default is v1 + // We could facilitate the move to v2 by marshalling a v2 query lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) } } @@ -785,57 +790,72 @@ func (lh *LightHouse) StartUpdateWorker() { } func (lh *LightHouse) SendUpdate() { - var v4 []*Ip4AndPort - var v6 []*Ip6AndPort + var v4 []*V4AddrPort + var v6 []*V6AddrPort for _, e := range lh.GetAdvertiseAddrs() { if e.Addr().Is4() { - v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port())) + v4 = append(v4, netAddrToProtoV4AddrPort(e.Addr(), e.Port())) } else { - v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port())) + v6 = append(v6, netAddrToProtoV6AddrPort(e.Addr(), e.Port())) } } lal := lh.GetLocalAllowList() for _, e := range localIps(lh.l, lal) { - if lh.myVpnNet.Contains(e) { + _, found := lh.myVpnNetworksTable.Lookup(e) + if found { continue } // Only add IPs that aren't my VPN/tun IP if e.Is4() { - v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort))) + v4 = append(v4, netAddrToProtoV4AddrPort(e, uint16(lh.nebulaPort))) } else { - v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort))) + v6 = append(v6, netAddrToProtoV6AddrPort(e, uint16(lh.nebulaPort))) } } - var relays []uint32 - for _, r := range lh.GetRelaysForMe() { - //TODO: IPV6-WORK both relays and vpnip need ipv6 support - b := r.As4() - relays = append(relays, binary.BigEndian.Uint32(b[:])) - } - - //TODO: IPV6-WORK both relays and vpnip need ipv6 support - b := lh.myVpnNet.Addr().As4() - - m := &NebulaMeta{ + v := lh.protocolVersion.Load() + msg := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(b[:]), - Ip4AndPorts: v4, - Ip6AndPorts: v6, - RelayVpnIp: relays, + V4AddrPorts: v4, + V6AddrPorts: v6, }, } + if v == 1 { + var relays []uint32 + for _, r := range lh.GetRelaysForMe() { + if !r.Is4() { + continue + } + b := r.As4() + relays = append(relays, binary.BigEndian.Uint32(b[:])) + } + + //TODO: need an ipv4 vpn addr to use + msg.Details.OldRelayVpnAddrs = relays + + } else if v == 2 { + var relays []*Addr + for _, r := range lh.GetRelaysForMe() { + relays = append(relays, netAddrToProtoAddr(r)) + } + + //TODO: need a vpn addr to use + + } else { + panic("protocol version not supported") + } + lighthouses := lh.GetLighthouses() lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lighthouses))) nb := make([]byte, 12, 12) out := make([]byte, mtu) - mm, err := m.Marshal() + mm, err := msg.Marshal() if err != nil { lh.l.WithError(err).Error("Error while marshaling for lighthouse update") return @@ -886,32 +906,33 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { lhh.meta.Reset() // Keep the array memory around - details.Ip4AndPorts = details.Ip4AndPorts[:0] - details.Ip6AndPorts = details.Ip6AndPorts[:0] - details.RelayVpnIp = details.RelayVpnIp[:0] + details.V4AddrPorts = details.V4AddrPorts[:0] + details.V6AddrPorts = details.V6AddrPorts[:0] + details.RelayVpnAddrs = details.RelayVpnAddrs[:0] + details.OldRelayVpnAddrs = details.OldRelayVpnAddrs[:0] lhh.meta.Details = details return lhh.meta } func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { - return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) { - lhh.HandleRequest(rAddr, vpnIp, p, f) + return func(rAddr netip.AddrPort, vpnAddrs []netip.Addr, p []byte) { + lhh.HandleRequest(rAddr, vpnAddrs, p, f) } } -func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnAddrs []netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). + lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") //TODO: send recv_error? return } if n.Details == nil { - lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). + lhh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") //TODO: send recv_error? return @@ -921,24 +942,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Ad switch n.Type { case NebulaMeta_HostQuery: - lhh.handleHostQuery(n, vpnIp, rAddr, w) + lhh.handleHostQuery(n, vpnAddrs, rAddr, w) case NebulaMeta_HostQueryReply: - lhh.handleHostQueryReply(n, vpnIp) + lhh.handleHostQueryReply(n, vpnAddrs) case NebulaMeta_HostUpdateNotification: - lhh.handleHostUpdateNotification(n, vpnIp, w) + lhh.handleHostUpdateNotification(n, vpnAddrs, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: - lhh.handleHostPunchNotification(n, vpnIp, w) + lhh.handleHostPunchNotification(n, vpnAddrs, w) case NebulaMeta_HostUpdateNotificationAck: // noop } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -947,21 +968,36 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a return } - //TODO: we can DRY this further - reqVpnIp := n.Details.VpnIp - - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - queryVpnIp := netip.AddrFrom4(b) + var useVersion cert.Version + var queryVpnIp netip.Addr + var reqVpnIp netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + queryVpnIp = netip.AddrFrom4(b) + reqVpnIp = queryVpnIp + useVersion = 1 + } else if n.Details.VpnAddr != nil { + queryVpnIp = protoAddrToNetAddr(n.Details.VpnAddr) + reqVpnIp = queryVpnIp + useVersion = 2 + } //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply - n.Details.VpnIp = reqVpnIp + if useVersion == 1 { + if !reqVpnIp.Is4() { + return 0, fmt.Errorf("invalid vpn ip for v1 handleHostQuery") + } + b := reqVpnIp.As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + } else { + n.Details.VpnAddr = netAddrToProtoAddr(reqVpnIp) + } - lhh.coalesceAnswers(c, n) + lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) @@ -971,21 +1007,40 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply") + lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) // This signals the other side to punch some zero byte udp packets - found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { + found, ln, err = lhh.lh.queryAndPrepMessage(vpnAddrs[0], func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - //TODO: IPV6-WORK - b = vpnIp.As4() - n.Details.VpnIp = binary.BigEndian.Uint32(b[:]) - lhh.coalesceAnswers(c, n) + //TODO: unsure which version to use. If we had access to the hostmap we could see if there is already a tunnel + // and use that version then fallback to our default configuration + targetHI := lhh.lh.ifce.GetHostInfo(reqVpnIp) + useVersion = cert.Version(lhh.lh.protocolVersion.Load()) + if targetHI != nil { + useVersion = targetHI.GetCert().Certificate.Version() + } + + if useVersion == cert.Version1 { + if !vpnAddrs[0].Is4() { + return 0, fmt.Errorf("invalid vpn ip for v1 handleHostQuery") + } + b := vpnAddrs[0].As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + lhh.coalesceAnswers(useVersion, c, n) + + } else if useVersion == cert.Version2 { + n.Details.VpnAddr = netAddrToProtoAddr(vpnAddrs[0]) + lhh.coalesceAnswers(useVersion, c, n) + + } else { + panic("unsupported version") + } return n.MarshalTo(lhh.pb) }) @@ -995,74 +1050,96 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for") + lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).Error("Failed to marshal lighthouse host was queried for") return } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - - //TODO: IPV6-WORK - binary.BigEndian.PutUint32(b[:], reqVpnIp) - sendTo := netip.AddrFrom4(b) - w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, reqVpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { +func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) { if c.v4 != nil { if c.v4.learned != nil { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned) + n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned) } if c.v4.reported != nil && len(c.v4.reported) > 0 { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...) + n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...) } } if c.v6 != nil { if c.v6.learned != nil { - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned) + n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned) } if c.v6.reported != nil && len(c.v6.reported) > 0 { - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...) + n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...) } } if c.relay != nil { - //TODO: IPV6-WORK - relays := make([]uint32, len(c.relay.relay)) - b := [4]byte{} - for i, _ := range relays { - b = c.relay.relay[i].As4() - relays[i] = binary.BigEndian.Uint32(b[:]) + if v == cert.Version1 { + b := [4]byte{} + for _, r := range c.relay.relay { + if !r.Is4() { + continue + } + + b = r.As4() + n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:])) + } + + } else if v == cert.Version2 { + for _, r := range c.relay.relay { + n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) + } + + } else { + panic("unsupported version") } - n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...) } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) { - if !lhh.lh.IsLighthouseIP(vpnIp) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnAddrs []netip.Addr) { + //TODO: this is kind of dumb + if !lhh.lh.IsLighthouseIP(vpnAddrs[0]) { return } lhh.lh.Lock() - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - certVpnIp := netip.AddrFrom4(b) + + var certVpnIp netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + certVpnIp = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + certVpnIp = protoAddrToNetAddr(n.Details.VpnAddr) + } + am := lhh.lh.unlockedGetRemoteList(certVpnIp) am.Lock() lhh.lh.Unlock() - //TODO: IPV6-WORK - am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(vpnAddrs[0], certVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnAddrs[0], certVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) - //TODO: IPV6-WORK - relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) - for i, _ := range n.Details.RelayVpnIp { - binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) - relays[i] = netip.AddrFrom4(b) + var relays []netip.Addr + if len(n.Details.OldRelayVpnAddrs) > 0 { + b := [4]byte{} + for _, r := range n.Details.OldRelayVpnAddrs { + binary.BigEndian.PutUint32(b[:], r) + relays = append(relays, netip.AddrFrom4(b)) + } } - am.unlockedSetRelay(vpnIp, certVpnIp, relays) + + if len(n.Details.RelayVpnAddrs) > 0 { + for _, r := range n.Details.RelayVpnAddrs { + relays = append(relays, protoAddrToNetAddr(r)) + } + } + + am.unlockedSetRelay(vpnAddrs[0], certVpnIp, relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block @@ -1072,62 +1149,91 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Ad } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) + lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnAddrs) } return } //Simple check that the host sent this not someone else - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - detailsVpnIp := netip.AddrFrom4(b) - if detailsVpnIp != vpnIp { + var detailsVpnIp netip.Addr + var useVersion cert.Version + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + detailsVpnIp = netip.AddrFrom4(b) + useVersion = 1 + } else if n.Details.VpnAddr != nil { + detailsVpnIp = protoAddrToNetAddr(n.Details.VpnAddr) + useVersion = 2 + } + + if detailsVpnIp != vpnAddrs[0] { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") + lhh.l.WithField("vpnAddrs", vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") } return } lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(vpnIp) + am := lhh.lh.unlockedGetRemoteList(vpnAddrs[0]) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(vpnAddrs[0], detailsVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(vpnAddrs[0], detailsVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) - //TODO: IPV6-WORK - relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) - for i, _ := range n.Details.RelayVpnIp { - binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) - relays[i] = netip.AddrFrom4(b) + var relays []netip.Addr + if len(n.Details.OldRelayVpnAddrs) > 0 { + b := [4]byte{} + for _, r := range n.Details.OldRelayVpnAddrs { + binary.BigEndian.PutUint32(b[:], r) + relays = append(relays, netip.AddrFrom4(b)) + } } - am.unlockedSetRelay(vpnIp, detailsVpnIp, relays) + + if len(n.Details.RelayVpnAddrs) > 0 { + for _, r := range n.Details.RelayVpnAddrs { + relays = append(relays, protoAddrToNetAddr(r)) + } + } + + am.unlockedSetRelay(vpnAddrs[0], detailsVpnIp, relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck - //TODO: IPV6-WORK - vpnIpB := vpnIp.As4() - n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:]) - ln, err := n.MarshalTo(lhh.pb) + if useVersion == cert.Version1 { + if !vpnAddrs[0].Is4() { + lhh.l.WithField("vpnAddrs", vpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + return + } + vpnIpB := vpnAddrs[0].As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnIpB[:]) + + } else if useVersion == cert.Version2 { + n.Details.VpnAddr = netAddrToProtoAddr(vpnAddrs[0]) + } else { + panic("unsupported version") + } + + ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host update ack") + lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).Error("Failed to marshal lighthouse host update ack") return } lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { - if !lhh.lh.IsLighthouseIP(vpnIp) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnAddrs []netip.Addr, w EncWriter) { + //TODO: this is kinda stupid + if !lhh.lh.IsLighthouseIP(vpnAddrs[0]) { return } @@ -1144,30 +1250,39 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp n }() if lhh.l.Level >= logrus.DebugLevel { - //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - //TODO: IPV6-WORK, make this debug line not suck - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b)) + var logVpnIp netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + logVpnIp = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + logVpnIp = protoAddrToNetAddr(n.Details.VpnAddr) + } + lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnIp) } } - for _, a := range n.Details.Ip4AndPorts { - punch(AddrPortFromIp4AndPort(a)) + for _, a := range n.Details.V4AddrPorts { + punch(protoV4AddrPortToNetAddrPort(a)) } - for _, a := range n.Details.Ip6AndPorts { - punch(AddrPortFromIp6AndPort(a)) + for _, a := range n.Details.V6AddrPorts { + punch(protoV6AddrPortToNetAddrPort(a)) } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. if lhh.lh.punchy.GetRespond() { - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - queryVpnIp := netip.AddrFrom4(b) + var queryVpnIp netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + queryVpnIp = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + queryVpnIp = protoAddrToNetAddr(n.Details.VpnAddr) + } + go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { @@ -1180,3 +1295,48 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp n }() } } + +func protoAddrToNetAddr(addr *Addr) netip.Addr { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], addr.Hi) + binary.BigEndian.PutUint64(b[8:], addr.Lo) + return netip.AddrFrom16(b).Unmap() +} + +func protoV4AddrPortToNetAddrPort(ap *V4AddrPort) netip.AddrPort { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], ap.Addr) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ap.Port)) +} + +func protoV6AddrPortToNetAddrPort(ap *V6AddrPort) netip.AddrPort { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], ap.Hi) + binary.BigEndian.PutUint64(b[8:], ap.Lo) + return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ap.Port)) +} + +func netAddrToProtoAddr(addr netip.Addr) *Addr { + b := addr.As16() + return &Addr{ + Hi: binary.BigEndian.Uint64(b[:8]), + Lo: binary.BigEndian.Uint64(b[8:]), + } +} + +func netAddrToProtoV4AddrPort(addr netip.Addr, port uint16) *V4AddrPort { + v4Addr := addr.As4() + return &V4AddrPort{ + Addr: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(port), + } +} + +func netAddrToProtoV6AddrPort(addr netip.Addr, port uint16) *V6AddrPort { + ip6Addr := addr.As16() + return &V6AddrPort{ + Hi: binary.BigEndian.Uint64(ip6Addr[:8]), + Lo: binary.BigEndian.Uint64(ip6Addr[8:]), + Port: uint32(port), + } +} diff --git a/lighthouse_test.go b/lighthouse_test.go index 2599f5f2e..fbb86a137 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -7,6 +7,7 @@ import ( "net/netip" "testing" + "github.com/gaissmai/bart" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" @@ -19,57 +20,48 @@ import ( func TestOldIPv4Only(t *testing.T) { // This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility b := []byte{8, 129, 130, 132, 80, 16, 10} - var m Ip4AndPort + var m V4AddrPort err := m.Unmarshal(b) assert.NoError(t, err) ip := netip.MustParseAddr("10.1.1.1") bp := ip.As4() - assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp()) -} - -func TestNewLhQuery(t *testing.T) { - myIp, err := netip.ParseAddr("192.1.1.1") - assert.NoError(t, err) - - // Generating a new lh query should work - a := NewLhQueryByInt(myIp) - - // The result should be a nebulameta protobuf - assert.IsType(t, &NebulaMeta{}, a) - - // It should also Marshal fine - b, err := a.Marshal() - assert.Nil(t, err) - - // and then Unmarshal fine - n := &NebulaMeta{} - err = n.Unmarshal(b) - assert.Nil(t, err) - + assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr()) } func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } lh1 := "10.128.0.2" c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.Nil(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } lh1 := "10.128.0.2" c := config.NewC(l) @@ -79,7 +71,7 @@ func TestReloadLighthouseInterval(t *testing.T) { } c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -99,9 +91,15 @@ func TestReloadLighthouseInterval(t *testing.T) { func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/0") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } c := config.NewC(l) - lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) if !assert.NoError(b, err) { b.Fatal() } @@ -114,11 +112,11 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { lh.addrMap[vpnIp3].unlockedSetV4( vpnIp3, vpnIp3, - []*Ip4AndPort{ - NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()), - NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()), + []*V4AddrPort{ + netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()), + netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()), }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rAddr := netip.MustParseAddrPort("1.2.2.3:12345") @@ -128,11 +126,11 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { lh.addrMap[vpnIp2].unlockedSetV4( vpnIp3, vpnIp3, - []*Ip4AndPort{ - NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()), - NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()), + []*V4AddrPort{ + netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()), + netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()), }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) mw := &mockEncWriter{} @@ -142,14 +140,14 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: 4, - Ip4AndPorts: nil, + OldVpnAddr: 4, + V4AddrPorts: nil, }, } p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, vpnIp2, p, mw) + lhh.HandleRequest(rAddr, []netip.Addr{vpnIp2}, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -157,15 +155,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: 3, - Ip4AndPorts: nil, + OldVpnAddr: 3, + V4AddrPorts: nil, }, } p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, vpnIp2, p, mw) + lhh.HandleRequest(rAddr, []netip.Addr{vpnIp2}, p, mw) } }) } @@ -197,40 +195,49 @@ func TestLighthouse_Memory(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh.ifce = &mockEncWriter{} assert.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2) // Ensure we don't accumulate addresses newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3) // Grow it back to 2 newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) // Update a different host and ask about it newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Have both hosts ask about the other r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Make sure we didn't get changed r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) // Ensure proper ordering and limiting // Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe) @@ -255,7 +262,7 @@ func TestLighthouse_Memory(t *testing.T) { r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray( t, - r.msg.Details.Ip4AndPorts, + r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, ) @@ -265,7 +272,7 @@ func TestLighthouse_Memory(t *testing.T) { good := netip.MustParseAddrPort("1.128.0.99:4242") newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, good) } func TestLighthouse_reload(t *testing.T) { @@ -273,7 +280,16 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.NoError(t, err) nc := map[interface{}]interface{}{ @@ -295,7 +311,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(bip[:]), + OldVpnAddr: binary.BigEndian.Uint32(bip[:]), }, } @@ -308,7 +324,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l w := &testEncWriter{ metaFilter: &filter, } - lhh.HandleRequest(fromAddr, myVpnIp, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w) return w.lastReply } @@ -318,13 +334,13 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad req := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(bip[:]), - Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), + OldVpnAddr: binary.BigEndian.Uint32(bip[:]), + V4AddrPorts: make([]*V4AddrPort, len(addrs)), }, } for k, v := range addrs { - req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port()) + req.Details.V4AddrPorts[k] = netAddrToProtoV4AddrPort(v.Addr(), v.Port()) } b, err := req.Marshal() @@ -333,7 +349,7 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad } w := &testEncWriter{} - lhh.HandleRequest(fromAddr, vpnIp, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w) } //TODO: this is a RemoteList test @@ -426,7 +442,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M tw.lastReply = testLhReply{ nebType: t, nebSubType: st, - vpnIp: hostinfo.vpnIp, + vpnIp: hostinfo.vpnAddrs[0], msg: msg, } } @@ -453,15 +469,19 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess } } +func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return nil +} + // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match -func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) { +func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) { if !assert.Len(t, have, len(want)) { return } for k, w := range want { //TODO: IPV6-WORK - h := AddrPortFromIp4AndPort(have[k]) + h := protoV4AddrPortToNetAddrPort(have[k]) if !(h == w) { assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) } diff --git a/main.go b/main.go index 8f4535951..5e97b4a77 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package nebula import ( "context" - "encoding/binary" "fmt" "net" "net/netip" @@ -61,15 +60,13 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - certificate := pki.GetCertState().Certificate + certificate := pki.getDefaultCertificate() fw, err := NewFirewallFromConfig(l, certificate, c) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - tunCidr := certificate.Networks()[0] - ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) @@ -132,7 +129,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg deviceFactory = overlay.NewDeviceFromConfig } - tun, err = deviceFactory(c, l, tunCidr, routines) + //TODO: device needs all networks not just the first one + tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks[0], routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } @@ -187,9 +185,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } - hostMap := NewHostMapFromConfig(l, tunCidr, c) + hostMap := NewHostMapFromConfig(l, c) punchy := NewPunchyFromConfig(l, c) - lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) + lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } @@ -232,7 +230,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg Inside: tun, Outside: udpConns[0], pki: pki, - Cipher: c.GetString("cipher", "aes"), Firewall: fw, ServeDns: serveDns, HandshakeManager: handshakeManager, @@ -254,15 +251,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg l: l, } - switch ifConfig.Cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian - default: - return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher) - } - var ifce *Interface if !configTest { ifce, err = NewInterface(ctx, ifConfig) @@ -303,7 +291,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg var dnsStart func() if lightHouse.amLighthouse && serveDns { l.Debugln("Starting dns server") - dnsStart = dnsMain(l, hostMap, c) + dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) } return &Control{ diff --git a/nebula.pb.go b/nebula.pb.go index 3ae0371ef..2fd2ff665 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -96,7 +96,7 @@ func (x NebulaPing_MessageType) String() string { } func (NebulaPing_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{4, 0} + return fileDescriptor_2d65afa7693df5ef, []int{5, 0} } type NebulaControl_MessageType int32 @@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string { } func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7, 0} + return fileDescriptor_2d65afa7693df5ef, []int{8, 0} } type NebulaMeta struct { @@ -180,11 +180,13 @@ func (m *NebulaMeta) GetDetails() *NebulaMetaDetails { } type NebulaMetaDetails struct { - VpnIp uint32 `protobuf:"varint,1,opt,name=VpnIp,proto3" json:"VpnIp,omitempty"` - Ip4AndPorts []*Ip4AndPort `protobuf:"bytes,2,rep,name=Ip4AndPorts,proto3" json:"Ip4AndPorts,omitempty"` - Ip6AndPorts []*Ip6AndPort `protobuf:"bytes,4,rep,name=Ip6AndPorts,proto3" json:"Ip6AndPorts,omitempty"` - RelayVpnIp []uint32 `protobuf:"varint,5,rep,packed,name=RelayVpnIp,proto3" json:"RelayVpnIp,omitempty"` - Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` + OldVpnAddr uint32 `protobuf:"varint,1,opt,name=OldVpnAddr,proto3" json:"OldVpnAddr,omitempty"` // Deprecated: Do not use. + VpnAddr *Addr `protobuf:"bytes,6,opt,name=VpnAddr,proto3" json:"VpnAddr,omitempty"` + OldRelayVpnAddrs []uint32 `protobuf:"varint,5,rep,packed,name=OldRelayVpnAddrs,proto3" json:"OldRelayVpnAddrs,omitempty"` // Deprecated: Do not use. + RelayVpnAddrs []*Addr `protobuf:"bytes,7,rep,name=RelayVpnAddrs,proto3" json:"RelayVpnAddrs,omitempty"` + V4AddrPorts []*V4AddrPort `protobuf:"bytes,2,rep,name=V4AddrPorts,proto3" json:"V4AddrPorts,omitempty"` + V6AddrPorts []*V6AddrPort `protobuf:"bytes,4,rep,name=V6AddrPorts,proto3" json:"V6AddrPorts,omitempty"` + Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` } func (m *NebulaMetaDetails) Reset() { *m = NebulaMetaDetails{} } @@ -220,30 +222,46 @@ func (m *NebulaMetaDetails) XXX_DiscardUnknown() { var xxx_messageInfo_NebulaMetaDetails proto.InternalMessageInfo -func (m *NebulaMetaDetails) GetVpnIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaMetaDetails) GetOldVpnAddr() uint32 { if m != nil { - return m.VpnIp + return m.OldVpnAddr } return 0 } -func (m *NebulaMetaDetails) GetIp4AndPorts() []*Ip4AndPort { +func (m *NebulaMetaDetails) GetVpnAddr() *Addr { if m != nil { - return m.Ip4AndPorts + return m.VpnAddr } return nil } -func (m *NebulaMetaDetails) GetIp6AndPorts() []*Ip6AndPort { +// Deprecated: Do not use. +func (m *NebulaMetaDetails) GetOldRelayVpnAddrs() []uint32 { if m != nil { - return m.Ip6AndPorts + return m.OldRelayVpnAddrs } return nil } -func (m *NebulaMetaDetails) GetRelayVpnIp() []uint32 { +func (m *NebulaMetaDetails) GetRelayVpnAddrs() []*Addr { if m != nil { - return m.RelayVpnIp + return m.RelayVpnAddrs + } + return nil +} + +func (m *NebulaMetaDetails) GetV4AddrPorts() []*V4AddrPort { + if m != nil { + return m.V4AddrPorts + } + return nil +} + +func (m *NebulaMetaDetails) GetV6AddrPorts() []*V6AddrPort { + if m != nil { + return m.V6AddrPorts } return nil } @@ -255,23 +273,75 @@ func (m *NebulaMetaDetails) GetCounter() uint32 { return 0 } -type Ip4AndPort struct { - Ip uint32 `protobuf:"varint,1,opt,name=Ip,proto3" json:"Ip,omitempty"` - Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` +type Addr struct { + Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` + Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` } -func (m *Ip4AndPort) Reset() { *m = Ip4AndPort{} } -func (m *Ip4AndPort) String() string { return proto.CompactTextString(m) } -func (*Ip4AndPort) ProtoMessage() {} -func (*Ip4AndPort) Descriptor() ([]byte, []int) { +func (m *Addr) Reset() { *m = Addr{} } +func (m *Addr) String() string { return proto.CompactTextString(m) } +func (*Addr) ProtoMessage() {} +func (*Addr) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{2} } -func (m *Ip4AndPort) XXX_Unmarshal(b []byte) error { +func (m *Addr) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Addr) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Addr.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Addr) XXX_Merge(src proto.Message) { + xxx_messageInfo_Addr.Merge(m, src) +} +func (m *Addr) XXX_Size() int { + return m.Size() +} +func (m *Addr) XXX_DiscardUnknown() { + xxx_messageInfo_Addr.DiscardUnknown(m) +} + +var xxx_messageInfo_Addr proto.InternalMessageInfo + +func (m *Addr) GetHi() uint64 { + if m != nil { + return m.Hi + } + return 0 +} + +func (m *Addr) GetLo() uint64 { + if m != nil { + return m.Lo + } + return 0 +} + +type V4AddrPort struct { + Addr uint32 `protobuf:"varint,1,opt,name=Addr,proto3" json:"Addr,omitempty"` + Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` +} + +func (m *V4AddrPort) Reset() { *m = V4AddrPort{} } +func (m *V4AddrPort) String() string { return proto.CompactTextString(m) } +func (*V4AddrPort) ProtoMessage() {} +func (*V4AddrPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{3} +} +func (m *V4AddrPort) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } -func (m *Ip4AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { +func (m *V4AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { - return xxx_messageInfo_Ip4AndPort.Marshal(b, m, deterministic) + return xxx_messageInfo_V4AddrPort.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) @@ -281,50 +351,50 @@ func (m *Ip4AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return b[:n], nil } } -func (m *Ip4AndPort) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ip4AndPort.Merge(m, src) +func (m *V4AddrPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_V4AddrPort.Merge(m, src) } -func (m *Ip4AndPort) XXX_Size() int { +func (m *V4AddrPort) XXX_Size() int { return m.Size() } -func (m *Ip4AndPort) XXX_DiscardUnknown() { - xxx_messageInfo_Ip4AndPort.DiscardUnknown(m) +func (m *V4AddrPort) XXX_DiscardUnknown() { + xxx_messageInfo_V4AddrPort.DiscardUnknown(m) } -var xxx_messageInfo_Ip4AndPort proto.InternalMessageInfo +var xxx_messageInfo_V4AddrPort proto.InternalMessageInfo -func (m *Ip4AndPort) GetIp() uint32 { +func (m *V4AddrPort) GetAddr() uint32 { if m != nil { - return m.Ip + return m.Addr } return 0 } -func (m *Ip4AndPort) GetPort() uint32 { +func (m *V4AddrPort) GetPort() uint32 { if m != nil { return m.Port } return 0 } -type Ip6AndPort struct { +type V6AddrPort struct { Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` Port uint32 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"` } -func (m *Ip6AndPort) Reset() { *m = Ip6AndPort{} } -func (m *Ip6AndPort) String() string { return proto.CompactTextString(m) } -func (*Ip6AndPort) ProtoMessage() {} -func (*Ip6AndPort) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{3} +func (m *V6AddrPort) Reset() { *m = V6AddrPort{} } +func (m *V6AddrPort) String() string { return proto.CompactTextString(m) } +func (*V6AddrPort) ProtoMessage() {} +func (*V6AddrPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{4} } -func (m *Ip6AndPort) XXX_Unmarshal(b []byte) error { +func (m *V6AddrPort) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } -func (m *Ip6AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { +func (m *V6AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { - return xxx_messageInfo_Ip6AndPort.Marshal(b, m, deterministic) + return xxx_messageInfo_V6AddrPort.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) @@ -334,33 +404,33 @@ func (m *Ip6AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return b[:n], nil } } -func (m *Ip6AndPort) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ip6AndPort.Merge(m, src) +func (m *V6AddrPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_V6AddrPort.Merge(m, src) } -func (m *Ip6AndPort) XXX_Size() int { +func (m *V6AddrPort) XXX_Size() int { return m.Size() } -func (m *Ip6AndPort) XXX_DiscardUnknown() { - xxx_messageInfo_Ip6AndPort.DiscardUnknown(m) +func (m *V6AddrPort) XXX_DiscardUnknown() { + xxx_messageInfo_V6AddrPort.DiscardUnknown(m) } -var xxx_messageInfo_Ip6AndPort proto.InternalMessageInfo +var xxx_messageInfo_V6AddrPort proto.InternalMessageInfo -func (m *Ip6AndPort) GetHi() uint64 { +func (m *V6AddrPort) GetHi() uint64 { if m != nil { return m.Hi } return 0 } -func (m *Ip6AndPort) GetLo() uint64 { +func (m *V6AddrPort) GetLo() uint64 { if m != nil { return m.Lo } return 0 } -func (m *Ip6AndPort) GetPort() uint32 { +func (m *V6AddrPort) GetPort() uint32 { if m != nil { return m.Port } @@ -376,7 +446,7 @@ func (m *NebulaPing) Reset() { *m = NebulaPing{} } func (m *NebulaPing) String() string { return proto.CompactTextString(m) } func (*NebulaPing) ProtoMessage() {} func (*NebulaPing) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{4} + return fileDescriptor_2d65afa7693df5ef, []int{5} } func (m *NebulaPing) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -428,7 +498,7 @@ func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} } func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } func (*NebulaHandshake) ProtoMessage() {} func (*NebulaHandshake) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{5} + return fileDescriptor_2d65afa7693df5ef, []int{6} } func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -484,7 +554,7 @@ func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } func (*NebulaHandshakeDetails) ProtoMessage() {} func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{6} + return fileDescriptor_2d65afa7693df5ef, []int{7} } func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -559,15 +629,17 @@ type NebulaControl struct { Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` ResponderRelayIndex uint32 `protobuf:"varint,3,opt,name=ResponderRelayIndex,proto3" json:"ResponderRelayIndex,omitempty"` - RelayToIp uint32 `protobuf:"varint,4,opt,name=RelayToIp,proto3" json:"RelayToIp,omitempty"` - RelayFromIp uint32 `protobuf:"varint,5,opt,name=RelayFromIp,proto3" json:"RelayFromIp,omitempty"` + OldRelayToAddr uint32 `protobuf:"varint,4,opt,name=OldRelayToAddr,proto3" json:"OldRelayToAddr,omitempty"` // Deprecated: Do not use. + OldRelayFromAddr uint32 `protobuf:"varint,5,opt,name=OldRelayFromAddr,proto3" json:"OldRelayFromAddr,omitempty"` // Deprecated: Do not use. + RelayToAddr *Addr `protobuf:"bytes,6,opt,name=RelayToAddr,proto3" json:"RelayToAddr,omitempty"` + RelayFromAddr *Addr `protobuf:"bytes,7,opt,name=RelayFromAddr,proto3" json:"RelayFromAddr,omitempty"` } func (m *NebulaControl) Reset() { *m = NebulaControl{} } func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7} + return fileDescriptor_2d65afa7693df5ef, []int{8} } func (m *NebulaControl) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -617,28 +689,45 @@ func (m *NebulaControl) GetResponderRelayIndex() uint32 { return 0 } -func (m *NebulaControl) GetRelayToIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaControl) GetOldRelayToAddr() uint32 { if m != nil { - return m.RelayToIp + return m.OldRelayToAddr } return 0 } -func (m *NebulaControl) GetRelayFromIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaControl) GetOldRelayFromAddr() uint32 { if m != nil { - return m.RelayFromIp + return m.OldRelayFromAddr } return 0 } +func (m *NebulaControl) GetRelayToAddr() *Addr { + if m != nil { + return m.RelayToAddr + } + return nil +} + +func (m *NebulaControl) GetRelayFromAddr() *Addr { + if m != nil { + return m.RelayFromAddr + } + return nil +} + func init() { proto.RegisterEnum("nebula.NebulaMeta_MessageType", NebulaMeta_MessageType_name, NebulaMeta_MessageType_value) proto.RegisterEnum("nebula.NebulaPing_MessageType", NebulaPing_MessageType_name, NebulaPing_MessageType_value) proto.RegisterEnum("nebula.NebulaControl_MessageType", NebulaControl_MessageType_name, NebulaControl_MessageType_value) proto.RegisterType((*NebulaMeta)(nil), "nebula.NebulaMeta") proto.RegisterType((*NebulaMetaDetails)(nil), "nebula.NebulaMetaDetails") - proto.RegisterType((*Ip4AndPort)(nil), "nebula.Ip4AndPort") - proto.RegisterType((*Ip6AndPort)(nil), "nebula.Ip6AndPort") + proto.RegisterType((*Addr)(nil), "nebula.Addr") + proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") + proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails") @@ -648,52 +737,57 @@ func init() { func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 720 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x54, 0xcd, 0x6e, 0xd3, 0x40, - 0x10, 0x8e, 0x1d, 0xe7, 0x6f, 0xd2, 0xa4, 0x66, 0x0a, 0x21, 0x41, 0x60, 0x05, 0x1f, 0x50, 0x4e, - 0x69, 0x95, 0x96, 0x8a, 0x23, 0x25, 0x08, 0x25, 0x55, 0x5b, 0x85, 0x55, 0x29, 0x12, 0x17, 0xb4, - 0x75, 0x96, 0xc6, 0x4a, 0xe2, 0x75, 0xed, 0x0d, 0x6a, 0xde, 0x82, 0x87, 0xe1, 0x21, 0xb8, 0xd1, - 0x13, 0xe2, 0x88, 0xda, 0x23, 0x47, 0x5e, 0x00, 0xed, 0x3a, 0x71, 0x9c, 0x34, 0x70, 0xdb, 0x99, - 0xf9, 0xbe, 0xd9, 0x6f, 0xbe, 0x1d, 0x1b, 0x36, 0x3c, 0x76, 0x3e, 0x19, 0xd1, 0xa6, 0x1f, 0x70, - 0xc1, 0x31, 0x1b, 0x45, 0xf6, 0x6f, 0x1d, 0xe0, 0x44, 0x1d, 0x8f, 0x99, 0xa0, 0xd8, 0x02, 0xe3, - 0x74, 0xea, 0xb3, 0xaa, 0x56, 0xd7, 0x1a, 0xe5, 0x96, 0xd5, 0x9c, 0x71, 0x16, 0x88, 0xe6, 0x31, - 0x0b, 0x43, 0x7a, 0xc1, 0x24, 0x8a, 0x28, 0x2c, 0xee, 0x42, 0xee, 0x35, 0x13, 0xd4, 0x1d, 0x85, - 0x55, 0xbd, 0xae, 0x35, 0x8a, 0xad, 0xda, 0x5d, 0xda, 0x0c, 0x40, 0xe6, 0x48, 0xfb, 0x8f, 0x06, - 0xc5, 0x44, 0x2b, 0xcc, 0x83, 0x71, 0xc2, 0x3d, 0x66, 0xa6, 0xb0, 0x04, 0x85, 0x0e, 0x0f, 0xc5, - 0xdb, 0x09, 0x0b, 0xa6, 0xa6, 0x86, 0x08, 0xe5, 0x38, 0x24, 0xcc, 0x1f, 0x4d, 0x4d, 0x1d, 0x1f, - 0x41, 0x45, 0xe6, 0xde, 0xf9, 0x7d, 0x2a, 0xd8, 0x09, 0x17, 0xee, 0x27, 0xd7, 0xa1, 0xc2, 0xe5, - 0x9e, 0x99, 0xc6, 0x1a, 0x3c, 0x90, 0xb5, 0x63, 0xfe, 0x99, 0xf5, 0x97, 0x4a, 0xc6, 0xbc, 0xd4, - 0x9b, 0x78, 0xce, 0x60, 0xa9, 0x94, 0xc1, 0x32, 0x80, 0x2c, 0xbd, 0x1f, 0x70, 0x3a, 0x76, 0xcd, - 0x2c, 0x6e, 0xc1, 0xe6, 0x22, 0x8e, 0xae, 0xcd, 0x49, 0x65, 0x3d, 0x2a, 0x06, 0xed, 0x01, 0x73, - 0x86, 0x66, 0x5e, 0x2a, 0x8b, 0xc3, 0x08, 0x52, 0xc0, 0x27, 0x50, 0x5b, 0xaf, 0xec, 0xc0, 0x19, - 0x9a, 0x60, 0x7f, 0xd7, 0xe0, 0xde, 0x1d, 0x53, 0xf0, 0x3e, 0x64, 0xce, 0x7c, 0xaf, 0xeb, 0x2b, - 0xd7, 0x4b, 0x24, 0x0a, 0x70, 0x0f, 0x8a, 0x5d, 0x7f, 0xef, 0xc0, 0xeb, 0xf7, 0x78, 0x20, 0xa4, - 0xb5, 0xe9, 0x46, 0xb1, 0x85, 0x73, 0x6b, 0x17, 0x25, 0x92, 0x84, 0x45, 0xac, 0xfd, 0x98, 0x65, - 0xac, 0xb2, 0xf6, 0x13, 0xac, 0x18, 0x86, 0x16, 0x00, 0x61, 0x23, 0x3a, 0x8d, 0x64, 0x64, 0xea, - 0xe9, 0x46, 0x89, 0x24, 0x32, 0x58, 0x85, 0x9c, 0xc3, 0x27, 0x9e, 0x60, 0x41, 0x35, 0xad, 0x34, - 0xce, 0x43, 0x7b, 0x07, 0x60, 0x71, 0x3d, 0x96, 0x41, 0x8f, 0xc7, 0xd0, 0xbb, 0x3e, 0x22, 0x18, - 0x32, 0xaf, 0xf6, 0xa2, 0x44, 0xd4, 0xd9, 0x7e, 0x29, 0x19, 0xfb, 0x09, 0x46, 0xc7, 0x55, 0x0c, - 0x83, 0xe8, 0x1d, 0x57, 0xc6, 0x47, 0x5c, 0xe1, 0x0d, 0xa2, 0x1f, 0xf1, 0xb8, 0x43, 0x3a, 0xd1, - 0xe1, 0x6a, 0xbe, 0xb2, 0x3d, 0xd7, 0xbb, 0xf8, 0xff, 0xca, 0x4a, 0xc4, 0x9a, 0x95, 0x45, 0x30, - 0x4e, 0xdd, 0x31, 0x9b, 0xdd, 0xa3, 0xce, 0xb6, 0x7d, 0x67, 0x21, 0x25, 0xd9, 0x4c, 0x61, 0x01, - 0x32, 0xd1, 0xf3, 0x6a, 0xf6, 0x47, 0xd8, 0x8c, 0xfa, 0x76, 0xa8, 0xd7, 0x0f, 0x07, 0x74, 0xc8, - 0xf0, 0xc5, 0x62, 0xfb, 0x35, 0xb5, 0xfd, 0x2b, 0x0a, 0x62, 0xe4, 0xea, 0x27, 0x20, 0x45, 0x74, - 0xc6, 0xd4, 0x51, 0x22, 0x36, 0x88, 0x3a, 0xdb, 0x3f, 0x34, 0xa8, 0xac, 0xe7, 0x49, 0x78, 0x9b, - 0x05, 0x42, 0xdd, 0xb2, 0x41, 0xd4, 0x19, 0x9f, 0x41, 0xb9, 0xeb, 0xb9, 0xc2, 0xa5, 0x82, 0x07, - 0x5d, 0xaf, 0xcf, 0xae, 0x66, 0x4e, 0xaf, 0x64, 0x25, 0x8e, 0xb0, 0xd0, 0xe7, 0x5e, 0x9f, 0xcd, - 0x70, 0x91, 0x9f, 0x2b, 0x59, 0xac, 0x40, 0xb6, 0xcd, 0xf9, 0xd0, 0x65, 0x55, 0x43, 0x39, 0x33, - 0x8b, 0x62, 0xbf, 0x32, 0x0b, 0xbf, 0xb0, 0x0e, 0x45, 0xa9, 0xe1, 0x8c, 0x05, 0xa1, 0xcb, 0xbd, - 0x6a, 0x5e, 0x35, 0x4c, 0xa6, 0x0e, 0x8d, 0x7c, 0xd6, 0xcc, 0x1d, 0x1a, 0xf9, 0x9c, 0x99, 0xb7, - 0xbf, 0xea, 0x50, 0x8a, 0x06, 0x6b, 0x73, 0x4f, 0x04, 0x7c, 0x84, 0xcf, 0x97, 0xde, 0xed, 0xe9, - 0xb2, 0x6b, 0x33, 0xd0, 0x9a, 0xa7, 0xdb, 0x81, 0xad, 0x78, 0x38, 0xb5, 0xa1, 0xc9, 0xb9, 0xd7, - 0x95, 0x24, 0x23, 0x1e, 0x33, 0xc1, 0x88, 0x1c, 0x58, 0x57, 0xc2, 0xc7, 0x50, 0x50, 0xd1, 0x29, - 0xef, 0xfa, 0xca, 0x89, 0x12, 0x59, 0x24, 0xe4, 0xe0, 0x2a, 0x78, 0x13, 0xf0, 0xb1, 0xfa, 0x5a, - 0xd4, 0xe0, 0x89, 0x94, 0xdd, 0xf9, 0xd7, 0xbf, 0xad, 0x02, 0xd8, 0x0e, 0x18, 0x15, 0x4c, 0xa1, - 0x09, 0xbb, 0x9c, 0xb0, 0x50, 0x98, 0x1a, 0x3e, 0x84, 0xad, 0xa5, 0xbc, 0x94, 0x14, 0x32, 0x53, - 0x7f, 0xb5, 0xfb, 0xed, 0xc6, 0xd2, 0xae, 0x6f, 0x2c, 0xed, 0xd7, 0x8d, 0xa5, 0x7d, 0xb9, 0xb5, - 0x52, 0xd7, 0xb7, 0x56, 0xea, 0xe7, 0xad, 0x95, 0xfa, 0x50, 0xbb, 0x70, 0xc5, 0x60, 0x72, 0xde, - 0x74, 0xf8, 0x78, 0x3b, 0x1c, 0x51, 0x67, 0x38, 0xb8, 0xdc, 0x8e, 0x2c, 0x3c, 0xcf, 0xaa, 0x5f, - 0xfc, 0xee, 0xdf, 0x00, 0x00, 0x00, 0xff, 0xff, 0xa3, 0xa5, 0xef, 0x45, 0xf2, 0x05, 0x00, 0x00, + // 785 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44, + 0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54, + 0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a, + 0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c, + 0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb, + 0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 0xf1, 0xed, + 0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc, + 0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d, + 0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6, + 0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65, + 0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70, + 0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb, + 0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f, + 0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b, + 0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b, + 0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6, + 0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0, + 0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b, + 0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c, + 0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18, + 0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86, + 0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88, + 0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7, + 0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a, + 0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba, + 0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00, + 0x79, 0x1b, 0x88, 0xb1, 0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 0x49, 0x4d, 0x97, 0xc8, + 0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1, + 0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68, + 0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50, + 0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f, + 0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91, + 0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee, + 0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b, + 0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7, + 0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab, + 0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60, + 0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e, + 0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90, + 0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64, + 0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9, + 0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef, + 0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf, + 0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f, + 0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8, + 0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65, + 0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9, + 0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94, + 0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00, + 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -756,28 +850,54 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if len(m.RelayVpnIp) > 0 { - dAtA3 := make([]byte, len(m.RelayVpnIp)*10) - var j2 int - for _, num := range m.RelayVpnIp { + if len(m.RelayVpnAddrs) > 0 { + for iNdEx := len(m.RelayVpnAddrs) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.RelayVpnAddrs[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x3a + } + } + if m.VpnAddr != nil { + { + size, err := m.VpnAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x32 + } + if len(m.OldRelayVpnAddrs) > 0 { + dAtA4 := make([]byte, len(m.OldRelayVpnAddrs)*10) + var j3 int + for _, num := range m.OldRelayVpnAddrs { for num >= 1<<7 { - dAtA3[j2] = uint8(uint64(num)&0x7f | 0x80) + dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) num >>= 7 - j2++ + j3++ } - dAtA3[j2] = uint8(num) - j2++ + dAtA4[j3] = uint8(num) + j3++ } - i -= j2 - copy(dAtA[i:], dAtA3[:j2]) - i = encodeVarintNebula(dAtA, i, uint64(j2)) + i -= j3 + copy(dAtA[i:], dAtA4[:j3]) + i = encodeVarintNebula(dAtA, i, uint64(j3)) i-- dAtA[i] = 0x2a } - if len(m.Ip6AndPorts) > 0 { - for iNdEx := len(m.Ip6AndPorts) - 1; iNdEx >= 0; iNdEx-- { + if len(m.V6AddrPorts) > 0 { + for iNdEx := len(m.V6AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { - size, err := m.Ip6AndPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + size, err := m.V6AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -793,10 +913,10 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x18 } - if len(m.Ip4AndPorts) > 0 { - for iNdEx := len(m.Ip4AndPorts) - 1; iNdEx >= 0; iNdEx-- { + if len(m.V4AddrPorts) > 0 { + for iNdEx := len(m.V4AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { - size, err := m.Ip4AndPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + size, err := m.V4AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -807,15 +927,48 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { dAtA[i] = 0x12 } } - if m.VpnIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.VpnIp)) + if m.OldVpnAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldVpnAddr)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *Addr) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Addr) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Addr) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Lo != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Lo)) + i-- + dAtA[i] = 0x10 + } + if m.Hi != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Hi)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } -func (m *Ip4AndPort) Marshal() (dAtA []byte, err error) { +func (m *V4AddrPort) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -825,12 +978,12 @@ func (m *Ip4AndPort) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ip4AndPort) MarshalTo(dAtA []byte) (int, error) { +func (m *V4AddrPort) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ip4AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *V4AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -840,15 +993,15 @@ func (m *Ip4AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x10 } - if m.Ip != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Ip)) + if m.Addr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Addr)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } -func (m *Ip6AndPort) Marshal() (dAtA []byte, err error) { +func (m *V6AddrPort) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -858,12 +1011,12 @@ func (m *Ip6AndPort) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ip6AndPort) MarshalTo(dAtA []byte) (int, error) { +func (m *V6AddrPort) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ip6AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *V6AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -1036,13 +1189,37 @@ func (m *NebulaControl) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if m.RelayFromIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.RelayFromIp)) + if m.RelayFromAddr != nil { + { + size, err := m.RelayFromAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x3a + } + if m.RelayToAddr != nil { + { + size, err := m.RelayToAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x32 + } + if m.OldRelayFromAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayFromAddr)) i-- dAtA[i] = 0x28 } - if m.RelayToIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.RelayToIp)) + if m.OldRelayToAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayToAddr)) i-- dAtA[i] = 0x20 } @@ -1097,11 +1274,11 @@ func (m *NebulaMetaDetails) Size() (n int) { } var l int _ = l - if m.VpnIp != 0 { - n += 1 + sovNebula(uint64(m.VpnIp)) + if m.OldVpnAddr != 0 { + n += 1 + sovNebula(uint64(m.OldVpnAddr)) } - if len(m.Ip4AndPorts) > 0 { - for _, e := range m.Ip4AndPorts { + if len(m.V4AddrPorts) > 0 { + for _, e := range m.V4AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } @@ -1109,30 +1286,55 @@ func (m *NebulaMetaDetails) Size() (n int) { if m.Counter != 0 { n += 1 + sovNebula(uint64(m.Counter)) } - if len(m.Ip6AndPorts) > 0 { - for _, e := range m.Ip6AndPorts { + if len(m.V6AddrPorts) > 0 { + for _, e := range m.V6AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } } - if len(m.RelayVpnIp) > 0 { + if len(m.OldRelayVpnAddrs) > 0 { l = 0 - for _, e := range m.RelayVpnIp { + for _, e := range m.OldRelayVpnAddrs { l += sovNebula(uint64(e)) } n += 1 + sovNebula(uint64(l)) + l } + if m.VpnAddr != nil { + l = m.VpnAddr.Size() + n += 1 + l + sovNebula(uint64(l)) + } + if len(m.RelayVpnAddrs) > 0 { + for _, e := range m.RelayVpnAddrs { + l = e.Size() + n += 1 + l + sovNebula(uint64(l)) + } + } return n } -func (m *Ip4AndPort) Size() (n int) { +func (m *Addr) Size() (n int) { if m == nil { return 0 } var l int _ = l - if m.Ip != 0 { - n += 1 + sovNebula(uint64(m.Ip)) + if m.Hi != 0 { + n += 1 + sovNebula(uint64(m.Hi)) + } + if m.Lo != 0 { + n += 1 + sovNebula(uint64(m.Lo)) + } + return n +} + +func (m *V4AddrPort) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Addr != 0 { + n += 1 + sovNebula(uint64(m.Addr)) } if m.Port != 0 { n += 1 + sovNebula(uint64(m.Port)) @@ -1140,7 +1342,7 @@ func (m *Ip4AndPort) Size() (n int) { return n } -func (m *Ip6AndPort) Size() (n int) { +func (m *V6AddrPort) Size() (n int) { if m == nil { return 0 } @@ -1233,11 +1435,19 @@ func (m *NebulaControl) Size() (n int) { if m.ResponderRelayIndex != 0 { n += 1 + sovNebula(uint64(m.ResponderRelayIndex)) } - if m.RelayToIp != 0 { - n += 1 + sovNebula(uint64(m.RelayToIp)) + if m.OldRelayToAddr != 0 { + n += 1 + sovNebula(uint64(m.OldRelayToAddr)) } - if m.RelayFromIp != 0 { - n += 1 + sovNebula(uint64(m.RelayFromIp)) + if m.OldRelayFromAddr != 0 { + n += 1 + sovNebula(uint64(m.OldRelayFromAddr)) + } + if m.RelayToAddr != nil { + l = m.RelayToAddr.Size() + n += 1 + l + sovNebula(uint64(l)) + } + if m.RelayFromAddr != nil { + l = m.RelayFromAddr.Size() + n += 1 + l + sovNebula(uint64(l)) } return n } @@ -1384,9 +1594,9 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { switch fieldNum { case 1: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field VpnIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldVpnAddr", wireType) } - m.VpnIp = 0 + m.OldVpnAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -1396,14 +1606,14 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.VpnIp |= uint32(b&0x7F) << shift + m.OldVpnAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip4AndPorts", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field V4AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -1430,8 +1640,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Ip4AndPorts = append(m.Ip4AndPorts, &Ip4AndPort{}) - if err := m.Ip4AndPorts[len(m.Ip4AndPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + m.V4AddrPorts = append(m.V4AddrPorts, &V4AddrPort{}) + if err := m.V4AddrPorts[len(m.V4AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -1456,7 +1666,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } case 4: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip6AndPorts", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field V6AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -1483,8 +1693,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Ip6AndPorts = append(m.Ip6AndPorts, &Ip6AndPort{}) - if err := m.Ip6AndPorts[len(m.Ip6AndPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + m.V6AddrPorts = append(m.V6AddrPorts, &V6AddrPort{}) + if err := m.V6AddrPorts[len(m.V6AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -1505,7 +1715,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { break } } - m.RelayVpnIp = append(m.RelayVpnIp, v) + m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } else if wireType == 2 { var packedLen int for shift := uint(0); ; shift += 7 { @@ -1540,8 +1750,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } } elementCount = count - if elementCount != 0 && len(m.RelayVpnIp) == 0 { - m.RelayVpnIp = make([]uint32, 0, elementCount) + if elementCount != 0 && len(m.OldRelayVpnAddrs) == 0 { + m.OldRelayVpnAddrs = make([]uint32, 0, elementCount) } for iNdEx < postIndex { var v uint32 @@ -1559,10 +1769,168 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { break } } - m.RelayVpnIp = append(m.RelayVpnIp, v) + m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } } else { - return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayVpnAddrs", wireType) + } + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field VpnAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.VpnAddr == nil { + m.VpnAddr = &Addr{} + } + if err := m.VpnAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnAddrs", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.RelayVpnAddrs = append(m.RelayVpnAddrs, &Addr{}) + if err := m.RelayVpnAddrs[len(m.RelayVpnAddrs)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipNebula(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthNebula + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *Addr) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Addr: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Addr: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Hi", wireType) + } + m.Hi = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Hi |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Lo", wireType) + } + m.Lo = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Lo |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } } default: iNdEx = preIndex @@ -1585,7 +1953,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { +func (m *V4AddrPort) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1608,17 +1976,17 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ip4AndPort: wiretype end group for non-group") + return fmt.Errorf("proto: V4AddrPort: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ip4AndPort: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: V4AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Addr", wireType) } - m.Ip = 0 + m.Addr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -1628,7 +1996,7 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.Ip |= uint32(b&0x7F) << shift + m.Addr |= uint32(b&0x7F) << shift if b < 0x80 { break } @@ -1673,7 +2041,7 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ip6AndPort) Unmarshal(dAtA []byte) error { +func (m *V6AddrPort) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1696,10 +2064,10 @@ func (m *Ip6AndPort) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ip6AndPort: wiretype end group for non-group") + return fmt.Errorf("proto: V6AddrPort: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ip6AndPort: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: V6AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -2255,9 +2623,9 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } case 4: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RelayToIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayToAddr", wireType) } - m.RelayToIp = 0 + m.OldRelayToAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -2267,16 +2635,16 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.RelayToIp |= uint32(b&0x7F) << shift + m.OldRelayToAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 5: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RelayFromIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayFromAddr", wireType) } - m.RelayFromIp = 0 + m.OldRelayFromAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -2286,11 +2654,83 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.RelayFromIp |= uint32(b&0x7F) << shift + m.OldRelayFromAddr |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayToAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift if b < 0x80 { break } } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RelayToAddr == nil { + m.RelayToAddr = &Addr{} + } + if err := m.RelayToAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayFromAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RelayFromAddr == nil { + m.RelayFromAddr = &Addr{} + } + if err := m.RelayFromAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) diff --git a/nebula.proto b/nebula.proto index 4dc15f193..ea1023348 100644 --- a/nebula.proto +++ b/nebula.proto @@ -23,19 +23,28 @@ message NebulaMeta { } message NebulaMetaDetails { - uint32 VpnIp = 1; - repeated Ip4AndPort Ip4AndPorts = 2; - repeated Ip6AndPort Ip6AndPorts = 4; - repeated uint32 RelayVpnIp = 5; + uint32 OldVpnAddr = 1 [deprecated = true]; + Addr VpnAddr = 6; + + repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true]; + repeated Addr RelayVpnAddrs = 7; + + repeated V4AddrPort V4AddrPorts = 2; + repeated V6AddrPort V6AddrPorts = 4; uint32 counter = 3; } -message Ip4AndPort { - uint32 Ip = 1; +message Addr { + uint64 Hi = 1; + uint64 Lo = 2; +} + +message V4AddrPort { + uint32 Addr = 1; uint32 Port = 2; } -message Ip6AndPort { +message V6AddrPort { uint64 Hi = 1; uint64 Lo = 2; uint32 Port = 3; @@ -77,6 +86,10 @@ message NebulaControl { uint32 InitiatorRelayIndex = 2; uint32 ResponderRelayIndex = 3; - uint32 RelayToIp = 4; - uint32 RelayFromIp = 5; + + uint32 OldRelayToAddr = 4 [deprecated = true]; + uint32 OldRelayFromAddr = 5 [deprecated = true]; + + Addr RelayToAddr = 6; + Addr RelayFromAddr = 7; } diff --git a/outside.go b/outside.go index c83d77cdb..dd2ae2520 100644 --- a/outside.go +++ b/outside.go @@ -49,7 +49,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] //l.Error("in packet ", header, packet[HeaderLen:]) if ip.IsValid() { - if f.myVpnNet.Contains(ip.Addr()) { + _, found := f.myVpnNetworks.Lookup(ip.Addr()) + if found { if f.l.Level >= logrus.DebugLevel { f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } @@ -106,7 +107,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] if !ok { // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. - hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") return } @@ -118,9 +119,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return case ForwardingType: // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp) + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrRelayFor(hostinfo.vpnAddrs[0], relay.PeerAddr) if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip") + hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).Info("Failed to find target host info by ip") return } @@ -136,7 +137,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") } } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") return } } @@ -159,7 +160,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return } - lhf(ip, hostinfo.vpnIp, d) + lhf(ip, hostinfo.vpnAddrs, d) // Fallthrough to the bottom to record incoming traffic @@ -226,14 +227,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] Error("Failed to decrypt Control packet") return } - m := &NebulaControl{} - err = m.Unmarshal(d) - if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message") - break - } - f.relayManager.HandleControlMsg(hostinfo, m, f) + f.relayManager.HandleControlMsg(hostinfo, d, f) default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) @@ -251,7 +246,8 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) { final := f.hostMap.DeleteHostInfo(hostInfo) if final { // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage - f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) + //TODO: we should delete all related vpnaddrs too + f.lightHouse.DeleteVpnAddr(hostInfo.vpnAddrs[0]) } } @@ -262,7 +258,8 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { if ip.IsValid() && hostinfo.remote != ip { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) { + //TODO: this is weird now that we can have multiple vpn addrs + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], ip.Addr()) { hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") return } diff --git a/pki.go b/pki.go index 490b30c8b..25d4a0e14 100644 --- a/pki.go +++ b/pki.go @@ -1,13 +1,18 @@ package nebula import ( + "encoding/binary" + "encoding/json" "errors" "fmt" + "net/netip" "os" + "slices" "strings" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" @@ -21,12 +26,21 @@ type PKI struct { } type CertState struct { - Certificate cert.Certificate - RawCertificate []byte - RawCertificateNoKey []byte - PublicKey []byte - PrivateKey []byte - pkcs11Backed bool + v1Cert cert.Certificate + v1HandshakeBytes []byte + + v2Cert cert.Certificate + v2HandshakeBytes []byte + + defaultVersion cert.Version + privateKey []byte + pkcs11Backed bool + cipher string + + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr + myVpnAddrsTable *bart.Table[struct{}] } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { @@ -46,16 +60,24 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { return pki, nil } -func (p *PKI) GetCertState() *CertState { +func (p *PKI) GetCAPool() *cert.CAPool { + return p.caPool.Load() +} + +func (p *PKI) getCertState() *CertState { return p.cs.Load() } -func (p *PKI) GetCAPool() *cert.CAPool { - return p.caPool.Load() +func (p *PKI) getDefaultCertificate() cert.Certificate { + return p.cs.Load().GetDefaultCertificate() +} + +func (p *PKI) getCertificate(v cert.Version) cert.Certificate { + return p.cs.Load().getCertificate(v) } func (p *PKI) reload(c *config.C, initial bool) error { - err := p.reloadCert(c, initial) + err := p.reloadCerts(c, initial) if err != nil { if initial { return err @@ -74,33 +96,94 @@ func (p *PKI) reload(c *config.C, initial bool) error { return nil } -func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { - cs, err := newCertStateFromConfig(c) +func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { + newState, err := newCertStateFromConfig(c) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } if !initial { - //TODO: include check for mask equality as well + currentState := p.cs.Load() + if newState.v1Cert != nil { + if currentState.v1Cert == nil { + return util.NewContextualError("v1 certificate was added, restart required", nil, err) + } + + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()}, + nil, + ) + } + + if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()}, + nil, + ) + } + + } else if currentState.v1Cert != nil { + //TODO: we should be able to tear this down + return util.NewContextualError("v1 certificate was removed, restart required", nil, err) + } + + if newState.v2Cert != nil { + if currentState.v2Cert == nil { + return util.NewContextualError("v2 certificate was added, restart required", nil, err) + } + + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()}, + nil, + ) + } + + if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()}, + nil, + ) + } + + } else if currentState.v2Cert != nil { + return util.NewContextualError("v2 certificate was removed, restart required", nil, err) + } - // did IP in cert change? if so, don't set - currentCert := p.cs.Load().Certificate - oldIPs := currentCert.Networks() - newIPs := cs.Certificate.Networks() - if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { + // Cipher cant be hot swapped so just leave it at what it was before + newState.cipher = currentState.cipher + + } else { + newState.cipher = c.GetString("cipher", "aes") + //TODO: this sucks and we should make it not a global + switch newState.cipher { + case "aes": + noiseEndianness = binary.BigEndian + case "chachapoly": + noiseEndianness = binary.LittleEndian + default: return util.NewContextualError( - "Networks in new cert was different from old", - m{"new_network": newIPs[0], "old_network": oldIPs[0]}, + "unknown cipher", + m{"cipher": newState.cipher}, nil, ) } } - p.cs.Store(cs) + p.cs.Store(newState) + + //TODO: newState needs a stringer that does json if initial { - p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") + p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") } else { - p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") + p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") } return nil } @@ -116,55 +199,67 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { return nil } -func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) { - // Marshal the certificate to ensure it is valid - rawCertificate, err := certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) +func (cs *CertState) GetDefaultCertificate() cert.Certificate { + c := cs.getCertificate(cs.defaultVersion) + if c == nil { + panic("No default certificate found") } + return c +} + +func (cs *CertState) getDefaultHandshakeBytes() []byte { + return cs.getHandshakeBytes(cs.defaultVersion) +} - publicKey := certificate.PublicKey() - cs := &CertState{ - RawCertificate: rawCertificate, - Certificate: certificate, - PrivateKey: privateKey, - PublicKey: publicKey, - pkcs11Backed: pkcs11backed, +func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { + switch v { + case cert.Version1: + return cs.v1Cert + case cert.Version2: + return cs.v2Cert } - rawCertNoKey, err := cs.Certificate.MarshalForHandshakes() - if err != nil { - return nil, fmt.Errorf("error marshalling certificate no key: %s", err) + return nil +} + +func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { + switch v { + case cert.Version1: + return cs.v1HandshakeBytes + case cert.Version2: + return cs.v2HandshakeBytes } - cs.RawCertificateNoKey = rawCertNoKey - return cs, nil + panic("No handshake bytes found") } -func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { - var pemPrivateKey []byte - if strings.Contains(privPathOrPEM, "-----BEGIN") { - pemPrivateKey = []byte(privPathOrPEM) - privPathOrPEM = "" - rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) - if err != nil { - return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) - } - } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") { - rawKey = []byte(privPathOrPEM) - return rawKey, cert.Curve_P256, true, nil - } else { - pemPrivateKey, err = os.ReadFile(privPathOrPEM) +func (cs *CertState) String() string { + b, err := cs.MarshalJSON() + if err != nil { + return fmt.Sprintf("error marshaling certificate state: %v", err) + } + return string(b) +} + +func (cs *CertState) MarshalJSON() ([]byte, error) { + msg := []json.RawMessage{} + if cs.v1Cert != nil { + b, err := cs.v1Cert.MarshalJSON() if err != nil { - return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + return nil, err } - rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + msg = append(msg, b) + } + + if cs.v2Cert != nil { + b, err := cs.v2Cert.MarshalJSON() if err != nil { - return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + return nil, err } + msg = append(msg, b) } - return + return json.Marshal(msg) } func newCertStateFromConfig(c *config.C) (*CertState, error) { @@ -198,28 +293,177 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { } } - nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert) - if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) + var crt, v1, v2 cert.Certificate + for len(rawCert) != 0 { + // Load the certificate + crt, rawCert, err = loadCertificate(rawCert) + if err != nil { + //TODO: check error + return nil, err + } + + switch crt.Version() { + case cert.Version1: + if v1 != nil { + return nil, fmt.Errorf("v1 certificate already found in pki.cert") + } + v1 = crt + case cert.Version2: + if v2 != nil { + return nil, fmt.Errorf("v2 certificate already found in pki.cert") + } + v2 = crt + default: + return nil, fmt.Errorf("unknown certificate version %v", crt.Version()) + } + } + + rawDefaultVersion := c.GetUint32("pki.default_version", 1) + var defaultVersion cert.Version + switch rawDefaultVersion { + case 1: + defaultVersion = cert.Version1 + case 2: + defaultVersion = cert.Version2 + default: + return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion) + } + + return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey) +} + +func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { + cs := CertState{ + privateKey: privateKey, + pkcs11Backed: pkcs11backed, + myVpnNetworksTable: new(bart.Table[struct{}]), + myVpnAddrsTable: new(bart.Table[struct{}]), + } + + if v1 != nil && v2 != nil { + if !slices.Equal(v1.PublicKey(), v2.PublicKey()) { + return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil) + } + + if v1.Curve() != v2.Curve() { + return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil) + } + + cs.defaultVersion = dv + } + + if v1 != nil { + if pkcs11backed { + //TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + } + + v1hs, err := v1.MarshalForHandshakes() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + } + cs.v1Cert = v1 + cs.v1HandshakeBytes = v1hs + + if cs.defaultVersion == 0 { + cs.defaultVersion = cert.Version1 + } } - if nebulaCert.Expired(time.Now()) { - return nil, fmt.Errorf("nebula certificate for this host is expired") + if v2 != nil { + if pkcs11backed { + //TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + } + + v2hs, err := v2.MarshalForHandshakes() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + } + cs.v2Cert = v2 + cs.v2HandshakeBytes = v2hs + + if cs.defaultVersion == 0 { + cs.defaultVersion = cert.Version2 + } + } + + var crt cert.Certificate + crt = cs.getCertificate(cert.Version2) + if crt == nil { + // v2 certificates are a superset, only look at v1 if its all we have + crt = cs.getCertificate(cert.Version1) } - if len(nebulaCert.Networks()) == 0 { - return nil, fmt.Errorf("no networks encoded in certificate") + for _, network := range crt.Networks() { + cs.myVpnNetworks = append(cs.myVpnNetworks, network) + cs.myVpnNetworksTable.Insert(network, struct{}{}) + + cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr()) + cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) + + if network.Addr().Is4() { + //TODO: finish calculating the broadcast ips + //addr := network.Masked().Addr().As4() + //binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) + //ifce.myBroadcastAddr = netip.AddrFrom4(addr) + } } - if isPkcs11 { - //TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm + return &cs, nil +} + +func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { + var pemPrivateKey []byte + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") { + rawKey = []byte(privPathOrPEM) + return rawKey, cert.Curve_P256, true, nil } else { - if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { - return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + pemPrivateKey, err = os.ReadFile(privPathOrPEM) + if err != nil { + return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) } + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } + + return +} + +func loadCertificate(b []byte) (cert.Certificate, []byte, error) { + c, b, err := cert.UnmarshalCertificateFromPEM(b) + if err != nil { + return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err) + } + + if c.Expired(time.Now()) { + return nil, b, fmt.Errorf("nebula certificate for this host is expired") + } + + if len(c.Networks()) == 0 { + return nil, b, fmt.Errorf("no networks encoded in certificate") + } + + if c.IsCA() { + return nil, b, fmt.Errorf("host certificate is a CA certificate") } - return newCertState(nebulaCert, isPkcs11, rawKey) + return c, b, nil } func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { diff --git a/relay_manager.go b/relay_manager.go index 1a3a4d48f..bbc151db1 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -72,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti Type: relayType, State: state, LocalIndex: index, - PeerIp: vpnIp, + PeerAddr: vpnIp, } if remoteIdx != nil { @@ -91,40 +92,60 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { - rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp, + //TODO: we need to handle possibly logging deprecated fields as well + rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnAddrs[0], "initiatorRelayIndex": m.InitiatorRelayIndex, - "relayFrom": m.RelayFromIp, - "relayTo": m.RelayToIp}).Info("relayManager failed to update relay") + "relayFrom": m.RelayFromAddr, + "relayTo": m.RelayToAddr}).Info("relayManager failed to update relay") return nil, fmt.Errorf("unknown relay") } return relay, nil } -func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) { +func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { + msg := &NebulaControl{} + err := msg.Unmarshal(d) + if err != nil { + h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") + return + } + + var v cert.Version + if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 { + v = cert.Version1 + + //TODO: yeah this is junk but maybe its less junky than the other options + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr) + msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) - switch m.Type { + binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr) + msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) + } else { + v = cert.Version2 + } + + switch msg.Type { case NebulaControl_CreateRelayRequest: - rm.handleCreateRelayRequest(h, f, m) + rm.handleCreateRelayRequest(v, h, f, msg) case NebulaControl_CreateRelayResponse: - rm.handleCreateRelayResponse(h, f, m) + rm.handleCreateRelayResponse(v, h, f, msg) } } -func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { +func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": m.RelayFromIp, - "relayTo": m.RelayToIp, + "relayFrom": m.RelayFromAddr, + "relayTo": m.RelayToAddr, "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("handleCreateRelayResponse") - target := m.RelayToIp - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], m.RelayToIp) - targetAddr := netip.AddrFrom4(b) + + target := m.RelayToAddr + targetAddr := protoAddrToNetAddr(target) relay, err := rm.EstablishRelay(h, m) if err != nil { @@ -136,68 +157,79 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * return } // I'm the middle man. Let the initiator know that the I've established the relay they requested. - peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp) + peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) if peerHostInfo == nil { - rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") + rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { - rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") + rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") return } if peerRelay.State == PeerRequested { - //TODO: IPV6-WORK - b = peerHostInfo.vpnIp.As4() peerRelay.State = Established resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: peerRelay.LocalIndex, InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(b[:]), - RelayToIp: uint32(target), } + + if v == cert.Version1 { + peer := peerHostInfo.vpnAddrs[0] + if !peer.Is4() { + //TODO: log cant do it + return + } + + b := peer.As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = targetAddr.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0]) + resp.RelayToAddr = target + } + msg, err := resp.Marshal() if err != nil { - rm.l. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + rm.l.WithError(err). + Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": resp.RelayFromIp, - "relayTo": resp.RelayToIp, + "relayFrom": resp.RelayFromAddr, + "relayTo": resp.RelayToAddr, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": peerHostInfo.vpnIp}). + "vpnAddrs": peerHostInfo.vpnAddrs}). Info("send CreateRelayResponse") } } } -func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], m.RelayFromIp) - from := netip.AddrFrom4(b) - - binary.BigEndian.PutUint32(b[:], m.RelayToIp) - target := netip.AddrFrom4(b) +func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { + from := protoAddrToNetAddr(m.RelayFromAddr) + target := protoAddrToNetAddr(m.RelayToAddr) logMsg := rm.l.WithFields(logrus.Fields{ "relayFrom": from, "relayTo": target, "initiatorRelayIndex": m.InitiatorRelayIndex, - "vpnIp": h.vpnIp}) + "vpnAddrs": h.vpnAddrs}) logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - if from == f.myVpnNet.Addr() { + _, found := f.myVpnAddrsTable.Lookup(from) + if found { logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } + // Is the target of the relay me? - if target == f.myVpnNet.Addr() { + _, found = f.myVpnAddrsTable.Lookup(target) + if found { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { @@ -230,17 +262,22 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N return } - //TODO: IPV6-WORK - fromB := from.As4() - targetB := target.As4() - resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), } + + if v == cert.Version1 { + b := from.As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(from) + resp.RelayToAddr = netAddrToProtoAddr(target) + } + msg, err := resp.Marshal() if err != nil { logMsg. @@ -253,7 +290,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("send CreateRelayResponse") } return @@ -262,7 +299,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N if !rm.GetAmRelay() { return } - peer := rm.hostmap.QueryVpnIp(target) + peer := rm.hostmap.QueryVpnAddr(target) if peer == nil { // Try to establish a connection to this host. If we get a future relay request, // we'll be ready! @@ -291,17 +328,27 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N sendCreateRequest = true } if sendCreateRequest { - //TODO: IPV6-WORK - fromB := h.vpnIp.As4() - targetB := target.As4() - // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), } + + if v == cert.Version1 { + if !h.vpnAddrs[0].Is4() { + //TODO: log it + return + } + + b := h.vpnAddrs[0].As4() + req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0]) + req.RelayToAddr = netAddrToProtoAddr(target) + } + msg, err := req.Marshal() if err != nil { logMsg. @@ -310,11 +357,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ //TODO: IPV6-WORK another lazy used to use the req object - "relayFrom": h.vpnIp, + "relayFrom": h.vpnAddrs[0], "relayTo": target, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, - "vpnIp": target}). + "vpnAddr": target}). Info("send CreateRelayRequest") } } @@ -342,16 +389,28 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } - //TODO: IPV6-WORK - fromB := h.vpnIp.As4() - targetB := target.As4() + resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), } + + if v == cert.Version1 { + if !h.vpnAddrs[0].Is4() { + //TODO: log it + return + } + + b := h.vpnAddrs[0].As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0]) + resp.RelayToAddr = netAddrToProtoAddr(target) + } + msg, err := resp.Marshal() if err != nil { rm.l. @@ -360,11 +419,11 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ //TODO: IPV6-WORK more lazy, used to use resp object - "relayFrom": h.vpnIp, + "relayFrom": h.vpnAddrs[0], "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("send CreateRelayResponse") } diff --git a/remote_list.go b/remote_list.go index fa14f4295..1c7fd4db5 100644 --- a/remote_list.go +++ b/remote_list.go @@ -17,8 +17,8 @@ import ( type forEachFunc func(addr netip.AddrPort, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool -type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool +type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -48,14 +48,14 @@ type cacheRelay struct { // cacheV4 stores learned and reported ipv4 records under cache type cacheV4 struct { - learned *Ip4AndPort - reported []*Ip4AndPort + learned *V4AddrPort + reported []*V4AddrPort } // cacheV4 stores learned and reported ipv6 records under cache type cacheV6 struct { - learned *Ip6AndPort - reported []*Ip6AndPort + learned *V6AddrPort + reported []*V6AddrPort } type hostnamePort struct { @@ -273,9 +273,9 @@ func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { r.Lock() defer r.Unlock() if remote.Addr().Is4() { - r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port())) + r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port())) } else { - r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port())) + r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port())) } } @@ -304,21 +304,21 @@ func (r *RemoteList) CopyCache() *CacheMap { if mc.v4 != nil { if mc.v4.learned != nil { - c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned)) + c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned)) } for _, a := range mc.v4.reported { - c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a)) + c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a)) } } if mc.v6 != nil { if mc.v6.learned != nil { - c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned)) + c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned)) } for _, a := range mc.v6.reported { - c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a)) + c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a)) } } @@ -401,14 +401,14 @@ func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -436,12 +436,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.A // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { +func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) // We are doing the easy append because this is rarely called - c.reported = append([]*Ip4AndPort{to}, c.reported...) + c.reported = append([]*V4AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } @@ -449,14 +449,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { +func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -473,12 +473,12 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPor // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { +func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) // We are doing the easy append because this is rarely called - c.reported = append([]*Ip6AndPort{to}, c.reported...) + c.reported = append([]*V6AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } @@ -536,14 +536,14 @@ func (r *RemoteList) unlockedCollect() { for _, c := range r.cache { if c.v4 != nil { if c.v4.learned != nil { - u := AddrPortFromIp4AndPort(c.v4.learned) + u := protoV4AddrPortToNetAddrPort(c.v4.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v4.reported { - u := AddrPortFromIp4AndPort(v) + u := protoV4AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -552,14 +552,14 @@ func (r *RemoteList) unlockedCollect() { if c.v6 != nil { if c.v6.learned != nil { - u := AddrPortFromIp6AndPort(c.v6.learned) + u := protoV6AddrPortToNetAddrPort(c.v6.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v6.reported { - u := AddrPortFromIp6AndPort(v) + u := protoV6AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } diff --git a/remote_list_test.go b/remote_list_test.go index 62a892b00..33bfbb128 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -13,7 +13,7 @@ func TestRemoteList_Rebuild(t *testing.T) { rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), // this is duped newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), // this is duped @@ -25,20 +25,20 @@ func TestRemoteList_Rebuild(t *testing.T) { newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.1"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), // this is duped newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:2"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) rl.Rebuild([]netip.Prefix{}) @@ -102,7 +102,7 @@ func BenchmarkFullRebuild(b *testing.B) { rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), @@ -112,19 +112,19 @@ func BenchmarkFullRebuild(b *testing.B) { newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -164,7 +164,7 @@ func BenchmarkSortRebuild(b *testing.B) { rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), @@ -174,19 +174,19 @@ func BenchmarkSortRebuild(b *testing.B) { newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -224,19 +224,19 @@ func BenchmarkSortRebuild(b *testing.B) { }) } -func newIp4AndPortFromString(s string) *Ip4AndPort { +func newIp4AndPortFromString(s string) *V4AddrPort { a := netip.MustParseAddrPort(s) v4Addr := a.Addr().As4() - return &Ip4AndPort{ - Ip: binary.BigEndian.Uint32(v4Addr[:]), + return &V4AddrPort{ + Addr: binary.BigEndian.Uint32(v4Addr[:]), Port: uint32(a.Port()), } } -func newIp6AndPortFromString(s string) *Ip6AndPort { +func newIp6AndPortFromString(s string) *V6AddrPort { a := netip.MustParseAddrPort(s) v6Addr := a.Addr().As16() - return &Ip6AndPort{ + return &V6AddrPort{ Hi: binary.BigEndian.Uint64(v6Addr[:8]), Lo: binary.BigEndian.Uint64(v6Addr[8:]), Port: uint32(a.Port()), diff --git a/ssh.go b/ssh.go index 881ee4696..2aba7f313 100644 --- a/ssh.go +++ b/ssh.go @@ -430,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } sort.Slice(hm, func(i, j int) bool { - return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0 + return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0 }) if fs.Json || fs.Pretty { @@ -447,7 +447,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } else { for _, v := range hm { - err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs)) + err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs)) if err != nil { return err } @@ -581,7 +581,7 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -622,12 +622,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already exists")) } - hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp) + hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnIp) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } @@ -677,7 +677,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -785,7 +785,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return nil } - cert := ifce.pki.GetCertState().Certificate + cert := ifce.pki.getDefaultCertificate() if len(a) > 0 { vpnIp, err := netip.ParseAddr(a[0]) if err != nil { @@ -796,7 +796,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } @@ -880,16 +880,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr } for k, v := range relays { - ro := RelayOutput{NebulaIp: v.vpnIp} + ro := RelayOutput{NebulaIp: v.vpnAddrs[0]} co.Relays = append(co.Relays, &ro) - relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp) + relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0]) if relayHI == nil { ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")}) continue } for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { rf := RelayFor{Error: nil} - r, ok := relayHI.relayState.GetRelayForByIp(vpnIp) + r, ok := relayHI.relayState.GetRelayForByAddr(vpnIp) if ok { t := "" switch r.Type { @@ -913,14 +913,14 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr rf.LocalIndex = r.LocalIndex rf.RemoteIndex = r.RemoteIndex - rf.PeerIp = r.PeerIp + rf.PeerIp = r.PeerAddr rf.Type = t rf.State = s if rf.LocalIndex != k { rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) } } - relayedHI := ifce.hostMap.QueryVpnIp(vpnIp) + relayedHI := ifce.hostMap.QueryVpnAddr(vpnIp) if relayedHI != nil { rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) } @@ -955,7 +955,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } diff --git a/udp/temp.go b/udp/temp.go index b281906f5..416b80155 100644 --- a/udp/temp.go +++ b/udp/temp.go @@ -7,4 +7,4 @@ import ( //TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare // TODO: IPV6-WORK this can likely be removed now -type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) +type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnAddrs []netip.Addr, p []byte) From 55676971693f0959a392a718fc6575a1671597ee Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 19 Sep 2024 21:49:16 -0500 Subject: [PATCH 06/17] Fixes --- inside.go | 4 +- interface.go | 103 +++++++++++++++++++++------------------------------ outside.go | 2 +- pki.go | 33 ++++++++++------- 4 files changed, 66 insertions(+), 76 deletions(-) diff --git a/inside.go b/inside.go index 1b75f0f46..6813237ed 100644 --- a/inside.go +++ b/inside.go @@ -21,7 +21,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet // Ignore local broadcast packets if f.dropLocalBroadcast { - _, found := f.myBroadcastAddr.Lookup(fwPacket.RemoteIP) + _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteIP) if found { return } @@ -129,7 +129,7 @@ func (f *Interface) Handshake(vpnIp netip.Addr) { // getOrHandshake returns nil if the vpnIp is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - _, found := f.myVpnNetworks.Lookup(vpnIp) + _, found := f.myVpnNetworksTable.Lookup(vpnIp) if !found { vpnIp = f.inside.RouteFor(vpnIp) if !vpnIp.IsValid() { diff --git a/interface.go b/interface.go index 9686d128b..a403f5d03 100644 --- a/interface.go +++ b/interface.go @@ -14,7 +14,6 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -52,26 +51,27 @@ type InterfaceConfig struct { } type Interface struct { - hostMap *HostMap - outside udp.Conn - inside overlay.Device - pki *PKI - firewall *Firewall - connectionManager *connectionManager - handshakeManager *HandshakeManager - serveDns bool - createTime time.Time - lightHouse *LightHouse - myBroadcastAddr *bart.Table[struct{}] - myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate - myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate - myVpnNetworks *bart.Table[struct{}] // A table of networks assigned to us via our certificate - dropLocalBroadcast bool - dropMulticast bool - routines int - disconnectInvalid atomic.Bool - closed atomic.Bool - relayManager *relayManager + hostMap *HostMap + outside udp.Conn + inside overlay.Device + pki *PKI + firewall *Firewall + connectionManager *connectionManager + handshakeManager *HandshakeManager + serveDns bool + createTime time.Time + lightHouse *LightHouse + myBroadcastAddrsTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate + myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate + myVpnNetworks []netip.Prefix // A table of networks assigned to us via our certificate + myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate + dropLocalBroadcast bool + dropMulticast bool + routines int + disconnectInvalid atomic.Bool + closed atomic.Bool + relayManager *relayManager tryPromoteEvery atomic.Uint32 reQueryEvery atomic.Uint32 @@ -157,25 +157,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return nil, errors.New("no firewall rules") } + cs := c.pki.getCertState() ifce := &Interface{ - pki: c.pki, - hostMap: c.HostMap, - outside: c.Outside, - inside: c.Inside, - firewall: c.Firewall, - serveDns: c.ServeDns, - handshakeManager: c.HandshakeManager, - createTime: time.Now(), - lightHouse: c.lightHouse, - dropLocalBroadcast: c.DropLocalBroadcast, - dropMulticast: c.DropMulticast, - routines: c.routines, - version: c.version, - writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), - myVpnNetworks: new(bart.Table[struct{}]), - myVpnAddrsTable: new(bart.Table[struct{}]), - relayManager: c.relayManager, + pki: c.pki, + hostMap: c.HostMap, + outside: c.Outside, + inside: c.Inside, + firewall: c.Firewall, + serveDns: c.ServeDns, + handshakeManager: c.HandshakeManager, + createTime: time.Now(), + lightHouse: c.lightHouse, + dropLocalBroadcast: c.DropLocalBroadcast, + dropMulticast: c.DropMulticast, + routines: c.routines, + version: c.version, + writers: make([]udp.Conn, c.routines), + readers: make([]io.ReadWriteCloser, c.routines), + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + myVpnAddrs: cs.myVpnAddrs, + myVpnAddrsTable: cs.myVpnAddrsTable, + myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable, + relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -189,27 +193,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - var crt cert.Certificate - cs := c.pki.getCertState() - crt = cs.getCertificate(cert.Version2) - if crt == nil { - // v2 certificates are a superset, only look at v1 if its all we have - crt = cs.getCertificate(cert.Version1) - } - - for _, network := range crt.Networks() { - ifce.myVpnNetworks.Insert(network, struct{}{}) - ifce.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) - ifce.myVpnAddrs = append(ifce.myVpnAddrs, network.Addr()) - - if network.Addr().Is4() { - //TODO: finish calculating the broadcast ips - //addr := network.Masked().Addr().As4() - //binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) - //ifce.myBroadcastAddr = netip.AddrFrom4(addr) - } - } - ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) diff --git a/outside.go b/outside.go index dd2ae2520..f7dbbd32e 100644 --- a/outside.go +++ b/outside.go @@ -49,7 +49,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] //l.Error("in packet ", header, packet[HeaderLen:]) if ip.IsValid() { - _, found := f.myVpnNetworks.Lookup(ip.Addr()) + _, found := f.myVpnNetworksTable.Lookup(ip.Addr()) if found { if f.l.Level >= logrus.DebugLevel { f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") diff --git a/pki.go b/pki.go index 25d4a0e14..c4160d5a8 100644 --- a/pki.go +++ b/pki.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "net/netip" "os" "slices" @@ -37,10 +38,11 @@ type CertState struct { pkcs11Backed bool cipher string - myVpnNetworks []netip.Prefix - myVpnNetworksTable *bart.Table[struct{}] - myVpnAddrs []netip.Addr - myVpnAddrsTable *bart.Table[struct{}] + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr + myVpnAddrsTable *bart.Table[struct{}] + myVpnBroadcastAddrsTable *bart.Table[struct{}] } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { @@ -294,7 +296,7 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { } var crt, v1, v2 cert.Certificate - for len(rawCert) != 0 { + for { // Load the certificate crt, rawCert, err = loadCertificate(rawCert) if err != nil { @@ -316,6 +318,10 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { default: return nil, fmt.Errorf("unknown certificate version %v", crt.Version()) } + + if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" { + break + } } rawDefaultVersion := c.GetUint32("pki.default_version", 1) @@ -334,10 +340,11 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { cs := CertState{ - privateKey: privateKey, - pkcs11Backed: pkcs11backed, - myVpnNetworksTable: new(bart.Table[struct{}]), - myVpnAddrsTable: new(bart.Table[struct{}]), + privateKey: privateKey, + pkcs11Backed: pkcs11backed, + myVpnNetworksTable: new(bart.Table[struct{}]), + myVpnAddrsTable: new(bart.Table[struct{}]), + myVpnBroadcastAddrsTable: new(bart.Table[struct{}]), } if v1 != nil && v2 != nil { @@ -409,10 +416,10 @@ func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, p cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) if network.Addr().Is4() { - //TODO: finish calculating the broadcast ips - //addr := network.Masked().Addr().As4() - //binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) - //ifce.myBroadcastAddr = netip.AddrFrom4(addr) + addr := network.Masked().Addr().As4() + mask := net.CIDRMask(network.Bits(), network.Addr().BitLen()) + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) + cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{}) } } From 7adbb3523d246bb284a7d45739e8af98e3793b6b Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 19 Sep 2024 22:05:57 -0500 Subject: [PATCH 07/17] Fix v2 handshaking --- cert/cert_v2.go | 18 +++++++++++++++++- handshake_ix.go | 1 + 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 117de22e1..329c66822 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -223,7 +223,23 @@ func (c *certificateV2) String() string { } func (c *certificateV2) MarshalForHandshakes() ([]byte, error) { - panic("TODO") + var b cryptobyte.Builder + // Outermost certificate + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + + // Add the cert details which is already marshalled + //TODO: panic on nil rawDetails + b.AddBytes(c.rawDetails) + + // Skipping the curve and public key since those come across in a different part of the handshake + + // Add the signature + b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { + b.AddBytes(c.signature) + }) + }) + + return b.Bytes() } func (c *certificateV2) Marshal() ([]byte, error) { diff --git a/handshake_ix.go b/handshake_ix.go index 4cb642ffa..e20ac2f24 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -31,6 +31,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { InitiatorIndex: hh.hostinfo.localIndexId, Time: uint64(time.Now().UnixNano()), Cert: cs.getDefaultHandshakeBytes(), + CertVersion: uint32(cs.defaultVersion), }, } From 02edb72458eccceaa634f39aea36999541e43360 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Thu, 19 Sep 2024 22:36:38 -0500 Subject: [PATCH 08/17] Fix v2 cert copying --- cert/cert_v2.go | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 329c66822..2c6f14a00 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -323,17 +323,11 @@ func (c *certificateV2) Copy() Certificate { signature: make([]byte, len(c.signature)), } - copy(c.signature, c.signature) - copy(c.details.groups, c.details.groups) - copy(c.publicKey, c.publicKey) - - for i, p := range c.details.networks { - c.details.networks[i] = p - } - - for i, p := range c.details.unsafeNetworks { - c.details.unsafeNetworks[i] = p - } + copy(nc.signature, c.signature) + copy(nc.details.groups, c.details.groups) + copy(nc.publicKey, c.publicKey) + copy(nc.details.networks, c.details.networks) + copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) return nc } From 8b52c957839f9c19fd619162c869e27cfbf5244b Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 20 Sep 2024 09:56:33 -0500 Subject: [PATCH 09/17] Copy darwin tun handling from other branch --- overlay/tun_darwin.go | 421 +++++++++++++++++++++--------------------- 1 file changed, 207 insertions(+), 214 deletions(-) diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 0b573e6b3..1cff2144c 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -24,56 +24,62 @@ import ( type tun struct { io.ReadWriteCloser - Device string - cidr netip.Prefix - DefaultMTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - linkAddr *netroute.LinkAddr - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + DefaultMTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + linkAddr *netroute.LinkAddr + l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } -type sockaddrCtl struct { - scLen uint8 - scFamily uint8 - ssSysaddr uint16 - scID uint32 - scUnit uint32 - scReserved [5]uint32 -} - type ifReq struct { - Name [16]byte + Name [unix.IFNAMSIZ]byte Flags uint16 pad [8]byte } -var sockaddrCtlSize uintptr = 32 - const ( - _SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */ - _AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */ - _PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM - _CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info) - utunControlName = "com.apple.net.utun_control" + _SIOCAIFADDR_IN6 = 2155899162 + _UTUN_OPT_IFNAME = 2 + _IN6_IFF_NODAD = 0x0020 + _IN6_IFF_SECURED = 0x0400 + utunControlName = "com.apple.net.utun_control" ) -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 pad [8]byte } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +type addrLifetime struct { + Expire float64 + Preferred float64 + Vltime uint32 + Pltime uint32 +} + +type ifreqAlias4 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet4 + DstAddr unix.RawSockaddrInet4 + MaskAddr unix.RawSockaddrInet4 +} + +type ifreqAlias6 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet6 + DstAddr unix.RawSockaddrInet6 + PrefixMask unix.RawSockaddrInet6 + Flags uint32 + Lifetime addrLifetime +} + +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err } } - fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL) + fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL) if err != nil { return nil, fmt.Errorf("system socket: %v", err) } - var ctlInfo = &struct { - ctlID uint32 - ctlName [96]byte - }{} + var ctlInfo = &unix.CtlInfo{} + copy(ctlInfo.Name[:], utunControlName) - copy(ctlInfo.ctlName[:], utunControlName) - - err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo))) + err = unix.IoctlCtlInfo(fd, ctlInfo) if err != nil { return nil, fmt.Errorf("CTLIOCGINFO: %v", err) } - sc := sockaddrCtl{ - scLen: uint8(sockaddrCtlSize), - scFamily: unix.AF_SYSTEM, - ssSysaddr: _AF_SYS_CONTROL, - scID: ctlInfo.ctlID, - scUnit: uint32(ifIndex) + 1, - } - - _, _, errno := unix.RawSyscall( - unix.SYS_CONNECT, - uintptr(fd), - uintptr(unsafe.Pointer(&sc)), - sockaddrCtlSize, - ) - if errno != 0 { - return nil, fmt.Errorf("SYS_CONNECT: %v", errno) + err = unix.Connect(fd, &unix.SockaddrCtl{ + ID: ctlInfo.Id, + Unit: uint32(ifIndex) + 1, + }) + if err != nil { + return nil, fmt.Errorf("SYS_CONNECT: %v", err) } - var ifName struct { - name [16]byte - } - ifNameSize := uintptr(len(ifName.name)) - _, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), - 2, // SYSPROTO_CONTROL - 2, // UTUN_OPT_IFNAME - uintptr(unsafe.Pointer(&ifName)), - uintptr(unsafe.Pointer(&ifNameSize)), 0) - if errno != 0 { - return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno) + name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME) + if err != nil { + return nil, fmt.Errorf("failed to retrieve tun name: %w", err) } - name = string(ifName.name[:ifNameSize-1]) - err = syscall.SetNonblock(fd, true) + err = unix.SetNonblock(fd, true) if err != nil { return nil, fmt.Errorf("SetNonblock: %v", err) } - file := os.NewFile(uintptr(fd), "") - t := &tun{ - ReadWriteCloser: file, + ReadWriteCloser: os.NewFile(uintptr(fd), ""), Device: name, - cidr: cidr, + vpnNetworks: vpnNetworks, DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -186,16 +167,6 @@ func (t *tun) Close() error { func (t *tun) Activate() error { devName := t.deviceBytes() - var addr, mask [4]byte - - if !t.cidr.Addr().Is4() { - //TODO: IPV6-WORK - panic("need ipv6") - } - - addr = t.cidr.Addr().As4() - copy(mask[:], prefixToMask(t.cidr)) - s, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM, @@ -208,66 +179,18 @@ func (t *tun) Activate() error { fd := uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - - // Set the device name - ifrf := ifReq{Name: devName} - if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to set tun device name: %s", err) - } - // Set the MTU on the device ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)} if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { return fmt.Errorf("failed to set tun mtu: %v", err) } - /* - // Set the transmit queue length - ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} - if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { - // If we can't set the queue length nebula will still work but it may lead to packet loss - l.WithError(err).Error("Failed to set tun tx queue length") - } - */ - - // Bring up the interface - ifrf.Flags = ifrf.Flags | unix.IFF_UP - if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to bring the tun device up: %s", err) - } - - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + // Get the device flags + ifrf := ifReq{Name: devName} + if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return fmt.Errorf("failed to get tun flags: %s", err) } - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} linkAddr, err := getLinkAddr(t.Device) if err != nil { return err @@ -277,14 +200,18 @@ func (t *tun) Activate() error { } t.linkAddr = linkAddr - copy(routeAddr.IP[:], addr[:]) - copy(maskAddr.IP[:], mask[:]) - err = addRoute(routeSock, routeAddr, maskAddr, linkAddr) - if err != nil { - if errors.Is(err, unix.EEXIST) { - err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr) + for _, network := range t.vpnNetworks { + if network.Addr().Is4() { + err = t.activate4(network) + if err != nil { + return err + } + } else { + err = t.activate6(network) + if err != nil { + return err + } } - return err } // Run the interface @@ -297,8 +224,89 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) activate4(network netip.Prefix) error { + s, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + unix.IPPROTO_IP, + ) + if err != nil { + return err + } + defer unix.Close(s) + + ifr := ifreqAlias4{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: network.Addr().As4(), + }, + DstAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: network.Addr().As4(), + }, + MaskAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: prefixToMask(network).As4(), + }, + } + + if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun v4 address: %s", err) + } + + err = addRoute(network, t.linkAddr) + if err != nil { + return err + } + + return nil +} + +func (t *tun) activate6(network netip.Prefix) error { + s, err := unix.Socket( + unix.AF_INET6, + unix.SOCK_DGRAM, + unix.IPPROTO_IP, + ) + if err != nil { + return err + } + defer unix.Close(s) + + ifr := ifreqAlias6{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: network.Addr().As16(), + }, + PrefixMask: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: prefixToMask(network).As16(), + }, + Lifetime: addrLifetime{ + // never expires + Vltime: 0xffffffff, + Pltime: 0xffffffff, + }, + //TODO: should we disable DAD (duplicate address detection) and mark this as a secured address? + Flags: _IN6_IFF_NODAD, + } + + if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun address: %s", err) + } + + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) { } func (t *tun) addRoutes(logErrors bool) error { - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} routes := *t.Routes.Load() + for _, r := range routes { if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - if !r.Cidr.Addr().Is4() { - //TODO: implement ipv6 - panic("Cant handle ipv6 routes yet") - } - - routeAddr.IP = r.Cidr.Addr().As4() - //TODO: we could avoid the copy - copy(maskAddr.IP[:], prefixToMask(r.Cidr)) - - err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + err := addRoute(r.Cidr, t.linkAddr) if err != nil { if errors.Is(err, unix.EEXIST) { t.l.WithField("route", r.Cidr). @@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error { } func (t *tun) removeRoutes(routes []Route) error { - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} - for _, r := range routes { if !r.Install { continue } - if r.Cidr.Addr().Is6() { - //TODO: implement ipv6 - panic("Cant handle ipv6 routes yet") - } - - routeAddr.IP = r.Cidr.Addr().As4() - copy(maskAddr.IP[:], prefixToMask(r.Cidr)) - - err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + err := delRoute(r.Cidr, t.linkAddr) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { - r := netroute.RouteMessage{ +func addRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := &netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_ADD, Flags: unix.RTF_UP, Seq: 1, - Addrs: []netroute.Addr{ - unix.RTAX_DST: addr, - unix.RTAX_GATEWAY: link, - unix.RTAX_NETMASK: mask, - }, } - data, err := r.Marshal() + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } + _, err = unix.Write(sock, data[:]) if err != nil { return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) @@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) return nil } -func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { - r := netroute.RouteMessage{ +func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_DELETE, Seq: 1, - Addrs: []netroute.Addr{ - unix.RTAX_DST: addr, - unix.RTAX_GATEWAY: link, - unix.RTAX_NETMASK: mask, - }, } - data, err := r.Marshal() + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } @@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) } func (t *tun) Read(to []byte) (int, error) { - buf := make([]byte, len(to)+4) n, err := t.ReadWriteCloser.Read(buf) @@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) { return n - 4, err } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } -func prefixToMask(prefix netip.Prefix) []byte { +func prefixToMask(prefix netip.Prefix) netip.Addr { pLen := 128 if prefix.Addr().Is4() { pLen = 32 } - return net.CIDRMask(prefix.Bits(), pLen) + + addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen)) + return addr } From f4329b1a5ccc6e8d671da6207c0d40bf7dc356fd Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 20 Sep 2024 09:59:29 -0500 Subject: [PATCH 10/17] Copy device and route files over --- overlay/device.go | 2 +- overlay/route.go | 36 ++++++++++++++++++++---------------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/overlay/device.go b/overlay/device.go index 50ad6ad5b..da8cbe92b 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -8,7 +8,7 @@ import ( type Device interface { io.ReadWriteCloser Activate() error - Cidr() netip.Prefix + Networks() []netip.Prefix Name() string RouteFor(netip.Addr) netip.Addr NewMultiQueueReader() (io.ReadWriteCloser, error) diff --git a/overlay/route.go b/overlay/route.go index 8ccc9943c..14b184c46 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table return routeTree, nil } -func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { +func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") @@ -117,13 +117,15 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } - if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { - return nil, fmt.Errorf( - "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", - i+1, - r.Cidr.String(), - network.String(), - ) + for _, network := range networks { + if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { + return nil, fmt.Errorf( + "entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, network: %v", + i+1, + r.Cidr.String(), + network.String(), + ) + } } routes[i] = r @@ -132,7 +134,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return routes, nil } -func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { +func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") @@ -229,13 +231,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) } - if network.Contains(r.Cidr.Addr()) { - return nil, fmt.Errorf( - "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", - i+1, - r.Cidr.String(), - network.String(), - ) + for _, network := range networks { + if network.Contains(r.Cidr.Addr()) { + return nil, fmt.Errorf( + "entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v", + i+1, + r.Cidr.String(), + network.String(), + ) + } } routes[i] = r From 77b875dcc86126370a16ba82010e0f75894b3ab1 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 20 Sep 2024 10:16:20 -0500 Subject: [PATCH 11/17] Port over more of the ipv6 tun device changes --- control_tester.go | 43 +++++++++++++++++++++++++++++++---------- e2e/handshakes_test.go | 38 ++++++++++++++++++------------------ e2e/helpers_test.go | 4 ++-- interface.go | 2 +- main.go | 3 +-- overlay/tun.go | 18 ++++++++--------- overlay/tun_disabled.go | 16 +++++++-------- overlay/tun_tester.go | 34 ++++++++++++++++---------------- overlay/user.go | 12 ++++++------ ssh.go | 8 +++++--- 10 files changed, 101 insertions(+), 77 deletions(-) diff --git a/control_tester.go b/control_tester.go index cde0919f7..586617af7 100644 --- a/control_tester.go +++ b/control_tester.go @@ -97,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) { } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) { - //TODO: IPV6-WORK - ip := layers.IPv4{ - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(), - DstIP: toIp.Unmap().AsSlice(), +func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) { + serialize := make([]gopacket.SerializableLayer, 0) + var netLayer gopacket.NetworkLayer + if toAddr.Is6() { + if !fromAddr.Is6() { + panic("Cant send ipv6 to ipv4") + } + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip + } else { + if !fromAddr.Is4() { + panic("Cant send ipv4 to ipv6") + } + + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip } udp := layers.UDP{ SrcPort: layers.UDPPort(fromPort), DstPort: layers.UDPPort(toPort), } - err := udp.SetNetworkLayerForChecksum(&ip) + err := udp.SetNetworkLayerForChecksum(netLayer) if err != nil { panic(err) } @@ -121,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui ComputeChecksums: true, FixLengths: true, } - err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data)) + + serialize = append(serialize, &udp, gopacket.Payload(data)) + err = gopacket.SerializeLayers(buffer, opt, serialize...) if err != nil { panic(err) } diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index a3717bcad..383478af8 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -21,7 +21,7 @@ import ( func BenchmarkHotPath(b *testing.B) { ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) + myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse @@ -35,7 +35,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -56,7 +56,7 @@ func TestGoodHandshake(t *testing.T) { theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -120,7 +120,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see our cached packet get sent to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { h := &header.H{} err := h.Parse(p.Data) @@ -181,8 +181,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -258,7 +258,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -269,7 +269,7 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me again")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -307,7 +307,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) @@ -319,7 +319,7 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, 80, []byte("Hi from them again")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) @@ -361,7 +361,7 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -403,8 +403,8 @@ func TestStage1RaceRelays(t *testing.T) { assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -456,8 +456,8 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -529,7 +529,7 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -633,7 +633,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") @@ -933,8 +933,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index f8a224366..c8b42b007 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -143,12 +143,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me - controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) + controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them - controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) + controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")) aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } diff --git a/interface.go b/interface.go index a403f5d03..c7f6c3e00 100644 --- a/interface.go +++ b/interface.go @@ -213,7 +213,7 @@ func (f *Interface) activate() { f.l.WithError(err).Error("Failed to get udp listen address") } - f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()). + f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). WithField("build", f.version).WithField("udpAddr", addr). WithField("boringcrypto", boringEnabled()). Info("Nebula interface is active") diff --git a/main.go b/main.go index 5e97b4a77..6aea39a0f 100644 --- a/main.go +++ b/main.go @@ -129,8 +129,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg deviceFactory = overlay.NewDeviceFromConfig } - //TODO: device needs all networks not just the first one - tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks[0], routines) + tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } diff --git a/overlay/tun.go b/overlay/tun.go index 12460da1f..4a6377d2a 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -11,36 +11,36 @@ import ( const DefaultMTU = 1300 // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): - tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) + tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: - return newTun(c, l, tunCidr, routines > 1) + return newTun(c, l, vpnNetworks, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { - return newTunFromFd(c, l, *fd, tunCidr) + return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, vpnNetworks) } } -func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) { +func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { return false, nil, nil } - routes, err := parseRoutes(c, cidr) + routes, err := parseRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) } - unsafeRoutes, err := parseUnsafeRoutes(c, cidr) + unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index 130f8f99f..cfbf17d97 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -12,8 +12,8 @@ import ( ) type disabledTun struct { - read chan []byte - cidr netip.Prefix + read chan []byte + vpnNetworks []netip.Prefix // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter @@ -21,11 +21,11 @@ type disabledTun struct { l *logrus.Logger } -func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ - cidr: cidr, - read: make(chan []byte, queueLen), - l: l, + vpnNetworks: vpnNetworks, + read: make(chan []byte, queueLen), + l: l, } if metricsEnabled { @@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { return netip.Addr{} } -func (t *disabledTun) Cidr() netip.Prefix { - return t.cidr +func (t *disabledTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (*disabledTun) Name() string { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index ba15723a1..cc3942f41 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -16,19 +16,19 @@ import ( ) type TestTun struct { - Device string - cidr netip.Prefix - Routes []Route - routeTree *bart.Table[netip.Addr] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + Routes []Route + routeTree *bart.Table[netip.Addr] + l *logrus.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) { - _, routes, err := getAllRoutesFromConfig(c, cidr, true) +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { + _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err } @@ -38,17 +38,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, } return &TestTun{ - Device: c.GetString("tun.dev", ""), - cidr: cidr, - Routes: routes, - routeTree: routeTree, - l: l, - rxPackets: make(chan []byte, 10), - TxPackets: make(chan []byte, 10), + Device: c.GetString("tun.dev", ""), + vpnNetworks: vpnNetworks, + Routes: routes, + routeTree: routeTree, + l: l, + rxPackets: make(chan []byte, 10), + TxPackets: make(chan []byte, 10), }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -95,8 +95,8 @@ func (t *TestTun) Activate() error { return nil } -func (t *TestTun) Cidr() netip.Prefix { - return t.cidr +func (t *TestTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *TestTun) Name() string { diff --git a/overlay/user.go b/overlay/user.go index 1bb4ef5f7..ae665f3a7 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -8,16 +8,16 @@ import ( "github.com/slackhq/nebula/config" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { - return NewUserDevice(tunCidr) +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return NewUserDevice(vpnNetworks) } -func NewUserDevice(tunCidr netip.Prefix) (Device, error) { +func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() return &UserDevice{ - tunCidr: tunCidr, + vpnNetworks: vpnNetworks, outboundReader: or, outboundWriter: ow, inboundReader: ir, @@ -26,7 +26,7 @@ func NewUserDevice(tunCidr netip.Prefix) (Device, error) { } type UserDevice struct { - tunCidr netip.Prefix + vpnNetworks []netip.Prefix outboundReader *io.PipeReader outboundWriter *io.PipeWriter @@ -38,7 +38,7 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } +func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } func (d *UserDevice) Name() string { return "faketun0" } func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { diff --git a/ssh.go b/ssh.go index 2aba7f313..a04b28b23 100644 --- a/ssh.go +++ b/ssh.go @@ -971,13 +971,15 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { data := struct { - Name string `json:"name"` - Cidr string `json:"cidr"` + Name string `json:"name"` + Cidr []netip.Prefix `json:"cidr"` }{ Name: ifce.inside.Name(), - Cidr: ifce.inside.Cidr().String(), + Cidr: make([]netip.Prefix, len(ifce.inside.Networks())), } + copy(data.Cidr, ifce.inside.Networks()) + flags, ok := fs.(*sshDeviceInfoFlags) if !ok { return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs) From bf79947345006a8defa3c0055cfe76d8b8971128 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 20 Sep 2024 10:35:53 -0500 Subject: [PATCH 12/17] Fixup tests --- overlay/route.go | 20 +++++++----- overlay/route_test.go | 72 +++++++++++++++++++++++-------------------- service/service.go | 4 +-- test/tun.go | 4 +-- 4 files changed, 56 insertions(+), 44 deletions(-) diff --git a/overlay/route.go b/overlay/route.go index 14b184c46..687cc11b8 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -117,17 +117,23 @@ func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } + found := false for _, network := range networks { - if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { - return nil, fmt.Errorf( - "entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, network: %v", - i+1, - r.Cidr.String(), - network.String(), - ) + if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() { + found = true + break } } + if !found { + return nil, fmt.Errorf( + "entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v", + i+1, + r.Cidr.String(), + networks, + ) + } + routes[i] = r } diff --git a/overlay/route_test.go b/overlay/route_test.go index d7913894b..c60e4c24b 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) { assert.NoError(t, err) // test no routes config - routes, err := parseRoutes(c, n) + routes, err := parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "tun.routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") + + // Not in multiple ranges + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} + routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 2) @@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) { assert.NoError(t, err) // test no routes config - routes, err := parseUnsafeRoutes(c, n) + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // not an array c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // weird route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") @@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) { 127, false, nil, 1.0, []string{"1", "2"}, } { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) } // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Nil(t, err) // above network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Nil(t, err) // no mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Equal(t, 0, routes[0].MTU) // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") @@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 4) @@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) { map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} - routes, err := parseUnsafeRoutes(c, n) + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) assert.NoError(t, err) assert.Len(t, routes, 2) routeTree, err := makeRouteTree(l, routes, true) diff --git a/service/service.go b/service/service.go index 4ddd30182..4339677bd 100644 --- a/service/service.go +++ b/service/service.go @@ -90,9 +90,9 @@ func New(config *config.C) (*Service, error) { }, }) - ipNet := device.Cidr() + ipNet := device.Networks() pa := tcpip.ProtocolAddress{ - AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(), Protocol: ipv4.ProtocolNumber, } if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ diff --git a/test/tun.go b/test/tun.go index fbf58295a..b29d61c7b 100644 --- a/test/tun.go +++ b/test/tun.go @@ -16,8 +16,8 @@ func (NoopTun) Activate() error { return nil } -func (NoopTun) Cidr() netip.Prefix { - return netip.Prefix{} +func (NoopTun) Networks() []netip.Prefix { + return []netip.Prefix{} } func (NoopTun) Name() string { From 28cd2575e81808c1b8a96d169a27b7f94d7f2652 Mon Sep 17 00:00:00 2001 From: Jack Doan Date: Fri, 20 Sep 2024 23:04:18 -0400 Subject: [PATCH 13/17] Cert v2 + tun changes for Linux (#1224) --- cert/cert.go | 15 ++- cert/cert_v1.go | 4 +- cert/cert_v2.go | 7 +- cert/pem.go | 2 +- connection_state.go | 4 + dns_server.go | 9 ++ handshake_ix.go | 4 +- lighthouse.go | 3 + outside.go | 111 +++++++++++++++++++-- outside_test.go | 18 ++-- overlay/tun_android.go | 22 ++--- overlay/tun_freebsd.go | 40 +++++--- overlay/tun_ios.go | 20 ++-- overlay/tun_linux.go | 177 +++++++++++++++++++++------------- overlay/tun_netbsd.go | 45 +++++---- overlay/tun_openbsd.go | 48 +++++---- overlay/tun_water_windows.go | 36 +++---- overlay/tun_windows.go | 8 +- overlay/tun_wintun_windows.go | 30 +++--- 19 files changed, 396 insertions(+), 207 deletions(-) diff --git a/cert/cert.go b/cert/cert.go index 5d5abd646..ed94ec043 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -113,7 +113,7 @@ type CachedCertificate struct { func UnmarshalCertificate(b []byte) (Certificate, error) { //TODO: you left off here, no one uses this function but it might be beneficial to export _something_ that someone can use, maybe the Versioned unmarshallsers? var c Certificate - c, err := unmarshalCertificateV2(b, nil) + c, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) if err == nil { return c, nil } @@ -129,7 +129,7 @@ func UnmarshalCertificate(b []byte) (Certificate, error) { // UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake. // Handshakes save space by placing the peers public key in a different part of the packet, we have to // reassemble the actual certificate structure with that in mind. -func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (Certificate, error) { +func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) { var c Certificate var err error @@ -137,7 +137,7 @@ func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (C case VersionPre1, Version1: c, err = unmarshalCertificateV1(b, publicKey) case Version2: - c, err = unmarshalCertificateV2(b, publicKey) + c, err = unmarshalCertificateV2(b, publicKey, curve) default: //TODO: make a static var return nil, fmt.Errorf("unknown certificate version %d", v) @@ -146,10 +146,15 @@ func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (C if err != nil { return nil, err } + + if c.Curve() != curve { + return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String()) + } + return c, nil } -func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAPool) (*CachedCertificate, error) { +func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) { if publicKey == nil { return nil, ErrNoPeerStaticKey } @@ -158,7 +163,7 @@ func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAP return nil, ErrNoPayload } - c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey) + c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve) if err != nil { return nil, fmt.Errorf("error unmarshaling cert: %w", err) } diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 4dc38fcef..9f88723bd 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -14,7 +14,6 @@ import ( "fmt" "net" "net/netip" - "slices" "time" "golang.org/x/crypto/curve25519" @@ -393,8 +392,7 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) } } - slices.SortFunc(nc.details.networks, comparePrefix) - slices.SortFunc(nc.details.unsafeNetworks, comparePrefix) + //do not sort the subnets field for V1 certs return &nc, nil } diff --git a/cert/cert_v2.go b/cert/cert_v2.go index 2c6f14a00..11aa666bd 100644 --- a/cert/cert_v2.go +++ b/cert/cert_v2.go @@ -455,7 +455,7 @@ func (d *detailsV2) Marshal() ([]byte, error) { return b.Bytes() } -func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error) { +func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) { l := len(b) if l == 0 || l > MaxCertificateSize { return nil, ErrBadFormat @@ -473,11 +473,12 @@ func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error) return nil, ErrBadFormat } + //Maybe grab the curve var rawCurve byte - if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(Curve_CURVE25519)) { + if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) { return nil, ErrBadFormat } - curve := Curve(rawCurve) + curve = Curve(rawCurve) // Maybe grab the public key var rawPublicKey cryptobyte.String diff --git a/cert/pem.go b/cert/pem.go index 8f9fe8e99..249b63917 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -37,7 +37,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { case CertificateBanner: c, err = unmarshalCertificateV1(p.Bytes, nil) case CertificateV2Banner: - c, err = unmarshalCertificateV2(p.Bytes, nil) + c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) default: return nil, r, ErrInvalidPEMCertificateBanner } diff --git a/connection_state.go b/connection_state.go index cfd86eb35..a13164d5c 100644 --- a/connection_state.go +++ b/connection_state.go @@ -91,3 +91,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { "message_counter": cs.messageCounter.Load(), }) } + +func (cs *ConnectionState) Curve() cert.Curve { + return cs.myCert.Curve() +} diff --git a/dns_server.go b/dns_server.go index 991f27068..1b97d7a4f 100644 --- a/dns_server.go +++ b/dns_server.go @@ -85,6 +85,15 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { m.Answer = append(m.Answer, rr) } } + case dns.TypeAAAA: + l.Debugf("Query for AAAA %s", q.Name) + ip := dnsR.Query(q.Name) + if ip != "" { + rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip)) + if err == nil { + m.Answer = append(m.Answer, rr) + } + } case dns.TypeTXT: a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) b, err := netip.ParseAddr(a) diff --git a/handshake_ix.go b/handshake_ix.go index e20ac2f24..55b857e46 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -81,7 +81,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -404,7 +404,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) diff --git a/lighthouse.go b/lighthouse.go index 623ed0b9e..ab81d3c66 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -1170,6 +1170,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAdd useVersion = 2 } + //todo hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? + //todo why do we care about the vpnip in the packet? We know where it came from, right? + if detailsVpnIp != vpnAddrs[0] { if lhh.l.Level >= logrus.DebugLevel { lhh.l.WithField("vpnAddrs", vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") diff --git a/outside.go b/outside.go index f7dbbd32e..6eb6ea011 100644 --- a/outside.go +++ b/outside.go @@ -7,6 +7,9 @@ import ( "net/netip" "time" + "github.com/google/gopacket/layers" + "golang.org/x/net/ipv6" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" @@ -297,14 +300,104 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { - // Do we at least have an ipv4 header worth of data? - if len(data) < ipv4.HeaderLen { - return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen) + if len(data) < 1 { + return errors.New("packet too short") + } + + version := int((data[0] >> 4) & 0x0f) + switch version { + case ipv4.Version: + return parseV4(data, incoming, fp) + case ipv6.Version: + return parseV6(data, incoming, fp) } + return fmt.Errorf("packet is an unknown ip version: %v", version) +} + +func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { + dataLen := len(data) + if dataLen < ipv6.HeaderLen { + return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen) + } + + if incoming { + fp.RemoteIP, _ = netip.AddrFromSlice(data[8:24]) + fp.LocalIP, _ = netip.AddrFromSlice(data[24:40]) + } else { + fp.LocalIP, _ = netip.AddrFromSlice(data[8:24]) + fp.RemoteIP, _ = netip.AddrFromSlice(data[24:40]) + } + + //TODO: whats a reasonable number of extension headers to attempt to parse? + //https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html + protoAt := 6 + offset := 40 + for i := 0; i < 24; i++ { + if dataLen < offset { + break + } + + proto := layers.IPProtocol(data[protoAt]) + //fmt.Println(proto, protoAt) + switch proto { + case layers.IPProtocolICMPv6: + //TODO: we need a new protocol in config language "icmpv6" + fp.Protocol = uint8(proto) + fp.RemotePort = 0 + fp.LocalPort = 0 + fp.Fragment = false + return nil - // Is it an ipv4 packet? - if int((data[0]>>4)&0x0f) != 4 { - return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f)) + case layers.IPProtocolTCP: + if dataLen < offset+4 { + return fmt.Errorf("ipv6 packet was too small") + } + fp.Protocol = uint8(proto) + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + fp.Fragment = false + return nil + + case layers.IPProtocolUDP: + if dataLen < offset+4 { + return fmt.Errorf("ipv6 packet was too small") + } + fp.Protocol = uint8(proto) + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + fp.Fragment = false + return nil + + case layers.IPProtocolIPv6Fragment: + //TODO: can we determine the protocol? + fp.RemotePort = 0 + fp.LocalPort = 0 + fp.Fragment = true + return nil + + default: + if dataLen < offset+1 { + break + } + + next := int(data[offset+1]) * 8 + if next == 0 { + // each extension is at least 8 bytes + next = 8 + } + + protoAt = offset + offset = offset + next + } + } + + return fmt.Errorf("could not find payload in ipv6 packet") +} + +func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { + // Do we at least have an ipv4 header worth of data? + if len(data) < ipv4.HeaderLen { + return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen) } // Adjust our start position based on the advertised ip header length @@ -312,7 +405,7 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { // Well formed ip header length? if ihl < ipv4.HeaderLen { - return fmt.Errorf("packet had an invalid header length: %v", ihl) + return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl) } // Check if this is the second or further fragment of a fragmented packet. @@ -328,12 +421,11 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { minLen += minFwPacketLen } if len(data) < minLen { - return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl) + return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl) } // Firewall packets are locally oriented if incoming { - //TODO: IPV6-WORK fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) fp.LocalIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { @@ -344,7 +436,6 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - //TODO: IPV6-WORK fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { diff --git a/outside_test.go b/outside_test.go index f9d4bfa48..aa5581f03 100644 --- a/outside_test.go +++ b/outside_test.go @@ -13,9 +13,15 @@ import ( func Test_newPacket(t *testing.T) { p := &firewall.Packet{} - // length fail - err := newPacket([]byte{0, 1}, true, p) - assert.EqualError(t, err, "packet is less than 20 bytes") + // length fails + err := newPacket([]byte{}, true, p) + assert.EqualError(t, err, "packet too short") + + err = newPacket([]byte{0x40}, true, p) + assert.EqualError(t, err, "ipv4 packet is less than 20 bytes") + + err = newPacket([]byte{0x60}, true, p) + assert.EqualError(t, err, "ipv6 packet is less than 20 bytes") // length fail with ip options h := ipv4.Header{ @@ -29,15 +35,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24") + assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24") // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet is not ipv4, type: 0") + assert.EqualError(t, err, "packet is an unknown ip version: 0") // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet had an invalid header length: 8") + assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8") // account for variable ip header length - incoming h = ipv4.Header{ diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 98ad9b408..72a656500 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -18,14 +18,14 @@ import ( type tun struct { io.ReadWriteCloser - fd int - cidr netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + fd int + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix t := &tun{ ReadWriteCloser: file, fd: deviceFd, - cidr: cidr, + vpnNetworks: vpnNetworks, l: l, } @@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } @@ -66,7 +66,7 @@ func (t tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index bdfeb5802..69690e948 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -46,12 +46,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser } @@ -78,11 +78,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var file *os.File var err error @@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -195,8 +195,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 20981f08c..e99d44718 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -21,20 +21,20 @@ import ( type tun struct { io.ReadWriteCloser - cidr netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ - cidr: cidr, + vpnNetworks: vpnNetworks, ReadWriteCloser: &tunReadCloser{f: file}, l: l, } @@ -59,7 +59,7 @@ func (t *tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error { return tr.f.Close() } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 0e7e20d41..08c65b4e0 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -25,7 +25,7 @@ type tun struct { io.ReadWriteCloser fd int Device string - cidr netip.Prefix + vpnNetworks []netip.Prefix MaxMTU int DefaultMTU int TXQueueLen int @@ -40,18 +40,16 @@ type tun struct { l *logrus.Logger } +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks +} + type ifReq struct { Name [16]byte Flags uint16 pad [8]byte } -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 @@ -64,10 +62,10 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -77,7 +75,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix return t, nil } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -112,7 +110,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -122,11 +120,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), - cidr: cidr, + vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), l: l, @@ -148,7 +146,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref } func (t *tun) reload(c *config.C, initial bool) error { - routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -190,11 +188,13 @@ func (t *tun) reload(c *config.C, initial bool) error { } if oldDefaultMTU != newDefaultMTU { - err := t.setDefaultRoute() - if err != nil { - t.l.Warn(err) - } else { - t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + for i := range t.vpnNetworks { + err := t.setDefaultRoute(t.vpnNetworks[i]) + if err != nil { + t.l.Warn(err) + } else { + t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + } } } @@ -237,10 +237,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) Write(b []byte) (int, error) { var nn int - max := len(b) + maximum := len(b) for { - n, err := unix.Write(t.fd, b[nn:max]) + n, err := unix.Write(t.fd, b[nn:maximum]) if n > 0 { nn += n } @@ -265,6 +265,58 @@ func (t *tun) deviceBytes() (o [16]byte) { return } +func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool { + for i := range al { + if al[i].Equal(x) { + return true + } + } + return false +} + +// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there +func (t *tun) addIPs(link netlink.Link) error { + newAddrs := make([]*netlink.Addr, len(t.vpnNetworks)) + for i := range t.vpnNetworks { + newAddrs[i] = &netlink.Addr{ + IPNet: &net.IPNet{ + IP: t.vpnNetworks[i].Addr().AsSlice(), + Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()), + }, + Label: t.vpnNetworks[i].Addr().Zone(), + } + } + + //add all new addresses + for i := range newAddrs { + //todo do we want to stack errors and try as many ops as possible? + //AddrReplace still adds new IPs, but if their properties change it will change them as well + if err := netlink.AddrReplace(link, newAddrs[i]); err != nil { + return err + } + } + + //iterate over remainder, remove whoever shouldn't be there + al, err := netlink.AddrList(link, netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("failed to get tun address list: %s", err) + } + + for i := range al { + if hasNetlinkAddr(newAddrs, al[i]) { + continue + } + err = netlink.AddrDel(link, &al[i]) + if err != nil { + t.l.WithError(err).Error("failed to remove address from tun address list") + } else { + t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") + } + } + + return nil +} + func (t *tun) Activate() error { devName := t.deviceBytes() @@ -272,15 +324,8 @@ func (t *tun) Activate() error { t.watchRoutes() } - var addr, mask [4]byte - - //TODO: IPV6-WORK - addr = t.cidr.Addr().As4() - tmask := net.CIDRMask(t.cidr.Bits(), 32) - copy(mask[:], tmask) - s, err := unix.Socket( - unix.AF_INET, + unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine unix.SOCK_DGRAM, unix.IPPROTO_IP, ) @@ -289,31 +334,19 @@ func (t *tun) Activate() error { } t.ioctlFd = uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - // Set the device name ifrf := ifReq{Name: devName} if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to set tun device name: %s", err) } + link, err := netlink.LinkByName(t.Device) + if err != nil { + return fmt.Errorf("failed to get tun device link: %s", err) + } + + t.deviceIndex = link.Attrs().Index + // Setup our default MTU t.setMTU() @@ -324,20 +357,27 @@ func (t *tun) Activate() error { t.l.WithError(err).Error("Failed to set tun tx queue length") } + if err = t.addIPs(link); err != nil { + return err + } + // Bring up the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to bring the tun device up: %s", err) } - link, err := netlink.LinkByName(t.Device) - if err != nil { - return fmt.Errorf("failed to get tun device link: %s", err) + // Run the interface + ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING + if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return fmt.Errorf("failed to run tun device: %s", err) } - t.deviceIndex = link.Attrs().Index - if err = t.setDefaultRoute(); err != nil { - return err + //set route MTU + for i := range t.vpnNetworks { + if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil { + return fmt.Errorf("failed to set default route MTU: %w", err) + } } // Set the routes @@ -345,11 +385,7 @@ func (t *tun) Activate() error { return err } - // Run the interface - ifrf.Flags = ifrf.Flags | unix.IFF_UP | unix.IFF_RUNNING - if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to run tun device: %s", err) - } + //todo do we want to keep the link-local address? return nil } @@ -363,12 +399,12 @@ func (t *tun) setMTU() { } } -func (t *tun) setDefaultRoute() error { +func (t *tun) setDefaultRoute(cidr netip.Prefix) error { // Default route dr := &net.IPNet{ - IP: t.cidr.Masked().Addr().AsSlice(), - Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), + IP: cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()), } nr := netlink.Route{ @@ -377,7 +413,7 @@ func (t *tun) setDefaultRoute() error { MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), Scope: unix.RT_SCOPE_LINK, - Src: net.IP(t.cidr.Addr().AsSlice()), + Src: net.IP(cidr.Addr().AsSlice()), Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, @@ -463,10 +499,6 @@ func (t *tun) removeRoutes(routes []Route) { } } -func (t *tun) Cidr() netip.Prefix { - return t.cidr -} - func (t *tun) Name() string { return t.Device } @@ -523,9 +555,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } gwAddr = gwAddr.Unmap() - if !t.cidr.Contains(gwAddr) { + withinNetworks := false + for i := range t.vpnNetworks { + if t.vpnNetworks[i].Contains(gwAddr) { + withinNetworks = true + break + } + } + if !withinNetworks { // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our network") + t.l.WithField("route", r).Debug("Ignoring route update, not in our networks") return } diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 24ab24f78..54984910e 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -27,12 +27,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser } @@ -58,13 +58,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var file *os.File var err error @@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -130,8 +130,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error { continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) + //todo is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 6463ccbba..dcba05478 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -21,12 +21,12 @@ import ( ) type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser @@ -42,13 +42,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of tunN must be specified") @@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -138,7 +138,7 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -148,6 +148,16 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) RouteFor(ip netip.Addr) netip.Addr { r, _ := t.routeTree.Load().Lookup(ip) return r @@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error { // We don't allow route MTUs so only install routes with a via continue } - - cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + //todo is this right? + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error { if !r.Install { continue } - - cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + //todo is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") @@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go index d78f564cf..73252c71d 100644 --- a/overlay/tun_water_windows.go +++ b/overlay/tun_water_windows.go @@ -17,22 +17,22 @@ import ( ) type waterTun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger - f *net.Interface + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger + f *net.Interface *water.Interface } -func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) { +func newWaterTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*waterTun, error) { // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() t := &waterTun{ - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err := t.reload(c, true) @@ -52,11 +52,13 @@ func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*wat func (t *waterTun) Activate() error { var err error + //TODO multi-ip support + cidr := t.vpnNetworks[0] t.Interface, err = water.New(water.Config{ DeviceType: water.TUN, PlatformSpecificParams: water.PlatformSpecificParams{ ComponentID: "tap0901", - Network: t.cidr.String(), + Network: cidr.String(), }, }) if err != nil { @@ -70,8 +72,8 @@ func (t *waterTun) Activate() error { `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", fmt.Sprintf("name=%s", t.Device), "source=static", - fmt.Sprintf("addr=%s", t.cidr.Addr()), - fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())), + fmt.Sprintf("addr=%s", cidr.Addr()), + fmt.Sprintf("mask=%s", net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen())), "gateway=none", ).Run() if err != nil { @@ -100,7 +102,7 @@ func (t *waterTun) Activate() error { } func (t *waterTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -187,8 +189,8 @@ func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *waterTun) Cidr() netip.Prefix { - return t.cidr +func (t *waterTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *waterTun) Name() string { diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 3d883093c..125d72bbe 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -15,11 +15,11 @@ import ( "github.com/slackhq/nebula/config" ) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (Device, error) { useWintun := true if err := checkWinTunExists(); err != nil { l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") @@ -27,14 +27,14 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( } if useWintun { - device, err := newWinTun(c, l, cidr, multiqueue) + device, err := newWinTun(c, l, vpnNetworks, multiqueue) if err != nil { return nil, fmt.Errorf("create Wintun interface failed, %w", err) } return device, nil } - device, err := newWaterTun(c, l, cidr, multiqueue) + device, err := newWaterTun(c, l, vpnNetworks, multiqueue) if err != nil { return nil, fmt.Errorf("create wintap driver failed, %w", err) } diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go index d0103879a..5a801c318 100644 --- a/overlay/tun_wintun_windows.go +++ b/overlay/tun_wintun_windows.go @@ -20,12 +20,12 @@ import ( const tunGUIDLabel = "Fixed Nebula Windows GUID v1" type winTun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger tun *wintun.NativeTun } @@ -49,7 +49,7 @@ func generateGUIDByDeviceName(name string) (*windows.GUID, error) { return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } -func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) { +func newWinTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { deviceName := c.GetString("tun.dev", "") guid, err := generateGUIDByDeviceName(deviceName) if err != nil { @@ -57,10 +57,10 @@ func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTu } t := &winTun{ - Device: deviceName, - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, } err = t.reload(c, true) @@ -92,7 +92,7 @@ func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTu } func (t *winTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -131,7 +131,7 @@ func (t *winTun) reload(c *config.C, initial bool) error { func (t *winTun) Activate() error { luid := winipcfg.LUID(t.tun.LUID()) - err := luid.SetIPAddresses([]netip.Prefix{t.cidr}) + err := luid.SetIPAddresses(t.vpnNetworks) if err != nil { return fmt.Errorf("failed to set address: %w", err) } @@ -216,8 +216,8 @@ func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *winTun) Cidr() netip.Prefix { - return t.cidr +func (t *winTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *winTun) Name() string { From e681be20c968f836f4a7be1449b6b122e3c74902 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Fri, 20 Sep 2024 23:21:01 -0500 Subject: [PATCH 14/17] Shed the layers of indirection on udp listeners, get the full hostinfo to the lighthouse request handler --- interface.go | 12 +++++-- lighthouse.go | 78 +++++++++++++++++++----------------------- lighthouse_test.go | 9 ++--- outside.go | 22 ++---------- udp/conn.go | 15 ++------ udp/temp.go | 10 ------ udp/udp_generic.go | 20 ++--------- udp/udp_linux.go | 26 ++------------ udp/udp_rio_windows.go | 21 ++---------- udp/udp_tester.go | 10 ++---- 10 files changed, 64 insertions(+), 159 deletions(-) delete mode 100644 udp/temp.go diff --git a/interface.go b/interface.go index c7f6c3e00..9ec7d1bf7 100644 --- a/interface.go +++ b/interface.go @@ -254,16 +254,22 @@ func (f *Interface) listenOut(i int) { runtime.LockOSThread() var li udp.Conn - // TODO clean this up with a coherent interface for each outside connection if i > 0 { li = f.writers[i] } else { li = f.outside } + ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) + plaintext := make([]byte, udp.MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + nb := make([]byte, 12, 12) + + li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + }) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { diff --git a/lighthouse.go b/lighthouse.go index ab81d3c66..23a4b9f50 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -915,24 +915,18 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { return lhh.meta } -func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { - return func(rAddr netip.AddrPort, vpnAddrs []netip.Addr, p []byte) { - lhh.HandleRequest(rAddr, vpnAddrs, p, f) - } -} - -func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnAddrs []netip.Addr, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, reqHostinfo *HostInfo, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", rAddr). + lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") //TODO: send recv_error? return } if n.Details == nil { - lhh.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", rAddr). + lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") //TODO: send recv_error? return @@ -942,24 +936,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnAddrs []net switch n.Type { case NebulaMeta_HostQuery: - lhh.handleHostQuery(n, vpnAddrs, rAddr, w) + lhh.handleHostQuery(n, reqHostinfo, rAddr, w) case NebulaMeta_HostQueryReply: - lhh.handleHostQueryReply(n, vpnAddrs) + lhh.handleHostQueryReply(n, reqHostinfo) case NebulaMeta_HostUpdateNotification: - lhh.handleHostUpdateNotification(n, vpnAddrs, w) + lhh.handleHostUpdateNotification(n, reqHostinfo, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: - lhh.handleHostPunchNotification(n, vpnAddrs, w) + lhh.handleHostPunchNotification(n, reqHostinfo, w) case NebulaMeta_HostUpdateNotificationAck: // noop } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, reqHostinfo *HostInfo, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -1007,15 +1001,15 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnAddrs []netip.Ad } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).Error("Failed to marshal lighthouse host query reply") + lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, reqHostinfo.vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) // This signals the other side to punch some zero byte udp packets - found, ln, err = lhh.lh.queryAndPrepMessage(vpnAddrs[0], func(c *cache) (int, error) { + found, ln, err = lhh.lh.queryAndPrepMessage(reqHostinfo.vpnAddrs[0], func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification //TODO: unsure which version to use. If we had access to the hostmap we could see if there is already a tunnel @@ -1027,15 +1021,15 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnAddrs []netip.Ad } if useVersion == cert.Version1 { - if !vpnAddrs[0].Is4() { + if !reqHostinfo.vpnAddrs[0].Is4() { return 0, fmt.Errorf("invalid vpn ip for v1 handleHostQuery") } - b := vpnAddrs[0].As4() + b := reqHostinfo.vpnAddrs[0].As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) lhh.coalesceAnswers(useVersion, c, n) } else if useVersion == cert.Version2 { - n.Details.VpnAddr = netAddrToProtoAddr(vpnAddrs[0]) + n.Details.VpnAddr = netAddrToProtoAddr(reqHostinfo.vpnAddrs[0]) lhh.coalesceAnswers(useVersion, c, n) } else { @@ -1050,7 +1044,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnAddrs []netip.Ad } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).Error("Failed to marshal lighthouse host was queried for") + lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host was queried for") return } @@ -1100,9 +1094,9 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnAddrs []netip.Addr) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, reqHostinfo *HostInfo) { //TODO: this is kind of dumb - if !lhh.lh.IsLighthouseIP(vpnAddrs[0]) { + if !lhh.lh.IsLighthouseIP(reqHostinfo.vpnAddrs[0]) { return } @@ -1121,8 +1115,8 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnAddrs []net am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnAddrs[0], certVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnAddrs[0], certVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(reqHostinfo.vpnAddrs[0], certVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(reqHostinfo.vpnAddrs[0], certVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) var relays []netip.Addr if len(n.Details.OldRelayVpnAddrs) > 0 { @@ -1139,7 +1133,7 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnAddrs []net } } - am.unlockedSetRelay(vpnAddrs[0], certVpnIp, relays) + am.unlockedSetRelay(reqHostinfo.vpnAddrs[0], certVpnIp, relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block @@ -1149,10 +1143,10 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnAddrs []net } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAddrs []netip.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, reqHostinfo *HostInfo, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnAddrs) + lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", reqHostinfo.vpnAddrs) } return } @@ -1173,20 +1167,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAdd //todo hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? //todo why do we care about the vpnip in the packet? We know where it came from, right? - if detailsVpnIp != vpnAddrs[0] { + if detailsVpnIp != reqHostinfo.vpnAddrs[0] { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") + lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") } return } lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(vpnAddrs[0]) + am := lhh.lh.unlockedGetRemoteList(reqHostinfo.vpnAddrs[0]) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnAddrs[0], detailsVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnAddrs[0], detailsVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(reqHostinfo.vpnAddrs[0], detailsVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(reqHostinfo.vpnAddrs[0], detailsVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) var relays []netip.Addr if len(n.Details.OldRelayVpnAddrs) > 0 { @@ -1203,22 +1197,22 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAdd } } - am.unlockedSetRelay(vpnAddrs[0], detailsVpnIp, relays) + am.unlockedSetRelay(reqHostinfo.vpnAddrs[0], detailsVpnIp, relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck if useVersion == cert.Version1 { - if !vpnAddrs[0].Is4() { - lhh.l.WithField("vpnAddrs", vpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + if !reqHostinfo.vpnAddrs[0].Is4() { + lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") return } - vpnIpB := vpnAddrs[0].As4() + vpnIpB := reqHostinfo.vpnAddrs[0].As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnIpB[:]) } else if useVersion == cert.Version2 { - n.Details.VpnAddr = netAddrToProtoAddr(vpnAddrs[0]) + n.Details.VpnAddr = netAddrToProtoAddr(reqHostinfo.vpnAddrs[0]) } else { panic("unsupported version") @@ -1226,17 +1220,17 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAdd ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", vpnAddrs).Error("Failed to marshal lighthouse host update ack") + lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host update ack") return } lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, reqHostinfo.vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnAddrs []netip.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, reqHostinfo *HostInfo, w EncWriter) { //TODO: this is kinda stupid - if !lhh.lh.IsLighthouseIP(vpnAddrs[0]) { + if !lhh.lh.IsLighthouseIP(reqHostinfo.vpnAddrs[0]) { return } diff --git a/lighthouse_test.go b/lighthouse_test.go index fbb86a137..0c315c09c 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -135,6 +135,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { mw := &mockEncWriter{} + hi := &HostInfo{vpnAddrs: []netip.Addr{vpnIp2}} b.Run("notfound", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ @@ -147,7 +148,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, []netip.Addr{vpnIp2}, p, mw) + lhh.HandleRequest(rAddr, hi, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -163,7 +164,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, []netip.Addr{vpnIp2}, p, mw) + lhh.HandleRequest(rAddr, hi, p, mw) } }) } @@ -324,7 +325,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l w := &testEncWriter{ metaFilter: &filter, } - lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w) + lhh.HandleRequest(fromAddr, &HostInfo{vpnAddrs: []netip.Addr{myVpnIp}}, b, w) return w.lastReply } @@ -349,7 +350,7 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad } w := &testEncWriter{} - lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w) + lhh.HandleRequest(fromAddr, &HostInfo{vpnAddrs: []netip.Addr{vpnIp}}, b, w) } //TODO: this is a RemoteList test diff --git a/outside.go b/outside.go index 6eb6ea011..a94ac9c17 100644 --- a/outside.go +++ b/outside.go @@ -13,7 +13,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" ) @@ -21,24 +20,7 @@ const ( minFwPacketLen = 4 ) -// TODO: IPV6-WORK this can likely be removed now -func readOutsidePackets(f *Interface) udp.EncReader { - return func( - addr netip.AddrPort, - out []byte, - packet []byte, - header *header.H, - fwPacket *firewall.Packet, - lhh udp.LightHouseHandlerFunc, - nb []byte, - q int, - localCache firewall.ConntrackCache, - ) { - f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache) - } -} - -func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log @@ -163,7 +145,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return } - lhf(ip, hostinfo.vpnAddrs, d) + lhf.HandleRequest(ip, hostinfo, d, f) // Fallthrough to the bottom to record incoming traffic diff --git a/udp/conn.go b/udp/conn.go index fa4e44304..895b0df35 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -4,28 +4,19 @@ import ( "net/netip" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" ) const MTU = 9001 type EncReader func( addr netip.AddrPort, - out []byte, - packet []byte, - header *header.H, - fwPacket *firewall.Packet, - lhh LightHouseHandlerFunc, - nb []byte, - q int, - localCache firewall.ConntrackCache, + payload []byte, ) type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) + ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error @@ -39,7 +30,7 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { +func (NoopConn) ListenOut(_ EncReader) { return } func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { diff --git a/udp/temp.go b/udp/temp.go deleted file mode 100644 index 416b80155..000000000 --- a/udp/temp.go +++ /dev/null @@ -1,10 +0,0 @@ -package udp - -import ( - "net/netip" -) - -//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare - -// TODO: IPV6-WORK this can likely be removed now -type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnAddrs []netip.Addr, p []byte) diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 2d8453694..99a3eca21 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -15,8 +15,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" ) type GenericConn struct { @@ -72,12 +70,8 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) +func (u *GenericConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) for { // Just read one packet at a time @@ -87,16 +81,6 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f return } - r( - netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), - plaintext[:0], - buffer[:n], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 2eee76ee2..36ab67c50 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -14,8 +14,6 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" "golang.org/x/sys/unix" ) @@ -120,15 +118,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } } -func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} +func (u *StdConn) ListenOut(r EncReader) { var ip netip.Addr - nb := make([]byte, 12, 12) - //TODO: should we track this? - //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) msgs, buffers, names := u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { @@ -142,26 +134,14 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - //metric.Update(int64(n)) for i := 0; i < n; i++ { + // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { ip, _ = netip.AddrFromSlice(names[i][4:8]) - //TODO: IPV6-WORK what is not ok? } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) - //TODO: IPV6-WORK what is not ok? } - r( - netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), - plaintext[:0], - buffers[i][:msgs[i].Len], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) } } } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index ee7e1e002..585b642bb 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -18,9 +18,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" - "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" ) @@ -118,12 +115,8 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) +func (u *RIOConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) for { // Just read one packet at a time @@ -133,17 +126,7 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - r( - netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), - plaintext[:0], - buffer[:n], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) } } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index f03a3535f..8d5e6c14a 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -10,7 +10,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" ) @@ -107,18 +106,13 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } -func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) - +func (u *TesterConn) ListenOut(r EncReader) { for { p, ok := <-u.RxPackets if !ok { return } - r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(p.From, p.Data) } } From f8b9c804db85ddc65e4a14a47bb10feb6f85435f Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Sat, 21 Sep 2024 20:25:18 -0500 Subject: [PATCH 15/17] Support multiple vpn addrs in lighthouse and hostmap --- connection_manager.go | 2 +- control_tester.go | 4 +- handshake_ix.go | 19 +++-- handshake_manager.go | 2 +- hostmap.go | 18 ++++- lighthouse.go | 174 +++++++++++++++++++++++------------------- lighthouse_test.go | 6 +- outside.go | 23 +++--- 8 files changed, 140 insertions(+), 108 deletions(-) diff --git a/connection_manager.go b/connection_manager.go index eecfd7d46..6b7b0df91 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, case deleteTunnel: if n.hostMap.DeleteHostInfo(hostinfo) { // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap - n.intf.lightHouse.DeleteVpnAddr(hostinfo.vpnAddrs[0]) + n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) } case closeTunnel: diff --git a/control_tester.go b/control_tester.go index 586617af7..93c4e06a8 100644 --- a/control_tester.go +++ b/control_tester.go @@ -49,7 +49,7 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, // This is necessary if you did not configure static hosts or are not running a lighthouse func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() @@ -65,7 +65,7 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) // This is necessary to inform an initiator of possible relays for communicating with a responder func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() diff --git a/handshake_ix.go b/handshake_ix.go index 55b857e46..65ce5c876 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -2,6 +2,7 @@ package nebula import ( "net/netip" + "slices" "time" "github.com/flynn/noise" @@ -230,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ci.dKey = NewNebulaCipherState(dKey) ci.eKey = NewNebulaCipherState(eKey) - hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs[0]) + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) hostinfo.CreateRemoteCIDR(remoteCert.Certificate) @@ -436,9 +437,13 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha fingerprint := remoteCert.ShaSum issuer := remoteCert.Certificate.Issuer() + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + for i, n := range vpnNetworks { + vpnAddrs[i] = n.Addr() + } + // Ensure the right host responded - //TODO: this is a horribly broken test - if vpnNetworks[0].Addr() != hostinfo.vpnAddrs[0] { + if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). WithField("udpAddr", addr).WithField("certName", certName). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). @@ -455,7 +460,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha newHH.hostinfo.remotes.BlockRemote(addr) // Get the correct remote list for the host we did handshake with - hostinfo.remotes = f.lightHouse.QueryCache(vpnNetworks[0].Addr()) + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnNetworks", vpnNetworks). WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). @@ -466,10 +471,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hh.packetStore = []*cachedPacket{} // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnAddrs = nil - for _, n := range vpnNetworks { - hostinfo.vpnAddrs = append(hostinfo.vpnAddrs, n.Addr()) - } + hostinfo.vpnAddrs = vpnAddrs f.sendCloseTunnel(hostinfo) }) @@ -492,6 +494,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo.remoteIndexId = hs.Details.ResponderIndex hostinfo.lastHandshakeTime = hs.Details.Time + hostinfo.vpnAddrs = vpnAddrs // Store their cert and our symmetric keys ci.peerCert = remoteCert diff --git a/handshake_manager.go b/handshake_manager.go index 258d5ae94..6b3902dfa 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -209,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // NB ^ This comment doesn't jive. It's how the thing gets initialized. // It's the common path. Should it update every time, in case a future LH query/queries give us more info? if hostinfo.remotes == nil { - hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp}) } remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) diff --git a/hostmap.go b/hostmap.go index fbafc06d5..63601ee37 100644 --- a/hostmap.go +++ b/hostmap.go @@ -308,7 +308,7 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { hm.Lock() // If we have a previous or next hostinfo then we are not the last one for this vpn ip final := (hostinfo.next == nil && hostinfo.prev == nil) - hm.unlockedDeleteHostInfo(hostinfo) + hm.unlockedDeleteHostInfo(hostinfo, false) hm.Unlock() return final @@ -345,7 +345,7 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { hostinfo.prev = nil } -func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { +func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo, dontRecurse bool) { primary, ok := hm.Hosts[hostinfo.vpnAddrs[0]] if ok && primary == hostinfo { // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it @@ -399,6 +399,18 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { delete(hm.Relays, localRelayIdx) } + + if !dontRecurse { + for _, addr := range hostinfo.vpnAddrs { + h := hm.Hosts[addr] + for h != nil { + if h == hostinfo { + hm.unlockedDeleteHostInfo(h, true) + } + h = h.next + } + } + } } func (hm *HostMap) QueryIndex(index uint32) *HostInfo { @@ -501,7 +513,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { check := hostinfo for check != nil { if i > MaxHostInfosPerVpnIp { - hm.unlockedDeleteHostInfo(check) + hm.unlockedDeleteHostInfo(check, false) } check = check.next i++ diff --git a/lighthouse.go b/lighthouse.go index 23a4b9f50..5549e8386 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/netip" + "slices" "strconv" "sync" "sync/atomic" @@ -472,12 +473,12 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return nil } -func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { - if !lh.IsLighthouseIP(ip) { - lh.QueryServer(ip) +func (lh *LightHouse) Query(vpnAddr netip.Addr) *RemoteList { + if !lh.IsLighthouseIP(vpnAddr) { + lh.QueryServer(vpnAddr) } lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { lh.RUnlock() return v } @@ -486,18 +487,18 @@ func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { } // QueryServer is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip netip.Addr) { +func (lh *LightHouse) QueryServer(vpnAddr netip.Addr) { // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses - if lh.amLighthouse || lh.IsLighthouseIP(ip) { + if lh.amLighthouse || lh.IsLighthouseIP(vpnAddr) { return } - lh.queryChan <- ip + lh.queryChan <- vpnAddr } -func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { +func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList { lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddrs[0]]; ok { lh.RUnlock() return v } @@ -506,16 +507,16 @@ func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { lh.Lock() defer lh.Unlock() // Add an entry if we don't already have one - return lh.unlockedGetRemoteList(ip) + return lh.unlockedGetRemoteList(vpnAddrs) } // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? - if v, ok := lh.addrMap[vpnIp]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { // Swap lh lock for remote list lock v.RLock() defer v.RUnlock() @@ -523,7 +524,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, lh.RUnlock() // vpnIp should also be the owner here since we are a lighthouse. - c := v.cache[vpnIp] + c := v.cache[vpnAddr] // Make sure we have if c != nil { n, err := f(c) @@ -535,20 +536,25 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, return false, 0, nil } -func (lh *LightHouse) DeleteVpnAddr(vpnIp netip.Addr) { +func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { // First we check the static mapping // and do nothing if it is there - if _, ok := lh.GetStaticHostList()[vpnIp]; ok { + if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok { return } lh.Lock() - //l.Debugln(lh.addrMap) - delete(lh.addrMap, vpnIp) - - if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", vpnIp) + rm, ok := lh.addrMap[allVpnAddrs[0]] + if ok { + for _, addr := range allVpnAddrs { + srm := lh.addrMap[addr] + if srm == rm { + delete(lh.addrMap, addr) + if lh.l.Level >= logrus.DebugLevel { + lh.l.Debugf("deleting %s from lighthouse.", addr) + } + } + } } - lh.Unlock() } @@ -556,9 +562,9 @@ func (lh *LightHouse) DeleteVpnAddr(vpnIp netip.Addr) { // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnAddr netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() ctx := lh.ctx @@ -572,12 +578,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t am.shouldRebuild = true }) if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnAddr, "entry": i + 1}, err) } am.unlockedSetHostnamesResults(hr) for _, addrPort := range hr.GetIPs() { - if !lh.shouldAdd(vpnIp, addrPort.Addr()) { + if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { continue } switch { @@ -589,49 +595,52 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t } // Mark it as static in the caller provided map - staticList[vpnIp] = struct{}{} + staticList[vpnAddr] = struct{}{} return nil } // addCalculatedRemotes adds any calculated remotes based on the // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added -func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { +func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool { //TODO: this needs to support v6 addresses too tree := lh.getCalculatedRemotes() if tree == nil { return false } - calculatedRemotes, ok := tree.Lookup(vpnIp) + calculatedRemotes, ok := tree.Lookup(vpnAddr) if !ok { return false } var calculated []*V4AddrPort for _, cr := range calculatedRemotes { - c := cr.Apply(vpnIp) + c := cr.Apply(vpnAddr) if c != nil { calculated = append(calculated, c) } } lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnNetworks[0].Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) + am.unlockedSetV4(lh.myVpnNetworks[0].Addr(), vpnAddr, calculated, lh.unlockedShouldAddV4) return len(calculated) > 0 } -// unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList { - am, ok := lh.addrMap[vpnIp] +// unlockedGetRemoteList +// assumes you have the lh lock +func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { + am, ok := lh.addrMap[allAddrs[0]] if !ok { - am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) - lh.addrMap[vpnIp] = am + am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) + for _, addr := range allAddrs { + lh.addrMap[addr] = am + } } return am } @@ -693,13 +702,25 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *V6AddrPort) bool return true } -func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { - if _, ok := lh.GetLighthouses()[vpnIp]; ok { +func (lh *LightHouse) IsLighthouseIP(vpnAddr netip.Addr) bool { + if _, ok := lh.GetLighthouses()[vpnAddr]; ok { return true } return false } +// TODO: IsLighthouseIP should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake +// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially +func (lh *LightHouse) IsAnyLighthouseIP(vpnAddr []netip.Addr) bool { + l := lh.GetLighthouses() + for _, a := range vpnAddr { + if _, ok := l[a]; ok { + return true + } + } + return false +} + func (lh *LightHouse) startQueryWorker() { if lh.amLighthouse { return @@ -915,20 +936,18 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { return lhh.meta } -func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, reqHostinfo *HostInfo, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("udpAddr", rAddr). + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") - //TODO: send recv_error? return } if n.Details == nil { - lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("udpAddr", rAddr). + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") - //TODO: send recv_error? return } @@ -936,24 +955,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, reqHostinfo *H switch n.Type { case NebulaMeta_HostQuery: - lhh.handleHostQuery(n, reqHostinfo, rAddr, w) + lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w) case NebulaMeta_HostQueryReply: - lhh.handleHostQueryReply(n, reqHostinfo) + lhh.handleHostQueryReply(n, fromVpnAddrs) case NebulaMeta_HostUpdateNotification: - lhh.handleHostUpdateNotification(n, reqHostinfo, w) + lhh.handleHostUpdateNotification(n, fromVpnAddrs, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: - lhh.handleHostPunchNotification(n, reqHostinfo, w) + lhh.handleHostPunchNotification(n, fromVpnAddrs, w) case NebulaMeta_HostUpdateNotificationAck: // noop } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, reqHostinfo *HostInfo, addr netip.AddrPort, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -1001,15 +1020,15 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, reqHostinfo *HostIn } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host query reply") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, reqHostinfo.vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) // This signals the other side to punch some zero byte udp packets - found, ln, err = lhh.lh.queryAndPrepMessage(reqHostinfo.vpnAddrs[0], func(c *cache) (int, error) { + found, ln, err = lhh.lh.queryAndPrepMessage(fromVpnAddrs[0], func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification //TODO: unsure which version to use. If we had access to the hostmap we could see if there is already a tunnel @@ -1021,15 +1040,15 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, reqHostinfo *HostIn } if useVersion == cert.Version1 { - if !reqHostinfo.vpnAddrs[0].Is4() { + if !fromVpnAddrs[0].Is4() { return 0, fmt.Errorf("invalid vpn ip for v1 handleHostQuery") } - b := reqHostinfo.vpnAddrs[0].As4() + b := fromVpnAddrs[0].As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) lhh.coalesceAnswers(useVersion, c, n) } else if useVersion == cert.Version2 { - n.Details.VpnAddr = netAddrToProtoAddr(reqHostinfo.vpnAddrs[0]) + n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) lhh.coalesceAnswers(useVersion, c, n) } else { @@ -1044,7 +1063,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, reqHostinfo *HostIn } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host was queried for") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") return } @@ -1094,9 +1113,8 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, reqHostinfo *HostInfo) { - //TODO: this is kind of dumb - if !lhh.lh.IsLighthouseIP(reqHostinfo.vpnAddrs[0]) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs []netip.Addr) { + if !lhh.lh.IsAnyLighthouseIP(fromVpnAddrs) { return } @@ -1111,12 +1129,12 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, reqHostinfo *H certVpnIp = protoAddrToNetAddr(n.Details.VpnAddr) } - am := lhh.lh.unlockedGetRemoteList(certVpnIp) + am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnIp}) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(reqHostinfo.vpnAddrs[0], certVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(reqHostinfo.vpnAddrs[0], certVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(fromVpnAddrs[0], certVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], certVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) var relays []netip.Addr if len(n.Details.OldRelayVpnAddrs) > 0 { @@ -1133,7 +1151,7 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, reqHostinfo *H } } - am.unlockedSetRelay(reqHostinfo.vpnAddrs[0], certVpnIp, relays) + am.unlockedSetRelay(fromVpnAddrs[0], certVpnIp, relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block @@ -1143,10 +1161,10 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, reqHostinfo *H } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, reqHostinfo *HostInfo, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", reqHostinfo.vpnAddrs) + lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) } return } @@ -1167,20 +1185,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, reqHos //todo hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? //todo why do we care about the vpnip in the packet? We know where it came from, right? - if detailsVpnIp != reqHostinfo.vpnAddrs[0] { + if !slices.Contains(fromVpnAddrs, detailsVpnIp) { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") } return } lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(reqHostinfo.vpnAddrs[0]) + am := lhh.lh.unlockedGetRemoteList(fromVpnAddrs) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(reqHostinfo.vpnAddrs[0], detailsVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(reqHostinfo.vpnAddrs[0], detailsVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(fromVpnAddrs[0], detailsVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], detailsVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) var relays []netip.Addr if len(n.Details.OldRelayVpnAddrs) > 0 { @@ -1197,22 +1215,22 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, reqHos } } - am.unlockedSetRelay(reqHostinfo.vpnAddrs[0], detailsVpnIp, relays) + am.unlockedSetRelay(fromVpnAddrs[0], detailsVpnIp, relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck if useVersion == cert.Version1 { - if !reqHostinfo.vpnAddrs[0].Is4() { - lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + if !fromVpnAddrs[0].Is4() { + lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") return } - vpnIpB := reqHostinfo.vpnAddrs[0].As4() + vpnIpB := fromVpnAddrs[0].As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnIpB[:]) } else if useVersion == cert.Version2 { - n.Details.VpnAddr = netAddrToProtoAddr(reqHostinfo.vpnAddrs[0]) + n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) } else { panic("unsupported version") @@ -1220,17 +1238,17 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, reqHos ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host update ack") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") return } lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, reqHostinfo.vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, reqHostinfo *HostInfo, w EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { //TODO: this is kinda stupid - if !lhh.lh.IsLighthouseIP(reqHostinfo.vpnAddrs[0]) { + if !lhh.lh.IsAnyLighthouseIP(fromVpnAddrs) { return } diff --git a/lighthouse_test.go b/lighthouse_test.go index 0c315c09c..2cdfce79a 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -135,7 +135,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { mw := &mockEncWriter{} - hi := &HostInfo{vpnAddrs: []netip.Addr{vpnIp2}} + hi := []netip.Addr{vpnIp2} b.Run("notfound", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ @@ -325,7 +325,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l w := &testEncWriter{ metaFilter: &filter, } - lhh.HandleRequest(fromAddr, &HostInfo{vpnAddrs: []netip.Addr{myVpnIp}}, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w) return w.lastReply } @@ -350,7 +350,7 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad } w := &testEncWriter{} - lhh.HandleRequest(fromAddr, &HostInfo{vpnAddrs: []netip.Addr{vpnIp}}, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w) } //TODO: this is a RemoteList test diff --git a/outside.go b/outside.go index a94ac9c17..f504bb406 100644 --- a/outside.go +++ b/outside.go @@ -145,7 +145,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return } - lhf.HandleRequest(ip, hostinfo, d, f) + lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) // Fallthrough to the bottom to record incoming traffic @@ -230,9 +230,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] func (f *Interface) closeTunnel(hostInfo *HostInfo) { final := f.hostMap.DeleteHostInfo(hostInfo) if final { - // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage - //TODO: we should delete all related vpnaddrs too - f.lightHouse.DeleteVpnAddr(hostInfo.vpnAddrs[0]) + // We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage + f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs) } } @@ -241,26 +240,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { - if ip.IsValid() && hostinfo.remote != ip { +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, vpnAddr netip.AddrPort) { + if vpnAddr.IsValid() && hostinfo.remote != vpnAddr { //TODO: this is weird now that we can have multiple vpn addrs - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], ip.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], vpnAddr.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", vpnAddr).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if !hostinfo.lastRoam.IsZero() && vpnAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(ip) + hostinfo.SetRemote(vpnAddr) } } From 3ed4dd7e29c15486a84c9bb602558bc1c5a602c7 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 23 Sep 2024 13:06:52 -0500 Subject: [PATCH 16/17] Mostly working --- control_test.go | 2 +- e2e/handshakes_test.go | 116 ++++++++++++++++++++++++-------------- e2e/helpers.go | 4 +- e2e/helpers_test.go | 42 ++++++++++++-- e2e/router/router.go | 47 ++++++++++----- handshake_manager.go | 8 ++- handshake_manager_test.go | 2 +- hostmap.go | 60 +++++++++++--------- lighthouse.go | 35 ++++++++---- lighthouse_test.go | 4 +- remote_list.go | 10 +++- remote_list_test.go | 6 +- service/service_test.go | 2 +- 13 files changed, 222 insertions(+), 116 deletions(-) diff --git a/control_test.go b/control_test.go index cd6364068..0f21ed92a 100644 --- a/control_test.go +++ b/control_test.go @@ -35,7 +35,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { Mask: net.IPMask{255, 255, 255, 0}, } - remotes := NewRemoteList(nil) + remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil) remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port())) remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port())) diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 383478af8..11d2f6724 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -4,7 +4,6 @@ package e2e import ( - "fmt" "net/netip" "slices" "testing" @@ -12,6 +11,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" @@ -21,8 +21,8 @@ import ( func BenchmarkHotPath(b *testing.B) { ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) @@ -45,8 +45,8 @@ func BenchmarkHotPath(b *testing.B) { func TestGoodHandshake(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) @@ -100,9 +100,9 @@ func TestWrongResponderHandshake(t *testing.T) { // The IPs here are chosen on purpose: // The current remote handling will sort by preference, public, and then lexically. // So we need them to have a higher address than evil (we could apply a preference though) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) // Add their real udp addr, which should be tried after evil. myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) @@ -165,8 +165,8 @@ func TestStage1Race(t *testing.T) { // But will eventually collapse down to a single tunnel ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) @@ -242,8 +242,8 @@ func TestStage1Race(t *testing.T) { func TestUncleanShutdownRaceLoser(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) @@ -291,8 +291,8 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { func TestUncleanShutdownRaceWinner(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) @@ -342,9 +342,9 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { func TestRelays(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) @@ -373,9 +373,9 @@ func TestRelays(t *testing.T) { func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) @@ -422,9 +422,9 @@ func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay @@ -510,9 +510,9 @@ func TestStage1RaceRelays2(t *testing.T) { func TestRehandshakingRelays(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) @@ -539,7 +539,7 @@ func TestRehandshakingRelays(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -614,9 +614,9 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) @@ -643,7 +643,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -717,8 +717,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { func TestRehandshaking(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) @@ -738,7 +738,7 @@ func TestRehandshaking(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(cert.Version1, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -814,8 +814,8 @@ func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) @@ -832,14 +832,10 @@ func TestRehandshakingLoser(t *testing.T) { t.Log("Stand up a tunnel between me and them") assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) - tt1 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) - tt2 := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) - fmt.Println(tt1.LocalIndex, tt2.LocalIndex) - r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := NewTestCert(cert.Version1, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -914,8 +910,8 @@ func TestRaceRegression(t *testing.T) { // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) @@ -970,6 +966,40 @@ func TestRaceRegression(t *testing.T) { theirControl.Stop() } +func TestV2NonPrimaryWithLighthouse(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}}) + + o := m{ + "static_host_map": m{ + lhVpnIpNet[1].Addr().String(): []string{lhUdpAddr.String()}, + }, + "lighthouse": m{ + "hosts": []string{lhVpnIpNet[1].Addr().String()}, + }, + } + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o) + theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, lhControl, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + lhControl.Start() + myControl.Start() + theirControl.Start() + + t.Log("Stand up an ipv6 tunnel between me and them") + assert.True(t, myVpnIpNet[1].Addr().Is6()) + assert.True(t, theirVpnIpNet[1].Addr().Is6()) + assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r) + + lhControl.Stop() + myControl.Stop() + theirControl.Stop() +} + //TODO: test // Race winner renews and handshakes // Race loser renews and handshakes diff --git a/e2e/helpers.go b/e2e/helpers.go index c0893aca2..d34d15211 100644 --- a/e2e/helpers.go +++ b/e2e/helpers.go @@ -48,7 +48,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre // NewTestCert will generate a signed certificate with the provided details. // Expiry times are defaulted if you do not pass them in -func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { +func NewTestCert(v cert.Version, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } @@ -59,7 +59,7 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim pub, rawPriv := x25519Keypair() nc := &cert.TBSCertificate{ - Version: cert.Version1, + Version: v, Name: name, Networks: networks, UnsafeNetworks: unsafeNetworks, diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index c8b42b007..72e172bcd 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -27,12 +27,12 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { +func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() var vpnNetworks []netip.Prefix for _, sn := range strings.Split(sVpnNetworks, ",") { - vpnIpNet, err := netip.ParsePrefix(sn) + vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) if err != nil { panic(err) } @@ -55,7 +55,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnNetw budpIp[3] = 239 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } - _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) + _, _, myPrivKey, myPEM := NewTestCert(v, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) caB, err := caCrt.MarshalPEM() if err != nil { @@ -99,11 +99,16 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnNetw } if overrides != nil { - err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + final := m{} + err = mergo.Merge(&final, overrides, mergo.WithAppendSlice) if err != nil { panic(err) } - mc = overrides + err = mergo.Merge(&final, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = final } cb, err := yaml.Marshal(mc) @@ -191,6 +196,33 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn } func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + if toIp.Is6() { + assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) + } else { + assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort) + } +} + +func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) + v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + assert.NotNil(t, v6, "No ipv6 data found") + + assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect") + assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect") + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + assert.NotNil(t, udp, "No udp data found") + + assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect") + assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect") + + data := packet.ApplicationLayer() + assert.NotNil(t, data) + assert.Equal(t, expected, data.Payload(), "Data was incorrect") +} + +func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) assert.NotNil(t, v4, "No ipv4 data found") diff --git a/e2e/router/router.go b/e2e/router/router.go index 5fa382344..492ef3864 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -10,8 +10,8 @@ import ( "os" "path/filepath" "reflect" + "regexp" "sort" - "strings" "sync" "testing" "time" @@ -216,7 +216,7 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr.String(), ":", "-", 1) + sanAddr := normalizeName(addr.String()) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", @@ -253,9 +253,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1), + normalizeName(p.from.GetUDPAddr().String()), line, - strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), + normalizeName(p.to.GetUDPAddr().String()), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -270,6 +270,11 @@ func (r *R) renderFlow() { } } +func normalizeName(s string) string { + rx := regexp.MustCompile("[\\[\\]\\:]") + return rx.ReplaceAllLiteralString(s, "_") +} + // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria. // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered @@ -714,30 +719,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C } func (r *R) formatUdpPacket(p *packet) string { - packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) - v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) - if v4 == nil { - panic("not an ipv4 packet") + var packet gopacket.Packet + var srcAddr netip.Addr + + packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy) + if packet.ErrorLayer() == nil { + v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + if v6 == nil { + panic("not an ipv6 packet") + } + srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) + } else { + packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) + v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if v6 == nil { + panic("not an ipv6 packet") + } + srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) } from := "unknown" - srcAddr, _ := netip.AddrFromSlice(v4.SrcIP) if c, ok := r.vpnControls[srcAddr]; ok { from = c.GetUDPAddr().String() } - udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) - if udp == nil { + udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udpLayer == nil { panic("not a udp packet") } data := packet.ApplicationLayer() return fmt.Sprintf( " %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n", - strings.Replace(from, ":", "-", 1), - strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), - udp.SrcPort, - udp.DstPort, + normalizeName(from), + normalizeName(p.to.GetUDPAddr().String()), + udpLayer.SrcPort, + udpLayer.DstPort, string(data.Payload()), ) } diff --git a/handshake_manager.go b/handshake_manager.go index 6b3902dfa..ee1545647 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -607,14 +607,16 @@ func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { } func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { - //TODO: need to iterate hostinfo.vpnAddrs - delete(hm.vpnIps, hostinfo.vpnAddrs[0]) + for _, addr := range hostinfo.vpnAddrs { + delete(hm.vpnIps, addr) + } + if len(hm.vpnIps) == 0 { hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } delete(hm.indexes, hostinfo.localIndexId) - if len(hm.vpnIps) == 0 { + if len(hm.indexes) == 0 { hm.indexes = map[uint32]*HandshakeHostInfo{} } diff --git a/handshake_manager_test.go b/handshake_manager_test.go index ef6a88893..c1898384a 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { i2 := blah.StartHandshake(ip, nil) assert.Same(t, i, i2) - i.remotes = NewRemoteList(nil) + i.remotes = NewRemoteList([]netip.Addr{}, nil) // Adding something to pending should not affect the main hostmap assert.Len(t, mainHM.Hosts, 0) diff --git a/hostmap.go b/hostmap.go index 63601ee37..e3817e7c6 100644 --- a/hostmap.go +++ b/hostmap.go @@ -308,7 +308,7 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { hm.Lock() // If we have a previous or next hostinfo then we are not the last one for this vpn ip final := (hostinfo.next == nil && hostinfo.prev == nil) - hm.unlockedDeleteHostInfo(hostinfo, false) + hm.unlockedDeleteHostInfo(hostinfo) hm.Unlock() return final @@ -321,6 +321,8 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { } func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { + //TODO: we may need to promote follow on hostinfos from these vpnAddrs as well since their oldHostinfo might not be the same as this one + // this really looks like an ideal spot for memory leaks oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]] if oldHostinfo == hostinfo { return @@ -345,7 +347,19 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { hostinfo.prev = nil } -func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo, dontRecurse bool) { +func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { + for _, addr := range hostinfo.vpnAddrs { + h := hm.Hosts[addr] + for h != nil { + if h == hostinfo { + hm.unlockedInnerDeleteHostInfo(h) + } + h = h.next + } + } +} + +func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo) { primary, ok := hm.Hosts[hostinfo.vpnAddrs[0]] if ok && primary == hostinfo { // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it @@ -399,18 +413,6 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo, dontRecurse bool) for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { delete(hm.Relays, localRelayIdx) } - - if !dontRecurse { - for _, addr := range hostinfo.vpnAddrs { - h := hm.Hosts[addr] - for h != nil { - if h == hostinfo { - hm.unlockedDeleteHostInfo(h, true) - } - h = h.next - } - } - } } func (hm *HostMap) QueryIndex(index uint32) *HostInfo { @@ -487,17 +489,8 @@ func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostI // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps. // If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { - if f.serveDns { - remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String()) - } - - existing := hm.Hosts[hostinfo.vpnAddrs[0]] - hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo - - if existing != nil { - hostinfo.next = existing - existing.prev = hostinfo + for _, addr := range hostinfo.vpnAddrs { + hm.unlockedInnerAddHostInfo(addr, hostinfo, f) } hm.Indexes[hostinfo.localIndexId] = hostinfo @@ -508,12 +501,27 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). Debug("Hostmap vpnIp added") } +} + +func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) { + if f.serveDns { + remoteCert := hostinfo.ConnectionState.peerCert + dnsR.Add(remoteCert.Certificate.Name()+".", vpnAddr.String()) + } + + existing := hm.Hosts[vpnAddr] + hm.Hosts[vpnAddr] = hostinfo + + if existing != nil && existing != hostinfo { + hostinfo.next = existing + existing.prev = hostinfo + } i := 1 check := hostinfo for check != nil { if i > MaxHostInfosPerVpnIp { - hm.unlockedDeleteHostInfo(check, false) + hm.unlockedDeleteHostInfo(check) } check = check.next i++ diff --git a/lighthouse.go b/lighthouse.go index 5549e8386..32e280b0e 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -523,7 +523,10 @@ func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (in lh.RUnlock() - // vpnIp should also be the owner here since we are a lighthouse. + // We may be asking about a non primary address so lets get the primary address + if slices.Contains(v.vpnAddrs, vpnAddr) { + vpnAddr = v.vpnAddrs[0] + } c := v.cache[vpnAddr] // Make sure we have if c != nil { @@ -637,7 +640,7 @@ func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool { func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { am, ok := lh.addrMap[allAddrs[0]] if !ok { - am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) + am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) for _, addr := range allAddrs { lh.addrMap[addr] = am } @@ -747,12 +750,15 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { } // Send a query to the lighthouses and hope for the best next time + //TODO: this is not sufficient since the version depends on the certs loaded into memory as well v := lh.protocolVersion.Load() msg := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{}, } + //TODO: remove this + v = 2 if v == 1 { if !addr.Is4() { lh.l.WithField("vpnAddr", addr).Error("Can't query lighthouse for v6 address using a v1 protocol") @@ -846,6 +852,8 @@ func (lh *LightHouse) SendUpdate() { }, } + //TODO: remove this + v = 2 if v == 1 { var relays []uint32 for _, r := range lh.GetRelaysForMe() { @@ -856,8 +864,10 @@ func (lh *LightHouse) SendUpdate() { relays = append(relays, binary.BigEndian.Uint32(b[:])) } - //TODO: need an ipv4 vpn addr to use msg.Details.OldRelayVpnAddrs = relays + //TODO: assert ipv4 + b := lh.myVpnNetworks[0].Addr().As4() + msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) } else if v == 2 { var relays []*Addr @@ -865,7 +875,8 @@ func (lh *LightHouse) SendUpdate() { relays = append(relays, netAddrToProtoAddr(r)) } - //TODO: need a vpn addr to use + // time="lh 15:57:55.871069" level=debug msg="Host sent invalid update" answer="ff::ffff:a80:3" vpnAddrs="[10.128.0.3 ff::3]" what??? + msg.Details.VpnAddr = netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()) } else { panic("protocol version not supported") @@ -931,6 +942,9 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { details.V6AddrPorts = details.V6AddrPorts[:0] details.RelayVpnAddrs = details.RelayVpnAddrs[:0] details.OldRelayVpnAddrs = details.OldRelayVpnAddrs[:0] + //TODO: these are unfortunate + details.OldVpnAddr = 0 + details.VpnAddr = nil lhh.meta.Details = details return lhh.meta @@ -983,16 +997,13 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti var useVersion cert.Version var queryVpnIp netip.Addr - var reqVpnIp netip.Addr if n.Details.OldVpnAddr != 0 { b := [4]byte{} binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) queryVpnIp = netip.AddrFrom4(b) - reqVpnIp = queryVpnIp useVersion = 1 } else if n.Details.VpnAddr != nil { queryVpnIp = protoAddrToNetAddr(n.Details.VpnAddr) - reqVpnIp = queryVpnIp useVersion = 2 } @@ -1001,13 +1012,13 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply if useVersion == 1 { - if !reqVpnIp.Is4() { + if !queryVpnIp.Is4() { return 0, fmt.Errorf("invalid vpn ip for v1 handleHostQuery") } - b := reqVpnIp.As4() + b := queryVpnIp.As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) } else { - n.Details.VpnAddr = netAddrToProtoAddr(reqVpnIp) + n.Details.VpnAddr = netAddrToProtoAddr(queryVpnIp) } lhh.coalesceAnswers(useVersion, c, n) @@ -1033,7 +1044,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti n.Type = NebulaMeta_HostPunchNotification //TODO: unsure which version to use. If we had access to the hostmap we could see if there is already a tunnel // and use that version then fallback to our default configuration - targetHI := lhh.lh.ifce.GetHostInfo(reqVpnIp) + targetHI := lhh.lh.ifce.GetHostInfo(queryVpnIp) useVersion = cert.Version(lhh.lh.protocolVersion.Load()) if targetHI != nil { useVersion = targetHI.GetCert().Certificate.Version() @@ -1068,7 +1079,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, reqVpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, queryVpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) { diff --git a/lighthouse_test.go b/lighthouse_test.go index 2cdfce79a..116099773 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -108,7 +108,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") vpnIp3 := netip.MustParseAddr("0.0.0.3") - lh.addrMap[vpnIp3] = NewRemoteList(nil) + lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil) lh.addrMap[vpnIp3].unlockedSetV4( vpnIp3, vpnIp3, @@ -122,7 +122,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { rAddr := netip.MustParseAddrPort("1.2.2.3:12345") rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") vpnIp2 := netip.MustParseAddr("0.0.0.3") - lh.addrMap[vpnIp2] = NewRemoteList(nil) + lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil) lh.addrMap[vpnIp2].unlockedSetV4( vpnIp3, vpnIp3, diff --git a/remote_list.go b/remote_list.go index 1c7fd4db5..cbbb22a99 100644 --- a/remote_list.go +++ b/remote_list.go @@ -189,6 +189,9 @@ type RemoteList struct { // Every interaction with internals requires a lock! sync.RWMutex + // The full list of vpn addresses assigned to this host + vpnAddrs []netip.Addr + // A deduplicated set of addresses. Any accessor should lock beforehand. addrs []netip.AddrPort @@ -212,13 +215,16 @@ type RemoteList struct { } // NewRemoteList creates a new empty RemoteList -func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { - return &RemoteList{ +func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList { + r := &RemoteList{ + vpnAddrs: make([]netip.Addr, len(vpnAddrs)), addrs: make([]netip.AddrPort, 0), relays: make([]netip.Addr, 0), cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, } + copy(r.vpnAddrs, vpnAddrs) + return r } func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { diff --git a/remote_list_test.go b/remote_list_test.go index 33bfbb128..1f548c71e 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -9,7 +9,7 @@ import ( ) func TestRemoteList_Rebuild(t *testing.T) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), @@ -98,7 +98,7 @@ func TestRemoteList_Rebuild(t *testing.T) { } func BenchmarkFullRebuild(b *testing.B) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), @@ -160,7 +160,7 @@ func BenchmarkFullRebuild(b *testing.B) { } func BenchmarkSortRebuild(b *testing.B) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), diff --git a/service/service_test.go b/service/service_test.go index e9fceef6f..9fbf088d6 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -19,7 +19,7 @@ import ( type m map[string]interface{} func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) + _, _, myPrivKey, myPEM := e2e.NewTestCert(cert.Version2, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) caB, err := caCrt.MarshalPEM() if err != nil { panic(err) From fd4d68ab1a10b73235abf144bc085d06111a895a Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Mon, 23 Sep 2024 21:23:21 -0500 Subject: [PATCH 17/17] Better e2e errors --- e2e/router/router.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/e2e/router/router.go b/e2e/router/router.go index 492ef3864..d5da587da 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -427,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // Nope, lets push the sender along case p := <-udpTx: r.Lock() - c := r.getControl(sender.GetUDPAddr(), p.To, p) + a := sender.GetUDPAddr() + c := r.getControl(a, p.To, p) if c == nil { r.Unlock() - panic("No control for udp tx") + panic("No control for udp tx " + a.String()) } fp := r.unlockedInjectFlow(sender, c, p, false) c.InjectUDPPacket(p) @@ -483,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { } else { // we are a udp tx, route and continue p := rx.Interface().(*udp.Packet) - c := r.getControl(cm[x].GetUDPAddr(), p.To, p) + a := cm[x].GetUDPAddr() + c := r.getControl(a, p.To, p) if c == nil { r.Unlock() - panic("No control for udp tx") + panic("No control for udp tx " + a.String()) } fp := r.unlockedInjectFlow(cm[x], c, p, false) c.InjectUDPPacket(p)