diff --git a/swift/common/middleware/list_endpoints.py b/swift/common/middleware/list_endpoints.py index fbe3fad..8be54e2 100644 --- a/swift/common/middleware/list_endpoints.py +++ b/swift/common/middleware/list_endpoints.py @@ -64,6 +64,8 @@ from swift.common.swob import HTTPBadRequest, HTTPMethodNotAllowed from swift.common.storage_policy import POLICIES from swift.proxy.controllers.base import get_container_info +RESPONSE_VERSIONS = (1.0, 2.0) + class ListEndpointsMiddleware(object): """ @@ -87,6 +89,22 @@ class ListEndpointsMiddleware(object): self.endpoints_path = conf.get('list_endpoints_path', '/endpoints/') if not self.endpoints_path.endswith('/'): self.endpoints_path += '/' + response_version_option = conf.get( + 'default_response_version', 'v1') + try: + self.default_response_version = self._parse_version( + response_version_option) + except ValueError: + self.logger.warning( + 'Unknown response version %r specified for ' + '"default_response_version" option in list_endpoints ' + 'middleware configuration section, using v1 response ' + 'as default', response_version_option) + self.default_response_version = 1.0 + self.response_map = { + 1.0: self.v1_format_response, + 2.0: self.v2_format_response, + } def get_object_ring(self, policy_idx): """ @@ -97,6 +115,69 @@ class ListEndpointsMiddleware(object): """ return POLICIES.get_object_ring(policy_idx, self.swift_dir) + def _parse_version(self, raw_version): + err_msg = 'Unsupport version %r' % raw_version + try: + version = float(raw_version.lstrip('v')) + except ValueError: + raise ValueError(err_msg) + if not any(version == v for v in RESPONSE_VERSIONS): + raise ValueError(err_msg) + return version + + def _parse_path(self, request): + """ + Parse path parts of request into a tuple of version, account, + container, obj. Unspecified path parts are filled in as None, + excpet version which is always returned as a float using the + configured default response version if not specified in the + request. + + :param request: the swob request + + :returns: parsed path parts as a tuple with version filled in as + configured default response version if not specified. + :raises: ValueError if path is invalid, message will say why. + """ + clean_path = request.path[len(self.endpoints_path) - 1:] + # try to peal off version + try: + raw_version, rest = split_path(clean_path, 1, 2, True) + except ValueError: + raise ValueError('No account specified') + try: + version = self._parse_version(raw_version) + except ValueError: + if raw_version.startswith('v') and '_' not in raw_version: + # looks more like a invalid version than an account + raise + # probably no version specified, but if the client really + # said /endpoints/v_3/account they'll probably be sorta + # confused by the useless response and lack of error. + version = self.default_response_version + rest = clean_path + else: + rest = '/' + rest + try: + account, container, obj = split_path(rest, 1, 3, True) + except ValueError: + raise ValueError('No account specified') + return version, account, container, obj + + def v1_format_response(self, req, endpoints, **kwargs): + return Response(json.dumps(endpoints), + content_type='application/json') + + def v2_format_response(self, req, endpoints, storage_policy_index, + **kwargs): + resp = { + 'endpoints': endpoints + } + if storage_policy_index is not None: + resp['storage_policy_index'] = int(storage_policy_index) + return Response(json.dumps(resp), + content_type='application/json') + def __call__(self, env, start_response): request = Request(env) if not request.path.startswith(self.endpoints_path): @@ -107,11 +188,9 @@ class ListEndpointsMiddleware(object): req=request, headers={"Allow": "GET"})(env, start_response) try: - clean_path = request.path[len(self.endpoints_path) - 1:] - account, container, obj = \ - split_path(clean_path, 1, 3, True) - except ValueError: - return HTTPBadRequest('No account specified')(env, start_response) + version, account, container, obj = self._parse_path(request) + except ValueError as err: + return HTTPBadRequest(str(err))(env, start_response) if account is not None: account = unquote(account) @@ -120,16 +199,13 @@ class ListEndpointsMiddleware(object): if obj is not None: obj = unquote(obj) + storage_policy_index = None if obj is not None: - # remove 'endpoints' from call to get_container_info - stripped = request.environ - if stripped['PATH_INFO'][:len(self.endpoints_path)] == \ - self.endpoints_path: - stripped['PATH_INFO'] = "/v1/" + \ - stripped['PATH_INFO'][len(self.endpoints_path):] container_info = get_container_info( - stripped, self.app, swift_source='LE') - obj_ring = self.get_object_ring(container_info['storage_policy']) + {'PATH_INFO': '/v1/%s/%s' % (account, container)}, + self.app, swift_source='LE') + storage_policy_index = container_info['storage_policy'] + obj_ring = self.get_object_ring(storage_policy_index) partition, nodes = obj_ring.get_nodes( account, container, obj) endpoint_template = 'http://{ip}:{port}/{device}/{partition}/' + \ @@ -157,8 +233,13 @@ class ListEndpointsMiddleware(object): obj=quote(obj or '')) endpoints.append(endpoint) - return Response(json.dumps(endpoints), - content_type='application/json')(env, start_response) + resp_info = { + 'endpoints': endpoints, + 'storage_policy_index': storage_policy_index, + } + + resp = self.response_map[version](request, **resp_info) + return resp(env, start_response) def filter_factory(global_conf, **local_conf): diff --git a/test/unit/common/middleware/test_list_endpoints.py b/test/unit/common/middleware/test_list_endpoints.py index bb491ad..bd69c6a 100644 --- a/test/unit/common/middleware/test_list_endpoints.py +++ b/test/unit/common/middleware/test_list_endpoints.py @@ -25,7 +25,7 @@ from swift.common.utils import json, split_path from swift.common.swob import Request, Response from swift.common.middleware import list_endpoints from swift.common.storage_policy import StoragePolicy, POLICIES -from test.unit import patch_policies +from test.unit import patch_policies, debug_logger class FakeApp(object): @@ -110,10 +110,53 @@ class TestListEndpoints(unittest.TestCase): info['storage_policy'] = self.policy_to_test (version, account, container, unused) = \ split_path(env['PATH_INFO'], 3, 4, True) - self.assertEquals((version, account, container, unused), - self.expected_path) + self.assertEquals((version, account, container), + self.expected_path[:3]) return info + def test_parse_default_response_version(self): + expectations = { + '1': 1.0, + 'v1': 1.0, + '1.0': 1.0, + 'v1.0': 1.0, + '2': 2.0, + 'v2': 2.0, + '2.0': 2.0, + 'v2.0': 2.0, + } + for option, expected in expectations.items(): + config = { + 'swift_dir': self.testdir, + 'default_response_version': option, + } + filtered_app = list_endpoints.filter_factory(config)(self.app) + self.assertEqual(expected, filtered_app.default_response_version) + + def test_invalid_default_response_version(self): + bad_options = ( + '3', + 'v3', + '3.0', + 'v3.0', + 'a', + '', + ) + for option in bad_options: + config = { + 'swift_dir': self.testdir, + 'default_response_version': option, + } + logger = debug_logger() + patch_path = 'swift.common.middleware.list_endpoints.get_logger' + with mock.patch(patch_path, lambda *args, **kwargs: logger): + filtered_app = list_endpoints.filter_factory(config)(self.app) + warnings = logger.get_lines_for_level('warning') + for warning in warnings: + self.assert_('Unknown response version' in warning) + self.assert_('using v1 response' in warning) + self.assertEqual(1.0, filtered_app.default_response_version) + def test_get_object_ring(self): self.assertEquals(isinstance(self.list_endpoints.get_object_ring(0), ring.Ring), True) @@ -121,6 +164,30 @@ class TestListEndpoints(unittest.TestCase): ring.Ring), True) self.assertRaises(ValueError, self.list_endpoints.get_object_ring, 99) + def test_parse_path_no_version_specified(self): + req = Request.blank('/endpoints/a/c/o1') + version, account, container, obj = \ + self.list_endpoints._parse_path(req) + self.assertEqual(version, + self.list_endpoints.default_response_version) + self.assertEqual(account, 'a') + self.assertEqual(container, 'c') + self.assertEqual(obj, 'o1') + + def test_parse_path_with_valid_version(self): + req = Request.blank('/endpoints/v2/a/c/o1') + version, account, container, obj = \ + self.list_endpoints._parse_path(req) + self.assertEqual(version, 2.0) + self.assertEqual(account, 'a') + self.assertEqual(container, 'c') + self.assertEqual(obj, 'o1') + + def test_parse_path_with_invalid_version(self): + req = Request.blank('/endpoints/v3/a/c/o1') + self.assertRaises(ValueError, self.list_endpoints._parse_path, + req) + def test_get_endpoint(self): # Expected results for objects taken from test_ring # Expected results for others computed by manually invoking @@ -245,6 +312,77 @@ class TestListEndpoints(unittest.TestCase): self.assertEquals(resp.content_type, 'application/json') self.assertEquals(json.loads(resp.body), expected[pol.idx]) + def test_v1_response(self): + req = Request.blank('/endpoints/v1/a/c/o1') + resp = req.get_response(self.list_endpoints) + expected = ["http://10.1.1.1:6000/sdb1/1/a/c/o1", + "http://10.1.2.2:6000/sdd1/1/a/c/o1"] + self.assertEqual(resp.body, json.dumps(expected)) + + def test_v2_obj_response(self): + req = Request.blank('/endpoints/v2/a/c/o1') + resp = req.get_response(self.list_endpoints) + expected = { + 'endpoints': ["http://10.1.1.1:6000/sdb1/1/a/c/o1", + "http://10.1.2.2:6000/sdd1/1/a/c/o1"], + 'storage_policy_index': 0, # FWR, not found is legacy + } + self.assertEqual(resp.body, json.dumps(expected)) + for policy in POLICIES: + patch_path = 'swift.common.middleware.list_endpoints' \ + '.get_container_info' + mock_get_container_info = lambda *args, **kwargs: \ + {'storage_policy': int(policy)} + with mock.patch(patch_path, mock_get_container_info): + resp = req.get_response(self.list_endpoints) + part, nodes = policy.object_ring.get_nodes('a', 'c', 'o1') + [node.update({'part': part}) for node in nodes] + path = 'http://%(ip)s:%(port)s/%(device)s/%(part)s/a/c/o1' + expected = { + 'storage_policy_index': int(policy), + 'endpoints': [path % node for node in nodes], + } + self.assertEqual(resp.body, json.dumps(expected)) + + def test_v2_non_obj_response(self): + # account + req = Request.blank('/endpoints/v2/a') + resp = req.get_response(self.list_endpoints) + expected = { + 'endpoints': ["http://10.1.2.1:6000/sdc1/0/a", + "http://10.1.1.1:6000/sda1/0/a", + "http://10.1.1.1:6000/sdb1/0/a"], + } + # container + self.assertEqual(resp.body, json.dumps(expected)) + req = Request.blank('/endpoints/v2/a/c') + resp = req.get_response(self.list_endpoints) + expected = { + 'endpoints': ["http://10.1.2.2:6000/sdd1/0/a/c", + "http://10.1.1.1:6000/sda1/0/a/c", + "http://10.1.2.1:6000/sdc1/0/a/c"], + } + self.assertEqual(resp.body, json.dumps(expected)) + + def test_default_response(self): + req = Request.blank('/endpoints/a') + + self.list_endpoints.default_response_version = 1.0 + resp = req.get_response(self.list_endpoints) + expected = ["http://10.1.2.1:6000/sdc1/0/a", + "http://10.1.1.1:6000/sda1/0/a", + "http://10.1.1.1:6000/sdb1/0/a"] + self.assertEqual(resp.body, json.dumps(expected)) + + self.list_endpoints.default_response_version = 2.0 + expected = { + 'endpoints': ["http://10.1.2.1:6000/sdc1/0/a", + "http://10.1.1.1:6000/sda1/0/a", + "http://10.1.1.1:6000/sdb1/0/a"], + } + resp = req.get_response(self.list_endpoints) + self.assertEqual(resp.body, json.dumps(expected)) + if __name__ == '__main__': unittest.main()