Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions src/goats_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import shutil
import subprocess
import time
from importlib.metadata import version
from pathlib import Path

import click
Expand All @@ -15,6 +14,56 @@
from goats_cli.exceptions import GOATSClickException
from goats_cli.modify_settings import modify_settings
from goats_cli.process_manager import ProcessManager
from goats_cli.versioning import VersionChecker

UPDATE_DOC_URL = "https://goats.readthedocs.io/en/stable/update.html"


def _check_version() -> None:
"""
Check whether GOATS is outdated.
This function does not return a value. It may block for user input and
can raise ``click.Abort`` if the user cancels

Raises
------
GOATSClickException
If the latest version cannot be resolved or version strings are invalid.
"""
utils.display_message("Checking for updates...\n")
try:
checker = VersionChecker()
if checker.check_if_outdated():
utils.display_warning(
"A new version of GOATS is available: "
f"{checker.latest_version} (current: {checker.current_version})"
)
utils.display_info(
"GOATS interacts with several external services (e.g., GPP, GOA, TNS)"
"\n "
"which may evolve over time. Using an outdated version can result in"
"\n "
"unexpected behavior or failed operations due to API changes or\n "
"incompatible features.\n\n"
)
utils.display_info(
f"➤ Visit {UPDATE_DOC_URL} for update instructions\n\n",
)
utils.display_info(
"Press Enter to continue at your own risk, or Ctrl+C to cancel...",
)
try:
click.prompt("", default="", show_default=False, prompt_suffix="")
except (KeyboardInterrupt, EOFError):
raise click.Abort()

else:
utils.display_message(
f"GOATS is up to date (version {checker.current_version})."
"No update necessary.\n"
)
except GOATSClickException as error:
utils.display_warning((f"{error}\n Proceeding without version check.\n"))


def _run_migrations(manage_file: Path) -> None:
Expand Down Expand Up @@ -194,7 +243,7 @@ def start_background_workers(manage_file: Path, workers: int) -> subprocess.Pope


@click.group(invoke_without_command=True)
@click.version_option(version=version("goats"))
@click.version_option(package_name="goats")
@click.pass_context
def cli(ctx):
"""Gemini Observation and Analysis of Targets System (GOATS).
Expand Down Expand Up @@ -454,6 +503,9 @@ def run(
Raised if the 'subprocess' calls fail.
"""
utils.display_message("Serving GOATS.\n")

_check_version()

utils.display_message(
"Finding GOATS and Redis installation:", show_goats_emoji=True
)
Expand Down
147 changes: 147 additions & 0 deletions src/goats_cli/versioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from importlib.metadata import version as get_version
from json import JSONDecodeError

import requests
from packaging.version import InvalidVersion, Version

from goats_cli.exceptions import GOATSClickException

CHANNELDATA_URL = "https://gemini-hlsw.github.io/goats-infra/conda/channeldata.json"


class VersionChecker:
"""
Compare the installed GOATS version against the latest available in the channel.

Parameters
----------
channeldata_url : str, optional
URL to the ``channeldata.json`` file used to resolve the latest version.
Defaults to ``CHANNELDATA_URL``.
package_name : str, optional
Package name whose installed version is obtained via
``importlib.metadata.version``. Defaults to ``"goats"``.
timeout_sec : float, optional
HTTP request timeout in seconds. Defaults to ``1.0``.

Attributes
----------
channeldata_url : str
URL used to query the latest available version.
package_name : str
Package name used to resolve the installed version.
timeout_sec : float
Timeout applied to the HTTP request.
current_version : str | None
Installed package version (``None`` until computed).
latest_version : str | None
Latest available version from the channel (``None`` until computed).
is_outdated : bool | None
``True`` if ``current_version < latest_version``, ``False`` otherwise
(``None`` until computed).
"""

def __init__(
self,
channeldata_url: str = CHANNELDATA_URL,
package_name: str = "goats",
timeout_sec: float = 1.0,
) -> None:
"""
Initialize a :class:`VersionChecker` instance.

Parameters
----------
channeldata_url : str, optional
URL to the channel metadata JSON. Defaults to ``CHANNELDATA_URL``.
package_name : str, optional
Package name to inspect. Defaults to ``"goats"``.
timeout_sec : float, optional
Timeout (seconds) for HTTP requests. Defaults to ``1.0``.
"""
self.channeldata_url = channeldata_url
self.package_name = package_name
self.timeout_sec = timeout_sec

self.current_version: str | None = None
self.latest_version: str | None = None
self.is_outdated: bool | None = None

def _get_current_version(self) -> str:
"""
Return the currently installed version string for ``self.package_name``.

Returns
-------
str
Installed version string (e.g., ``"1.2.3"``).

Raises
------
importlib.metadata.PackageNotFoundError
If the package is not installed in the current environment.
"""
return get_version(self.package_name).strip()

def _get_latest_version(self) -> str:
"""
Fetch the latest available version string from the Conda channel.

Returns
-------
str
Latest version string for ``self.package_name`` (e.g., ``"1.2.3"``).

Raises
------
GOATSClickException
If a network/HTTP error occurs, the response payload is invalid JSON,
or the JSON structure does not contain the expected keys.
"""
try:
resp = requests.get(self.channeldata_url, timeout=self.timeout_sec)
resp.raise_for_status()
data = resp.json()
return data["packages"][self.package_name]["version"].strip()
except requests.RequestException as error:
raise GOATSClickException(f"Failed to fetch latest version info: {error}")
except JSONDecodeError as error:
raise GOATSClickException(f"Invalid JSON: {error}")
except (KeyError, TypeError) as error:
raise GOATSClickException(
f"Malformed channel metadata while obtaining latest version: {error}"
)

def check_if_outdated(self) -> bool:
"""
Resolve both installed and latest versions and update the instance state.

This method always re-queries the environment and the channel:
it refreshes :attr:`current_version`, :attr:`latest_version`, and
recomputes :attr:`is_outdated`.

Returns
-------
bool
``True`` if an update is available (``installed < latest``),
otherwise ``False``.

Raises
------
GOATSClickException
If fetching/parsing the channel metadata fails, or if either version
string is invalid (invalid PEP 440 format).
"""
self.current_version = self._get_current_version()
self.latest_version = self._get_latest_version()
try:
self.is_outdated = Version(self.current_version) < Version(
self.latest_version
)
return self.is_outdated
except InvalidVersion as error:
raise GOATSClickException(
"Invalid version string while comparing versions: "
f"current={self.current_version!r}, "
f"latest={self.latest_version!r}"
) from error
129 changes: 129 additions & 0 deletions tests/goats_cli/test_versioning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import pytest
from json import JSONDecodeError
from unittest.mock import Mock, patch
from requests import RequestException, HTTPError
import requests

from goats_cli.versioning import VersionChecker
from goats_cli.exceptions import GOATSClickException


def _fake_response(json_payload=None, status_code=200, json_raises=None) -> requests.Response:
resp = Mock(spec=requests.Response)
resp.status_code = status_code
if status_code >= 400:
resp.raise_for_status.side_effect = HTTPError(f"status={status_code}")
else:
resp.raise_for_status.return_value = None
if json_raises is not None:
resp.json.side_effect = json_raises
else:
resp.json.return_value = json_payload
return resp


@patch("goats_cli.versioning.get_version", return_value="1.0.0")
@patch("goats_cli.versioning.requests.get")
def test_is_outdated_true(mock_get, _mock_get_version):
payload = {"packages": {"goats": {"version": "1.2.0"}}}
mock_get.return_value = _fake_response(json_payload=payload)

vc = VersionChecker()
assert vc.check_if_outdated() is True

assert vc.current_version == "1.0.0"
assert vc.latest_version == "1.2.0"
assert vc.is_outdated is True

mock_get.assert_called_once()
called_kwargs = mock_get.call_args.kwargs
assert called_kwargs["timeout"] == 1.0


@patch("goats_cli.versioning.get_version", return_value="1.2.0")
@patch("goats_cli.versioning.requests.get")
def test_is_outdated_false_equal(mock_get, _):
payload = {"packages": {"goats": {"version": "1.2.0"}}}
mock_get.return_value = _fake_response(json_payload=payload)

vc = VersionChecker()
assert vc.check_if_outdated() is False
assert vc.is_outdated is False


@patch("goats_cli.versioning.get_version", return_value="1.0.0")
@patch("goats_cli.versioning.requests.get")
def test_request_exception_wraps_in_goatsclick(mock_get, _):
mock_get.side_effect = RequestException("network down")

vc = VersionChecker()
with pytest.raises(GOATSClickException) as exc:
vc.check_if_outdated()
assert "Failed to fetch latest version info" in str(exc.value)


@patch("goats_cli.versioning.get_version", return_value="1.0.0")
@patch("goats_cli.versioning.requests.get")
def test_jsondecodeerror_is_caught(mock_get, _):
mock_get.return_value = _fake_response(
json_raises=JSONDecodeError("Invalid JSON", doc="<<<", pos=1)
)

vc = VersionChecker()
with pytest.raises(GOATSClickException) as exc:
vc.check_if_outdated()
assert "Invalid JSON" in str(exc.value)


@patch("goats_cli.versioning.get_version", return_value="1.0.0")
@patch("goats_cli.versioning.requests.get")
def test_malformed_structure_keyerror(mock_get, _):
payload = {"packages": {"other": {"version": "9.9.9"}}}
mock_get.return_value = _fake_response(json_payload=payload)

vc = VersionChecker()
with pytest.raises(GOATSClickException) as exc:
vc.check_if_outdated()
assert "Malformed channel metadata" in str(exc.value)


@patch("goats_cli.versioning.get_version", return_value="1.0.0")
@patch("goats_cli.versioning.requests.get")
def test_malformed_structure_typeerror(mock_get, _):
mock_get.return_value = _fake_response(json_payload=["not", "a", "dict"])

vc = VersionChecker()
with pytest.raises(GOATSClickException) as exc:
vc.check_if_outdated()
assert "Malformed channel metadata" in str(exc.value)


@patch("goats_cli.versioning.get_version", return_value="0.9.0")
@patch("goats_cli.versioning.requests.get")
def test_custom_url_timeout_and_package(mock_get, _):
payload = {
"packages": {
"goats-cli": {"version": "1.0.0"},
"goats": {"version": "0.1.0"},
}
}
mock_get.return_value = _fake_response(json_payload=payload)

url = "https://example.invalid/channel.json"
vc = VersionChecker(channeldata_url=url, package_name="goats-cli", timeout_sec=5.0)

assert vc.check_if_outdated() is True
assert vc.current_version == "0.9.0"
assert vc.latest_version == "1.0.0"
assert vc.is_outdated is True

mock_get.assert_called_once()
args, kwargs = mock_get.call_args
assert args[0] == url
assert kwargs["timeout"] == 5.0

@pytest.mark.remote_data()
def test_get_latest_version_live():
checker = VersionChecker()
assert isinstance(checker.check_if_outdated(), bool)
assert isinstance(checker.latest_version, str)
24 changes: 12 additions & 12 deletions tests/goats_tom/ocs/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,18 @@ def test_get_observation_summary(client, odb_xml, observation_id, mocker):
assert isinstance(result, dict)
assert result["success"]

@pytest.mark.remote_data()
def test_get_coordinates_remote(client, observation_id):
coordinates_response = client.get_coordinates(observation_id)
assert coordinates_response["data"]
assert coordinates_response["success"]


@pytest.mark.remote_data()
def test_get_sequence_remote(client, observation_id):
sequence_response = client.get_sequence(observation_id)
assert sequence_response["data"]
assert sequence_response["success"]
# @pytest.mark.remote_data()
# def test_get_coordinates_remote(client, observation_id):
# coordinates_response = client.get_coordinates(observation_id)
# assert coordinates_response["data"]
# assert coordinates_response["success"]


# @pytest.mark.remote_data()
# def test_get_sequence_remote(client, observation_id):
# sequence_response = client.get_sequence(observation_id)
# assert sequence_response["data"]
# assert sequence_response["success"]


@pytest.mark.remote_data()
Expand Down
Loading