pyspark 在Databricks上使用MLflow记录Spark模型时出错- mlflow.spark.log_model()

lpwwtiir  于 2024-01-06  发布在  Spark
关注(0)|答案(1)|浏览(195)

我尝试使用下面的代码片段记录Spark模型。模型度量和参数保存在ML流运行中,但模型本身不会保存在artefacts下。但是,在同一环境中使用model.sklearn.log_model()记录Scikit-learn模型时,模型会成功保存。
环境:Databricks 10.4 LTS ML集群

  1. train, test = train_test_random_split(conf, data)
  2. experiment_name = "/mlflow_experiments/debug_spark_model"
  3. mlflow.set_experiment(experiment_name)
  4. evaluator = BinaryClassificationEvaluator()
  5. rf = RandomForestClassifier()
  6. param_grid = (
  7. ParamGridBuilder()
  8. .addGrid(rf.numTrees,[15)
  9. .addGrid(rf.maxDepth, [6])
  10. .addGrid(
  11. rf.minInstancesPerNode,
  12. [7],
  13. )
  14. .build()
  15. )
  16. cv = CrossValidator(
  17. estimator=rf,
  18. estimatorParamMaps=param_grid,
  19. evaluator=BinaryClassificationEvaluator(metricName="areaUnderROC"),
  20. numFolds=10,
  21. )
  22. cv_model = cv.fit(train)
  23. # best model
  24. model = cv_model.bestModel
  25. model_params_best = {
  26. "numTrees": cv_model.getEstimatorParamMaps()[np.argmax(cv_model.avgMetrics)][
  27. cv_model.bestModel.numTrees
  28. ],
  29. "maxDepth": cv_model.getEstimatorParamMaps()[np.argmax(cv_model.avgMetrics)][
  30. cv_model.bestModel.maxDepth
  31. ],
  32. "minInstancesPerNode": cv_model.getEstimatorParamMaps()[
  33. np.argmax(cv_model.avgMetrics)
  34. ][cv_model.bestModel.minInstancesPerNode],
  35. }
  36. model_metrics_best, artifacts_best, predicted_df_best = train_model(
  37. model, train, test, evaluator
  38. )
  39. with mlflow.start_run(run_name="debug_run_1"):
  40. run_id = mlflow.active_run().info.run_id
  41. mlflow.log_params(model_params_best)
  42. mlflow.log_metrics(model_metrics_best)
  43. #debug 1
  44. artifact_path = "best_model"
  45. mlflow.spark.log_model(spark_model = model, artifact_path = artifact_path)
  46. source = get_artifact_uri(run_id=run_id, artifact_path=artifact_path)

字符串
它给出了下面的错误。
Copyright © 2018 - 2019 www.qqq.com All Rights Reserved.粤ICP备15047777号-1技术支持:中企动力
x1c 0d1x的数据
我感谢任何调试方向或解决方案有关此错误。

brjng4g3

brjng4g31#

找到解决此错误或大多数与mlflowdbfs相关的错误的方法。
禁用mlflowdbfs在Databricks ML MySQL集群中的工作上述错误。另一个选项将使用正常的Databricks MySQL集群。

  1. import os
  2. os.environ["DISABLE_MLFLOWDBFS"] = "true"

字符串

相关问题