From b2e4ea024663f3abdfc0704f626304625677a553 Mon Sep 17 00:00:00 2001 From: Alexander Cueva Date: Tue, 24 Jun 2025 15:06:10 -0700 Subject: [PATCH] Netlink_Netfilter process message implementation for adding and getting tables. Added an Nftables field to the stack and functionality and error-checking for creating a new NFTable (NFT_MSG_NEWTABLE) and retrieving that table (NFT_MSG_GETTABLE). Removed hostinet tests, as nftables is only supported for netstack. PiperOrigin-RevId: 775398315 --- pkg/abi/linux/netlink_netfilter.go | 2 + pkg/abi/linux/nf_tables.go | 23 ++ pkg/sentry/socket/netlink/netfilter/BUILD | 4 + .../socket/netlink/netfilter/protocol.go | 110 +++++- pkg/tcpip/nftables/nftables.go | 19 +- pkg/tcpip/nftables/nftables_test.go | 30 +- pkg/tcpip/nftables/nftables_types.go | 3 - pkg/tcpip/stack/nftables_types.go | 6 +- pkg/tcpip/stack/stack.go | 18 + runsc/boot/BUILD | 1 + runsc/boot/loader.go | 5 + runsc/cli/main.go | 3 +- test/syscalls/BUILD | 1 - test/syscalls/linux/BUILD | 10 + .../linux/socket_netlink_netfilter.cc | 353 +++++++++++++++++- .../linux/socket_netlink_netfilter_util.cc | 31 ++ .../linux/socket_netlink_netfilter_util.h | 35 ++ test/syscalls/linux/socket_netlink_util.cc | 49 +++ test/syscalls/linux/socket_netlink_util.h | 19 + test/util/BUILD | 11 +- test/util/socket_util.h | 4 + 21 files changed, 690 insertions(+), 47 deletions(-) create mode 100644 test/syscalls/linux/socket_netlink_netfilter_util.cc create mode 100644 test/syscalls/linux/socket_netlink_netfilter_util.h diff --git a/pkg/abi/linux/netlink_netfilter.go b/pkg/abi/linux/netlink_netfilter.go index 3fb95baf93..646e193d3e 100644 --- a/pkg/abi/linux/netlink_netfilter.go +++ b/pkg/abi/linux/netlink_netfilter.go @@ -35,6 +35,8 @@ const ( ) // NetFilterGenMsg describes the netlink netfilter genmsg message, from uapi/linux/netfilter/nfnetlink.h. +// +// +marshal type NetFilterGenMsg struct { Family uint8 Version uint8 diff --git a/pkg/abi/linux/nf_tables.go b/pkg/abi/linux/nf_tables.go index 17c14bdc77..a18d5fd549 100644 --- a/pkg/abi/linux/nf_tables.go +++ b/pkg/abi/linux/nf_tables.go @@ -124,6 +124,29 @@ const ( NFT_MSG_MAX ) +// NfTableFlags represents table flags that can be set for a table, namely dormant. +// These correspond to values in include/uapi/linux/netfilter/nf_tables.h. +const ( + NFT_TABLE_F_DORMANT = 0x1 +) + +// NfTableAttributes represents the netfilter table attributes. +// These correspond to values in include/uapi/linux/netfilter/nf_tables.h. +const ( + NFTA_TABLE_UNSPEC uint16 = iota + NFTA_TABLE_NAME + NFTA_TABLE_FLAGS + NFTA_TABLE_USE + NFTA_TABLE_HANDLE + NFTA_TABLE_PAD + NFTA_TABLE_USERDATA + NFTA_TABLE_OWNER + __NFTA_TABLE_MAX +) + +// NFTA_TABLE_MAX is the maximum netfilter table attribute. +const NFTA_TABLE_MAX = __NFTA_TABLE_MAX - 1 + // Nf table relational operators. // Used by the nft comparison operation to compare values in registers. // These correspond to enum values in include/uapi/linux/netfilter/nf_tables.h. diff --git a/pkg/sentry/socket/netlink/netfilter/BUILD b/pkg/sentry/socket/netlink/netfilter/BUILD index 0149f3b593..320496c7ac 100644 --- a/pkg/sentry/socket/netlink/netfilter/BUILD +++ b/pkg/sentry/socket/netlink/netfilter/BUILD @@ -12,10 +12,14 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/context", + "//pkg/log", + "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/socket/netlink", "//pkg/sentry/socket/netlink/nlmsg", + "//pkg/sentry/socket/netstack", "//pkg/syserr", "//pkg/tcpip/nftables", + "//pkg/tcpip/stack", ], ) diff --git a/pkg/sentry/socket/netlink/netfilter/protocol.go b/pkg/sentry/socket/netlink/netfilter/protocol.go index 306adbdbc8..552cad2652 100644 --- a/pkg/sentry/socket/netlink/netfilter/protocol.go +++ b/pkg/sentry/socket/netlink/netfilter/protocol.go @@ -18,11 +18,15 @@ package netfilter import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/netlink" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/nlmsg" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip/nftables" + "gvisor.dev/gvisor/pkg/tcpip/stack" ) // Protocol implements netlink.Protocol. @@ -60,22 +64,122 @@ func (p *Protocol) ProcessMessage(ctx context.Context, s *netlink.Socket, msg *n // Netlink message payloads must be of at least the size of the genmsg. Return early if it is not, // from linux/net/netfilter/nfnetlink.c. if netLinkMessagePayloadSize(&hdr) < linux.SizeOfNetfilterGenMsg { + log.Debugf("Netlink message payload is too small: %d < %d", netLinkMessagePayloadSize(&hdr), linux.SizeOfNetfilterGenMsg) return nil } msgType := hdr.NetFilterMsgType() + st := inet.StackFromContext(ctx).(*netstack.Stack).Stack + nft := (st.NFTables()).(*nftables.NFTables) + var nfGenMsg linux.NetFilterGenMsg + + // The payload of a message is its attributes. + atr, ok := msg.GetData(&nfGenMsg) + if !ok { + log.Debugf("Failed to get message data") + return syserr.ErrInvalidArgument + } + + attrs, ok := atr.Parse() + if !ok { + log.Debugf("Failed to parse message attributes") + return syserr.ErrInvalidArgument + } + + // Nftables functions error check the address family value. + family := stack.AddressFamily(nfGenMsg.Family) // TODO: b/421437663 - Match the message type and call the appropriate Nftables function. switch msgType { + case linux.NFT_MSG_NEWTABLE: + return p.newTable(nft, attrs, family, hdr.Flags) + case linux.NFT_MSG_GETTABLE: + return p.getTable(nft, attrs, family, hdr.Flags, ms) default: + log.Debugf("Unsupported message type: %d", msgType) return syserr.ErrInvalidArgument } } -// init registers the NETLINK_NETFILTER provider. -func init() { - netlink.RegisterProvider(linux.NETLINK_NETFILTER, NewProtocol) +// newTable creates a new table for the given family. +func (p *Protocol) newTable(nft *nftables.NFTables, attrs map[uint16]nlmsg.BytesView, family stack.AddressFamily, flags uint16) *syserr.Error { + // TODO: b/421437663 - Handle the case where the table name is set to empty string. + // The table name is required. + tabNameBytes, ok := attrs[linux.NFTA_TABLE_NAME] + if !ok { + log.Debugf("Nftables: Table name attribute is malformed or not found") + return syserr.ErrInvalidArgument + } + + var dormant bool + if dbytes, ok := attrs[linux.NFTA_TABLE_FLAGS]; ok { + dflag, _ := dbytes.Uint32() + dormant = (dflag & linux.NFT_TABLE_F_DORMANT) == linux.NFT_TABLE_F_DORMANT + } + + tab, err := nft.GetTable(family, tabNameBytes.String()) + + // If a table already exists, only update its dormant flags if NLM_F_EXCL and NLM_F_REPLACE + // are not set. From net/netfilter/nf_tables_api.c:nf_tables_newtable:nf_tables_updtable + if tab != nil && err == nil { + if flags&linux.NLM_F_EXCL == linux.NLM_F_EXCL { + log.Debugf("Nftables: Table with name: %s already exists", tabNameBytes.String()) + return syserr.ErrExists + } + + if flags&linux.NLM_F_REPLACE == linux.NLM_F_REPLACE { + log.Debugf("Nftables: Table with name: %s already exists and NLM_F_REPLACE is not supported", tabNameBytes.String()) + return syserr.ErrNotSupported + } + } else { + // There does not seem to be a way to add comments to a table using the nft binary. + tab, err = nft.CreateTable(family, tabNameBytes.String()) + if err != nil { + log.Debugf("Nftables: Failed to create table with name: %s. Error: %s", tabNameBytes.String(), err.Error()) + // If there is an error, it is not a duplicate error (checked above). + return syserr.ErrInvalidArgument + } + } + + tab.SetDormant(dormant) + return nil +} + +// getTable returns a table for the given family. Returns nil on success and +// a sys.error on failure. +func (p *Protocol) getTable(nft *nftables.NFTables, attrs map[uint16]nlmsg.BytesView, family stack.AddressFamily, flags uint16, ms *nlmsg.MessageSet) *syserr.Error { + // The table name is required. + tabNameBytes, ok := attrs[linux.NFTA_TABLE_NAME] + if !ok { + log.Debugf("Nftables: Table name attribute is malformed or not found") + return syserr.ErrInvalidArgument + } + + tab, err := nft.GetTable(family, tabNameBytes.String()) + if err != nil { + log.Debugf("Nftables: ENOENT for table with name: %s", tabNameBytes.String()) + return syserr.ErrNoFileOrDir + } + + tabName := tab.GetName() + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: uint16(linux.NFNL_SUBSYS_NFTABLES)<<8 | uint16(linux.NFT_MSG_GETTABLE), + }) + + m.Put(&linux.NetFilterGenMsg{ + Family: uint8(family), + Version: uint8(linux.NFNETLINK_V0), + // Unused, set to 0. + ResourceID: uint16(0), + }) + m.PutAttrString(linux.NFTA_TABLE_NAME, tabName) + return nil } func netLinkMessagePayloadSize(h *linux.NetlinkMessageHeader) int { return int(h.Length) - linux.NetlinkMessageHeaderSize } + +// init registers the NETLINK_NETFILTER provider. +func init() { + netlink.RegisterProvider(linux.NETLINK_NETFILTER, NewProtocol) +} diff --git a/pkg/tcpip/nftables/nftables.go b/pkg/tcpip/nftables/nftables.go index d4d8670995..7274db093b 100644 --- a/pkg/tcpip/nftables/nftables.go +++ b/pkg/tcpip/nftables/nftables.go @@ -24,6 +24,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) +// TODO: b/421437663 - Refactor functions to return a POSIX syserr + // // Interface-Related Methods // @@ -291,7 +293,7 @@ func (nf *NFTables) GetTable(family stack.AddressFamily, tableName string) (*Tab // Note: if the table already exists, the existing table is returned without any // modifications. // Note: Table initialized as not dormant. -func (nf *NFTables) AddTable(family stack.AddressFamily, name string, comment string, +func (nf *NFTables) AddTable(family stack.AddressFamily, name string, errorOnDuplicate bool) (*Table, error) { // Ensures address family is valid. if err := validateAddressFamily(family); err != nil { @@ -325,7 +327,6 @@ func (nf *NFTables) AddTable(family stack.AddressFamily, name string, comment st name: name, afFilter: nf.filters[family], chains: make(map[string]*Chain), - comment: comment, flagSet: make(map[TableFlag]struct{}), } tableMap[name] = t @@ -337,8 +338,8 @@ func (nf *NFTables) AddTable(family stack.AddressFamily, name string, comment st // but also returns an error if a table by the same name already exists. // Note: this interface mirrors the difference between the create and add // commands within the nft binary. -func (nf *NFTables) CreateTable(family stack.AddressFamily, name string, comment string) (*Table, error) { - return nf.AddTable(family, name, comment, true) +func (nf *NFTables) CreateTable(family stack.AddressFamily, name string) (*Table, error) { + return nf.AddTable(family, name, true) } // DeleteTable deletes the specified table from the NFTables object returning @@ -436,16 +437,6 @@ func (t *Table) GetAddressFamily() stack.AddressFamily { return t.afFilter.family } -// GetComment returns the comment of the table. -func (t *Table) GetComment() string { - return t.comment -} - -// SetComment sets the comment of the table. -func (t *Table) SetComment(comment string) { - t.comment = comment -} - // IsDormant returns whether the table is dormant. func (t *Table) IsDormant() bool { _, dormant := t.flagSet[TableFlagDormant] diff --git a/pkg/tcpip/nftables/nftables_test.go b/pkg/tcpip/nftables/nftables_test.go index 1749927657..8a893e7c65 100644 --- a/pkg/tcpip/nftables/nftables_test.go +++ b/pkg/tcpip/nftables/nftables_test.go @@ -528,7 +528,7 @@ func TestEvaluateImmediateVerdict(t *testing.T) { // Sets up an NFTables object with a base chain (for 2 rules) and another // target chain (for 1 rule). nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -592,7 +592,7 @@ func TestEvaluateImmediateBytesData(t *testing.T) { t.Run(tname, func(t *testing.T) { // Sets up an NFTables object with a base chain with policy accept. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -1078,7 +1078,7 @@ func TestEvaluateComparison(t *testing.T) { t.Run(test.tname, func(t *testing.T) { // Sets up an NFTables object with a single table, chain, and rule. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -1315,7 +1315,7 @@ func TestEvaluateRanged(t *testing.T) { t.Run(test.tname, func(t *testing.T) { // Sets up an NFTables object with a single table, chain, and rule. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -1549,7 +1549,7 @@ func TestEvaluatePayloadLoad(t *testing.T) { t.Run(test.tname, func(t *testing.T) { // Sets up an NFTables object with a single table, chain, and rule. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -2053,7 +2053,7 @@ func TestEvaluatePayloadSet(t *testing.T) { t.Run(test.tname, func(t *testing.T) { // Sets up an NFTables object with a single table, chain, and rule. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -2238,7 +2238,7 @@ func TestEvaluateBitwise(t *testing.T) { t.Run(test.tname, func(t *testing.T) { // Sets up an NFTables object with a single table, chain, and rule. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -2304,7 +2304,7 @@ func TestEvaluateCounter(t *testing.T) { t.Run("counter increment tests", func(t *testing.T) { // Sets up an NFTables object with a base chain with policy accept. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -2372,7 +2372,7 @@ func TestEvaluateLast(t *testing.T) { fakeClock := faketime.NewManualClock() fixedRNG := rand.RNGFrom(&fixedReader{}) nf := NewNFTables(fakeClock, fixedRNG) - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -2499,7 +2499,7 @@ func TestEvaluateRoute(t *testing.T) { t.Run(test.tname, func(t *testing.T) { // Sets up an NFTables object with a single table, chain, and rule. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -2740,7 +2740,7 @@ func TestEvaluateByteorder(t *testing.T) { t.Run(test.tname, func(t *testing.T) { // Sets up an NFTables object with a single table, chain, and rule. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -2922,7 +2922,7 @@ func TestEvaluateMetaLoad(t *testing.T) { // Using Manual Clock sets time.Now to Unix Epoch which fixes rng seed! nf := NewNFTables(fakeClock, rand.RNGFrom(&fixedReader{})) - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -3012,7 +3012,7 @@ func TestEvaluateMetaSet(t *testing.T) { t.Run(test.tname, func(t *testing.T) { // Sets up an NFTables object with a single table, chain, and rule. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -3470,7 +3470,7 @@ func TestLoopCheckOnRegisterAndUnregister(t *testing.T) { t.Run(test.tname, func(t *testing.T) { // Sets up an NFTables object based on test struct. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } @@ -3581,7 +3581,7 @@ func TestMaxNestedJumps(t *testing.T) { t.Run(test.tname, func(t *testing.T) { // Sets up chains of nested jumps or gotos. nf := newNFTablesStd() - tab, err := nf.AddTable(arbitraryFamily, "test", "test table", false) + tab, err := nf.AddTable(arbitraryFamily, "test", false) if err != nil { t.Fatalf("unexpected error for AddTable: %v", err) } diff --git a/pkg/tcpip/nftables/nftables_types.go b/pkg/tcpip/nftables/nftables_types.go index f3d0a6d215..e13828071b 100644 --- a/pkg/tcpip/nftables/nftables_types.go +++ b/pkg/tcpip/nftables/nftables_types.go @@ -178,9 +178,6 @@ type Table struct { // flags is the set of optional flags for the table. // Note: currently nftables only has the single Dormant flag. flagSet map[TableFlag]struct{} - - // comment is the optional comment for the table. - comment string } // hookFunctionStack represents the list of base chains for a specific hook. diff --git a/pkg/tcpip/stack/nftables_types.go b/pkg/tcpip/stack/nftables_types.go index 9926b1788a..5856c587fd 100644 --- a/pkg/tcpip/stack/nftables_types.go +++ b/pkg/tcpip/stack/nftables_types.go @@ -117,8 +117,8 @@ const ( NumAFs ) -// addressFamilyStrings maps address families to their string representation. -var addressFamilyStrings = map[AddressFamily]string{ +// AddressFamilyStrings maps address families to their string representation. +var AddressFamilyStrings = map[AddressFamily]string{ IP: "IPv4", IP6: "IPv6", Inet: "Internet (Both IPv4/IPv6)", @@ -137,7 +137,7 @@ func ValidateAddressFamily(family AddressFamily) error { // String for AddressFamily returns the name of the address family. func (f AddressFamily) String() string { - if af, ok := addressFamilyStrings[f]; ok { + if af, ok := AddressFamilyStrings[f]; ok { return af } panic(fmt.Sprintf("invalid address family: %d", int(f))) diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 2d24becc98..d9de59daab 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -118,6 +118,9 @@ type Stack struct { // TODO(gvisor.dev/issue/4595): S/R this field. tables *IPTables `state:"nosave"` + // nftables is the nftables interface for packet filtering and manipulation rules. + nftables NFTablesInterface `state:"nosave"` + // restoredEndpoints is a list of endpoints that need to be restored if the // stack is being restored. restoredEndpoints []RestoredEndpoint @@ -238,6 +241,9 @@ type Options struct { // all traffic. IPTables *IPTables + // NFTables is the nftables interface for packet filtering and manipulation rules. + NFTables NFTablesInterface + // DefaultIPTables is an optional iptables rules constructor that is called // if IPTables is nil. If both fields are nil, iptables will allow all // traffic. @@ -390,6 +396,7 @@ func New(opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, tables: opts.IPTables, + nftables: opts.NFTables, icmpRateLimiter: NewICMPRateLimiter(clock), seed: secureRNG.Uint32(), nudConfigs: opts.NUDConfigs, @@ -1994,6 +2001,7 @@ func (s *Stack) ReplaceConfig(st *Stack) { _ = s.NextNICID() } s.tables = st.tables + s.nftables = st.nftables } // Restore restarts the stack after a restore. This must be called after the @@ -2188,6 +2196,16 @@ func (s *Stack) IPTables() *IPTables { return s.tables } +// NFTables returns the stack's nftables. +func (s *Stack) NFTables() NFTablesInterface { + return s.nftables +} + +// SetNFTables sets the stack's nftables. +func (s *Stack) SetNFTables(nft NFTablesInterface) { + s.nftables = nft +} + // ICMPLimit returns the maximum number of ICMP messages that can be sent // in one second. func (s *Stack) ICMPLimit() rate.Limit { diff --git a/runsc/boot/BUILD b/runsc/boot/BUILD index 30d6c7e5e5..c4390cbbf9 100644 --- a/runsc/boot/BUILD +++ b/runsc/boot/BUILD @@ -121,6 +121,7 @@ go_library( "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/nftables", "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/raw", diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index bc8df90bd9..ff18ba346b 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -71,6 +71,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/nftables" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/raw" @@ -1595,6 +1596,10 @@ func newEmptySandboxNetworkStack(clock tcpip.Clock, allowPacketEndpointWrite boo DefaultIPTables: netfilter.DefaultLinuxTables, })} + if nftables.IsNFTablesEnabled() { + s.Stack.SetNFTables(nftables.NewNFTables(clock, s.Stack.SecureRNG())) + } + // Enable SACK Recovery. { opt := tcpip.TCPSACKEnabled(true) diff --git a/runsc/cli/main.go b/runsc/cli/main.go index 313f6c2792..352a449f56 100644 --- a/runsc/cli/main.go +++ b/runsc/cli/main.go @@ -94,7 +94,8 @@ func Main() { util.Fatalf("%s", err.Error()) } - if conf.Nftables { + // NFtables is only supported for netstack. + if (conf.Network == config.NetworkNone || conf.Network == config.NetworkSandbox) && conf.Nftables { nftables.EnableNFTables() } diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 11a2f05691..a7565d27a2 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -938,7 +938,6 @@ syscall_test( ) syscall_test( - add_hostinet = True, nftables = True, test = "//test/syscalls/linux:socket_netlink_netfilter_test", ) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 14259293a6..3fc93f08a9 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -137,7 +137,9 @@ cc_library( deps = [ "//test/util:file_descriptor", "//test/util:posix_error", + "//test/util:save_util", "//test/util:socket_util", + "//test/util:test_util", "@com_google_absl//absl/strings", ], ) @@ -3502,7 +3504,9 @@ cc_binary( linkstatic = 1, malloc = "//test/util:errno_safe_allocator", deps = select_gtest() + [ + ":socket_netlink_netfilter_util", ":socket_netlink_util", + "//test/util:capability_util", "//test/util:file_descriptor", "//test/util:posix_error", "//test/util:socket_util", @@ -3684,6 +3688,12 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "socket_netlink_netfilter_util", + srcs = ["socket_netlink_netfilter_util.cc"], + hdrs = ["socket_netlink_netfilter_util.h"], +) + cc_binary( name = "socket_stream_local_test", testonly = 1, diff --git a/test/syscalls/linux/socket_netlink_netfilter.cc b/test/syscalls/linux/socket_netlink_netfilter.cc index 0211e950fa..f56f3fce55 100644 --- a/test/syscalls/linux/socket_netlink_netfilter.cc +++ b/test/syscalls/linux/socket_netlink_netfilter.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include -#include +#include +#include #include #include #include @@ -23,8 +23,10 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/strings/str_format.h" +#include "test/syscalls/linux/socket_netlink_netfilter_util.h" #include "test/syscalls/linux/socket_netlink_util.h" #include "test/util/file_descriptor.h" +#include "test/util/linux_capability_util.h" #include "test/util/posix_error.h" #include "test/util/socket_util.h" #include "test/util/test_util.h" @@ -36,6 +38,11 @@ namespace testing { namespace { +constexpr uint32_t kSeq = 12345; + +using ::testing::_; +using ::testing::Eq; + using SockOptTest = ::testing::TestWithParam< std::tuple, std::string>>; @@ -89,6 +96,348 @@ TEST(NetlinkNetfilterTest, CanCreateSocket) { ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); EXPECT_THAT(fd.get(), SyscallSucceeds()); } + +TEST(NetlinkNetfilterTest, AddAndAddTableWithDormantFlag) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + uint16_t default_table_id = 0; + const char test_table_name[] = "test_table"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + struct nameAttribute { + struct nlattr attr; + char name[32]; + }; + struct flagAttribute { + struct nlattr attr; + uint32_t flags; + }; + struct request { + struct nlmsghdr hdr; + struct nfgenmsg msg; + struct nameAttribute nattr; + }; + + struct request_2 { + struct nlmsghdr hdr; + struct nfgenmsg msg; + struct nameAttribute nattr; + struct flagAttribute fattr; + }; + + struct request add_tab_req = {}; + InitNetlinkHdr(&add_tab_req.hdr, sizeof(add_tab_req), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), + kSeq, NLM_F_REQUEST | NLM_F_ACK); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&add_tab_req.msg, AF_INET, NFNETLINK_V0, + default_table_id); + // Attribute setting + InitNetlinkAttr(&add_tab_req.nattr.attr, sizeof(add_tab_req.nattr.name), + NFTA_TABLE_NAME); + absl::SNPrintF(add_tab_req.nattr.name, sizeof(add_tab_req.nattr.name), + test_table_name); + + struct request_2 add_tab_req_2 = {}; + InitNetlinkHdr(&add_tab_req_2.hdr, sizeof(add_tab_req_2), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), + kSeq + 1, NLM_F_REQUEST | NLM_F_ACK); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&add_tab_req_2.msg, AF_INET, NFNETLINK_V0, + default_table_id); + // Attribute setting + InitNetlinkAttr(&add_tab_req_2.nattr.attr, sizeof(add_tab_req_2.nattr.name), + NFTA_TABLE_NAME); + absl::SNPrintF(add_tab_req_2.nattr.name, sizeof(add_tab_req_2.nattr.name), + test_table_name); + InitNetlinkAttr(&add_tab_req_2.fattr.attr, sizeof(add_tab_req_2.fattr.flags), + NFTA_TABLE_FLAGS); + add_tab_req_2.fattr.flags = NFT_TABLE_F_DORMANT; + + ASSERT_NO_ERRNO( + NetlinkRequestAckOrError(fd, kSeq, &add_tab_req, sizeof(add_tab_req))); + ASSERT_NO_ERRNO(NetlinkRequestAckOrError(fd, kSeq + 1, &add_tab_req_2, + sizeof(add_tab_req_2))); +} + +TEST(NetlinkNetfilterTest, AddAndRetrieveNewTable) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + uint16_t default_table_id = 0; + const char test_table_name[] = "test_table"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + struct nameAttribute { + struct nlattr attr; + char name[32]; + }; + struct request { + struct nlmsghdr hdr; + struct nfgenmsg msg; + struct nameAttribute attr; + }; + + struct request add_tab_req = {}; + InitNetlinkHdr(&add_tab_req.hdr, sizeof(add_tab_req), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), + kSeq, NLM_F_REQUEST | NLM_F_ACK); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&add_tab_req.msg, AF_INET, NFNETLINK_V0, + default_table_id); + // Attribute setting + InitNetlinkAttr(&add_tab_req.attr.attr, sizeof(add_tab_req.attr.name), + NFTA_TABLE_NAME); + absl::SNPrintF(add_tab_req.attr.name, sizeof(add_tab_req.attr.name), + test_table_name); + + struct request add_tab_req_2 = {}; + bool correct_response = false; + InitNetlinkHdr(&add_tab_req_2.hdr, sizeof(add_tab_req_2), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_GETTABLE), + kSeq + 1, NLM_F_REQUEST); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&add_tab_req_2.msg, AF_INET, NFNETLINK_V0, + default_table_id); + // Attribute setting + InitNetlinkAttr(&add_tab_req_2.attr.attr, sizeof(add_tab_req_2.attr.name), + NFTA_TABLE_NAME); + absl::SNPrintF(add_tab_req_2.attr.name, sizeof(add_tab_req_2.attr.name), + test_table_name); + + ASSERT_NO_ERRNO( + NetlinkRequestAckOrError(fd, kSeq, &add_tab_req, sizeof(add_tab_req))); + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, &add_tab_req_2, sizeof(add_tab_req_2), + [&](const struct nlmsghdr* hdr) { + ASSERT_THAT(hdr->nlmsg_type, Eq(MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, + NFT_MSG_GETTABLE))); + ASSERT_GE(hdr->nlmsg_len, NLMSG_SPACE(sizeof(struct nfgenmsg))); + const struct nfgenmsg* genmsg = + reinterpret_cast(NLMSG_DATA(hdr)); + EXPECT_EQ(genmsg->nfgen_family, AF_INET); + EXPECT_EQ(genmsg->version, NFNETLINK_V0); + + const struct nfattr* nfattr = FindNfAttr(hdr, genmsg, NFTA_TABLE_NAME); + EXPECT_NE(nullptr, nfattr) << "NFTA_TABLE_NAME not found in message."; + if (nfattr == nullptr) { + return; + } + + std::string name(reinterpret_cast(NFA_DATA(nfattr))); + EXPECT_EQ(name, test_table_name); + correct_response = true; + }, + false)); + + ASSERT_TRUE(correct_response); +} + +TEST(NetlinkNetfilterTest, ErrAddExistingTableWithExclusiveFlag) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + uint16_t default_table_id = 0; + const char test_table_name[] = "test_table"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + struct nameAttribute { + struct nlattr attr; + char name[32]; + }; + struct request { + struct nlmsghdr hdr; + struct nfgenmsg msg; + struct nameAttribute attr; + }; + + struct request add_tab_req = {}; + InitNetlinkHdr(&add_tab_req.hdr, sizeof(add_tab_req), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), + kSeq, NLM_F_REQUEST | NLM_F_ACK); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&add_tab_req.msg, AF_INET, NFNETLINK_V0, + default_table_id); + // Attribute setting + InitNetlinkAttr(&add_tab_req.attr.attr, sizeof(add_tab_req.attr.name), + NFTA_TABLE_NAME); + absl::SNPrintF(add_tab_req.attr.name, sizeof(add_tab_req.attr.name), + test_table_name); + + struct request add_tab_req_2 = {}; + InitNetlinkHdr(&add_tab_req_2.hdr, sizeof(add_tab_req_2), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), + kSeq + 1, NLM_F_REQUEST | NLM_F_EXCL); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&add_tab_req_2.msg, AF_INET, NFNETLINK_V0, + default_table_id); + // Attribute setting + InitNetlinkAttr(&add_tab_req_2.attr.attr, sizeof(add_tab_req_2.attr.name), + NFTA_TABLE_NAME); + absl::SNPrintF(add_tab_req_2.attr.name, sizeof(add_tab_req_2.attr.name), + test_table_name); + + ASSERT_NO_ERRNO( + NetlinkRequestAckOrError(fd, kSeq, &add_tab_req, sizeof(add_tab_req))); + ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq + 1, &add_tab_req_2, + sizeof(add_tab_req_2)), + PosixErrorIs(EEXIST, _)); +} + +TEST(NetlinkNetfilterTest, ErrAddExistingTableWithReplaceFlag) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + uint16_t default_table_id = 0; + const char test_table_name[] = "test_table"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + struct nameAttribute { + struct nlattr attr; + char name[32]; + }; + struct request { + struct nlmsghdr hdr; + struct nfgenmsg msg; + struct nameAttribute attr; + }; + + struct request add_tab_req = {}; + InitNetlinkHdr(&add_tab_req.hdr, sizeof(add_tab_req), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), + kSeq, NLM_F_REQUEST | NLM_F_ACK); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&add_tab_req.msg, AF_INET, NFNETLINK_V0, + default_table_id); + // Attribute setting + InitNetlinkAttr(&add_tab_req.attr.attr, sizeof(add_tab_req.attr.name), + NFTA_TABLE_NAME); + absl::SNPrintF(add_tab_req.attr.name, sizeof(add_tab_req.attr.name), + test_table_name); + + struct request add_tab_req_2 = {}; + InitNetlinkHdr(&add_tab_req_2.hdr, sizeof(add_tab_req_2), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), + kSeq + 1, NLM_F_REQUEST | NLM_F_REPLACE); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&add_tab_req_2.msg, AF_INET, NFNETLINK_V0, + default_table_id); + // Attribute setting + InitNetlinkAttr(&add_tab_req_2.attr.attr, sizeof(add_tab_req_2.attr.name), + NFTA_TABLE_NAME); + absl::SNPrintF(add_tab_req_2.attr.name, sizeof(add_tab_req_2.attr.name), + test_table_name); + + ASSERT_NO_ERRNO( + NetlinkRequestAckOrError(fd, kSeq, &add_tab_req, sizeof(add_tab_req))); + ASSERT_THAT(NetlinkRequestAckOrError(fd, kSeq + 1, &add_tab_req_2, + sizeof(add_tab_req_2)), + PosixErrorIs(ENOTSUP, _)); +} + +TEST(NetlinkNetfilterTest, ErrAddTableWithUnknownFamily) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + uint8_t unknown_family = 255; + uint16_t default_table_id = 0; + const char test_table_name[] = "unknown_family_table"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + struct nameAttribute { + struct nlattr attr; + char name[32]; + }; + struct request { + struct nlmsghdr hdr; + struct nfgenmsg msg; + struct nameAttribute attr; + }; + + struct request get_tab_req = {}; + InitNetlinkHdr(&get_tab_req.hdr, sizeof(get_tab_req), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_NEWTABLE), + kSeq, NLM_F_REQUEST); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&get_tab_req.msg, unknown_family, NFNETLINK_V0, + default_table_id); + // Attribute setting + InitNetlinkAttr(&get_tab_req.attr.attr, sizeof(get_tab_req.attr.name), + NFTA_TABLE_NAME); + absl::SNPrintF(get_tab_req.attr.name, sizeof(get_tab_req.attr.name), + test_table_name); + + ASSERT_THAT( + NetlinkRequestAckOrError(fd, kSeq, &get_tab_req, sizeof(get_tab_req)), + PosixErrorIs(EINVAL, _)); +} + +TEST(NetlinkNetfilterTest, ErrRetrieveNoSpecifiedNameTable) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + uint16_t default_table_id = 0; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + struct nameAttribute { + struct nlattr attr; + char name[32]; + }; + struct request { + struct nlmsghdr hdr; + struct nfgenmsg msg; + }; + + struct request get_tab_req = {}; + InitNetlinkHdr(&get_tab_req.hdr, sizeof(get_tab_req), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_GETTABLE), + kSeq, NLM_F_REQUEST); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&get_tab_req.msg, AF_INET, NFNETLINK_V0, + default_table_id); + + ASSERT_THAT( + NetlinkRequestAckOrError(fd, kSeq, &get_tab_req, sizeof(get_tab_req)), + PosixErrorIs(EINVAL, _)); +} + +TEST(NetlinkNetfilterTest, ErrRetrieveNonexistentTable) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + uint16_t default_table_id = 0; + const char test_table_name[] = "undefined_table"; + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_NETFILTER)); + + struct nameAttribute { + struct nlattr attr; + char name[32]; + }; + struct request { + struct nlmsghdr hdr; + struct nfgenmsg msg; + struct nameAttribute attr; + }; + + struct request get_tab_req = {}; + InitNetlinkHdr(&get_tab_req.hdr, sizeof(get_tab_req), + MakeNetlinkMsgType(NFNL_SUBSYS_NFTABLES, NFT_MSG_GETTABLE), + kSeq, NLM_F_REQUEST); + // For both ipv4 and ipv6 tables. + InitNetfilterGenmsg(&get_tab_req.msg, AF_INET, NFNETLINK_V0, + default_table_id); + // Attribute setting + InitNetlinkAttr(&get_tab_req.attr.attr, sizeof(get_tab_req.attr.name), + NFTA_TABLE_NAME); + absl::SNPrintF(get_tab_req.attr.name, sizeof(get_tab_req.attr.name), + test_table_name); + + ASSERT_THAT( + NetlinkRequestAckOrError(fd, kSeq, &get_tab_req, sizeof(get_tab_req)), + PosixErrorIs(ENOENT, _)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/socket_netlink_netfilter_util.cc b/test/syscalls/linux/socket_netlink_netfilter_util.cc new file mode 100644 index 0000000000..54a7ebef0d --- /dev/null +++ b/test/syscalls/linux/socket_netlink_netfilter_util.cc @@ -0,0 +1,31 @@ +// Copyright 2025 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_netlink_netfilter_util.h" + +#include + +namespace gvisor { +namespace testing { + +// Helper function to initialize a nfgenmsg header. +void InitNetfilterGenmsg(struct nfgenmsg* genmsg, uint8_t family, + uint8_t version, uint16_t res_id) { + genmsg->nfgen_family = family; + genmsg->version = version; + genmsg->res_id = res_id; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_netfilter_util.h b/test/syscalls/linux/socket_netlink_netfilter_util.h new file mode 100644 index 0000000000..8241ed3c12 --- /dev/null +++ b/test/syscalls/linux/socket_netlink_netfilter_util.h @@ -0,0 +1,35 @@ +// Copyright 2025 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_NETFILTER_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_NETFILTER_UTIL_H_ + +#include +#include +#include +#include +#include + +#include + +namespace gvisor { +namespace testing { + +void InitNetfilterGenmsg(struct nfgenmsg* genmsg, uint8_t family, + uint8_t version, uint16_t res_id); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_NETLINK_NETFILTER_UTIL_H_ diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index c1bff3c65f..d4ad57d6a2 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -19,10 +19,19 @@ #include #include +#include +#include +#include +#include +#include #include #include "absl/strings/str_cat.h" +#include "test/util/file_descriptor.h" +#include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/socket_util.h" +#include "test/util/test_util.h" namespace gvisor { namespace testing { @@ -194,5 +203,45 @@ const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, return nullptr; } +uint16_t MakeNetlinkMsgType(uint8_t subsys_id, uint8_t msg_type) { + return (static_cast(subsys_id) << 8) | + static_cast(msg_type); +} + +// Helper function to initialize a netlink header. +void InitNetlinkHdr(struct nlmsghdr* hdr, uint32_t msg_len, uint16_t msg_type, + uint32_t seq, uint16_t flags) { + hdr->nlmsg_len = msg_len; + hdr->nlmsg_type = msg_type; + hdr->nlmsg_flags = flags; + hdr->nlmsg_seq = seq; +} + +// Helper function to initialize a netlink attribute. +void InitNetlinkAttr(struct nlattr* attr, int payload_size, + uint16_t attr_type) { + attr->nla_len = NLA_HDRLEN + payload_size; + attr->nla_type = attr_type; +} + +// Helper function to find a netlink attribute in a message. +const struct nfattr* FindNfAttr(const struct nlmsghdr* hdr, + const struct nfgenmsg* msg, int16_t attr) { + // The space dedicated to the nlmsghdr and nfgenmsg headers. + const int nf_space = NLMSG_SPACE(sizeof(nfgenmsg)); + + // The hdr->nlmsg_len = nf_space + attribute payload. + int attrlen = hdr->nlmsg_len - nf_space; + // Ensure nf_space is aligned when traversing to the attributes. + const struct nfattr* nfa = reinterpret_cast( + reinterpret_cast(hdr) + NLMSG_ALIGN(nf_space)); + for (; NFA_OK(nfa, attrlen); nfa = NFA_NEXT(nfa, attrlen)) { + if (nfa->nfa_type == attr) { + return nfa; + } + } + return nullptr; +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index 42d7e02402..c8c9c569aa 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -18,9 +18,14 @@ #include // socket.h has to be included before if_arp.h. #include +#include #include #include +#include +#include +#include + #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" @@ -64,6 +69,20 @@ PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq, const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr, const struct ifinfomsg* msg, int16_t attr); +// Helper function to make a netlink message type from a subsystem ID and a +// message type. +uint16_t MakeNetlinkMsgType(uint8_t subsys_id, uint8_t msg_type); + +// Helper function to initialize a netlink header. +void InitNetlinkHdr(struct nlmsghdr* hdr, uint32_t msg_len, uint16_t msg_type, + uint32_t seq, uint16_t flags); + +// Helper function to initialize a netlink attribute. +void InitNetlinkAttr(struct nlattr* attr, int payload_size, uint16_t attr_type); + +// Helper function to find a netlink attribute in a message. +const struct nfattr* FindNfAttr(const struct nlmsghdr* hdr, + const struct nfgenmsg* msg, int16_t attr); } // namespace testing } // namespace gvisor diff --git a/test/util/BUILD b/test/util/BUILD index 10284b9f9b..6fcc5dfe33 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -428,11 +428,12 @@ cc_library( hdrs = ["socket_util.h"], defines = select_system(), deps = default_net_util() + select_gtest() + [ - "//test/util:file_descriptor", - "//test/util:posix_error", - "//test/util:temp_path", - "//test/util:test_util", - "//test/util:thread_util", + ":file_descriptor", + ":posix_error", + ":save_util", + ":temp_path", + ":test_util", + ":thread_util", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/test/util/socket_util.h b/test/util/socket_util.h index 58c58d085b..7d138d12e0 100644 --- a/test/util/socket_util.h +++ b/test/util/socket_util.h @@ -24,6 +24,9 @@ #include #include +#include +#include +#include #include #include #include @@ -34,6 +37,7 @@ #include "absl/strings/str_format.h" #include "test/util/file_descriptor.h" #include "test/util/posix_error.h" +#include "test/util/save_util.h" #include "test/util/test_util.h" namespace gvisor {