¿Cómo puedo llamar clasificadores scikit-learn desde Java?

Tengo un clasificador que entrené usando Python scikit-learn. ¿Cómo puedo usar el clasificador de un progtwig Java? ¿Puedo usar Jython? ¿Hay alguna forma de guardar el clasificador en Python y cargarlo en Java? ¿Hay alguna otra manera de usarlo?

Related of "¿Cómo puedo llamar clasificadores scikit-learn desde Java?"

No se puede usar jython, ya que scikit-learn se basa en gran medida en scpy y scipy que tienen muchas extensiones comstackdas de C y Fortran, por lo que no pueden funcionar en jython.

Las formas más fáciles de usar scikit-learn en un entorno java serían:

  • exponga el clasificador como un servicio HTTP / Json, por ejemplo, utilizando un microframo, como un matraz o botella o cornisa, y llámelo desde java utilizando una biblioteca de cliente HTTP

  • escriba una aplicación de envoltorio de línea de comandos en python que lea datos en stdin y produzca predicciones en stdout usando algún formato como CSV o JSON (o alguna representación binaria de nivel inferior) y llame al progtwig python desde java, por ejemplo, utilizando Apache Commons Exec .

  • haga que la salida del progtwig python sea los parámetros numéricos en bruto aprendidos en el momento del ajuste (típicamente como una matriz de valores de punto flotante) y vuelva a implementar la función de predicción en java (esto es típicamente fácil para los modelos lineales predictivos donde la predicción es a menudo solo un producto de puntos con umbral) .

El último enfoque será mucho más trabajo si también necesita volver a implementar la extracción de características en Java.

Finalmente, puede utilizar una biblioteca de Java como Weka o Mahout que implementa los algoritmos que necesita en lugar de intentar utilizar scikit-learn desde Java.

Hay un proyecto JPMML para este propósito.

Primero, puede serializar el modelo scikit-learn a PMML (que es XML internamente) usando la biblioteca sklearn2pmml directamente desde python o volcarlo en python primero y convertir usando jpmml-sklearn en java o desde una línea de comando proporcionada por esta biblioteca. A continuación, puede cargar el archivo pmml, deserializar y ejecutar el modelo cargado utilizando jpmml-evaluator en su código Java.

De esta manera no funciona con todos los modelos de scikit-learn, sino con muchos de ellos.

Puede utilizar un cargador, he probado el sklearn-porter ( https://github.com/nok/sklearn-porter ) y funciona bien para Java.

Mi código es el siguiente:

import pandas as pd from sklearn import tree from sklearn_porter import Porter train_dataset = pd.read_csv('./result2.csv').as_matrix() X_train = train_dataset[:90, :8] Y_train = train_dataset[:90, 8:] X_test = train_dataset[90:, :8] Y_test = train_dataset[90:, 8:] print X_train.shape print Y_train.shape clf = tree.DecisionTreeClassifier() clf = clf.fit(X_train, Y_train) porter = Porter(clf, language='java') output = porter.export(embed_data=True) print(output) 

En mi caso, estoy usando un DecisionTreeClassifier, y la salida de

imprimir (salida)

Es el siguiente código como texto en la consola:

 class DecisionTreeClassifier { private static int findMax(int[] nums) { int index = 0; for (int i = 0; i < nums.length; i++) { index = nums[i] > nums[index] ? i : index; } return index; } public static int predict(double[] features) { int[] classes = new int[2]; if (features[5] <= 51.5) { if (features[6] <= 21.0) { // HUGE amount of ifs.......... } } return findMax(classes); } public static void main(String[] args) { if (args.length == 8) { // Features: double[] features = new double[args.length]; for (int i = 0, l = args.length; i < l; i++) { features[i] = Double.parseDouble(args[i]); } // Prediction: int prediction = DecisionTreeClassifier.predict(features); System.out.println(prediction); } } } 

Aquí hay algo de código para la solución JPMML:

–Parte de piña–

 # helper function to determine the string columns which have to be one-hot-encoded in order to apply an estimator. def determine_categorical_columns(df): categorical_columns = [] x = 0 for col in df.dtypes: if col == 'object': val = df[df.columns[x]].iloc[0] if not isinstance(val,Decimal): categorical_columns.append(df.columns[x]) x += 1 return categorical_columns categorical_columns = determine_categorical_columns(df) other_columns = list(set(df.columns).difference(categorical_columns)) #construction of transformators for our example labelBinarizers = [(d, LabelBinarizer()) for d in categorical_columns] nones = [(d, None) for d in other_columns] transformators = labelBinarizers+nones mapper = DataFrameMapper(transformators,df_out=True) gbc = GradientBoostingClassifier() #construction of the pipeline lm = PMMLPipeline([ ("mapper", mapper), ("estimator", gbc) ]) 

–JAVA PART –

 //Initialisation. String pmmlFile = "ScikitLearnNew.pmml"; PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(new FileInputStream(pmmlFile)); ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance(); MiningModelEvaluator evaluator = (MiningModelEvaluator) modelEvaluatorFactory.newModelEvaluator(pmml); //Determine which features are required as input HashMap() inputFieldMap = new HashMap(); for (int i = 0; i < evaluator.getInputFields().size();i++) { InputField curInputField = evaluator.getInputFields().get(i); String fieldName = curInputField.getName().getValue(); inputFieldMap.put(fieldName.toLowerCase(),curInputField.getField()); } //prediction HashMap argsMap = new HashMap(); //... fill argsMap with input Map res; // here we keep only features that are required by the model Map args = new HashMap(); Iterator iter = argsMap.keySet().iterator(); while (iter.hasNext()) { String key = iter.next(); Field f = inputFieldMap.get(key); if (f != null) { FieldName name =f.getName(); String value = argsMap.get(key); args.put(name, value); } } //the model is applied to input, a probability distribution is obtained res = evaluator.evaluate(args); SegmentResult segmentResult = (SegmentResult) res; Object targetValue = segmentResult.getTargetValue(); ProbabilityDistribution probabilityDistribution = (ProbabilityDistribution) targetValue; 

Me encontré en una situación similar. Recomiendo tallar un microservicio clasificador. Podría tener un microservicio clasificador que se ejecute en Python y luego exponga las llamadas a ese servicio a través de alguna API RESTFul que produzca un formato de intercambio de datos JSON / XML. Creo que este es un enfoque más limpio.

Alternativamente, solo puede generar un código Python a partir de un modelo entrenado. Aquí hay una herramienta que puede ayudarlo con https://github.com/BayesWitnesses/m2cgen