[aes/model] Extend crypto with support for CBC and CTR modes

Signed-off-by: Pirmin Vogel <vogelpi@lowrisc.org>
diff --git a/hw/ip/aes/dv/aes_model_dpi/aes_model_dpi.c b/hw/ip/aes/dv/aes_model_dpi/aes_model_dpi.c
index 67cbafc..1ae6b33 100644
--- a/hw/ip/aes/dv/aes_model_dpi/aes_model_dpi.c
+++ b/hw/ip/aes/dv/aes_model_dpi/aes_model_dpi.c
@@ -46,9 +46,9 @@
     unsigned char iv[16];
     memset(iv, 0, 16);
     if (!op_i) {
-      crypto_encrypt(ref_out, iv, ref_in, 16, key, key_len);
+      crypto_encrypt(ref_out, iv, ref_in, 16, key, key_len, kCryptoAesEcb);
     } else {
-      crypto_decrypt(ref_out, iv, ref_in, 16, key, key_len);
+      crypto_decrypt(ref_out, iv, ref_in, 16, key, key_len, kCryptoAesEcb);
     }
   }
 
diff --git a/hw/ip/aes/model/aes_example.c b/hw/ip/aes/model/aes_example.c
index f0ffcec..20a7ce0 100644
--- a/hw/ip/aes/model/aes_example.c
+++ b/hw/ip/aes/model/aes_example.c
@@ -213,8 +213,8 @@
   }
 
   // check state vs BoringSSL/OpenSSL
-  cipher_text_len =
-      crypto_encrypt(cipher_text, iv, plain_text, 16, key, key_len);
+  cipher_text_len = crypto_encrypt(cipher_text, iv, plain_text, 16, key,
+                                   key_len, kCryptoAesEcb);
   if (!check_block(state, cipher_text, 0)) {
     printf("SUCCESS: state matches %s cipher text\n", crypto_lib);
   } else {
@@ -333,8 +333,8 @@
   }
 
   // check state vs BoringSSL/OpenSSL
-  crypto_decrypt(decrypted_text, iv, cipher_text, cipher_text_len, key,
-                 key_len);
+  crypto_decrypt(decrypted_text, iv, cipher_text, cipher_text_len, key, key_len,
+                 kCryptoAesEcb);
   if (!check_block(state, decrypted_text, 0)) {
     printf("SUCCESS: state matches %s decrypted text\n", crypto_lib);
   } else {
diff --git a/hw/ip/aes/model/crypto.c b/hw/ip/aes/model/crypto.c
index 56c3ff2..6bbaab8 100644
--- a/hw/ip/aes/model/crypto.c
+++ b/hw/ip/aes/model/crypto.c
@@ -5,9 +5,52 @@
 #include <openssl/conf.h>
 #include <openssl/evp.h>
 
+#include "crypto.h"
+
+/**
+ * Get EVP_CIPHER type pointer defined by key_len and mode.
+ * If the selected cipher is not supported, the AES-128 ECB type is returned.
+ *
+ * @param  key_len   Encryption key length in bytes (16, 24, 32)
+ * @param  mode      AES cipher mode @see crypto_mode.
+ * @return Pointer to EVP_CIPHER type
+ */
+static const EVP_CIPHER *crypto_get_EVP_cipher(int key_len,
+                                               crypto_mode_t mode) {
+  const EVP_CIPHER *cipher;
+
+  if (mode == kCryptoAesCbc) {
+    if (key_len == 32) {
+      cipher = EVP_aes_256_cbc();
+    } else if (key_len == 24) {
+      cipher = EVP_aes_192_cbc();
+    } else {  // key_len = 16
+      cipher = EVP_aes_128_cbc();
+    }
+  } else if (mode == kCryptoAesCtr) {
+    if (key_len == 32) {
+      cipher = EVP_aes_256_ctr();
+    } else if (key_len == 24) {
+      cipher = EVP_aes_192_ctr();
+    } else {  // key_len = 16
+      cipher = EVP_aes_128_ctr();
+    }
+  } else {  // kCryptoAesEcb
+    if (key_len == 32) {
+      cipher = EVP_aes_256_ecb();
+    } else if (key_len == 24) {
+      cipher = EVP_aes_192_ecb();
+    } else {  // key_len = 16
+      cipher = EVP_aes_128_ecb();
+    }
+  }
+
+  return cipher;
+}
+
 int crypto_encrypt(unsigned char *output, const unsigned char *iv,
-                   const unsigned char *input, const int input_len,
-                   const unsigned char *key, const int key_len) {
+                   const unsigned char *input, int input_len,
+                   const unsigned char *key, int key_len, crypto_mode_t mode) {
   EVP_CIPHER_CTX *ctx;
   int ret;
   int len, output_len;
@@ -19,14 +62,12 @@
     return -1;
   }
 
+  // Get cipher
+  const EVP_CIPHER *cipher = crypto_get_EVP_cipher(key_len, mode);
+
   // Init encryption context
-  if (key_len == 16) {
-    ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_ecb(), NULL, key, iv);
-  } else if (key_len == 24) {
-    ret = EVP_EncryptInit_ex(ctx, EVP_aes_192_ecb(), NULL, key, iv);
-  } else {  // key_len = 32
-    ret = EVP_EncryptInit_ex(ctx, EVP_aes_256_ecb(), NULL, key, iv);
-  }
+  ret = EVP_EncryptInit_ex(ctx, cipher, NULL, key, iv);
+
   if (ret != 1) {
     printf("ERROR: Initialization of encryption context failed\n");
     return -1;
@@ -58,8 +99,8 @@
 }
 
 int crypto_decrypt(unsigned char *output, const unsigned char *iv,
-                   const unsigned char *input, const int input_len,
-                   const unsigned char *key, const int key_len) {
+                   const unsigned char *input, int input_len,
+                   const unsigned char *key, int key_len, crypto_mode_t mode) {
   EVP_CIPHER_CTX *ctx;
   int ret;
   int len, output_len;
@@ -71,14 +112,11 @@
     return -1;
   }
 
+  // Get cipher
+  const EVP_CIPHER *cipher = crypto_get_EVP_cipher(key_len, mode);
+
   // Init decryption context
-  if (key_len == 16) {
-    ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_ecb(), NULL, key, iv);
-  } else if (key_len == 24) {
-    ret = EVP_DecryptInit_ex(ctx, EVP_aes_192_ecb(), NULL, key, iv);
-  } else {  // key_len == 32
-    ret = EVP_DecryptInit_ex(ctx, EVP_aes_256_ecb(), NULL, key, iv);
-  }
+  ret = EVP_DecryptInit_ex(ctx, cipher, NULL, key, iv);
   if (ret != 1) {
     printf("ERROR: Initialization of decryption context failed\n");
     return -1;
diff --git a/hw/ip/aes/model/crypto.h b/hw/ip/aes/model/crypto.h
index 1f92abb..ba22e5e 100644
--- a/hw/ip/aes/model/crypto.h
+++ b/hw/ip/aes/model/crypto.h
@@ -6,6 +6,15 @@
 #define CRYPTO_H_
 
 /**
+ * AES cipher mode
+ */
+typedef enum crypto_mode {
+  kCryptoAesEcb = 1 << 0,
+  kCryptoAesCbc = 1 << 1,
+  kCryptoAesCtr = 1 << 2
+} crypto_mode_t;
+
+/**
  * Encrypt using BoringSSL/OpenSSL
  *
  * @param  output    Output cipher text, must be a multiple of 16 bytes
@@ -15,11 +24,12 @@
  *                   of 16
  * @param  key       Encryption key
  * @param  key_len   Encryption key length in bytes (16, 24, 32)
+ * @param  mode      AES cipher mode @see crypto_mode.
  * @return Length of the output cipher text in bytes, -1 in case of error
  */
 int crypto_encrypt(unsigned char *output, const unsigned char *iv,
-                   const unsigned char *input, const int input_len,
-                   const unsigned char *key, const int key_len);
+                   const unsigned char *input, int input_len,
+                   const unsigned char *key, int key_len, crypto_mode_t mode);
 
 /**
  * Decrypt using BoringSSL/OpenSSL
@@ -31,10 +41,11 @@
  *                   multiple of 16
  * @param  key       Encryption key, decryption key is derived internally
  * @param  key_len   Encryption key length in bytes (16, 24, 32)
+ * @param  mode      AES cipher mode @see crypto_mode.
  * @return Length of the output plain text in bytes, -1 in case of error
  */
 int crypto_decrypt(unsigned char *output, const unsigned char *iv,
-                   const unsigned char *input, const int input_len,
-                   const unsigned char *key, const int key_len);
+                   const unsigned char *input, int input_len,
+                   const unsigned char *key, int key_len, crypto_mode_t mode);
 
 #endif  // CRYPTO_H_