Use generated function to compute number of workgroups to be launched. (#3095)
The current mechanism of deciding number of workgroups to use for the
launch uses an enumeration of fixed strategies. This is not sufficient
to cover all uses. This commit adds a new mechanism where the
codegeneration passes generate the function to use that when called
from the host side will return the number of workgroups to use. The
arguments to the function are the shapes of the inputs and outputs to
the dispatch region.
Currently this is enabled only for static shapes, since there are some
unresolved issues w.r.t to the dynamic shape case.
diff --git a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
index adc2f53..6a4c81e 100644
--- a/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
+++ b/iree/compiler/Conversion/CodegenUtils/FunctionUtils.h
@@ -23,6 +23,16 @@
/// Returns true if the given `func` is a kernel dispatch entry point.
bool isEntryPoint(FuncOp func);
+/// Returns the attribute name used to record the binding associated with an
+/// iree.placeholder operation.
+inline const char* getBindingAttrName() { return "binding"; }
+
+/// Returns the attribute name used to record argument position in the (operand
+/// + result) list of shaped types of the dispatch region.
+inline const char* getOperandResultNumAttrName() {
+ return "operand_result_index";
+}
+
} // namespace iree_compiler
} // namespace mlir
diff --git a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
index 0808d42..ebf96aa 100644
--- a/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/FusionOfTensorOps.cpp
@@ -41,8 +41,8 @@
reshapeOp.src().getDefiningOp<IREE::HAL::InterfaceLoadTensorOp>();
if (!loadOp) return failure();
rewriter.replaceOpWithNewOp<IREE::HAL::InterfaceLoadTensorOp>(
- reshapeOp, reshapeOp.getResultType(), loadOp.binding(),
- loadOp.offset());
+ reshapeOp, reshapeOp.getResultType(), loadOp.offset(),
+ loadOp.getAttrs());
return success();
}
};
@@ -55,8 +55,9 @@
PatternRewriter &rewriter) const override {
auto reshapeOp = storeOp.operand().getDefiningOp<linalg::TensorReshapeOp>();
if (!reshapeOp) return failure();
+ SmallVector<Value, 2> operands = {reshapeOp.src(), storeOp.offset()};
rewriter.replaceOpWithNewOp<IREE::HAL::InterfaceStoreTensorOp>(
- storeOp, reshapeOp.src(), storeOp.binding(), storeOp.offset());
+ storeOp, ArrayRef<Type>(), operands, storeOp.getAttrs());
return success();
}
};
diff --git a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
index 70a407f..87ed137 100644
--- a/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
+++ b/iree/compiler/Conversion/HLOToLinalg/HLOToLinalgOnBuffers.cpp
@@ -22,6 +22,7 @@
#include <cstddef>
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/HLOToLinalg/Passes.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/IREE/IR/IREEDialect.h"
@@ -1348,7 +1349,10 @@
// annotation is carried over if exists.
auto phOp = rewriter.create<IREE::PlaceholderOp>(
loadOp.getLoc(), bufferType, "interface buffer");
- phOp.setAttr("binding", loadOp.binding());
+ phOp.setAttr(getBindingAttrName(), loadOp.binding());
+ StringRef attrName = getOperandResultNumAttrName();
+ if (auto operandResultNumAttr = loadOp.getAttr(attrName))
+ phOp.setAttr(attrName, operandResultNumAttr);
Value buffer = phOp.getResult();
// If the result of the load is already mapped to a buffer, a copy is
@@ -1463,7 +1467,10 @@
// annotation is carried over if exists.
auto phOp = builder.create<IREE::PlaceholderOp>(op.getLoc(), bufferType,
"interface buffer");
- phOp.setAttr("binding", op.binding());
+ phOp.setAttr(getBindingAttrName(), op.binding());
+ StringRef attrName = getOperandResultNumAttrName();
+ if (Attribute operandResultNumAttr = op.getAttr(attrName))
+ phOp.setAttr(attrName, operandResultNumAttr);
Value buffer = phOp;
outputBufferMap[op] = buffer;
@@ -1576,7 +1583,8 @@
target.addIllegalOp<linalg::TensorReshapeOp>();
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
Optional<ConversionTarget::DynamicLegalityCallbackFn>([](Operation *op) {
- // The generated structured Linalg ops should have buffer semantics.
+ // The generated structured Linalg ops should have buffer
+ // semantics.
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
return linalgOp.hasBufferSemantics();
}
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir b/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir
index a82179b..bb090b8 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/fusion.mlir
@@ -38,7 +38,7 @@
return
}
hal.interface @legacy_io attributes {sym_visiblity = "private"} {
- hal.interface.binding @ret0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=0, type="StorageBuffer", access="Write|Discard"
}
}
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
index 2bab77a..9363815 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/linalg_tensor_to_buffer.mlir
@@ -5,29 +5,35 @@
module {
func @element_wise() {
%c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x2xf32>
- %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0 : tensor<2x2xf32>
- %2 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]} %0, %1 {
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ {operand_result_index = 0 : i32} : tensor<2x2xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0
+ {operand_result_index = 1 : i32} : tensor<2x2xf32>
+ %2 = linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel"]} %0, %1 {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%3 = addf %arg3, %arg4 : f32
linalg.yield %3 : f32
}: tensor<2x2xf32>, tensor<2x2xf32> -> tensor<2x2xf32>
- hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0 : tensor<2x2xf32>
+ hal.interface.store.tensor %2, @legacy_io::@ret0, offset = %c0
+ {operand_result_index = 2 : i32} : tensor<2x2xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer",
+ access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer",
+ access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer",
+ access="Write"
}
}
// CHECK-LABEL: func @element_wise
-// CHECK-DAG: %[[ARG2:.*]] = iree.placeholder
-// CHECK-SAME: {binding = @legacy_io::@ret0}
-// CHECK-DAG: %[[ARG0:.*]] = iree.placeholder
-// CHECK-SAME: {binding = @legacy_io::@arg0}
-// CHECK-DAG: %[[ARG1:.*]] = iree.placeholder
-// CHECK-SAME: {binding = @legacy_io::@arg1}
+// CHECK-DAG: %[[ARG2:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 2 : i32}
+// CHECK-DAG: %[[ARG0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
+// CHECK-DAG: %[[ARG1:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_index = 1 : i32}
// CHECK-NOT: hal.interface.load.tensor
// CHECK: linalg.generic
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]]
@@ -45,8 +51,12 @@
module {
func @indexed_generic() {
%c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<2x2xi32>
- %1 = linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]} %0 {
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ {operand_result_index = 0 : i32} : tensor<2x2xi32>
+ %1 = linalg.indexed_generic
+ {args_in = 1 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"]} %0 {
^bb0(%arg2: index, %arg3: index, %arg4: i32): // no predecessors
%2 = index_cast %arg2 : index to i32
%3 = index_cast %arg3 : index to i32
@@ -54,19 +64,20 @@
%5 = addi %4, %3 : i32
linalg.yield %5 : i32
}: tensor<2x2xi32> -> tensor<2x2xi32>
- hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0 : tensor<2x2xi32>
+ hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0
+ {operand_result_index = 1 : i32} : tensor<2x2xi32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer",
+ access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer",
+ access="Write"
}
}
// CHECK: func @indexed_generic
-// CHECK-DAG: %[[RET0:.*]] = iree.placeholder
-// CHECK-SAME: {binding = @legacy_io::@ret0}
-// CHECK-DAG: %[[ARG0:.*]] = iree.placeholder
-// CHECK-SAME: {binding = @legacy_io::@arg0}
+// CHECK-DAG: %[[RET0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 1 : i32}
+// CHECK-DAG: %[[ARG0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
// CHECK-NOT: hal.interface.load.tensor
// CHECK: linalg.indexed_generic
// CHECK-SAME: %[[ARG0]], %[[RET0]]
@@ -90,23 +101,23 @@
module {
func @reshape_arg_result() {
%c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0,
- offset = %c0 : tensor<5xf32>
- %1 = hal.interface.load.tensor @legacy_io::@arg1,
- offset = %c0 : tensor<5xf32>
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ {operand_result_index = 0 : i32} : tensor<5xf32>
+ %1 = hal.interface.load.tensor @legacy_io::@arg1, offset = %c0
+ {operand_result_index = 1 : i32} : tensor<5xf32>
%2 = linalg.tensor_reshape %0 [#map0] : tensor<5xf32> into tensor<5x1xf32>
%3 = linalg.tensor_reshape %1 [#map0] : tensor<5xf32> into tensor<1x5xf32>
%4 = linalg.generic
- {args_in = 2 : i64, args_out = 1 : i64,
- indexing_maps = [#map1, #map2, #map0],
- iterator_types = ["parallel", "parallel"]} %2, %3 {
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map1, #map2, #map0],
+ iterator_types = ["parallel", "parallel"]} %2, %3 {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%5 = addf %arg3, %arg4 : f32
linalg.yield %5 : f32
}: tensor<5x1xf32>, tensor<1x5xf32> -> tensor<5x5xf32>
%6 = linalg.tensor_reshape %4 [#map0] : tensor<5x5xf32> into tensor<25xf32>
- hal.interface.store.tensor %6, @legacy_io::@ret0,
- offset = %c0 : tensor<25xf32>
+ hal.interface.store.tensor %6, @legacy_io::@ret0, offset = %c0
+ {operand_result_index = 2 : i32} : tensor<25xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
@@ -122,13 +133,10 @@
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (0, d1)>
// CHECK: func @reshape_arg_result
-// CHECK-DAG: %[[RET0:.*]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@ret0
+// CHECK-DAG: %[[RET0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 2 : i32}
// CHECK-DAG: %[[RESULT:.*]] = linalg.reshape %[[RET0]] [#[[MAP0]]]
-// CHECK-DAG: %[[ARG0:.*]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@arg0
-// CHECK-DAG: %[[ARG1:.*]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@arg1
+// CHECK-DAG: %[[ARG0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
+// CHECK-DAG: %[[ARG1:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1, operand_result_index = 1 : i32}
// CHECK-DAG: %[[LHS:.*]] = linalg.reshape %[[ARG0]] [#[[MAP0]]]
// CHECK-DAG: %[[RHS:.*]] = linalg.reshape %[[ARG1]] [#[[MAP0]]]
// CHECK: linalg.generic
@@ -141,11 +149,11 @@
module {
func @reshape_only() {
%c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0,
- offset = %c0 : tensor<5x5xf32>
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ {operand_result_index = 0 : i32} : tensor<5x5xf32>
%1 = linalg.tensor_reshape %0 [#map0] : tensor<5x5xf32> into tensor<25xf32>
- hal.interface.store.tensor %1, @legacy_io::@ret0,
- offset = %c0 : tensor<25xf32>
+ hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0
+ {operand_result_index = 1 : i32} : tensor<25xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
@@ -157,11 +165,9 @@
}
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: func @reshape_only
-// CHECK-DAG: %[[RET0:.*]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@ret0
+// CHECK-DAG: %[[RET0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 1 : i32}
// CHECK-DAG: %[[RESULT:.*]] = linalg.reshape %[[RET0]] [#[[MAP0]]]
-// CHECK-DAG: %[[ARG0:.*]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
// CHECK: linalg.copy(%[[ARG0]], %[[RESULT]])
// -----
@@ -171,34 +177,36 @@
module {
func @store_value_twice() {
%c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0,
- offset = %c0 : tensor<2x4xf32>
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ {operand_result_index = 0 : i32} : tensor<2x4xf32>
%1 = linalg.generic
- {args_in = 1 : i64, args_out = 1 : i64,
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"]} %0 {
+ {args_in = 1 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"]} %0 {
^bb0(%arg0: f32): // no predecessors
%2 = tanh %arg0 : f32
linalg.yield %2 : f32
}: tensor<2x4xf32> -> tensor<2x4xf32>
- hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0 : tensor<2x4xf32>
- hal.interface.store.tensor %1, @legacy_io::@ret1, offset = %c0 : tensor<2x4xf32>
+ hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0
+ {operand_result_index = 1 : i32} : tensor<2x4xf32>
+ hal.interface.store.tensor %1, @legacy_io::@ret1, offset = %c0
+ {operand_result_index = 2 : i32} : tensor<2x4xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0,
- type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1,
- type="StorageBuffer", access="Write|Discard"
- hal.interface.binding @ret1, set=0, binding=2,
- type="StorageBuffer", access="Write|Discard"
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer",
+ access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer",
+ access="Write|Discard"
+ hal.interface.binding @ret1, set=0, binding=2, type="StorageBuffer",
+ access="Write|Discard"
}
}
// CHECK-LABEL: func @store_value_twice
-// CHECK: %[[T0:.*]] = iree.placeholder {{.*}} @legacy_io::@ret0
-// CHECK: %[[T1:.*]] = iree.placeholder {{.*}} @legacy_io::@ret1
-// CHECK: %[[T2:.*]] = iree.placeholder {{.*}} @legacy_io::@arg0
+// CHECK-DAG: %[[T0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 1 : i32}
+// CHECK: %[[T1:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1, operand_result_index = 2 : i32}
+// CHECK: %[[T2:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
// CHECK: linalg.generic {{.*}} %[[T2]], %[[T0]]
// CHECK: linalg.copy(%[[T0]], %[[T1]])
@@ -211,36 +219,39 @@
module {
func @store_reshape_src_and_result_0() {
%c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0,
- offset = %c0 : tensor<2x4xf32>
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ {operand_result_index = 0 : i32} : tensor<2x4xf32>
%1 = linalg.generic
- {args_in = 1 : i64, args_out = 1 : i64,
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"]} %0 {
+ {args_in = 1 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"]} %0 {
^bb0(%arg0: f32): // no predecessors
%2 = tanh %arg0 : f32
linalg.yield %2 : f32
}: tensor<2x4xf32> -> tensor<2x4xf32>
- %3 = linalg.tensor_reshape %1 [#map1, #map2] : tensor<2x4xf32> into tensor<1x2x4xf32>
- hal.interface.store.tensor %3, @legacy_io::@ret1, offset = %c0 : tensor<1x2x4xf32>
- hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0 : tensor<2x4xf32>
+ %3 = linalg.tensor_reshape %1 [#map1, #map2]
+ : tensor<2x4xf32> into tensor<1x2x4xf32>
+ hal.interface.store.tensor %3, @legacy_io::@ret1, offset = %c0
+ {operand_result_index = 2 : i32} : tensor<1x2x4xf32>
+ hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0
+ {operand_result_index = 1 : i32} : tensor<2x4xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0,
- type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1,
- type="StorageBuffer", access="Write|Discard"
- hal.interface.binding @ret1, set=0, binding=2,
- type="StorageBuffer", access="Write|Discard"
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer",
+ access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer",
+ access="Write|Discard"
+ hal.interface.binding @ret1, set=0, binding=2, type="StorageBuffer",
+ access="Write|Discard"
}
}
// CHECK-LABEL: func @store_reshape_src_and_result_0
-// CHECK: %[[T0:.*]] = iree.placeholder {{.*}} @legacy_io::@ret1
+// CHECK: %[[T0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1, operand_result_index = 2 : i32}
// CHECK: %[[T1:.*]] = linalg.reshape %[[T0]]
-// CHECK: %[[T2:.*]] = iree.placeholder {{.*}} @legacy_io::@ret0
-// CHECK: %[[T3:.*]] = iree.placeholder {{.*}} @legacy_io::@arg0
+// CHECK: %[[T2:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 1 : i32}
+// CHECK: %[[T3:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
// CHECK: linalg.generic {{.*}} %[[T3]], %[[T1]]
// CHECK: linalg.copy(%[[T1]], %[[T2]])
// CHECK: return
@@ -254,35 +265,38 @@
module {
func @store_reshape_src_and_result_1() {
%c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0,
- offset = %c0 : tensor<2x4xf32>
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ {operand_result_index = 0 : i32} : tensor<2x4xf32>
%1 = linalg.generic
- {args_in = 1 : i64, args_out = 1 : i64,
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"]} %0 {
+ {args_in = 1 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"]} %0 {
^bb0(%arg0: f32): // no predecessors
%2 = tanh %arg0 : f32
linalg.yield %2 : f32
}: tensor<2x4xf32> -> tensor<2x4xf32>
- %3 = linalg.tensor_reshape %1 [#map1, #map2] : tensor<2x4xf32> into tensor<1x2x4xf32>
- hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0 : tensor<2x4xf32>
- hal.interface.store.tensor %3, @legacy_io::@ret1, offset = %c0 : tensor<1x2x4xf32>
+ %3 = linalg.tensor_reshape %1 [#map1, #map2]
+ : tensor<2x4xf32> into tensor<1x2x4xf32>
+ hal.interface.store.tensor %1, @legacy_io::@ret0, offset = %c0
+ {operand_result_index = 1 : i32} : tensor<2x4xf32>
+ hal.interface.store.tensor %3, @legacy_io::@ret1, offset = %c0
+ {operand_result_index = 2 : i32} : tensor<1x2x4xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0,
- type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1,
- type="StorageBuffer", access="Write|Discard"
- hal.interface.binding @ret1, set=0, binding=2,
- type="StorageBuffer", access="Write|Discard"
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer",
+ access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer",
+ access="Write|Discard"
+ hal.interface.binding @ret1, set=0, binding=2, type="StorageBuffer",
+ access="Write|Discard"
}
}
// CHECK-LABEL: func @store_reshape_src_and_result_1
-// CHECK: %[[T0:.*]] = iree.placeholder {{.*}} @legacy_io::@ret0
-// CHECK: %[[T1:.*]] = iree.placeholder {{.*}} @legacy_io::@ret1
-// CHECK: %[[T2:.*]] = iree.placeholder {{.*}} @legacy_io::@arg0
+// CHECK-DAG: %[[T0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 1 : i32}
+// CHECK-DAG: %[[T1:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1, operand_result_index = 2 : i32}
+// CHECK-DAG: %[[T2:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
// CHECK: linalg.generic {{.*}} %[[T2]], %[[T0]]
// CHECK: %[[T3:.*]] = linalg.reshape %[[T0]]
// CHECK: linalg.copy(%[[T3]], %[[T1]])
@@ -297,42 +311,48 @@
module {
func @store_reshape_src_and_result_2() {
%c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0,
- offset = %c0 : tensor<2x4xf32>
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ {operand_result_index = 0 : i32} : tensor<2x4xf32>
%1 = linalg.generic
- {args_in = 1 : i64, args_out = 1 : i64,
- indexing_maps = [#map0, #map0],
- iterator_types = ["parallel", "parallel"]} %0 {
+ {args_in = 1 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0],
+ iterator_types = ["parallel", "parallel"]} %0 {
^bb0(%arg0: f32): // no predecessors
%2 = tanh %arg0 : f32
linalg.yield %2 : f32
}: tensor<2x4xf32> -> tensor<2x4xf32>
- %3 = linalg.tensor_reshape %1 [#map1, #map2] : tensor<2x4xf32> into tensor<1x2x4xf32>
- %4 = linalg.tensor_reshape %1 [#map1, #map2] : tensor<2x4xf32> into tensor<1x2x4xf32>
- %5 = linalg.tensor_reshape %1 [#map1, #map2] : tensor<2x4xf32> into tensor<1x2x4xf32>
- hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0 : tensor<1x2x4xf32>
- hal.interface.store.tensor %4, @legacy_io::@ret1, offset = %c0 : tensor<1x2x4xf32>
- hal.interface.store.tensor %5, @legacy_io::@ret2, offset = %c0 : tensor<1x2x4xf32>
+ %3 = linalg.tensor_reshape %1 [#map1, #map2]
+ : tensor<2x4xf32> into tensor<1x2x4xf32>
+ %4 = linalg.tensor_reshape %1 [#map1, #map2]
+ : tensor<2x4xf32> into tensor<1x2x4xf32>
+ %5 = linalg.tensor_reshape %1 [#map1, #map2]
+ : tensor<2x4xf32> into tensor<1x2x4xf32>
+ hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0
+ {operand_result_index = 1 : i32} : tensor<1x2x4xf32>
+ hal.interface.store.tensor %4, @legacy_io::@ret1, offset = %c0
+ {operand_result_index = 2 : i32} : tensor<1x2x4xf32>
+ hal.interface.store.tensor %5, @legacy_io::@ret2, offset = %c0
+ {operand_result_index = 3 : i32} : tensor<1x2x4xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0,
- type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1,
- type="StorageBuffer", access="Write|Discard"
- hal.interface.binding @ret1, set=0, binding=2,
- type="StorageBuffer", access="Write|Discard"
- hal.interface.binding @ret2, set=0, binding=3,
- type="StorageBuffer", access="Write|Discard"
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer",
+ access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer",
+ access="Write|Discard"
+ hal.interface.binding @ret1, set=0, binding=2, type="StorageBuffer",
+ access="Write|Discard"
+ hal.interface.binding @ret2, set=0, binding=3, type="StorageBuffer",
+ access="Write|Discard"
}
}
// CHECK-LABEL: func @store_reshape_src_and_result_2
-// CHECK: %[[T0:.*]] = iree.placeholder {{.*}} @legacy_io::@ret0
+// CHECK-DAG: %[[T0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 1 : i32}
// CHECK: %[[T1:.*]] = linalg.reshape %[[T0]]
-// CHECK: %[[T2:.*]] = iree.placeholder {{.*}} @legacy_io::@ret1
-// CHECK: %[[T3:.*]] = iree.placeholder {{.*}} @legacy_io::@ret2
-// CHECK: %[[T4:.*]] = iree.placeholder {{.*}} @legacy_io::@arg0
+// CHECK-DAG: %[[T2:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1, operand_result_index = 2 : i32}
+// CHECK-DAG: %[[T3:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret2, operand_result_index = 3 : i32}
+// CHECK-DAG: %[[T4:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
// CHECK: linalg.generic {{.*}} %[[T4]], %[[T1]]
// CHECK: linalg.copy(%[[T0]], %[[T2]])
// CHECK: linalg.copy(%[[T0]], %[[T3]])
@@ -347,27 +367,37 @@
module {
func @edge_detect_sobel_operator_ex_dispatch_3() {
%c0 = constant 0 : index
- %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x128x128x1xf32>
- %1 = linalg.tensor_reshape %0 [#map0, #map1] : tensor<1x128x128x1xf32> into tensor<128x128xf32>
- %2 = linalg.tensor_reshape %0 [#map0, #map1] : tensor<1x128x128x1xf32> into tensor<128x128xf32>
- %3 = linalg.generic {args_in = 2 : i64, args_out = 1 : i64, indexing_maps = [#map2, #map2, #map2], iterator_types = ["parallel", "parallel"]} %1, %2 {
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ {operand_result_index = 0 : i32} : tensor<1x128x128x1xf32>
+ %1 = linalg.tensor_reshape %0 [#map0, #map1]
+ : tensor<1x128x128x1xf32> into tensor<128x128xf32>
+ %2 = linalg.tensor_reshape %0 [#map0, #map1]
+ : tensor<1x128x128x1xf32> into tensor<128x128xf32>
+ %3 = linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map2, #map2, #map2],
+ iterator_types = ["parallel", "parallel"]} %1, %2 {
^bb0(%arg0: f32, %arg1: f32): // no predecessors
%5 = mulf %arg0, %arg1 : f32
linalg.yield %5 : f32
}: tensor<128x128xf32>, tensor<128x128xf32> -> tensor<128x128xf32>
- %4 = linalg.tensor_reshape %3 [#map0, #map1] : tensor<128x128xf32> into tensor<1x128x128x1xf32>
- hal.interface.store.tensor %4, @legacy_io::@ret0, offset = %c0 : tensor<1x128x128x1xf32>
+ %4 = linalg.tensor_reshape %3 [#map0, #map1]
+ : tensor<128x128xf32> into tensor<1x128x128x1xf32>
+ hal.interface.store.tensor %4, @legacy_io::@ret0, offset = %c0
+ {operand_result_index = 1 : i32} : tensor<1x128x128x1xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer",
+ access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer",
+ access="Write|Discard"
}
}
// CHECK-LABEL: func @edge_detect_sobel_operator_ex_dispatch_3
-// CHECK: %[[T0:.*]] = iree.placeholder {{.*}} @legacy_io::@ret0
+// CHECK: %[[T0:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0, operand_result_index = 1 : i32}
// CHECK: %[[T1:.*]] = linalg.reshape %[[T0]]
-// CHECK: %[[T2:.*]] = iree.placeholder {{.*}} @legacy_io::@arg0
+// CHECK: %[[T2:.*]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
// CHECK: %[[T3:.*]] = linalg.reshape %[[T2]]
// CHECK: %[[T4:.*]] = linalg.reshape %[[T2]]
// CHECK: linalg.generic {{.*}} %[[T3]], %[[T4]], %[[T1]]
@@ -383,35 +413,44 @@
func @generic_reshape_reshape() {
%c0 = constant 0 : index
%cst = constant 0.000000e+00 : f32
- %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<1x1x1x1000xf32>
- %1 = linalg.tensor_reshape %0 [#map0] : tensor<1x1x1x1000xf32> into tensor<1000xf32>
- %2 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} %1 {
+ %0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0
+ {operand_result_index = 0 : i32} : tensor<1x1x1x1000xf32>
+ %1 = linalg.tensor_reshape %0 [#map0]
+ : tensor<1x1x1x1000xf32> into tensor<1000xf32>
+ %2 = linalg.generic
+ {args_in = 1 : i64, args_out = 1 : i64,
+ indexing_maps = [#map1, #map1], iterator_types = ["parallel"]} %1 {
^bb0(%arg0: f32): // no predecessors
%5 = addf %arg0, %cst : f32
linalg.yield %5 : f32
}: tensor<1000xf32> -> tensor<1000xf32>
- %3 = linalg.tensor_reshape %2 [#map0] : tensor<1000xf32> into tensor<1x1x1x1000xf32>
- %4 = linalg.tensor_reshape %3 [#map2, #map3] : tensor<1x1x1x1000xf32> into tensor<1x1000xf32>
- hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0 : tensor<1x1x1x1000xf32>
- hal.interface.store.tensor %4, @legacy_io::@ret1, offset = %c0 : tensor<1x1000xf32>
+ %3 = linalg.tensor_reshape %2 [#map0]
+ : tensor<1000xf32> into tensor<1x1x1x1000xf32>
+ %4 = linalg.tensor_reshape %3 [#map2, #map3]
+ : tensor<1x1x1x1000xf32> into tensor<1x1000xf32>
+ hal.interface.store.tensor %3, @legacy_io::@ret0, offset = %c0
+ {operand_result_index = 1 : i32} : tensor<1x1x1x1000xf32>
+ hal.interface.store.tensor %4, @legacy_io::@ret1, offset = %c0
+ {operand_result_index = 2 : i32} : tensor<1x1000xf32>
return
}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
- hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
- hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
- hal.interface.binding @ret1, set=0, binding=2, type="StorageBuffer", access="Write|Discard"
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer",
+ access="Read"
+ hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer",
+ access="Write|Discard"
+ hal.interface.binding @ret1, set=0, binding=2, type="StorageBuffer",
+ access="Write|Discard"
}
}
// CHECK-LABEL: func @generic_reshape_reshape
// CHECK: %[[RET0:.+]] = iree.placeholder
-// CHECK-SAME: @legacy_io::@ret0
+// CHECK-SAME: binding = @legacy_io::@ret0, operand_result_index = 1 : i32
// CHECK: %[[RET0_RESHAPE:.+]] = linalg.reshape %[[RET0]]
// CHECK-SAME: memref<1x1x1x1000xf32> into memref<1000xf32>
-// CHECK: %[[RET1:.+]] = iree.placeholder
-// CHECK-SAME: @legacy_io::@ret1
-// CHECK: %[[ARG0:.+]] = iree.placeholder
-// CHECK-SAME: @legacy_io::@arg0
+// CHECK: %[[RET1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1, operand_result_index = 2 : i32}
+// CHECK: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0, operand_result_index = 0 : i32}
// CHECK: %[[ARG0_RESHAPE:.+]] = linalg.reshape %[[ARG0]]
// CHECK-SAME: memref<1x1x1x1000xf32> into memref<1000xf32>
// CHECK: linalg.generic
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/pipeline_test.mlir b/iree/compiler/Conversion/HLOToLinalg/test/pipeline_test.mlir
index c5b558b..90c4749 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/pipeline_test.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/pipeline_test.mlir
@@ -46,14 +46,10 @@
}
}
// CHECK-LABEL: func @bug_2882_repro2
-// CHECK: %[[RET0:.+]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@ret0
-// CHECK: %[[RET1:.+]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@ret1
-// CHECK: %[[ARG1:.+]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@arg1
-// CHECK: %[[ARG0:.+]] = iree.placeholder
-// CHECK-SAME: binding = @legacy_io::@arg0
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0}
+// CHECK-DAG: %[[RET1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1}
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1}
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0}
// CHECK: linalg.generic
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[RET0]]
// CHECK: linalg.copy(%[[RET0]], %[[RET1]])
diff --git a/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir b/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
index 300c0ff..0c17311 100644
--- a/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
+++ b/iree/compiler/Conversion/HLOToLinalg/test/slice.mlir
@@ -4,7 +4,8 @@
// CHECK_LABEL: @slice_whole_buffer
// CHECK-NOT: subview
// CHECK: linalg.copy
- func @slice_whole_buffer() {
+ func @slice_whole_buffer()
+ attributes {signature = (tensor<3x4xi32>) -> (tensor<3x4xi32>)} {
%c0 = constant 0 : index
%0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<3x4xi32>
%1 = "mhlo.slice"(%0) {
@@ -38,7 +39,8 @@
// CHECK-SAME: [%[[ONE]], %[[ONE]]]
// CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #[[MAP]]>
// CHECK: linalg.copy
- func @slice_whole_stride() {
+ func @slice_whole_stride()
+ attributes {signature = (tensor<3x4xi32>) -> (tensor<1x4xi32>)} {
%c0 = constant 0 : index
%0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<3x4xi32>
%1 = "mhlo.slice"(%0) {
@@ -72,7 +74,8 @@
// CHECK-SAME: [%[[ONE]], %[[ONE]]]
// CHECK-SAME: : memref<3x4xi32> to memref<?x?xi32, #map0>
// CHECK: linalg.copy
- func @slice_stride_part() {
+ func @slice_stride_part()
+ attributes {signature = (tensor<3x4xi32>) -> (tensor<1x2xi32>)} {
%c0 = constant 0 : index
%0 = hal.interface.load.tensor @legacy_io::@arg0, offset = %c0 : tensor<3x4xi32>
%1 = "mhlo.slice"(%0) {
diff --git a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir
index 7d47f47..123a1e0 100644
--- a/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToLLVM/test/matmul_vectorization.mlir
@@ -1,31 +1,31 @@
// RUN: iree-opt --iree-codegen-linalg-to-llvm-matmul-vectorization-pass -split-input-file %s | IreeFileCheck %s
-// CHECK-LABEL: func @matmul_128x128x128
-// CHECK-SAME: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>)
func @matmul_128x128x128(%arg0 : memref<128x128xf32>, %arg1: memref<128x128xf32>, %arg2: memref<128x128xf32>) {
linalg.matmul %arg0, %arg1, %arg2 : (memref<128x128xf32>, memref<128x128xf32>, memref<128x128xf32>)
return
}
-// CHECK: %[[L3END:.+]] = constant 128 : index
-// CHECK: %[[L3STEP:.+]] = constant 64 : index
-// CHECK: %[[L1STEP:.+]] = constant 4 : index
-// CHECK: %[[L2STEP:.+]] = constant 32 : index
-// CHECK: %[[START:.+]] = constant 0 : index
-// CHECK: scf.for %[[IL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
-// CHECK: scf.for %[[JL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
-// CHECK: scf.for %[[KL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
-// CHECK: %[[ARG0_TILE_L3:.+]] = subview %[[ARG0]][%[[IL3]], %[[KL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
-// CHECK: %[[ARG1_TILE_L3:.+]] = subview %[[ARG1]][%[[KL3]], %[[JL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
-// CHECK: %[[ARG2_TILE_L3:.+]] = subview %[[ARG2]][%[[IL3]], %[[JL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
-// CHECK: scf.for %[[IL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
-// CHECK: scf.for %[[JL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
-// CHECK: scf.for %[[KL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
-// CHECK: %[[ARG0_TILE_L2:.+]] = subview %[[ARG0_TILE_L3]][%[[IL2]], %[[KL2]]] [32, 32] [1, 1] : memref<64x64xf32
-// CHECK: %[[ARG1_TILE_L2:.+]] = subview %[[ARG1_TILE_L3]][%[[KL2]], %[[JL2]]] [32, 32] [1, 1] : memref<64x64xf32
-// CHECK: %[[ARG2_TILE_L2:.+]] = subview %[[ARG2_TILE_L3]][%[[IL2]], %[[JL2]]] [32, 32] [1, 1] : memref<64x64xf32
-// CHECK: scf.for %[[IL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
-// CHECK: scf.for %[[JL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
-// CHECK: scf.for %[[KL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
-// CHECK: %[[ARG0_TILE_L1:.+]] = subview %[[ARG0_TILE_L2]][%[[IL1]], %[[KL1]]] [4, 4] [1, 1] : memref<32x32xf32
-// CHECK: %[[ARG1_TILE_L1:.+]] = subview %[[ARG1_TILE_L2]][%[[KL1]], %[[JL1]]] [4, 4] [1, 1] : memref<32x32xf32
-// CHECK: %[[ARG2_TILE_L1:.+]] = subview %[[ARG2_TILE_L2]][%[[IL1]], %[[JL1]]] [4, 4] [1, 1] : memref<32x32xf32
+// CHECK-LABEL: func @matmul_128x128x128
+// CHECK-SAME: (%[[ARG0:.+]]: memref<128x128xf32>, %[[ARG1:.+]]: memref<128x128xf32>, %[[ARG2:.+]]: memref<128x128xf32>)
+// CHECK-DAG: %[[L3END:.+]] = constant 128 : index
+// CHECK-DAG: %[[L3STEP:.+]] = constant 64 : index
+// CHECK-DAG: %[[L1STEP:.+]] = constant 4 : index
+// CHECK-DAG: %[[L2STEP:.+]] = constant 32 : index
+// CHECK-DAG: %[[START:.+]] = constant 0 : index
+// CHECK: scf.for %[[IL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
+// CHECK: scf.for %[[JL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
+// CHECK: scf.for %[[KL3:.+]] = %[[START]] to %[[L3END]] step %[[L3STEP]]
+// CHECK: %[[ARG0_TILE_L3:.+]] = subview %[[ARG0]][%[[IL3]], %[[KL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
+// CHECK: %[[ARG1_TILE_L3:.+]] = subview %[[ARG1]][%[[KL3]], %[[JL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
+// CHECK: %[[ARG2_TILE_L3:.+]] = subview %[[ARG2]][%[[IL3]], %[[JL3]]] [64, 64] [1, 1] : memref<128x128xf32> to memref<64x64xf32
+// CHECK: scf.for %[[IL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
+// CHECK: scf.for %[[JL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
+// CHECK: scf.for %[[KL2:.+]] = %[[START]] to %[[L3STEP]] step %[[L2STEP]]
+// CHECK: %[[ARG0_TILE_L2:.+]] = subview %[[ARG0_TILE_L3]][%[[IL2]], %[[KL2]]] [32, 32] [1, 1] : memref<64x64xf32
+// CHECK: %[[ARG1_TILE_L2:.+]] = subview %[[ARG1_TILE_L3]][%[[KL2]], %[[JL2]]] [32, 32] [1, 1] : memref<64x64xf32
+// CHECK: %[[ARG2_TILE_L2:.+]] = subview %[[ARG2_TILE_L3]][%[[IL2]], %[[JL2]]] [32, 32] [1, 1] : memref<64x64xf32
+// CHECK: scf.for %[[IL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
+// CHECK: scf.for %[[JL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
+// CHECK: scf.for %[[KL1:.+]] = %[[START]] to %[[L2STEP]] step %[[L1STEP]]
+// CHECK: %[[ARG0_TILE_L1:.+]] = subview %[[ARG0_TILE_L2]][%[[IL1]], %[[KL1]]] [4, 4] [1, 1] : memref<32x32xf32
+// CHECK: %[[ARG1_TILE_L1:.+]] = subview %[[ARG1_TILE_L2]][%[[KL1]], %[[JL1]]] [4, 4] [1, 1] : memref<32x32xf32
+// CHECK: %[[ARG2_TILE_L1:.+]] = subview %[[ARG2_TILE_L2]][%[[IL1]], %[[JL1]]] [4, 4] [1, 1] : memref<32x32xf32
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Attributes.h b/iree/compiler/Conversion/LinalgToSPIRV/Attributes.h
index e801c90..84d6393 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Attributes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Attributes.h
@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#ifndef IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_ATTRIBUTES_H_
-#define IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_ATTRIBUTES_H_
+#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_ATTRIBUTES_H_
+#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_ATTRIBUTES_H_
#include "llvm/ADT/StringRef.h"
@@ -26,34 +26,13 @@
return "vkspv.entry_point_schedule";
}
-/// Enumerates the methods used to compute the number of workgroups to use with
-/// an entry point function. The lowering to SPIR-V sets an integer attribute on
-/// the entry point function with one of these values. It is later used by
-/// `recordDispatch` to compute the number of workgroups for the entry point
-/// function.
-enum class WorkgroupCountMethodology {
- // TODO(#2134): Remove the `Default` option. This is only a fallback used for
- // convolution/pooling cases that are currently not working as intended, as
- // described in the bug.
- Default = 0, // Use the default mechanism used by IREE
- LinearizeResultShape = 1, // Use the linearized shape of the result of the
- // dispatch region
- ResultShape = 2 // Use the shape of the dispatch region.
-};
-
-/// Returns the name of the attribute to use that propagates the method to use
-/// to compute the number of workgroups to use with an entry point function. The
-/// attribute used is an IntegerAttr with value being one of the enum entries of
-/// WorkgroupCountMethodology.
-// TODO(ravishankarm): The approach to use attributes to propagate the
-// methodology to use to compute number of workgroups is to convoluted. Ideally,
-// the lowering should produce a function that should then just be inlined at
-// the point this is needed.
-inline llvm::StringRef getWorkgroupCountAttrName() {
- return "vkspv.workgroup_count_from_result_shape";
+/// Attribute on a entry point function that specifies which function computes
+/// the number of workgroups.
+inline llvm::StringRef getNumWorkgroupsFnAttrName() {
+ return "vkspv.num_workgroups_fn";
}
} // namespace iree_compiler
} // namespace mlir
-#endif // IREE_COMPILER_TRANSLATION_SPIRV_LINALGTOSPIRV_ATTRIBUTES_H_
+#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_ATTRIBUTES_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/BUILD b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
index 6a30713..dd6717a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/BUILD
+++ b/iree/compiler/Conversion/LinalgToSPIRV/BUILD
@@ -24,6 +24,9 @@
"ConvertToGPUPass.cpp",
"ConvertToSPIRVPass.cpp",
"CooperativeMatrixAnalysis.cpp",
+ "DeclareNumWorkgroupsFnPass.cpp",
+ "KernelDispatchUtils.cpp",
+ "LegalizeNumWorkgroupsFnPass.cpp",
"LinalgTileAndFusePass.cpp",
"MarkerUtils.cpp",
"MatMulVectorizationTest.cpp",
@@ -36,6 +39,7 @@
hdrs = [
"Attributes.h",
"CooperativeMatrixAnalysis.h",
+ "KernelDispatchUtils.h",
"MarkerUtils.h",
"MemorySpace.h",
"Passes.h",
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
index 8e34e87..ce9a6cb 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
+++ b/iree/compiler/Conversion/LinalgToSPIRV/CMakeLists.txt
@@ -20,6 +20,7 @@
HDRS
"Attributes.h"
"CooperativeMatrixAnalysis.h"
+ "KernelDispatchUtils.h"
"MarkerUtils.h"
"MemorySpace.h"
"Passes.h"
@@ -28,6 +29,9 @@
"ConvertToGPUPass.cpp"
"ConvertToSPIRVPass.cpp"
"CooperativeMatrixAnalysis.cpp"
+ "DeclareNumWorkgroupsFnPass.cpp"
+ "KernelDispatchUtils.cpp"
+ "LegalizeNumWorkgroupsFnPass.cpp"
"LinalgTileAndFusePass.cpp"
"MarkerUtils.cpp"
"MatMulVectorizationTest.cpp"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
index 38794f4..7f82188 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/ConvertToGPUPass.cpp
@@ -21,11 +21,14 @@
#include <array>
#include <numeric>
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
@@ -506,15 +509,16 @@
namespace {
/// Pass to convert from tiled and fused linalg ops into gpu.func.
-struct ConvertToGPUPass : public PassWrapper<ConvertToGPUPass, FunctionPass> {
+struct ConvertToGPUPass
+ : public PassWrapper<ConvertToGPUPass, OperationPass<ModuleOp>> {
ConvertToGPUPass() = default;
ConvertToGPUPass(const ConvertToGPUPass &pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
- registry.insert<AffineDialect, gpu::GPUDialect, scf::SCFDialect>();
+ registry.insert<AffineDialect, gpu::GPUDialect, scf::SCFDialect,
+ ShapeDialect>();
}
-
- void runOnFunction() override;
+ void runOnOperation() override;
};
struct SerializeParallelLoopPattern
@@ -651,11 +655,29 @@
workgroupSize = {32, 1, 1};
}
}
- rewriter.eraseOp(linalgOp);
if (failed(updateWorkGroupSize(funcOp, workgroupSize))) return failure();
- funcOp.setAttr(getWorkgroupCountAttrName(),
- rewriter.getI32IntegerAttr(static_cast<int32_t>(
- WorkgroupCountMethodology::LinearizeResultShape)));
+
+ // TODO(#3145): The use of the generated function to compute the number of
+ // workgroups works for dynamic shapes as well, but there is an issue with
+ // the shape computation + ha.device.switch creation that needs to be
+ // resolved before the generated function can be used on the host side. So
+ // disabling this approach for dynamic shape case.
+ if (llvm::all_of(linalgOp.getOperands(), [](Value v) {
+ Type t = v.getType();
+ return (t.isa<ShapedType>() &&
+ t.cast<ShapedType>().hasStaticShape()) ||
+ t.isIntOrIndexOrFloat();
+ })) {
+ if (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
+ failed(createNumWorkgroupsFromLinearizedResultShape(
+ rewriter, cast<linalg::LinalgOp>(linalgOp.getOperation()), funcOp,
+ workgroupSize[0]))) {
+ return failure();
+ }
+ } else {
+ funcOp.removeAttr(getNumWorkgroupsFnAttrName());
+ }
+ rewriter.eraseOp(linalgOp);
return success();
}
};
@@ -730,15 +752,7 @@
patterns.insert<TileAndDistributeCopyOp>(context);
}
-void ConvertToGPUPass::runOnFunction() {
- FuncOp funcOp = getFunction();
-
- Region &body = funcOp.getBody();
- if (!llvm::hasSingleElement(body)) {
- funcOp.emitError("unhandled dispatch function with multiple blocks");
- return signalPassFailure();
- }
-
+void ConvertToGPUPass::runOnOperation() {
MLIRContext *context = &getContext();
ConversionTarget target(*context);
// After this pass Linalg and scf.parallel ops should be gone.
@@ -766,11 +780,19 @@
MapLinalgOpToLocalInvocationId<linalg::PoolingSumOp>,
RemoveLinalgRange, SerializeParallelLoopPattern>(context);
- if (failed(applyFullConversion(funcOp, target, patterns)))
- return signalPassFailure();
+ for (FuncOp funcOp : getOperation().getOps<FuncOp>()) {
+ if (!isEntryPoint(funcOp)) continue;
+ Region &body = funcOp.getBody();
+ if (!llvm::hasSingleElement(body)) {
+ funcOp.emitError("unhandled dispatch function with multiple blocks");
+ return signalPassFailure();
+ }
+ if (failed(applyFullConversion(funcOp, target, patterns)))
+ return signalPassFailure();
+ }
}
-std::unique_ptr<OperationPass<FuncOp>> createConvertToGPUPass() {
+std::unique_ptr<OperationPass<ModuleOp>> createConvertToGPUPass() {
return std::make_unique<ConvertToGPUPass>();
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/DeclareNumWorkgroupsFnPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/DeclareNumWorkgroupsFnPass.cpp
new file mode 100644
index 0000000..3bd8074
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/DeclareNumWorkgroupsFnPass.cpp
@@ -0,0 +1,152 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===- DeclareNumWorkgroupsFnPass.cpp - Declares num_workgroups_fn --------===//
+//
+// Define the function that computes the number of workgroups for every entry
+// point function. This pass only defines the function. Its body will be filled
+// in later.
+//
+//===----------------------------------------------------------------------===//
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+static constexpr const char kNumWorkgroupsStr[] = "__num_workgroups__";
+
+namespace {
+
+/// The contract between the host and the device is captured by the _impl
+/// function that is called from the main entry point function. This pattern
+/// looks for the call operation and
+/// - Declares (doesnt define) the function that computes the number of
+/// workgroups to use for this entry point function. It is defined later in
+/// the codegen pipeline, when the computation is mapped to
+/// workgroups/workitems. The signature of this function is
+///
+/// (!shapex.ranked_shape, !shapex.ranked_shape, ....) ->
+/// (index, index, index)
+///
+/// where the arguments are the shape of the tensor inputs + outputs of the
+/// dispatch region.
+/// - Sets the attribute `operand_result_index` on the
+/// `hal.interface.load.tensor`/`hal.interface.store.tensor` ops that are
+/// later used in the generation of the function declared here.
+struct DeclareNumWorkgroupsFn : OpRewritePattern<FuncOp> {
+ using OpRewritePattern<FuncOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(FuncOp entryPointFn,
+ PatternRewriter &rewriter) const override {
+ if (!isEntryPoint(entryPointFn) ||
+ entryPointFn.getAttr(getNumWorkgroupsFnAttrName()))
+ return failure();
+ Region &body = entryPointFn.getBody();
+ if (!llvm::hasSingleElement(body)) {
+ return entryPointFn.emitError(
+ "unhandled dispatch function with multiple blocks");
+ }
+ auto callOps = body.front().getOps<CallOp>();
+ if (!llvm::hasSingleElement(callOps)) {
+ return entryPointFn.emitError(
+ "expected dispatch function to have a single call operation");
+ }
+ CallOp callOp = *callOps.begin();
+
+ SmallVector<ShapedType, 4> shapedTypes;
+ shapedTypes.reserve(callOp.getNumOperands() - 1 + callOp.getNumResults());
+
+ // Add `operand_result_index` attribute to `hal.interface.load.tensor`
+ // operations that define the operands of the call op.
+ for (Value operand : callOp.operands()) {
+ if (!operand.getType().isa<ShapedType>()) continue;
+ if (auto definingOp =
+ operand.getDefiningOp<IREE::HAL::InterfaceLoadTensorOp>()) {
+ definingOp.setAttr(getOperandResultNumAttrName(),
+ rewriter.getI32IntegerAttr(shapedTypes.size()));
+ }
+ shapedTypes.push_back(operand.getType().cast<ShapedType>());
+ }
+
+ // Add `operand_result_index` attribute to the `hal.interface.store.tensor`
+ // that use the value returned by the call op.
+ for (Value result : callOp.getResults()) {
+ if (!result.getType().isa<ShapedType>()) continue;
+ for (auto &use : result.getUses()) {
+ if (auto storeOp =
+ dyn_cast<IREE::HAL::InterfaceStoreTensorOp>(use.getOwner())) {
+ storeOp.setAttr(getOperandResultNumAttrName(),
+ rewriter.getI32IntegerAttr(shapedTypes.size()));
+ }
+ }
+ shapedTypes.push_back(result.getType().cast<ShapedType>());
+ }
+
+ IndexType indexType = rewriter.getIndexType();
+ SmallVector<Type, 4> argTypes = llvm::to_vector<4>(
+ llvm::map_range(shapedTypes, [&rewriter](ShapedType t) -> Type {
+ return Shape::RankedShapeType::get(t.getShape(),
+ rewriter.getContext());
+ }));
+ FuncOp numWorkgroupsFn = rewriter.create<FuncOp>(
+ entryPointFn.getLoc(), entryPointFn.getName().str() + kNumWorkgroupsStr,
+ rewriter.getFunctionType(argTypes, {indexType, indexType, indexType}));
+ numWorkgroupsFn.setVisibility(FuncOp::Visibility::Private);
+ entryPointFn.setAttr(getNumWorkgroupsFnAttrName(),
+ rewriter.getSymbolRefAttr(numWorkgroupsFn));
+ rewriter.updateRootInPlace(entryPointFn, []() {});
+ return success();
+ }
+};
+
+/// Pass to define the function for number of workgroups for every entry point
+/// function.
+struct DeclareNumWorkgroupsFnPass
+ : public PassWrapper<DeclareNumWorkgroupsFnPass, OperationPass<ModuleOp>> {
+ DeclareNumWorkgroupsFnPass() = default;
+ DeclareNumWorkgroupsFnPass(const DeclareNumWorkgroupsFnPass &pass) {}
+ void runOnOperation() override;
+
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<ShapeDialect>();
+ }
+};
+} // namespace
+
+void DeclareNumWorkgroupsFnPass::runOnOperation() {
+ OwningRewritePatternList patterns;
+ MLIRContext *context = &getContext();
+ patterns.insert<DeclareNumWorkgroupsFn>(context);
+ applyPatternsAndFoldGreedily(getOperation(), patterns);
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> createDeclareNumWorkgroupsFnPass() {
+ return std::make_unique<DeclareNumWorkgroupsFnPass>();
+}
+
+static PassRegistration<DeclareNumWorkgroupsFnPass> pass(
+ "iree-codegen-init-num-workgroups-fn",
+ "Declares the function that computes the number of workgroups to use",
+ [] { return std::make_unique<DeclareNumWorkgroupsFnPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
new file mode 100644
index 0000000..ea85ec2
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.cpp
@@ -0,0 +1,157 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- KernelDispatchUtils.cpp - Utilities for generating dispatch info ---===//
+//
+// This file defines utility functions that can be used to create information
+// the dispatch on the host side needs to execute an entry point function, like
+// the number of workgroups to use for launch, etc.
+//
+//===----------------------------------------------------------------------===//
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "llvm/Support/Debug.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+
+#define DEBUG_TYPE "kernel-dispatch-utils"
+
+namespace mlir {
+namespace iree_compiler {
+
+FuncOp getNumWorkgroupsFn(FuncOp entryPointFn) {
+ SymbolRefAttr attr =
+ entryPointFn.getAttrOfType<SymbolRefAttr>(getNumWorkgroupsFnAttrName());
+ if (!attr) {
+ entryPointFn.emitError("missing attribute '")
+ << getNumWorkgroupsFnAttrName() << "'";
+ return nullptr;
+ }
+ FuncOp numWorkgroupsFn = dyn_cast_or_null<FuncOp>(SymbolTable::lookupSymbolIn(
+ entryPointFn.getParentOfType<ModuleOp>(), attr));
+ if (!numWorkgroupsFn) {
+ entryPointFn.emitError("unable to find num workgroups fn ") << attr;
+ return nullptr;
+ }
+ if (!numWorkgroupsFn.empty()) {
+ entryPointFn.emitError("num workgroups fn expected to be empty");
+ return nullptr;
+ }
+ return numWorkgroupsFn;
+}
+
+/// Computes the bounds of the parallel loops partitioned across workgroups.
+static Optional<SmallVector<Value, 2>> getParallelLoopRange(
+ PatternRewriter &rewriter, Location loc, linalg::LinalgOp linalgOp) {
+ FuncOp numWorkgroupsFn =
+ getNumWorkgroupsFn(linalgOp.getParentOfType<FuncOp>());
+ if (!numWorkgroupsFn) return {};
+ LLVM_DEBUG({
+ llvm::dbgs() << "Found num workgroups function : "
+ << numWorkgroupsFn.getName();
+ });
+ rewriter.setInsertionPointToEnd(numWorkgroupsFn.addEntryBlock());
+ llvm::SetVector<Operation *> slice;
+ getBackwardSlice(linalgOp, &slice);
+ BlockAndValueMapping mapper;
+ for (Operation *op : slice) {
+ rewriter.clone(*op, mapper);
+ }
+ // Clone the linalg operation just to compute the loop bounds.
+ linalg::LinalgOp clonedLinalgOp =
+ rewriter.clone(*linalgOp.getOperation(), mapper);
+ Optional<SmallVector<Value, 4>> bounds =
+ getLoopRanges(rewriter, clonedLinalgOp);
+ unsigned numParallelLoops = linalgOp.iterator_types()
+ .getValue()
+ .take_while([](Attribute attr) -> bool {
+ return attr.cast<StringAttr>().getValue() ==
+ getParallelIteratorTypeName();
+ })
+ .size();
+ SmallVector<Value, 2> returnVals(
+ bounds->begin(), std::next(bounds->begin(), numParallelLoops));
+ rewriter.eraseOp(clonedLinalgOp);
+ return returnVals;
+}
+
+/// Utility method to build IR that computes ceil(`numerator` / `denominator`)
+static Value buildCeilDiv(PatternRewriter &rewriter, Location loc,
+ Value numerator, Value denominator) {
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ Value t = rewriter.create<AddIOp>(
+ loc, numerator, rewriter.create<SubIOp>(loc, denominator, one));
+ return rewriter.create<SignedDivIOp>(loc, t, denominator);
+}
+
+/// Utility method to build IR that computes ceil(`numerator` / `denominator`)
+/// when denominator is a constant.
+static Value buildCeilDivConstDenominator(PatternRewriter &rewriter,
+ Location loc, Value numerator,
+ int64_t denominator) {
+ return buildCeilDiv(rewriter, loc, numerator,
+ rewriter.create<ConstantIndexOp>(loc, denominator));
+}
+
+LogicalResult createNumWorkgroupsFromResultShape(PatternRewriter &rewriter,
+ linalg::LinalgOp linalgOp,
+ FuncOp entryPointFn,
+ ArrayRef<int64_t> tileSizes) {
+ Location loc = linalgOp.getLoc();
+ OpBuilder::InsertionGuard gaurd(rewriter);
+ Optional<SmallVector<Value, 2>> parallelLoopRange =
+ getParallelLoopRange(rewriter, loc, linalgOp);
+ if (!parallelLoopRange) return failure();
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ SmallVector<Value, 3> returnValues(3, one);
+ for (size_t i = 0, e = std::min<size_t>(parallelLoopRange->size(), 3); i != e;
+ ++i) {
+ returnValues[i] = buildCeilDivConstDenominator(
+ rewriter, loc, (*parallelLoopRange)[e - i - 1], tileSizes[e - i - 1]);
+ }
+ rewriter.create<mlir::ReturnOp>(loc, returnValues);
+ return success();
+}
+
+LogicalResult createNumWorkgroupsFromLinearizedResultShape(
+ PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+ int64_t workgroupSizeX) {
+ Location loc = linalgOp.getLoc();
+ OpBuilder::InsertionGuard gaurd(rewriter);
+ Optional<SmallVector<Value, 2>> parallelLoopRange =
+ getParallelLoopRange(rewriter, loc, linalgOp);
+ if (!parallelLoopRange) return failure();
+ Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+ SmallVector<Value, 3> returnValues(3, one);
+ for (auto range : *parallelLoopRange) {
+ returnValues[0] = rewriter.create<MulIOp>(loc, range, returnValues[0]);
+ }
+ returnValues[0] = buildCeilDivConstDenominator(rewriter, loc, returnValues[0],
+ workgroupSizeX);
+ rewriter.create<mlir::ReturnOp>(loc, returnValues);
+ return success();
+}
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
new file mode 100644
index 0000000..1573c55
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h
@@ -0,0 +1,66 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===- KernelDispatchUtils.h - Utilities for generating dispatch info -----===//
+//
+// This file declares utility functions that can be used to create information
+// the dispatch on the host side needs to execute an entry point function, like
+// the number of workgroups to use for launch, etc.
+//
+//===----------------------------------------------------------------------===//
+#ifndef IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
+#define IREE_COMPILER_CONVERSION_LINALGTOSPIRV_KERNELDISPATCHUTILS_H_
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class FuncOp;
+class LogicalResult;
+class PatternRewriter;
+class ShapedType;
+class Value;
+
+namespace linalg {
+class LinalgOp;
+}
+
+namespace iree_compiler {
+
+/// Generates a function that computes the number of workgroups as
+/// [ceil(`parallelLoopRange`[2] / `tileSizes`[2]),
+/// ceil(`parallelLoopRange`[1] / `tileSizes`[1]),
+/// ceil(`parallelLoopRange`[0] / `tileSizes`[0])]
+/// where `parallelLoopRange` is the ranges of the parallel loops of `linalgOp`
+/// distributed across workgroups.
+LogicalResult createNumWorkgroupsFromResultShape(PatternRewriter &rewriter,
+ linalg::LinalgOp linalgOp,
+ FuncOp entryPointFn,
+ ArrayRef<int64_t> tileSizes);
+
+/// Generates a function that computes the number of workgroups as
+/// ceil(`parallelLoopRange`[0] * `parallelLoopRange`[1] * ... *
+/// `parallelLoopRange`[n-1] / `workgroupSizeX`)
+/// where `parallelLoopRange` is the ranges of the parallel loops of `linalgOp`
+/// distributed across workgroups.
+LogicalResult createNumWorkgroupsFromLinearizedResultShape(
+ PatternRewriter &rewriter, linalg::LinalgOp linalgOp, FuncOp entryPointFn,
+ int64_t workgroupSizeX);
+
+/// For a given `entryPointFn` return the function that computes the number of
+/// workgroups to use at launch time.
+FuncOp getNumWorkgroupsFn(FuncOp entryPointFn);
+
+} // namespace iree_compiler
+} // namespace mlir
+#endif // IREE_COMPILER_CONVERSION_LINALGTOSPIRV_DISPATCHUTILS_H_
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LegalizeNumWorkgroupsFnPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LegalizeNumWorkgroupsFnPass.cpp
new file mode 100644
index 0000000..1d832c1
--- /dev/null
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LegalizeNumWorkgroupsFnPass.cpp
@@ -0,0 +1,126 @@
+// Copyright 2020 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//===-LegalizeNumWorkgroupsFnPass.cpp - Legalize to be runnable on host ---===//
+//
+// The function generated by the codegeneration pass to compute the number of
+// workgroups uses a slice of the device-side code. Legalize it to run on the
+// host.
+//
+//===----------------------------------------------------------------------===//
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace iree_compiler {
+
+namespace {
+
+/// Pattern to legalize shapex.tie_shape operation to tie the shape of the
+/// `iree.placeholder` result to the argument of the function.
+struct LegalizeTieShapeOp : OpRewritePattern<Shape::TieShapeOp> {
+ using OpRewritePattern<Shape::TieShapeOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(Shape::TieShapeOp tieShapeOp,
+ PatternRewriter &rewriter) const override {
+ if (tieShapeOp.shape().isa<BlockArgument>()) return failure();
+ auto phOp = dyn_cast_or_null<IREE::PlaceholderOp>(
+ tieShapeOp.operand().getDefiningOp());
+ if (!phOp) return failure();
+ IntegerAttr operandNumAttr =
+ phOp.getAttrOfType<IntegerAttr>(getOperandResultNumAttrName());
+ if (!operandNumAttr) {
+ return phOp.emitRemark("expected operand_result_index attribute");
+ }
+ FuncOp numWorkgroupsFn = phOp.getParentOfType<FuncOp>();
+ rewriter.replaceOpWithNewOp<Shape::TieShapeOp>(
+ tieShapeOp, phOp,
+ numWorkgroupsFn.getArgument(
+ phOp.getAttrOfType<IntegerAttr>(getOperandResultNumAttrName())
+ .getInt()));
+ return success();
+ }
+};
+
+/// Pattern to remove dead `iree.placeholder` ops. They arent removed since they
+/// are tagged as having `MemoryEffect`.
+struct RemoveDeadPlaceholderOp : OpRewritePattern<IREE::PlaceholderOp> {
+ using OpRewritePattern<IREE::PlaceholderOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(IREE::PlaceholderOp phOp,
+ PatternRewriter &rewriter) const override {
+ if (phOp.use_empty()) {
+ rewriter.eraseOp(phOp);
+ return success();
+ }
+ return failure();
+ }
+};
+
+/// Pass to legalize the function that computes the number of workgroups to use
+/// for launch to be runnable on the host-side.
+struct LegalizeNumWorkgroupsFnPass
+ : public PassWrapper<LegalizeNumWorkgroupsFnPass, OperationPass<ModuleOp>> {
+ LegalizeNumWorkgroupsFnPass() = default;
+ LegalizeNumWorkgroupsFnPass(const LegalizeNumWorkgroupsFnPass &pass) {}
+ void runOnOperation() override;
+};
+} // namespace
+
+static void populateLegalizeNumWorkgroupsFnPattern(
+ MLIRContext *context, OwningRewritePatternList &patterns) {
+ patterns.insert<LegalizeTieShapeOp, RemoveDeadPlaceholderOp>(context);
+}
+
+void LegalizeNumWorkgroupsFnPass::runOnOperation() {
+ ModuleOp module = getOperation();
+ auto fns = module.getOps<FuncOp>();
+
+ OwningRewritePatternList patterns;
+ MLIRContext *context = &getContext();
+ populateLegalizeNumWorkgroupsFnPattern(context, patterns);
+
+ SymbolTable symbolTable(module.getOperation());
+ for (FuncOp fn : fns) {
+ if (!isEntryPoint(fn)) continue;
+ auto numWorkgroupsFnAttr =
+ fn.getAttrOfType<SymbolRefAttr>(getNumWorkgroupsFnAttrName());
+ if (!numWorkgroupsFnAttr) continue;
+ StringRef numWorkgroupsFnName = numWorkgroupsFnAttr.getLeafReference();
+ FuncOp numWorkgroupsFn = symbolTable.lookup<FuncOp>(numWorkgroupsFnName);
+ if (!numWorkgroupsFn) {
+ fn.emitError("unable to find function to compute number of workgroups ")
+ << numWorkgroupsFnName;
+ return signalPassFailure();
+ }
+ if (failed(applyPatternsAndFoldGreedily(numWorkgroupsFn, patterns)))
+ return signalPassFailure();
+ }
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeNumWorkgroupsFnPass() {
+ return std::make_unique<LegalizeNumWorkgroupsFnPass>();
+}
+
+static PassRegistration<LegalizeNumWorkgroupsFnPass> pass(
+ "iree-codegen-legalize-num-workgroups-fn",
+ "Legalize the function that computes the number of workgroups to use to be "
+ "usable on the host side",
+ [] { return std::make_unique<LegalizeNumWorkgroupsFnPass>(); });
+
+} // namespace iree_compiler
+} // namespace mlir
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
index 71f550f..4c7417e 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/LinalgTileAndFusePass.cpp
@@ -17,11 +17,14 @@
// Implements a pass to tile and fuse linalg operations on buffers.
//
//===----------------------------------------------------------------------===//
+#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MarkerUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/MemorySpace.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Utils.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeDialect.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -195,7 +198,7 @@
namespace {
/// Function pass that implements tiling and fusion in Linalg on buffers.
struct LinalgTileAndFusePass
- : public PassWrapper<LinalgTileAndFusePass, FunctionPass> {
+ : public PassWrapper<LinalgTileAndFusePass, OperationPass<ModuleOp>> {
LinalgTileAndFusePass(ArrayRef<int64_t> workgroupSize = {},
ArrayRef<int64_t> tileSizes = {},
bool useWorkgroupMem = false) {
@@ -206,15 +209,10 @@
LinalgTileAndFusePass(const LinalgTileAndFusePass &pass) {}
void getDependentDialects(DialectRegistry ®istry) const override {
- // clang-format off
- registry.insert<AffineDialect,
- gpu::GPUDialect,
- linalg::LinalgDialect,
- scf::SCFDialect>();
- // clang-format on
+ registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
+ scf::SCFDialect, ShapeDialect>();
}
-
- void runOnFunction() override;
+ void runOnOperation() override;
private:
Option<bool> useWorkgroupMemory{
@@ -243,17 +241,20 @@
/// Pattern for tiling operations. Updates the workgroup size in the surrounding
/// function operation if tiling succeeds.
-struct TileMatmulPattern
- : public linalg::LinalgTilingPattern<linalg::MatmulOp> {
- using Base = linalg::LinalgTilingPattern<linalg::MatmulOp>;
+template <typename MatmulOp>
+struct TileMatmulPattern : public linalg::LinalgBaseTilingPattern {
+ using Base = linalg::LinalgBaseTilingPattern;
TileMatmulPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
+ ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> workgroupSize, PatternBenefit benefit = 1)
- : Base(context, options.setDistributionOptions(matmulDistributionOptions),
+ : Base(MatmulOp::getOperationName(), context,
+ options.setDistributionOptions(matmulDistributionOptions),
linalg::LinalgMarker(
ArrayRef<Identifier>(),
Identifier::get(getWorkgroupNumItemsGENumItersMarker(),
context)),
benefit),
+ tileSizes(tileSizes.begin(), tileSizes.end()),
workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
virtual LogicalResult matchAndRewrite(Operation *op,
@@ -262,44 +263,17 @@
// erased.
FuncOp funcOp = op->getParentOfType<FuncOp>();
if (!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
- failed(updateWorkGroupSize(funcOp, workgroupSize)))
- return failure();
- funcOp.setAttr(getWorkgroupCountAttrName(),
- rewriter.getI32IntegerAttr(static_cast<int32_t>(
- WorkgroupCountMethodology::ResultShape)));
- return success();
- }
-
- SmallVector<int64_t, 3> workgroupSize;
-};
-
-struct TileBatchMatmulPattern
- : public linalg::LinalgTilingPattern<linalg::BatchMatmulOp> {
- using Base = linalg::LinalgTilingPattern<linalg::BatchMatmulOp>;
- TileBatchMatmulPattern(MLIRContext *context,
- linalg::LinalgTilingOptions options,
- ArrayRef<int64_t> workgroupSize,
- PatternBenefit benefit = 1)
- : Base(context, options.setDistributionOptions(matmulDistributionOptions),
- linalg::LinalgMarker(
- ArrayRef<Identifier>(),
- Identifier::get(getWorkgroupMarker(), context)),
- benefit),
- workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
-
- LogicalResult matchAndRewrite(Operation *op,
- PatternRewriter &rewriter) const override {
- FuncOp funcOp = op->getParentOfType<FuncOp>();
- if (!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
- failed(updateWorkGroupSize(funcOp, this->workgroupSize))) {
+ failed(updateWorkGroupSize(funcOp, workgroupSize)) ||
+ (funcOp.getAttr(getNumWorkgroupsFnAttrName()) &&
+ failed(createNumWorkgroupsFromResultShape(
+ rewriter, cast<linalg::LinalgOp>(op), funcOp, tileSizes)))) {
return failure();
}
- funcOp.setAttr(getWorkgroupCountAttrName(),
- rewriter.getI32IntegerAttr(static_cast<int32_t>(
- WorkgroupCountMethodology::ResultShape)));
+ rewriter.eraseOp(op);
return success();
}
+ SmallVector<int64_t, 3> tileSizes;
SmallVector<int64_t, 3> workgroupSize;
};
@@ -318,6 +292,7 @@
struct TileConvPoolPattern : public linalg::LinalgTilingPattern<OpTy> {
using Base = linalg::LinalgTilingPattern<OpTy>;
TileConvPoolPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
+ ArrayRef<int64_t> tileSizes,
ArrayRef<int64_t> workgroupSize,
PatternBenefit benefit = 1)
: Base(context,
@@ -326,6 +301,7 @@
ArrayRef<Identifier>(),
Identifier::get(getWorkgroupMarker(), context)),
benefit),
+ tileSizes(tileSizes.begin(), tileSizes.end()),
workgroupSize(workgroupSize.begin(), workgroupSize.end()) {}
LogicalResult matchAndRewrite(Operation *op,
@@ -335,12 +311,11 @@
if (!funcOp || failed(Base::matchAndRewrite(op, rewriter)) ||
failed(updateWorkGroupSize(funcOp, this->workgroupSize)))
return failure();
- funcOp.setAttr(getWorkgroupCountAttrName(),
- rewriter.getI32IntegerAttr(static_cast<int32_t>(
- WorkgroupCountMethodology::Default)));
+ funcOp.removeAttr(getNumWorkgroupsFnAttrName());
return success();
}
+ SmallVector<int64_t, 3> tileSizes;
SmallVector<int64_t, 3> workgroupSize;
};
@@ -406,66 +381,75 @@
};
} // namespace
-void LinalgTileAndFusePass::runOnFunction() {
+void LinalgTileAndFusePass::runOnOperation() {
MLIRContext *context = &getContext();
- FuncOp funcOp = getFunction();
- Region &body = funcOp.getBody();
- if (!llvm::hasSingleElement(body.getBlocks())) {
- funcOp.emitError("unhandled dispatch function with multiple blocks");
- return signalPassFailure();
- }
- Block &block = body.front();
- auto linalgOps = block.getOps<linalg::LinalgOp>();
- if (linalgOps.empty()) return;
+ ModuleOp module = getOperation();
- TileSizeCalculator tileSizeCalculator(funcOp);
- if (tileSizes.empty()) {
- // Get the tile sizes to use for the lowering.
- SmallVector<int64_t, 3> tileSizes;
- SmallVector<linalg::LinalgOp, 1> opsVec(linalgOps.begin(), linalgOps.end());
- if (failed(tileSizeCalculator.inferTileAndWorkgroupSize(opsVec)))
+ LLVM_DEBUG(
+ llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";);
+ for (FuncOp funcOp : module.getOps<FuncOp>()) {
+ if (!isEntryPoint(funcOp)) continue;
+
+ Region &body = funcOp.getBody();
+ if (!llvm::hasSingleElement(body.getBlocks())) {
+ funcOp.emitError("unhandled dispatch function with multiple blocks");
return signalPassFailure();
- } else {
- tileSizeCalculator.setTileSizes(tileSizes);
- if (!workgroupSize.empty())
- tileSizeCalculator.setWorkgroupSize(workgroupSize);
- }
+ }
+ Block &block = body.front();
+ auto linalgOps = block.getOps<linalg::LinalgOp>();
+ if (linalgOps.empty()) return;
- LLVM_DEBUG({
- llvm::dbgs() << "--- IREE Linalg tile and fuse configuration ---\n";
- llvm::dbgs() << "# workgroup sizes: [";
- interleaveComma(tileSizeCalculator.getWorkgroupSize(), llvm::dbgs());
- llvm::dbgs() << "]\ntile sizes: [";
- interleaveComma(tileSizeCalculator.getTileSizes(), llvm::dbgs());
- llvm::dbgs() << "]\n";
- });
+ TileSizeCalculator tileSizeCalculator(funcOp);
+ if (tileSizes.empty()) {
+ // Get the tile sizes to use for the lowering.
+ SmallVector<int64_t, 3> tileSizes;
+ SmallVector<linalg::LinalgOp, 1> opsVec(linalgOps.begin(),
+ linalgOps.end());
+ if (failed(tileSizeCalculator.inferTileAndWorkgroupSize(opsVec)))
+ return signalPassFailure();
+ } else {
+ tileSizeCalculator.setTileSizes(tileSizes);
+ if (!workgroupSize.empty())
+ tileSizeCalculator.setWorkgroupSize(workgroupSize);
+ }
- OwningRewritePatternList tilingPatterns;
- tilingPatterns
- .insert<TileConvPoolPattern<linalg::ConvOp>, TileMatmulPattern,
- TileBatchMatmulPattern, TileConvPoolPattern<linalg::PoolingMaxOp>,
- TileConvPoolPattern<linalg::PoolingMinOp>,
- TileConvPoolPattern<linalg::PoolingSumOp>>(
- context,
- linalg::LinalgTilingOptions()
- .setTileSizes(tileSizeCalculator.getTileSizes())
- .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
- tileSizeCalculator.getWorkgroupSize());
- applyPatternsAndFoldGreedily(getOperation(), tilingPatterns);
+ LLVM_DEBUG({
+ llvm::dbgs() << "@func " << funcOp.getName() << ": # workgroup sizes: [";
+ interleaveComma(tileSizeCalculator.getWorkgroupSize(), llvm::dbgs());
+ llvm::dbgs() << "]\ntile sizes: [";
+ interleaveComma(tileSizeCalculator.getTileSizes(), llvm::dbgs());
+ llvm::dbgs() << "]\n";
+ });
- if (useWorkgroupMemory) {
- // The promotion patterns are put separate from the tiling patterns to make
- // sure that the allocated scratchspace memory is constant sizes which
- // requires some folding to trigger.
- OwningRewritePatternList promotionPatterns;
- promotionPatterns.insert<PromoteMatmulSubviewsPattern,
- PromoteConvolutionSubviewsPattern>(
+ OwningRewritePatternList tilingPatterns;
+ tilingPatterns.insert<TileConvPoolPattern<linalg::ConvOp>,
+ TileMatmulPattern<linalg::MatmulOp>,
+ TileMatmulPattern<linalg::BatchMatmulOp>,
+ TileConvPoolPattern<linalg::PoolingMaxOp>,
+ TileConvPoolPattern<linalg::PoolingMinOp>,
+ TileConvPoolPattern<linalg::PoolingSumOp>>(
context,
- linalg::LinalgPromotionOptions()
- .setAllocationDeallocationFns(allocateWorkgroupMemory,
- deallocateWorkgroupMemory)
- .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory));
- applyPatternsAndFoldGreedily(getOperation(), promotionPatterns);
+ linalg::LinalgTilingOptions()
+ .setTileSizes(tileSizeCalculator.getTileSizes())
+ .setLoopType(linalg::LinalgTilingLoopType::ParallelLoops),
+ tileSizeCalculator.getTileSizes(),
+ tileSizeCalculator.getWorkgroupSize());
+ applyPatternsAndFoldGreedily(funcOp, tilingPatterns);
+
+ if (useWorkgroupMemory) {
+ // The promotion patterns are put separate from the tiling patterns to
+ // make sure that the allocated scratchspace memory is constant sizes
+ // which requires some folding to trigger.
+ OwningRewritePatternList promotionPatterns;
+ promotionPatterns.insert<PromoteMatmulSubviewsPattern,
+ PromoteConvolutionSubviewsPattern>(
+ context,
+ linalg::LinalgPromotionOptions()
+ .setAllocationDeallocationFns(allocateWorkgroupMemory,
+ deallocateWorkgroupMemory)
+ .setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory));
+ applyPatternsAndFoldGreedily(funcOp, promotionPatterns);
+ }
}
}
@@ -473,7 +457,7 @@
// Pass entry point and registration
//===----------------------------------------------------------------------===//
-std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFusePass(
+std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndFusePass(
ArrayRef<int64_t> workgroupSize, ArrayRef<int64_t> tileSizes,
bool useWorkgroupMemory) {
return std::make_unique<LinalgTileAndFusePass>(workgroupSize, tileSizes,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
index 5c1b672..cc90ccc 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.cpp
@@ -100,11 +100,9 @@
// afterwards. This gives each Linalg op a second chance to be tiled,
// with the second tile and fuse pass.
//===--------------------------------------------------------------------===//
+ pm.addPass(createSplitDispatchFunctionPass());
pm.addPass(createLinalgTileAndFusePass(
options.workgroupSize, options.tileSizes, options.useWorkgroupMemory));
- pm.addPass(createSplitDispatchFunctionPass());
- pm.addPass(createLinalgTileAndFusePass(options.workgroupSize,
- options.useWorkgroupMemory));
pm.addPass(createCanonicalizerPass());
//===--------------------------------------------------------------------===//
@@ -121,6 +119,16 @@
pm.addPass(createCSEPass());
//===--------------------------------------------------------------------===//
+ // Legalize the function that computes the number of workgroups to be runnable
+ // on the host.
+ //
+ // Post-conditions:
+ // - The shape of the values created from `iree.placeholder` operations are
+ // tied to the arguments of the function.
+ //===--------------------------------------------------------------------===//
+ pm.addPass(createLegalizeNumWorkgroupsFnPass());
+
+ //===--------------------------------------------------------------------===//
// Resolve shape related ops.
//
// Pre-conditions:
@@ -137,6 +145,16 @@
pm.addPass(createResolveShapeOpsPass());
//===--------------------------------------------------------------------===//
+ // Legalize the function that computes the number of workgroups to be runnable
+ // on the host.
+ //
+ // Post-conditions:
+ // - The dead `iree.placeholder` operations are removed after shape
+ // resolution.
+ //===--------------------------------------------------------------------===//
+ pm.addPass(createLegalizeNumWorkgroupsFnPass());
+
+ //===--------------------------------------------------------------------===//
// Prepare stdandard ops for SPIR-V conversion.
//
// Post-conditions:
@@ -174,6 +192,21 @@
void buildSPIRVTransformPassPipeline(OpPassManager &pm,
const SPIRVCodegenOptions &options) {
//===--------------------------------------------------------------------===//
+ // The entry point functions call an _impl function that captures the ABI that
+ // the host side uses for the dispatch region. This ABI is needed when
+ // generating the function that computes the number of workgroups. Declare the
+ // function that returns the number of workgroups needed for an entry point
+ // function.
+ //
+ // Post-conditions
+
+ // - An empty, private function is defined for each entry point function
+ // that returns the number of workgroups.
+ // - The entry point function gets an attribute `vkspv.num_workgroups_fn` to
+ // record which function in the module returns the number of workgroups.
+ pm.addPass(createDeclareNumWorkgroupsFnPass());
+
+ //===--------------------------------------------------------------------===//
// Inline the impl dispatch function into the wrapper dispatch function.
//
// TODO(antiagainst): re-evaluate the inlining timing.
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
index 72872f5..6e07cb0 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
+++ b/iree/compiler/Conversion/LinalgToSPIRV/Passes.h
@@ -29,19 +29,27 @@
bool useWorkgroupMemory = false;
};
+/// Pass to initialize the function that computes the number of workgroups for
+/// each entry point function. The function is defined, but is populated later.
+std::unique_ptr<OperationPass<ModuleOp>> createDeclareNumWorkgroupsFnPass();
+
/// Pass to tile and fuse linalg operations on buffers. The pass takes as
/// argument the `workgroupSize` that the tiling should use. Note that the
/// tile-sizes are the reverse of the workgroup size. So workgroup size along
/// "x" is used to tile the innermost loop, along "y" for the next innermost (if
/// it exists) and along "z" for the next loop (if it exists). The workgroup
/// size is expected to be of size at-most 3.
-std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFusePass(
+std::unique_ptr<OperationPass<ModuleOp>> createLinalgTileAndFusePass(
ArrayRef<int64_t> workGroupSize = {}, ArrayRef<int64_t> tileSizes = {},
bool useWorkgroupMem = false);
/// Pass to add the synchronizations and attributes needed to lower from PLoops
/// to GPU dialect.
-std::unique_ptr<OperationPass<FuncOp>> createConvertToGPUPass();
+std::unique_ptr<OperationPass<ModuleOp>> createConvertToGPUPass();
+
+/// Pass to legalize function that returns number of workgroups to use for
+/// launch to be runnable on the host.
+std::unique_ptr<OperationPass<ModuleOp>> createLegalizeNumWorkgroupsFnPass();
/// Pass to perform the final conversion to SPIR-V dialect.
/// This pass converts remaining interface ops into SPIR-V global variables,
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
index 9a98f6f..94436c6 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
+++ b/iree/compiler/Conversion/LinalgToSPIRV/SplitDispatchFunctionPass.cpp
@@ -28,6 +28,7 @@
#include "iree/compiler/Conversion/CodegenUtils/FunctionUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
+#include "iree/compiler/Conversion/LinalgToSPIRV/KernelDispatchUtils.h"
#include "iree/compiler/Conversion/LinalgToSPIRV/Passes.h"
#include "iree/compiler/Dialect/IREE/IR/IREEOps.h"
#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
@@ -45,6 +46,8 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/RegionUtils.h"
+#define DEBUG_TYPE "split-dispatch-function"
+
namespace mlir {
namespace iree_compiler {
@@ -178,13 +181,32 @@
StringRef newFnName = splitKernels.back();
builder.setInsertionPointToStart(moduleOp.getBody());
auto newFn = builder.create<FuncOp>(loc, newFnName, oldFn.getType());
+ LLVM_DEBUG({
+ llvm::dbgs() << "Created new function : func @" << newFn.getName()
+ << "\n";
+ });
// Copy over all attributes except type and name.
for (const auto &namedAttr : oldFn.getAttrs()) {
if (namedAttr.first != impl::getTypeAttrName() &&
- namedAttr.first != SymbolTable::getSymbolAttrName())
+ namedAttr.first != SymbolTable::getSymbolAttrName() &&
+ namedAttr.first != getNumWorkgroupsFnAttrName())
newFn.setAttr(namedAttr.first, namedAttr.second);
}
+ // Need special handling for the number of workgroups function.
+ if (FuncOp numWorkgroupsFn = getNumWorkgroupsFn(oldFn)) {
+ FuncOp newNumWorkgroupsFn =
+ builder.create<FuncOp>(loc, newFnName.str() + "__num_workgroups__",
+ numWorkgroupsFn.getType());
+ newNumWorkgroupsFn.setVisibility(FuncOp::Visibility::Private);
+ newFn.setAttr(getNumWorkgroupsFnAttrName(),
+ builder.getSymbolRefAttr(newNumWorkgroupsFn));
+ LLVM_DEBUG({
+ llvm::dbgs() << "Added func @" << newNumWorkgroupsFn.getName()
+ << " as num workgroups fn for func @" << newFn.getName()
+ << "\n";
+ });
+ }
// Collect the closure for the current Linalg op.
closure.clear();
@@ -211,6 +233,15 @@
moduleOp.setAttr(getEntryPointScheduleAttrName(),
builder.getArrayAttr(entryPoints));
+ if (FuncOp numWorkgroupsFn = getNumWorkgroupsFn(oldFn)) {
+ LLVM_DEBUG({
+ llvm::dbgs() << "Erased num workgroups fn func @"
+ << numWorkgroupsFn.getName() << " for func @"
+ << oldFn.getName() << "\n";
+ });
+ numWorkgroupsFn.erase();
+ }
+ LLVM_DEBUG({ llvm::dbgs() << "Erased func @" << oldFn.getName() << "\n"; });
oldFn.erase();
return success();
}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
index 51b088b..daa92e2 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/convert_to_gpu.mlir
@@ -7,10 +7,13 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @parallel_4D(%arg0: memref<?x?x?x?xf32>,
- %arg1 : memref<?x?x?x?xf32>,
- %arg2 : memref<?x?x?x?xf32>)
- attributes {iree.dispatch_fn_name = "parallel_4D"} {
+ func @parallel_4D() {
+ %arg0 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 4 : i32} : memref<?x?x?x?xf32>
+ %arg1 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 9 : i32} : memref<?x?x?x?xf32>
+ %arg2 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 10 : i32} : memref<?x?x?x?xf32>
linalg.generic
{args_in = 2 : i64, args_out = 1 : i64,
indexing_maps = [#map0, #map0, #map0],
@@ -22,10 +25,18 @@
} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
return
}
+ func @parallel_4D__num_workgroups__
+ (!shapex.ranked_shape<[?,?,?,?]>, !shapex.ranked_shape<[?,?,?,?]>,
+ !shapex.ranked_shape<[?,?,?,?]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
// CHECK-LABEL: func @parallel_4D
// CHECK-SAME: local_size = dense<[32, 1, 1]>
-// CHECK-SAME: vkspv.workgroup_count_from_result_shape = 1
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
@@ -56,6 +67,71 @@
// -----
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+module attributes {
+ spv.target_env =
+ #spv.target_env<#spv.vce<v1.3,
+ [Shader], [SPV_KHR_storage_buffer_storage_class]>,
+ {max_compute_workgroup_invocations = 128 : i32,
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @parallel_4D_static() attributes {vkspv.num_workgroups_fn = @parallel_4D_static__num_workgroups__} {
+ %arg0 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4x5x6xf32>
+ %arg1 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<3x4x5x6xf32>
+ %arg2 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<3x4x5x6xf32>
+ linalg.generic
+ {args_in = 2 : i64, args_out = 1 : i64,
+ indexing_maps = [#map0, #map0, #map0],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+ %arg0, %arg1, %arg2 {
+ ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+ %0 = addf %arg3, %arg4 : f32
+ linalg.yield %0 : f32
+ } : memref<3x4x5x6xf32>, memref<3x4x5x6xf32>, memref<3x4x5x6xf32>
+ return
+ }
+ func @parallel_4D_static__num_workgroups__
+ (!shapex.ranked_shape<[3,4,5,6]>, !shapex.ranked_shape<[3,4,5,6]>,
+ !shapex.ranked_shape<[3,4,5,6]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
+}
+// CHECK-LABEL: func @parallel_4D_static()
+// CHECK-SAME: local_size = dense<[32, 1, 1]>
+// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[C360:.+]] = constant 360 : index
+// CHECK-DAG: %[[C120:.+]] = constant 120 : index
+// CHECK-DAG: %[[C30:.+]] = constant 30 : index
+// CHECK-DAG: %[[C6:.+]] = constant 6 : index
+// CHECK-DAG: %[[BID:.+]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[BDIM:.+]] = "gpu.block_dim"() {dimension = "x"}
+// CHECK-DAG: %[[TID:.+]] = "gpu.thread_id"() {dimension = "x"}
+// CHECK: %[[BOFFSET:.+]] = muli %[[BID]], %[[BDIM]]
+// CHECK: %[[IV:.+]] = addi %[[BOFFSET]], %[[TID]]
+// CHECK: %[[COND:.+]] = cmpi "slt", %[[IV]], %[[C360]]
+// CHECK: scf.if %[[COND]]
+// CHECK: %[[IV0:.+]] = divi_signed %[[IV]], %[[C120]]
+// CHECK: %[[T14:.+]] = remi_signed %[[IV]], %[[C120]]
+// CHECK: %[[IV1:.+]] = divi_signed %[[T14]], %[[C30]]
+// CHECK: %[[T16:.+]] = remi_signed %[[T14]], %[[C30]]
+// CHECK: %[[IV2:.+]] = divi_signed %[[T16]], %[[C6]]
+// CHECK: %[[IV3:.+]] = remi_signed %[[T16]], %[[C6]]
+// CHECK: load %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK: load %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+// CHECK: store %{{.+}}[%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
+
+// CHECK: func @[[NUM_WORKGROUPS_FN]]
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C12:.+]] = constant 12 : index
+// CHECK: return %[[C12]], %[[C1]], %[[C1]]
+// -----
+
#map0 = affine_map<() -> ()>
#accesses = [#map0, #map0, #map0]
#trait = {
@@ -71,9 +147,13 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @scalar_add(%arg0 : memref<f32>, %arg1 : memref<f32>,
- %arg2 : memref<f32>)
- {
+ func @scalar_add() attributes {vkspv.num_workgroups_fn = @scalar_add__num_workgroups__} {
+ %arg0 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<f32>
+ %arg1 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<f32>
+ %arg2 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<f32>
linalg.generic #trait %arg0, %arg1, %arg2 {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
%0 = addf %arg3, %arg4 : f32
@@ -81,21 +161,38 @@
} : memref<f32>, memref<f32>, memref<f32>
return
}
+ func @scalar_add__num_workgroups__
+ (!shapex.ranked_shape<[]>, !shapex.ranked_shape<[]>,
+ !shapex.ranked_shape<[]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
-// CHECK-LABEL: func @scalar_add
-// CHECK-SAME: local_size = dense<1> : vector<3xi32>
-// CHECK-SAME: vkspv.workgroup_count_from_result_shape = 1
-// CHECK-NEXT: load
+// CHECK-LABEL: func @scalar_add()
+// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:[a-zA-Z0-9_]+]]
+// CHECK: load
// CHECK-NEXT: load
// CHECK-NEXT: addf
// CHECK-NEXT: store
// CHECK-NEXT: return
+// CHECK: func @[[NUM_WORKGROUPS_FN]]
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: return %[[C1]], %[[C1]], %[[C1]]
+
// -----
module {
- func @reduce_sum(%arg0: memref<?x?x?xf32>, %arg1: memref<f32>, %arg2: memref<?xf32>)
- attributes {iree.dispatch_fn_name = "reduce_sum"} {
+ func @reduce_sum() {
+ %arg0 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?x?xf32>
+ %arg1 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<f32>
+ %arg2 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<?xf32>
linalg.indexed_generic
{args_in = 2 : i64, args_out = 1 : i64,
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>,
@@ -113,10 +210,14 @@
}: memref<?x?x?xf32>, memref<f32>, memref<?xf32>
return
}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
// CHECK-LABEL: func @reduce_sum
// CHECK-SAME: local_size = dense<[32, 1, 1]> : vector<3xi32>
-// CHECK-SAME: vkspv.workgroup_count_from_result_shape = 1
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
@@ -146,10 +247,11 @@
spv.target_env =
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>)
- attributes {spv.entry_point_abi = {local_size = dense<[8, 8, 1]> : vector<3xi32>},
- vkspv.workgroup_count_from_result_shape = 2 : i32} {
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @matmul() {
+ %arg0 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32>
+ %arg1 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1} : memref<?x?xf32>
+ %arg2 = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0} : memref<?x?xf32>
%c4 = constant 4 : index
%c0 = constant 0 : index
%c1 = constant 1 : index
@@ -178,6 +280,11 @@
}
return
}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
// CHECK-LABEL: func @matmul
@@ -209,12 +316,11 @@
spv.target_env =
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @conv_no_padding(%arg0: memref<?x?x?x?xf32>,
- %arg1: memref<?x?x?x?xf32>,
- %arg2: memref<?x?x?x?xf32>)
- attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>},
- vkspv.workgroup_count_from_result_shape = 0 : i32} {
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @conv_no_padding() {
+ %arg0 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0} : memref<?x?x?x?xf32>
+ %arg1 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1} : memref<?x?x?x?xf32>
+ %arg2 = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0} : memref<?x?x?x?xf32>
%c2 = constant 2 : index
%c0 = constant 0 : index
%c3 = constant 3 : index
@@ -252,24 +358,29 @@
: memref<?x?x?x?xf32> to memref<?x?x?x?xf32, #map5>
linalg.conv(%arg0, %21, %27)
{__internal_linalg_transform__ = "workgroup", dilations = [1, 1], strides = [1, 1]}
- : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32, #map5>
+ : memref<?x?x?x?xf32>, memref<?x?x?x?xf32, #map5>, memref<?x?x?x?xf32, #map5>
scf.yield
}
return
}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
// CHECK: func @conv_no_padding
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0}
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1}
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0}
// CHECK-DAG: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[C1:.+]] = constant 1
// CHECK-DAG: %[[C2:.+]] = constant 2
// CHECK-DAG: %[[N:.+]] = dim %[[ARG1]], %[[C0]]
-// CHECK-DAG: %[[P:.+]] = dim %[[ARG2]], %[[C1]]
-// CHECK-DAG: %[[Q:.+]] = dim %[[ARG2]], %[[C2]]
+// CHECK-DAG: %[[P:.+]] = dim %[[RET0]], %[[C1]]
+// CHECK-DAG: %[[Q:.+]] = dim %[[RET0]], %[[C2]]
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
@@ -284,7 +395,7 @@
// CHECK: scf.for %[[IV4:.+]] = %[[BOFFSETY]] to %[[P]] step %[[BSTEPY]]
// CHECK: scf.for %[[IV5:.+]] = %[[BOFFSETX]] to %[[Q]] step %[[BSTEPX]]
// CHECK: %[[SV1:.+]] = subview %[[ARG1]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
-// CHECK: %[[SV2:.+]] = subview %[[ARG2]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
+// CHECK: %[[SV2:.+]] = subview %[[RET0]][%[[IV3]], %[[IV4]], %[[IV5]], 0]
// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
@@ -315,10 +426,11 @@
spv.target_env =
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @pooling_no_padding(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>)
- attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>},
- vkspv.workgroup_count_from_result_shape = 0 : i32} {
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @pooling_no_padding() {
+ %arg0 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0} : memref<?x?xf32>
+ %arg1 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1} : memref<?x?xf32>
+ %arg2 = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0} : memref<?x?xf32>
%c0 = constant 0 : index
%c1 = constant 1 : index
%0 = dim %arg1, %c0 : memref<?x?xf32>
@@ -344,23 +456,28 @@
%19 = subview %arg2[%arg3, %arg4] [%17, %18] [1, 1] : memref<?x?xf32> to memref<?x?xf32, #map4>
linalg.pooling_max(%16, %arg1, %19)
{__internal_linalg_transform__ = "workgroup", dilations = [1, 1], strides = [1, 1]}
- : memref<?x?xf32, #map4>, memref<?x?xf32>, memref<?x?xf32, #map4>
+ : memref<?x?xf32, #map4>, memref<?x?xf32>, memref<?x?xf32, #map4>
scf.yield
}
return
}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
// CHECK: func @pooling_no_padding
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0}
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1}
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0}
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[P:.+]] = dim %[[ARG2]], %[[C0]]
-// CHECK-DAG: %[[Q:.+]] = dim %[[ARG2]], %[[C1]]
+// CHECK-DAG: %[[P:.+]] = dim %[[RET0]], %[[C0]]
+// CHECK-DAG: %[[Q:.+]] = dim %[[RET0]], %[[C1]]
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
@@ -372,7 +489,7 @@
// CHECK: scf.for %[[IV3:.+]] = %[[BOFFSETY]] to %[[P]] step %[[BSTEPY]]
// CHECK: scf.for %[[IV4:.+]] = %[[BOFFSETX]] to %[[Q]] step %[[BSTEPX]]
// CHECK: %[[SV1:.+]] = subview %[[ARG0]][%[[IV3]], %[[IV4]]]
-// CHECK: %[[SV2:.+]] = subview %[[ARG2]][%[[IV3]], %[[IV4]]]
+// CHECK: %[[SV2:.+]] = subview %[[RET0]][%[[IV3]], %[[IV4]]]
// CHECK-DAG: %[[TIDX:.+]] = "gpu.thread_id"() {dimension = "x"}
// CHECK-DAG: %[[NTHREADSX:.+]] = "gpu.block_dim"() {dimension = "x"}
// CHECK-DAG: %[[TIDY:.+]] = "gpu.thread_id"() {dimension = "y"}
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
index 33cfde8..e06eae5 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/linalg_tile_and_fuse.mlir
@@ -7,23 +7,30 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @conv_padding(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
- %arg2 : memref<?x?x?x?xf32>) {
- linalg.conv(%arg0, %arg1, %arg2)
+ func @conv_padding() {
+ %0 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0} : memref<?x?x?x?xf32>
+ %1 = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1} : memref<?x?x?x?xf32>
+ %2 = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0} : memref<?x?x?x?xf32>
+ linalg.conv(%0, %1, %2)
{dilations = [1, 1],
padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [1, 1]} :
memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
return
}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
-// CHECK-LABEL: func @conv_padding
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK: func @conv_padding()
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0}
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1}
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0}
// CHECK: linalg.conv
// CHECK-SAME: %[[ARG0]]
// CHECK-SAME: %[[ARG1]]
-// CHECK-SAME: %[[ARG2]]
+// CHECK-SAME: %[[RET0]]
// -----
@@ -33,21 +40,30 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @conv_no_padding(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>,
- %arg2 : memref<?x?x?x?xf32>) {
- linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]} :
+ func @conv_no_padding() {
+ %0 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?x?x?xf32>
+ %1 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<?x?x?x?xf32>
+ %2 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<?x?x?x?xf32>
+ linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [1, 1]} :
memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
return
}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
-// CHECK: func @conv_no_padding
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?x?x?xf32>
+// CHECK: func @conv_no_padding()
// CHECK-SAME: local_size = dense<[32, 4, 1]>
-// CHECK-SAME: vkspv.workgroup_count_from_result_shape = 0
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
// CHECK-DAG: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
@@ -59,12 +75,13 @@
// CHECK: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[NBLOCKSY]]]
// CHECK: %[[LBX:.+]] = affine.apply #[[MAP1]]()[%[[BIDX]]
// CHECK: %[[STEPX:.+]] = affine.apply #[[MAP1]]()[%[[NBLOCKSX]]]
-// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) =
-// (%[[BIDZ]], %[[LBY]], %[[LBX]])
-// CHECK-SAME: step (%[[NBLOCKSZ]], %[[STEPY]], %[[STEPX]])
-// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]][%[[IV0]], %[[IV1]],
-// %[[IV2]], %[[C0]]] CHECK: %[[VIEW2:.+]] = subview
-// %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[C0]]] CHECK: linalg.conv
+// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]], %[[IV2:.+]]) = (%[[BIDZ]], %[[LBY]], %[[LBX]])
+// CHECK-SAME: step (%[[NBLOCKSZ]], %[[STEPY]], %[[STEPX]])
+// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], %[[C0]]]
+// CHECK: %[[VIEW2:.+]] = subview %[[RET0]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], %[[IV2]], %[[C0]]]
+// CHECK: linalg.conv
// CHECK-SAME: %[[ARG0]], %[[VIEW1]], %[[VIEW2]]
// CHECK-SAME: "workgroup"
@@ -76,22 +93,35 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul(%arg0: memref<?x?xf32>,
- %arg1: memref<?x?xf32>,
- %ret0: memref<?x?xf32>) {
- linalg.matmul %arg0, %arg1, %ret0 :
+ func @matmul() attributes {vkspv.num_workgroups_fn = @matmul__num_workgroups__} {
+ %0 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
+ %1 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<?x?xf32>
+ %2 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<?x?xf32>
+ linalg.matmul %0, %1, %2 :
(memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
return
}
+ func @matmul__num_workgroups__
+ (!shapex.ranked_shape<[?,?]>, !shapex.ranked_shape<[?,?]>,
+ !shapex.ranked_shape<[?,?]>) -> (index, index, index)
+ attributes {sym_visibility = "private"}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 8)>
-// CHECK: func @matmul
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK: func @matmul()
// CHECK-SAME: local_size = dense<[8, 8, 1]>
-// CHECK-SAME: vkspv.workgroup_count_from_result_shape = 2
+// CHECK-SAME: vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN:.[a-zA-Z0-9_]+]]
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
// CHECK-DAG: %[[C0:.+]] = constant 0
// CHECK-DAG: %[[C4:.+]] = constant 4
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
@@ -104,10 +134,22 @@
// CHECK: %[[VIEW1:.+]] = subview %[[ARG1]][%[[IV]], %[[LBX]]]
// CHECK: %[[LBY_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDY]]]
// CHECK: %[[LBX_2:.+]] = affine.apply #[[MAP0]]()[%[[BIDX]]]
-// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]][%[[LBY_2]], %[[LBX_2]]]
+// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[LBY_2]], %[[LBX_2]]]
// CHECK: linalg.matmul
// CHECK-SAME: "workgroup_numprocs_ge_numiters"
// CHECK-SAME: %[[VIEW0]], %[[VIEW1]], %[[VIEW2]]
+// CHECK: func @[[NUM_WORKGROUPS_FN]]
+// CHECK-DAG: %[[C8:.+]] = constant 8 : index
+// CHECK-DAG: %[[C7:.+]] = constant 7 : index
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK: %[[DIM0:.+]] = dim %{{.*}}, %[[C0]]
+// CHECK: %[[DIM1:.+]] = dim %{{.*}}, %[[C1]]
+// CHECK: %[[T0:.+]] = addi %[[DIM1]], %[[C7]]
+// CHECK: %[[T1:.+]] = divi_signed %[[T0]], %[[C8]]
+// CHECK: %[[T2:.+]] = addi %[[DIM0]], %[[C7]]
+// CHECK: %[[T3:.+]] = divi_signed %[[T2]], %[[C8]]
+// CHECK: return %[[T1]], %[[T3]], %[[C1]]
// -----
@@ -117,21 +159,30 @@
[Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @pooling_sum_no_padding(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>,
- %arg2 : memref<?x?xf32>) {
- linalg.pooling_max(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]} :
+ func @pooling_sum_no_padding() {
+ %0 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
+ %1 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<?x?xf32>
+ %2 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<?x?xf32>
+ linalg.pooling_max(%0, %1, %2) {dilations = [1, 1], strides = [1, 1]} :
memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
return
}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 32)>
-// CHECK: func @pooling_sum_no_padding
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9$._-]+]]: memref<?x?xf32>
+// CHECK: func @pooling_sum_no_padding()
// CHECK-SAME: local_size = dense<[32, 4, 1]>
-// CHECK-SAME: vkspv.workgroup_count_from_result_shape = 0
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder {{.*}} {binding = @legacy_io::@ret0
// CHECK-DAG: %[[BIDX:.+]] = "gpu.block_id"() {dimension = "x"}
// CHECK-DAG: %[[NBLOCKSX:.+]] = "gpu.grid_dim"() {dimension = "x"}
// CHECK-DAG: %[[BIDY:.+]] = "gpu.block_id"() {dimension = "y"}
@@ -143,7 +194,7 @@
// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = (%[[LBY]], %[[LBX]])
// CHECK-SAME: step (%[[STEPY]], %[[STEPX]])
// CHECK: %[[VIEW0:.+]] = subview %[[ARG0]][%[[IV0]], %[[IV1]]]
-// CHECK: %[[VIEW2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[VIEW2:.+]] = subview %[[RET0]][%[[IV0]], %[[IV1]]]
// CHECK: linalg.pooling_max
// CHECK-SAME: %[[VIEW0]], %[[ARG1]], %[[VIEW2]]
// CHECK-SAME: "workgroup"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
index 63ba65f..c642f62 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/matmul_vectorization.mlir
@@ -9,13 +9,13 @@
return
}
-// CHECK: %[[TILESIZE:.+]] = constant 32 : index
-// CHECK: %[[MATSIZE:.+]] = constant 128 : index
-// CHECK: %[[START:.+]] = constant 0 : index
-// CHECK: scf.for %[[IL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
-// CHECK: scf.for %[[JL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
-// CHECK: %[[SUBVVIEWC:.+]] = subview %[[ARG2]][%[[IL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
-// CHECK: scf.for %[[KL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
-// CHECK: %[[SUBVVIEWA:.+]] = subview %[[ARG0]][%[[IL]], %[[KL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
-// CHECK: %[[SUBVVIEWB:.+]] = subview %[[ARG1]][%[[KL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+// CHECK-DAG: %[[TILESIZE:.+]] = constant 32 : index
+// CHECK-DAG: %[[MATSIZE:.+]] = constant 128 : index
+// CHECK-DAG: %[[START:.+]] = constant 0 : index
+// CHECK: scf.for %[[IL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: scf.for %[[JL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: %[[SUBVVIEWC:.+]] = subview %[[ARG2]][%[[IL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+// CHECK: scf.for %[[KL:.+]] = %[[START]] to %[[MATSIZE]] step %[[TILESIZE]]
+// CHECK: %[[SUBVVIEWA:.+]] = subview %[[ARG0]][%[[IL]], %[[KL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
+// CHECK: %[[SUBVVIEWB:.+]] = subview %[[ARG1]][%[[KL]], %[[JL]]] [32, 32] [1, 1] : memref<128x128xf32> to memref<32x32xf32
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
index dbffa87..c115b4a 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/split_dispatch_function.mlir
@@ -23,7 +23,7 @@
// CHECK: linalg.fill(%[[TS]], %[[ZERO]])
// CHECK: return
- func @kernel() {
+ func @kernel() attributes {vkspv.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%dim = hal.interface.load.constant offset = 0 : index
%shape1 = shapex.make_ranked_shape %dim : (index) -> !shapex.ranked_shape<[?,2,2,512]>
@@ -37,6 +37,11 @@
linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
return
}
+ func @kernel__num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
+ !shapex.ranked_shape<[3,3,512,1]>,
+ !shapex.ranked_shape<[?,1,1,512]>)
+ -> (index, index, index)
+ attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -48,35 +53,44 @@
// CHECK: module attributes {vkspv.entry_point_schedule = ["kernel_dispatch_0", "kernel_dispatch_1", "kernel_dispatch_2"]}
module {
- // CHECK: func @kernel_dispatch_2()
- // CHECK: %[[DIM:.+]] = hal.interface.load.constant
- // CHECK: %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
- // CHECK: %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
- // CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
- // CHECK: %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
- // CHECK: %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
- // CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
- // CHECK: %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
- // CHECK: linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
- // CHECK: return
+// CHECK: func @kernel_dispatch_2()
+// CHECK-SAME: {vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN2:.+]]}
+// CHECK: %[[DIM:.+]] = hal.interface.load.constant
+// CHECK: %[[SHAPE1:.+]] = shapex.make_ranked_shape %[[DIM]]
+// CHECK: %[[SHAPE2:.+]] = shapex.make_ranked_shape %[[DIM]]
+// CHECK: %[[IN1:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<?x2x2x512xf32>
+// CHECK: %[[TS1:.+]] = shapex.tie_shape %[[IN1]], %[[SHAPE1]]
+// CHECK: %[[IN2:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
+// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+// CHECK: %[[TS2:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE2]]
+// CHECK: linalg.conv(%[[IN2]], %[[TS1]], %[[TS2]])
+// CHECK: return
- // CHECK: func @kernel_dispatch_1() {
- // CHECK: %[[C0:.+]] = constant 0 : index
- // CHECK: %[[C1:.+]] = constant 1 : index
- // CHECK: scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[C1]]) step (%[[C1]])
- // CHECK: scf.yield
- // CHECK: return
+// CHECK: func @[[NUM_WORKGROUPS_FN2]]
- // CHECK: func @kernel_dispatch_0()
- // CHECK: %[[ZERO:.+]] = constant
- // CHECK: %[[DIM:.+]] = hal.interface.load.constant
- // CHECK: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[DIM]]
- // CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
- // CHECK: %[[TS:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE]]
- // CHECK: linalg.fill(%[[TS]], %[[ZERO]])
- // CHECK: return
+// CHECK: func @kernel_dispatch_1()
+// CHECK-SAME: {vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN1:.+]]}
+// CHECK: %[[C0:.+]] = constant 0 : index
+// CHECK: %[[C1:.+]] = constant 1 : index
+// CHECK: scf.parallel (%{{.*}}) = (%[[C0]]) to (%[[C1]]) step (%[[C1]])
+// CHECK: scf.yield
+// CHECK: return
- func @kernel() {
+// CHECK: func @[[NUM_WORKGROUPS_FN1]]
+
+// CHECK: func @kernel_dispatch_0()
+// CHECK-SAME: {vkspv.num_workgroups_fn = @[[NUM_WORKGROUPS_FN0:.+]]}
+// CHECK: %[[ZERO:.+]] = constant
+// CHECK: %[[DIM:.+]] = hal.interface.load.constant
+// CHECK: %[[SHAPE:.+]] = shapex.make_ranked_shape %[[DIM]]
+// CHECK: %[[OUT:.+]] = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<?x1x1x512xf32>
+// CHECK: %[[TS:.+]] = shapex.tie_shape %[[OUT]], %[[SHAPE]]
+// CHECK: linalg.fill(%[[TS]], %[[ZERO]])
+// CHECK: return
+
+// CHECK: func @[[NUM_WORKGROUPS_FN0]]
+
+ func @kernel() attributes {vkspv.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%c0 = constant 0 : index
%c1 = constant 1 : index
@@ -95,6 +109,11 @@
linalg.conv(%1, %ts1, %ts2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<?x2x2x512xf32>, memref<?x1x1x512xf32>
return
}
+ func @kernel__num_workgroups__(!shapex.ranked_shape<[?,2,2,512]>,
+ !shapex.ranked_shape<[3,3,512,1]>,
+ !shapex.ranked_shape<[?,1,1,512]>)
+ -> (index, index, index)
+ attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -110,7 +129,7 @@
// CHECK-NOT: vkspv.entry_point_schedule
module {
// CHECK-LABEL: @kernel()
- func @kernel() {
+ func @kernel() attributes {vkspv.num_workgroups_fn = @kernel__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x2x2x512xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg1} : memref<3x3x512x1xf32>
@@ -118,6 +137,8 @@
linalg.conv(%1, %0, %2) {dilations = [1, 1], padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, strides = [2, 2]} : memref<3x3x512x1xf32>, memref<1x2x2x512xf32>, memref<1x1x1x512xf32>
return
}
+ // CHECK-LABEL: @kernel__num_workgroups__
+ func @kernel__num_workgroups__(!shapex.ranked_shape<[1,2,2,512]>, !shapex.ranked_shape<[3,3,1,512]>, !shapex.ranked_shape<[1,1,1,512]>) -> (index, index, index) attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
@@ -153,7 +174,8 @@
#map0 = affine_map<(d0, d1) -> (d0 * 12 + d1 + 53)>
module {
- func @subview_interleaved() {
+ func @subview_interleaved()
+ attributes {vkspv.num_workgroups_fn = @subview_interleaved__num_workgroups__} {
%cst = constant 0.000000e+00 : f32
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<18x12xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<12x4xf32>
@@ -162,6 +184,10 @@
linalg.copy(%1, %2) : memref<12x4xf32>, memref<18x12xf32, #map0>
return
}
+ func @subview_interleaved__num_workgroups__(!shapex.ranked_shape<[12,4]>,
+ !shapex.ranked_shape<[18,12]>)
+ -> (index, index, index)
+ attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write"
@@ -191,7 +217,8 @@
#map2 = affine_map<(d0, d1, d2) -> (d2)>
module {
- func @reshape_interleaved() {
+ func @reshape_interleaved()
+ attributes {vkspv.num_workgroups_fn = @reshape_interleaved__num_workgroups__} {
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<2x4xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1} : memref<1x2x4xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<2x4xf32>
@@ -206,6 +233,11 @@
linalg.copy(%3, %1) : memref<1x2x4xf32>, memref<1x2x4xf32>
return
}
+ func @reshape_interleaved__num_workgroups__(!shapex.ranked_shape<[2,4]>,
+ !shapex.ranked_shape<[2,4]>,
+ !shapex.ranked_shape<[1,2,4]>)
+ -> (index, index, index)
+ attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @ret0, set=0, binding=1, type="StorageBuffer", access="Write|Discard"
@@ -233,7 +265,8 @@
// -----
module {
- func @predict_ex_dispatch_0() {
+ func @predict_ex_dispatch_0()
+ attributes {vkspv.num_workgroups_fn = @predict_ex_dispatch_0__num_workgroups__} {
%0 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret0} : memref<1x512x1xf32>
%1 = iree.placeholder for "interface buffer" {binding = @legacy_io::@ret1} : memref<4x8x16xf32>
%2 = iree.placeholder for "interface buffer" {binding = @legacy_io::@arg0} : memref<1x512x1xf32>
@@ -248,6 +281,12 @@
}: memref<4x8x16xf32>, memref<4x8x16xf32>
return
}
+ func @predict_ex_dispatch_0__num_workgroups__(!shapex.ranked_shape<[1,512,1]>,
+ !shapex.ranked_shape<[4,8,16]>,
+ !shapex.ranked_shape<[1,512,1]>,
+ !shapex.ranked_shape<[4,8,16]>)
+ -> (index, index, index)
+ attributes {sym_visibility = "private"}
hal.interface @legacy_io attributes {push_constants = 1 : i32, sym_visibility = "private"} {
hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
diff --git a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
index fbbb483..9ff21ba 100644
--- a/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
+++ b/iree/compiler/Conversion/LinalgToSPIRV/test/workgroup_memory_promotion.mlir
@@ -4,21 +4,32 @@
spv.target_env =
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @matmul_tile(%arg0 : memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
- linalg.matmul %arg0, %arg1, %arg2 :
- (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @matmul_tile()
+ attributes {signature = (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>)} {
+ %0 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<?x?xf32>
+ %1 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<?x?xf32>
+ %2 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<?x?xf32>
+ linalg.matmul %0, %1, %2 : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>)
return
}
+ hal.interface @legacy_io attributes {sym_visibility = "private"} {
+ hal.interface.binding @arg0, set=0, binding=0, type="StorageBuffer", access="Read"
+ hal.interface.binding @arg1, set=0, binding=1, type="StorageBuffer", access="Read"
+ hal.interface.binding @ret0, set=0, binding=2, type="StorageBuffer", access="Write"
+ }
}
-// CHECK-LABEL: func @matmul_tile
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?xf32>
+// CHECK: func @matmul_tile()
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0
// CHECK: scf.for
// CHECK: %[[ARG0SV:.+]] = subview %[[ARG0]]
// CHECK: %[[ARG1SV:.+]] = subview %[[ARG1]]
-// CHECK: %[[ARG2SV:.+]] = subview %[[ARG2]]
+// CHECK: %[[RET0SV:.+]] = subview %[[RET0]]
// CHECK: %[[ALLOC1:.+]] = alloc() : memref<8x4xf32, 3>
// CHECK: %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
// CHECK: %[[ALLOC2:.+]] = alloc() : memref<4x8xf32, 3>
@@ -29,7 +40,7 @@
// CHECK-SAME: "copy_to_workgroup_memory"
// CHECK: linalg.matmul
// CHECK-SAME: "workgroup_memory_numprocs_ge_numiters"
-// CHECK-SAME: %[[SUBVIEW1]], %[[SUBVIEW2]], %[[ARG2SV]]
+// CHECK-SAME: %[[SUBVIEW1]], %[[SUBVIEW2]], %[[RET0SV]]
// CHECK-DAG: dealloc %[[ALLOC1]] : memref<8x4xf32, 3>
// CHECK-DAG: dealloc %[[ALLOC2]] : memref<4x8xf32, 3>
@@ -39,25 +50,31 @@
spv.target_env =
#spv.target_env<#spv.vce<v1.3, [Shader], [SPV_KHR_storage_buffer_storage_class]>,
{max_compute_workgroup_invocations = 128 : i32,
- max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
- func @conv_no_padding_tile(%arg0: memref<3x4x3x2xf32>,
- %arg1: memref<?x?x?x3xf32>, %arg2: memref<?x?x?x2xf32>) {
- linalg.conv(%arg0, %arg1, %arg2) {dilations = [1, 1], strides = [1, 1]}
+ max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>} {
+ func @conv_no_padding_tile()
+ attributes {signature = (tensor<3x4x3x2xf32>, tensor<?x?x?x3xf32>) -> (tensor<?x?x?x2xf32>)} {
+ %0 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg0, operand_result_index = 0 : i32} : memref<3x4x3x2xf32>
+ %1 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@arg1, operand_result_index = 1 : i32} : memref<?x?x?x3xf32>
+ %2 = iree.placeholder for "interace buffer"
+ {binding = @legacy_io::@ret0, operand_result_index = 2 : i32} : memref<?x?x?x2xf32>
+ linalg.conv(%0, %1, %2) {dilations = [1, 1], strides = [1, 1]}
: memref<3x4x3x2xf32>, memref<?x?x?x3xf32>, memref<?x?x?x2xf32>
return
}
}
-// CHECK-LABEL: func @conv_no_padding_tile
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<3x4x3x2xf32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?x?x3xf32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<?x?x?x2xf32>
-// CHECK: scf.parallel (%{{.*}}, %{{.*}}, %{{.*}})
-// CHECK: %[[ARG1SV:.+]] = subview %[[ARG1]]
-// CHECK: %[[ARG2SV:.+]] = subview %[[ARG2]]
-// CHECK: %[[ALLOC1:.+]] = alloc()
-// CHECK: %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
-// CHECK: linalg.copy(%[[ARG1SV]], %[[SUBVIEW1]])
-// CHECK-SAME: "copy_to_workgroup_memory"
-// CHECK: linalg.conv(%[[ARG0]], %[[SUBVIEW1]], %[[ARG2SV]])
-// CHECK-SAME: "workgroup_memory"
-// CHECK: dealloc %[[ALLOC1]]
+// CHECK: func @conv_no_padding_tile()
+// CHECK-DAG: %[[ARG0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg0
+// CHECK-DAG: %[[ARG1:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@arg1
+// CHECK-DAG: %[[RET0:.+]] = iree.placeholder for "interace buffer" {binding = @legacy_io::@ret0
+// CHECK: scf.parallel (%{{.*}}, %{{.*}}, %{{.*}})
+// CHECK: %[[ARG1SV:.+]] = subview %[[ARG1]]
+// CHECK: %[[RET0SV:.+]] = subview %[[RET0]]
+// CHECK: %[[ALLOC1:.+]] = alloc() : memref<1x6x35x3xf32, 3>
+// CHECK: %[[SUBVIEW1:.+]] = subview %[[ALLOC1]]
+// CHECK: linalg.copy(%[[ARG1SV]], %[[SUBVIEW1]])
+// CHECK-SAME: "copy_to_workgroup_memory"
+// CHECK: linalg.conv(%[[ARG0]], %[[SUBVIEW1]], %[[RET0SV]])
+// CHECK-SAME: "workgroup_memory"
+// CHECK: dealloc %[[ALLOC1]] : memref<1x6x35x3xf32, 3>
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
index 18006a7..97f0662 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/BUILD
@@ -41,6 +41,7 @@
"//iree/compiler/Conversion/LinalgToSPIRV",
"//iree/compiler/Dialect/Flow/IR",
"//iree/compiler/Dialect/HAL/Target",
+ "//iree/compiler/Dialect/Shape/IR",
"//iree/compiler/Dialect/Vulkan/IR",
"//iree/compiler/Dialect/Vulkan/Utils",
"//iree/schemas:spirv_executable_def_cc_fbs",
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
index 05d2710..fdc9911 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/CMakeLists.txt
@@ -44,6 +44,7 @@
iree::compiler::Conversion::LinalgToSPIRV
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
+ iree::compiler::Dialect::Shape::IR
iree::compiler::Dialect::Vulkan::IR
iree::compiler::Dialect::Vulkan::Utils
iree::schemas::spirv_executable_def_cc_fbs
diff --git a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
index 6851ec2..f18c2b8 100644
--- a/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
+++ b/iree/compiler/Dialect/HAL/Target/VulkanSPIRV/VulkanSPIRVTarget.cpp
@@ -21,6 +21,7 @@
#include "iree/compiler/Conversion/LinalgToSPIRV/Attributes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetRegistry.h"
+#include "iree/compiler/Dialect/Shape/IR/ShapeOps.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanAttributes.h"
#include "iree/compiler/Dialect/Vulkan/IR/VulkanDialect.h"
#include "iree/compiler/Dialect/Vulkan/Utils/TargetEnvUtils.h"
@@ -37,6 +38,7 @@
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/Dialect/SPIRV/TargetAndABI.h"
#include "mlir/Dialect/Vector/VectorOps.h"
+#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
@@ -127,105 +129,61 @@
ArrayRef<Value>{memoryBarrier}, ArrayRef<Value>{});
}
-/// Generates IR to compute the ceil(`numerator`, `denominator`).
-static Value computeCeilDiv(Location loc, Value one, Value numerator,
- Value denominator, OpBuilder &builder) {
- Value dm1 = builder.create<SubIOp>(loc, denominator, one);
- return builder.create<SignedDivIOp>(
- loc, builder.create<AddIOp>(loc, numerator, dm1), denominator);
-}
-
-/// Calculates the number of workgroups to use based on the shape of the result
-/// of the dispatch region. If the `resultShape` is {s0, s1, s2, s3, ....} and
-/// `workgroupSize` is {wx, wy, wz}, the number of workgroups is {ceil(s0/wz),
-/// ceil(s1/wy), ceil(s2/wx)}
-static std::array<Value, 3> calculateDispatchWorkgroupCountFromResultShape(
- Location loc, ArrayRef<Value> resultShape,
- const std::array<Value, 3> &workgroupSize, OpBuilder &builder) {
- if (resultShape.size() > 3) resultShape = resultShape.take_front(3);
- SmallVector<Value, 4> reverseResultSize(reverse(resultShape));
- Value one = builder.create<ConstantOp>(loc, builder.getIndexAttr(1));
- reverseResultSize.resize(3, one);
- return {
- computeCeilDiv(loc, one, reverseResultSize[0], workgroupSize[0], builder),
- computeCeilDiv(loc, one, reverseResultSize[1], workgroupSize[1], builder),
- computeCeilDiv(loc, one, reverseResultSize[2], workgroupSize[2],
- builder)};
-}
-
-/// Calculates the number of workgroups to use based on the linearized shape of
-/// the result of the dispatch region. The `workgroupSize` is assumed to be of
-/// the form {wx, 1, 1}. If the `resultShape` is {s0, s1, s2, ... sn}, then the
-/// number of workgroups is {ceil(s0*s1*s2*...*sn, wx)}
-static std::array<Value, 3>
-calculateDispatchWorkgroupCountFromLinearizedResultShape(
- Location loc, ArrayRef<Value> resultShape,
- const std::array<Value, 3> &workgroupSize, OpBuilder &builder) {
- if (!mlir::matchPattern(workgroupSize[1], m_One()) ||
- !mlir::matchPattern(workgroupSize[2], m_One())) {
- emitError(loc,
- "invalid workgroup size when computing workgroup count "
- "based linearized result shape");
- return {nullptr, nullptr, nullptr};
+/// The codegeneration emits a function `numWorkgroupsFn` for each entry point
+/// function. This function has arguments the !shapex.ranked_shape for all the
+/// input and output shaped types. Using this the function returns the number of
+/// workgroups to use. To use this function on the host side, generate the
+/// !shapex.ranked_shape values that describe the shape of the inputs and
+/// outputs of the dispatch region and "inline" the function body.
+static std::array<Value, 3> calculateWorkgroupCountFromNumWorkgroupsFn(
+ Location loc, FuncOp numWorkgroupsFn, IREE::HAL::InterfaceOp interface,
+ ArrayRef<Optional<TensorRewriteAdaptor>> operands,
+ ArrayRef<Optional<TensorRewriteAdaptor>> results, OpBuilder &builder) {
+ std::array<Value, 3> returnValue = {nullptr, nullptr, nullptr};
+ // TODO: This is really just inlining a function. For now assume that the
+ // `numWorkgroupsFn` has a single block to make inlining easier.
+ if (!numWorkgroupsFn || !llvm::hasSingleElement(numWorkgroupsFn))
+ return returnValue;
+ SmallVector<SmallVector<Value, 4>, 4> shapeValues;
+ shapeValues.reserve(operands.size() + results.size());
+ auto getShapeValuesFn =
+ [&](ArrayRef<Optional<TensorRewriteAdaptor>> values) -> LogicalResult {
+ for (auto val : values) {
+ if (!val) continue;
+ Optional<SmallVector<Value, 4>> shape = val->getShapeDims(builder);
+ if (!shape) return emitError(loc, "shape computation for operand failed");
+ shapeValues.push_back(shape.getValue());
+ }
+ return success();
+ };
+ if (failed(getShapeValuesFn(operands)) || failed(getShapeValuesFn(results)))
+ return returnValue;
+ BlockAndValueMapping mapper;
+ for (Operation &op : numWorkgroupsFn.front()) {
+ if (isa<mlir::ReturnOp>(op)) {
+ for (unsigned i = 0, e = std::min<unsigned>(3, op.getNumOperands());
+ i != e; ++i) {
+ returnValue[i] = mapper.lookupOrNull(op.getOperand(i));
+ }
+ break;
+ }
+ if (auto shapeOp = dyn_cast<Shape::RankedDimOp>(op)) {
+ if (BlockArgument arg = shapeOp.shape().dyn_cast<BlockArgument>()) {
+ auto &dimValues = shapeValues[arg.getArgNumber()];
+ mapper.map(arg, dimValues[shapeOp.getIndex()]);
+ continue;
+ }
+ return returnValue;
+ }
+ // If all its operands are mapped, clone it.
+ if (llvm::all_of(op.getOperands(), [&mapper](Value operand) {
+ return mapper.contains(operand);
+ })) {
+ builder.clone(op, mapper);
+ continue;
+ }
}
- Value one = builder.create<ConstantOp>(loc, builder.getIndexAttr(1));
- Value linearizedSize = one;
- for (Value dim : resultShape)
- linearizedSize = builder.create<MulIOp>(loc, linearizedSize, dim);
- return {computeCeilDiv(loc, one, linearizedSize, workgroupSize[0], builder),
- one, one};
-}
-
-/// Calculates the number of workgroups to use for a dispatch region based on
-/// the value of `workgroupCountMethodAttr`. This is obtained from an attribute
-/// specified on the entry point functions that is added while lowering to
-/// SPIR-V.
-// TODO(ravishankarm): This method of using enums to specify methodology to
-// compute workgroup count is very hard to maintain. The best approach would be
-// that the lowering generates a function that is "inlined" here. Need to figure
-// out the signature of that function so that it covers all use cases.
-static std::array<Value, 3> calculateSPIRVDispatchWorkgroupCount(
- Location loc, ArrayRef<Value> resultShape,
- IntegerAttr workgroupCountMethodAttr,
- const std::array<Value, 3> &workgroupSize, OpBuilder &builder) {
- WorkgroupCountMethodology workgroupCountMethod =
- static_cast<WorkgroupCountMethodology>(
- workgroupCountMethodAttr.getValue().getZExtValue());
- switch (workgroupCountMethod) {
- case WorkgroupCountMethodology::Default:
- return {nullptr, nullptr, nullptr};
- case WorkgroupCountMethodology::LinearizeResultShape:
- return calculateDispatchWorkgroupCountFromLinearizedResultShape(
- loc, resultShape, workgroupSize, builder);
- case WorkgroupCountMethodology::ResultShape:
- return calculateDispatchWorkgroupCountFromResultShape(
- loc, resultShape, workgroupSize, builder);
- }
- return {nullptr, nullptr, nullptr};
-}
-
-/// Gets the shape of the result from the dispatchState.
-static Optional<SmallVector<Value, 4>> getFirstResultShape(
- Location loc, TargetBackend::DispatchState dispatchState,
- OpBuilder &builder) {
- if (dispatchState.results.empty()) return llvm::None;
- Optional<TensorRewriteAdaptor> result = dispatchState.results[0];
- SmallVector<Value, 4> resultShape;
- // If the output is not a shaped type, assume it is a scalar, and return {1}.
- if (!result) {
- resultShape.push_back(
- builder.create<ConstantOp>(loc, builder.getIndexAttr(1)));
- return resultShape;
- }
-
- // TODO(ravishankarm): Using the result shape to get workgroup count, which
- // involes using `getShapeDims,` results in the shape values being captured
- // from outside of the switch statement in dynamic shape cases. This results
- // in an error since switch statements cannot capture. For now, use the
- // default path when the shape is dynamic.
- if (!result->getTensorType().hasStaticShape()) return llvm::None;
-
- return result->getShapeDims(builder);
+ return returnValue;
}
class VulkanSPIRVTargetBackend : public TargetBackend {
@@ -346,32 +304,18 @@
auto workgroupSize = calculateDispatchWorkgroupSize(
loc, spvModuleOp, spvFuncOp.sym_name(), workload, builder);
- StringRef workgroupCountAttrName = getWorkgroupCountAttrName();
- IntegerAttr workgroupCountAttr =
- spvFuncOp.getAttrOfType<IntegerAttr>(workgroupCountAttrName);
- if (!workgroupCountAttr)
- return spvFuncOp.emitError("missing attribute ")
- << workgroupCountAttrName;
-
- // Assuming here that the shape of the first result value of the dispatch
- // region is enough to calculate the number of workgroups. Either
- // - All results have the same shape and the `workgroupCountMethod` is set
- // to WorkgroupCountMethodology::ResultShape, or
- // - All the results have the same linearized shape and the
- // `workgourpCountMethod` is set to
- // WorkgroupCountMethodology::LinearizedResultShape.
- Optional<SmallVector<Value, 4>> resultShape =
- getFirstResultShape(loc, dispatchState, builder);
-
- WorkgroupCountMethodology workgroupCountMethod =
- static_cast<WorkgroupCountMethodology>(
- workgroupCountAttr.getValue().getZExtValue());
+ FlatSymbolRefAttr numWorkgroupsFnAttr =
+ spvFuncOp.getAttrOfType<FlatSymbolRefAttr>(
+ getNumWorkgroupsFnAttrName());
std::array<Value, 3> workgroupCount = {nullptr, nullptr, nullptr};
- if (resultShape &&
- workgroupCountMethod != WorkgroupCountMethodology::Default) {
- workgroupCount = calculateSPIRVDispatchWorkgroupCount(
- loc, *resultShape, workgroupCountAttr, workgroupSize, builder);
+ if (numWorkgroupsFnAttr) {
+ FuncOp numWorkgroupsFn = dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(
+ spvFuncOp.getParentOfType<ModuleOp>(), numWorkgroupsFnAttr));
+ if (!numWorkgroupsFn) return failure();
+ workgroupCount = calculateWorkgroupCountFromNumWorkgroupsFn(
+ loc, numWorkgroupsFn, executableOp.getInterfaceOp(),
+ dispatchState.operands, dispatchState.results, builder);
} else {
workgroupCount = calculateDispatchWorkgroupCount(
loc, workload, workgroupSize, builder);