Details
-
Task
-
Status: Resolved
-
Major
-
Resolution: Fixed
-
None
-
None
Description
Restore the RDDConverterUtilsExt.stringDataFrameToVectorDataFrame method and migrate it to Spark 2 (mllib -> ml Vector), if needed.
The RDDConverterUtilsExt.stringDataFrameToVectorDataFrame method was removed by commit https://github.com/apache/incubator-systemml/commit/578e595fdc506fb8a0c0b18c312fe420a406276d. If this method is still needed, it should be migrated to Spark 2.
Old method:
/**
 * Converts a DataFrame whose single column holds vector-formatted strings into a
 * DataFrame whose column holds parsed vectors.
 *
 * Each string may be wrapped in up to two levels of matching round or square brackets,
 * e.g. {@code ((1.2,4.3, 3.4))}, {@code (1.2, 3.4, 2.2)}, {@code [[1.2,34.3, 1.2, 1.2]]}
 * or {@code [1.2, 3.4]}. Whitespace around commas is removed before parsing.
 * NOTE(review): the original comment also lists space-separated forms such as
 * {@code (1.2 3.4)}, but no code here normalizes them -- confirm whether
 * {@code Vectors.parse} accepts them.
 *
 * @param sqlContext SQL context used to create the output DataFrame
 * @param inputDF input DataFrame; every row must have at most one column, of type String
 * @return DataFrame with the same column names, each column typed as {@code VectorUDT}
 * @throws DMLRuntimeException declared by the signature; the per-row checks (more than one
 *         column, non-String value) throw it inside the Spark map function
 */
public static Dataset<Row> stringDataFrameToVectorDataFrame(SQLContext sqlContext, Dataset<Row> inputDF)
	throws DMLRuntimeException
{
	StructField[] oldSchema = inputDF.schema().fields();

	// create the new schema: same column names, every column becomes a nullable vector
	StructField[] newSchema = new StructField[oldSchema.length];
	for (int i = 0; i < oldSchema.length; i++) {
		String colName = oldSchema[i].name();
		newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true);
	}

	// converter: parses the string column(s) of one row into vectors
	class StringToVector implements Function<Row, Row> {
		private static final long serialVersionUID = -4733816995375745659L;

		@Override
		public Row call(Row oldRow) throws Exception {
			int oldNumCols = oldRow.length();
			if (oldNumCols > 1) {
				throw new DMLRuntimeException("The row must have at most one column");
			}

			Object[] fields = new Object[oldNumCols];
			for (int i = 0; i < oldNumCols; i++) {
				Object ci = oldRow.get(i);
				if (!(ci instanceof String)) {
					throw new DMLRuntimeException("Only String is supported");
				}

				StringBuilder sb = new StringBuilder(((String) ci).trim());
				// Remove up to two levels of enclosing brackets. The loop uses its own
				// counter (the old code advanced the column index here) and stops once
				// fewer than two characters remain, so "" or "()" cannot index out of bounds.
				for (int level = 0; level < 2 && sb.length() >= 2; level++) {
					char first = sb.charAt(0);
					char last = sb.charAt(sb.length() - 1);
					if ((first == '(' && last == ')') || (first == '[' && last == ']')) {
						sb.deleteCharAt(0);
						sb.setLength(sb.length() - 1);
					}
				}

				// normalize ", " separators and re-wrap in the dense-vector bracket form
				String ncis = "[" + sb.toString().replaceAll(" *, *", ",") + "]";
				Vector v = Vectors.parse(ncis);
				fields[i] = v;
			}
			return RowFactory.create(fields);
		}
	}

	// output DF; no zipWithIndex() is needed because the converter never used the row index
	JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().map(new StringToVector());
	// TODO investigate why sqlContext.createDataFrame(newRows, new StructType(newSchema)) doesn't work
	Dataset<Row> outDF = sqlContext.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema));
	return outDF;
}
Note: the org.apache.spark.ml.linalg.Vectors.parse() method does not exist in Spark 2.