
Knowledge Graph Prediction


Overview

Knowledge graph prediction uses structured representations of entities and their relationships to forecast future events, discover hidden connections, and reason about complex systems. By converting unstructured data into graph structures and applying graph-based reasoning, we can predict missing links, forecast temporal events, and generate insights that are impossible to extract from flat data. GraphRAG (Graph Retrieval-Augmented Generation) extends this by combining knowledge graphs with large language models for enhanced prediction.

Knowledge Graph Fundamentals

Structure

A knowledge graph represents knowledge as a collection of triples:

(Subject, Predicate, Object)
(Apple, competes_with, Microsoft)
(GPT-4, developed_by, OpenAI)
(OpenAI, founded_in, 2015)
(Sam Altman, leads, OpenAI)
from dataclasses import dataclass, field
from typing import Optional
import uuid
from datetime import datetime, timedelta

@dataclass
class Entity:
    """A node in the knowledge graph."""
    id: str
    name: str
    entity_type: str  # Person, Organization, Event, Concept, etc.
    properties: dict = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)

@dataclass
class Relation:
    """An edge in the knowledge graph."""
    id: str
    source_id: str
    target_id: str
    relation_type: str
    properties: dict = field(default_factory=dict)
    confidence: float = 1.0
    temporal: Optional[dict] = None  # {start: date, end: date}

class KnowledgeGraph:
    """Core knowledge graph data structure."""

    def __init__(self):
        self.entities = {}  # id -> Entity
        self.relations = {}  # id -> Relation
        self.adjacency = {}  # entity_id -> [(relation_id, target_id)]
        self.reverse_adjacency = {}  # entity_id -> [(relation_id, source_id)]
        self.type_index = {}  # entity_type -> [entity_ids]

    def add_entity(self, name: str, entity_type: str, **properties) -> Entity:
        entity = Entity(
            id=str(uuid.uuid4())[:12],
            name=name,
            entity_type=entity_type,
            properties=properties
        )
        self.entities[entity.id] = entity
        self.adjacency[entity.id] = []
        self.reverse_adjacency[entity.id] = []

        if entity_type not in self.type_index:
            self.type_index[entity_type] = []
        self.type_index[entity_type].append(entity.id)

        return entity

    def add_relation(self, source_id: str, target_id: str,
                     relation_type: str, confidence: float = 1.0,
                     **properties) -> Relation:
        relation = Relation(
            id=str(uuid.uuid4())[:12],
            source_id=source_id,
            target_id=target_id,
            relation_type=relation_type,
            properties=properties,
            confidence=confidence
        )
        self.relations[relation.id] = relation
        self.adjacency[source_id].append((relation.id, target_id))
        self.reverse_adjacency[target_id].append((relation.id, source_id))
        return relation

    def get_neighbors(self, entity_id: str, relation_type: str = None) -> list:
        """Get all entities connected to a given entity."""
        neighbors = []
        for rel_id, target_id in self.adjacency.get(entity_id, []):
            rel = self.relations[rel_id]
            if relation_type is None or rel.relation_type == relation_type:
                neighbors.append({
                    'entity': self.entities[target_id],
                    'relation': rel,
                    'direction': 'outgoing'
                })
        for rel_id, source_id in self.reverse_adjacency.get(entity_id, []):
            rel = self.relations[rel_id]
            if relation_type is None or rel.relation_type == relation_type:
                neighbors.append({
                    'entity': self.entities[source_id],
                    'relation': rel,
                    'direction': 'incoming'
                })
        return neighbors

    def find_paths(self, source_id: str, target_id: str,
                   max_depth: int = 3) -> list:
        """Find all paths between two entities up to max_depth."""
        paths = []

        def dfs(current, target, path, visited, depth):
            if current == target:
                paths.append(list(path))
                return
            if depth >= max_depth:
                return
            for rel_id, next_id in self.adjacency.get(current, []):
                if next_id not in visited:
                    visited.add(next_id)
                    path.append((self.relations[rel_id], self.entities[next_id]))
                    dfs(next_id, target, path, visited, depth + 1)
                    path.pop()
                    visited.remove(next_id)

        dfs(source_id, target_id, [], {source_id}, 0)
        return paths

    def subgraph(self, entity_ids: set, max_hops: int = 1) -> 'KnowledgeGraph':
        """Extract a subgraph around specified entities."""
        sub = KnowledgeGraph()
        expanded_ids = set(entity_ids)

        for _ in range(max_hops):
            new_ids = set()
            for eid in expanded_ids:
                for rel_id, target_id in self.adjacency.get(eid, []):
                    new_ids.add(target_id)
                for rel_id, source_id in self.reverse_adjacency.get(eid, []):
                    new_ids.add(source_id)
            expanded_ids.update(new_ids)

        for eid in expanded_ids:
            if eid in self.entities:
                # Reuse the entity object so relation source/target ids still
                # resolve, and initialize its adjacency lists for add_relation.
                sub.entities[eid] = self.entities[eid]
                sub.adjacency[eid] = []
                sub.reverse_adjacency[eid] = []

        for rel_id, rel in self.relations.items():
            if rel.source_id in expanded_ids and rel.target_id in expanded_ids:
                sub.add_relation(rel.source_id, rel.target_id,
                               rel.relation_type, rel.confidence)

        return sub
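To make the structure above concrete, here is a condensed, dictionary-based sketch of the same ideas: outgoing adjacency built from the example triples, neighbor lookup, and a depth-limited path search. The triples are illustrative, not a real dataset.

```python
# Condensed sketch of the triple store above: adjacency lists over
# (subject, predicate, object) triples, with neighbor lookup and a
# depth-limited path search over outgoing edges.
from collections import defaultdict

triples = [
    ("Apple", "competes_with", "Microsoft"),
    ("GPT-4", "developed_by", "OpenAI"),
    ("Sam Altman", "leads", "OpenAI"),
    ("OpenAI", "partners_with", "Microsoft"),  # illustrative edge
]

adjacency = defaultdict(list)
for s, p, o in triples:
    adjacency[s].append((p, o))

def neighbors(entity):
    """Outgoing (relation, target) pairs for an entity."""
    return adjacency[entity]

def find_paths(source, target, max_depth=3, path=None):
    """All outgoing-edge paths from source to target up to max_depth."""
    path = path or []
    if source == target:
        return [path]
    if len(path) >= max_depth:
        return []
    found = []
    for rel, nxt in adjacency[source]:
        if nxt not in {n for _, n in path}:  # avoid revisiting entities
            found += find_paths(nxt, target, max_depth, path + [(rel, nxt)])
    return found

print(neighbors("OpenAI"))
print(find_paths("Sam Altman", "Microsoft"))
```

The two-hop path from Sam Altman to Microsoft (leads, then partners_with) is exactly the kind of indirect connection that multi-hop reasoning surfaces.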

Converting Unstructured Data to Knowledge Graphs

Entity and Relationship Extraction

class GraphExtractor:
    """Extract knowledge graph triples from unstructured text."""

    def __init__(self, llm_client=None):
        self.llm_client = llm_client
        self.entity_types = [
            'Person', 'Organization', 'Technology', 'Event',
            'Location', 'Product', 'Concept', 'Date'
        ]
        self.relation_types = [
            'develops', 'competes_with', 'acquired', 'founded',
            'leads', 'invested_in', 'partners_with', 'located_in',
            'caused', 'preceded', 'affects', 'part_of'
        ]

    async def extract_from_text(self, text: str, kg: KnowledgeGraph) -> dict:
        """Extract entities and relations from text using LLM."""
        prompt = f"""Extract entities and relationships from this text.

Text: {text}

Entity types: {', '.join(self.entity_types)}
Relation types: {', '.join(self.relation_types)}

Respond as JSON:
{{
  "entities": [
    {{"name": "...", "type": "...", "properties": {{}}}}
  ],
  "relations": [
    {{"source": "...", "target": "...", "type": "...", "confidence": 0.0-1.0}}
  ]
}}"""

        response = await self.llm_client.complete(prompt)
        extracted = self._parse_response(response)

        # Add to knowledge graph
        entity_map = {}
        for e in extracted.get('entities', []):
            entity = kg.add_entity(e['name'], e['type'], **e.get('properties', {}))
            entity_map[e['name']] = entity.id

        for r in extracted.get('relations', []):
            source_id = entity_map.get(r['source'])
            target_id = entity_map.get(r['target'])
            if source_id and target_id:
                kg.add_relation(source_id, target_id, r['type'],
                              confidence=r.get('confidence', 0.8))

        return extracted

    def extract_from_documents(self, documents: list,
                                kg: KnowledgeGraph) -> KnowledgeGraph:
        """Batch extract from multiple documents with deduplication."""
        for doc in documents:
            # Chunk document
            chunks = self._chunk_text(doc['text'], max_chars=2000)
            for chunk in chunks:
                # In production, this would be async
                pass  # await self.extract_from_text(chunk, kg)

        # Deduplicate entities
        self._deduplicate_entities(kg)
        return kg

    def _deduplicate_entities(self, kg: KnowledgeGraph):
        """Merge duplicate entities based on name similarity."""
        from difflib import SequenceMatcher

        entities_by_type = {}
        for eid, entity in kg.entities.items():
            if entity.entity_type not in entities_by_type:
                entities_by_type[entity.entity_type] = []
            entities_by_type[entity.entity_type].append(entity)

        merges = []
        for etype, entities in entities_by_type.items():
            for i in range(len(entities)):
                for j in range(i + 1, len(entities)):
                    similarity = SequenceMatcher(
                        None, entities[i].name.lower(), entities[j].name.lower()
                    ).ratio()
                    if similarity > 0.85:
                        merges.append((entities[i].id, entities[j].id))

        for keep_id, merge_id in merges:
            self._merge_entities(kg, keep_id, merge_id)

    def _merge_entities(self, kg: KnowledgeGraph, keep_id: str, merge_id: str):
        """Rewire merge_id's relations onto keep_id, then drop merge_id."""
        for rel_id, target_id in kg.adjacency.get(merge_id, []):
            rel = kg.relations[rel_id]
            kg.add_relation(keep_id, target_id, rel.relation_type,
                            confidence=rel.confidence)
        for rel_id, source_id in kg.reverse_adjacency.get(merge_id, []):
            rel = kg.relations[rel_id]
            kg.add_relation(source_id, keep_id, rel.relation_type,
                            confidence=rel.confidence)

        kg.adjacency.pop(merge_id, None)
        kg.reverse_adjacency.pop(merge_id, None)
        kg.entities.pop(merge_id, None)

    def _chunk_text(self, text: str, max_chars: int = 2000) -> list:
        sentences = text.split('. ')
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) > max_chars:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = sentence
            else:
                current_chunk += ". " + sentence if current_chunk else sentence
        if current_chunk:
            chunks.append(current_chunk)
        return chunks

    def _parse_response(self, response: str) -> dict:
        import json
        try:
            return json.loads(response)
        except json.JSONDecodeError:
            return {'entities': [], 'relations': []}

Link Prediction

Predicting Missing or Future Connections

import numpy as np

class LinkPredictor:
    """Predict missing or future links in a knowledge graph."""

    def __init__(self, kg: KnowledgeGraph):
        self.kg = kg

    def common_neighbors_score(self, entity_a: str, entity_b: str) -> float:
        """Simple: more common neighbors = more likely to be connected."""
        neighbors_a = set(t for _, t in self.kg.adjacency.get(entity_a, []))
        neighbors_b = set(t for _, t in self.kg.adjacency.get(entity_b, []))
        common = neighbors_a & neighbors_b
        return len(common)

    def jaccard_coefficient(self, entity_a: str, entity_b: str) -> float:
        """Normalized common neighbors."""
        neighbors_a = set(t for _, t in self.kg.adjacency.get(entity_a, []))
        neighbors_b = set(t for _, t in self.kg.adjacency.get(entity_b, []))
        union = neighbors_a | neighbors_b
        if not union:
            return 0
        return len(neighbors_a & neighbors_b) / len(union)

    def adamic_adar_score(self, entity_a: str, entity_b: str) -> float:
        """Weight common neighbors by inverse log of their degree."""
        neighbors_a = set(t for _, t in self.kg.adjacency.get(entity_a, []))
        neighbors_b = set(t for _, t in self.kg.adjacency.get(entity_b, []))
        common = neighbors_a & neighbors_b

        score = 0
        for neighbor in common:
            degree = len(self.kg.adjacency.get(neighbor, []))
            if degree > 1:
                score += 1 / np.log(degree)
        return score

    def predict_links(self, entity_id: str, top_k: int = 10,
                      relation_type: str = None) -> list:
        """Predict the most likely new connections for an entity."""
        candidates = []

        existing = set(t for _, t in self.kg.adjacency.get(entity_id, []))
        existing.add(entity_id)

        for candidate_id in self.kg.entities:
            if candidate_id in existing:
                continue

            score = (
                0.4 * self.common_neighbors_score(entity_id, candidate_id) +
                0.3 * self.jaccard_coefficient(entity_id, candidate_id) +
                0.3 * self.adamic_adar_score(entity_id, candidate_id)
            )

            if score > 0:
                candidates.append({
                    'entity': self.kg.entities[candidate_id],
                    'score': score,
                    'common_neighbors': self.common_neighbors_score(entity_id, candidate_id)
                })

        candidates.sort(key=lambda x: -x['score'])
        return candidates[:top_k]

Knowledge Graph Embeddings for Link Prediction

class TransE:
    """
    TransE embedding model: h + r ≈ t
    Entities and relations are embedded in the same vector space.
    A relation is modeled as a translation from head to tail.
    """

    def __init__(self, n_entities: int, n_relations: int,
                 embedding_dim: int = 100, learning_rate: float = 0.01,
                 margin: float = 1.0):
        self.dim = embedding_dim
        self.lr = learning_rate
        self.margin = margin

        # Initialize embeddings
        self.entity_embeddings = np.random.uniform(
            -6/np.sqrt(embedding_dim), 6/np.sqrt(embedding_dim),
            (n_entities, embedding_dim)
        )
        self.relation_embeddings = np.random.uniform(
            -6/np.sqrt(embedding_dim), 6/np.sqrt(embedding_dim),
            (n_relations, embedding_dim)
        )

        # Normalize entity embeddings
        norms = np.linalg.norm(self.entity_embeddings, axis=1, keepdims=True)
        self.entity_embeddings /= np.where(norms > 0, norms, 1)

    def score(self, head: int, relation: int, tail: int) -> float:
        """Score a triple. Lower = more likely to be true."""
        h = self.entity_embeddings[head]
        r = self.relation_embeddings[relation]
        t = self.entity_embeddings[tail]
        return np.linalg.norm(h + r - t)

    def train_step(self, positive_triples: list, n_entities: int):
        """One training step with margin-based loss."""
        for h, r, t in positive_triples:
            # Generate negative sample by corrupting head or tail
            if np.random.random() < 0.5:
                h_neg = np.random.randint(n_entities)
                t_neg = t
            else:
                h_neg = h
                t_neg = np.random.randint(n_entities)

            pos_score = self.score(h, r, t)
            neg_score = self.score(h_neg, r, t_neg)

            loss = max(0, self.margin + pos_score - neg_score)

            if loss > 0:
                # Gradient of the squared distance ||h + r - t||^2
                # (a common simplification of the exact L2-norm gradient)
                pos_grad = 2 * (self.entity_embeddings[h] +
                                self.relation_embeddings[r] -
                                self.entity_embeddings[t])
                neg_grad = 2 * (self.entity_embeddings[h_neg] +
                                self.relation_embeddings[r] -
                                self.entity_embeddings[t_neg])

                # Pull the positive triple together...
                self.entity_embeddings[h] -= self.lr * pos_grad
                self.relation_embeddings[r] -= self.lr * pos_grad
                self.entity_embeddings[t] += self.lr * pos_grad
                # ...and push the corrupted triple apart
                self.entity_embeddings[h_neg] += self.lr * neg_grad
                self.relation_embeddings[r] += self.lr * neg_grad
                self.entity_embeddings[t_neg] -= self.lr * neg_grad

        # Renormalize
        norms = np.linalg.norm(self.entity_embeddings, axis=1, keepdims=True)
        self.entity_embeddings /= np.where(norms > 0, norms, 1)

    def predict_tail(self, head: int, relation: int, top_k: int = 10) -> list:
        """Predict most likely tail entities for (head, relation, ?)."""
        h = self.entity_embeddings[head]
        r = self.relation_embeddings[relation]
        target = h + r

        distances = np.linalg.norm(self.entity_embeddings - target, axis=1)
        top_indices = np.argsort(distances)[:top_k]

        return [(idx, distances[idx]) for idx in top_indices]
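The h + r ≈ t intuition is easiest to see with hand-built embeddings. In this sketch the "capital_of" relation is literally a fixed offset in a 2-D space, so scoring ranks the true tail first; the entities and vectors are purely illustrative.

```python
# The TransE scoring idea in isolation: a relation is a translation
# vector, and plausible triples have h + r landing near t.
import numpy as np

entities = {"Paris":   np.array([1.0, 0.0]),
            "France":  np.array([1.0, 1.0]),
            "Berlin":  np.array([0.0, 0.0]),
            "Germany": np.array([0.0, 1.0])}
relation = np.array([0.0, 1.0])  # "capital_of" as a translation

def score(h, t):
    """Lower = more plausible under TransE."""
    return np.linalg.norm(entities[h] + relation - entities[t])

# Rank candidate tails for (Paris, capital_of, ?)
ranked = sorted(entities, key=lambda t: score("Paris", t))
print(ranked[0])
```

Because Paris + capital_of lands exactly on France, the score for the true triple is zero; a trained model only approximates this geometry.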

Temporal Knowledge Graphs

class TemporalKnowledgeGraph(KnowledgeGraph):
    """Knowledge graph with time-aware relations for temporal prediction."""

    def add_temporal_relation(self, source_id: str, target_id: str,
                              relation_type: str, start_time: datetime,
                              end_time: datetime = None, **properties):
        rel = self.add_relation(source_id, target_id, relation_type, **properties)
        rel.temporal = {'start': start_time, 'end': end_time}
        return rel

    def snapshot_at(self, timestamp: datetime) -> KnowledgeGraph:
        """Get the state of the graph at a specific point in time."""
        snapshot = KnowledgeGraph()

        for eid, entity in self.entities.items():
            snapshot.entities[eid] = entity
            snapshot.adjacency[eid] = []
            snapshot.reverse_adjacency[eid] = []

        for rel_id, rel in self.relations.items():
            if rel.temporal:
                start = rel.temporal['start']
                end = rel.temporal.get('end')
                if start <= timestamp and (end is None or end >= timestamp):
                    snapshot.relations[rel_id] = rel
                    snapshot.adjacency[rel.source_id].append((rel_id, rel.target_id))
                    snapshot.reverse_adjacency[rel.target_id].append((rel_id, rel.source_id))
            else:
                snapshot.relations[rel_id] = rel
                snapshot.adjacency[rel.source_id].append((rel_id, rel.target_id))
                snapshot.reverse_adjacency[rel.target_id].append((rel_id, rel.source_id))

        return snapshot

    def predict_future_relations(self, entity_id: str,
                                 future_time: datetime) -> list:
        """Predict what relations an entity will have in the future."""
        # Analyze temporal patterns
        entity_history = []
        for rel_id, rel in self.relations.items():
            if rel.source_id == entity_id or rel.target_id == entity_id:
                if rel.temporal:
                    entity_history.append({
                        'relation': rel,
                        'start': rel.temporal['start'],
                        'end': rel.temporal.get('end'),
                        'type': rel.relation_type
                    })

        entity_history.sort(key=lambda x: x['start'])

        # Pattern detection: recurring relation types
        type_counts = {}
        for event in entity_history:
            rtype = event['type']
            type_counts[rtype] = type_counts.get(rtype, 0) + 1

        predictions = []
        for rtype, count in type_counts.items():
            if count >= 2:
                # Calculate average interval between occurrences
                occurrences = [e['start'] for e in entity_history if e['type'] == rtype]
                if len(occurrences) >= 2:
                    intervals = [(occurrences[i+1] - occurrences[i]).days
                                for i in range(len(occurrences)-1)]
                    avg_interval = np.mean(intervals)
                    last = occurrences[-1]
                    next_expected = last + timedelta(days=float(avg_interval))

                    predictions.append({
                        'relation_type': rtype,
                        'expected_time': next_expected,
                        'confidence': min(count / 10, 0.9),
                        'based_on_occurrences': count
                    })

        return predictions
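The recurring-pattern heuristic in `predict_future_relations` reduces to simple date arithmetic: average the gaps between past occurrences and project the next one. A self-contained sketch with illustrative dates:

```python
# Average-interval projection: given past occurrence dates of a
# relation type, predict when the next one is expected.
from datetime import datetime, timedelta

occurrences = [datetime(2023, 1, 10), datetime(2023, 4, 10),
               datetime(2023, 7, 10)]

intervals = [(b - a).days for a, b in zip(occurrences, occurrences[1:])]
avg_interval = sum(intervals) / len(intervals)          # 90.5 days
next_expected = occurrences[-1] + timedelta(days=avg_interval)
confidence = min(len(occurrences) / 10, 0.9)            # crude scaling

print(next_expected.date(), confidence)
```

With three roughly quarterly occurrences, the projected next event lands about three months after the last one; the confidence scaling is the same crude count-based heuristic used in the class above.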

GraphRAG for Prediction

Combining Knowledge Graphs with LLMs

class GraphRAGPredictor:
    """
    Use knowledge graphs to enhance LLM predictions.
    GraphRAG provides structured context that improves reasoning.
    """

    def __init__(self, kg: KnowledgeGraph, llm_client=None):
        self.kg = kg
        self.llm = llm_client

    def predict_with_context(self, question: str, max_context_triples: int = 50) -> dict:
        """Answer a prediction question using graph-enhanced context."""

        # Step 1: Extract key entities from the question
        entities = self._extract_question_entities(question)

        # Step 2: Retrieve relevant subgraph
        relevant_triples = self._retrieve_context(entities, max_context_triples)

        # Step 3: Format context for LLM
        context = self._format_graph_context(relevant_triples)

        # Step 4: Generate prediction with context
        prompt = f"""Based on the following knowledge graph context, answer the prediction question.

Knowledge Graph Context:
{context}

Question: {question}

Provide:
1. Your prediction with a probability estimate
2. Key reasoning chains from the knowledge graph
3. What additional information would change your prediction
"""

        # In production the prompt would be sent to self.llm; here the
        # assembled prompt is returned alongside retrieval stats.
        return {
            'context_triples': len(relevant_triples),
            'entities_found': len(entities),
            'prompt': prompt
        }

    def _extract_question_entities(self, question: str) -> list:
        """Find entities in the knowledge graph mentioned in the question."""
        found = []
        question_lower = question.lower()
        for eid, entity in self.kg.entities.items():
            if entity.name.lower() in question_lower:
                found.append(eid)
        return found

    def _retrieve_context(self, entity_ids: list, max_triples: int) -> list:
        """Retrieve relevant triples around the question entities."""
        triples = []
        visited = set()

        for eid in entity_ids:
            # Get direct connections
            for rel_id, target_id in self.kg.adjacency.get(eid, []):
                if rel_id not in visited:
                    rel = self.kg.relations[rel_id]
                    triples.append((
                        self.kg.entities[eid].name,
                        rel.relation_type,
                        self.kg.entities[target_id].name,
                        rel.confidence
                    ))
                    visited.add(rel_id)

            # Get 2-hop connections
            for rel_id, target_id in self.kg.adjacency.get(eid, []):
                for rel_id2, target_id2 in self.kg.adjacency.get(target_id, []):
                    if rel_id2 not in visited and len(triples) < max_triples:
                        rel2 = self.kg.relations[rel_id2]
                        triples.append((
                            self.kg.entities[target_id].name,
                            rel2.relation_type,
                            self.kg.entities[target_id2].name,
                            rel2.confidence
                        ))
                        visited.add(rel_id2)

        # Sort by confidence
        triples.sort(key=lambda x: -x[3])
        return triples[:max_triples]

    def _format_graph_context(self, triples: list) -> str:
        lines = []
        for subj, pred, obj, conf in triples:
            confidence_label = "high" if conf > 0.8 else "medium" if conf > 0.5 else "low"
            lines.append(f"- {subj} --[{pred}]--> {obj} (confidence: {confidence_label})")
        return "\n".join(lines)
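The `_format_graph_context` step is worth seeing on its own: retrieved triples are rendered as plain arrow lines, with confidence bucketed into labels rather than passed to the LLM as raw floats. The triples below are illustrative.

```python
# Rendering retrieved triples as LLM-ready context lines, bucketing
# confidence floats into high/medium/low labels.
triples = [("OpenAI", "developed", "GPT-4", 0.95),
           ("OpenAI", "partners_with", "Microsoft", 0.7),
           ("OpenAI", "competes_with", "Anthropic", 0.4)]

def label(conf):
    return "high" if conf > 0.8 else "medium" if conf > 0.5 else "low"

context = "\n".join(f"- {s} --[{p}]--> {o} (confidence: {label(c)})"
                    for s, p, o, c in triples)
print(context)
```

Verbal labels keep the prompt readable and discourage the model from over-interpreting small numeric differences in extraction confidence.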

Key Takeaways

  1. Knowledge graphs provide structured, queryable representations of relationships that enhance prediction accuracy
  2. Entity and relationship extraction from unstructured text is the critical bottleneck; LLMs have dramatically improved this
  3. Link prediction using graph topology (common neighbors, Jaccard, Adamic-Adar) provides strong baselines with no training required
  4. Knowledge graph embeddings (TransE and successors) learn dense representations enabling similarity-based prediction
  5. Temporal knowledge graphs capture how relationships evolve, enabling pattern-based temporal prediction
  6. GraphRAG combines the structured reasoning of knowledge graphs with the generative power of LLMs for enhanced forecasting
  7. Deduplication and entity resolution are essential maintenance tasks for production knowledge graphs
  8. Multi-hop reasoning through knowledge graphs reveals indirect connections that are invisible in flat data
