# Patch summary for bzrlib/transport/sftp.py (test-server plumbing):
#
#   * Imports `select` so the listener thread can poll its socket.
#   * Adds a blank line before `class SFTPTransport`, and replaces the two
#     blank lines before the listener class (the exact whitespace of the
#     removed lines is not visible in this flattened copy).
#   * Renames SingleListener -> SocketListener and stop_event -> _stop_event.
#     The old run() accepted exactly one connection, closed the listening
#     socket, and called callback(s, stop_event) inline on the listener
#     thread.  The new run() loops: it calls select() on the listening socket
#     with a 0.1 s timeout, returns as soon as _stop_event is set, and hands
#     each accepted socket to callback(s) on a freshly started thread.
#     Non-socket exceptions are reported via sys.excepthook plus warning().
#   * stop() sets _stop_event, joins with a 5 s timeout, then closes the
#     listening socket.
#   * Both _run_server callbacks drop the stop_event parameter.
#     SFTPServer._run_server no longer blocks on stop_event.wait(30.0).
#   * SFTPServer._run_server writes the stub key via open()/write()/close()
#     instead of an unclosed file(...).write(...).
#   * setUp() constructs SocketListener instead of SingleListener.
#   * The loopback FakeChannel gains a no-op close() method.
#
# Review notes (NOT addressed by the hunks below; hedged, verify):
#   * The per-connection threads started in SocketListener.run() are not
#     marked daemon, whereas the listener itself is (setDaemon(True) in
#     setUp).  The in-code comment says transports are never explicitly
#     closed.  A handler that never returns could therefore keep the test
#     process alive at interpreter exit.  Consider setDaemon(True) before
#     start().
#   * socket.error from accept() is discarded with `pass`, and the bound
#     name `x` is unused.  Consider logging it like the generic branch.
#   * Now that several connections can be accepted, each handler thread
#     re-opens test_rsa.key with 'w' (truncating it) before reading it back.
#     Concurrent connections could presumably race on that file.  TODO
#     confirm whether tests ever open overlapping connections.
#   * stop() closes the listening socket even if join(5.0) timed out.  run()
#     may then still be inside select() on a closed socket, outside the try
#     block.  Verify that this cannot surface as an unhandled thread error.
#
=== modified file 'bzrlib/transport/sftp.py' --- bzrlib/transport/sftp.py +++ bzrlib/transport/sftp.py @@ -22,6 +22,7 @@ import os import random import re +import select import stat import subprocess import sys @@ -300,6 +301,7 @@ except (NoSuchFile,): # What specific errors should we catch here? pass + class SFTPTransport (Transport): """ @@ -905,9 +907,9 @@ nvuQES5C9BMHjF39LZiGH1iLQy7FgdHyoP+eodI7 -----END RSA PRIVATE KEY----- """ - - -class SingleListener(threading.Thread): + + +class SocketListener(threading.Thread): def __init__(self, callback): threading.Thread.__init__(self) @@ -917,25 +919,33 @@ self._socket.bind(('localhost', 0)) self._socket.listen(1) self.port = self._socket.getsockname()[1] - self.stop_event = threading.Event() - - def run(self): - s, _ = self._socket.accept() - # now close the listen socket - self._socket.close() - try: - self._callback(s, self.stop_event) - except socket.error: - pass #Ignore socket errors - except Exception, x: - # probably a failed test - warning('Exception from within unit test server thread: %r' % x) + self._stop_event = threading.Event() def stop(self): - self.stop_event.set() + self._stop_event.set() # use a timeout here, because if the test fails, the server thread may # never notice the stop_event. self.join(5.0) + self._socket.close() + + def run(self): + while True: + readable, _, _ = select.select([self._socket], [], [], 0.1) + if self._stop_event.isSet(): + return + if len(readable) == 0: + continue + try: + s, _ = self._socket.accept() + # because the loopback socket is inline, and transports are + # never explicitly closed, best to launch a new thread. 
+ threading.Thread(target=self._callback, args=(s,)).start() + except socket.error, x: + pass #Ignore socket errors + except Exception, x: + # probably a failed test + sys.excepthook(*sys.exc_info()) + warning('Exception from within unit test server thread: %r' % x) class SFTPServer(Server): @@ -959,10 +969,12 @@ """StubServer uses this to log when a new server is created.""" self.logs.append(message) - def _run_server(self, s, stop_event): + def _run_server(self, s): ssh_server = paramiko.Transport(s) key_file = os.path.join(self._homedir, 'test_rsa.key') - file(key_file, 'w').write(STUB_SERVER_KEY) + f = open(key_file, 'w') + f.write(STUB_SERVER_KEY) + f.close() host_key = paramiko.RSAKey.from_private_key_file(key_file) ssh_server.add_server_key(host_key) server = StubServer(self) @@ -972,7 +984,6 @@ event = threading.Event() ssh_server.start_server(event, server) event.wait(5.0) - stop_event.wait(30.0) def setUp(self): global _ssh_vendor @@ -983,7 +994,7 @@ self._server_homedir = self._homedir self._root = '/' # FIXME WINDOWS: _root should be _server_homedir[0]:/ - self._listener = SingleListener(self._run_server) + self._listener = SocketListener(self._run_server) self._listener.setDaemon(True) self._listener.start() @@ -1009,7 +1020,7 @@ super(SFTPServerWithoutSSH, self).__init__() self._vendor = 'loopback' - def _run_server(self, sock, stop_event): + def _run_server(self, sock): class FakeChannel(object): def get_transport(self): return self @@ -1019,6 +1030,8 @@ return '1' def get_hexdump(self): return False + def close(self): + pass server = paramiko.SFTPServer(FakeChannel(), 'sftp', StubServer(self), StubSFTPServer, root=self._root, home=self._server_homedir)