diff --git a/database.db b/database.db index 22f568f..5fcb2b7 100644 Binary files a/database.db and b/database.db differ diff --git a/fleetcontrol b/fleetcontrol index a3ebb26..f5c4a8f 100644 --- a/fleetcontrol +++ b/fleetcontrol @@ -58,6 +58,14 @@ def add_image(image_name: str, image_file: str, image_version: str): logger.error(f"Error adding image to the database: {str(ex)}") exit(-1) +def remove_image(image_name: str, image_version: str): + try: + db = Database(config.database_file, config.server_loglevel) + obj_to_remove = db.get_image_by_name_version_string(f"{image_name}@{image_version}") + db.delete_image(obj_to_remove) + except Exception as ex: + logger.error(f"Error removing image from the database: {str(ex)}") + exit(-1) def assign_image(image_name: str, image_version: str, client_mac_address: str): try: @@ -65,6 +73,14 @@ def assign_image(image_name: str, image_version: str, client_mac_address: str): db.assign_image_to_client(client_mac_address=client_mac_address, image_name_version_combo=f"{image_name}@{image_version}") except Exception as ex: logger.error(f"Error assigning image to a client: {str(ex)}") + exit(-1) + +def detach_image(image_name: str, image_version: str, client_mac_address: str): + try: + db = Database(config.database_file, config.server_loglevel) + db.detach_image_from_client(client_mac_address=client_mac_address, image_name_version_combo=f"{image_name}@{image_version}") + except Exception as ex: + logger.error(f"Error detaching image from the client {client_mac_address}; error was {str(ex)}") parser = argparse.ArgumentParser( @@ -75,7 +91,9 @@ parser = argparse.ArgumentParser( function_mapper = { "run": run_server, "add_image": add_image, + "remove_image": remove_image, "assign_image": assign_image, + "detach_image": detach_image } parser.add_argument("command", choices=function_mapper) @@ -95,12 +113,23 @@ if "add_image" == args.command: image_file=args.image_filepath, image_version=args.image_version, ) +elif "remove_image" == args.command: + fun( + image_name=args.image_name, + image_version=args.image_version + ) elif "assign_image" == args.command: fun( image_name=args.image_name, image_version=args.image_version, client_mac_address=args.mac_address, ) +elif "detach_image" == args.command: + fun( + image_name=args.image_name, + image_version=args.image_version, + client_mac_address=args.mac_address, + ) elif "run" == args.command: fun() else: diff --git a/network/__pycache__/communication.cpython-310.pyc b/network/__pycache__/communication.cpython-310.pyc index fab1572..4f01d07 100644 Binary files a/network/__pycache__/communication.cpython-310.pyc and b/network/__pycache__/communication.cpython-310.pyc differ diff --git a/network/communication.py b/network/communication.py index eee77fb..494fe25 100644 --- a/network/communication.py +++ b/network/communication.py @@ -143,19 +143,43 @@ class Server(): return response except Exception as ex: response = jsonify({ - "message": "Internal server error", + "message": "Bad input", "data": None, "error": str(ex) }) response.status_code = 400 return response + @require_auth def update_client_data(request_user, self): request_content_type = request.headers.get('Content-Type') if request_content_type == 'application/json': json_object = request.json try: - pass + old_client: Client = self.database.get_client_by_mac_address(json_object["mac_address"]) + if old_client == None: + response = jsonify({ + "message": "client not found", + "data": None, + "error": None + }) + response.status_code = 404 + return response + new_client: Client = Client( + mac_address=json_object["mac_address"], + ip_address=json_object["ip_address"], + hostname=json_object["hostname"], + client_version=json_object["client_version"], + vm_list_on_machine=[] + ) + self.database.modify_client(new_client) + response = jsonify({ + "message": "Data updated", + "data": None, + "error": None + }) + response.status_code = 201 + return response except Exception as ex: response = jsonify({ "message": "Internal server error", @@ -164,6 +188,35 @@ class Server(): }) response.status_code = 400 return response + + @require_auth + def get_client_data(request_user, self, client_mac_address): + try: + client_data = self.database.get_client_by_mac_address(client_mac_address) + return jsonify(client_data.as_dict()) + except Exception as ex: + response = jsonify({ + "message": "Internal server error", + "data": None, + "error": str(ex) + }) + response.status_code = 500 + return response + + @require_auth + def get_client_list_of_vms(request_user, self, client_mac_address): + try: + vm_ids_list = self.database.get_client_vm_list_by_mac_address(client_mac_address) + return jsonify(vm_ids_list) + except Exception as ex: + response = jsonify({ + "message": "Internal server error", + "data": None, + "error": str(ex) + }) + response.status_code = 500 + return response + def run(self): # add admin user to dataabse (or update existing one) @@ -183,5 +236,8 @@ class Server(): handler=self.register_new_client_to_database, methods=["POST"]) self.app.add_endpoint(endpoint="/images", endpoint_name="add_image", handler=self.add_image_to_database, methods=["POST"]) + self.app.add_endpoint(endpoint="/clients", endpoint_name="update_client", handler=self.update_client_data, methods=["PUT"]) + self.app.add_endpoint(endpoint="/clients/", endpoint_name="get_client_data", handler=self.get_client_data, methods=["GET"]) + self.app.add_endpoint(endpoint="/clients//vms", endpoint_name="get_client_vms_list", handler=self.get_client_list_of_vms, methods=["GET"]) # TODO: add rest of endpoints self.app.run() diff --git a/utils/database/__pycache__/database.cpython-310.pyc b/utils/database/__pycache__/database.cpython-310.pyc index b437982..d36a1f2 100644 Binary files a/utils/database/__pycache__/database.cpython-310.pyc and b/utils/database/__pycache__/database.cpython-310.pyc differ diff --git a/utils/database/database.py b/utils/database/database.py index 42de3ce..da85caa 100644 --- a/utils/database/database.py +++ b/utils/database/database.py @@ -24,7 +24,7 @@ class Database: "INFO": 20, "WARNING": 30, "ERROR": 40, - "CRITICAL": 50 + "CRITICAL": 50, } self.logger.setLevel(log_level_mapping_dict[logging_level]) except Exception as ex: @@ -38,18 +38,39 @@ class Database: with session.begin(): result = session.query(Client).all() except Exception as ex: - self.logger.error( - f"Error getting list of clients from database: {ex}") + self.logger.error(f"Error getting list of clients from database: {ex}") result = [] return result def get_client_by_mac_address(self, mac_address: str) -> Client: result = None session = self.Session() + session.expire_on_commit = False + try: + with session.begin(): + result = ( + session.query(Client) + .filter(Client.mac_address == mac_address) + .first() + ) + except Exception as ex: + self.logger.warn(f"Error getting client by mac address: {ex}") + return result + + def get_client_vm_list_by_mac_address(self, mac_address: str): + result = None + session = self.Session() + session.expire_on_commit = False try: with session.begin(): - result = session.query( - Client).filter(Client.mac_address==mac_address).first() + client = ( + session.query(Client) + .filter(Client.mac_address == mac_address) + .first() + ) + result = [] + for vm in client.vm_list_on_machine: + result.append(vm.image_id) except Exception as ex: self.logger.warn(f"Error getting client by mac address: {ex}") return result @@ -59,11 +80,9 @@ class Database: try: session = self.Session() with session.begin(): - result = session.query( - Client, client_version=client_version).all() + result = session.query(Client, client_version=client_version).all() except Exception as ex: - self.logger.warn( - f"Error getting client list by software version: {ex}") + self.logger.warn(f"Error getting client list by software version: {ex}") return result def get_clients_by_vm_image(self, vm_image: VMImage) -> list[Client]: @@ -74,8 +93,7 @@ class Database: if client.has_vm_installed(vm_image.image_hash): result.append() except Exception as ex: - self.logger.warn( - f"Error getting list of clients with VM installed: {ex}") + self.logger.warn(f"Error getting list of clients with VM installed: {ex}") result = [] return result @@ -92,10 +110,12 @@ class Database: def modify_client(self, client: Client) -> Client: try: - old_object = self.get_client_by_mac_address(client.mac_address) session = self.Session() with session.begin(): - old_object = client + old_object: Client = session.query(Client).filter(Client.mac_address==client.mac_address).first() + old_object.ip_address = client.ip_address + old_object.hostname = client.hostname + old_object.client_version = client.client_version session.merge(old_object) session.flush() session.commit() @@ -116,8 +136,9 @@ class Database: try: session = self.Session() with session.begin(): - response = session.query( - VMImage, image_id=image_id).first() + response = ( + session.query(VMImage).filter(VMImage.image_id == image_id).first() + ) return response except Exception as ex: self.logger.error(f"Error getting image data from database: {ex}") @@ -129,40 +150,50 @@ class Database: response = session.query(VMImage).all() return response except Exception as ex: - self.logger.error( - f"Error getting list of images from database: {ex}") + self.logger.error(f"Error getting list of images from database: {ex}") def get_image_by_name(self, image_name: str) -> list[VMImage]: try: session = self.Session() with session.begin(): - response = session.query( - VMImage, image_name=image_name).all() + response = ( + session.query(VMImage) + .filter(VMImage.image_name == image_name) + .all() + ) return response except Exception as ex: - self.logger.error( - f"Error getting list of images from database: {ex}") - - def get_image_by_name_version_string(self, image_name_version_string: str) -> list[VMImage]: + self.logger.error(f"Error getting list of images from database: {ex}") + + def get_image_by_name_version_string( + self, image_name_version_string: str + ) -> list[VMImage]: try: session = self.Session() with session.begin(): - response = session.query( - VMImage).filter(VMImage.image_name_version_combo==image_name_version_string).first() + response = ( + session.query(VMImage) + .filter( + VMImage.image_name_version_combo == image_name_version_string + ) + .first() + ) return response except Exception as ex: - self.logger.error( - f"Error getting list of images from database: {ex}") + self.logger.error(f"Error getting list of images from database: {ex}") def get_image_by_hash(self, image_hash: str) -> list[VMImage]: try: session = self.Session() with session.begin(): - response = session.query(VMImage, image_hash=image_hash) + response = ( + session.query(VMImage) + .filter(VMImage.image_hash == image_hash) + .first() + ) return response except Exception as ex: - self.logger.error( - f"Error getting list of images with specified hash: {ex}") + self.logger.error(f"Error getting list of images with specified hash: {ex}") def add_image(self, image: VMImage): try: @@ -187,15 +218,41 @@ class Database: return old_object except Exception as ex: self.logger.error(f"Couldn't modify object in database: {ex}") + raise DatabaseException(f"Couldn't modify object in database: {ex}") + + def delete_image(self, image_to_delete: VMImage): + try: + session = self.Session() + with session.begin(): + session.delete(image_to_delete) + session.flush() + session.commit() + except Exception as ex: + self.logger.error( + f"Error deleting image with id={image_to_delete.image_id}: {str(ex)}" + ) raise DatabaseException( - f"Couldn't modify object in database: {ex}") + f"Error deleting image with id={image_to_delete.image_id}: {str(ex)}" + ) - def assign_image_to_client(self, client_mac_address: str, image_name_version_combo: str): + def assign_image_to_client( + self, client_mac_address: str, image_name_version_combo: str + ): try: session = self.Session() with session.begin(): - client = session.query(Client).filter(Client.mac_address==client_mac_address).first() - image = session.query(VMImage).filter(VMImage.image_name_version_combo==image_name_version_combo).first() + client = ( + session.query(Client) + .filter(Client.mac_address == client_mac_address) + .first() + ) + image = ( + session.query(VMImage) + .filter( + VMImage.image_name_version_combo == image_name_version_combo + ) + .first() + ) client.vm_list_on_machine.append(image) session.merge(client) session.flush() @@ -204,6 +261,34 @@ class Database: self.logger.error(f"Couldn't add image to client list: {str(ex)}") raise DatabaseException(f"Couldn't add image to client list: {str(ex)}") + def detach_image_from_client( + self, client_mac_address: str, image_name_version_combo: str + ): + try: + session = self.Session() + with session.begin(): + client = ( + session.query(Client) + .filter(Client.mac_address == client_mac_address) + .first() + ) + image = ( + session.query(VMImage) + .filter( + VMImage.image_name_version_combo == image_name_version_combo + ) + .first() + ) + client.vm_list_on_machine.remove(image) + session.merge(client) + session.flush() + session.commit() + except Exception as ex: + self.logger.error(f"Couldn't remove image from client list: {str(ex)}") + raise DatabaseException( + f"Couldn't remove image from client list: {str(ex)}" + ) + def add_user(self, new_user: User): try: session = self.Session() @@ -219,8 +304,7 @@ class Database: try: session = self.Session() with session.begin(): - user = session.query(User).filter( - User.user_id == user_id).first() + user = session.query(User).filter(User.user_id == user_id).first() return user except Exception as ex: self.logger.error(f"Error getting data from database: {ex}") @@ -230,8 +314,7 @@ class Database: try: session = self.Session() with session.begin(): - user = session.query(User).filter( - User.username == username).first() + user = session.query(User).filter(User.username == username).first() return user except Exception as ex: self.logger.error(f"Error getting data from database: {ex}")