diff --git a/cmd/nft-update.go b/cmd/nft-update.go index 4631769..2a994d4 100644 --- a/cmd/nft-update.go +++ b/cmd/nft-update.go @@ -30,7 +30,7 @@ func init() { log.Printf("Configuration in use: %v", viper.AllSettings()) service.LoadConfig() - err := service.NFTUpdateSets(*table, *set4, *ip4, *set6, *ip6) + err := service.NFTUpdateSetsCmd(*table, *set4, *ip4, *set6, *ip6) if err != nil { log.Fatalf("Could not update NFT: %s", err) } diff --git a/service/nftables.go b/service/nftables.go index f67c824..986fc10 100644 --- a/service/nftables.go +++ b/service/nftables.go @@ -1,14 +1,81 @@ package service import ( + "encoding/json" "fmt" + "log" "net" "runtime" "github.com/google/nftables" ) -func NFTUpdateSets(tableName string, set4name string, ip4 net.IP, set6name string, ip6 net.IP) error { +type NFTUpdate struct { + Tables map[string]*NFTUpdateTable +} + +type NFTUpdateTable struct { + TableName string + Sets map[string]*NFTUpdateSet + done bool +} + +type NFTUpdateSet struct { + SetName string + IP6Set bool + IPs []net.IP + done bool +} + +func NewNFTUpdate() *NFTUpdate { + f := new(NFTUpdate) + f.Tables = make(map[string]*NFTUpdateTable) + return f +} + +func (nu *NFTUpdate) FindOrAddTable(TableName string) *NFTUpdateTable { + if nut, ok := nu.Tables[TableName]; ok { + return nut + } else { + nut := NFTUpdateTable{ + TableName: TableName, + Sets: make(map[string]*NFTUpdateSet), + done: false, + } + nu.Tables[TableName] = &nut + return &nut + } +} + +func (nut *NFTUpdateTable) FindOrAddSet(SetName string, IP6 bool) (*NFTUpdateSet, error) { + if nus, ok := nut.Sets[SetName]; ok { + if nus.IP6Set != IP6 { + return nil, fmt.Errorf("set %s has been already declared with IP6=%t", SetName, nus.IP6Set) + } + return nus, nil + } else { + nus := NFTUpdateSet{ + SetName: SetName, + IP6Set: IP6, + done: false, + } + nut.Sets[SetName] = &nus + return &nus, nil + } +} + +func (nu *NFTUpdate) AddIP(TableName string, SetName string, IP net.IP) error { + ip6 := (IP.To4() == nil) + nut := nu.FindOrAddTable(TableName) + nus, err := nut.FindOrAddSet(SetName, ip6) + if err != nil { + return fmt.Errorf("set %s in Table %T is of ipv6=%t, %v is not", SetName, TableName, nus.IP6Set, IP) + } + nus.IPs = append(nus.IPs, IP) + return nil +} + +func (upd *NFTUpdate) Process() error { runtime.LockOSThread() defer runtime.UnlockOSThread() @@ -19,64 +86,95 @@ func NFTUpdateSets(tableName string, set4name string, ip4 net.IP, set6name strin return fmt.Errorf("could not list NFT tables: %v", err) } - var table *nftables.Table = nil - for _, t := range tables { - if t.Name == tableName { - if table == nil { - table = t - } else { - return fmt.Errorf("found two tables with name %s", tableName) + for _, table := range tables { + if updTable, ok := upd.Tables[table.Name]; ok { + + for _, updSet := range updTable.Sets { + // Try to load set + set, err := conn.GetSetByName(table, updSet.SetName) + if err != nil { + return fmt.Errorf("could not find IPv4 NFT set %s: %v", updSet.SetName, err) + } + + // Validate set type + if updSet.IP6Set { + if set.KeyType.GetNFTMagic() != nftables.TypeIP6Addr.GetNFTMagic() { + return fmt.Errorf("the NFT set %s is not of type ip6", updSet.SetName) + } + } else { + if set.KeyType.GetNFTMagic() != nftables.TypeIPAddr.GetNFTMagic() { + return fmt.Errorf("the NFT set %s is not of type ip", updSet.SetName) + } + } + + // Flush the set so that we can start adding new IPs + conn.FlushSet(set) + + // Loop over all IPs to add + for _, ip := range updSet.IPs { + // Convert IP to binary representation and validate inbound IP type + var ipbin net.IP + if updSet.IP6Set { + ipbin = ip.To16() + } else { + ipbin = ip.To4() + } + if ipbin == nil { + return fmt.Errorf("ip %v not valid for Table %s, Set %s, IP6=%t", ip, table.Name, set.Name, updSet.IP6Set) + } + + // Add IP to the set + err = conn.SetAddElements(set, []nftables.SetElement{{Key: ipbin}}) + if err != nil { + return fmt.Errorf("failed to add ip %v to Table %s, Set %s, IP6=%t: %v", ip, table.Name, set.Name, updSet.IP6Set, err) + } + } + + updSet.done = true } - } - } - if table == nil { - return fmt.Errorf("could not find table %s", tableName) - } - if ip4 != nil { - ip4bin := ip4.To4() - if ip4bin == nil { - return fmt.Errorf("ipv4 must be a valid IPv4 address") - } - - set4, err := conn.GetSetByName(table, set4name) - if err != nil { - return fmt.Errorf("could not find IPv4 NFT set %s: %v", set4name, err) - } - if set4.KeyType.GetNFTMagic() != nftables.TypeIPAddr.GetNFTMagic() { - return fmt.Errorf("the NFT set %s is not of type ip", set4name) - } - - conn.FlushSet(set4) - - err = conn.SetAddElements(set4, []nftables.SetElement{{Key: ip4bin}}) - if err != nil { - return fmt.Errorf("failed to add IP %v to set %s: %v", ip4, set4name, err) + updTable.done = true } } - if ip6 != nil { - ip6bin := ip6.To16() - if ip6bin == nil { - return fmt.Errorf("ipv6 must be a valid IPv6 address") - } - - set6, err := conn.GetSetByName(table, set6name) - if err != nil { - return fmt.Errorf("could not find IPv6 NFT set %s: %v", set6name, err) - } - if set6.KeyType.GetNFTMagic() != nftables.TypeIP6Addr.GetNFTMagic() { - return fmt.Errorf("the NFT set %s is not of type ip6", set6name) - } - - conn.FlushSet(set6) - - err = conn.SetAddElements(set6, []nftables.SetElement{{Key: ip6bin}}) - if err != nil { - return fmt.Errorf("failed to add IP %v to set %s: %v", ip6, set6name, err) + // Check for unprocessed tables at this point. We use a NFT Table List as outer + // loop, as we cannot directly lookup an NFT Table by its name. So we need to + // check for inconsistencies here. Sets are looked up by their name, so we loop + // over our update structure here, so there is no need to check this. + for _, updTable := range upd.Tables { + if !updTable.done { + return fmt.Errorf("the table %s was not found while updating, aborting", updTable.TableName) } } + // Send the changes. conn.Flush() return nil } + +func (obj *NFTUpdate) PrettyPrint() string { + s, err := json.MarshalIndent(obj, "", " ") + if err != nil { + log.Fatalf("Failed to pretty print NFT Update Struct via JSON: %v", err) + } + return string(s) +} + +func NFTUpdateSetsCmd(tableName string, set4name string, ip4 net.IP, set6name string, ip6 net.IP) error { + upd := NewNFTUpdate() + if ip4 != nil { + if err := upd.AddIP(tableName, set4name, ip4); err != nil { + return fmt.Errorf("failed to add ip4 %v: %v", ip4, err) + } + } + if ip6 != nil { + if err := upd.AddIP(tableName, set6name, ip6); err != nil { + return fmt.Errorf("failed to add ip6 %v: %v", ip6, err) + } + } + if err := upd.Process(); err != nil { + return fmt.Errorf("failed to process NFT Updates: %v", err) + } + + return nil +}