@@ -7,10 +7,12 @@ import (
77 "fmt"
88 "os"
99 "path/filepath"
10+ "strings"
1011 "sync"
1112 "testing"
1213 "time"
1314
15+ "github.com/Masterminds/squirrel"
1416 "github.com/juju/errors"
1517 "github.com/loopfz/gadgeto/zesty"
1618 "github.com/maxatome/go-testdeep/td"
@@ -23,6 +25,7 @@ import (
2325 "github.com/ovh/utask/api"
2426 "github.com/ovh/utask/db"
2527 "github.com/ovh/utask/db/pgjuju"
28+ "github.com/ovh/utask/db/sqlgenerator"
2629 "github.com/ovh/utask/engine"
2730 "github.com/ovh/utask/engine/functions"
2831 functionrunner "github.com/ovh/utask/engine/functions/runner"
@@ -36,6 +39,7 @@ import (
3639 compress "github.com/ovh/utask/pkg/compress/init"
3740 "github.com/ovh/utask/pkg/now"
3841 "github.com/ovh/utask/pkg/plugins"
42+ pluginbatch "github.com/ovh/utask/pkg/plugins/builtin/batch"
3943 plugincallback "github.com/ovh/utask/pkg/plugins/builtin/callback"
4044 "github.com/ovh/utask/pkg/plugins/builtin/echo"
4145 "github.com/ovh/utask/pkg/plugins/builtin/script"
@@ -91,6 +95,7 @@ func TestMain(m *testing.M) {
9195 step .RegisterRunner (echo .Plugin .PluginName (), echo .Plugin )
9296 step .RegisterRunner (script .Plugin .PluginName (), script .Plugin )
9397 step .RegisterRunner (pluginsubtask .Plugin .PluginName (), pluginsubtask .Plugin )
98+ step .RegisterRunner (pluginbatch .Plugin .PluginName (), pluginbatch .Plugin )
9499 step .RegisterRunner (plugincallback .Plugin .PluginName (), plugincallback .Plugin )
95100
96101 os .Exit (m .Run ())
@@ -194,6 +199,21 @@ func templateFromYAML(dbp zesty.DBProvider, filename string) (*tasktemplate.Task
194199 return tasktemplate .LoadFromName (dbp , tmpl .Name )
195200}
196201
202+ func listBatchTasks (dbp zesty.DBProvider , batchID int64 ) ([]string , error ) {
203+ query , params , err := sqlgenerator .PGsql .
204+ Select ("public_id" ).
205+ From ("task" ).
206+ Where (squirrel.Eq {"id_batch" : batchID }).
207+ ToSql ()
208+ if err != nil {
209+ return nil , err
210+ }
211+
212+ var taskIDs []string
213+ _ , err = dbp .DB ().Select (& taskIDs , query , params ... )
214+ return taskIDs , err
215+ }
216+
197217func TestSimpleTemplate (t * testing.T ) {
198218 input := map [string ]interface {}{
199219 "foo" : "bar" ,
@@ -1370,3 +1390,106 @@ func TestB64RawEncodeDecode(t *testing.T) {
13701390 assert .Equal (t , "cmF3IG1lc3NhZ2U" , output ["a" ])
13711391 assert .Equal (t , "raw message" , output ["b" ])
13721392}
1393+
1394+ func TestBatch (t * testing.T ) {
1395+ dbp , err := zesty .NewDBProvider (utask .DBName )
1396+ require .Nil (t , err )
1397+
1398+ _ , err = templateFromYAML (dbp , "batchedTask.yaml" )
1399+ require .Nil (t , err )
1400+
1401+ _ , err = templateFromYAML (dbp , "batch.yaml" )
1402+ require .Nil (t , err )
1403+
1404+ res , err := createResolution ("batch.yaml" , map [string ]interface {}{}, nil )
1405+ require .Nil (t , err , "failed to create resolution: %s" , err )
1406+
1407+ res , err = runResolution (res )
1408+ require .Nil (t , err )
1409+ require .NotNil (t , res )
1410+ assert .Equal (t , resolution .StateWaiting , res .State )
1411+
1412+ for _ , batchStepName := range []string {"batchJsonInputs" , "batchYamlInputs" } {
1413+ batchStepMetadataRaw , ok := res .Steps [batchStepName ].Metadata .(string )
1414+ assert .True (t , ok , "wrong type of metadata for step '%s'" , batchStepName )
1415+
1416+ assert .Nil (t , res .Steps [batchStepName ].Output , "output nil for step '%s'" , batchStepName )
1417+
1418+ // The plugin formats Metadata in a special way that we need to revert before unmarshalling them
1419+ batchStepMetadataRaw = strings .ReplaceAll (batchStepMetadataRaw , `\"` , `"` )
1420+ var batchStepMetadata map [string ]any
1421+ err := json .Unmarshal ([]byte (batchStepMetadataRaw ), & batchStepMetadata )
1422+ require .Nil (t , err , "metadata unmarshalling of step '%s'" , batchStepName )
1423+
1424+ batchPublicID := batchStepMetadata ["batch_id" ].(string )
1425+ assert .NotEqual (t , "" , batchPublicID , "wrong batch ID '%s'" , batchPublicID )
1426+
1427+ b , err := task .LoadBatchFromPublicID (dbp , batchPublicID )
1428+ require .Nil (t , err )
1429+
1430+ taskIDs , err := listBatchTasks (dbp , b .ID )
1431+ require .Nil (t , err )
1432+ assert .Len (t , taskIDs , 2 )
1433+
1434+ for i , publicID := range taskIDs {
1435+ child , err := task .LoadFromPublicID (dbp , publicID )
1436+ require .Nil (t , err )
1437+ assert .Equal (t , task .StateTODO , child .State )
1438+
1439+ childResolution , err := resolution .Create (dbp , child , nil , "" , false , nil )
1440+ require .Nil (t , err )
1441+
1442+ childResolution , err = runResolution (childResolution )
1443+ require .Nil (t , err )
1444+ assert .Equal (t , resolution .StateDone , childResolution .State )
1445+
1446+ for k , v := range childResolution .Steps {
1447+ assert .Equal (t , step .StateDone , v .State , "not valid state for step %s" , k )
1448+ }
1449+
1450+ child , err = task .LoadFromPublicID (dbp , child .PublicID )
1451+ require .Nil (t , err )
1452+ assert .Equal (t , task .StateDone , child .State )
1453+
1454+ parentTaskToResume , err := taskutils .ShouldResumeParentTask (dbp , child )
1455+ require .Nil (t , err )
1456+ if i == len (taskIDs )- 1 {
1457+ // Only the last child task should resume the parent
1458+ require .NotNil (t , parentTaskToResume )
1459+ assert .Equal (t , res .TaskID , parentTaskToResume .ID )
1460+ } else {
1461+ require .Nil (t , parentTaskToResume )
1462+ }
1463+ }
1464+ }
1465+
1466+ // checking if the parent task is picked up after that the subtask is resolved.
1467+ // need to sleep a bit because the parent task is resumed asynchronously
1468+ ti := time .Second
1469+ i := time .Duration (0 )
1470+ for i < ti {
1471+ res , err = resolution .LoadFromPublicID (dbp , res .PublicID )
1472+ require .Nil (t , err )
1473+ if res .State != resolution .StateWaiting {
1474+ break
1475+ }
1476+
1477+ time .Sleep (time .Millisecond * 10 )
1478+ i += time .Millisecond * 10
1479+ }
1480+
1481+ ti = time .Second
1482+ i = time .Duration (0 )
1483+ for i < ti {
1484+ res , err = resolution .LoadFromPublicID (dbp , res .PublicID )
1485+ require .Nil (t , err )
1486+ if res .State != resolution .StateRunning {
1487+ break
1488+ }
1489+
1490+ time .Sleep (time .Millisecond * 10 )
1491+ i += time .Millisecond * 10
1492+
1493+ }
1494+ assert .Equal (t , resolution .StateDone , res .State )
1495+ }
0 commit comments