[swfinterp] Interpret more multinames
[youtube-dl] / youtube_dl / swfinterp.py
index 1cd2921386f3a86cc12a95545478311e571a268a..7c0ee1e61a89d13d56bf0f2c54111bd738054890 100644 (file)
@@ -2,23 +2,42 @@ from __future__ import unicode_literals
 
 import collections
 import io
-import struct
 import zlib
 
-from .utils import ExtractorError
-
-
-def _extract_tags(content):
-    pos = 0
+from .utils import (
+    compat_str,
+    ExtractorError,
+    struct_unpack,
+)
+
+
+def _extract_tags(file_contents):
+    if file_contents[1:3] != b'WS':
+        raise ExtractorError(
+            'Not an SWF file; header is %r' % file_contents[:3])
+    if file_contents[:1] == b'C':
+        content = zlib.decompress(file_contents[8:])
+    else:
+        raise NotImplementedError(
+            'Unsupported compression format %r' %
+            file_contents[:1])
+
+    # Determine number of bits in framesize rectangle
+    framesize_nbits = struct_unpack('!B', content[:1])[0] >> 3
+    framesize_len = (5 + 4 * framesize_nbits + 7) // 8
+
+    pos = framesize_len + 2 + 2
     while pos < len(content):
-        header16 = struct.unpack('<H', content[pos:pos + 2])[0]
+        header16 = struct_unpack('<H', content[pos:pos + 2])[0]
         pos += 2
         tag_code = header16 >> 6
         tag_len = header16 & 0x3f
         if tag_len == 0x3f:
-            tag_len = struct.unpack('<I', content[pos:pos + 4])[0]
+            tag_len = struct_unpack('<I', content[pos:pos + 4])[0]
             pos += 4
-        assert pos + tag_len <= len(content)
+        assert pos + tag_len <= len(content), \
+            ('Tag %d ends at %d+%d - that\'s longer than the file (%d)'
+                % (tag_code, pos, tag_len, len(content)))
         yield (tag_code, content[pos:pos + tag_len])
         pos += tag_len
 
@@ -31,6 +50,17 @@ class _AVMClass_Object(object):
         return '%s#%x' % (self.avm_class.name, id(self))
 
 
+class _ScopeDict(dict):
+    def __init__(self, avm_class):
+        super(_ScopeDict, self).__init__()
+        self.avm_class = avm_class
+
+    def __repr__(self):
+        return '%s__Scope(%s)' % (
+            self.avm_class.name,
+            super(_ScopeDict, self).__repr__())
+
+
 class _AVMClass(object):
     def __init__(self, name_idx, name):
         self.name_idx = name_idx
@@ -39,11 +69,29 @@ class _AVMClass(object):
         self.method_idxs = {}
         self.methods = {}
         self.method_pyfunctions = {}
-        self.variables = {}
+
+        self.variables = _ScopeDict(self)
 
     def make_object(self):
         return _AVMClass_Object(self)
 
+    def __repr__(self):
+        return '_AVMClass(%s)' % (self.name)
+
+    def register_methods(self, methods):
+        self.method_names.update(methods.items())
+        self.method_idxs.update(dict(
+            (idx, name)
+            for name, idx in methods.items()))
+
+
+class _Multiname(object):
+    def __init__(self, kind):
+        self.kind = kind
+
+    def __repr__(self):
+        return '[MULTINAME kind: 0x%x]' % self.kind
+
 
 def _read_int(reader):
     res = 0
@@ -51,7 +99,7 @@ def _read_int(reader):
     for _ in range(5):
         buf = reader.read(1)
         assert len(buf) == 1
-        b = struct.unpack('<B', buf)[0]
+        b = struct_unpack('<B', buf)[0]
         res = res | ((b & 0x7f) << shift)
         if b & 0x80 == 0:
             break
@@ -63,7 +111,7 @@ def _u30(reader):
     res = _read_int(reader)
     assert res & 0xf0000000 == 0
     return res
-u32 = _read_int
+_u32 = _read_int
 
 
 def _s32(reader):
@@ -76,8 +124,8 @@ def _s32(reader):
 def _s24(reader):
     bs = reader.read(3)
     assert len(bs) == 3
-    first_byte = b'\xff' if (ord(bs[0:1]) >= 0x80) else b'\x00'
-    return struct.unpack('!i', first_byte + bs)
+    last_byte = b'\xff' if (ord(bs[2:3]) >= 0x80) else b'\x00'
+    return struct_unpack('<i', bs + last_byte)[0]
 
 
 def _read_string(reader):
@@ -88,8 +136,7 @@ def _read_string(reader):
 
 
 def _read_bytes(count, reader):
-    if reader is None:
-        reader = code_reader
+    assert count >= 0
     resb = reader.read(count)
     assert len(resb) == count
     return resb
@@ -97,24 +144,15 @@ def _read_bytes(count, reader):
 
 def _read_byte(reader):
     resb = _read_bytes(1, reader=reader)
-    res = struct.unpack('<B', resb)[0]
+    res = struct_unpack('<B', resb)[0]
     return res
 
 
 class SWFInterpreter(object):
     def __init__(self, file_contents):
-        if file_contents[1:3] != b'WS':
-            raise ExtractorError(
-                'Not an SWF file; header is %r' % file_contents[:3])
-        if file_contents[:1] == b'C':
-            content = zlib.decompress(file_contents[8:])
-        else:
-            raise NotImplementedError(
-                'Unsupported compression format %r' %
-                file_contents[:1])
-
+        self._patched_functions = {}
         code_tag = next(tag
-                        for tag_code, tag in _extract_tags(content)
+                        for tag_code, tag in _extract_tags(file_contents)
                         if tag_code == 82)
         p = code_tag.index(b'\0', 4) + 1
         code_reader = io.BytesIO(code_tag[p:])
@@ -139,12 +177,12 @@ class SWFInterpreter(object):
         for _c in range(1, uint_count):
             u32()
         double_count = u30()
-        read_bytes((double_count - 1) * 8)
+        read_bytes(max(0, (double_count - 1)) * 8)
         string_count = u30()
-        constant_strings = ['']
+        self.constant_strings = ['']
         for _c in range(1, string_count):
             s = _read_string(code_reader)
-            constant_strings.append(s)
+            self.constant_strings.append(s)
         namespace_count = u30()
         for _c in range(1, namespace_count):
             read_bytes(1)  # kind
@@ -174,9 +212,13 @@ class SWFInterpreter(object):
             if kind == 0x07:
                 u30()  # namespace_idx
                 name_idx = u30()
-                self.multinames.append(constant_strings[name_idx])
+                self.multinames.append(self.constant_strings[name_idx])
+            elif kind == 0x09:
+                name_idx = u30()
+                u30()
+                self.multinames.append(self.constant_strings[name_idx])
             else:
-                self.multinames.append('[MULTINAME kind: %d]' % kind)
+                self.multinames.append(_Multiname(kind))
                 for _c2 in range(MULTINAME_SIZES[kind]):
                     u30()
 
@@ -253,7 +295,11 @@ class SWFInterpreter(object):
         classes = []
         for class_id in range(class_count):
             name_idx = u30()
-            classes.append(_AVMClass(name_idx, self.multinames[name_idx]))
+
+            cname = self.multinames[name_idx]
+            avm_class = _AVMClass(name_idx, cname)
+            classes.append(avm_class)
+
             u30()  # super_name idx
             flags = read_byte()
             if flags & 0x08 != 0:  # Protected namespace is present
@@ -264,7 +310,9 @@ class SWFInterpreter(object):
             u30()  # iinit
             trait_count = u30()
             for _c2 in range(trait_count):
-                parse_traits_info()
+                trait_methods = parse_traits_info()
+                avm_class.register_methods(trait_methods)
+
         assert len(classes) == class_count
         self._classes_by_name = dict((c.name, c) for c in classes)
 
@@ -273,10 +321,7 @@ class SWFInterpreter(object):
             trait_count = u30()
             for _c2 in range(trait_count):
                 trait_methods = parse_traits_info()
-                avm_class.method_names.update(trait_methods.items())
-                avm_class.method_idxs.update(dict(
-                    (idx, name)
-                    for name, idx in trait_methods.items()))
+                avm_class.register_methods(trait_methods)
 
         # Scripts
         script_count = u30()
@@ -314,6 +359,9 @@ class SWFInterpreter(object):
 
         assert p + code_reader.tell() == len(code_tag)
 
+    def patch_function(self, avm_class, func_name, f):
+        self._patched_functions[(avm_class, func_name)] = f
+
     def extract_class(self, class_name):
         try:
             return self._classes_by_name[class_name]
@@ -321,12 +369,16 @@ class SWFInterpreter(object):
             raise ExtractorError('Class %r not found' % class_name)
 
     def extract_function(self, avm_class, func_name):
+        p = self._patched_functions.get((avm_class, func_name))
+        if p:
+            return p
         if func_name in avm_class.method_pyfunctions:
             return avm_class.method_pyfunctions[func_name]
         if func_name in self._classes_by_name:
             return self._classes_by_name[func_name].make_object()
         if func_name not in avm_class.methods:
-            raise ExtractorError('Cannot find function %r' % func_name)
+            raise ExtractorError('Cannot find function %s.%s' % (
+                avm_class.name, func_name))
         m = avm_class.methods[func_name]
 
         def resfunc(args):
@@ -335,27 +387,41 @@ class SWFInterpreter(object):
             s24 = lambda: _s24(coder)
             u30 = lambda: _u30(coder)
 
-            print('Invoking %s.%s(%r)' % (avm_class.name, func_name, tuple(args)))
-            registers = ['(this)'] + list(args) + [None] * m.local_count
+            registers = [avm_class.variables] + list(args) + [None] * m.local_count
             stack = []
+            scopes = collections.deque([
+                self._classes_by_name, avm_class.variables])
             while True:
                 opcode = _read_byte(coder)
-                print('opcode: %r, stack(%d): %r' % (opcode, len(stack), stack))
                 if opcode == 17:  # iftrue
                     offset = s24()
                     value = stack.pop()
                     if value:
                         coder.seek(coder.tell() + offset)
+                elif opcode == 18:  # iffalse
+                    offset = s24()
+                    value = stack.pop()
+                    if not value:
+                        coder.seek(coder.tell() + offset)
                 elif opcode == 36:  # pushbyte
                     v = _read_byte(coder)
                     stack.append(v)
+                elif opcode == 42:  # dup
+                    value = stack[-1]
+                    stack.append(value)
                 elif opcode == 44:  # pushstring
                     idx = u30()
-                    stack.append(constant_strings[idx])
+                    stack.append(self.constant_strings[idx])
                 elif opcode == 48:  # pushscope
-                    # We don't implement the scope register, so we'll just
-                    # ignore the popped value
                     new_scope = stack.pop()
+                    scopes.append(new_scope)
+                elif opcode == 66:  # construct
+                    arg_count = u30()
+                    args = list(reversed(
+                        [stack.pop() for _ in range(arg_count)]))
+                    obj = stack.pop()
+                    res = obj.avm_class.make_object()
+                    stack.append(res)
                 elif opcode == 70:  # callproperty
                     index = u30()
                     mname = self.multinames[index]
@@ -363,33 +429,46 @@ class SWFInterpreter(object):
                     args = list(reversed(
                         [stack.pop() for _ in range(arg_count)]))
                     obj = stack.pop()
-                    if mname == 'split':
-                        assert len(args) == 1
-                        assert isinstance(args[0], compat_str)
-                        assert isinstance(obj, compat_str)
-                        if args[0] == '':
-                            res = list(obj)
-                        else:
-                            res = obj.split(args[0])
-                        stack.append(res)
-                    elif mname == 'slice':
-                        assert len(args) == 1
-                        assert isinstance(args[0], int)
-                        assert isinstance(obj, list)
-                        res = obj[args[0]:]
+
+                    if isinstance(obj, _AVMClass_Object):
+                        func = self.extract_function(obj.avm_class, mname)
+                        res = func(args)
                         stack.append(res)
-                    elif mname == 'join':
-                        assert len(args) == 1
-                        assert isinstance(args[0], compat_str)
-                        assert isinstance(obj, list)
-                        res = args[0].join(obj)
+                        continue
+                    elif isinstance(obj, _ScopeDict):
+                        if mname in obj.avm_class.method_names:
+                            func = self.extract_function(obj.avm_class, mname)
+                            res = func(args)
+                        else:
+                            res = obj[mname]
                         stack.append(res)
-                    elif mname in avm_class.method_pyfunctions:
-                        stack.append(avm_class.method_pyfunctions[mname](args))
-                    else:
-                        raise NotImplementedError(
-                            'Unsupported property %r on %r'
-                            % (mname, obj))
+                        continue
+                    elif isinstance(obj, compat_str):
+                        if mname == 'split':
+                            assert len(args) == 1
+                            assert isinstance(args[0], compat_str)
+                            if args[0] == '':
+                                res = list(obj)
+                            else:
+                                res = obj.split(args[0])
+                            stack.append(res)
+                            continue
+                    elif isinstance(obj, list):
+                        if mname == 'slice':
+                            assert len(args) == 1
+                            assert isinstance(args[0], int)
+                            res = obj[args[0]:]
+                            stack.append(res)
+                            continue
+                        elif mname == 'join':
+                            assert len(args) == 1
+                            assert isinstance(args[0], compat_str)
+                            res = args[0].join(obj)
+                            stack.append(res)
+                            continue
+                    raise NotImplementedError(
+                        'Unsupported property %r on %r'
+                        % (mname, obj))
                 elif opcode == 72:  # returnvalue
                     res = stack.pop()
                     return res
@@ -401,11 +480,11 @@ class SWFInterpreter(object):
                     obj = stack.pop()
 
                     mname = self.multinames[index]
-                    construct_method = self.extract_function(
-                        obj.avm_class, mname)
+                    assert isinstance(obj, _AVMClass)
+
                     # We do not actually call the constructor for now;
                     # we just pretend it does nothing
-                    stack.append(obj)
+                    stack.append(obj.make_object())
                 elif opcode == 79:  # callpropvoid
                     index = u30()
                     mname = self.multinames[index]
@@ -430,22 +509,42 @@ class SWFInterpreter(object):
                 elif opcode == 93:  # findpropstrict
                     index = u30()
                     mname = self.multinames[index]
-                    res = self.extract_function(avm_class, mname)
-                    stack.append(res)
+                    for s in reversed(scopes):
+                        if mname in s:
+                            res = s
+                            break
+                    else:
+                        res = scopes[0]
+                    stack.append(res[mname])
                 elif opcode == 94:  # findproperty
                     index = u30()
                     mname = self.multinames[index]
-                    res = avm_class.variables.get(mname)
+                    for s in reversed(scopes):
+                        if mname in s:
+                            res = s
+                            break
+                    else:
+                        res = avm_class.variables
                     stack.append(res)
                 elif opcode == 96:  # getlex
                     index = u30()
                     mname = self.multinames[index]
-                    res = avm_class.variables.get(mname, None)
+                    for s in reversed(scopes):
+                        if mname in s:
+                            scope = s
+                            break
+                    else:
+                        scope = avm_class.variables
+                    # I cannot find where static variables are initialized
+                    # so let's just return None
+                    res = scope.get(mname)
                     stack.append(res)
                 elif opcode == 97:  # setproperty
                     index = u30()
                     value = stack.pop()
                     idx = self.multinames[index]
+                    if isinstance(idx, _Multiname):
+                        idx = stack.pop()
                     obj = stack.pop()
                     obj[idx] = value
                 elif opcode == 98:  # getlocal
@@ -462,16 +561,35 @@ class SWFInterpreter(object):
                         obj = stack.pop()
                         assert isinstance(obj, list)
                         stack.append(len(obj))
+                    elif isinstance(pname, compat_str):  # Member access
+                        obj = stack.pop()
+                        assert isinstance(obj, (dict, _ScopeDict)), \
+                            'Accessing member on %r' % obj
+                        stack.append(obj[pname])
                     else:  # Assume attribute access
                         idx = stack.pop()
                         assert isinstance(idx, int)
                         obj = stack.pop()
                         assert isinstance(obj, list)
                         stack.append(obj[idx])
+                elif opcode == 115:  # convert_i
+                    value = stack.pop()
+                    intvalue = int(value)
+                    stack.append(intvalue)
                 elif opcode == 128:  # coerce
                     u30()
                 elif opcode == 133:  # coerce_s
                     assert isinstance(stack[-1], (type(None), compat_str))
+                elif opcode == 160:  # add
+                    value2 = stack.pop()
+                    value1 = stack.pop()
+                    res = value1 + value2
+                    stack.append(res)
+                elif opcode == 161:  # subtract
+                    value2 = stack.pop()
+                    value1 = stack.pop()
+                    res = value1 - value2
+                    stack.append(res)
                 elif opcode == 164:  # modulo
                     value2 = stack.pop()
                     value1 = stack.pop()
@@ -490,6 +608,10 @@ class SWFInterpreter(object):
                     stack.append(registers[2])
                 elif opcode == 211:  # getlocal_3
                     stack.append(registers[3])
+                elif opcode == 212:  # setlocal_0
+                    registers[0] = stack.pop()
+                elif opcode == 213:  # setlocal_1
+                    registers[1] = stack.pop()
                 elif opcode == 214:  # setlocal_2
                     registers[2] = stack.pop()
                 elif opcode == 215:  # setlocal_3