From 06defb59440db27561e596978523add13cfbc523 Mon Sep 17 00:00:00 2001 From: Parker TenBroeck <51721964+ParkerTenBroeck@users.noreply.github.com> Date: Thu, 3 Apr 2025 02:17:11 -0400 Subject: [PATCH] started cleaning up and organizing code --- src/Examples.java | 74 +++ src/{generator/Test.java => Lexer.java} | 74 +-- src/Main.java | 15 +- src/generator/Fun.java | 454 +----------------- src/generator/runtime/GeneratorBuilder.java | 340 +++++++++++++ .../runtime/GeneratorClassLoader.java | 96 ++++ 6 files changed, 531 insertions(+), 522 deletions(-) create mode 100644 src/Examples.java rename src/{generator/Test.java => Lexer.java} (56%) create mode 100644 src/generator/runtime/GeneratorBuilder.java create mode 100644 src/generator/runtime/GeneratorClassLoader.java diff --git a/src/Examples.java b/src/Examples.java new file mode 100644 index 0000000..11ab025 --- /dev/null +++ b/src/Examples.java @@ -0,0 +1,74 @@ +import generator.Gen; + +public class Examples { + // public static Gen parse(String str){ +// for(var gen = parse(str); gen.next() instanceof Gen.Yield(var item);){ +// +// } +// return Gen.ret(); +// } + +// public static Gen parse(String str){ +// { +// String meow = "10"; +// meow += "11"; +// Gen.yield(meow); +// Gen.yield(meow); +// Gen.yield(meow); +// Gen.yield(meow); +// } +// for(var split : str.split(" ")){ +// Gen.yield(split); +// } +// { +// var str2 = str; +// while(str2.length()>10){ +// var len = str2.length(); +// Gen.yield(len+" length"); +// str2 = str2.substring(1); +// } +// } +// +// while(str.length()>10){ +// var len = str.length(); +// Gen.yield(len+" length"); +// str = str.substring(1); +// } +// return Gen.ret(); +// } + +// public static Gen gen(int times, double mul) { +// mul -= 0.5; +// for (int i = 0; i < times; i ++) { +// Gen.yield("iteration number: " + i*mul); +// } +// return Gen.ret(); +// } +// public static Gen gen(int times, double mul) { +// mul -= 0.5; +// for (int i = 0; i < times; i++) { +// Gen.yield("iteration number: " + i*mul); +// } +// return Gen.ret(); +// } + + // public static Gen gen() { +// Gen.yield("1"); +// Gen.yield("2"); +// Gen.yield("3"); +// Gen.yield("4"); +// return Gen.ret(); +// } + +// public static Gentest(T val){ +// Gen.yield(val); +// return Gen.ret(); +// } + + public static Gen test(double[] nyas){ + for(var d : nyas){ + Gen.yield(d); + } + return Gen.ret(); + } +} diff --git a/src/generator/Test.java b/src/Lexer.java similarity index 56% rename from src/generator/Test.java rename to src/Lexer.java index 26cdebc..8d003e3 100644 --- a/src/generator/Test.java +++ b/src/Lexer.java @@ -1,26 +1,7 @@ -package generator; +import generator.Gen; +public class Lexer { -public class Test { -// public static Gen gen() { -// Gen.yield("1"); -// Gen.yield("2"); -// Gen.yield("3"); -// Gen.yield("4"); -// return Gen.ret(); -// } - -// public static Gentest(T val){ -// Gen.yield(val); -// return Gen.ret(); -// } - - public static Gentest(double[] nyas){ - for(var d : nyas){ - Gen.yield(d); - } - return Gen.ret(); - } public sealed interface Token{} public enum Punc implements Token { @@ -85,55 +66,4 @@ public class Test { } } } - -// public static Gen parse(String str){ -// for(var gen = parse(str); gen.next() instanceof Gen.Yield(var item);){ -// -// } -// return Gen.ret(); -// } - -// public static Gen parse(String str){ -// { -// String meow = "10"; -// meow += "11"; -// Gen.yield(meow); -// Gen.yield(meow); -// Gen.yield(meow); -// Gen.yield(meow); -// } -// for(var split : str.split(" ")){ -// Gen.yield(split); -// } -// { -// var str2 = str; -// while(str2.length()>10){ -// var len = str2.length(); -// Gen.yield(len+" length"); -// str2 = str2.substring(1); -// } -// } -// -// while(str.length()>10){ -// var len = str.length(); -// Gen.yield(len+" length"); -// str = str.substring(1); -// } -// return Gen.ret(); -// } - -// public static Gen gen(int times, double mul) { -// mul -= 0.5; -// for (int i = 0; i < times; i ++) { -// Gen.yield("iteration number: " + i*mul); -// } -// return Gen.ret(); -// } -// public static Gen gen(int times, double mul) { -// mul -= 0.5; -// for (int i = 0; i < times; i++) { -// Gen.yield("iteration number: " + i*mul); -// } -// return Gen.ret(); -// } } diff --git a/src/Main.java b/src/Main.java index d5238c9..5e9bc17 100644 --- a/src/Main.java +++ b/src/Main.java @@ -1,5 +1,16 @@ -public class Main { +import generator.Fun; +import generator.Gen; + +public class Main implements Runnable { public static void main(String[] args) { - System.out.println("Hello, World!"); + Fun.runWithGeneratorSupport(Main.class); + } + + @Override + public void run() { + var gen = Lexer.parse("f7(x,y,z,w, u,v, othersIg) = v-(x*y+y+ln(z)^2*sin(z*pi/2))/(w*u)+sqrt(othersIg*120e-1)"); + while(gen.next() instanceof Gen.Yield(var tok)) { + System.out.println(tok); + } } } \ No newline at end of file diff --git a/src/generator/Fun.java b/src/generator/Fun.java index 400f25c..9244274 100644 --- a/src/generator/Fun.java +++ b/src/generator/Fun.java @@ -1,457 +1,15 @@ package generator; -import java.io.IOException; -import java.io.PrintStream; -import java.lang.classfile.*; -import java.lang.classfile.attribute.StackMapFrameInfo; -import java.lang.classfile.instruction.*; -import java.lang.constant.ClassDesc; -import java.lang.constant.ConstantDescs; -import java.lang.constant.MethodTypeDesc; -import java.lang.reflect.AccessFlag; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Objects; -import java.util.concurrent.atomic.AtomicInteger; - -import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.TOP; +import generator.runtime.GeneratorClassLoader; public class Fun { - public static void main(String... args) throws Exception { + public static void runWithGeneratorSupport(Class clazz){ var loader = new GeneratorClassLoader(Fun.class.getClassLoader()); - loader.loadClass(Fun.class.getName()).getMethod("run").invoke(null); - } - - public static void run() { - { - var gen = Test.test(new double[]{1,2,3,45,6,4,3,3,452,452,45,345,45}); - while (gen.next() instanceof Gen.Yield(var val)) { - System.out.println(val); - } - } - - var gen = Test.parse("f7(x,y,z,w, u,v, othersIg) = v-(x*y+y+ln(z)^2*sin(z*pi/2))/(w*u)+sqrt(othersIg*120e-1)"); - var res = gen.next(); - while (true) { - System.out.println(res); - if (res instanceof Gen.Ret || res == null) break; - res = gen.next(); - } - } - - public static class GeneratorClassLoader extends ClassLoader { - private final static MethodTypeDesc MTD_void_String = MethodTypeDesc.of(ConstantDescs.CD_void, ConstantDescs.CD_String); - private final static ClassDesc CD_System = ClassDesc.ofDescriptor(System.class.descriptorString()); - private final static ClassDesc CD_PrintStream = ClassDesc.ofDescriptor(PrintStream.class.descriptorString()); - - private final static ClassDesc CD_Gen = ClassDesc.ofDescriptor(Gen.class.descriptorString()); - private final static ClassDesc CD_Res = ClassDesc.ofDescriptor(Gen.Res.class.descriptorString()); - private final static ClassDesc CD_Yield = ClassDesc.ofDescriptor(Gen.Yield.class.descriptorString()); - private final static ClassDesc CD_Ret = ClassDesc.ofDescriptor(Gen.Ret.class.descriptorString()); - private final static MethodTypeDesc MTD_Res = MethodTypeDesc.of(CD_Res); - private final static MethodTypeDesc MTD_Gen_Obj = MethodTypeDesc.of(CD_Gen, ConstantDescs.CD_Object); - - private final HashMap customClazzDefMap = new HashMap<>(); - private final HashMap> customClazzMap = new HashMap<>(); - - public GeneratorClassLoader(ClassLoader parent) { - super(parent); - } - - @Override - public Class loadClass(String name) throws ClassNotFoundException { - if (customClazzDefMap.get(name) instanceof byte[] bytes) - customClazzMap.put(name, defineClass(name, bytes, 0, bytes.length)); - if (customClazzMap.get(name) instanceof Class clazz) - return clazz; - if (name.startsWith("java")) - return super.loadClass(name); - - var p = "/" + name.replace('.', '/') + ".class"; - try (var stream = Fun.class.getResourceAsStream(p)) { - var bytes = Objects.requireNonNull(stream).readAllBytes(); - bytes = searchForGenerators(bytes); - customClazzDefMap.put(name, bytes); - customClazzMap.put(name, defineClass(name, bytes, 0, bytes.length)); - return customClazzMap.get(name); - } catch (IOException e) { - throw new ClassNotFoundException(name, e); - } - } - - public byte[] searchForGenerators(byte[] in) { - var clm = ClassFile.of().parse(in); - return ClassFile.of().build(clm.thisClass().asSymbol(), cb -> { - for (var ce : clm) { - var isGen = clm.thisClass().asSymbol().descriptorString().equals(Gen.class.descriptorString()); - if (ce instanceof MethodModel mem && !isGen) { - var methodRetGen = mem.methodTypeSymbol().returnType().descriptorString().equals(Gen.class.descriptorString()); - if (methodRetGen) { - cb.withMethod(mem.methodName().stringValue(), mem.methodTypeSymbol(), mem.flags().flagsMask(), mb -> { - for (var me : mem) { - if (me instanceof CodeModel com) { - mb.withCode(cob -> rebuildGeneratorMethod(clm, mem, com, cob)); - } else mb.with(me); - } - }); - } else - cb.with(mem); - - - } else cb.with(ce); - } - }); - } - - private ClassDesc generateGeneratorFromGenMethod(ClassModel clm, MethodModel mem, CodeModel com, CodeBuilder scob) { - var cd = ClassDesc.of("Gen" + customClazzDefMap.size()); - - var bytes = ClassFile.of(ClassFile.StackMapsOption.STACK_MAPS_WHEN_REQUIRED).build(cd, clb -> { - - clb.withInterfaces(List.of(clb.constantPool().classEntry(CD_Gen))); - - scob.new_(cd).dup(); - var mts = mem.methodTypeSymbol(); - mts = mts.changeReturnType(ConstantDescs.CD_void); - if (!mem.flags().has(AccessFlag.STATIC)) { - mts = mts.insertParameterTypes(0, clm.thisClass().asSymbol()); - } - - int offset = 0; - var mts_params = mts.parameterArray(); - for (var param : mts_params) { - clb.withField("param_" + offset, param, ClassFile.ACC_PRIVATE); - var tk = TypeKind.fromDescriptor(param.descriptorString()); - scob.loadLocal(tk, offset); - offset += tk.slotSize(); - } - var count = offset; - scob.invokespecial(cd, ConstantDescs.INIT_NAME, mts).areturn(); - - - clb.withMethod(ConstantDescs.INIT_NAME, mts, ClassFile.ACC_PUBLIC, - mb -> mb.withCode(cob -> { - cob.aload(0).invokespecial(ConstantDescs.CD_Object, ConstantDescs.INIT_NAME, ConstantDescs.MTD_void); - int offset2 = 0; - for (var param : mts_params) { - var tk = TypeKind.fromDescriptor(param.descriptorString()); - cob.aload(0).loadLocal(tk, offset2 + 1).putfield(cd, "param_" + offset2, param); - offset2 += tk.slotSize(); - } - cob.return_(); - }) - ); - clb.withMethod("next", MTD_Res, ClassFile.ACC_PUBLIC, mb -> { - mb.withCode(cob -> cob.trying( - tcob -> { - generateStateMachine(mts_params, cd, clb, com, tcob, count); - }, - ctb -> ctb.catchingAll( - blc -> blc.aload(0).loadConstant(-1).putfield(cd, "___state___", TypeKind.INT.upperBound()).athrow() - ) - ).aconst_null().areturn() - ); - }); - } - ); - customClazzDefMap.put(cd.displayName(), bytes); - return cd; - } - - private static class LocalTracker{ - - record LocalStore(String name, ClassDesc cd){} - - HashMap parameter_map = new HashMap<>(); - - HashMap localVarTypes = new HashMap<>(); - HashMap localVarDetailedType = new HashMap<>(); - HashMap stackMapFrames = new HashMap<>(); - StackMapFrameInfo currentFrame; - ArrayList localStore = new ArrayList<>(); - - - private LocalTracker(ClassDesc[] mts_params, ClassDesc cd, ClassBuilder clb, CodeModel com, CodeBuilder cob, int count){ - int offset = 0; - for (var param : mts_params) { - parameter_map.put(offset, param); - offset += TypeKind.from(param).slotSize(); - } - - for(var attr : com.findAttributes(Attributes.stackMapTable())){ - var entries = new ArrayList(); - for(var smfi : attr.entries()){ - var locals = new ArrayList<>(smfi.locals()); - for(int i = 0; i < mts_params.length; i ++) locals.removeFirst(); - locals.addFirst(StackMapFrameInfo.ObjectVerificationTypeInfo.of(cd)); - entries.add(StackMapFrameInfo.of(smfi.target(), locals, smfi.stack())); - stackMapFrames.put(smfi.target(), entries.getLast()); - } - } - } - - - //Tries its best to reuse old saved locals field slots, only reuses if types exactly match - public void savingLocals(ClassDesc cd, CodeBuilder cob, Runnable run) { - record Saved(int slot, String name, ClassDesc cd){}; - var saved = new ArrayList(); - var lls = new ArrayList<>(localStore); - foreach((slot, tk, desc) -> { - String name = null; - for(int i = 0; i < lls.size(); i ++) - if(lls.get(i).cd.equals(desc)){ - name = lls.get(i).name; - lls.remove(i); - break; - } - if(name==null){ - name = "local_" + localStore.size(); - localStore.add(new LocalStore(name, desc)); - } - saved.add(new Saved(slot, name, desc)); - cob.aload(0).loadLocal(tk, slot).putfield(cd, name, desc); - }); - run.run(); - - for(var save : saved){ - cob.aload(0).getfield(cd, save.name, save.cd).storeLocal(TypeKind.from(save.cd), save.slot); - } - } - - public void createLocalStoreFields(ClassBuilder clb){ - for (var local : localStore) { - clb.withField(local.name, local.cd, ClassFile.ACC_PRIVATE); - } - } - - public void encounterLabel(Label l){ - var tmp = stackMapFrames.get(l); - if(tmp!=null){ - localVarTypes.clear(); - localVarDetailedType.clear(); - currentFrame=tmp; - } - } - - public ClassDesc paramType(int slot){ - return parameter_map.get(slot); - } - - public void addDetailedLocal(int slot, ClassDesc desc) { - localVarDetailedType.put(slot, desc); - } - - public void addLocal(int slot, TypeKind typeKind) { - var prev = localVarTypes.put(slot, typeKind); - } - - interface LocalConsumer{ - void consume(int slot, TypeKind tk, ClassDesc desc); - } - - void foreach(LocalConsumer consumer){ - var slot = 0; - if(currentFrame!=null) - for(var kind : currentFrame.locals()){ - switch(kind){ - case StackMapFrameInfo.ObjectVerificationTypeInfo o -> { - if(slot!=0) - consumer.consume(slot, o.className().typeKind(), o.classSymbol()); - slot += 1; - } - case StackMapFrameInfo.SimpleVerificationTypeInfo ti -> { - if(kind==TOP) { - slot += 1; - if(localVarTypes.get(slot-1) instanceof TypeKind tk){ - ClassDesc cd = tk.upperBound(); - if(localVarDetailedType.get(slot-1) instanceof ClassDesc cld) - cd = cld; - consumer.consume(slot-1, tk, cd); - } - continue; - } - var type = switch(ti){ - case INTEGER -> TypeKind.INT; - case FLOAT -> TypeKind.FLOAT; - case DOUBLE -> TypeKind.DOUBLE; - case LONG -> TypeKind.LONG; - case NULL -> TypeKind.REFERENCE; - default -> - throw new IllegalStateException(); - }; - consumer.consume(slot, type, type.upperBound()); - slot += 1; - } - case StackMapFrameInfo.UninitializedVerificationTypeInfo _ -> - throw new IllegalStateException(); - } - } - for(var entry : localVarTypes.entrySet()){ - if(entry.getKey()(); - var invalidState = cob.newLabel(); - stateSwitchCases.add(SwitchCase.of(0, cob.newLabel())); - int switchCase = 1; - for (CodeElement coe : com){ - if(coe instanceof InvokeInstruction is && is.opcode().equals(Opcode.INVOKESTATIC) && is.owner().asSymbol().equals(CD_Gen) && (is.name().equalsString("yield"))){ - stateSwitchCases.add(SwitchCase.of(switchCase, cob.newLabel())); - switchCase++; - } - } - cob.aload(0).getfield(cd, "___state___", TypeKind.INT.upperBound()).lookupswitch(invalidState, stateSwitchCases); - var start = cob.startLabel(); - var end = cob.newLabel(); - cob.localVariable(0, "this", cd, start, end); - - var localTracker = new LocalTracker(mts_params, cd,clb,com,cob,count); - - switchCase = 1; - cob.labelBinding(stateSwitchCases.removeFirst().target()); - final boolean[] ignore_next_return = {false}; - final boolean[] ignore_next_pop = {false}; - for (CodeElement coe : com) { - switch(coe){ - case Instruction ins when ins.opcode() == Opcode.POP && ignore_next_pop[0] -> { - ignore_next_pop[0] = false; - continue; - } - case ReturnInstruction _ when ignore_next_return[0] -> { - ignore_next_return[0] = false; - continue; - } - case Instruction _ when ignore_next_return[0] || ignore_next_pop[0] -> - throw new RuntimeException(); - - case Label l -> { - localTracker.encounterLabel(l); - } - - default -> {} - } - switch (coe) { - case InvokeInstruction is when is.opcode().equals(Opcode.INVOKESTATIC) && is.owner().asSymbol().equals(CD_Gen) && (is.name().equalsString("yield") || is.name().equalsString("ret")) -> { - if (MethodTypeDesc.ofDescriptor(is.method().type().stringValue()).parameterArray().length == 0) { - cob.aconst_null(); - } - - if (is.name().equalsString("ret")) { - cob.aload(0).loadConstant(-1).putfield(cd, "___state___", TypeKind.INT.upperBound()) - .new_(CD_Ret) - .dup_x1() - .swap() - .invokespecial(CD_Ret, ConstantDescs.INIT_NAME, MethodTypeDesc.of(ConstantDescs.CD_void, ConstantDescs.CD_Object)) - .areturn(); - ignore_next_return[0] = true; - } else { - int finalSwitchCase = switchCase; - switchCase++; - localTracker.savingLocals(cd, cob, () -> { - cob.aload(0).loadConstant(finalSwitchCase).putfield(cd, "___state___", TypeKind.INT.upperBound()) - .new_(CD_Yield) - .dup_x1() - .swap() - .invokespecial(CD_Yield, ConstantDescs.INIT_NAME, MethodTypeDesc.of(ConstantDescs.CD_void, ConstantDescs.CD_Object)) - .areturn(); - cob.labelBinding(stateSwitchCases.removeFirst().target()); - }); - - ignore_next_pop[0] = true; - } - } - case BranchInstruction b -> cob.with(b); - case LocalVariable lv when lv.slot() < count -> {} - case LocalVariable lv -> { - cob.localVariable(lv.slot() - count + 1, lv.name(), lv.type(), lv.startScope(), lv.endScope()); - //this might be useful to add? -// localTracker.addDetailedLocal(lv.slot() - count + 1, lv.typeSymbol()); - } - - case IncrementInstruction ii when ii.slot() < count -> throw new RuntimeException(); - case IncrementInstruction ii -> cob.iinc(ii.slot() - count + 1, ii.constant()); - - case LoadInstruction li when li.slot() < count -> - cob.aload(0).getfield(cd, "param_" + li.slot(), localTracker.paramType(li.slot())); - case LoadInstruction li -> - cob.loadLocal(li.typeKind(), li.slot() - count + 1); - - case StoreInstruction ls when ls.slot() < count && ls.typeKind().slotSize()==2 -> - cob.aload(0).dup_x2().pop().putfield(cd, "param_" + ls.slot(), localTracker.paramType(ls.slot())); - case StoreInstruction ls when ls.slot() < count -> - cob.aload(0).swap().putfield(cd, "param_" + ls.slot(), localTracker.paramType(ls.slot())); - - case StoreInstruction ls -> { - localTracker.addLocal(ls.slot() - count + 1, ls.typeKind()); - cob.storeLocal(ls.typeKind(), ls.slot() - count + 1); - } - case ConstantInstruction ci -> { - cob.loadConstant(ci.constantValue()); - } - - default -> cob.with(coe); - } - } - cob.labelBinding(invalidState); - cob.new_(ClassDesc.ofDescriptor(IllegalStateException.class.descriptorString())).dup() - .invokespecial(ClassDesc.ofDescriptor(IllegalStateException.class.descriptorString()), ConstantDescs.INIT_NAME, ConstantDescs.MTD_void) - .athrow(); - cob.labelBinding(end); - - localTracker.createLocalStoreFields(clb); - } - - private void rebuildGeneratorMethod(ClassModel clm, MethodModel mem, CodeModel com, CodeBuilder cob) { - generateGeneratorFromGenMethod(clm, mem, com, cob); - } - } - -// public static Gen meow() { -// for (int i = 0; i < 2; i++) { -// System.out.println(i + "asldkjasd"); -// } -// return null; -// } - -// public Gen gen(double l, int v) { -// Gen.yield("Yield"); -// return Gen.ret("Ret"); -// } -// -// public Gen gen2() { -// Gen.yield("Yield"); -// return Gen.ret("Ret"); -// } - - - public static class Mixer implements Gen { - public void test() { - for (int i = 0; i < 10; i++) { - System.out.println(i + "th iteration"); - } - - while (this.next() instanceof Yield(var yield)) { - - } - } - - @Override - public Res next() { - return new Yield<>("12"); + try{ + ((Runnable)loader.loadClass(clazz.getName()).getConstructor().newInstance()).run(); + }catch (Exception e){ + throw new RuntimeException(e); } } } diff --git a/src/generator/runtime/GeneratorBuilder.java b/src/generator/runtime/GeneratorBuilder.java new file mode 100644 index 0000000..c1daee7 --- /dev/null +++ b/src/generator/runtime/GeneratorBuilder.java @@ -0,0 +1,340 @@ +package generator.runtime; + +import generator.Gen; + +import java.lang.classfile.*; +import java.lang.classfile.attribute.StackMapFrameInfo; +import java.lang.classfile.instruction.*; +import java.lang.constant.ClassDesc; +import java.lang.constant.ConstantDescs; +import java.lang.constant.MethodTypeDesc; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; + +import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.TOP; + +public class GeneratorBuilder { + private final static String PARAM_PREFIX = "param_"; + private final static String LOCAL_PREFIX = "local_"; + private final static String STACK_PREFIX = "stack_"; + private final static String STATE_NAME = "state"; + + private final static ClassDesc CD_Gen = ClassDesc.ofDescriptor(Gen.class.descriptorString()); + private final static ClassDesc CD_Res = ClassDesc.ofDescriptor(Gen.Res.class.descriptorString()); + private final static ClassDesc CD_Yield = ClassDesc.ofDescriptor(Gen.Yield.class.descriptorString()); + private final static ClassDesc CD_Ret = ClassDesc.ofDescriptor(Gen.Ret.class.descriptorString()); + private final static MethodTypeDesc MTD_Res = MethodTypeDesc.of(CD_Res); + private final static MethodTypeDesc MTD_Gen_Obj = MethodTypeDesc.of(CD_Gen, ConstantDescs.CD_Object); + + private final String name; + public final ClassDesc CD_this_gen; + private final ClassDesc[] params; + private final MethodTypeDesc MTD_init; + private final int paramSlotOff; + + public interface ParamConsumer{ + void consume(String param, int slot, ClassDesc type); + } + + public GeneratorBuilder(String name, ClassDesc[] params){ + this.name = name; + CD_this_gen = ClassDesc.of(name); + this.params = params; + MTD_init = MethodTypeDesc.of(ConstantDescs.CD_void, params); + paramSlotOff = Arrays.stream(params).mapToInt(p -> TypeKind.from(p).slotSize()).sum(); + } + + public void params(int slot_start, ParamConsumer consumer){ + int offset = 0; + for (var param : params) { + consumer.consume(PARAM_PREFIX+offset, offset+slot_start, param); + offset += TypeKind.from(param).slotSize(); + } + } + + public void buildGeneratorMethodShim(CodeBuilder cob){ + cob.new_(CD_this_gen).dup(); + params(0, (_, slot, type) -> { + cob.loadLocal(TypeKind.from(type), slot); + }); + cob.invokespecial(CD_this_gen, ConstantDescs.INIT_NAME, MTD_init).areturn(); + } + + public byte[] buildGenerator(CodeModel com){ + return ClassFile.of(ClassFile.StackMapsOption.STACK_MAPS_WHEN_REQUIRED).build(CD_this_gen, clb -> { + clb.withInterfaces(List.of(clb.constantPool().classEntry(CD_Gen))); + // parameter fields + params(0, (param, _, type) -> { + clb.withField(param, type, ClassFile.ACC_PRIVATE); + }); + // fms state + clb.withField(STATE_NAME, ConstantDescs.CD_int, ClassFile.ACC_PRIVATE); + + // constructor + clb.withMethod(ConstantDescs.INIT_NAME, MTD_init, ClassFile.ACC_PUBLIC, mb -> mb.withCode(cob -> { + cob.aload(0).invokespecial(ConstantDescs.CD_Object, ConstantDescs.INIT_NAME, ConstantDescs.MTD_void); + params(1, (param, slot, type) -> { + cob.aload(0).loadLocal(TypeKind.from(type), slot).putfield(CD_this_gen, param, type); + }); + cob.return_(); + })); + + // fms method + clb.withMethod("next", MethodTypeDesc.of(CD_Res), ClassFile.ACC_PUBLIC, mb -> mb.withCode(cob -> { + buildGeneratorNext(clb, cob, com); + })); + }); + } + + private void buildGeneratorNext(ClassBuilder clb, CodeBuilder cob, CodeModel com){ + cob.trying( + tcob -> generateStateMachine(clb, tcob, com), + // catch anything set our state to -1 and throw the exception + ctb -> ctb.catchingAll( + blc -> blc.aload(0).loadConstant(-1).putfield(CD_this_gen, STATE_NAME, ConstantDescs.CD_int).athrow() + ) + ).aconst_null().areturn(); + } + + private void generateStateMachine(ClassBuilder clb, CodeBuilder cob, CodeModel com){ + var stateSwitchCases = new ArrayList(); + var invalidState = cob.newLabel(); + stateSwitchCases.add(SwitchCase.of(0, cob.newLabel())); + int switchCase = 1; + for (CodeElement coe : com) { + if (coe instanceof InvokeInstruction is && is.opcode().equals(Opcode.INVOKESTATIC) && is.owner().asSymbol().equals(CD_Gen) && (is.name().equalsString("yield"))) { + stateSwitchCases.add(SwitchCase.of(switchCase, cob.newLabel())); + switchCase++; + } + } + cob.aload(0).getfield(CD_this_gen, STATE_NAME, TypeKind.INT.upperBound()).lookupswitch(invalidState, stateSwitchCases); + var start = cob.startLabel(); + var end = cob.newLabel(); + cob.localVariable(0, "this", CD_this_gen, start, end); + + var localTracker = new LocalTracker(com); + + switchCase = 1; + cob.labelBinding(stateSwitchCases.removeFirst().target()); + final boolean[] ignore_next_return = {false}; + final boolean[] ignore_next_pop = {false}; + for (CodeElement coe : com) { + switch (coe) { + case Instruction ins when ins.opcode() == Opcode.POP && ignore_next_pop[0] -> { + ignore_next_pop[0] = false; + continue; + } + case ReturnInstruction _ when ignore_next_return[0] -> { + ignore_next_return[0] = false; + continue; + } + case Instruction _ when ignore_next_return[0] || ignore_next_pop[0] -> throw new RuntimeException(); + + case Label l -> localTracker.encounterLabel(l); + + default -> {} + } + if(coe instanceof InvokeInstruction is + && is.opcode().equals(Opcode.INVOKESTATIC) + && is.owner().asSymbol().equals(CD_Gen) + && (is.name().equalsString("yield") || is.name().equalsString("ret"))){ + if (MethodTypeDesc.ofDescriptor(is.method().type().stringValue()).parameterArray().length == 0) { + cob.aconst_null(); + } + + if (is.name().equalsString("ret")) { + cob.aload(0).loadConstant(-1).putfield(CD_this_gen, STATE_NAME, TypeKind.INT.upperBound()) + .new_(CD_Ret) + .dup_x1() + .swap() + .invokespecial(CD_Ret, ConstantDescs.INIT_NAME, MethodTypeDesc.of(ConstantDescs.CD_void, ConstantDescs.CD_Object)) + .areturn(); + ignore_next_return[0] = true; + } else { + int finalSwitchCase = switchCase; + switchCase++; + localTracker.savingLocals(CD_this_gen, cob, () -> { + cob.aload(0).loadConstant(finalSwitchCase).putfield(CD_this_gen, STATE_NAME, TypeKind.INT.upperBound()) + .new_(CD_Yield) + .dup_x1() + .swap() + .invokespecial(CD_Yield, ConstantDescs.INIT_NAME, MethodTypeDesc.of(ConstantDescs.CD_void, ConstantDescs.CD_Object)) + .areturn(); + cob.labelBinding(stateSwitchCases.removeFirst().target()); + }); + + ignore_next_pop[0] = true; + } + continue; + } + switch (coe) { + // locals which were once function parameters can be ignored + case LocalVariable lv when lv.slot() < paramSlotOff -> {} + case LocalVariable lv -> cob.localVariable(lv.slot() - paramSlotOff + 1, lv.name(), lv.type(), lv.startScope(), lv.endScope()); + + // increment indexes into the stack + case IncrementInstruction ii when ii.slot() < paramSlotOff -> + cob.aload(0).dup().getfield(CD_this_gen, PARAM_PREFIX + ii.slot(), ConstantDescs.CD_int) + .loadConstant(ii.constant()).iadd() + .putfield(CD_this_gen, PARAM_PREFIX + ii.slot(), ConstantDescs.CD_int); + case IncrementInstruction ii -> cob.iinc(ii.slot() - paramSlotOff + 1, ii.constant()); + + // convert local function parameters to class fields and offset regular locals + case LoadInstruction li when li.slot() < paramSlotOff -> + cob.aload(0).getfield(CD_this_gen, PARAM_PREFIX + li.slot(), localTracker.paramType(li.slot())); + case LoadInstruction li -> cob.loadLocal(li.typeKind(), li.slot() - paramSlotOff + 1); + + // convert local function parameters to class fields and offset regular locals + case StoreInstruction ls when ls.slot() < paramSlotOff && ls.typeKind().slotSize() == 2 -> + cob.aload(0).dup_x2().pop().putfield(CD_this_gen, PARAM_PREFIX + ls.slot(), localTracker.paramType(ls.slot())); + case StoreInstruction ls when ls.slot() < paramSlotOff -> + cob.aload(0).swap().putfield(CD_this_gen, PARAM_PREFIX + ls.slot(), localTracker.paramType(ls.slot())); + case StoreInstruction ls -> { + localTracker.trackLocal(ls.slot() - paramSlotOff + 1, ls.typeKind()); + cob.storeLocal(ls.typeKind(), ls.slot() - paramSlotOff + 1); + } + + default -> cob.with(coe); + } + } + cob.labelBinding(invalidState); + cob.new_(ClassDesc.ofDescriptor(IllegalStateException.class.descriptorString())).dup() + .invokespecial(ClassDesc.ofDescriptor(IllegalStateException.class.descriptorString()), ConstantDescs.INIT_NAME, ConstantDescs.MTD_void) + .athrow(); + cob.labelBinding(end); + + localTracker.createLocalStoreFields(clb); + } + + + private class LocalTracker { + + record LocalStore(String name, ClassDesc cd) { + } + + HashMap parameter_map = new HashMap<>(); + + HashMap localVarTypes = new HashMap<>(); + HashMap stackMapFrames = new HashMap<>(); + StackMapFrameInfo currentFrame; + ArrayList localStore = new ArrayList<>(); + + + private LocalTracker(CodeModel com) { + int offset = 0; + for (var param : params) { + parameter_map.put(offset, param); + offset += TypeKind.from(param).slotSize(); + } + + for (var attr : com.findAttributes(Attributes.stackMapTable())) { + var entries = new ArrayList(); + for (var smfi : attr.entries()) { + var locals = new ArrayList<>(smfi.locals()); + for (int i = 0; i < params.length; i++) locals.removeFirst(); + locals.addFirst(StackMapFrameInfo.ObjectVerificationTypeInfo.of(CD_this_gen)); + entries.add(StackMapFrameInfo.of(smfi.target(), locals, smfi.stack())); + stackMapFrames.put(smfi.target(), entries.getLast()); + } + } + } + + + //Tries its best to reuse old saved locals field slots, only reuses if types exactly match + public void savingLocals(ClassDesc cd, CodeBuilder cob, Runnable run) { + record Saved(int slot, String name, ClassDesc cd) { + } + + var saved = new ArrayList(); + var lls = new ArrayList<>(localStore); + currentLocals((slot, tk, desc) -> { + String name = null; + for (int i = 0; i < lls.size(); i++) + if (lls.get(i).cd.equals(desc)) { + name = lls.get(i).name; + lls.remove(i); + break; + } + if (name == null) { + name = LOCAL_PREFIX + localStore.size(); + localStore.add(new LocalStore(name, desc)); + } + saved.add(new Saved(slot, name, desc)); + cob.aload(0).loadLocal(tk, slot).putfield(cd, name, desc); + }); + run.run(); + + for (var save : saved) { + cob.aload(0).getfield(cd, save.name, save.cd).storeLocal(TypeKind.from(save.cd), save.slot); + } + } + + public void createLocalStoreFields(ClassBuilder clb) { + for (var local : localStore) { + clb.withField(local.name, local.cd, ClassFile.ACC_PRIVATE); + } + } + + public void encounterLabel(Label l) { + var tmp = stackMapFrames.get(l); + if (tmp != null) { + localVarTypes.clear(); + currentFrame = tmp; + } + } + + public ClassDesc paramType(int slot) { + return parameter_map.get(slot); + } + + public void trackLocal(int slot, TypeKind typeKind) { + localVarTypes.put(slot, typeKind); + } + + interface LocalConsumer { + void consume(int slot, TypeKind tk, ClassDesc desc); + } + + void currentLocals(LocalConsumer consumer) { + var slot = 0; + if (currentFrame != null) + for (var kind : currentFrame.locals()) { + switch (kind) { + case StackMapFrameInfo.ObjectVerificationTypeInfo o -> { + if (slot != 0) + consumer.consume(slot, o.className().typeKind(), o.classSymbol()); + slot += 1; + } + case StackMapFrameInfo.SimpleVerificationTypeInfo ti -> { + if (kind == TOP) { + slot += 1; + if (localVarTypes.get(slot - 1) instanceof TypeKind tk) { + ClassDesc cd = tk.upperBound(); + consumer.consume(slot - 1, tk, cd); + } + continue; + } + var type = switch (ti) { + case INTEGER -> TypeKind.INT; + case FLOAT -> TypeKind.FLOAT; + case DOUBLE -> TypeKind.DOUBLE; + case LONG -> TypeKind.LONG; + case NULL -> TypeKind.REFERENCE; + default -> throw new IllegalStateException(); + }; + consumer.consume(slot, type, type.upperBound()); + slot += 1; + } + case StackMapFrameInfo.UninitializedVerificationTypeInfo _ -> throw new IllegalStateException(); + } + } + for (var entry : localVarTypes.entrySet()) { + if (entry.getKey() < slot) continue; + ClassDesc cd = entry.getValue().upperBound(); + consumer.consume(entry.getKey(), TypeKind.from(cd), cd); + } + } + } +} diff --git a/src/generator/runtime/GeneratorClassLoader.java b/src/generator/runtime/GeneratorClassLoader.java new file mode 100644 index 0000000..79b4d04 --- /dev/null +++ b/src/generator/runtime/GeneratorClassLoader.java @@ -0,0 +1,96 @@ +package generator.runtime; + +import generator.Fun; +import generator.Gen; + +import java.io.IOException; +import java.io.PrintStream; +import java.lang.classfile.*; +import java.lang.classfile.attribute.StackMapFrameInfo; +import java.lang.classfile.instruction.*; +import java.lang.constant.ClassDesc; +import java.lang.constant.ConstantDescs; +import java.lang.constant.MethodTypeDesc; +import java.lang.reflect.AccessFlag; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Objects; + +import static java.lang.classfile.attribute.StackMapFrameInfo.SimpleVerificationTypeInfo.TOP; + +public class GeneratorClassLoader extends ClassLoader { + private final HashMap customClazzDefMap = new HashMap<>(); + private final HashMap> customClazzMap = new HashMap<>(); + + public GeneratorClassLoader(ClassLoader parent) { + super(parent); + } + + @Override + public Class loadClass(String name) throws ClassNotFoundException { + if (customClazzDefMap.get(name) instanceof byte[] bytes) + customClazzMap.put(name, defineClass(name, bytes, 0, bytes.length)); + if (customClazzMap.get(name) instanceof Class clazz) + return clazz; + if (name.startsWith("java")) + return super.loadClass(name); + + var p = "/" + name.replace('.', '/') + ".class"; + try (var stream = Fun.class.getResourceAsStream(p)) { + var bytes = Objects.requireNonNull(stream).readAllBytes(); + bytes = searchForGenerators(bytes); + customClazzDefMap.put(name, bytes); + customClazzMap.put(name, defineClass(name, bytes, 0, bytes.length)); + return customClazzMap.get(name); + } catch (IOException e) { + throw new ClassNotFoundException(name, e); + } + } + + public byte[] searchForGenerators(byte[] in) { + var clm = ClassFile.of().parse(in); + var isGen = clm.thisClass().asSymbol().descriptorString().equals(Gen.class.descriptorString()); + return ClassFile.of().build(clm.thisClass().asSymbol(), cb -> { + for (var ce : clm) { + if (ce instanceof MethodModel mem && !isGen) { + var methodRetGen = mem.methodTypeSymbol().returnType().descriptorString().equals(Gen.class.descriptorString()); + if (!methodRetGen) { + cb.with(mem); + } else{ + generatorMethod(cb, mem, clm); + } + } else + cb.with(ce); + } + }); + } + + private void generatorMethod(ClassBuilder cb, MethodModel mem, ClassModel clm) { + cb.withMethod(mem.methodName(), mem.methodType(), mem.flags().flagsMask(), mb -> { + for (var me : mem) { + if (me instanceof CodeModel com) { + var mts = mem.methodTypeSymbol(); + mts = mts.changeReturnType(ConstantDescs.CD_void); + if (!mem.flags().has(AccessFlag.STATIC)) { + mts = mts.insertParameterTypes(0, clm.thisClass().asSymbol()); + } + var gb = new GeneratorBuilder("Gen" + customClazzDefMap.size(), mts.parameterArray()); + mb.withCode(gb::buildGeneratorMethodShim); + addGenerator(gb.CD_this_gen.displayName(), gb.buildGenerator(com)); + } else mb.with(me); + } + }); + } + + protected void addGenerator(String name, byte[] def){ + try { + Files.write(Path.of("out/production/generators/" + name + ".class"), def); + } catch (IOException e) { + throw new RuntimeException(e); + } + customClazzDefMap.put(name, def); + } +}