Merged
6 changes: 2 additions & 4 deletions icechunk-python/python/icechunk/_icechunk_python.pyi
@@ -964,10 +964,8 @@ class PyRepository:
def save_config(self) -> None: ...
def config(self) -> RepositoryConfig: ...
def storage(self) -> Storage: ...
def set_default_commit_metadata(
self, metadata: dict[str, Any] | None = None
) -> None: ...
def default_commit_metadata(self) -> dict[str, Any] | None: ...
def set_default_commit_metadata(self, metadata: dict[str, Any]) -> None: ...
def default_commit_metadata(self) -> dict[str, Any]: ...
def async_ancestry(
self,
*,
10 changes: 5 additions & 5 deletions icechunk-python/python/icechunk/repository.py
@@ -216,7 +216,7 @@ def storage(self) -> Storage:
"""
return self._repository.storage()

def set_default_commit_metadata(self, metadata: dict[str, Any] | None = None) -> None:
def set_default_commit_metadata(self, metadata: dict[str, Any]) -> None:
"""
Set the default commit metadata for the repository. This is useful for providing
additional static system-context metadata to all commits.
@@ -230,18 +230,18 @@ def set_default_commit_metadata(self, metadata: dict[str, Any] | None = None) ->

Parameters
----------
metadata : dict[str, Any], optional
The default commit metadata.
metadata : dict[str, Any]
The default commit metadata. Pass an empty dict to clear the default metadata.
"""
return self._repository.set_default_commit_metadata(metadata)

def default_commit_metadata(self) -> dict[str, Any] | None:
def default_commit_metadata(self) -> dict[str, Any]:
"""
Get the current configured default commit metadata for the repository.

Returns
-------
dict[str, Any] | None
dict[str, Any]
The default commit metadata.
"""
return self._repository.default_commit_metadata()
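For reference, a minimal usage sketch of the API above. The storage path and metadata values are illustrative; `Repository.create` and `local_filesystem_storage` are the same helpers used in the pickle tests below.

    import icechunk

    # Hypothetical local path, for illustration only.
    repo = icechunk.Repository.create(
        storage=icechunk.local_filesystem_storage("/tmp/example-repo")
    )

    # Every subsequent commit will carry this static metadata.
    repo.set_default_commit_metadata({"author": "etl-pipeline"})
    assert repo.default_commit_metadata() == {"author": "etl-pipeline"}

    # Passing an empty dict clears the default metadata.
    repo.set_default_commit_metadata({})
    assert repo.default_commit_metadata() == {}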
130 changes: 96 additions & 34 deletions icechunk-python/src/repository.rs
@@ -362,7 +362,7 @@ impl PyGCSummary {
}

#[pyclass]
pub struct PyRepository(Arc<Repository>);
pub struct PyRepository(Arc<RwLock<Repository>>);

#[pymethods]
/// Most functions in this class call `Runtime.block_on` so they need to `allow_threads` so other
@@ -390,7 +390,7 @@ impl PyRepository {
.map_err(PyIcechunkStoreError::RepositoryError)
})?;

Ok(Self(Arc::new(repository)))
Ok(Self(Arc::new(RwLock::new(repository))))
})
}

@@ -416,7 +416,7 @@ impl PyRepository {
.map_err(PyIcechunkStoreError::RepositoryError)
})?;

Ok(Self(Arc::new(repository)))
Ok(Self(Arc::new(RwLock::new(repository))))
})
}

@@ -444,7 +444,7 @@ impl PyRepository {
)
})?;

Ok(Self(Arc::new(repository)))
Ok(Self(Arc::new(RwLock::new(repository))))
})
}

@@ -474,14 +474,15 @@ impl PyRepository {
virtual_chunk_credentials: Option<Option<HashMap<String, PyCredentials>>>,
) -> PyResult<Self> {
py.allow_threads(move || {
Ok(Self(Arc::new(
Ok(Self(Arc::new(RwLock::new(
self.0
.blocking_read()
.reopen(
config.map(|c| c.into()),
virtual_chunk_credentials.map(map_credentials),
)
.map_err(PyIcechunkStoreError::RepositoryError)?,
)))
))))
})
}

@@ -495,15 +496,18 @@ impl PyRepository {
py.allow_threads(move || {
let repository = Repository::from_bytes(bytes)
.map_err(PyIcechunkStoreError::RepositoryError)?;
Ok(Self(Arc::new(repository)))
Ok(Self(Arc::new(RwLock::new(repository))))
})
}

fn as_bytes(&self, py: Python<'_>) -> PyResult<Cow<[u8]>> {
// This is a compute-intensive task, so we need to release the GIL
py.allow_threads(move || {
let bytes =
self.0.as_bytes().map_err(PyIcechunkStoreError::RepositoryError)?;
let bytes = self
.0
.blocking_read()
.as_bytes()
.map_err(PyIcechunkStoreError::RepositoryError)?;
Ok(Cow::Owned(bytes))
})
}
@@ -530,6 +534,8 @@ impl PyRepository {
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let _etag = self
.0
.read()
.await
.save_config()
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
@@ -539,33 +545,32 @@ impl PyRepository {
}

pub fn config(&self) -> PyRepositoryConfig {
self.0.config().clone().into()
self.0.blocking_read().config().clone().into()
}

pub fn storage_settings(&self) -> PyStorageSettings {
self.0.storage_settings().clone().into()
self.0.blocking_read().storage_settings().clone().into()
}

pub fn storage(&self) -> PyStorage {
PyStorage(Arc::clone(self.0.storage()))
PyStorage(Arc::clone(self.0.blocking_read().storage()))
}

#[pyo3(signature = (metadata))]
pub fn set_default_commit_metadata(
&self,
metadata: Option<PySnapshotProperties>,
) -> PyResult<()> {
let metadata = metadata.map(|m| m.into());
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
self.0.set_default_commit_metadata(metadata).await;
Ok(())
py: Python<'_>,
metadata: PySnapshotProperties,
) {
py.allow_threads(move || {
let metadata = metadata.into();
self.0.blocking_write().set_default_commit_metadata(metadata);
})
}

pub fn default_commit_metadata(&self) -> PyResult<Option<PySnapshotProperties>> {
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let metadata = self.0.default_commit_metadata().await;
Ok(metadata.map(|m| m.into()))
pub fn default_commit_metadata(&self, py: Python<'_>) -> PySnapshotProperties {
py.allow_threads(move || {
let metadata = self.0.blocking_read().default_commit_metadata();
metadata.into()
})
}

@@ -578,12 +583,20 @@ impl PyRepository {
tag: Option<String>,
snapshot_id: Option<String>,
) -> PyResult<PyAsyncGenerator> {
let repo = Arc::clone(&self.0);
// This function calls block_on, so we need to allow other Python threads to make progress
py.allow_threads(move || {
let version = args_to_version_info(branch, tag, snapshot_id, None)?;
let ancestry = pyo3_async_runtimes::tokio::get_runtime()
.block_on(async move { repo.ancestry_arc(&version).await })
.block_on(async move {
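// Take the read lock only long enough to resolve the version and clone the
// Arc<AssetManager>; the guard is dropped at the end of the inner scope, so
// the ancestry walk below runs without holding the lock.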
let (snapshot_id, asset_manager) = {
let lock = self.0.read().await;
(
lock.resolve_version(&version).await?,
Arc::clone(lock.asset_manager()),
Review thread on this line:

Collaborator: nit: you could do the Arc::clone once the lock is released.

mpiannucci (author): How? We need the lock to get the reference to the asset manager, right?

Collaborator: lock -> get the &Arc<AssetManager> -> unlock -> Clone

mpiannucci (author, Mar 28, 2025): Aha, I didn't know you could keep a reference after unlocking.

Collaborator: Ah, I didn't really think about that... I'm not sure now.

mpiannucci (author): Yeah, I have to clone from a scope where the lock is held, which I handled by making an inner scope so it is released immediately after cloning the asset manager.

)
};
asset_manager.snapshot_ancestry(&snapshot_id).await
})
.map_err(PyIcechunkStoreError::RepositoryError)?
.map_err(PyIcechunkStoreError::RepositoryError);

@@ -615,6 +628,8 @@ impl PyRepository {

pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
self.0
.read()
.await
.create_branch(branch_name, &snapshot_id)
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
@@ -629,6 +644,8 @@ impl PyRepository {
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let branches = self
.0
.read()
.await
.list_branches()
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
@@ -643,6 +660,8 @@ impl PyRepository {
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let tip = self
.0
.read()
.await
.lookup_branch(branch_name)
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
@@ -667,6 +686,8 @@ impl PyRepository {

pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
self.0
.read()
.await
.reset_branch(branch_name, &snapshot_id)
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
@@ -680,6 +701,8 @@ impl PyRepository {
py.allow_threads(move || {
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
self.0
.read()
.await
.delete_branch(branch)
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
@@ -693,6 +716,8 @@ impl PyRepository {
py.allow_threads(move || {
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
self.0
.read()
.await
.delete_tag(tag)
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
@@ -717,6 +742,8 @@ impl PyRepository {

pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
self.0
.read()
.await
.create_tag(tag_name, &snapshot_id)
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
@@ -731,6 +758,8 @@ impl PyRepository {
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let tags = self
.0
.read()
.await
.list_tags()
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
@@ -745,6 +774,8 @@ impl PyRepository {
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let tag = self
.0
.read()
.await
.lookup_tag(tag)
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
@@ -773,6 +804,8 @@ impl PyRepository {
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let diff = self
.0
.read()
.await
.diff(&from, &to)
.await
.map_err(PyIcechunkStoreError::SessionError)?;
@@ -796,6 +829,8 @@ impl PyRepository {
let session =
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
self.0
.read()
.await
.readonly_session(&version)
.await
.map_err(PyIcechunkStoreError::RepositoryError)
@@ -811,6 +846,8 @@ impl PyRepository {
let session =
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
self.0
.read()
.await
.writable_session(branch)
.await
.map_err(PyIcechunkStoreError::RepositoryError)
@@ -832,10 +869,19 @@ impl PyRepository {
py.allow_threads(move || {
let result =
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
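// Same pattern as ancestry above: snapshot the storage handles inside a short
// read-lock scope, then run the expiration without holding the lock.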
let (storage, storage_settings, asset_manager) = {
let lock = self.0.read().await;
(
Arc::clone(lock.storage()),
lock.storage_settings().clone(),
Arc::clone(lock.asset_manager()),
)
};

let result = expire(
self.0.storage().as_ref(),
self.0.storage_settings(),
self.0.asset_manager().clone(),
storage.as_ref(),
&storage_settings,
asset_manager,
older_than,
if delete_expired_branches {
ExpiredRefAction::Delete
@@ -877,10 +923,18 @@ impl PyRepository {
delete_object_older_than,
Default::default(),
);
let (storage, storage_settings, asset_manager) = {
let lock = self.0.read().await;
(
Arc::clone(lock.storage()),
lock.storage_settings().clone(),
Arc::clone(lock.asset_manager()),
)
};
let result = garbage_collect(
self.0.storage().as_ref(),
self.0.storage_settings(),
self.0.asset_manager().clone(),
storage.as_ref(),
&storage_settings,
asset_manager,
&gc_config,
)
.await
@@ -897,10 +951,18 @@ impl PyRepository {
py.allow_threads(move || {
let result =
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let (storage, storage_settings, asset_manager) = {
let lock = self.0.read().await;
(
Arc::clone(lock.storage()),
lock.storage_settings().clone(),
Arc::clone(lock.asset_manager()),
)
};
let result = repo_chunks_storage(
self.0.storage().as_ref(),
self.0.storage_settings(),
self.0.asset_manager().clone(),
storage.as_ref(),
&storage_settings,
asset_manager,
)
.await
.map_err(PyIcechunkStoreError::RepositoryError)?;
6 changes: 5 additions & 1 deletion icechunk-python/tests/test_pickle.py
@@ -15,7 +15,9 @@


def create_local_repo(path: str) -> Repository:
return Repository.create(storage=local_filesystem_storage(path))
repo = Repository.create(storage=local_filesystem_storage(path))
repo.set_default_commit_metadata({"author": "test"})
return repo


@pytest.fixture(scope="function")
@@ -29,6 +31,8 @@ def test_pickle_repository(tmpdir: Path, tmp_repo: Repository) -> None:
pickled = pickle.dumps(tmp_repo)
roundtripped = pickle.loads(pickled)
assert tmp_repo.list_branches() == roundtripped.list_branches()
assert tmp_repo.default_commit_metadata() == roundtripped.default_commit_metadata()
assert tmp_repo.default_commit_metadata() == {"author": "test"}

storage = tmp_repo.storage
assert (