I would like to add some precomputed embeddings as properties to nodes. I have a CSV file containing these embeddings. However, when I try to cast the embeddings to a list of floats I get an error.
LOAD CSV WITH HEADERS FROM "<my_file.csv>" AS row
MATCH (p:Skill_t4 {name: row.Tier4})
SET p.embedding = toFloatList(row.embeddings);
Error:
Invalid input for function 'toFloatList()': Expected a List, got: String("[-0.5468039 0.21683593 0.00861383 ... 0.7526193 0.1819184
0.88167435]")
Since everything in a csv is always a string, how can I get around this?
Use string functions to turn the string into a list of strings, and then you can call toFloatList() on it. Assuming the values in the file are comma-separated, you need to remove the brackets and split on the comma:
LOAD CSV WITH HEADERS FROM "<my_file.csv>" AS row
// Carry row through the WITH, otherwise row.Tier4 is out of scope in the MATCH
WITH row, split(replace(replace(row.embeddings, "[", ""), "]", ""), ",") AS embedding
MATCH (p:Skill_t4 {name: row.Tier4})
SET p.embedding = toFloatList(embedding);
Not pretty (there might be a better way to get it to a string list), but it works.
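If you would rather clean the vectors before they reach Neo4j, the same strip-the-brackets-and-split idea can be written in Python. This is a minimal sketch, and `parse_embedding` is a hypothetical helper name. It accepts commas and/or whitespace (including the line break visible in the error message) as separators, and it refuses strings containing "...", because a summarised vector has lost values that cannot be recovered.

```python
import re


def parse_embedding(text: str) -> list[float]:
    """Turn a stringified vector such as "[0.1, 0.2]" or "[0.1 0.2]" into floats.

    Mirrors the Cypher replace()/split() approach: drop the brackets, split on
    the delimiter, convert each piece. Commas, spaces and newlines all count
    as separators.
    """
    if "..." in text:
        # A summarised repr has dropped values; refuse to guess them.
        raise ValueError("embedding string is truncated")
    inner = text.strip().strip("[]")
    # Filter out empty tokens produced by leading/trailing separators.
    return [float(tok) for tok in re.split(r"[,\s]+", inner) if tok]


print(parse_embedding("[-0.5468039, 0.21683593, 0.00861383]"))
```

The resulting list can be passed as a query parameter, so no string handling is needed in Cypher at all.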
I saw it, thank you. And thank you @christoffer.bergman. I ran into a few other snags that I have to fix first. At least now it's a proper "list", albeit stored as a string. The lists were truncated before; that was a separate issue with how the file had been saved.
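One plausible cause of the truncation, assuming the embeddings were numpy arrays written to the CSV with pandas: the cell then holds the array's printed form, and numpy summarises arrays with more than 1000 elements using "..." (the default `threshold` in `np.set_printoptions`) and wraps lines at 75 characters, which would match both the "..." and the line break in the error above. A sketch of one way to avoid that is to store each vector as a JSON array string. The helpers `write_embeddings_csv` and `read_embeddings_csv` are hypothetical names; the column names come from this thread.

```python
import csv
import io
import json


def write_embeddings_csv(rows, fh):
    """Write (name, vector) pairs; each vector is stored as a JSON array string.

    json.dumps never summarises, so every value survives the round trip.
    If a vector is a numpy array, pass vector.tolist() or rely on float(x).
    """
    writer = csv.writer(fh)
    writer.writerow(["Tier4", "embeddings"])
    for name, vector in rows:
        writer.writerow([name, json.dumps([float(x) for x in vector])])


def read_embeddings_csv(fh):
    """Yield (name, list_of_floats) pairs from a file written as above."""
    for row in csv.DictReader(fh):
        yield row["Tier4"], json.loads(row["embeddings"])


buf = io.StringIO()
write_embeddings_csv([("python", [0.1, -0.25, 3.0])], buf)
buf.seek(0)
print(list(read_embeddings_csv(buf)))
```

Because the stored value is "[0.1, -0.25, 3.0]", it also works with the replace()/split() approach in Cypher and with ast.literal_eval in Python.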
@joshcornejo and @christoffer.bergman, thank you both very much for your help. In the end I found it easier to use FastAPI and the GraphDatabase driver. The code is below for anyone who runs into this issue. I could probably have used UNWIND instead of looping inside the session, but since this is a once-off, that was one less thing to figure out ;)
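import ast

import pandas as pd
from fastapi import Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# app, driver (neo4j GraphDatabase.driver), api_key, authenticate and API_KEY
# are defined elsewhere in the application.

class AddEmbeddings(BaseModel):
    UID: str

@app.post("/api/v1/add_embeddings")
async def add_embeddings(request: AddEmbeddings, key: str = Depends(api_key)):
    await authenticate(key, API_KEY)
    df = pd.read_csv("service_data/t4_nodes.csv")
    t4_skills = list(df["Tier4"])
    embeddings = list(df["embeddings_flt"])

    def set_emb(tx, skill, emb):
        # emb is already a list of floats; toFloatList() is just a safeguard
        tx.run("""
            MATCH (p:Skill_t4 {name: $skill})
            SET p.embedding = toFloatList($emb)
            """,
            skill=skill,
            emb=emb)

    with driver.session() as session:
        for skill, emb_str in zip(t4_skills, embeddings):
            # Parse the "[0.1, 0.2, ...]" string from the CSV into a Python list
            emb = ast.literal_eval(emb_str)
            session.execute_write(set_emb, skill, emb)

    result = True
    return JSONResponse(status_code=200, content=result)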
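For larger files, the UNWIND variant mentioned above sends many rows per round trip instead of one transaction per node. This is a minimal sketch, not the poster's code: `batched_rows`, `set_embeddings_batch`, `write_all` and `BATCH_SIZE` are hypothetical names, and the batch size of 500 is an assumption to tune. It relies only on `session.execute_write` and `tx.run(query, **params)` from the neo4j Python driver, as the endpoint above already does.

```python
from typing import Iterable, Iterator

# Hypothetical batch size; tune it for vector length and server memory.
BATCH_SIZE = 500

SET_EMBEDDINGS = """
UNWIND $rows AS row
MATCH (p:Skill_t4 {name: row.skill})
SET p.embedding = row.emb
"""


def batched_rows(skills: Iterable[str], embeddings: Iterable[list[float]],
                 size: int = BATCH_SIZE) -> Iterator[list[dict]]:
    """Group (skill, embedding) pairs into lists of {"skill", "emb"} maps."""
    batch: list[dict] = []
    for skill, emb in zip(skills, embeddings):
        batch.append({"skill": skill, "emb": emb})
        if len(batch) == size:
            yield batch
            batch = []
    if batch:
        # Flush the final, possibly smaller, batch.
        yield batch


def set_embeddings_batch(tx, rows: list[dict]) -> None:
    # One query per batch; UNWIND turns the list parameter back into rows.
    tx.run(SET_EMBEDDINGS, rows=rows)


def write_all(driver, skills, embeddings) -> None:
    """driver is a neo4j Driver, e.g. the one already used in the endpoint."""
    with driver.session() as session:
        for rows in batched_rows(skills, embeddings):
            session.execute_write(set_embeddings_batch, rows)
```

In the endpoint, the loop could then become `write_all(driver, t4_skills, [ast.literal_eval(e) for e in embeddings])`. Since the parameters already arrive as Python floats, toFloatList() is not needed here.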