1
+ from datetime import timedelta
1
2
from typing import List , Union
2
3
3
4
import pytest
4
5
import yaml
5
6
from msrest import Serializer
6
-
7
- from azure .ai .ml ._restclient .v2023_04_01_preview .models import DataFactory
8
7
from test_utilities .utils import verify_entity_load_and_dump
9
8
10
9
from azure .ai .ml import load_compute
10
+ from azure .ai .ml ._restclient .v2023_04_01_preview .models import DataFactory
11
11
from azure .ai .ml ._restclient .v2023_08_01_preview .models import ComputeResource , ImageMetadata
12
12
from azure .ai .ml .constants ._compute import CustomApplicationDefaults
13
13
from azure .ai .ml .entities import (
17
17
KubernetesCompute ,
18
18
ManagedIdentityConfiguration ,
19
19
SynapseSparkCompute ,
20
- VirtualMachineCompute ,
21
20
UnsupportedCompute ,
21
+ VirtualMachineCompute ,
22
22
)
23
23
24
24
@@ -66,6 +66,9 @@ def test_compute_from_yaml(self):
66
66
)[0 ]
67
67
assert compute .ssh_settings .admin_username == "azureuser"
68
68
assert compute .identity .type == "user_assigned"
69
+ assert compute .idle_time_before_scale_down == 100
70
+ assert compute .min_instances == 0
71
+ assert compute .max_instances == 2
69
72
70
73
rest_intermediate = compute ._to_rest_object ()
71
74
assert rest_intermediate .properties .compute_type == "AmlCompute"
@@ -76,7 +79,9 @@ def test_compute_from_yaml(self):
76
79
assert rest_intermediate .tags is not None
77
80
assert rest_intermediate .tags ["test" ] == "true"
78
81
assert rest_intermediate .properties .disable_local_auth is False
79
- assert rest_intermediate .properties .properties .remote_login_port_public_access == "Enabled"
82
+ assert rest_intermediate .properties .properties .scale_settings .max_node_count == 2
83
+ assert rest_intermediate .properties .properties .scale_settings .min_node_count == 0
84
+ assert rest_intermediate .properties .properties .scale_settings .node_idle_time_before_scale_down == "PT1M40S"
80
85
81
86
serializer = Serializer ({"ComputeResource" : ComputeResource })
82
87
body = serializer .body (rest_intermediate , "ComputeResource" )
@@ -101,6 +106,9 @@ def test_aml_compute_from_yaml_with_disable_public_access(self):
101
106
assert rest_intermediate .properties .disable_local_auth is True
102
107
assert rest_intermediate .location == compute .location
103
108
assert rest_intermediate .properties .properties .remote_login_port_public_access == "NotSpecified"
109
+ assert rest_intermediate .properties .properties .scale_settings .max_node_count == 4
110
+ assert rest_intermediate .properties .properties .scale_settings .min_node_count == 0
111
+ assert rest_intermediate .properties .properties .scale_settings .node_idle_time_before_scale_down == "PT2M"
104
112
105
113
def test_aml_compute_from_yaml_with_creds_and_disable_public_access (self ):
106
114
compute : AmlCompute = load_compute ("tests/test_configs/compute/compute-aml-no-identity.yaml" )
@@ -345,6 +353,11 @@ def validate_no_public_ip(compute: Compute):
345
353
assert compute .enable_node_public_ip == False
346
354
compute_resource = compute ._to_rest_object ()
347
355
assert compute_resource .properties .properties .enable_node_public_ip == False
356
+ # AmlCompute _from_rest_object expects a timedelta object for node_idle_time_before_scale_down
357
+ if compute_resource .properties .compute_type == "AmlCompute" :
358
+ compute_resource .properties .properties .scale_settings .node_idle_time_before_scale_down = timedelta (
359
+ seconds = 120
360
+ )
348
361
compute_from_rest = Compute ._from_rest_object (compute_resource )
349
362
assert compute_from_rest .enable_node_public_ip == False
350
363
0 commit comments