diff --git a/config.yml b/config.yml index 7f1a02d..2e9073a 100644 --- a/config.yml +++ b/config.yml @@ -4,4 +4,5 @@ server_loglevel: "INFO" database_file: "database.db" server_host: "localhost" server_password: "sekret_password" -server_access_username: "user" \ No newline at end of file +server_access_username: "user" +jwt_secret: "sekrit" \ No newline at end of file diff --git a/database.db b/database.db new file mode 100644 index 0000000..e69de29 diff --git a/fleetcontrol b/fleetcontrol index 05192c6..6c508a3 100644 --- a/fleetcontrol +++ b/fleetcontrol @@ -5,7 +5,7 @@ from utils.config.config import ServerConfig config = ServerConfig() server = Server(host=config.server_host, port=config.server_port, name=config.server_name, access_password=config.server_password, - access_username=config.server_access_username, version="v0.0.1alpha", database_file_path=config.database_file, logging_level=config.server_loglevel) + access_username=config.server_access_username, jwt_secret=config.jwt_secret, version="v0.0.1alpha", database_file_path=config.database_file, logging_level=config.server_loglevel) server.run() diff --git a/network/__pycache__/communication.cpython-310.pyc b/network/__pycache__/communication.cpython-310.pyc index 8c65223..e8db3a4 100644 Binary files a/network/__pycache__/communication.cpython-310.pyc and b/network/__pycache__/communication.cpython-310.pyc differ diff --git a/network/communication.py b/network/communication.py index 009c473..787045f 100644 --- a/network/communication.py +++ b/network/communication.py @@ -2,8 +2,11 @@ from http.client import NON_AUTHORITATIVE_INFORMATION from flask import Flask, request, jsonify from utils.database.database import Database from utils.exceptions import DatabaseException -from utils.models.models import Client +from utils.models.models import Client, VMImage, User +from utils.middleware.auth import require_auth import json +import bcrypt +import jwt class FlaskAppWrapper(object): @@ -26,7 +29,7 @@ class FlaskAppWrapper(object): class Server(): - def __init__(self, host: str, port: int, name: str, access_password: str, access_username: str, version: str, database_file_path: str, logging_level: str, ): + def __init__(self, host: str, port: int, name: str, access_password: str, access_username: str, jwt_secret: str, version: str, database_file_path: str, logging_level: str, ): self.host = host self.port = port self.name = name @@ -36,12 +39,66 @@ class Server(): self.database = Database( database_file=database_file_path, logging_level=logging_level) self.flask_app = Flask(name) + self.flask_app.config['SECRET_KEY'] = jwt_secret + self.flask_app.config['DATABASE_FILE_PATH'] = database_file_path + self.flask_app.config['LOGGING_LEVEL'] = logging_level self.app = FlaskAppWrapper(self.flask_app) def basic_server_data(self): return {"server_name": self.name, "server_version": self.version, "host": self.host} - def register_new_client_to_database(self): + def login(self): + try: + request_data = request.json + if not request_data: + return { + "message": "Please provide login info", + "data": None, + "error": "Bad request" + }, 400 + + current_user = self.database.get_user_by_name( + request_data["username"]) + if current_user is None: + return { + "message": "Please provide correct login info", + "data": None, + "error": "Bad request" + }, 400 + correct_password = False + temporary_salt = bcrypt.gensalt() + hashed_request_password = bcrypt.hashpw( + password=request_data["password"].encode("utf-8"), salt=temporary_salt) + if current_user.password_hash != hashed_request_password: + return { + "message": "Invalid login data", + "data": None, + "error": "Auth error" + }, 401 + try: + current_user["token"] = jwt.encode({ + "username": current_user["username"] + }, + self.flask_app.config["SECRET_KEY"], + algorithm="HS256" + ) + current_user.pop("password") + return current_user, 202 + except Exception as ex: + return { + "message": "Error loging in", + "data": None, + "error": f"Internal server error: {str(ex)}", + }, 500 + except Exception as ex: + return { + "message": "Error loging in", + "data": None, + "error": f"Internal server error: {str(ex)}", + }, 500 + + @require_auth + def register_new_client_to_database(self, request_user): request_content_type = request.headers.get('Content-Type') json_string = "" if request_content_type == 'application/json': @@ -60,6 +117,15 @@ class Server(): return response def run(self): + # add admin user to dataabse (or update existing one) + salt = bcrypt.gensalt() + temp_password_hash = bcrypt.hashpw( + self.access_password.encode("utf-8"), salt) + admin_user = self.database.get_user_by_name(self.access_username) + if admin_user == None: + admin_user = User(username=self.access_username, + password_hash=temp_password_hash) + self.database.add_user(admin_user) self.app.add_endpoint(endpoint="/", endpoint_name="server_data", handler=self.basic_server_data, methods=["GET"]) self.app.add_endpoint(endpoint="/clients", endpoint_name="register_client", diff --git a/utils/config/__pycache__/config.cpython-310.pyc b/utils/config/__pycache__/config.cpython-310.pyc index 1d71752..a88902d 100644 Binary files a/utils/config/__pycache__/config.cpython-310.pyc and b/utils/config/__pycache__/config.cpython-310.pyc differ diff --git a/utils/config/config.py b/utils/config/config.py index 7a5d0e4..01f2844 100644 --- a/utils/config/config.py +++ b/utils/config/config.py @@ -18,6 +18,7 @@ class ServerConfig: server_port_override = os.environ.get("VALHALLA_SERVER_PORT") server_host_override = os.environ.get("VALHALLA_SERVER_HOST") server_password_override = os.environ.get("VALHALLA_SERVER_PASSWORD") + jwt_secret_override = os.environ.get("VALHALLA_JWT_SECRET") server_access_username_override = os.environ.get( "VALHALLA_SERVER_ACCESS_USERNAME") database_file_override = os.environ.get("VALHALLA_DATABASE_FILE") @@ -38,6 +39,9 @@ class ServerConfig: if server_password_override: config["server_password"] = server_password_override + if jwt_secret_override: + config["jwt_secret"] = jwt_secret_override + if server_access_username_override: config["server_access_username"] = server_access_username_override @@ -49,5 +53,6 @@ class ServerConfig: self.database_file = config["database_file"] self.server_host = config["server_host"] self.server_password = config["server_password"] + self.jwt_secret = config["jwt_secret"] self.server_access_username = config["server_access_username"] self.server_loglevel = config["server_loglevel"] diff --git a/utils/database/__pycache__/database.cpython-310.pyc b/utils/database/__pycache__/database.cpython-310.pyc index 772f4a4..a861f46 100644 Binary files a/utils/database/__pycache__/database.cpython-310.pyc and b/utils/database/__pycache__/database.cpython-310.pyc differ diff --git a/utils/database/database.py b/utils/database/database.py index 1753959..e4bc5dc 100644 --- a/utils/database/database.py +++ b/utils/database/database.py @@ -1,6 +1,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from utils.models.models import Client, VMImage +from sqlalchemy.ext.declarative import declarative_base +from utils.models.models import Client, VMImage, User from utils.exceptions.DatabaseException import DatabaseException import logging @@ -163,3 +164,29 @@ class Database: self.logger.error(f"Couldn't modify object in database: {ex}") raise DatabaseException( f"Couldn't modify object in database: {ex}") + + def add_user(self, new_user: User): + try: + with self.session.begin(): + self.session.add(new_user) + self.session.flush() + self.session.merge() + except Exception as ex: + self.logger.error(f"Couldn't add user to the database: {ex}") + raise DatabaseException(f"Couldn't add user to the database: {ex}") + + def get_user_by_id(self, user_id: int) -> User: + try: + with self.session.begin(): + return self.session.query(User).filter(User.user_id == user_id).first() + except Exception as ex: + self.logger.error(f"Error getting data from database: {ex}") + raise DatabaseException(f"Error getting data from database: {ex}") + + def get_user_by_name(self, username: str) -> User: + try: + with self.session.begin(): + return self.session.query(User).filter(User.username == username).first() + except Exception as ex: + self.logger.error(f"Error getting data from database: {ex}") + raise DatabaseException(f"Error getting data from database: {ex}") diff --git a/utils/middleware/__init__.py b/utils/middleware/__init__.py new file mode 100644 index 0000000..25d9ef7 --- /dev/null +++ b/utils/middleware/__init__.py @@ -0,0 +1 @@ +from . import auth diff --git a/utils/middleware/__pycache__/__init__.cpython-310.pyc b/utils/middleware/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..8746ad3 Binary files /dev/null and b/utils/middleware/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/middleware/__pycache__/auth.cpython-310.pyc b/utils/middleware/__pycache__/auth.cpython-310.pyc new file mode 100644 index 0000000..b3a9f89 Binary files /dev/null and b/utils/middleware/__pycache__/auth.cpython-310.pyc differ diff --git a/utils/middleware/auth.py b/utils/middleware/auth.py new file mode 100644 index 0000000..921ebaf --- /dev/null +++ b/utils/middleware/auth.py @@ -0,0 +1,45 @@ +from functools import wraps +import jwt +from flask import request, abort +from flask import current_app +from utils.models.models import User +from utils.database.database import Database + +# Inspired by: https://blog.loginradius.com/engineering/guest-post/securing-flask-api-with-jwt/ [access: 16.11.2022, 18:33 CET] + + +def require_auth(f): + @wraps(f) + def decorated(*args, **kwargs): + token = None + if "Authorization" in request.headers: + token = request.headers["Authorization"].split(" ")[1] + if not token: + return { + "message": "Missing auth token", + "data": None, + "error": "Unauthorized" + }, 401 + try: + database = Database( + database_file=current_app.config["DATABASE_FILE"], logging_level=current_app.config["LOGGING_LEVEL"]) + user_data_from_request = jwt.decode( + token, current_app.config["SECRET_KEY"], algorithms=["HS256"]) + request_user = database.get_user_by_name( + username=user_data_from_request["username"]) + if request_user is None: + return { + "message": "Invalid auth token", + "data": None, + "error": "Unauthorized" + }, 403 + except Exception as ex: + return { + "message": "Internal server error", + "data": None, + "error": str(ex) + }, 500 + + return f(request_user, *args, **kwargs) + + return decorated diff --git a/utils/models/__pycache__/models.cpython-310.pyc b/utils/models/__pycache__/models.cpython-310.pyc index 62ec91e..8a7fa42 100644 Binary files a/utils/models/__pycache__/models.cpython-310.pyc and b/utils/models/__pycache__/models.cpython-310.pyc differ diff --git a/utils/models/models.py b/utils/models/models.py index 01feba7..24060bd 100644 --- a/utils/models/models.py +++ b/utils/models/models.py @@ -1,8 +1,13 @@ from sqlalchemy import Column, Integer, String, ForeignKey, Table from sqlalchemy.orm import relationship, backref +from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base +from utils.config.config import ServerConfig -Base = declarative_base() +config = ServerConfig() +engine = create_engine(f"sqlite:///{config.database_file}") +Base = declarative_base(bind=engine) +Base.metadata.create_all() client_image_table = Table( "client_image", @@ -19,9 +24,8 @@ class Client(Base): hostname = Column(String(100), nullable=False) client_version = Column(String(100), nullable=False) vm_list_on_machine = relationship( - "VMImages", + "VMImage", secondary=client_image_table, - back_populates="vm_images" ) def has_vm_installed(self, vm_object): @@ -37,16 +41,15 @@ class VMImage(Base): image_name = Column(String(100), unique=True, nullable=False) image_file = Column(String(500), unique=False, nullable=False) image_version = Column(String(100), nullable=False) - image_hash = Column(String(500), nullalbe=False) + image_hash = Column(String(500), nullable=False) clients = relationship( - "Clients", - secondary=client_image_table, - back_populates="clients" + "Client", + secondary=client_image_table ) class User(Base): __tablename__ = "users" - user_id = Column(Integer, primary_key=True) + user_id = Column(Integer, primary_key=True, autoincrement=True) username = Column(String) password_hash = Column(String)