Comment 12 for bug 1820083

Revision history for this message
Heather Lemon (hypothetical-lemon) wrote:

# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import ssl
import socket

import threading

from OpenSSL import crypto
from etcd3gw.client import Etcd3Client
from etcd3gw.tests import base
from future.backports.http.server import (HTTPServer as _HTTPServer,
                                          SimpleHTTPRequestHandler, BaseHTTPRequestHandler)

class ETCDMock(_HTTPServer):
    """HTTP server that wraps every accepted connection in server-side TLS.

    :param server_address: ``(host, port)`` tuple to bind.
    :param handler_class: request handler class instantiated per connection.
    :param context: ``ssl.SSLContext`` used for the server-side handshake.
    """

    def __init__(self, server_address, handler_class, context):
        _HTTPServer.__init__(self, server_address, handler_class)
        self.context = context

    def __str__(self):
        return ('<%s %s:%s>' %
                (self.__class__.__name__,
                 self.server_name,
                 self.server_port))

    def get_request(self):
        """Accept one connection and complete the server-side TLS handshake.

        :returns: ``(tls_socket, client_address)``.
        :raises socket.error: if accept() or the handshake fails
            (``ssl.SSLError`` is a subclass). The socketserver request loop
            drops a connection whose get_request() raises ``socket.error`` /
            ``OSError`` and keeps serving, so one misbehaving client does not
            stop the mock server.
        """
        sock, addr = self.socket.accept()
        try:
            sslconn = self.context.wrap_socket(sock, server_side=True)
        except Exception:
            # The handshake failed: close the plain socket instead of
            # leaking it, then let the server loop discard this request.
            sock.close()
            raise
        return sslconn, addr

class ETCDMockRequestHandler(SimpleHTTPRequestHandler):
    """Answer the few etcd endpoints that the TLS test exercises.

    ``GET /health`` and ``POST /maintenance/status`` both reply 200 with a
    fixed body. Other GET paths fall through to SimpleHTTPRequestHandler.
    Other POST paths get 501, because SimpleHTTPRequestHandler has no
    ``do_POST`` to delegate to.
    """

    protocol_version = "HTTP/1.0"

    # Canned body. It is not strictly valid JSON, but it is kept
    # byte-for-byte because callers compare the response text exactly.
    _EXAMPLE_RESPONSE = b"{health:true}"

    def _send_example_response(self):
        """Write a complete 200 response carrying the canned body."""
        body = self._EXAMPLE_RESPONSE
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        if self.path == "/health":
            self._send_example_response()
        else:
            super().do_GET()

    def do_POST(self):
        if self.path == "/maintenance/status":
            self._send_example_response()
        else:
            # SimpleHTTPRequestHandler defines no do_POST, so calling super()
            # would raise AttributeError. Reply the same way
            # BaseHTTPRequestHandler does for an unsupported method.
            self.send_error(501, "Unsupported method (%r)" % self.command)

class ETCDServerThread(threading.Thread):
    """Daemon thread that runs an ETCDMock TLS server in the background."""

    def __init__(self, context, server_address=('127.0.0.1', 2379)):
        """Create the mock server; it begins serving once start() is called.

        :param context: ``ssl.SSLContext`` handed to ETCDMock for the
            server-side TLS handshake.
        :param server_address: ``(host, port)`` to bind. Defaults to
            127.0.0.1:2379. Pass port 0 to let the OS choose a free port,
            then read the bound port from ``self.port``.
        """
        self.flag = None
        self.server = ETCDMock(server_address,
                               ETCDMockRequestHandler,
                               context)
        self.port = self.server.server_port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def start(self, flag=None):
        """Start the thread; ``flag`` (an Event) is set once run() begins."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        if self.flag:
            self.flag.set()
        try:
            # Poll every 50 ms so shutdown() is noticed promptly.
            self.server.serve_forever(0.05)
        finally:
            self.server.server_close()

    def stop(self):
        """Stop serving, wait for the thread to exit, and release the socket."""
        if not self.is_alive():
            # shutdown() blocks forever unless serve_forever() is running in
            # another thread. If the thread never started (or has already
            # exited), just make sure the listening socket is closed.
            self.server.server_close()
            return
        self.server.shutdown()
        self.join()

def create_self_signed_cert():
    """Write a self-signed certificate and private key for 127.0.0.1.

    The files are created in the current working directory as ``test.crt``
    and ``test.key``; the caller is responsible for removing them.

    :returns: ``(cert_file, key_file, ca_file)``. The certificate is
        self-signed, so it is its own CA and ``ca_file == cert_file``.
    """
    # RSA key pair. PKey holds the private key, not just the public half.
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)

    cert = crypto.X509()
    # X.509 v3 (encoded as zero-based version 2) is required for extensions.
    cert.set_version(2)
    subject = cert.get_subject()
    subject.C = "US"
    subject.ST = "Boston"
    subject.L = "Boston"
    subject.O = "Test Company Ltd"
    subject.OU = "Test Company Ltd"
    subject.CN = "127.0.0.1"
    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)
    cert.set_issuer(cert.get_subject())
    # When the client connects to an IP address, modern TLS hostname
    # verification matches only IP subjectAltName entries and ignores the CN.
    # Without this extension, https://127.0.0.1 is rejected.
    cert.add_extensions([
        crypto.X509Extension(b"subjectAltName", False, b"IP:127.0.0.1"),
    ])
    cert.set_pubkey(key)
    cert.sign(key, "sha256")

    cert_file = 'test.crt'
    key_file = 'test.key'
    with open(cert_file, 'wb') as crt:
        crt.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
    with open(key_file, 'wb') as key_out:
        key_out.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))

    # The self-signed certificate doubles as the CA bundle.
    return cert_file, key_file, cert_file

class TestEtcd3Gateway(base.TestCase):
    """Tests for Etcd3Client URL construction and TLS connectivity."""

    def test_client_default(self):
        client = Etcd3Client()
        self.assertEqual("http://localhost:2379/v3alpha/lease/grant",
                         client.get_url("/lease/grant"))

    def test_client_ipv4(self):
        client = Etcd3Client(host="127.0.0.1")
        self.assertEqual("http://127.0.0.1:2379/v3alpha/lease/grant",
                         client.get_url("/lease/grant"))

    def test_client_ipv6(self):
        client = Etcd3Client(host="::1")
        self.assertEqual("http://[::1]:2379/v3alpha/lease/grant",
                         client.get_url("/lease/grant"))

    def test_client_tls(self):
        """GET /health over TLS from a local mock server with a self-signed cert."""
        cert_file, key_file, ca_file = create_self_signed_cert()
        # Register each cleanup as soon as its resource exists, so a failure
        # in any later step still removes the key material and stops the
        # server. Cleanups run in reverse order of registration.
        for path in sorted({cert_file, key_file, ca_file}):
            self.addCleanup(os.remove, path)

        # Explicit server-side protocol: a no-argument SSLContext() is
        # deprecated.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(certfile=cert_file, keyfile=key_file)
        ctx.load_verify_locations(cafile=ca_file)

        server = ETCDServerThread(ctx)
        started = threading.Event()
        server.start(started)
        self.addCleanup(server.stop)
        self.assertTrue(started.wait(10), "mock etcd server did not start")

        client = Etcd3Client(host="127.0.0.1", protocol="https",
                             ca_cert=ca_file,
                             cert_key=key_file,
                             cert_cert=cert_file, timeout=10)
        self.addCleanup(client.session.close)

        # Connection or TLS errors propagate and fail the test rather than
        # being printed and ignored.
        response = client.session.get(
            "https://127.0.0.1:%d/health" % server.port)
        self.addCleanup(response.close)
        self.assertEqual(200, response.status_code)
        self.assertEqual("{health:true}", response.text)