diff --git a/database.db b/database.db index e69de29..abd3ec3 100644 Binary files a/database.db and b/database.db differ diff --git a/fleetcontrol b/fleetcontrol index 6c508a3..a4fe5b1 100644 --- a/fleetcontrol +++ b/fleetcontrol @@ -1,5 +1,4 @@ #!/usr/bin/python3 - from network.communication import Server from utils.config.config import ServerConfig diff --git a/network/__pycache__/communication.cpython-310.pyc b/network/__pycache__/communication.cpython-310.pyc index e8db3a4..ce0e29d 100644 Binary files a/network/__pycache__/communication.cpython-310.pyc and b/network/__pycache__/communication.cpython-310.pyc differ diff --git a/network/communication.py b/network/communication.py index 787045f..a53e975 100644 --- a/network/communication.py +++ b/network/communication.py @@ -69,21 +69,26 @@ class Server(): temporary_salt = bcrypt.gensalt() hashed_request_password = bcrypt.hashpw( password=request_data["password"].encode("utf-8"), salt=temporary_salt) - if current_user.password_hash != hashed_request_password: + print( + f"Password: {request_data['password']}, password hash from database: {current_user.password_hash}") + # if current_user.password_hash != hashed_request_password: + if not bcrypt.checkpw(request_data["password"].encode("utf-8"), current_user.password_hash): return { "message": "Invalid login data", "data": None, "error": "Auth error" }, 401 try: - current_user["token"] = jwt.encode({ - "username": current_user["username"] + new_token = jwt.encode({ + "username": current_user.username }, self.flask_app.config["SECRET_KEY"], algorithm="HS256" ) - current_user.pop("password") - return current_user, 202 + user_dictionary = current_user.as_dict() + user_dictionary.pop("password_hash") + user_dictionary["token"] = new_token + return user_dictionary, 202 except Exception as ex: return { "message": "Error loging in", @@ -98,13 +103,11 @@ class Server(): }, 500 @require_auth - def register_new_client_to_database(self, request_user): + def register_new_client_to_database(request_user, self): request_content_type = request.headers.get('Content-Type') - json_string = "" if request_content_type == 'application/json': - json_string = request.json + json_object = request.json try: - json_object = json.loads(json_string) new_client_object = Client(mac_address=json_object["mac_address"], ip_address=json_object["ip_address"], hostname=json_object[ "hostname"], client_version=json_object["client_version"], vm_list_on_machine=json_object["vm_list_on_machine"]) self.database.add_client(new_client_object) @@ -112,7 +115,38 @@ class Server(): response.status_code = 201 return response except Exception as ex: - response = jsonify(success=False) + response = jsonify({ + "message": "Internal server error", + "data": None, + "error": str(ex) + }) + response.status_code = 400 + return response + + @require_auth + def add_image_to_database(request_user, self): + request_content_type = request.headers.get('Content-Type') + if request_content_type == 'application/json': + json_object = request.json + try: + new_image_object = VMImage( + image_name=json_object["image_name"], + image_file=json_object["image_file"], + image_version=json_object["image_version"], + image_hash=json_object["image_hash"], + image_name_version_combo=f"{json_object['image_name']}@{json_object['image_version']}", + clients=[] + ) + self.database.add_client(new_image_object) + response = jsonify(success=True) + response.status_code = 201 + return response + except Exception as ex: + response = jsonify({ + "message": "Internal server error", + "data": None, + "error": str(ex) + }) response.status_code = 400 return response @@ -128,7 +162,11 @@ class Server(): self.database.add_user(admin_user) self.app.add_endpoint(endpoint="/", endpoint_name="server_data", handler=self.basic_server_data, methods=["GET"]) + self.app.add_endpoint( + endpoint="/login", endpoint_name="login", handler=self.login, methods=["POST"]) self.app.add_endpoint(endpoint="/clients", endpoint_name="register_client", handler=self.register_new_client_to_database, methods=["POST"]) + self.app.add_endpoint(endpoint="/images", endpoint_name="add_image", + handler=self.add_image_to_database, methods=["POST"]) # TODO: add rest of endpoints self.app.run() diff --git a/utils/database/__pycache__/database.cpython-310.pyc b/utils/database/__pycache__/database.cpython-310.pyc index a861f46..fabbb5b 100644 Binary files a/utils/database/__pycache__/database.cpython-310.pyc and b/utils/database/__pycache__/database.cpython-310.pyc differ diff --git a/utils/database/database.py b/utils/database/database.py index e4bc5dc..0c196bb 100644 --- a/utils/database/database.py +++ b/utils/database/database.py @@ -1,7 +1,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base -from utils.models.models import Client, VMImage, User +from utils.models.models import Client, VMImage, User, Base from utils.exceptions.DatabaseException import DatabaseException import logging @@ -11,9 +11,11 @@ class Database: try: # Connect to the database using SQLAlchemy engine = create_engine(f"sqlite:///{database_file}") - Session = sessionmaker() - Session.configure(bind=engine) - self.session = Session() + self.Session = sessionmaker() + self.Session.configure(bind=engine, expire_on_commit=False) + self.base = Base + self.base.metadata.create_all(bind=engine) + # session = self.Session() # create logger using data from config file self.logger = logging.getLogger(__name__) log_level_mapping_dict = { @@ -32,8 +34,9 @@ class Database: def get_clients(self) -> list[Client]: result = [] try: - with self.session.begin(): - result = self.session.query(Client).all() + session = self.Session() + with session.begin(): + result = session.query(Client).all() except Exception as ex: self.logger.error( f"Error getting list of clients from database: {ex}") @@ -42,9 +45,10 @@ class Database: def get_client_by_mac_address(self, mac_address: str) -> Client: result = None + session = self.Session() try: - with self.session.begin(): - result = self.session.query( + with session.begin(): + result = session.query( Client, mac_address=mac_address).first() except Exception as ex: self.logger.warn(f"Error getting client by mac address: {ex}") @@ -53,8 +57,9 @@ class Database: def get_clients_by_client_version(self, client_version: str) -> list[Client]: result = [] try: - with self.session.begin(): - result = self.session.query( + session = self.Session() + with session.begin(): + result = session.query( Client, client_version=client_version).all() except Exception as ex: self.logger.warn( @@ -76,10 +81,11 @@ class Database: def add_client(self, client: Client): try: - with self.session.begin(): - self.session.add(client) - self.session.flush() - self.session.commit() + session = self.Session() + with session.begin(): + session.add(client) + session.flush() + session.commit() except Exception as ex: self.logger.error(f"Error adding entity to database: {ex}") raise DatabaseException("Error adding entity to database") @@ -87,11 +93,12 @@ class Database: def modify_client(self, client: Client) -> Client: try: old_object = self.get_client_by_mac_address(client.mac_address) - with self.session.begin(): + session = self.Session() + with session.begin(): old_object = client - self.session.merge(old_object) - self.session.flush() - self.session.commit() + session.merge(old_object) + session.flush() + session.commit() return old_object except Exception as ex: self.logger.error(f"Error modifying object in the database: {ex}") @@ -99,15 +106,17 @@ class Database: def delete_client(self, client: Client): try: - with self.session.begin(): - self.session.delete(client) + session = self.Session() + with session.begin(): + session.delete(client) except Exception as ex: self.logger.error(f"Error deleting client from database: {ex}") def get_image_by_id(self, image_id: int) -> VMImage: try: - with self.session.begin(): - response = self.session.query( + session = self.Session() + with session.begin(): + response = session.query( VMImage, image_id=image_id).first() return response except Exception as ex: @@ -115,8 +124,9 @@ class Database: def get_images(self) -> list[VMImage]: try: - with self.session.begin(): - response = self.session.query(VMImage).all() + session = self.Session() + with session.begin(): + response = session.query(VMImage).all() return response except Exception as ex: self.logger.error( @@ -124,8 +134,9 @@ class Database: def get_image_by_name(self, image_name: str) -> list[VMImage]: try: - with self.session.begin(): - response = self.session.query( + session = self.Session() + with session.begin(): + response = session.query( VMImage, image_name=image_name).all() return response except Exception as ex: @@ -134,8 +145,9 @@ class Database: def get_image_by_hash(self, image_hash: str) -> list[VMImage]: try: - with self.session.begin(): - response = self.session.query(VMImage, image_hash=image_hash) + session = self.Session() + with session.begin(): + response = session.query(VMImage, image_hash=image_hash) return response except Exception as ex: self.logger.error( @@ -143,10 +155,11 @@ class Database: def add_image(self, image: VMImage): try: - with self.session.begin(): - self.session.add(image) - self.session.flush() - self.session.commit() + session = self.Session() + with session.begin(): + session.add(image) + session.flush() + session.commit() except Exception as ex: self.logger.error(f"Couldn't save client data do database: {ex}") raise DatabaseException(f"Couldn't add image to database: {ex}") @@ -154,11 +167,12 @@ class Database: def modify_image(self, new_image_object: VMImage) -> VMImage: try: old_object = self.get_image_by_id(new_image_object.image_id) - with self.session.begin(): + session = self.Session() + with session.begin(): old_object = new_image_object - self.session.merge(old_object) - self.session.flush() - self.session.commit() + session.merge(old_object) + session.flush() + session.commit() return old_object except Exception as ex: self.logger.error(f"Couldn't modify object in database: {ex}") @@ -167,26 +181,33 @@ class Database: def add_user(self, new_user: User): try: - with self.session.begin(): - self.session.add(new_user) - self.session.flush() - self.session.merge() + session = self.Session() + with session.begin(): + session.add(new_user) + session.flush() + session.commit() except Exception as ex: self.logger.error(f"Couldn't add user to the database: {ex}") raise DatabaseException(f"Couldn't add user to the database: {ex}") def get_user_by_id(self, user_id: int) -> User: try: - with self.session.begin(): - return self.session.query(User).filter(User.user_id == user_id).first() + session = self.Session() + with session.begin(): + user = session.query(User).filter( + User.user_id == user_id).first() + return user except Exception as ex: self.logger.error(f"Error getting data from database: {ex}") raise DatabaseException(f"Error getting data from database: {ex}") def get_user_by_name(self, username: str) -> User: try: - with self.session.begin(): - return self.session.query(User).filter(User.username == username).first() + session = self.Session() + with session.begin(): + user = session.query(User).filter( + User.username == username).first() + return user except Exception as ex: self.logger.error(f"Error getting data from database: {ex}") raise DatabaseException(f"Error getting data from database: {ex}") diff --git a/utils/middleware/__pycache__/auth.cpython-310.pyc b/utils/middleware/__pycache__/auth.cpython-310.pyc index b3a9f89..97c9706 100644 Binary files a/utils/middleware/__pycache__/auth.cpython-310.pyc and b/utils/middleware/__pycache__/auth.cpython-310.pyc differ diff --git a/utils/middleware/auth.py b/utils/middleware/auth.py index 921ebaf..017a52e 100644 --- a/utils/middleware/auth.py +++ b/utils/middleware/auth.py @@ -4,6 +4,7 @@ from flask import request, abort from flask import current_app from utils.models.models import User from utils.database.database import Database +from utils.config.config import ServerConfig # Inspired by: https://blog.loginradius.com/engineering/guest-post/securing-flask-api-with-jwt/ [access: 16.11.2022, 18:33 CET] @@ -21,10 +22,11 @@ def require_auth(f): "error": "Unauthorized" }, 401 try: + config = ServerConfig() database = Database( - database_file=current_app.config["DATABASE_FILE"], logging_level=current_app.config["LOGGING_LEVEL"]) + database_file=config.database_file, logging_level=config.server_loglevel) user_data_from_request = jwt.decode( - token, current_app.config["SECRET_KEY"], algorithms=["HS256"]) + token, config.jwt_secret, algorithms=["HS256"]) request_user = database.get_user_by_name( username=user_data_from_request["username"]) if request_user is None: diff --git a/utils/models/__pycache__/models.cpython-310.pyc b/utils/models/__pycache__/models.cpython-310.pyc index 8a7fa42..495f88d 100644 Binary files a/utils/models/__pycache__/models.cpython-310.pyc and b/utils/models/__pycache__/models.cpython-310.pyc differ diff --git a/utils/models/models.py b/utils/models/models.py index 24060bd..b745440 100644 --- a/utils/models/models.py +++ b/utils/models/models.py @@ -1,13 +1,8 @@ from sqlalchemy import Column, Integer, String, ForeignKey, Table from sqlalchemy.orm import relationship, backref -from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base -from utils.config.config import ServerConfig -config = ServerConfig() -engine = create_engine(f"sqlite:///{config.database_file}") -Base = declarative_base(bind=engine) -Base.metadata.create_all() +Base = declarative_base() client_image_table = Table( "client_image", @@ -34,22 +29,32 @@ class Client(Base): return True return False + def as_dict(self): + return {c.name: str(getattr(self, c.name)) for c in self.__table__.columns} + class VMImage(Base): __tablename__ = "vm_images" image_id = Column(Integer, primary_key=True) - image_name = Column(String(100), unique=True, nullable=False) + image_name = Column(String(100), unique=False, nullable=False) image_file = Column(String(500), unique=False, nullable=False) image_version = Column(String(100), nullable=False) image_hash = Column(String(500), nullable=False) + image_name_version_combo = Column(String(600), nullable=False, unique=True) clients = relationship( "Client", secondary=client_image_table ) + def as_dict(self): + return {c.name: str(getattr(self, c.name)) for c in self.__table__.columns} + class User(Base): __tablename__ = "users" user_id = Column(Integer, primary_key=True, autoincrement=True) username = Column(String) password_hash = Column(String) + + def as_dict(self): + return {c.name: str(getattr(self, c.name)) for c in self.__table__.columns}