separated state machine generation into an abstract class

This commit is contained in:
ParkerTenBroeck 2025-04-25 17:03:46 -04:00
parent c3532507f3
commit 9b0a9b7ad2
12 changed files with 379 additions and 286 deletions

View file

@ -1,162 +1,161 @@
package generator.runtime;
import generator.gen.Gen;
import java.lang.classfile.*;
import java.lang.classfile.attribute.InnerClassInfo;
import java.lang.classfile.attribute.InnerClassesAttribute;
import java.lang.classfile.attribute.NestHostAttribute;
import java.lang.classfile.instruction.*;
import java.lang.constant.ClassDesc;
import java.lang.constant.ConstantDescs;
import java.lang.constant.MethodTypeDesc;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.function.Function;
import java.lang.reflect.AccessFlag;
import java.util.*;
import java.util.function.BiFunction;
public class StateMachineBuilder {
final ClassBuilder clb;
final CodeBuilder cob;
final CodeModel com;
final GeneratorBuilder gb;
final LocalTracker lt;
public abstract class StateMachineBuilder {
public final static String PARAM_PREFIX = "param_";
public final static String LOCAL_PREFIX = "local_";
public final static String STATE_NAME = "state";
private HashMap<SpecialMethod, Function<StateMachineBuilder, SpecialMethodHandler>> smmap = new HashMap<>();
private boolean ignore_next_pop = false;
final ArrayList<SwitchCase> stateSwitchCases = new ArrayList<>();
final Label invalidState;
private static int sequence;
public final ClassDesc CD_this;
public final ClassDesc[] params;
public final MethodTypeDesc MTD_init;
public final int paramSlotOff;
public enum HandlerRan{
ImmediateRemovePop,
Immediate,
ReplacingNextReturn,
}
public interface SpecialMethodHandler {
void handle(StateMachineBuilder smb);
default boolean removeCall(){return true;}
default HandlerRan handlerRan(){return HandlerRan.ImmediateRemovePop;}
public final ClassModel src_clm;
public final MethodModel src_mem;
public final CodeModel src_com;
public final LocalTracker lt;
protected HashMap<SpecialMethod, BiFunction<StateMachineBuilder, CodeBuilder, SpecialMethodHandler>> smmap = new HashMap<>();
private final ArrayList<SwitchCase> stateSwitchCases = new ArrayList<>();
protected final String uniqueName(){
return sequence+++"";
}
public record SpecialMethod(ClassDesc owner, String name, MethodTypeDesc desc) {
}
static class YieldHandler implements SpecialMethodHandler {
final int resume_state;
final Label resume_label;
final boolean is_void;
public YieldHandler(StateMachineBuilder smb, boolean is_void) {
resume_state = smb.add_state(resume_label = smb.cob.newLabel());
this.is_void = is_void;
}
@Override
public void handle(StateMachineBuilder smb) {
if(is_void)smb.cob.aconst_null();
smb.lt.savingLocals(smb.gb.CD_this, smb.cob, () -> {
smb.cob.aload(0).loadConstant(resume_state).putfield(smb.gb.CD_this, GeneratorBuilder.STATE_NAME, TypeKind.INT.upperBound())
.new_(GeneratorBuilder.CD_Yield)
.dup_x1()
.swap()
.invokespecial(GeneratorBuilder.CD_Yield, ConstantDescs.INIT_NAME, MethodTypeDesc.of(ConstantDescs.CD_void, ConstantDescs.CD_Object))
.areturn();
smb.cob.labelBinding(resume_label);
});
smb.ignore_next_pop = true;
public void params(int slot_start, ParamConsumer consumer){
int offset = 0;
for (var param : params) {
consumer.consume(PARAM_PREFIX+offset, offset+slot_start, param);
offset += TypeKind.from(param).slotSize();
}
}
static class RetHandler implements SpecialMethodHandler {
final boolean is_void;
public StateMachineBuilder(ClassModel src_clm, MethodModel src_mem, CodeModel src_com){
this.src_clm = src_clm;
this.src_mem = src_mem;
this.src_com = src_com;
public RetHandler(boolean is_void) {
this.is_void = is_void;
var mts = src_mem.methodTypeSymbol();
mts = mts.changeReturnType(ConstantDescs.CD_void);
if (!src_mem.flags().has(AccessFlag.STATIC)) {
mts = mts.insertParameterTypes(0, src_clm.thisClass().asSymbol());
}
var name = src_clm.thisClass().name().stringValue() + "$" + src_mem.methodName().stringValue() + "$" + uniqueName();
public HandlerRan handlerRan(){return HandlerRan.ReplacingNextReturn;}
this.CD_this = ClassDesc.of(src_clm.thisClass().asSymbol().packageName(), name);
this.params = mts.parameterArray();
this.MTD_init = MethodTypeDesc.of(ConstantDescs.CD_void, params);
this.paramSlotOff = Arrays.stream(params).mapToInt(p -> TypeKind.from(p).slotSize()).sum();
@Override
public void handle(StateMachineBuilder smb) {
if(is_void)smb.cob.aconst_null();
smb.cob.aload(0).loadConstant(-1).putfield(smb.gb.CD_this, GeneratorBuilder.STATE_NAME, TypeKind.INT.upperBound())
.new_(GeneratorBuilder.CD_Ret)
.dup_x1()
.swap()
.invokespecial(GeneratorBuilder.CD_Ret, ConstantDescs.INIT_NAME, MethodTypeDesc.of(ConstantDescs.CD_void, ConstantDescs.CD_Object))
.areturn();
}
this.lt = new LocalTracker(this, src_com);
}
static class AwaitHandler implements SpecialMethodHandler{
final int yield_state;
final Label yield_label;
public AwaitHandler(StateMachineBuilder smb) {
yield_state = smb.add_state(yield_label = smb.cob.newLabel());
}
public HandlerRan handlerRan(){return HandlerRan.Immediate;}
@Override
public void handle(StateMachineBuilder smb) {
smb.cob.aload(0).loadConstant(yield_state).putfield(smb.gb.CD_this, GeneratorBuilder.STATE_NAME, TypeKind.INT.upperBound());
var start = smb.cob.newBoundLabel();
smb.cob.dup().dup()
.invokeinterface(GeneratorBuilder.CD_Gen, "next", MethodTypeDesc.of(GeneratorBuilder.CD_Res)).dup()
.instanceOf(GeneratorBuilder.CD_Ret);
smb.cob.ifThenElse(bcb -> {
bcb.checkcast(GeneratorBuilder.CD_Ret).invokevirtual(GeneratorBuilder.CD_Ret, "v", MethodTypeDesc.of(ConstantDescs.CD_Object)).swap().pop();
}, bcb -> {
smb.lt.savingLocals(smb.gb.CD_this, bcb, () -> {
bcb.swap().loadLocal(TypeKind.from(smb.gb.CD_this), 0).swap().putfield(smb.gb.CD_this, "meow", GeneratorBuilder.CD_Gen);
bcb.areturn().labelBinding(yield_label);
bcb.loadLocal(TypeKind.from(smb.gb.CD_this), 0).getfield(smb.gb.CD_this, "meow", GeneratorBuilder.CD_Gen);
});
bcb.goto_(start);
});
}
}
StateMachineBuilder(GeneratorBuilder gb, ClassBuilder clb, CodeBuilder cob, CodeModel com) {
this.gb = gb;
this.clb = clb;
this.cob = cob;
this.com = com;
this.lt = new LocalTracker(this, com);
invalidState = cob.newLabel();
smmap.put(new SpecialMethod(GeneratorBuilder.CD_Gen, "yield", GeneratorBuilder.MTD_Gen_Obj),smb -> new YieldHandler(smb, false));
smmap.put(new SpecialMethod(GeneratorBuilder.CD_Gen, "yield", GeneratorBuilder.MTD_Gen),smb -> new YieldHandler(smb, true));
smmap.put(new SpecialMethod(GeneratorBuilder.CD_Gen, "ret", GeneratorBuilder.MTD_Gen_Obj),_ -> new RetHandler(false));
smmap.put(new SpecialMethod(GeneratorBuilder.CD_Gen, "ret", GeneratorBuilder.MTD_Gen),_ -> new RetHandler(true));
smmap.put(new SpecialMethod(GeneratorBuilder.CD_Gen, "await", GeneratorBuilder.MTD_Obj), AwaitHandler::new);
}
int add_state(Label label) {
public int add_state(Label label) {
stateSwitchCases.add(SwitchCase.of(stateSwitchCases.size(), label));
return stateSwitchCases.size() - 1;
}
public void generateStateMachine() {
public void buildSourceMethodShim(CodeBuilder cob){
cob.new_(CD_this).dup();
params(0, (_, slot, type) -> {
cob.loadLocal(TypeKind.from(type), slot);
});
cob.invokespecial(CD_this, ConstantDescs.INIT_NAME, MTD_init).areturn();
}
public boolean shouldBeInnerClass(){
return false;
}
public byte[] buildStateMachine(){
return ClassFile.of(ClassFile.StackMapsOption.STACK_MAPS_WHEN_REQUIRED, ClassFile.AttributesProcessingOption.PASS_ALL_ATTRIBUTES).build(CD_this, clb -> {
if(shouldBeInnerClass()){
src_clm.findAttributes(Attributes.sourceFile()).forEach(clb::with);
clb.with(InnerClassesAttribute.of(InnerClassInfo.of(CD_this, Optional.of(src_clm.thisClass().asSymbol()), Optional.of(CD_this.displayName().split("\\$")[1]), AccessFlag.PUBLIC, AccessFlag.FINAL, AccessFlag.STATIC)));
clb.with(NestHostAttribute.of(src_clm.thisClass()));
}
// parameter fields
params(0, (param, _, type) -> {
clb.withField(param, type, ClassFile.ACC_PRIVATE);
});
clb.withField(STATE_NAME, ConstantDescs.CD_int, ClassFile.ACC_PRIVATE);
// constructor
clb.withMethod(ConstantDescs.INIT_NAME, MTD_init, ClassFile.ACC_PUBLIC, mb -> mb.withCode(cob -> {
cob.aload(0).invokespecial(ConstantDescs.CD_Object, ConstantDescs.INIT_NAME, ConstantDescs.MTD_void);
params(1, (param, slot, type) -> {
cob.aload(0).loadLocal(TypeKind.from(type), slot).putfield(CD_this, param, type);
});
cob.return_();
}));
buildStateMachineMethod(clb);
});
}
protected abstract void buildStateMachineMethod(ClassBuilder clb);
public void buildStateMachineMethodCode(ClassBuilder clb, CodeBuilder cob){
cob.trying(
tcob -> buildStateMachineCode(clb, tcob),
// catch anything set our state to -1 and throw the exception
ctb -> ctb.catchingAll(
blc ->
blc.aload(0).loadConstant(-1).putfield(CD_this, STATE_NAME, ConstantDescs.CD_int)
.new_(ClassDesc.ofDescriptor(RuntimeException.class.descriptorString()))
.dup_x1()
.swap()
.invokespecial(ClassDesc.ofDescriptor(RuntimeException.class.descriptorString()), ConstantDescs.INIT_NAME, MethodTypeDesc.of(ConstantDescs.CD_void, ConstantDescs.CD_Throwable))
.athrow()
)
).aconst_null().areturn();
}
public void buildStateMachineCode(ClassBuilder clb, CodeBuilder cob) {
boolean ignore_next_pop = false;
var invalidState = cob.newLabel();
var start_label = cob.newLabel();
add_state(start_label);
var handlers = new ArrayList<SpecialMethodHandler>();
for (CodeElement coe : com) {
for (CodeElement coe : src_com) {
if (coe instanceof InvokeInstruction is){
var handler = smmap.get(new SpecialMethod(is.owner().asSymbol(), is.name().stringValue(), is.typeSymbol()));
if(handler != null)
handlers.add(handler.apply(this));
handlers.add(handler.apply(this, cob));
}
}
cob.aload(0).getfield(gb.CD_this, GeneratorBuilder.STATE_NAME, TypeKind.INT.upperBound()).lookupswitch(invalidState, stateSwitchCases);
cob.aload(0).getfield(CD_this, STATE_NAME, TypeKind.INT.upperBound()).lookupswitch(invalidState, stateSwitchCases);
var start = cob.startLabel();
var end = cob.newLabel();
cob.localVariable(0, "this", gb.CD_this, start, end);
cob.localVariable(0, "this", CD_this, start, end);
SpecialMethodHandler currentHandler = null;
cob.labelBinding(start_label);
for (CodeElement coe : com) {
for (CodeElement coe : src_com) {
if (coe instanceof Instruction i) {
if (ignore_next_pop)
if (i.opcode() == Opcode.POP) {
@ -164,8 +163,8 @@ public class StateMachineBuilder {
continue;
}else throw new RuntimeException("Expected Pop Instruction");
if (i.opcode() == Opcode.ARETURN){
if (currentHandler !=null && currentHandler.handlerRan() == HandlerRan.ReplacingNextReturn){
currentHandler.handle(this);
if (currentHandler !=null && currentHandler.replacementKind() == ReplacementKind.ReplacingNextReturn){
currentHandler.handle(this, cob);
currentHandler = null;
continue;
}
@ -179,9 +178,9 @@ public class StateMachineBuilder {
if(currentHandler!=null)throw new RuntimeException("Multiple method handlers at once not supported");
var handler = handlers.removeFirst();
if(!handler.removeCall()) cob.with(coe);
if(handler.handlerRan() == HandlerRan.Immediate) handler.handle(this);
else if(handler.handlerRan() == HandlerRan.ImmediateRemovePop) {
handler.handle(this);
if(handler.replacementKind() == ReplacementKind.Immediate) handler.handle(this, cob);
else if(handler.replacementKind() == ReplacementKind.ImmediateReplacingPop) {
handler.handle(this, cob);
ignore_next_pop = true;
}else
currentHandler = handler;
@ -189,35 +188,33 @@ public class StateMachineBuilder {
}
}
// System.out.println(coe);
switch (coe) {
// locals which were once function parameters can be ignored
case LocalVariable lv when lv.slot() < gb.paramSlotOff -> {
case LocalVariable lv when lv.slot() < paramSlotOff -> {
}
case LocalVariable lv ->
cob.localVariable(lv.slot() - gb.paramSlotOff + 1, lv.name(), lv.type(), lv.startScope(), lv.endScope());
cob.localVariable(lv.slot() - paramSlotOff + 1, lv.name(), lv.type(), lv.startScope(), lv.endScope());
// increment indexes into the stack
case IncrementInstruction ii when ii.slot() < gb.paramSlotOff ->
cob.aload(0).dup().getfield(gb.CD_this, GeneratorBuilder.PARAM_PREFIX + ii.slot(), ConstantDescs.CD_int)
case IncrementInstruction ii when ii.slot() < paramSlotOff ->
cob.aload(0).dup().getfield(CD_this, PARAM_PREFIX + ii.slot(), ConstantDescs.CD_int)
.loadConstant(ii.constant()).iadd()
.putfield(gb.CD_this, GeneratorBuilder.PARAM_PREFIX + ii.slot(), ConstantDescs.CD_int);
case IncrementInstruction ii -> cob.iinc(ii.slot() - gb.paramSlotOff + 1, ii.constant());
.putfield(CD_this, PARAM_PREFIX + ii.slot(), ConstantDescs.CD_int);
case IncrementInstruction ii -> cob.iinc(ii.slot() - paramSlotOff + 1, ii.constant());
// convert local function parameters to class fields and offset regular locals
case LoadInstruction li when li.slot() < gb.paramSlotOff ->
cob.aload(0).getfield(gb.CD_this, GeneratorBuilder.PARAM_PREFIX + li.slot(), lt.paramType(li.slot()));
case LoadInstruction li -> cob.loadLocal(li.typeKind(), li.slot() - gb.paramSlotOff + 1);
case LoadInstruction li when li.slot() < paramSlotOff ->
cob.aload(0).getfield(CD_this, PARAM_PREFIX + li.slot(), lt.paramType(li.slot()));
case LoadInstruction li -> cob.loadLocal(li.typeKind(), li.slot() - paramSlotOff + 1);
// convert local function parameters to class fields and offset regular locals
case StoreInstruction ls when ls.slot() < gb.paramSlotOff && ls.typeKind().slotSize() == 2 ->
cob.aload(0).dup_x2().pop().putfield(gb.CD_this, GeneratorBuilder.PARAM_PREFIX + ls.slot(), lt.paramType(ls.slot()));
case StoreInstruction ls when ls.slot() < gb.paramSlotOff ->
cob.aload(0).swap().putfield(gb.CD_this, GeneratorBuilder.PARAM_PREFIX + ls.slot(), lt.paramType(ls.slot()));
case StoreInstruction ls when ls.slot() < paramSlotOff && ls.typeKind().slotSize() == 2 ->
cob.aload(0).dup_x2().pop().putfield(CD_this, PARAM_PREFIX + ls.slot(), lt.paramType(ls.slot()));
case StoreInstruction ls when ls.slot() < paramSlotOff ->
cob.aload(0).swap().putfield(CD_this, PARAM_PREFIX + ls.slot(), lt.paramType(ls.slot()));
case StoreInstruction ls -> {
lt.trackLocal(ls.slot() - gb.paramSlotOff + 1, ls.typeKind());
cob.storeLocal(ls.typeKind(), ls.slot() - gb.paramSlotOff + 1);
lt.trackLocal(ls.slot() - paramSlotOff + 1, ls.typeKind());
cob.storeLocal(ls.typeKind(), ls.slot() - paramSlotOff + 1);
}
default -> cob.with(coe);
@ -230,6 +227,5 @@ public class StateMachineBuilder {
cob.labelBinding(end);
lt.createLocalStoreFields(clb);
clb.withField("meow", GeneratorBuilder.CD_Gen, ClassFile.ACC_PRIVATE);
}
}