package deepimagej;

import com.google.protobuf.InvalidProtocolBufferException;
import deepimagej.tools.DijTensor;
import deepimagej.tools.Index;
import deepimagej.tools.StartTensorflowService;
import ij.IJ;
import java.io.File;
import java.net.JarURLConnection;
import java.net.URL;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.TensorFlowException;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorShapeProto;

/* loaded from: input_file:deepimagej/DeepLearningModel.class */
/**
 * Static helpers used by DeepImageJ to load TensorFlow SavedModel bundles, inspect their
 * signatures and tensor shapes, translate between plain tag/signature names and their
 * {@code tf.saved_model.*} constant names, and detect which TensorFlow / DJL Pytorch native jars
 * are installed in the ImageJ {@code plugins} and {@code jars} directories.
 */
public class DeepLearningModel {
    private static final String DEFAULT_TAG = "serve";
    /** Plain SavedModel tag names; index-aligned with {@link #TF_MODEL_TAGS}. */
    private static final String[] MODEL_TAGS = {DEFAULT_TAG, "inference", "train", "eval", "gpu", "tpu"};
    /** Python constant names of {@link #MODEL_TAGS}, in the same order. */
    private static final String[] TF_MODEL_TAGS = {"tf.saved_model.tag_constants.SERVING", "tf.saved_model.tag_constants.INFERENCE", "tf.saved_model.tag_constants.TRAINING", "tf.saved_model.tag_constants.EVAL", "tf.saved_model.tag_constants.GPU", "tf.saved_model.tag_constants.TPU"};
    /**
     * Plain signature keys / method names; index-aligned with {@link #TF_SIGNATURE_CONSTANTS}.
     * "inputs" and "outputs" occur more than once, so the plain-to-TF mapping is ambiguous for them.
     */
    private static final String[] SIGNATURE_CONSTANTS = {"serving_default", "inputs", "tensorflow/serving/classify", "classes", "scores", "inputs", "tensorflow/serving/predict", "outputs", "inputs", "tensorflow/serving/regress", "outputs", "train", "eval", "tensorflow/supervised/training", "tensorflow/supervised/eval"};
    /** Python constant names of {@link #SIGNATURE_CONSTANTS}, in the same order. */
    private static final String[] TF_SIGNATURE_CONSTANTS = {"tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY", "tf.saved_model.signature_constants.CLASSIFY_INPUTS", "tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME", "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES", "tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES", "tf.saved_model.signature_constants.PREDICT_INPUTS", "tf.saved_model.signature_constants.PREDICT_METHOD_NAME", "tf.saved_model.signature_constants.PREDICT_OUTPUTS", "tf.saved_model.signature_constants.REGRESS_INPUTS", "tf.saved_model.signature_constants.REGRESS_METHOD_NAME", "tf.saved_model.signature_constants.REGRESS_OUTPUTS", "tf.saved_model.signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY", "tf.saved_model.signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY", "tf.saved_model.signature_constants.SUPERVISED_TRAIN_METHOD_NAME", "tf.saved_model.signature_constants.SUPERVISED_EVAL_METHOD_NAME"};

    /** Sentinel returned by the jar finders when a directory holds more than one matching jar. */
    private static final String MULTIPLE_JARS = "more than 1 version";
    /** Substring identifying a DJL Pytorch native jar by file name. */
    private static final String PYTORCH_JAR_MARKER = "pytorch-native-auto";
    /** Prefix that directly precedes the version in a DJL Pytorch native jar name. */
    private static final String PYTORCH_VERSION_MARKER = "pytorch-native-auto-";
    /** Substring identifying a TensorFlow JNI jar by file name. */
    private static final String TF_JAR_MARKER = "libtensorflow_jni";

    /**
     * Loads a SavedModel with a single tag, reporting progress and failures to {@code log}.
     *
     * @param str path of the SavedModel
     * @param str2 meta-graph tag to load
     * @param log destination for progress and error messages
     * @return the loaded bundle, or {@code null} if any exception was raised while loading
     */
    public static SavedModelBundle loadTf(String str, String str2, deepimagej.tools.Log log) {
        log.print("load model from " + str);
        try {
            Runtime runtime = Runtime.getRuntime();
            double freeMemory = runtime.freeMemory() / 1048576.0d;
            SavedModelBundle load = SavedModelBundle.load(str, new String[]{str2});
            // Debug output: change of free JVM memory (MiB) across the load.
            System.out.println((runtime.freeMemory() / 1048576.0d) - freeMemory);
            log.print("Loaded");
            return load;
        } catch (Exception e) {
            log.print("Exception in loading model " + str);
            log.print(e.toString());
            log.print(e.getMessage());
            return null;
        }
    }

    /**
     * Loads a SavedModel with a single tag.
     *
     * @param str path of the SavedModel
     * @param str2 meta-graph tag to load
     * @return the loaded bundle, or {@code null} if TensorFlow rejected the load
     */
    public static SavedModelBundle loadTfModel(String str, String str2) {
        try {
            return SavedModelBundle.load(str, new String[]{str2});
        } catch (TensorFlowException e) {
            System.out.println("The tag was incorrect");
            return null;
        }
    }

    /**
     * Finds a tag the model at {@code str} can be loaded with, starting from {@code "serve"}.
     *
     * @return see {@link #checkTfTags(String, String)}
     */
    public static Object[] findTfTag(String str) {
        return checkTfTags(str, DEFAULT_TAG);
    }

    /**
     * Tries to load the model with {@code str2}; on a TensorFlow error retries with the next tag of
     * {@link #MODEL_TAGS} (or its first tag if {@code str2} is not one of them) until the list is
     * exhausted.
     *
     * @param str path of the SavedModel
     * @param str2 first tag to try
     * @return a 3-element array: {@code [0]} the tag that worked (String, or {@code null}),
     *     {@code [1]} the signature keys of that meta-graph (Set&lt;String&gt;, or {@code null}),
     *     {@code [2]} the loaded SavedModelBundle (or {@code null} if no tag worked)
     */
    public static Object[] checkTfTags(String str, String str2) {
        String tag = str2;
        Set<String> signatures;
        SavedModelBundle bundle = null;
        try {
            bundle = SavedModelBundle.load(str, new String[]{tag});
            signatures = metaGraphsSet(bundle);
        } catch (TensorFlowException e) {
            int position = Index.indexOf(MODEL_TAGS, tag);
            if (position < MODEL_TAGS.length - 1) {
                Object[] next = checkTfTags(str, MODEL_TAGS[position + 1]);
                tag = (String) next[0];
                @SuppressWarnings("unchecked") // element [1] is always a Set<String> or null
                Set<String> nextSignatures = (Set<String>) next[1];
                signatures = nextSignatures;
                // Keep the bundle the retry loaded instead of discarding (and leaking) it.
                bundle = (SavedModelBundle) next[2];
            } else {
                tag = null;
                signatures = null;
            }
        }
        return new Object[]{tag, signatures, bundle};
    }

    /**
     * Returns the signature keys of the bundle's meta-graph.
     *
     * @return the keys, or an empty set if the meta-graph bytes could not be parsed
     */
    public static Set<String> metaGraphsSet(SavedModelBundle savedModelBundle) {
        try {
            Map<String, SignatureDef> signatures =
                    MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef()).getSignatureDefMap();
            return signatures.keySet();
        } catch (InvalidProtocolBufferException e) {
            System.out.println("The model is not a correct SavedModel model");
            // Previously fell through to a null dereference; an empty set keeps callers safe.
            return new HashSet<>();
        }
    }

    /**
     * Returns the signature {@code str} of the bundle's meta-graph.
     *
     * @return the signature, or {@code null} if the meta-graph bytes could not be parsed
     * @throws IllegalArgumentException if the meta-graph has no signature named {@code str}
     */
    public static SignatureDef getSignatureFromGraph(SavedModelBundle savedModelBundle, String str) {
        try {
            return MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef()).getSignatureDefOrThrow(str);
        } catch (InvalidProtocolBufferException e) {
            System.out.println("Invalid graph");
            return null;
        }
    }

    /** Returns the declared dimension sizes of output {@code str} of the signature. */
    public static int[] modelTfExitDimensions(SignatureDef signatureDef, String str) {
        return shapeToArray(signatureDef.getOutputsOrThrow(str).getTensorShape());
    }

    /** Returns the declared dimension sizes of input {@code str} of the signature. */
    public static int[] modelTfEntryDimensions(SignatureDef signatureDef, String str) {
        return shapeToArray(signatureDef.getInputsOrThrow(str).getTensorShape());
    }

    /** Copies each dimension size of {@code shape} into an int array (values are narrowed from long). */
    private static int[] shapeToArray(TensorShapeProto shape) {
        List<TensorShapeProto.Dim> dims = shape.getDimList();
        int[] sizes = new int[dims.size()];
        for (int i = 0; i < sizes.length; i++) {
            sizes[i] = (int) dims.get(i).getSize();
        }
        return sizes;
    }

    /** Returns the names of the signature's outputs. */
    public static String[] returnTfOutputs(SignatureDef signatureDef) {
        Set<String> keys = signatureDef.getOutputsMap().keySet();
        return keys.toArray(new String[keys.size()]);
    }

    /** Returns the names of the signature's inputs. */
    public static String[] returnTfInputs(SignatureDef signatureDef) {
        Set<String> keys = signatureDef.getInputsMap().keySet();
        return keys.toArray(new String[keys.size()]);
    }

    /**
     * Returns {@code minimum_size} along the "C" axis (when {@code str} is "channels") or the "Z"
     * axis (otherwise) of the tensor's {@code form}, or 1 if that axis is absent.
     */
    public static int nChannelsOrSlices(DijTensor dijTensor, String str) {
        int indexOf = Index.indexOf(dijTensor.form.split(""), str.equals("channels") ? "C" : "Z");
        return indexOf == -1 ? 1 : dijTensor.minimum_size[indexOf];
    }

    /**
     * Returns, as a string, the first input's {@code tensor_shape} entry at the position of "Y"
     * in {@code str}, or "-1" if {@code str} has no "Y".
     */
    public static String hSize(Parameters parameters, String str) {
        int indexOf = Index.indexOf(str.split(""), "Y");
        return indexOf == -1 ? "-1" : Integer.toString(parameters.inputList.get(0).tensor_shape[indexOf]);
    }

    /**
     * Returns, as a string, the first input's {@code tensor_shape} entry at the position of "X"
     * in {@code str}, or "-1" if {@code str} has no "X".
     */
    public static String wSize(Parameters parameters, String str) {
        int indexOf = Index.indexOf(str.split(""), "X");
        return indexOf == -1 ? "-1" : Integer.toString(parameters.inputList.get(0).tensor_shape[indexOf]);
    }

    /**
     * Returns, as a string, {@code iArr} at the position of "B" in {@code str}; "1" when there is
     * no "B" or the value is -1.
     */
    public static String nBatch(int[] iArr, String str) {
        int indexOf = Index.indexOf(str.split(""), "B");
        String num = indexOf == -1 ? "1" : Integer.toString(iArr[indexOf]);
        if (num.equals("-1")) {
            num = "1";
        }
        return num;
    }

    /** Maps a plain tag name to its {@code tf.saved_model.tag_constants} name; unknown tags pass through. */
    public static String returnTfTag(String str) {
        int indexOf = Index.indexOf(MODEL_TAGS, str);
        return indexOf == -1 ? str : TF_MODEL_TAGS[indexOf];
    }

    /** Maps a {@code tf.saved_model.tag_constants} name back to the plain tag; unknown names pass through. */
    public static String returnStringTag(String str) {
        int indexOf = Index.indexOf(TF_MODEL_TAGS, str);
        return indexOf == -1 ? str : MODEL_TAGS[indexOf];
    }

    /**
     * Converts a set of plain signature names into their {@code tf.saved_model.signature_constants}
     * names, but only when every name is known and maps unambiguously. Otherwise {@code set} itself
     * is returned, so no name is ever silently dropped.
     */
    public static Set<String> returnTfSig(Set<String> set) {
        Set<String> converted = new HashSet<>();
        for (String signature : set) {
            int match = -1;
            for (int i = 0; i < SIGNATURE_CONSTANTS.length; i++) {
                if (SIGNATURE_CONSTANTS[i].equals(signature)) {
                    if (match != -1) {
                        return set; // ambiguous ("inputs"/"outputs")
                    }
                    match = i;
                }
            }
            if (match == -1) {
                return set; // unknown name
            }
            converted.add(TF_SIGNATURE_CONSTANTS[match]);
        }
        return converted;
    }

    /** Maps a {@code tf.saved_model.signature_constants} name to its plain value; unknown names pass through. */
    public static String returnStringSig(String str) {
        int indexOf = Index.indexOf(TF_SIGNATURE_CONSTANTS, str);
        return indexOf == -1 ? str : SIGNATURE_CONSTANTS[indexOf];
    }

    /** Maps a plain signature name to its first matching TF constant name; unknown names pass through. */
    public static String returnTfSig(String str) {
        int indexOf = Index.indexOf(SIGNATURE_CONSTANTS, str);
        return indexOf == -1 ? str : TF_SIGNATURE_CONSTANTS[indexOf];
    }

    /**
     * Returns a warning when the DJL Pytorch version {@code str} is not known to work with CUDA
     * version {@code str2}, a generic hint for versions not covered here, or "" otherwise
     * (including when {@code str2} is "nocuda").
     */
    public static String PytorchCUDACompatibility(String str, String str2) {
        String message = "";
        if (str2.equals("nocuda")) {
            message = "";
        } else if (str.contains("1.7.0") && !str2.contains("10.2") && !str2.contains("10.1") && !str2.contains("11.0")) {
            message = "Installed CUDA version " + str2 + " is not compatible with DJL Pytorch 1.7.0.\nThe plugin might not be able to run on GPU.\nFor optimal performance please install either CUDA 10.1, CUDA 10.2 or CUDA 11.0.\n";
        } else if (str.contains("1.6.0") && !str2.contains("10.2") && !str2.contains("10.1")) {
            message = "Installed CUDA version " + str2 + " is not compatible with DJL Pytorch  1.6.0.\nThe plugin might not be able to run on GPU.\nFor optimal performance please install CUDA 10.1, CUDA 10.2.\n";
        } else if (str.contains("1.5.0") && !str2.contains("10.1") && !str2.contains("10.2") && !str2.contains("9.2")) {
            message = "Installed CUDA version " + str2 + " is not compatible with DJL Pytorch  1.5.0.\nThe plugin might not be able to run on GPU.\nFor optimal performance please install CUDA 9.2, CUDA 10.1 or CUDA 10.2.\n";
        } else if (str.contains("1.4.0") && !str2.contains("10.1") && !str2.contains("9.2")) {
            message = "Installed CUDA version " + str2 + " is not compatible with DJL Pytorch  1.4.0.\nThe plugin might not be able to run on GPU.\nFor optimal performance please install CUDA 9.2 or CUDA 10.1.\n";
        } else if (!str2.toLowerCase().contains("nocuda") && !str.equals("")
                // Only versions this method does not know about get the generic hint; known
                // versions with a compatible CUDA reach here too and must not be warned about.
                && !str.contains("1.7.0") && !str.contains("1.6.0")
                && !str.contains("1.5.0") && !str.contains("1.4.0")) {
            message = "Make sure that the DJL Pytorch version is compatible with the installed CUDA version.\nCheck the DeepImageJ Wiki for more information";
        }
        return message;
    }

    /**
     * Returns a warning when TensorFlow version {@code str} is not known to work with CUDA version
     * {@code str2}, a generic hint for TF versions not covered here, or "" otherwise.
     */
    public static String TensorflowCUDACompatibility(String str, String str2) {
        String message = "";
        if (str.contains("1.15.0") && !str2.contains("10.0")) {
            message = "Installed CUDA version " + str2 + " is not compatible with tf 1.15.0.\nThe plugin might not be able to run on GPU.\nFor optimal performance please install CUDA 10.0.\n";
        } else if (str.contains("1.14.0") && !str2.contains("10.0")) {
            message = "Installed CUDA version " + str2 + " is not compatible with tf 1.14.0.\nThe plugin might not be able to run on GPU.\nFor optimal performance please install CUDA 10.0.\n";
        } else if (str.contains("1.13.0") && !str2.contains("10.0")) {
            message = "Installed CUDA version " + str2 + " is not compatible with tf 1.13.0.\nThe plugin might not be able to run on GPU.\nFor optimal performance please install CUDA 10.0.\n";
        } else if (str.contains("1.12.0") && !str2.contains("9.0")) {
            message = "Installed CUDA version " + str2 + " is not compatible with tf 1.12.0.\nThe plugin might not be able to run on GPU.\nFor optimal performance please install CUDA 9.0.\n";
        } else if (!str.contains("1.15.0") && !str.contains("1.14.0") && !str.contains("1.13.0") && !str.contains("1.12.0")) {
            message = "Make sure that the Tensorflow version is compatible with the installed CUDA version.\n";
        }
        return message;
    }

    /**
     * Returns the installed DJL Pytorch version parsed from its jar name, or the status message of
     * {@link #getLibPytorchJar()} when no single usable jar was found.
     */
    public static String getPytorchVersion() {
        String libPytorchJar = getLibPytorchJar();
        // Status messages (e.g. "...plugins and jars directories...") are not jar paths; the old
        // contains("jar") test let them through to the parser, which then threw.
        String version = extractJarVersion(libPytorchJar, PYTORCH_VERSION_MARKER, 0);
        return version == null ? libPytorchJar : version;
    }

    /**
     * Locates the DJL Pytorch native jar in ImageJ's {@code plugins} and {@code jars} directories.
     *
     * @return the jar's absolute path (plugins preferred), or a status message starting with "-"
     *     when none, several, or differently named jars were found
     */
    public static String getLibPytorchJar() {
        String imageJDir = String.valueOf(IJ.getDirectory("imagej")) + File.separator;
        return chooseLibJar(
                findPytorchJar(imageJDir + "plugins" + File.separator),
                findPytorchJar(imageJDir + "jars" + File.separator),
                "-No Pytorch version found-",
                "-More than one Pytorch version present-",
                "-The plugins and jars directories contains a different version of Pytorch each-");
    }

    /**
     * Returns the absolute path of the only file in directory {@code str} whose name contains
     * "pytorch-native-auto", "" if there is none (or the directory cannot be listed), or
     * "more than 1 version" if there are several.
     */
    public static String findPytorchJar(String str) {
        return findJarInDirectory(str, PYTORCH_JAR_MARKER);
    }

    /**
     * Extracts the version between "pytorch-native-auto-" and ".jar" (lower-cased) from a jar path.
     *
     * @throws StringIndexOutOfBoundsException if {@code str} does not contain that pattern
     */
    public static String getPytorchVersionFromJar(String str) {
        return requireJarVersion(str, PYTORCH_VERSION_MARKER, 0);
    }

    /**
     * Returns the TensorFlow version, either from the TensorFlow service ({@code z == true}) or
     * from the JNI jar available to ImageJ.
     */
    public static String getTFVersion(boolean z) {
        return z ? StartTensorflowService.getTfService().getTensorFlowVersion().getVersionNumber() : getTFVersionIJ();
    }

    /**
     * Determines the TensorFlow version from the JNI jar: first the jar that provides
     * {@code org/tensorflow/native} on the classpath, then a scan of ImageJ's directories.
     *
     * @return e.g. "1.15.0", "1.15.0 GPU" for GPU jars, or the status message of
     *     {@link #getLibTfJar()} when no usable jar was found
     */
    public static String getTFVersionIJ() {
        String libTfJar = findTfJarOnClasspath();
        if (libTfJar == null) {
            libTfJar = getLibTfJar();
        }
        String version = extractJarVersion(libTfJar, TF_JAR_MARKER, 1);
        if (version == null) {
            return libTfJar;
        }
        if (version.contains("gpu")) {
            // Skip "gpu" plus its separator ('-' or '_'), e.g. "gpu-1.15.0" -> "1.15.0".
            int versionStart = Math.min(version.indexOf("gpu") + "gpu".length() + 1, version.length());
            version = version.substring(versionStart) + " GPU";
        }
        return version;
    }

    /**
     * Returns the path of the jar serving {@code org/tensorflow/native} when its name carries a
     * parsable TensorFlow JNI version, otherwise {@code null}.
     */
    private static String findTfJarOnClasspath() {
        try {
            URL resource = ClassLoader.getSystemClassLoader().getResource("org/tensorflow/native");
            if (resource == null) {
                resource = IJ.getClassLoader().getResource("org/tensorflow/native");
            }
            if (resource == null) {
                return null;
            }
            String jar = ((JarURLConnection) resource.openConnection()).getJarFileURL().getFile();
            return extractJarVersion(jar, TF_JAR_MARKER, 1) == null ? null : jar;
        } catch (Exception e) {
            // Not inside a jar (ClassCastException), I/O failure, etc.: fall back to scanning.
            return null;
        }
    }

    /**
     * Locates the TensorFlow JNI jar in ImageJ's {@code plugins} and {@code jars} directories.
     *
     * @return the jar's absolute path (plugins preferred), or a status message starting with "-"
     *     when none, several, or differently named jars were found
     */
    public static String getLibTfJar() {
        String imageJDir = String.valueOf(IJ.getDirectory("imagej")) + File.separator;
        return chooseLibJar(
                findTFJar(imageJDir + "plugins" + File.separator),
                findTFJar(imageJDir + "jars" + File.separator),
                "-No Tensorflow version found-",
                "-More than one tensorflow version present-",
                "-The plugins and jars directories contains a different version of TF each-");
    }

    /**
     * Returns the absolute path of the only file in directory {@code str} whose name contains
     * "libtensorflow_jni", "" if there is none (or the directory cannot be listed), or
     * "more than 1 version" if there are several.
     */
    public static String findTFJar(String str) {
        return findJarInDirectory(str, TF_JAR_MARKER);
    }

    /**
     * Extracts the lower-cased text between "libtensorflow_jni" plus one separator character and
     * ".jar", e.g. "1.15.0" or "gpu-1.15.0".
     *
     * @throws StringIndexOutOfBoundsException if {@code str} does not contain that pattern
     */
    public static String getTfVersionFromJar(String str) {
        return requireJarVersion(str, TF_JAR_MARKER, 1);
    }

    /** Shared body of the jar finders; matches {@code marker} against file names only. */
    private static String findJarInDirectory(String directory, String marker) {
        File[] files = new File(directory).listFiles();
        if (files == null) {
            return "";
        }
        int matches = 0;
        String found = "";
        for (File file : files) {
            if (file.isFile() && file.getName().contains(marker)) {
                matches++;
                found = file.getAbsolutePath();
            }
        }
        return matches > 1 ? MULTIPLE_JARS : found;
    }

    /**
     * Combines the results of scanning the plugins and jars directories into a single jar path or
     * one of the given status messages. Jars found in both directories are compared by file name,
     * since their absolute paths always differ.
     */
    private static String chooseLibJar(String pluginsJar, String jarsJar, String noneFound,
                                       String moreThanOne, String differentVersions) {
        if (pluginsJar.isEmpty() && jarsJar.isEmpty()) {
            return noneFound;
        }
        if (MULTIPLE_JARS.equals(pluginsJar) || MULTIPLE_JARS.equals(jarsJar)) {
            return moreThanOne;
        }
        if (!pluginsJar.isEmpty() && !jarsJar.isEmpty()
                && !new File(pluginsJar).getName().equals(new File(jarsJar).getName())) {
            return differentVersions;
        }
        return pluginsJar.isEmpty() ? jarsJar : pluginsJar;
    }

    /**
     * Returns the lower-cased text between the last {@code marker} (plus {@code skip} characters)
     * and the next ".jar" in {@code path}, or {@code null} if that pattern is absent.
     */
    private static String extractJarVersion(String path, String marker, int skip) {
        String lower = path.toLowerCase();
        int markerAt = lower.lastIndexOf(marker);
        if (markerAt == -1) {
            return null;
        }
        int start = markerAt + marker.length() + skip;
        int end = lower.indexOf(".jar", start);
        return end == -1 ? null : lower.substring(start, end);
    }

    /** Like {@link #extractJarVersion} but fails for unparsable input, as the public parsers always did. */
    private static String requireJarVersion(String path, String marker, int skip) {
        String version = extractJarVersion(path, marker, skip);
        if (version == null) {
            // Same exception type the former substring-based parsing threw on malformed input.
            throw new StringIndexOutOfBoundsException("Cannot read a version from '" + path + "'");
        }
        return version;
    }
}
