Skip to content

Commit 35ce161

Browse files
Add NVIDIA NGC image aliases (#384)
* Add NVIDIA NGC image alias * Apply suggestions from code review * Restyled by gofmt Co-authored-by: Restyled.io <commits@restyled.io>
1 parent d8de27f commit 35ce161

File tree

4 files changed

+36
-7
lines changed

4 files changed

+36
-7
lines changed

task/aws/resources/data_source_image.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ func (i *Image) Read(ctx context.Context) error {
3434
image := i.Identifier
3535
images := map[string]string{
3636
"ubuntu": "ubuntu@099720109477:x86_64:*ubuntu/images/hvm-ssd/ubuntu-focal-20.04*",
37+
"nvidia": "ubuntu@679593333241:x86_64:NVIDIA Deep Learning AMI v21.02.2-*",
3738
}
3839
if val, ok := images[image]; ok {
3940
image = val

task/az/resources/resource_virtual_machine_scale_set.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,13 @@ func (v *VirtualMachineScaleSet) Create(ctx context.Context) error {
8282
image := v.Attributes.Environment.Image
8383
images := map[string]string{
8484
"ubuntu": "ubuntu@Canonical:0001-com-ubuntu-server-focal:20_04-lts:latest",
85+
"nvidia": "ubuntu@nvidia:ngc_base_image_version_b:gen2_21-11-0:latest#plan",
8586
}
8687
if val, ok := images[image]; ok {
8788
image = val
8889
}
8990

90-
imageParts := regexp.MustCompile(`^([^@]+)@([^:]+):([^:]+):([^:]+):([^:]+)$`).FindStringSubmatch(image)
91+
imageParts := regexp.MustCompile(`^([^@]+)@([^:]+):([^:]+):([^:]+):([^:]+)(:?(#plan)?)$`).FindStringSubmatch(image)
9192
if imageParts == nil {
9293
return errors.New("invalid machine image format: use publisher:offer:sku:version")
9394
}
@@ -97,6 +98,7 @@ func (v *VirtualMachineScaleSet) Create(ctx context.Context) error {
9798
offer := imageParts[3]
9899
sku := imageParts[4]
99100
version := imageParts[5]
101+
plan := imageParts[6]
100102

101103
size := v.Attributes.Size.Machine
102104
sizes := map[string]string{
@@ -185,6 +187,14 @@ func (v *VirtualMachineScaleSet) Create(ctx context.Context) error {
185187
},
186188
}
187189

190+
if plan == "#plan" {
191+
settings.Plan = &compute.Plan{
192+
Publisher: to.StringPtr(publisher),
193+
Product: to.StringPtr(offer),
194+
Name: to.StringPtr(sku),
195+
}
196+
}
197+
188198
spot := v.Attributes.Spot
189199
if spot >= 0 {
190200
if spot == 0 {

task/gcp/resources/data_source_image.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,25 +32,35 @@ func (i *Image) Read(ctx context.Context) error {
3232
image := i.Identifier
3333
images := map[string]string{
3434
"ubuntu": "ubuntu@ubuntu-os-cloud/ubuntu-2004-lts",
35+
"nvidia": "ubuntu@nvidia-ngc-public/nvidia-gpu-cloud-image-20211105",
3536
}
3637
if val, ok := images[image]; ok {
3738
image = val
3839
}
3940

4041
match := regexp.MustCompile(`^([^@]+)@([^/]+)/([^/]+)$`).FindStringSubmatch(image)
4142
if match == nil {
42-
return common.NotFoundError
43+
return errors.New("wrong image name")
4344
}
4445

4546
i.Attributes.SSHUser = match[1]
4647
project := match[2]
47-
family := match[3]
48+
imageOrFamily := match[3]
4849

49-
resource, err := i.Client.Services.Compute.Images.GetFromFamily(project, family).Do()
50+
resource, err := i.Client.Services.Compute.Images.Get(project, imageOrFamily).Do()
5051
if err != nil {
5152
var e *googleapi.Error
5253
if errors.As(err, &e) && e.Code == 404 {
53-
return common.NotFoundError
54+
resource, err := i.Client.Services.Compute.Images.GetFromFamily(project, imageOrFamily).Do()
55+
if err != nil {
56+
var e *googleapi.Error
57+
if errors.As(err, &e) && e.Code == 404 {
58+
return common.NotFoundError
59+
}
60+
return err
61+
}
62+
i.Resource = resource
63+
return nil
5464
}
5565
return err
5666
}

task/k8s/resources/resource_job.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,19 @@ func (j *Job) Create(ctx context.Context) error {
6868
"l+v100": "32-256000+nvidia-tesla-v100*4",
6969
"xl+v100": "64-512000+nvidia-tesla-v100*8",
7070
}
71-
7271
if val, ok := sizes[size]; ok {
7372
size = val
7473
}
7574

75+
image := j.Attributes.Task.Environment.Image
76+
images := map[string]string{
77+
"ubuntu": "ubuntu",
78+
"nvidia": "nvidia/cuda",
79+
}
80+
if val, ok := images[image]; ok {
81+
image = val
82+
}
83+
7684
match := regexp.MustCompile(`^(\d+)-(\d+)(?:\+([^*]+)\*([1-9]\d*))?$`).FindStringSubmatch(size)
7785
if match == nil {
7886
return common.NotFoundError
@@ -206,7 +214,7 @@ func (j *Job) Create(ctx context.Context) error {
206214
Containers: []kubernetes_core.Container{
207215
{
208216
Name: j.Identifier,
209-
Image: j.Attributes.Task.Environment.Image,
217+
Image: image,
210218
Resources: kubernetes_core.ResourceRequirements{
211219
Limits: jobLimits,
212220
Requests: kubernetes_core.ResourceList{

0 commit comments

Comments
 (0)