[spirv] Migrate tests for LinalgTileAndDistributePass (#5610)
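
The tile-size computation in KernelDispatchUtils.cpp now pulls the matmul shapes
through getInputOutputTypes() instead of casting op.inputs() directly, and only
asserts that tileSizes is empty right before pushing the computed sizes. The
clamping helper it calls is untouched by this change; as a rough sketch only
(assuming getMinIfShapeStatic keeps the requested tile size for dynamic
dimensions and that dynamic dims are encoded as ShapedType::kDynamicSize):

    #include <algorithm>               // std::min
    #include "mlir/IR/BuiltinTypes.h"  // mlir::ShapedType

    // Sketch, not part of this patch: clamp a tile size to a static dimension.
    static int64_t getMinIfShapeStatic(int64_t dim, int64_t tileSize) {
      if (dim == mlir::ShapedType::kDynamicSize)
        return tileSize;               // dynamic dim: keep the requested tile size
      return std::min(dim, tileSize);  // static dim: never tile past the dimension
    }

Test files are renamed to match the passes they exercise:
batch_matmul_vectorization.mlir -> tile_and_vectorize_batch_matmul.mlir,
elementwise_vectorization.mlir -> vectorize_elementwise_ops.mlir, and
matmul_vectorization.mlir -> vectorize_matmul.mlir; the fused fill + matmul
coverage from matmul_fused_vectorization.mlir moves into
tile_and_vectorize_matmul.mlir. The RUN lines drop
iree-codegen-spirv-linalg-tile-and-distribute and the
-iree-spirv-enable-vectorization flag, feeding already-distributed IR to
iree-spirv-tile-and-vectorize-in-one-workgroup (preceded by
iree-spirv-concretize-tile-among-workgroups where workgroup tiling is still
needed).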
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
index c2f5d19..5355724 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -480,14 +480,18 @@
// memory available (maybe). For now, just hard-wire it.
tileSizeK = 32;
}
- assert(tileSizes.empty());
- int64_t M = op.inputs()[0].getType().cast<ShapedType>().getShape()[0];
- int64_t N = op.inputs()[1].getType().cast<ShapedType>().getShape()[1];
- int64_t K = op.inputs()[0].getType().cast<ShapedType>().getShape()[1];
+
+ SmallVector<ShapedType> inputTypes;
+ std::tie(inputTypes, std::ignore) = getInputOutputTypes(op);
+ int64_t M = inputTypes[0].getShape()[0];
+ int64_t N = inputTypes[1].getShape()[1];
+ int64_t K = inputTypes[0].getShape()[1];
+
SmallVector<int64_t, 4> ts = {
getMinIfShapeStatic(M, nRowsPerWorkitem * config.workgroupSize[1]),
getMinIfShapeStatic(N, nColsPerWorkitem * config.workgroupSize[0]),
getMinIfShapeStatic(K, tileSizeK)};
+ assert(tileSizes.empty());
tileSizes.emplace_back(std::move(ts));
return success();
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
index 2b3f9a6..f6c84b2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/BUILD
@@ -27,28 +27,27 @@
name = "lit",
srcs = enforce_glob(
[
- "batch_matmul_vectorization.mlir",
"concretize_tile_among_workgroups.mlir",
"concretize_tile_among_workgroups_dynamic.mlir",
"convert_to_gpu.mlir",
"convert_to_spirv.mlir",
"dead_alloc.mlir",
- "elementwise_vectorization.mlir",
"fold-gpu-procid-uses.mlir",
"forop_canonicalization.mlir",
"materialize_launch_configuration.mlir",
"materialize_launch_configuration2.mlir",
- "matmul_fused_vectorization.mlir",
- "matmul_vectorization.mlir",
"matmul_vectorization_licm.mlir",
- "vectorize_memref_load_store.mlir",
"pipeline_matmul_vectorization.mlir",
"pipeline_test_cooperative_mat.mlir",
"promote_workgroup_memory.mlir",
"split_dispatch_function.mlir",
+ "tile_and_vectorize_batch_matmul.mlir",
"tile_and_vectorize_conv.mlir",
"tile_and_vectorize_matmul.mlir",
"vector_to_gpu.mlir",
+ "vectorize_elementwise_ops.mlir",
+ "vectorize_matmul.mlir",
+ "vectorize_memref_load_store.mlir",
],
include = ["*.mlir"],
),
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
index a00bf0a..45673f5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/CMakeLists.txt
@@ -14,27 +14,26 @@
NAME
lit
SRCS
- "batch_matmul_vectorization.mlir"
"concretize_tile_among_workgroups.mlir"
"concretize_tile_among_workgroups_dynamic.mlir"
"convert_to_gpu.mlir"
"convert_to_spirv.mlir"
"dead_alloc.mlir"
- "elementwise_vectorization.mlir"
"fold-gpu-procid-uses.mlir"
"forop_canonicalization.mlir"
"materialize_launch_configuration.mlir"
"materialize_launch_configuration2.mlir"
- "matmul_fused_vectorization.mlir"
- "matmul_vectorization.mlir"
"matmul_vectorization_licm.mlir"
"pipeline_matmul_vectorization.mlir"
"pipeline_test_cooperative_mat.mlir"
"promote_workgroup_memory.mlir"
"split_dispatch_function.mlir"
+ "tile_and_vectorize_batch_matmul.mlir"
"tile_and_vectorize_conv.mlir"
"tile_and_vectorize_matmul.mlir"
"vector_to_gpu.mlir"
+ "vectorize_elementwise_ops.mlir"
+ "vectorize_matmul.mlir"
"vectorize_memref_load_store.mlir"
DATA
iree::tools::IreeFileCheck
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
deleted file mode 100644
index 458e135..0000000
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_fused_vectorization.mlir
+++ /dev/null
@@ -1,63 +0,0 @@
-// RUN: iree-opt -split-input-file -iree-spirv-workgroup-tile-size=8,64,4 -iree-spirv-invocation-tile-size=8,4,4 -iree-spirv-workgroup-size=16,1,1 -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
-
-hal.executable @matmul_static_shape attributes {sym_visibility = "private"} {
- hal.interface @io {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- hal.executable.target @vulkan, filter="dylib*" {
- hal.executable.entry_point @matmul_static_shape attributes {
- interface = @io, ordinal = 0 : index,
- signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
- !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
- module attributes {
- spv.target_env =
- #spv.target_env<#spv.vce<v1.3,
- [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
- StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
- UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
- GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
- GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
- VariablePointersStorageBuffer],
- [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
- SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
- ARM:IntegratedGPU,
- {max_compute_shared_memory_size = 32768 : i32,
- max_compute_workgroup_invocations = 512 : i32,
- max_compute_workgroup_size = dense<512> : vector<3xi32>,
- subgroup_size = 16 : i32}>} {
- func @matmul_static_shape()
- attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
- %arg0 = iree.placeholder for "interface buffer"
- {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf32>
- %arg1 = iree.placeholder for "interface buffer"
- {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf32>
- %ret0 = iree.placeholder for "interface buffer"
- {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf32>
- %cst = constant 0.000000e+00 : f32
- linalg.fill(%ret0, %cst) : memref<4096x4096xf32>, f32
- linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf32>, memref<4096x4096xf32>)
- outs(%ret0 : memref<4096x4096xf32>)
- return
- }
- func private @matmul_static_shape__num_workgroups__
- (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
- !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
- hal.interface @io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
- }
- }
- }
-}
-// CHECK-LABEL: func @matmul_static_shape
-// CHECK-COUNT-8: vector.transfer_write
-// CHECK-COUNT-8: vector.transfer_read
-// CHECK: %[[FOR_RES:.+]]:8 = scf.for
-// CHECK-COUNT-12: vector.transfer_read
-// CHECK-COUNT-32: vector.contract
-// CHECK: scf.yield
-// CHECK-COUNT-8: vector.transfer_write %[[FOR_RES]]
-// CHECK: return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_batch_matmul.mlir
similarity index 62%
rename from iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir
rename to iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_batch_matmul.mlir
index da4dc0a..4cdaeeb 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/batch_matmul_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_batch_matmul.mlir
@@ -1,46 +1,53 @@
-// RUN: iree-opt -split-input-file -iree-spirv-workgroup-tile-size=1,8,64,4 -iree-spirv-invocation-tile-size=1,8,4,4 -iree-spirv-workgroup-size=16,1,1 -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-spirv-tile-and-vectorize-in-one-workgroup))" -canonicalize -cse -iree-spirv-workgroup-tile-size=1,8,64,4 -iree-spirv-invocation-tile-size=1,8,4,4 -iree-spirv-workgroup-size=16,1,1 %s | IreeFileCheck %s
hal.executable @batch_matmul_static_shape attributes {sym_visibility = "private"} {
- hal.interface @io {
+ hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
- hal.executable.target @vulkan, filter="dylib*" {
+ hal.executable.target @vulkan_spirv, filter="vulkan*" {
hal.executable.entry_point @batch_matmul_static_shape attributes {
interface = @io, ordinal = 0 : index,
- signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
- !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
- module attributes {
- spv.target_env =
- #spv.target_env<#spv.vce<v1.3,
- [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
- StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
- UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
- GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
- GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
- VariablePointersStorageBuffer],
- [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
- SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
- ARM:IntegratedGPU,
- {max_compute_shared_memory_size = 32768 : i32,
- max_compute_workgroup_invocations = 512 : i32,
- max_compute_workgroup_size = dense<512> : vector<3xi32>,
- subgroup_size = 16 : i32}>} {
- func @batch_matmul_static_shape()
- attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
- %arg0 = iree.placeholder for "interface buffer"
- {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<4x1024x1024xf32>
- %arg1 = iree.placeholder for "interface buffer"
- {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<4x1024x1024xf32>
- %ret0 = iree.placeholder for "interface buffer"
- {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<4x1024x1024xf32>
- linalg.batch_matmul ins(%arg0, %arg1 : memref<4x1024x1024xf32>, memref<4x1024x1024xf32>) outs(%ret0 : memref<4x1024x1024xf32>)
+ signature = (!flow.dispatch.tensor<readonly:4x1024x1024xf32>, !flow.dispatch.tensor<readonly:4x1024x1024xf32>, !flow.dispatch.tensor<writeonly:4x1024x1024xf32>) -> ()}
+ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
+ func @batch_matmul_static_shape() {
+ %c0 = constant 0 : index
+ %c4 = constant 4 : index
+ %c1024 = constant 1024 : index
+ %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<4x1024x1024xf32>
+ %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<4x1024x1024xf32>
+ %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<4x1024x1024xf32>
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_size_z = hal.interface.workgroup.size[2] : index
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+ %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
+ %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
+ scf.for %arg0 = %3 to %c4 step %4 {
+ %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+ %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+ scf.for %arg1 = %5 to %c1024 step %6 {
+ %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+ %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+ scf.for %arg2 = %7 to %c1024 step %8 {
+ %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
+ %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_y]
+ %11 = memref.subview %0[%arg0, %arg1, 0] [%9, %10, 1024] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+ %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg2)[%workgroup_size_x]
+ %13 = memref.subview %1[%arg0, 0, %arg2] [%9, 1024, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+ %14 = memref.subview %2[%arg0, %arg1, %arg2] [%9, %10, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+ linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %13 : memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>) outs(%14 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>)
+ }
+ }
+ }
return
}
- func private @matmul_static_shape__num_workgroups__
- (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
- !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -49,15 +56,15 @@
}
}
}
+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK: func @batch_matmul_static_shape
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @io::@arg0
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @io::@arg1
-// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @io::@ret0
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0]
+// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0]
+// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0]
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
-// CHECK-DAG: %[[CST:.+]] = constant 0.0
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
@@ -65,11 +72,15 @@
// CHECK-DAG: %[[C5:.+]] = constant 5 : index
// CHECK-DAG: %[[C6:.+]] = constant 6 : index
// CHECK-DAG: %[[C7:.+]] = constant 7 : index
-// CHECK: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
-// CHECK: %[[BIDZ:.+]] = "gpu.block_id"() {dimension = "z"}
+// CHECK: %[[BIDX:.+]] = hal.interface.workgroup.id[0]
+// CHECK: %[[BIDY:.+]] = hal.interface.workgroup.id[1]
+// CHECK: %[[BIDZ:.+]] = hal.interface.workgroup.id[2]
// CHECK-DAG: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK-DAG: %[[BOFFSET_X:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]]
+// CHECK: %[[SUBVIEW_ARG0:.+]] = memref.subview %[[ARG0]]
+// CHECK-SAME: [%[[BIDZ]], %[[BOFFSET_Y]], 0] [1, 8, 1024]
+// CHECK: %[[SUBVIEW_ARG1:.+]] = memref.subview %[[ARG1]]
+// CHECK-SAME: [%[[BIDZ]], 0, %[[BOFFSET_X]]] [1, 1024, 64]
// CHECK: %[[SUBVIEW_RESULT:.+]] = memref.subview %[[RET0]]
// CHECK-SAME: [%[[BIDZ]], %[[BOFFSET_Y]], %[[BOFFSET_X]]] [1, 8, 64]
// CHECK: %[[IIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
@@ -105,12 +116,10 @@
// CHECK-SAME: %[[ACC_5:.+]] = %[[READ_INIT_5]],
// CHECK-SAME: %[[ACC_6:.+]] = %[[READ_INIT_6]],
// CHECK-SAME: %[[ACC_7:.+]] = %[[READ_INIT_7]])
-// CHECK-DAG: %[[SUBVIEW_LHS:.+]] = memref.subview %[[ARG0]]
-// CHECK-SAME: [%[[BIDZ]], %[[BOFFSET_Y]], %[[IV0]]] [1, 8, 4]
-// CHECK-DAG: %[[SUBVIEW_RHS:.+]] = memref.subview %[[ARG1]]
-// CHECK-SAME: [%[[BIDZ]], %[[IV0]], %[[BOFFSET_X]]] [1, 4, 64]
-// CHECK-DAG: %[[SUBVIEW_RHS_2:.+]] = memref.subview %[[SUBVIEW_RHS]]
-// CHECK-SAME: [%[[IIDZ]], 0, %[[IOFFSET_X]]] [1, 4, 4] [1, 1, 1]
+// CHECK-DAG: %[[SUBVIEW_LHS:.+]] = memref.subview %[[SUBVIEW_ARG0]]
+// CHECK-SAME: [%[[IIDZ]], %[[IOFFSET_Y]], %[[IV0]]] [1, 8, 4]
+// CHECK-DAG: %[[SUBVIEW_RHS:.+]] = memref.subview %[[SUBVIEW_ARG1]]
+// CHECK-SAME: [%[[IIDZ]], %[[IV0]], %[[IOFFSET_X]]] [1, 4, 4] [1, 1, 1]
// CHECK-DAG: %[[READ_LHS_0:.+]] = vector.transfer_read
// CHECK-SAME: %[[SUBVIEW_LHS]][%[[C0]], %[[C0]], %[[C0]]]
@@ -130,13 +139,13 @@
// CHECK-SAME: %[[SUBVIEW_LHS]][%[[C0]], %[[C7]], %[[C0]]]
// CHECK-DAG: %[[READ_RHS_0:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C0]], %[[C0]]]
+// CHECK-SAME: %[[SUBVIEW_RHS]][%[[C0]], %[[C0]], %[[C0]]]
// CHECK-DAG: %[[READ_RHS_1:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C1]], %[[C0]]]
+// CHECK-SAME: %[[SUBVIEW_RHS]][%[[C0]], %[[C1]], %[[C0]]]
// CHECK-DAG: %[[READ_RHS_2:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C2]], %[[C0]]]
+// CHECK-SAME: %[[SUBVIEW_RHS]][%[[C0]], %[[C2]], %[[C0]]]
// CHECK-DAG: %[[READ_RHS_3:.+]] = vector.transfer_read
-// CHECK-SAME: %[[SUBVIEW_RHS_2]][%[[C0]], %[[C3]], %[[C0]]]
+// CHECK-SAME: %[[SUBVIEW_RHS]][%[[C0]], %[[C3]], %[[C0]]]
// CHECK-DAG: %[[READ_LHS_0_0:.+]] = vector.extract_strided_slice
// CHECK-SAME: %[[READ_LHS_0]] {offsets = [0, 0, 0]
@@ -290,49 +299,56 @@
// -----
-hal.executable @batch_matmul_fused_fillop attributes {sym_visibility = "private"} {
- hal.interface @io {
+hal.executable @fused_fill_batch_matmul attributes {sym_visibility = "private"} {
+ hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
- hal.executable.target @vulkan, filter="dylib*" {
- hal.executable.entry_point @batch_matmul_fused_fillop attributes {
+ hal.executable.target @vulkan_spirv, filter="vulkan*" {
+ hal.executable.entry_point @fused_fill_batch_matmul attributes {
interface = @io, ordinal = 0 : index,
- signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
- !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
- module attributes {
- spv.target_env =
- #spv.target_env<#spv.vce<v1.3,
- [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
- StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
- UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
- GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
- GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
- VariablePointersStorageBuffer],
- [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
- SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>,
- ARM:IntegratedGPU,
- {max_compute_shared_memory_size = 32768 : i32,
- max_compute_workgroup_invocations = 512 : i32,
- max_compute_workgroup_size = dense<512> : vector<3xi32>,
- subgroup_size = 16 : i32}>} {
- func @batch_matmul_fused_fillop()
- attributes {vkspv.num_workgroups_fn = @batch_matmul_fused_fillop__num_workgroups__} {
- %cst = constant 0.000000e+00 : f32
- %arg0 = iree.placeholder for "interface buffer"
- {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<4x1024x1024xf32>
- %arg1 = iree.placeholder for "interface buffer"
- {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<4x1024x1024xf32>
- %ret0 = iree.placeholder for "interface buffer"
- {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<4x1024x1024xf32>
- linalg.fill(%ret0, %cst) : memref<4x1024x1024xf32>, f32
- linalg.batch_matmul ins(%arg0, %arg1 : memref<4x1024x1024xf32>, memref<4x1024x1024xf32>) outs(%ret0 : memref<4x1024x1024xf32>)
+ signature = (!flow.dispatch.tensor<readonly:4x1024x1024xf32>, !flow.dispatch.tensor<readonly:4x1024x1024xf32>, !flow.dispatch.tensor<writeonly:4x1024x1024xf32>) -> ()}
+ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
+ func @fused_fill_batch_matmul() {
+ %zero = constant 0.0 : f32
+ %c0 = constant 0 : index
+ %c4 = constant 4 : index
+ %c1024 = constant 1024 : index
+ %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<4x1024x1024xf32>
+ %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<4x1024x1024xf32>
+ %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<4x1024x1024xf32>
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_size_z = hal.interface.workgroup.size[2] : index
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %workgroup_id_z = hal.interface.workgroup.id[2] : index
+ %workgroup_count_z = hal.interface.workgroup.count[2] : index
+ %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_z, %workgroup_size_z]
+ %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_z, %workgroup_size_z]
+ scf.for %arg0 = %3 to %c4 step %4 {
+ %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+ %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+ scf.for %arg1 = %5 to %c1024 step %6 {
+ %7 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+ %8 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+ scf.for %arg2 = %7 to %c1024 step %8 {
+ %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4)>(%arg0)[%workgroup_size_z]
+ %10 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg1)[%workgroup_size_y]
+ %11 = memref.subview %0[%arg0, %arg1, 0] [%9, %10, 1024] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+ %12 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 1024)>(%arg2)[%workgroup_size_x]
+ %13 = memref.subview %1[%arg0, 0, %arg2] [%9, 1024, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+ %14 = memref.subview %2[%arg0, %arg1, %arg2] [%9, %10, %12] [1, 1, 1] : memref<4x1024x1024xf32> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>
+ linalg.fill(%14, %zero) : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, f32
+ linalg.batch_matmul {__internal_linalg_transform__ = "workgroup"} ins(%11, %13 : memref<?x?x1024xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>, memref<?x1024x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>) outs(%14 : memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * 1048576 + s0 + d1 * 1024 + d2)>>)
+ }
+ }
+ }
return
}
- func private @batch_matmul_fused_fillop__num_workgroups__
- (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
- !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -341,7 +357,8 @@
}
}
}
-// CHECK-LABEL: func @batch_matmul_fused_fillop
+
+// CHECK-LABEL: func @fused_fill_batch_matmul
// CHECK-COUNT-8: vector.transfer_write
// CHECK-COUNT-8: vector.transfer_read
// CHECK: %[[FOR_RES:.+]]:8 = scf.for
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir
index 07e871d..89d2053 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/tile_and_vectorize_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-spirv-tile-and-vectorize-in-one-workgroup))" -iree-spirv-enable-vectorization -iree-codegen-spirv-experimental-linalg-on-tensors -canonicalize -cse %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-concretize-tile-among-workgroups,iree-spirv-tile-and-vectorize-in-one-workgroup))" -canonicalize -cse -iree-spirv-workgroup-tile-size=8,64,4 -iree-spirv-invocation-tile-size=8,4,4 -iree-spirv-workgroup-size=16,1,1 %s | IreeFileCheck %s
hal.executable @matmul_static_shape_f16 attributes {sym_visibility = "private"} {
hal.interface @io attributes {sym_visibility = "private"} {
@@ -9,7 +9,7 @@
hal.executable.target @vulkan_spirv, filter="vulkan*" {
hal.executable.entry_point @matmul_static_shape_f16 attributes {
interface = @io, ordinal = 0 : index,
- signature = (!flow.dispatch.tensor<readonly:1x225x225x16xf32>, !flow.dispatch.tensor<readonly:3x3x16x32xf32>, !flow.dispatch.tensor<writeonly:1x112x112x32xf32>) -> ()}
+ signature = (!flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<writeonly:4096x4096xf16>) -> ()}
module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
func @matmul_static_shape_f16() {
%cst = constant 0.000000e+00 : f16
@@ -51,11 +51,73 @@
}
// CHECK-LABEL: func @matmul_static_shape_f16
-// CHECK-COUNT-16: vector.transfer_write
-// CHECK-COUNT-16: vector.transfer_read
-// CHECK: %[[FOR_RES:.+]]:16 = scf.for
-// CHECK-COUNT-16: vector.transfer_read
-// CHECK-COUNT-64: vector.contract
+// CHECK-COUNT-8: vector.transfer_write
+// CHECK-COUNT-8: vector.transfer_read
+// CHECK: %[[FOR_RES:.+]]:8 = scf.for
+// CHECK-COUNT-12: vector.transfer_read
+// CHECK-COUNT-32: vector.contract
// CHECK: scf.yield
-// CHECK-COUNT-16: vector.transfer_write %[[FOR_RES]]
+// CHECK-COUNT-8: vector.transfer_write %[[FOR_RES]]
+// CHECK: return
+
+// -----
+
+hal.executable @matmul_static_shape_f32 attributes {sym_visibility = "private"} {
+ hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ hal.executable.target @vulkan_spirv, filter="vulkan*" {
+ hal.executable.entry_point @matmul_static_shape_f32 attributes {
+ interface = @io, ordinal = 0 : index,
+ signature = (!flow.dispatch.tensor<readonly:4096x4096xf32>, !flow.dispatch.tensor<readonly:4096x4096xf32>, !flow.dispatch.tensor<writeonly:4096x4096xf32>) -> ()}
+ module attributes {spv.target_env = #spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>, ARM:IntegratedGPU, {}>} {
+ func @matmul_static_shape_f32() {
+ %c0 = constant 0 : index
+ %cst = constant 0.000000e+00 : f32
+ %c4096 = constant 4096 : index
+ %0 = hal.interface.binding.subspan @io::@s0b0_ro_external[%c0] : memref<4096x4096xf32>
+ %1 = hal.interface.binding.subspan @io::@s0b1_ro_external[%c0] : memref<4096x4096xf32>
+ %2 = hal.interface.binding.subspan @io::@s0b2_xw_external[%c0] : memref<4096x4096xf32>
+ %workgroup_size_x = hal.interface.workgroup.size[0] : index
+ %workgroup_size_y = hal.interface.workgroup.size[1] : index
+ %workgroup_id_x = hal.interface.workgroup.id[0] : index
+ %workgroup_count_x = hal.interface.workgroup.count[0] : index
+ %workgroup_id_y = hal.interface.workgroup.id[1] : index
+ %workgroup_count_y = hal.interface.workgroup.count[1] : index
+ %3 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_y, %workgroup_size_y]
+ %4 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_y, %workgroup_size_y]
+ scf.for %arg0 = %3 to %c4096 step %4 {
+ %5 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_id_x, %workgroup_size_x]
+ %6 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%workgroup_count_x, %workgroup_size_x]
+ scf.for %arg1 = %5 to %c4096 step %6 {
+ %7 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg0)[%workgroup_size_y]
+ %8 = memref.subview %0[%arg0, 0] [%7, 4096] [1, 1] : memref<4096x4096xf32> to memref<?x4096xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ %9 = affine.min affine_map<(d0)[s0] -> (s0, -d0 + 4096)>(%arg1)[%workgroup_size_x]
+ %10 = memref.subview %1[0, %arg1] [4096, %9] [1, 1] : memref<4096x4096xf32> to memref<4096x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ %11 = memref.subview %2[%arg0, %arg1] [%7, %9] [1, 1] : memref<4096x4096xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ linalg.fill(%11, %cst) : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, f32
+ linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%8, %10 : memref<?x4096xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<4096x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+ }
+ }
+ return
+ }
+ hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ }
+ }
+}
+
+// CHECK-LABEL: func @matmul_static_shape_f32
+// CHECK-COUNT-8: vector.transfer_write
+// CHECK-COUNT-8: vector.transfer_read
+// CHECK: %[[FOR_RES:.+]]:8 = scf.for
+// CHECK-COUNT-12: vector.transfer_read
+// CHECK-COUNT-32: vector.contract
+// CHECK: scf.yield
+// CHECK-COUNT-8: vector.transfer_write %[[FOR_RES]]
// CHECK: return
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/elementwise_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_elementwise_ops.mlir
similarity index 61%
rename from iree/compiler/Conversion/LinalgToSPIRV/test/elementwise_vectorization.mlir
rename to iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_elementwise_ops.mlir
index 983dd35..b0e3c2b 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/elementwise_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_elementwise_ops.mlir
@@ -1,8 +1,8 @@
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" %s | IreeFileCheck %s
// CHECK-LABEL: func @elementwise_static_shape
-// CHECK: vector.transfer_read %10[%c0], {{.*}} memref<4xf32, #map1>, vector<4xf32>
-// CHECK: vector.transfer_read %11[%c0], {{.*}} memref<4xf32, #map1>, vector<4xf32>
+// CHECK: vector.transfer_read %{{.+}}[%c0], {{.+}} memref<4xf32, #{{.+}}>, vector<4xf32>
+// CHECK: vector.transfer_read %{{.+}}[%c0], {{.+}} memref<4xf32, #{{.+}}>, vector<4xf32>
// CHECK: addf %{{.*}}, %{{.*}} : vector<4xf32>
// CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32
hal.executable @elementwise_static_shape attributes {sym_visibility = "private"} {
@@ -11,7 +11,7 @@
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
- hal.executable.target @vulkan, filter="dylib*" {
+ hal.executable.target @vulkan, filter="vulkan*" {
hal.executable.entry_point @elementwise_static_shape attributes {
interface = @io, ordinal = 0 : index,
signature = (!flow.dispatch.tensor<readonly:?xf32>,
@@ -23,30 +23,25 @@
[Shader],
[]>, NVIDIA:DiscreteGPU,
{subgroup_size = 32 : i32}>} {
- func @elementwise_static_shape()
- attributes {vkspv.num_workgroups_fn = @elementwise_static_shape__num_workgroups__} {
- %arg0 = iree.placeholder for "interface buffer"
- {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<128xf32>
- %arg1 = iree.placeholder for "interface buffer"
- {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<128xf32>
- %ret0 = iree.placeholder for "interface buffer"
- {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<128xf32>
+ func @elementwise_static_shape() {
+ %c0 = constant 0 : index
+ %arg0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<128xf32>
+ %arg1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<128xf32>
+ %ret0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<128xf32>
linalg.generic {
+ __internal_linalg_transform__ = "workgroup",
indexing_maps = [affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>],
- iterator_types = ["parallel"]
- } ins(%arg0, %arg1 : memref<128xf32>, memref<128xf32>)
- outs(%ret0 : memref<128xf32>) {
+ iterator_types = ["parallel"]
+ } ins(%arg0, %arg1 : memref<128xf32>, memref<128xf32>)
+ outs(%ret0 : memref<128xf32>) {
^bb0(%a : f32, %b : f32, %c : f32):
%add = addf %a, %b : f32
linalg.yield %add : f32
}
return
}
- func private @elementwise_static_shape__num_workgroups__
- (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
- !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -81,30 +76,25 @@
[Shader],
[]>, NVIDIA:DiscreteGPU,
{subgroup_size = 32 : i32}>} {
- func @elementwise_transpose()
- attributes {vkspv.num_workgroups_fn = @elementwise_transpose__num_workgroups__} {
- %arg0 = iree.placeholder for "interface buffer"
- {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<128x8xf32>
- %arg1 = iree.placeholder for "interface buffer"
- {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<128xf32>
- %ret0 = iree.placeholder for "interface buffer"
- {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<128x8xf32>
+ func @elementwise_transpose() {
+ %c0 = constant 0 : index
+ %arg0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<128x8xf32>
+ %arg1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<128xf32>
+ %ret0 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<128x8xf32>
linalg.generic {
+ __internal_linalg_transform__ = "workgroup",
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0)>,
affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]
- } ins(%arg0, %arg1 : memref<128x8xf32>, memref<128xf32>)
- outs(%ret0 : memref<128x8xf32>) {
+ iterator_types = ["parallel", "parallel"]
+ } ins(%arg0, %arg1 : memref<128x8xf32>, memref<128xf32>)
+ outs(%ret0 : memref<128x8xf32>) {
^bb0(%a : f32, %b : f32, %c : f32):
%add = addf %a, %b : f32
linalg.yield %add : f32
}
return
}
- func private @elementwise_transpose__num_workgroups__
- (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
- !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_matmul.mlir
similarity index 79%
rename from iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
rename to iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_matmul.mlir
index 80ed9b1..b97424c 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/vectorize_matmul.mlir
@@ -1,17 +1,16 @@
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-enable-vectorization %s | IreeFileCheck %s
-// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-codegen-spirv-linalg-tile-and-distribute,iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-enable-vectorization -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" %s | IreeFileCheck %s
+// RUN: iree-opt -split-input-file -pass-pipeline="hal.executable(hal.executable.target(iree-spirv-tile-and-vectorize-in-one-workgroup,canonicalize,cse))" -iree-spirv-use-workgroup-memory %s | IreeFileCheck %s -check-prefix=PROMOTE
hal.executable @matmul_static_shape attributes {sym_visibility = "private"} {
- hal.interface @io {
+ hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
}
- hal.executable.target @vulkan, filter="dylib*" {
+ hal.executable.target @vulkan_spirv, filter="vulkan*" {
hal.executable.entry_point @matmul_static_shape attributes {
interface = @io, ordinal = 0 : index,
- signature = (!flow.dispatch.tensor<readonly:?x?xf32>, !flow.dispatch.tensor<readonly:?x?xf32>,
- !flow.dispatch.tensor<writeonly:?x?xf32>) -> ()}
+ signature = (!flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<writeonly:4096x4096xf16>) -> ()}
module attributes {
spv.target_env =
#spv.target_env<#spv.vce<v1.5,
@@ -37,21 +36,27 @@
max_compute_workgroup_invocations = 1024 : i32,
max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
subgroup_size = 32 : i32}>} {
- func @matmul_static_shape()
- attributes {vkspv.num_workgroups_fn = @matmul_static_shape__num_workgroups__} {
- %arg0 = iree.placeholder for "interface buffer"
- {binding = @io::@arg0, operand_result_num = 0 : i32} : memref<4096x4096xf16>
- %arg1 = iree.placeholder for "interface buffer"
- {binding = @io::@arg1, operand_result_num = 1 : i32} : memref<4096x4096xf16>
- %ret0 = iree.placeholder for "interface buffer"
- {binding = @io::@ret0, operand_result_num = 2 : i32} : memref<4096x4096xf16>
- linalg.matmul ins(%arg0, %arg1 : memref<4096x4096xf16>, memref<4096x4096xf16>)
- outs(%ret0 : memref<4096x4096xf16>)
+ func @matmul_static_shape() {
+ %c32 = constant 32 : index
+ %c4096 = constant 4096 : index
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4096x4096xf16>
+ %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4096x4096xf16>
+ %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4096x4096xf16>
+ %3 = hal.interface.workgroup.size[0] : index
+ %4 = hal.interface.workgroup.size[1] : index
+ scf.for %arg0 = %c0 to %c4096 step %c32 {
+ %5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%4]
+ %6 = memref.subview %0[%5, %arg0] [64, 32] [1, 1] : memref<4096x4096xf16> to memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ %7 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%3]
+ %8 = memref.subview %1[%arg0, %7] [32, 64] [1, 1] : memref<4096x4096xf16> to memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ %9 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%4]
+ %10 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%3]
+ %11 = memref.subview %2[%9, %10] [64, 64] [1, 1] : memref<4096x4096xf16> to memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ linalg.matmul {__internal_linalg_transform__ = "workgroup"} ins(%6, %8 : memref<64x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<64x64xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+ }
return
}
- func private @matmul_static_shape__num_workgroups__
- (!shapex.ranked_shape<[4096, 4096]>, !shapex.ranked_shape<[4096, 4096]>,
- !shapex.ranked_shape<[4096, 4096]>) -> (index, index, index)
hal.interface @io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -60,18 +65,19 @@
}
}
}
+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK: func @matmul_static_shape
-// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @io::@arg0
-// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @io::@arg1
-// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @io::@ret0
-// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[CST:.+]] = constant 0.0
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C16:.+]] = constant 16 : index
// CHECK-DAG: %[[C32:.+]] = constant 32 : index
// CHECK-DAG: %[[C48:.+]] = constant 48 : index
-// CHECK: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
-// CHECK: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]]
+// CHECK-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1[%[[C0]]]
+// CHECK-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]]
+// CHECK: %[[BIDX:.+]] = hal.interface.workgroup.size[0]
+// CHECK: %[[BIDY:.+]] = hal.interface.workgroup.size[1]
// CHECK: %[[BOFFSET_Y:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[BOFFSET_X:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
// CHECK: %[[SUBVIEW_RESULT:.+]] = memref.subview %[[RET0]]
@@ -268,12 +274,75 @@
// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#14, %[[SUBVIEW_RESULT]][%[[C48]], %[[C32]]]
// CHECK-DAG: vector.transfer_write %[[FOR_RES]]#15, %[[SUBVIEW_RESULT]][%[[C48]], %[[C48]]]
+// -----
+
+hal.executable @matmul_static_shape attributes {sym_visibility = "private"} {
+ hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ hal.executable.target @vulkan_spirv, filter="vulkan*" {
+ hal.executable.entry_point @matmul_static_shape attributes {
+ interface = @io, ordinal = 0 : index,
+ signature = (!flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<readonly:4096x4096xf16>, !flow.dispatch.tensor<writeonly:4096x4096xf16>) -> ()}
+ module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.5,
+ [Shader, Float64, Float16, Int64, Int16, Int8, StorageBuffer16BitAccess,
+ StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess,
+ UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniform,
+ GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot,
+ GroupNonUniformShuffle, GroupNonUniformShuffleRelative, VariablePointers,
+ VariablePointersStorageBuffer, CooperativeMatrixNV],
+ [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage,
+ SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers,
+ SPV_NV_cooperative_matrix]>, NVIDIA:DiscreteGPU,
+ {cooperative_matrix_properties_nv = [
+ {a_type = i8, b_type = i8, c_type = i32, k_size = 32 : i32,
+ m_size = 8 : i32, n_size = 8 : i32, result_type = i32, scope = 3 : i32},
+ {a_type = f16, b_type = f16, c_type = f16, k_size = 16 : i32,
+ m_size = 16 : i32, n_size = 16 : i32, result_type = f16,
+ scope = 3 : i32},
+ {a_type = f16, b_type = f16, c_type = f32, k_size = 16 : i32,
+ m_size = 16 : i32, n_size = 16 : i32, result_type = f32,
+ scope = 3 : i32}],
+ max_compute_shared_memory_size = 49152 : i32,
+ max_compute_workgroup_invocations = 1024 : i32,
+ max_compute_workgroup_size = dense<[2147483647, 65535, 65535]> : vector<3xi32>,
+ subgroup_size = 32 : i32}>} {
+ func @matmul_static_shape() {
+ %c32 = constant 32 : index
+ %c4096 = constant 4096 : index
+ %c0 = constant 0 : index
+ %0 = hal.interface.binding.subspan @io::@arg0[%c0] : memref<4096x4096xf16>
+ %1 = hal.interface.binding.subspan @io::@arg1[%c0] : memref<4096x4096xf16>
+ %2 = hal.interface.binding.subspan @io::@ret0[%c0] : memref<4096x4096xf16>
+ %3 = hal.interface.workgroup.size[0] : index
+ %4 = hal.interface.workgroup.size[1] : index
+ scf.for %arg0 = %c0 to %c4096 step %c32 {
+ %5 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
+ %6 = memref.subview %0[%5, %arg0] [128, 32] [1, 1] : memref<4096x4096xf16> to memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ %7 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
+ %8 = memref.subview %1[%arg0, %7] [32, 128] [1, 1] : memref<4096x4096xf16> to memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ %9 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%4]
+ %10 = affine.apply affine_map<()[s0] -> (s0 * 128)>()[%3]
+ %11 = memref.subview %2[%9, %10] [128, 128] [1, 1] : memref<4096x4096xf16> to memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>
+ linalg.matmul {__internal_linalg_transform__ = "workgroup", is_root_op, launch_info_key = "__op_num_0__"} ins(%6, %8 : memref<128x32xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>, memref<32x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>) outs(%11 : memref<128x128xf16, affine_map<(d0, d1)[s0] -> (d0 * 4096 + s0 + d1)>>)
+ }
+ return
+ }
+ hal.interface @io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ }
+ }
+ }
+}
// PROMOTE-DAG: #[[MAP4:.+]] = affine_map<()[s0] -> (s0 * 64 - (s0 floordiv 2) * 128)>
// PROMOTE: func @matmul_static_shape
-// PROMOTE-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @io::@arg0
-// PROMOTE-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @io::@arg1
-// PROMOTE-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @io::@ret0
// PROMOTE-DAG: %[[C0:.+]] = constant 0 : index
// PROMOTE-DAG: %[[C2:.+]] = constant 2
// PROMOTE-DAG: %[[C16:.+]] = constant 16
@@ -281,6 +350,9 @@
// PROMOTE-DAG: %[[C48:.+]] = constant 48
// PROMOTE-DAG: %[[ALLOC1:.+]] = memref.alloc() : memref<128x32xf16, 3>
// PROMOTE-DAG: %[[ALLOC2:.+]] = memref.alloc() : memref<32x128xf16, 3>
+// PROMOTE-DAG: %[[ARG0:.+]] = hal.interface.binding.subspan @io::@arg0[%[[C0]]]
+// PROMOTE-DAG: %[[ARG1:.+]] = hal.interface.binding.subspan @io::@arg1[%[[C0]]]
+// PROMOTE-DAG: %[[RET0:.+]] = hal.interface.binding.subspan @io::@ret0[%[[C0]]]
// PROMOTE: %[[RESULT_SUBVIEW:.+]] = memref.subview %[[RET0]]
// PROMOTE: %[[WGMEM_LHS_SUBVIEW:.+]] = memref.subview %[[ALLOC1]][0, 0] [128, 32] [1, 1]