|
4 | 4 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
5 | 5 |
|
6 | 6 | """Unit tests for model frameworks. Includes tests for:
|
7 |
| - - SparkPipelineModel |
| 7 | +- SparkPipelineModel |
8 | 8 | """
|
| 9 | + |
9 | 10 | import os
|
10 | 11 | import shutil
|
11 | 12 | import tempfile
|
@@ -55,7 +56,6 @@ def generate_data1():
|
55 | 56 |
|
56 | 57 |
|
57 | 58 | def build_spark_pipeline1(training, test):
|
58 |
| - |
59 | 59 | # Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
|
60 | 60 | tokenizer = Tokenizer(inputCol="text", outputCol="words")
|
61 | 61 | hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
|
@@ -138,85 +138,85 @@ def test_serialize_with_incorrect_model_file_name_pt(self, model_data):
|
138 | 138 | as_onnx=True, model_file_name="model.onnx"
|
139 | 139 | )
|
140 | 140 |
|
141 |
| - @pytest.mark.parametrize("model_data", model_group) |
142 |
| - def test_bad_inputs(self, model_data): |
143 |
| - """ |
144 |
| - { |
145 |
| - "training": training1, |
146 |
| - "test": test1, |
147 |
| - "model": model1, |
148 |
| - "pred": pred1, |
149 |
| - "spark_model": spark_model1, |
150 |
| - "artifact_dir":artifact_dir1, |
151 |
| - } |
152 |
| - """ |
153 |
| - model = model_data["spark_model"] |
154 |
| - test = model_data["test"] |
155 |
| - pred = model_data["pred"] |
156 |
| - model.prepare( |
157 |
| - inference_conda_env=self.inference_conda_env, |
158 |
| - model_file_name=self.model_file_name, |
159 |
| - inference_python_version=self.inference_python_version, |
160 |
| - force_overwrite=True, |
161 |
| - training_id=None, |
162 |
| - X_sample=test, |
163 |
| - y_sample=pred, |
164 |
| - ) |
165 |
| - with pytest.raises(AttributeError): |
166 |
| - model.prepare( |
167 |
| - inference_conda_env=self.inference_conda_env, |
168 |
| - model_file_name=self.model_file_name, |
169 |
| - inference_python_version=self.inference_python_version, |
170 |
| - force_overwrite=True, |
171 |
| - training_id=None, |
172 |
| - X_sample=test, |
173 |
| - y_sample=pred, |
174 |
| - as_onnx=True, |
175 |
| - ) |
176 |
| - with pytest.raises(TypeError): |
177 |
| - model.prepare( |
178 |
| - inference_conda_env=self.inference_conda_env, |
179 |
| - model_file_name=self.model_file_name, |
180 |
| - inference_python_version=self.inference_python_version, |
181 |
| - force_overwrite=True, |
182 |
| - training_id=None, |
183 |
| - ) |
184 |
| - |
185 |
| - with pytest.raises(ValueError): |
186 |
| - model.prepare( |
187 |
| - inference_conda_env=self.inference_conda_env, |
188 |
| - model_file_name=self.model_file_name, |
189 |
| - inference_python_version=self.inference_python_version, |
190 |
| - force_overwrite=False, |
191 |
| - training_id=None, |
192 |
| - X_sample=test, |
193 |
| - y_sample=pred, |
194 |
| - ) |
195 |
| - |
196 |
| - assert ( |
197 |
| - pred == model.verify(test)["prediction"] |
198 |
| - ), "normal verify, normal test is failing" |
199 |
| - assert ( |
200 |
| - pred == model.verify(test.take(test.count()))["prediction"] |
201 |
| - ), "spark sql DF sampling not working in verify" |
202 |
| - assert ( |
203 |
| - pred == model.verify(test.toPandas())["prediction"] |
204 |
| - ), "spark sql converting to pandas not working in verify" |
205 |
| - if version.parse(spark.version) >= version.parse("3.2.0"): |
206 |
| - assert ( |
207 |
| - pred == model.verify(test.to_pandas_on_spark())["prediction"] |
208 |
| - ), "spark sql converting to pandas on spark not working in verify" |
209 |
| - assert ( |
210 |
| - pred[:1] == model.verify(test.toJSON().collect()[0])["prediction"] |
211 |
| - ), "failed when passing in a single json serialized row as a str" |
212 |
| - assert ( |
213 |
| - pred[:2] == model.verify(test.toPandas().head(2))["prediction"] |
214 |
| - ), "failed when passing in a pandas df" |
215 |
| - |
216 |
| - with pytest.raises(TypeError): |
217 |
| - model.verify(test.take(0)) |
218 |
| - with pytest.raises(Exception): |
219 |
| - model.verify(np.ones(test.toPandas().shape)) |
# NOTE(review): this test used to be disabled by commenting out every line,
# which hides it from pytest's collection and reports.  It is kept disabled
# here via an explicit skip marker so that it still shows up as "skipped".
# The reason it was disabled is not recorded in this file -- confirm with the
# author and either fix and re-enable it or delete it.
@pytest.mark.skip(
    reason="test_bad_inputs is disabled pending review (was previously commented out)"
)
@pytest.mark.parametrize("model_data", model_group)
def test_bad_inputs(self, model_data):
    """Check ``prepare()`` / ``verify()`` behaviour on bad and alternative inputs.

    This is a method of the enclosing test class: it reads
    ``self.inference_conda_env``, ``self.model_file_name`` and
    ``self.inference_python_version``.

    ``model_data`` is one entry of ``model_group``, a dict shaped like::

        {
            "training": training1,
            "test": test1,
            "model": model1,
            "pred": pred1,
            "spark_model": spark_model1,
            "artifact_dir": artifact_dir1,
        }
    """
    model = model_data["spark_model"]
    test = model_data["test"]
    pred = model_data["pred"]

    # Baseline: a regular prepare() call with samples is expected to succeed.
    model.prepare(
        inference_conda_env=self.inference_conda_env,
        model_file_name=self.model_file_name,
        inference_python_version=self.inference_python_version,
        force_overwrite=True,
        training_id=None,
        X_sample=test,
        y_sample=pred,
    )

    # Same call with as_onnx=True is expected to raise AttributeError.
    with pytest.raises(AttributeError):
        model.prepare(
            inference_conda_env=self.inference_conda_env,
            model_file_name=self.model_file_name,
            inference_python_version=self.inference_python_version,
            force_overwrite=True,
            training_id=None,
            X_sample=test,
            y_sample=pred,
            as_onnx=True,
        )

    # Omitting X_sample / y_sample is expected to raise TypeError.
    with pytest.raises(TypeError):
        model.prepare(
            inference_conda_env=self.inference_conda_env,
            model_file_name=self.model_file_name,
            inference_python_version=self.inference_python_version,
            force_overwrite=True,
            training_id=None,
        )

    # Preparing again with force_overwrite=False is expected to raise ValueError.
    with pytest.raises(ValueError):
        model.prepare(
            inference_conda_env=self.inference_conda_env,
            model_file_name=self.model_file_name,
            inference_python_version=self.inference_python_version,
            force_overwrite=False,
            training_id=None,
            X_sample=test,
            y_sample=pred,
        )

    # verify() must give the same predictions for each supported input form.
    assert (
        pred == model.verify(test)["prediction"]
    ), "normal verify, normal test is failing"
    assert (
        pred == model.verify(test.take(test.count()))["prediction"]
    ), "spark sql DF sampling not working in verify"
    assert (
        pred == model.verify(test.toPandas())["prediction"]
    ), "spark sql converting to pandas not working in verify"
    # to_pandas_on_spark() is only exercised on Spark >= 3.2.0.
    if version.parse(spark.version) >= version.parse("3.2.0"):
        assert (
            pred == model.verify(test.to_pandas_on_spark())["prediction"]
        ), "spark sql converting to pandas on spark not working in verify"
    assert (
        pred[:1] == model.verify(test.toJSON().collect()[0])["prediction"]
    ), "failed when passing in a single json serialized row as a str"
    assert (
        pred[:2] == model.verify(test.toPandas().head(2))["prediction"]
    ), "failed when passing in a pandas df"

    # Inputs that verify() is expected to reject.
    with pytest.raises(TypeError):
        model.verify(test.take(0))
    with pytest.raises(Exception):
        model.verify(np.ones(test.toPandas().shape))
220 | 220 |
|
221 | 221 |
|
222 | 222 | def teardown_module():
|
|
0 commit comments