@sefatanam
Forked from brownan/prompt.py
Created January 28, 2025 16:27
Ollama command line interface with Markdown rendering
#!/bin/env python3
import argparse
import io
import json
import os
import sys
import urllib.parse
import urllib.request
from typing import Optional, NamedTuple
import rich.console
import rich.live
import rich.markdown
import rich.panel
import rich.spinner
import rich.text
import rich.table
import rich.pretty

query_url = urllib.parse.urljoin(
    os.environ.get("OLLAMA_HOST", "http://localhost:11434"), "/api/generate"
)


class ExecutionParams(NamedTuple):
    raw: bool
    stats: bool
    seed: Optional[int] = None
    temperature: Optional[float] = None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        "--model",
        default="llama3",
        help="The name of the models to run. Multiple models can be specified separated by commas",
    )
    parser.add_argument(
        "--raw",
        action="store_true",
        help="Do not format the results with markdown. Output the results to stdout",
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Output some stats about the generation at the end",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="The seed to use in the model evaluation. "
        "For deterministic output you must also set temperature=0.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="The temperature to use in the model evaluation. A value of 0 means the model"
        " will always pick the highest probability token. Values above 0 add some"
        " randomness to token selection and generally make the model more creative."
        " For deterministic output, set this to 0 and set a seed to a non-zero constant."
        " Default in ollama is 0.8.",
    )
    parser.add_argument("prompt", nargs="*")
    args = vars(parser.parse_args())
    raw = args["raw"]
    # The prompt comes from the positional arguments, or from stdin if none are given.
    if args["prompt"]:
        prompt = " ".join(args["prompt"])
    else:
        prompt = sys.stdin.read()
    # In raw mode, rich output goes to stderr so that stdout carries only the model response.
    console = rich.console.Console(stderr=raw)
    if not raw:
        console.print(rich.panel.Panel.fit("[bold]Prompt"))
        console.print(rich.text.Text(prompt))
    models = args["models"].split(",")
    for model in models:
        run_ollama(
            model,
            prompt,
            console,
            ExecutionParams(
                raw=raw,
                stats=args["stats"],
                seed=args["seed"],
                temperature=args["temperature"],
            ),
        )


def run_ollama(model, prompt, console: rich.console.Console, params: ExecutionParams):
    raw = params.raw
    if not raw:
        console.print()
        console.print(rich.panel.Panel.fit("[bold]" + model))
    req_data = {
        "model": model,
        "prompt": prompt,
        "options": {},
    }
    if params.temperature is not None:
        req_data["options"]["temperature"] = params.temperature
    if params.seed is not None:
        req_data["options"]["seed"] = params.seed
    request = urllib.request.Request(
        query_url,
        data=json.dumps(req_data).encode("utf-8"),
        headers={
            "Content-Type": "application/json; charset=utf-8",
        },
    )
    output = io.StringIO()
    spinner = rich.spinner.Spinner(
        "dots", text="Loading model...", style="status.spinner", speed=1.0
    )
    stats_data: Optional[dict] = None
    with rich.live.Live(
        spinner,
        console=console,
        vertical_overflow="ellipsis",
        refresh_per_second=12.5,
        transient=raw,
    ) as live:
        # Ollama streams the generation back as newline-delimited JSON objects.
        response = urllib.request.urlopen(request)
        with response:
            response_buf = io.TextIOWrapper(response, encoding="utf-8")
            while True:
                buf = response_buf.readline()
                if not buf:
                    break
                data = json.loads(buf)
                if data["done"]:
                    # The final object carries the generation statistics.
                    stats_data = data
                    if raw:
                        sys.stdout.write("\n")
                    break
                if raw:
                    live.stop()
                    sys.stdout.write(data["response"])
                else:
                    output.write(data["response"])
                    live.update(rich.markdown.Markdown(output.getvalue()))
    if params.stats and stats_data:
        console.print()
        table = rich.table.Table("Generation Stats", box=None)
        table.add_column()
        table.add_column()

        def format_int(val):
            return format(val, "n")

        def format_float(val):
            return format(val, ".2f")

        table.add_row("[bold]Model", stats_data.get("model"))
        table.add_row("[bold]Created At", stats_data.get("created_at"))
        table.add_row(
            "[bold]Total Duration (s)",
            format_float(int(stats_data.get("total_duration")) / 10**9),
        )
        table.add_row(
            "[bold]Load Duration (s)",
            format_float(int(stats_data.get("load_duration")) / 10**9),
        )
        if "prompt_eval_count" in stats_data:
            table.add_row(
                "[bold]Input Tokens",
                format_int(int(stats_data.get("prompt_eval_count"))),
            )
            table.add_row(
                "[bold]Prompt Evaluation Duration (s)",
                format_float(int(stats_data.get("prompt_eval_duration")) / 10**9),
            )
        table.add_row(
            "[bold]Response Tokens", format_int(int(stats_data.get("eval_count")))
        )
        table.add_row(
            "[bold]Response Evaluation Duration (s)",
            format_float(int(stats_data.get("eval_duration")) / 10**9),
        )
        table.add_row(
            "[bold]Tokens per second",
            format_float(
                int(stats_data.get("eval_count"))
                / int(stats_data.get("eval_duration"))
                * 10**9
            ),
        )
        console.print(table)


if __name__ == "__main__":
    main()
@sefatanam

How to use

Prerequisites

Before you begin, ensure the following are set up:

  1. Ollama: Install Ollama, make sure the server is running, and pull the model(s) you want to use (a quick connectivity check is sketched after this list).

  2. Python 3: Ensure Python 3 is installed system-wide.
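
If you want to confirm the Ollama server is actually up before running the script, the snippet below is a minimal, hypothetical check (not part of prompt.py): it opens the same OLLAMA_HOST base URL the script uses and assumes the server answers a plain GET request at its root.

    # check_ollama.py -- hypothetical helper, not part of prompt.py.
    # Assumes the Ollama server answers a plain GET at its root URL.
    import os
    import urllib.request

    host = os.environ.get("OLLAMA_HOST", "http://localhost:11434")
    try:
        with urllib.request.urlopen(host, timeout=5) as resp:
            print(f"Ollama reachable at {host} (HTTP {resp.status})")
    except OSError as exc:
        print(f"Could not reach Ollama at {host}: {exc}")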


Running the Script

  1. Download the Script
    Save the script as prompt.py (or any filename you prefer) in your desired folder.

  2. Set Up Virtual Environment
    In your terminal, create a virtual environment:

    python3 -m venv path/to/venv
  3. Activate the Virtual Environment
    For macOS/Linux:

    source path/to/venv/bin/activate

    For Windows:

    path\to\venv\Scripts\activate
  4. Install Dependencies
    Inside the activated virtual environment, install the required dependencies:

    python3 -m pip install rich
  5. Run the Script
    Now you’re ready to go! 💥 Execute the script with a command like:

    python3 prompt.py --models "deepseek-r1:7b" --temperature 0.8 "Write a recursive function in Go"
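
    Because the script reads the prompt from stdin when no positional argument is given, and supports --raw, --stats, and comma-separated --models (see the argparse options above), invocations like the following should also work; the model names are only examples:

    echo "Summarize the Go memory model" | python3 prompt.py --models "llama3" --raw
    python3 prompt.py --models "llama3,deepseek-r1:7b" --stats "Compare Go and Rust error handling"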

@sefatanam

Screenshot of an example run (2025-01-28).
