diff --git a/database.py b/database.py index f3a39ee..d93f256 100644 --- a/database.py +++ b/database.py @@ -40,7 +40,6 @@ def adapt_movie_data(data_in: dict): def fill_db( db=SessionLocal(), movie_input_file: str = "input_data/movies_metadata.csv" ): - create_db() import crud import csv @@ -56,4 +55,5 @@ def fill_db( if __name__ == "__main__": + create_db() fill_db() diff --git a/utests/test_sql_database.py b/utests/test_sql_database.py index 0e3e3b7..6a1f474 100644 --- a/utests/test_sql_database.py +++ b/utests/test_sql_database.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool from sqlalchemy import MetaData from sqlalchemy import exc -from database import Base +from database import Base, fill_db from dev import app, get_db from models import Movie @@ -120,3 +120,27 @@ def test_list_movies(): movies = client.get("movies") movies_by_title = {m["title"]: m for m in movies.json()} assert all(movies_by_title[name] for name in names) + + +def test_sample_import_toy_story(): + clear_db() + + movie_title = "Toy Story" + file_path = "input_data/movies_metadata_short.csv" + + movies = client.get("movies") + movies_by_title = {m["title"]: m for m in movies.json()} + + assert movie_title not in movies_by_title, "The movie should not be pre existing" + + with db_context() as db: + fill_db(db, file_path) + + movies = client.get("movies") + movies_by_title = {m["title"]: m for m in movies.json()} + + toy_story = movies_by_title["Toy Story"] + + assert "Andy" in toy_story["description"] + + # non regression