2323ARISING IN ANY WAY OUT OF THE USE OF THE SOFTWARE CODE, EVEN IF ADVISED OF THE
2424POSSIBILITY OF SUCH DAMAGE.
2525"""
26+ import json
2627import os
2728import sys
2829import argparse
@@ -69,8 +70,9 @@ def main():
6970 "--model_name" ,
7071 type = str ,
7172 help = "Name of the Model" ,
72- default = "sklearn_regression_model .pkl" ,
73+ default = "diabetes_model .pkl" ,
7374 )
75+
7476 parser .add_argument (
7577 "--step_input" ,
7678 type = str ,
@@ -85,40 +87,58 @@ def main():
8587 model_name = args .model_name
8688 model_path = args .step_input
8789
90+ print ("Getting registration parameters" )
91+
92+ # Load the registration parameters from the parameters file
93+ with open ("parameters.json" ) as f :
94+ pars = json .load (f )
95+ try :
96+ register_args = pars ["registration" ]
97+ except KeyError :
98+ print ("Could not load registration values from file" )
99+ register_args = {"tags" : []}
100+
101+ model_tags = {}
102+ for tag in register_args ["tags" ]:
103+ try :
104+ mtag = run .parent .get_metrics ()[tag ]
105+ model_tags [tag ] = mtag
106+ except KeyError :
107+ print (f"Could not find { tag } metric on parent run." )
108+
88109 # load the model
89110 print ("Loading model from " + model_path )
90111 model_file = os .path .join (model_path , model_name )
91112 model = joblib .load (model_file )
92- model_mse = run .parent .get_metrics ()["mse" ]
93113 parent_tags = run .parent .get_tags ()
94114 try :
95115 build_id = parent_tags ["BuildId" ]
96116 except KeyError :
97117 build_id = None
98118 print ("BuildId tag not found on parent run." )
99- print ("Tags present: {parent_tags}" )
119+ print (f "Tags present: { parent_tags } " )
100120 try :
101121 build_uri = parent_tags ["BuildUri" ]
102122 except KeyError :
103123 build_uri = None
104124 print ("BuildUri tag not found on parent run." )
105- print ("Tags present: {parent_tags}" )
125+ print (f "Tags present: { parent_tags } " )
106126
107127 if (model is not None ):
108128 dataset_id = parent_tags ["dataset_id" ]
109129 if (build_id is None ):
110130 register_aml_model (
111131 model_file ,
112132 model_name ,
113- model_mse ,
133+ model_tags ,
114134 exp ,
115135 run_id ,
116136 dataset_id )
117137 elif (build_uri is None ):
118138 register_aml_model (
119139 model_file ,
120140 model_name ,
121- model_mse ,
141+ model_tags ,
122142 exp ,
123143 run_id ,
124144 dataset_id ,
@@ -127,7 +147,7 @@ def main():
127147 register_aml_model (
128148 model_file ,
129149 model_name ,
130- model_mse ,
150+ model_tags ,
131151 exp ,
132152 run_id ,
133153 dataset_id ,
@@ -152,7 +172,7 @@ def model_already_registered(model_name, exp, run_id):
152172def register_aml_model (
153173 model_path ,
154174 model_name ,
155- model_mse ,
175+ model_tags ,
156176 exp ,
157177 run_id ,
158178 dataset_id ,
@@ -162,8 +182,8 @@ def register_aml_model(
162182 try :
163183 tagsValue = {"area" : "diabetes_regression" ,
164184 "run_id" : run_id ,
165- "experiment_name" : exp .name ,
166- "mse" : model_mse }
185+ "experiment_name" : exp .name }
186+ tagsValue . update ( model_tags )
167187 if (build_id != 'none' ):
168188 model_already_registered (model_name , exp , run_id )
169189 tagsValue ["BuildId" ] = build_id
0 commit comments