=== added file 'tests/test_context.py' --- tests/test_context.py 1970-01-01 00:00:00 +0000 +++ tests/test_context.py 2012-05-22 22:59:33 +0000 @@ -0,0 +1,177 @@ +# tests.test_context - test ssl context creation +# +# Copyright 2012 Canonical Ltd. +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License version 3, +# as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranties of +# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR +# PURPOSE. See the GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# +# In addition, as a special exception, the copyright holders give +# permission to link the code of portions of this program with the +# OpenSSL library under certain conditions as described in each +# individual source file, and distribute linked combinations +# including the two. +# You must obey the GNU General Public License in all respects +# for all of the code used other than OpenSSL. If you modify +# file(s) with this exception, you may extend this exception to your +# version of the file(s), but you are not obligated to do so. If you +# do not wish to do so, delete this exception statement from your +# version. If you delete this exception statement from all source +# files in the program, then also delete it here. + +import os + +from OpenSSL import crypto, SSL +from twisted.internet import defer, reactor, ssl +from twisted.trial import unittest +from twisted.web import client, resource, server + +from ubuntuone.storageprotocol import context + + +class FakeCerts(object): + """CA and Server certificate.""" + + def __init__(self, testcase, common_name="fake.domain"): + """Initialize this fake instance.""" + self.cert_dir = os.path.join(testcase.mktemp(), 'certs') + if not os.path.exists(self.cert_dir): + os.makedirs(self.cert_dir) + + ca_key = self._build_key() + ca_req = self._build_request(ca_key, "Fake Cert Authority") + self.ca_cert = self._build_cert(ca_req, ca_req, ca_key) + + server_key = self._build_key() + server_req = self._build_request(server_key, common_name) + server_cert = self._build_cert(server_req, self.ca_cert, ca_key) + + self.server_key_path = self._save_key(server_key, "server_key.pem") + self.server_cert_path = self._save_cert(server_cert, "server_cert.pem") + + def _save_key(self, key, filename): + """Save a certificate.""" + data = crypto.dump_privatekey(crypto.FILETYPE_PEM, key) + return self._save(filename, data) + + def _save_cert(self, cert, filename): + """Save a certificate.""" + data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) + return self._save(filename, data) + + def _save(self, filename, data): + """Save a key or certificate, and return the full path.""" + fullpath = os.path.join(self.cert_dir, filename) + if os.path.exists(fullpath): + os.unlink(fullpath) + with open(fullpath, 'wt') as fd: + fd.write(data) + return fullpath + + def _build_key(self): + """Create a private/public key, save it in a temp dir.""" + key = crypto.PKey() + key.generate_key(crypto.TYPE_RSA, 1024) + return key + + def _build_request(self, key, common_name): + """Create a new certificate request.""" + request = crypto.X509Req() + request.get_subject().CN = common_name + request.set_pubkey(key) + request.sign(key, "md5") + return request + + def _build_cert(self, request, ca_cert, ca_key): + """Create a new certificate.""" + certificate = crypto.X509() + certificate.set_serial_number(1) + certificate.set_issuer(ca_cert.get_subject()) + certificate.set_subject(request.get_subject()) + certificate.set_pubkey(request.get_pubkey()) + certificate.gmtime_adj_notBefore(0) + certificate.gmtime_adj_notAfter(3600) # valid for one hour + certificate.sign(ca_key, "md5") + return certificate + + +class FakeResource(resource.Resource): + """A fake resource.""" + + isLeaf = True + + def render(self, request): + """Render this resource.""" + return "ok" + + +class SSLContextTestCase(unittest.TestCase): + """Tests for the context.get_ssl_context function.""" + + @defer.inlineCallbacks + def verify_context(self, server_context, client_context): + """Verify a client context with a given server context.""" + site = server.Site(FakeResource()) + port = reactor.listenSSL(0, site, server_context) + self.addCleanup(port.stopListening) + url = "https://localhost:%d" % port.getHost().port + result = yield client.getPage(url, contextFactory=client_context) + self.assertEqual(result, "ok") + + @defer.inlineCallbacks + def test_no_verify(self): + """Test the no_verify option.""" + certs = FakeCerts(self, "localhost") + server_context = ssl.DefaultOpenSSLContextFactory( + certs.server_key_path, certs.server_cert_path) + client_context = context.get_ssl_context(no_verify=True, + hostname="localhost") + + yield self.verify_context(server_context, client_context) + + @defer.inlineCallbacks + def test_fails_certificate(self): + """A wrong certificate is rejected.""" + certs = FakeCerts(self, "localhost") + server_context = ssl.DefaultOpenSSLContextFactory( + certs.server_key_path, certs.server_cert_path) + client_context = context.get_ssl_context(no_verify=False, + hostname="localhost") + + d = self.verify_context(server_context, client_context) + e = yield self.assertFailure(d, SSL.Error) + self.assertEqual(e[0][0][1], "SSL3_GET_SERVER_CERTIFICATE") + + @defer.inlineCallbacks + def test_fails_hostname(self): + """A wrong hostname is rejected.""" + certs = FakeCerts(self, "thisiswronghost.net") + server_context = ssl.DefaultOpenSSLContextFactory( + certs.server_key_path, certs.server_cert_path) + self.patch(context, "certificates", [certs.ca_cert]) + client_context = context.get_ssl_context(no_verify=False, + hostname="localhost") + + d = self.verify_context(server_context, client_context) + e = yield self.assertFailure(d, SSL.Error) + self.assertEqual(e[0][0][1], "SSL3_GET_SERVER_CERTIFICATE") + + @defer.inlineCallbacks + def test_matches_all(self): + """A valid certificate passes checks.""" + certs = FakeCerts(self, "localhost") + server_context = ssl.DefaultOpenSSLContextFactory( + certs.server_key_path, certs.server_cert_path) + self.patch(context, "certificates", [certs.ca_cert]) + client_context = context.get_ssl_context(no_verify=False, + hostname="localhost") + + yield self.verify_context(server_context, client_context) === modified file 'ubuntuone/storageprotocol/context.py' --- ubuntuone/storageprotocol/context.py 2012-03-29 20:28:09 +0000 +++ ubuntuone/storageprotocol/context.py 2012-05-22 22:59:33 +0000 @@ -34,6 +34,7 @@ from OpenSSL import SSL from twisted.internet import ssl +from twisted.python import log if sys.platform == "win32": # diable pylint warning, as it may be the wrong platform @@ -58,18 +59,45 @@ ssl_cert_location = '/etc/ssl/certs' -def get_ssl_context(no_verify): - """ Get the ssl context """ +ca_file = ssl.Certificate.loadPEM(file(os.path.join(ssl_cert_location, + 'UbuntuOne-Go_Daddy_Class_2_CA.pem'), 'r').read()) +ca_file_2 = ssl.Certificate.loadPEM(file(os.path.join(ssl_cert_location, + 'UbuntuOne-Go_Daddy_CA.pem'), 'r').read()) +certificates = [ca_file.original, ca_file_2.original] + + +class HostnameVerifyContextFactory(ssl.CertificateOptions): + """Does hostname checks in addition to certificate checks.""" + + def __init__(self, hostname, *args, **kwargs): + """Initialize this instance.""" + super(HostnameVerifyContextFactory, self).__init__(*args, **kwargs) + self.expected_hostname = hostname + + def verify_server_hostname(self, conn, cert, errno, depth, preverifyOK): + """Verify the server hostname.""" + if depth == 0: + # No extra checks because U1 certs have the right commonName + if self.expected_hostname != cert.get_subject().commonName: + log.err("Host name does not match certificate. " + "Expected %s but got %s." % (self.expected_hostname, + cert.get_subject().commonName)) + return False + return preverifyOK + + def getContext(self): + """The context returned will verify the hostname too.""" + ctx = super(HostnameVerifyContextFactory, self).getContext() + flags = SSL.VERIFY_PEER | SSL.VERIFY_FAIL_IF_NO_PEER_CERT + ctx.set_verify(flags, self.verify_server_hostname) + return ctx + + +def get_ssl_context(no_verify, hostname): + """Get the ssl context.""" if no_verify: ctx = ssl.ClientContextFactory() else: - ca_file = ssl.Certificate.loadPEM(file( - os.path.join(ssl_cert_location, - 'UbuntuOne-Go_Daddy_Class_2_CA.pem'), 'r').read()) - ca_file_2 = ssl.Certificate.loadPEM(file( - os.path.join(ssl_cert_location, - 'UbuntuOne-Go_Daddy_CA.pem'), 'r').read()) - ctx = ssl.CertificateOptions(verify=True, - caCerts=[ca_file.original, ca_file_2.original], - method=SSL.SSLv23_METHOD) + ctx = HostnameVerifyContextFactory(hostname, verify=True, + caCerts=certificates, method=SSL.SSLv23_METHOD) return ctx