package deepimagej;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import deepimagej.tools.DijTensor;
import deepimagej.tools.FileTools;
import ij.IJ;
import ij.gui.GenericDialog;
import java.awt.TextArea;
import java.io.File;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.net.MalformedURLException;
import java.net.URL;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Locale;
import org.tensorflow.SavedModelBundle;

/* loaded from: input_file:deepimagej/DeepImageJ.class */
public class DeepImageJ {
    private String path;
    public String dirname;
    public Parameters params;
    private boolean valid;
    public boolean presentYaml;
    private boolean developer;
    private SavedModelBundle tfModel = null;
    private ZooModel<NDList, NDList> torchModel = null;
    public String ptName = "pytorch_script.pt";
    public String ptName2 = "weights-torchscript.pt";
    public String tfName = "tensorflow_saved_model_bundle.zip";

    public DeepImageJ(String str, String str2, boolean z) {
        this.valid = false;
        this.presentYaml = true;
        this.developer = true;
        String str3 = String.valueOf(str) + File.separator + str2 + File.separator;
        this.path = str3.replace(String.valueOf(File.separator) + File.separator, File.separator);
        this.path = cleanPathStr(str3);
        this.dirname = str2;
        this.developer = z;
        if (!z && !new File(this.path, "model.yaml").isFile() && !new File(this.path, "rdf.yaml").isFile()) {
            this.presentYaml = false;
            this.params = new Parameters(this.valid, this.path, z);
            this.valid = check(str3);
        } else if (z || new File(this.path, "model.yaml").isFile() || new File(this.path, "rdf.yaml").isFile()) {
            try {
                this.params = new Parameters(this.valid, this.path, z);
                this.params.path2Model = this.path;
                this.valid = check(str3);
            } catch (Exception e) {
                IJ.log("Unable to read the rdf.yaml specifications file in following fodler.\nPlease review that the compulsory fields are not missing.\n -" + this.path);
            }
        }
        if (this.valid && z && this.params.framework.equals("tensorflow/pytorch")) {
            askFrameworkGUI();
        }
    }

    public static String cleanPathStr(String str) {
        while (str.indexOf(String.valueOf(File.separator) + File.separator) != -1) {
            str = str.replace(String.valueOf(File.separator) + File.separator, File.separator);
        }
        return str;
    }

    public String getPath() {
        return this.path;
    }

    public String getName() {
        return (this.params.name.equals("n.a.") ? this.dirname : this.params.name).replace("\"", "");
    }

    public ZooModel<NDList, NDList> getTorchModel() {
        return this.torchModel;
    }

    public void setTorchModel(ZooModel<NDList, NDList> zooModel) {
        this.torchModel = zooModel;
    }

    public SavedModelBundle getTfModel() {
        return this.tfModel;
    }

    public void setTfModel(SavedModelBundle savedModelBundle) {
        this.tfModel = savedModelBundle;
    }

    public boolean getValid() {
        return this.valid;
    }

    public static HashMap<String, DeepImageJ> list(String str, boolean z, TextArea textArea, String str2) {
        if (str2 == null) {
            return list(str, z, textArea);
        }
        if (!new File(str2).isDirectory()) {
            String str3 = "The following directory does not contain a model:" + System.lineSeparator() + " - " + str2;
            System.out.println("[DEBUG] " + str3);
            IJ.log(str3);
            return list(str, z, textArea);
        }
        HashMap<String, DeepImageJ> hashMap = new HashMap<>();
        DeepImageJ deepImageJ = new DeepImageJ(String.valueOf(new File(str2).getParent()) + File.separator, new File(str2).getName(), z);
        if (deepImageJ.valid && deepImageJ.params != null) {
            hashMap.put(deepImageJ.dirname, deepImageJ);
        }
        return hashMap;
    }

    public static HashMap<String, DeepImageJ> list(String str, boolean z, TextArea textArea) {
        HashMap<String, DeepImageJ> hashMap = new HashMap<>();
        File[] listFiles = new File(str).listFiles();
        if (listFiles == null) {
            String str2 = "No models found at: " + System.lineSeparator() + " - " + str;
            System.out.println("[DEBUG] " + str2);
            IJ.log(str2);
            return hashMap;
        }
        Date date = new Date();
        for (File file : listFiles) {
            if (file.isDirectory()) {
                String name = file.getName();
                if (textArea != null) {
                    textArea.append(" - " + new SimpleDateFormat("HH:mm:ss").format(date) + " -- Looking for a model at: " + name + "\n");
                }
                DeepImageJ deepImageJ = new DeepImageJ(String.valueOf(str) + File.separator, name, z);
                if (deepImageJ.valid && deepImageJ.params != null) {
                    hashMap.put(deepImageJ.dirname, deepImageJ);
                }
            }
        }
        return hashMap;
    }

    public boolean loadTfModel(boolean z) {
        double nanoTime = System.nanoTime();
        try {
            setTfModel(SavedModelBundle.load(this.path, new String[]{DeepLearningModel.returnStringTag(this.params.tag)}));
            double nanoTime2 = (System.nanoTime() - nanoTime) / 1000000.0d;
            return true;
        } catch (Exception e) {
            IJ.log("Exception in loading model " + this.dirname);
            IJ.log(e.toString());
            IJ.log(e.getMessage());
            return false;
        }
    }

    public boolean loadPtModel(String str, boolean z) {
        if (!z) {
            try {
                Thread.currentThread().setContextClassLoader(IJ.getClassLoader());
            } catch (IOException e) {
                IJ.log("Model not found in the path provided:");
                IJ.log(str);
                e.printStackTrace();
                return false;
            } catch (Exception e2) {
                e2.printStackTrace();
                return false;
            } catch (UnsatisfiedLinkError e3) {
                e3.printStackTrace();
                IJ.log("DeepImageJ could not load the Pytorch model.");
                IJ.log("This is probably because the Visual Studio 2019 Redistributables are missing.");
                IJ.log("In order to be able to load Pytorch models, download Visual Studio 2019 and ");
                IJ.log("its redistributables from teh following links.");
                IJ.log("- https://visualstudio.microsoft.com/es/downloads/");
                IJ.log("- https://support.microsoft.com/en-us/help/2977003/the-latest-supported-visual-c-downloads");
                IJ.log("If the problem persists visit the following link for more info:");
                IJ.log("- http://docs.djl.ai/docs/development/troubleshooting.html#13-unsatisfiedlinkerror-issue");
                return false;
            } catch (ModelNotFoundException e4) {
                e4.printStackTrace();
                return false;
            } catch (MalformedURLException e5) {
                e5.printStackTrace();
                return false;
            } catch (MalformedModelException e6) {
                e6.printStackTrace();
                return false;
            }
        }
        URL url = new File(new File(str).getParent()).toURI().toURL();
        String name = new File(str).getName();
        setTorchModel(ModelZoo.loadModel(Criteria.builder().setTypes(NDList.class, NDList.class).optModelUrls(url.toString()).optModelName(name.substring(0, name.indexOf(".pt"))).optProgress(new ProgressBar()).build()));
        return true;
    }

    public void writeParameters(TextArea textArea) {
        if (this.params == null) {
            textArea.append("No params\n");
            return;
        }
        if (this.params.ptAttachmentsNotIncluded.size() != 0) {
            textArea.append("----------- ATTENTION -----------\n");
            textArea.append("To use the Pytorch format, please make sure that\nthe following plugins/jars are installed:\n");
            Iterator<String> it = this.params.ptAttachmentsNotIncluded.iterator();
            while (it.hasNext()) {
                textArea.append(" - " + it.next() + "\n");
            }
        }
        if (this.params.tfAttachmentsNotIncluded.size() != 0) {
            textArea.append("----------- ATTENTION -----------\n");
            textArea.append("To use the Tensorflow format, please make sure that\nthe following plugins/jars are installed:\n");
            Iterator<String> it2 = this.params.tfAttachmentsNotIncluded.iterator();
            while (it2.hasNext()) {
                textArea.append(" - " + it2.next() + "\n");
            }
        }
        textArea.append("---------- MODEL INFO ----------\n");
        textArea.append("Authors\n");
        for (HashMap<String, String> hashMap : this.params.author) {
            String str = hashMap.get("name") == null ? "n/a" : hashMap.get("name");
            String str2 = hashMap.get("affiliation") == null ? "n/a" : hashMap.get("affiliation");
            String str3 = hashMap.get("orcid") == null ? "n/a" : hashMap.get("orcid");
            textArea.append("  - Name: " + str + "\n");
            textArea.append("    Affiliation: " + str2 + "\n");
            textArea.append("    Orcid: " + str3 + "\n");
        }
        textArea.append("References\n");
        for (HashMap<String, String> hashMap2 : this.params.cite) {
            textArea.append("  - Article: " + hashMap2.get("text") + "\n");
            textArea.append("    Doi: " + hashMap2.get("doi") + "\n");
        }
        textArea.append("Framework: " + this.params.framework + "\n");
        if (this.params.framework.contains("tensorflow")) {
            textArea.append("Tag: " + this.params.tag + "\n");
            textArea.append("Signature: " + this.params.graph + "\n");
        }
        textArea.append("Allow tiling: " + this.params.allowPatching + "\n");
        textArea.append("\n");
        textArea.append("------------ TEST INFO -----------\n");
        textArea.append("Inputs:\n");
        for (DijTensor dijTensor : this.params.inputList) {
            textArea.append("  - Name: " + dijTensor.exampleInput + "\n");
            textArea.append("    Size: " + dijTensor.inputTestSize + "\n");
            textArea.append("      x: " + dijTensor.inputPixelSizeX + "\n");
            textArea.append("      y: " + dijTensor.inputPixelSizeY + "\n");
            textArea.append("      z: " + dijTensor.inputPixelSizeZ + "\n");
        }
        textArea.append("Outputs:\n");
        for (HashMap<String, String> hashMap3 : this.params.savedOutputs) {
            textArea.append("  - Name: " + hashMap3.get("name") + "\n");
            textArea.append("  - Type: " + hashMap3.get("type") + "\n");
            textArea.append("     Size: " + hashMap3.get("size") + "\n");
        }
        textArea.append("Memory peak: " + this.params.memoryPeak + "\n");
        textArea.append("Runtime: " + this.params.runtime + "\n");
        String str4 = "weights-torchscript.pt";
        String str5 = "tensorflow_saved_model_bundle.zip";
        if (this.params.framework.toLowerCase().contains("pytorch")) {
            String findNameFromSourceParam = findNameFromSourceParam(this.params.ptSource, "pytorch");
            if (new File(String.valueOf(getPath()) + File.separator + findNameFromSourceParam).exists()) {
                str4 = findNameFromSourceParam;
            }
        }
        if (this.params.framework.toLowerCase().contains("tensorflow")) {
            String findNameFromSourceParam2 = findNameFromSourceParam(this.params.tfSource, "tensorflow");
            if (new File(String.valueOf(getPath()) + File.separator + findNameFromSourceParam2).exists()) {
                str5 = findNameFromSourceParam2;
            }
        }
        if (this.params.framework.equals("pytorch")) {
            String sb = new StringBuilder().append(new File(String.valueOf(getPath()) + File.separator + str4).length() / 1048576.0d).toString();
            textArea.append("Weights size: " + sb.substring(0, sb.lastIndexOf(".") + 3) + " MB\n");
            return;
        }
        if (this.params.framework.equals("tensorflow") && new File(getPath(), "variables").exists()) {
            String sb2 = new StringBuilder().append(FileTools.getFolderSize(String.valueOf(getPath()) + File.separator + "variables") / 1048576.0d).toString();
            textArea.append("Weights size: " + sb2.substring(0, sb2.lastIndexOf(".") + 3) + " MB\n");
            return;
        }
        if (this.params.framework.equals("tensorflow")) {
            String sb3 = new StringBuilder().append(new File(String.valueOf(getPath()) + File.separator + str5).length() / 1048576.0d).toString();
            textArea.append("Zipped model size: " + sb3.substring(0, sb3.lastIndexOf(".") + 2) + " MB\n");
            return;
        }
        if (this.params.framework.equals("tensorflow/pytorch") && new File(getPath(), "variables").exists()) {
            String sb4 = new StringBuilder().append(new File(String.valueOf(getPath()) + File.separator + str4).length() / 1048576.0d).toString();
            textArea.append("Pytorch weights size: " + sb4.substring(0, sb4.lastIndexOf(".") + 3) + " MB\n");
            String sb5 = new StringBuilder().append(FileTools.getFolderSize(String.valueOf(getPath()) + File.separator + "variables") / 1048576.0d).toString();
            textArea.append("Tensorflow weights size: " + sb5.substring(0, sb5.lastIndexOf(".") + 3) + " MB\n");
            return;
        }
        if (this.params.framework.equals("tensorflow/pytorch")) {
            String sb6 = new StringBuilder().append(new File(String.valueOf(getPath()) + File.separator + str4).length() / 1048576.0d).toString();
            textArea.append("Pytorch weights size: " + sb6.substring(0, sb6.lastIndexOf(".") + 3) + " MB\n");
            String sb7 = new StringBuilder().append(new File(String.valueOf(getPath()) + File.separator + str5).length() / 1048576.0d).toString();
            textArea.append("Zipped Tensorflow model size: " + sb7.substring(0, sb7.lastIndexOf(".") + 3) + " MB\n");
        }
    }

    public static String findNameFromSourceParam(String str, String str2) {
        String str3 = str;
        if (str3 == null && str2.toLowerCase().contentEquals("pytorch")) {
            str3 = "weights-torchscript.pt";
        } else if (str3 == null && str2.toLowerCase().contentEquals("tensorflow")) {
            str3 = "tensorflow_saved_model_bundle.zip";
        } else if (str3.indexOf("/") != -1 && str3.indexOf("/") < 2) {
            str3 = str3.substring(str3.indexOf("/") + 1);
        }
        return str3;
    }

    public boolean check(String str) {
        File file = new File(str);
        if (!file.exists() || !file.isDirectory()) {
            return false;
        }
        boolean z = false;
        boolean z2 = false;
        File file2 = new File(String.valueOf(str) + "saved_model.pb");
        File file3 = new File(String.valueOf(str) + "variables");
        if (file2.exists() && file3.exists()) {
            z = true;
            this.params.framework = "tensorflow";
        }
        if (findPytorchModel(file)) {
            this.params.selectedModelPath = file.getAbsolutePath();
            z2 = true;
            this.params.framework = "pytorch";
        }
        if (z && z2) {
            this.params.framework = "tensorflow/pytorch";
        }
        if (!z && !z2) {
            try {
                z = findZippedBiozooModel(file);
            } catch (IOException e) {
                z = false;
            }
        }
        return z || z2;
    }

    public boolean findZippedBiozooModel(File file) throws IOException {
        String str = this.params.tfSource;
        String name = str == null ? "tensorflow_saved_model_bundle.zip" : (str.indexOf("/") == -1 || str.indexOf("/") >= 2) ? checkURL(str) ? new File(str).getName() : new File(str).getName() : str.substring(str.indexOf("/") + 1);
        boolean z = false;
        for (String str2 : file.list()) {
            if (str2.equals(name) && !this.presentYaml) {
                this.tfName = name;
                return true;
            }
            if (str2.equals(name)) {
                FileTools.createSHA256(String.valueOf(file.getPath()) + File.separator + str2);
                this.tfName = name;
                return true;
            }
            if (str2.equals(name)) {
                IJ.log("Zipped Bioimage Model Zoo model at:");
                IJ.log(String.valueOf(file.getAbsolutePath()) + File.separator + str2);
                IJ.log("does not coincide with the one specified in the rdf.yaml (incorrect sha256).");
                IJ.log("\n");
                this.params.incorrectSha256 = true;
                this.tfName = name;
                return true;
            }
            if (str2.equals("tensorflow_saved_model_bundle.zip")) {
                z = true;
            }
        }
        if (z && !this.presentYaml) {
            return true;
        }
        if (z && FileTools.createSHA256(String.valueOf(file.getPath()) + File.separator + "tensorflow_saved_model_bundle.zip").equals(this.params.tfSha256)) {
            return true;
        }
        if (!z) {
            return false;
        }
        IJ.log("Zipped Bioimage Model Zoo model at:");
        IJ.log(String.valueOf(file.getAbsolutePath()) + File.separator + "tensorflow_saved_model_bundle.zip");
        IJ.log("does not coincide with the one specified in the rdf.yaml (incorrect sha256).");
        IJ.log("\n");
        this.params.incorrectSha256 = true;
        return true;
    }

    public boolean findPytorchModel(File file) {
        String str = this.params.ptSource;
        String name = str == null ? "weights-torchscript.pt" : (str.indexOf("/") == -1 || str.indexOf("/") >= 2) ? checkURL(str) ? new File(str).getName() : new File(str).getName() : str.substring(str.indexOf("/") + 1);
        boolean z = false;
        try {
            for (String str2 : file.list()) {
                if (!this.developer && str2.contains(name) && !this.presentYaml) {
                    this.ptName = name;
                    return true;
                }
                if (!this.developer && str2.contains(name) && FileTools.createSHA256(String.valueOf(file.getPath()) + File.separator + str2).equals(this.params.ptSha256)) {
                    this.ptName = name;
                    return true;
                }
                if (this.developer && str2.contains(".pt")) {
                    this.ptName = str2;
                    return true;
                }
                if (!this.developer && str2.contains(name)) {
                    IJ.log("Pytorch model at:");
                    IJ.log(String.valueOf(file.getAbsolutePath()) + File.separator + str2);
                    IJ.log("does not coincide with the one specified in the rdf.yaml (incorrect sha256).");
                    IJ.log("\n");
                    this.params.incorrectSha256 = true;
                    this.ptName = name;
                    return true;
                }
                if (!this.developer && (str2.contains("pytorch_script.pt") || str2.contains("weights-torchscript.pt"))) {
                    z = true;
                }
            }
            if (!this.developer && z && !this.presentYaml) {
                return true;
            }
            if (!this.developer && z && FileTools.createSHA256(String.valueOf(file.getPath()) + File.separator + "pytorch_script.pt").equals(this.params.ptSha256)) {
                return true;
            }
            if (this.developer || !z) {
                return false;
            }
            IJ.log("Zipped Bioimage Model Zoo model at:");
            IJ.log(String.valueOf(file.getAbsolutePath()) + File.separator + "pytorch_script.pt");
            IJ.log("does not coincide with the one specified in the rdf.yaml (incorrect sha256).");
            IJ.log("\n");
            this.params.incorrectSha256 = true;
            return true;
        } catch (IOException e) {
            return false;
        }
    }

    public static boolean isTherePytorch(File file) {
        for (String str : file.list()) {
            if (str.contains(".pt") && (str.indexOf(".pt") == str.lastIndexOf(".") || str.indexOf(".pth") == str.lastIndexOf("."))) {
                return true;
            }
        }
        return false;
    }

    public void askFrameworkGUI() {
        GenericDialog genericDialog = new GenericDialog("Choose model framework");
        genericDialog.addMessage("The folder provided contained both a Tensorflow and a Pytorch model");
        genericDialog.addMessage("Select which do you want to load.");
        genericDialog.addChoice("Select framework", new String[]{"tensorflow", "pytorch"}, "tensorflow");
        genericDialog.showDialog();
        if (genericDialog.wasCanceled()) {
            genericDialog.dispose();
        } else {
            this.params.framework = genericDialog.getNextChoice();
        }
    }

    public static boolean checkURL(String str) {
        try {
            new URL(str);
            return true;
        } catch (MalformedURLException e) {
            return false;
        }
    }
}
