Torben Nehmer
fbfd75bf99
- more flexibility through a struct-based approach - allows collecting all changes from a complex config - pushes everything at once to NFT
181 lines
4.5 KiB
Go
181 lines
4.5 KiB
Go
package service
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"runtime"
|
|
|
|
"github.com/google/nftables"
|
|
)
|
|
|
|
// NFTUpdate collects the desired contents of nftables IP sets, grouped by
// table name, so that all changes can be pushed in one go by Process.
type NFTUpdate struct {
	// Tables maps a table name to the pending set updates within that table.
	Tables map[string]*NFTUpdateTable
}

// NFTUpdateTable holds the pending set updates for a single nftables table.
type NFTUpdateTable struct {
	// TableName is the name of the nftables table.
	TableName string
	// Sets maps a set name to its pending update.
	Sets map[string]*NFTUpdateSet
	// done is set by Process once the table has been found and handled.
	done bool
}

// NFTUpdateSet holds the IPs that should make up the contents of a single
// nftables set after Process has run (Process flushes the set first).
type NFTUpdateSet struct {
	// SetName is the name of the nftables set.
	SetName string
	// IP6Set is true for a set of IPv6 addresses, false for IPv4.
	IP6Set bool
	// IPs are the addresses to insert into the set.
	IPs []net.IP
	// done is set by Process once the set has been handled.
	done bool
}

// NewNFTUpdate returns an empty NFTUpdate ready to collect IPs.
func NewNFTUpdate() *NFTUpdate {
	return &NFTUpdate{Tables: make(map[string]*NFTUpdateTable)}
}

// FindOrAddTable returns the pending update for the table named TableName,
// creating and registering an empty one if it does not exist yet.
func (nu *NFTUpdate) FindOrAddTable(TableName string) *NFTUpdateTable {
	if nut, ok := nu.Tables[TableName]; ok {
		return nut
	}
	// Lazily allocate so a zero-value NFTUpdate does not panic on a nil map.
	if nu.Tables == nil {
		nu.Tables = make(map[string]*NFTUpdateTable)
	}
	nut := &NFTUpdateTable{
		TableName: TableName,
		Sets:      make(map[string]*NFTUpdateSet),
	}
	nu.Tables[TableName] = nut
	return nut
}

// FindOrAddSet returns the pending update for the set named SetName, creating
// it for the given address family (IP6) if needed. It returns an error if the
// set was already declared with the other family.
func (nut *NFTUpdateTable) FindOrAddSet(SetName string, IP6 bool) (*NFTUpdateSet, error) {
	if nus, ok := nut.Sets[SetName]; ok {
		if nus.IP6Set != IP6 {
			return nil, fmt.Errorf("set %s has been already declared with IP6=%t", SetName, nus.IP6Set)
		}
		return nus, nil
	}
	// Lazily allocate so a zero-value NFTUpdateTable does not panic on a nil map.
	if nut.Sets == nil {
		nut.Sets = make(map[string]*NFTUpdateSet)
	}
	nus := &NFTUpdateSet{SetName: SetName, IP6Set: IP6}
	nut.Sets[SetName] = nus
	return nus, nil
}

// AddIP queues IP for insertion into set SetName of table TableName, creating
// the table and set entries as needed. The first IP added to a set fixes its
// address family. It returns an error, leaving the update unchanged, when IP
// is nil or not a valid IPv4/IPv6 address, or when its family does not match
// the family the set was already declared with.
func (nu *NFTUpdate) AddIP(TableName string, SetName string, IP net.IP) error {
	// Reject nil or malformed addresses up front: To4() would return nil for
	// them too, silently classifying them as IPv6.
	if IP.To16() == nil {
		return fmt.Errorf("invalid ip %v for set %s in table %s", IP, SetName, TableName)
	}
	ip6 := IP.To4() == nil

	nus, err := nu.FindOrAddTable(TableName).FindOrAddSet(SetName, ip6)
	if err != nil {
		// nus is nil on error, so all details must come from err and our inputs.
		return fmt.Errorf("set %s in table %s cannot take %v (ipv6=%t): %w", SetName, TableName, IP, ip6, err)
	}
	nus.IPs = append(nus.IPs, IP)
	return nil
}
|
|
|
|
func (upd *NFTUpdate) Process() error {
|
|
runtime.LockOSThread()
|
|
defer runtime.UnlockOSThread()
|
|
|
|
conn := &nftables.Conn{}
|
|
|
|
tables, err := conn.ListTables()
|
|
if err != nil {
|
|
return fmt.Errorf("could not list NFT tables: %v", err)
|
|
}
|
|
|
|
for _, table := range tables {
|
|
if updTable, ok := upd.Tables[table.Name]; ok {
|
|
|
|
for _, updSet := range updTable.Sets {
|
|
// Try to load set
|
|
set, err := conn.GetSetByName(table, updSet.SetName)
|
|
if err != nil {
|
|
return fmt.Errorf("could not find IPv4 NFT set %s: %v", updSet.SetName, err)
|
|
}
|
|
|
|
// Validate set type
|
|
if updSet.IP6Set {
|
|
if set.KeyType.GetNFTMagic() != nftables.TypeIP6Addr.GetNFTMagic() {
|
|
return fmt.Errorf("the NFT set %s is not of type ip6", updSet.SetName)
|
|
}
|
|
} else {
|
|
if set.KeyType.GetNFTMagic() != nftables.TypeIPAddr.GetNFTMagic() {
|
|
return fmt.Errorf("the NFT set %s is not of type ip", updSet.SetName)
|
|
}
|
|
}
|
|
|
|
// Flush the set so that we can start adding new IPs
|
|
conn.FlushSet(set)
|
|
|
|
// Loop over all IPs to add
|
|
for _, ip := range updSet.IPs {
|
|
// Convert IP to binary representation and validate inbound IP type
|
|
var ipbin net.IP
|
|
if updSet.IP6Set {
|
|
ipbin = ip.To16()
|
|
} else {
|
|
ipbin = ip.To4()
|
|
}
|
|
if ipbin == nil {
|
|
return fmt.Errorf("ip %v not valid for Table %s, Set %s, IP6=%t", ip, table.Name, set.Name, updSet.IP6Set)
|
|
}
|
|
|
|
// Add IP to the set
|
|
err = conn.SetAddElements(set, []nftables.SetElement{{Key: ipbin}})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to add ip %v to Table %s, Set %s, IP6=%t: %v", ip, table.Name, set.Name, updSet.IP6Set, err)
|
|
}
|
|
}
|
|
|
|
updSet.done = true
|
|
}
|
|
|
|
updTable.done = true
|
|
}
|
|
}
|
|
|
|
// Check for unprocessed tables at this point. We use a NFT Table List as outer
|
|
// loop, as we cannot directly lookup an NFT Table by its name. So we need to
|
|
// check for inconsistencies here. Sets are looked up by their name, so we loop
|
|
// over our update structure here, so there is no need to check this.
|
|
for _, updTable := range upd.Tables {
|
|
if !updTable.done {
|
|
return fmt.Errorf("the table %s was not found while updating, aborting", updTable.TableName)
|
|
}
|
|
}
|
|
|
|
// Send the changes.
|
|
conn.Flush()
|
|
return nil
|
|
}
|
|
|
|
func (obj *NFTUpdate) PrettyPrint() string {
|
|
s, err := json.MarshalIndent(obj, "", " ")
|
|
if err != nil {
|
|
log.Fatalf("Failed to pretty print NFT Update Struct via JSON: %v", err)
|
|
}
|
|
return string(s)
|
|
}
|
|
|
|
func NFTUpdateSetsCmd(tableName string, set4name string, ip4 net.IP, set6name string, ip6 net.IP) error {
|
|
upd := NewNFTUpdate()
|
|
if ip4 != nil {
|
|
if err := upd.AddIP(tableName, set4name, ip4); err != nil {
|
|
return fmt.Errorf("failed to add ip4 %v: %v", ip4, err)
|
|
}
|
|
}
|
|
if ip6 != nil {
|
|
if err := upd.AddIP(tableName, set6name, ip6); err != nil {
|
|
return fmt.Errorf("failed to add ip6 %v: %v", ip6, err)
|
|
}
|
|
}
|
|
if err := upd.Process(); err != nil {
|
|
return fmt.Errorf("failed to process NFT Updates: %v", err)
|
|
}
|
|
|
|
return nil
|
|
}
|