diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 17fda417bc..bf688072d2 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -828,3 +829,8 @@ func init() { registered[fam] = struct{}{} } } + +// GetPeerCreds implements socket.Socket.GetPeerCreds +func (s *Socket) GetPeerCreds(t *kernel.Task) (marshal.Marshallable, *syserr.Error) { + return nil, syserr.ErrNotSupported +} diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 60403b04db..0167c24060 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -568,6 +568,11 @@ func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er return sa, uint32(sa.SizeBytes()), nil } +// GetPeerCreds implements socket.Socket.GetPeerCreds. +func (s *Socket) GetPeerCreds(t *kernel.Task) (marshal.Marshallable, *syserr.Error) { + return nil, syserr.ErrNotSupported +} + // RecvMsg implements socket.Socket.RecvMsg. func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { from := &linux.SockAddrNetlink{ diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 2a5a23d93e..79de589d0e 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -980,14 +980,7 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family if family != linux.AF_UNIX || outLen < unix.SizeofUcred { return nil, syserr.ErrInvalidArgument } - - tcred := t.Credentials() - creds := linux.ControlMessageCredentials{ - PID: int32(t.ThreadGroup().ID()), - UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - } - return &creds, nil + return s.GetPeerCreds(t) case linux.SO_PASSCRED: if outLen < sizeOfInt32 { @@ -3051,6 +3044,10 @@ func (s *sock) GetPeerName(*kernel.Task) (linux.SockAddr, uint32, *syserr.Error) return a, l, nil } +func (s *sock) GetPeerCreds(*kernel.Task) (marshal.Marshallable, *syserr.Error) { + return nil, syserr.ErrNotSupported +} + func (s *sock) fillCmsgInq(cmsg *socket.ControlMessages) { if !s.sockOptInq { return diff --git a/pkg/sentry/socket/plugin/stack/socket.go b/pkg/sentry/socket/plugin/stack/socket.go index 665966958c..058bf6b9cf 100644 --- a/pkg/sentry/socket/plugin/stack/socket.go +++ b/pkg/sentry/socket/plugin/stack/socket.go @@ -569,6 +569,11 @@ func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil } +// GetPeerCreds implements socket.Socket.GetPeerCreds. +func (s *socketOperations) GetPeerCreds(t *kernel.Task) (marshal.Marshallable, *syserr.Error) { + return nil, syserr.ErrInvalidEndpointState +} + // recv is a helper function for doing non-blocking read once. // It returns: // 1. number of bytes received; diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index ac7c2e1317..35d16b98c5 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -257,6 +257,9 @@ type Socket interface { // necessarily the actual length of the address. GetPeerName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error) + // GetPeerCreds returns the peer credentials of the socket. + GetPeerCreds(t *kernel.Task) (marshal.Marshallable, *syserr.Error) + // RecvMsg implements the recvmsg(2) linux unix. // // senderAddrLen is the address length to be returned to the application, diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index ad0fef609c..aed19cb692 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -42,6 +42,7 @@ go_library( "//pkg/sentry/fsutil", "//pkg/sentry/inet", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/ktime", "//pkg/sentry/socket", "//pkg/sentry/socket/control", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 4bb2ea8a1d..38e2b40642 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -97,6 +97,7 @@ go_library( "//pkg/log", "//pkg/refs", "//pkg/sentry/hostfd", + "//pkg/sentry/kernel/auth", "//pkg/sentry/uniqueid", "//pkg/sync", "//pkg/sync/locking", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index deda6ac71f..643dc2d5ff 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -106,6 +106,14 @@ type connectionedEndpoint struct { // tcpip.SockStream. stype linux.SockType + // peerCreds is used to store the peer credentials. + // This will store the socket's own credentials until the connection is + // established with connect(2). Once the connection is established, this + // will store the peer's credentials. The use of this option is possible + // only for connected `AF_UNIX` stream sockets and for `AF_UNIX` stream and + // datagram socket pairs created using socketpair(2) + peerCreds CredentialsControlMessage + // acceptedChan is per the TCP endpoint implementation. Note that the // sockets in this channel are _already in the connected state_, and // have another associated connectionedEndpoint. @@ -274,6 +282,16 @@ func (e *connectionedEndpoint) Close(ctx context.Context) { } } +func (e *connectionedEndpoint) swapPeerCredsLocked(ctx context.Context, cend ConnectingEndpoint, ne *connectionedEndpoint) *syserr.Error { + ce, ok := cend.(*connectionedEndpoint) + if !ok { + return syserr.ErrInvalidEndpointState + } + // Swap peer credentials between the two endpoints. + ne.peerCreds, ce.peerCreds = ce.peerCreds, ne.peerCreds + return nil +} + // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint), opts UnixSocketOpts) *syserr.Error { if ce.Type() != e.stype { @@ -327,6 +345,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn ne.ops.SetSendBufferSize(defaultBufferSize, false /* notify */) ne.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */) ne.SocketOptions().SetPassCred(e.SocketOptions().GetPassCred()) + ne.peerCreds = e.peerCreds readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: defaultBufferSize} readQueue.InitRefs() @@ -354,6 +373,9 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn } readQueue.IncRef() if e.stype == linux.SOCK_STREAM { + if err := e.swapPeerCredsLocked(ctx, ce, ne); err != nil { + return err + } returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected) } else { returnConnect(&queueReceiver{readQueue: readQueue}, connected) @@ -655,3 +677,15 @@ func (e *connectionedEndpoint) EventUnregister(we *waiter.Entry) { func (e *connectionedEndpoint) GetAcceptConn() bool { return e.Listening() } + +func (e *connectionedEndpoint) PeerCreds() CredentialsControlMessage { + e.Lock() + defer e.Unlock() + return e.peerCreds +} + +func (e *connectionedEndpoint) SetPeerCreds(creds CredentialsControlMessage) { + e.Lock() + defer e.Unlock() + e.peerCreds = creds +} diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 7890c73dd5..86d6dfc2ac 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -158,6 +158,14 @@ func (*connectionlessEndpoint) Accept(context.Context, *Address, UnixSocketOpts) return nil, syserr.ErrNotSupported } +func (e *connectionlessEndpoint) PeerCreds() CredentialsControlMessage { + return nil +} + +func (e *connectionlessEndpoint) SetPeerCreds(creds CredentialsControlMessage) { + // no-op +} + // Bind binds the connection. // // For Unix endpoints, this _only sets the address associated with the socket_. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index f83090bcf8..a88b838181 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -52,6 +52,16 @@ type CredentialsControlMessage interface { Equals(CredentialsControlMessage) bool } +// A PeerCredentialer is a socket or endpoint that supports the SO_PEERCREDS socket +// option. +type PeerCredentialer interface { + // PeerCreds returns the peer credentials. + PeerCreds() CredentialsControlMessage + + // SetPeerCreds sets the peer credentials. + SetPeerCreds(creds CredentialsControlMessage) +} + // A ControlMessages represents a collection of socket control messages. // // +stateify savable @@ -151,6 +161,7 @@ type UnixSocketOpts struct { // etc. to Unix socket implementations. type Endpoint interface { Credentialer + PeerCredentialer waiter.Waitable // Close puts the endpoint in a closed state and frees all resources diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index e0c6686c97..0a1e6d7040 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/ktime" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" @@ -422,6 +423,10 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs. ep2.Close(t) return nil, nil, err } + // Initialize the peer credentials. + creds := control.MakeCreds(t) + ep1.SetPeerCreds(creds) + ep2.SetPeerCreds(creds) return s1, s2, nil } @@ -507,6 +512,30 @@ func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er return a, l, nil } +// GetPeerCreds returns the peer credentials of the socket backed by a +// transport.Endpoint. +func (s *Socket) GetPeerCreds(t *kernel.Task) (marshal.Marshallable, *syserr.Error) { + pCreds := s.ep.PeerCreds() + if pCreds == nil { + // https://elixir.bootlin.com/linux/v6.16-rc7/source/net/core/sock.c#L1905 + return &linux.ControlMessageCredentials{ + PID: 0, + UID: auth.NoID, + GID: auth.NoID, + }, nil + } + scmCreds, ok := pCreds.(control.SCMCredentials) + if !ok { + return nil, syserr.ErrInvalidEndpointState + } + pid, uid, gid := scmCreds.Credentials(t) + return &linux.ControlMessageCredentials{ + PID: int32(pid), + UID: uint32(uid), + GID: uint32(gid), + }, nil +} + // GetSockName implements the linux syscall getsockname(2) for sockets backed by // a transport.Endpoint. func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { @@ -522,7 +551,11 @@ func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er // Listen implements the linux syscall listen(2) for sockets backed by // a transport.Endpoint. func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error { - return s.ep.Listen(t, backlog) + if err := s.ep.Listen(t, backlog); err != nil { + return err + } + s.ep.SetPeerCreds(control.MakeCreds(t)) + return nil } // extractEndpoint retrieves the transport.BoundEndpoint associated with a Unix @@ -581,6 +614,9 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr } defer ep.Release(t) + // Initialize the peer credentials of client endpoint. + s.ep.SetPeerCreds(control.MakeCreds(t)) + // Connect the server endpoint. err = s.ep.Connect(t, ep, t.Kernel().UnixSocketOpts) diff --git a/test/image/BUILD b/test/image/BUILD index 4ca950f803..a53798cdea 100644 --- a/test/image/BUILD +++ b/test/image/BUILD @@ -29,7 +29,6 @@ go_test( deps = [ "//pkg/test/dockerutil", "//pkg/test/testutil", - "@com_github_cenkalti_backoff//:go_default_library", "@com_github_docker_docker//api/types/mount:go_default_library", "@in_gopkg_yaml_v3//:go_default_library", ], diff --git a/test/image/image_test.go b/test/image/image_test.go index 2b09f825d6..d38d6612cd 100644 --- a/test/image/image_test.go +++ b/test/image/image_test.go @@ -33,7 +33,6 @@ import ( "testing" "time" - "github.com/cenkalti/backoff" "github.com/docker/docker/api/types/mount" yaml "gopkg.in/yaml.v3" "gvisor.dev/gvisor/pkg/test/dockerutil" @@ -500,18 +499,6 @@ func testDockerMatrix(ctx context.Context, t *testing.T, d *dockerutil.Container } name := strings.Join(nameParts, "_") t.Run(name, func(t *testing.T) { - if err := backoff.Retry(func() error { - output, err := dockerInGvisorExecOutput(ctx, d, []string{"docker", "info"}) - if err != nil { - return fmt.Errorf("docker exec failed: %v", err) - } - if !strings.Contains(output, "Cannot connect to the Docker daemon") { - return nil - } - return fmt.Errorf("docker daemon not ready") - }, backoff.WithMaxRetries(backoff.NewConstantBackOff(100*time.Millisecond), 10)); err != nil { - t.Fatalf("failed to run docker test %q: %v", name, err) - } def.testFunc(ctx, t, d, opts) }) } @@ -612,18 +599,6 @@ func removeDockerImage(ctx context.Context, imageName string, d *dockerutil.Cont return nil } -func dockerInGvisorExecOutput(ctx context.Context, d *dockerutil.Container, cmd []string) (string, error) { - execProc, err := d.ExecProcess(ctx, dockerutil.ExecOpts{}, cmd...) - if err != nil { - return "", fmt.Errorf("docker exec failed: %v", err) - } - output, err := execProc.Logs() - if err != nil { - return "", fmt.Errorf("docker logs failed: %v", err) - } - return output, nil -} - func testDockerRun(ctx context.Context, t *testing.T, d *dockerutil.Container, opts dockerCommandOptions) { cmd := []string{"docker", "run", "--rm"} if opts.hostNetwork { @@ -633,12 +608,15 @@ func testDockerRun(ctx context.Context, t *testing.T, d *dockerutil.Container, o cmd = append(cmd, "--privileged") } cmd = append(cmd, testAlpineImage, "sh", "-c", "apk add curl && apk info -d curl") - - expectedOutput := "URL retrival utility and library" - output, err := dockerInGvisorExecOutput(ctx, d, cmd) + execProc, err := d.ExecProcess(ctx, dockerutil.ExecOpts{}, cmd...) if err != nil { t.Fatalf("docker exec failed: %v", err) } + output, err := execProc.Logs() + if err != nil { + t.Fatalf("docker logs failed: %v", err) + } + expectedOutput := "URL retrival utility and library" if !strings.Contains(output, expectedOutput) { t.Fatalf("docker didn't get output expected: %q, got: %q", expectedOutput, output) } @@ -652,10 +630,14 @@ func testDockerBuild(ctx context.Context, t *testing.T, d *dockerutil.Container, imageName := strings.ToLower(strings.ReplaceAll(testutil.RandomID("test_docker_build"), "/", "-")) parts = append(parts, "-t", imageName, "-f", "-", ".") cmd := strings.Join(parts, " ") - _, err := dockerInGvisorExecOutput(ctx, d, []string{"/bin/sh", "-c", cmd}) + dockerBuildProc, err := d.ExecProcess(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", cmd) if err != nil { t.Fatalf("docker exec failed: %v", err) } + _, err = dockerBuildProc.Logs() + if err != nil { + t.Fatalf("docker logs failed: %v", err) + } defer removeDockerImage(ctx, imageName, d) if err := checkDockerImage(ctx, imageName, d); err != nil { t.Fatalf("failed to find docker image: %v", err) @@ -680,9 +662,13 @@ func testDockerExec(ctx context.Context, t *testing.T, d *dockerutil.Container, }() for i := 0; i < 10; i++ { - inspectOutput, err := dockerInGvisorExecOutput(ctx, d, []string{"docker", "container", "inspect", containerName}) + inspectProc, err := d.ExecProcess(ctx, dockerutil.ExecOpts{}, []string{"docker", "container", "inspect", containerName}...) + if err != nil { + t.Fatalf("docker container inspect failed: %v", err) + } + inspectOutput, err := inspectProc.Logs() if err != nil { - t.Fatalf("docker exec failed: %v", err) + t.Fatalf("docker logs failed: %v", err) } if strings.Contains(inspectOutput, "\"Status\": \"running\"") { break @@ -696,11 +682,15 @@ func testDockerExec(ctx context.Context, t *testing.T, d *dockerutil.Container, } // Execute echo command in the container. execCmd = append(execCmd, containerName, "echo", "exec in "+containerName) - - output, err := dockerInGvisorExecOutput(ctx, d, execCmd) + execProc, err := d.ExecProcess(ctx, dockerutil.ExecOpts{}, execCmd...) if err != nil { t.Fatalf("docker exec failed: %v", err) } + + output, err := execProc.Logs() + if err != nil { + t.Fatalf("docker logs failed: %v", err) + } expectedOutput := "exec in " + containerName if !strings.Contains(output, expectedOutput) { t.Fatalf("docker didn't get output expected: %q, got: %q", expectedOutput, output) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 086569c812..dd53cc79b5 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -394,6 +394,47 @@ cc_binary( ], ) +cc_library( + name = "socket_unix_peercred_test_cases", + testonly = 1, + srcs = [ + "socket_unix_peercred.cc", + ], + hdrs = [ + "socket_unix_peercred.h", + ], + deps = select_gtest() + [ + "//test/util:capability_util", + "//test/util:cleanup", + "//test/util:eventfd_util", + "//test/util:file_descriptor", + "//test/util:fs_util", + "//test/util:logging", + "//test/util:memory_util", + "//test/util:mount_util", + "//test/util:multiprocess_util", + "//test/util:posix_error", + "//test/util:proc_util", + "//test/util:save_util", + "//test/util:socket_util", + "//test/util:temp_path", + "//test/util:test_util", + "//test/util:thread_util", + "//test/util:time_util", + "//test/util:timer_util", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:btree", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/synchronization", + "@com_google_absl//absl/time", + ], + alwayslink = 1, +) + cc_binary( name = "chdir_test", testonly = 1, @@ -3930,6 +3971,7 @@ cc_binary( malloc = "//test/util:errno_safe_allocator", deps = select_gtest() + [ ":socket_unix_cmsg_test_cases", + ":socket_unix_peercred_test_cases", ":socket_unix_test_cases", ":unix_domain_socket_test_util", "//test/util:socket_util", diff --git a/test/syscalls/linux/socket_unix_pair.cc b/test/syscalls/linux/socket_unix_pair.cc index 28a4373391..9080ecf930 100644 --- a/test/syscalls/linux/socket_unix_pair.cc +++ b/test/syscalls/linux/socket_unix_pair.cc @@ -16,6 +16,7 @@ #include "test/syscalls/linux/socket_unix.h" #include "test/syscalls/linux/socket_unix_cmsg.h" +#include "test/syscalls/linux/socket_unix_peercred.h" #include "test/syscalls/linux/unix_domain_socket_test_util.h" #include "test/util/socket_util.h" #include "test/util/test_util.h" @@ -39,6 +40,10 @@ INSTANTIATE_TEST_SUITE_P( AllUnixDomainSockets, UnixSocketPairCmsgTest, ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); +INSTANTIATE_TEST_SUITE_P( + AllUnixDomainSockets, UnixSocketPeerCredTest, + ::testing::ValuesIn(IncludeReversals(GetSocketPairs()))); + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_peercred.cc b/test/syscalls/linux/socket_unix_peercred.cc new file mode 100644 index 0000000000..c3ac7d6996 --- /dev/null +++ b/test/syscalls/linux/socket_unix_peercred.cc @@ -0,0 +1,195 @@ +// Copyright 2025 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_unix_peercred.h" + +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "test/util/logging.h" +#include "test/util/multiprocess_util.h" +#include "test/util/posix_error.h" +#include "test/util/save_util.h" +#include "test/util/socket_util.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void AssertCredSetTo(struct ucred peerCreds, pid_t pid, uid_t uid, gid_t gid) { + ASSERT_EQ(peerCreds.pid, pid); + ASSERT_EQ(peerCreds.uid, uid); + ASSERT_EQ(peerCreds.gid, gid); +} + +void TestCredSetTo(struct ucred peerCreds, pid_t pid, uid_t uid, gid_t gid) { + TEST_PCHECK_MSG(peerCreds.pid == pid, "peer pid does not match expected pid"); + TEST_PCHECK_MSG(peerCreds.uid == uid, "peer uid does not match expected uid"); + TEST_PCHECK_MSG(peerCreds.gid == gid, "peer gid does not match expected gid"); +} + +TEST_P(UnixSocketPeerCredTest, GetPeerCred) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + + struct ucred cred; + socklen_t len = sizeof(cred); + ASSERT_THAT( + getsockopt(sockets->first_fd(), SOL_SOCKET, SO_PEERCRED, &cred, &len), + SyscallSucceeds()); + + AssertCredSetTo(cred, getpid(), getuid(), getgid()); +} + +TEST_P(UnixSocketPeerCredTest, PeerCredBeforeListen) { + if (GetParam().type != SOCK_STREAM) { + GTEST_SKIP() << "Test requires SOCK_STREAM"; + } + + auto server_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, 0)); + // Before listen, the peer credentials should not be set. + struct ucred peerCreds; + socklen_t peerCredsLen = sizeof(peerCreds); + ASSERT_THAT(getsockopt(server_socket.get(), SOL_SOCKET, SO_PEERCRED, + &peerCreds, &peerCredsLen), + SyscallSucceeds()); + AssertCredSetTo(peerCreds, 0, -1, -1); +} + +TEST_P(UnixSocketPeerCredTest, PeerCredAfterListen) { + if (GetParam().type != SOCK_STREAM) { + GTEST_SKIP() << "Test requires SOCK_STREAM"; + } + auto addr = ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true, AF_UNIX)); + + auto server_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, 0)); + ASSERT_THAT(bind(server_socket.get(), AsSockAddr(&addr), sizeof(addr)), + SyscallSucceeds()); + ASSERT_THAT(listen(server_socket.get(), 5), SyscallSucceeds()); + // After listen, the peer credentials should be set to process credentials. + struct ucred peerCreds; + socklen_t peerCredsLen = sizeof(peerCreds); + ASSERT_THAT(getsockopt(server_socket.get(), SOL_SOCKET, SO_PEERCRED, + &peerCreds, &peerCredsLen), + SyscallSucceeds()); + AssertCredSetTo(peerCreds, getpid(), getuid(), getgid()); +} + +TEST_P(UnixSocketPeerCredTest, AfterConnectClientPeerCred) { + if (GetParam().type != SOCK_STREAM) { + GTEST_SKIP() << "Test requires SOCK_STREAM"; + } + auto addr = ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true, AF_UNIX)); + auto server_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, 0)); + ASSERT_THAT(bind(server_socket.get(), AsSockAddr(&addr), sizeof(addr)), + SyscallSucceeds()); + ASSERT_THAT(listen(server_socket.get(), 5), SyscallSucceeds()); + + pid_t server_pid = getpid(); + gid_t server_gid = getgid(); + uid_t server_uid = getuid(); + const auto client = [&] { + auto client_socket = Socket(AF_UNIX, SOCK_STREAM, 0); + TEST_PCHECK_MSG(client_socket.ok(), "client socket failed"); + TEST_PCHECK_MSG(connect(client_socket.ValueOrDie().get(), AsSockAddr(&addr), + sizeof(addr)) == 0, + "connect failed"); + struct ucred clientPeerCred; + socklen_t len = sizeof(clientPeerCred); + // After connect, the peer credentials should be set to the server's + // credentials. + TEST_PCHECK_MSG(getsockopt(client_socket.ValueOrDie().get(), SOL_SOCKET, + SO_PEERCRED, &clientPeerCred, &len) == 0, + "client getsockopt failed"); + TestCredSetTo(clientPeerCred, server_pid, server_uid, server_gid); + }; + EXPECT_THAT(InForkedProcess(client), IsPosixErrorOkAndHolds(0)); +} + +TEST_P(UnixSocketPeerCredTest, AfterConnectServerPeerCred) { + if (GetParam().type != SOCK_STREAM) { + GTEST_SKIP() << "Test requires SOCK_STREAM"; + } + // Create a pipe. + int pipe_fd[2]; + ASSERT_THAT(pipe(pipe_fd), SyscallSucceeds()); + // Create a server socket. + auto addr = ASSERT_NO_ERRNO_AND_VALUE(UniqueUnixAddr(true, AF_UNIX)); + auto server_socket = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_UNIX, SOCK_STREAM, 0)); + ASSERT_THAT(bind(server_socket.get(), AsSockAddr(&addr), sizeof(addr)), + SyscallSucceeds()); + ASSERT_THAT(listen(server_socket.get(), 5), SyscallSucceeds()); + + const auto client = [&] { + TEST_PCHECK_MSG(close(pipe_fd[1]) == 0, "close failed"); + auto client_socket = Socket(AF_UNIX, SOCK_STREAM, 0); + auto client_soc = client_socket.ValueOrDie().get(); + TEST_PCHECK_MSG(client_socket.ok(), "client socket failed"); + TEST_PCHECK_MSG(connect(client_soc, AsSockAddr(&addr), sizeof(addr)) == 0, + "connect failed"); + // Wait for the server to close the connection. + char ok = 0; + TEST_PCHECK_MSG(read(pipe_fd[0], &ok, sizeof(ok)) == sizeof(ok), + "read failed"); + TEST_PCHECK_MSG(close(pipe_fd[0]) == 0, "closing pipe failed"); + TEST_PCHECK_MSG(close(client_soc) == 0, "closing client socket failed"); + }; + pid_t pid = fork(); + if (pid == 0) { + client(); + _exit(0); + } + MaybeSave(); + ASSERT_GT(pid, 0); + ASSERT_THAT(close(pipe_fd[0]), SyscallSucceeds()); + + char ok = 1; + ASSERT_THAT(write(pipe_fd[1], &ok, sizeof(ok)), + SyscallSucceedsWithValue(sizeof(ok))); + ASSERT_THAT(close(pipe_fd[1]), SyscallSucceeds()); + auto accepted_socket = + ASSERT_NO_ERRNO_AND_VALUE(Accept(server_socket.get(), nullptr, nullptr)); + + struct ucred serverPeerCred; + socklen_t len = sizeof(serverPeerCred); + ASSERT_THAT(getsockopt(accepted_socket.get(), SOL_SOCKET, SO_PEERCRED, + &serverPeerCred, &len), + SyscallSucceeds()); + // After connection is established, the peer credentials should be set to the + // client's credentials. + AssertCredSetTo(serverPeerCred, pid, getuid(), getgid()); + struct ucred serverSocPeerCred; + ASSERT_THAT(getsockopt(server_socket.get(), SOL_SOCKET, SO_PEERCRED, + &serverSocPeerCred, &len), + SyscallSucceeds()); + // The listening socket's credentials should remain its own. + AssertCredSetTo(serverSocPeerCred, getpid(), getuid(), getgid()); + // Wait for the client to exit. + int status; + ASSERT_GE(waitpid(pid, &status, 0), 0); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_unix_peercred.h b/test/syscalls/linux/socket_unix_peercred.h new file mode 100644 index 0000000000..2162b1cbd7 --- /dev/null +++ b/test/syscalls/linux/socket_unix_peercred.h @@ -0,0 +1,30 @@ +// Copyright 2025 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_PEERCRED_H_ +#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_PEERCRED_H_ + +#include "test/util/socket_util.h" + +namespace gvisor { +namespace testing { + +// Test fixture for tests that apply to pairs of connected unix sockets about +// SO_PEERCRED. +using UnixSocketPeerCredTest = SocketPairTest; + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_LINUX_SOCKET_UNIX_PEERCRED_H_