package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.AccessFlag;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformer;
import net.covers1624.coffeegrinder.type.ClassType;
import net.covers1624.coffeegrinder.type.Method;
import net.covers1624.coffeegrinder.type.ReferenceType;
import net.covers1624.coffeegrinder.type.TypeResolver;
import net.covers1624.quack.collection.ColUtils;
import org.objectweb.asm.Type;

import java.util.*;

import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchInvokeDynamic;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchLoadField;
import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchLoadLocal;

/**
 * Responsible for inlining Lambda synthetic methods.
 * <p>
 * Created by covers1624 on 7/8/21.
 */
public class Lambdas implements ClassTransformer {

    /**
     * {@code java/lang/invoke/LambdaMetafactory}, used to recognise the InvokeDynamic call sites
     * javac emits for lambdas and method references.
     */
    private final ClassType lambdaMetaFactory;

    /** Context of the class currently being transformed; assigned at the start of {@link #transform}. */
    @SuppressWarnings ("NotNullFieldNotInitialized")
    private ClassTransformContext ctx;

    /**
     * Members already processed during the current {@link #transform} call. A synthetic lambda method
     * can be reached both directly (as a class member) and recursively (while inlining the lambda that
     * encloses it); this set ensures it is only processed once.
     */
    private final Set<Instruction> visited = new HashSet<>();

    public Lambdas(TypeResolver typeResolver) {
        lambdaMetaFactory = typeResolver.resolveClassDecl("java/lang/invoke/LambdaMetafactory");
    }

    /**
     * Removes {@code $deserializeLambda$} if present, then visits every field and method of the class,
     * turning LambdaMetafactory call sites back into method references or inlined lambda bodies.
     *
     * @param cInsn The class to transform.
     * @param ctx   The transform context, used to record transformation steps.
     */
    @Override
    public void transform(ClassDecl cInsn, ClassTransformContext ctx) {
        this.ctx = ctx;
        visited.clear();

        // javac emits this method only to support deserialization of serializable lambdas.
        // It has no source-level counterpart and is only used at runtime, so it can be removed.
        MethodDecl deserializeLambdaDef = cInsn.findMethod("$deserializeLambda$", Type.getMethodType("(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;"));
        if (deserializeLambdaDef != null) {
            ctx.pushStep("Remove $deserializeLambda$ function.");
            deserializeLambdaDef.remove();
            ctx.popStep();
        }

        // Build lookup map of synthetic methods. (for lambdas javac always produces 'lambda$<outside_method_name>$<class_global_lambda_index>')
        Map<Method, MethodDecl> syntheticMethods = cInsn.getMethodMembers()
                .filter(e -> e.getMethod().isSynthetic())
                .toImmutableMap(MethodDecl::getMethod, e -> e);

        ctx.pushStep("Inline lambdas");
        // Field initializers may contain lambdas too.
        for (FieldDecl field : cInsn.getFieldMembers()) {
            visitMember(field, syntheticMethods);
        }
        // Visit all functions and attempt to inline lambdas.
        // Iterate a snapshot: inlining and method-reference conversion detach synthetic methods from the class.
        for (MethodDecl method : cInsn.getMethodMembers().toLinkedList()) {
            visitMember(method, syntheticMethods);
        }
        ctx.popStep();
    }

    /**
     * Processes every LambdaMetafactory InvokeDynamic found inside {@code toVisit}.
     * <p>
     * Call sites targeting a non-synthetic method become a {@link MethodReference}. Call sites targeting a
     * synthetic method of this class are first checked for javac's method-reference-as-lambda desugaring
     * (see {@link #convertLambdaToMethodReference}); otherwise the synthetic method body is inlined.
     *
     * @param toVisit          The member (or lambda body) to process.
     * @param syntheticMethods Lookup of the class's synthetic methods by {@link Method}.
     */
    private void visitMember(Instruction toVisit, Map<Method, MethodDecl> syntheticMethods) {
        if (!visited.add(toVisit)) return;
        // Lambda -> Method Ref conversion removes synthetic methods from the class; skip those.
        if (!toVisit.isConnected()) return;

        for (InvokeDynamic invokeDynamic : toVisit.descendantsOfType(InvokeDynamic.class)) {
            if (matchInvokeDynamic(invokeDynamic, lambdaMetaFactory) == null) continue;

            // Bootstrap argument 1 of the metafactory is the implementation method.
            Method lambdaMethod = (Method) invokeDynamic.bootstrapArguments[1];
            if (!lambdaMethod.isSynthetic()) {
                // The implementation is a regular method, so this call site is a method reference to it.
                produceMethodReference(invokeDynamic, lambdaMethod);
                continue;
            }

            // Try and find the synthetic method for this lambda.
            MethodDecl lambdaTarget = syntheticMethods.get(lambdaMethod);
            if (lambdaTarget == null) continue;

            // Try and convert the lambda back to a method reference.
            if (convertLambdaToMethodReference(invokeDynamic, lambdaTarget)) continue;

            inlineLambda(invokeDynamic, lambdaMethod, lambdaTarget, syntheticMethods);
        }
    }

    /**
     * Replaces {@code invokeDynamic} with a {@link MethodReference} to {@code target}.
     * <p>
     * The call site is expected to carry at most one argument (asserted). That argument is passed through
     * to the method reference, or a {@link Nop} is used when there is none.
     *
     * @param invokeDynamic The LambdaMetafactory call site to replace.
     * @param target        The referenced method.
     */
    private void produceMethodReference(InvokeDynamic invokeDynamic, Method target) {
        assert invokeDynamic.arguments.size() <= 1;
        ctx.pushStep("Produce method reference");
        invokeDynamic.replaceWith(new MethodReference(
                (ClassType) invokeDynamic.getResultType(),
                target,
                invokeDynamic.arguments.onlyOrDefault(new Nop())
        ).withOffsets(invokeDynamic));
        ctx.popStep();
    }

    /**
     * Moves the synthetic lambda method {@code lambdaTarget} into the position of {@code invokeDynamic}.
     * <p>
     * Lambdas nested inside the body are processed first. The method's result type becomes the call
     * site's functional interface type. Each parameter that receives a captured value is then substituted
     * with a copy of the captured argument expression at every load, and made implicit.
     *
     * @param invokeDynamic    The LambdaMetafactory call site being replaced.
     * @param lambdaMethod     The implementation method referenced by the call site.
     * @param lambdaTarget     The synthetic method declaration implementing the lambda.
     * @param syntheticMethods Lookup of the class's synthetic methods, used for the recursive visit.
     */
    private void inlineLambda(InvokeDynamic invokeDynamic, Method lambdaMethod, MethodDecl lambdaTarget, Map<Method, MethodDecl> syntheticMethods) {
        ctx.pushStep("inline lambda " + lambdaMethod.getName());

        // Recurse! Lambdas nested in this lambda's body are handled before the body is moved.
        visitMember(lambdaTarget, syntheticMethods);

        // Remove the lambda target from the Class and replace the InvokeDynamic with it.
        invokeDynamic.replaceWith(lambdaTarget);
        lambdaTarget.setResultType((ReferenceType) invokeDynamic.getResultType());

        // Build a map of variable replacements.
        Map<ParameterVariable, Instruction> replacements = new HashMap<>();

        // For instance lambda methods the first indy argument is the receiver ('this'), which has no parameter.
        int indyArgOffset = 0;
        if (!lambdaMethod.isStatic()) {
            assert invokeDynamic.arguments.first() instanceof LoadThis;
            indyArgOffset++;
        }

        // Pair each captured argument with the leading lambda parameter that receives it.
        for (int i = 0; i < invokeDynamic.arguments.size() - indyArgOffset; i++) {
            Instruction arg = invokeDynamic.arguments.get(i + indyArgOffset);

            ParameterVariable param = lambdaTarget.parameters.get(i);
            // Captures are expected to be a local load, or a field load on 'this'.
            assert matchLoadLocal(arg) != null ||
                   matchLoadField(arg) != null && ((FieldReference) ((Load) arg).getReference()).getTarget() instanceof LoadThis;

            // If the parameter has any stores, it would not be effectively final.
            assert param.getStoreCount() == 0;

            // Record the substitution; the parameter is made implicit once all loads are replaced.
            replacements.put(param, arg);
        }

        // Run replacements: every load of a capture parameter gets its own copy of the captured expression.
        lambdaTarget.descendantsMatching(LoadStoreMatching::matchLoadLocal)
                .forEach(load -> {
                    @SuppressWarnings ("SuspiciousMethodCalls")
                    Instruction replacement = replacements.get(load.getVariable());
                    if (replacement != null) {
                        load.replaceWith(replacement.copy());
                    }
                });

        replacements.keySet().forEach(ParameterVariable::makeImplicit);

        ctx.popStep();
    }

    /**
     * Detects a synthetic lambda that javac generated in place of a method reference and, if found,
     * replaces the call site with that {@link MethodReference} and removes the synthetic method.
     * <p>
     * The pattern matched is a static synthetic method with a single block. The block either returns
     * the (optionally cast) result of one invoke, or performs one invoke followed by the return. The
     * invoked method must be protected and declared in a different package from the lambda's class.
     *
     * @param invokeDynamic The LambdaMetafactory call site.
     * @param lambdaTarget  The synthetic method it targets.
     * @return {@code true} if the call site was replaced with a method reference.
     */
    private boolean convertLambdaToMethodReference(InvokeDynamic invokeDynamic, MethodDecl lambdaTarget) {
        Method lambdaMethod = lambdaTarget.getMethod();
        // Javac is only known to do this for static lambda methods.
        if (!lambdaMethod.isStatic()) return false;

        // Lambda body must have a single block.
        if (lambdaTarget.getBody().blocks.size() != 1) return false;

        // Either zero or a single parameter into the indy.
        if (invokeDynamic.arguments.size() > 1) return false;
        Block bodyBlock = lambdaTarget.getBody().blocks.first();

        // 2 variants
        Instruction toMatch;
        if (bodyBlock.instructions.size() == 1) {
            // the result is returned
            if (!(bodyBlock.instructions.first() instanceof Return ret)) return false;

            toMatch = ret.getFirstChild();
        } else {
            // The result is not returned.
            toMatch = bodyBlock.instructions.first();
            // The candidate must be immediately followed by the block's last instruction.
            // With a single block, that last instruction is always the return.
            if (toMatch.getNextSiblingOrNull() != bodyBlock.instructions.last()) return false;
        }

        // Match through casts, these appear when the method ref is generic and are safe to yeet.
        if (toMatch instanceof Cast) {
            toMatch = toMatch.getFirstChild();
        }

        if (!(toMatch instanceof Invoke invoke)) return false;
        // Javac only generates these for protected methods declared in a different package, so require both.
        if (!invoke.getMethod().getAccessFlags().get(AccessFlag.PROTECTED)) return false;
        if (invoke.getMethod().getDeclaringClass().getPackage().equals(lambdaTarget.getMethod().getDeclaringClass().getPackage())) return false;

        // For static targets with parameters, require at least one lambda parameter named with the 'x$'
        // prefix, which javac appears to use for the parameters it generates.
        List<ParameterVariable> parameters = lambdaTarget.parameters.toList();
        if (invoke.getMethod().isStatic() && !parameters.isEmpty() && !ColUtils.anyMatch(parameters, e -> e.getName().startsWith("x$"))) {
            return false;
        }

        ctx.pushStep("Convert lambda to method reference");
        invokeDynamic.replaceWith(new MethodReference(
                (ClassType) invokeDynamic.getResultType(),
                invoke.getTargetClassType(),
                invoke.getMethod(),
                invokeDynamic.arguments.onlyOrDefault(new Nop())
        ).withOffsets(invokeDynamic));
        lambdaTarget.remove();
        ctx.popStep();
        return true;
    }
}
