diff --git a/network/communication.py b/network/communication.py index f6ade2b..009c473 100644 --- a/network/communication.py +++ b/network/communication.py @@ -1,7 +1,9 @@ from http.client import NON_AUTHORITATIVE_INFORMATION -from flask import Flask +from flask import Flask, request, jsonify from utils.database.database import Database from utils.exceptions import DatabaseException +from utils.models.models import Client +import json class FlaskAppWrapper(object): @@ -40,11 +42,27 @@ class Server(): return {"server_name": self.name, "server_version": self.version, "host": self.host} def register_new_client_to_database(self): - # TODO: implement - self.database + request_content_type = request.headers.get('Content-Type') + json_string = "" + if request_content_type == 'application/json': + json_string = request.json + try: + json_object = json.loads(json_string) + new_client_object = Client(mac_address=json_object["mac_address"], ip_address=json_object["ip_address"], hostname=json_object[ + "hostname"], client_version=json_object["client_version"], vm_list_on_machine=json_object["vm_list_on_machine"]) + self.database.add_client(new_client_object) + response = jsonify(success=True) + response.status_code = 201 + return response + except Exception as ex: + response = jsonify(success=False) + response.status_code = 400 + return response def run(self): self.app.add_endpoint(endpoint="/", endpoint_name="server_data", handler=self.basic_server_data, methods=["GET"]) + self.app.add_endpoint(endpoint="/clients", endpoint_name="register_client", + handler=self.register_new_client_to_database, methods=["POST"]) # TODO: add rest of endpoints self.app.run() diff --git a/utils/models/models.py b/utils/models/models.py index 317651f..01feba7 100644 --- a/utils/models/models.py +++ b/utils/models/models.py @@ -11,33 +11,42 @@ client_image_table = Table( Column("image_id", Integer, ForeignKey("vm_images.image_id")) ) + class Client(Base): __tablename__ = "clients" mac_address = Column(String, primary_key=True) - ip_address = Column(String) - hostname = Column(String) - client_version = Column(String) + ip_address = Column(String(16), nullable=False) + hostname = Column(String(100), nullable=False) + client_version = Column(String(100), nullable=False) vm_list_on_machine = relationship( "VMImages", - secondary = client_image_table, - back_populates = "vm_images" + secondary=client_image_table, + back_populates="vm_images" ) - + def has_vm_installed(self, vm_object): for vm in self.vm_list_on_machine: if vm.image_hash == vm_object.image_hash: return True return False + class VMImage(Base): __tablename__ = "vm_images" image_id = Column(Integer, primary_key=True) - image_name = Column(String) - image_file = Column(String) - image_version = Column(String) - image_hash = Column(String) + image_name = Column(String(100), unique=True, nullable=False) + image_file = Column(String(500), unique=False, nullable=False) + image_version = Column(String(100), nullable=False) + image_hash = Column(String(500), nullalbe=False) clients = relationship( "Clients", - secondary = client_image_table, - back_populates = "clients" - ) \ No newline at end of file + secondary=client_image_table, + back_populates="clients" + ) + + +class User(Base): + __tablename__ = "users" + user_id = Column(Integer, primary_key=True) + username = Column(String) + password_hash = Column(String)