Created
October 24, 2024 17:53
-
-
Save cyysky/fff21e8722f1e0d84ca49523a060c98e to your computer and use it in GitHub Desktop.
some_resoning_idea
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio
import json
import logging
import re
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx
import numpy as np
import torch
from transformers import pipeline
# Configure logging: INFO-level records formatted as
# "<time> - <logger name> - <level> - <message>", and a logger named after
# this module for use throughout the file.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
@dataclass
class ReasoningStep:
    """A single step of a reasoning chain together with its metadata.

    The field order below defines the positional constructor signature, so it
    must be kept as is.
    """

    # Text of the step as extracted from the model's reply.
    content: str
    # Heuristic confidence score attached to the step.
    confidence: float
    # Evidence lines that accompany the step (may be empty).
    supporting_evidence: List[str]
    # Whether the step is considered verified.
    verification_status: bool
    # Time the step object was created, as a float timestamp.
    timestamp: float
    # Position of the step within its reasoning chain.
    step_number: int
class Llama3Reasoner:
    """Reasoning system using Llama 3.1 for step-by-step problem solving.

    A ``transformers`` text-generation pipeline is prompted to answer in a
    ``Step N: ...`` / ``Evidence: ...`` layout. The reply is parsed into
    ``ReasoningStep`` objects, chained into a directed graph (one node per
    step, one edge between consecutive steps) and scored with simple keyword
    heuristics.
    """

    # Header that opens a step in the model's reply, e.g. "Step 3:".
    # Requiring the number and the colon keeps words such as "Steps" or a
    # passing mention of "Step" in prose from being treated as a boundary.
    _STEP_HEADER = re.compile(r"\bStep\s+\d+\s*:")

    def __init__(
        self,
        model_id: str = "meta-llama/Llama-3.1-8B-Instruct",
        cache_dir: Optional[str] = None,
        max_steps: int = 5
    ):
        """Create the worker pool and load the generation pipeline.

        Args:
            model_id: Model identifier passed to ``transformers.pipeline``.
            cache_dir: Optional directory; stored as ``self.cache_dir``
                (a ``Path``, or ``None``). It is not forwarded to the
                pipeline by this class.
            max_steps: Stored as ``self.max_steps``. It is not enforced when
                parsing the model's reply.
        """
        self.max_steps = max_steps
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.executor = ThreadPoolExecutor(max_workers=4)

        # Initialize the model pipeline
        logger.info(f"Loading Llama 3.1 model: {model_id}")
        self.pipe = pipeline(
            "text-generation",
            model=model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

    async def generate_reasoning(self, prompt: str) -> Dict[str, Any]:
        """Generate step-by-step reasoning for a given problem.

        Args:
            prompt: The problem statement to solve.

        Returns:
            A dict with keys ``'steps'`` (list of ``ReasoningStep``),
            ``'graph'`` (the reasoning graph) and ``'metrics'`` (quality
            metrics for the steps and graph).

        Raises:
            Exception: Any error raised during generation or parsing is
                logged and re-raised unchanged.
        """
        messages = [
            {
                "role": "system",
                "content": """You are a logical reasoning assistant. Analyze problems step by step,
providing clear explanations and evidence for each step. Format your response as:
Step 1: [explanation]
Evidence: [supporting facts]
Step 2: [explanation]
Evidence: [supporting facts]
And so on..."""
            },
            {
                "role": "user",
                "content": f"Please solve this problem step by step: {prompt}"
            }
        ]
        try:
            # Run the blocking pipeline call in the executor so the event
            # loop stays responsive. We are inside a coroutine, so the
            # running loop is the one to use.
            loop = asyncio.get_running_loop()
            outputs = await loop.run_in_executor(
                self.executor,
                lambda: self.pipe(
                    messages,
                    max_new_tokens=512,
                    temperature=0.7,
                    top_p=0.9,
                )
            )
            # The last chat message of the first output is the model's reply.
            response = outputs[0]["generated_text"][-1]['content']
            steps = await self._parse_steps(response)
            # Create reasoning graph
            graph = await self._create_reasoning_graph(steps)
            return {
                'steps': steps,
                'graph': graph,
                'metrics': self._calculate_metrics(steps, graph)
            }
        except Exception as e:
            logger.error(f"Error generating reasoning: {str(e)}")
            raise

    @classmethod
    def _split_response(cls, response: str) -> List[Tuple[str, List[str]]]:
        """Split a reply into ``(content, evidence_lines)`` pairs, one per step.

        Text before the first ``Step N:`` header is treated as preamble and
        dropped. If the reply contains no header at all, the whole reply is
        returned as a single step. The ``Step N:`` label itself is not part
        of the content. Evidence is everything after the first ``Evidence:``
        marker in a step, split into non-empty, stripped lines.
        """
        pieces = cls._STEP_HEADER.split(response)
        if len(pieces) > 1:
            # pieces[0] is whatever preceded the first header (preamble).
            pieces = pieces[1:]

        parsed: List[Tuple[str, List[str]]] = []
        for piece in pieces:
            body, _, evidence_text = piece.partition("Evidence:")
            content = body.strip()
            evidence = [
                line.strip()
                for line in evidence_text.splitlines()
                if line.strip()
            ]
            if content or evidence:
                parsed.append((content, evidence))
        return parsed

    async def _parse_steps(self, response: str) -> List["ReasoningStep"]:
        """Parse the model's response into structured reasoning steps.

        Steps are numbered from 0 in the order they appear. Every step is
        created with ``verification_status=True`` and the current time as
        its timestamp.
        """
        steps = []
        for index, (content, evidence) in enumerate(self._split_response(response)):
            steps.append(ReasoningStep(
                content=content,
                # Confidence is estimated from language markers in the text.
                confidence=self._calculate_confidence(content),
                supporting_evidence=evidence,
                verification_status=True,
                timestamp=time.time(),
                step_number=index
            ))
        return steps

    async def _create_reasoning_graph(self, steps: List["ReasoningStep"]) -> "nx.DiGraph":
        """Create a directed graph representing the reasoning process.

        Node ``i`` carries the content, confidence and evidence of step ``i``.
        Each step after the first gets an edge from its predecessor whose
        ``weight`` is the coherence between the two steps.
        """
        graph = nx.DiGraph()
        for i, step in enumerate(steps):
            # Add node with step information
            graph.add_node(
                i,
                content=step.content,
                confidence=step.confidence,
                evidence=step.supporting_evidence
            )
            # Add edge to previous step if it exists
            if i > 0:
                weight = self._calculate_step_coherence(steps[i-1], step)
                graph.add_edge(i-1, i, weight=weight)
        return graph

    @staticmethod
    def _mentions_any(text: str, phrases: List[str]) -> bool:
        """Return True if any phrase occurs in ``text`` as whole words.

        Matching is case-insensitive. Word boundaries prevent hits inside
        longer words (e.g. "must" in "mustard", "thus" in "enthusiasm").
        """
        alternatives = "|".join(re.escape(phrase) for phrase in phrases)
        return re.search(rf"\b(?:{alternatives})\b", text, re.IGNORECASE) is not None

    def _calculate_confidence(self, text: str) -> float:
        """Calculate a confidence score based on language patterns.

        Returns 0.9 when a high-confidence marker is present, else 0.7 for a
        medium-confidence marker, else 0.5 for a low-confidence marker, and
        0.6 when no marker is found.
        """
        high_confidence = ['clearly', 'definitely', 'must', 'certainly']
        medium_confidence = ['likely', 'probably', 'should']
        low_confidence = ['might', 'maybe', 'possibly', 'could']
        # Checked strongest first, so a high marker wins over a weaker one.
        tiers = (
            (0.9, high_confidence),
            (0.7, medium_confidence),
            (0.5, low_confidence),
        )
        for score, markers in tiers:
            if self._mentions_any(text, markers):
                return score
        return 0.6

    def _calculate_step_coherence(self, prev_step: "ReasoningStep", curr_step: "ReasoningStep") -> float:
        """Calculate coherence between consecutive steps.

        Starts from 0.6, adds 0.2 if the current step contains a logical
        connector and 0.1 per evidence entry shared with the previous step;
        the result is capped at 1.0.
        """
        # Check for logical connectors
        connectors = ['therefore', 'thus', 'consequently', 'as a result']
        has_connector = self._mentions_any(curr_step.content, connectors)
        # Check for shared evidence
        shared_evidence = len(
            set(prev_step.supporting_evidence) & set(curr_step.supporting_evidence)
        )
        coherence = 0.6  # base coherence
        if has_connector:
            coherence += 0.2
        if shared_evidence:
            coherence += 0.1 * shared_evidence
        return min(coherence, 1.0)

    def _calculate_metrics(self, steps: List["ReasoningStep"], graph: "nx.DiGraph") -> Dict[str, float]:
        """Calculate quality metrics for the reasoning process.

        Returns a dict with ``'average_confidence'``, ``'min_confidence'``,
        ``'coherence'`` (mean edge weight, 1.0 when there are no edges) and
        ``'step_count'``. With no steps, both confidence metrics are 0.0
        instead of raising. Values are plain Python numbers.
        """
        confidence_scores = [step.confidence for step in steps]
        edge_weights = [data['weight'] for _, _, data in graph.edges(data=True)]
        # np.min raises on an empty sequence, so the empty case is explicit.
        if confidence_scores:
            average_confidence = float(np.mean(confidence_scores))
            min_confidence = float(np.min(confidence_scores))
        else:
            average_confidence = 0.0
            min_confidence = 0.0
        return {
            'average_confidence': average_confidence,
            'min_confidence': min_confidence,
            'coherence': float(np.mean(edge_weights)) if edge_weights else 1.0,
            'step_count': len(steps)
        }

    def save_analysis(self, analysis: Dict[str, Any], filepath: str):
        """Save the reasoning analysis to a JSON file.

        Args:
            analysis: A dict with ``'steps'``, ``'metrics'`` and ``'graph'``
                entries, as returned by ``generate_reasoning``.
            filepath: Destination path; the file is overwritten.

        The file holds each step's content, confidence, evidence and step
        number, the metrics, the graph in node-link form, and the time of
        saving. Step timestamps and verification status are not written.
        """
        save_data = {
            'steps': [
                {
                    'content': step.content,
                    'confidence': step.confidence,
                    'evidence': step.supporting_evidence,
                    'step_number': step.step_number
                }
                for step in analysis['steps']
            ],
            'metrics': analysis['metrics'],
            'graph': nx.node_link_data(analysis['graph']),
            'timestamp': time.time()
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(save_data, f, indent=2)
async def main():
    """Run the river-crossing demo end to end.

    Builds a reasoner, asks it to solve the puzzle, logs every step and
    metric, then writes the analysis to ``reasoning_analysis.json``. Any
    exception is logged and re-raised.
    """
    try:
        reasoner = Llama3Reasoner()
        # The puzzle text, one sentence per line, wrapped in newlines.
        problem = (
            "\n"
            "A farmer is traveling with a fox, a chicken, and a bag of grain.\n"
            "He must cross a river, but his boat can only carry himself and one other item.\n"
            "If left alone, the fox will eat the chicken, and the chicken will eat the grain.\n"
            "How can he safely transport all three across the river?\n"
        )

        logger.info("Solving problem...")
        analysis = await reasoner.generate_reasoning(problem)

        logger.info("\nReasoning steps:")
        for item in analysis['steps']:
            # Steps are stored 0-based; show them 1-based.
            logger.info(f"\nStep {item.step_number + 1}:")
            logger.info(f"Content: {item.content}")
            logger.info(f"Confidence: {item.confidence:.2f}")
            if not item.supporting_evidence:
                continue
            logger.info("Evidence:")
            for fact in item.supporting_evidence:
                logger.info(f"- {fact}")

        logger.info("\nMetrics:")
        for name, score in analysis['metrics'].items():
            logger.info(f"{name}: {score:.2f}")

        # Save analysis
        reasoner.save_analysis(analysis, 'reasoning_analysis.json')
    except Exception as e:
        logger.error(f"Error in main: {str(e)}")
        raise


# Script entry point: run the async demo when executed directly.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment