# chat.py (6.4 KB) - FastAPI WebSocket chat router backed by Redis.
  1. import asyncio
  2. import json
  3. import logging
  4. from fastapi import Query, Request, HTTPException, Depends, APIRouter, WebSocket, WebSocketDisconnect
  5. from fastapi.responses import JSONResponse
  6. from fastapi.security import HTTPAuthorizationCredentials
  7. from pydantic import BaseModel, ConfigDict
  8. from broadcaster import Broadcast
  9. from models.chat import ChatCompletionRequest
  10. from models.user import User
  11. from services.openai_service.openai_service import generate_completion
  12. from auth.security import get_current_user
  13. from config.messages import SuccessResponse
  14. from redis import Redis
  15. logger = logging.getLogger(__name__)
  16. chat_router = APIRouter()
  17. redis_client = Redis(host='localhost', port=6379, db=0, decode_responses=True)
  18. @chat_router.websocket("/ws")
  19. async def chat_irc_endpoint(websocket: WebSocket):
  20. """WebSocket endpoint for real-time chat interactions"""
  21. global connected_users
  22. # Backend de broadcast con Redis
  23. broadcast = Broadcast("redis://localhost:6379")
  24. logger.info("New WebSocket connection attempt")
  25. logger.debug(f"WebSocket query params: {websocket.query_params}")
  26. current_user = await get_current_user(
  27. HTTPAuthorizationCredentials(
  28. scheme="Bearer",
  29. credentials=websocket.query_params.get("token", "")
  30. )
  31. )
  32. if not current_user:
  33. await websocket.close(code=1008) # Policy Violation
  34. return
  35. await websocket.accept()
  36. logger.info(f"User {current_user.email} connected to WebSocket chat.")
  37. username = None
  38. async def reader():
  39. """Lee mensajes desde Redis y los manda al WebSocket"""
  40. logger.info("Iniciando lector de mensajes")
  41. async with broadcast.subscribe(channel="chat") as subscriber:
  42. logger.info("Subscribed to Redis channel 'chat'")
  43. logger.debug("Starting message read loop")
  44. async for event in subscriber:
  45. try:
  46. logger.debug(f"Broadcasting message to WebSocket: {event.message}")
  47. await websocket.send_text(event.message)
  48. except Exception:
  49. logger.error("Error sending message to WebSocket, closing connection.")
  50. break
  51. await broadcast.connect()
  52. reader_task = asyncio.create_task(reader())
  53. try:
  54. while True:
  55. try:
  56. data = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
  57. payload = json.loads(data)
  58. event_type = payload.get("type")
  59. logger.debug(f"Received event: {event_type} with payload: {payload}")
  60. if event_type == "join":
  61. logger.info(f"User {payload['username']} joined the chat.")
  62. data = {
  63. "mail": current_user.email,
  64. "username": payload["username"]
  65. }
  66. redis_client.sadd("connected_users", json.dumps(data))
  67. response = {"type": "join", "username": payload["username"]}
  68. elif event_type == "message":
  69. logger.debug(f"Message from {payload['username']}: {payload['message']}")
  70. message_username = payload["username"]
  71. message = payload["message"]
  72. response = {"type": "message", "username": message_username, "message": message}
  73. elif event_type == "leave":
  74. logger.info(f"User {payload['username']} left the chat.")
  75. message_username = payload["username"]
  76. users = redis_client.smembers("connected_users")
  77. for user in users:
  78. user_data = json.loads(user)
  79. if user_data["username"] == message_username:
  80. redis_client.srem("connected_users", user)
  81. break
  82. response = {"type": "leave", "username": message_username}
  83. await websocket.close()
  84. break
  85. elif event_type == "ai_message":
  86. logger.debug(f"IA Message from {payload['username']}: {payload['messages']}")
  87. message_username = payload["username"]
  88. messages = payload["messages"]
  89. response_content = await generate_completion(messages, current_user)
  90. response = {"type": "message", "username": "IAKlein", "message": response_content}
  91. elif event_type == "mention":
  92. logger.debug(f"Mention to {payload['username']}")
  93. mention_username = payload["username"]
  94. logger.debug(f"Mention username: {mention_username}, Current username: {username}")
  95. await broadcast.publish(channel="chat", message=json.dumps({"type": "mentioned", "username": mention_username}))
  96. continue
  97. elif event_type == "pong":
  98. logger.debug(f"Received pong from user {current_user.email}")
  99. continue
  100. # Publicar en Redis
  101. await broadcast.publish(channel="chat", message=json.dumps(response))
  102. except asyncio.TimeoutError:
  103. await websocket.send_text(json.dumps({"type": "ping"}))
  104. except WebSocketDisconnect:
  105. logger.info(f"User {current_user.email} disconnected from WebSocket chat.")
  106. users = redis_client.smembers("connected_users")
  107. for user in users:
  108. user_data = json.loads(user)
  109. if user_data["mail"] == current_user.email:
  110. redis_client.srem("connected_users", user)
  111. break
  112. response = {"type": "leave", "username": user_data["username"]}
  113. await broadcast.publish(channel="chat", message=json.dumps(response))
  114. finally:
  115. reader_task.cancel()
  116. await broadcast.disconnect()
  117. @chat_router.get("/users")
  118. async def get_connected_users(q: str = Query(...), _: User = Depends(get_current_user)):
  119. """Get a list of connected users (solo local al worker)"""
  120. # return {"users": [user.username for user in connected_users if q.lower() in user.username.lower()]}
  121. all_users = redis_client.smembers("connected_users")
  122. all_users = [json.loads(user)["username"] for user in all_users]
  123. filtered_users = [user for user in all_users if q.lower() in user.lower()]
  124. return {"users": filtered_users}