Skip to content

Prevent using Azure Low Priority VMs when using Batch Managed mode #6267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions plugins/nf-azure/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ dependencies {
api('com.azure:azure-identity:1.15.1') {
exclude group: 'org.slf4j', module: 'slf4j-api'
}
api('com.azure.resourcemanager:azure-resourcemanager-batch:1.1.0-beta.4') {
exclude group: 'org.slf4j', module: 'slf4j-api'
}

// address security vulnerabilities
runtimeOnly 'io.netty:netty-handler:4.1.118.Final'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,36 @@ class AzBatchExecutor extends Executor implements ExtensionPoint {
}
}

/**
 * Check the configured pools for low priority VMs and make sure the Batch
 * account can run them.
 *
 * Nothing happens when no pool has {@code lowPriority} enabled. Otherwise the
 * pool allocation mode is requested from the batch service and:
 * - Batch Managed mode aborts the run, naming the offending pools;
 * - User Subscription mode is accepted (debug log only);
 * - any other value, including {@code null}, only produces a warning.
 *
 * @throws AbortOperationException when low priority pools are configured and
 *         the allocation mode is Batch Managed
 */
protected void validateLowPriorityVMs() {
    // Keep only the pool entries that request low priority nodes
    final flagged = config.batch().pools.findAll { name, opts ->
        opts.lowPriority
    }

    // Nothing to verify: skip the allocation mode lookup altogether
    if( !flagged )
        return

    final affected = flagged.keySet().join(', ')

    final mode = batchService.getPoolAllocationMode()
    log.debug "[AZURE BATCH] Pool allocation mode determined as: ${mode}"

    switch( mode ) {
        case 'BATCH_SERVICE':
        case 'BatchService':
            throw new AbortOperationException(
                "Low Priority VMs are not supported with Batch Managed pool allocation mode. " +
                "Update your configuration to use standard VMs or switch to User Subscription mode. " +
                "Pools: ${affected}."
            )

        case 'USER_SUBSCRIPTION':
        case 'UserSubscription':
            // Low priority VMs remain available in this mode, so no warning is needed
            log.debug "[AZURE BATCH] User Subscription mode detected, allowing low priority VMs in pools: ${affected}"
            break

        default:
            // The mode could not be resolved: warn the user but let the run continue
            log.warn "[AZURE BATCH] Unable to determine pool allocation mode. Low Priority VMs are configured in pools: ${affected}. " +
                "Low Priority VMs may not be supported. Set 'azure.batch.subscriptionId' in your config or 'AZURE_SUBSCRIPTION_ID' environment variable for automatic detection."
    }
}

protected void uploadBinDir() {
/*
* upload local binaries
Expand Down Expand Up @@ -120,6 +150,7 @@ class AzBatchExecutor extends Executor implements ExtensionPoint {
initBatchService()
validateWorkDir()
validatePathDir()
validateLowPriorityVMs()
uploadBinDir()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,103 @@ class AzBatchService implements Closeable {
return builder.buildClient()
}

/**
 * Resolve the pool allocation mode of the configured Azure Batch account.
 *
 * The lookup needs three inputs, checked in this order: an account name taken
 * from the batch endpoint, a subscription ID and a token credential. When any
 * of them is missing a debug message is logged and {@code null} is returned.
 * Any exception raised along the way is logged as a warning and also results
 * in {@code null}. The method is annotated with {@code @Memoized}.
 *
 * @return The pool allocation mode ('BatchService' or 'UserSubscription'), or null if it cannot be determined
 */
@Memoized
protected String getPoolAllocationMode() {
    try {
        // 1. batch account name, taken from the endpoint
        final name = extractAccountName(config.batch().endpoint)
        if (!name) {
            log.debug "[AZURE BATCH] Cannot extract account name from endpoint"
            return null
        }

        // 2. subscription that is searched for the account
        final subscription = config.batch().subscriptionId
        if (!subscription) {
            log.debug "[AZURE BATCH] No subscription ID configured. Set azure.batch.subscriptionId or AZURE_SUBSCRIPTION_ID"
            return null
        }

        // 3. token credential used to authenticate the management client
        final tokenCredential = getAzureCredential()
        if (!tokenCredential) {
            log.debug "[AZURE BATCH] No valid credentials for Azure Resource Manager"
            return null
        }

        // Build the management client and look the account up by name
        return findBatchAccountPoolMode(createBatchManager(tokenCredential, subscription), name)
    }
    catch (Exception err) {
        log.warn "[AZURE BATCH] Failed to determine pool allocation mode: ${err.message}", err
        return null
    }
}

/**
 * Extract the Batch account name from the batch endpoint URL.
 *
 * The expected format is {@code https://accountname.region.batch.azure.com};
 * the account name is the text between the optional URI scheme and the first dot.
 * Any scheme is stripped regardless of case, and surrounding whitespace is ignored.
 *
 * @param endpoint The batch endpoint as provided in the configuration; may be null
 * @return The account name, or {@code null} when the endpoint is null, blank,
 *         or has no text before the first dot
 */
private String extractAccountName(String endpoint) {
    if (!endpoint) return null

    // Strip any leading URI scheme ('http://', 'HTTPS://', ...). Removing only a
    // literal lowercase 'https://' would leave the scheme inside the account name.
    final withoutScheme = endpoint.trim().replaceFirst('(?i)^[a-z][a-z0-9+.-]*://', '')

    // Take everything before the first dot. indexOf/substring avoids indexing into
    // the empty array that split() returns for dot-only input.
    final firstDot = withoutScheme.indexOf('.')
    final name = firstDot < 0 ? withoutScheme : withoutScheme.substring(0, firstDot)

    // Report "nothing usable" as null, never as an empty string
    return name ?: null
}

/**
 * Pick the token credential matching the current configuration.
 *
 * Managed identity is checked first, then the Active Directory (service
 * principal) settings.
 *
 * @return The credential for the first configured mechanism, or {@code null}
 *         when neither is configured
 */
private TokenCredential getAzureCredential() {
    if (config.managedIdentity().isConfigured())
        return createBatchCredentialsWithManagedIdentity()

    if (config.activeDirectory().isConfigured())
        return createBatchCredentialsWithServicePrincipal()

    // Neither mechanism is configured
    return null
}

/**
 * Build an authenticated {@code BatchManager} for the given subscription.
 *
 * @param credential The token credential used to authenticate
 * @param subscriptionId The ID of the subscription the manager operates on
 * @return The manager produced by {@code BatchManager.configure().authenticate(...)}
 */
private com.azure.resourcemanager.batch.BatchManager createBatchManager(TokenCredential credential, String subscriptionId) {
    final environment = com.azure.core.management.AzureEnvironment.AZURE

    // AzureProfile arguments are (tenantId, subscriptionId, environment).
    // tenantId is passed as null so that the default from the credential is used.
    final azureProfile = new com.azure.core.management.profile.AzureProfile(null, subscriptionId, environment)

    // Go through configure() before authenticate() to ensure proper initialization
    final configurable = com.azure.resourcemanager.batch.BatchManager.configure()
    return configurable.authenticate(credential, azureProfile)
}

/**
 * Look up a Batch account by name and report its pool allocation mode.
 *
 * The accounts returned by {@code batchManager.batchAccounts().list()} are
 * scanned in order and the first one with a matching name is used.
 *
 * @param batchManager The manager used to list the Batch accounts
 * @param accountName The account name to match exactly
 * @return The string form of the account's pool allocation mode, or
 *         {@code null} when the account is not listed or reports no mode
 */
private String findBatchAccountPoolMode(com.azure.resourcemanager.batch.BatchManager batchManager, String accountName) {
    log.debug "[AZURE BATCH] Searching for account '${accountName}'"

    for (candidate in batchManager.batchAccounts().list()) {
        // Skip every account except the one being looked for
        if (candidate.name() != accountName)
            continue

        final mode = candidate.poolAllocationMode()
        log.debug "[AZURE BATCH] Found account with pool allocation mode: ${mode}"
        return mode == null ? null : mode.toString()
    }

    log.debug "[AZURE BATCH] Account '${accountName}' not found in subscription"
    return null
}

AzTaskKey submitTask(TaskRun task) {
final poolId = getOrCreatePool(task)
final jobId = getOrCreateJob(poolId, task)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class AzBatchOpts implements CloudTransferOptions {
String accountKey
String endpoint
String location
String subscriptionId
Boolean autoPoolMode
Boolean allowPoolCreation
Boolean terminateJobsOnCompletion
Expand All @@ -67,6 +68,7 @@ class AzBatchOpts implements CloudTransferOptions {
accountKey = config.accountKey ?: sysEnv.get('AZURE_BATCH_ACCOUNT_KEY')
endpoint = config.endpoint
location = config.location
subscriptionId = config.subscriptionId ?: sysEnv.get('AZURE_SUBSCRIPTION_ID')
autoPoolMode = config.autoPoolMode
allowPoolCreation = config.allowPoolCreation
terminateJobsOnCompletion = config.terminateJobsOnCompletion != Boolean.FALSE
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package nextflow.cloud.azure.batch

import nextflow.Session
import nextflow.cloud.azure.config.AzConfig
import nextflow.cloud.azure.config.AzBatchOpts
import nextflow.cloud.azure.config.AzPoolOpts
import nextflow.exception.AbortOperationException
import spock.lang.Specification

/**
 * Tests for the low priority VM validation logic of {@link AzBatchExecutor}.
 *
 * {@link AzBatchService} is always mocked, so no Azure call is made; the only
 * behaviour driven by the mock is the value returned by {@code getPoolAllocationMode()}.
 */
class AzBatchExecutorTest extends Specification {

    /** Endpoint used by every scenario. */
    private static final String ENDPOINT = 'https://testaccount.eastus.batch.azure.com'

    /**
     * Create an executor wired with the given pool definitions and batch service.
     *
     * @param pools Pool name to pool options map, as it would appear under {@code azure.batch.pools}
     * @param batchService The (mocked) service the executor should query
     */
    private AzBatchExecutor createExecutor(Map pools, AzBatchService batchService) {
        def executor = new AzBatchExecutor()
        executor.config = new AzConfig([batch: [endpoint: ENDPOINT, pools: pools]])
        executor.batchService = batchService
        return executor
    }

    def 'should validate low priority VMs for BatchService allocation mode'() {
        given:
        def batchService = Mock(AzBatchService) {
            getPoolAllocationMode() >> MODE
        }

        and:
        def executor = createExecutor([
            'pool1': [vmType: 'Standard_D2_v2', lowPriority: true],
            'pool2': [vmType: 'Standard_D2_v2', lowPriority: false]
        ], batchService)

        when:
        executor.validateLowPriorityVMs()

        then:
        def e = thrown(AbortOperationException)
        e.message.contains('Low Priority VMs are not supported with Batch Managed pool allocation mode')
        e.message.contains('Update your configuration to use standard VMs or switch to User Subscription mode')

        and: 'only the pool that requests low priority VMs is reported'
        e.message.contains('pool1')
        !e.message.contains('pool2')

        where: 'both spellings accepted by the validation are covered'
        MODE << ['BatchService', 'BATCH_SERVICE']
    }

    def 'should allow low priority VMs for UserSubscription allocation mode'() {
        given:
        def batchService = Mock(AzBatchService) {
            getPoolAllocationMode() >> MODE
        }

        and:
        def executor = createExecutor([
            'pool1': [vmType: 'Standard_D2_v2', lowPriority: true]
        ], batchService)

        when:
        executor.validateLowPriorityVMs()

        then:
        noExceptionThrown()

        where: 'both spellings accepted by the validation are covered'
        MODE << ['UserSubscription', 'USER_SUBSCRIPTION']
    }

    def 'should handle unknown allocation mode gracefully'() {
        given:
        def batchService = Mock(AzBatchService) {
            getPoolAllocationMode() >> null
        }

        and:
        def executor = createExecutor([
            'pool1': [vmType: 'Standard_D2_v2', lowPriority: true]
        ], batchService)

        when:
        executor.validateLowPriorityVMs()

        then:
        noExceptionThrown()
    }

    def 'should not validate when no low priority VMs configured'() {
        given:
        def batchService = Mock(AzBatchService)

        and:
        def executor = createExecutor([
            'pool1': [vmType: 'Standard_D2_v2', lowPriority: false],
            'pool2': [vmType: 'Standard_D2_v2']
        ], batchService)

        when:
        executor.validateLowPriorityVMs()

        then: 'the allocation mode is never looked up'
        0 * batchService.getPoolAllocationMode()
        noExceptionThrown()
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,26 @@ class AzBatchOptsTest extends Specification {
then:
opts3.jobMaxWallClockTime.toString() == '12h'
}

def 'should set subscription ID from config or environment' () {
    given: 'one config map and one environment map that both carry a subscription ID'
    def cfgWithId = [subscriptionId: 'config-sub-id']
    def envWithId = [AZURE_SUBSCRIPTION_ID: 'env-sub-id']

    expect: 'no subscription ID when neither source defines one'
    new AzBatchOpts([:], [:]).subscriptionId == null

    and: 'the config value is used when only the config defines it'
    new AzBatchOpts(cfgWithId, [:]).subscriptionId == 'config-sub-id'

    and: 'the environment value is used when only the environment defines it'
    new AzBatchOpts([:], envWithId).subscriptionId == 'env-sub-id'

    and: 'config takes precedence over environment'
    new AzBatchOpts(cfgWithId, envWithId).subscriptionId == 'config-sub-id'
}
}