Is there an integration between PyTorch Geometric and Neo4j?

Dear Neo4j Community,

Is there a way to export a PyG object of type HeteroData into Neo4j Workspace?

I want to visualize my graph, which I created from the Yelp dataset for recommendation, and use it for further steps such as building HybridRAG or GraphRAG agents. Here is part of my code:

# Create the graph: a heterogeneous graph with 'restaurant' and 'user' node types

yelp_graph = HeteroData()

# Node feature matrices, one per node type
yelp_graph['restaurant'].x = restaurant_node_x

yelp_graph['user'].x = user_x

# 'reviews' edges from users to restaurants, with per-edge attributes
yelp_graph['user', 'reviews', 'restaurant'].edge_index = reviews_index

yelp_graph['user', 'reviews', 'restaurant'].edge_attr = reviews_attr

# Add the reverse edges from restaurants to users so that a GNN can pass messages in both directions

yelp_graph = ToUndirected()(yelp_graph)

# Display the graph summary (see the output below)
yelp_graph

Output:

HeteroData(
  restaurant={ x=[1161, 142] },
  user={ x=[37170, 20] },
  (user, reviews, restaurant)={
    edge_index=[2, 52807],
    edge_attr=[52807, 5],
  },
  (restaurant, rev_reviews, user)={
    edge_index=[2, 52807],
    edge_attr=[52807, 5],
  }
)
# Transform data type: cast all node features and edge attributes to float32

# Impute (replace NaN with 0) -- restaurants without reviews will have 0 for sentiment and compliment count

yelp_graph['restaurant'].x = torch.nan_to_num(yelp_graph['restaurant'].x, nan=0.0)

yelp_graph['restaurant'].x = yelp_graph['restaurant'].x.type(torch.float32)

yelp_graph['user'].x = yelp_graph['user'].x.type(torch.float32)

# Cast the attributes of both the forward and the reverse edge type
yelp_graph['user', 'reviews', 'restaurant'].edge_attr = yelp_graph['user', 'reviews', 'restaurant'].edge_attr.type(torch.float32)

yelp_graph['restaurant', 'rev_reviews', 'user'].edge_attr = yelp_graph['restaurant', 'rev_reviews', 'user'].edge_attr.type(torch.float32)

I will be grateful for your assistance.

You can try something like this:

from torch_geometric.data import HeteroData
from neo4j import GraphDatabase
from typing import Dict, List, Tuple
import torch

def export_to_neo4j(graph: HeteroData, uri: str, user: str, password: str) -> None:
    """Export a PyG ``HeteroData`` graph into a Neo4j database.

    Every node type that has a feature matrix ``x`` becomes a Neo4j label.
    Each row of ``x`` is merged as one node whose ``id`` property is the row
    index and whose other properties are the feature values of that row.

    Every edge type ``(src, rel, dst)`` becomes the relationship type ``rel``.
    Each column of ``edge_index`` is merged as one relationship between the
    nodes with the corresponding ``id`` values; the matching row of
    ``edge_attr`` (if present) is stored as relationship properties.

    Properties are named after a ``feature_names`` attribute on the tensor
    when it has one, otherwise ``f0``, ``f1``, ... in column order.

    Notes:
        * Writes use ``MERGE``, so re-running the export updates existing
          nodes and relationships instead of duplicating them. As a
          consequence, parallel edges between the same pair of nodes collapse
          into a single relationship that keeps the properties of the last
          such edge.
        * Node types without ``x`` are not created, so edges that touch them
          match no node and are not written.
        * NOTE(review): nodes are looked up by ``id`` for every edge; creating
          an index or uniqueness constraint on ``id`` for each label before a
          large export should speed this up considerably.

    Args:
        graph: The heterogeneous graph to export.
        uri: Connection URI of the Neo4j server, e.g. ``bolt://localhost:7687``.
        user: Neo4j user name.
        password: Password of ``user``.

    Raises:
        IndexError: If an edge type has fewer ``edge_attr`` rows than edges,
            or fewer feature columns than feature names.
    """
    batch_size = 1000  # rows written per transaction

    def quote(identifier: str) -> str:
        # Backtick-quote labels / relationship types so that names containing
        # spaces, dashes or other special characters still form valid Cypher.
        return "`" + str(identifier).replace("`", "``") + "`"

    def as_matrix(tensor: torch.Tensor) -> torch.Tensor:
        # Treat a 1-D tensor as a single feature column.
        return tensor.unsqueeze(1) if tensor.dim() == 1 else tensor

    def feature_names_for(tensor: torch.Tensor, width: int, default_prefix: str = "f") -> List[str]:
        # Only build the default names when the tensor carries none.
        names = getattr(tensor, 'feature_names', None)
        if names is None:
            return [f"{default_prefix}{i}" for i in range(width)]
        return list(names)

    def to_props(names: List[str], values: list) -> Dict[str, float]:
        return {name: float(values[j]) for j, name in enumerate(names)}

    def run_rows(tx, query: str, rows: List[dict]) -> None:
        tx.run(query, rows=rows)

    driver = GraphDatabase.driver(uri, auth=(user, password))
    try:
        with driver.session() as session:
            # `execute_write` replaces the deprecated `write_transaction` in
            # driver 5.x; fall back for older drivers that only have the latter.
            write = getattr(session, "execute_write", None) or session.write_transaction

            # Export nodes, one UNWIND batch per transaction.
            for node_type in graph.node_types:
                x = getattr(graph[node_type], 'x', None)
                if x is None:
                    continue
                features = as_matrix(x)
                names = feature_names_for(x, features.size(1))
                node_query = (
                    "UNWIND $rows AS row\n"
                    f"MERGE (n:{quote(node_type)} {{id: row.id}})\n"
                    "SET n += row.props"
                )
                for start in range(0, features.size(0), batch_size):
                    # One bulk conversion per batch instead of one tensor
                    # lookup per feature value.
                    chunk = features[start:start + batch_size].tolist()
                    rows = [
                        {"id": start + offset, "props": to_props(names, values)}
                        for offset, values in enumerate(chunk)
                    ]
                    write(run_rows, node_query, rows)

            # Export edges, one UNWIND batch per transaction.
            for edge_type in graph.edge_types:
                src_type, rel_type, dst_type = edge_type
                store = graph[edge_type]
                edge_index = store.edge_index
                edge_attr = getattr(store, 'edge_attr', None)
                num_edges = edge_index.size(1)

                if edge_attr is not None:
                    attrs = as_matrix(edge_attr)
                    if attrs.size(0) < num_edges:
                        raise IndexError(
                            f"edge type {edge_type!r} has {num_edges} edges "
                            f"but only {attrs.size(0)} edge_attr rows"
                        )
                    names = feature_names_for(edge_attr, attrs.size(1))
                else:
                    attrs = None
                    names = []

                edge_query = (
                    "UNWIND $rows AS row\n"
                    f"MATCH (a:{quote(src_type)} {{id: row.src_id}})\n"
                    f"MATCH (b:{quote(dst_type)} {{id: row.dst_id}})\n"
                    f"MERGE (a)-[r:{quote(rel_type)}]->(b)\n"
                    "SET r += row.props"
                )
                for start in range(0, num_edges, batch_size):
                    stop = start + batch_size
                    sources = edge_index[0, start:stop].tolist()
                    targets = edge_index[1, start:stop].tolist()
                    if attrs is not None:
                        props = [to_props(names, values) for values in attrs[start:stop].tolist()]
                    else:
                        props = [{} for _ in sources]
                    rows = [
                        {"src_id": int(src), "dst_id": int(dst), "props": edge_props}
                        for src, dst, edge_props in zip(sources, targets, props)
                    ]
                    write(run_rows, edge_query, rows)
    finally:
        # Release the connection pool even if a write fails part-way through.
        driver.close()

Some property names can be changed in the code.

Try calling the function like this:
# Replace the URI, user name and password with those of your own Neo4j instance.
export_to_neo4j(yelp_graph, "bolt://localhost:7687", "neo4j", "password")

I think it should properly decompose your PyG graph into nodes and relationships.

1 Like