
Fix #827 #907 #908 #901, add compress operator and encrypt operator #906

Closed
wants to merge 13 commits into from
Changes from 12 commits
14 changes: 13 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -54,6 +54,7 @@ cvxpy = { version = ">=1.1.0", optional = true }
graphviz = { version = ">=0.15", optional = true }
matplotlib = { version = ">=3.0.0", optional = true }
numpy = { version = ">=1.19.0", optional = true }
networkx = { version = ">=2.5", optional = true }

# gateway dependencies
flask = { version = "^2.1.2", optional = true }
@@ -70,7 +71,7 @@ gcp = ["google-api-python-client", "google-auth", "google-cloud-compute", "googl
ibm = ["ibm-cloud-sdk-core", "ibm-cos-sdk", "ibm-vpc"]
all = ["boto3", "azure-identity", "azure-mgmt-authorization", "azure-mgmt-compute", "azure-mgmt-network", "azure-mgmt-resource", "azure-mgmt-storage", "azure-mgmt-subscription", "azure-storage-blob", "google-api-python-client", "google-auth", "google-cloud-compute", "google-cloud-storage", "ibm-cloud-sdk-core", "ibm-cos-sdk", "ibm-vpc"]
gateway = ["flask", "lz4", "pynacl", "pyopenssl", "werkzeug"]
solver = ["cvxpy", "graphviz", "matplotlib", "numpy"]
solver = ["networkx", "cvxpy", "graphviz", "matplotlib", "numpy"]

[tool.poetry.dev-dependencies]
pytest = ">=6.0.0"
1 change: 1 addition & 0 deletions scripts/requirements-gateway.txt
@@ -33,3 +33,4 @@ numpy
pandas
pyarrow
typer
networkx
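
networkx is added to both the solver extra in pyproject.toml and the gateway requirements above, presumably to support the new MulticastMDSTPlanner imported in skyplane/api/pipeline.py further down. A minimal sketch of the kind of computation it provides; the regions, costs, and rooted-tree framing here are illustrative assumptions, not the planner's actual logic:

```python
# Illustrative only: find a minimum spanning arborescence (a min-cost directed
# spanning tree) over a hypothetical inter-region cost graph. A multicast MDST
# planner could route a transfer along such a tree. Region names and costs are
# made up for this sketch.
import networkx as nx

g = nx.DiGraph()
g.add_edge("aws:us-east-1", "aws:eu-west-1", weight=0.020)    # hypothetical $/GB
g.add_edge("aws:us-east-1", "gcp:us-central1", weight=0.010)
g.add_edge("gcp:us-central1", "aws:eu-west-1", weight=0.012)

tree = nx.minimum_spanning_arborescence(g, attr="weight")
print(sorted(tree.edges()))  # cheapest directed tree reaching every region
```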
4 changes: 0 additions & 4 deletions skyplane/api/client.py
@@ -5,15 +5,11 @@
from typing import TYPE_CHECKING, Optional

from skyplane.api.config import TransferConfig
from skyplane.api.dataplane import Dataplane
from skyplane.api.provisioner import Provisioner
from skyplane.api.obj_store import ObjectStore
from skyplane.api.usage import get_clientid
from skyplane.obj_store.object_store_interface import ObjectStoreInterface
from skyplane.planner.planner import MulticastDirectPlanner
from skyplane.utils import logger
from skyplane.utils.definitions import tmp_log_dir
from skyplane.utils.path import parse_path

from skyplane.api.pipeline import Pipeline

2 changes: 1 addition & 1 deletion skyplane/api/config.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass

from typing import Optional, List
from typing import Optional

from skyplane import compute

14 changes: 8 additions & 6 deletions skyplane/api/dataplane.py
@@ -1,21 +1,22 @@
import json
import os
import threading
from collections import defaultdict, Counter
from collections import defaultdict
from datetime import datetime
from functools import partial
from datetime import datetime

import nacl.secret
import nacl.utils
import typer
import urllib3
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional

from skyplane import compute
from skyplane.exceptions import GatewayContainerStartException
from skyplane.api.tracker import TransferProgressTracker, TransferHook
from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob
from skyplane.api.transfer_job import TransferJob
from skyplane.api.config import TransferConfig
from skyplane.planner.topology import TopologyPlan, TopologyPlanGateway
from skyplane.utils import logger
@@ -89,7 +90,6 @@ def _start_gateway(
gateway_server: compute.Server,
gateway_log_dir: Optional[PathLike],
authorize_ssh_pub_key: Optional[str] = None,
e2ee_key_bytes: Optional[str] = None,
):
# map outgoing ports
setup_args = {}
@@ -119,9 +119,7 @@ def _start_gateway(
gateway_docker_image=gateway_docker_image,
gateway_program_path=str(gateway_program_filename),
gateway_info_path=f"{gateway_log_dir}/gateway_info.json",
e2ee_key_bytes=e2ee_key_bytes, # TODO: remove
use_bbr=self.transfer_config.use_bbr, # TODO: remove
use_compression=self.transfer_config.use_compression,
use_socket_tls=self.transfer_config.use_socket_tls,
)

@@ -202,6 +200,10 @@ def provision(
# todo: move server.py:start_gateway here
logger.fs.info(f"Using docker image {gateway_docker_image}")
e2ee_key_bytes = nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE)
# save E2EE keys
e2ee_key_file = "e2ee_key"
with open(f"/tmp/{e2ee_key_file}", 'wb') as f:
f.write(e2ee_key_bytes)

# create gateway logging dir
gateway_program_dir = f"{self.log_dir}/programs"
@@ -218,7 +220,7 @@
jobs = []
for node, server in gateway_bound_nodes.items():
jobs.append(
partial(self._start_gateway, gateway_docker_image, node, server, gateway_program_dir, authorize_ssh_pub_key, e2ee_key_bytes)
partial(self._start_gateway, gateway_docker_image, node, server, gateway_program_dir, authorize_ssh_pub_key)
)
logger.fs.debug(f"[Dataplane.provision] Starting gateways on {len(jobs)} servers")
try:
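
Note on the provision() change above: the end-to-end encryption key is now generated in provision() and written to /tmp/e2ee_key rather than being passed to _start_gateway as e2ee_key_bytes. A minimal sketch of how such a key file can be used with PyNaCl; only the key generation and file path come from this diff, the load-and-decrypt side is an assumption for illustration:

```python
# Sketch only: persist a SecretBox key the way provision() does, then load it
# and round-trip a payload. Paths and the consumer side are illustrative.
import nacl.secret
import nacl.utils

key = nacl.utils.random(nacl.secret.SecretBox.KEY_SIZE)  # 32 random bytes
with open("/tmp/e2ee_key", "wb") as f:
    f.write(key)

# Later, e.g. on a gateway: load the key and build a SecretBox.
with open("/tmp/e2ee_key", "rb") as f:
    box = nacl.secret.SecretBox(f.read())

ciphertext = box.encrypt(b"chunk payload")        # nonce is prepended automatically
assert box.decrypt(ciphertext) == b"chunk payload"
```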
42 changes: 27 additions & 15 deletions skyplane/api/pipeline.py
@@ -1,27 +1,27 @@
import json
import time
import os
import threading
from collections import defaultdict, Counter
from datetime import datetime
from functools import partial
from datetime import datetime

import nacl.secret
import nacl.utils
import urllib3
from typing import TYPE_CHECKING, Dict, List, Optional

from skyplane import compute
from skyplane.api.tracker import TransferProgressTracker, TransferHook
from skyplane.api.tracker import TransferProgressTracker
from skyplane.api.transfer_job import CopyJob, SyncJob, TransferJob
from skyplane.api.config import TransferConfig

from skyplane.planner.planner import MulticastDirectPlanner, DirectPlannerSourceOneSided, DirectPlannerDestOneSided
from skyplane.planner.planner import (
MulticastDirectPlanner,
DirectPlannerSourceOneSided,
DirectPlannerDestOneSided,
UnicastDirectPlanner,
UnicastILPPlanner,
MulticastILPPlanner,
MulticastMDSTPlanner,
)
from skyplane.planner.topology import TopologyPlanGateway
from skyplane.utils import logger
from skyplane.utils.definitions import gateway_docker_image, tmp_log_dir
from skyplane.utils.fn import PathLike, do_parallel
from skyplane.utils.definitions import tmp_log_dir

from skyplane.api.dataplane import Dataplane

@@ -69,12 +69,23 @@ def __init__(

# planner
self.planning_algorithm = planning_algorithm

if self.planning_algorithm == "direct":
self.planner = MulticastDirectPlanner(self.max_instances, self.n_connections, self.transfer_config)
self.planner = MulticastDirectPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "src_one_sided":
self.planner = DirectPlannerSourceOneSided(self.max_instances, self.n_connections, self.transfer_config)
self.planner = DirectPlannerSourceOneSided(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "dst_one_sided":
self.planner = DirectPlannerDestOneSided(self.max_instances, self.n_connections, self.transfer_config)
self.planner = DirectPlannerDestOneSided(self.transfer_config, self.max_instances, self.n_connections)
# TODO: should find some ways to merge direct / Ndirect
self.planner = UnicastDirectPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "multi_direct":
self.planner = MulticastDirectPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "multi_dst":
self.planner = MulticastMDSTPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "multi_ilp":
self.planner = MulticastILPPlanner(self.transfer_config, self.max_instances, self.n_connections)
elif self.planning_algorithm == "uni_ilp":
self.planner = UnicastILPPlanner(self.transfer_config, self.max_instances, self.n_connections)
else:
raise ValueError(f"No such planning algorithm {planning_algorithm}")

@@ -118,7 +129,7 @@ def start(self, debug=False, progress=False):
# copy gateway logs
if debug:
dp.copy_gateway_logs()
except Exception as e:
except Exception:
dp.copy_gateway_logs()
dp.deprovision(spinner=True)
return dp
@@ -193,3 +204,4 @@ def estimate_total_cost(self):

# return size
return total_size * topo.cost_per_gb
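
For reference, the __init__ hunk above extends planning_algorithm to accept "src_one_sided", "dst_one_sided", "multi_direct", "multi_dst", "multi_ilp", and "uni_ilp" in addition to "direct". A table-driven sketch of the same dispatch, assuming every planner class takes (transfer_config, max_instances, n_connections) as in the calls shown; the mapping for "direct" follows the MulticastDirectPlanner branch, though the rendered diff also shows a UnicastDirectPlanner line for that case:

```python
# Sketch: lookup-table equivalent of the if/elif chain in Pipeline.__init__.
from skyplane.planner.planner import (
    MulticastDirectPlanner,
    DirectPlannerSourceOneSided,
    DirectPlannerDestOneSided,
    UnicastILPPlanner,
    MulticastILPPlanner,
    MulticastMDSTPlanner,
)

PLANNERS = {
    "direct": MulticastDirectPlanner,
    "src_one_sided": DirectPlannerSourceOneSided,
    "dst_one_sided": DirectPlannerDestOneSided,
    "multi_direct": MulticastDirectPlanner,
    "multi_dst": MulticastMDSTPlanner,
    "multi_ilp": MulticastILPPlanner,
    "uni_ilp": UnicastILPPlanner,
}

def make_planner(planning_algorithm, transfer_config, max_instances, n_connections):
    try:
        planner_cls = PLANNERS[planning_algorithm]
    except KeyError:
        raise ValueError(f"No such planning algorithm {planning_algorithm}") from None
    return planner_cls(transfer_config, max_instances, n_connections)
```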

1 change: 0 additions & 1 deletion skyplane/api/tracker.py
@@ -1,5 +1,4 @@
import functools
from pprint import pprint
import json
import time
from abc import ABC
17 changes: 11 additions & 6 deletions skyplane/api/transfer_job.py
@@ -16,15 +16,15 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, Generator, List, Optional, Tuple, TypeVar, Dict

from abc import ABC, abstractmethod
from abc import ABC

import urllib3
from rich import print as rprint
from functools import partial

from skyplane import exceptions
from skyplane.api.config import TransferConfig
from skyplane.chunk import Chunk, ChunkRequest
from skyplane.chunk import Chunk
from skyplane.obj_store.storage_interface import StorageInterface
from skyplane.obj_store.object_store_interface import ObjectStoreObject, ObjectStoreInterface
from skyplane.utils import logger
@@ -102,6 +102,7 @@ def _run_multipart_chunk_thread(
src_object = transfer_pair.src_obj
dest_objects = transfer_pair.dst_objs
dest_key = transfer_pair.dst_key
print("dest_key: ", dest_key)
if isinstance(self.src_iface, ObjectStoreInterface):
mime_type = self.src_iface.get_obj_mime_type(src_object.key)
# create multipart upload request per destination
@@ -283,10 +284,10 @@ def transfer_pair_generator(
dest_provider, dest_region = dst_iface.region_tag().split(":")
try:
dest_key = self.map_object_key_prefix(src_prefix, obj.key, dst_prefix, recursive=recursive)
assert (
dest_key[: len(dst_prefix)] == dst_prefix
), f"Destination key {dest_key} does not start with destination prefix {dst_prefix}"
dest_keys.append(dest_key[len(dst_prefix) :])
# TODO: why is it changed here?
# dest_keys.append(dest_key[len(dst_prefix) :])

dest_keys.append(dest_key)
except exceptions.MissingObjectException as e:
logger.fs.exception(e)
raise e from None
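
A small illustration of the key-mapping change in the hunk above: previously the destination prefix was stripped before appending, now the full destination key is kept (and the prefix assertion is dropped). The prefix and key below are made up; map_object_key_prefix itself is not reimplemented:

```python
# Hypothetical values showing what gets appended to dest_keys before vs. after.
dst_prefix = "backup/"
dest_key = "backup/photos/2023/img001.jpg"   # assume map_object_key_prefix returned this

old_appended = dest_key[len(dst_prefix):]    # "photos/2023/img001.jpg"
new_appended = dest_key                      # "backup/photos/2023/img001.jpg"
print(old_appended, new_appended)
```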
@@ -508,8 +509,12 @@ def dst_prefixes(self) -> List[str]:
if not hasattr(self, "_dst_prefix"):
if self.transfer_type == "unicast":
self._dst_prefix = [str(parse_path(self.dst_paths[0])[2])]
print("return dst_prefixes for unicast", self._dst_prefix)
else:
for path in self.dst_paths:
print("Parsing result for multicast", parse_path(path))
self._dst_prefix = [str(parse_path(path)[2]) for path in self.dst_paths]
print("return dst_prefixes for multicast", self._dst_prefix)
return self._dst_prefix

@property
2 changes: 1 addition & 1 deletion skyplane/api/usage.py
@@ -11,7 +11,7 @@

import requests
from rich import print as rprint
from typing import Optional, Dict, List
from typing import Optional, Dict

import skyplane
from skyplane.utils.definitions import tmp_log_dir
11 changes: 3 additions & 8 deletions skyplane/chunk.py
@@ -26,12 +26,11 @@ class Chunk:
part_number: Optional[int] = None
upload_id: Optional[str] = None # TODO: for broadcast, this is not used

def to_wire_header(self, n_chunks_left_on_socket: int, wire_length: int, raw_wire_length: int, is_compressed: bool = False):
def to_wire_header(self, n_chunks_left_on_socket: int, wire_length: int, raw_wire_length: int):
return WireProtocolHeader(
chunk_id=self.chunk_id,
data_len=wire_length,
raw_data_len=raw_wire_length,
is_compressed=is_compressed,
n_chunks_left_on_socket=n_chunks_left_on_socket,
)

@@ -99,7 +98,6 @@ class WireProtocolHeader:
chunk_id: str # 128bit UUID
data_len: int # long
raw_data_len: int # long (uncompressed, unecrypted)
is_compressed: bool # char
n_chunks_left_on_socket: int # long

@staticmethod
@@ -115,8 +113,8 @@ def protocol_version():

@staticmethod
def length_bytes():
# magic (8) + protocol_version (4) + chunk_id (16) + data_len (8) + raw_data_len(8) + is_compressed (1) + n_chunks_left_on_socket (8)
return 8 + 4 + 16 + 8 + 8 + 1 + 8
# magic (8) + protocol_version (4) + chunk_id (16) + data_len (8) + raw_data_len(8) + n_chunks_left_on_socket (8)
return 8 + 4 + 16 + 8 + 8 + 8

@staticmethod
def from_bytes(data: bytes):
@@ -130,13 +128,11 @@ def from_bytes(data: bytes):
chunk_id = data[12:28].hex()
chunk_len = int.from_bytes(data[28:36], byteorder="big")
raw_chunk_len = int.from_bytes(data[36:44], byteorder="big")
is_compressed = bool(int.from_bytes(data[44:45], byteorder="big"))
n_chunks_left_on_socket = int.from_bytes(data[44:52], byteorder="big")
return WireProtocolHeader(
chunk_id=chunk_id,
data_len=chunk_len,
raw_data_len=raw_chunk_len,
is_compressed=is_compressed,
n_chunks_left_on_socket=n_chunks_left_on_socket,
)

@@ -149,7 +145,6 @@ def to_bytes(self):
out_bytes += chunk_id_bytes
out_bytes += self.data_len.to_bytes(8, byteorder="big")
out_bytes += self.raw_data_len.to_bytes(8, byteorder="big")
out_bytes += self.is_compressed.to_bytes(1, byteorder="big")
out_bytes += self.n_chunks_left_on_socket.to_bytes(8, byteorder="big")
assert len(out_bytes) == WireProtocolHeader.length_bytes(), f"{len(out_bytes)} != {WireProtocolHeader.length_bytes()}"
return out_bytes
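
With is_compressed removed, the wire header shrinks from 53 to 52 bytes. A standalone sketch of the resulting layout described by the length_bytes() comment; MAGIC and VERSION below are placeholders, not Skyplane's real constants:

```python
# Sketch only: pack/unpack the 52-byte header
#   magic (8) + protocol_version (4) + chunk_id (16) + data_len (8)
#   + raw_data_len (8) + n_chunks_left_on_socket (8)
import uuid

MAGIC = 0x0102030405060708   # placeholder magic value
VERSION = 1                  # placeholder protocol version

def pack_header(chunk_id: str, data_len: int, raw_data_len: int, n_left: int) -> bytes:
    out = MAGIC.to_bytes(8, "big")
    out += VERSION.to_bytes(4, "big")
    out += uuid.UUID(chunk_id).bytes                 # 16-byte chunk UUID
    out += data_len.to_bytes(8, "big")
    out += raw_data_len.to_bytes(8, "big")
    out += n_left.to_bytes(8, "big")
    assert len(out) == 8 + 4 + 16 + 8 + 8 + 8        # 52 bytes, as in length_bytes()
    return out

def unpack_header(data: bytes) -> dict:
    assert int.from_bytes(data[0:8], "big") == MAGIC
    return {
        "chunk_id": data[12:28].hex(),
        "data_len": int.from_bytes(data[28:36], "big"),
        "raw_data_len": int.from_bytes(data[36:44], "big"),
        "n_chunks_left_on_socket": int.from_bytes(data[44:52], "big"),
    }

header = pack_header(str(uuid.uuid4()), data_len=1024, raw_data_len=1024, n_left=3)
print(unpack_header(header))
```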
2 changes: 1 addition & 1 deletion skyplane/cli/impl/progress_bar.py
@@ -3,7 +3,7 @@
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, DownloadColumn, TransferSpeedColumn, TimeRemainingColumn
from skyplane import exceptions
from skyplane.chunk import Chunk
from skyplane.cli.impl.common import console, print_stats_completed
from skyplane.cli.impl.common import console
from skyplane.utils.definitions import format_bytes
from skyplane.api.tracker import TransferHook
