diff --git a/database.py b/database.py index a2e0870..552017d 100644 --- a/database.py +++ b/database.py @@ -19,6 +19,25 @@ def create_db(): models.Base.metadata.create_all(bind=engine) +def _json_quotes(s): + return s.replace("'", '"') + + +def adapt_movie_data(data_in: dict): + import copy, json + + data_out = copy.deepcopy(data_in) + + # adapt genresd to stringlist + breakpoint() + + data_out["genres"] = [ + genre["name"] for genre in json.loads(_json_quotes(data_in["genres"])) + ] + + return data_out + + def fill_db( db=SessionLocal(), movie_input_file: str = "input_data/movies_metadata_short.csv" ): @@ -28,7 +47,8 @@ def fill_db( with open(movie_input_file) as csvfile: for movie_data in csv.DictReader(csvfile): - crud.create_movie(db, **movie_data) + adapted_data = adapt_movie_data(movie_data) + crud.create_movie(db, **adapted_data) if __name__ == "__main__":