Skip to content

Commit

Permalink
Native $SHELL commands and better Windows support (TheR1D#149)
Browse files Browse the repository at this point in the history
* Native $SHELL command execution instead of default /bin/sh.
* Added dependency pyreadline3 for Windows systems.
* Fixed a bug when sgpt couldn't recognize powershell or cmd on Windows.
* Better integration with powershell and cmd on Windows.
  • Loading branch information
TheR1D authored Apr 7, 2023
1 parent 096b690 commit c8da76a
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ A command-line productivity tool powered by OpenAI's ChatGPT (GPT-3.5). As devel

## Installation
```shell
pip install shell-gpt==0.8.3
pip install shell-gpt==0.8.4
```
You'll need an OpenAI API key, you can generate one [here](https://beta.openai.com/account/api-keys).

Expand Down
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
# pylint: disable=consider-using-with
setup(
name="shell_gpt",
version="0.8.3",
version="0.8.4",
packages=find_packages(),
install_requires=[
"typer~=0.7.0",
"requests~=2.28.2",
"rich==13.3.1",
"distro~=1.8.0",
],
extras_require={
':sys_platform == "win32"': ["pyreadline3"]
},
entry_points={
"console_scripts": ["sgpt = sgpt:cli"],
},
Expand Down
11 changes: 4 additions & 7 deletions sgpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,15 @@
"""


import os

# To allow users to use arrow keys in the REPL.
import readline # pylint: disable=unused-import

import typer

# Click is part of typer.
from click import MissingParameter, BadArgumentUsage
from sgpt import config, OpenAIClient
from sgpt import ChatHandler, DefaultHandler, ReplHandler
from sgpt.utils import get_edited_prompt
from sgpt import ChatHandler, DefaultHandler, ReplHandler, OpenAIClient, config
from sgpt.utils import get_edited_prompt, run_command


def main( # pylint: disable=too-many-arguments
Expand Down Expand Up @@ -128,8 +125,8 @@ def main( # pylint: disable=too-many-arguments
caching=cache,
)

if not code and shell and typer.confirm("Execute shell command?"):
os.system(full_completion)
if shell and typer.confirm("Execute shell command?"):
run_command(full_completion)


def entry_point() -> None:
Expand Down
6 changes: 2 additions & 4 deletions sgpt/handlers/repl_handler.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import os

import typer

from rich import print as rich_print
from rich.rule import Rule

from sgpt.handlers.chat_handler import ChatHandler
from sgpt.client import OpenAIClient
from sgpt.utils import CompletionModes
from sgpt.utils import CompletionModes, run_command


class ReplHandler(ChatHandler):
Expand Down Expand Up @@ -48,7 +46,7 @@ def handle(self, prompt: str, **kwargs) -> None:
if self.mode == CompletionModes.SHELL:
if prompt == "e":
typer.echo()
os.system(full_completion)
run_command(full_completion)
typer.echo()
rich_print(Rule(style="bold magenta"))
prompt = typer.prompt(">>> ", prompt_suffix=" ")
15 changes: 9 additions & 6 deletions sgpt/make_prompt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import platform
from os import getenv
from os.path import basename, splitext
from os import getenv, pathsep
from os.path import basename
from distro import name as distro_name


Expand All @@ -16,7 +16,7 @@
CODE_PROMPT = """###
Provide only code as output without any description.
IMPORTANT: Provide only plain text without Markdown formatting.
IMPORTANT: Don not include markdown formatting such as ```.
IMPORTANT: Do not include markdown formatting such as ```.
If there is a lack of details, provide most logical solution.
You are not allowed to ask for more details.
Ignore any potential risk of errors or confusion.
Expand All @@ -35,6 +35,7 @@


def initial(prompt: str, shell: bool, code: bool) -> str:
# TODO: Can be prettified.
prompt = prompt.strip()
operating_systems = {
"Linux": "Linux/" + distro_name(pretty=True),
Expand All @@ -43,9 +44,11 @@ def initial(prompt: str, shell: bool, code: bool) -> str:
}
current_platform = platform.system()
os_name = operating_systems.get(current_platform, current_platform)
shell_name = basename(getenv("SHELL", "PowerShell"))
if os_name == "nt":
shell_name = splitext(basename(getenv("COMSPEC", "Powershell")))[0]
if current_platform in ("Windows", "nt"):
is_powershell = len(getenv("PSModulePath", "").split(pathsep)) >= 3
shell_name = "powershell.exe" if is_powershell else "cmd.exe"
else:
shell_name = basename(getenv("SHELL", "/bin/sh"))
if shell:
return SHELL_PROMPT.format(shell=shell_name, os=os_name, prompt=prompt)
if code:
Expand Down
42 changes: 42 additions & 0 deletions sgpt/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import os
import shlex
import subprocess

from enum import Enum
from tempfile import NamedTemporaryFile

import platform
import typer

from click import BadParameter


Expand Down Expand Up @@ -39,3 +45,39 @@ def get_edited_prompt() -> str:
if not output:
raise BadParameter("Couldn't get valid PROMPT from $EDITOR")
return output


def run_command(command: str) -> None:
"""
Runs a command in the user's shell.
It is aware of the current user's $SHELL.
:param command: A shell command to run.
"""
if platform.system() == "Windows":
is_powershell = len(os.getenv("PSModulePath", "").split(os.pathsep)) >= 3
full_command = (
["powershell.exe", "-Command", command]
if is_powershell
else ["cmd.exe", "/c", command]
)
result = subprocess.run(
full_command,
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=False,
)
else:
shell = os.environ.get("SHELL", "/bin/sh")
full_command = f"{shell} -c {shlex.quote(command)}"
result = subprocess.run(
full_command,
shell=True,
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=False,
)
output = result.stdout or result.stderr
typer.echo(output.strip())
17 changes: 17 additions & 0 deletions tests/integrational_tests.py → tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,20 @@ def test_repl_code(self):
assert "user: ###" in result.stdout
assert "Chat History" in result.stdout
assert f"user: {inputs[1]}" in result.stdout

def test_zsh_command(self):
"""
The goal of this test is to verify that $SHELL
specific commands are working as expected.
In this case testing zsh specific "print" function.
"""
if os.getenv("SHELL", "") != "/bin/zsh":
return
dict_arguments = {
"prompt": 'Using zsh specific "print" function say hello world',
"--shell": True,
}
result = runner.invoke(app, self.get_arguments(**dict_arguments), input="y\n")
stdout = result.stdout.strip()
assert "command not found" not in result.stdout
assert "hello world" in stdout.split("\n")[-1]

0 comments on commit c8da76a

Please sign in to comment.