=== modified file 'bzrlib/smart/medium.py' --- bzrlib/smart/medium.py 2009-08-07 05:56:29 +0000 +++ bzrlib/smart/medium.py 2009-08-22 21:20:40 +0000 @@ -271,6 +271,10 @@ sock.setblocking(True) self.socket = sock + def _close(self): + self.socket.close() + self.finished = True + def _serve_one_request_unguarded(self, protocol): while protocol.next_read_size(): # We can safely try to read large chunks. If there is less data @@ -291,8 +295,7 @@ def terminate_due_to_error(self): # TODO: This should log to a server log file, but no such thing # exists yet. Andrew Bennetts 2006-09-29. - self.socket.close() - self.finished = True + self._close() def _write_out(self, bytes): osutils.send_all(self.socket, bytes, self._report_activity) === modified file 'bzrlib/smart/server.py' --- bzrlib/smart/server.py 2009-07-20 11:27:05 +0000 +++ bzrlib/smart/server.py 2009-08-24 10:33:00 +0000 @@ -84,6 +84,8 @@ self._started = threading.Event() self._stopped = threading.Event() self.root_client_path = root_client_path + self._keep_track_of_client_connections = False + self._client_connections = None def serve(self, thread_name_suffix=''): self._should_terminate = False @@ -159,12 +161,22 @@ conn.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) handler = medium.SmartServerSocketStreamMedium( conn, self.backing_transport, self.root_client_path) + self._register_client_connection(handler) thread_name = 'smart-server-child' + thread_name_suffix connection_thread = threading.Thread( None, handler.serve, name=thread_name) + connection_thread._my_handler = handler connection_thread.setDaemon(True) connection_thread.start() + def _register_client_connection(self, handler): + if not self._keep_track_of_client_connections: + return + if self._client_connections is None: + import weakref + self._client_connections = weakref.WeakKeyDictionary() + self._client_connections[handler] = True + def start_background_thread(self, thread_name_suffix=''): self._started.clear() self._server_thread = threading.Thread(None, @@ -198,6 +210,9 @@ temp_socket.close() self._stopped.wait() self._server_thread.join() + if self._client_connections: + for x in self._client_connections.keys(): + x._close() class SmartServerHooks(Hooks): @@ -257,6 +272,7 @@ `bzr://127.0.0.1:nnnn/`. Default value is `extra`, so that tests by default will fail unless they do the necessary path translation. """ + self._keep_track_of_client_connections = True if not client_path_extra.startswith('/'): raise ValueError(client_path_extra) from bzrlib.transport.chroot import ChrootServer === modified file 'bzrlib/tests/__init__.py' --- bzrlib/tests/__init__.py 2009-08-20 05:05:59 +0000 +++ bzrlib/tests/__init__.py 2009-08-22 17:13:38 +0000 @@ -774,6 +774,38 @@ return NullProgressView() +COUNTING = 0 +MAIN_THREAD = threading.current_thread() + +def wait_for_threads(): + curr = threading.current_thread() + for t in threading.enumerate(): + if t != curr and t != MAIN_THREAD: + try: + t.join(0.1) + except RuntimeError: + pass + +def not_too_many_threads(): + cnt_beg = threading.active_count() + while cnt_beg > 1: + import gc, time + for i in range(1,5): + wait_for_threads() + gc.collect() + cnt_end = threading.active_count() + if cnt_end >= cnt_beg: + #for t in threading.enumerate(): + # print t + from pydbgr.api import debug + debug() + raise errors.BzrError( + "thread leak: %d threads at test %d" + % (cnt_end, COUNTING)) + else: + cnt_beg = cnt_end + + class TestCase(unittest.TestCase): """Base class for bzr unit tests. @@ -1426,6 +1458,9 @@ addSkip(self, reason) def run(self, result=None): + global COUNTING + COUNTING += 1 + not_too_many_threads() if result is None: result = self.defaultTestResult() for feature in getattr(self, '_test_needs_features', []): if not feature.available(): === modified file 'bzrlib/transport/sftp.py' --- bzrlib/transport/sftp.py 2009-06-10 03:56:49 +0000 +++ bzrlib/transport/sftp.py 2009-08-22 17:29:00 +0000 @@ -1063,6 +1063,7 @@ def _run_server(self, s): ssh_server = paramiko.Transport(s) + self._ssh_server = ssh_server key_file = pathjoin(self._homedir, 'test_rsa.key') f = open(key_file, 'w') f.write(STUB_SERVER_KEY) @@ -1105,6 +1106,9 @@ def tearDown(self): """See bzrlib.transport.Server.tearDown.""" self._listener.stop() + ssh_server = getattr(self, '_ssh_server', None) + if ssh_server: + ssh_server.close() ssh._ssh_vendor_manager._cached_ssh_vendor = self._original_vendor def get_bogus_url(self):