| import gradio as gr |
| import torch |
| import os |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| from sentence_transformers import SentenceTransformer |
| from sklearn.metrics.pairwise import cosine_similarity |
| import yfinance as yf |
| from datetime import datetime |
| from huggingface_hub import login |
|
|
| |
# Authenticate with the Hugging Face Hub when a token is supplied through the
# HF_TOKEN environment variable; otherwise only warn and continue.
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    print("WARNING: HF_TOKEN not found in environment variables. You may face access issues for gated models.")
else:
    login(token=hf_token)
    print("Successfully logged in to Hugging Face")
|
|
| |
# Hub repository id of the fine-tuned chat model served by this app.
model_name = "Akshit-77/llama-3.2-3b-chatbot"

# Load the matching tokenizer. A failure is reported on stdout and then
# re-raised, so the script stops at import time instead of running half set up.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
except Exception as exc:
    print(f"Error loading tokenizer: {exc}")
    raise
|
|
| |
# 4-bit quantization settings passed to the model loader below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
)

# Load the causal LM with the quantization config and automatic device
# placement. Any failure is printed and re-raised so start-up aborts.
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=hf_token,
        device_map="auto",
        quantization_config=bnb_config,
    )
except Exception as exc:
    print(f"Error loading model: {exc}")
    raise
else:
    print("Model loaded successfully")
|
|
| |
# Custom stylesheet handed to gr.Blocks(css=...) when the interface is built.
# It styles the element with id "chatbot" (the elem_id given to gr.Chatbot)
# and declares message-bubble and timestamp classes.
css = """
#chatbot {
font-family: Arial, sans-serif;
background-color: #e5ddd5;
}
.message {
padding: 10px 15px;
border-radius: 7.5px;
margin: 5px 0;
max-width: 75%;
position: relative;
}
.user-message {
background: #dcf8c6;
margin-left: auto;
margin-right: 10px;
}
.bot-message {
background: white;
margin-left: 10px;
}
.timestamp {
font-size: 0.7em;
color: #667781;
float: right;
margin-left: 10px;
margin-top: 3px;
}
"""
|
|
class StockDataRetriever:
    """Look up quotes for NSE-listed stocks through Yahoo Finance.

    ``stock_mapping`` translates upper-case, hyphen-separated company names
    (and bare NSE symbols) into Yahoo Finance tickers of the form
    ``<SYMBOL>.NS``.
    """

    def __init__(self):
        self.stock_mapping = {
            'RELIANCE': 'RELIANCE.NS',
            'TCS': 'TCS.NS',
            'HDFCBANK': 'HDFCBANK.NS',
            'INFY': 'INFY.NS',
            'ICICIBANK': 'ICICIBANK.NS',

            'RELIANCE-INDUSTRIES': 'RELIANCE.NS',
            # HDFC Ltd. merged into HDFC Bank (July 2023) and no longer
            # trades under its own ticker, so point at the surviving listing.
            'HDFC': 'HDFCBANK.NS',
            'ONGC': 'ONGC.NS',
            'INDIAN-OIL-CORPORATION': 'IOC.NS',
            'ADANI-GROUP': 'ADANIENT.NS',

            'HERO-MOTOCORP': 'HEROMOTOCO.NS',
            'ASIAN-PAINTS': 'ASIANPAINT.NS',
            'EICHER-MOTORS': 'EICHERMOT.NS',
            'ITC': 'ITC.NS',
            'TATA-STEEL': 'TATASTEEL.NS',

            'SHRIRAM-TRANSPORT-FINANCE': 'SHRIRAMFIN.NS',
            'DR-REDDYS-LABORATORIES': 'DRREDDY.NS',
            'INFOSYS': 'INFY.NS',
            'SUN-PHARMA': 'SUNPHARMA.NS',
            'TATA-CONSULTANCY-SERVICES': 'TCS.NS',

            'MARUTI-SUZUKI': 'MARUTI.NS',
            'HCL-TECHNOLOGIES': 'HCLTECH.NS',
            'COAL-INDIA': 'COALINDIA.NS',
            # The merged LTIMindtree entity trades as LTIM; MINDTREE is delisted.
            'LTI-MINDTREE': 'LTIM.NS',
            'HDFC-LIFE': 'HDFCLIFE.NS',

            'BAJAJ-AUTO': 'BAJAJ-AUTO.NS',
            'BRITANNIA-INDUSTRIES': 'BRITANNIA.NS',
            'HINDALCO-INDUSTRIES': 'HINDALCO.NS',
            'LARSEN-AND-TOUBRO': 'LT.NS',
            'TATA-CONSUMER-PRODUCTS': 'TATACONSUM.NS',

            'WIPRO': 'WIPRO.NS',
            'TITAN': 'TITAN.NS',
            'BAJAJ-FINANCE': 'BAJFINANCE.NS',
            'JSW-STEEL': 'JSWSTEEL.NS',
            'ICICI-BANK': 'ICICIBANK.NS',

            'INDUSIND-BANK': 'INDUSINDBK.NS',
            'BHARTI-AIRTEL': 'BHARTIARTL.NS',
            'DIVIS-LABORATORIES': 'DIVISLAB.NS',
            'SBI-LIFE-INSURANCE': 'SBILIFE.NS',
            'BAJAJ-FINSERV': 'BAJAJFINSV.NS',

            'CIPLA': 'CIPLA.NS',
            'GRASIM-INDUSTRIES': 'GRASIM.NS',
            'HINDUSTAN-UNILEVER': 'HINDUNILVR.NS',
            'MAHINDRA-AND-MAHINDRA': 'M&M.NS',
            'TATA-MOTORS': 'TATAMOTORS.NS',

            'APOLLO-HOSPITALS-ENTERPRISES': 'APOLLOHOSP.NS',
            'SBI': 'SBIN.NS',
            'KOTAK-MAHINDRA-BANK': 'KOTAKBANK.NS',
            'POWER-GRID-CORPORATION-OF-INDIA': 'POWERGRID.NS',
            'AXIS-BANK': 'AXISBANK.NS',

            'NTPC': 'NTPC.NS',
            'TECH-MAHINDRA': 'TECHM.NS',
            'ADANI-PORTS': 'ADANIPORTS.NS',
            'ULTRATECH-CEMENT': 'ULTRACEMCO.NS',
            # Nestle India is listed as NESTLEIND, not NESTLE.
            'NESTLE': 'NESTLEIND.NS',
            'BHARAT-PETROLEUM': 'BPCL.NS',
        }

    def _resolve_ticker(self, symbol: str) -> str:
        """Return the Yahoo Finance ticker for a symbol or company name.

        The input is upper-cased and its whitespace-separated words are joined
        with hyphens, so "tata motors" and "TATA-MOTORS" hit the same mapping
        entry. Names without a mapping entry fall back to ``<NAME>.NS``.
        """
        key = "-".join(symbol.upper().split())
        return self.stock_mapping.get(key, f"{key}.NS")

    def get_stock_data(self, symbol: str) -> dict:
        """Fetch stock data from Yahoo Finance.

        Args:
            symbol: NSE symbol or a company name from ``stock_mapping``
                (case-insensitive; spaces or hyphens between words).

        Returns:
            A dict with ``current_price``, ``previous_close``, ``day_high``,
            ``day_low`` and ``last_updated`` on success, or a dict with a
            single ``error`` key describing the failure. This method does
            not raise: every exception is converted into an error dict.
        """
        # Reject unusable input up front so no network request is made for it.
        if not isinstance(symbol, str) or not symbol.strip():
            return {"error": "No stock symbol was provided."}

        try:
            info = yf.Ticker(self._resolve_ticker(symbol)).info

            # A missing 'currentPrice' is treated as "unknown symbol".
            if not info or 'currentPrice' not in info:
                return {"error": f"Stock symbol '{symbol}' not found or invalid. Please verify the symbol."}

            return {
                "current_price": info.get("currentPrice", "N/A"),
                "previous_close": info.get("previousClose", "N/A"),
                "day_high": info.get("dayHigh", "N/A"),
                "day_low": info.get("dayLow", "N/A"),
                "last_updated": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
        except Exception as e:
            # Boundary handler: callers expect an error dict, never an exception.
            return {"error": f"Could not fetch stock data: {str(e)}"}
|
|
|
|
class RAGPipeline:
    """Answer chat questions, adding live stock quotes for price queries.

    A query is classified as a price query by embedding similarity against a
    small list of price-related phrases. For such queries a stock symbol is
    extracted, its quote is fetched, and the formatted quote is placed in the
    prompt as context before the language model generates the answer.
    """

    # Minimum cosine similarity to a knowledge-base phrase for a query to be
    # treated as a price query.
    PRICE_QUERY_THRESHOLD = 0.5

    def __init__(self, model_path):
        """Set up the retriever, the sentence encoder and the phrase embeddings.

        Args:
            model_path: Accepted for interface compatibility but not used;
                the module-level ``tokenizer`` and ``model`` are used instead.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.stock_retriever = StockDataRetriever()
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')

        # Reference phrases that characterise a "what is the price" question.
        self.knowledge_base = [
            "stock price of",
            "current price",
            "stock performance",
            "today's stock price",
            "stock data for",
            "price of stock"
        ]
        self.knowledge_embeddings = self.encoder.encode(self.knowledge_base)

        # Known names/symbols, used to spot a stock mention inside a query.
        self.stock_symbols = list(self.stock_retriever.stock_mapping.keys())

    @staticmethod
    def _normalize(text):
        """Upper-case ``text`` and reduce it to space-separated words.

        Apostrophes are dropped ("Reddy's" -> "REDDYS"); every other
        character that is neither alphanumeric nor '&' becomes a space, so
        "TATA-MOTORS" and "Tata Motors?" both normalize to "TATA MOTORS".
        """
        chars = []
        for ch in text.upper():
            if ch.isalnum() or ch == "&":
                chars.append(ch)
            elif ch not in "'\u2019":
                chars.append(" ")
        return " ".join("".join(chars).split())

    def _extract_stock_symbol(self, query):
        """Return the stock name/symbol mentioned in ``query``, or ``None``.

        Known names are matched as whole words (so "ITC" does not match inside
        "switch") and longer names are tried first (so "HDFC Life" resolves to
        'HDFC-LIFE' rather than 'HDFC'). If no known name is present, the last
        word of the query, stripped of punctuation, is returned as a guess
        when it has more than one character.
        """
        normalized = self._normalize(query or "")
        if not normalized:
            return None

        # Pad with spaces so a simple substring test enforces word boundaries.
        padded = f" {normalized} "
        for symbol in sorted(self.stock_symbols, key=len, reverse=True):
            if f" {self._normalize(symbol)} " in padded:
                return symbol

        # Fallback: assume the user ended the question with the ticker.
        last_word = normalized.split()[-1]
        if len(last_word) > 1:
            return last_word

        return None

    def _is_price_query(self, query):
        """Return True when ``query`` is similar enough to a price phrase."""
        query_embedding = self.encoder.encode([query.lower()])
        similarities = cosine_similarity(query_embedding, self.knowledge_embeddings)[0]
        # Convert to a plain bool instead of leaking a numpy scalar.
        return bool(similarities.max() > self.PRICE_QUERY_THRESHOLD)

    def _format_stock_data(self, stock_data):
        """Format stock data into a readable string.

        Returns the error message unchanged when ``stock_data`` contains an
        'error' key; otherwise a multi-line summary of the quote fields.
        """
        if 'error' in stock_data:
            return stock_data['error']

        return (
            f"Stock Data:\n"
            f"Current Price: ₹{stock_data['current_price']}\n"
            f"Previous Close: ₹{stock_data['previous_close']}\n"
            f"Day's High: ₹{stock_data['day_high']}\n"
            f"Day's Low: ₹{stock_data['day_low']}\n"
            f"Last Updated: {stock_data['last_updated']}"
        )

    def generate_response(self, query, *, max_new_tokens=400):
        """Generate the model's answer to ``query``.

        Args:
            query: The user's question.
            max_new_tokens: Upper bound on the number of generated tokens.
                Unlike ``max_length``, this does not count the prompt tokens.

        Returns:
            Only the newly generated answer text; the prompt (context and
            question) is not echoed back.
        """
        stock_context = ""
        if self._is_price_query(query):
            symbol = self._extract_stock_symbol(query)
            if symbol:
                stock_data = self.stock_retriever.get_stock_data(symbol)
                stock_context = self._format_stock_data(stock_data)
            else:
                stock_context = "No specific stock symbol could be identified."

        full_prompt = (
            f"Context: {stock_context}\n\n"
            f"Question: {query}\n"
            "Answer:"
        )

        inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        # Fall back to EOS for padding when the tokenizer defines no pad token.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            # **inputs forwards the attention mask together with the input ids.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=pad_token_id,
            )

        # The returned sequence starts with the prompt tokens; keep only the
        # continuation so the user sees just the answer.
        completion = outputs[0][prompt_length:]
        return self.tokenizer.decode(completion, skip_special_tokens=True).strip()
|
|
|
|
| |
# Build the shared pipeline once at import time. A failure is printed and
# re-raised so the app does not start without a working pipeline.
try:
    pipeline = RAGPipeline(model_name)
except Exception as exc:
    print(f"Error initializing RAG Pipeline: {exc}")
    raise
else:
    print("RAG Pipeline initialized successfully")
|
|
| |
def chat(message, history):
    """Handle one chat submission from the Gradio textbox.

    Args:
        message: Text the user submitted.
        history: Existing list of ``(user_message, bot_reply)`` pairs, or
            ``None`` when the conversation is empty.

    Returns:
        A ``(history, "")`` tuple: the updated conversation and an empty
        string, which clears the input textbox. Generation errors are shown
        to the user as the bot's reply instead of being raised.
    """
    # Work on a copy so the caller's list is not mutated.
    turns = list(history or [])

    # A blank submission should not cost a full model generation.
    if not message or not message.strip():
        return turns, ""

    try:
        reply = pipeline.generate_response(message)
    except Exception as e:
        reply = f"Error generating response: {str(e)}"

    turns.append((message, reply))
    return turns, ""
|
|
| |
def create_interface():
    """Assemble and return the Gradio Blocks UI for the assistant.

    The layout is a title and intro, the chat window, a textbox whose submit
    event is wired to ``chat``, and a footer with example questions.
    """
    with gr.Blocks(css=css) as demo:
        gr.HTML("<h1>Indian Stock Market Assistant</h1>")
        gr.HTML("<p>Ask me about Indian stock prices or any general questions.</p>")

        chat_window = gr.Chatbot(height=600, elem_id="chatbot")
        question_box = gr.Textbox(
            show_label=False,
            placeholder="Type your question here (e.g., 'What is the current price of RELIANCE?')",
        )

        # On Enter: run chat(), refresh the chat window and clear the textbox.
        question_box.submit(
            fn=chat,
            inputs=[question_box, chat_window],
            outputs=[chat_window, question_box],
        )

        gr.HTML("""
        <div style="text-align: center; margin-top: 20px; padding: 10px; background-color: #f0f0f0; border-radius: 5px;">
        <p>This chatbot provides real-time Indian stock market data and can answer general questions.</p>
        <p>Examples: "What's the current price of TCS?", "How is HDFC performing today?", "Tell me about RELIANCE stock"</p>
        </div>
        """)

    return demo
|
|
| |
# Build the UI at import time. A failure is printed and re-raised so the
# script does not continue without an interface.
try:
    iface = create_interface()
except Exception as exc:
    print(f"Error creating interface: {exc}")
    raise
else:
    print("Interface created successfully")


# Start the Gradio server only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()