159 lines
5.3 KiB
Python
159 lines
5.3 KiB
Python
from auth import *
|
|
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from functions import *
|
|
import random
|
|
from sql import *
|
|
from tables import *
|
|
|
|
# --- FastAPI application setup ---
app = FastAPI(lifespan=lifespan, redirect_slashes=False)

# Cross-origin policy: every origin, method and header is allowed, with credentials.
# NOTE(review): FastAPI/Starlette docs state that allow_origins/allow_methods/
# allow_headers should be listed explicitly (not ["*"]) when allow_credentials
# is True — confirm this wide-open policy is intentional before production.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
# --- routes ---

@app.get("/")
async def root():
    """Return a static greeting payload at the site root."""
    greeting = "Hello World"
    return {"message": greeting}
|
|
|
|
@app.post("/user/create", response_model=UserPublic)
def create_user(user: UserCreate, session: SessionDep):
    """Register a new user account.

    Rejects the request when the email or username is already taken;
    otherwise hashes the submitted password with a freshly generated salt,
    persists the user and returns the stored row.

    Raises:
        HTTPException: 400 if the email or the username is already registered.
    """
    import secrets  # CSPRNG: the `random` module is not meant for security-related values

    # The return value is validated against response_model=UserPublic, so an
    # error dict like {"detail": ...} would not produce a clean client error;
    # report conflicts as a proper HTTP error instead.
    if session.query(User).filter(User.email == user.email).first():
        raise HTTPException(status_code=400, detail="Email already registered")
    if session.query(User).filter(User.username == user.username).first():
        raise HTTPException(status_code=400, detail="Username already registered")

    # Same 0..99999 range as the previous random.randint(0, 99999), so the
    # stored salt keeps its type and width.
    salt = secrets.randbelow(100_000)
    user.password = hash_password(user.password, salt)
    user.salt = salt

    db_user = User.model_validate(user)
    session.add(db_user)
    session.commit()
    session.refresh(db_user)
    return db_user
|
|
|
|
# generated using Claude 3.5 Sonnet
|
|
def _next_task_id(tasks):
    """Return one more than the largest integer ``task_id`` in *tasks*, or 0 if there is none.

    Used instead of ``len(tasks)``, which can hand out an id still in use
    once a task has been deleted from the middle of the list.
    """
    int_ids = [t.get('task_id') for t in tasks if isinstance(t.get('task_id'), int)]
    return max(int_ids, default=-1) + 1


@app.patch("/user/update-tasks", response_model=UserPublic)
async def update_user_task(task: UpdateTask, session: SessionDep, current_user: User = Depends(get_current_user)):
    """Create or overwrite a single task in the authenticated user's task list.

    If the payload carries a ``task_id`` matching a stored task, that stored
    task is replaced wholesale by the payload (not merged field by field).
    Otherwise the payload is appended; when it has no ``task_id`` one is
    assigned via ``_next_task_id``.

    Returns:
        The refreshed ``User`` row (serialised through ``UserPublic``).
    """
    user_db = session.get(User, current_user.id)
    # exclude_unset: only the fields the client actually sent are stored.
    new_task = task.model_dump(exclude_unset=True)
    task_id = new_task.get('task_id')

    # Build a new list and reassign it below rather than mutating the stored
    # one in place — presumably so the ORM detects the column change; confirm.
    updated_tasks = list(user_db.tasks or [])

    # Only a payload that actually names a task can match an existing one;
    # an id-less payload always creates a new task.
    match_index = None
    if task_id is not None:
        match_index = next(
            (i for i, stored in enumerate(updated_tasks) if stored.get('task_id') == task_id),
            None,
        )

    if match_index is not None:
        updated_tasks[match_index] = new_task
    else:
        if task_id is None:
            new_task['task_id'] = _next_task_id(updated_tasks)
        updated_tasks.append(new_task)

    user_db.tasks = updated_tasks
    session.add(user_db)
    session.commit()
    session.refresh(user_db)
    return user_db
|
|
|
|
@app.delete("/user/delete-task", response_model=UserPublic)
async def delete_user_task(task: DeleteTask, session: SessionDep, current_user: User = Depends(get_current_user)):
    """Remove the task with the given ``task_id`` from the authenticated user's task list.

    An id that is not present leaves the list unchanged (the request still
    succeeds and returns the user).

    Raises:
        HTTPException: 422 if the payload carries no ``task_id``; 404 if the
            user has no task list at all.
    """
    user_db = session.get(User, current_user.id)
    # With exclude_unset the key may be absent; indexing it directly raised KeyError (500).
    task_id = task.model_dump(exclude_unset=True).get('task_id')
    if task_id is None:
        raise HTTPException(status_code=422, detail="task_id is required")

    if user_db.tasks is None:
        # Returning an error dict here would be validated against
        # response_model=UserPublic and fail; raise a real HTTP error instead.
        raise HTTPException(status_code=404, detail="No tasks to delete")

    # The comprehension yields a new list object, which is then reassigned.
    user_db.tasks = [t for t in user_db.tasks if t.get('task_id') != task_id]
    session.add(user_db)
    session.commit()
    session.refresh(user_db)
    return user_db
|
|
|
|
@app.patch("/user/update", response_model=UserPublic)
async def update_user(user: UserUpdate, session: SessionDep, current_user: User = Depends(get_current_user)):
    """Apply a partial update to the authenticated user's account.

    Only fields present in the request body are written. A new password is
    hashed with a freshly generated salt before it is stored. As in
    ``create_user``, a new email or username must not already belong to a
    different user.

    Raises:
        HTTPException: 400 if the requested email or username is already
            registered to another user.
    """
    import secrets  # CSPRNG: the `random` module is not meant for security-related values

    user_db = session.get(User, current_user.id)
    user_data = user.model_dump(exclude_unset=True)

    # .get(): these checks are skipped if UserUpdate has no such field or it wasn't sent.
    new_email = user_data.get('email')
    if new_email is not None and session.query(User).filter(
        User.email == new_email, User.id != user_db.id
    ).first():
        raise HTTPException(status_code=400, detail="Email already registered")
    new_username = user_data.get('username')
    if new_username is not None and session.query(User).filter(
        User.username == new_username, User.id != user_db.id
    ).first():
        raise HTTPException(status_code=400, detail="Username already registered")

    # Pop rather than read: an explicit "password": null must neither be passed
    # to hash_password (which presumably can't handle None) nor be copied onto
    # the row as a null password.
    new_password = user_data.pop('password', None)
    if new_password is not None:
        # Same 0..99999 range as the previous random.randint(0, 99999).
        salt = secrets.randbelow(100_000)
        user_data['password'] = hash_password(new_password, salt)
        user_data['salt'] = salt

    # NOTE(review): sqlmodel_update writes every remaining field from the
    # request — confirm UserUpdate does not expose fields like salt/id/tasks.
    user_db.sqlmodel_update(user_data)
    session.add(user_db)
    session.commit()
    session.refresh(user_db)
    return user_db
|
|
|
|
@app.post('/user/login')
async def verify_user(user: VerifyUser, session: SessionDep):
    """Check a username/password pair without issuing a token.

    Returns ``{"detail": "Login successful"}`` when the credentials match.

    Raises:
        HTTPException: 401 with one generic message whether the username is
            unknown or the password is wrong. Distinct messages would let a
            caller probe which usernames exist; this mirrors the ``/token``
            endpoint's failure response.
    """
    existing_user = session.query(User).filter(User.username == user.username).first()
    credentials_ok = existing_user is not None and validate_password(
        user.password, existing_user.password, existing_user.salt
    )
    if not credentials_ok:
        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return {"detail": "Login successful"}
|
|
|
|
|
|
@app.post('/token', response_model=Token)
async def login_for_access_token(session: SessionDep, form_data: OAuth2PasswordRequestForm = Depends()):
    """OAuth2 password-flow token endpoint.

    Authenticates the submitted form credentials and, on success, issues a
    bearer access token whose lifetime (in minutes) is read from the
    ``ACCESS_TOKEN_EXPIRES_MINUTES`` environment variable.

    Raises:
        HTTPException: 401 when ``authenticate_user`` rejects the credentials.
    """
    authenticated = authenticate_user(form_data.username, form_data.password, session)
    if not authenticated:
        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # NOTE(review): there is no default — int(None) raises TypeError when the
    # environment variable is unset.
    lifetime = timedelta(minutes=int(os.getenv("ACCESS_TOKEN_EXPIRES_MINUTES")))
    token = create_access_token(data={"sub": authenticated.username}, expires_delta=lifetime)

    # NOTE(review): timedelta.days counts whole days only, so any lifetime
    # shorter than 24h is reported as 0 — confirm clients expect that.
    return {
        "access_token": token,
        "token_type": "bearer",
        "expires_days": lifetime.days,
    }
|
|
|
|
@app.get("/users/me", response_model=UserPublic)
async def read_users_me(session: SessionDep, current_user: User = Depends(get_current_user)):
    """Return the user resolved by the ``get_current_user`` dependency.

    ``session`` is accepted but not used in the body.
    """
    me = current_user
    return me