package service import ( "fmt" "net" "runtime" "github.com/google/nftables" ) func UpdateNFTSets(tableName string, set4name string, ip4 net.IP, set6name string, ip6 net.IP) error { runtime.LockOSThread() defer runtime.UnlockOSThread() conn := &nftables.Conn{} tables, err := conn.ListTables() if err != nil { return fmt.Errorf("could not list NFT tables: %v", err) } var table *nftables.Table = nil for _, t := range tables { if t.Name == tableName { if table == nil { table = t } else { return fmt.Errorf("found two tables with name %s", tableName) } } } if table == nil { return fmt.Errorf("could not find table %s", tableName) } if ip4 != nil { ip4bin := ip4.To4() if ip4bin == nil { return fmt.Errorf("ipv4 must be a valid IPv4 address") } set4, err := conn.GetSetByName(table, set4name) if err != nil { return fmt.Errorf("could not find IPv4 NFT set %s: %v", set4name, err) } if set4.KeyType.GetNFTMagic() != nftables.TypeIPAddr.GetNFTMagic() { return fmt.Errorf("the NFT set %s is not of type ip", set4name) } conn.FlushSet(set4) err = conn.SetAddElements(set4, []nftables.SetElement{{Key: ip4bin}}) if err != nil { return fmt.Errorf("failed to add IP %v to set %s: %v", ip4, set4name, err) } } if ip6 != nil { ip6bin := ip6.To16() if ip6bin == nil { return fmt.Errorf("ipv6 must be a valid IPv6 address") } set6, err := conn.GetSetByName(table, set6name) if err != nil { return fmt.Errorf("could not find IPv6 NFT set %s: %v", set6name, err) } if set6.KeyType.GetNFTMagic() != nftables.TypeIP6Addr.GetNFTMagic() { return fmt.Errorf("the NFT set %s is not of type ip6", set6name) } conn.FlushSet(set6) err = conn.SetAddElements(set6, []nftables.SetElement{{Key: ip6bin}}) if err != nil { return fmt.Errorf("failed to add IP %v to set %s: %v", ip6, set6name, err) } } conn.Flush() return nil }