| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import time |
| | from typing import Any, Dict, List, Union |
| |
|
| | from requests.exceptions import RequestException |
| |
|
| | from camel.toolkits import FunctionTool |
| | from camel.toolkits.base import BaseToolkit |
| |
|
| |
|
class RedditToolkit(BaseToolkit):
    r"""A class representing a toolkit for Reddit operations.

    This toolkit provides methods to interact with the Reddit API, allowing
    users to collect top posts, perform sentiment analysis on comments, and
    track keyword discussions across multiple subreddits.

    Credentials are read from the `REDDIT_CLIENT_ID`,
    `REDDIT_CLIENT_SECRET` and `REDDIT_USER_AGENT` environment variables.
    When any of them is empty, the data-collecting methods return a warning
    string instead of contacting Reddit.

    Attributes:
        retries (int): Number of attempts for API requests in case of
            failure. At least one attempt is always made.
        delay (int): Delay between retries in seconds. The same delay is
            applied after each processed post.
        client_id (str): Value of `REDDIT_CLIENT_ID` (empty if unset).
        client_secret (str): Value of `REDDIT_CLIENT_SECRET` (empty if
            unset).
        user_agent (str): Value of `REDDIT_USER_AGENT` (empty if unset).
        reddit (Reddit): An instance of the Reddit client.
    """

    # Returned verbatim by the public methods when credentials are missing.
    _MISSING_CREDENTIALS_MESSAGE = (
        "Reddit API credentials are not set. "
        "Please set the environment variables."
    )

    def __init__(self, retries: int = 3, delay: int = 0):
        r"""Initializes the RedditToolkit with the specified number of retries
        and delay.

        Args:
            retries (int): Number of times to retry the request in case of
                failure. Defaults to `3`.
            delay (int): Time in seconds to wait between retries. Defaults to
                `0`.
        """
        # Imported lazily so the module can be imported without praw.
        from praw import Reddit

        self.retries = retries
        self.delay = delay

        self.client_id = os.environ.get("REDDIT_CLIENT_ID", "")
        self.client_secret = os.environ.get("REDDIT_CLIENT_SECRET", "")
        self.user_agent = os.environ.get("REDDIT_USER_AGENT", "")

        self.reddit = Reddit(
            client_id=self.client_id,
            client_secret=self.client_secret,
            user_agent=self.user_agent,
            request_timeout=30,
        )

    def _retry_request(self, func, *args, **kwargs):
        r"""Retries a function in case of network-related errors.

        Args:
            func (callable): The function to be retried.
            *args: Arguments to pass to the function.
            **kwargs: Keyword arguments to pass to the function.

        Returns:
            Any: The result of the function call if successful.

        Raises:
            RequestException: If all retry attempts fail.
        """
        # Always call `func` at least once: with `retries <= 0` the loop
        # would otherwise be skipped and `None` returned silently.
        attempts = max(1, self.retries)
        # NOTE(review): only `requests.exceptions.RequestException` is
        # retried; confirm this matches the exception type the Reddit
        # client actually raises for transport failures.
        for attempt in range(1, attempts + 1):
            try:
                return func(*args, **kwargs)
            except RequestException as e:
                print(f"Attempt {attempt}/{attempts} failed: {e}")
                if attempt == attempts:
                    raise
                time.sleep(self.delay)

    def _has_credentials(self) -> bool:
        r"""Returns whether all three Reddit credentials are non-empty."""
        return all([self.client_id, self.client_secret, self.user_agent])

    def _fetch_top_posts(
        self, subreddit_name: str, post_limit: int
    ) -> List[Any]:
        r"""Fetches the top posts of a subreddit as a list.

        Args:
            subreddit_name (str): The name of the subreddit.
            post_limit (int): The maximum number of posts to fetch.

        Returns:
            List[Any]: The post objects yielded by `subreddit.top`.
        """
        subreddit = self._retry_request(self.reddit.subreddit, subreddit_name)
        # Per the PRAW docs, listings are fetched lazily on iteration, so
        # the listing is materialised inside the retried callable; wrapping
        # only `subreddit.top` would leave the network call unprotected.
        return self._retry_request(
            lambda: list(subreddit.top(limit=post_limit))
        )

    def _fetch_comments(self, post: Any, comment_limit: int) -> List[Any]:
        r"""Fetches up to `comment_limit` top-level comments of a post.

        Args:
            post (Any): A post object exposing a `comments` iterable.
            comment_limit (int): The maximum number of comments to return.

        Returns:
            List[Any]: Comment objects, excluding `MoreComments`
                placeholders.
        """
        from praw.models import MoreComments

        loaded = self._retry_request(lambda: list(post.comments))
        # `MoreComments` placeholders have no `body`/`score`; drop them
        # before slicing so they neither crash callers nor use up the limit.
        comments = [
            comment
            for comment in loaded
            if not isinstance(comment, MoreComments)
        ]
        return comments[:comment_limit]

    def collect_top_posts(
        self,
        subreddit_name: str,
        post_limit: int = 5,
        comment_limit: int = 5,
    ) -> Union[List[Dict[str, Any]], str]:
        r"""Collects the top posts and their comments from a specified
        subreddit.

        Args:
            subreddit_name (str): The name of the subreddit to collect posts
                from.
            post_limit (int): The maximum number of top posts to collect.
                Defaults to `5`.
            comment_limit (int): The maximum number of top comments to collect
                per post. Defaults to `5`.

        Returns:
            Union[List[Dict[str, Any]], str]: A list of dictionaries, each
                containing the post title and its top comments if success.
                String warning if credentials are not set.
        """
        if not self._has_credentials():
            return self._MISSING_CREDENTIALS_MESSAGE

        data: List[Dict[str, Any]] = []
        for post in self._fetch_top_posts(subreddit_name, post_limit):
            title = post.title
            comments = self._fetch_comments(post, comment_limit)
            data.append(
                {
                    "Post Title": title,
                    "Comments": [
                        {
                            "Comment Body": comment.body,
                            "Upvotes": comment.score,
                        }
                        for comment in comments
                    ],
                }
            )
            # Pause after each post to space out successive requests.
            time.sleep(self.delay)

        return data

    def perform_sentiment_analysis(
        self, data: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        r"""Performs sentiment analysis on the comments collected from Reddit
        posts.

        Both record shapes produced by this toolkit are accepted: flat
        comment records carrying a 'Comment Body' key (as returned by
        `track_keyword_discussions`) and post records carrying a nested
        'Comments' list (as returned by `collect_top_posts`). The input
        dictionaries are updated in place.

        Args:
            data (List[Dict[str, Any]]): A list of dictionaries containing
                Reddit post data and comments.

        Returns:
            List[Dict[str, Any]]: The original data with an added 'Sentiment
                Score' for each comment.

        Raises:
            KeyError: If a record has neither a 'Comment Body' nor a
                'Comments' key, or a nested comment lacks 'Comment Body'.
        """
        from textblob import TextBlob

        for item in data:
            nested = item.get("Comments")
            targets = list(nested) if nested is not None else []
            # A record without a nested list is treated as a flat comment,
            # so a malformed record still raises KeyError as before.
            if "Comment Body" in item or nested is None:
                targets.append(item)

            for target in targets:
                target["Sentiment Score"] = TextBlob(
                    target["Comment Body"]
                ).sentiment.polarity

        return data

    def track_keyword_discussions(
        self,
        subreddits: List[str],
        keywords: List[str],
        post_limit: int = 10,
        comment_limit: int = 10,
        sentiment_analysis: bool = False,
    ) -> Union[List[Dict[str, Any]], str]:
        r"""Tracks discussions about specific keywords in specified subreddits.

        Args:
            subreddits (List[str]): A list of subreddit names to search within.
            keywords (List[str]): A list of keywords to track in the subreddit
                discussions. Matching is a case-insensitive substring test.
            post_limit (int): The maximum number of top posts to collect per
                subreddit. Defaults to `10`.
            comment_limit (int): The maximum number of top comments to collect
                per post. Defaults to `10`.
            sentiment_analysis (bool): If True, performs sentiment analysis on
                the comments. Defaults to `False`.

        Returns:
            Union[List[Dict[str, Any]], str]: A list of dictionaries
                containing the subreddit name, post title, comment body, and
                upvotes for each comment that contains the specified keywords
                if success. String warning if credentials are not set.
        """
        if not self._has_credentials():
            return self._MISSING_CREDENTIALS_MESSAGE

        # Lower-case the keywords once instead of once per comment.
        lowered_keywords = [keyword.lower() for keyword in keywords]
        data: List[Dict[str, Any]] = []

        for subreddit_name in subreddits:
            for post in self._fetch_top_posts(subreddit_name, post_limit):
                for comment in self._fetch_comments(post, comment_limit):
                    body_lower = comment.body.lower()
                    if any(
                        keyword in body_lower for keyword in lowered_keywords
                    ):
                        data.append(
                            {
                                "Subreddit": subreddit_name,
                                "Post Title": post.title,
                                "Comment Body": comment.body,
                                "Upvotes": comment.score,
                            }
                        )
                # Pause after each post to space out successive requests.
                time.sleep(self.delay)

        if sentiment_analysis:
            data = self.perform_sentiment_analysis(data)
        return data

    def get_tools(self) -> List[FunctionTool]:
        r"""Returns a list of FunctionTool objects representing the
        functions in the toolkit.

        Returns:
            List[FunctionTool]: A list of FunctionTool objects for the
                toolkit methods.
        """
        return [
            FunctionTool(self.collect_top_posts),
            FunctionTool(self.perform_sentiment_analysis),
            FunctionTool(self.track_keyword_discussions),
        ]
| |
|