diff --git a/src/codeflare_sdk/cli/cli_utils.py b/src/codeflare_sdk/cli/cli_utils.py index 0c557a8..c9d2c87 100644 --- a/src/codeflare_sdk/cli/cli_utils.py +++ b/src/codeflare_sdk/cli/cli_utils.py @@ -57,3 +57,19 @@ def load_auth(): click.echo("No authentication found, trying default kubeconfig") except client.ApiException: click.echo("Invalid authentication, trying default kubeconfig") + + +class PluralAlias(click.Group): + def get_command(self, ctx, cmd_name): + rv = click.Group.get_command(self, ctx, cmd_name) + if rv is not None: + return rv + for x in self.list_commands(ctx): + if x + "s" == cmd_name: + return click.Group.get_command(self, ctx, x) + return None + + def resolve_command(self, ctx, args): + # always return the full command name + _, cmd, args = super().resolve_command(ctx, args) + return cmd.name, cmd, args diff --git a/src/codeflare_sdk/cli/codeflare_cli.py b/src/codeflare_sdk/cli/codeflare_cli.py index 7835469..2731ac0 100644 --- a/src/codeflare_sdk/cli/codeflare_cli.py +++ b/src/codeflare_sdk/cli/codeflare_cli.py @@ -2,6 +2,7 @@ import os from codeflare_sdk.cli.cli_utils import load_auth +from codeflare_sdk.cluster.cluster import get_current_namespace cmd_folder = os.path.abspath(os.path.join(os.path.dirname(__file__), "commands")) @@ -9,6 +10,7 @@ class CodeflareContext: def __init__(self): self.codeflare_path = _initialize_codeflare_folder() + self.current_namespace = get_current_namespace() def _initialize_codeflare_folder(): diff --git a/src/codeflare_sdk/cli/commands/define.py b/src/codeflare_sdk/cli/commands/define.py index 09cfd1f..4db177f 100644 --- a/src/codeflare_sdk/cli/commands/define.py +++ b/src/codeflare_sdk/cli/commands/define.py @@ -12,8 +12,9 @@ def cli(): @cli.command() +@click.pass_context @click.option("--name", type=str, required=True) -@click.option("--namespace", "-n", type=str, required=True) +@click.option("--namespace", "-n", type=str) @click.option("--head_info", cls=PythonLiteralOption, type=list) @click.option("--machine_types", cls=PythonLiteralOption, type=list) @click.option("--min_cpus", type=int) @@ -29,8 +30,10 @@ def cli(): @click.option("--image", type=str) @click.option("--local_interactive", type=bool) @click.option("--image_pull_secrets", cls=PythonLiteralOption, type=list) -def raycluster(**kwargs): +def raycluster(ctx, **kwargs): """Define a RayCluster with parameter specifications""" filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + if "namespace" not in filtered_kwargs.keys(): + filtered_kwargs["namespace"] = ctx.obj.current_namespace clusterConfig = ClusterConfiguration(**filtered_kwargs) Cluster(clusterConfig) # Creates yaml file diff --git a/src/codeflare_sdk/cli/commands/delete.py b/src/codeflare_sdk/cli/commands/delete.py index 7ce9744..c225d42 100644 --- a/src/codeflare_sdk/cli/commands/delete.py +++ b/src/codeflare_sdk/cli/commands/delete.py @@ -12,12 +12,14 @@ def cli(): @cli.command() +@click.pass_context @click.argument("name", type=str) -@click.option("--namespace", type=str, required=True) -def raycluster(name, namespace): +@click.option("--namespace", type=str) +def raycluster(ctx, name, namespace): """ Delete a specified RayCluster from the Kubernetes cluster """ + namespace = namespace or ctx.obj.current_namespace try: cluster = get_cluster(name, namespace) except FileNotFoundError: diff --git a/src/codeflare_sdk/cli/commands/details.py b/src/codeflare_sdk/cli/commands/details.py index b865caa..f6890e7 100644 --- a/src/codeflare_sdk/cli/commands/details.py +++ b/src/codeflare_sdk/cli/commands/details.py @@ -11,10 +11,11 @@ def cli(): @cli.command() @click.argument("name", type=str) -@click.option("--namespace", type=str, required=True) +@click.option("--namespace", type=str) @click.pass_context def raycluster(ctx, name, namespace): """Get the details of a specified RayCluster""" + namespace = namespace or ctx.obj.current_namespace try: cluster = get_cluster(name, namespace) except FileNotFoundError: diff --git a/src/codeflare_sdk/cli/commands/list.py b/src/codeflare_sdk/cli/commands/list.py index dd3ad4e..533aaed 100644 --- a/src/codeflare_sdk/cli/commands/list.py +++ b/src/codeflare_sdk/cli/commands/list.py @@ -4,12 +4,11 @@ from codeflare_sdk.cluster.cluster import ( list_clusters_all_namespaces, list_all_clusters, - get_current_namespace, ) -from codeflare_sdk.cli.cli_utils import load_auth +from codeflare_sdk.cli.cli_utils import PluralAlias -@click.group() +@click.group(cls=PluralAlias) def cli(): """List a specified resource""" pass @@ -19,14 +18,12 @@ def cli(): @click.option("--namespace", type=str) @click.option("--all", is_flag=True) @click.pass_context -def rayclusters(ctx, namespace, all): +def raycluster(ctx, namespace, all): """List all rayclusters in a specified namespace""" if all and namespace: click.echo("--all and --namespace are mutually exclusive") return - if not all and not namespace: - click.echo("You must specify either --namespace or --all") - return + namespace = namespace or ctx.obj.current_namespace if not all: list_all_clusters(namespace) return diff --git a/src/codeflare_sdk/cli/commands/status.py b/src/codeflare_sdk/cli/commands/status.py index fc76ffc..dbd92a5 100644 --- a/src/codeflare_sdk/cli/commands/status.py +++ b/src/codeflare_sdk/cli/commands/status.py @@ -11,10 +11,11 @@ def cli(): @cli.command() @click.argument("name", type=str) -@click.option("--namespace", type=str, required=True) +@click.option("--namespace", type=str) @click.pass_context def raycluster(ctx, name, namespace): """Get the status of a specified RayCluster""" + namespace = namespace or ctx.obj.current_namespace try: cluster = get_cluster(name, namespace) except FileNotFoundError: diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index a25dd1b..b98255e 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -412,7 +412,7 @@ def list_all_clusters(namespace: str, print_to_console: bool = True): """ Returns (and prints by default) a list of all clusters in a given namespace. """ - clusters = _get_ray_clusters_in_namespace(namespace) + clusters = _get_all_ray_clusters(namespace) if print_to_console: pretty_print.print_clusters(clusters) return clusters @@ -539,17 +539,24 @@ def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]: return None -def _get_ray_clusters_in_namespace(namespace="default") -> List[RayCluster]: +def _get_all_ray_clusters(namespace: str = None) -> List[RayCluster]: list_of_clusters = [] try: config_check() api_instance = client.CustomObjectsApi(api_config_handler()) - rcs = api_instance.list_namespaced_custom_object( - group="ray.io", - version="v1alpha1", - namespace=namespace, - plural="rayclusters", - ) + if namespace: + rcs = api_instance.list_namespaced_custom_object( + group="ray.io", + version="v1alpha1", + namespace=namespace, + plural="rayclusters", + ) + else: + rcs = api_instance.list_cluster_custom_object( + group="ray.io", + version="v1alpha1", + plural="rayclusters", + ) except Exception as e: # pragma: no cover return _kube_api_error_handling(e) @@ -558,23 +565,6 @@ def _get_ray_clusters_in_namespace(namespace="default") -> List[RayCluster]: return list_of_clusters -def _get_all_ray_clusters() -> List[RayCluster]: - list_of_clusters = [] - try: - config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) - rcs = api_instance.list_cluster_custom_object( - group="ray.io", - version="v1alpha1", - plural="rayclusters", - ) - except Exception as e: - return _kube_api_error_handling(e) - for rc in rcs["items"]: - list_of_clusters.append(_map_to_ray_cluster(rc)) - return list_of_clusters - - def _get_app_wrappers( namespace="default", filter=List[AppWrapperStatus] ) -> List[AppWrapper]: diff --git a/tests/unit_test.py b/tests/unit_test.py index 45b7038..a925801 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -160,7 +160,8 @@ def test_login_tls_cli(mocker): tls_result = runner.invoke(cli, k8s_tls_login_command) skip_tls_result = runner.invoke(cli, k8s_skip_tls_login_command) assert ( - tls_result.output == skip_tls_result.output == "Logged into 'testserver:6443'\n" + "Logged into 'testserver:6443'\n" in tls_result.output + and "Logged into 'testserver:6443'\n" in skip_tls_result.output ) @@ -169,7 +170,7 @@ def test_logout_cli(mocker): mocker.patch.object(client, "ApiClient") k8s_logout_command = "logout" logout_result = runner.invoke(cli, k8s_logout_command) - assert logout_result.output == "Successfully logged out of 'testserver:6443'\n" + assert "Successfully logged out of 'testserver:6443'\n" in logout_result.output assert not os.path.exists(os.path.expanduser("~/.codeflare/auth")) @@ -198,6 +199,10 @@ def test_cluster_deletion_cli(mocker): "kubernetes.client.CustomObjectsApi.list_namespaced_custom_object", side_effect=get_ray_obj, ) + mocker.patch( + "codeflare_sdk.cluster.cluster.get_current_namespace", + return_value="ns", + ) runner = CliRunner() delete_cluster_command = """ delete raycluster