diff --git a/README.md b/README.md index 0e373ee2..d138ee19 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ A command-line productivity tool powered by OpenAI's ChatGPT (GPT-3.5). As devel ## Installation ```shell -pip install shell-gpt==0.8.3 +pip install shell-gpt==0.8.4 ``` You'll need an OpenAI API key, you can generate one [here](https://beta.openai.com/account/api-keys). diff --git a/setup.py b/setup.py index 9bb4db15..33b70126 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ # pylint: disable=consider-using-with setup( name="shell_gpt", - version="0.8.3", + version="0.8.4", packages=find_packages(), install_requires=[ "typer~=0.7.0", @@ -11,6 +11,9 @@ "rich==13.3.1", "distro~=1.8.0", ], + extras_require={ + ':sys_platform == "win32"': ["pyreadline3"] + }, entry_points={ "console_scripts": ["sgpt = sgpt:cli"], }, diff --git a/sgpt/app.py b/sgpt/app.py index e8d61fd9..eb45bc38 100644 --- a/sgpt/app.py +++ b/sgpt/app.py @@ -11,8 +11,6 @@ """ -import os - # To allow users to use arrow keys in the REPL. import readline # pylint: disable=unused-import @@ -20,9 +18,8 @@ # Click is part of typer. from click import MissingParameter, BadArgumentUsage -from sgpt import config, OpenAIClient -from sgpt import ChatHandler, DefaultHandler, ReplHandler -from sgpt.utils import get_edited_prompt +from sgpt import ChatHandler, DefaultHandler, ReplHandler, OpenAIClient, config +from sgpt.utils import get_edited_prompt, run_command def main( # pylint: disable=too-many-arguments @@ -128,8 +125,8 @@ def main( # pylint: disable=too-many-arguments caching=cache, ) - if not code and shell and typer.confirm("Execute shell command?"): - os.system(full_completion) + if shell and typer.confirm("Execute shell command?"): + run_command(full_completion) def entry_point() -> None: diff --git a/sgpt/handlers/repl_handler.py b/sgpt/handlers/repl_handler.py index 0ea8eaa3..5898bdd5 100644 --- a/sgpt/handlers/repl_handler.py +++ b/sgpt/handlers/repl_handler.py @@ -1,5 +1,3 @@ -import os - import typer from rich import print as rich_print @@ -7,7 +5,7 @@ from sgpt.handlers.chat_handler import ChatHandler from sgpt.client import OpenAIClient -from sgpt.utils import CompletionModes +from sgpt.utils import CompletionModes, run_command class ReplHandler(ChatHandler): @@ -48,7 +46,7 @@ def handle(self, prompt: str, **kwargs) -> None: if self.mode == CompletionModes.SHELL: if prompt == "e": typer.echo() - os.system(full_completion) + run_command(full_completion) typer.echo() rich_print(Rule(style="bold magenta")) prompt = typer.prompt(">>> ", prompt_suffix=" ") diff --git a/sgpt/make_prompt.py b/sgpt/make_prompt.py index 02fc648b..8a73596d 100644 --- a/sgpt/make_prompt.py +++ b/sgpt/make_prompt.py @@ -1,6 +1,6 @@ import platform -from os import getenv -from os.path import basename, splitext +from os import getenv, pathsep +from os.path import basename from distro import name as distro_name @@ -16,7 +16,7 @@ CODE_PROMPT = """### Provide only code as output without any description. IMPORTANT: Provide only plain text without Markdown formatting. -IMPORTANT: Don not include markdown formatting such as ```. +IMPORTANT: Do not include markdown formatting such as ```. If there is a lack of details, provide most logical solution. You are not allowed to ask for more details. Ignore any potential risk of errors or confusion. @@ -35,6 +35,7 @@ def initial(prompt: str, shell: bool, code: bool) -> str: + # TODO: Can be prettified. prompt = prompt.strip() operating_systems = { "Linux": "Linux/" + distro_name(pretty=True), @@ -43,9 +44,11 @@ def initial(prompt: str, shell: bool, code: bool) -> str: } current_platform = platform.system() os_name = operating_systems.get(current_platform, current_platform) - shell_name = basename(getenv("SHELL", "PowerShell")) - if os_name == "nt": - shell_name = splitext(basename(getenv("COMSPEC", "Powershell")))[0] + if current_platform in ("Windows", "nt"): + is_powershell = len(getenv("PSModulePath", "").split(pathsep)) >= 3 + shell_name = "powershell.exe" if is_powershell else "cmd.exe" + else: + shell_name = basename(getenv("SHELL", "/bin/sh")) if shell: return SHELL_PROMPT.format(shell=shell_name, os=os_name, prompt=prompt) if code: diff --git a/sgpt/utils.py b/sgpt/utils.py index c9e03242..8d8414a0 100644 --- a/sgpt/utils.py +++ b/sgpt/utils.py @@ -1,7 +1,13 @@ import os +import shlex +import subprocess + from enum import Enum from tempfile import NamedTemporaryFile +import platform +import typer + from click import BadParameter @@ -39,3 +45,39 @@ def get_edited_prompt() -> str: if not output: raise BadParameter("Couldn't get valid PROMPT from $EDITOR") return output + + +def run_command(command: str) -> None: + """ + Runs a command in the user's shell. + It is aware of the current user's $SHELL. + :param command: A shell command to run. + """ + if platform.system() == "Windows": + is_powershell = len(os.getenv("PSModulePath", "").split(os.pathsep)) >= 3 + full_command = ( + ["powershell.exe", "-Command", command] + if is_powershell + else ["cmd.exe", "/c", command] + ) + result = subprocess.run( + full_command, + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + else: + shell = os.environ.get("SHELL", "/bin/sh") + full_command = f"{shell} -c {shlex.quote(command)}" + result = subprocess.run( + full_command, + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + output = result.stdout or result.stderr + typer.echo(output.strip()) diff --git a/tests/integrational_tests.py b/tests/integration_tests.py similarity index 94% rename from tests/integrational_tests.py rename to tests/integration_tests.py index 7221fe25..d70e7382 100644 --- a/tests/integrational_tests.py +++ b/tests/integration_tests.py @@ -272,3 +272,20 @@ def test_repl_code(self): assert "user: ###" in result.stdout assert "Chat History" in result.stdout assert f"user: {inputs[1]}" in result.stdout + + def test_zsh_command(self): + """ + The goal of this test is to verify that $SHELL + specific commands are working as expected. + In this case testing zsh specific "print" function. + """ + if os.getenv("SHELL", "") != "/bin/zsh": + return + dict_arguments = { + "prompt": 'Using zsh specific "print" function say hello world', + "--shell": True, + } + result = runner.invoke(app, self.get_arguments(**dict_arguments), input="y\n") + stdout = result.stdout.strip() + assert "command not found" not in result.stdout + assert "hello world" in stdout.split("\n")[-1]