[swfinterp] Extend tests and fix parsing
[youtube-dl] / youtube_dl / swfinterp.py
1 from __future__ import unicode_literals
2
3 import collections
4 import io
5 import struct
6 import zlib
7
8 from .utils import ExtractorError
9
10
def _extract_tags(file_contents):
    """Yield ``(tag_code, tag_body)`` for every tag of an SWF file.

    file_contents holds the raw bytes of the whole file.  Both
    uncompressed ('FWS') and zlib-compressed ('CWS') files are handled.

    Raises ExtractorError if the signature is not an SWF one or if the tag
    stream is truncated, and NotImplementedError for any other compression
    scheme.  Because this is a generator, errors surface on iteration.
    """
    if file_contents[1:3] != b'WS':
        raise ExtractorError(
            'Not an SWF file; header is %r' % file_contents[:3])
    compression = file_contents[:1]
    # The first 8 bytes (signature, version, file length) are never
    # compressed; everything after them is the payload.
    if compression == b'C':
        content = zlib.decompress(file_contents[8:])
    elif compression == b'F':
        content = file_contents[8:]
    else:
        raise NotImplementedError(
            'Unsupported compression format %r' %
            file_contents[:1])

    if not content:
        raise ExtractorError('SWF file has no content after its header')

    # Determine number of bits in framesize rectangle: a 5-bit field width
    # followed by four fields of that width, padded to a whole byte.
    framesize_nbits = struct.unpack('!B', content[:1])[0] >> 3
    framesize_len = (5 + 4 * framesize_nbits + 7) // 8

    # Skip the rectangle plus two more 2-byte header fields
    pos = framesize_len + 2 + 2
    content_len = len(content)
    while pos < content_len:
        if pos + 2 > content_len:
            raise ExtractorError(
                'Truncated tag header at offset %d' % pos)
        header16 = struct.unpack('<H', content[pos:pos + 2])[0]
        pos += 2
        tag_code = header16 >> 6
        tag_len = header16 & 0x3f
        if tag_len == 0x3f:
            # Long form: the real length follows as a 32-bit integer
            if pos + 4 > content_len:
                raise ExtractorError(
                    'Truncated long tag header at offset %d' % pos)
            tag_len = struct.unpack('<I', content[pos:pos + 4])[0]
            pos += 4
        # Explicit check instead of assert, so that it survives python -O
        if pos + tag_len > content_len:
            raise ExtractorError(
                'Tag %d ends at %d+%d - that\'s longer than the file (%d)'
                % (tag_code, pos, tag_len, content_len))
        yield (tag_code, content[pos:pos + tag_len])
        pos += tag_len
40
41
class _AVM_Object(object):
    """Holder for a value plus an optional label shown in its repr."""

    def __init__(self, value=None, name_hint=None):
        self.name_hint = name_hint
        self.value = value

    def __repr__(self):
        # The label, when given, is separated from the prefix by a space
        if self.name_hint is None:
            label = ''
        else:
            label = ' %s' % self.name_hint
        return 'AVMObject%s(%r)' % (label, self.value)
50
51
class _AVMClass_Object(object):
    """Object that only remembers the class description it was made from."""

    def __init__(self, avm_class):
        self.avm_class = avm_class

    def __repr__(self):
        # Class name followed by this object's id in hexadecimal
        return '{0}#{1:x}'.format(self.avm_class.name, id(self))
58
59
class _AVMClass(object):
    """Description of one class: its name and per-class lookup tables."""

    def __init__(self, name_idx, name):
        self.name_idx, self.name = name_idx, name
        # Every table starts out empty and is a distinct dict
        self.method_names = dict()
        self.method_idxs = dict()
        self.methods = dict()
        self.method_pyfunctions = dict()
        self.variables = dict()

    def make_object(self):
        """Return a new _AVMClass_Object bound to this class."""
        instance = _AVMClass_Object(self)
        return instance
72
73
def _read_int(reader):
    """Decode one variable-length unsigned integer from reader.

    Each byte contributes its low 7 bits, least significant group first;
    a set high bit means another byte follows.  At most 5 bytes are
    consumed.  reader must provide read(n); an AssertionError is raised if
    it runs out of data.
    """
    res = 0
    for shift in range(0, 5 * 7, 7):
        buf = reader.read(1)
        assert len(buf) == 1
        b = struct.unpack('<B', buf)[0]
        res |= (b & 0x7f) << shift
        if b & 0x80 == 0:
            break
    return res


def _u30(reader):
    """Read a variable-length integer and assert its top four bits
    (of 32) are clear."""
    res = _read_int(reader)
    assert res & 0xf0000000 == 0
    return res


# SWFInterpreter looks this reader up as ``_u32``; without this name the
# uint constant pool could not be parsed (NameError).
_u32 = _read_int
# Previous public spelling, kept so existing references keep working.
u32 = _read_int
93
94
def _s32(reader):
    """Read a variable-length integer and reinterpret it as signed.

    When bit 31 is set the value is converted with a 32-bit two's
    complement negation; otherwise it is returned unchanged.
    """
    raw = _read_int(reader)
    if raw & 0x80000000 == 0:
        return raw
    return -((raw ^ 0xffffffff) + 1)
100
101
def _s24(reader):
    """Read a signed 24-bit little-endian integer from reader.

    The three bytes are widened to four by repeating the sign, then
    decoded as a signed 32-bit value.  Asserts that 3 bytes were available.
    """
    raw = reader.read(3)
    assert len(raw) == 3
    # Top byte < 0x80 means non-negative, so pad with zeros; else with ones
    if ord(raw[2:3]) < 0x80:
        padding = b'\x00'
    else:
        padding = b'\xff'
    (value,) = struct.unpack('<i', raw + padding)
    return value
107
108
def _read_string(reader):
    """Read a length-prefixed UTF-8 string from reader.

    The length is read with _u30 and counts bytes, not characters.
    Asserts that the announced number of bytes was available.
    """
    byte_count = _u30(reader)
    raw = reader.read(byte_count)
    assert len(raw) == byte_count
    text = raw.decode('utf-8')
    return text
114
115
def _read_bytes(count, reader):
    """Read exactly count bytes from reader and return them.

    Asserts that count is non-negative and that the reader supplied all
    requested bytes.
    """
    assert count >= 0
    data = reader.read(count)
    assert len(data) == count
    return data
121
122
def _read_byte(reader):
    """Read a single byte from reader and return it as an integer."""
    (value,) = struct.unpack('<B', _read_bytes(1, reader=reader))
    return value
127
128
class SWFInterpreter(object):
    """Parse the ABC bytecode block of an SWF file and interpret parts of it.

    The constructor reads the constant pool, method signatures, classes,
    scripts and method bodies out of the first tag with code 82.
    extract_class() then looks up a parsed class by name and
    extract_function() turns one of its methods into a Python callable
    that interprets the (small) supported subset of opcodes.
    """

    # Text type used for isinstance checks on string values.  The module
    # enables unicode_literals, so this is unicode on Python 2 and str on
    # Python 3.
    _TEXT_TYPE = type('')

    def __init__(self, file_contents):
        """Parse file_contents (raw bytes of an SWF file).

        Raises ExtractorError if the file holds no tag with code 82 or
        uses an unsupported trait kind; malformed data trips assertions
        in the low-level readers.
        """
        code_tag = next(
            (tag
             for tag_code, tag in _extract_tags(file_contents)
             if tag_code == 82),
            None)
        if code_tag is None:
            # Report a missing code tag properly instead of leaking
            # StopIteration
            raise ExtractorError('SWF file contains no code tag (82)')
        # Skip the first 4 bytes and the NUL-terminated string after them
        p = code_tag.index(b'\0', 4) + 1
        code_reader = io.BytesIO(code_tag[p:])

        # Parse ABC (AVM2 ByteCode)

        # Define a couple convenience methods
        u30 = lambda *args: _u30(*args, reader=code_reader)
        s32 = lambda *args: _s32(*args, reader=code_reader)
        # u32 shares the plain variable-length encoding of _read_int
        u32 = lambda *args: _read_int(*args, reader=code_reader)
        read_bytes = lambda *args: _read_bytes(*args, reader=code_reader)
        read_byte = lambda *args: _read_byte(*args, reader=code_reader)

        # minor_version + major_version
        read_bytes(2 + 2)

        # Constant pool (entry 0 of every pool is implicit, hence the
        # ranges starting at 1)
        int_count = u30()
        for _c in range(1, int_count):
            s32()
        uint_count = u30()
        for _c in range(1, uint_count):
            u32()
        double_count = u30()
        read_bytes(max(0, (double_count - 1)) * 8)
        string_count = u30()
        # Kept on the instance: the interpreter needs it for pushstring
        self.constant_strings = ['']
        for _c in range(1, string_count):
            s = _read_string(code_reader)
            self.constant_strings.append(s)
        namespace_count = u30()
        for _c in range(1, namespace_count):
            read_bytes(1)  # kind
            u30()  # name
        ns_set_count = u30()
        for _c in range(1, ns_set_count):
            count = u30()
            for _c2 in range(count):
                u30()
        multiname_count = u30()
        # Number of u30 fields following each multiname kind
        MULTINAME_SIZES = {
            0x07: 2,  # QName
            0x0d: 2,  # QNameA
            0x0f: 1,  # RTQName
            0x10: 1,  # RTQNameA
            0x11: 0,  # RTQNameL
            0x12: 0,  # RTQNameLA
            0x09: 2,  # Multiname
            0x0e: 2,  # MultinameA
            0x1b: 1,  # MultinameL
            0x1c: 1,  # MultinameLA
        }
        self.multinames = ['']
        for _c in range(1, multiname_count):
            kind = u30()
            assert kind in MULTINAME_SIZES, 'Invalid multiname kind %r' % kind
            if kind == 0x07:
                u30()  # namespace_idx
                name_idx = u30()
                self.multinames.append(self.constant_strings[name_idx])
            else:
                # Only QNames are resolved; others get a placeholder
                self.multinames.append('[MULTINAME kind: %d]' % kind)
                for _c2 in range(MULTINAME_SIZES[kind]):
                    u30()

        # Methods
        method_count = u30()
        MethodInfo = collections.namedtuple(
            'MethodInfo',
            ['NEED_ARGUMENTS', 'NEED_REST'])
        method_infos = []
        for _method_id in range(method_count):
            param_count = u30()
            u30()  # return type
            for _ in range(param_count):
                u30()  # param type
            u30()  # name index (always 0 for youtube)
            flags = read_byte()
            if flags & 0x08 != 0:
                # Options present
                option_count = u30()
                for _c in range(option_count):
                    u30()  # val
                    read_bytes(1)  # kind
            if flags & 0x80 != 0:
                # Param names present
                for _ in range(param_count):
                    u30()  # param name
            mi = MethodInfo(flags & 0x01 != 0, flags & 0x04 != 0)
            method_infos.append(mi)

        # Metadata
        metadata_count = u30()
        for _c in range(metadata_count):
            u30()  # name
            item_count = u30()
            for _c2 in range(item_count):
                u30()  # key
                u30()  # value

        def parse_traits_info():
            """Consume one trait entry; return {name: method_idx} for
            method-like traits and an empty dict otherwise."""
            trait_name_idx = u30()
            kind_full = read_byte()
            kind = kind_full & 0x0f
            attrs = kind_full >> 4
            methods = {}
            if kind in [0x00, 0x06]:  # Slot or Const
                u30()  # Slot id
                u30()  # type_name_idx
                vindex = u30()
                if vindex != 0:
                    read_byte()  # vkind
            elif kind in [0x01, 0x02, 0x03]:  # Method / Getter / Setter
                u30()  # disp_id
                method_idx = u30()
                methods[self.multinames[trait_name_idx]] = method_idx
            elif kind == 0x04:  # Class
                u30()  # slot_id
                u30()  # classi
            elif kind == 0x05:  # Function
                u30()  # slot_id
                function_idx = u30()
                # Same name -> index orientation as for methods; callers
                # invert the mapping themselves
                methods[self.multinames[trait_name_idx]] = function_idx
            else:
                raise ExtractorError('Unsupported trait kind %d' % kind)

            if attrs & 0x4 != 0:  # Metadata present
                metadata_count = u30()
                for _c3 in range(metadata_count):
                    u30()  # metadata index

            return methods

        # Classes
        class_count = u30()
        classes = []
        for _class_id in range(class_count):
            name_idx = u30()
            classes.append(_AVMClass(name_idx, self.multinames[name_idx]))
            u30()  # super_name idx
            flags = read_byte()
            if flags & 0x08 != 0:  # Protected namespace is present
                u30()  # protected_ns_idx
            intrf_count = u30()
            for _c2 in range(intrf_count):
                u30()
            u30()  # iinit
            trait_count = u30()
            for _c2 in range(trait_count):
                parse_traits_info()
        assert len(classes) == class_count
        self._classes_by_name = dict((c.name, c) for c in classes)

        # Second pass over the classes: only the traits read here are
        # recorded as methods of the class
        for avm_class in classes:
            u30()  # cinit
            trait_count = u30()
            for _c2 in range(trait_count):
                trait_methods = parse_traits_info()
                avm_class.method_names.update(trait_methods.items())
                avm_class.method_idxs.update(dict(
                    (idx, name)
                    for name, idx in trait_methods.items()))

        # Scripts
        script_count = u30()
        for _c in range(script_count):
            u30()  # init
            trait_count = u30()
            for _c2 in range(trait_count):
                parse_traits_info()

        # Method bodies
        method_body_count = u30()
        Method = collections.namedtuple('Method', ['code', 'local_count'])
        for _c in range(method_body_count):
            method_idx = u30()
            u30()  # max_stack
            local_count = u30()
            u30()  # init_scope_depth
            u30()  # max_scope_depth
            code_length = u30()
            code = read_bytes(code_length)
            # Attach the body to every class that declared this method
            for avm_class in classes:
                if method_idx in avm_class.method_idxs:
                    m = Method(code, local_count)
                    avm_class.methods[avm_class.method_idxs[method_idx]] = m
            exception_count = u30()
            for _c2 in range(exception_count):
                u30()  # from
                u30()  # to
                u30()  # target
                u30()  # exc_type
                u30()  # var_name
            trait_count = u30()
            for _c2 in range(trait_count):
                parse_traits_info()

        # The whole tag must have been consumed
        assert p + code_reader.tell() == len(code_tag)

    def extract_class(self, class_name):
        """Return the parsed class called class_name.

        Raises ExtractorError if no such class was found in the file.
        """
        try:
            return self._classes_by_name[class_name]
        except KeyError:
            raise ExtractorError('Class %r not found' % class_name)

    def extract_function(self, avm_class, func_name):
        """Return a Python callable interpreting avm_class's func_name.

        The callable takes one argument, the list of call arguments, and
        is cached in avm_class.method_pyfunctions.  If func_name is the
        name of a parsed class, a new object of that class is returned
        instead.  Raises ExtractorError if the function is unknown; the
        callable raises NotImplementedError on unsupported opcodes or
        properties.
        """
        if func_name in avm_class.method_pyfunctions:
            return avm_class.method_pyfunctions[func_name]
        if func_name in self._classes_by_name:
            return self._classes_by_name[func_name].make_object()
        if func_name not in avm_class.methods:
            raise ExtractorError('Cannot find function %r' % func_name)
        m = avm_class.methods[func_name]
        text_type = self._TEXT_TYPE

        def resfunc(args):
            # Helper functions
            coder = io.BytesIO(m.code)
            s24 = lambda: _s24(coder)
            u30 = lambda: _u30(coder)

            # Register 0 holds the class variables, then the arguments,
            # then the method's locals
            registers = [avm_class.variables] + list(args) + [None] * m.local_count
            stack = []
            scopes = collections.deque([avm_class.variables])
            while True:
                opcode = _read_byte(coder)
                if opcode == 17:  # iftrue
                    offset = s24()
                    value = stack.pop()
                    if value:
                        coder.seek(coder.tell() + offset)
                elif opcode == 18:  # iffalse
                    offset = s24()
                    value = stack.pop()
                    if not value:
                        coder.seek(coder.tell() + offset)
                elif opcode == 36:  # pushbyte
                    v = _read_byte(coder)
                    stack.append(v)
                elif opcode == 42:  # dup
                    value = stack[-1]
                    stack.append(value)
                elif opcode == 44:  # pushstring
                    idx = u30()
                    stack.append(self.constant_strings[idx])
                elif opcode == 48:  # pushscope
                    new_scope = stack.pop()
                    scopes.append(new_scope)
                elif opcode == 70:  # callproperty
                    index = u30()
                    mname = self.multinames[index]
                    arg_count = u30()
                    # Arguments were pushed left to right
                    call_args = list(reversed(
                        [stack.pop() for _ in range(arg_count)]))
                    obj = stack.pop()
                    if mname == 'split':
                        assert len(call_args) == 1
                        assert isinstance(call_args[0], text_type)
                        assert isinstance(obj, text_type)
                        if call_args[0] == '':
                            res = list(obj)
                        else:
                            res = obj.split(call_args[0])
                        stack.append(res)
                    elif mname == 'slice':
                        assert len(call_args) == 1
                        assert isinstance(call_args[0], int)
                        assert isinstance(obj, list)
                        res = obj[call_args[0]:]
                        stack.append(res)
                    elif mname == 'join':
                        assert len(call_args) == 1
                        assert isinstance(call_args[0], text_type)
                        assert isinstance(obj, list)
                        res = call_args[0].join(obj)
                        stack.append(res)
                    elif mname in avm_class.method_pyfunctions:
                        stack.append(
                            avm_class.method_pyfunctions[mname](call_args))
                    else:
                        raise NotImplementedError(
                            'Unsupported property %r on %r'
                            % (mname, obj))
                elif opcode == 72:  # returnvalue
                    res = stack.pop()
                    return res
                elif opcode == 74:  # constructproperty
                    index = u30()
                    arg_count = u30()
                    call_args = list(reversed(
                        [stack.pop() for _ in range(arg_count)]))
                    obj = stack.pop()

                    mname = self.multinames[index]
                    # Resolve the constructor (raises if it is unknown).
                    # We do not actually call the constructor for now;
                    # we just pretend it does nothing
                    self.extract_function(obj.avm_class, mname)
                    stack.append(obj)
                elif opcode == 79:  # callpropvoid
                    index = u30()
                    mname = self.multinames[index]
                    arg_count = u30()
                    call_args = list(reversed(
                        [stack.pop() for _ in range(arg_count)]))
                    obj = stack.pop()
                    if mname == 'reverse':
                        assert isinstance(obj, list)
                        obj.reverse()
                    else:
                        raise NotImplementedError(
                            'Unsupported (void) property %r on %r'
                            % (mname, obj))
                elif opcode == 86:  # newarray
                    arg_count = u30()
                    arr = []
                    for i in range(arg_count):
                        arr.append(stack.pop())
                    arr = arr[::-1]
                    stack.append(arr)
                elif opcode == 94:  # findproperty
                    index = u30()
                    mname = self.multinames[index]
                    # Innermost scope containing the name, else the
                    # outermost scope
                    for s in reversed(scopes):
                        if mname in s:
                            res = s
                            break
                    else:
                        res = scopes[0]
                    stack.append(res)
                elif opcode == 96:  # getlex
                    index = u30()
                    mname = self.multinames[index]
                    for s in reversed(scopes):
                        if mname in s:
                            scope = s
                            break
                    else:
                        scope = scopes[0]
                    # I cannot find where static variables are initialized
                    # so let's just return None
                    res = scope.get(mname)
                    stack.append(res)
                elif opcode == 97:  # setproperty
                    index = u30()
                    value = stack.pop()
                    idx = self.multinames[index]
                    obj = stack.pop()
                    obj[idx] = value
                elif opcode == 98:  # getlocal
                    index = u30()
                    stack.append(registers[index])
                elif opcode == 99:  # setlocal
                    index = u30()
                    value = stack.pop()
                    registers[index] = value
                elif opcode == 102:  # getproperty
                    index = u30()
                    pname = self.multinames[index]
                    if pname == 'length':
                        obj = stack.pop()
                        assert isinstance(obj, list)
                        stack.append(len(obj))
                    else:  # Assume attribute access
                        idx = stack.pop()
                        assert isinstance(idx, int)
                        obj = stack.pop()
                        assert isinstance(obj, list)
                        stack.append(obj[idx])
                elif opcode == 115:  # convert_
                    value = stack.pop()
                    intvalue = int(value)
                    stack.append(intvalue)
                elif opcode == 128:  # coerce
                    u30()
                elif opcode == 133:  # coerce_s
                    assert isinstance(stack[-1], (type(None), text_type))
                elif opcode == 160:  # add
                    value2 = stack.pop()
                    value1 = stack.pop()
                    res = value1 + value2
                    stack.append(res)
                elif opcode == 161:  # subtract
                    value2 = stack.pop()
                    value1 = stack.pop()
                    res = value1 - value2
                    stack.append(res)
                elif opcode == 164:  # modulo
                    value2 = stack.pop()
                    value1 = stack.pop()
                    res = value1 % value2
                    stack.append(res)
                elif opcode == 175:  # greaterequals
                    value2 = stack.pop()
                    value1 = stack.pop()
                    result = value1 >= value2
                    stack.append(result)
                elif opcode == 208:  # getlocal_0
                    stack.append(registers[0])
                elif opcode == 209:  # getlocal_1
                    stack.append(registers[1])
                elif opcode == 210:  # getlocal_2
                    stack.append(registers[2])
                elif opcode == 211:  # getlocal_3
                    stack.append(registers[3])
                elif opcode == 214:  # setlocal_2
                    registers[2] = stack.pop()
                elif opcode == 215:  # setlocal_3
                    registers[3] = stack.pop()
                else:
                    raise NotImplementedError(
                        'Unsupported opcode %d' % opcode)

        avm_class.method_pyfunctions[func_name] = resfunc
        return resfunc
548