NXPR-3 Database support

This commit is contained in:
Wojciech Janota 2022-11-18 18:24:30 +01:00
parent 2ccc2e82a3
commit 915fb12f01
No known key found for this signature in database
GPG Key ID: E83FBD2850CC1F14
10 changed files with 129 additions and 64 deletions

Binary file not shown.

View File

@ -1,5 +1,4 @@
#!/usr/bin/python3 #!/usr/bin/python3
from network.communication import Server from network.communication import Server
from utils.config.config import ServerConfig from utils.config.config import ServerConfig

View File

@ -69,21 +69,26 @@ class Server():
temporary_salt = bcrypt.gensalt() temporary_salt = bcrypt.gensalt()
hashed_request_password = bcrypt.hashpw( hashed_request_password = bcrypt.hashpw(
password=request_data["password"].encode("utf-8"), salt=temporary_salt) password=request_data["password"].encode("utf-8"), salt=temporary_salt)
if current_user.password_hash != hashed_request_password: print(
f"Password: {request_data['password']}, password hash from database: {current_user.password_hash}")
# if current_user.password_hash != hashed_request_password:
if not bcrypt.checkpw(request_data["password"].encode("utf-8"), current_user.password_hash):
return { return {
"message": "Invalid login data", "message": "Invalid login data",
"data": None, "data": None,
"error": "Auth error" "error": "Auth error"
}, 401 }, 401
try: try:
current_user["token"] = jwt.encode({ new_token = jwt.encode({
"username": current_user["username"] "username": current_user.username
}, },
self.flask_app.config["SECRET_KEY"], self.flask_app.config["SECRET_KEY"],
algorithm="HS256" algorithm="HS256"
) )
current_user.pop("password") user_dictionary = current_user.as_dict()
return current_user, 202 user_dictionary.pop("password_hash")
user_dictionary["token"] = new_token
return user_dictionary, 202
except Exception as ex: except Exception as ex:
return { return {
"message": "Error loging in", "message": "Error loging in",
@ -98,13 +103,11 @@ class Server():
}, 500 }, 500
@require_auth @require_auth
def register_new_client_to_database(self, request_user): def register_new_client_to_database(request_user, self):
request_content_type = request.headers.get('Content-Type') request_content_type = request.headers.get('Content-Type')
json_string = ""
if request_content_type == 'application/json': if request_content_type == 'application/json':
json_string = request.json json_object = request.json
try: try:
json_object = json.loads(json_string)
new_client_object = Client(mac_address=json_object["mac_address"], ip_address=json_object["ip_address"], hostname=json_object[ new_client_object = Client(mac_address=json_object["mac_address"], ip_address=json_object["ip_address"], hostname=json_object[
"hostname"], client_version=json_object["client_version"], vm_list_on_machine=json_object["vm_list_on_machine"]) "hostname"], client_version=json_object["client_version"], vm_list_on_machine=json_object["vm_list_on_machine"])
self.database.add_client(new_client_object) self.database.add_client(new_client_object)
@ -112,7 +115,38 @@ class Server():
response.status_code = 201 response.status_code = 201
return response return response
except Exception as ex: except Exception as ex:
response = jsonify(success=False) response = jsonify({
"message": "Internal server error",
"data": None,
"error": str(ex)
})
response.status_code = 400
return response
@require_auth
def add_image_to_database(request_user, self):
request_content_type = request.headers.get('Content-Type')
if request_content_type == 'application/json':
json_object = request.json
try:
new_image_object = VMImage(
image_name=json_object["image_name"],
image_file=json_object["image_file"],
image_version=json_object["image_version"],
image_hash=json_object["image_hash"],
image_name_version_combo=f"{json_object['image_name']}@{json_object['image_version']}",
clients=[]
)
self.database.add_client(new_image_object)
response = jsonify(success=True)
response.status_code = 201
return response
except Exception as ex:
response = jsonify({
"message": "Internal server error",
"data": None,
"error": str(ex)
})
response.status_code = 400 response.status_code = 400
return response return response
@ -128,7 +162,11 @@ class Server():
self.database.add_user(admin_user) self.database.add_user(admin_user)
self.app.add_endpoint(endpoint="/", endpoint_name="server_data", self.app.add_endpoint(endpoint="/", endpoint_name="server_data",
handler=self.basic_server_data, methods=["GET"]) handler=self.basic_server_data, methods=["GET"])
self.app.add_endpoint(
endpoint="/login", endpoint_name="login", handler=self.login, methods=["POST"])
self.app.add_endpoint(endpoint="/clients", endpoint_name="register_client", self.app.add_endpoint(endpoint="/clients", endpoint_name="register_client",
handler=self.register_new_client_to_database, methods=["POST"]) handler=self.register_new_client_to_database, methods=["POST"])
self.app.add_endpoint(endpoint="/images", endpoint_name="add_image",
handler=self.add_image_to_database, methods=["POST"])
# TODO: add rest of endpoints # TODO: add rest of endpoints
self.app.run() self.app.run()

View File

@ -1,7 +1,7 @@
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from utils.models.models import Client, VMImage, User from utils.models.models import Client, VMImage, User, Base
from utils.exceptions.DatabaseException import DatabaseException from utils.exceptions.DatabaseException import DatabaseException
import logging import logging
@ -11,9 +11,11 @@ class Database:
try: try:
# Connect to the database using SQLAlchemy # Connect to the database using SQLAlchemy
engine = create_engine(f"sqlite:///{database_file}") engine = create_engine(f"sqlite:///{database_file}")
Session = sessionmaker() self.Session = sessionmaker()
Session.configure(bind=engine) self.Session.configure(bind=engine, expire_on_commit=False)
self.session = Session() self.base = Base
self.base.metadata.create_all(bind=engine)
# session = self.Session()
# create logger using data from config file # create logger using data from config file
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
log_level_mapping_dict = { log_level_mapping_dict = {
@ -32,8 +34,9 @@ class Database:
def get_clients(self) -> list[Client]: def get_clients(self) -> list[Client]:
result = [] result = []
try: try:
with self.session.begin(): session = self.Session()
result = self.session.query(Client).all() with session.begin():
result = session.query(Client).all()
except Exception as ex: except Exception as ex:
self.logger.error( self.logger.error(
f"Error getting list of clients from database: {ex}") f"Error getting list of clients from database: {ex}")
@ -42,9 +45,10 @@ class Database:
def get_client_by_mac_address(self, mac_address: str) -> Client: def get_client_by_mac_address(self, mac_address: str) -> Client:
result = None result = None
session = self.Session()
try: try:
with self.session.begin(): with session.begin():
result = self.session.query( result = session.query(
Client, mac_address=mac_address).first() Client, mac_address=mac_address).first()
except Exception as ex: except Exception as ex:
self.logger.warn(f"Error getting client by mac address: {ex}") self.logger.warn(f"Error getting client by mac address: {ex}")
@ -53,8 +57,9 @@ class Database:
def get_clients_by_client_version(self, client_version: str) -> list[Client]: def get_clients_by_client_version(self, client_version: str) -> list[Client]:
result = [] result = []
try: try:
with self.session.begin(): session = self.Session()
result = self.session.query( with session.begin():
result = session.query(
Client, client_version=client_version).all() Client, client_version=client_version).all()
except Exception as ex: except Exception as ex:
self.logger.warn( self.logger.warn(
@ -76,10 +81,11 @@ class Database:
def add_client(self, client: Client): def add_client(self, client: Client):
try: try:
with self.session.begin(): session = self.Session()
self.session.add(client) with session.begin():
self.session.flush() session.add(client)
self.session.commit() session.flush()
session.commit()
except Exception as ex: except Exception as ex:
self.logger.error(f"Error adding entity to database: {ex}") self.logger.error(f"Error adding entity to database: {ex}")
raise DatabaseException("Error adding entity to database") raise DatabaseException("Error adding entity to database")
@ -87,11 +93,12 @@ class Database:
def modify_client(self, client: Client) -> Client: def modify_client(self, client: Client) -> Client:
try: try:
old_object = self.get_client_by_mac_address(client.mac_address) old_object = self.get_client_by_mac_address(client.mac_address)
with self.session.begin(): session = self.Session()
with session.begin():
old_object = client old_object = client
self.session.merge(old_object) session.merge(old_object)
self.session.flush() session.flush()
self.session.commit() session.commit()
return old_object return old_object
except Exception as ex: except Exception as ex:
self.logger.error(f"Error modifying object in the database: {ex}") self.logger.error(f"Error modifying object in the database: {ex}")
@ -99,15 +106,17 @@ class Database:
def delete_client(self, client: Client): def delete_client(self, client: Client):
try: try:
with self.session.begin(): session = self.Session()
self.session.delete(client) with session.begin():
session.delete(client)
except Exception as ex: except Exception as ex:
self.logger.error(f"Error deleting client from database: {ex}") self.logger.error(f"Error deleting client from database: {ex}")
def get_image_by_id(self, image_id: int) -> VMImage: def get_image_by_id(self, image_id: int) -> VMImage:
try: try:
with self.session.begin(): session = self.Session()
response = self.session.query( with session.begin():
response = session.query(
VMImage, image_id=image_id).first() VMImage, image_id=image_id).first()
return response return response
except Exception as ex: except Exception as ex:
@ -115,8 +124,9 @@ class Database:
def get_images(self) -> list[VMImage]: def get_images(self) -> list[VMImage]:
try: try:
with self.session.begin(): session = self.Session()
response = self.session.query(VMImage).all() with session.begin():
response = session.query(VMImage).all()
return response return response
except Exception as ex: except Exception as ex:
self.logger.error( self.logger.error(
@ -124,8 +134,9 @@ class Database:
def get_image_by_name(self, image_name: str) -> list[VMImage]: def get_image_by_name(self, image_name: str) -> list[VMImage]:
try: try:
with self.session.begin(): session = self.Session()
response = self.session.query( with session.begin():
response = session.query(
VMImage, image_name=image_name).all() VMImage, image_name=image_name).all()
return response return response
except Exception as ex: except Exception as ex:
@ -134,8 +145,9 @@ class Database:
def get_image_by_hash(self, image_hash: str) -> list[VMImage]: def get_image_by_hash(self, image_hash: str) -> list[VMImage]:
try: try:
with self.session.begin(): session = self.Session()
response = self.session.query(VMImage, image_hash=image_hash) with session.begin():
response = session.query(VMImage, image_hash=image_hash)
return response return response
except Exception as ex: except Exception as ex:
self.logger.error( self.logger.error(
@ -143,10 +155,11 @@ class Database:
def add_image(self, image: VMImage): def add_image(self, image: VMImage):
try: try:
with self.session.begin(): session = self.Session()
self.session.add(image) with session.begin():
self.session.flush() session.add(image)
self.session.commit() session.flush()
session.commit()
except Exception as ex: except Exception as ex:
self.logger.error(f"Couldn't save client data do database: {ex}") self.logger.error(f"Couldn't save client data do database: {ex}")
raise DatabaseException(f"Couldn't add image to database: {ex}") raise DatabaseException(f"Couldn't add image to database: {ex}")
@ -154,11 +167,12 @@ class Database:
def modify_image(self, new_image_object: VMImage) -> VMImage: def modify_image(self, new_image_object: VMImage) -> VMImage:
try: try:
old_object = self.get_image_by_id(new_image_object.image_id) old_object = self.get_image_by_id(new_image_object.image_id)
with self.session.begin(): session = self.Session()
with session.begin():
old_object = new_image_object old_object = new_image_object
self.session.merge(old_object) session.merge(old_object)
self.session.flush() session.flush()
self.session.commit() session.commit()
return old_object return old_object
except Exception as ex: except Exception as ex:
self.logger.error(f"Couldn't modify object in database: {ex}") self.logger.error(f"Couldn't modify object in database: {ex}")
@ -167,26 +181,33 @@ class Database:
def add_user(self, new_user: User): def add_user(self, new_user: User):
try: try:
with self.session.begin(): session = self.Session()
self.session.add(new_user) with session.begin():
self.session.flush() session.add(new_user)
self.session.merge() session.flush()
session.commit()
except Exception as ex: except Exception as ex:
self.logger.error(f"Couldn't add user to the database: {ex}") self.logger.error(f"Couldn't add user to the database: {ex}")
raise DatabaseException(f"Couldn't add user to the database: {ex}") raise DatabaseException(f"Couldn't add user to the database: {ex}")
def get_user_by_id(self, user_id: int) -> User: def get_user_by_id(self, user_id: int) -> User:
try: try:
with self.session.begin(): session = self.Session()
return self.session.query(User).filter(User.user_id == user_id).first() with session.begin():
user = session.query(User).filter(
User.user_id == user_id).first()
return user
except Exception as ex: except Exception as ex:
self.logger.error(f"Error getting data from database: {ex}") self.logger.error(f"Error getting data from database: {ex}")
raise DatabaseException(f"Error getting data from database: {ex}") raise DatabaseException(f"Error getting data from database: {ex}")
def get_user_by_name(self, username: str) -> User: def get_user_by_name(self, username: str) -> User:
try: try:
with self.session.begin(): session = self.Session()
return self.session.query(User).filter(User.username == username).first() with session.begin():
user = session.query(User).filter(
User.username == username).first()
return user
except Exception as ex: except Exception as ex:
self.logger.error(f"Error getting data from database: {ex}") self.logger.error(f"Error getting data from database: {ex}")
raise DatabaseException(f"Error getting data from database: {ex}") raise DatabaseException(f"Error getting data from database: {ex}")

View File

@ -4,6 +4,7 @@ from flask import request, abort
from flask import current_app from flask import current_app
from utils.models.models import User from utils.models.models import User
from utils.database.database import Database from utils.database.database import Database
from utils.config.config import ServerConfig
# Inspired by: https://blog.loginradius.com/engineering/guest-post/securing-flask-api-with-jwt/ [access: 16.11.2022, 18:33 CET] # Inspired by: https://blog.loginradius.com/engineering/guest-post/securing-flask-api-with-jwt/ [access: 16.11.2022, 18:33 CET]
@ -21,10 +22,11 @@ def require_auth(f):
"error": "Unauthorized" "error": "Unauthorized"
}, 401 }, 401
try: try:
config = ServerConfig()
database = Database( database = Database(
database_file=current_app.config["DATABASE_FILE"], logging_level=current_app.config["LOGGING_LEVEL"]) database_file=config.database_file, logging_level=config.server_loglevel)
user_data_from_request = jwt.decode( user_data_from_request = jwt.decode(
token, current_app.config["SECRET_KEY"], algorithms=["HS256"]) token, config.jwt_secret, algorithms=["HS256"])
request_user = database.get_user_by_name( request_user = database.get_user_by_name(
username=user_data_from_request["username"]) username=user_data_from_request["username"])
if request_user is None: if request_user is None:

View File

@ -1,13 +1,8 @@
from sqlalchemy import Column, Integer, String, ForeignKey, Table from sqlalchemy import Column, Integer, String, ForeignKey, Table
from sqlalchemy.orm import relationship, backref from sqlalchemy.orm import relationship, backref
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from utils.config.config import ServerConfig
config = ServerConfig() Base = declarative_base()
engine = create_engine(f"sqlite:///{config.database_file}")
Base = declarative_base(bind=engine)
Base.metadata.create_all()
client_image_table = Table( client_image_table = Table(
"client_image", "client_image",
@ -34,22 +29,32 @@ class Client(Base):
return True return True
return False return False
def as_dict(self):
return {c.name: str(getattr(self, c.name)) for c in self.__table__.columns}
class VMImage(Base): class VMImage(Base):
__tablename__ = "vm_images" __tablename__ = "vm_images"
image_id = Column(Integer, primary_key=True) image_id = Column(Integer, primary_key=True)
image_name = Column(String(100), unique=True, nullable=False) image_name = Column(String(100), unique=False, nullable=False)
image_file = Column(String(500), unique=False, nullable=False) image_file = Column(String(500), unique=False, nullable=False)
image_version = Column(String(100), nullable=False) image_version = Column(String(100), nullable=False)
image_hash = Column(String(500), nullable=False) image_hash = Column(String(500), nullable=False)
image_name_version_combo = Column(String(600), nullable=False, unique=True)
clients = relationship( clients = relationship(
"Client", "Client",
secondary=client_image_table secondary=client_image_table
) )
def as_dict(self):
return {c.name: str(getattr(self, c.name)) for c in self.__table__.columns}
class User(Base): class User(Base):
__tablename__ = "users" __tablename__ = "users"
user_id = Column(Integer, primary_key=True, autoincrement=True) user_id = Column(Integer, primary_key=True, autoincrement=True)
username = Column(String) username = Column(String)
password_hash = Column(String) password_hash = Column(String)
def as_dict(self):
return {c.name: str(getattr(self, c.name)) for c in self.__table__.columns}