from operators import multiply_operator as x
# Cipher dimensions: Nb = number of 4-byte columns in the state, Nr = number of
# rounds, Nk = number of 4-byte words in the key. Nr = 3 is far below the 10
# rounds of standard AES-128, so this is a reduced-round variant.
Nb = 4
Nr = 3
Nk = 4
# Forward MixColumns coefficients (the standard AES matrix); crypt() passes
# this to mix_column(). Note mix_column() names its own parameter `k` too.
k = [[2, 3, 1, 1],
     [1, 2, 3, 1],
     [1, 1, 2, 3],
     [3, 1, 1, 2]]
# Inverse MixColumns coefficients (the standard AES inverse matrix); decrypt()
# applies it both to the state and to the middle round keys.
inv_k = [[0x0e,0x0b,0x0d,0x09],
         [0x09,0x0e,0x0b,0x0d],
         [0x0d,0x09,0x0e,0x0b],
         [0x0b,0x0d,0x09,0x0e]]
def get_r_const():
    """Return a fresh 4 x 10 round-constant table.

    Row 0 holds the ten Rcon bytes (powers of two reduced modulo the AES
    polynomial); rows 1-3 are all zero. A new nested list is built on every
    call, so callers may mutate the result freely.
    """
    powers = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36]
    zero_rows = [[0x00] * len(powers) for _ in range(3)]
    return [powers] + zero_rows
def get_s_block(name='s.ini'):
    """Load a substitution table from a text file of hex byte values.

    Every non-blank line of *name* becomes one row of ints; fields are hex
    numbers separated by whitespace (the original format is tab-separated).
    Callers index the result as ``table[byte >> 4][byte & 0xf]``, so the file
    is expected to hold 16 rows of 16 entries.

    :param name: path of the table file, relative to the working directory.
    :returns: list of rows, each a list of ints.
    :raises IOError/OSError: if the file cannot be opened.
    :raises ValueError: if a field is not valid hexadecimal.
    """
    with open(name, 'r') as s_file:
        # Skip blank lines (e.g. a trailing empty line) instead of failing on
        # int('', 16); split() also tolerates '\r\n' endings and trailing tabs.
        return [[int(elem, 16) for elem in line.split()]
                for line in s_file if line.strip()]
def to_block(input_bytes, nb=4):
    """Arrange bytes column-major into a 4 x nb block.

    ``block[row][column] == input_bytes[row + 4 * column]``, i.e. each group of
    four consecutive input bytes fills one column.

    :param input_bytes: indexable sequence of at least ``4 * nb`` values.
    :param nb: number of columns.
    :returns: list of 4 rows, each a list of ``nb`` values.
    :raises IndexError: if *input_bytes* is too short.
    """
    # Always exactly 4 rows; allocating ``nb`` rows (as before) raised
    # IndexError for nb < 4 and left empty trailing rows for nb > 4.
    return [[input_bytes[row + 4 * column] for column in range(nb)]
            for row in range(4)]
def to_block2(input_bytes, nb=4):
    """Arrange bytes row-major into a 4 x nb block.

    ``block[row][column] == input_bytes[row * nb + column]``, i.e. each group
    of ``nb`` consecutive input bytes fills one row. crypt()/decrypt() use this
    layout for both the data and the key (each key row is one schedule word).

    :param input_bytes: indexable sequence of at least ``4 * nb`` values.
    :param nb: number of columns.
    :returns: list of 4 rows, each a list of ``nb`` values.
    :raises IndexError: if *input_bytes* is too short.
    """
    # Always exactly 4 rows; allocating ``nb`` rows (as before) raised
    # IndexError for nb < 4 and left empty trailing rows for nb > 4.
    return [[input_bytes[row * nb + column] for column in range(nb)]
            for row in range(4)]
def from_block(block):
    """Flatten a column-major block back into a list of two-digit hex strings.

    Inverse of to_block(): the output position of ``block[row][column]`` is
    ``row + rows * column``. The shape is taken from *block* itself instead of
    the global ``Nb``, so blocks built with any column count round-trip.

    :param block: list of equal-length rows of non-negative ints.
    :returns: list of lowercase hex strings, zero-padded to two digits.
    """
    rows = len(block)
    columns = len(block[0]) if rows else 0
    return [format(block[row][column], '02x')
            for column in range(columns)
            for row in range(rows)]
def from_block2(block):
    """Render every cell of *block* as a zero-padded hex string, keeping the 2-D shape."""
    rendered = []
    for row in block:
        rendered.append([hex(cell)[2:].zfill(2) for cell in row])
    return rendered
def sub_bytes(block, s_block):
    """Substitute every byte of *block* in place through the S-box table.

    The high nibble of each byte selects the table row and the low nibble the
    column. The same *block* object is returned.
    """
    for row in block:
        for position, value in enumerate(row):
            high_nibble = value >> 4
            low_nibble = value & 0xf
            row[position] = s_block[high_nibble][low_nibble]
    return block
def shift_rows(block):
    """Rotate row ``i`` of *block* left by ``i`` positions, in place; return *block*."""
    block[:] = [row[offset:] + row[:offset] for offset, row in enumerate(block)]
    return block
def inv_shift_rows(block):
    """Rotate row ``i`` of *block* right by ``i`` positions, in place; undoes shift_rows()."""
    # For offset 0 the slices are row[0:] and row[:0], i.e. the row unchanged.
    block[:] = [row[-offset:] + row[:-offset] for offset, row in enumerate(block)]
    return block
def mix_column(block, k):
    """Multiply each column of *block* by the 4 x 4 coefficient matrix *k*, in place.

    For every column, output row ``i`` is the XOR over ``j`` of
    ``block[j][column] |x| k[i][j]``. ``|x|`` is the infix multiply operator
    imported from the ``operators`` module (presumably GF(2^8) multiplication;
    confirm there). Only rows 0-3 of *block* and *k* are read.

    :returns: the same *block* object.
    """
    for col in range(len(block[0])):
        mixed = []
        for out_row in range(4):
            coeffs = k[out_row]
            total = block[0][col] |x| coeffs[0]
            for in_row in range(1, 4):
                total = total ^ (block[in_row][col] |x| coeffs[in_row])
            mixed.append(total)
        # Write back only after the whole column is computed, since every
        # output byte depends on all four input bytes.
        for out_row in range(4):
            block[out_row][col] = mixed[out_row]
    return block
def add_round_key(block, key):
    """XOR *key* into *block* element-wise, in place; return *block*.

    *key* must have at least as many rows and columns as *block*.
    """
    for r, state_row in enumerate(block):
        for c, value in enumerate(state_row):
            state_row[c] = value ^ key[r][c]
    return block
def get_rot_word(block, index):
    """Return column *index* of *block* as a list (one byte taken from each row)."""
    return [row[index] for row in block]
def shift_rot_word(rot_word):
    """RotWord: return a copy of *rot_word* rotated one position to the left.

    ``[a0, a1, a2, a3] -> [a1, a2, a3, a0]``; the argument is not modified.
    """
    head, tail = rot_word[:1], rot_word[1:]
    return tail + head
def sub_words(rot_words, s_block):
    """SubWord: replace each byte of *rot_words* in place with its S-box entry.

    The high nibble of each byte selects the table row and the low nibble the
    column. The same list object is returned.
    """
    def substitute(byte):
        return s_block[byte >> 4][byte & 0xf]

    for position, byte in enumerate(rot_words):
        rot_words[position] = substitute(byte)
    return rot_words
def key_expansion(key, s_block):
    """Expand a column-major key into Nb*(Nr+1) words, keeping the row-per-byte layout.

    *key* uses the layout produced by to_block(): ``key[row][w]`` is byte
    ``row`` of word ``w``. The returned schedule keeps that layout, so word
    ``w`` is ``[row[w] for row in schedule]`` and get_round_key() can slice
    round keys out of each row.

    :param key: 4 rows of 4 key bytes (not modified).
    :param s_block: substitution table as returned by get_s_block().
    :returns: 4 rows, each holding Nb*(Nr+1) bytes.
    """
    r_const = get_r_const()
    # Copy every row: extending a shallow copy (key[:]) would append the
    # expanded words to the caller's own key rows.
    key_schedule = [list(row) for row in key]
    for w in range(4, Nb * (Nr + 1)):
        if w % 4 == 0:
            # SubWord(RotWord(previous word)). Column w-1 does not change while
            # word w is appended, so compute it once per word, not once per byte.
            temp = get_rot_word(key_schedule, w - 1)
            temp = shift_rot_word(temp)
            temp = sub_words(temp, s_block)
            # Word w uses Rcon entry w/4 - 1 (0x01, 0x02, 0x04, ...). The old
            # index w-4 picked entries 0, 4, 8 (0x01, 0x10, 0x1b) instead.
            rcon_column = w // 4 - 1
            for row in range(4):
                key_schedule[row].append(
                    key_schedule[row][w - 4] ^ temp[row] ^ r_const[row][rcon_column])
        else:
            for row in range(4):
                key_schedule[row].append(
                    key_schedule[row][w - 1] ^ key_schedule[row][w - 4])
    return key_schedule
def key_expansion2(key, s_block):
    """Expand a key stored one word per row into Nb*(Nr+1) words.

    *key* uses the layout produced by to_block2(): ``key[w]`` is word ``w``
    (4 bytes). In the result, ``key_schedule[w]`` is word ``w``, and round key
    ``i`` is ``key_schedule[4*i:4*(i+1)]``, which is how crypt() and decrypt()
    slice it.

    :param key: 4 words of 4 bytes (not modified and not aliased by the result).
    :param s_block: substitution table as returned by get_s_block().
    :returns: list of Nb*(Nr+1) words, each a list of 4 ints.
    """
    r_const = get_r_const()
    # Copy the key words so the schedule never shares row objects with the caller.
    key_schedule = [list(word) for word in key]
    for w in range(4, Nb * (Nr + 1)):
        previous = key_schedule[w - 1]
        back = key_schedule[w - 4]
        if w % 4 == 0:
            # SubWord(RotWord(previous word)); computed once per word rather
            # than once per byte as before.
            temp = sub_words(shift_rot_word(previous), s_block)
            # Integer division: under Python 3, w / 4 is a float and cannot
            # be used as a list index.
            rcon_column = w // 4 - 1
            # NOTE(review): the non-zero Rcon row is XORed into byte 3 of the
            # word (r_const[3 - c]), whereas FIPS-197 XORs it into byte 0.
            # Preserved as-is; confirm this variant is intended.
            new_word = [back[c] ^ temp[c] ^ r_const[3 - c][rcon_column]
                        for c in range(4)]
        else:
            new_word = [previous[c] ^ back[c] for c in range(4)]
        key_schedule.append(new_word)
    return key_schedule
def get_round_key(key_schedule, rnd):
    """Slice round key *rnd* out of a row-per-byte schedule (key_expansion() layout).

    Each row contributes its Nk bytes starting at column ``rnd * Nk``.
    """
    start = rnd * Nk
    stop = start + Nk
    return [row[start:stop] for row in key_schedule]
def crypt(input_bytes, input_key, s_block):
    """Encrypt one block with this module's Nr-round cipher.

    :param input_bytes: 16 plaintext byte values, arranged by to_block2().
    :param input_key: 16 key byte values, arranged by to_block2().
    :param s_block: forward substitution table as returned by get_s_block().
    :returns: the 4 x 4 ciphertext block (list of rows of ints).
    """
    state = to_block2(input_bytes)
    words = key_expansion2(to_block2(input_key), s_block)

    def round_key(number):
        # Round key `number` is the four consecutive schedule words from 4*number.
        return words[number * 4:(number + 1) * 4]

    # Initial key addition with the cipher key itself.
    state = add_round_key(state, round_key(0))
    # Full rounds 1 .. Nr-1.
    for rnd in range(1, Nr):
        state = sub_bytes(state, s_block)
        state = shift_rows(state)
        state = mix_column(state, k)
        state = add_round_key(state, round_key(rnd))
    # The last round skips mix_column.
    state = sub_bytes(state, s_block)
    state = shift_rows(state)
    return add_round_key(state, round_key(Nr))
def pretty_print(string, block):
    """Print a heading line, then each row of *block* on a line of its own."""
    # Single-argument print(...) behaves the same as the print statement
    # under Python 2 and as the print function under Python 3.
    lines = [string]
    lines.extend(block)
    for line in lines:
        print(line)
def decrypt(input_bytes, input_key):
    """Decrypt one block with the equivalent inverse cipher, printing every step.

    Reads the module-level ``s_block`` (for the key schedule), ``inv_s_block``
    (for inverse substitution) and ``inv_k`` (inverse mixing coefficients).

    :param input_bytes: 16 ciphertext byte values, arranged by to_block2().
    :param input_key: 16 key byte values, arranged by to_block2().
    :returns: the 4 x 4 plaintext block (list of rows of ints).
    """
    block = to_block2(input_bytes)
    key_schedule = key_expansion2(to_block2(input_key), s_block)
    pretty_print('\nAll keys', from_block2(key_schedule))
    # Equivalent inverse cipher: in the middle rounds below, inverse mixing
    # runs before the key addition, so the round keys of rounds 1 .. Nr-1 are
    # passed through inverse mixing first. mix_column() works in place, so
    # those schedule words are modified too.
    transformed = key_schedule[:4]
    for rnd in range(1, Nr):
        transformed += mix_column(key_schedule[rnd * 4:(rnd + 1) * 4], inv_k)
    key_schedule = transformed + key_schedule[Nr * 4:]
    pretty_print('\nAll keys after mix column', from_block2(key_schedule))

    # Steps are numbered with one running counter. The previous formula
    # (Nr - i + offset) reused the same step numbers in successive rounds.
    step = 1
    block = add_round_key(block, key_schedule[Nr * 4:(Nr + 1) * 4])
    pretty_print('\nStep %d: add %dth key' % (step, Nr + 1), from_block2(block))
    for i in range(Nr - 1, 0, -1):
        print('\nRound: %d' % (Nr - i))
        block = sub_bytes(block, inv_s_block)
        step += 1
        pretty_print('Step %d: substitution bytes by s-block' % step, from_block2(block))
        block = inv_shift_rows(block)
        step += 1
        pretty_print('\nStep %d: shift rows' % step, from_block2(block))
        block = mix_column(block, inv_k)
        step += 1
        pretty_print('\nStep %d: mix column' % step, from_block2(block))
        block = add_round_key(block, key_schedule[i * 4:(i + 1) * 4])
        step += 1
        pretty_print('\nStep %d: add %dth key' % (step, i + 1), from_block2(block))

    # Final round: no inverse mixing; finish with the original cipher key.
    print('\nRound: %d' % Nr)
    block = sub_bytes(block, inv_s_block)
    step += 1
    pretty_print('Step %d: substitution bytes by s-block' % step, from_block2(block))
    block = inv_shift_rows(block)
    step += 1
    pretty_print('\nStep %d: shift rows' % step, from_block2(block))
    block = add_round_key(block, key_schedule[:4])
    step += 1
    pretty_print('\nStep %d: add %dth key' % (step, 1), from_block2(block))
    return block
# Substitution tables are read from the working directory at import time;
# decrypt() reads both of these module-level names directly.
s_block = get_s_block()
inv_s_block = get_s_block('inv_s.ini')
if __name__ == '__main__':
    # 32 bytes = two 16-byte blocks (the second is a copy of the first). Each
    # block is decrypted on its own with the same key; there is no chaining.
    # NOTE: this section uses Python 2 print statements.
    input_bytes = [0x8E, 0xA5, 0xD6, 0x5A,
                   0xB4, 0xAE, 0x68, 0x93,
                   0xAB, 0x56, 0x81, 0x0B,
                   0xC0, 0x7C, 0xA5, 0x59,
                   0x8E, 0xA5, 0xD6, 0x5A,
                   0xB4, 0xAE, 0x68, 0x93,
                   0xAB, 0x56, 0x81, 0x0B,
                   0xC0, 0x7C, 0xA5, 0x59]
    input_key = [0x2f, 0x22, 0x2e, 0x27, 0x30, 0x3f, 0x2d, 0x2b, 0x2c, 0x00, 0x25, 0x27, 0x30, 0x2b, 0x2c, 0x00]
    print '\n0-16 bytes'
    a = decrypt(input_bytes[:16], input_key)
    pretty_print('\nFinal:', from_block2(a))
    print '\n', '='*70
    print '\n16-32 bytes'
    a2 = decrypt(input_bytes[16:], input_key)
    pretty_print('\nFinal:', from_block2(a2))