package net.covers1624.coffeegrinder.source;

import net.covers1624.coffeegrinder.bytecode.AccessFlag;
import net.covers1624.coffeegrinder.bytecode.DebugPrintOptions;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.Invoke.InvokeKind;
import net.covers1624.coffeegrinder.bytecode.insns.LocalVariable.VariableKind;
import net.covers1624.coffeegrinder.type.*;
import net.covers1624.coffeegrinder.util.None;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

import static java.util.Objects.requireNonNull;
import static net.covers1624.coffeegrinder.source.LineBuffer.of;
import static net.covers1624.coffeegrinder.source.LineBuffer.paren;
import static net.covers1624.coffeegrinder.type.TypeSystem.isObject;

/**
 * Debug printer that renders the decompiler's instruction tree as indented text.
 * <p>
 * Each instruction is printed as an upper-snake-case node name (see {@link #name(Instruction)}).
 * The name is optionally prefixed by the instruction's bytecode offset and tag, and followed by
 * node-specific details and its arguments. Which optional details appear is controlled by the
 * supplied {@link DebugPrintOptions}.
 * <p>
 * Created by covers1624 on 28/7/22.
 */
public class AstSourceVisitor extends AbstractSourceVisitor {

    /**
     * Cache of generated node names, keyed by instruction class.
     * <p>
     * This map is static, so it is shared by every visitor instance. It is a
     * {@link ConcurrentHashMap} so concurrent {@code computeIfAbsent} calls cannot corrupt it.
     */
    private static final Map<Class<?>, String> AST_NAMES = new ConcurrentHashMap<>();

    /** Options controlling which optional details (ranges, tags, implicits, ...) are printed. */
    private final DebugPrintOptions opts;

    /**
     * @param opts The debug print options to honour.
     */
    public AstSourceVisitor(DebugPrintOptions opts) {
        super(null);
        this.opts = opts;
    }

    @Override
    protected boolean showImplicits() {
        return opts.showImplicits();
    }

    /**
     * Wraps the output of {@code body} in braces. The body is evaluated and indented one level
     * deeper than the current indent.
     *
     * @param body Supplier of the lines to place between the braces.
     * @return The braced lines.
     */
    private LineBuffer braced(Supplier<LineBuffer> body) {
        LineBuffer buffer = of("{");
        pushIndent();
        try {
            buffer = buffer.join(indent(body.get()));
        } finally {
            // Always restore the indent level, even when rendering the body throws.
            // Otherwise this visitor's indent state stays shifted for all later output.
            popIndent();
        }
        return buffer.add("}");
    }

    /**
     * Builds the header for an instruction. The bytecode range is included when enabled.
     *
     * @param insn The instruction.
     * @return The header line.
     */
    private LineBuffer header(Instruction insn) {
        return header(insn, true);
    }

    /**
     * Builds the header for an instruction.
     * <p>
     * The header is an optional {@code [offset] } prefix (only when both the {@code printRanges}
     * option and {@code printRange} are set), an optional {@code [tag] } prefix (only when the
     * {@code printTags} option is set and the instruction has a tag), then the node name.
     *
     * @param insn       The instruction.
     * @param printRange Whether this node kind should ever show its bytecode range.
     * @return The header line.
     */
    private LineBuffer header(Instruction insn, boolean printRange) {
        return of()
                .append(opts.printRanges() && printRange ? range(insn).append(" ") : of())
                .append(opts.printTags() && insn.getTag() != null ? "[" + insn.getTag().describe() + "] " : "")
                .append(name(insn));
    }

    /**
     * Returns the display name of an instruction.
     * <p>
     * A few instruction types have fixed short names. All others derive their name from the class
     * simple name via {@link #computeName(Class)}, cached in {@link #AST_NAMES}.
     *
     * @param insn The instruction.
     * @return The display name.
     */
    private String name(Instruction insn) {
        return switch (insn) {
            case LocalReference ignored -> "LOCAL";
            case ArrayElementReference ignored -> "ELEM";
            case FieldReference ignored -> "FIELD";
            case TryCatch.TryCatchHandler ignored -> "CATCH";
            case IfInstruction ignored -> "IF";
            case Cast ignored -> "CHECK_CAST";
            case DoWhileLoop ignored -> "DO_WHILE";
            default -> AST_NAMES.computeIfAbsent(insn.getClass(), AstSourceVisitor::computeName);
        };
    }

    /**
     * Converts a class's CamelCase simple name to UPPER_SNAKE_CASE.
     * <p>
     * An underscore is inserted before every upper-case character except the first, so
     * {@code NewObject} becomes {@code NEW_OBJECT}.
     *
     * @param clazz The class whose simple name to convert.
     * @return The converted name.
     */
    private static String computeName(Class<?> clazz) {
        StringBuilder sb = new StringBuilder();
        for (char c : clazz.getSimpleName().toCharArray()) {
            if (!sb.isEmpty() && Character.isUpperCase(c)) {
                sb.append("_");
            }
            sb.append(Character.toUpperCase(c));
        }
        return sb.toString();
    }

    /**
     * Renders an instruction's bytecode offset as {@code [offset]}.
     * An offset of {@code -1} is rendered as {@code [?]}.
     *
     * @param insn The instruction.
     * @return The rendered range.
     */
    private LineBuffer range(Instruction insn) {
        int offset = insn.getBytecodeOffset();
        return of(offset == -1 ? "[?]" : "[" + offset + "]");
    }

    /**
     * Returns the imports gathered by the import collector while printing.
     *
     * @param ctx The class context to compute imports for, may be {@code null}.
     * @return The import list.
     */
    public List<String> getImports(@Nullable ClassDecl ctx) {
        return importCollector.getImports(ctx);
    }

    /**
     * Renders {@code Type.} as a member qualifier.
     * <p>
     * Nothing is printed unless {@code mustPrint} is set or the {@code qualifiedMemberReferences}
     * option is enabled.
     *
     * @param type      The declaring type.
     * @param mustPrint Whether the qualifier is required regardless of options.
     * @return The qualifier, or an empty buffer.
     */
    private LineBuffer typeQualifier(ClassType type, boolean mustPrint) {
        if (!mustPrint && !opts.qualifiedMemberReferences()) return of();
        return of(importCollector.collect(type))
                .append(".");
    }

    /**
     * Renders an optional argument, preceded by a space.
     * A {@link Nop} argument is treated as absent and renders as nothing.
     *
     * @param arg The argument instruction.
     * @return The rendered argument, or an empty buffer.
     */
    protected LineBuffer optionalArg(Instruction arg) {
        if (arg instanceof Nop) return of();

        return of(" ").append(argList(arg));
    }

    /**
     * Renders a type parameter declaration such as {@code <T, U>}.
     *
     * @param params The type parameters. When empty, nothing is rendered.
     * @param padEnd Whether to append a trailing space after the closing {@code >}.
     * @return The declaration, or an empty buffer.
     */
    private LineBuffer typeParamDecl(List<TypeParameter> params, boolean padEnd) {
        if (params.isEmpty()) return of();

        return of("<")
                .append(FastStream.of(params).map(importCollector::collectSimpleTypeParam).join(", "))
                .append(">")
                .append(padEnd ? " " : "");
    }

    /**
     * Renders the type arguments of a method.
     * <p>
     * Prints {@code <???>} when the method declares type parameters but no arguments are known.
     *
     * @param method   The method.
     * @param explicit Whether the type arguments are explicit in source.
     * @return The rendered type arguments, or an empty buffer.
     */
    private LineBuffer parameterize(Method method, boolean explicit) {
        if (method instanceof ParameterizedMethod pMethod) {
            List<ReferenceType> typeArguments = pMethod.getTypeArguments();
            if (!typeArguments.isEmpty()) {
                return parameterize(typeArguments, explicit);
            }
        }
        if (method.hasTypeParameters()) {
            return of("<???>");// bound but args not yet inferred
        }

        return of();
    }

    /**
     * Renders the type arguments of a class type.
     * <p>
     * Prints {@code <???>} for a non-parameterized class that declares type parameters.
     *
     * @param clazz    The class type.
     * @param explicit Whether the type arguments are explicit in source.
     * @return The rendered type arguments, or an empty buffer.
     */
    private LineBuffer parameterize(ClassType clazz, boolean explicit) {
        if (clazz instanceof ParameterizedClass pClass) {
            List<ReferenceType> typeArguments = pClass.getTypeArguments();
            if (!typeArguments.isEmpty()) {
                return parameterize(typeArguments, explicit);
            }
        } else if (clazz.hasTypeParameters()) {
            return of("<???>");
        }

        return of();
    }

    /**
     * Renders a list of type arguments.
     * <p>
     * Arguments are printed in full, as {@code <A, B>}, when they are explicit or the
     * {@code showImplicits} option is set. Implicit arguments printed in full get a {@code ~}
     * prefix. Otherwise implicit arguments collapse to a diamond, {@code <>}.
     *
     * @param typeArguments The type arguments. When empty, nothing is rendered.
     * @param explicit      Whether the type arguments are explicit in source.
     * @return The rendered type arguments, or an empty buffer.
     */
    private LineBuffer parameterize(List<ReferenceType> typeArguments, boolean explicit) {
        LineBuffer buffer = of();
        if (typeArguments.isEmpty()) return buffer;

        if (opts.showImplicits() || explicit) {
            if (!explicit) {
                // Implicit arguments
                buffer = buffer.append("~");
            }
            buffer = buffer.append("<");
            for (int i = 0; i < typeArguments.size(); i++) {
                ReferenceType arg = typeArguments.get(i);
                if (i > 0) {
                    buffer = buffer.append(", ");
                }
                buffer = buffer.append(importCollector.collect(arg));
            }
            buffer = buffer.append(">");
        } else {
            // Implicit arguments
            buffer = buffer.append("<>");
        }
        return buffer;
    }

    @Override
    public LineBuffer visitDefault(Instruction insn, None ctx) {
        return printDefault(insn);
    }

    /**
     * Default rendering with no extra qualifier.
     *
     * @param insn The instruction.
     * @return The rendered instruction.
     */
    private LineBuffer printDefault(Instruction insn) {
        return printDefault(insn, null);
    }

    /**
     * Default rendering: the header, an optional {@code .extra} suffix, then the instruction's
     * children as an argument list (only when it has children).
     *
     * @param insn  The instruction.
     * @param extra An optional qualifier, for example an operator or kind. May be {@code null}.
     * @return The rendered instruction.
     */
    private LineBuffer printDefault(Instruction insn, @Nullable Object extra) {
        boolean hasChildren = insn.getFirstChildOrNull() != null;

        return header(insn)
                .append(extra != null ? "." + extra : "")
                .append(hasChildren ? of(" ").append(argList(insn.getChildren())) : of());
    }

    @Override
    public LineBuffer visitNop(Nop nop, None ctx) {
        return header(nop, false);
    }

    @Override
    public LineBuffer visitDeadCode(DeadCode deadCode, None ctx) {
        return header(deadCode)
                .append(" ")
                .append(argList(deadCode.getCode()));
    }

    @Override
    public LineBuffer visitArrayElementReference(ArrayElementReference elemRef, None ctx) {
        return printDefault(elemRef);
    }

    @Override
    public LineBuffer visitBinary(Binary binary, None ctx) {
        return printDefault(binary, binary.getOp());
    }

    /**
     * Renders a block: its header, name, incoming edge count (only when the block lives directly
     * inside a {@link BlockContainer}), and its instructions in braces.
     */
    @Override
    public LineBuffer visitBlock(Block block, None ctx) {
        return header(block)
                .append(" ")
                .append(block.getName())
                .append(block.getParentOrNull() instanceof BlockContainer ? " (incoming: " + block.getIncomingEdgeCount() + ")" : "")
                .append(" ")
                .append(braced(() -> visitBlockBody(block)));
    }

    /**
     * Renders every instruction of a block on its own line(s).
     * Each instruction is preceded by a {@code # LN:} comment when line numbers are enabled.
     *
     * @param block The block.
     * @return The rendered body.
     */
    private LineBuffer visitBlockBody(Block block) {
        LineBuffer buffer = of();
        for (Instruction insn : block.instructions) {
            if (opts.printLineNumbers()) {
                buffer = buffer.add("# LN: " + insn.getSourceLine());
            }
            buffer = buffer.join(lines(insn));
        }
        return buffer;
    }

    @Override
    public LineBuffer visitBlockContainer(BlockContainer container, None ctx) {
        return header(container)
                .append(" ")
                .append(braced(() -> visitContainerBlocks(container)));
    }

    /**
     * Renders all blocks of a container, separated by blank lines.
     *
     * @param container The container.
     * @return The rendered blocks.
     */
    private LineBuffer visitContainerBlocks(BlockContainer container) {
        LineBuffer buffer = of();
        boolean first = true;
        for (Block block : container.blocks) {
            if (!first) {
                buffer = buffer.add("");
            }
            first = false;
            buffer = buffer.join(lines(block));
        }
        return buffer;
    }

    @Override
    public LineBuffer visitBranch(Branch branch, None ctx) {
        return printDefault(branch)
                .append(" ")
                .append(branch.getTargetBlock().getName());
    }

    @Override
    public LineBuffer visitCheckCast(Cast cast, None ctx) {
        return header(cast)
                .append(" " + importCollector.collect(cast.getType()))
                .append(argList(cast.getArgument()));
    }

    /**
     * Renders a class declaration.
     * <p>
     * The class header (flags, name, type parameters, supertypes) is omitted when the declaration
     * belongs to a {@link New}, i.e. an anonymous class.
     */
    @Override
    public LineBuffer visitClassDecl(ClassDecl classDecl, None ctx) {
        return header(classDecl, false)
                .append(" ")
                .append(!(classDecl.getParentOrNull() instanceof New) ? printClassHeader(classDecl) : of())
                .append(braced(() -> visitClassMembers(classDecl)));
    }

    /**
     * Renders access flags, name, type parameters, superclass and interfaces of a class declaration.
     *
     * @param classDecl The class declaration.
     * @return The rendered header, ending with a space.
     */
    private LineBuffer printClassHeader(ClassDecl classDecl) {
        ClassType clazz = classDecl.getClazz();
        return of(AccessFlag.stringRep(clazz.getAccessFlags()))
                .append(clazz.getName())
                .append(typeParamDecl(clazz.getTypeParameters(), false))
                .append(visitSuperClass(clazz))
                .append(visitInterfaces(clazz))
                .append(" ");
    }

    /**
     * Renders {@code  extends Super}. Nothing is rendered when the superclass is {@code Object}.
     *
     * @param type The class type.
     * @return The extends clause, or an empty buffer.
     */
    private LineBuffer visitSuperClass(ClassType type) {
        if (isObject(type.getSuperClass())) return of();

        return of(" extends ")
                .append(importCollector.collect(type.getSuperClass()));
    }

    /**
     * Renders the interface list. The keyword is {@code extends} for interfaces and
     * {@code implements} for classes.
     *
     * @param type The class type.
     * @return The interface clause, or an empty buffer when there are no interfaces.
     */
    private LineBuffer visitInterfaces(ClassType type) {
        List<ClassType> iFaces = type.getInterfaces();
        if (iFaces.isEmpty()) return of();

        LineBuffer buffer = of();
        if (type.isInterface()) {
            buffer = buffer.append(" extends ");
        } else {
            buffer = buffer.append(" implements ");
        }

        return buffer.append(FastStream.of(iFaces).map(importCollector::collect).join(", "));
    }

    /**
     * Renders every member of a class declaration, each preceded by a blank line.
     *
     * @param decl The class declaration.
     * @return The rendered members.
     */
    private LineBuffer visitClassMembers(ClassDecl decl) {
        LineBuffer buffer = of();
        for (Instruction member : decl.members) {
            buffer = buffer.add("")
                    .join(lines(member));
        }
        return buffer;
    }

    @Override
    public LineBuffer visitCompare(Compare compare, None ctx) {
        return printDefault(compare, compare.getKind());
    }

    @Override
    public LineBuffer visitComparison(Comparison comparison, None ctx) {
        return printDefault(comparison, comparison.getKind());
    }

    @Override
    public LineBuffer visitCompoundAssignment(CompoundAssignment comp, None ctx) {
        return printDefault(comp, comp.getOp());
    }

    /** Renders a continue, identified by the name of its loop body's entry block. */
    @Override
    public LineBuffer visitContinue(Continue cont, None ctx) {
        return printDefault(cont)
                .append(" ")
                .append(cont.getLoop().getBody().getEntryPoint().getName());
    }

    @Override
    public LineBuffer visitDoWhileLoop(DoWhileLoop doWhileLoop, None ctx) {
        return header(doWhileLoop)
                .append(" ")
                .append(lines(doWhileLoop.getBody()))
                .append(" ")
                .append(argList(doWhileLoop.getCondition()));
    }

    @Override
    public LineBuffer visitFieldDecl(FieldDecl fieldDecl, None ctx) {
        Field field = fieldDecl.getField();

        // NOTE(review): this uses AccessFlag.toString while class/method headers use
        // AccessFlag.stringRep; confirm the difference is intentional.
        return header(fieldDecl, false)
                .append(" ")
                .append(AccessFlag.toString(field.getAccessFlags()))
                .append(importCollector.collect(field.getType()))
                .append(" ")
                .append(field.getName())
                .append(optionalArg(fieldDecl.getValue()));
    }

    /**
     * Renders a field reference: the optional target, then the field name.
     * The name is qualified with its declaring class when static, or when the options ask for it.
     */
    @Override
    public LineBuffer visitFieldReference(FieldReference fieldRef, None ctx) {
        Field field = fieldRef.getField();

        return header(fieldRef, false)
                .append(optionalArg(fieldRef.getTarget()))
                .append(" ")
                .append(typeQualifier(field.getDeclaringClass(), field.isStatic()))
                .append(field.getName());
    }

    @Override
    public LineBuffer visitForEachLoop(ForEachLoop forEachLoop, None ctx) {
        return header(forEachLoop)
                .append(" ")
                .append(argList(forEachLoop.getVariable(), forEachLoop.getIterator()))
                .append(" ")
                .append(lines(forEachLoop.getBody()));
    }

    @Override
    public LineBuffer visitForLoop(ForLoop forLoop, None ctx) {
        return header(forLoop)
                .append(" ")
                .append(argList(forLoop.getInitializer(), forLoop.getCondition(), forLoop.getIncrement()))
                .append(" ")
                .append(lines(forLoop.getBody()));
    }

    /** Renders an if. The {@code ELSE} branch is only printed when the false branch is not a {@link Nop}. */
    @Override
    public LineBuffer visitIfInstruction(IfInstruction ifInsn, None ctx) {
        LineBuffer buffer = header(ifInsn)
                .append(" ")
                .append(argList(ifInsn.getCondition()))
                .append(" ")
                .append(lines(ifInsn.getTrueInsn()));

        if (!(ifInsn.getFalseInsn() instanceof Nop)) {
            buffer = buffer.add("ELSE ")
                    .append(lines(ifInsn.getFalseInsn()));
        }
        return buffer;
    }

    /**
     * Renders a local variable declaration line.
     * <p>
     * The line shows: kind, unique name and type; the slot index (omitted for stack slots);
     * load and store counts; and Synthetic / Implicit flags when applicable.
     */
    @Override
    public LineBuffer visitLocalVariable(LocalVariable lv, None ctx) {

        return of(lv.getKind().name())
                .append(" ")
                .append(lv.getUniqueName())
                .append(" : ")
                .append(importCollector.collect(lv.getType()))
                .append("(")
                .append(lv.getKind() != VariableKind.STACK_SLOT ? "Index=" + lv.getIndex() + ", " : "")
                .append("LoadCount=" + lv.getLoadCount())
                .append(", StoreCount=" + lv.getStoreCount())
                .append(lv.isSynthetic() ? ", Synthetic" : "")
                .append(lv.getKind() == VariableKind.PARAMETER && ((ParameterVariable) lv).isImplicit() ? ", Implicit" : "")
                .append(")");
    }

    @Override
    public LineBuffer visitPostIncrement(PostIncrement postIncrement, None ctx) {
        return printDefault(postIncrement, postIncrement.isPositive() ? "ADD" : "SUB");
    }

    @Override
    public LineBuffer visitInstanceOf(InstanceOf instanceOf, None ctx) {
        return header(instanceOf)
                .append(" " + importCollector.collect(instanceOf.getType()))
                .append(argList(instanceOf.getChildren().filterNot(e -> e instanceof Nop)));
    }

    /**
     * Renders an invocation.
     * <p>
     * The output is: the invoke kind, the optional target, the method name with its type
     * arguments, then the argument list. The declaring class qualifier is forced for
     * {@code SPECIAL} and static invokes.
     */
    @Override
    public LineBuffer visitInvoke(Invoke invoke, None ctx) {
        Method method = invoke.getMethod();

        return header(invoke)
                .append(".")
                .append(invoke.getKind().name())
                .append(optionalArg(invoke.getTarget()))
                .append(" ")
                .append(typeQualifier(method.getDeclaringClass(), invoke.getKind() == InvokeKind.SPECIAL || method.isStatic()))
                .append(parameterize(method, invoke.explicitTypeArgs))
                .append(method.getName())
                .append(argList(invoke.getArguments()));
    }

    /**
     * Renders an invokedynamic.
     * <p>
     * The output is the bootstrap method, then a parenthesised, comma-separated list of the call
     * site name and bootstrap arguments, then the dynamic call arguments.
     */
    @Override
    public LineBuffer visitInvokeDynamic(InvokeDynamic indy, None ctx) {
        Method bsm = indy.bootstrapHandle;
        return header(indy)
                .append(" ")
                .append(importCollector.collect(bsm.getDeclaringClass()))
                .append(".")
                .append(bsm.getName())
                .append(paren(FastStream.concat(FastStream.of(indy.name), FastStream.of(indy.bootstrapArguments))
                        .map(this::debugIndyBSMArg)
                        .fold(of(), (a, b) -> {
                            assert a != null;
                            // Separate entries with ", " once the accumulator has content.
                            if (!a.lines.isEmpty()) a = a.append(", ");
                            return a.append(b);
                        })
                ))
                .append(argList(indy.arguments));
    }

    /**
     * Renders a constant load using its raw value.
     *
     * @param ldc The constant load.
     * @return The rendered constant.
     */
    private LineBuffer ldc(LdcInsn ldc) {
        return ldc(ldc, requireNonNull(ldc.getRawValue()));
    }

    /**
     * Renders a constant load with a pre-formatted value.
     *
     * @param ldc   The constant load.
     * @param value The value to print after the default rendering.
     * @return The rendered constant.
     */
    private LineBuffer ldc(LdcInsn ldc, Object value) {
        return printDefault(ldc).append(" " + value);
    }

    @Override
    public LineBuffer visitLdcBoolean(LdcBoolean ldcBoolean, None ctx) {
        return ldc(ldcBoolean);
    }

    @Override
    public LineBuffer visitLdcChar(LdcChar ldcChar, None ctx) {
        return ldc(ldcChar, "'" + EscapeUtils.escapeChar(ldcChar.getValue()) + "'");
    }

    @Override
    public LineBuffer visitLdcClass(LdcClass ldcClass, None ctx) {
        return ldc(ldcClass, importCollector.collect(ldcClass.getType()));
    }

    /** Renders a numeric constant with a Java literal suffix for double, float and long values. */
    @Override
    public LineBuffer visitLdcNumber(LdcNumber ldcNumber, None ctx) {
        String suffix = switch (ldcNumber.getValue()) {
            case Double v -> "D";
            case Float v -> "F";
            case Long l -> "L";
            default -> "";
        };
        return ldc(ldcNumber, ldcNumber.getValue() + suffix);
    }

    @Override
    public LineBuffer visitLdcString(LdcString ldcString, None ctx) {
        return ldc(ldcString, "\"" + EscapeUtils.escapeChars(ldcString.getValue()) + "\"");
    }

    /** Renders a leave, identified by the name of its target container's entry block. */
    @Override
    public LineBuffer visitLeave(Leave leave, None ctx) {
        return header(leave)
                .append(" ")
                .append(leave.getTargetContainer().getEntryPoint().getName());
    }

    @Override
    public LineBuffer visitReturn(Return ret, None ctx) {
        return header(ret)
                .append(optionalArg(ret.getValue()));
    }

    /** Renders a load. When the {@code hideLoads} option is set, only the loaded reference is printed. */
    @Override
    public LineBuffer visitLoad(Load load, None ctx) {
        if (opts.hideLoads()) return lines(load.getReference());

        return printDefault(load);
    }

    @Override
    public LineBuffer visitLocalReference(LocalReference localRef, None ctx) {
        return printDefault(localRef)
                .append(" ")
                .append(localRef.variable.getUniqueName());
    }

    /**
     * Renders a method declaration.
     * <p>
     * A method with a parent that is not a {@link ClassDecl} is treated as a lambda and gets the
     * compact lambda header. The body lists the parameters, the locals and then the code.
     */
    @Override
    public LineBuffer visitMethodDecl(MethodDecl methodDecl, None ctx) {
        var parent = methodDecl.getParentOrNull();
        boolean isLambda = parent != null && !(parent instanceof ClassDecl);

        return header(methodDecl, false)
                .append(" ")
                .append(!isLambda ? visitMethodHeader(methodDecl) : visitLambdaHeader(methodDecl))
                .append(braced(() -> visitMethodBody(methodDecl)));
    }

    /**
     * Renders access flags, type parameters, return type and name of a method declaration.
     *
     * @param methodDecl The method declaration.
     * @return The rendered header, ending with a space.
     */
    private LineBuffer visitMethodHeader(MethodDecl methodDecl) {
        Method method = methodDecl.getMethod();

        return of(AccessFlag.stringRep(method.getAccessFlags()))
                .append(typeParamDecl(method.getTypeParameters(), true))
                .append(importCollector.collect(methodDecl.getReturnType()))
                .append(" ")
                .append(method.getName())
                .append(" ");
    }

    /**
     * Renders a lambda header as {@code ReturnType : ResultType }.
     *
     * @param lambda The lambda's method declaration.
     * @return The rendered header.
     */
    private LineBuffer visitLambdaHeader(MethodDecl lambda) {
        return of()
                .append(importCollector.collect(lambda.getReturnType()))
                .append(" : ")
                .append(importCollector.collect(lambda.getResultType()))
                .append(" ");
    }

    /**
     * Renders the parameters, the declared locals, and (when present) the body of a method.
     * Implicit parameters are only listed when the {@code showImplicits} option is set.
     *
     * @param methodDecl The method declaration.
     * @return The rendered body.
     */
    private LineBuffer visitMethodBody(MethodDecl methodDecl) {
        LineBuffer buffer = of();
        for (ParameterVariable variable : methodDecl.parameters) {
            if (!variable.isImplicit() || opts.showImplicits()) {
                buffer = buffer.join(lines(variable));
            }
        }
        for (LocalVariable variable : methodDecl.variables) {
            buffer = buffer.join(lines(variable));
        }

        if (methodDecl.hasBody()) {
            buffer = buffer.join(lines(methodDecl.getBody()));
        }
        return buffer;
    }

    /** Renders a method reference as {@code Owner::name}, using {@code new} for constructors. */
    @Override
    public LineBuffer visitMethodReference(MethodReference mRef, None ctx) {
        Method method = mRef.getMethod();

        return header(mRef)
                .append(optionalArg(mRef.getTarget()))
                .append(" ")
                .append(importCollector.collect(method.getDeclaringClass()))
                .append("::")
                .append(method.isConstructor() ? "new" : method.getName());
    }

    /**
     * Renders a constructor call.
     * <p>
     * {@code .INNER_INST} is appended when the call has an outer-instance target. The output then
     * shows the constructor type arguments, the raw class, the class type arguments and the call
     * arguments, followed by the anonymous class body if there is one.
     */
    @Override
    public LineBuffer visitNew(New newInsn, None ctx) {
        Method method = newInsn.getMethod();

        return header(newInsn)
                .append(newInsn.getTarget() != null ? ".INNER_INST" : "")
                .append(" ")
                .append(parameterize(method, newInsn.explicitTypeArgs))
                .append(importCollector.collect(newInsn.getResultType().asRaw()))
                .append(parameterize(newInsn.getResultType(), newInsn.explicitClassTypeArgs))
                .append(argList(newInsn.getArguments()))
                .append(newInsn.hasAnonymousClassDeclaration() ? lines(newInsn.getAnonymousClassDeclaration()).prepend(" ") : of());
    }

    @Override
    public LineBuffer visitNewObject(NewObject newObject, None ctx) {
        return printDefault(newObject)
                .append(" ")
                .append(importCollector.collect(newObject.getType()));
    }

    @Override
    public LineBuffer visitSwitch(Switch switchInsn, None ctx) {
        return header(switchInsn)
                .append(" ")
                .append(argList(switchInsn.getValue()))
                .append(" ")
                .append(lines(switchInsn.getBody()));
    }

    @Override
    public LineBuffer visitSwitchTable(SwitchTable switchTable, None ctx) {
        return header(switchTable)
                .append(optionalArg(switchTable.getValue()))
                .append(" ")
                .append(braced(() -> switchTable.sections
                        .map(this::lines)
                        .fold(of(), (a, b) -> a.join(b))
                ));
    }

    /**
     * Renders one switch section as one {@code case X: } line per value, followed by the section
     * body. A {@link Nop} value renders as {@code default: }.
     */
    @Override
    public LineBuffer visitSwitchSection(SwitchTable.SwitchSection switchSection, None ctx) {
        LineBuffer buffer = of();
        for (Instruction value : switchSection.values) {
            if (value instanceof Nop) {
                buffer = buffer.add("default: ");
            } else {
                buffer = buffer.add("case ")
                        .append(lines(value))
                        .append(": ");
            }
        }
        return buffer.append(lines(switchSection.getBody()));
    }

    @Override
    public LineBuffer visitSynchronized(Synchronized synchInsn, None ctx) {
        return header(synchInsn)
                .append(argList(synchInsn.getVariable()))
                .append(" ")
                .append(lines(synchInsn.getBody()));
    }

    /** Renders a try block followed by each of its catch handlers. */
    @Override
    public LineBuffer visitTryCatch(TryCatch tryCatch, None ctx) {
        LineBuffer buffer = header(tryCatch)
                .append(" ")
                .append(lines(tryCatch.getTryBody()));

        for (TryCatch.TryCatchHandler handler : tryCatch.handlers) {
            buffer = buffer.join(lines(handler));
        }
        return buffer;
    }

    /**
     * Renders a catch handler: the exception type and variable name, an optional
     * "Unprocessed Finally" marker, then the handler body.
     */
    @Override
    public LineBuffer visitTryCatchHandler(TryCatch.TryCatchHandler catchHandler, None ctx) {
        return header(catchHandler, false)
                .append(" (")
                .append(importCollector.collect(catchHandler.getVariable().getType()))
                .append(" ")
                .append(catchHandler.getVariable().variable.getUniqueName())
                .append(catchHandler.isUnprocessedFinally ? ", Unprocessed Finally" : "")
                .append(") ")
                .append(lines(catchHandler.getBody()));
    }

    @Override
    public LineBuffer visitTryFinally(TryFinally tryFinally, None ctx) {
        return header(tryFinally)
                .append(" ")
                .append(lines(tryFinally.getTryBody()))
                .add("FINALLY ")
                .append(lines(tryFinally.getFinallyBody()));
    }

    @Override
    public LineBuffer visitTryWithResources(TryWithResources tryWithResources, None ctx) {
        return header(tryWithResources)
                .append(" ")
                .append(argList(tryWithResources.getResource()))
                .append(" ")
                .append(lines(tryWithResources.getTryBody()));
    }

    @Override
    public LineBuffer visitWhileLoop(WhileLoop whileLoop, None ctx) {
        return header(whileLoop)
                .append(" ")
                .append(argList(whileLoop.getCondition()))
                .append(" ")
                .append(lines(whileLoop.getBody()));
    }
}
