| /* |
| Copyright 2023 The Kubernetes Authors. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| */ |
| |
| package iptree |
| |
| import ( |
| "fmt" |
| "math/bits" |
| "net/netip" |
| ) |
| |
| // iptree implement a radix tree that uses IP prefixes as nodes and allows to store values in each node. |
| // Example: |
| // |
| // r := New[int]() |
| // |
| // prefixes := []string{ |
| // "0.0.0.0/0", |
| // "10.0.0.0/8", |
| // "10.0.0.0/16", |
| // "10.1.0.0/16", |
| // "10.1.1.0/24", |
| // "10.1.244.0/24", |
| // "10.0.0.0/24", |
| // "10.0.0.3/32", |
| // "192.168.0.0/24", |
| // "192.168.0.0/28", |
| // "192.168.129.0/28", |
| // } |
| // for _, k := range prefixes { |
| // r.InsertPrefix(netip.MustParsePrefix(k), 0) |
| // } |
| // |
| // (*) means the node is not public, is not storing any value |
| // |
| // 0.0.0.0/0 --- 10.0.0.0/8 --- *10.0.0.0/15 --- 10.0.0.0/16 --- 10.0.0.0/24 --- 10.0.0.3/32 |
| // | | |
| // | \ -------- 10.1.0.0/16 --- 10.1.1.0/24 |
| // | | |
| // | \ ------- 10.1.244.0/24 |
| // | |
| // \------ *192.168.0.0/16 --- 192.168.0.0/24 --- 192.168.0.0/28 |
| // | |
| // \ -------- 192.168.129.0/28 |
| |
| // node is an element of radix tree with a netip.Prefix optimized to store IP prefixes. |
| type node[T any] struct { |
| // prefix network CIDR |
| prefix netip.Prefix |
| // public nodes are used to store values |
| public bool |
| val T |
| |
| child [2]*node[T] // binary tree |
| } |
| |
| // mergeChild allow to compress the tree |
| // when n has exactly one child and no value |
| // p -> n -> b -> c ==> p -> b -> c |
| func (n *node[T]) mergeChild() { |
| // public nodes can not be merged |
| if n.public { |
| return |
| } |
| // can not merge if there are two children |
| if n.child[0] != nil && |
| n.child[1] != nil { |
| return |
| } |
| // can not merge if there are no children |
| if n.child[0] == nil && |
| n.child[1] == nil { |
| return |
| } |
| // find the child and merge it |
| var child *node[T] |
| if n.child[0] != nil { |
| child = n.child[0] |
| } else if n.child[1] != nil { |
| child = n.child[1] |
| } |
| n.prefix = child.prefix |
| n.public = child.public |
| n.val = child.val |
| n.child = child.child |
| // remove any references from the deleted node |
| // to avoid memory leak |
| child.child[0] = nil |
| child.child[1] = nil |
| } |
| |
| // Tree is a radix tree for IPv4 and IPv6 networks. |
| type Tree[T any] struct { |
| rootV4 *node[T] |
| rootV6 *node[T] |
| } |
| |
| // New creates a new Radix Tree for IP addresses. |
| func New[T any]() *Tree[T] { |
| return &Tree[T]{ |
| rootV4: &node[T]{ |
| prefix: netip.PrefixFrom(netip.IPv4Unspecified(), 0), |
| }, |
| rootV6: &node[T]{ |
| prefix: netip.PrefixFrom(netip.IPv6Unspecified(), 0), |
| }, |
| } |
| } |
| |
| // GetPrefix returns the stored value and true if the exact prefix exists in the tree. |
| func (t *Tree[T]) GetPrefix(prefix netip.Prefix) (T, bool) { |
| var zeroT T |
| |
| n := t.rootV4 |
| if prefix.Addr().Is6() { |
| n = t.rootV6 |
| } |
| bitPosition := 0 |
| // mask the address for sanity |
| address := prefix.Masked().Addr() |
| // we can't check longer than the request mask |
| mask := prefix.Bits() |
| // walk the network bits of the prefix |
| for bitPosition < mask { |
| // Look for a child checking the bit position after the mask |
| n = n.child[getBitFromAddr(address, bitPosition+1)] |
| if n == nil { |
| return zeroT, false |
| } |
| // check we are in the right branch comparing the suffixes |
| if !n.prefix.Contains(address) { |
| return zeroT, false |
| } |
| // update the new bit position with the new node mask |
| bitPosition = n.prefix.Bits() |
| } |
| // check if this node is a public node and contains a prefix |
| if n != nil && n.public && n.prefix == prefix { |
| return n.val, true |
| } |
| |
| return zeroT, false |
| } |
| |
| // LongestPrefixMatch returns the longest prefix match, the stored value and true if exist. |
| // For example, considering the following prefixes 192.168.20.16/28 and 192.168.0.0/16, |
| // when the address 192.168.20.19/32 is looked up it will return 192.168.20.16/28. |
| func (t *Tree[T]) LongestPrefixMatch(prefix netip.Prefix) (netip.Prefix, T, bool) { |
| n := t.rootV4 |
| if prefix.Addr().Is6() { |
| n = t.rootV6 |
| } |
| |
| var last *node[T] |
| // bit position is given by the mask bits |
| bitPosition := 0 |
| // mask the address |
| address := prefix.Masked().Addr() |
| mask := prefix.Bits() |
| // walk the network bits of the prefix |
| for bitPosition < mask { |
| if n.public { |
| last = n |
| } |
| // Look for a child checking the bit position after the mask |
| n = n.child[getBitFromAddr(address, bitPosition+1)] |
| if n == nil { |
| break |
| } |
| // check we are in the right branch comparing the suffixes |
| if !n.prefix.Contains(address) { |
| break |
| } |
| // update the new bit position with the new node mask |
| bitPosition = n.prefix.Bits() |
| } |
| |
| if n != nil && n.public && n.prefix == prefix { |
| last = n |
| } |
| |
| if last != nil { |
| return last.prefix, last.val, true |
| } |
| var zeroT T |
| return netip.Prefix{}, zeroT, false |
| } |
| |
| // ShortestPrefixMatch returns the shortest prefix match, the stored value and true if exist. |
| // For example, considering the following prefixes 192.168.20.16/28 and 192.168.0.0/16, |
| // when the address 192.168.20.19/32 is looked up it will return 192.168.0.0/16. |
| func (t *Tree[T]) ShortestPrefixMatch(prefix netip.Prefix) (netip.Prefix, T, bool) { |
| var zeroT T |
| |
| n := t.rootV4 |
| if prefix.Addr().Is6() { |
| n = t.rootV6 |
| } |
| // bit position is given by the mask bits |
| bitPosition := 0 |
| // mask the address |
| address := prefix.Masked().Addr() |
| mask := prefix.Bits() |
| for bitPosition < mask { |
| if n.public { |
| return n.prefix, n.val, true |
| } |
| // Look for a child checking the bit position after the mask |
| n = n.child[getBitFromAddr(address, bitPosition+1)] |
| if n == nil { |
| return netip.Prefix{}, zeroT, false |
| } |
| // check we are in the right branch comparing the suffixes |
| if !n.prefix.Contains(address) { |
| return netip.Prefix{}, zeroT, false |
| } |
| // update the new bit position with the new node mask |
| bitPosition = n.prefix.Bits() |
| } |
| |
| if n != nil && n.public && n.prefix == prefix { |
| return n.prefix, n.val, true |
| } |
| return netip.Prefix{}, zeroT, false |
| } |
| |
| // InsertPrefix is used to add a new entry or update |
| // an existing entry. Returns true if updated. |
| func (t *Tree[T]) InsertPrefix(prefix netip.Prefix, v T) bool { |
| n := t.rootV4 |
| if prefix.Addr().Is6() { |
| n = t.rootV6 |
| } |
| var parent *node[T] |
| // bit position is given by the mask bits |
| bitPosition := 0 |
| // mask the address |
| address := prefix.Masked().Addr() |
| mask := prefix.Bits() |
| for bitPosition < mask { |
| // Look for a child checking the bit position after the mask |
| childIndex := getBitFromAddr(address, bitPosition+1) |
| parent = n |
| n = n.child[childIndex] |
| // if no child create a new one with |
| if n == nil { |
| parent.child[childIndex] = &node[T]{ |
| public: true, |
| val: v, |
| prefix: prefix, |
| } |
| return false |
| } |
| |
| // update the new bit position with the new node mask |
| bitPosition = n.prefix.Bits() |
| |
| // continue if we are in the right branch and current |
| // node is our parent |
| if n.prefix.Contains(address) && bitPosition <= mask { |
| continue |
| } |
| |
| // Split the node and add a new child: |
| // - Case 1: parent -> child -> n |
| // - Case 2: parent -> newnode |--> child |
| // |--> n |
| child := &node[T]{ |
| prefix: prefix, |
| public: true, |
| val: v, |
| } |
| // Case 1: existing node is a sibling |
| if prefix.Contains(n.prefix.Addr()) && bitPosition > mask { |
| // parent to child |
| parent.child[childIndex] = child |
| pos := prefix.Bits() + 1 |
| // calculate if the sibling is at the left or right |
| child.child[getBitFromAddr(n.prefix.Addr(), pos)] = n |
| return false |
| } |
| |
| // Case 2: existing node has the same mask but different base address |
| // add common ancestor and branch on it |
| ancestor := findAncestor(prefix, n.prefix) |
| link := &node[T]{ |
| prefix: ancestor, |
| } |
| pos := parent.prefix.Bits() + 1 |
| parent.child[getBitFromAddr(ancestor.Addr(), pos)] = link |
| // ancestor -> children |
| pos = ancestor.Bits() + 1 |
| idxChild := getBitFromAddr(prefix.Addr(), pos) |
| idxN := getBitFromAddr(n.prefix.Addr(), pos) |
| if idxChild == idxN { |
| panic(fmt.Sprintf("wrong ancestor %s: child %s N %s", ancestor.String(), prefix.String(), n.prefix.String())) |
| } |
| link.child[idxChild] = child |
| link.child[idxN] = n |
| return false |
| } |
| |
| // if already exist update it and make it public |
| if n != nil && n.prefix == prefix { |
| if n.public { |
| n.val = v |
| n.public = true |
| return true |
| } |
| n.val = v |
| n.public = true |
| return false |
| } |
| |
| return false |
| } |
| |
| // DeletePrefix delete the exact prefix and return true if it existed. |
| func (t *Tree[T]) DeletePrefix(prefix netip.Prefix) bool { |
| root := t.rootV4 |
| if prefix.Addr().Is6() { |
| root = t.rootV6 |
| } |
| var parent *node[T] |
| n := root |
| // bit position is given by the mask bits |
| bitPosition := 0 |
| // mask the address |
| address := prefix.Masked().Addr() |
| mask := prefix.Bits() |
| for bitPosition < mask { |
| // Look for a child checking the bit position after the mask |
| parent = n |
| n = n.child[getBitFromAddr(address, bitPosition+1)] |
| if n == nil { |
| return false |
| } |
| // check we are in the right branch comparing the suffixes |
| if !n.prefix.Contains(address) { |
| return false |
| } |
| // update the new bit position with the new node mask |
| bitPosition = n.prefix.Bits() |
| } |
| // check if the node contains the prefix we want to delete |
| if n.prefix != prefix { |
| return false |
| } |
| // Delete the value |
| n.public = false |
| var zeroT T |
| n.val = zeroT |
| |
| nodeChildren := 0 |
| if n.child[0] != nil { |
| nodeChildren++ |
| } |
| if n.child[1] != nil { |
| nodeChildren++ |
| } |
| // If there is a parent and this node does not have any children |
| // this is a leaf so we can delete this node. |
| // - parent -> child(to be deleted) |
| if parent != nil && nodeChildren == 0 { |
| if parent.child[0] != nil && parent.child[0] == n { |
| parent.child[0] = nil |
| } else if parent.child[1] != nil && parent.child[1] == n { |
| parent.child[1] = nil |
| } else { |
| panic("wrong parent") |
| } |
| n = nil |
| } |
| // Check if we should merge this node |
| // The root node can not be merged |
| if n != root && nodeChildren == 1 { |
| n.mergeChild() |
| } |
| // Check if we should merge the parent's other child |
| // parent -> deletedNode |
| // |--> child |
| parentChildren := 0 |
| if parent != nil { |
| if parent.child[0] != nil { |
| parentChildren++ |
| } |
| if parent.child[1] != nil { |
| parentChildren++ |
| } |
| if parent != root && parentChildren == 1 && !parent.public { |
| parent.mergeChild() |
| } |
| } |
| return true |
| } |
| |
| // for testing, returns the number of public nodes in the tree. |
| func (t *Tree[T]) Len(isV6 bool) int { |
| count := 0 |
| t.DepthFirstWalk(isV6, func(k netip.Prefix, v T) bool { |
| count++ |
| return false |
| }) |
| return count |
| } |
| |
| // WalkFn is used when walking the tree. Takes a |
| // key and value, returning if iteration should |
| // be terminated. |
| type WalkFn[T any] func(s netip.Prefix, v T) bool |
| |
| // DepthFirstWalk is used to walk the tree of the corresponding IP family |
| func (t *Tree[T]) DepthFirstWalk(isIPv6 bool, fn WalkFn[T]) { |
| if isIPv6 { |
| recursiveWalk(t.rootV6, fn) |
| } |
| recursiveWalk(t.rootV4, fn) |
| } |
| |
| // recursiveWalk is used to do a pre-order walk of a node |
| // recursively. Returns true if the walk should be aborted |
| func recursiveWalk[T any](n *node[T], fn WalkFn[T]) bool { |
| if n == nil { |
| return true |
| } |
| // Visit the public values if any |
| if n.public && fn(n.prefix, n.val) { |
| return true |
| } |
| |
| // Recurse on the children |
| if n.child[0] != nil { |
| if recursiveWalk(n.child[0], fn) { |
| return true |
| } |
| } |
| if n.child[1] != nil { |
| if recursiveWalk(n.child[1], fn) { |
| return true |
| } |
| } |
| return false |
| } |
| |
| // WalkPrefix is used to walk the tree under a prefix |
| func (t *Tree[T]) WalkPrefix(prefix netip.Prefix, fn WalkFn[T]) { |
| n := t.rootV4 |
| if prefix.Addr().Is6() { |
| n = t.rootV6 |
| } |
| bitPosition := 0 |
| // mask the address for sanity |
| address := prefix.Masked().Addr() |
| // we can't check longer than the request mask |
| mask := prefix.Bits() |
| // walk the network bits of the prefix |
| for bitPosition < mask { |
| // Look for a child checking the bit position after the mask |
| n = n.child[getBitFromAddr(address, bitPosition+1)] |
| if n == nil { |
| return |
| } |
| // check we are in the right branch comparing the suffixes |
| if !n.prefix.Contains(address) { |
| break |
| } |
| // update the new bit position with the new node mask |
| bitPosition = n.prefix.Bits() |
| } |
| recursiveWalk[T](n, fn) |
| |
| } |
| |
| // WalkPath is used to walk the tree, but only visiting nodes |
| // from the root down to a given IP prefix. Where WalkPrefix walks |
| // all the entries *under* the given prefix, this walks the |
| // entries *above* the given prefix. |
| func (t *Tree[T]) WalkPath(path netip.Prefix, fn WalkFn[T]) { |
| n := t.rootV4 |
| if path.Addr().Is6() { |
| n = t.rootV6 |
| } |
| bitPosition := 0 |
| // mask the address for sanity |
| address := path.Masked().Addr() |
| // we can't check longer than the request mask |
| mask := path.Bits() |
| // walk the network bits of the prefix |
| for bitPosition < mask { |
| // Visit the public values if any |
| if n.public && fn(n.prefix, n.val) { |
| return |
| } |
| // Look for a child checking the bit position after the mask |
| n = n.child[getBitFromAddr(address, bitPosition+1)] |
| if n == nil { |
| return |
| } |
| // check we are in the right branch comparing the suffixes |
| if !n.prefix.Contains(address) { |
| return |
| } |
| // update the new bit position with the new node mask |
| bitPosition = n.prefix.Bits() |
| } |
| // check if this node is a public node and contains a prefix |
| if n != nil && n.public && n.prefix == path { |
| fn(n.prefix, n.val) |
| } |
| } |
| |
| // TopLevelPrefixes is used to return a map with all the Top Level prefixes |
| // from the corresponding IP family and its values. |
| // For example, if the tree contains entries for 10.0.0.0/8, 10.1.0.0/16, and 192.168.0.0/16, |
| // this will return 10.0.0.0/8 and 192.168.0.0/16. |
| func (t *Tree[T]) TopLevelPrefixes(isIPv6 bool) map[string]T { |
| if isIPv6 { |
| return t.topLevelPrefixes(t.rootV6) |
| } |
| return t.topLevelPrefixes(t.rootV4) |
| } |
| |
| // topLevelPrefixes is used to return a map with all the Top Level prefixes and its values |
| func (t *Tree[T]) topLevelPrefixes(root *node[T]) map[string]T { |
| result := map[string]T{} |
| queue := []*node[T]{root} |
| |
| for len(queue) > 0 { |
| n := queue[0] |
| queue = queue[1:] |
| // store and continue, only interested on the top level prefixes |
| if n.public { |
| result[n.prefix.String()] = n.val |
| continue |
| } |
| if n.child[0] != nil { |
| queue = append(queue, n.child[0]) |
| } |
| if n.child[1] != nil { |
| queue = append(queue, n.child[1]) |
| } |
| } |
| return result |
| } |
| |
| // GetHostIPPrefixMatches returns the list of prefixes that contain the specified Host IP. |
| // An IP is considered a Host IP if is within the subnet range and is not the network address |
| // or, if IPv4, the broadcast address (RFC 1878). |
| func (t *Tree[T]) GetHostIPPrefixMatches(ip netip.Addr) map[netip.Prefix]T { |
| // walk the tree to find all the prefixes containing this IP |
| ipPrefix := netip.PrefixFrom(ip, ip.BitLen()) |
| prefixes := map[netip.Prefix]T{} |
| t.WalkPath(ipPrefix, func(k netip.Prefix, v T) bool { |
| if prefixContainIP(k, ipPrefix.Addr()) { |
| prefixes[k] = v |
| } |
| return false |
| }) |
| return prefixes |
| } |
| |
| // assume starts at 0 from the MSB: 0.1.2......31 |
| // return 0 or 1 |
| func getBitFromAddr(ip netip.Addr, pos int) int { |
| bytes := ip.AsSlice() |
| // get the byte in the slice |
| index := (pos - 1) / 8 |
| if index >= len(bytes) { |
| panic(fmt.Sprintf("ip %s pos %d index %d bytes %v", ip, pos, index, bytes)) |
| } |
| // get the offset inside the byte |
| offset := (pos - 1) % 8 |
| // check if the bit is set |
| if bytes[index]&(uint8(0x80)>>offset) > 0 { |
| return 1 |
| } |
| return 0 |
| } |
| |
| // find the common subnet, aka the one with the common prefix |
| func findAncestor(a, b netip.Prefix) netip.Prefix { |
| bytesA := a.Addr().AsSlice() |
| bytesB := b.Addr().AsSlice() |
| bytes := make([]byte, len(bytesA)) |
| |
| max := a.Bits() |
| if l := b.Bits(); l < max { |
| max = l |
| } |
| |
| mask := 0 |
| for i := range bytesA { |
| xor := bytesA[i] ^ bytesB[i] |
| if xor == 0 { |
| bytes[i] = bytesA[i] |
| mask += 8 |
| |
| } else { |
| pos := bits.LeadingZeros8(xor) |
| mask += pos |
| // mask off the non leading zeros |
| bytes[i] = bytesA[i] & (^uint8(0) << (8 - pos)) |
| break |
| } |
| } |
| if mask > max { |
| mask = max |
| } |
| |
| addr, ok := netip.AddrFromSlice(bytes) |
| if !ok { |
| panic(bytes) |
| } |
| ancestor := netip.PrefixFrom(addr, mask) |
| return ancestor.Masked() |
| } |
| |
| // prefixContainIP returns true if the given IP is contained with the prefix, |
| // is not the network address and also, if IPv4, is not the broadcast address. |
| // This is required because the Kubernetes allocators reserve these addresses |
| // so IPAddresses can not block deletion of this ranges. |
| func prefixContainIP(prefix netip.Prefix, ip netip.Addr) bool { |
| // if the IP is the network address is not contained |
| if prefix.Masked().Addr() == ip { |
| return false |
| } |
| // the broadcast address is not considered contained for IPv4 |
| if !ip.Is6() { |
| ipLast, err := broadcastAddress(prefix) |
| if err != nil || ipLast == ip { |
| return false |
| } |
| } |
| return prefix.Contains(ip) |
| } |
| |
| // TODO(aojea) consolidate all these IPs utils |
| // pkg/registry/core/service/ipallocator/ipallocator.go |
| // broadcastAddress returns the broadcast address of the subnet |
| // The broadcast address is obtained by setting all the host bits |
| // in a subnet to 1. |
| // network 192.168.0.0/24 : subnet bits 24 host bits 32 - 24 = 8 |
| // broadcast address 192.168.0.255 |
| func broadcastAddress(subnet netip.Prefix) (netip.Addr, error) { |
| base := subnet.Masked().Addr() |
| bytes := base.AsSlice() |
| // get all the host bits from the subnet |
| n := 8*len(bytes) - subnet.Bits() |
| // set all the host bits to 1 |
| for i := len(bytes) - 1; i >= 0 && n > 0; i-- { |
| if n >= 8 { |
| bytes[i] = 0xff |
| n -= 8 |
| } else { |
| mask := ^uint8(0) >> (8 - n) |
| bytes[i] |= mask |
| break |
| } |
| } |
| |
| addr, ok := netip.AddrFromSlice(bytes) |
| if !ok { |
| return netip.Addr{}, fmt.Errorf("invalid address %v", bytes) |
| } |
| return addr, nil |
| } |