# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import os
import socket
import ssl
import threading

from future.backports.http.server import (HTTPServer as _HTTPServer,
                                          SimpleHTTPRequestHandler,
                                          BaseHTTPRequestHandler)
from OpenSSL import crypto

from etcd3gw.client import Etcd3Client
from etcd3gw.tests import base


class ETCDMock(_HTTPServer):
    """HTTP server that wraps every accepted connection in TLS."""

    def __init__(self, server_address, handler_class, context):
        """Bind the server and remember the TLS context.

        :param server_address: ``(host, port)`` tuple to listen on.
        :param handler_class: request handler class passed to HTTPServer.
        :param context: ``ssl.SSLContext`` used to wrap accepted sockets.
        """
        _HTTPServer.__init__(self, server_address, handler_class)
        self.context = context

    def __str__(self):
        return ('<%s %s:%s>' %
                (self.__class__.__name__, self.server_name, self.server_port))

    def get_request(self):
        """Accept one connection and run the server-side TLS handshake.

        :returns: ``(ssl_socket, client_address)``.
        :raises socket.error: if the handshake fails. The raw socket is
            closed first, and the error propagates to the socketserver
            request loop, which discards this connection and keeps serving.
        """
        sock, addr = self.socket.accept()
        try:
            sslconn = self.context.wrap_socket(sock, server_side=True)
        except socket.error as e:
            # ssl.SSLError is a subclass of socket.error (OSError) on py3.
            print("failure in etcdservermock: %s" % e)
            sock.close()
            raise
        return sslconn, addr


class ETCDMockRequestHandler(SimpleHTTPRequestHandler):
    """Request handler that fakes a couple of etcd gateway endpoints."""

    protocol_version = "HTTP/1.0"

    # Canned body for the faked endpoints; test_client_tls asserts this text.
    _EXAMPLE_RESPONSE = b"{health:true}"

    def _send_example_response(self):
        """Write a 200 response carrying the canned body."""
        body = self._EXAMPLE_RESPONSE
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        if self.path == "/health":
            self._send_example_response()
        else:
            # Defer to SimpleHTTPRequestHandler's static-file handling.
            super().do_GET()

    def do_POST(self):
        if self.path == "/maintenance/status":
            self._send_example_response()
        else:
            # SimpleHTTPRequestHandler has no do_POST, so delegating to
            # super() would raise AttributeError; reply the way the base
            # handler does for unsupported methods instead.
            self.send_error(501, "Unsupported method (%r)" % self.command)


class ETCDServerThread(threading.Thread):
    """Daemon thread running an ETCDMock HTTPS server on 127.0.0.1."""

    def __init__(self, context, port=2379):
        """Create and bind the mock server.

        :param context: server-side ``ssl.SSLContext``.
        :param port: TCP port to listen on (0 lets the OS pick one); the
            bound port is exposed as ``self.port``.
        """
        threading.Thread.__init__(self)
        self.daemon = True
        self.flag = None
        self.server = ETCDMock(('127.0.0.1', port), ETCDMockRequestHandler,
                               context)
        self.port = self.server.server_port

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def start(self, flag=None):
        """Start the thread.

        :param flag: optional ``threading.Event`` set just before the serve
            loop is entered.
        """
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        if self.flag:
            self.flag.set()
        try:
            self.server.serve_forever(0.05)
        finally:
            self.server.server_close()

    def stop(self, timeout=10):
        """Stop the serve loop and wait for the thread to exit.

        Joining ensures ``server_close`` has run, so the port is released
        before another server tries to bind it.
        """
        self.server.shutdown()
        self.join(timeout)


# Subject entries (in order) for the generated test certificate.
_CERT_SUBJECT = (
    ("C", "US"),
    ("ST", "Boston"),
    ("L", "Boston"),
    ("O", "Test Company Ltd"),
    ("OU", "Test Company Ltd"),
    ("CN", "127.0.0.1"),
)


def create_self_signed_cert():
    """Generate a self-signed RSA-2048/SHA-256 certificate for 127.0.0.1.

    Writes ``test.crt`` and ``test.key`` (PEM) into the current directory.

    :returns: ``(cert_file, key_file, ca_file)``. The certificate is
        self-signed and therefore acts as its own CA, so ``ca_file`` is
        the same path as ``cert_file``.
    """
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)

    cert = crypto.X509()
    subject = cert.get_subject()
    for field, value in _CERT_SUBJECT:
        setattr(subject, field, value)
    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)  # ~10 years
    cert.set_issuer(cert.get_subject())
    cert.set_pubkey(key)
    cert.sign(key, "sha256")
    # NOTE(review): the IP is only in the CN; strict TLS clients match IP
    # hosts against subjectAltName only — add an "IP:127.0.0.1" SAN if
    # hostname verification fails on newer Python/urllib3.

    cert_file = 'test.crt'
    key_file = 'test.key'
    with open(cert_file, 'wb') as crt:
        crt.write(crypto.dump_certificate(crypto.FILETYPE_PEM, cert))
    with open(key_file, 'wb') as key_out:
        key_out.write(crypto.dump_privatekey(crypto.FILETYPE_PEM, key))
    return cert_file, key_file, cert_file


class TestEtcd3Gateway(base.TestCase):

    def test_client_default(self):
        client = Etcd3Client()
        self.assertEqual("http://localhost:2379/v3alpha/lease/grant",
                         client.get_url("/lease/grant"))

    def test_client_ipv4(self):
        client = Etcd3Client(host="127.0.0.1")
        self.assertEqual("http://127.0.0.1:2379/v3alpha/lease/grant",
                         client.get_url("/lease/grant"))

    def test_client_ipv6(self):
        client = Etcd3Client(host="::1")
        self.assertEqual("http://[::1]:2379/v3alpha/lease/grant",
                         client.get_url("/lease/grant"))

    def test_client_tls(self):
        cert_file, key_file, ca_file = create_self_signed_cert()
        # Register cleanups immediately so files are removed even if any
        # later step fails; the set avoids a double remove when ca_file is
        # cert_file.
        for path in {cert_file, key_file, ca_file}:
            self.addCleanup(os.remove, path)

        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(certfile=cert_file, keyfile=key_file)
        ctx.load_verify_locations(cafile=ca_file)

        server = ETCDServerThread(ctx)
        flag = threading.Event()
        server.start(flag)
        self.addCleanup(server.stop)
        flag.wait(10)

        try:
            client = Etcd3Client(host="127.0.0.1", protocol="https",
                                 ca_cert=ca_file, cert_key=key_file,
                                 cert_cert=cert_file, timeout=10)
            self.addCleanup(client.session.close)
            response = client.session.get(
                "https://127.0.0.1:%d/health" % server.port)
        except ValueError as e:
            self.fail("Connection failure to TLS etcd: %s" % e)

        try:
            self.assertEqual(200, response.status_code)
            self.assertEqual("{health:true}", response.text)
        finally:
            response.close()