Introduce get_elements_by_class and get_elements_by_attribute utility functions
authorThomas Christlieb <thomaschristlieb@hotmail.com>
Sat, 11 Feb 2017 09:16:54 +0000 (10:16 +0100)
committerSergey M <dstftw@gmail.com>
Sat, 11 Feb 2017 09:16:54 +0000 (17:16 +0800)
test/test_utils.py
youtube_dl/utils.py

index edc712f0741576c852be2b528f95dcf81f309bfc..3cdb21d409b97c4dec4a1c205960ae133642c620 100644 (file)
@@ -34,6 +34,9 @@ from youtube_dl.utils import (
     find_xpath_attr,
     fix_xml_ampersands,
     get_element_by_class,
+    get_element_by_attribute,
+    get_elements_by_class,
+    get_elements_by_attribute,
     InAdvancePagedList,
     intlist_to_bytes,
     is_html,
@@ -1124,6 +1127,32 @@ The first line
         self.assertEqual(get_element_by_class('foo', html), 'nice')
         self.assertEqual(get_element_by_class('no-such-class', html), None)
 
+    def test_get_element_by_attribute(self):
+        html = '''
+            <span class="foo bar">nice</span>
+        '''
+
+        self.assertEqual(get_element_by_attribute('class', 'foo bar', html), 'nice')
+        self.assertEqual(get_element_by_attribute('class', 'foo', html), None)
+        self.assertEqual(get_element_by_attribute('class', 'no-such-foo', html), None)
+
+    def test_get_elements_by_class(self):
+        html = '''
+            <span class="foo bar">nice</span><span class="foo bar">also nice</span>
+        '''
+
+        self.assertEqual(get_elements_by_class('foo', html), ['nice', 'also nice'])
+        self.assertEqual(get_elements_by_class('no-such-class', html), [])
+
+    def test_get_elements_by_attribute(self):
+        html = '''
+            <span class="foo bar">nice</span><span class="foo bar">also nice</span>
+        '''
+
+        self.assertEqual(get_elements_by_attribute('class', 'foo bar', html), ['nice', 'also nice'])
+        self.assertEqual(get_elements_by_attribute('class', 'foo', html), [])
+        self.assertEqual(get_elements_by_attribute('class', 'no-such-foo', html), [])
+
 
 if __name__ == '__main__':
     unittest.main()
index 67a847ebad8238fc4f368f46b336b80e6caa3673..a81fe7d30c4b26e9a4ec1861dd6e9c2a3f1f590d 100644 (file)
@@ -337,17 +337,30 @@ def get_element_by_id(id, html):
 
 
 def get_element_by_class(class_name, html):
-    return get_element_by_attribute(
+    """Return the content of the first tag with the specified class in the passed HTML document"""
+    retval = get_elements_by_class(class_name, html)
+    return retval[0] if retval else None
+
+
+def get_element_by_attribute(attribute, value, html, escape_value=True):
+    retval = get_elements_by_attribute(attribute, value, html, escape_value)
+    return retval[0] if retval else None
+
+
+def get_elements_by_class(class_name, html):
+    """Return the content of all tags with the specified class in the passed HTML document as a list"""
+    return get_elements_by_attribute(
         'class', r'[^\'"]*\b%s\b[^\'"]*' % re.escape(class_name),
         html, escape_value=False)
 
 
-def get_element_by_attribute(attribute, value, html, escape_value=True):
+def get_elements_by_attribute(attribute, value, html, escape_value=True):
     """Return the content of the tag with the specified attribute in the passed HTML document"""
 
     value = re.escape(value) if escape_value else value
 
-    m = re.search(r'''(?xs)
+    retlist = []
+    for m in re.finditer(r'''(?xs)
         <([a-zA-Z0-9:._-]+)
          (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'))*?
          \s+%s=['"]?%s['"]?
@@ -355,16 +368,15 @@ def get_element_by_attribute(attribute, value, html, escape_value=True):
         \s*>
         (?P<content>.*?)
         </\1>
-    ''' % (re.escape(attribute), value), html)
+    ''' % (re.escape(attribute), value), html):
+        res = m.group('content')
 
-    if not m:
-        return None
-    res = m.group('content')
+        if res.startswith('"') or res.startswith("'"):
+            res = res[1:-1]
 
-    if res.startswith('"') or res.startswith("'"):
-        res = res[1:-1]
+        retlist.append(unescapeHTML(res))
 
-    return unescapeHTML(res)
+    return retlist
 
 
 class HTMLAttributeParser(compat_HTMLParser):