Skip to content
This repository was archived by the owner on Jun 11, 2025. It is now read-only.

feat: implement ciphertext cache preload upon restart #187

Merged
merged 2 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 127 additions & 9 deletions fhevm-engine/fhevm-go-native/fhevm/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ type ExecutorApi interface {
// We pass current block number to know at which
// block ciphertext should be materialized inside blockchain state.
CreateSession(blockNumber int64) ExecutorSession
// Preload ciphertexts into cache and perform initial computations,
// should be called once after blockchain node initialization
PreloadCiphertexts(blockNumber int64, api ChainStorageApi) error
}

type SegmentId int
Expand Down Expand Up @@ -230,6 +233,114 @@ func (executorApi *ApiImpl) CreateSession(blockNumber int64) ExecutorSession {
}
}

func (executorApi *ApiImpl) PreloadCiphertexts(blockNumber int64, api ChainStorageApi) error {
computations := executorApi.loadComputationsFromStateToCache(blockNumber, api)
if computations > 0 {
return executorProcessPendingComputations(executorApi)
}

return nil
}

func (executorApi *ApiImpl) loadComputationsFromStateToCache(startBlockNumber int64, api ChainStorageApi) int {
loadStartTime := time.Now()
computations := 0
defer func() {
duration := time.Since(loadStartTime)
fmt.Printf("ciphertext cache preloaded with %d ciphertexts in %dms\n", computations, duration.Milliseconds())
}()

// TODO: figure out the limit how long in future blocks we should preload
lastBlockToPreload := startBlockNumber + 30

executorApi.cache.lock.Lock()
defer executorApi.cache.lock.Unlock()

for block := startBlockNumber; block < lastBlockToPreload; block++ {
countAddress := blockNumberToQueueItemCountAddress(block)
ciphertextsInBlock := api.GetState(executorApi.contractStorageAddress, countAddress).Big()
inBlock := ciphertextsInBlock.Int64()
queue := make([]*ComputationToInsert, 0)
enqueuedCiphertext := make(map[string]bool)

if inBlock == 0 {
continue
}

computations += int(inBlock)

for ctNum := 0; ctNum < int(inBlock); ctNum++ {
layout := blockQueueStorageLayout(block, int64(ctNum))
metadata := bytesToMetadata(api.GetState(executorApi.contractStorageAddress, layout.metadata))
outputHandle := api.GetState(executorApi.contractStorageAddress, layout.outputHandle)
computation := &ComputationToInsert{
segmentId: 0,
Operation: metadata.Operation,
OutputHandle: outputHandle[:],
CommitBlockId: block,
}

if isBinaryOp(metadata.Operation) {
firstOpHandle := api.GetState(executorApi.contractStorageAddress, layout.firstOperand)
firstOpCt := ReadBytesToAddress(api, executorApi.contractStorageAddress, firstOpHandle)

computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: false,
Handle: firstOpHandle[:],
CompressedCiphertext: firstOpCt,
FheUintType: handleType(firstOpHandle[:]),
})

if metadata.IsBigScalar {
// TODO: implement big scalar
} else if metadata.IsScalar {
secondOpHandle := api.GetState(executorApi.contractStorageAddress, layout.secondOperand)
computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: true,
Handle: secondOpHandle[:],
FheUintType: handleType(firstOpHandle[:]),
})
} else {
secondOpHandle := api.GetState(executorApi.contractStorageAddress, layout.secondOperand)
secondOpCt := ReadBytesToAddress(api, executorApi.contractStorageAddress, secondOpHandle)

computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: false,
Handle: secondOpHandle[:],
CompressedCiphertext: secondOpCt,
FheUintType: handleType(secondOpHandle[:]),
})
}
} else if isUnaryOp(metadata.Operation) {
firstOpAddress := api.GetState(executorApi.contractStorageAddress, layout.firstOperand)
firstOpCt := ReadBytesToAddress(api, executorApi.contractStorageAddress, firstOpAddress)

computation.Operands = append(computation.Operands, ComputationOperand{
IsScalar: false,
Handle: firstOpAddress[:],
CompressedCiphertext: firstOpCt,
FheUintType: handleType(firstOpAddress[:]),
})
} else {
// TODO: handle all special functions to load their ciphertext arguments
}

if !enqueuedCiphertext[string(computation.OutputHandle)] {
queue = append(queue, computation)
enqueuedCiphertext[string(computation.OutputHandle)] = true
}
}

ctsToCompute := &BlockCiphertextQueue{
queue: queue,
enqueuedCiphertext: enqueuedCiphertext,
}
executorApi.cache.ciphertextsToCompute[block] = ctsToCompute
}

return computations
}

func (sessionApi *SessionImpl) Commit(blockNumber int64, storage ChainStorageApi) error {
err := sessionApi.sessionStore.Commit(storage)
if err != nil {
Expand Down Expand Up @@ -530,12 +641,13 @@ func (dbApi *EvmStorageComputationStore) InsertComputationBatch(evmStorage Chain

for _, comp := range bucket {
// don't have duplicates, from possibly evaluating multiple trie caches
if !ctsStorage.enqueuedCiphertext[common.Bytes2Hex(comp.OutputHandle)] {
if !ctsStorage.enqueuedCiphertext[string(comp.OutputHandle)] {
// we must fill the raw ciphertext values here from storage so cache
// would have ciphertexts to compute on, as cache doesn't have easy
// access to the evm state
dbApi.hydrateComputationFromEvmState(evmStorage, comp)
ctsStorage.queue = append(ctsStorage.queue, comp)
ctsStorage.enqueuedCiphertext[string(comp.OutputHandle)] = true
}
}
}
Expand Down Expand Up @@ -766,18 +878,20 @@ func InitExecutor() (ExecutorApi, error) {

workAvailableChan := make(chan bool, 10)

cache := &CiphertextCache{
lock: sync.RWMutex{},
blocksCiphertexts: make(map[int64]*CacheBlockData),
ciphertextsToCompute: make(map[int64]*BlockCiphertextQueue),
workAvailableChan: workAvailableChan,
lastCacheGc: time.Now(),
}

apiImpl := &ApiImpl{
address: fhevmContractAddress,
aclContractAddress: aclContractAddress,
contractStorageAddress: storageAddress,
executorUrl: executorUrl,
cache: &CiphertextCache{
lock: sync.RWMutex{},
blocksCiphertexts: make(map[int64]*CacheBlockData),
ciphertextsToCompute: make(map[int64]*BlockCiphertextQueue),
workAvailableChan: workAvailableChan,
lastCacheGc: time.Now(),
},
cache: cache,
}

// run executor worker in the background
Expand Down Expand Up @@ -885,8 +999,12 @@ func executorProcessPendingComputations(impl *ApiImpl) error {
if err != nil {
return err
}
ciphertexts := response.GetResultCiphertexts()
if ciphertexts == nil {
return errors.New(response.GetError().String())
}

outCts := response.GetResultCiphertexts().Ciphertexts
outCts := ciphertexts.Ciphertexts
fmt.Printf("got %d ciphertext responses from the executor\n", len(outCts))
for _, ct := range outCts {
theBlock, exists := ctToBlockIndex[string(ct.Handle)]
Expand Down
122 changes: 122 additions & 0 deletions fhevm-engine/fhevm-go-native/fhevm/fhelib_ops.go
Original file line number Diff line number Diff line change
Expand Up @@ -1688,3 +1688,125 @@ func getThreeFheOperands(sess ExecutorSession, input []byte) (first []byte, seco

return input[0:32], input[32:64], input[64:96], nil
}

func isBinaryOp(op FheOp) bool {
switch op {
case FheAdd:
return true
case FheBitAnd:
return true
case FheBitOr:
return true
case FheBitXor:
return true
case FheDiv:
return true
case FheEq:
return true
case FheGe:
return true
case FheGt:
return true
case FheLe:
return true
case FheLt:
return true
case FheMax:
return true
case FheMin:
return true
case FheMul:
return true
case FheNe:
return true
case FheRem:
return true
case FheRotl:
return true
case FheRotr:
return true
case FheShl:
return true
case FheShr:
return true
case FheSub:
return true
case FheCast:
return false
case FheNeg:
return false
case FheNot:
return false
case FheRand:
return false
case FheRandBounded:
return false
case FheIfThenElse:
return false
case TrivialEncrypt:
return false
default:
return false
}
}

func isUnaryOp(op FheOp) bool {
switch op {
case FheNeg:
return true
case FheNot:
return true
case FheAdd:
return false
case FheBitAnd:
return false
case FheBitOr:
return false
case FheBitXor:
return false
case FheDiv:
return false
case FheEq:
return false
case FheGe:
return false
case FheGt:
return false
case FheLe:
return false
case FheLt:
return false
case FheMax:
return false
case FheMin:
return false
case FheMul:
return false
case FheNe:
return false
case FheRem:
return false
case FheRotl:
return false
case FheRotr:
return false
case FheShl:
return false
case FheShr:
return false
case FheSub:
return false
case FheCast:
return false
case FheRand:
return false
case FheRandBounded:
return false
case FheIfThenElse:
return false
case TrivialEncrypt:
return false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for nitpicking - but might even drop the false branch and leave it to default :)

If we don't have ternary ops we might otherwise just implement binary as !unary, but that's not essential.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish it was more like rust, where we would have to list all the cases or its a compile time error, we might forget an operand 🤔

default:
return false
}
}
Loading