diff --git a/network/communication.py b/network/communication.py index a231c0d..83f07cb 100644 --- a/network/communication.py +++ b/network/communication.py @@ -1,8 +1,18 @@ +from http.client import NON_AUTHORITATIVE_INFORMATION from flask import Flask +from utils.config import ServerConfig +import sqlite3 +class Server(): -class Server(host, port, name, access_password, version): - app = Flask(self.name) + def __init__(self, host, port, name, access_password, version, database_file): + self.app = Flask(self.name) + self.host = host + self.port = port + self.name = name + self.access_password = access_password + self.version = version + self.client_database = sqlite3.connect(database_file) @app.route("/") def basic_server_data(self): @@ -10,5 +20,5 @@ class Server(host, port, name, access_password, version): @app.route("/client/register") def register_new_client_to_database(self): - # TODO: implement + pass diff --git a/utils/config/config.py b/utils/config/config.py index 6266634..0aff9b4 100644 --- a/utils/config/config.py +++ b/utils/config/config.py @@ -1,16 +1,32 @@ import os import yaml -class ServerConfig: - config_file = {} +class ServerConfig: + def __init__(self): + config_file = {} - with open("~/.config/orchestrator/config.yml", "r") as stream: - try: - config_file = yaml.safe_load(stream) - except Exception as e: - print(e) + with open("~/.config/orchestrator/config.yml", "r") as stream: + try: + config_file = yaml.safe_load(stream) + except Exception as e: + print(e) - config = config_file + config = config_file - server_name_override = os.environ.get("VALHALLA_SERVER_NAME") - server_port_override = os.environ.get("VALHALLA_SERVER_PORT") + server_name_override = os.environ.get("VALHALLA_SERVER_NAME") + server_port_override = os.environ.get("VALHALLA_SERVER_PORT") + database_file_override = os.environ.get("VALHALLA_DATABASE_FILE") + + if server_name_override: + config["server_name"] = server_name_override + + if server_port_override: + config["server_port"] = server_port_override + + if database_file_override: + config["database_file"] = database_file_override + + self.server_name = config["server_name"] + self.server_port = config["server_port"] + self.database_file = config["database_file"] + \ No newline at end of file diff --git a/utils/database/database.py b/utils/database/database.py index dac2592..6609391 100644 --- a/utils/database/database.py +++ b/utils/database/database.py @@ -1 +1,27 @@ -from simplesqlite import SimpleSQLite +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from utils.models import Client, VMImage +class Database: + def __init__(self, database_file: str): + try: + # Connect to the database using SQLAlchemy + engine = create_engine(f"sqlite:///{database_file}") + Session = sessionmaker() + Session.configure(bind=engine) + self.session = Session() + except Exception as ex: + print(ex) + exit(-1) + + def get_clients(self) -> list(Client): + result = None + try: + result = self.session.query(Client).all() + except Exception as ex: + print(f"Error getting list of clients from database: {ex}") + result = None + return result + + def get_client_by_mac_address(mac_address: str) -> Client: + result = None + \ No newline at end of file diff --git a/utils/exceptions/DatabaseException.py b/utils/exceptions/DatabaseException.py new file mode 100644 index 0000000..c613867 --- /dev/null +++ b/utils/exceptions/DatabaseException.py @@ -0,0 +1,2 @@ +class DatabaseException(Exception): + \ No newline at end of file diff --git a/utils/models/__init__.py b/utils/models/__init__.py new file mode 100644 index 0000000..20f2059 --- /dev/null +++ b/utils/models/__init__.py @@ -0,0 +1 @@ +from utils.models import Client, VMImage \ No newline at end of file diff --git a/utils/models/models.py b/utils/models/models.py new file mode 100644 index 0000000..317651f --- /dev/null +++ b/utils/models/models.py @@ -0,0 +1,43 @@ +from sqlalchemy import Column, Integer, String, ForeignKey, Table +from sqlalchemy.orm import relationship, backref +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() + +client_image_table = Table( + "client_image", + Base.metadata, + Column("client_mac", String, ForeignKey("clients.mac_address")), + Column("image_id", Integer, ForeignKey("vm_images.image_id")) +) + +class Client(Base): + __tablename__ = "clients" + mac_address = Column(String, primary_key=True) + ip_address = Column(String) + hostname = Column(String) + client_version = Column(String) + vm_list_on_machine = relationship( + "VMImages", + secondary = client_image_table, + back_populates = "vm_images" + ) + + def has_vm_installed(self, vm_object): + for vm in self.vm_list_on_machine: + if vm.image_hash == vm_object.image_hash: + return True + return False + +class VMImage(Base): + __tablename__ = "vm_images" + image_id = Column(Integer, primary_key=True) + image_name = Column(String) + image_file = Column(String) + image_version = Column(String) + image_hash = Column(String) + clients = relationship( + "Clients", + secondary = client_image_table, + back_populates = "clients" + ) \ No newline at end of file