From 0617ed1d72480535ccdccdff974dd32522070c7d Mon Sep 17 00:00:00 2001 From: huangwei Date: Mon, 31 Jul 2023 11:36:08 +0800 Subject: [PATCH] seperate compress and encryption --- skyplane/gateway/gateway_daemon.py | 56 ++++- skyplane/gateway/gateway_program.py | 32 ++- .../gateway/operators/gateway_operator.py | 203 ++++++++++++++++-- .../gateway/operators/gateway_receiver.py | 24 +-- skyplane/planner/planner.py | 54 ++++- 5 files changed, 312 insertions(+), 57 deletions(-) diff --git a/skyplane/gateway/gateway_daemon.py b/skyplane/gateway/gateway_daemon.py index 06ef3e1ad..343bad6ac 100644 --- a/skyplane/gateway/gateway_daemon.py +++ b/skyplane/gateway/gateway_daemon.py @@ -23,6 +23,10 @@ GatewayObjStoreReadOperator, GatewayObjStoreWriteOperator, GatewayWaitReceiver, + GatewayCompress, + GatewayDecompress, + GatewayEncrypt, + GatewayDecrypt, ) from skyplane.gateway.operators.gateway_receiver import GatewayReceiver from skyplane.utils import logger @@ -90,8 +94,6 @@ def __init__( error_queue=self.error_queue, max_pending_chunks=max_incoming_ports, use_tls=self.use_tls, - use_compression=use_compression, - e2ee_key_bytes=self.e2ee_key_bytes, ) # API server @@ -232,8 +234,6 @@ def create_gateway_operators_helper(input_queue, program: List[Dict], partition_ error_queue=self.error_queue, chunk_store=self.chunk_store, use_tls=self.use_tls, - use_compression=op["compress"], - e2ee_key_bytes=self.e2ee_key_bytes, n_processes=op["num_connections"], ) total_p += op["num_connections"] @@ -264,6 +264,54 @@ def create_gateway_operators_helper(input_queue, program: List[Dict], partition_ chunk_store=self.chunk_store, ) total_p += 1 + elif op["op_type"] == "compress": + operators[handle] = GatewayCompress( + handle=handle, + region=self.region, + input_queue=input_queue, + output_queue=output_queue, + error_event=self.error_event, + error_queue=self.error_queue, + chunk_store=self.chunk_store, + use_compression=op["compress"], + ) + total_p += 1 + elif op["op_type"] == 
"decompress": + operators[handle] = GatewayDecompress( + handle=handle, + region=self.region, + input_queue=input_queue, + output_queue=output_queue, + error_event=self.error_event, + error_queue=self.error_queue, + chunk_store=self.chunk_store, + use_compression=op["compress"], + ) + total_p += 1 + elif op["op_type"] == "encrypt": + operators[handle] = GatewayEncrypt( + handle=handle, + region=self.region, + input_queue=input_queue, + output_queue=output_queue, + error_event=self.error_event, + error_queue=self.error_queue, + chunk_store=self.chunk_store, + e2ee_key_bytes=self.e2ee_key_bytes, + ) + total_p += 1 + elif op["op_type"] == "decrypt": + operators[handle] = GatewayDecrypt( + handle=handle, + region=self.region, + input_queue=input_queue, + output_queue=output_queue, + error_event=self.error_event, + error_queue=self.error_queue, + chunk_store=self.chunk_store, + e2ee_key_bytes=self.e2ee_key_bytes, + ) + total_p += 1 else: raise ValueError(f"Unsupported op_type {op['op_type']}") # recursively create for child operators diff --git a/skyplane/gateway/gateway_program.py b/skyplane/gateway/gateway_program.py index f275b82e8..f2432df95 100644 --- a/skyplane/gateway/gateway_program.py +++ b/skyplane/gateway/gateway_program.py @@ -37,24 +37,18 @@ def __init__( target_gateway_id: str, region: str, num_connections: int = 32, - compress: bool = False, - encrypt: bool = False, private_ip: bool = False, ): super().__init__("send") self.target_gateway_id = target_gateway_id # gateway to send to self.region = region # region to send to self.num_connections = num_connections # default this for now - self.compress = compress - self.encrypt = encrypt self.private_ip = private_ip # whether to send to private or public IP (private for GCP->GCP) class GatewayReceive(GatewayOperator): - def __init__(self, decompress: bool = False, decrypt: bool = False, max_pending_chunks: int = 1000): + def __init__(self, max_pending_chunks: int = 1000): super().__init__("receive") - 
self.decompress = decompress - self.decrypt = decrypt self.max_pending_chunks = max_pending_chunks @@ -97,6 +91,30 @@ def __init__(self): super().__init__("mux_or") +class GatewayCompress(GatewayOperator): + def __init__(self, compress: bool = False): + super().__init__("compress") + self.compress = compress + + +class GatewayDecompress(GatewayOperator): + def __init__(self, compress: bool = False): + super().__init__("decompress") + self.compress = compress + + +class GatewayEncrypt(GatewayOperator): + def __init__(self, encrypt: bool = False): + super().__init__("encrypt") + self.encrypt = encrypt + + +class GatewayDecrypt(GatewayOperator): + def __init__(self, decrypt: bool = False): + super().__init__("decrypt") + self.decrypt = decrypt + + class GatewayProgram: """ diff --git a/skyplane/gateway/operators/gateway_operator.py b/skyplane/gateway/operators/gateway_operator.py index c574caebc..49417ad48 100644 --- a/skyplane/gateway/operators/gateway_operator.py +++ b/skyplane/gateway/operators/gateway_operator.py @@ -162,15 +162,11 @@ def __init__( chunk_store: ChunkStore, ip_addr: str, use_tls: Optional[bool] = True, - use_compression: Optional[bool] = True, - e2ee_key_bytes: Optional[bytes] = None, n_processes: Optional[int] = 32, ): super().__init__(handle, region, input_queue, output_queue, error_event, error_queue, chunk_store, n_processes) self.ip_addr = ip_addr self.use_tls = use_tls - self.use_compression = use_compression - self.e2ee_key_bytes = e2ee_key_bytes self.args = (ip_addr,) # provider = region.split(":")[0] @@ -179,12 +175,6 @@ def __init__( # elif provider == "azure": # self.n_processes = 24 # due to throttling limits from authentication - # encryption - if e2ee_key_bytes is None: - self.e2ee_secretbox = None - else: - self.e2ee_secretbox = nacl.secret.SecretBox(e2ee_key_bytes) - # SSL context if use_tls: self.ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) @@ -332,14 +322,6 @@ def process(self, chunk_req: ChunkRequest, 
dst_host: str):
             raw_wire_length = wire_length
             compressed_length = None
 
-            if self.use_compression:
-                data = lz4.frame.compress(data)
-                wire_length = len(data)
-                compressed_length = wire_length
-            if self.e2ee_secretbox is not None:
-                data = self.e2ee_secretbox.encrypt(data)
-                wire_length = len(data)
-
             # send chunk header
             header = chunk.to_wire_header(
                 n_chunks_left_on_socket=len(chunk_ids) - idx - 1,
@@ -600,3 +582,188 @@ def process(self, chunk_req: ChunkRequest):
             f"[obj_store:{self.worker_id}] Uploaded {chunk_req.chunk.chunk_id} partition {chunk_req.chunk.part_number} to {self.bucket_name}"
         )
         return True
+
+
+class GatewayCompress(GatewayOperator):
+    def __init__(
+        self,
+        handle: str,
+        region: str,
+        input_queue: GatewayQueue,
+        output_queue: GatewayQueue,
+        error_event,
+        error_queue: Queue,
+        chunk_store: Optional[ChunkStore] = None,
+        use_compression: Optional[bool] = True,
+        prefix: Optional[str] = "",
+    ):
+        super().__init__(
+            handle, region, input_queue, output_queue, error_event, error_queue, chunk_store
+        )
+        self.chunk_store = chunk_store
+        self.use_compression = use_compression
+        self.prefix = prefix
+
+    def process(self, chunk_req: ChunkRequest):
+        if not self.use_compression: return True
+        logger.debug(
+            f"[{self.handle}:{self.worker_id}] Compressing chunk {chunk_req.chunk.chunk_id} in place"
+        )
+        chunk_reqs = [chunk_req]
+        for idx, chunk_req in enumerate(chunk_reqs):
+            chunk_id = chunk_req.chunk.chunk_id
+            chunk = chunk_req.chunk
+            chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id)
+
+            with open(chunk_file_path, "rb") as f:
+                data = f.read()
+                assert len(data) == chunk.chunk_length_bytes, f"chunk {chunk_id} has size {len(data)} but should be {chunk.chunk_length_bytes}"
+
+            data = lz4.frame.compress(data)
+            wire_length = len(data)
+            compressed_length = wire_length
+
+            with open(chunk_file_path, "wb") as f:
+                f.write(data)
+        return True
+
+
+class GatewayDecompress(GatewayOperator):
+    def __init__(
+        self,
+        handle: str,
+        region: str,
+        input_queue: GatewayQueue,
+        output_queue: GatewayQueue,
+        error_event,
+        error_queue: Queue,
+        chunk_store: Optional[ChunkStore] = None,
+        use_compression: Optional[bool] = True,
+        prefix: Optional[str] = "",
+    ):
+        super().__init__(
+            handle, region, input_queue, output_queue, error_event, error_queue, chunk_store
+        )
+        self.chunk_store = chunk_store
+        self.use_compression = use_compression
+        self.prefix = prefix
+
+    def process(self, chunk_req: ChunkRequest):
+        if not self.use_compression: return True
+        logger.debug(
+            f"[{self.handle}:{self.worker_id}] Decompressing chunk {chunk_req.chunk.chunk_id} in place"
+        )
+        chunk_reqs = [chunk_req]
+        for idx, chunk_req in enumerate(chunk_reqs):
+            chunk_id = chunk_req.chunk.chunk_id
+            chunk = chunk_req.chunk
+            chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id)
+
+            with open(chunk_file_path, "rb") as f:
+                data = f.read()
+                # NOTE(review): data on disk is lz4-compressed here, so its length need not equal chunk.chunk_length_bytes
+
+            data = lz4.frame.decompress(data)
+            wire_length = len(data)
+            compressed_length = wire_length
+
+            with open(chunk_file_path, "wb") as f:
+                f.write(data)
+        return True
+
+class GatewayEncrypt(GatewayOperator):
+    def __init__(
+        self,
+        handle: str,
+        region: str,
+        input_queue: GatewayQueue,
+        output_queue: GatewayQueue,
+        error_event,
+        error_queue: Queue,
+        chunk_store: Optional[ChunkStore] = None,
+        e2ee_key_bytes: Optional[bytes] = None,
+        prefix: Optional[str] = "",
+    ):
+        super().__init__(
+            handle, region, input_queue, output_queue, error_event, error_queue, chunk_store
+        )
+        self.chunk_store = chunk_store
+        self.prefix = prefix
+
+        # encryption
+        if e2ee_key_bytes is None:
+            self.e2ee_secretbox = None
+        else:
+            self.e2ee_secretbox = nacl.secret.SecretBox(e2ee_key_bytes)
+
+    def process(self, chunk_req: ChunkRequest):
+        if not self.e2ee_secretbox: return True
+        logger.debug(
+            f"[{self.handle}:{self.worker_id}] Encrypting chunk {chunk_req.chunk.chunk_id} in place"
+        )
+        chunk_reqs = [chunk_req]
+        for idx, chunk_req in enumerate(chunk_reqs):
+            chunk_id = chunk_req.chunk.chunk_id
+            chunk = chunk_req.chunk
+            chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id)
+
+            with open(chunk_file_path, "rb") as f:
+                data = f.read()
+                # NOTE(review): upstream compression may change the on-disk size, so do not assert chunk.chunk_length_bytes here
+
+            data = self.e2ee_secretbox.encrypt(data)
+            wire_length = len(data)
+            encrypted_length = wire_length
+
+            with open(chunk_file_path, "wb") as f:
+                f.write(data)
+        return True
+
+
+class GatewayDecrypt(GatewayOperator):
+    def __init__(
+        self,
+        handle: str,
+        region: str,
+        input_queue: GatewayQueue,
+        output_queue: GatewayQueue,
+        error_event,
+        error_queue: Queue,
+        chunk_store: Optional[ChunkStore] = None,
+        e2ee_key_bytes: Optional[bytes] = None,
+        prefix: Optional[str] = "",
+    ):
+        super().__init__(
+            handle, region, input_queue, output_queue, error_event, error_queue, chunk_store
+        )
+        self.chunk_store = chunk_store
+        self.prefix = prefix
+
+        # encryption
+        if e2ee_key_bytes is None:
+            self.e2ee_secretbox = None
+        else:
+            self.e2ee_secretbox = nacl.secret.SecretBox(e2ee_key_bytes)
+
+    def process(self, chunk_req: ChunkRequest):
+        if not self.e2ee_secretbox: return True
+        logger.debug(
+            f"[{self.handle}:{self.worker_id}] Decrypting chunk {chunk_req.chunk.chunk_id} in place"
+        )
+        chunk_reqs = [chunk_req]
+        for idx, chunk_req in enumerate(chunk_reqs):
+            chunk_id = chunk_req.chunk.chunk_id
+            chunk = chunk_req.chunk
+            chunk_file_path = self.chunk_store.get_chunk_file_path(chunk_id)
+
+            with open(chunk_file_path, "rb") as f:
+                data = f.read()
+                # NOTE(review): data on disk is encrypted here, so its length need not equal chunk.chunk_length_bytes
+
+            data = self.e2ee_secretbox.decrypt(data)
+            wire_length = len(data)
+            encrypted_length = wire_length
+
+            with open(chunk_file_path, "wb") as f:
+                f.write(data)
+        return True
diff --git a/skyplane/gateway/operators/gateway_receiver.py b/skyplane/gateway/operators/gateway_receiver.py
index aeed0e341..21d3c8c98 100644
--- a/skyplane/gateway/operators/gateway_receiver.py
+++ b/skyplane/gateway/operators/gateway_receiver.py
@@ -31,8 +31,6 @@ def __init__(
         recv_block_size=4 * MB,
         max_pending_chunks=1,
         use_tls: Optional[bool] = True,
-        use_compression: Optional[bool] = True,
-        e2ee_key_bytes: Optional[bytes] = None,
     ):
         self.handle = handle
         self.region = region
@@ -42,11 +40,6 @@ def __init__(
         self.recv_block_size = recv_block_size
         self.max_pending_chunks = max_pending_chunks
         print("Max pending chunks", self.max_pending_chunks)
-        self.use_compression = use_compression
-        if e2ee_key_bytes is None:
-            self.e2ee_secretbox = None
-        else:
-            self.e2ee_secretbox = nacl.secret.SecretBox(e2ee_key_bytes)
         self.server_processes = []
         self.server_ports = []
         self.next_gateway_worker_id = 0
@@ -153,9 +146,6 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
                 # TODO: this wont work
                 # chunk_request = self.chunk_store.get_chunk_request(chunk_header.chunk_id)
 
-                should_decrypt = self.e2ee_secretbox is not None  # and chunk_request.dst_region == self.region
-                should_decompress = chunk_header.is_compressed  # and chunk_request.dst_region == self.region
-
                 # wait for space
                 # while self.chunk_store.remaining_bytes() < chunk_header.data_len * self.max_pending_chunks:
                 #     print(
@@ -171,7 +161,7 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
                 fpath = self.chunk_store.get_chunk_file_path(chunk_header.chunk_id)
                 with fpath.open("wb") as f:
                     socket_data_len = chunk_header.data_len
-                    chunk_received_size, chunk_received_size_decompressed = 0, 0
+                    chunk_received_size = 0
                    to_write = bytearray(socket_data_len)
                    to_write_view = memoryview(to_write)
                    while socket_data_len > 0:
@@ 
-188,18 +178,6 @@ def recv_chunks(self, conn: socket.socket, addr: Tuple[str, int]):
                        )
                    to_write = bytes(to_write)
 
-                    if should_decrypt:
-                        to_write = self.e2ee_secretbox.decrypt(to_write)
-                        print(f"[receiver:{server_port}]:{chunk_header.chunk_id} Decrypting {len(to_write)} bytes")
-
-                    if should_decompress:
-                        data_batch_decompressed = lz4.frame.decompress(to_write)
-                        chunk_received_size_decompressed += len(data_batch_decompressed)
-                        to_write = data_batch_decompressed
-                        print(
-                            f"[receiver:{server_port}]:{chunk_header.chunk_id} Decompressing {len(to_write)} bytes to {chunk_received_size_decompressed} bytes"
-                        )
-
                    # try to write data until successful
                    while True:
                        try:
diff --git a/skyplane/planner/planner.py b/skyplane/planner/planner.py
index 6bfb5233a..03a400092 100644
--- a/skyplane/planner/planner.py
+++ b/skyplane/planner/planner.py
@@ -17,6 +17,10 @@
     GatewayWriteObjectStore,
     GatewayReceive,
     GatewaySend,
+    GatewayEncrypt,
+    GatewayDecrypt,
+    GatewayCompress,
+    GatewayDecompress,
 )
 from skyplane.api.transfer_job import TransferJob
 
@@ -239,6 +243,16 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan:
            )
            mux_or = src_program.add_operator(GatewayMuxOr(), parent_handle=obj_store_read, partition_id=partition_id)
            for i in range(n_instances):
+                compress_op = src_program.add_operator(
+                    GatewayCompress(compress=self.transfer_config.use_compression),
+                    parent_handle=mux_or,
+                    partition_id=partition_id,
+                )
+                encrypt_op = src_program.add_operator(
+                    GatewayEncrypt(encrypt=self.transfer_config.use_e2ee),
+                    parent_handle=compress_op,
+                    partition_id=partition_id,
+                )
                src_program.add_operator(
                    GatewaySend(
                        target_gateway_id=dst_gateways[i].gateway_id,
@@ -247,14 +261,22 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan:
-                        compress=True,
-                        encrypt=True,
                    ),
-                    parent_handle=mux_or,
+                    parent_handle=encrypt_op,
                    partition_id=partition_id,
                )
 
            # dst region gateway program
-            recv_op = dst_program.add_operator(GatewayReceive(decompress=True, decrypt=True), partition_id=partition_id)
+            recv_op = dst_program.add_operator(GatewayReceive(), partition_id=partition_id)
+            decrypt_op = dst_program.add_operator(
+                GatewayDecrypt(decrypt=self.transfer_config.use_e2ee),
+                parent_handle=recv_op,
+                partition_id=partition_id,
+            )
+            decompress_op = dst_program.add_operator(
+                GatewayDecompress(compress=self.transfer_config.use_compression),
+                parent_handle=decrypt_op,
+                partition_id=partition_id,
+            )
            dst_program.add_operator(
-                GatewayWriteObjectStore(dst_bucket, dst_region_tag, self.n_connections), parent_handle=recv_op, partition_id=partition_id
+                GatewayWriteObjectStore(dst_bucket, dst_region_tag, self.n_connections), parent_handle=decompress_op, partition_id=partition_id
            )
 
            # update cost per GB
@@ -341,6 +363,16 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan:
            if dst_gateways[i].provider == "gcp" and src_provider == "gcp":
                # print("Using private IP for GCP to GCP transfer", src_region_tag, dst_region_tag)
                private_ip = True
+            compress_op = src_program.add_operator(
+                GatewayCompress(compress=self.transfer_config.use_compression),
+                parent_handle=mux_or,
+                partition_id=partition_id,
+            )
+            encrypt_op = src_program.add_operator(
+                GatewayEncrypt(encrypt=self.transfer_config.use_e2ee),
+                parent_handle=compress_op,
+                partition_id=partition_id,
+            )
            src_program.add_operator(
                GatewaySend(
                    target_gateway_id=dst_gateways[i].gateway_id,
@@ -350,18 +382,26 @@ def plan(self, jobs: List[TransferJob]) -> TopologyPlan:
-                    compress=self.transfer_config.use_compression,
-                    encrypt=self.transfer_config.use_e2ee,
                ),
-                parent_handle=mux_or,
+                parent_handle=encrypt_op,
                partition_id=partition_id,
            )
 
            # each gateway also recieves data from source
            recv_op = dst_program[dst_region_tag].add_operator(
-                GatewayReceive(decompress=self.transfer_config.use_compression, decrypt=self.transfer_config.use_e2ee),
+                GatewayReceive(),
+                partition_id=partition_id,
+            )
+            decrypt_op = dst_program[dst_region_tag].add_operator(
+                GatewayDecrypt(decrypt=self.transfer_config.use_e2ee),
+                parent_handle=recv_op,
+                partition_id=partition_id,
+            )
+            decompress_op = dst_program[dst_region_tag].add_operator(
+                GatewayDecompress(compress=self.transfer_config.use_compression),
+                parent_handle=decrypt_op,
                partition_id=partition_id,
            )
 
            dst_program[dst_region_tag].add_operator(
                GatewayWriteObjectStore(dst_bucket, dst_region_tag, self.n_connections, key_prefix=dst_prefix),
-                parent_handle=recv_op,
+                parent_handle=decompress_op,
                partition_id=partition_id,
            )