90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
import unittest
|
|
import numpy as np
|
|
import psycopg2
|
|
import os
|
|
from vectData import calculate_cosine_similarity, is_similar_data, insert_data, get_data, create_db
|
|
|
|
class TestIntegration(unittest.TestCase):
    """Integration tests for the ``vectData`` helpers against a live PostgreSQL server.

    Connection settings are taken from the ``DB_HOST``, ``DB_PORT``,
    ``DB_USER``, ``DB_PASSWORD`` and ``DB_NAME`` environment variables.
    One connection is shared by the whole class (opened in ``setUpClass``,
    closed in ``tearDownClass``); each test gets its own cursor on it.
    """

    @classmethod
    def _connect(cls):
        """Open a new psycopg2 connection from the settings stored by setUpClass."""
        return psycopg2.connect(
            host=cls.host,
            port=cls.port,
            user=cls.user,
            password=cls.password,
            dbname=cls.dbname,
        )

    @classmethod
    def setUpClass(cls):
        """Read connection settings, open the shared connection and run create_db on it."""
        cls.host = os.getenv("DB_HOST")
        cls.port = os.getenv("DB_PORT")
        cls.user = os.getenv("DB_USER")
        cls.password = os.getenv("DB_PASSWORD")
        cls.dbname = os.getenv("DB_NAME")

        cls.conn = cls._connect()
        try:
            create_db(cls.conn)
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises,
            # so release the connection here before propagating the error.
            cls.conn.close()
            raise

    @classmethod
    def tearDownClass(cls):
        """Close the shared connection."""
        cls.conn.close()

    def setUp(self):
        """Give each test a fresh cursor, reconnecting if the shared connection is closed."""
        test_class = type(self)
        if test_class.conn.closed:
            # Store the replacement on the class so tearDownClass closes it.
            test_class.conn = test_class._connect()
        self.cursor = self.conn.cursor()

    def tearDown(self):
        """Close the per-test cursor and end the transaction the test opened.

        The shared connection is left open for the next test. Rolling back
        ends the current transaction just as closing the connection did,
        without paying for a reconnect before every test.
        """
        if not self.cursor.closed:
            self.cursor.close()
        if not self.conn.closed:
            self.conn.rollback()

    def test_insert_and_retrieve_data(self):
        """A row stored with insert_data is returned by get_data."""
        title = 'test_title'
        text = 'test_text'
        link = 'test_link'
        # 1536-element integer vector (1..1536) used as the embedding.
        embedding = np.arange(1, 1537)

        insert_data(title, text, link, embedding)

        data = get_data()

        # Nothing in this suite empties the table, and
        # test_is_similar_data_integration inserts the same row, so require
        # membership instead of the table holding exactly this one row.
        self.assertIn((title, text, link), data)

    def test_is_similar_data_integration(self):
        """is_similar_data reports a row identical to one just inserted as similar."""
        title = 'test_title'
        text = 'test_text'
        link = 'test_link'
        embedding = np.arange(1, 1537)

        insert_data(title, text, link, embedding)

        # Checked three times, as originally written -- presumably to confirm
        # repeated calls keep giving the same answer.
        for attempt in range(3):
            with self.subTest(attempt=attempt):
                self.assertTrue(is_similar_data(title, text, link, embedding))

    def test_create_db_integration(self):
        """After setUpClass ran create_db, table 'vectorsvevijesti' is listed in information_schema."""
        # Use the per-test cursor so it is closed in tearDown rather than leaked.
        self.cursor.execute("SELECT * FROM information_schema.tables WHERE table_name = 'vectorsvevijesti'")
        self.assertIsNotNone(self.cursor.fetchone())
|
|
|
|
if __name__ == '__main__':
    # Script entry point: hand control to unittest's command-line runner,
    # which collects the TestCase classes defined in this module.
    unittest.main(module='__main__')
|