diff --git a/dev.py b/dev.py index aeadee1..d17c339 100644 --- a/dev.py +++ b/dev.py @@ -114,8 +114,11 @@ async def delete_movie(id_: str, db: Session = Depends(get_db)) -> None: @app.get("/movies/") -async def list_movie(db: Session = Depends(get_db)) -> list[schemas.MovieObject]: - return crud.get_all_movies(db) +async def list_movie(db: Session = Depends(get_db)) -> schemas.MovieObjectsOut: + movies = crud.get_all_movies(db) + count = len(movies) + + return {"movies": movies, "count": count} if __name__ == "__main__": diff --git a/schemas.py b/schemas.py index 7015bba..eb5da75 100644 --- a/schemas.py +++ b/schemas.py @@ -14,3 +14,8 @@ class MoviePayload(BaseModel): class MovieObject(MoviePayload): id: int | str + + +class MovieObjectsOut(BaseModel): + movies: list[MovieObject] + count: int diff --git a/utests/test_api.py b/utests/test_api.py index 5c3d659..cef7ddd 100644 --- a/utests/test_api.py +++ b/utests/test_api.py @@ -112,6 +112,7 @@ class BaseCrud(unittest.TestCase): def test_list_movies(self): response = client.get("/movies/") assert response.status_code == 200 + primary_count = response.json()["count"] # assert response.json() == [] N = 10 @@ -124,11 +125,38 @@ class BaseCrud(unittest.TestCase): response = client.post("/movies/", json=self.create_payload) assert response.status_code == 200 - movies = client.get("/movies/") - movies_by_title = {m["title"]: m for m in movies.json()} + response = client.get("/movies/").json() + + movies = response["movies"] + count = response["count"] + movies_by_title = {m["title"]: m for m in movies} found = list(movies_by_title[title] for title in names) assert all(movies_by_title[title] for title in names) + assert count == primary_count + N + + def test_list_movies_payload_format(self): + response = client.get("/movies/") + assert response.status_code == 200 + # assert response.json() == [] + primary_count = response.json()["count"] + + N = 10 + names = [] + for _ in range(N): + name = rand_name() + + names.append(name) + self.create_payload["title"] = name + response = client.post("/movies/", json=self.create_payload) + assert response.status_code == 200 + + movies = client.get("/movies/").json() + + assert isinstance(movies["count"], int) + assert isinstance(movies["movies"], list) + assert movies["count"] == primary_count + N + class ApiTestCase(unittest.TestCase): def test_payload_content_in_and_out_loopback(self): diff --git a/utests/test_sql_database.py b/utests/test_sql_database.py index 2fde7e7..a7cc834 100644 --- a/utests/test_sql_database.py +++ b/utests/test_sql_database.py @@ -117,8 +117,8 @@ def test_list_movies(): names.append(name) crud.create_movie(db, title=name, genres=["Animated", "Paropaganda"]) - movies = client.get("movies") - movies_by_title = {m["title"]: m for m in movies.json()} + movies = client.get("movies").json()["movies"] + movies_by_title = {m["title"]: m for m in movies} assert all(movies_by_title[name] for name in names) @@ -128,16 +128,16 @@ def test_sample_import_toy_story(): movie_title = "Toy Story" file_path = "input_data/movies_metadata_short.csv" - movies = client.get("movies") - movies_by_title = {m["title"]: m for m in movies.json()} + movies = client.get("movies").json()["movies"] + movies_by_title = {m["title"]: m for m in movies} assert movie_title not in movies_by_title, "The movie should not be pre existing" with db_context() as db: fill_db(db, file_path) - movies = client.get("movies") - movies_by_title = {m["title"]: m for m in movies.json()} + movies = client.get("movies").json()["movies"] + movies_by_title = {m["title"]: m for m in movies} toy_story = movies_by_title["Toy Story"] @@ -156,12 +156,12 @@ def test_title_is_taken_form_original_title_is_missing(): file_path = "utests/movie_error_missing_title.csv" file_path = "input_data/movies_metadata.csv" - movies = client.get("movies") - movies_by_title = {m["title"]: m for m in movies.json()} + movies = client.get("movies").json()["movies"] + movies_by_title = {m["title"]: m for m in movies} assert movie_title not in movies_by_title, "The movie should not be pre existing" with db_context() as db: fill_db(db, file_path, sample_rate=1) - movies = client.get("movies") - movies_by_title = {m["title"]: m for m in movies.json()} + movies = client.get("movies").json()["movies"] + movies_by_title = {m["title"]: m for m in movies}