Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-35142][PYTHON][ML] Fix incorrect return type for rawPredictionUDF in OneVsRestModel #32245

Closed
wants to merge 8 commits into from

Conversation

@harupy
Copy link
Contributor

@harupy harupy commented Apr 20, 2021

What changes were proposed in this pull request?

Fixes incorrect return type for rawPredictionUDF in OneVsRestModel.

Why are the changes needed?

Bugfix

Does this PR introduce any user-facing change?

No

How was this patch tested?

Unit test.

@harupy harupy force-pushed the harupy:SPARK-35142 branch from bb641f5 to 3f75ab2 Apr 20, 2021
@harupy harupy marked this pull request as ready for review Apr 20, 2021
@@ -3151,7 +3151,7 @@ def func(predictions):
predArray.append(x)
return Vectors.dense(predArray)

rawPredictionUDF = udf(func)

This comment has been minimized.

@harupy

harupy Apr 20, 2021
Author Contributor

Should I add a test here to ensure that the rawPrediction column is no longer string

def test_output_columns(self):
df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
(1.0, Vectors.sparse(2, [], [])),
(2.0, Vectors.dense(0.5, 0.5))],
["label", "features"])
lr = LogisticRegression(maxIter=5, regParam=0.01)
ovr = OneVsRest(classifier=lr, parallelism=1)
model = ovr.fit(df)
output = model.transform(df)
self.assertEqual(output.columns, ["label", "features", "rawPrediction", "prediction"])

This comment has been minimized.

@HyukjinKwon

HyukjinKwon Apr 20, 2021
Member

Yeah, I think we should better add a test if possible.

This comment has been minimized.

@harupy

harupy Apr 20, 2021
Author Contributor

Got it, added a test

This comment has been minimized.

@WeichenXu123

WeichenXu123 Apr 20, 2021
Contributor

@HyukjinKwon
why only transformed_df.head() trigger this error ?
does it indicate bugs in pyspark-sql udf ?

This comment has been minimized.

@HyukjinKwon

HyukjinKwon Apr 21, 2021
Member

Seems like pred.show() triggers an exception too? what does it return in other methods?

@HyukjinKwon
Copy link
Member

@HyukjinKwon HyukjinKwon commented Apr 20, 2021

ok to test

@HyukjinKwon
Copy link
Member

@HyukjinKwon HyukjinKwon commented Apr 20, 2021

add to whitelist

@HyukjinKwon
Copy link
Member

@HyukjinKwon HyukjinKwon commented Apr 20, 2021

cc @WeichenXu123 FYI

harupy added 2 commits Apr 20, 2021
@SparkQA
Copy link

@SparkQA SparkQA commented Apr 20, 2021

Test build #137665 has finished for PR 32245 at commit 3f75ab2.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@SparkQA
Copy link

@SparkQA SparkQA commented Apr 20, 2021

@SparkQA
Copy link

@SparkQA SparkQA commented Apr 20, 2021

@SparkQA
Copy link

@SparkQA SparkQA commented Apr 20, 2021

Test build #137666 has finished for PR 32245 at commit 5e05b50.

  • This patch fails Python style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@SparkQA
Copy link

@SparkQA SparkQA commented Apr 20, 2021

@SparkQA
Copy link

@SparkQA SparkQA commented Apr 20, 2021

@WeichenXu123
Copy link
Contributor

@WeichenXu123 WeichenXu123 commented Apr 20, 2021

@SparkQA
Copy link

@SparkQA SparkQA commented Apr 20, 2021

Test build #137668 has finished for PR 32245 at commit 3c2ac95.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@SparkQA
Copy link

@SparkQA SparkQA commented Apr 20, 2021

@SparkQA
Copy link

@SparkQA SparkQA commented Apr 20, 2021

Copy link
Contributor

@WeichenXu123 WeichenXu123 left a comment

LGTM

harupy added 3 commits Apr 21, 2021
@SparkQA
Copy link

@SparkQA SparkQA commented Apr 21, 2021

Test build #137708 has finished for PR 32245 at commit b6fabb3.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@SparkQA
Copy link

@SparkQA SparkQA commented Apr 21, 2021

Kubernetes integration test unable to build dist.

exiting with code: 1
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/42236/

@SparkQA
Copy link

@SparkQA SparkQA commented Apr 21, 2021

Test build #137713 has finished for PR 32245 at commit ed26d2c.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
@SparkQA
Copy link

@SparkQA SparkQA commented Apr 21, 2021

@SparkQA
Copy link

@SparkQA SparkQA commented Apr 21, 2021

@WeichenXu123
Copy link
Contributor

@WeichenXu123 WeichenXu123 commented Apr 21, 2021

LGTM

@HyukjinKwon
Copy link
Member

@HyukjinKwon HyukjinKwon commented Apr 21, 2021

Looks good. @harupy, would you mind filling the PR description per the template?

@HyukjinKwon HyukjinKwon changed the title [SPARK-35142][ML] Fix incorrect return type for rawPredictionUDF in OneVsRestModel [SPARK-35142][PYTHON][ML] Fix incorrect return type for rawPredictionUDF in OneVsRestModel Apr 21, 2021
@HyukjinKwon
Copy link
Member

@HyukjinKwon HyukjinKwon commented Apr 21, 2021

@viirya, are you preparing Spark 2.4 RC now? This is supposed to be in Spark 2.4 too but this isn't a regression so it doesn't block. It's just a good to have so if you're preparing, it should be fine to don't backport.

@viirya
Copy link
Member

@viirya viirya commented Apr 21, 2021

@viirya, are you preparing Spark 2.4 RC now? This is supposed to be in Spark 2.4 too but this isn't a regression so it doesn't block. It's just a good to have so if you're preparing, it should be fine to don't backport.

#32256 was just merged, so I have not started new RC yet. I can wait for this.

@HyukjinKwon
Copy link
Member

@HyukjinKwon HyukjinKwon commented Apr 21, 2021

BTW, the tests passed at https://github.com/harupy/spark/actions/runs/769366516. GitHub Actions didn't work properly for linking that run for some reasons ..

I will leave it to @WeichenXu123 then.

WeichenXu123 added a commit that referenced this pull request Apr 21, 2021
…nUDF` in `OneVsRestModel`

### What changes were proposed in this pull request?

Fixes incorrect return type for `rawPredictionUDF` in `OneVsRestModel`.

### Why are the changes needed?
Bugfix

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit test.

Closes #32245 from harupy/SPARK-35142.

Authored-by: harupy <17039389+harupy@users.noreply.github.com>
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
(cherry picked from commit b6350f5)
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
@WeichenXu123
Copy link
Contributor

@WeichenXu123 WeichenXu123 commented Apr 21, 2021

@harupy

Backport to branch-3.1 cause conflicts.
Could you create a PR against apache/spark branch-3.1 ?

++<<<<<<< HEAD
 +    def test_parallelism_doesnt_change_output(self):
++=======
+     def test_raw_prediction_column_is_of_vector_type(self):
+         # SPARK-35142: `OneVsRestModel` outputs raw prediction as a string column
+         df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8)),
+                                          (1.0, Vectors.sparse(2, [], [])),
+                                          (2.0, Vectors.dense(0.5, 0.5))],
+                                         ["label", "features"])
+         lr = LogisticRegression(maxIter=5, regParam=0.01)
+         ovr = OneVsRest(classifier=lr, parallelism=1)
+         model = ovr.fit(df)
+         row = model.transform(df).head()
+         self.assertIsInstance(row["rawPrediction"], DenseVector)
+ 
+     def test_parallelism_does_not_change_output(self):
++>>>>>>> b6350f5bb0... [SPARK-35142][PYTHON][ML] Fix incorrect return type for `rawPredictionUDF` in `OneVsRestModel`
@harupy
Copy link
Contributor Author

@harupy harupy commented Apr 21, 2021

@WeichenXu123 Opened a PR: #32269

@viirya
Copy link
Member

@viirya viirya commented Apr 21, 2021

I don't see backport to 2.4. Do you plan to backport it? @WeichenXu123 @harupy?

@harupy
Copy link
Contributor Author

@harupy harupy commented Apr 21, 2021

@viirya Got it. I'll open another PR for 2.4.


Wait, does OneVsRestModel in 2.4 output the raw prediction column? Looks like it doesn't.

def _transform(self, dataset):
# determine the input columns: these need to be passed through
origCols = dataset.columns
# add an accumulator column to store predictions of all the models
accColName = "mbc$acc" + str(uuid.uuid4())
initUDF = udf(lambda _: [], ArrayType(DoubleType()))
newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))
# persist if underlying dataset is not persistent.
handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
if handlePersistence:
newDataset.persist(StorageLevel.MEMORY_AND_DISK)
# update the accumulator column with the result of prediction of models
aggregatedDataset = newDataset
for index, model in enumerate(self.models):
rawPredictionCol = model._call_java("getRawPredictionCol")
columns = origCols + [rawPredictionCol, accColName]
# add temporary column to store intermediate scores and update
tmpColName = "mbc$tmp" + str(uuid.uuid4())
updateUDF = udf(
lambda predictions, prediction: predictions + [prediction.tolist()[1]],
ArrayType(DoubleType()))
transformedDataset = model.transform(aggregatedDataset).select(*columns)
updatedDataset = transformedDataset.withColumn(
tmpColName,
updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]))
newColumns = origCols + [tmpColName]
# switch out the intermediate column with the accumulator column
aggregatedDataset = updatedDataset\
.select(*newColumns).withColumnRenamed(tmpColName, accColName)
if handlePersistence:
newDataset.unpersist()
# output the index of the classifier with highest confidence as prediction
labelUDF = udf(
lambda predictions: float(max(enumerate(predictions), key=operator.itemgetter(1))[0]),
DoubleType())
# output label and label metadata as prediction
return aggregatedDataset.withColumn(
self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName)

@HyukjinKwon
Copy link
Member

@HyukjinKwon HyukjinKwon commented Apr 21, 2021

Okay, looks like we can skip Spark 2.4.

@viirya
Copy link
Member

@viirya viirya commented Apr 21, 2021

Thanks for confirming. @harupy @HyukjinKwon

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
5 participants