Sin and Cos lowering for Spirv
PiperOrigin-RevId: 294310300
diff --git a/iree/compiler/Translation/SPIRV/IndexComputation/IndexComputationPass.cpp b/iree/compiler/Translation/SPIRV/IndexComputation/IndexComputationPass.cpp
index c08f2bb..227a4bf 100644
--- a/iree/compiler/Translation/SPIRV/IndexComputation/IndexComputationPass.cpp
+++ b/iree/compiler/Translation/SPIRV/IndexComputation/IndexComputationPass.cpp
@@ -70,6 +70,7 @@
NoBroadcastPwOpIndexPropagation<xla_hlo::CeilOp>,
NoBroadcastPwOpIndexPropagation<xla_hlo::ConvertOp>,
NoBroadcastPwOpIndexPropagation<xla_hlo::CosOp>,
+ NoBroadcastPwOpIndexPropagation<xla_hlo::SinOp>,
NoBroadcastPwOpIndexPropagation<xla_hlo::ExpOp>,
NoBroadcastPwOpIndexPropagation<xla_hlo::FloorOp>,
NoBroadcastPwOpIndexPropagation<xla_hlo::LogOp>,
diff --git a/iree/compiler/Translation/SPIRV/XLAToSPIRV/IREEToSPIRVPass.cpp b/iree/compiler/Translation/SPIRV/XLAToSPIRV/IREEToSPIRVPass.cpp
index e7d8249..aa824a4 100644
--- a/iree/compiler/Translation/SPIRV/XLAToSPIRV/IREEToSPIRVPass.cpp
+++ b/iree/compiler/Translation/SPIRV/XLAToSPIRV/IREEToSPIRVPass.cpp
@@ -115,6 +115,7 @@
SPIRVPwOpLowering<xla_hlo::AbsOp, spirv::GLSLSAbsOp, spirv::GLSLFAbsOp>,
SPIRVPwOpLowering<xla_hlo::CeilOp, spirv::GLSLCeilOp>,
SPIRVPwOpLowering<xla_hlo::CosOp, spirv::GLSLCosOp>,
+ SPIRVPwOpLowering<xla_hlo::SinOp, spirv::GLSLSinOp>,
SPIRVPwOpLowering<xla_hlo::ExpOp, spirv::GLSLExpOp>,
// TODO(ravishankarm) : For now extract-elementOp is a no-op cause index
// propagation only supports aggregates of rank 0.
diff --git a/test/e2e/xla/cos.mlir b/test/e2e/xla/cos.mlir
index fcbd787..a7b1320 100644
--- a/test/e2e/xla/cos.mlir
+++ b/test/e2e/xla/cos.mlir
@@ -1,4 +1,5 @@
// RUN: iree-run-mlir -iree-hal-target-backends=interpreter-bytecode %s | IreeFileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv %s | IreeFileCheck %s)
// CHECK-LABEL: EXEC @tensor
func @tensor() -> tensor<4xf32> {
diff --git a/test/e2e/xla/sin.mlir b/test/e2e/xla/sin.mlir
index 41ac757..744329d 100644
--- a/test/e2e/xla/sin.mlir
+++ b/test/e2e/xla/sin.mlir
@@ -1,4 +1,5 @@
// RUN: iree-run-mlir -iree-hal-target-backends=interpreter-bytecode %s | IreeFileCheck %s
+// RUN: [[ $IREE_VULKAN_DISABLE == 1 ]] || (iree-run-mlir -iree-hal-target-backends=vulkan-spirv %s | IreeFileCheck %s)
// CHECK-LABEL: EXEC @tensor
func @tensor() -> tensor<4xf32> {