#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import shutil
import tempfile
import time
from typing import Any, Dict, List
from urllib.parse import urlparse

from pyspark import __version__ as pyspark_version
from pyspark.ml.base import Params
from pyspark.sql import SparkSession
from pyspark.sql.utils import is_remote

# Name of the JSON file holding instance metadata inside a saved model directory.
_META_DATA_FILE_NAME = "metadata.json"


def _copy_file_from_local_to_fs(local_path: str, dest_path: str) -> None:
    """Copy a single local file to the given cloud storage / DFS path."""
    session = SparkSession.active()
    if is_remote():
        # Spark Connect: the server performs the copy on our behalf.
        session.copyFromLocalToFs(local_path, dest_path)
    else:
        # Classic mode: go through the JVM-side helper directly.
        jvm = session.sparkContext._gateway.jvm  # type: ignore[union-attr]
        jvm.org.apache.spark.ml.python.MLUtil.copyFileFromLocalToFs(local_path, dest_path)


def _copy_dir_from_local_to_fs(local_path: str, dest_path: str) -> None:
    """
    Copy directory from local path to cloud storage path.
    Limitation: Currently only one level directory is supported.
    """
    assert os.path.isdir(local_path)

    for entry in os.listdir(local_path):
        src_file = os.path.join(local_path, entry)
        dest_file = os.path.join(dest_path, entry)
        # Nested directories are unsupported: every entry must be a plain file.
        assert os.path.isfile(src_file)
        _copy_file_from_local_to_fs(src_file, dest_file)


def _get_class(clazz: str) -> Any:
    """
    Loads Python class from its name.
    """
    module_name, _, class_name = clazz.rpartition(".")
    module = __import__(module_name, fromlist=[class_name])
    return getattr(module, class_name)
class ParamsReadWrite(Params):
    """
    The base interface Estimator / Transformer / Model / Evaluator needs to inherit
    for supporting saving and loading.
    """

    def _get_extra_metadata(self) -> Any:
        """
        Returns extra metadata of the instance.
        """
        return None

    def _get_skip_saving_params(self) -> List[str]:
        """
        Returns params to be skipped when saving metadata.
        """
        return []

    def _get_metadata_to_save(self) -> Dict[str, Any]:
        """
        Extract metadata of Estimator / Transformer / Model / Evaluator instance.
        """
        extra_metadata = self._get_extra_metadata()
        # Subclasses may return None; normalize to an empty skip-list.
        skipped_params = self._get_skip_saving_params() or []

        uid = self.uid
        cls = self.__module__ + "." + self.__class__.__name__

        # User-supplied param values (minus the ones the subclass opted out of).
        params = self._paramMap
        json_params = {p.name: params[p] for p in params if p.name not in skipped_params}

        # Default param values.
        json_default_params = {p.name: v for p, v in self._defaultParamMap.items()}

        metadata = {
            "class": cls,
            "timestamp": int(round(time.time() * 1000)),
            "sparkVersion": pyspark_version,
            "uid": uid,
            "paramMap": json_params,
            "defaultParamMap": json_default_params,
            # Marker distinguishing 'pyspark.ml.connect' saves from legacy formats.
            "type": "spark_connect",
        }
        if extra_metadata is not None:
            assert isinstance(extra_metadata, dict)
            metadata["extra"] = extra_metadata

        return metadata

    def _load_extra_metadata(self, metadata: Dict[str, Any]) -> None:
        """
        Load extra metadata attribute from metadata json object.
        """
        pass

    def _save_to_local(self, path: str) -> None:
        # Save the (possibly nested) instance tree, then write its metadata JSON.
        metadata = self._save_to_node_path(path, [])
        with open(os.path.join(path, _META_DATA_FILE_NAME), "w") as fp:
            json.dump(metadata, fp)

    def saveToLocal(self, path: str, *, overwrite: bool = False) -> None:
        """
        Save Estimator / Transformer / Model / Evaluator to provided local path.

        .. versionadded:: 3.5.0
        """
        if os.path.exists(path):
            if overwrite:
                if os.path.isdir(path):
                    shutil.rmtree(path)
                else:
                    os.remove(path)
            else:
                raise ValueError(f"The path {path} already exists.")

        os.makedirs(path)
        self._save_to_local(path)

    @classmethod
    def _load_metadata(cls, metadata: Dict[str, Any]) -> "Params":
        # Reject data saved by the legacy (non-connect) ML persistence format.
        if "type" not in metadata or metadata["type"] != "spark_connect":
            raise RuntimeError(
                "The saved data is not saved by ML algorithm implemented in 'pyspark.ml.connect' "
                "module."
            )
        class_name = metadata["class"]
        instance = _get_class(class_name)()
        instance._resetUid(metadata["uid"])

        # Set user-supplied param values.
        for paramName in metadata["paramMap"]:
            param = instance.getParam(paramName)
            paramValue = metadata["paramMap"][paramName]
            instance.set(param, paramValue)

        # Restore default param values.
        for paramName in metadata["defaultParamMap"]:
            paramValue = metadata["defaultParamMap"][paramName]
            instance._setDefault(**{paramName: paramValue})

        if "extra" in metadata:
            instance._load_extra_metadata(metadata["extra"])
        return instance

    @classmethod
    def _load_instance_from_metadata(cls, metadata: Dict[str, Any], path: str) -> Any:
        instance = cls._load_metadata(metadata)

        # Models with a separate "core model" file load it from the saved tree.
        if isinstance(instance, CoreModelReadWrite):
            core_model_path = metadata["core_model_path"]
            instance._load_core_model(os.path.join(path, core_model_path))

        # Meta-algorithms (pipelines, CV, ...) recursively restore child stages.
        if isinstance(instance, MetaAlgorithmReadWrite):
            instance._load_meta_algorithm(path, metadata)

        return instance

    @classmethod
    def _load_from_local(cls, path: str) -> "Params":
        with open(os.path.join(path, _META_DATA_FILE_NAME), "r") as fp:
            metadata = json.load(fp)

        return cls._load_instance_from_metadata(metadata, path)

    @classmethod
    def loadFromLocal(cls, path: str) -> "Params":
        """
        Load Estimator / Transformer / Model / Evaluator from provided local path.

        .. versionadded:: 3.5.0
        """
        return cls._load_from_local(path)

    def _save_to_node_path(self, root_path: str, node_path: List[str]) -> Any:
        """
        Save the instance to provided node path, and return the node metadata.
        """
        if isinstance(self, MetaAlgorithmReadWrite):
            # Meta-algorithms serialize themselves plus all child stages.
            metadata = self._save_meta_algorithm(root_path, node_path)
        else:
            metadata = self._get_metadata_to_save()
            if isinstance(self, CoreModelReadWrite):
                # Core-model filename is namespaced by the node path so siblings
                # in the same directory cannot collide.
                core_model_path = ".".join(node_path + [self._get_core_model_filename()])
                self._save_core_model(os.path.join(root_path, core_model_path))
                metadata["core_model_path"] = core_model_path

        return metadata

    def save(self, path: str, *, overwrite: bool = False) -> None:
        """
        Save Estimator / Transformer / Model / Evaluator to provided cloud storage path.

        .. versionadded:: 3.5.0
        """
        session = SparkSession.active()
        path_exist = True
        try:
            # Probe the destination by attempting to read it; there is no direct
            # "exists" API for arbitrary cloud storage paths.
            session.read.format("binaryFile").load(path).head()
        except Exception as e:
            if "Path does not exist" in str(e):
                path_exist = False
            else:
                # Unexpected error: bare `raise` preserves the original traceback.
                raise

        if path_exist and not overwrite:
            raise ValueError(f"The path {path} already exists.")

        # Serialize locally first, then upload; always clean up the temp dir.
        tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_")
        try:
            self._save_to_local(tmp_local_dir)
            _copy_dir_from_local_to_fs(tmp_local_dir, path)
        finally:
            shutil.rmtree(tmp_local_dir, ignore_errors=True)

    @classmethod
    def load(cls, path: str) -> "Params":
        """
        Load Estimator / Transformer / Model / Evaluator from provided cloud storage path.

        .. versionadded:: 3.5.0
        """
        session = SparkSession.active()

        # Download every file of the saved model into a local temp dir,
        # then reuse the local-loading path.
        tmp_local_dir = tempfile.mkdtemp(prefix="pyspark_ml_model_")
        try:
            file_data_df = session.read.format("binaryFile").load(path)

            for row in file_data_df.toLocalIterator():
                file_name = os.path.basename(urlparse(row.path).path)
                file_content = bytes(row.content)
                with open(os.path.join(tmp_local_dir, file_name), "wb") as f:
                    f.write(file_content)

            return cls._load_from_local(tmp_local_dir)
        finally:
            shutil.rmtree(tmp_local_dir, ignore_errors=True)
class CoreModelReadWrite:
    """
    Interface for estimators / models whose saved form includes a separate
    "core model" file alongside the metadata JSON. Subclasses must override
    all three hooks below.
    """

    def _get_core_model_filename(self) -> str:
        """
        Returns the name of the file for saving the core model.
        """
        raise NotImplementedError()

    def _save_core_model(self, path: str) -> None:
        """
        Save the core model to provided local path.

        Different pyspark models contain different type of core model,
        e.g. for LogisticRegressionModel, its core model is a pytorch model.
        """
        raise NotImplementedError()

    def _load_core_model(self, path: str) -> None:
        """
        Load the core model from provided local path.
        """
        raise NotImplementedError()
class MetaAlgorithmReadWrite(ParamsReadWrite):
    """
    Meta-algorithm such as pipeline and cross validator must implement this interface.
    """

    def _get_child_stages(self) -> List[Any]:
        raise NotImplementedError()

    def _save_meta_algorithm(self, root_path: str, node_path: List[str]) -> Dict[str, Any]:
        raise NotImplementedError()

    def _load_meta_algorithm(self, root_path: str, node_metadata: Dict[str, Any]) -> None:
        raise NotImplementedError()

    @staticmethod
    def _get_all_nested_stages(instance: Any) -> List[Any]:
        """
        Return ``instance`` followed by all of its (recursively) nested child stages.
        Non-meta instances are leaves and contribute only themselves.
        """
        if isinstance(instance, MetaAlgorithmReadWrite):
            child_stages = instance._get_child_stages()
        else:
            child_stages = []

        nested_stages = []
        for stage in child_stages:
            nested_stages.extend(MetaAlgorithmReadWrite._get_all_nested_stages(stage))

        return [instance] + nested_stages

    @staticmethod
    def get_uid_map(instance: Any) -> Dict[str, Any]:
        """
        Map every nested stage's uid to the stage object, raising if two
        stages share a uid (which would make the saved tree ambiguous).
        """
        all_nested_stages = MetaAlgorithmReadWrite._get_all_nested_stages(instance)
        uid_map = {stage.uid: stage for stage in all_nested_stages}
        if len(all_nested_stages) != len(uid_map):
            # Note: a space follows the class name so the message reads
            # "pkg.Class is a compound estimator ..." rather than fusing the words.
            raise RuntimeError(
                f"{instance.__class__.__module__}.{instance.__class__.__name__} "
                f"is a compound estimator with stages with duplicate "
                f"UIDs. List of UIDs: {list(uid_map.keys())}."
            )
        return uid_map