Created December 11, 2019 23:32
-
-
Save MightyAlex200/b57da504fd333a00daca3603da63551f to your computer and use it in GitHub Desktop.
AIDungeon2 everything
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
%%bash | |
# Apply the AIDungeon2 "All mods" patch. The patch below changes play.py and
# story/story_manager.py to add:
#   - custom prompts that skip the initial generated block (noblock=True),
#   - a func_timeout wrapper around story_manager.act plus an `infto N`
#     command to change the timeout,
#   - a `retry` command that regenerates the last result,
#   - a `setchar NAME` command that changes who performs actions,
#   - `!text` for inserting raw story text.
#
# Steps: install func_timeout, write the patch to fix.patch, set a git
# identity, then apply it with `git am --3way`.
#
# NOTE(review): most lines of this cell end in " | |" (the last line in a
# lone " |"), which looks like residue from copying the gist's rendered page.
# As written, `pip install func_timeout | |` is a bash syntax error, and the
# suffix is also baked into every line of the patch text. Strip it before
# running.
# NOTE(review): the patch text also appears to have lost leading whitespace:
#   - unchanged context lines no longer start with a space;
#   - Python indentation inside the added lines is gone;
#   - blank context lines are missing (e.g. hunk "@@ -5,6 +5,7 @@" now has
#     only 3 context lines).
# `git am` will presumably reject it as corrupt. Re-copy the patch from the
# raw gist. TODO confirm.
# NOTE(review): assumes this runs from the root of the AIDungeon git
# checkout. The patch paths are play.py and story/story_manager.py, and
# `git am` needs a repository.
pip install func_timeout | |
# How the patch is quoted:
#   - It is passed to echo as ONE bash double-quoted string. The opening """
#     is an empty "" followed by the opening quote of the real text; the
#     closing """ mirrors that.
#   - The Python inside is therefore escaped for bash: \" is written out as "
#     and \\ as \ (e.g. \"\\\\n\" becomes "\\n" in play.py).
#   - The patch text cannot carry shell comments without corrupting it.
# Defects visible inside the patch (fix them in the patch source):
#   - The retry handler tests `len(story_manager.story.actions) is 0`. That
#     is an identity comparison on an int, which emits a SyntaxWarning on
#     Python 3.8+. `== 0` (or `not story_manager.story.actions`) is what is
#     meant.
#   - The infto handler uses a bare `except:`. `except ValueError:` covers
#     the int() conversion failure it is guarding.
echo """From 5c7ad075b4065625bc314782ba041fb998bed72f Mon Sep 17 00:00:00 2001 | |
From: MightyAlex200 <[email protected]> | |
Date: Wed, 11 Dec 2019 18:28:15 -0500 | |
Subject: [PATCH] All mods | |
--- | |
play.py | 101 ++++++++++++++++++++++++++++++++++++----- | |
story/story_manager.py | 9 ++-- | |
2 files changed, 96 insertions(+), 14 deletions(-) | |
diff --git a/play.py b/play.py | |
index 31cd2bd..883d691 100644 | |
--- a/play.py | |
+++ b/play.py | |
@@ -5,6 +5,7 @@ import time | |
from generator.gpt2.gpt2_generator import * | |
from story.story_manager import * | |
from story.utils import * | |
+from func_timeout import func_timeout, FunctionTimedOut | |
os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" | |
@@ -43,7 +44,7 @@ def select_game(): | |
+ \"terrorizing the kingdom. You enter the forest searching for the dragon and see' \" | |
) | |
prompt = input(\"Starting Prompt: \") | |
- return context, prompt | |
+ return context, prompt, True | |
setting_key = list(settings)[choice] | |
@@ -73,7 +74,7 @@ def select_game(): | |
prompt_num = np.random.randint(0, len(character[\"prompts\"])) | |
prompt = character[\"prompts\"][prompt_num] | |
- return context, prompt | |
+ return context, prompt, False | |
def instructions(): | |
@@ -105,6 +106,11 @@ def play_aidungeon_2(): | |
print(\"\\nInitializing AI Dungeon! (This might take a few minutes)\\n\") | |
generator = GPT2Generator() | |
story_manager = UnconstrainedStoryManager(generator) | |
+ inference_timeout = 30 | |
+ def act(action): | |
+ return func_timeout(inference_timeout, story_manager.act, (action,)) | |
+ def notify_hanged(): | |
+ console_print(f\"That input caused the model to hang (timeout is {inference_timeout}, use infto ## command to change)\") | |
print(\"\\n\") | |
with open(\"opening.txt\", \"r\", encoding=\"utf-8\") as file: | |
@@ -115,18 +121,21 @@ def play_aidungeon_2(): | |
if story_manager.story != None: | |
del story_manager.story | |
+ characters = [] | |
+ current_character = \"You\" | |
+ | |
print(\"\\n\\n\") | |
splash_choice = splash() | |
if splash_choice == \"new\": | |
print(\"\\n\\n\") | |
- context, prompt = select_game() | |
+ context, prompt, noblock = select_game() | |
console_print(instructions()) | |
print(\"\\nGenerating story...\") | |
story_manager.start_new_story( | |
- prompt, context=context, upload_story=upload_story | |
+ prompt, context=context, upload_story=upload_story, noblock=noblock | |
) | |
print(\"\\n\") | |
console_print(str(story_manager.story)) | |
@@ -207,31 +216,101 @@ def play_aidungeon_2(): | |
else: | |
console_print(story_manager.story.story_start) | |
continue | |
+ | |
+ elif action == 'retry': | |
+ | |
+ if len(story_manager.story.actions) is 0: | |
+ console_print(\"There is nothing to retry.\") | |
+ continue | |
+ | |
+ last_action = story_manager.story.actions.pop() | |
+ last_result = story_manager.story.results.pop() | |
+ | |
+ try: | |
+ # Compatibility with timeout patch | |
+ act | |
+ except NameError: | |
+ act = story_manager.act | |
+ | |
+ try: | |
+ try: | |
+ act(last_action) | |
+ console_print(last_action) | |
+ console_print(story_manager.story.results[-1]) | |
+ except FunctionTimedOut: | |
+ story_manager.story.actions.append(last_action) | |
+ story_manager.story.results.append(last_result) | |
+ notify_hanged() | |
+ console_print(\"Your story progress has not been altered.\") | |
+ except NameError: | |
+ pass | |
+ | |
+ continue | |
+ | |
+ elif len(action.split(\" \")) == 2 and action.split(\" \")[0] == 'infto': | |
+ | |
+ try: | |
+ inference_timeout = int(action.split(\" \")[1]) | |
+ console_print(f\"Set timeout to {inference_timeout}\") | |
+ except: | |
+ console_print(\"Failed to set timeout. Example usage: infto 30\") | |
+ | |
+ continue | |
+ | |
+ elif len(action.split(\" \")) >= 2 and action.split(\" \")[0] == \"setchar\": | |
+ | |
+ new_char = action[len(action.split(\" \")[0]):].strip() | |
+ if new_char == \"\": | |
+ console_print(\"Character name cannot be empty\") | |
+ continue | |
+ is_known_char = False | |
+ for known_char in characters: | |
+ if known_char.lower() == new_char.lower(): | |
+ is_known_char = True | |
+ new_char = known_char | |
+ break | |
+ if not is_known_char: | |
+ characters.append(new_char) | |
+ | |
+ current_character = new_char | |
+ console_print(\"Switched to character \" + new_char) | |
+ continue | |
else: | |
if action == \"\": | |
action = \"\" | |
- result = story_manager.act(action) | |
+ try: | |
+ result = act(action) | |
+ except FunctionTimedOut: | |
+ notify_hanged() | |
+ continue | |
console_print(result) | |
elif action[0] == '\"': | |
- action = \"You say \" + action | |
+ if current_character == \"You\": | |
+ action = \"You say \" + action | |
+ else: | |
+ action = current_character + \" says \" + action | |
+ | |
+ elif action[0] == '!': | |
+ action = \"\\n\" + action[1:].replace(\"\\\\n\", \"\\n\") + \"\\n\" | |
else: | |
action = action.strip() | |
action = action[0].lower() + action[1:] | |
- if \"You\" not in action[:6] and \"I\" not in action[:6]: | |
- action = \"You \" + action | |
+ action = current_character + \" \" + action | |
if action[-1] not in [\".\", \"?\", \"!\"]: | |
action = action + \".\" | |
- action = first_to_second_person(action) | |
- | |
action = \"\\n> \" + action + \"\\n\" | |
- result = \"\\n\" + story_manager.act(action) | |
+ try: | |
+ result = \"\\n\" + act(action) | |
+ except FunctionTimedOut: | |
+ notify_hanged() | |
+ continue | |
if len(story_manager.story.results) >= 2: | |
similarity = get_similarity( | |
story_manager.story.results[-1], story_manager.story.results[-2] | |
diff --git a/story/story_manager.py b/story/story_manager.py | |
index aba3974..78bd0b0 100644 | |
--- a/story/story_manager.py | |
+++ b/story/story_manager.py | |
@@ -159,10 +159,13 @@ class StoryManager: | |
self.story = None | |
def start_new_story( | |
- self, story_prompt, context=\"\", game_state=None, upload_story=False | |
+ self, story_prompt, context=\"\", game_state=None, upload_story=False, noblock=False | |
): | |
- block = self.generator.generate(context + story_prompt) | |
- block = cut_trailing_sentence(block) | |
+ if noblock: | |
+ block = \"\" | |
+ else: | |
+ block = self.generator.generate(context + story_prompt) | |
+ block = cut_trailing_sentence(block) | |
self.story = Story( | |
context + story_prompt + block, | |
context=context, | |
-- | |
2.24.0""" > fix.patch | |
# `git am` records the patch as a commit, so a committer email and name are
# required. Set placeholder values in this repository's config only (no
# --global).
git config user.email "[email protected]"
git config user.name "Anonymous"
# --3way falls back to a three-way merge when the patch context does not
# match exactly. "Patch Applied!" prints only if `git am` succeeds (&&).
# NOTE(review): the trailing " |" on the next line leaves a dangling pipe,
# which bash rejects at end of input. Remove it before running.
git am --3way fix.patch && echo Patch Applied! |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.