dequantize softmax (#9337)
See #8974. This is still a 20% end-to-end latency improvement on MobileBert-int8 on configs where matmuls are already reasonably fast, making other things like Softmax more important relatively. That is even after Softmax slowness was much improved recently as observed in #9170. Moreover, discussion around #8974 suggests that the path forward for non-dequantized Softmax is nontrivial, so putting our benchmarks on the dequantized path for now will help insulate them a bit from what we expect to be in-flux for the foreseeable future.
diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
index 6adf676..804d2e0 100644
--- a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
+++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp
@@ -52,9 +52,13 @@
// Convert all TFL ops to TOSA ops
//----------------------------------------------------------------------------
- mlir::tosa::TOSATFTFLLegalizationPipelineOptions tosaOptions;
pm.addPass(createLowerGlobalTensorsPass());
+
+ mlir::tosa::TOSATFTFLLegalizationPipelineOptions tosaOptions;
+ // Temporary work-around for https://github.com/google/iree/issues/8974
+ tosaOptions.dequantize_tfl_softmax = true;
mlir::tosa::createTFTFLtoTOSALegalizationPipeline(pm, tosaOptions);
+
pm.nest<func::FuncOp>().addPass(mlir::tosa::createStripQuantTypesPass());
pm.addPass(createCanonicalizerPass());
pm.addPass(createReconcileUnrealizedCastsPass());
diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py
index fa95b23..259aa1f 100644
--- a/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py
+++ b/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py
@@ -35,11 +35,16 @@
def compare_results(self, iree_results, tflite_results, details):
super(MobileBertTest, self).compare_results(iree_results, tflite_results,
details)
- # We have confirmed in large scale accuracy tests that differences this large is acceptable.
+ # We have confirmed in large scale accuracy tests that differences as large
+ # as 5.0 is acceptable. We later further relaxed from 5.0 to 7.0 in
+ # https://github.com/google/iree/pull/9337 when quantized Softmax got
+ # de-quantized, which should be numerically correct albeit not bit-exact.
+ # The actual observed max error was ~ 6.36. The value 7.0 is that rounded up
+ # to the next integer.
self.assertTrue(
- np.isclose(iree_results[0], tflite_results[0], atol=5.0).all())
+ np.isclose(iree_results[0], tflite_results[0], atol=7.0).all())
self.assertTrue(
- np.isclose(iree_results[1], tflite_results[1], atol=5.0).all())
+ np.isclose(iree_results[1], tflite_results[1], atol=7.0).all())
def test_compile_tflite(self):
self.compile_and_execute()