From d678a9dc35194c51eb3e9f450d43d2e48f65cb44 Mon Sep 17 00:00:00 2001 From: "John R. Lenton" Date: Sat, 26 Jan 2019 17:55:19 +0000 Subject: [PATCH] We were alerted that the socket peer credential code was susceptible to an attack allowing creating of root account from an otherwise unprivileged user. The attack is well explained in the bug report referenced below. The quick summary is that the code can be fooled into thinking user is root thanks to specifically crafted client-side socket name. This change fixes that by making us a lot stricter in parsing the remote address we set internally, and by not setting the remote address to contain the user-provided remote address. Fixes: https://bugs.launchpad.net/snapd/+bug/1813365 --- daemon/daemon_test.go | 28 ++++++++++++++-------------- daemon/ucrednet.go | 30 +++++++++++------------------- daemon/ucrednet_test.go | 8 ++++---- 3 files changed, 29 insertions(+), 37 deletions(-) diff --git a/daemon/daemon_test.go b/daemon/daemon_test.go index e3494d569..d2e4de8ee 100644 --- a/daemon/daemon_test.go +++ b/daemon/daemon_test.go @@ -146,7 +146,7 @@ func (s *daemonSuite) TestCommandMethodDispatch(c *check.C) { c.Check(rec.Code, check.Equals, 401, check.Commentf(method)) rec = httptest.NewRecorder() - req.RemoteAddr = "pid=100;uid=0;" + req.RemoteAddr + req.RemoteAddr = "pid=100;uid=0;socket=;" cmd.ServeHTTP(rec, req) c.Check(mck.lastMethod, check.Equals, method) @@ -155,7 +155,7 @@ func (s *daemonSuite) TestCommandMethodDispatch(c *check.C) { req, err := http.NewRequest("POTATO", "", nil) c.Assert(err, check.IsNil) - req.RemoteAddr = "pid=100;uid=0;" + req.RemoteAddr + req.RemoteAddr = "pid=100;uid=0;socket=;" rec := httptest.NewRecorder() cmd.ServeHTTP(rec, req) @@ -171,7 +171,7 @@ func (s *daemonSuite) TestCommandRestartingState(c *check.C) { } req, err := http.NewRequest("GET", "", nil) c.Assert(err, check.IsNil) - req.RemoteAddr = "pid=100;uid=0;" + req.RemoteAddr + req.RemoteAddr = "pid=100;uid=0;socket=;" rec := httptest.NewRecorder() cmd.ServeHTTP(rec, req) @@ -215,7 +215,7 @@ func (s *daemonSuite) TestFillsWarnings(c *check.C) { } req, err := http.NewRequest("GET", "", nil) c.Assert(err, check.IsNil) - req.RemoteAddr = "pid=100;uid=0;" + req.RemoteAddr + req.RemoteAddr = "pid=100;uid=0;socket=;" rec := httptest.NewRecorder() cmd.ServeHTTP(rec, req) @@ -269,7 +269,7 @@ func (s *daemonSuite) TestGuestAccess(c *check.C) { } func (s *daemonSuite) TestSnapctlAccessSnapOKWithUser(c *check.C) { - remoteAddr := "pid=100;uid=1000;socket=" + dirs.SnapSocket + remoteAddr := "pid=100;uid=1000;socket=" + dirs.SnapSocket + ";" get := &http.Request{Method: "GET", RemoteAddr: remoteAddr} put := &http.Request{Method: "PUT", RemoteAddr: remoteAddr} pst := &http.Request{Method: "POST", RemoteAddr: remoteAddr} @@ -283,7 +283,7 @@ func (s *daemonSuite) TestSnapctlAccessSnapOKWithUser(c *check.C) { } func (s *daemonSuite) TestSnapctlAccessSnapOKWithRoot(c *check.C) { - remoteAddr := "pid=100;uid=0;socket=" + dirs.SnapSocket + remoteAddr := "pid=100;uid=0;socket=" + dirs.SnapSocket + ";" get := &http.Request{Method: "GET", RemoteAddr: remoteAddr} put := &http.Request{Method: "PUT", RemoteAddr: remoteAddr} pst := &http.Request{Method: "POST", RemoteAddr: remoteAddr} @@ -297,8 +297,8 @@ func (s *daemonSuite) TestSnapctlAccessSnapOKWithRoot(c *check.C) { } func (s *daemonSuite) TestUserAccess(c *check.C) { - get := &http.Request{Method: "GET", RemoteAddr: "pid=100;uid=42;"} - put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=42;"} + get := &http.Request{Method: "GET", RemoteAddr: "pid=100;uid=42;socket=;"} + put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=42;socket=;"} cmd := &Command{d: newTestDaemon(c)} c.Check(cmd.canAccess(get, nil), check.Equals, accessUnauthorized) @@ -321,8 +321,8 @@ func (s *daemonSuite) TestUserAccess(c *check.C) { } func (s *daemonSuite) TestSuperAccess(c *check.C) { - get := &http.Request{Method: "GET", RemoteAddr: "pid=100;uid=0;"} - put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=0;"} + get := &http.Request{Method: "GET", RemoteAddr: "pid=100;uid=0;socket=;"} + put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=0;socket=;"} cmd := &Command{d: newTestDaemon(c)} c.Check(cmd.canAccess(get, nil), check.Equals, accessOK) @@ -342,7 +342,7 @@ func (s *daemonSuite) TestSuperAccess(c *check.C) { } func (s *daemonSuite) TestPolkitAccess(c *check.C) { - put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=42;"} + put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=42;socket=;"} cmd := &Command{d: newTestDaemon(c), PolkitOK: "polkit.action"} // polkit says user is not authorised @@ -363,7 +363,7 @@ func (s *daemonSuite) TestPolkitAccess(c *check.C) { } func (s *daemonSuite) TestPolkitAccessForGet(c *check.C) { - get := &http.Request{Method: "GET", RemoteAddr: "pid=100;uid=42;"} + get := &http.Request{Method: "GET", RemoteAddr: "pid=100;uid=42;socket=;"} cmd := &Command{d: newTestDaemon(c), PolkitOK: "polkit.action"} // polkit can grant authorisation for GET requests @@ -379,7 +379,7 @@ func (s *daemonSuite) TestPolkitAccessForGet(c *check.C) { } func (s *daemonSuite) TestPolkitInteractivity(c *check.C) { - put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=42;", Header: make(http.Header)} + put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=42;socket=;", Header: make(http.Header)} cmd := &Command{d: newTestDaemon(c), PolkitOK: "polkit.action"} s.authorized = true @@ -950,7 +950,7 @@ func (s *daemonSuite) TestShutdownServerCanShutdown(c *check.C) { func doTestReq(c *check.C, cmd *Command, mth string) *httptest.ResponseRecorder { req, err := http.NewRequest(mth, "", nil) c.Assert(err, check.IsNil) - req.RemoteAddr = "pid=100;uid=0;" + req.RemoteAddr + req.RemoteAddr = "pid=100;uid=0;socket=;" rec := httptest.NewRecorder() cmd.ServeHTTP(rec, req) return rec diff --git a/daemon/ucrednet.go b/daemon/ucrednet.go index 5e981e115..c12db91a3 100644 --- a/daemon/ucrednet.go +++ b/daemon/ucrednet.go @@ -23,8 +23,8 @@ import ( "errors" "fmt" "net" + "regexp" "strconv" - "strings" sys "syscall" ) @@ -35,28 +35,20 @@ const ( ucrednetNobody = uint32((1 << 32) - 1) ) +var raddrRegexp = regexp.MustCompile(`^pid=(\d+);uid=(\d+);socket=([^;]*);$`) + func ucrednetGet(remoteAddr string) (pid uint32, uid uint32, socket string, err error) { pid = ucrednetNoProcess uid = ucrednetNobody - for _, token := range strings.Split(remoteAddr, ";") { - var v uint64 - if strings.HasPrefix(token, "pid=") { - if v, err = strconv.ParseUint(token[4:], 10, 32); err == nil { - pid = uint32(v) - } else { - break - } - } else if strings.HasPrefix(token, "uid=") { - if v, err = strconv.ParseUint(token[4:], 10, 32); err == nil { - uid = uint32(v) - } else { - break - } + subs := raddrRegexp.FindStringSubmatch(remoteAddr) + if subs != nil { + if v, err := strconv.ParseUint(subs[1], 10, 32); err == nil { + pid = uint32(v) } - if strings.HasPrefix(token, "socket=") { - socket = token[7:] + if v, err := strconv.ParseUint(subs[2], 10, 32); err == nil { + uid = uint32(v) } - + socket = subs[3] } if pid == ucrednetNoProcess || uid == ucrednetNobody { err = errNoID @@ -73,7 +65,7 @@ type ucrednetAddr struct { } func (wa *ucrednetAddr) String() string { - return fmt.Sprintf("pid=%s;uid=%s;socket=%s;%s", wa.pid, wa.uid, wa.socket, wa.Addr) + return fmt.Sprintf("pid=%s;uid=%s;socket=%s;", wa.pid, wa.uid, wa.socket) } type ucrednetConn struct { diff --git a/daemon/ucrednet_test.go b/daemon/ucrednet_test.go index 83511a845..ea51db659 100644 --- a/daemon/ucrednet_test.go +++ b/daemon/ucrednet_test.go @@ -144,14 +144,14 @@ func (s *ucrednetSuite) TestUcredErrors(c *check.C) { } func (s *ucrednetSuite) TestGetNoUid(c *check.C) { - pid, uid, _, err := ucrednetGet("pid=100;uid=;") + pid, uid, _, err := ucrednetGet("pid=100;uid=;socket=;") c.Check(err, check.Equals, errNoID) - c.Check(pid, check.Equals, uint32(100)) + c.Check(pid, check.Equals, ucrednetNoProcess) c.Check(uid, check.Equals, ucrednetNobody) } func (s *ucrednetSuite) TestGetBadUid(c *check.C) { - pid, uid, _, err := ucrednetGet("pid=100;uid=hello;") + pid, uid, _, err := ucrednetGet("pid=100;uid=4294967296;socket=;") c.Check(err, check.NotNil) c.Check(pid, check.Equals, uint32(100)) c.Check(uid, check.Equals, ucrednetNobody) @@ -172,7 +172,7 @@ func (s *ucrednetSuite) TestGetNothing(c *check.C) { } func (s *ucrednetSuite) TestGet(c *check.C) { - pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket") + pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;") c.Check(err, check.IsNil) c.Check(pid, check.Equals, uint32(100)) c.Check(uid, check.Equals, uint32(42)) -- 2.20.1