add an aes implementation
[youtube-dl] / youtube_dl / aes.py
1 __all__ = ['aes_encrypt', 'key_expansion', 'aes_ctr_decrypt', 'aes_decrypt_text']
2
3 import base64
4 from math import ceil
5
6 BLOCK_SIZE_BYTES = 16
7
8 def aes_ctr_decrypt(data, key, counter):
9     """
10     Decrypt with aes in counter mode
11     
12     @param {int[]} data        cipher
13     @param {int[]} key         16/24/32-Byte cipher key
14     @param {instance} counter  Instance whose next_value function (@returns {int[]}  16-Byte block)
15                                returns the next counter block
16     @returns {int[]}           decrypted data
17     """
18     expanded_key = key_expansion(key)
19     block_count = int(ceil(float(len(data)) / BLOCK_SIZE_BYTES))
20     
21     decrypted_data=[]
22     for i in range(block_count):
23         counter_block = counter.next_value()
24         block = data[i*BLOCK_SIZE_BYTES : (i+1)*BLOCK_SIZE_BYTES]
25         block += [0]*(BLOCK_SIZE_BYTES - len(block))
26         
27         cipher_counter_block = aes_encrypt(counter_block, expanded_key)
28         decrypted_data += xor(block, cipher_counter_block)
29     decrypted_data = decrypted_data[:len(data)]
30     
31     return decrypted_data
32
33 def key_expansion(data):
34     """
35     Generate key schedule
36     
37     @param {int[]} data  16/24/32-Byte cipher key
38     @returns {int[]}     176/208/240-Byte expanded key 
39     """
40     data = data[:] # copy
41     rcon_iteration = 1
42     key_size_bytes = len(data)
43     expanded_key_size_bytes = (key_size_bytes/4 + 7) * BLOCK_SIZE_BYTES
44     
45     while len(data) < expanded_key_size_bytes:
46         temp = data[-4:]
47         temp = key_schedule_core(temp, rcon_iteration)
48         rcon_iteration += 1
49         data += xor(temp, data[-key_size_bytes : 4-key_size_bytes])
50         
51         for _ in range(3):
52             temp = data[-4:]
53             data += xor(temp, data[-key_size_bytes : 4-key_size_bytes])
54         
55         if key_size_bytes == 32:
56             temp = data[-4:]
57             temp = sub_bytes(temp)
58             data += xor(temp, data[-key_size_bytes : 4-key_size_bytes])
59         
60         for _ in range(3 if key_size_bytes == 32  else 2 if key_size_bytes == 24 else 0):
61             temp = data[-4:]
62             data += xor(temp, data[-key_size_bytes : 4-key_size_bytes])
63     data = data[:expanded_key_size_bytes]
64     
65     return data
66
67 def aes_encrypt(data, expanded_key):
68     """
69     Encrypt one block with aes
70     
71     @param {int[]} data          16-Byte state
72     @param {int[]} expanded_key  176/208/240-Byte expanded key 
73     @returns {int[]}             16-Byte cipher
74     """
75     rounds = len(expanded_key) / BLOCK_SIZE_BYTES - 1
76     
77     data = xor(data, expanded_key[:BLOCK_SIZE_BYTES])
78     for i in range(1, rounds+1):
79         data = sub_bytes(data)
80         data = shift_rows(data)
81         if i != rounds:
82             data = mix_columns(data)
83         data = xor(data, expanded_key[i*BLOCK_SIZE_BYTES : (i+1)*BLOCK_SIZE_BYTES])
84     
85     return data
86
87 def aes_decrypt_text(data, password, key_size_bytes):
88     """
89     Decrypt text
90     - The first 8 Bytes of decoded 'data' are the 8 high Bytes of the counter
91     - The cipher key is retrieved by encrypting the first 16 Byte of 'password'
92       with the first 'key_size_bytes' Bytes from 'password' (if necessary filled with 0's)
93     - Mode of operation is 'counter'
94     
95     @param {str} data                    Base64 encoded string
96     @param {str,unicode} password        Password (will be encoded with utf-8)
97     @param {int} key_size_bytes          Possible values: 16 for 128-Bit, 24 for 192-Bit or 32 for 256-Bit
98     @returns {str}                       Decrypted data
99     """
100     NONCE_LENGTH_BYTES = 8
101     
102     data = map(lambda c: ord(c), base64.b64decode(data))
103     password = map(lambda c: ord(c), password.encode('utf-8'))
104     
105     key = password[:key_size_bytes] + [0]*(key_size_bytes - len(password))
106     key = aes_encrypt(key[:BLOCK_SIZE_BYTES], key_expansion(key)) * (key_size_bytes / BLOCK_SIZE_BYTES)
107     
108     nonce = data[:NONCE_LENGTH_BYTES]
109     cipher = data[NONCE_LENGTH_BYTES:]
110     
111     class Counter:
112         __value = nonce + [0]*(BLOCK_SIZE_BYTES - NONCE_LENGTH_BYTES)
113         def next_value(self):
114             temp = self.__value
115             self.__value = inc(self.__value)
116             return temp
117     
118     decrypted_data = aes_ctr_decrypt(cipher, key, Counter())
119     plaintext = ''.join(map(lambda x: chr(x), decrypted_data))
120     
121     return plaintext
122
123 RCON = (0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36)
124 SBOX = (0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76,
125         0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0,
126         0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15,
127         0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75,
128         0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84,
129         0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF,
130         0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8,
131         0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2,
132         0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73,
133         0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB,
134         0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79,
135         0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08,
136         0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A,
137         0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E,
138         0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF,
139         0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16)
140 MIX_COLUMN_MATRIX = ((2,3,1,1),
141                      (1,2,3,1),
142                      (1,1,2,3),
143                      (3,1,1,2))
144
145 def sub_bytes(data):
146     return map(lambda x: SBOX[x], data)
147
148 def rotate(data):
149     return data[1:] + [data[0]]
150
151 def key_schedule_core(data, rcon_iteration):
152     data = rotate(data)
153     data = sub_bytes(data)
154     data[0] = data[0] ^ RCON[rcon_iteration]
155     
156     return data
157
158 def xor(data1, data2):
159     return map(lambda (x,y): x^y, zip(data1, data2))
160
161 def mix_column(data):
162     data_mixed = []
163     for row in range(4):
164         mixed = 0
165         for column in range(4):
166             addend = data[column]
167             if MIX_COLUMN_MATRIX[row][column] in (2,3):
168                 addend <<= 1
169                 if addend > 0xff:
170                     addend &= 0xff
171                     addend ^= 0x1b
172                 if MIX_COLUMN_MATRIX[row][column] == 3:
173                     addend ^= data[column]
174             mixed ^= addend & 0xff
175         data_mixed.append(mixed)
176     return data_mixed
177
178 def mix_columns(data):
179     data_mixed = []
180     for i in range(4):
181         column = data[i*4 : (i+1)*4]
182         data_mixed += mix_column(column)
183     return data_mixed
184
185 def shift_rows(data):
186     data_shifted = []
187     for column in range(4):
188         for row in range(4):
189             data_shifted.append( data[((column + row) & 0b11) * 4 + row] )
190     return data_shifted
191
192 def inc(data):
193     data = data[:] # copy
194     for i in range(len(data)-1,-1,-1):
195         if data[i] == 255:
196             data[i] = 0
197         else:
198             data[i] = data[i] + 1
199             break
200     return data