1
0
mirror of https://github.com/MarkParker5/STARK.git synced 2024-11-24 08:12:13 +02:00

hub endpoint: async db

update requirments.txt
This commit is contained in:
MarkParker5 2022-06-20 17:20:25 +02:00
parent 030317c99f
commit c5014f0d5d
No known key found for this signature in database
GPG Key ID: C87FA4BD47B5A169
24 changed files with 449 additions and 189 deletions

View File

@ -0,0 +1,67 @@
import time
from abc import ABC, abstractmethod
from uuid import UUID
from enum import Enum
from pydantic import BaseModel
from jose import jwt, JWTError
import config
import exceptions
class TokenType(str, Enum):
access = 'access'
refresh = 'refresh'
class TokensPair(BaseModel):
access_token: str
refresh_token: str
class HubAuthItem(TokensPair):
public_key: str
class BaseToken(BaseModel):
type: TokenType
exp: float
class BaseAuthManager(ABC):
@abstractmethod
def get_tokens_pair(self, user_id: UUID, is_admin: bool) -> (str, str):
pass
@abstractmethod
def _get_token(self, payload: dict) -> BaseToken:
pass
def validate_refresh(self, token: str) -> BaseToken:
if (token := self._validate_token(token)) and self._is_refresh_valid(token):
return token
raise exceptions.invalid_token
def validate_access(self, token: str) -> BaseToken:
if (token := self._validate_token(token)) and self._is_access_valid(token):
return token
raise exceptions.invalid_token
def _validate_token(self, token: str) -> BaseToken:
try:
payload = jwt.decode(token, config.public_key, algorithms=[config.algorithm])
token = self._get_token(payload)
except JWTError as e:
raise exceptions.invalid_token
if not token or not self._is_valid_token(token):
raise exceptions.invalid_token
return token
def _is_access_valid(self, token: BaseToken) -> bool:
return token.type == TokenType.access
def _is_refresh_valid(self, token: BaseToken) -> bool:
return token.type == TokenType.refresh
def _is_valid_token(self, token: BaseToken) -> bool:
return time.time() <= token.exp

View File

@ -0,0 +1,33 @@
import time
from uuid import UUID
from datetime import timedelta
from jose import jwt
import config
from .BaseAuth import BaseAuthManager, BaseToken, TokenType
class UserToken(BaseToken):
user_id: UUID
is_admin: bool
class UserAuthManager(BaseAuthManager):
def get_tokens_pair(self, user_id: UUID, is_admin: bool) -> (str, str):
access_token = self._get_token_str(TokenType.access, user_id, is_admin, config.access_token_lifetime)
refresh_token = self._get_token_str(TokenType.refresh, user_id, is_admin, config.refresh_token_lifetime)
return access_token, refresh_token
def _get_token(self, payload: dict) -> UserToken:
return UserToken(**payload)
def _is_valid_token(self, token: UserToken) -> bool:
return token.user_id != None
def _get_token_str(self, type: TokenType, user_id: UUID, is_admin: bool, lifetime: timedelta) -> str:
payload = {
'type': type.value,
'user_id': user_id.hex,
'is_admin': is_admin,
'exp': time.time() + lifetime.seconds
}
return jwt.encode(payload, config.secret_key, algorithm = config.algorithm)

View File

@ -0,0 +1,9 @@
from . import passwords
from .BaseAuth import (
TokenType,
TokensPair
)
from .UserAuth import (
UserToken,
UserAuthManager
)

View File

@ -11,22 +11,34 @@ __all__ = [
'create_session',
'create_async_session',
# dependencies
# 'get_session',
# 'get_async_session',
'get_session',
'get_async_session',
]
engine = create_engine(config.db_url)
create_session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# sync
engine = create_engine(
config.db_url, connect_args={'check_same_thread': False}
)
create_session = sessionmaker(
autocommit=False, autoflush=False, bind=engine
)
def get_session() -> Session:
with create_session() as session:
yield session
# async_engine = create_async_engine(config.db_url)
# create_async_session = sessionmaker(
# async_engine, class_ = AsyncSession, expire_on_commit = False
# )
#
# async def get_async_session() -> AsyncSession:
# async with create_async_session() as session:
# yield session
# async
async_engine = create_async_engine(
config.db_async_url, connect_args={'check_same_thread': False}
)
create_async_session = sessionmaker(
async_engine, class_ = AsyncSession, expire_on_commit = False
)
async def get_async_session() -> AsyncSession:
async with create_async_session() as session:
yield session

View File

@ -2,48 +2,49 @@ from __future__ import annotations
from uuid import UUID
from fastapi import Depends
from sqlalchemy import delete
from sqlalchemy import select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import NoResultFound
from SmartHome.API.models import Hub
from SmartHome.API.dependencies import database
from . import schemas
from Raspberry import WiFi
import config
from API import exceptions
from API.models import Hub
from API.dependencies import database
from . import schemas
class HubManager:
def __init__(self, session = Depends(database.get_session)):
session: AsyncSession
def __init__(self, session = Depends(database.get_async_session)):
self.session = session
def __del__(self):
self.session.close()
async def init(self, create_hub: schemas.HubInit) -> Hub:
db: AsyncSession = self.session
@classmethod
def default(cls, session: database.Session | None = None) -> HubManager:
session = session or database.create_session()
return cls(session)
def get(self) -> Hub:
db = self.session
return db.query(Hub).one()
def create(self, create_hub: schemas.Hub) -> Hub:
db = self.session
if hub := self.get():
db.delete(hub)
if hub := await self.get():
await db.delete(hub)
hub = Hub(id = create_hub.id, name = create_hub.name, house_id = create_hub.house_id)
db.add(hub)
db.commit()
db.refresh(hub)
await db.commit()
await db.refresh(hub)
return hub
def patch(self, id: UUID, hub: schemas.PatchHub):
db = self.session
values = {key: value for key, value in hub.dict().items() if key != 'id' and value != None}
db.execute(Hub.__table__.update().values(values).filter_by(id = id))
db.commit()
async def get(self) -> Hub | None:
db: AsyncSession = self.session
result = await db.scalars(select(Hub))
hub = result.first()
return hub
async def patch(self, hub: schemas.PatchHub):
db: AsyncSession = self.session
values = {key: value for key, value in hub.dict().items() if value != None}
await db.execute(update(Hub).values(**values))
await db.commit()
def wifi(self, ssid: str, password: str):
WiFi.save_and_connect(ssid, password)
@ -51,8 +52,13 @@ class HubManager:
def get_hotspots(self) -> list[schemas.Hotspot]:
return [schemas.Hotspot(**cell.__dict__) for cell in WiFi.get_list()]
def set_tokens(tokens_pair: schemas.TokensPair):
with open(f'{path}/{resources}/access_token.txt', 'w') as f:
def save_tokens(self, tokens_pair: schemas.HubAuthItems):
with open(f'{config.src}/access_token.txt', 'w') as f:
f.write(tokens_pair.access_token)
with open(f'{path}/{resources}/refresh_token.txt', 'w') as f:
with open(f'{config.src}/refresh_token.txt', 'w') as f:
f.write(tokens_pair.refresh_token)
def save_credentials(self, credentials: schemas.HubAuthItems):
self.save_tokens(credentials)
with open(f'{config.src}/public_key.txt', 'w') as f:
f.write(credentials.public_key)

View File

@ -1,8 +1,7 @@
from uuid import UUID
from fastapi import APIRouter, Depends
import SmartHome.API.exceptions
from API import exceptions
from .HubManager import HubManager
from .schemas import Hub, PatchHub, TokensPair, Hotspot
from .schemas import HubInit, Hub, HubPatch, TokensPair, Hotspot
router = APIRouter(
@ -10,26 +9,30 @@ router = APIRouter(
tags = ['hub'],
)
@router.get('', response_model = Hub)
async def hub_get(manager: HubManager = Depends()):
return manager.get()
@router.post('', response_model = Hub)
async def hub_create(hub: Hub, manager: HubManager = Depends()):
return manager.create(hub)
async def init_hub(hub: HubInit, manager: HubManager = Depends()):
return await manager.init(hub)
@router.get('', response_model = Hub)
async def get_hub(manager: HubManager = Depends()):
hub = await manager.get()
if hub:
return hub
else:
raise exceptions.not_found
@router.patch('')
async def hub_patch(hub: PatchHub, manager: HubManager = Depends()):
manager.patch(hub)
async def patch_hub(hub: HubPatch, manager: HubManager = Depends()):
await manager.patch(hub)
@router.post('/connect')
async def hub_connect(ssid: str, password: str, manager: HubManager = Depends()):
async def connect_to_hub(ssid: str, password: str, manager: HubManager = Depends()):
manager.wifi(ssid, password)
@router.get('/hotspots')
async def hub_hotspots(manager: HubManager = Depends()):
@router.get('/hotspots', response_model = list[Hotspot])
async def get_hub_hotspots(manager: HubManager = Depends()):
return manager.get_hotspots()
@router.post('/set_tokens')
async def set_tokens(tokens: TokensPair):
async def set_tokens(tokens: TokensPair, manager: HubManager = Depends()):
manager.save_tokens(tokens)

View File

@ -2,8 +2,12 @@ from uuid import UUID
from pydantic import BaseModel
class PatchHub(BaseModel):
name: str
class TokensPair(BaseModel):
access_token: str
refresh_token: str
class HubAuthItems(TokensPair):
public_key: str
class Hub(BaseModel):
id: UUID
@ -13,9 +17,11 @@ class Hub(BaseModel):
class Config:
orm_mode = True
class TokensPair(BaseModel):
access_token: str
refresh_token: str
class HubPatch(BaseModel):
name: str
class HubInit(Hub, HubAuthItems):
...
class Hotspot(BaseModel):
ssid: str

View File

@ -1,6 +1,8 @@
from SmartHome.Merlin import Merlin, MerlinMessage
from SmartHome.API.dependencies import database
from SmartHome.API.models import Device
from sqlalchemy import select
from Merlin import Merlin, MerlinMessage
from API.dependencies import database
from API.models import Device, DeviceModelParameter
from . import schemas
@ -8,23 +10,18 @@ class WSManager:
merlin = Merlin()
def merlin_send(self, data: schemas.MerlinData):
print(data)
return
# raise Exception('Not implemented')
db = database.get_session()
try:
device = db.get(Device, data.device_id)
if not device:
model_parameter = db.execute(
select(DeviceModelParameter)
.where(
DeviceModelParameter.devicemodel_id == device.model.id,
DeviceModelParameter.parameter_id == data.parameter_id
)
).scalar_one()
except: # TODO: Specify exceptions
return
parameter = next([p for p in device.parameters if p.id == data.parameter_id], None)
if not parameter:
return
# TODO: f = #device.parameters.index(parameter)
x = data.value
self.merlin.send(MerlinMessage(device.urdi, f, x))
def merlin_send_raw(self, data: schemas.MerlinRaw):
self.merlin.send(MerlinMessage(data.urdi, data.f, data.x))
self.merlin.send(MerlinMessage(device.urdi, model_parameter.f, data.value))

View File

@ -1,5 +1,5 @@
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from .schemas import SocketType, SocketData, MerlinData, MerlinRaw
from .schemas import SocketType, SocketData, MerlinData
from .WSManager import WSManager
@ -22,14 +22,11 @@ class ConnectionManager:
except: return
match socket.type:
case SocketType.merlin:
try: merlin_data = MerlinData(**socket.data)
except: return
try:
merlin_data = MerlinData(**socket.data)
except: # TODO: specify exception
return
self.wsmanager.merlin_send(merlin_data)
case SocketType.merlin_raw:
try: merlin_raw_data = MerlinRaw(**socket.data)
except: return
self.wsmanager.merlin_send_raw(merlin_raw_data)
# await websocket.send_text(message)
connection = ConnectionManager()
router = APIRouter(

View File

@ -5,7 +5,6 @@ from pydantic import BaseModel
class SocketType(str, Enum):
merlin = 'merlin'
merlin_raw = 'merlin_raw'
class SocketData(BaseModel):
type: SocketType
@ -15,8 +14,3 @@ class MerlinData(BaseModel):
device_id: str
parameter_id: str
value: str
class MerlinRaw(BaseModel):
urdi: int
f: int
x: int

View File

@ -5,7 +5,7 @@ from sqlalchemy.orm import relationship
from .Base import Base
class DeviceParameterAssociation(Base):
class DeviceParameterAssociation(Base): # TODO: remove
id = Column(UUIDType, index = True, primary_key = True, default = uuid1)
device_id = Column(ForeignKey('devices.id'), primary_key = True)
parameter_id = Column(ForeignKey('parameters.id'), primary_key = True)

View File

@ -1,19 +1,26 @@
from uuid import uuid1
from sqlalchemy import Table, Column, String, Integer, ForeignKey
from sqlalchemy import Table, Column, String, Integer, ForeignKey, UniqueConstraint
from sqlalchemy_utils import UUIDType
from sqlalchemy.orm import relationship
from .Base import Base
_devicemodelparameters = Table('devicemodelparameters', Base.metadata,
Column('device_model_id', ForeignKey('devicemodels.id')),
Column('parameter_id', ForeignKey('parameters.id')),
)
class DeviceModelParameter(Base):
id = Column(UUIDType, index = True, primary_key = True, default = uuid1) # TODO: remove
devicemodel_id = Column(ForeignKey('devicemodels.id'), primary_key = True)
parameter_id = Column(ForeignKey('parameters.id'), primary_key = True)
f = Column(Integer, nullable = False, default = 0)
parameter = relationship('Parameter', lazy='joined')
devicemodel = relationship('DeviceModel')
__table_args__ = (UniqueConstraint('devicemodel_id', 'parameter_id'),)
def __str__(self):
return self.parameter.name or super().__str__()
class DeviceModel(Base):
id = Column(UUIDType, index = True, primary_key = True, default = uuid1)
name = Column(String)
parameters = relationship('Parameter', secondary = _devicemodelparameters)
parameters = relationship('DeviceModelParameter', back_populates = 'devicemodel', cascade = 'all, delete-orphan')
def __str__(self):
return self.name or super().__str__()

View File

@ -5,11 +5,12 @@ from sqlalchemy_utils import UUIDType
from .Base import Base
class Parameter(Base):
class Parameter(Base): # TODO: remove
id = Column(UUIDType, index = True, primary_key = True, default = uuid1)
name = Column(String, nullable = False)
value_type = Column(String, nullable = False)
device_parameters = relationship('DeviceParameterAssociation', back_populates = 'parameter', cascade = 'all, delete-orphan')
devicemodel_parameters = relationship('DeviceModelParameter', back_populates = 'parameter', cascade = 'all, delete-orphan')
def __str__(self):
return self.name or super().__str__()

View File

@ -3,5 +3,5 @@ from .House import House
from .Hub import Hub
from .Room import Room
from .Device import Device, DeviceParameterAssociation
from .DeviceModel import DeviceModel
from .DeviceModel import DeviceModel, DeviceModelParameter
from .Parameter import Parameter

View File

@ -1,20 +1,26 @@
from typing import List
from threading import Thread
try:
import RPi.GPIO as GPIO
import spidev
GPIO.setmode(GPIO.BCM)
is_rpi = True
except ModuleNotFoundError:
is_rpi = False
from .MerlinMessage import MerlinMessage
from .lib_nrf24 import NRF24
GPIO.setmode(GPIO.BCM)
class Merlin():
radio: NRF24
send_queue: List[MerlinMessage] = []
my_urdi = [0xf0, 0xf0, 0xf0, 0xf0, 0xe1]
def __init__(self):
if not is_rpi: return
radio = NRF24(GPIO, spidev.SpiDev())
radio.begin(0, 17)
radio.setPALevel(NRF24.PA_HIGH)
@ -37,7 +43,10 @@ class Merlin():
self.send_queue.append(message)
def _send(self, message: MerlinMessage):
if not is_rpi:
# TODO: Log
print(messag.urdi, message.data, [b.to_bytes(1, 'big') for b in message.data])
return
self.radio.stopListening()
self.radio.openWritingPipe(message.urdi)
self.radio.write(message.data)
@ -49,6 +58,8 @@ class Merlin():
while self.send_queue and (message := self.send_queue.pop()):
self._send(message)
if not is_rpi: return
# receiving messages
if not self.radio.available():
continue
@ -56,7 +67,7 @@ class Merlin():
rawData = []
self.radio.read(rawData, self.radio.getPayloadSize())
func, arg = rawData
print(f'{func=} {arg=}')
print(f'{func=} {arg=}') # TODO: Log
receiveAndTransmitThread = Thread(target = Merlin().receiveAndTransmit)
receiveAndTransmitThread.start()

View File

@ -1,5 +1,5 @@
from wifi import Cell, Scheme
from PyAccessPoint import pyaccesspoint
# from PyAccessPoint import pyaccesspoint
import config
@ -13,9 +13,9 @@ __all__ = [
'is_hotspot_running'
]
access_point = pyaccesspoint.AccessPoint()
access_point.ssid = config.wifi_ssid
access_point.password = config.wifi_password
# access_point = pyaccesspoint.AccessPoint()
# access_point.ssid = config.wifi_ssid
# access_point.password = config.wifi_password
class ConnectionException(Exception):
pass
@ -48,10 +48,13 @@ def connect_first() -> None | ConnectionException:
raise ConnectionException('No schemes available')
def start_hotspot():
access_point.start()
# access_point.start()
...
def stop_hotspot():
access_point.stop()
# access_point.stop()
...
def is_hotspot_running():
access_point.is_running()
def is_hotspot_running() -> bool:
return False
# access_point.is_running()

View File

@ -10,10 +10,10 @@ from Raspberry import WiFi
if __name__ == '__main__':
try:
hub = HubManager.default().get()
WiFi.connect_first()
except:
WiFi.start_hotspot()
# try:
# hub = HubManager.default().get()
# WiFi.connect_first()
# except:
# WiFi.start_hotspot()
uvicorn.run('main:app', host = '0.0.0.0', port = 8001, reload = True, reload_dirs=[root,])
uvicorn.run('main:app', host = '0.0.0.0', port = 8000, reload = True, reload_dirs=[root,])

View File

66
SmartHome/tests/setup.py Normal file
View File

@ -0,0 +1,66 @@
from uuid import uuid1, UUID
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from main import app
from API.dependencies.database import get_session, get_async_session
from API import models
# Settings
secret_key = ''
public_key = ''
access_token = ''
hub_access_token = ''
hub_refresh_token = ''
# Dependencies
# sync db
engine = create_engine(
'sqlite:///./test.db', connect_args={'check_same_thread': False}
)
create_session = sessionmaker(
autocommit=False, autoflush=False, bind=engine, connect_args={'check_same_thread': False}
)
def override_get_session() -> Session:
with create_session() as session:
yield session
# async db
async_engine = create_async_engine(
'sqlite+aiosqlite:///./test.db', connect_args={'check_same_thread': False}
)
create_async_session = sessionmaker(
async_engine, class_ = AsyncSession, expire_on_commit = False
)
async def override_get_async_session() -> AsyncSession:
async with create_async_session() as session:
yield session
models.Base.metadata.drop_all(bind = engine)
models.Base.metadata.create_all(bind = engine)
# App
app.dependency_overrides[get_session] = override_get_session
app.dependency_overrides[get_async_session] = override_get_async_session
client = TestClient(app)
# Simulate Cloud
# TODO: auth tokens

View File

@ -0,0 +1,30 @@
from tests.setup import *
def test_init_hub():
id = uuid1()
house_id = uuid1()
response = client.post('/api/hub', json = {
'id': str(id),
'name': 'Inited Hub',
'house_id': str(house_id),
'access_token': hub_access_token,
'refresh_token': hub_refresh_token,
'public_key': public_key,
})
assert response.status_code == 200
# Patch
def test_patch_hub():
hub = client.get('/api/hub').json()
hub['name'] = 'Patched Hub'
response = client.patch('/api/hub', json = {
'name': 'Patched Hub',
})
assert response.status_code == 200
assert client.get('/api/hub').json() == hub

View File

@ -0,0 +1,13 @@
from tests.setup import *
def _test_ws(device_id: UUID, parameter_id: UUID, value: int):
with client.websocket_connect('/ws') as ws:
ws.send_json({
'type': 'merlin',
'data': {
'device_id': str(device_id),
'parameter_id': str(parameter_id),
'value': value
}
}, mode='text')

View File

@ -1,64 +0,0 @@
Python 3.10
# general
requests
aiohttp
# voice assistant
pip install sounddevice
pip install soundfile
pip install numpy
pip install vosk
# download model from https://alphacephei.com/vosk/models
# tts
pip install google-cloud-texttospeech
# Django
pip install django
pip install django-rest-framework
pip install djangorestframework-simplejwt
# telegram
pip install PyTelegramBotApi
# QA
pip install bs4
pip install wikipedia
# Zieit
pip install xlrd
pip install xlwt
pip install xlutils
# Media
pip install aiohttp
pip install pafy
pip install screeninfo
pip install psutil
pip install yt_dlp #pip install youtube-dl
# API
pip install uvicorn
pip install fastapi
pip install sqlalchemy
pip install sqlalchemy_utils
pip install pydantic
pip install passlib
pip install python-dotenv
pip install python-jose
pip install bcrypt
pip install sqladmin # https://github.com/MarkParker5/sqladmin
sudo apt-get install libssl-dev
# WIFI
pip install PyAccessPoint
pip install wifi

69
resources/requirments.txt Normal file
View File

@ -0,0 +1,69 @@
# Python 3.10
# download vosk model from https://alphacephei.com/vosk/models
# general
requests
aiohttp
# voice assistant
sounddevice
soundfile
numpy
vosk # TODO: repo
# tts
google-cloud-texttospeech
# telegram
PyTelegramBotApi
# QA
bs4
wikipedia
# Zieit
# xlrd
# xlwt
# xlutils
# Media
aiohttp
pafy
screeninfo
psutil
yt_dlp # pip install youtube-dl
# DB
pydantic
sqlalchemy
sqlalchemy_utils
# aiosqlite
# API
uvicorn
fastapi
passlib
python-dotenv
python-jose
bcrypt
libssl-dev
git+https://github.com/MarkParker5/sqladmin
# WIFI
# pip install PyAccessPoint
pip install wifi
# tests
pytest