package net.covers1624.coffeegrinder.bytecode.transform.transformers.statement;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.LocalVariable.VariableKind;
import net.covers1624.coffeegrinder.bytecode.insns.tags.IIncTag;
import net.covers1624.coffeegrinder.bytecode.insns.tags.PotentialConstantLookupTag;
import net.covers1624.coffeegrinder.bytecode.insns.tags.RecordPatternComponentTag;
import net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformer;
import net.covers1624.coffeegrinder.type.*;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.Nullable;

import java.util.LinkedList;
import java.util.List;
import java.util.function.Function;

import static java.util.Objects.requireNonNull;

/**
 * Performs inlining transformations.
 * <p>
 * Inlines synthetic locals that are stored exactly once and loaded exactly once (stack slots and a few
 * javac-generated temporaries) into their single use site, and replaces stores to stack slots that are
 * never loaded with the stored value itself.
 * <p>
 * Created by covers1624 on 20/4/21.
 */
public class Inlining implements StatementTransformer {

    /**
     * Attempts to inline the local variable stored by {@code statement}.
     * <p>
     * Only applies when {@code statement} is a local store to a variable that is stored exactly once and is
     * recognised by {@link #isKnownSyntheticVariable}. A stack slot that is never loaded has its store replaced
     * by the stored value. Otherwise the variable must be loaded exactly once, and that load must be in a position
     * where moving the value cannot reorder side effects:
     * <ul>
     *     <li>stack slots: the load must be inside the next statement, or inside the statement after it when the
     *     next statement is itself a pending iinc / constant-lookup inline (which is then performed first, see
     *     {@link #matchWithPotentialInline});</li>
     *     <li>other synthetic locals: only the specific shapes accepted by {@link #loadIsFirstOpIn}.</li>
     * </ul>
     * After inlining, an integer constant may be wrapped in a cast and a {@code String.valueOf} call feeding a
     * binary {@code +} may be unwrapped.
     *
     * @param statement The statement to inspect.
     * @param ctx       The statement transform context, used for step tracking and type resolution.
     */
    @Override
    public void transform(Instruction statement, StatementTransformContext ctx) {
        Store store = LoadStoreMatching.matchStoreLocal(statement);
        if (store == null) return;

        LocalVariable v = store.getVariable();
        if (v.getStoreCount() != 1) return;
        if (!isKnownSyntheticVariable(v)) return;

        if (v.getLoadCount() == 0 && v.getKind() == VariableKind.STACK_SLOT) {
            // The value is never read back, keep only its evaluation.
            ctx.pushStep("Remove redundant stack store " + v.getUniqueName());
            store.replaceWith(store.getValue());
            ctx.popStep();
            return;
        }

        if (v.getLoadCount() != 1) return;

        Instruction inlinedExpression = store.getValue();
        // The only read reference of v; its parent is the Load instruction we will replace.
        Load inlineTarget = (Load) FastStream.of(v.getReferences()).filter(Reference::isReadFrom).only().getParent();

        // Inline tasks (iinc / constant lookup) queued by matchWithPotentialInline; they run before this inline.
        List<Runnable> extraTransforms = new LinkedList<>();
        if (v.getKind() == VariableKind.STACK_SLOT) {
            // Make sure the stack load is in the next insn. Semantics of stack variables guarantee the rest.
            if (matchWithPotentialInline(store.getNextSiblingOrNull(), extraTransforms, ctx, e -> matchLoadIn(e, inlineTarget)) == null) {
                return;
            }
        } else {
            // Careful, inlining here may change semantics; as such, we only allow it in 'safe' cases.
            if (!loadIsFirstOpIn(store.getNextSiblingOrNull(), inlineTarget)) {
                return;
            }
        }

        ctx.pushStep("Inline variable " + v.getUniqueName());
        extraTransforms.forEach(Runnable::run);
        // Assign the ranges of the load instruction:
        inlinedExpression.setOffsets(inlineTarget);
        // Assign the ranges of the store instruction:
        inlinedExpression.setOffsets(store);

        inlineTarget.replaceWith(inlinedExpression);
        store.remove();
        insertCastForIntegerConstantIfNecessary(inlinedExpression, v.getType(), ctx);
        unwrapStringValueOf(inlinedExpression, ctx);
        ctx.popStep();
    }

    /**
     * Wraps {@code expr} in a cast to {@code type} when {@code expr} is an integer constant
     * ({@link IntegerConstantType}), its new parent is not a {@link Store}, and {@code int} is not assignable
     * to {@code type}.
     * <p>
     * This keeps the inlined constant typed as the variable it replaced.
     *
     * @param expr The freshly inlined expression.
     * @param type The declared type of the variable that was inlined.
     * @param ctx  The transform context.
     */
    private void insertCastForIntegerConstantIfNecessary(Instruction expr, AType type, StatementTransformContext ctx) {
        // NOTE(review): a Store parent is skipped, presumably because the store's target already fixes the type.
        if (expr.getParent() instanceof Store) return;

        if (!(expr.getResultType() instanceof IntegerConstantType)) return;
        if (TypeSystem.isAssignableTo(PrimitiveType.INT, type)) return;

        ctx.pushStep("Add cast");
        expr.replaceWith(new Cast(expr, type));
        ctx.popStep();
    }

    /**
     * Unwraps a static {@code String.valueOf(x)} call that was inlined as an operand of a binary {@code +},
     * leaving {@code x} in its place.
     * <p>
     * If, after unwrapping, neither operand of the {@code +} has a {@code String} result type, an empty string
     * literal is prepended to the left operand so the expression remains a string concatenation.
     * <p>
     * NOTE(review): only the first argument of {@code valueOf} is kept. The {@code String.valueOf(char[], int, int)}
     * overload would lose arguments here, and concatenating a {@code char[]} is not equivalent to
     * {@code String.valueOf(char[])}. TODO confirm those overloads cannot reach this point.
     *
     * @param expr The freshly inlined expression.
     * @param ctx  The transform context, also used to resolve the {@code String} class.
     */
    private void unwrapStringValueOf(Instruction expr, StatementTransformContext ctx) {
        Invoke valueOf = InvokeMatching.matchInvoke(expr, Invoke.InvokeKind.STATIC, "valueOf");
        if (valueOf == null) return;
        if (!TypeSystem.isString(valueOf.getMethod().getDeclaringClass())) return;

        Binary binary = matchBinaryAdd(valueOf.getParent());
        if (binary == null) return;

        ctx.pushStep("Unwrap String.valueOf in string concat.");
        valueOf.replaceWith(valueOf.getArguments().get(0));

        if (!TypeSystem.isString(binary.getLeft().getResultType()) && !TypeSystem.isString(binary.getRight().getResultType())) {
            // Without a String operand the '+' would become numeric addition; force concatenation with "".
            ctx.pushStep("Add empty string constant to coerce concat of non-strings");
            ClassType string = ctx.getTypeResolver().resolveClassDecl(TypeResolver.STRING_TYPE);
            binary.getLeft().replaceWith(new Binary(BinaryOp.ADD, new LdcString(string, ""), binary.getLeft()));
            ctx.popStep();
        }
        ctx.popStep();
    }

    /**
     * Matches a binary {@code +}.
     *
     * @param insn The instruction to test, may be {@code null}.
     * @return {@code insn} as a {@link Binary} if it is an {@link BinaryOp#ADD}, otherwise {@code null}.
     */
    @Contract("null->null")
    private static @Nullable Binary matchBinaryAdd(@Nullable Instruction insn) {
        if (!(insn instanceof Binary binary)) return null;
        if (binary.getOp() != BinaryOp.ADD) return null;

        return binary;
    }

    /**
     * Decides whether a non-stack synthetic local may be inlined into {@code inlineTarget}, where {@code e} is the
     * statement following the store (possibly {@code null}).
     *
     * @param e            The statement immediately after the store, or {@code null} if there is none.
     * @param inlineTarget The single load of the variable.
     * @return {@code true} if the load sits in one of the known safe shapes listed below.
     */
    private boolean loadIsFirstOpIn(@Nullable Instruction e, Load inlineTarget) {
        // because the expression we're inlining may be impure, we need to make sure we don't inline over something else with side effects
        // for now, we can be conservative, and only inline to the first executed slot
        // we try and design the instruction system to have children in semantic execution order, but it's not actually enforced anywhere
        // as such, we'll just match known cases here, rather than recursively matching first child for eg

        // current known cases:
        //  return from try-finally
        //  record pattern getter results
        //  synthetic compound assignment on a constructor result
        //  pattern instanceof
        //  Java25 method ref as lambda target

        // RETURN (inlineTarget) as the next statement.
        if (inlineTarget.getParent() instanceof Return ret && ret == e) return true;

        // STORE (LOCAL ..., inlineTarget)
        if (e instanceof Store store && store.getReference() instanceof LocalReference && store.getValue() == inlineTarget) return true;

        if ((inlineTarget.getParent() instanceof FieldReference fieldRef) && fieldRef.getParent() instanceof CompoundAssignment assignment && assignment == e) {
            return true; // a field ref is not a value on its own, so it can only be the 'target' of a compound assign
        }

        // NOTE(review): unlike the cases above, this one does not check that the instanceof is inside 'e'.
        if (inlineTarget.getParent() instanceof InstanceOf instanceOf && !(instanceOf.getPattern() instanceof Nop)) {
            return true;
        }

        // Java 25 mref as lambda target. We can't actually detect this is a method reference, because we
        // don't have access to the lambda body.
        // NOTE(review): this case also does not check that the invokedynamic is inside 'e'.
        if (inlineTarget.getParent() instanceof InvokeDynamic indy && indy.arguments.onlyOrDefault() == inlineTarget) {
            return true;
        }

        return false;
    }

    /**
     * @return {@code e} if {@code inlineTarget} is a descendant of it, otherwise {@code null}.
     */
    @Nullable
    private Instruction matchLoadIn(Instruction e, Load inlineTarget) {
        return inlineTarget.isDescendantOf(e) ? e : null;
    }

    /**
     * Determines whether {@code var} is a compiler-generated local that is safe to consider for inlining.
     *
     * @param var The variable to test.
     * @return {@code true} if the variable is flagged synthetic, or has one of javac's known synthetic names.
     */
    private static boolean isKnownSyntheticVariable(LocalVariable var) {
        if (var.isSynthetic()) return true;

        // JDK-8372635: https://github.com/openjdk/jdk/pull/28724
        // Javac doesn't preserve its internal flags for synthetic local variables when lowering a lambda,
        // so some of javac's synthetic locals are also recognised by name.
        String name = var.getName();

        // Instanceof patterns which require a temp local.
        if (name.startsWith("patt") && name.endsWith("$temp")) return true;

        // Java 25 mref to lambda target/receiver.
        return name.equals("rec$");
    }

    /**
     * Applies {@code matcher} to {@code insn}. If that fails and {@code insn} can itself be inlined into its next
     * sibling (see {@link #canInlineIfRequired}), queues that inline onto {@code extraTransforms} and applies
     * {@code matcher} to the next sibling instead.
     * <p>
     * The caller is expected to run the queued tasks (as {@link #transform} does) when it acts on the match.
     * A task queued here is left in {@code extraTransforms} even if the retried match fails.
     *
     * @param insn            The instruction to match, may be {@code null}.
     * @param extraTransforms Receives any inline task that must run for the match to be valid.
     * @param ctx             The transform context, captured by queued tasks for step tracking.
     * @param matcher         The matcher; returns {@code null} on no match.
     * @return The matcher's result, or {@code null} if {@code insn} is {@code null} or neither attempt matched.
     */
    @Nullable
    public static <T> T matchWithPotentialInline(@Nullable Instruction insn, List<Runnable> extraTransforms, MethodTransformContext ctx, Function<Instruction, @Nullable T> matcher) {
        if (insn == null) return null;
        T matched = matcher.apply(insn);
        if (matched != null) return matched;
        if (!canInlineIfRequired(insn, extraTransforms, ctx)) return null;
        // canInlineIfRequired guarantees a next sibling exists.
        return matcher.apply(insn.getNextSiblingOrNull());
    }

    /**
     * Checks whether {@code expr} is a pending inline candidate (a tagged iinc or a potential constant lookup)
     * and, if so, queues the task that performs the inline onto {@code extraInliningTasks}.
     *
     * @param expr               The candidate instruction.
     * @param extraInliningTasks Receives the inline task when this returns {@code true}.
     * @param ctx                The transform context, captured by the queued task.
     * @return {@code false} if {@code expr} has no next sibling or is not an inline candidate.
     */
    public static boolean canInlineIfRequired(Instruction expr, List<Runnable> extraInliningTasks, MethodTransformContext ctx) {
        if (expr.getNextSiblingOrNull() == null) return false;

        if (canInlineIInc(expr, extraInliningTasks, ctx)) return true;
        if (canInlinePotentialConstantLookup(expr, extraInliningTasks, ctx)) return true;
        return false;
    }

    /**
     * Handles an instruction tagged with {@link IIncTag} whose {@code potentialInline} load is set: queues a task
     * replacing that load with {@code expr} itself.
     */
    private static boolean canInlineIInc(Instruction expr, List<Runnable> extraInliningTasks, MethodTransformContext ctx) {
        if (!(expr.getTag() instanceof IIncTag iinc)) return false;

        Load potentialInline = iinc.potentialInline;
        if (potentialInline == null) return false;

        extraInliningTasks.add(() -> {
            ctx.pushStep("Inline iinc");
            potentialInline.replaceWith(expr);
            ctx.popStep();
        });
        return true;
    }

    /**
     * Handles an instruction tagged with {@link PotentialConstantLookupTag}: looks for a constant field on the
     * target's class whose value is the tagged ldc's raw value.
     * <p>
     * For static lookups the target is {@code expr} itself. Otherwise {@code expr} is a synthetic null check on
     * the target: {@code Objects.requireNonNull(target)} (a static invoke) or {@code target.getClass()}.
     * <p>
     * If a field is found, queues a task that replaces the ldc with a load of that field and, for the null-check
     * form, removes {@code expr}. If not, the tag is cleared so {@code expr} is not considered again.
     */
    private static boolean canInlinePotentialConstantLookup(Instruction expr, List<Runnable> extraInliningTasks, MethodTransformContext ctx) {
        if (!(expr.getTag() instanceof PotentialConstantLookupTag(boolean isStatic, LdcInsn ldc))) return false;
        Object ldcValue = requireNonNull(ldc.getRawValue());

        Instruction target;
        if (isStatic) {
            target = expr;
        } else {
            Invoke invoke = (Invoke) expr;
            // >= J9 synthetic null checks will be static invoke to `Objects.requireNonNull(expr)`.
            if (invoke.getKind() == Invoke.InvokeKind.STATIC) {
                target = invoke.getArguments().first();
            } else {
                // <= J8 will be `expr.getClass()` invokes.
                target = invoke.getTarget();
            }
        }
        Field constantField = ((ClassType) target.getResultType()).findConstant(ldcValue, isStatic);
        if (constantField == null) {
            expr.setTag(null);
            return false;
        }

        assert ldc.isConnected();

        extraInliningTasks.add(() -> {
            ctx.pushStep("Inline constant lookup");
            ldc.replaceWith(new Load(new FieldReference(constantField, target)));
            // The null-check invoke is no longer needed once its operand feeds the field reference.
            if (expr != target) {
                expr.remove();
            }
            ctx.popStep();
        });
        return true;
    }
}
