blob: 9613be11f520cebe5271b47aa546fea7b1d4541c [file] [log] [blame]
// RUN: rm -f %t_main.vmfb %t_encoder.mlir %t_encoder.vmfb %t_input.irpa %t_output.irpa
//
// Compile main module with encoder MLIR output and splat parameter export.
// RUN: iree-compile %s \
// RUN: --iree-hal-target-device=local \
// RUN: --iree-hal-local-target-device-backends=vmvx \
// RUN: --iree-parameter-encoder-output-file=%t_encoder.mlir \
// RUN: --iree-parameter-splat=%t_input.irpa \
// RUN: -o %t_main.vmfb
//
// Compile the encoder module separately.
// RUN: iree-compile %t_encoder.mlir \
// RUN: --iree-hal-target-device=local \
// RUN: --iree-hal-local-target-device-backends=vmvx \
// RUN: -o %t_encoder.vmfb
//
// Run the encoder to transform parameters.
// RUN: iree-encode-parameters \
// RUN: --module=%t_encoder.vmfb \
// RUN: --parameters=model=%t_input.irpa \
// RUN: --output=encoded=%t_output.irpa \
// RUN: --quiet
//
// Run the main module with both input and encoded parameters.
// The encoded parameters contain the pre-computed transformed values.
// RUN: iree-run-module \
// RUN: --device=local-sync \
// RUN: --module=%t_main.vmfb \
// RUN: --function=main \
// RUN: --parameters=model=%t_input.irpa \
// RUN: --parameters=encoded=%t_output.irpa | \
// RUN: FileCheck %s
// Test parameter transformation with encoder.
// The global loads a parameter and applies an add operation to transform it.
// The encoder runs the add offline, and the main module loads the
// pre-computed result from the encoded parameter scope.
// CHECK-LABEL: EXEC @main
// CHECK: 256xi32=42 42 42 42
// Parameter loaded from input archive (model scope).
// The splat export creates this with all zeros.
util.global private @raw_param = #flow.parameter.named<"model"::"param_global"> : tensor<256xi32>
// This global holds the transformed value.
util.global private @transformed : tensor<256xi32>
util.initializer {
// Load the raw parameter (all zeros from splat).
%raw = util.global.load @raw_param : tensor<256xi32>
// Add 42 to each element - this uses the parameter values and can be encoded.
// With input of 0s, result is 42s.
%c42 = arith.constant 42 : i32
%init = tensor.empty() : tensor<256xi32>
%c42_tensor = linalg.fill ins(%c42 : i32) outs(%init : tensor<256xi32>) -> tensor<256xi32>
%added = linalg.add ins(%raw, %c42_tensor : tensor<256xi32>, tensor<256xi32>) outs(%init : tensor<256xi32>) -> tensor<256xi32>
util.global.store %added, @transformed : tensor<256xi32>
util.return
}
func.func @main() -> tensor<256xi32> {
// Load and return the full transformed tensor.
// If encoding worked, all elements should be 42 (0 + 42).
// If encoding didn't work, all elements would be 0 (splat init).
%tensor = util.global.load @transformed : tensor<256xi32>
return %tensor : tensor<256xi32>
}