package deepimagej.stamp;

import ai.djl.MalformedModelException;
import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
import ai.djl.ndarray.NDList;
import ai.djl.pytorch.jni.LibUtils;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.training.util.ProgressBar;
import deepimagej.BuildDialog;
import deepimagej.Constants;
import deepimagej.DeepLearningModel;
import deepimagej.Parameters;
import deepimagej.components.HTMLPane;
import deepimagej.tools.DijTensor;
import deepimagej.tools.SystemUsage;
import ij.IJ;
import ij.gui.GenericDialog;
import java.awt.BorderLayout;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Method;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.swing.BoxLayout;
import javax.swing.JPanel;
import javax.swing.JTextField;
import org.mvel2.MVEL;

/* loaded from: input_file:deepimagej/stamp/LoadPytorchStamp.class */
public class LoadPytorchStamp extends AbstractStamp implements Runnable {
    private JTextField inpNumber;
    private JTextField outNumber;
    private HTMLPane pnLoad;

    public LoadPytorchStamp(BuildDialog buildDialog) {
        super(buildDialog);
        this.inpNumber = new JTextField();
        this.outNumber = new JTextField();
        buildPanel();
    }

    @Override // deepimagej.stamp.AbstractStamp
    public void buildPanel() {
        this.pnLoad = new HTMLPane(Constants.width, 70);
        HTMLPane hTMLPane = new HTMLPane(Constants.width / 2, 70);
        hTMLPane.append("h2", "Number of inputs");
        hTMLPane.append("p", "Number of inputs to the Pytorch model");
        HTMLPane hTMLPane2 = new HTMLPane((2 * Constants.width) / 2, 70);
        hTMLPane2.append("h2", "Number of outputs");
        hTMLPane2.append("p", "Number of outputs of the Pytorch model.");
        JPanel jPanel = new JPanel();
        jPanel.setLayout(new BoxLayout(jPanel, 3));
        jPanel.add(hTMLPane.getPane());
        jPanel.add(this.inpNumber);
        this.inpNumber.setText(MVEL.VERSION_SUB);
        this.inpNumber.setEnabled(false);
        jPanel.add(hTMLPane2.getPane());
        jPanel.add(this.outNumber);
        this.outNumber.setText(MVEL.VERSION_SUB);
        this.outNumber.setEnabled(false);
        JPanel jPanel2 = new JPanel(new BorderLayout());
        jPanel2.add(this.pnLoad.getPane(), "Center");
        jPanel2.add(jPanel, "South");
        this.panel.add(jPanel2);
    }

    @Override // deepimagej.stamp.AbstractStamp
    public void init() {
        Thread thread = new Thread(this);
        thread.setPriority(1);
        thread.start();
    }

    @Override // deepimagej.stamp.AbstractStamp
    public boolean finish() {
        Parameters parameters = this.parent.getDeepPlugin().params;
        parameters.totalInputList = new ArrayList();
        parameters.totalOutputList = new ArrayList();
        try {
            int parseInt = Integer.parseInt(this.inpNumber.getText().trim());
            int parseInt2 = Integer.parseInt(this.outNumber.getText().trim());
            if (parseInt2 < 1) {
                IJ.error("The number of outputs shoud be 1 or bigger");
                return false;
            }
            if (parseInt < 1) {
                IJ.error("The number of inputs shoud be 1 or bigger");
                return false;
            }
            for (int i = 0; i < parseInt; i++) {
                parameters.totalInputList.add(new DijTensor("input" + i));
            }
            for (int i2 = 0; i2 < parseInt2; i2++) {
                parameters.totalOutputList.add(new DijTensor("output" + i2));
            }
            return true;
        } catch (Exception e) {
            if (0 == 0) {
                IJ.error("Please introduce a valid integer for the number of inputs.");
                return false;
            }
            if (0 == 0) {
                return false;
            }
            IJ.error("Please introduce a valid integer for the number of outputs.");
            return false;
        }
    }

    @Override // java.lang.Runnable
    public void run() {
        this.pnLoad.setCaretPosition(0);
        this.pnLoad.setText("");
        this.pnLoad.append("p", "Loading Deep Java Library...");
        Parameters parameters = this.parent.getDeepPlugin().params;
        parameters.selectedModelPath = findPytorchModels(parameters.path2Model);
        this.pnLoad.clear();
        parameters.pytorchVersion = DeepLearningModel.getPytorchVersion();
        this.pnLoad.append("h2", "Pytorch version");
        this.pnLoad.append("p", "Currently using Pytorch " + parameters.pytorchVersion);
        this.pnLoad.append("p", "Supported by Deep Java Library " + parameters.pytorchVersion);
        String cUDAEnvVariables = SystemUsage.getCUDAEnvVariables();
        if (cUDAEnvVariables.toLowerCase().equals("nocuda")) {
            this.pnLoad.append("p", "No CUDA distribution found.\n");
            this.parent.setGPUTf("CPU");
        } else if (!cUDAEnvVariables.contains(File.separator) && !cUDAEnvVariables.contains("---")) {
            this.pnLoad.append("p", "Currently using CUDA " + cUDAEnvVariables);
        } else if (!cUDAEnvVariables.contains(File.separator) && cUDAEnvVariables.contains("---")) {
            String[] split = cUDAEnvVariables.split("---");
            if (split.length == 1) {
                this.pnLoad.append("p", "Currently using CUDA " + split[0]);
            } else {
                for (String str : split) {
                    this.pnLoad.append("p", "Found CUDA " + str);
                }
            }
        } else if (cUDAEnvVariables.contains("bin") || cUDAEnvVariables.contains("libnvvp")) {
            String[] split2 = cUDAEnvVariables.split(";");
            this.pnLoad.append("p", "Found CUDA distribution " + split2[0] + ".\n");
            this.pnLoad.append("p", "Could not find environment variable:\n - " + split2[1] + "\n");
            if (split2.length == 3) {
                this.pnLoad.append("p", "Could not find environment variable:\n - " + split2[2] + "\n");
            }
            this.pnLoad.append("p", "Please add the missing environment variables to the path.\n");
        }
        this.pnLoad.append("p", DeepLearningModel.PytorchCUDACompatibility(parameters.pytorchVersion, cUDAEnvVariables));
        this.pnLoad.append("h2", "Model info");
        this.pnLoad.append("p", "Path: " + parameters.selectedModelPath);
        this.pnLoad.append("<p>Loading model...");
        if (!SystemUsage.checkFiji()) {
            Thread.currentThread().setContextClassLoader(IJ.getClassLoader());
        }
        this.parent.setEnabledBackNext(false);
        try {
            URL url = new File(new File(parameters.path2Model).getAbsolutePath()).toURI().toURL();
            if (parameters.selectedModelPath.equals("")) {
                this.pnLoad.append("No Pytorch model found in the directory.");
                this.parent.setEnabledBack(true);
            }
            String name = new File(parameters.selectedModelPath).getName();
            String substring = name.substring(0, name.indexOf(".pt"));
            long nanoTime = System.nanoTime();
            this.parent.getDeepPlugin().setTorchModel(ModelZoo.loadModel(Criteria.builder().setTypes(NDList.class, NDList.class).optModelUrls(url.toString()).optModelName(substring).optProgress(new ProgressBar()).build()));
            this.pnLoad.append(" -> Loaded!!!</p>");
            parameters.pytorchVersion = Engine.getInstance().getVersion();
            if (new File(getNativeLbraryFile()).getName().toLowerCase().contains("cpu")) {
                this.pnLoad.append("p", "Model loaded on the <b>CPU</b>.\n");
                this.parent.setGPUPt("CPU");
            } else {
                this.pnLoad.append("p", "Model loaded on the <b>GPU</b>.\n");
                this.parent.setGPUPt("GPU");
            }
            String sb = new StringBuilder().append(new File(parameters.selectedModelPath).length() / 1048576.0d).toString();
            String substring2 = sb.substring(0, sb.lastIndexOf(".") + 2);
            String sb2 = new StringBuilder().append(((float) (System.nanoTime() - nanoTime)) / 1.0E9f).toString();
            String substring3 = sb2.substring(0, sb2.lastIndexOf(".") + 3);
            this.pnLoad.append("p", "Model size: " + substring2 + " Mb");
            this.pnLoad.append("p", "Loading time: " + substring3 + " s");
            this.parent.setEnabledBackNext(true);
            this.inpNumber.setEnabled(true);
            this.outNumber.setEnabled(true);
        } catch (Exception e) {
            this.pnLoad.append("p", "DeepImageJ could not load the model");
            this.pnLoad.append("p", "Error whie accessing the model file.");
            this.parent.setEnabledBack(true);
            e.printStackTrace();
        } catch (MalformedModelException e2) {
            this.pnLoad.append("p", "DeepImageJ could not load the model");
            this.pnLoad.append("p", "The model provided is not a correct Torchscript model.");
            this.parent.setEnabledBack(true);
            e2.printStackTrace();
        } catch (MalformedURLException e3) {
            this.pnLoad.append("p", "DeepImageJ could not load the model");
            this.pnLoad.append("p", "Check that the path provided to the model remains the same.");
            this.parent.setEnabledBack(true);
            e3.printStackTrace();
        } catch (IOException e4) {
            this.pnLoad.append("p", "DeepImageJ could not load the model");
            this.pnLoad.append("p", "Error whie accessing the model file.");
            this.parent.setEnabledBack(true);
            e4.printStackTrace();
        } catch (ModelNotFoundException e5) {
            this.pnLoad.append("p", "DeepImageJ could not load the model");
            this.pnLoad.append("p", "No model was found in the path provided.");
            this.parent.setEnabledBack(true);
            e5.printStackTrace();
        } catch (EngineException e6) {
            String message = e6.getMessage();
            String lowerCase = System.getProperty("os.name").toLowerCase();
            if (lowerCase.contains("win") && message.contains("https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md")) {
                this.pnLoad.append("p", "DeepImageJ could not load the model");
                this.pnLoad.append("p", "Please install the Visual Studio 2019 redistributables and reboot\nyour machine to be able to use Pytorch with DeepImageJ.");
                this.pnLoad.append("p", "For more information:\n");
                this.pnLoad.append("p", " -https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md");
                this.pnLoad.append("p", " -https://github.com/awslabs/djl/issues/126");
                this.pnLoad.append("p", "If you already have installed VS2019 redistributables, the error\nmight be caused by a missing dependency or an incompatible Pytorch version.");
                this.pnLoad.append("p", "Furthermore, the DJL Pytorch dependencies (pytorch-egine, pytorch-api and pytorch-native-auto) should be compatible with each other.");
                this.pnLoad.append("p", "Please check the DeepImageJ Wiki.");
            } else if ((lowerCase.contains("linux") || lowerCase.contains("unix")) && message.contains("https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md")) {
                this.pnLoad.append("p", "DeepImageJ could not load the model.");
                this.pnLoad.append("p", "Check that there are no repeated dependencies on the jars folder.");
                this.pnLoad.append("p", "The problem might be caused by a missing or repeated dependency or an incompatible Pytorch version.");
                this.pnLoad.append("p", "Furthermore, the DJL Pytorch dependencies (pytorch-egine, pytorch-api and pytorch-native-auto) should be compatible with each other.");
                this.pnLoad.append("p", "If the problem persists, please check the DeepImageJ Wiki.");
            } else {
                this.pnLoad.append("p", "DeepImageJ could not load the model");
                this.pnLoad.append("p", "Either the DJL Pytorch version is incompatible with the Torchscript model's Pytorch version or the DJL Pytorch dependencies (pytorch-egine, pytorch-api and pytorch-native-auto) are not compatible with each other.");
                this.pnLoad.append("p", "Please check the DeepImageJ Wiki.");
            }
            this.parent.setEnabledBack(true);
            e6.printStackTrace();
        }
    }

    private String findPytorchModels(String str) {
        File[] listFiles = new File(str).listFiles();
        ArrayList arrayList = new ArrayList();
        for (File file : listFiles) {
            if (file.getName().contains(".pt")) {
                arrayList.add(file);
            }
        }
        if (arrayList.size() == 1) {
            return ((File) arrayList.get(0)).getAbsolutePath();
        }
        GenericDialog genericDialog = new GenericDialog("Choose Pytorch model");
        genericDialog.addMessage("The folder provided contained several Pytorch models");
        genericDialog.addMessage("Select which do you want to load.");
        String[] strArr = new String[arrayList.size()];
        int i = 0;
        Iterator it = arrayList.iterator();
        while (it.hasNext()) {
            int i2 = i;
            i++;
            strArr[i2] = ((File) it.next()).getName();
        }
        genericDialog.addChoice("Select framework", strArr, strArr[0]);
        genericDialog.showDialog();
        if (!genericDialog.wasCanceled()) {
            return String.valueOf(str) + File.separator + genericDialog.getNextChoice();
        }
        genericDialog.dispose();
        return "";
    }

    public static String getNativeLbraryFile() {
        String str = "???";
        try {
            Method declaredMethod = LibUtils.class.getDeclaredMethod("findNativeLibrary", AtomicBoolean.class);
            declaredMethod.setAccessible(true);
            str = (String) declaredMethod.invoke(LibUtils.class, new AtomicBoolean(false));
        } catch (Exception e) {
            e.printStackTrace();
        }
        return str;
    }
}
