package net.covers1624.coffeegrinder;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import it.unimi.dsi.fastutil.ints.IntArraySet;
import it.unimi.dsi.fastutil.ints.IntSet;
import joptsimple.OptionParser;
import joptsimple.OptionSet;
import joptsimple.OptionSpec;
import joptsimple.util.PathConverter;
import net.covers1624.coffeegrinder.bytecode.ClassProcessor;
import net.covers1624.coffeegrinder.bytecode.insns.ClassDecl;
import net.covers1624.coffeegrinder.debug.Debugger;
import net.covers1624.coffeegrinder.debug.NullStepper;
import net.covers1624.coffeegrinder.source.JavaSourceVisitor;
import net.covers1624.coffeegrinder.source.LineBuffer;
import net.covers1624.coffeegrinder.type.ClassType;
import net.covers1624.coffeegrinder.type.TypeResolver;
import net.covers1624.coffeegrinder.util.jvm.JVMUtils;
import net.covers1624.coffeegrinder.util.resolver.ClassResolver;
import net.covers1624.coffeegrinder.util.resolver.Resolver;
import net.covers1624.jdkutils.JavaInstall;
import net.covers1624.quack.collection.FastStream;
import net.covers1624.quack.io.IOUtils;
import net.covers1624.quack.io.IndentPrintWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;

/**
 * Created by covers1624 on 20/2/21.
 */
public class CoffeeGrinder {

    // Logger used for all console output of the command line entry point.
    private static final Logger LOGGER = LoggerFactory.getLogger(CoffeeGrinder.class);

    // Java versions that input classes may be compiled for without triggering the
    // "not known to be supported" warning emitted by mainI.
    private static final IntSet KNOWN_SUPPORTED = IntArraySet.of(8, 9, 10, 11, 12, 13, 14, 15, 16, 17);

    /**
     * JVM entry point. Delegates to {@link #mainI(String[])} and terminates the
     * process using the returned value as the exit code.
     *
     * @param args The raw command line arguments.
     * @throws Throwable Anything thrown by {@link #mainI(String[])}.
     */
    public static void main(String[] args) throws Throwable {
        int exitCode = mainI(args);
        System.exit(exitCode);
    }

    /**
     * Parses the command line, then either starts the interactive debugger or
     * decompiles every top-level class of the input to source and/or AST files.
     *
     * @param args The raw command line arguments.
     * @return {@code 0} when the debugger was started or all classes decompiled successfully,
     * {@code 1} when at least one class failed to decompile, and {@code -1} when help was
     * requested, the arguments were invalid, or the debugger is unavailable.
     * @throws Throwable Any error raised during option parsing, class resolution,
     *                   waiting for the worker threads, or writing the report.
     */
    public static int mainI(String[] args) throws Throwable {
        OptionParser parser = new OptionParser();
        OptionSpec<String> nonOptions = parser.nonOptions();

        OptionSpec<Void> helpOpt = parser.acceptsAll(asList("h", "help"), "Prints this help").forHelp();
        OptionSpec<Path> inputOpt = parser.acceptsAll(asList("i", "input"), "Sets the input directory.")
                .withRequiredArg()
                .required()
                .withValuesConvertedBy(new PathConverter());

        OptionSpec<Path> libsOpt = parser.acceptsAll(asList("l", "lib"), "Specifies an optional library for context.")
                .withRequiredArg()
                .withValuesSeparatedBy(File.pathSeparatorChar)
                .withValuesConvertedBy(new PathConverter());
        OptionSpec<Path> jreRefOpt = parser.acceptsAll(asList("j", "jreRef"), "Set to the Java executable of JRE to use as reference when decompiling. If not specified, uses the Decompilers JVM.")
                .withRequiredArg()
                .withValuesConvertedBy(new PathConverter());

        OptionSpec<String> debuggerOpt = parser.accepts("debugger", "Start an interactive http debugger server on the specified address + port.")
                .withRequiredArg()
                .defaultsTo("localhost:8081");

        OptionSpec<Path> outSourceOpt = parser.acceptsAll(asList("o", "output"), "Specifies the location to dump output source files.")
                .availableUnless(debuggerOpt)
                .withRequiredArg()
                .withValuesConvertedBy(new PathConverter());
        OptionSpec<Path> outAstOpt = parser.acceptsAll(asList("a", "output-ast"), "Specifies the location to dump output AST files.")
                .availableUnless(debuggerOpt)
                .withRequiredArg()
                .withValuesConvertedBy(new PathConverter());

        // Clamp the default to 1: on a single core machine 'cores - 1' is 0, which would
        // be rejected by the '--threads' validation below even though the user passed nothing.
        OptionSpec<Integer> threadsOpt = parser.acceptsAll(asList("t", "threads"), "Specifies the number of threads to use when decompiling. "
                                                                                   + "The default value will be the number of cores - 1 (minimum 1). To disable threading set to 1.")
                .availableUnless(debuggerOpt)
                .withRequiredArg()
                .ofType(Integer.class)
                .defaultsTo(Math.max(1, Runtime.getRuntime().availableProcessors() - 1));

        OptionSpec<Path> reportOpt = parser.acceptsAll(asList("r", "report"), "Prints a decompilation report to the specified path on disk.")
                .availableUnless(debuggerOpt)
                .withRequiredArg()
                .withValuesConvertedBy(new PathConverter());

        OptionSet optSet = parser.parse(args);
        if (optSet.has(helpOpt)) {
            parser.printHelpOn(System.err);
            return -1;
        }

        Path inputPath = optSet.valueOf(inputOpt);

        if (Files.notExists(inputPath)) {
            LOGGER.error("Expected '--input' path to exist.");
            parser.printHelpOn(System.err);
            return -1;
        }

        // Build the class lookup: reference JRE first, then any user supplied libraries, then the input itself.
        ClassResolver classResolver = new ClassResolver();
        if (optSet.has(jreRefOpt)) {
            JavaInstall javaInstall = JVMUtils.getJavaInstall(optSet.valueOf(jreRefOpt));
            classResolver.addResolvers(JVMUtils.getTargetJREClasspath(javaInstall));
        } else {
            classResolver.addResolvers(JVMUtils.getRuntimeJREClasspath());
        }

        for (Path path : optSet.valuesOf(libsOpt)) {
            classResolver.addResolver(path);
        }

        classResolver.setTarget(inputPath);

        // Debugger mode replaces batch decompilation entirely.
        if (optSet.has(debuggerOpt)) {
            if (Debugger.DEBUGGER == null) {
                LOGGER.error("Debugger is not available in this environment, please ensure it's included on the runtime classpath.");
                return -1;
            }
            Debugger.DEBUGGER.startDebugger(optSet.valueOf(debuggerOpt), classResolver);
            return 0;
        }

        Path sourceOutput;
        Path astOutput;
        if (optSet.has(outSourceOpt)) {
            sourceOutput = optSet.valueOf(outSourceOpt);
            if (Files.exists(sourceOutput) && !Files.isDirectory(sourceOutput)) {
                LOGGER.error("Expected '--output' to be a folder.");
                parser.printHelpOn(System.err);
                return -1;
            }
        } else {
            sourceOutput = null;
        }

        if (optSet.has(outAstOpt)) {
            astOutput = optSet.valueOf(outAstOpt);
            if (Files.exists(astOutput) && !Files.isDirectory(astOutput)) {
                LOGGER.error("Expected '--output-ast' to be a folder.");
                parser.printHelpOn(System.err);
                return -1;
            }
        } else {
            astOutput = null;
        }
        // At least one kind of output is required, otherwise there is nothing to do.
        if (sourceOutput == null && astOutput == null) {
            LOGGER.error("Expected '--output' or '--output-ast'.");
            return -1;
        }

        int threads = optSet.valueOf(threadsOpt);
        if (threads < 1) {
            // 1 is a valid value (it disables threading), so the bound is inclusive.
            LOGGER.error("Expected '--threads' value to be at least 1.");
            parser.printHelpOn(System.err);
            return -1;
        }

        TypeResolver typeResolver = new TypeResolver(classResolver);
        DecompilerSettings settings = new DecompilerSettings();

        Resolver resolver = classResolver.getTargetResolver();

        ExecutorService executor = Executors.newFixedThreadPool(threads, new ThreadFactoryBuilder()
                .setDaemon(true)
                .setNameFormat("Decompile Thread %d")
                .build()
        );

        List<String> classes = resolver.getAllClasses().toLinkedList();
        LOGGER.info("Decompiling {} classes.", classes.size());

        // Tracks which unsupported versions were already reported, so each is only warned about once.
        IntSet warnedUnsupportedClassVersions = new IntArraySet();

        long decompStart = System.nanoTime();
        // Written to from the worker threads, hence the synchronized wrapper.
        List<Failure> failures = Collections.synchronizedList(new LinkedList<>());
        for (String clazz : classes) {
            // TODO, J9+ Module info files.
            if (clazz.endsWith("module-info")) continue;

            // Only process top-level classes. ClassProcessor pulls these in automatically.
            ClassType clazzType = typeResolver.resolveClassDecl(clazz);
            if (clazzType.getDeclType() != ClassType.DeclType.TOP_LEVEL) continue;

            int classVersion = clazzType.getClassVersion().ordinal() + 1;
            if (!KNOWN_SUPPORTED.contains(classVersion) && warnedUnsupportedClassVersions.add(classVersion)) {
                // TODO print full class name? perhaps behind a --verbose cli switch?
                LOGGER.warn("Found classes compiled for java {}, this version is not known to be supported yet. Things may break.", classVersion);
            }

            executor.submit(() -> {
                try {
                    ClassProcessor processor = new ClassProcessor(typeResolver, clazzType, settings);
                    ClassDecl decl = processor.process(NullStepper.INSTANCE);
                    if (astOutput != null) {
                        Path output = astOutput.resolve(clazz + ".ast");
                        Files.createDirectories(output.getParent());
                        Files.write(output, decl.toString().getBytes(StandardCharsets.UTF_8));
                    }
                    if (sourceOutput != null) {
                        LineBuffer lines = decl.accept(new JavaSourceVisitor(typeResolver));
                        Path output = sourceOutput.resolve(clazz + ".java");
                        Files.createDirectories(output.getParent());
                        Files.write(output, String.join("\n", lines.lines).getBytes(StandardCharsets.UTF_8));
                    }
                } catch (Throwable e) {
                    // A failing class must not abort the run; record it for the final report instead.
                    failures.add(new Failure(clazz, e));
                    LOGGER.error("Failed to decompile: {}", clazz, e);
                }
            });
        }

        // No more tasks will be submitted; block until every queued class has been processed.
        executor.shutdown();
        while (!executor.awaitTermination(5, TimeUnit.MINUTES)) {
            LOGGER.info("Waiting for tasks...");
        }
        long decompDuration = System.nanoTime() - decompStart;
        LOGGER.info("Finished. Took {}", formatDuration(decompDuration));

        String failureReport = buildFailureReport(failures);
        if (!failureReport.isEmpty()) {
            LOGGER.error("\n{}", failureReport);
        }

        if (optSet.has(reportOpt)) {
            try (PrintWriter pw = new PrintWriter(Files.newBufferedWriter(IOUtils.makeParents(optSet.valueOf(reportOpt))), true)) {
                pw.printf("Decompiled %d classes in %s%n", classes.size(), formatDuration(decompDuration));

                if (!failureReport.isEmpty()) {
                    pw.println();
                    pw.println(failureReport);
                }
            }
        }

        return !failures.isEmpty() ? 1 : 0;
    }

    /**
     * Renders a textual summary of all decompilation failures.
     * <p>
     * Failures are grouped by the first 7 lines of their trimmed stack trace and the
     * groups are listed from most to least frequent. For each group at most 3 of the
     * affected class names are listed, followed by a count of the remainder.
     *
     * @param failures The failures to summarise.
     * @return The report, or an empty string if there were no failures.
     */
    private static String buildFailureReport(List<Failure> failures) {
        if (failures.isEmpty()) {
            return "";
        }

        StringWriter buffer = new StringWriter();
        IndentPrintWriter out = new IndentPrintWriter(new PrintWriter(buffer, true));
        out.printf("### %d classes failed! ###%n", failures.size());
        FastStream.of(failures)
                .groupBy(failure -> trimmedStackTrace(requireNonNull(failure.ex()), 7))
                // Negated count, so the largest groups sort first.
                .sorted(Comparator.comparingInt(group -> -group.count()))
                .forEach(group -> {
                    int total = group.count();
                    out.printf("%dx ", total);
                    out.println(group.getKey());
                    group.limit(3)
                            .map(Failure::clazz)
                            .forEach(name -> out.printf("in %s%n", name));
                    if (total > 3) {
                        out.printf("in (%d more)%n", total - 3);
                    }
                    out.println();
                });
        return buffer.toString();
    }

    /**
     * Formats a nanosecond duration as {@code "Xh Ym Zs Wms"}.
     * <p>
     * Units larger than the duration itself are omitted, e.g. a duration of
     * one and a half seconds yields {@code "1s 500ms"}. Anything shorter than
     * one millisecond (including negative values) yields {@code "< 1ms"}.
     *
     * @param elapsedTimeInNs The duration in nanoseconds.
     * @return The formatted duration.
     */
    private static String formatDuration(long elapsedTimeInNs) {
        final long nsPerMilli = 1_000_000L;
        final long nsPerSecond = 1_000L * nsPerMilli;
        final long nsPerMinute = 60L * nsPerSecond;
        final long nsPerHour = 60L * nsPerMinute;

        if (elapsedTimeInNs < nsPerMilli) {
            return "< 1ms";
        }

        StringBuilder sb = new StringBuilder();
        if (elapsedTimeInNs >= nsPerHour) {
            long hours = elapsedTimeInNs / nsPerHour;
            sb.append(hours).append("h ");
        }
        if (elapsedTimeInNs >= nsPerMinute) {
            long minutes = (elapsedTimeInNs % nsPerHour) / nsPerMinute;
            sb.append(minutes).append("m ");
        }
        if (elapsedTimeInNs >= nsPerSecond) {
            long seconds = (elapsedTimeInNs % nsPerMinute) / nsPerSecond;
            sb.append(seconds).append("s ");
        }
        // Milliseconds are always present at this point due to the guard above.
        long millis = (elapsedTimeInNs % nsPerSecond) / nsPerMilli;
        sb.append(millis).append("ms");
        return sb.toString();
    }

    /**
     * Builds a short textual form of an exception: its class name and message,
     * followed by at most {@code nLines} of its top-most stack frames.
     * <p>
     * If the exception carries no stack frames at all, {@code " No stack trace"}
     * is appended to the header line instead.
     *
     * @param ex     The exception to describe.
     * @param nLines The maximum number of stack frames to include.
     * @return The trimmed representation.
     */
    private static String trimmedStackTrace(Throwable ex, int nLines) {
        StringBuilder sb = new StringBuilder(ex.getClass().getName());
        sb.append(" ").append(ex.getMessage());

        StackTraceElement[] frames = ex.getStackTrace();
        if (frames.length == 0) {
            return sb.append(" No stack trace").toString();
        }

        int limit = Math.min(nLines, frames.length);
        int index = 0;
        while (index < limit) {
            sb.append("\n\tat ").append(frames[index]);
            index++;
        }
        return sb.toString();
    }

    /**
     * A single class that failed to decompile.
     *
     * @param clazz The name of the class that failed.
     * @param ex    The error thrown while processing the class.
     */
    private record Failure(String clazz, Throwable ex) { }
}
