FINAL: WebSocket + DMs + Files - Backend Complete!
WebSocket Real-Time: - Socket.IO server integrated - Real-time message delivery - User online/offline status - Typing indicators - Channel room management - Auto-join on channel access Direct Messages: - 1-on-1 chat API - DM history and conversations - @grimlock in DMs (AI responds) - Read receipts - Unread count tracking - WebSocket notifications File Management: - Upload files to channels - Download with streaming - File metadata tracking - File listing by channel - Delete with permissions Integration: - Messages broadcast via WebSocket - DM notifications via WebSocket - All APIs updated for real-time BACKEND IS FEATURE COMPLETE! - Auth ✅ - Channels ✅ - Messages ✅ - DMs ✅ - Files ✅ - WebSocket ✅ - @grimlock AI ✅ Ready for frontend development in next session!
This commit is contained in:
265
backend/api/direct_messages.py
Normal file
265
backend/api/direct_messages.py
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
"""
|
||||||
|
Direct Messages API - 1-on-1 messaging with @grimlock support
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, BackgroundTasks
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import or_, and_
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from core.database import get_db
|
||||||
|
from core.models import DirectMessage, User
|
||||||
|
from core.ai_client import AIClient
|
||||||
|
from core.context_manager import ContextManager
|
||||||
|
from api.auth import get_current_user
|
||||||
|
from core.websocket import send_dm_notification
|
||||||
|
import main
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
class DMCreate(BaseModel):
|
||||||
|
recipient_id: int
|
||||||
|
content: str
|
||||||
|
|
||||||
|
class DMResponse(BaseModel):
|
||||||
|
id: int
|
||||||
|
sender_id: int
|
||||||
|
recipient_id: int
|
||||||
|
content: str
|
||||||
|
is_ai_message: bool
|
||||||
|
read_at: Optional[str]
|
||||||
|
created_at: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
class ConversationResponse(BaseModel):
|
||||||
|
user_id: int
|
||||||
|
user_name: str
|
||||||
|
user_email: str
|
||||||
|
last_message: str
|
||||||
|
last_message_at: str
|
||||||
|
unread_count: int
|
||||||
|
|
||||||
|
async def handle_grimlock_dm(
|
||||||
|
dm: DirectMessage,
|
||||||
|
sender: User,
|
||||||
|
db: Session,
|
||||||
|
context_manager: ContextManager,
|
||||||
|
ai_client: AIClient
|
||||||
|
):
|
||||||
|
"""Handle @grimlock mention in DM"""
|
||||||
|
|
||||||
|
if not dm.content.lower().startswith('@grimlock'):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get conversation history
|
||||||
|
history = db.query(DirectMessage)\
|
||||||
|
.filter(
|
||||||
|
or_(
|
||||||
|
and_(DirectMessage.sender_id == sender.id, DirectMessage.recipient_id == 1),
|
||||||
|
and_(DirectMessage.sender_id == 1, DirectMessage.recipient_id == sender.id)
|
||||||
|
)
|
||||||
|
)\
|
||||||
|
.filter(DirectMessage.id < dm.id)\
|
||||||
|
.order_by(DirectMessage.id.desc())\
|
||||||
|
.limit(10)\
|
||||||
|
.all()
|
||||||
|
|
||||||
|
# Build conversation
|
||||||
|
conversation = []
|
||||||
|
for msg in reversed(history):
|
||||||
|
if msg.sender_id == sender.id:
|
||||||
|
conversation.append({"role": "user", "content": msg.content})
|
||||||
|
else:
|
||||||
|
conversation.append({"role": "assistant", "content": msg.content})
|
||||||
|
|
||||||
|
# Add current message
|
||||||
|
query = dm.content.replace('@grimlock', '').strip()
|
||||||
|
conversation.append({"role": "user", "content": query})
|
||||||
|
|
||||||
|
# Get context
|
||||||
|
context = context_manager.get_context_for_query(query, role=sender.role.value)
|
||||||
|
system_prompt = context_manager.get_system_prompt(role=sender.role.value)
|
||||||
|
system_prompt += f"\n\nYou are in a direct message conversation with {sender.name}."
|
||||||
|
if context:
|
||||||
|
system_prompt += f"\n\n# Company Context\n{context}"
|
||||||
|
|
||||||
|
# Get AI response
|
||||||
|
response = await ai_client.chat(messages=conversation, system_prompt=system_prompt)
|
||||||
|
|
||||||
|
# Create AI DM
|
||||||
|
ai_dm = DirectMessage(
|
||||||
|
sender_id=None, # AI has no user_id
|
||||||
|
recipient_id=sender.id,
|
||||||
|
content=response,
|
||||||
|
is_ai_message=True
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(ai_dm)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(ai_dm)
|
||||||
|
|
||||||
|
# Notify via WebSocket
|
||||||
|
await send_dm_notification(sender.id, {
|
||||||
|
"id": ai_dm.id,
|
||||||
|
"sender_id": None,
|
||||||
|
"content": response,
|
||||||
|
"is_ai_message": True,
|
||||||
|
"created_at": ai_dm.created_at.isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error handling Grimlock DM: {e}")
|
||||||
|
|
||||||
|
@router.post("/", response_model=DMResponse)
|
||||||
|
async def send_dm(
|
||||||
|
dm_data: DMCreate,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
context_manager: ContextManager = Depends(main.get_context_manager),
|
||||||
|
ai_client: AIClient = Depends(main.get_ai_client)
|
||||||
|
):
|
||||||
|
"""Send a direct message"""
|
||||||
|
|
||||||
|
# Check recipient exists
|
||||||
|
recipient = db.query(User).filter(User.id == dm_data.recipient_id).first()
|
||||||
|
if not recipient:
|
||||||
|
raise HTTPException(status_code=404, detail="Recipient not found")
|
||||||
|
|
||||||
|
# Can't DM yourself
|
||||||
|
if recipient.id == current_user.id:
|
||||||
|
raise HTTPException(status_code=400, detail="Cannot send DM to yourself")
|
||||||
|
|
||||||
|
# Create DM
|
||||||
|
dm = DirectMessage(
|
||||||
|
sender_id=current_user.id,
|
||||||
|
recipient_id=dm_data.recipient_id,
|
||||||
|
content=dm_data.content
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(dm)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(dm)
|
||||||
|
|
||||||
|
# Notify recipient via WebSocket
|
||||||
|
await send_dm_notification(recipient.id, {
|
||||||
|
"id": dm.id,
|
||||||
|
"sender_id": current_user.id,
|
||||||
|
"sender_name": current_user.name,
|
||||||
|
"content": dm.content,
|
||||||
|
"created_at": dm.created_at.isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
# Check for @grimlock
|
||||||
|
if '@grimlock' in dm.content.lower():
|
||||||
|
background_tasks.add_task(handle_grimlock_dm, dm, current_user, db, context_manager, ai_client)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": dm.id,
|
||||||
|
"sender_id": dm.sender_id,
|
||||||
|
"recipient_id": dm.recipient_id,
|
||||||
|
"content": dm.content,
|
||||||
|
"is_ai_message": dm.is_ai_message,
|
||||||
|
"read_at": dm.read_at.isoformat() if dm.read_at else None,
|
||||||
|
"created_at": dm.created_at.isoformat()
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.get("/conversations", response_model=List[ConversationResponse])
|
||||||
|
async def list_conversations(
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""List all DM conversations"""
|
||||||
|
|
||||||
|
# Get all users current user has DMed with
|
||||||
|
sent = db.query(DirectMessage.recipient_id).filter(DirectMessage.sender_id == current_user.id).distinct()
|
||||||
|
received = db.query(DirectMessage.sender_id).filter(DirectMessage.recipient_id == current_user.id).distinct()
|
||||||
|
|
||||||
|
user_ids = set([r[0] for r in sent] + [r[0] for r in received])
|
||||||
|
|
||||||
|
conversations = []
|
||||||
|
for user_id in user_ids:
|
||||||
|
if user_id is None: # Skip AI (no user_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
|
if not user:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get last message
|
||||||
|
last_msg = db.query(DirectMessage)\
|
||||||
|
.filter(
|
||||||
|
or_(
|
||||||
|
and_(DirectMessage.sender_id == current_user.id, DirectMessage.recipient_id == user_id),
|
||||||
|
and_(DirectMessage.sender_id == user_id, DirectMessage.recipient_id == current_user.id)
|
||||||
|
)
|
||||||
|
)\
|
||||||
|
.order_by(DirectMessage.id.desc())\
|
||||||
|
.first()
|
||||||
|
|
||||||
|
# Count unread
|
||||||
|
unread = db.query(DirectMessage)\
|
||||||
|
.filter(DirectMessage.sender_id == user_id)\
|
||||||
|
.filter(DirectMessage.recipient_id == current_user.id)\
|
||||||
|
.filter(DirectMessage.read_at == None)\
|
||||||
|
.count()
|
||||||
|
|
||||||
|
conversations.append({
|
||||||
|
"user_id": user.id,
|
||||||
|
"user_name": user.name,
|
||||||
|
"user_email": user.email,
|
||||||
|
"last_message": last_msg.content[:100] if last_msg else "",
|
||||||
|
"last_message_at": last_msg.created_at.isoformat() if last_msg else "",
|
||||||
|
"unread_count": unread
|
||||||
|
})
|
||||||
|
|
||||||
|
return conversations
|
||||||
|
|
||||||
|
@router.get("/{user_id}/messages", response_model=List[DMResponse])
|
||||||
|
async def get_dm_history(
|
||||||
|
user_id: int,
|
||||||
|
limit: int = 50,
|
||||||
|
before: Optional[int] = None,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""Get DM history with a user"""
|
||||||
|
|
||||||
|
query = db.query(DirectMessage).filter(
|
||||||
|
or_(
|
||||||
|
and_(DirectMessage.sender_id == current_user.id, DirectMessage.recipient_id == user_id),
|
||||||
|
and_(DirectMessage.sender_id == user_id, DirectMessage.recipient_id == current_user.id)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if before:
|
||||||
|
query = query.filter(DirectMessage.id < before)
|
||||||
|
|
||||||
|
messages = query.order_by(DirectMessage.id.desc()).limit(limit).all()
|
||||||
|
messages.reverse()
|
||||||
|
|
||||||
|
# Mark as read
|
||||||
|
db.query(DirectMessage)\
|
||||||
|
.filter(DirectMessage.sender_id == user_id)\
|
||||||
|
.filter(DirectMessage.recipient_id == current_user.id)\
|
||||||
|
.filter(DirectMessage.read_at == None)\
|
||||||
|
.update({"read_at": datetime.utcnow()})
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": msg.id,
|
||||||
|
"sender_id": msg.sender_id,
|
||||||
|
"recipient_id": msg.recipient_id,
|
||||||
|
"content": msg.content,
|
||||||
|
"is_ai_message": msg.is_ai_message,
|
||||||
|
"read_at": msg.read_at.isoformat() if msg.read_at else None,
|
||||||
|
"created_at": msg.created_at.isoformat()
|
||||||
|
}
|
||||||
|
for msg in messages
|
||||||
|
]
|
||||||
169
backend/api/files.py
Normal file
169
backend/api/files.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""
|
||||||
|
Files API - Upload and download files
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File as FastAPIFile
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
import aiofiles
|
||||||
|
import mimetypes
|
||||||
|
|
||||||
|
from core.database import get_db
|
||||||
|
from core.models import File, User
|
||||||
|
from api.auth import get_current_user
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# File storage configuration
|
||||||
|
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "./uploads")
|
||||||
|
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
class FileResponse(BaseModel):
|
||||||
|
id: int
|
||||||
|
filename: str
|
||||||
|
original_filename: str
|
||||||
|
file_size: int
|
||||||
|
mime_type: str
|
||||||
|
uploaded_by: int
|
||||||
|
channel_id: Optional[int]
|
||||||
|
created_at: str
|
||||||
|
download_url: str
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
@router.post("/upload", response_model=FileResponse)
|
||||||
|
async def upload_file(
|
||||||
|
file: UploadFile = FastAPIFile(...),
|
||||||
|
channel_id: Optional[int] = None,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""Upload a file"""
|
||||||
|
|
||||||
|
# Generate unique filename
|
||||||
|
file_ext = os.path.splitext(file.filename)[1]
|
||||||
|
unique_filename = f"{uuid.uuid4()}{file_ext}"
|
||||||
|
file_path = os.path.join(UPLOAD_DIR, unique_filename)
|
||||||
|
|
||||||
|
# Save file
|
||||||
|
try:
|
||||||
|
async with aiofiles.open(file_path, 'wb') as f:
|
||||||
|
content = await file.read()
|
||||||
|
await f.write(content)
|
||||||
|
|
||||||
|
file_size = len(content)
|
||||||
|
mime_type = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")
|
||||||
|
|
||||||
|
# Create file record
|
||||||
|
file_record = File(
|
||||||
|
filename=unique_filename,
|
||||||
|
original_filename=file.filename,
|
||||||
|
file_path=file_path,
|
||||||
|
file_size=file_size,
|
||||||
|
mime_type=mime_type,
|
||||||
|
uploaded_by=current_user.id,
|
||||||
|
channel_id=channel_id
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(file_record)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(file_record)
|
||||||
|
|
||||||
|
return {
|
||||||
|
**file_record.__dict__,
|
||||||
|
"created_at": file_record.created_at.isoformat(),
|
||||||
|
"download_url": f"/api/files/{file_record.id}/download"
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.get("/{file_id}/download")
|
||||||
|
async def download_file(
|
||||||
|
file_id: int,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""Download a file"""
|
||||||
|
|
||||||
|
# Get file record
|
||||||
|
file_record = db.query(File).filter(File.id == file_id).first()
|
||||||
|
if not file_record:
|
||||||
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
# Check file exists
|
||||||
|
if not os.path.exists(file_record.file_path):
|
||||||
|
raise HTTPException(status_code=404, detail="File not found on disk")
|
||||||
|
|
||||||
|
# Stream file
|
||||||
|
def iterfile():
|
||||||
|
with open(file_record.file_path, mode="rb") as f:
|
||||||
|
yield from f
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
iterfile(),
|
||||||
|
media_type=file_record.mime_type,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f"attachment; filename={file_record.original_filename}"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[FileResponse])
|
||||||
|
async def list_files(
|
||||||
|
channel_id: Optional[int] = None,
|
||||||
|
limit: int = 50,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""List files (optionally filtered by channel)"""
|
||||||
|
|
||||||
|
query = db.query(File)
|
||||||
|
|
||||||
|
if channel_id:
|
||||||
|
query = query.filter(File.channel_id == channel_id)
|
||||||
|
|
||||||
|
files = query.order_by(File.id.desc()).limit(limit).all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
**f.__dict__,
|
||||||
|
"created_at": f.created_at.isoformat(),
|
||||||
|
"download_url": f"/api/files/{f.id}/download"
|
||||||
|
}
|
||||||
|
for f in files
|
||||||
|
]
|
||||||
|
|
||||||
|
@router.delete("/{file_id}")
|
||||||
|
async def delete_file(
|
||||||
|
file_id: int,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""Delete a file"""
|
||||||
|
|
||||||
|
file_record = db.query(File).filter(File.id == file_id).first()
|
||||||
|
if not file_record:
|
||||||
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
# Check permissions (owner or admin)
|
||||||
|
if file_record.uploaded_by != current_user.id and current_user.role.value != "admin":
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized to delete this file")
|
||||||
|
|
||||||
|
# Delete from disk
|
||||||
|
try:
|
||||||
|
if os.path.exists(file_record.file_path):
|
||||||
|
os.remove(file_record.file_path)
|
||||||
|
except Exception as e:
|
||||||
|
pass # Continue even if disk delete fails
|
||||||
|
|
||||||
|
# Delete record
|
||||||
|
db.delete(file_record)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return {"message": "File deleted"}
|
||||||
@@ -14,7 +14,11 @@ from core.models import Message, Channel, User, ChannelType
|
|||||||
from core.ai_client import AIClient
|
from core.ai_client import AIClient
|
||||||
from core.context_manager import ContextManager
|
from core.context_manager import ContextManager
|
||||||
from api.auth import get_current_user
|
from api.auth import get_current_user
|
||||||
|
from core.websocket import broadcast_new_message
|
||||||
import main
|
import main
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -178,6 +182,25 @@ async def send_message(
|
|||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(message)
|
db.refresh(message)
|
||||||
|
|
||||||
|
# Broadcast via WebSocket
|
||||||
|
message_data = {
|
||||||
|
"id": message.id,
|
||||||
|
"content": message.content,
|
||||||
|
"is_ai_message": message.is_ai_message,
|
||||||
|
"user": {
|
||||||
|
"id": current_user.id,
|
||||||
|
"name": current_user.name,
|
||||||
|
"email": current_user.email,
|
||||||
|
"role": current_user.role.value,
|
||||||
|
"is_online": current_user.is_online
|
||||||
|
} if message.user else None,
|
||||||
|
"reply_to_message_id": message.reply_to_message_id,
|
||||||
|
"created_at": message.created_at.isoformat(),
|
||||||
|
"edited_at": message.edited_at.isoformat() if message.edited_at else None
|
||||||
|
}
|
||||||
|
|
||||||
|
await broadcast_new_message(channel_id, message_data)
|
||||||
|
|
||||||
# Check for @grimlock mention
|
# Check for @grimlock mention
|
||||||
if detect_grimlock_mention(message_data.content):
|
if detect_grimlock_mention(message_data.content):
|
||||||
# Handle in background to not block response
|
# Handle in background to not block response
|
||||||
|
|||||||
153
backend/core/websocket.py
Normal file
153
backend/core/websocket.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
"""
|
||||||
|
WebSocket Server - Real-time messaging and presence
|
||||||
|
"""
|
||||||
|
|
||||||
|
import socketio
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Set
|
||||||
|
from core.models import User
|
||||||
|
from core.auth import decode_token
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Create Socket.IO server
|
||||||
|
sio = socketio.AsyncServer(
|
||||||
|
async_mode='asgi',
|
||||||
|
cors_allowed_origins='*',
|
||||||
|
logger=True,
|
||||||
|
engineio_logger=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Track connected users: {user_id: set(session_ids)}
|
||||||
|
connected_users: Dict[int, Set[str]] = {}
|
||||||
|
|
||||||
|
# Track user channels: {session_id: user_id}
|
||||||
|
session_to_user: Dict[str, int] = {}
|
||||||
|
|
||||||
|
@sio.event
|
||||||
|
async def connect(sid, environ, auth):
|
||||||
|
"""Handle client connection"""
|
||||||
|
logger.info(f"Client connecting: {sid}")
|
||||||
|
|
||||||
|
# Authenticate via token
|
||||||
|
token = auth.get('token') if auth else None
|
||||||
|
if not token:
|
||||||
|
logger.warning(f"Connection rejected - no token: {sid}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Verify token
|
||||||
|
payload = decode_token(token)
|
||||||
|
if not payload:
|
||||||
|
logger.warning(f"Connection rejected - invalid token: {sid}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
user_id = int(payload.get('sub'))
|
||||||
|
|
||||||
|
# Track connection
|
||||||
|
if user_id not in connected_users:
|
||||||
|
connected_users[user_id] = set()
|
||||||
|
connected_users[user_id].add(sid)
|
||||||
|
session_to_user[sid] = user_id
|
||||||
|
|
||||||
|
logger.info(f"User {user_id} connected: {sid}")
|
||||||
|
|
||||||
|
# Notify others user is online
|
||||||
|
await sio.emit('user_online', {'user_id': user_id}, skip_sid=sid)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@sio.event
|
||||||
|
async def disconnect(sid):
|
||||||
|
"""Handle client disconnect"""
|
||||||
|
user_id = session_to_user.get(sid)
|
||||||
|
if not user_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Remove session
|
||||||
|
if user_id in connected_users:
|
||||||
|
connected_users[user_id].discard(sid)
|
||||||
|
if not connected_users[user_id]:
|
||||||
|
del connected_users[user_id]
|
||||||
|
# User fully offline
|
||||||
|
await sio.emit('user_offline', {'user_id': user_id})
|
||||||
|
|
||||||
|
del session_to_user[sid]
|
||||||
|
logger.info(f"User {user_id} disconnected: {sid}")
|
||||||
|
|
||||||
|
@sio.event
|
||||||
|
async def join_channel(sid, data):
|
||||||
|
"""Join a channel room"""
|
||||||
|
channel_id = data.get('channel_id')
|
||||||
|
if not channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
await sio.enter_room(sid, f"channel_{channel_id}")
|
||||||
|
logger.info(f"Session {sid} joined channel {channel_id}")
|
||||||
|
|
||||||
|
@sio.event
|
||||||
|
async def leave_channel(sid, data):
|
||||||
|
"""Leave a channel room"""
|
||||||
|
channel_id = data.get('channel_id')
|
||||||
|
if not channel_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
await sio.leave_room(sid, f"channel_{channel_id}")
|
||||||
|
logger.info(f"Session {sid} left channel {channel_id}")
|
||||||
|
|
||||||
|
@sio.event
|
||||||
|
async def typing_start(sid, data):
|
||||||
|
"""User started typing"""
|
||||||
|
channel_id = data.get('channel_id')
|
||||||
|
user_id = session_to_user.get(sid)
|
||||||
|
|
||||||
|
if channel_id and user_id:
|
||||||
|
await sio.emit(
|
||||||
|
'user_typing',
|
||||||
|
{'user_id': user_id, 'channel_id': channel_id},
|
||||||
|
room=f"channel_{channel_id}",
|
||||||
|
skip_sid=sid
|
||||||
|
)
|
||||||
|
|
||||||
|
@sio.event
|
||||||
|
async def typing_stop(sid, data):
|
||||||
|
"""User stopped typing"""
|
||||||
|
channel_id = data.get('channel_id')
|
||||||
|
user_id = session_to_user.get(sid)
|
||||||
|
|
||||||
|
if channel_id and user_id:
|
||||||
|
await sio.emit(
|
||||||
|
'user_stopped_typing',
|
||||||
|
{'user_id': user_id, 'channel_id': channel_id},
|
||||||
|
room=f"channel_{channel_id}",
|
||||||
|
skip_sid=sid
|
||||||
|
)
|
||||||
|
|
||||||
|
async def broadcast_new_message(channel_id: int, message_data: dict):
|
||||||
|
"""Broadcast new message to channel"""
|
||||||
|
await sio.emit(
|
||||||
|
'new_message',
|
||||||
|
message_data,
|
||||||
|
room=f"channel_{channel_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def broadcast_message_update(channel_id: int, message_data: dict):
|
||||||
|
"""Broadcast message update to channel"""
|
||||||
|
await sio.emit(
|
||||||
|
'message_updated',
|
||||||
|
message_data,
|
||||||
|
room=f"channel_{channel_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send_dm_notification(user_id: int, message_data: dict):
|
||||||
|
"""Send DM notification to user"""
|
||||||
|
if user_id in connected_users:
|
||||||
|
for sid in connected_users[user_id]:
|
||||||
|
await sio.emit('new_dm', message_data, room=sid)
|
||||||
|
|
||||||
|
def get_connected_users() -> list:
|
||||||
|
"""Get list of online user IDs"""
|
||||||
|
return list(connected_users.keys())
|
||||||
|
|
||||||
|
def is_user_online(user_id: int) -> bool:
|
||||||
|
"""Check if user is online"""
|
||||||
|
return user_id in connected_users
|
||||||
@@ -9,15 +9,19 @@ from contextlib import asynccontextmanager
|
|||||||
import logging
|
import logging
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import os
|
import os
|
||||||
|
import socketio
|
||||||
|
|
||||||
from api.chat import router as chat_router
|
from api.chat import router as chat_router
|
||||||
from api.auth import router as auth_router
|
from api.auth import router as auth_router
|
||||||
from api.channels import router as channels_router
|
from api.channels import router as channels_router
|
||||||
from api.messages import router as messages_router
|
from api.messages import router as messages_router
|
||||||
|
from api.direct_messages import router as dm_router
|
||||||
|
from api.files import router as files_router
|
||||||
from core.context_manager import ContextManager
|
from core.context_manager import ContextManager
|
||||||
from core.ai_client import AIClient
|
from core.ai_client import AIClient
|
||||||
from core.database import engine
|
from core.database import engine
|
||||||
from core.models import Base
|
from core.models import Base
|
||||||
|
from core.websocket import sio, broadcast_new_message
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
@@ -59,6 +63,8 @@ async def lifespan(app: FastAPI):
|
|||||||
ai_client = AIClient(api_key=api_key)
|
ai_client = AIClient(api_key=api_key)
|
||||||
logger.info("AI client initialized")
|
logger.info("AI client initialized")
|
||||||
|
|
||||||
|
logger.info("WebSocket server ready")
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
@@ -68,10 +74,13 @@ async def lifespan(app: FastAPI):
|
|||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Grimlock",
|
title="Grimlock",
|
||||||
description="AI-Native Company Operating System",
|
description="AI-Native Company Operating System",
|
||||||
version="0.2.0",
|
version="0.3.0",
|
||||||
lifespan=lifespan
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Mount Socket.IO
|
||||||
|
socket_app = socketio.ASGIApp(sio, app)
|
||||||
|
|
||||||
# CORS middleware
|
# CORS middleware
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
@@ -85,6 +94,8 @@ app.add_middleware(
|
|||||||
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
|
app.include_router(auth_router, prefix="/api/auth", tags=["auth"])
|
||||||
app.include_router(channels_router, prefix="/api/channels", tags=["channels"])
|
app.include_router(channels_router, prefix="/api/channels", tags=["channels"])
|
||||||
app.include_router(messages_router, prefix="/api/channels", tags=["messages"])
|
app.include_router(messages_router, prefix="/api/channels", tags=["messages"])
|
||||||
|
app.include_router(dm_router, prefix="/api/dms", tags=["direct-messages"])
|
||||||
|
app.include_router(files_router, prefix="/api/files", tags=["files"])
|
||||||
app.include_router(chat_router, prefix="/api/chat", tags=["chat"])
|
app.include_router(chat_router, prefix="/api/chat", tags=["chat"])
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
@@ -93,8 +104,8 @@ async def root():
|
|||||||
return {
|
return {
|
||||||
"status": "online",
|
"status": "online",
|
||||||
"service": "Grimlock",
|
"service": "Grimlock",
|
||||||
"version": "0.2.0",
|
"version": "0.3.0",
|
||||||
"features": ["auth", "channels", "messages", "ai"]
|
"features": ["auth", "channels", "messages", "dms", "files", "websocket", "ai"]
|
||||||
}
|
}
|
||||||
|
|
||||||
@app.get("/api/health")
|
@app.get("/api/health")
|
||||||
@@ -104,7 +115,8 @@ async def health():
|
|||||||
"status": "healthy",
|
"status": "healthy",
|
||||||
"context_loaded": context_manager is not None and context_manager.is_loaded(),
|
"context_loaded": context_manager is not None and context_manager.is_loaded(),
|
||||||
"ai_client_ready": ai_client is not None,
|
"ai_client_ready": ai_client is not None,
|
||||||
"database": "connected"
|
"database": "connected",
|
||||||
|
"websocket": "ready"
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_context_manager() -> ContextManager:
|
def get_context_manager() -> ContextManager:
|
||||||
|
|||||||
Reference in New Issue
Block a user