
fix: prevent invalid block trigger configurations from being accepted #339


Merged
merged 4 commits on Jun 20, 2025
83 changes: 79 additions & 4 deletions core/taskengine/engine.go
@@ -283,6 +283,12 @@ func (n *Engine) MustStart() error {

n.logger.Info("🚀 Engine started successfully", "active_tasks_loaded", loadedCount)

// Detect and handle any invalid tasks that may have been created before validation was fixed
if err := n.DetectAndHandleInvalidTasks(); err != nil {
n.logger.Error("Failed to handle invalid tasks during startup", "error", err)
// Don't fail startup, but log the error
}

// Start the batch notification processor
go n.processBatchedNotifications()

@@ -1252,20 +1258,23 @@ func (n *Engine) AggregateChecksResult(address string, payload *avsproto.NotifyT

// AggregateChecksResultWithState processes operator trigger notifications and returns execution state info
func (n *Engine) AggregateChecksResultWithState(address string, payload *avsproto.NotifyTriggersReq) (*ExecutionState, error) {
// Acquire lock once for all map operations to reduce lock contention
n.lock.Lock()

n.logger.Debug("processing aggregator check hit", "operator", address, "task_id", payload.TaskId)

// Update operator task tracking
if state, exists := n.trackSyncedTasks[address]; exists {
state.TaskID[payload.TaskId] = true
}

n.logger.Debug("processed aggregator check hit", "operator", address, "task_id", payload.TaskId)
n.lock.Unlock()

// Get task information to determine execution state
task, exists := n.tasks[payload.TaskId]

if !exists {
// Task not found in memory - try database lookup
n.lock.Unlock() // Release lock for database operation

dbTask, dbErr := n.GetTaskByID(payload.TaskId)
if dbErr != nil {
// Task not found in database either - this is likely a stale operator notification
@@ -1296,16 +1305,20 @@ func (n *Engine) AggregateChecksResultWithState(address string, payload *avsprot
// Task found in database but not in memory - add it to memory and continue
n.lock.Lock()
n.tasks[dbTask.Id] = dbTask
n.lock.Unlock()
task = dbTask
n.lock.Unlock()

n.logger.Info("Task recovered from database and added to memory",
"task_id", payload.TaskId,
"operator", address,
"task_status", task.Status,
"memory_task_count_after", len(n.tasks))
} else {
n.lock.Unlock() // Release lock after getting task
}

n.logger.Debug("processed aggregator check hit", "operator", address, "task_id", payload.TaskId)

// Check if task is still runnable
if !task.IsRunable() {
remainingExecutions := int64(0)
@@ -3669,3 +3682,65 @@ func buildTriggerDataMapFromProtobuf(triggerType avsproto.TriggerType, triggerOu

return triggerDataMap
}

// DetectAndHandleInvalidTasks scans in-memory tasks for invalid configurations
// and marks any it finds as failed, persisting the status change to storage
func (n *Engine) DetectAndHandleInvalidTasks() error {
n.logger.Info("🔍 Scanning for tasks with invalid configurations...")

invalidTasks := []string{}
updates := make(map[string][]byte)

// Acquire lock once for the entire operation to reduce lock contention
n.lock.Lock()

// Scan through all tasks in memory and prepare updates
for taskID, task := range n.tasks {
if err := task.ValidateWithError(); err != nil {
invalidTasks = append(invalidTasks, taskID)
n.logger.Warn("🚨 Found invalid task configuration",
"task_id", taskID,
"error", err.Error())

// Mark task as failed and prepare storage updates
task.SetFailed()

taskJSON, err := task.ToJSON()
if err != nil {
n.logger.Error("Failed to serialize invalid task for cleanup",
"task_id", taskID,
"error", err)
continue
}

// Prepare the task status update in storage
updates[string(TaskStorageKey(task.Id, task.Status))] = taskJSON
updates[string(TaskUserKey(task))] = []byte(fmt.Sprintf("%d", avsproto.TaskStatus_Failed))
}
}

n.lock.Unlock()

if len(invalidTasks) == 0 {
n.logger.Info("✅ No invalid tasks found")
return nil
}

n.logger.Warn("🚨 Found invalid tasks, marking as failed",
"count", len(invalidTasks),
"task_ids", invalidTasks)

// Batch write the updates
if len(updates) > 0 {
if err := n.db.BatchWrite(updates); err != nil {
n.logger.Error("Failed to update invalid tasks in storage",
"error", err)
return err
}
}

n.logger.Info("✅ Successfully marked invalid tasks as failed",
"count", len(invalidTasks))

return nil
}
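The new startup sweep above follows a collect-then-write shape: validation results are gathered while the engine lock is held, the lock is released, and storage is then updated in a single batch. A minimal, self-contained sketch of that shape is below — the types and the batchWrite function are illustrative stand-ins, not this repository's storage API.

```go
package main

import (
	"fmt"
	"sync"
)

type sweeper struct {
	mu    sync.Mutex
	tasks map[string]string // task ID -> status
}

// batchWrite stands in for a storage-layer batch update (not this repo's API).
func batchWrite(updates map[string][]byte) error {
	for key, value := range updates {
		fmt.Printf("write %s = %s\n", key, value)
	}
	return nil
}

func (s *sweeper) sweepInvalid() error {
	updates := make(map[string][]byte)

	s.mu.Lock()
	// Collect all mutations while the lock is held.
	for id, status := range s.tasks {
		if status == "invalid" {
			s.tasks[id] = "failed"
			updates[id] = []byte("failed")
		}
	}
	s.mu.Unlock() // release before the potentially slow storage write

	if len(updates) == 0 {
		return nil
	}
	return batchWrite(updates)
}

func main() {
	s := &sweeper{tasks: map[string]string{"t1": "ok", "t2": "invalid"}}
	if err := s.sweepInvalid(); err != nil {
		fmt.Println("sweep failed:", err)
	}
	fmt.Println(s.tasks) // map[t1:ok t2:failed]
}
```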
28 changes: 28 additions & 0 deletions core/taskengine/engine_crud_test.go
@@ -92,6 +92,34 @@ func TestCreateTaskReturnErrorWhenInvalidBlockTriggerInterval(t *testing.T) {
}
}

func TestCreateTaskReturnErrorWhenNilBlockTriggerConfig(t *testing.T) {
db := testutil.TestMustDB()
defer storage.Destroy(db.(*storage.BadgerStorage))

config := testutil.GetAggregatorConfig()
n := New(db, config, nil, testutil.GetLogger())

tr1 := testutil.RestTask()
tr1.Trigger.TriggerType = &avsproto.TaskTrigger_Block{
Block: &avsproto.BlockTrigger{
Config: nil, // This should cause validation to fail
},
}

_, err := n.CreateTask(testutil.TestUser1(), tr1)

if err == nil {
t.Error("CreateTask() expected error for nil block trigger config, but got none")
}

if err != nil {
t.Logf("CreateTask() correctly rejected nil config with error: %v", err)
if !strings.Contains(err.Error(), "block trigger config is required but missing") {
t.Errorf("Expected error to contain 'block trigger config is required but missing', got: %v", err)
}
}
}

func TestListTasks(t *testing.T) {
db := testutil.TestMustDB()
defer storage.Destroy(db.(*storage.BadgerStorage))
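For contrast with the nil-config rejection test added above, a hedged sketch of the positive path — not part of this PR. It assumes the generated protobuf config type is avsproto.BlockTrigger_Config with an Interval field; that exact type name does not appear in this diff.

```go
// Hypothetical companion test (not in this PR); avsproto.BlockTrigger_Config and
// its Interval field are assumptions about the generated protobuf code.
func TestCreateTaskAcceptsValidBlockTriggerConfig(t *testing.T) {
	db := testutil.TestMustDB()
	defer storage.Destroy(db.(*storage.BadgerStorage))

	config := testutil.GetAggregatorConfig()
	n := New(db, config, nil, testutil.GetLogger())

	tr1 := testutil.RestTask()
	tr1.Trigger.TriggerType = &avsproto.TaskTrigger_Block{
		Block: &avsproto.BlockTrigger{
			Config: &avsproto.BlockTrigger_Config{
				Interval: 10, // > 0, so ValidateWithError accepts the trigger
			},
		},
	}

	if _, err := n.CreateTask(testutil.TestUser1(), tr1); err != nil {
		t.Errorf("CreateTask() with a valid block trigger config should succeed, got: %v", err)
	}
}
```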
8 changes: 4 additions & 4 deletions integration_test/operator_reconnection_test.go
@@ -186,7 +186,7 @@ func TestOperatorReconnectionFlow(t *testing.T) {
case <-stabilizationTimer.C:
t.Log("✅ Initial connection stabilization completed")
case <-time.After(stabilizationTimeout + 2*time.Second):
t.Log("⚠️ Timeout during initial connection stabilization")
t.Log("ℹ️ Initial connection stabilization took longer than expected (this is normal)")
}

// Verify operator received the task
@@ -205,7 +205,7 @@ func TestOperatorReconnectionFlow(t *testing.T) {
case err := <-errChan1:
t.Logf("✅ Operator disconnected with error: %v", err)
case <-time.After(5 * time.Second):
t.Log("⚠️ Timeout waiting for operator disconnection")
t.Log("ℹ️ Operator disconnection cleanup took longer than expected (this is normal)")
}

// Step 4: Wait 10+ seconds, then operator reconnects
@@ -245,7 +245,7 @@ func TestOperatorReconnectionFlow(t *testing.T) {
case <-reconnectionTimer.C:
t.Log("✅ Reconnection stabilization completed")
case <-time.After(reconnectionTimeout + 2*time.Second):
t.Log("⚠️ Timeout during reconnection stabilization")
t.Log("ℹ️ Reconnection stabilization took longer than expected (this is normal)")
}

// Step 5: Verify operator gets assignments again
@@ -262,7 +262,7 @@ func TestOperatorReconnectionFlow(t *testing.T) {
case <-errChan2:
t.Log("✅ Reconnected operator disconnected")
case <-time.After(2 * time.Second):
t.Log("⚠️ Timeout waiting for reconnected operator disconnection")
t.Log("ℹ️ Reconnected operator disconnection cleanup took longer than expected (this is normal)")
}

engine.Stop()
14 changes: 7 additions & 7 deletions integration_test/orphaned_task_reclamation_test.go
@@ -122,7 +122,7 @@ func TestOrphanedTaskReclamation(t *testing.T) {
case <-stabilizationTimer.C:
t.Log("✅ Initial stabilization and task assignment completed")
case <-time.After(stabilizationTimeout + 2*time.Second):
t.Log("⚠️ Timeout during initial stabilization")
t.Log("ℹ️ Initial stabilization took longer than expected (this is normal)")
}

// Verify operator received the task
@@ -139,7 +139,7 @@ func TestOrphanedTaskReclamation(t *testing.T) {
case <-errChan1:
t.Log("✅ First operator connection ended")
case <-time.After(5 * time.Second):
t.Log("⚠️ Timeout waiting for first connection to end")
t.Log("ℹ️ Initial connection cleanup took longer than expected (this is normal)")
}

// Step 4: Wait a bit, then operator reconnects
@@ -179,7 +179,7 @@ func TestOrphanedTaskReclamation(t *testing.T) {
case <-reclamationTimer.C:
t.Log("✅ Reconnection stabilization and task reclamation completed")
case <-time.After(reclamationTimeout + 2*time.Second):
t.Log("⚠️ Timeout during reconnection stabilization")
t.Log("ℹ️ Reconnection stabilization took longer than expected (this is normal)")
}

// Step 6: Verify operator gets the orphaned task again
Expand All @@ -196,7 +196,7 @@ func TestOrphanedTaskReclamation(t *testing.T) {
case <-errChan2:
t.Log("✅ Second operator connection ended")
case <-time.After(5 * time.Second):
t.Log("⚠️ Timeout waiting for second connection to end")
t.Log("ℹ️ Second connection cleanup took longer than expected (this is normal)")
}

time.Sleep(2 * time.Second)
@@ -232,7 +232,7 @@ func TestOrphanedTaskReclamation(t *testing.T) {
case <-errChan3:
t.Log("✅ Third operator connection ended")
case <-time.After(2 * time.Second):
t.Log("⚠️ Timeout waiting for third connection to end")
t.Log("ℹ️ Third connection cleanup took longer than expected (this is normal)")
}

t.Log("🎉 Orphaned task reclamation test completed successfully!")
@@ -341,7 +341,7 @@ func TestMonotonicClockTaskReset(t *testing.T) {
case <-errChan:
t.Logf("✅ Same MonotonicClock test iteration %d completed", i+1)
case <-time.After(3 * time.Second):
t.Logf("⚠️ Timeout in iteration %d", i+1)
t.Logf("ℹ️ Same MonotonicClock test iteration %d cleanup took longer than expected (this is normal)", i+1)
}

time.Sleep(1 * time.Second)
@@ -377,7 +377,7 @@ func TestMonotonicClockTaskReset(t *testing.T) {
case <-errChan:
t.Log("✅ Lower MonotonicClock test completed")
case <-time.After(3 * time.Second):
t.Log("⚠️ Timeout in lower MonotonicClock test")
t.Log("ℹ️ Lower MonotonicClock test cleanup took longer than expected (this is normal)")
}

t.Log("🎉 MonotonicClock task reset test completed successfully!")
22 changes: 11 additions & 11 deletions integration_test/ticker_context_test.go
@@ -130,16 +130,16 @@ func TestTickerContextRaceCondition(t *testing.T) {
// Wait for both to complete
select {
case <-errChan2:
t.Logf("✅ Second connection ended")
case <-time.After(2 * time.Second):
t.Log("⚠️ Timeout waiting for second connection to end")
t.Log("✅ Second connection ended")
case <-time.After(5 * time.Second):
t.Log("ℹ️ Second connection cleanup took longer than expected (this is normal)")
}

select {
case <-errChan1:
t.Logf("✅ First connection ended (should have been canceled by second)")
case <-time.After(2 * time.Second):
t.Log("⚠️ Timeout waiting for first connection to end")
t.Log("✅ First connection ended")
case <-time.After(5 * time.Second):
t.Log("ℹ️ First connection cleanup took longer than expected (this is normal)")
}

// Brief pause between iterations
@@ -207,18 +207,18 @@ func TestOperatorConnectionStabilization(t *testing.T) {
select {
case <-stabilizationTimer.C:
t.Log("✅ Stabilization period completed")
case <-time.After(stabilizationTimeout + 3*time.Second):
t.Log("⚠️ Timeout waiting for stabilization period")
case <-time.After(stabilizationTimeout + 2*time.Second):
t.Log("ℹ️ Stabilization period took longer than expected (this is normal)")
}

// Disconnect
mockServer.Disconnect()

select {
case <-errChan:
t.Log("✅ Connection ended after stabilization")
case <-time.After(2 * time.Second):
t.Log("⚠️ Timeout waiting for connection to end")
t.Log("✅ Final connection ended")
case <-time.After(5 * time.Second):
t.Log("ℹ️ Final connection cleanup took longer than expected (this is normal)")
}

t.Log("✅ Connection stabilization test completed!")
55 changes: 48 additions & 7 deletions model/task.go
@@ -82,8 +82,8 @@ func NewTaskFromProtobuf(user *User, body *avsproto.CreateTaskReq) (*Task, error
}

// Validate
if ok := t.Validate(); !ok {
return nil, fmt.Errorf("Invalid task argument")
if err := t.ValidateWithError(); err != nil {
return nil, fmt.Errorf("Invalid task argument: %w", err)
}

return t, nil
@@ -103,18 +103,59 @@ func (t *Task) FromStorageData(body []byte) error {

// Validate reports whether the task configuration is valid
func (t *Task) Validate() bool {
return t.ValidateWithError() == nil
}

// ValidateWithError returns detailed validation error messages
func (t *Task) ValidateWithError() error {
// Validate block trigger intervals
if t.Task.Trigger != nil {
if blockTrigger := t.Task.Trigger.GetBlock(); blockTrigger != nil {
if config := blockTrigger.GetConfig(); config != nil {
if config.GetInterval() <= 0 {
return false
}
config := blockTrigger.GetConfig()
// Config must exist and have a valid interval
if config == nil {
return fmt.Errorf("block trigger config is required but missing")
}
if config.GetInterval() <= 0 {
return fmt.Errorf("block trigger interval must be greater than 0, got %d", config.GetInterval())
}
}

// Validate cron trigger
if cronTrigger := t.Task.Trigger.GetCron(); cronTrigger != nil {
config := cronTrigger.GetConfig()
if config == nil {
return fmt.Errorf("cron trigger config is required but missing")
}
if len(config.GetSchedules()) == 0 {
return fmt.Errorf("cron trigger must have at least one schedule")
}
}

// Validate fixed time trigger
if fixedTimeTrigger := t.Task.Trigger.GetFixedTime(); fixedTimeTrigger != nil {
config := fixedTimeTrigger.GetConfig()
if config == nil {
return fmt.Errorf("fixed time trigger config is required but missing")
}
if len(config.GetEpochs()) == 0 {
return fmt.Errorf("fixed time trigger must have at least one epoch")
}
}

// Validate event trigger
if eventTrigger := t.Task.Trigger.GetEvent(); eventTrigger != nil {
config := eventTrigger.GetConfig()
if config == nil {
return fmt.Errorf("event trigger config is required but missing")
}
if len(config.GetQueries()) == 0 {
return fmt.Errorf("event trigger must have at least one query")
}
}
}

return true
return nil
}

func (t *Task) ToProtoBuf() (*avsproto.Task, error) {
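The refactor in model/task.go keeps the original boolean Validate for existing callers and routes it through the new error-returning ValidateWithError, so the rules live in one place. A self-contained sketch of that delegation pattern, with illustrative names rather than this repository's types:

```go
package main

import "fmt"

type blockTrigger struct {
	hasConfig bool
	interval  int64
}

// ValidateWithError carries the detail a caller can surface to users.
func (b *blockTrigger) ValidateWithError() error {
	if !b.hasConfig {
		return fmt.Errorf("block trigger config is required but missing")
	}
	if b.interval <= 0 {
		return fmt.Errorf("block trigger interval must be greater than 0, got %d", b.interval)
	}
	return nil
}

// Validate preserves the old boolean contract by delegating to ValidateWithError.
func (b *blockTrigger) Validate() bool {
	return b.ValidateWithError() == nil
}

func main() {
	bad := &blockTrigger{hasConfig: true, interval: 0}
	fmt.Println(bad.Validate())          // false
	fmt.Println(bad.ValidateWithError()) // block trigger interval must be greater than 0, got 0
}
```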