diff --git a/bs4/element.py b/bs4/element.py index c70ad5a..8042fd4 100644 --- a/bs4/element.py +++ b/bs4/element.py @@ -1286,9 +1286,23 @@ class Tag(PageElement): def select(self, selector, _candidate_generator=None, limit=None): """Perform a CSS selection operation on the current element.""" - # Remove whitespace directly after the grouping operator ',' - # then split into tokens. - tokens = re.sub(',[\s]*',',', selector).split() + # Handle grouping selectors if ',' exists, ie: p,a + if ',' in selector: + context = [] + for partial_selector in selector.split(','): + partial_selector = partial_selector.strip() + if partial_selector == '': + raise ValueError('Invalid group selection syntax: %s' % selector) + candidates = self.select(partial_selector, limit=limit) + for candidate in candidates: + if candidate not in context: + context.append(candidate) + + if limit and len(context) >= limit: + break + return context + + tokens = selector.split() current_context = [self] if tokens[-1] in self._selector_combinators: @@ -1298,198 +1312,192 @@ class Tag(PageElement): if self._select_debug: print 'Running CSS selector "%s"' % selector - for index, token_group in enumerate(tokens): + for index, token in enumerate(tokens): new_context = [] new_context_ids = set([]) - # Grouping selectors, ie: p,a - grouped_tokens = token_group.split(',') - if '' in grouped_tokens: - raise ValueError('Invalid group selection syntax: %s' % token_group) - if tokens[index-1] in self._selector_combinators: # This token was consumed by the previous combinator. Skip it. if self._select_debug: print ' Token was consumed by the previous combinator.' continue - for token in grouped_tokens: - if self._select_debug: - print ' Considering token "%s"' % token - recursive_candidate_generator = None - tag_name = None - - # Each operation corresponds to a checker function, a rule - # for determining whether a candidate matches the - # selector. Candidates are generated by the active - # iterator. - checker = None - - m = self.attribselect_re.match(token) - if m is not None: - # Attribute selector - tag_name, attribute, operator, value = m.groups() - checker = self._attribute_checker(operator, attribute, value) - - elif '#' in token: - # ID selector - tag_name, tag_id = token.split('#', 1) - def id_matches(tag): - return tag.get('id', None) == tag_id - checker = id_matches - - elif '.' in token: - # Class selector - tag_name, klass = token.split('.', 1) - classes = set(klass.split('.')) - def classes_match(candidate): - return classes.issubset(candidate.get('class', [])) - checker = classes_match - - elif ':' in token: - # Pseudo-class - tag_name, pseudo = token.split(':', 1) - if tag_name == '': - raise ValueError( - "A pseudo-class must be prefixed with a tag name.") - pseudo_attributes = re.match('([a-zA-Z\d-]+)\(([a-zA-Z\d]+)\)', pseudo) - found = [] - if pseudo_attributes is None: - pseudo_type = pseudo - pseudo_value = None - else: - pseudo_type, pseudo_value = pseudo_attributes.groups() - if pseudo_type == 'nth-of-type': - try: - pseudo_value = int(pseudo_value) - except: - raise NotImplementedError( - 'Only numeric values are currently supported for the nth-of-type pseudo-class.') - if pseudo_value < 1: - raise ValueError( - 'nth-of-type pseudo-class value must be at least 1.') - class Counter(object): - def __init__(self, destination): - self.count = 0 - self.destination = destination - - def nth_child_of_type(self, tag): - self.count += 1 - if self.count == self.destination: - return True - if self.count > self.destination: - # Stop the generator that's sending us - # these things. - raise StopIteration() - return False - checker = Counter(pseudo_value).nth_child_of_type - else: + if self._select_debug: + print ' Considering token "%s"' % token + recursive_candidate_generator = None + tag_name = None + + # Each operation corresponds to a checker function, a rule + # for determining whether a candidate matches the + # selector. Candidates are generated by the active + # iterator. + checker = None + + m = self.attribselect_re.match(token) + if m is not None: + # Attribute selector + tag_name, attribute, operator, value = m.groups() + checker = self._attribute_checker(operator, attribute, value) + + elif '#' in token: + # ID selector + tag_name, tag_id = token.split('#', 1) + def id_matches(tag): + return tag.get('id', None) == tag_id + checker = id_matches + + elif '.' in token: + # Class selector + tag_name, klass = token.split('.', 1) + classes = set(klass.split('.')) + def classes_match(candidate): + return classes.issubset(candidate.get('class', [])) + checker = classes_match + + elif ':' in token: + # Pseudo-class + tag_name, pseudo = token.split(':', 1) + if tag_name == '': + raise ValueError( + "A pseudo-class must be prefixed with a tag name.") + pseudo_attributes = re.match('([a-zA-Z\d-]+)\(([a-zA-Z\d]+)\)', pseudo) + found = [] + if pseudo_attributes is None: + pseudo_type = pseudo + pseudo_value = None + else: + pseudo_type, pseudo_value = pseudo_attributes.groups() + if pseudo_type == 'nth-of-type': + try: + pseudo_value = int(pseudo_value) + except: raise NotImplementedError( - 'Only the following pseudo-classes are implemented: nth-of-type.') - - elif token == '*': - # Star selector -- matches everything - pass - elif token == '>': - # Run the next token as a CSS selector against the - # direct children of each tag in the current context. - recursive_candidate_generator = lambda tag: tag.children - elif token == '~': - # Run the next token as a CSS selector against the - # siblings of each tag in the current context. - recursive_candidate_generator = lambda tag: tag.next_siblings - elif token == '+': - # For each tag in the current context, run the next - # token as a CSS selector against the tag's next - # sibling that's a tag. - def next_tag_sibling(tag): - yield tag.find_next_sibling(True) - recursive_candidate_generator = next_tag_sibling - - elif self.tag_name_re.match(token): - # Just a tag name. - tag_name = token + 'Only numeric values are currently supported for the nth-of-type pseudo-class.') + if pseudo_value < 1: + raise ValueError( + 'nth-of-type pseudo-class value must be at least 1.') + class Counter(object): + def __init__(self, destination): + self.count = 0 + self.destination = destination + + def nth_child_of_type(self, tag): + self.count += 1 + if self.count == self.destination: + return True + if self.count > self.destination: + # Stop the generator that's sending us + # these things. + raise StopIteration() + return False + checker = Counter(pseudo_value).nth_child_of_type else: - raise ValueError( - 'Unsupported or invalid CSS selector: "%s"' % token) - if recursive_candidate_generator: - # This happens when the selector looks like "> foo". - # - # The generator calls select() recursively on every - # member of the current context, passing in a different - # candidate generator and a different selector. - # - # In the case of "> foo", the candidate generator is - # one that yields a tag's direct children (">"), and - # the selector is "foo". - next_token = tokens[index+1] - def recursive_select(tag): - if self._select_debug: - print ' Calling select("%s") recursively on %s %s' % (next_token, tag.name, tag.attrs) - print '-' * 40 - for i in tag.select(next_token, recursive_candidate_generator): - if self._select_debug: - print '(Recursive select picked up candidate %s %s)' % (i.name, i.attrs) - yield i - if self._select_debug: - print '-' * 40 - _use_candidate_generator = recursive_select - elif _candidate_generator is None: - # By default, a tag's candidates are all of its - # children. If tag_name is defined, only yield tags - # with that name. + raise NotImplementedError( + 'Only the following pseudo-classes are implemented: nth-of-type.') + + elif token == '*': + # Star selector -- matches everything + pass + elif token == '>': + # Run the next token as a CSS selector against the + # direct children of each tag in the current context. + recursive_candidate_generator = lambda tag: tag.children + elif token == '~': + # Run the next token as a CSS selector against the + # siblings of each tag in the current context. + recursive_candidate_generator = lambda tag: tag.next_siblings + elif token == '+': + # For each tag in the current context, run the next + # token as a CSS selector against the tag's next + # sibling that's a tag. + def next_tag_sibling(tag): + yield tag.find_next_sibling(True) + recursive_candidate_generator = next_tag_sibling + + elif self.tag_name_re.match(token): + # Just a tag name. + tag_name = token + else: + raise ValueError( + 'Unsupported or invalid CSS selector: "%s"' % token) + if recursive_candidate_generator: + # This happens when the selector looks like "> foo". + # + # The generator calls select() recursively on every + # member of the current context, passing in a different + # candidate generator and a different selector. + # + # In the case of "> foo", the candidate generator is + # one that yields a tag's direct children (">"), and + # the selector is "foo". + next_token = tokens[index+1] + def recursive_select(tag): if self._select_debug: - if tag_name: - check = "[any]" - else: - check = tag_name - print ' Default candidate generator, tag name="%s"' % check + print ' Calling select("%s") recursively on %s %s' % (next_token, tag.name, tag.attrs) + print '-' * 40 + for i in tag.select(next_token, recursive_candidate_generator): + if self._select_debug: + print '(Recursive select picked up candidate %s %s)' % (i.name, i.attrs) + yield i if self._select_debug: - # This is redundant with later code, but it stops - # a bunch of bogus tags from cluttering up the - # debug log. - def default_candidate_generator(tag): - for child in tag.descendants: - if not isinstance(child, Tag): - continue - if tag_name and not child.name == tag_name: - continue - yield child - _use_candidate_generator = default_candidate_generator + print '-' * 40 + _use_candidate_generator = recursive_select + elif _candidate_generator is None: + # By default, a tag's candidates are all of its + # children. If tag_name is defined, only yield tags + # with that name. + if self._select_debug: + if tag_name: + check = "[any]" else: - _use_candidate_generator = lambda tag: tag.descendants + check = tag_name + print ' Default candidate generator, tag name="%s"' % check + if self._select_debug: + # This is redundant with later code, but it stops + # a bunch of bogus tags from cluttering up the + # debug log. + def default_candidate_generator(tag): + for child in tag.descendants: + if not isinstance(child, Tag): + continue + if tag_name and not child.name == tag_name: + continue + yield child + _use_candidate_generator = default_candidate_generator else: - _use_candidate_generator = _candidate_generator + _use_candidate_generator = lambda tag: tag.descendants + else: + _use_candidate_generator = _candidate_generator - count = 0 - for tag in current_context: - if self._select_debug: - print " Running candidate generator on %s %s" % ( - tag.name, repr(tag.attrs)) - for candidate in _use_candidate_generator(tag): - if not isinstance(candidate, Tag): - continue - if tag_name and candidate.name != tag_name: - continue - if checker is not None: - try: - result = checker(candidate) - except StopIteration: - # The checker has decided we should no longer - # run the generator. + count = 0 + for tag in current_context: + if self._select_debug: + print " Running candidate generator on %s %s" % ( + tag.name, repr(tag.attrs)) + for candidate in _use_candidate_generator(tag): + if not isinstance(candidate, Tag): + continue + if tag_name and candidate.name != tag_name: + continue + if checker is not None: + try: + result = checker(candidate) + except StopIteration: + # The checker has decided we should no longer + # run the generator. + break + if checker is None or result: + if self._select_debug: + print " SUCCESS %s %s" % (candidate.name, repr(candidate.attrs)) + if id(candidate) not in new_context_ids: + # If a tag matches a selector more than once, + # don't include it in the context more than once. + new_context.append(candidate) + new_context_ids.add(id(candidate)) + if limit and len(new_context) >= limit: break - if checker is None or result: - if self._select_debug: - print " SUCCESS %s %s" % (candidate.name, repr(candidate.attrs)) - if id(candidate) not in new_context_ids: - # If a tag matches a selector more than once, - # don't include it in the context more than once. - new_context.append(candidate) - new_context_ids.add(id(candidate)) - if limit and len(new_context) >= limit: - break - elif self._select_debug: - print " FAILURE %s %s" % (candidate.name, repr(candidate.attrs)) + elif self._select_debug: + print " FAILURE %s %s" % (candidate.name, repr(candidate.attrs)) current_context = new_context diff --git a/bs4/tests/test_tree.py b/bs4/tests/test_tree.py index 2371591..8aac22e 100644 --- a/bs4/tests/test_tree.py +++ b/bs4/tests/test_tree.py @@ -1942,22 +1942,25 @@ class TestSoupSelector(TreeTest): # Test the selector grouping operator (the comma) def test_multiple_select(self): - self.assertSelects('x, y',['xid','yid']) + self.assertSelects('x, y', ['xid', 'yid']) def test_multiple_select_with_no_space(self): - self.assertSelects('x,y',['xid','yid']) + self.assertSelects('x,y', ['xid', 'yid']) def test_multiple_select_with_more_space(self): - self.assertSelects('x, y',['xid', 'yid']) + self.assertSelects('x, y', ['xid', 'yid']) + + def test_multiple_select_duplicated(self): + self.assertSelects('x, x', ['xid']) def test_multiple_select_sibling(self): - self.assertSelects('x, y ~ p[lang=fr]',['lang-fr']) + self.assertSelects('x, y ~ p[lang=fr]', ['xid', 'lang-fr']) - def test_multiple_select(self): - self.assertSelects('x, y > z', ['zida', 'zidb', 'zidab', 'zidac']) + def test_multiple_select_tag_and_direct_descendant(self): + self.assertSelects('x, y > z', ['xid', 'zidb']) - def test_multiple_select_direct_descendant(self): - self.assertSelects('div > x, y, z', ['xid', 'yid']) + def test_multiple_select_direct_descendant_and_tags(self): + self.assertSelects('div > x, y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac']) def test_multiple_select_indirect_descendant(self): self.assertSelects('div x,y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac']) @@ -1966,14 +1969,14 @@ class TestSoupSelector(TreeTest): self.assertRaises(ValueError, self.soup.select, ',x, y') self.assertRaises(ValueError, self.soup.select, 'x,,y') - def test_multiple_select(self): - self.assertSelects('p[lang=en], p[lang=en-gb]',['lang-en','lang-en-gb']) + def test_multiple_select_attrs(self): + self.assertSelects('p[lang=en], p[lang=en-gb]', ['lang-en', 'lang-en-gb']) def test_multiple_select_ids(self): - self.assertSelects('x, y > z[id=zida], z[id=zidab], z[id=zidb]', ['zida', 'zidb','zidab']) + self.assertSelects('x, y > z[id=zida], z[id=zidab], z[id=zidb]', ['xid', 'zidb', 'zidab']) def test_multiple_select_nested(self): - self.assertSelects('body > div > x, y > z', ['zida', 'zidb', 'zidab', 'zidac']) + self.assertSelects('body > div > x, y > z', ['xid', 'zidb'])