package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.InsnOpcode;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.tags.ErrorTag;
import net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformer;
import net.covers1624.coffeegrinder.bytecode.transform.transformers.generics.GenericTransform;
import net.covers1624.coffeegrinder.debug.Step;
import net.covers1624.coffeegrinder.type.*;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import java.util.*;

import static java.util.Objects.requireNonNull;
import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchConstructorInvokeSpecial;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.*;
import static net.covers1624.quack.util.SneakyUtils.unsafeCast;

/**
 * Relocates local and anonymous classes, which start out as members of the outer {@link ClassDecl},
 * to their point of declaration, and strips the synthetic fields / constructor parameters javac
 * generates to pass captured values (enclosing instance, captured locals) into them.
 * <p>
 * Anonymous classes are attached to the {@link New} which constructs them. Local classes are inserted
 * into their enclosing method, before a declarable point covering all of their usages (or at the top
 * of the method, wrapped in {@link DeadCode}, when they have no usages).
 * <p>
 * Created by covers1624 on 2/9/21.
 */
public class LocalClasses implements ClassTransformer {

    @SuppressWarnings ("NotNullFieldNotInitialized")
    private ClassTransformContext ctx;

    @SuppressWarnings ("NotNullFieldNotInitialized")
    private ClassDecl outer;

    @Override
    public void transform(ClassDecl cInsn, ClassTransformContext ctx) {
        this.ctx = ctx;
        outer = cInsn;

        // Snapshot the non-synthetic local/anonymous members before mutating the tree, then handle them
        // in reverse member order. NOTE(review): the reason for the reverse order is not visible here.
        cInsn.getClassMembers()
                .filter(e -> !e.getClazz().isSynthetic() && e.getClazz().getDeclType().isLocalOrAnonymous())
                .toImmutableList()
                .reverse()
                .forEach(this::process);
    }

    /**
     * Dispatches to the anonymous or local class handling based on the class's declaration type.
     */
    private void process(ClassDecl classDecl) {
        if (classDecl.getClazz().getDeclType() == ClassType.DeclType.ANONYMOUS) {
            processAnonClass(classDecl);
        } else {
            processLocalClass(classDecl);
        }
    }

    /**
     * Finds the single {@link New} constructing {@code anonClass} and attaches the class to it.
     */
    private void processAnonClass(ClassDecl anonClass) {
        ClassType localClazz = anonClass.getClazz();

        FastStream<Instruction> enclosingMembersToSearch;
        if (!localClazz.getEnclosingMethod().isPresent()) {
            // No enclosing method: the anon class was declared in an initializer (a static initializer, or a
            // static/instance field initializer), see JVMS 4.7.7.
            // Search field members first, and then the static constructor if it exists
            enclosingMembersToSearch = FastStream.concat(outer.getFieldMembers(), FastStream.ofNullable(getStaticInit(outer)));
        } else {
            Method enclosing = localClazz.getEnclosingMethod().get();
            enclosingMembersToSearch = FastStream.ofNullable(outer.findMethod(enclosing));
            // Java 13 fixed a bug which was incorrectly causing the metadata for enclosing methods to point to
            // lambdas. This code is only required on J13+ where the enclosing method of an anon class, in a lambda,
            // in a field initializer, will point to the ctor/static initializer method, which, the lambda has already been inlined
            // to the field declaration.
            // https://github.com/openjdk/jdk17u/commit/23278fea38d6ad179dd2202b8a251199b7443674
            if (enclosing.isConstructor() || enclosing.getName().equals("<clinit>")) {
                enclosingMembersToSearch = enclosingMembersToSearch.concat(outer.getFieldMembers());
            }
        }

        // only() requires exactly one construction to be found.
        New construction = enclosingMembersToSearch.flatMap(f -> f.<New>descendantsWhere(e -> isNew(e, localClazz))).only();
        moveAnonClass(anonClass, construction);
    }

    /**
     * Moves a named local class into its enclosing method (or the static initializer when there is no
     * enclosing method) and, for classes that can capture state, removes its synthetic locals.
     */
    private void processLocalClass(ClassDecl localClass) {
        ClassType localClazz = localClass.getClazz();
        MethodDecl methodInsn = localClazz.getEnclosingMethod().map(outer::findMethod).orElseGet(() -> requireNonNull(getStaticInit(outer)));

        ctx.pushStep(methodInsn.getMethod().getName() + "::" + localClazz.getName(), Step.StepContextType.CLASS);

        List<Instruction> usages = methodInsn.descendantsToListWhere(e -> isUsageOf(e, localClazz));
        Optional<Instruction> declPointOpt = FastStream.of(usages).fold(VariableDeclarations::unifyUsages);

        if (!declPointOpt.isPresent()) {
            moveUnusedLocalClass(localClass, methodInsn);
            ctx.popStep();
            return; // no constructions to use for determining local scope
        }

        moveLocalClass(localClass, methodInsn, declPointOpt.get());
        if (!localClazz.isInterface() && !localClazz.isEnum()) {// Interface and Enum local classes are static
            // Add this ctor calls from within the class
            List<Instruction> constructions = FastStream.of(usages).filter(e -> isCtorUsage(e, localClazz)).toLinkedList();
            localClass.descendantsWhere(e -> isCtorUsage(e, localClazz)).forEach(constructions::add);
            replaceSyntheticLocals(localClass, unsafeCast(constructions));
        }

        ctx.popStep();
    }

    /**
     * @return {@code true} if {@code insn} is a {@code new} of {@code localClazz}, or an invokespecial of one
     * of its constructors.
     */
    private static boolean isCtorUsage(Instruction insn, ClassType localClazz) {
        return isNew(insn, localClazz) || matchConstructorInvokeSpecial(insn, localClazz) != null;
    }

    /**
     * @return The {@code <clinit>()V} method of {@code outer}, or {@code null} if it has none.
     */
    @Nullable
    private static MethodDecl getStaticInit(ClassDecl outer) {
        return outer.findMethod("<clinit>", Type.getType("()V"));
    }

    /**
     * Places a local class with no usages at the start of the method's entry block, wrapped in {@link DeadCode}.
     */
    private void moveUnusedLocalClass(ClassDecl localClass, MethodDecl methodInsn) {
        ctx.pushStep("Inline (unused)");
        // move to top of method
        methodInsn.getBody().getEntryPoint().instructions.addFirst(new DeadCode(localClass));
        ctx.popStep();
    }

    /**
     * Inserts {@code localClass} before the declarable parent of {@code unifiedUsagePoint}, after first
     * walking that point up until it is directly owned by {@code methodInsn}.
     */
    private void moveLocalClass(ClassDecl localClass, MethodDecl methodInsn, Instruction unifiedUsagePoint) {
        // a local class may be able to be declared inside another local class, or method, but we know the containing method, so keep going up
        while (unifiedUsagePoint.firstAncestorOfType(InsnOpcode.METHOD_DECL) != methodInsn) {
            unifiedUsagePoint = unifiedUsagePoint.getParent();
        }

        // selectDeclarableParent is required if there's only one usage, as unifyUsages won't get called
        Instruction declPoint = VariableDeclarations.selectDeclarableParent(unifiedUsagePoint);

        ctx.pushStep("Inline");
        declPoint.insertBefore(localClass);
        ctx.popStep();
    }

    /**
     * @return {@code true} if {@code insn} references {@code localClazz}: a construction, a class whose
     * direct super types mention it, a cast/instanceof/class literal of it, a parameter typed with it, or the
     * first reference to a local variable typed with it.
     */
    private boolean isUsageOf(Instruction insn, ClassType localClazz) {
        if (isCtorUsage(insn, localClazz)) return true;

        switch (insn.opcode) {
            case CLASS_DECL:
                return ((ClassDecl) insn).getClazz().getDirectSuperTypes().anyMatch(e -> e.mentions(localClazz));
            case CHECK_CAST:
                return ((Cast) insn).getType().mentions(localClazz);
            case INSTANCE_OF:
                return ((InstanceOf) insn).getType().mentions(localClazz);
            case LDC_CLASS:
                return ((LdcClass) insn).getType().mentions(localClazz);
            case LOCAL_VARIABLE:
                return ((LocalVariable) insn).getKind() == LocalVariable.VariableKind.PARAMETER && isUsageOf((LocalVariable) insn, localClazz);
            case LOCAL_REFERENCE:
                // Only the first reference counts, so each variable contributes a single usage point.
                return isFirstUsage((LocalReference) insn) && isUsageOf(((LocalReference) insn).variable, localClazz);
        }
        return false;
    }

    /**
     * @return {@code true} if {@code insn} is the first entry in its variable's reference list.
     */
    private boolean isFirstUsage(LocalReference insn) {
        return insn.variable.getReferences().get(0) == insn;
    }

    /**
     * @return {@code true} if the type of {@code var} (its generic type, when it has a generic signature)
     * mentions {@code decl}.
     */
    private boolean isUsageOf(LocalVariable var, ClassType decl) {
        AType type = var.getType();
        if (var.getGenericSignature() != null) {
            type = GenericTransform.getVariableGenericType(var, ctx.getTypeResolver());
        }
        return type.mentions(decl);
    }

    /**
     * Attaches {@code anonClass} to its {@code construction} and removes its synthetic locals.
     */
    private void moveAnonClass(ClassDecl anonClass, New construction) {
        ctx.pushStep(anonClass.getClazz().getName(), Step.StepContextType.CLASS);

        construction.setAnonymousClassDeclaration(anonClass);

        replaceSyntheticLocals(anonClass, Collections.singletonList(construction));

        ctx.popStep();
    }

    /**
     * @return {@code true} if {@code insn} is a {@code NEW} whose result type's declaration is {@code localClazz}.
     */
    private static boolean isNew(Instruction insn, ClassType localClazz) {
        if (insn.opcode != InsnOpcode.NEW) return false;
        AType resultType = insn.getResultType();
        if (!(resultType instanceof ClassType)) return false;
        return ((ClassType) resultType).getDeclaration() == localClazz;
    }

    /**
     * Removes the synthetic fields and constructor parameters of {@code localClass}, replacing their loads
     * with the value traced from {@code ctorUsages} (or a fallback when no value could be traced).
     *
     * @param localClass The local or anonymous class.
     * @param ctorUsages Every {@code new} / constructor invokespecial targeting {@code localClass}.
     */
    private void replaceSyntheticLocals(ClassDecl localClass, List<AbstractInvoke> ctorUsages) {
        List<Invoke> superCalls = localClass.getMethodMembers()
                .filter(e -> e.getMethod().isConstructor())
                .map(InvokeMatching::getSuperConstructorCall)
                .filter(Objects::nonNull)
                .toLinkedList();

        assert !superCalls.isEmpty();

        Map<Method, List<AbstractInvoke>> ctorUsageLookup = new HashMap<>();
        for (AbstractInvoke invoke : ctorUsages) {
            ctorUsageLookup.computeIfAbsent(invoke.getMethod().getDeclaration(), e -> new LinkedList<>())
                    .add(invoke);
        }

        // TODO We should probably clean this all up, use the synthetic/mandated parameter flags as the primary
        //      search mechanism, with this current tracing as the legacy fallback.

        // Apply Field based tracing of synthetic parameters.
        // Each iteration consumes the synthetic field store directly preceding every super() call.
        while (true) {
            List<@Nullable Store> synStores = FastStream.of(superCalls)
                    .map(e -> matchStoreField(e.getPrevSiblingOrNull()))
                    .toLinkedList();

            Store synStore = synStores.get(0);
            if (synStore == null) {
                break; // we're done here!
            }

            Field field = ((FieldReference) synStore.getReference()).getField();
            assert field.getDeclaringClass() == localClass.getClazz();
            assert field.isSynthetic(); // Must be synthetic.
            assert FastStream.of(synStores).allMatch(e -> matchStoreField(e, field) != null);

            ctx.pushStep(field.getName());

            CapturedVariableProcessor processor = new CapturedVariableProcessor(field, localClass.getClazz(), ctorUsageLookup);
            synStores.forEach(processor::processSynStore);
            if (processor.value == null) {
                assert ctorUsages.isEmpty();
                processor.value = createLocalClassFieldFallback(localClass, field);
            }

            // replace all loads of the synthetic field with the value
            localClass.descendantsMatching(e -> matchLoadField(e, field)).forEach(e -> e.replaceWith(processor.value.copy()));

            // replace all the other loads of removed params with the value
            processor.paramUsagesToReplace.forEach(e -> requireNonNull(matchLoad(e.getParent())).replaceWith(processor.value.copy()));

            // remove the synthetic field
            requireNonNull(localClass.findField(field)).remove();
            ctx.popStep();
        }

        // Don't handle mandated/synthetic parameters for enums, for now.
        if (localClass.getClazz().isEnum()) return;

        // Process Java21+ mandated/synthetic parameters without an associated field.
        // Super-calling constructors must be handled before delegating (this(...)) ones: tracing a
        // super-calling constructor's parameter recurses up through the delegating constructors and expects
        // their pass-through arguments to still be plain parameter loads. Handling a delegating constructor
        // first would replace those loads, and the later trace would fail to match them.
        List<MethodDecl> ctors = localClass.getMethodMembers()
                .filter(e -> e.getMethod().isConstructor())
                .toLinkedList();
        List<MethodDecl> orderedCtors = new ArrayList<>(ctors.size());
        for (MethodDecl ctor : ctors) {
            if (InvokeMatching.getSuperConstructorCall(ctor) != null) orderedCtors.add(ctor);
        }
        for (MethodDecl ctor : ctors) {
            if (InvokeMatching.getSuperConstructorCall(ctor) == null) orderedCtors.add(ctor);
        }
        for (MethodDecl ctor : orderedCtors) {
            processMandatedParameters(localClass, ctor, ctorUsageLookup);
        }
    }

    /**
     * Removes each mandated/synthetic parameter of {@code ctor} which has not already been made implicit,
     * replacing its remaining loads with the traced value (or a fallback when none could be traced).
     */
    private void processMandatedParameters(ClassDecl localClass, MethodDecl ctor, Map<Method, List<AbstractInvoke>> ctorUsageLookup) {
        for (ParameterVariable pVar : ctor.parameters) {
            Parameter par = pVar.parameter;
            // Already-implicit parameters were handled by an earlier trace (e.g. through a this(...) call).
            if (pVar.isImplicit() || !(par.isMandated() || par.isSynthetic())) continue;

            // We have found a J21 synthetic parameter which needs to be nuked.
            CapturedVariableProcessor processor = new CapturedVariableProcessor(null, localClass.getClazz(), ctorUsageLookup);
            processor.traceAndRemoveSynParam(ctor, pVar);
            if (processor.value == null) {
                processor.value = createLocalClassFieldFallbackNoField(pVar);
            }
            processor.paramUsagesToReplace.forEach(e -> requireNonNull(matchLoad(e.getParent())).replaceWith(processor.value.copy()));
        }
    }

    /**
     * Builds a replacement value for a synthetic field when none could be traced from constructor calls.
     * {@code *this$0} becomes a {@link LoadThis}; {@code val$name} becomes a load of the single variable called
     * {@code name} in the enclosing method, if exactly one exists. Otherwise an error-tagged {@link Nop} is returned.
     */
    private Instruction createLocalClassFieldFallback(ClassDecl localClass, Field field) {
        if (field.getName().endsWith("this$0")) return new LoadThis((ClassType) field.getType());

        if (field.getName().startsWith("val$")) {
            String name = field.getName().substring("val$".length());
            MethodDecl method = localClass.firstAncestorOfType(InsnOpcode.METHOD_DECL);
            // Without an enclosing method there is nothing to search; fall through to the error marker.
            if (method != null) {
                FastStream<LocalVariable> localVars = FastStream.concat(method.parameters, method.variables);
                List<LocalVariable> vars = localVars.filter(v -> v.getName().equals(name)).toLinkedList();
                if (vars.size() == 1) {
                    return new Load(new LocalReference(vars.get(0)));
                }
                // TODO: proper scope search if there's > 1 var?
            }
        }

        Nop err = new Nop();
        err.setTag(new ErrorTag("unmatched synthetic", field.getName()));
        return err;
    }

    /**
     * Builds a replacement value for a field-less synthetic parameter when none could be traced:
     * {@code *this$0} becomes a {@link LoadThis}, anything else an error-tagged {@link Nop}.
     */
    private Instruction createLocalClassFieldFallbackNoField(ParameterVariable pVar) {
        if (pVar.getName().endsWith("this$0")) {
            return new LoadThis((ClassType) pVar.getType());
        }

        Nop err = new Nop();
        err.setTag(new ErrorTag("unmatched synthetic", pVar.getName()));
        return err;
    }

    // return a list of values while nop-ing out the parameters till you get to an external call
    // called with a storeField inside a super-calling (non delegating) constructor
    public static class CapturedVariableProcessor {

        private final @Nullable Field field;
        private final ClassType localClass;
        private final Map<Method, List<AbstractInvoke>> ctorUsageLookup;
        private final List<LocalReference> paramUsagesToReplace = new LinkedList<>();

        @Nullable
        private Instruction value;

        public CapturedVariableProcessor(@Nullable Field field, ClassType localClass, Map<Method, List<AbstractInvoke>> ctorUsageLookup) {
            this.field = field;
            this.localClass = localClass;
            this.ctorUsageLookup = ctorUsageLookup;
        }

        // trace the values assigned to the synStore via the usages of the constructor param while removing the parameters and nop-ing out calls
        private void processSynStore(Store synStore) {
            MethodDecl ctor = getContainingFunction(synStore);
            ParameterVariable param = (ParameterVariable) requireNonNull(matchLoadLocal(synStore.getValue())).getVariable();

            synStore.remove();
            traceAndRemoveSynParam(ctor, param);
        }

        /**
         * If {@code insn} is a this(...) call inside one of the local class's own constructors, removes the
         * argument at {@code paramIdx} (which must be a pass-through parameter load) and recursively traces that
         * calling constructor's parameter.
         *
         * @return {@code true} if {@code insn} was handled as a this(...) call.
         */
        private boolean processThisCtorCall(Instruction insn, int paramIdx) {
            Invoke thisCtorCall = matchConstructorInvokeSpecial(insn, localClass);
            if (thisCtorCall == null) return false;

            MethodDecl ctor = getContainingFunction(thisCtorCall);
            if (ctor.getMethod().getDeclaringClass() != localClass) return false; //invoke is a super constructor call

            Load passThrough = requireNonNull(matchLoadLocal(thisCtorCall.getArguments().get(paramIdx)));
            ParameterVariable param = (ParameterVariable) passThrough.getVariable();

            passThrough.replaceWith(new Nop());
            traceAndRemoveSynParam(ctor, param);
            return true;
        }

        /**
         * Makes {@code param} implicit, records its remaining references for later replacement, and nops out the
         * matching argument at every call to {@code ctor}, collecting the passed value.
         */
        private void traceAndRemoveSynParam(MethodDecl ctor, ParameterVariable param) {
            param.makeImplicit();
            assert param.getStoreCount() == 0;
            paramUsagesToReplace.addAll(param.getReferences()); // other usages of the parameter rather than the synthetic field

            List<AbstractInvoke> usages = ctorUsageLookup.get(ctor.getMethod());
            if (usages == null) return;

            for (AbstractInvoke call : usages) {
                if (processThisCtorCall(call, param.pIndex)) { continue; }

                Instruction value = call.getArguments().get(param.pIndex);

                Load localRef = LoadStoreMatching.matchLoadLocal(value);
                boolean isCtorParamReuse = localRef != null && paramUsagesToReplace.remove((LocalReference) localRef.getReference());
                if (!isCtorParamReuse) {
                    addValue(value);
                }

                if (localClass.getDeclType() == ClassType.DeclType.ANONYMOUS && param.pIndex == 0 && value.opcode == InsnOpcode.LOAD_THIS) {
                    // Could also check the type of the loadthis, or the type of the param against the enclosing class type, but probably no need
                    // The index 0 check is sufficient

                    // If the usage is both targeted and an instanced anon, and the target is the same as the enclosing scope instance, javac only generates one param
                    // We need to leave the call alone and let the targeted new logic do its work with the LOAD_THIS
                    if (param.getLoadCount() != 0 && TypeSystem.isConstructedViaTargetInstance(localClass)) {
                        return;
                    }

                    ((New) call).hasEnclosingScopeInstanceParam = true;
                }

                value.replaceWith(new Nop());
            }
        }

        /**
         * Records {@code value} as the traced value; subsequent values must be a reuse of the synthetic field or
         * load the same target (checked by assertions only).
         */
        private void addValue(Instruction value) {
            if (this.value == null) {
                this.value = value;
                assert !isSyntheticReuse(value);
            } else {
                assert isSyntheticReuse(value) || loadTargetEqual(this.value, value);
            }
        }

        /**
         * @return {@code true} if both values are {@code LoadThis} of equal types, or loads of the same local variable.
         */
        private static boolean loadTargetEqual(Instruction value1, Instruction value2) {
            if (value1.opcode != value2.opcode) return false;

            if (value1.opcode == InsnOpcode.LOAD_THIS) {
                return ((LoadThis) value1).getType().equals(((LoadThis) value2).getType());
            }
            if (value1.opcode == InsnOpcode.LOAD) {
                Load load1 = matchLoadLocal(value1);
                Load load2 = matchLoadLocal(value2);
                return load1 != null && load2 != null && load1.getVariable() == load2.getVariable();
            }
            return false;
        }

        /**
         * @return {@code true} if {@code value} loads the synthetic field being processed.
         */
        private boolean isSyntheticReuse(Instruction value) {
            return field != null && LoadStoreMatching.matchLoadField(value, field) != null;
        }

        // Assumes insn is a top-level statement of the method body (statement -> block -> container -> method).
        private static MethodDecl getContainingFunction(Instruction insn) {
            return (MethodDecl) insn.getParent().getParent().getParent();
        }
    }
}
