Created
April 18, 2023 17:31
-
-
Save mbednarski/0359d58f1b85d793ef1514c6d904ca74 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from langchain.agents import initialize_agent, Tool | |
from langchain.agents import AgentType | |
from langchain.tools import BaseTool | |
from langchain.llms import OpenAI | |
from langchain import LLMMathChain, SerpAPIWrapper | |
from spotipy.oauth2 import SpotifyClientCredentials, SpotifyOAuth | |
import azapi | |
import spotipy | |
from langchain.utilities import GoogleSearchAPIWrapper | |
from urllib.parse import quote_plus | |
from langchain.memory import ConversationBufferMemory | |
# Shared LLM handle used by the agent built further down; temperature=0 is
# passed straight through to the OpenAI wrapper.
llm = OpenAI(temperature=0)

# Wrapper whose ``run`` method backs the search tool below.
search = GoogleSearchAPIWrapper()
import os

# General-purpose web search tool. NOTE: it is defined here but is not part
# of the ``tools`` list handed to the agent later in this file.
google_tool = Tool(
    func=search.run,
    name="Current Search",
    description=(
        "useful for when you need to answer questions about current events "
        "or the current state of the world"
    ),
)
def get_spotify(scope):
    """Return a ``spotipy.Spotify`` client authenticated through ``SpotifyOAuth``.

    Args:
        scope: Value forwarded unchanged as the ``scope`` argument of
            ``SpotifyOAuth``.

    Returns:
        A new ``spotipy.Spotify`` instance using that OAuth manager.
    """
    auth_manager = SpotifyOAuth(scope=scope)
    return spotipy.Spotify(auth_manager=auth_manager)
class LyricsTool(BaseTool):
    """Agent tool that looks up song lyrics through ``azapi.AZlyrics``."""

    name = "Lyrics"
    # NOTE(review): "inout" below looks like a typo for "input"; it is left
    # untouched because this text is part of the prompt shown to the LLM.
    description = "useful when you need to get song lyrics. The inout should be in format artist - track name. For example: Taylor Swift - Flowers"

    def _run(self, tool_input: str) -> str:
        """Fetch lyrics for a query of the form ``"artist - track name"``.

        Args:
            tool_input: Artist and title separated by a dash, e.g.
                ``"Taylor Swift - Flowers"``.

        Returns:
            The ``lyrics`` attribute of the AZlyrics client after the lookup.

        Raises:
            ValueError: If ``tool_input`` contains no dash separator.
        """
        # Prefer the documented " - " separator so hyphenated names such as
        # "Jay-Z - Song" are not torn apart; fall back to a bare "-" so
        # inputs like "Artist-Title" keep working as before.
        artist, separator, title = tool_input.partition(" - ")
        if not separator:
            artist, separator, title = tool_input.partition("-")
        if not separator:
            raise ValueError(
                f"expected input in format 'artist - track name', got {tool_input!r}"
            )

        api = azapi.AZlyrics()
        api.artist = artist.strip()
        api.title = title.strip()
        # NOTE(review): save=True presumably makes azapi also write the lyrics
        # to a file with the given extension - confirm this is intended.
        api.getLyrics(save=True, ext="lrc")
        return api.lyrics

    async def _arun(self, tool_input: str) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError()
class SpotifyCurrentTrackTool(BaseTool):
    """Agent tool that reports what Spotify is playing right now."""

    name = "Spotify"
    description = "useful when you need to read currently played track"

    def _run(self, tool_input: str) -> str:
        """Return the result of ``currently_playing()`` for the current user.

        Args:
            tool_input: Unused; present only to satisfy the tool signature.

        Returns:
            Whatever the Spotify client's ``currently_playing()`` call
            returns, passed through unchanged.
        """
        client = get_spotify("user-read-currently-playing")
        return client.currently_playing()

    async def _arun(self, tool_input: str) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError()
class SpotifyFindTrackIdTool(BaseTool):
    """Agent tool that searches Spotify for a track by title and artist."""

    name = "spotify-track-uri"
    description = "useful when you need to find spotify track URI by it's name and artist. The input should be in format name by artist. For example Flowers by Taylor Swift"

    def _run(self, tool_input: str) -> str:
        """Search Spotify for the single best match of ``"name by artist"``.

        Args:
            tool_input: Track title and artist joined by ``" by "``, e.g.
                ``"Flowers by Taylor Swift"``.

        Returns:
            The response of the Spotify client's ``search`` call, limited to
            one track and passed through unchanged.

        Raises:
            ValueError: If ``tool_input`` does not contain ``" by "``.
        """
        # Split on the LAST " by " so titles that themselves contain the
        # phrase ("Stand by Me by Ben E. King") still parse correctly.
        title, separator, artist = tool_input.rpartition(" by ")
        if not separator:
            raise ValueError(
                f"expected input in format 'name by artist', got {tool_input!r}"
            )

        spotify = get_spotify("user-read-currently-playing")
        result = spotify.search(
            q=f"track:{title.strip()} artist:{artist.strip()}",
            type="track",
            limit=1,
        )
        return result

    async def _arun(self, tool_input: str) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError()
class SpotifyPlayTool(BaseTool):
    """Agent tool that starts Spotify playback of a given track URI."""

    name = "spotify-play"
    description = "useful when you need to play a track using spotify. The input is spotify track URI"

    def _run(self, tool_input: str) -> str:
        """Start playback of the track identified by ``tool_input``.

        Args:
            tool_input: Spotify track URI to play.

        Returns:
            A short status message: a confirmation when the playback request
            went through, or a failure description otherwise. The tool never
            raises for ordinary errors, so the agent can read the message
            and decide how to proceed.
        """
        try:
            spotify = get_spotify("user-modify-playback-state")
            spotify.start_playback(uris=[tool_input])
        except Exception as exc:
            # Best-effort by design: report the problem to the agent instead
            # of crashing the run. ``Exception`` (rather than a bare except)
            # lets KeyboardInterrupt/SystemExit propagate.
            return f"Failed to start playback of {tool_input}: {exc}"
        # ``start_playback`` has no result worth forwarding, so give the
        # agent an explicit confirmation to observe.
        return f"Started playback of {tool_input}"

    async def _arun(self, tool_input: str) -> str:
        """Async execution is not supported by this tool."""
        raise NotImplementedError()
# Tools exposed to the agent: read the current track, look up a track, and
# start playback.
tools = [
    SpotifyCurrentTrackTool(),
    SpotifyFindTrackIdTool(),
    SpotifyPlayTool(),
]

# Conversation buffer stored under the "chat_history" key.
memory = ConversationBufferMemory(return_messages=True, memory_key="chat_history")

# Alternative agent type that was tried and left disabled:
# agent = initialize_agent(tools=tools, llm=llm, agent=AgentType.CONVERSATIONAL_REACT_DESCRIPTION, memory=memory, verbose=True)
agent = initialize_agent(
    llm=llm,
    tools=tools,
    memory=memory,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

# Runs at import time: ask the agent to find and play a track.
result = agent.run("Play lazarus by david bowie")
# print(result)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment