Rewrote entire NFT Update Logic

- more flexibility by a struct based approch
- allows for collection of all changes from a complex confi
- pushes everything at once to NFT
This commit is contained in:
Torben Nehmer 2021-08-22 20:45:36 +02:00
parent 5242cc98f7
commit fbfd75bf99
2 changed files with 150 additions and 52 deletions

View File

@ -30,7 +30,7 @@ func init() {
log.Printf("Configuration in use: %v", viper.AllSettings()) log.Printf("Configuration in use: %v", viper.AllSettings())
service.LoadConfig() service.LoadConfig()
err := service.NFTUpdateSets(*table, *set4, *ip4, *set6, *ip6) err := service.NFTUpdateSetsCmd(*table, *set4, *ip4, *set6, *ip6)
if err != nil { if err != nil {
log.Fatalf("Could not update NFT: %s", err) log.Fatalf("Could not update NFT: %s", err)
} }

View File

@ -1,14 +1,81 @@
package service package service
import ( import (
"encoding/json"
"fmt" "fmt"
"log"
"net" "net"
"runtime" "runtime"
"github.com/google/nftables" "github.com/google/nftables"
) )
func NFTUpdateSets(tableName string, set4name string, ip4 net.IP, set6name string, ip6 net.IP) error { type NFTUpdate struct {
Tables map[string]*NFTUpdateTable
}
type NFTUpdateTable struct {
TableName string
Sets map[string]*NFTUpdateSet
done bool
}
type NFTUpdateSet struct {
SetName string
IP6Set bool
IPs []net.IP
done bool
}
func NewNFTUpdate() *NFTUpdate {
f := new(NFTUpdate)
f.Tables = make(map[string]*NFTUpdateTable)
return f
}
func (nu *NFTUpdate) FindOrAddTable(TableName string) *NFTUpdateTable {
if nut, ok := nu.Tables[TableName]; ok {
return nut
} else {
nut := NFTUpdateTable{
TableName: TableName,
Sets: make(map[string]*NFTUpdateSet),
done: false,
}
nu.Tables[TableName] = &nut
return &nut
}
}
func (nut *NFTUpdateTable) FindOrAddSet(SetName string, IP6 bool) (*NFTUpdateSet, error) {
if nus, ok := nut.Sets[SetName]; ok {
if nus.IP6Set != IP6 {
return nil, fmt.Errorf("set %s has been already declared with IP6=%t", SetName, nus.IP6Set)
}
return nus, nil
} else {
nus := NFTUpdateSet{
SetName: SetName,
IP6Set: IP6,
done: false,
}
nut.Sets[SetName] = &nus
return &nus, nil
}
}
func (nu *NFTUpdate) AddIP(TableName string, SetName string, IP net.IP) error {
ip6 := (IP.To4() == nil)
nut := nu.FindOrAddTable(TableName)
nus, err := nut.FindOrAddSet(SetName, ip6)
if err != nil {
return fmt.Errorf("set %s in Table %T is of ipv6=%t, %v is not", SetName, TableName, nus.IP6Set, IP)
}
nus.IPs = append(nus.IPs, IP)
return nil
}
func (upd *NFTUpdate) Process() error {
runtime.LockOSThread() runtime.LockOSThread()
defer runtime.UnlockOSThread() defer runtime.UnlockOSThread()
@ -19,64 +86,95 @@ func NFTUpdateSets(tableName string, set4name string, ip4 net.IP, set6name strin
return fmt.Errorf("could not list NFT tables: %v", err) return fmt.Errorf("could not list NFT tables: %v", err)
} }
var table *nftables.Table = nil for _, table := range tables {
for _, t := range tables { if updTable, ok := upd.Tables[table.Name]; ok {
if t.Name == tableName {
if table == nil { for _, updSet := range updTable.Sets {
table = t // Try to load set
set, err := conn.GetSetByName(table, updSet.SetName)
if err != nil {
return fmt.Errorf("could not find IPv4 NFT set %s: %v", updSet.SetName, err)
}
// Validate set type
if updSet.IP6Set {
if set.KeyType.GetNFTMagic() != nftables.TypeIP6Addr.GetNFTMagic() {
return fmt.Errorf("the NFT set %s is not of type ip6", updSet.SetName)
}
} else { } else {
return fmt.Errorf("found two tables with name %s", tableName) if set.KeyType.GetNFTMagic() != nftables.TypeIPAddr.GetNFTMagic() {
return fmt.Errorf("the NFT set %s is not of type ip", updSet.SetName)
} }
} }
}
if table == nil {
return fmt.Errorf("could not find table %s", tableName)
}
if ip4 != nil { // Flush the set so that we can start adding new IPs
ip4bin := ip4.To4() conn.FlushSet(set)
if ip4bin == nil {
return fmt.Errorf("ipv4 must be a valid IPv4 address") // Loop over all IPs to add
for _, ip := range updSet.IPs {
// Convert IP to binary representation and validate inbound IP type
var ipbin net.IP
if updSet.IP6Set {
ipbin = ip.To16()
} else {
ipbin = ip.To4()
}
if ipbin == nil {
return fmt.Errorf("ip %v not valid for Table %s, Set %s, IP6=%t", ip, table.Name, set.Name, updSet.IP6Set)
} }
set4, err := conn.GetSetByName(table, set4name) // Add IP to the set
err = conn.SetAddElements(set, []nftables.SetElement{{Key: ipbin}})
if err != nil { if err != nil {
return fmt.Errorf("could not find IPv4 NFT set %s: %v", set4name, err) return fmt.Errorf("failed to add ip %v to Table %s, Set %s, IP6=%t: %v", ip, table.Name, set.Name, updSet.IP6Set, err)
}
if set4.KeyType.GetNFTMagic() != nftables.TypeIPAddr.GetNFTMagic() {
return fmt.Errorf("the NFT set %s is not of type ip", set4name)
}
conn.FlushSet(set4)
err = conn.SetAddElements(set4, []nftables.SetElement{{Key: ip4bin}})
if err != nil {
return fmt.Errorf("failed to add IP %v to set %s: %v", ip4, set4name, err)
} }
} }
if ip6 != nil { updSet.done = true
ip6bin := ip6.To16()
if ip6bin == nil {
return fmt.Errorf("ipv6 must be a valid IPv6 address")
} }
set6, err := conn.GetSetByName(table, set6name) updTable.done = true
if err != nil {
return fmt.Errorf("could not find IPv6 NFT set %s: %v", set6name, err)
}
if set6.KeyType.GetNFTMagic() != nftables.TypeIP6Addr.GetNFTMagic() {
return fmt.Errorf("the NFT set %s is not of type ip6", set6name)
}
conn.FlushSet(set6)
err = conn.SetAddElements(set6, []nftables.SetElement{{Key: ip6bin}})
if err != nil {
return fmt.Errorf("failed to add IP %v to set %s: %v", ip6, set6name, err)
} }
} }
// Check for unprocessed tables at this point. We use a NFT Table List as outer
// loop, as we cannot directly lookup an NFT Table by its name. So we need to
// check for inconsistencies here. Sets are looked up by their name, so we loop
// over our update structure here, so there is no need to check this.
for _, updTable := range upd.Tables {
if !updTable.done {
return fmt.Errorf("the table %s was not found while updating, aborting", updTable.TableName)
}
}
// Send the changes.
conn.Flush() conn.Flush()
return nil return nil
} }
func (obj *NFTUpdate) PrettyPrint() string {
s, err := json.MarshalIndent(obj, "", " ")
if err != nil {
log.Fatalf("Failed to pretty print NFT Update Struct via JSON: %v", err)
}
return string(s)
}
func NFTUpdateSetsCmd(tableName string, set4name string, ip4 net.IP, set6name string, ip6 net.IP) error {
upd := NewNFTUpdate()
if ip4 != nil {
if err := upd.AddIP(tableName, set4name, ip4); err != nil {
return fmt.Errorf("failed to add ip4 %v: %v", ip4, err)
}
}
if ip6 != nil {
if err := upd.AddIP(tableName, set6name, ip6); err != nil {
return fmt.Errorf("failed to add ip6 %v: %v", ip6, err)
}
}
if err := upd.Process(); err != nil {
return fmt.Errorf("failed to process NFT Updates: %v", err)
}
return nil
}