Skip to content

Wallet management #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions aggregator/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,16 @@ type RpcServer struct {
}

// Get nonce of an existing smart wallet of a given owner
func (r *RpcServer) GetNonce(ctx context.Context, payload *avsproto.NonceRequest) (*avsproto.NonceResp, error) {
func (r *RpcServer) CreateWallet(ctx context.Context, payload *avsproto.CreateWalletReq) (*avsproto.CreateWalletResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
}
return r.engine.CreateSmartWallet(user, payload)
}

// Get nonce of an existing smart wallet of a given owner
func (r *RpcServer) GetNonce(ctx context.Context, payload *avsproto.NonceRequest) (*avsproto.NonceResp, error) {
ownerAddress := common.HexToAddress(payload.Owner)

nonce, err := aa.GetNonce(r.smartWalletRpc, ownerAddress, big.NewInt(0))
Expand All @@ -55,17 +63,15 @@ func (r *RpcServer) GetNonce(ctx context.Context, payload *avsproto.NonceRequest

// GetAddress returns smart account address of the given owner in the auth key
func (r *RpcServer) GetSmartAccountAddress(ctx context.Context, payload *avsproto.AddressRequest) (*avsproto.AddressResp, error) {
ownerAddress := common.HexToAddress(payload.Owner)
salt := big.NewInt(0)
sender, err := aa.GetSenderAddress(r.smartWalletRpc, ownerAddress, salt)

user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletNotFoundError), "cannot determine smart wallet address")
return nil, status.Errorf(codes.Unauthenticated, "invalid authentication key")
}

wallets, err := r.engine.GetSmartWallets(user.Address)

return &avsproto.AddressResp{
SmartAccountAddress: sender.String(),
// TODO: return the right salt
Salt: big.NewInt(0).String(),
Wallets: wallets,
}, nil
}

Expand Down Expand Up @@ -160,12 +166,14 @@ func (r *RpcServer) GetTask(ctx context.Context, taskID *avsproto.UUID) (*avspro
return task.ToProtoBuf()
}

// Operator action
func (r *RpcServer) SyncTasks(payload *avsproto.SyncTasksReq, srv avsproto.Aggregator_SyncTasksServer) error {
err := r.engine.StreamCheckToOperator(payload, srv)

return err
}

// Operator action
func (r *RpcServer) UpdateChecks(ctx context.Context, payload *avsproto.UpdateChecksReq) (*avsproto.UpdateChecksResp, error) {
if err := r.engine.AggregateChecksResult(payload.Address, payload.Id); err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion aggregator/task_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (agg *Aggregator) startTaskEngine(ctx context.Context) {
agg.queue,
agg.logger,
)
agg.engine.Start()
agg.engine.MustStart()

agg.queue.MustStart()
agg.worker.MustStart()
Expand Down
11 changes: 11 additions & 0 deletions core/chainio/aa/aa.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,17 @@ func GetSenderAddress(conn *ethclient.Client, ownerAddress common.Address, salt
return &sender, nil
}

// Compute smart wallet address for a particular factory
func GetSenderAddressForFactory(conn *ethclient.Client, ownerAddress common.Address, customFactoryAddress common.Address, salt *big.Int) (*common.Address, error) {
simpleFactory, err := NewSimpleFactory(customFactoryAddress, conn)
if err != nil {
return nil, err
}

sender, err := simpleFactory.GetAddress(nil, ownerAddress, salt)
return &sender, nil
}

func GetNonce(conn *ethclient.Client, ownerAddress common.Address, salt *big.Int) (*big.Int, error) {
if salt == nil {
salt = defaultSalt
Expand Down
31 changes: 21 additions & 10 deletions core/taskengine/doc.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
/*
package trigger monitor the condition on when to fire a task
there are 3 trigger types:
Task Engine handles task storage and execution. We use badgerdb for all of our task storage. We like to make sure of Go cross compiling extensively and want to leverage pure-go as much as possible. badgerdb sastify that requirement.

Interval: Repeated at certain interval
Cron: trigger on time on cron
Onchain Event: when an event is emiited from a contract
Ev
**Wallet Info**

# Storage Layout
w:<eoa>:<smart-wallet-address> = {factory_address: address, salt: salt}

Task is store into 2 storage
t:a:<task-id>: the raw json of task data
u:<task-id>: the task status
**Task Storage**

w:<eoa>:<smart-wallet-address> -> {factory, salt}
t:<task-status>:<task-id> -> task payload, the source of truth of task information
u:<eoa>:<smart-wallet-address>:<task-id> -> task status
h:<smart-wallet-address>:<task-id>:<execution-id> -> an execution history

The task storage was designed for fast retrieve time at the cost of extra storage.

The storage can also be easily back-up, sync due to simplicity of supported write operation.

**Data console**

Storage can also be inspect with telnet:

telnet /tmp/ap.sock

Then issue `get <ket>` or `list <prefix>` or `list *` to inspect current keys in the storage.
*/
package taskengine
160 changes: 106 additions & 54 deletions core/taskengine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ import (
"github.com/AvaProtocol/ap-avs/model"
"github.com/AvaProtocol/ap-avs/storage"
sdklogging "github.com/Layr-Labs/eigensdk-go/logging"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/ethclient"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
grpcstatus "google.golang.org/grpc/status"

avsproto "github.com/AvaProtocol/ap-avs/protobuf"
Expand Down Expand Up @@ -50,6 +53,7 @@ type Engine struct {
lock *sync.Mutex
trackSyncedTasks map[string]*operatorState

smartWalletConfig *config.SmartWalletConfig
// when shutdown is true, our engine will perform the shutdown
// pending execution will be pushed out before the shutdown completely
// to force shutdown, one can type ctrl+c twice
Expand Down Expand Up @@ -91,10 +95,11 @@ func New(db storage.Storage, config *config.Config, queue *apqueue.Queue, logger
db: db,
queue: queue,

lock: &sync.Mutex{},
tasks: make(map[string]*model.Task),
trackSyncedTasks: make(map[string]*operatorState),
shutdown: false,
lock: &sync.Mutex{},
tasks: make(map[string]*model.Task),
trackSyncedTasks: make(map[string]*operatorState),
smartWalletConfig: config.SmartWallet,
shutdown: false,

logger: logger,
}
Expand All @@ -110,14 +115,15 @@ func (n *Engine) Stop() {
n.shutdown = true
}

func (n *Engine) Start() {
func (n *Engine) MustStart() {
var err error
n.seq, err = n.db.GetSequence([]byte("t:seq"), 1000)
if err != nil {
panic(err)
}

kvs, e := n.db.GetByPrefix([]byte(fmt.Sprintf("t:%s:", TaskStatusToStorageKey(avsproto.TaskStatus_Active))))
// Upon booting we will get all the active tasks to sync to operator
kvs, e := n.db.GetByPrefix(TaskByStatusStoragePrefix(avsproto.TaskStatus_Active))
if e != nil {
panic(e)
}
Expand All @@ -127,7 +133,87 @@ func (n *Engine) Start() {
n.tasks[string(item.Key)] = &task
}
}
}

func (n *Engine) GetSmartWallets(owner common.Address) ([]*avsproto.SmartWallet, error) {
// This is the default wallet with our own factory
salt := big.NewInt(0)
sender, err := aa.GetSenderAddress(rpcConn, owner, salt)
if err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletNotFoundError), "cannot determine smart wallet address")
}

// now load the customize wallet with different salt or factory that was initialed and store in our db
wallets := []*avsproto.SmartWallet{
&avsproto.SmartWallet{
Address: sender.String(),
Factory: n.smartWalletConfig.FactoryAddress.String(),
Salt: salt.String(),
},
}

items, err := n.db.GetByPrefix(WalletByOwnerPrefix(owner))

if err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_SmartWalletNotFoundError), "cannot determine smart wallet address")
}

for _, item := range items {
w := &model.SmartWallet{}
w.FromStorageData(item.Value)

wallets = append(wallets, &avsproto.SmartWallet{
Address: w.Address.String(),
Factory: w.Factory.String(),
Salt: w.Salt.String(),
})
}

return wallets, nil
}

func (n *Engine) CreateSmartWallet(user *model.User, payload *avsproto.CreateWalletReq) (*avsproto.CreateWalletResp, error) {
// Verify data
// when user passing a custom factory address, we want to validate it
if payload.FactoryAddress != "" && !common.IsHexAddress(payload.FactoryAddress) {
return nil, status.Errorf(codes.InvalidArgument, "invalid factory address")
}

salt := big.NewInt(0)
if payload.Salt != "" {
var ok bool
salt, ok = math.ParseBig256(payload.Salt)
if !ok {
return nil, status.Errorf(codes.InvalidArgument, "invalid salt value")
}
}

factoryAddress := n.smartWalletConfig.FactoryAddress
if payload.FactoryAddress != "" {
factoryAddress = common.HexToAddress(payload.FactoryAddress)

}

sender, err := aa.GetSenderAddressForFactory(rpcConn, user.Address, factoryAddress, salt)

wallet := &model.SmartWallet{
Owner: &user.Address,
Address: sender,
Factory: &factoryAddress,
Salt: salt,
}

updates := map[string][]byte{}

updates[string(WalletStorageKey(wallet))], err = wallet.ToJSON()

if err = n.db.BatchWrite(updates); err != nil {
return nil, status.Errorf(codes.Code(avsproto.Error_StorageWriteError), "cannot update key to storage")
}

return &avsproto.CreateWalletResp{
Address: sender.String(),
}, nil
}

func (n *Engine) CreateTask(user *model.User, taskPayload *avsproto.CreateTaskReq) (*model.Task, error) {
Expand All @@ -140,21 +226,16 @@ func (n *Engine) CreateTask(user *model.User, taskPayload *avsproto.CreateTaskRe
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_SmartWalletRpcError), "cannot get smart wallet address")
}

taskID, err := n.NewTaskID()
if err != nil {
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_StorageUnavailable), "cannot create task right now. storage unavailable")
}

task, err := model.NewTaskFromProtobuf(taskID, user, taskPayload)
task, err := model.NewTaskFromProtobuf(user, taskPayload)

if err != nil {
return nil, err
}

updates := map[string][]byte{}

updates[TaskStorageKey(task.ID, task.Status)], err = task.ToJSON()
updates[TaskUserKey(task)] = []byte(fmt.Sprintf("%d", avsproto.TaskStatus_Active))
updates[string(TaskStorageKey(task.ID, task.Status))], err = task.ToJSON()
updates[string(TaskUserKey(task))] = []byte(fmt.Sprintf("%d", avsproto.TaskStatus_Active))

if err = n.db.BatchWrite(updates); err != nil {
return nil, err
Expand Down Expand Up @@ -249,8 +330,8 @@ func (n *Engine) AggregateChecksResult(address string, ids []string) error {
n.logger.Debug("mark task in executing status", "task_id", id)

if err := n.db.Move(
[]byte(fmt.Sprintf("t:%s:%s", TaskStatusToStorageKey(avsproto.TaskStatus_Active), id)),
[]byte(fmt.Sprintf("t:%s:%s", TaskStatusToStorageKey(avsproto.TaskStatus_Executing), id)),
[]byte(TaskStorageKey(id, avsproto.TaskStatus_Active)),
[]byte(TaskStorageKey(id, avsproto.TaskStatus_Executing)),
); err != nil {
n.logger.Error("error moving the task storage from active to executing", "task", id, "error", err)
}
Expand All @@ -263,7 +344,7 @@ func (n *Engine) AggregateChecksResult(address string, ids []string) error {
}

func (n *Engine) ListTasksByUser(user *model.User) ([]*avsproto.ListTasksResp_TaskItemResp, error) {
taskIDs, err := n.db.GetByPrefix([]byte(fmt.Sprintf("u:%s", user.Address.String())))
taskIDs, err := n.db.GetByPrefix(UserTaskStoragePrefix(user.Address))

if err != nil {
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_StorageUnavailable), "storage is not ready")
Expand Down Expand Up @@ -326,8 +407,8 @@ func (n *Engine) DeleteTaskByUser(user *model.User, taskID string) (bool, error)
return false, fmt.Errorf("Only non executing task can be deleted")
}

n.db.Delete([]byte(TaskStorageKey(task.ID, task.Status)))
n.db.Delete([]byte(TaskUserKey(task)))
n.db.Delete(TaskStorageKey(task.ID, task.Status))
n.db.Delete(TaskUserKey(task))

return true, nil
}
Expand All @@ -346,13 +427,13 @@ func (n *Engine) CancelTaskByUser(user *model.User, taskID string) (bool, error)
updates := map[string][]byte{}
oldStatus := task.Status
task.SetCanceled()
updates[TaskStorageKey(task.ID, oldStatus)], err = task.ToJSON()
updates[TaskUserKey(task)] = []byte(fmt.Sprintf("%d", task.Status))
updates[string(TaskStorageKey(task.ID, oldStatus))], err = task.ToJSON()
updates[string(TaskUserKey(task))] = []byte(fmt.Sprintf("%d", task.Status))

if err = n.db.BatchWrite(updates); err == nil {
n.db.Move(
[]byte(TaskStorageKey(task.ID, oldStatus)),
[]byte(TaskStorageKey(task.ID, task.Status)),
TaskStorageKey(task.ID, oldStatus),
TaskStorageKey(task.ID, task.Status),
)

delete(n.tasks, task.ID)
Expand All @@ -363,37 +444,8 @@ func (n *Engine) CancelTaskByUser(user *model.User, taskID string) (bool, error)
return true, nil
}

func TaskStorageKey(id string, status avsproto.TaskStatus) string {
return fmt.Sprintf(
"t:%s:%s",
TaskStatusToStorageKey(status),
id,
)
}

func TaskUserKey(t *model.Task) string {
return fmt.Sprintf(
"u:%s",
t.Key(),
)
}

func TaskStatusToStorageKey(v avsproto.TaskStatus) string {
switch v {
case 1:
return "c"
case 2:
return "f"
case 3:
return "l"
case 4:
return "x"
}

return "a"
}

func (n *Engine) NewTaskID() (string, error) {
// A global counter for the task engine
func (n *Engine) NewSeqID() (string, error) {
num := uint64(0)
var err error

Expand Down
1 change: 1 addition & 0 deletions core/taskengine/engine_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package taskengine
4 changes: 2 additions & 2 deletions core/taskengine/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func (c *ContractProcessor) Perform(job *apqueue.Job) error {

defer func() {
updates := map[string][]byte{}
updates[TaskStorageKey(task.ID, avsproto.TaskStatus_Executing)], err = task.ToJSON()
updates[TaskUserKey(task)] = []byte(fmt.Sprintf("%d", task.Status))
updates[string(TaskStorageKey(task.ID, avsproto.TaskStatus_Executing))], err = task.ToJSON()
updates[string(TaskUserKey(task))] = []byte(fmt.Sprintf("%d", task.Status))

if err = c.db.BatchWrite(updates); err == nil {
c.db.Move(
Expand Down
Loading
Loading