Skip to content

seperate compress and encryption from code #905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 52 additions & 4 deletions skyplane/gateway/gateway_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
GatewayObjStoreReadOperator,
GatewayObjStoreWriteOperator,
GatewayWaitReceiver,
GatewayCompress,
GatewayDecompress,
GatewayEncrypt,
GatewayDecrypt,
)
from skyplane.gateway.operators.gateway_receiver import GatewayReceiver
from skyplane.utils import logger
Expand Down Expand Up @@ -90,8 +94,6 @@ def __init__(
error_queue=self.error_queue,
max_pending_chunks=max_incoming_ports,
use_tls=self.use_tls,
use_compression=use_compression,
e2ee_key_bytes=self.e2ee_key_bytes,
)

# API server
Expand Down Expand Up @@ -232,8 +234,6 @@ def create_gateway_operators_helper(input_queue, program: List[Dict], partition_
error_queue=self.error_queue,
chunk_store=self.chunk_store,
use_tls=self.use_tls,
use_compression=op["compress"],
e2ee_key_bytes=self.e2ee_key_bytes,
n_processes=op["num_connections"],
)
total_p += op["num_connections"]
Expand Down Expand Up @@ -264,6 +264,54 @@ def create_gateway_operators_helper(input_queue, program: List[Dict], partition_
chunk_store=self.chunk_store,
)
total_p += 1
elif op["op_type"] == "compress":
operators[handle] = GatewayCompress(
handle=handle,
region=self.region,
input_queue=input_queue,
output_queue=output_queue,
error_event=self.error_event,
error_queue=self.error_queue,
chunk_store=self.chunk_store,
use_compression=op["compress"],
)
total_p += 1
elif op["op_type"] == "decompress":
operators[handle] = GatewayDecompress(
handle=handle,
region=self.region,
input_queue=input_queue,
output_queue=output_queue,
error_event=self.error_event,
error_queue=self.error_queue,
chunk_store=self.chunk_store,
use_compression=op["compress"],
)
total_p += 1
elif op["op_type"] == "encrypt":
operators[handle] = GatewayEncrypt(
handle=handle,
region=self.region,
input_queue=input_queue,
output_queue=output_queue,
error_event=self.error_event,
error_queue=self.error_queue,
chunk_store=self.chunk_store,
e2ee_key_bytes=self.e2ee_key_bytes,
)
total_p += 1
elif op["op_type"] == "decrypt":
operators[handle] = GatewayDecrypt(
handle=handle,
region=self.region,
input_queue=input_queue,
output_queue=output_queue,
error_event=self.error_event,
error_queue=self.error_queue,
chunk_store=self.chunk_store,
e2ee_key_bytes=self.e2ee_key_bytes,
)
total_p += 1
else:
raise ValueError(f"Unsupported op_type {op['op_type']}")
# recursively create for child operators
Expand Down
32 changes: 25 additions & 7 deletions skyplane/gateway/gateway_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,18 @@ def __init__(
target_gateway_id: str,
region: str,
num_connections: int = 32,
compress: bool = False,
encrypt: bool = False,
private_ip: bool = False,
):
super().__init__("send")
self.target_gateway_id = target_gateway_id # gateway to send to
self.region = region # region to send to
self.num_connections = num_connections # default this for now
self.compress = compress
self.encrypt = encrypt
self.private_ip = private_ip # whether to send to private or public IP (private for GCP->GCP)


class GatewayReceive(GatewayOperator):
def __init__(self, decompress: bool = False, decrypt: bool = False, max_pending_chunks: int = 1000):
def __init__(self, max_pending_chunks: int = 1000):
super().__init__("receive")
self.decompress = decompress
self.decrypt = decrypt
self.max_pending_chunks = max_pending_chunks


Expand Down Expand Up @@ -97,6 +91,30 @@ def __init__(self):
super().__init__("mux_or")


class GatewayCompress(GatewayOperator):
def __init__(self, compress: bool = False):
super().__init__("compress")
self.compress = compress


class GatewayDecompress(GatewayOperator):
def __init__(self, compress: bool = False):
super().__init__("decompress")
self.compress = compress


class GatewayEncrypt(GatewayOperator):
def __init__(self, encrypt: bool = False):
super().__init__("encrypt")
self.encrypt = encrypt


class GatewayDecrypt(GatewayOperator):
def __init__(self, decrypt: bool = False):
super().__init__("decrypt")
self.decrypt = decrypt


class GatewayProgram:

"""
Expand Down
203 changes: 185 additions & 18 deletions skyplane/gateway/operators/gateway_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,15 +162,11 @@ def __init__(
chunk_store: ChunkStore,
ip_addr: str,
use_tls: Optional[bool] = True,
use_compression: Optional[bool] = True,
e2ee_key_bytes: Optional[bytes] = None,
n_processes: Optional[int] = 32,
):
super().__init__(handle, region, input_queue, output_queue, error_event, error_queue, chunk_store, n_processes)
self.ip_addr = ip_addr
self.use_tls = use_tls
self.use_compression = use_compression
self.e2ee_key_bytes = e2ee_key_bytes
self.args = (ip_addr,)

# provider = region.split(":")[0]
Expand All @@ -179,12 +175,6 @@ def __init__(
# elif provider == "azure":
# self.n_processes = 24 # due to throttling limits from authentication

# encryption
if e2ee_key_bytes is None:
self.e2ee_secretbox = None
else:
self.e2ee_secretbox = nacl.secret.SecretBox(e2ee_key_bytes)

# SSL context
if use_tls:
self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
Expand Down Expand Up @@ -332,14 +322,6 @@ def process(self, chunk_req: ChunkRequest, dst_host: str):
raw_wire_length = wire_length
compressed_length = None

if self.use_compression:
data = lz4.frame.compress(data)
wire_length = len(data)
compressed_length = wire_length
if self.e2ee_secretbox is not None:
data = self.e2ee_secretbox.encrypt(data)
wire_length = len(data)

# send chunk header
header = chunk.to_wire_header(
n_chunks_left_on_socket=len(chunk_ids) - idx - 1,
Expand Down Expand Up @@ -600,3 +582,188 @@ def process(self, chunk_req: ChunkRequest):
f"[obj_store:{self.worker_id}] Uploaded {chunk_req.chunk.chunk_id} partition {chunk_req.chunk.part_number} to {self.bucket_name}"
)
return True


class GatewayCompress(GatewayOperator):
def __init__(
self,
handle: str,
region: str,
input_queue: GatewayQueue,
output_queue: GatewayQueue,
error_event,
error_queue: Queue,
chunk_store: Optional[ChunkStore] = None,
use_compression: Optional[bool] = True,
prefix: Optional[str] = "",
):
super().__init__(
handle, region, input_queue, output_queue, error_event, error_queue, chunk_store
)
self.chunk_store = chunk_store
self.use_compression = use_compression
self.prefix = prefix

def process(self, chunk_req: ChunkRequest):
if not self.use_compression: return True
logger.debug(
f"[{self.handle}:{self.worker_id}] Start upload {chunk_req.chunk.chunk_id} to {self.bucket_name}, key {chunk_req.chunk.dest_key}"
)
chunk_reqs = [chunk_req]
for idx, chunk_req in enumerate(chunk_reqs):
chunk_id = chunk_req.chunk.chunk_id
chunk = chunk_req.chunk
chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id)

with open(chunk_file_path, "rb") as f:
data = f.read()
assert len(data) == chunk.chunk_length_bytes, f"chunk {chunk_id} has size {len(data)} but should be {chunk.chunk_length_bytes}"

data = lz4.frame.compress(data)
wire_length = len(data)
compressed_length = wire_length

with open(chunk_file_path, "wb") as f:
data = f.write()
return True


class GatewayDecompress(GatewayOperator):
def __init__(
self,
handle: str,
region: str,
input_queue: GatewayQueue,
output_queue: GatewayQueue,
error_event,
error_queue: Queue,
chunk_store: Optional[ChunkStore] = None,
use_compression: Optional[bool] = True,
prefix: Optional[str] = "",
):
super().__init__(
handle, region, input_queue, output_queue, error_event, error_queue, chunk_store
)
self.chunk_store = chunk_store
self.use_compression = use_compression
self.prefix = prefix

def process(self, chunk_req: ChunkRequest):
if not self.use_compression: return True
logger.debug(
f"[{self.handle}:{self.worker_id}] Start upload {chunk_req.chunk.chunk_id} to {self.bucket_name}, key {chunk_req.chunk.dest_key}"
)
chunk_reqs = [chunk_req]
for idx, chunk_req in enumerate(chunk_reqs):
chunk_id = chunk_req.chunk.chunk_id
chunk = chunk_req.chunk
chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id)

with open(chunk_file_path, "rb") as f:
data = f.read()
assert len(data) == chunk.chunk_length_bytes, f"chunk {chunk_id} has size {len(data)} but should be {chunk.chunk_length_bytes}"

data = lz4.frame.decompress(data)
wire_length = len(data)
compressed_length = wire_length

with open(chunk_file_path, "wb") as f:
data = f.write()
return True

class GatewayEncrypt(GatewayOperator):
def __init__(
self,
handle: str,
region: str,
input_queue: GatewayQueue,
output_queue: GatewayQueue,
error_event,
error_queue: Queue,
chunk_store: Optional[ChunkStore] = None,
e2ee_key_bytes: Optional[bytes] = None,
prefix: Optional[str] = "",
):
super().__init__(
handle, region, input_queue, output_queue, error_event, error_queue, chunk_store
)
self.chunk_store = chunk_store
self.prefix = prefix

# encryption
if e2ee_key_bytes is None:
self.e2ee_secretbox = None
else:
self.e2ee_secretbox = nacl.secret.SecretBox(e2ee_key_bytes)

def process(self, chunk_req: ChunkRequest):
if not self.e2ee_secretbox: return
logger.debug(
f"[{self.handle}:{self.worker_id}] Start upload {chunk_req.chunk.chunk_id} to {self.bucket_name}, key {chunk_req.chunk.dest_key}"
)
chunk_reqs = [chunk_req]
for idx, chunk_req in enumerate(chunk_reqs):
chunk_id = chunk_req.chunk.chunk_id
chunk = chunk_req.chunk
chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id)

with open(chunk_file_path, "rb") as f:
data = f.read()
assert len(data) == chunk.chunk_length_bytes, f"chunk {chunk_id} has size {len(data)} but should be {chunk.chunk_length_bytes}"

data = self.e2ee_secretbox.encrypt(data)
wire_length = len(data)
encrypted_length = wire_length

with open(chunk_file_path, "wb") as f:
data = f.write()
return True


class GatewayDecrypt(GatewayOperator):
def __init__(
self,
handle: str,
region: str,
input_queue: GatewayQueue,
output_queue: GatewayQueue,
error_event,
error_queue: Queue,
chunk_store: Optional[ChunkStore] = None,
e2ee_key_bytes: Optional[bytes] = None,
prefix: Optional[str] = "",
):
super().__init__(
handle, region, input_queue, output_queue, error_event, error_queue, chunk_store
)
self.chunk_store = chunk_store
self.prefix = prefix

# encryption
if e2ee_key_bytes is None:
self.e2ee_secretbox = None
else:
self.e2ee_secretbox = nacl.secret.SecretBox(e2ee_key_bytes)

def process(self, chunk_req: ChunkRequest):
if not self.e2ee_secretbox: return
logger.debug(
f"[{self.handle}:{self.worker_id}] Start upload {chunk_req.chunk.chunk_id} to {self.bucket_name}, key {chunk_req.chunk.dest_key}"
)
chunk_reqs = [chunk_req]
for idx, chunk_req in enumerate(chunk_reqs):
chunk_id = chunk_req.chunk.chunk_id
chunk = chunk_req.chunk
chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id)

with open(chunk_file_path, "rb") as f:
data = f.read()
assert len(data) == chunk.chunk_length_bytes, f"chunk {chunk_id} has size {len(data)} but should be {chunk.chunk_length_bytes}"

data = self.e2ee_secretbox.decrypt(data)
wire_length = len(data)
encrypted_length = wire_length

with open(chunk_file_path, "wb") as f:
data = f.write()
return True
Loading