83 lines
1.9 KiB
Go
83 lines
1.9 KiB
Go
|
package service
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"runtime"
|
||
|
|
||
|
"github.com/google/nftables"
|
||
|
)
|
||
|
|
||
|
func UpdateNFTSets(tableName string, set4name string, ip4 net.IP, set6name string, ip6 net.IP) error {
|
||
|
runtime.LockOSThread()
|
||
|
defer runtime.UnlockOSThread()
|
||
|
|
||
|
conn := &nftables.Conn{}
|
||
|
|
||
|
tables, err := conn.ListTables()
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not list NFT tables: %v", err)
|
||
|
}
|
||
|
|
||
|
var table *nftables.Table = nil
|
||
|
for _, t := range tables {
|
||
|
if t.Name == tableName {
|
||
|
if table == nil {
|
||
|
table = t
|
||
|
} else {
|
||
|
return fmt.Errorf("found two tables with name %s", tableName)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
if table == nil {
|
||
|
return fmt.Errorf("could not find table %s", tableName)
|
||
|
}
|
||
|
|
||
|
if ip4 != nil {
|
||
|
ip4bin := ip4.To4()
|
||
|
if ip4bin == nil {
|
||
|
return fmt.Errorf("ipv4 must be a valid IPv4 address")
|
||
|
}
|
||
|
|
||
|
set4, err := conn.GetSetByName(table, set4name)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not find IPv4 NFT set %s: %v", set4name, err)
|
||
|
}
|
||
|
if set4.KeyType.GetNFTMagic() != nftables.TypeIPAddr.GetNFTMagic() {
|
||
|
return fmt.Errorf("the NFT set %s is not of type ip", set4name)
|
||
|
}
|
||
|
|
||
|
conn.FlushSet(set4)
|
||
|
|
||
|
err = conn.SetAddElements(set4, []nftables.SetElement{{Key: ip4bin}})
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to add IP %v to set %s: %v", ip4, set4name, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if ip6 != nil {
|
||
|
ip6bin := ip6.To16()
|
||
|
if ip6bin == nil {
|
||
|
return fmt.Errorf("ipv6 must be a valid IPv6 address")
|
||
|
}
|
||
|
|
||
|
set6, err := conn.GetSetByName(table, set6name)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not find IPv6 NFT set %s: %v", set6name, err)
|
||
|
}
|
||
|
if set6.KeyType.GetNFTMagic() != nftables.TypeIP6Addr.GetNFTMagic() {
|
||
|
return fmt.Errorf("the NFT set %s is not of type ip6", set6name)
|
||
|
}
|
||
|
|
||
|
conn.FlushSet(set6)
|
||
|
|
||
|
err = conn.SetAddElements(set6, []nftables.SetElement{{Key: ip6bin}})
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to add IP %v to set %s: %v", ip6, set6name, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
conn.Flush()
|
||
|
return nil
|
||
|
}
|