# blob: 3293eff7cd0695227ffc504d04fe09d0123933d4 [file] [log] [blame]
# Copyright lowRISC contributors.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
import random
from typing import List, Optional, Tuple
from shared.insn_yaml import InsnsFile
from .program import Program
from .model import Model
from .snippet import Snippet
from .snippet_gen import GenRet, SnippetGen
from .gens.branch import Branch
from .gens.ecall import ECall
from .gens.jump import Jump
from .gens.straight_line_insn import StraightLineInsn
class SnippetGens:
    '''A collection of snippet generators'''

    # Each generator class paired with its base selection weight. gen()
    # multiplies this by the generator's pick_weight() before choosing.
    _WEIGHTED_CLASSES = [
        (Branch, 0.1),
        (ECall, 1.0),
        (Jump, 0.1),
        (StraightLineInsn, 1.0)
    ]

    def __init__(self, insns_file: InsnsFile) -> None:
        '''Instantiate each class in _WEIGHTED_CLASSES with insns_file.'''
        self.generators = [
            (cls(insns_file), weight)
            for cls, weight in SnippetGens._WEIGHTED_CLASSES
        ]  # type: List[Tuple[SnippetGen, float]]

    def gen(self,
            model: Model,
            program: Program,
            ecall: bool) -> Optional[GenRet]:
        '''Pick a snippet and update model, program with its contents.

        Each generator's base weight is scaled by its pick_weight() for the
        given model and program, and a generator is then chosen at random
        using the scaled weights. If the chosen generator fails (its gen
        method returns None), its weight is set to zero and another
        generator is picked.

        Normally returns the GenRet produced by the chosen generator. If the
        chosen generator is the ECall generator and ecall is False, this
        instead returns None without running that generator.

        '''
        real_weights = [weight * generator.pick_weight(model, program)
                        for generator, weight in self.generators]

        # Define a continuation (which basically just calls self.gens()) to
        # pass to each generator. This allows recursive generation and avoids
        # needing circular imports to get the types right.
        def cont(md: Model, prg: Program) -> Optional[Tuple[Snippet, Model]]:
            ret = self.gens(md, prg, False)
            if ret is None:
                return None

            snippet, next_model = ret
            # We should always have a Model returned (because the ecall
            # argument was False)
            assert next_model is not None
            return (snippet, next_model)

        while True:
            # There should always be at least one positive weight left.
            # Check this before choosing: from Python 3.9, random.choices
            # raises ValueError when every weight is zero, whereas older
            # versions silently return the last element. Asserting up front
            # gives the same failure on every Python version.
            assert any(w > 0 for w in real_weights), \
                'no snippet generator has a positive weight'

            # Pick a generator based on the weights in real_weights.
            idx = random.choices(range(len(self.generators)),
                                 weights=real_weights)[0]
            generator, _ = self.generators[idx]

            # Picking a zero-weight generator could loop forever (its weight
            # is already zero, so "disabling" it below changes nothing), so
            # double-check the choice we made had positive weight.
            assert real_weights[idx] > 0

            if isinstance(generator, ECall) and not ecall:
                return None

            # Run the generator to generate a snippet
            gen_res = generator.gen(cont, model, program)
            if gen_res is not None:
                return gen_res

            # If gen_res is None, the generator failed. Set that weight to
            # zero and try again.
            real_weights[idx] = 0.0

    def gens(self,
             model: Model,
             program: Program,
             ecall: bool) -> Optional[GenRet]:
        '''Generate some snippets to continue program.

        This will try to run down model.fuel and program.size. If ecall is
        True, it will eventually generate an ECALL instruction. If ecall is
        False then instead of generating the ECALL instruction, it will instead
        stop (leaving model.pc where the ECALL instruction would have been
        inserted).

        Returns None if no snippet was generated at all (only possible when
        ecall is False). Otherwise returns the generated snippets merged
        with Snippet.merge_list, paired with the last model obtained from
        gen() (None if the final snippet returned no model).

        '''
        children = []  # type: List[Snippet]
        next_model = model  # type: Optional[Model]
        while True:
            assert next_model is not None
            old_fuel = next_model.fuel

            gr = self.gen(next_model, program, ecall)
            if gr is None:
                # gen() only declines to generate when asked not to emit an
                # ECALL.
                assert ecall is False
                break

            snippet, next_model = gr
            children.append(snippet)

            # No model to continue from, so generation stops here.
            if next_model is None:
                break

            # Every successful step must consume fuel, otherwise this loop
            # might never terminate.
            assert next_model.fuel < old_fuel

        if not children:
            assert ecall is False
            return None

        snippet = Snippet.merge_list(children)
        return (snippet, next_model)