// blob: 1e2c713c507c21fac775613cc936721707c1ece4 [file] [log] [blame]
// Copyright lowRISC contributors.
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0
#include "aes_example.h"
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "aes.h"
#include "crypto.h"
// Verbosity of the encryption trace printed by main(): with a value above 0
// the initial input/key and every round's input and round key are printed;
// with a value above 1 the state after SubBytes, ShiftRows and MixColumns is
// printed as well.
#define DEBUG_LEVEL_ENC 0 // 0, 1, 2
// Verbosity of the decryption trace printed by main(): with a value above 0
// the initial input/key and every round's input, round key and mixed round
// key are printed; with a value above 1 the state after InvSubBytes,
// InvShiftRows and InvMixColumns is printed as well.
#define DEBUG_LEVEL_DEC 0 // 0, 1, 2
// Selects the set of test vectors used by main(): 0 picks the *_0 plain text,
// keys and golden cipher texts, any other value picks the *_1 set.
#define EXAMPLE 1 // 0, 1
/**
 * Compares two 16-byte blocks.
 *
 * @param actual   Block under test (16 bytes are read).
 * @param expected Reference block (16 bytes are read).
 * @param print    If non-zero, the first differing byte pair is reported on
 *                 stdout.
 * @return 0 if all 16 bytes are equal, 1 as soon as a difference is found.
 */
static int check_block(const unsigned char *actual,
                       const unsigned char *expected, const int print) {
  // Advance to the first position at which the blocks disagree.
  int pos = 0;
  while (pos < 16 && actual[pos] == expected[pos]) {
    ++pos;
  }

  // Ran off the end without finding a difference: the blocks match.
  if (pos == 16) {
    return 0;
  }

  if (print) {
    printf("ERROR: block mismatch. Found %#x, expected %#x\n", actual[pos],
           expected[pos]);
  }
  return 1;
}
/**
 * Example program: encrypts one 16-byte block round by round using the aes_*
 * primitives, then decrypts it again using the Equivalent Inverse Cipher.
 * After each direction the result is compared against the block-level AES
 * model functions, the selected golden vectors and the linked crypto library
 * (BoringSSL or OpenSSL), and the outcome of every comparison is printed.
 *
 * @param argc Number of command line arguments.
 * @param argv argv[1], if present, is the key length in bytes (16, 24 or 32).
 *             Any other value falls back to 16 with a warning; further
 *             arguments are ignored with a warning.
 * @return The negative value reported by aes_get_num_rounds() if there is
 *         one, -ENOMEM if the key buffer cannot be allocated, 0 otherwise.
 *         NOTE(review): 0 is also returned when a comparison fails; this
 *         matches the previous behaviour and is kept for existing callers.
 */
int main(int argc, char *argv[]) {
#ifdef USE_BORING_SSL
  char crypto_lib[10] = "BoringSSL";
#else
  char crypto_lib[10] = "OpenSSL";
#endif

  int key_len = 16;
  if (argc > 1) {
    // strtol() rather than atoi(): out-of-range input is well defined and
    // simply ends up in the "unsupported" branch below.
    const long requested = strtol(argv[1], NULL, 10);
    if (requested == 16 || requested == 24 || requested == 32) {
      key_len = (int)requested;
    } else {
      printf("WARNING: Unsupported key length %ld, switching to 16 (AES-128)\n",
             requested);
    }
    if (argc > 2) {
      printf(
          "WARNING: Only 1 command line argument supported. Ignoring "
          "further arguments\n");
    }
  }

  int num_rounds = aes_get_num_rounds(key_len);
  if (num_rounds < 0) {
    return num_rounds;
  }

  /*
   * Select plain_text, key and golden ciphertext from example
   */
  const unsigned char *plain_text;
  const unsigned char *key;
  const unsigned char *cipher_text_gold;
  if (EXAMPLE == 0) {
    plain_text = plain_text_0;
    if (key_len == 16) {
      key = key_16_0;
      cipher_text_gold = cipher_text_gold_16_0;
    } else if (key_len == 24) {
      key = key_24_0;
      cipher_text_gold = cipher_text_gold_24_0;
    } else {  // key_len == 32
      key = key_32_0;
      cipher_text_gold = cipher_text_gold_32_0;
    }
  } else {  // EXAMPLE == 1
    plain_text = plain_text_1;
    if (key_len == 16) {
      key = key_16_1;
      cipher_text_gold = cipher_text_gold_16_1;
    } else if (key_len == 24) {
      key = key_24_1;
      cipher_text_gold = cipher_text_gold_24_1;
    } else {  // key_len == 32
      key = key_32_1;
      cipher_text_gold = cipher_text_gold_32_1;
    }
  }

  // libcrypto-related variables and buffers
  unsigned char *iv = (unsigned char *)"0000000000000000";
  int cipher_text_len;
  unsigned char cipher_text[16];
  unsigned char decrypted_text[16];

  printf("Encryption key:\t\t");
  aes_print_block((const unsigned char *)key, 16);
  if (key_len > 16) {
    // Keys longer than one block are printed on a second, aligned line.
    printf("\t\t\t");
    aes_print_block((const unsigned char *)&key[16], key_len - 16);
  }
  printf("Input data:\t\t");
  aes_print_block((const unsigned char *)plain_text, 16);
  printf("\n");

  // intermediate buffers
  // full_key is owned by main() and released before every return below.
  unsigned char *full_key = malloc((size_t)key_len);
  if (full_key == NULL) {
    printf("ERROR: malloc() failed\n");
    return -ENOMEM;
  }
  unsigned char state[16];
  unsigned char state_lib[16];
  unsigned char round_key[16];
  unsigned char inv_round_key[16];
  unsigned char rcon;

  //
  // ENCRYPTION
  //

  // init
  // copy plain text to state
  memcpy(state, plain_text, 16);
  // copy first 16 bytes of key to round key
  memcpy(round_key, key, 16);
  // copy key to long key
  memcpy(full_key, key, (size_t)key_len);
  // reset rcon
  rcon = 0x00;

  if (DEBUG_LEVEL_ENC > 0) {
    printf("Init input:\t\t");
    aes_print_block((const unsigned char *)state, 16);
    printf("Init key:\t\t");
    aes_print_block((const unsigned char *)round_key, 16);
  }

  // add round key
  aes_add_round_key(state, round_key);

  // rounds
  for (int j = 0; j < num_rounds; j++) {
    if (DEBUG_LEVEL_ENC > 0) {
      printf("Round %i input:\t\t", j);
      aes_print_block((const unsigned char *)state, 16);
    }

    // sub bytes
    aes_sub_bytes(state);
    if (DEBUG_LEVEL_ENC > 1) {
      printf("Round %i SubBytes:\t", j);
      aes_print_block((const unsigned char *)state, 16);
    }

    // shift rows
    aes_shift_rows(state);
    if (DEBUG_LEVEL_ENC > 1) {
      printf("Round %i ShiftRows:\t", j);
      aes_print_block((const unsigned char *)state, 16);
    }

    // mix columns (skipped in the final round)
    if (j < (num_rounds - 1)) {
      aes_mix_columns(state);
      if (DEBUG_LEVEL_ENC > 1) {
        printf("Round %i MixColumns:\t", j);
        aes_print_block((const unsigned char *)state, 16);
      }
    }

    // expand key
    aes_key_expand(round_key, full_key, key_len, &rcon, j);
    if (DEBUG_LEVEL_ENC > 0) {
      printf("Round %i key:\t\t", j);
      aes_print_block((const unsigned char *)round_key, 16);
    }

    // add round key
    aes_add_round_key(state, round_key);
  }

  // print
  printf("Encryption output:\t");
  aes_print_block((const unsigned char *)state, 16);
  printf("\n");

  // check state vs AES model/lib
  aes_encrypt_block(plain_text, key, key_len, state_lib);
  if (check_block(state, state_lib, 0)) {
    printf("ERROR: state does not match AES model output\n");
  }

  // check state versus gold
  if (!check_block(state, cipher_text_gold, 1)) {
    printf("SUCCESS: state matches golden cipher text\n");
  } else {
    printf("ERROR: state does not match golden cipher text\n");
  }

  // check state vs BoringSSL/OpenSSL
  cipher_text_len = crypto_encrypt(cipher_text, iv, plain_text, 16, key,
                                   key_len, kCryptoAesEcb);
  if (!check_block(state, cipher_text, 0)) {
    printf("SUCCESS: state matches %s cipher text\n", crypto_lib);
  } else {
    printf("ERROR: state does not match %s cipher text\n", crypto_lib);
    // Decryption is skipped on this path; release the key buffer first.
    free(full_key);
    return 0;
  }
  printf("\n");

  //
  // DECRYPTION using Equivalent Inverse Cipher
  //

  //
  // generate decryption key
  //
  // copy first 16 bytes of key to round key
  memcpy(round_key, key, 16);
  // copy key to long key
  memcpy(full_key, key, (size_t)key_len);
  // reset rcon
  rcon = 0x00;
  // Run the forward key expansion through all rounds so that round_key and
  // full_key hold the values the inverse expansion starts from.
  for (int j = 0; j < num_rounds; j++) {
    aes_key_expand(round_key, full_key, key_len, &rcon, j);
  }

  // init
  // cipher text is already in state
  // round key is prepared already
  // reset rcon
  rcon = 0x00;

  if (DEBUG_LEVEL_DEC > 0) {
    printf("Init input:\t\t");
    aes_print_block((const unsigned char *)state, 16);
    printf("Init key:\t\t");
    aes_print_block((const unsigned char *)round_key, 16);
  }

  // add round key
  aes_add_round_key(state, round_key);

  // rounds
  for (int j = 0; j < num_rounds; j++) {
    if (DEBUG_LEVEL_DEC > 0) {
      printf("Round %i input:\t\t", j);
      aes_print_block((const unsigned char *)state, 16);
    }

    // sub bytes
    aes_inv_sub_bytes(state);
    if (DEBUG_LEVEL_DEC > 1) {
      printf("Round %i InvSubBytes:\t", j);
      aes_print_block((const unsigned char *)state, 16);
    }

    // shift rows
    aes_inv_shift_rows(state);
    if (DEBUG_LEVEL_DEC > 1) {
      printf("Round %i InvShiftRows:\t", j);
      aes_print_block((const unsigned char *)state, 16);
    }

    // mix columns (skipped in the final round)
    if (j < (num_rounds - 1)) {
      aes_inv_mix_columns(state);
      if (DEBUG_LEVEL_DEC > 1) {
        printf("Round %i InvMixColumns:\t", j);
        aes_print_block((const unsigned char *)state, 16);
      }
    }

    // expand key
    aes_inv_key_expand(round_key, full_key, key_len, &rcon, j);
    if (DEBUG_LEVEL_DEC > 0) {
      printf("Round %i key:\t\t", j);
      aes_print_block((const unsigned char *)round_key, 16);
    }

    // mix columns on round key
    // Work on a copy so that round_key itself stays untouched for the next
    // call to aes_inv_key_expand().
    memcpy(inv_round_key, round_key, 16);
    if (j < (num_rounds - 1)) {
      aes_inv_mix_columns(inv_round_key);
      if (DEBUG_LEVEL_DEC > 0) {
        printf("Round %i mixed key:\t", j);
        aes_print_block((const unsigned char *)inv_round_key, 16);
      }
    }

    // add round key
    aes_add_round_key(state, inv_round_key);
  }

  // print
  printf("Decryption Output:\t");
  aes_print_block((const unsigned char *)state, 16);
  printf("\n");

  // check state vs AES model/lib
  aes_decrypt_block(cipher_text, key, key_len, state_lib);
  if (check_block(state, state_lib, 0)) {
    printf("ERROR: state does not match AES model output\n");
  }

  // check state versus gold/plain_text
  if (!check_block(state, plain_text, 1)) {
    printf("SUCCESS: state matches expected plain text\n");
  } else {
    printf("ERROR: state does not match expected plain text\n");
  }

  // check state vs BoringSSL/OpenSSL
  crypto_decrypt(decrypted_text, iv, cipher_text, cipher_text_len, key, key_len,
                 kCryptoAesEcb);
  if (!check_block(state, decrypted_text, 0)) {
    printf("SUCCESS: state matches %s decrypted text\n", crypto_lib);
  } else {
    printf("ERROR: state does not match %s decrypted text\n", crypto_lib);
  }

  free(full_key);
  return 0;
}