package service

import (
	"encoding/json"
	"fmt"
	"log"
	"net"
	"runtime"

	"github.com/google/nftables"
)

// NFTUpdate collects IP addresses to load into nftables sets, grouped by
// table name. Populate it with AddIP (or FindOrAddTable/FindOrAddSet) and
// apply it with Process.
type NFTUpdate struct {
	Tables map[string]*NFTUpdateTable
}

// NFTUpdateTable holds the pending set updates for one nftables table name.
type NFTUpdateTable struct {
	TableName string
	Sets      map[string]*NFTUpdateSet
	// done is set by Process once a kernel table with TableName was handled.
	done bool
}

// NFTUpdateSet holds the IPs one nftables set should contain after Process
// runs; Process flushes the set before adding them.
type NFTUpdateSet struct {
	SetName string
	// IP6Set is true for a set of IPv6 addresses, false for IPv4.
	IP6Set bool
	IPs    []net.IP
	// done is set by Process once the set was flushed and refilled.
	done bool
}

// NewNFTUpdate returns an empty NFTUpdate.
func NewNFTUpdate() *NFTUpdate {
	return &NFTUpdate{Tables: make(map[string]*NFTUpdateTable)}
}

// FindOrAddTable returns the pending update for TableName, creating an empty
// one on first use.
func (nu *NFTUpdate) FindOrAddTable(TableName string) *NFTUpdateTable {
	if nut, ok := nu.Tables[TableName]; ok {
		return nut
	}
	if nu.Tables == nil {
		// Allow use of a zero-value NFTUpdate instead of panicking on map write.
		nu.Tables = make(map[string]*NFTUpdateTable)
	}
	nut := &NFTUpdateTable{
		TableName: TableName,
		Sets:      make(map[string]*NFTUpdateSet),
	}
	nu.Tables[TableName] = nut
	return nut
}

// FindOrAddSet returns the pending update for SetName, creating it with the
// given address family on first use. It returns an error (and a nil set) if
// SetName was already declared with the other address family.
func (nut *NFTUpdateTable) FindOrAddSet(SetName string, IP6 bool) (*NFTUpdateSet, error) {
	if nus, ok := nut.Sets[SetName]; ok {
		if nus.IP6Set != IP6 {
			return nil, fmt.Errorf("set %s has been already declared with IP6=%t", SetName, nus.IP6Set)
		}
		return nus, nil
	}
	if nut.Sets == nil {
		// Allow use of a zero-value NFTUpdateTable instead of panicking on map write.
		nut.Sets = make(map[string]*NFTUpdateSet)
	}
	nus := &NFTUpdateSet{
		SetName: SetName,
		IP6Set:  IP6,
	}
	nut.Sets[SetName] = nus
	return nus, nil
}

// AddIP queues IP for set SetName in table TableName. The set's address family
// is taken from the first IP added to it; a nil IP is ignored.
func (nu *NFTUpdate) AddIP(TableName string, SetName string, IP net.IP) error {
	if IP == nil {
		return nil
	}
	ip6 := IP.To4() == nil
	nus, err := nu.FindOrAddTable(TableName).FindOrAddSet(SetName, ip6)
	if err != nil {
		// nus is nil on error; on a family mismatch the existing set is !ip6.
		return fmt.Errorf("set %s in Table %s is of ipv6=%t, %v is not: %w", SetName, TableName, !ip6, IP, err)
	}
	nus.IPs = append(nus.IPs, IP)
	return nil
}

// Process applies every pending update to the kernel's nftables ruleset.
//
// Tables are matched by name against the kernel's table list. For each
// matching table, every pending set is looked up by name, its key type is
// checked against IP6Set, the set is flushed and the pending IPs are added.
// Process fails if any pending table name is not found.
//
// FlushSet and SetAddElements only queue messages on conn; they are sent by
// the final conn.Flush(). An early error return should therefore leave the
// kernel sets unchanged.
// NOTE(review): relies on nftables.Conn batching semantics; confirm for the
// pinned library version.
// NOTE(review): tables are matched by name only, so same-named tables in
// several families (ip, ip6, inet, ...) are all updated.
func (nu *NFTUpdate) Process() error {
	// Keep all netlink work on one OS thread (presumably for network-namespace
	// consistency; confirm).
	runtime.LockOSThread()
	defer runtime.UnlockOSThread()

	conn := &nftables.Conn{}

	tables, err := conn.ListTables()
	if err != nil {
		return fmt.Errorf("could not list NFT tables: %w", err)
	}

	for _, table := range tables {
		updTable, ok := nu.Tables[table.Name]
		if !ok {
			continue
		}
		for _, updSet := range updTable.Sets {
			if err := updateSet(conn, table, updSet); err != nil {
				return err
			}
			updSet.done = true
		}
		updTable.done = true
	}

	// Check for unprocessed tables at this point. We use the NFT table list as
	// the outer loop because we cannot look up an NFT table by its name
	// directly, so missing tables must be detected here. Sets are looked up by
	// name inside updateSet, so a missing set is already reported there.
	for _, updTable := range nu.Tables {
		if !updTable.done {
			return fmt.Errorf("the table %s was not found while updating, aborting", updTable.TableName)
		}
	}

	// Send the queued changes; this is the only step that talks to the kernel
	// for writes, so its error must not be dropped.
	if err := conn.Flush(); err != nil {
		return fmt.Errorf("failed to send NFT set updates: %w", err)
	}
	return nil
}

// updateSet validates the kernel set matching updSet in table, then queues a
// flush of it followed by the addition of all of updSet's IPs.
func updateSet(conn *nftables.Conn, table *nftables.Table, updSet *NFTUpdateSet) error {
	set, err := conn.GetSetByName(table, updSet.SetName)
	if err != nil {
		return fmt.Errorf("could not find NFT set %s (IP6=%t): %w", updSet.SetName, updSet.IP6Set, err)
	}

	wantType, typeName := nftables.TypeIPAddr, "ip"
	if updSet.IP6Set {
		wantType, typeName = nftables.TypeIP6Addr, "ip6"
	}
	if set.KeyType.GetNFTMagic() != wantType.GetNFTMagic() {
		return fmt.Errorf("the NFT set %s is not of type %s", updSet.SetName, typeName)
	}

	// Convert every IP to the binary form matching the set's key length.
	elems := make([]nftables.SetElement, 0, len(updSet.IPs))
	for _, ip := range updSet.IPs {
		ipbin := ip.To4()
		if updSet.IP6Set {
			ipbin = ip.To16()
		}
		if ipbin == nil {
			return fmt.Errorf("ip %v not valid for Table %s, Set %s, IP6=%t", ip, table.Name, set.Name, updSet.IP6Set)
		}
		elems = append(elems, nftables.SetElement{Key: ipbin})
	}

	// Flush first so the set ends up containing exactly the pending IPs.
	conn.FlushSet(set)
	if len(elems) == 0 {
		return nil
	}
	if err := conn.SetAddElements(set, elems); err != nil {
		return fmt.Errorf("failed to add ips to Table %s, Set %s, IP6=%t: %w", table.Name, set.Name, updSet.IP6Set, err)
	}
	return nil
}

// PrettyPrint renders nu as indented JSON, for logging or debugging. Only
// exported fields appear; the internal done flags are omitted.
func (nu *NFTUpdate) PrettyPrint() string {
	s, err := json.MarshalIndent(nu, "", " ")
	if err != nil {
		// Marshalling these plain structs should not fail. If it does, do not
		// terminate the process from a debug helper; fall back to %+v instead.
		log.Printf("Failed to pretty print NFT Update Struct via JSON: %v", err)
		return fmt.Sprintf("%+v", nu)
	}
	return string(s)
}

// NFTUpdateSetsCmd replaces the contents of set set4name with ip4 and of set
// set6name with ip6, both in table tableName. A nil IP leaves its set out of
// the update entirely (it is neither flushed nor refilled).
func NFTUpdateSetsCmd(tableName string, set4name string, ip4 net.IP, set6name string, ip6 net.IP) error {
	upd := NewNFTUpdate()
	if ip4 != nil {
		if err := upd.AddIP(tableName, set4name, ip4); err != nil {
			return fmt.Errorf("failed to add ip4 %v: %w", ip4, err)
		}
	}
	if ip6 != nil {
		if err := upd.AddIP(tableName, set6name, ip6); err != nil {
			return fmt.Errorf("failed to add ip6 %v: %w", ip6, err)
		}
	}
	if err := upd.Process(); err != nil {
		return fmt.Errorf("failed to process NFT Updates: %w", err)
	}
	return nil
}