Rewrote entire NFT Update Logic
- more flexibility by a struct based approch - allows for collection of all changes from a complex confi - pushes everything at once to NFT
This commit is contained in:
		@@ -30,7 +30,7 @@ func init() {
 | 
			
		||||
		log.Printf("Configuration in use: %v", viper.AllSettings())
 | 
			
		||||
		service.LoadConfig()
 | 
			
		||||
 | 
			
		||||
		err := service.NFTUpdateSets(*table, *set4, *ip4, *set6, *ip6)
 | 
			
		||||
		err := service.NFTUpdateSetsCmd(*table, *set4, *ip4, *set6, *ip6)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			log.Fatalf("Could not update NFT: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,14 +1,81 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net"
 | 
			
		||||
	"runtime"
 | 
			
		||||
 | 
			
		||||
	"github.com/google/nftables"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func NFTUpdateSets(tableName string, set4name string, ip4 net.IP, set6name string, ip6 net.IP) error {
 | 
			
		||||
type NFTUpdate struct {
 | 
			
		||||
	Tables map[string]*NFTUpdateTable
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NFTUpdateTable struct {
 | 
			
		||||
	TableName string
 | 
			
		||||
	Sets      map[string]*NFTUpdateSet
 | 
			
		||||
	done      bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type NFTUpdateSet struct {
 | 
			
		||||
	SetName string
 | 
			
		||||
	IP6Set  bool
 | 
			
		||||
	IPs     []net.IP
 | 
			
		||||
	done    bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewNFTUpdate() *NFTUpdate {
 | 
			
		||||
	f := new(NFTUpdate)
 | 
			
		||||
	f.Tables = make(map[string]*NFTUpdateTable)
 | 
			
		||||
	return f
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (nu *NFTUpdate) FindOrAddTable(TableName string) *NFTUpdateTable {
 | 
			
		||||
	if nut, ok := nu.Tables[TableName]; ok {
 | 
			
		||||
		return nut
 | 
			
		||||
	} else {
 | 
			
		||||
		nut := NFTUpdateTable{
 | 
			
		||||
			TableName: TableName,
 | 
			
		||||
			Sets:      make(map[string]*NFTUpdateSet),
 | 
			
		||||
			done:      false,
 | 
			
		||||
		}
 | 
			
		||||
		nu.Tables[TableName] = &nut
 | 
			
		||||
		return &nut
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (nut *NFTUpdateTable) FindOrAddSet(SetName string, IP6 bool) (*NFTUpdateSet, error) {
 | 
			
		||||
	if nus, ok := nut.Sets[SetName]; ok {
 | 
			
		||||
		if nus.IP6Set != IP6 {
 | 
			
		||||
			return nil, fmt.Errorf("set %s has been already declared with IP6=%t", SetName, nus.IP6Set)
 | 
			
		||||
		}
 | 
			
		||||
		return nus, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		nus := NFTUpdateSet{
 | 
			
		||||
			SetName: SetName,
 | 
			
		||||
			IP6Set:  IP6,
 | 
			
		||||
			done:    false,
 | 
			
		||||
		}
 | 
			
		||||
		nut.Sets[SetName] = &nus
 | 
			
		||||
		return &nus, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (nu *NFTUpdate) AddIP(TableName string, SetName string, IP net.IP) error {
 | 
			
		||||
	ip6 := (IP.To4() == nil)
 | 
			
		||||
	nut := nu.FindOrAddTable(TableName)
 | 
			
		||||
	nus, err := nut.FindOrAddSet(SetName, ip6)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("set %s in Table %T is of ipv6=%t, %v is not", SetName, TableName, nus.IP6Set, IP)
 | 
			
		||||
	}
 | 
			
		||||
	nus.IPs = append(nus.IPs, IP)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (upd *NFTUpdate) Process() error {
 | 
			
		||||
	runtime.LockOSThread()
 | 
			
		||||
	defer runtime.UnlockOSThread()
 | 
			
		||||
 | 
			
		||||
@@ -19,64 +86,95 @@ func NFTUpdateSets(tableName string, set4name string, ip4 net.IP, set6name strin
 | 
			
		||||
		return fmt.Errorf("could not list NFT tables: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var table *nftables.Table = nil
 | 
			
		||||
	for _, t := range tables {
 | 
			
		||||
		if t.Name == tableName {
 | 
			
		||||
			if table == nil {
 | 
			
		||||
				table = t
 | 
			
		||||
			} else {
 | 
			
		||||
				return fmt.Errorf("found two tables with name %s", tableName)
 | 
			
		||||
	for _, table := range tables {
 | 
			
		||||
		if updTable, ok := upd.Tables[table.Name]; ok {
 | 
			
		||||
 | 
			
		||||
			for _, updSet := range updTable.Sets {
 | 
			
		||||
				// Try to load set
 | 
			
		||||
				set, err := conn.GetSetByName(table, updSet.SetName)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return fmt.Errorf("could not find IPv4 NFT set %s: %v", updSet.SetName, err)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// Validate set type
 | 
			
		||||
				if updSet.IP6Set {
 | 
			
		||||
					if set.KeyType.GetNFTMagic() != nftables.TypeIP6Addr.GetNFTMagic() {
 | 
			
		||||
						return fmt.Errorf("the NFT set %s is not of type ip6", updSet.SetName)
 | 
			
		||||
					}
 | 
			
		||||
				} else {
 | 
			
		||||
					if set.KeyType.GetNFTMagic() != nftables.TypeIPAddr.GetNFTMagic() {
 | 
			
		||||
						return fmt.Errorf("the NFT set %s is not of type ip", updSet.SetName)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// Flush the set so that we can start adding new IPs
 | 
			
		||||
				conn.FlushSet(set)
 | 
			
		||||
 | 
			
		||||
				// Loop over all IPs to add
 | 
			
		||||
				for _, ip := range updSet.IPs {
 | 
			
		||||
					// Convert IP to binary representation and validate inbound IP type
 | 
			
		||||
					var ipbin net.IP
 | 
			
		||||
					if updSet.IP6Set {
 | 
			
		||||
						ipbin = ip.To16()
 | 
			
		||||
					} else {
 | 
			
		||||
						ipbin = ip.To4()
 | 
			
		||||
					}
 | 
			
		||||
					if ipbin == nil {
 | 
			
		||||
						return fmt.Errorf("ip %v not valid for Table %s, Set %s, IP6=%t", ip, table.Name, set.Name, updSet.IP6Set)
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					// Add IP to the set
 | 
			
		||||
					err = conn.SetAddElements(set, []nftables.SetElement{{Key: ipbin}})
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return fmt.Errorf("failed to add ip %v to Table %s, Set %s, IP6=%t: %v", ip, table.Name, set.Name, updSet.IP6Set, err)
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				updSet.done = true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if table == nil {
 | 
			
		||||
		return fmt.Errorf("could not find table %s", tableName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ip4 != nil {
 | 
			
		||||
		ip4bin := ip4.To4()
 | 
			
		||||
		if ip4bin == nil {
 | 
			
		||||
			return fmt.Errorf("ipv4 must be a valid IPv4 address")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		set4, err := conn.GetSetByName(table, set4name)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("could not find IPv4 NFT set %s: %v", set4name, err)
 | 
			
		||||
		}
 | 
			
		||||
		if set4.KeyType.GetNFTMagic() != nftables.TypeIPAddr.GetNFTMagic() {
 | 
			
		||||
			return fmt.Errorf("the NFT set %s is not of type ip", set4name)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		conn.FlushSet(set4)
 | 
			
		||||
 | 
			
		||||
		err = conn.SetAddElements(set4, []nftables.SetElement{{Key: ip4bin}})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to add IP %v to set %s: %v", ip4, set4name, err)
 | 
			
		||||
			updTable.done = true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if ip6 != nil {
 | 
			
		||||
		ip6bin := ip6.To16()
 | 
			
		||||
		if ip6bin == nil {
 | 
			
		||||
			return fmt.Errorf("ipv6 must be a valid IPv6 address")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		set6, err := conn.GetSetByName(table, set6name)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("could not find IPv6 NFT set %s: %v", set6name, err)
 | 
			
		||||
		}
 | 
			
		||||
		if set6.KeyType.GetNFTMagic() != nftables.TypeIP6Addr.GetNFTMagic() {
 | 
			
		||||
			return fmt.Errorf("the NFT set %s is not of type ip6", set6name)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		conn.FlushSet(set6)
 | 
			
		||||
 | 
			
		||||
		err = conn.SetAddElements(set6, []nftables.SetElement{{Key: ip6bin}})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to add IP %v to set %s: %v", ip6, set6name, err)
 | 
			
		||||
	// Check for unprocessed tables at this point. We use a NFT Table List as outer
 | 
			
		||||
	// loop, as we cannot directly lookup an NFT Table by its name. So we need to
 | 
			
		||||
	// check for inconsistencies here. Sets are looked up by their name, so we loop
 | 
			
		||||
	// over our update structure here, so there is no need to check this.
 | 
			
		||||
	for _, updTable := range upd.Tables {
 | 
			
		||||
		if !updTable.done {
 | 
			
		||||
			return fmt.Errorf("the table %s was not found while updating, aborting", updTable.TableName)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Send the changes.
 | 
			
		||||
	conn.Flush()
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (obj *NFTUpdate) PrettyPrint() string {
 | 
			
		||||
	s, err := json.MarshalIndent(obj, "", " ")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatalf("Failed to pretty print NFT Update Struct via JSON: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	return string(s)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NFTUpdateSetsCmd(tableName string, set4name string, ip4 net.IP, set6name string, ip6 net.IP) error {
 | 
			
		||||
	upd := NewNFTUpdate()
 | 
			
		||||
	if ip4 != nil {
 | 
			
		||||
		if err := upd.AddIP(tableName, set4name, ip4); err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to add ip4 %v: %v", ip4, err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if ip6 != nil {
 | 
			
		||||
		if err := upd.AddIP(tableName, set6name, ip6); err != nil {
 | 
			
		||||
			return fmt.Errorf("failed to add ip6 %v: %v", ip6, err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if err := upd.Process(); err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to process NFT Updates: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user