|
|
@@ -17,12 +17,10 @@ logger = logging.getLogger(__name__)
|
|
|
chat_router = APIRouter()
|
|
|
|
|
|
redis_client = Redis(host='localhost', port=6379, db=1 if DEVELOPMENT else 0, decode_responses=True)
|
|
|
-
|
|
|
+broadcast_channel = "chat" if not DEVELOPMENT else "chat_dev"
|
|
|
@chat_router.websocket("/ws")
|
|
|
async def chat_irc_endpoint(websocket: WebSocket):
|
|
|
"""WebSocket endpoint for real-time chat interactions"""
|
|
|
- global connected_users
|
|
|
-
|
|
|
|
|
|
# Backend de broadcast con Redis
|
|
|
broadcast = Broadcast("redis://localhost:6379")
|
|
|
@@ -46,15 +44,15 @@ async def chat_irc_endpoint(websocket: WebSocket):
|
|
|
async def reader():
|
|
|
"""Lee mensajes desde Redis y los manda al WebSocket"""
|
|
|
logger.info("Iniciando lector de mensajes")
|
|
|
- async with broadcast.subscribe(channel="chat") as subscriber:
|
|
|
- logger.info("Subscribed to Redis channel 'chat'")
|
|
|
+ async with broadcast.subscribe(channel=broadcast_channel) as subscriber:
|
|
|
+ logger.info(f"Subscribed to Redis channel '{broadcast_channel}'")
|
|
|
logger.debug("Starting message read loop")
|
|
|
- async for event in subscriber:
|
|
|
+ async for event in subscriber: # type: ignore
|
|
|
try:
|
|
|
logger.debug(f"Broadcasting message to WebSocket: {event.message}")
|
|
|
await websocket.send_text(event.message)
|
|
|
- except Exception:
|
|
|
- logger.error("Error sending message to WebSocket, closing connection.")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error sending message to WebSocket: {e}, closing connection.")
|
|
|
break
|
|
|
await broadcast.connect()
|
|
|
reader_task = asyncio.create_task(reader())
|
|
|
@@ -65,6 +63,7 @@ async def chat_irc_endpoint(websocket: WebSocket):
|
|
|
try:
|
|
|
|
|
|
data = await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
|
|
|
+ logger.debug(f"Received data from WebSocket: {data}")
|
|
|
payload = json.loads(data)
|
|
|
event_type = payload.get("type")
|
|
|
logger.debug(f"Received event: {event_type} with payload: {payload}")
|
|
|
@@ -89,7 +88,7 @@ async def chat_irc_endpoint(websocket: WebSocket):
|
|
|
message_username = payload["username"]
|
|
|
|
|
|
users = redis_client.smembers("connected_users")
|
|
|
- for user in users:
|
|
|
+ for user in users: # type: ignore
|
|
|
user_data = json.loads(user)
|
|
|
if user_data["username"] == message_username:
|
|
|
redis_client.srem("connected_users", user)
|
|
|
@@ -99,31 +98,23 @@ async def chat_irc_endpoint(websocket: WebSocket):
|
|
|
break
|
|
|
elif event_type == "ai_message":
|
|
|
messages = redis_client.lrange("chat_history", -15, -1)
|
|
|
- parsed_messages = [json.loads(msg) for msg in messages]
|
|
|
+ parsed_messages = [json.loads(msg) for msg in messages] # type: ignore
|
|
|
logger.debug(f"IA Message from {payload['username']}")
|
|
|
message_username = payload["username"]
|
|
|
response_content = await generate_completion(parsed_messages, current_user)
|
|
|
response = {"type": "message", "username": "IAKlein", "message": response_content}
|
|
|
- elif event_type == "notification":
|
|
|
- logger.debug(f"Notification to chat")
|
|
|
- notification_message = payload["message"]
|
|
|
- messages = redis_client.lrange("chat_history", -15, -1)
|
|
|
- parsed_messages = [json.loads(msg) for msg in messages]
|
|
|
- response_content = admin_completion(notification_message, parsed_messages)
|
|
|
- response = {"type": "message", "username": "IAKlein", "message": response_content}
|
|
|
-
|
|
|
elif event_type == "mention":
|
|
|
logger.debug(f"Mention to {payload['username']}")
|
|
|
mention_username = payload["username"]
|
|
|
logger.debug(f"Mention username: {mention_username}, Current username: {username}")
|
|
|
- await broadcast.publish(channel="chat", message=json.dumps({"type": "mentioned", "username": mention_username}))
|
|
|
+ await broadcast.publish(channel=broadcast_channel, message=json.dumps({"type": "mentioned", "username": mention_username}))
|
|
|
continue
|
|
|
elif event_type == "pong":
|
|
|
logger.debug(f"Received pong from user {current_user.email}")
|
|
|
continue
|
|
|
|
|
|
# Publicar en Redis
|
|
|
- await broadcast.publish(channel="chat", message=json.dumps(response))
|
|
|
+ await broadcast.publish(channel=broadcast_channel, message=json.dumps(response)) # type: ignore
|
|
|
|
|
|
except asyncio.TimeoutError:
|
|
|
await websocket.send_text(json.dumps({"type": "ping"}))
|
|
|
@@ -132,13 +123,13 @@ async def chat_irc_endpoint(websocket: WebSocket):
|
|
|
logger.info(f"User {current_user.email} disconnected from WebSocket chat.")
|
|
|
|
|
|
users = redis_client.smembers("connected_users")
|
|
|
- for user in users:
|
|
|
+ for user in users: # type: ignore
|
|
|
user_data = json.loads(user)
|
|
|
if user_data["mail"] == current_user.email:
|
|
|
redis_client.srem("connected_users", user)
|
|
|
break
|
|
|
- response = {"type": "leave", "username": user_data["username"]}
|
|
|
- await broadcast.publish(channel="chat", message=json.dumps(response))
|
|
|
+ response = {"type": "leave", "username": user_data["username"]} # type: ignore
|
|
|
+ await broadcast.publish(channel=broadcast_channel, message=json.dumps(response))
|
|
|
|
|
|
finally:
|
|
|
reader_task.cancel()
|
|
|
@@ -149,7 +140,19 @@ async def notify_users(message: NotifyRequest, _: User = Depends(get_current_use
|
|
|
"""Send a notification message to all connected users"""
|
|
|
broadcast = Broadcast("redis://localhost:6379")
|
|
|
await broadcast.connect()
|
|
|
- await broadcast.publish(channel="chat", message=json.dumps({"type": "notification", "message": message.message}))
|
|
|
+
|
|
|
+ # Obtener historial de mensajes para generar respuesta de IA
|
|
|
+ messages = redis_client.lrange("chat_history", -15, -1)
|
|
|
+ parsed_messages = [json.loads(msg) for msg in messages] # type: ignore
|
|
|
+
|
|
|
+ # Generar respuesta de IA para la notificación
|
|
|
+ logger.debug(f"Processing notification: {message.message}")
|
|
|
+ response_content = admin_completion(message.message, parsed_messages)
|
|
|
+
|
|
|
+ # Enviar directamente la respuesta de IA como mensaje
|
|
|
+ response = {"type": "message", "username": "IAKlein", "message": response_content}
|
|
|
+ await broadcast.publish(channel=broadcast_channel, message=json.dumps(response))
|
|
|
+
|
|
|
await broadcast.disconnect()
|
|
|
return {"status": "Notification sent"}
|
|
|
|
|
|
@@ -158,7 +161,7 @@ async def get_connected_users(q: Optional[str] = Query(None), _: User = Depends(
|
|
|
"""Get a list of connected users (solo local al worker)"""
|
|
|
# return {"users": [user.username for user in connected_users if q.lower() in user.username.lower()]}
|
|
|
all_users = redis_client.smembers("connected_users")
|
|
|
- all_users = [json.loads(user)["username"] for user in all_users]
|
|
|
+ all_users = [json.loads(user)["username"] for user in all_users] # type: ignore
|
|
|
if q is None or q.strip() == "":
|
|
|
return {"users": all_users}
|
|
|
filtered_users = [user for user in all_users if q.lower() in user.lower()]
|
|
|
@@ -168,4 +171,4 @@ async def get_connected_users(q: Optional[str] = Query(None), _: User = Depends(
|
|
|
async def get_online_user_count(_: User = Depends(get_current_user)):
|
|
|
"""Get the count of online users (solo local al worker)"""
|
|
|
all_users = redis_client.smembers("connected_users")
|
|
|
- return {"count": len(all_users)}
|
|
|
+ return {"count": len(all_users)} # type: ignore
|