Fix convolution op encoding

PiperOrigin-RevId: 560156772
diff --git a/sim/kelvin_conv.bin_fmt b/sim/kelvin_conv.bin_fmt index 48b0634..9526fdd 100644 --- a/sim/kelvin_conv.bin_fmt +++ b/sim/kelvin_conv.bin_fmt
@@ -1,4 +1,4 @@ instruction group KelvinVectorConvInst[32] : KelvinV3ArgsType { // vconv - aconv_vxv : KelvinV3ArgsType : func3_hi == 0b10, func3_lo == 0b00, vd == 48, vs1_low3 == 0, vs2 != 0, vs3_low3 == 0, m == 0, form == 0b101; + aconv_vxv : KelvinV3ArgsType : func3_hi == 0b10, func3_lo == 0b00, vd == 48, vs1_low4 == 0, vs2 != 0, vs3_low3 == 0, m == 0, form == 0b101; };
diff --git a/sim/kelvin_encoding.cc b/sim/kelvin_encoding.cc index 21dbce1..fa6acc5 100644 --- a/sim/kelvin_encoding.cc +++ b/sim/kelvin_encoding.cc
@@ -211,6 +211,10 @@ absl::StrCat(mpact::sim::riscv::RiscVState::kXregPrefix, reg_num), xreg_alias_[reg_num]); } + if (opcode_ == OpcodeEnum::kAdwinit) { + // Borrow the strip_mine setting to set 4x registers. + strip_mine = true; + } return GetVectorRegisterSourceOp<mpact::sim::riscv::RVVectorRegister>( state_, reg_num, strip_mine, GetSrc1WidenFactor()); }); @@ -220,8 +224,8 @@ auto reg_num = encoding::kelvin_v2_args_type::ExtractVs2(inst_word_); bool strip_mine = encoding::kelvin_v2_args_type::ExtractM(inst_word_); auto form = encoding::kelvin_v2_args_type::ExtractForm(inst_word_); - // .vx or .xx forms are using scalar xs2. - if (form == 2 || form == 3) { + // .vx or .xx forms are using scalar xs2, also for aconv op + if (form == 2 || form == 3 || opcode_ == OpcodeEnum::kAconvVxv) { if (reg_num == 0) { return new mpact::sim::generic::IntLiteralOperand<0>( {1}, xreg_alias_[0]); @@ -238,7 +242,7 @@ state_, reg_num, strip_mine, 1 /* widen_factor */); }); source_op_getters_.emplace( - // vst and vstq use `vd` field as the source for the vector store. + // `vst` and `vstq` use `vd` field as the source for the vector store. // convolution instructions also use `vd` as one of the sources. static_cast<int>(SourceOpEnum::kVd), [this]() -> SourceOperandInterface * { @@ -252,8 +256,9 @@ static_cast<int>(SourceOpEnum::kVs3), [this]() -> SourceOperandInterface * { auto reg_num = encoding::kelvin_v3_args_type::ExtractVs3(inst_word_); + int widen_factor = opcode_ == OpcodeEnum::kAconvVxv ? 
8 : 4; return GetVectorRegisterSourceOp<mpact::sim::riscv::RVVectorRegister>( - state_, reg_num, false /* strip_mine */, 1 /* widen_factor */); + state_, reg_num, false /* strip_mine */, widen_factor); }); source_op_getters_.insert(std::make_pair( static_cast<int>(SourceOpEnum::kNone), []() { return nullptr; })); @@ -289,6 +294,11 @@ [this](int latency) -> DestinationOperandInterface * { auto reg_num = encoding::kelvin_v2_args_type::ExtractVd(inst_word_); bool strip_mine = encoding::kelvin_v2_args_type::ExtractM(inst_word_); + if (opcode_ == OpcodeEnum::kVcget || opcode_ == OpcodeEnum::kAdwinit) { + // Borrow the strip_mine setting to set 4x/8x registers although it is + // not part of the encoding. + strip_mine = true; + } return GetVectorRegisterDestinationOp< mpact::sim::riscv::RVVectorRegister>(state_, reg_num, strip_mine, IsWidenDestinationRegisterOp(), @@ -397,6 +407,11 @@ } } + // `vcget` needs `vd` group size of 8. Use stripmine and widening to set it. + if (opcode_ == OpcodeEnum::kVcget) { + return true; + } + return false; } @@ -425,6 +440,12 @@ return 4; } + // Convolution related ops has 8x src1 registers. + if (opcode_ == OpcodeEnum::kAconvVxv || opcode_ == OpcodeEnum::kAcset || + opcode_ == OpcodeEnum::kActr) { + return 8; + } + return 1; }
diff --git a/sim/kelvin_format.bin_fmt b/sim/kelvin_format.bin_fmt index 7efefdf..a8bc126 100644 --- a/sim/kelvin_format.bin_fmt +++ b/sim/kelvin_format.bin_fmt
@@ -113,7 +113,7 @@ unsigned func1[3]; unsigned form[2]; // .vv==0b00, .vx==0b10, .xx==0b11 overlays: // For accumulation register support. - unsigned vs1_low3[3] = vs1[2..0]; + unsigned vs1_low4[4] = vs1[3..0]; unsigned vs1_low2[2] = vs1[1..0]; unsigned vd_low2[2] = vd[1..0]; }; @@ -130,6 +130,6 @@ unsigned func3_lo[2]; unsigned form[3]; // .vvv=0b001, .vxv=0b101. overlays: - unsigned vs1_low3[3] = vs1[2..0]; + unsigned vs1_low4[4] = vs1[3..0]; unsigned vs3_low3[3] = vs3[2..0]; };
diff --git a/sim/kelvin_memory.bin_fmt b/sim/kelvin_memory.bin_fmt index f3105d6..0e5530d 100644 --- a/sim/kelvin_memory.bin_fmt +++ b/sim/kelvin_memory.bin_fmt
@@ -126,6 +126,6 @@ // acset / actr / adwinit acset : KelvinV2ArgsType : func2 == 0b01'0000, vs2 == 0, m == 0, vd == 48; - actr : KelvinV2ArgsType : func2 == 0b01'0001, vs2 == 0, vs1_low3 == 0, m == 0, vd == 48; + actr : KelvinV2ArgsType : func2 == 0b01'0001, vs2 == 0, vs1_low4 == 0, m == 0, vd == 48; adwinit : KelvinV2ArgsType : func2 == 0b01'0010, vs2 == 0, vs1_low2 == 0, sz == 0b00, m == 0, vd_low2 == 0; };
diff --git a/sim/test/kelvin_encoding_test.cc b/sim/test/kelvin_encoding_test.cc index b7fa6c3..3542b76 100644 --- a/sim/test/kelvin_encoding_test.cc +++ b/sim/test/kelvin_encoding_test.cc
@@ -421,6 +421,32 @@ SourceOpEnum::kVs1); EXPECT_EQ(v_src->size(), 2); delete v_src; + + // Test acset.v, actr.v, adwinit.v + constexpr uint32_t kACSetVBase = 0b010000'000000'010000'00'110000'0'001'10; + v_src = EncodeOpHelper<RV32VectorSourceOperand>( + kACSetVBase, OpcodeEnum::kAcset, SourceOpEnum::kVs1); + EXPECT_EQ(v_src->size(), 8); + delete v_src; + + v_src = EncodeOpHelper<RV32VectorSourceOperand>( + kACSetVBase | (1 << 26 /* actr */), OpcodeEnum::kActr, + SourceOpEnum::kVs1); + EXPECT_EQ(v_src->size(), 8); + delete v_src; + + v_src = EncodeOpHelper<RV32VectorSourceOperand>( + kACSetVBase | (1 << 27 /* adwinit */), OpcodeEnum::kAdwinit, + SourceOpEnum::kVs1); + EXPECT_EQ(v_src->size(), 4); + delete v_src; + + // Test aconv.vxv + constexpr uint32_t kAVConvBase = 0b001000'000001'010000'10'110000'0'00'101; + v_src = EncodeOpHelper<RV32VectorSourceOperand>( + kAVConvBase, OpcodeEnum::kAconvVxv, SourceOpEnum::kVs1); + EXPECT_EQ(v_src->size(), 8); + delete v_src; } TEST_F(KelvinEncodingTest, KelvinWideningVd) { @@ -493,6 +519,29 @@ DestOpEnum::kVd); EXPECT_EQ(v_dest->size(), 2); delete v_dest; + + // Test adwinit.v + constexpr uint32_t kAdwinitVBase = 0b010010'000000'010000'00'110000'0'001'10; + v_dest = EncodeOpHelper<RV32VectorDestOperand>( + kAdwinitVBase, OpcodeEnum::kAdwinit, DestOpEnum::kVd); + EXPECT_EQ(v_dest->size(), 4); + delete v_dest; + + // Test vcget + constexpr uint32_t kVCGet = 0b010100'000000'000000'00'110000'0'111'11; + v_dest = EncodeOpHelper<RV32VectorDestOperand>(kVCGet, OpcodeEnum::kVcget, + DestOpEnum::kVd); + EXPECT_EQ(v_dest->size(), 8); + delete v_dest; +} + +TEST_F(KelvinEncodingTest, KelvinEncodeVs3) { + constexpr uint32_t kACovBase = 0b001000'000001'010000'10'110000'0'00'101; + auto *v_src = EncodeOpHelper<RV32VectorSourceOperand>( + kACovBase, OpcodeEnum::kAconvVxv, SourceOpEnum::kVs3); + EXPECT_EQ(v_src->AsString(), "v8"); + EXPECT_EQ(v_src->size(), 8); + delete v_src; } } // namespace
diff --git a/sim/test/kelvin_vector_convolution_instructions_test.cc b/sim/test/kelvin_vector_convolution_instructions_test.cc index d02aef2..e064785 100644 --- a/sim/test/kelvin_vector_convolution_instructions_test.cc +++ b/sim/test/kelvin_vector_convolution_instructions_test.cc
@@ -70,13 +70,13 @@ false /* widen_dst*/, {}); instructions[1]->set_semantic_function(KelvinVConv); - AppendVectorRegisterOperands(instructions[1].get(), kVLenInWord, - 1 /* src1_widen_factor*/, kVs3, {}, + AppendVectorRegisterOperands(instructions[1].get(), 1, + kVLenInWord /* src1_widen_factor*/, kVs3, {}, false /* widen_dst*/, {kVd}); AppendRegisterOperands(instructions[1].get(), {kelvin::sim::test::kRs2Name}, {}); - AppendVectorRegisterOperands(instructions[1].get(), kVLenInWord, - 1 /* src3_widen_factor*/, kVs1, {}, + AppendVectorRegisterOperands(instructions[1].get(), 1, + kVLenInWord /* src3_widen_factor*/, kVs1, {}, false /* widen_dst*/, {}); execution_fail_ = false; state_->set_on_trap(trap_call_back_);
diff --git a/sim/test/kelvin_vector_memory_instructions_test.cc b/sim/test/kelvin_vector_memory_instructions_test.cc index e11fe0e..524b3a4 100644 --- a/sim/test/kelvin_vector_memory_instructions_test.cc +++ b/sim/test/kelvin_vector_memory_instructions_test.cc
@@ -511,9 +511,9 @@ {{vd_name, vd_span.subspan(kVLenInWord * i, kVLenInWord)}}); } auto instruction = CreateInstruction(); - AppendVectorRegisterOperands(instruction.get(), kVLenInWord, + AppendVectorRegisterOperands(instruction.get(), kVLenInWord / 2, 1 /* src1_widen_factor */, {}, {}, - false /* widen_dst */, {kVd}); + true /* widen_dst */, {kVd}); instruction->set_semantic_function(&KelvinVcGet); instruction->Execute(); // Resulting v48..55 should all have 0 values @@ -544,8 +544,8 @@ {vs_name, vd_span.subspan(kVLenInWord * i, kVLenInWord)}}); } auto instruction = CreateInstruction(); - AppendVectorRegisterOperands(instruction.get(), kVLenInWord, - 1 /* src1_widen_factor */, kVs, {}, + AppendVectorRegisterOperands(instruction.get(), 1, + kVLenInWord /* src1_widen_factor */, kVs, {}, false /* widen_dst */, {kVd}); instruction->set_semantic_function( absl::bind_front(&KelvinAcSet, is_transpose)); @@ -603,7 +603,7 @@ {vd_name, vd_span.subspan(kVLenInByte * i, kVLenInByte)}}); } auto instruction = CreateInstruction(); - AppendVectorRegisterOperands(instruction.get(), kVLenInByte, + AppendVectorRegisterOperands(instruction.get(), kInitLength, 1 /* src1_widen_factor */, kVs, {}, false /* widen_dst */, {kVd}); instruction->set_semantic_function(&KelvinADwInit);