Skip to content

Implement the new trigger and generic task node structure #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,7 @@ func NewAggregator(c *config.Config) (*Aggregator, error) {
// Open and setup our database
func (agg *Aggregator) initDB(ctx context.Context) error {
var err error
agg.db, err = storage.New(&storage.Config{
Path: agg.config.DbPath,
})
agg.db, err = storage.NewWithPath(agg.config.DbPath)

if err != nil {
panic(err)
Expand Down
6 changes: 1 addition & 5 deletions aggregator/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package aggregator
import (
"context"
"fmt"
"math/big"
"strings"
"time"

"github.com/AvaProtocol/ap-avs/core/auth"
"github.com/AvaProtocol/ap-avs/core/chainio/aa"
"github.com/AvaProtocol/ap-avs/model"
avsproto "github.com/AvaProtocol/ap-avs/protobuf"
"github.com/ethereum/go-ethereum/accounts"
Expand Down Expand Up @@ -131,11 +129,9 @@ func (r *RpcServer) verifyAuth(ctx context.Context) (*model.User, error) {
Address: common.HexToAddress(claims["sub"].(string)),
}

smartAccountAddress, err := aa.GetSenderAddress(r.ethrpc, user.Address, big.NewInt(0))
if err != nil {
if err := user.LoadDefaultSmartWallet(r.smartWalletRpc); err != nil {
return nil, fmt.Errorf("Rpc error")
}
user.SmartAccountAddress = smartAccountAddress

return &user, nil
}
Expand Down
10 changes: 10 additions & 0 deletions aggregator/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,18 @@ func (agg *Aggregator) stopRepl() {
}

}

// Repl allow an operator to look into node storage directly with a REPL interface.
// It doesn't listen via TCP socket but directly unix socket on file system.
func (agg *Aggregator) startRepl() {
var err error

if _, err := os.Stat(agg.config.SocketPath); err == nil {
// File exists, most likely result of a previous crash without cleaning, attempt to delete
os.Remove(agg.config.SocketPath)
}
repListener, err = net.Listen("unix", agg.config.SocketPath)

if err != nil {
return
}
Expand All @@ -48,6 +57,7 @@ func handleConnection(agg *Aggregator, conn net.Conn) {

reader := bufio.NewReader(conn)
fmt.Fprintln(conn, "AP CLI REPL")
fmt.Fprintln(conn, "Use `list <prefix>*` to list key, `get <key>` to inspect content ")
fmt.Fprintln(conn, "-------------------------")

for {
Expand Down
19 changes: 10 additions & 9 deletions aggregator/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
timestamppb "google.golang.org/protobuf/types/known/timestamppb"
wrapperspb "google.golang.org/protobuf/types/known/wrapperspb"

"github.com/AvaProtocol/ap-avs/core/auth"
"github.com/AvaProtocol/ap-avs/core/chainio/aa"
"github.com/AvaProtocol/ap-avs/core/config"
"github.com/AvaProtocol/ap-avs/core/taskengine"
Expand All @@ -42,7 +43,7 @@ type RpcServer struct {
func (r *RpcServer) CreateWallet(ctx context.Context, payload *avsproto.CreateWalletReq) (*avsproto.CreateWalletResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}
return r.engine.CreateSmartWallet(user, payload)
}
Expand All @@ -53,7 +54,7 @@ func (r *RpcServer) GetNonce(ctx context.Context, payload *avsproto.NonceRequest

nonce, err := aa.GetNonce(r.smartWalletRpc, ownerAddress, big.NewInt(0))
if err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletRpcError), "cannot determine nonce for smart wallet")
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletRpcError), taskengine.NonceFetchingError)
}

return &avsproto.NonceResp{
Expand All @@ -65,7 +66,7 @@ func (r *RpcServer) GetNonce(ctx context.Context, payload *avsproto.NonceRequest
func (r *RpcServer) GetSmartAccountAddress(ctx context.Context, payload *avsproto.AddressRequest) (*avsproto.AddressResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}

wallets, err := r.engine.GetSmartWallets(user.Address)
Expand All @@ -78,7 +79,7 @@ func (r *RpcServer) GetSmartAccountAddress(ctx context.Context, payload *avsprot
func (r *RpcServer) CancelTask(ctx context.Context, taskID *avsproto.UUID) (*wrapperspb.BoolValue, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}

r.config.Logger.Info("Process Cancel Task",
Expand All @@ -98,7 +99,7 @@ func (r *RpcServer) CancelTask(ctx context.Context, taskID *avsproto.UUID) (*wra
func (r *RpcServer) DeleteTask(ctx context.Context, taskID *avsproto.UUID) (*wrapperspb.BoolValue, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}

r.config.Logger.Info("Process Delete Task",
Expand All @@ -118,7 +119,7 @@ func (r *RpcServer) DeleteTask(ctx context.Context, taskID *avsproto.UUID) (*wra
func (r *RpcServer) CreateTask(ctx context.Context, taskPayload *avsproto.CreateTaskReq) (*avsproto.CreateTaskResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}

task, err := r.engine.CreateTask(user, taskPayload)
Expand All @@ -134,7 +135,7 @@ func (r *RpcServer) CreateTask(ctx context.Context, taskPayload *avsproto.Create
func (r *RpcServer) ListTasks(ctx context.Context, payload *avsproto.ListTasksReq) (*avsproto.ListTasksResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}

r.config.Logger.Info("Process List Task",
Expand All @@ -155,15 +156,15 @@ func (r *RpcServer) ListTasks(ctx context.Context, payload *avsproto.ListTasksRe
func (r *RpcServer) GetTask(ctx context.Context, taskID *avsproto.UUID) (*avsproto.Task, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
return nil, status.Errorf(codes.Unauthenticated, auth.InvalidAuthenticationKey)
}

r.config.Logger.Info("Process Get Task",
"user", user.Address.String(),
"taskID", string(taskID.Bytes),
)

task, err := r.engine.GetTaskByUser(user, string(taskID.Bytes))
task, err := r.engine.GetTask(user, string(taskID.Bytes))
if err != nil {
return nil, err
}
Expand Down
5 changes: 5 additions & 0 deletions core/auth/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package auth

const (
InvalidAuthenticationKey = "Invalid authentication key"
)
93 changes: 47 additions & 46 deletions core/taskengine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"math/big"
"strconv"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -140,7 +141,7 @@ func (n *Engine) GetSmartWallets(owner common.Address) ([]*avsproto.SmartWallet,
salt := big.NewInt(0)
sender, err := aa.GetSenderAddress(rpcConn, owner, salt)
if err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletNotFoundError), "cannot determine smart wallet address")
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletNotFoundError), SmartAccountCreationError)
}

// now load the customize wallet with different salt or factory that was initialed and store in our db
Expand All @@ -155,7 +156,7 @@ func (n *Engine) GetSmartWallets(owner common.Address) ([]*avsproto.SmartWallet,
items, err := n.db.GetByPrefix(WalletByOwnerPrefix(owner))

if err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletNotFoundError), "cannot determine smart wallet address")
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletNotFoundError), SmartAccountCreationError)
}

for _, item := range items {
Expand All @@ -176,15 +177,15 @@ func (n *Engine) CreateSmartWallet(user *model.User, payload *avsproto.CreateWal
// Verify data
// when user passing a custom factory address, we want to validate it
if payload.FactoryAddress != "" && !common.IsHexAddress(payload.FactoryAddress) {
return nil, status.Errorf(codes.InvalidArgument, "invalid factory address")
return nil, status.Errorf(codes.InvalidArgument, InvalidFactoryAddressError)
}

salt := big.NewInt(0)
if payload.Salt != "" {
var ok bool
salt, ok = math.ParseBig256(payload.Salt)
if !ok {
return nil, status.Errorf(codes.InvalidArgument, "invalid salt value")
return nil, status.Errorf(codes.InvalidArgument, InvalidSmartAccountSaltError)
}
}

Expand All @@ -208,7 +209,7 @@ func (n *Engine) CreateSmartWallet(user *model.User, payload *avsproto.CreateWal
updates[string(WalletStorageKey(user.Address, sender.Hex()))], err = wallet.ToJSON()

if err = n.db.BatchWrite(updates); err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_StorageWriteError), "cannot update key to storage")
return nil, status.Errorf(codes.Code(avsproto.Error_StorageWriteError), StorageWriteError)
}

return &avsproto.CreateWalletResp{
Expand All @@ -220,19 +221,20 @@ func (n *Engine) CreateSmartWallet(user *model.User, payload *avsproto.CreateWal
func (n *Engine) CreateTask(user *model.User, taskPayload *avsproto.CreateTaskReq) (*model.Task, error) {
var err error

fmt.Println("user", user)
if taskPayload.SmartWalletAddress != "" {
if !ValidWalletAddress(taskPayload.SmartWalletAddress) {
return nil, status.Errorf(codes.InvalidArgument, "invalid smart account address")
return nil, status.Errorf(codes.InvalidArgument, InvalidSmartAccountAddressError)
}

if valid, _ := ValidWalletOwner(n.db, user, common.HexToAddress(taskPayload.SmartWalletAddress)); !valid {
return nil, status.Errorf(codes.InvalidArgument, "invalid smart account address")
return nil, status.Errorf(codes.InvalidArgument, InvalidSmartAccountAddressError)
}
}
task, err := model.NewTaskFromProtobuf(user, taskPayload)

if err != nil {
return nil, err
return nil, status.Errorf(codes.Code(avsproto.Error_TaskDataMissingError), err.Error())
}

updates := map[string][]byte{}
Expand Down Expand Up @@ -349,23 +351,25 @@ func (n *Engine) AggregateChecksResult(address string, ids []string) error {
func (n *Engine) ListTasksByUser(user *model.User, payload *avsproto.ListTasksReq) ([]*avsproto.Task, error) {
// by default show the task from the default smart wallet, if proving we look into that wallet specifically
owner := user.SmartAccountAddress
if payload.SmartWalletAddress != "" {
if !ValidWalletAddress(payload.SmartWalletAddress) {
return nil, status.Errorf(codes.InvalidArgument, "invalid smart account address")
}
if payload.SmartWalletAddress == "" {
return nil, status.Errorf(codes.InvalidArgument, MissingSmartWalletAddressError)
}

if valid, _ := ValidWalletOwner(n.db, user, common.HexToAddress(payload.SmartWalletAddress)); !valid {
return nil, status.Errorf(codes.InvalidArgument, "invalid smart account address")
}
if !ValidWalletAddress(payload.SmartWalletAddress) {
return nil, status.Errorf(codes.InvalidArgument, InvalidSmartAccountAddressError)
}

smartWallet := common.HexToAddress(payload.SmartWalletAddress)
owner = &smartWallet
if valid, _ := ValidWalletOwner(n.db, user, common.HexToAddress(payload.SmartWalletAddress)); !valid {
return nil, status.Errorf(codes.InvalidArgument, InvalidSmartAccountAddressError)
}

smartWallet := common.HexToAddress(payload.SmartWalletAddress)
owner = &smartWallet

taskIDs, err := n.db.GetByPrefix(SmartWalletTaskStoragePrefix(user.Address, *owner))

if err != nil {
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_StorageUnavailable), "storage is not ready")
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_StorageUnavailable), StorageUnavailableError)
}

tasks := make([]*avsproto.Task, len(taskIDs))
Expand All @@ -377,56 +381,53 @@ func (n *Engine) ListTasksByUser(user *model.User, payload *avsproto.ListTasksRe
continue
}

task := &model.Task{
ID: taskID,
Owner: user.Address.Hex(),
}
task := model.NewTask()
if err := task.FromStorageData(taskRawByte); err != nil {
continue
}
task.ID = taskID

tasks[i], _ = task.ToProtoBuf()
}

return tasks, nil
}

func (n *Engine) GetTaskByUser(user *model.User, taskID string) (*model.Task, error) {
task := &model.Task{
ID: taskID,
Owner: user.Address.Hex(),
}
func (n *Engine) GetTaskByID(taskID string) (*model.Task, error) {
for status, _ := range avsproto.TaskStatus_name {
if rawTaskData, err := n.db.GetKey(TaskStorageKey(taskID, avsproto.TaskStatus(status))); err == nil {
task := model.NewTask()
err = task.FromStorageData(rawTaskData)

// Get Task Status
rawStatus, err := n.db.GetKey([]byte(TaskUserKey(task)))
if err != nil {
return nil, grpcstatus.Errorf(codes.NotFound, "task not found")
if err == nil {
return task, nil
}

return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_TaskDataCorrupted), TaskStorageCorruptedError)
}
}
status, _ := strconv.Atoi(string(rawStatus))

taskRawByte, err := n.db.GetKey(TaskStorageKey(taskID, avsproto.TaskStatus(status)))
return nil, grpcstatus.Errorf(codes.NotFound, TaskNotFoundError)
}

func (n *Engine) GetTask(user *model.User, taskID string) (*model.Task, error) {
task, err := n.GetTaskByID(taskID)
if err != nil {
taskRawByte, err = n.db.GetKey([]byte(
TaskStorageKey(taskID, avsproto.TaskStatus_Executing),
))
if err != nil {
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_TaskDataCorrupted), "task data storage is corrupted")
}
return nil, err
}

err = task.FromStorageData(taskRawByte)
if err != nil {
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_TaskDataCorrupted), "task data storage is corrupted")
if strings.ToLower(task.Owner) != strings.ToLower(user.Address.Hex()) {
return nil, grpcstatus.Errorf(codes.NotFound, TaskNotFoundError)
}

return task, nil
}

func (n *Engine) DeleteTaskByUser(user *model.User, taskID string) (bool, error) {
task, err := n.GetTaskByUser(user, taskID)
task, err := n.GetTask(user, taskID)

if err != nil {
return false, grpcstatus.Errorf(codes.NotFound, "task not found")
return false, grpcstatus.Errorf(codes.NotFound, TaskNotFoundError)
}

if task.Status == avsproto.TaskStatus_Executing {
Expand All @@ -440,10 +441,10 @@ func (n *Engine) DeleteTaskByUser(user *model.User, taskID string) (bool, error)
}

func (n *Engine) CancelTaskByUser(user *model.User, taskID string) (bool, error) {
task, err := n.GetTaskByUser(user, taskID)
task, err := n.GetTask(user, taskID)

if err != nil {
return false, grpcstatus.Errorf(codes.NotFound, "task not found")
return false, grpcstatus.Errorf(codes.NotFound, TaskNotFoundError)
}

if task.Status != avsproto.TaskStatus_Active {
Expand Down
18 changes: 18 additions & 0 deletions core/taskengine/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package taskengine

const (
TaskNotFoundError = "task not found"

InvalidSmartAccountAddressError = "invalid smart account address"
InvalidFactoryAddressError = "invalid factory address"
InvalidSmartAccountSaltError = "invalid salt value"
SmartAccountCreationError = "cannot determine smart wallet address"
NonceFetchingError = "cannot determine nonce for smart wallet"

MissingSmartWalletAddressError = "Missing smart_wallet_address"

StorageUnavailableError = "storage is not ready"
StorageWriteError = "cannot write to storage"

TaskStorageCorruptedError = "task data storage is corrupted"
)
Loading
Loading