from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN
from sentence_transformers import SentenceTransformer
import numpy as np
from plotly import graph_objects as go
from os import listdir, path, stat
from datetime import datetime
import faiss

text_dir = "/home/trent/Videos/CobraArchive/AudioOnly/"
# listdir() returns entries in arbitrary order; sort so that position i in
# srt_files (and therefore in the embeddings, plot and ordered list built
# from it) refers to the same file on every run.
files = sorted(listdir(text_dir))
srt_files = [f for f in files if f.endswith(".srt")]
if not srt_files:
    # Nothing to embed: stop with a clear message instead of failing later
    # in the pipeline on an empty input.
    raise SystemExit(f"No .srt files found in {text_dir}")


def _read_transcript(file_path):
    """Return the full text of *file_path*.

    The file is decoded as UTF-8, and a leading byte-order mark is dropped
    if present. Undecodable bytes are replaced rather than aborting the
    whole run. The handle is closed as soon as the file has been read.
    """
    with open(file_path, encoding="utf-8-sig", errors="replace") as fh:
        return fh.read()


srt_contents = [_read_transcript(path.join(text_dir, f)) for f in srt_files]

model = SentenceTransformer('all-MiniLM-L12-v2', device='cuda')


def _srt_to_text(srt):
    """Return the caption text of an SRT document as one space-joined string.

    An SRT cue is:
    - a numeric index line,
    - a ``start --> end`` timing line,
    - one or more caption lines,
    - a blank separator.

    The index and timing lines carry no meaning for a sentence embedding,
    but they would otherwise be encoded along with the speech. So only the
    caption lines are kept.
    """
    lines = [line.strip() for line in srt.lstrip("\ufeff").splitlines()]
    kept = []
    for pos, line in enumerate(lines):
        if not line or "-->" in line:
            continue  # blank separator or timing line
        if line.isdigit() and pos + 1 < len(lines) and "-->" in lines[pos + 1]:
            continue  # cue index: a bare number directly above a timing line
        kept.append(line)
    return " ".join(kept)


# NOTE: SentenceTransformer truncates each input to model.max_seq_length
# tokens, so only the opening part of a long transcript shapes its vector.
embeddings = model.encode([_srt_to_text(s) for s in srt_contents], show_progress_bar=True)
# faiss (used for the search index) only accepts C-contiguous float32
# matrices; make that explicit instead of relying on encode()'s default.
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)

# Density-based clustering of the transcript embeddings; labels_ holds one
# cluster id per transcript, with -1 meaning "noise / no cluster".
dbs = DBSCAN(eps=0.5, min_samples=5).fit(embeddings)
labels = dbs.labels_
# One marker colour per cluster id; clusters beyond the tenth reuse colours.
palette = ['red', 'blue', 'green', 'yellow', 'orange', 'purple', 'pink', 'black', 'brown', 'gray']
# With plain modulo, label -1 would map to palette[9] ('gray') and make noise
# points indistinguishable from cluster 9, so noise gets its own colour.
noise_color = 'lightgray'
colors = [noise_color if label == -1 else palette[label % len(palette)] for label in labels]

def _make_tsne(n_components, n_samples):
    """Build the t-SNE estimator shared by the 2-D plot and the 1-D ordering.

    - scikit-learn requires perplexity < n_samples, so the default of 40 is
      capped for archives with few transcripts.
    - scikit-learn 1.5 renamed ``n_iter`` to ``max_iter`` and later removed
      ``n_iter``. The new name is tried first; older releases reject it with
      TypeError, and the old name is used instead.
    """
    kwargs = dict(
        n_components=n_components,
        verbose=1,
        perplexity=min(40, n_samples - 1),
        random_state=42,
    )
    try:
        return TSNE(max_iter=300, **kwargs)
    except TypeError:  # scikit-learn < 1.5 only knows ``n_iter``
        return TSNE(n_iter=300, **kwargs)


tsne = _make_tsne(2, len(embeddings))
tsne_results = tsne.fit_transform(embeddings)

# Interactive scatter plot: one marker per transcript, hover text is the file
# name, colour is the DBSCAN cluster.
fig = go.Figure()
fig.add_trace(go.Scatter(x=tsne_results[:, 0], y=tsne_results[:, 1], mode='markers', text=srt_files, marker=dict(color=colors)))
fig.show()

# Run t-SNE again with a single output dimension, then save the video titles
# sorted by that coordinate so titles with nearby values sit next to each other.
tsne = _make_tsne(1, len(embeddings))
tsne_results = tsne.fit_transform(embeddings)
order = np.argsort(tsne_results[:, 0])
ordered_files = [srt_files[i] for i in order]
with open("/home/trent/Videos/CobraArchive/video-recommendation-list.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(ordered_files))

# Build an L2-distance faiss index over the transcript embeddings.
index = faiss.IndexFlatL2(embeddings.shape[1])
# faiss only accepts C-contiguous float32 matrices.
index.add(np.ascontiguousarray(embeddings, dtype=np.float32))

# Interactive search: embed each query and print the file names of the 10
# nearest transcripts. An empty or whitespace-only line, EOF (Ctrl-D) or
# Ctrl-C ends the loop.
while True:
    try:
        query = input("Enter a search query: ")
    except (EOFError, KeyboardInterrupt):
        print()
        break
    if not query.strip():
        break
    query_embedding = np.ascontiguousarray(model.encode([query]), dtype=np.float32)
    D, I = index.search(query_embedding, 10)
    for i in I[0]:
        # faiss pads missing results with -1 when the index holds fewer than
        # k vectors; srt_files[-1] would silently print the last file.
        if i < 0:
            continue
        print(srt_files[i])