webleaf.model.WebGraphAutoEncoder

from torch_geometric.nn import GAE, GCN2Conv
import torch.nn.functional as F
from .TagModel import TagEmbeddingModel, exclude_html_tags, html_tags, TAG_DIMS
from .TextModel import TextEmbeddingModel, TEXT_DIMS
from torch.nn import Linear
import os
from lxml import etree
from pathlib import Path
import torch
import re

MODEL_PATH = os.path.join(Path(__file__).parent.absolute(), "product_page_model_4_80.torch")
EMBEDDING_DIMENSIONS = 32

# Shared singletons so repeated WebGraphAutoEncoder instances reuse the embedding models.
tag_embedding_model = None
text_embedding_model = None

class GCNEncoder(torch.nn.Module):
    """GCNII-style encoder: Linear -> stacked GCN2Conv with initial residual -> Linear."""

    def __init__(self, input_channels, hidden_channels, output_channels, num_layers, alpha, theta, shared_weights=True, dropout=0.0):
        super().__init__()

        # Input and output projections around the convolution stack.
        self.lins = torch.nn.ModuleList()
        self.lins.append(Linear(input_channels, hidden_channels))
        self.lins.append(Linear(hidden_channels, output_channels))

        # GCN2Conv takes the 1-based layer index to scale its identity mapping.
        self.convs = torch.nn.ModuleList()
        for layer in range(num_layers):
            self.convs.append(
                GCN2Conv(hidden_channels, alpha, theta, layer + 1,
                         shared_weights, normalize=False))

        self.dropout = dropout

    def forward(self, x, edge_index):
        x = F.dropout(x, self.dropout, training=self.training)
        x = x_0 = self.lins[0](x).relu()  # x_0 is the initial residual fed to every layer

        for conv in self.convs:
            x = F.dropout(x, self.dropout, training=self.training)
            x = conv(x, x_0, edge_index)
            x = x.relu()

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.lins[1](x)

        return x
 50
class WebGraphAutoEncoder:
    def __init__(self):
        # Lazily build the shared tag/text embedding models on first use.
        global tag_embedding_model
        if not tag_embedding_model:
            tag_embedding_model = TagEmbeddingModel()

        global text_embedding_model
        if not text_embedding_model:
            text_embedding_model = TextEmbeddingModel()

        num_features = TAG_DIMS + TEXT_DIMS
        hidden = 128
        out_channels = EMBEDDING_DIMENSIONS
        encoder = GCNEncoder(input_channels=num_features, hidden_channels=hidden, output_channels=out_channels, num_layers=6, alpha=0.1, theta=0.5, shared_weights=True, dropout=0.1)
        self.model = GAE(encoder)

        assert os.path.exists(MODEL_PATH), f"Could not find WebLeaf model at [{MODEL_PATH}]"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
        self.model.eval()
        self.model = self.model.to(self.device)

    def extract(self, tree):
        assert tree is not None, "Could not create tree"
        root = tree.getroot()

        # Strip formatting tags so their text merges into the parent element.
        formatting_tags = ['b', 'i', 'u', 'strong', 'em', 'mark', 'small', 'del', 'ins']
        etree.strip_tags(root, *formatting_tags)

        exclude_tag_lookup = set(exclude_html_tags)
        tag_lookup = set(html_tags)

        # Breadth-first traversal; node i's features sit at index i of texts/tags/paths.
        queue = [(root, 0)]
        i = 0
        texts = [""]
        tags = [root.tag]
        edge_index = []
        paths = [tree.getpath(root)]
        while queue:
            element, parent_id = queue.pop(0)

            for child in element:
                if isinstance(child, etree._Comment):
                    continue

                # Collapse chains of single-child <div> wrappers.
                while child.tag == "div" and len(child) == 1:
                    child = child[0]

                if child.tag not in exclude_tag_lookup:
                    tag = child.tag
                    if tag not in tag_lookup:
                        tag = "div"  # treat unknown tags as generic containers
                    tags.append(tag)
                    text = self.extract_text(child)[:256]
                    texts.append(text)
                    paths.append(tree.getpath(child))
                    i += 1
                    edge_index.append([parent_id, i])
                    queue.append((child, i))

        # Each node's feature vector is its text embedding concatenated with its tag embedding.
        text_embeddings = text_embedding_model.get_text_embeddings(texts)
        tag_embeddings = tag_embedding_model.get_tag_embedding(tags)
        x = []
        for i in range(len(text_embeddings)):
            x.append(torch.concatenate((torch.from_numpy(text_embeddings[i]), tag_embeddings[i])))

        input_features = torch.stack(x).to(self.device)
        input_edge_index = torch.tensor(edge_index, dtype=torch.int64).permute(1, 0)
        input_edge_index = input_edge_index.to(self.device)

        with torch.no_grad():
            features = self.model.encode(input_features, edge_index=input_edge_index).cpu().detach()
        torch.cuda.empty_cache()
        del input_features, input_edge_index
        return features, paths

    def clean_text(self, text):
        if not text:
            return ""
        # Keep ASCII letters, whitespace, and basic punctuation; collapse whitespace runs.
        cleaned_text = ' '.join(re.sub(r'[^a-zA-Z\s.,!?\'\";:]', '', text).split())
        return cleaned_text

    def extract_text(self, element) -> str:
        text = self.clean_text(element.text)
        if text:
            return text

        # Fall back to common descriptive attributes.
        for label in ["alt", "title", "aria-label"]:
            text = self.clean_text(element.get(label))
            if text:
                return text
        return ""
class GCNEncoder(torch.nn.modules.module.Module):
A GCNII-style graph encoder: a Linear projection into the hidden space, a stack of GCN2Conv layers that each mix the current node features with the initial projection, and a final Linear projection down to the embedding size.

GCNEncoder(input_channels, hidden_channels, output_channels, num_layers, alpha, theta, shared_weights=True, dropout=0.0)
Builds the two Linear projections and num_layers GCN2Conv layers, with dropout applied between all of them in forward. alpha sets the strength of the initial residual connection and theta the strength of the identity mapping; each convolution also receives its 1-based layer index.
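
The 1-based layer index matters: per the GCNII formulation that GCN2Conv implements, the identity-mapping weight is beta = log(theta / layer + 1), so deeper layers stay closer to the identity. A quick illustration with this module's theta = 0.5:

import math

theta = 0.5
for layer in range(1, 7):               # the encoder is built with num_layers=6
    beta = math.log(theta / layer + 1)  # identity-mapping strength at this depth
    print(layer, round(beta, 3))
# 1 0.405, 2 0.223, 3 0.154, 4 0.118, 5 0.095, 6 0.08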
def forward(self, x, edge_index):
Projects the input features to the hidden size, keeps the result as the initial residual x_0, runs each GCN2Conv over (x, x_0, edge_index) with ReLU activations and dropout in between, and finally projects to the output embedding size.
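
A toy shape check (hypothetical sizes; the shipped encoder uses 128 hidden and 32 output channels over 6 layers):

import torch

enc = GCNEncoder(input_channels=16, hidden_channels=32, output_channels=8,
                 num_layers=2, alpha=0.1, theta=0.5)
x = torch.randn(5, 16)                    # 5 nodes, 16 features each
edge_index = torch.tensor([[0, 0, 1, 2],  # source (parent) node ids
                           [1, 2, 3, 4]]) # target (child) node ids
print(enc(x, edge_index).shape)           # torch.Size([5, 8])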

class WebGraphAutoEncoder:
Wraps GCNEncoder in a torch_geometric GAE (stored as model), lazily constructs the shared tag and text embedding models, asserts that the pretrained weights exist at MODEL_PATH, and loads them onto device (CUDA when available, CPU otherwise).
def extract(self, tree):
Strips inline formatting tags so their text merges into the parent, then walks the DOM breadth-first: comments and excluded tags are skipped, chains of single-child div wrappers are collapsed, and unknown tags are mapped to div. Each kept node contributes its tag, up to 256 characters of extracted text, its XPath, and a parent-to-child edge. The text and tag embeddings are concatenated per node, encoded by the GAE on the model's device, and returned together with the XPaths.
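
A hand-traced sketch of what the traversal collects for a tiny document, assuming body, p, and a are all in html_tags and none are excluded:

# Input: <html><body><div><p>Hi</p></div><a href="#">link</a></body></html>
# The single-child <div> is collapsed into its <p> child, but the XPath
# still reflects the node's original position in the tree.
tags       = ["html", "body", "p", "a"]
texts      = ["", "", "Hi", "link"]
edge_index = [[0, 1], [1, 2], [1, 3]]  # [parent_id, child_id], BFS order
paths      = ["/html", "/html/body", "/html/body/div/p", "/html/body/a"]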
def clean_text(self, text):
Removes every character that is not an ASCII letter, whitespace, or basic punctuation, then collapses runs of whitespace into single spaces. Returns an empty string for empty input.
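
A standalone restatement of the cleaning logic, for illustration only:

import re

def clean_text(text):
    if not text:
        return ""
    # Drop everything that is not an ASCII letter, whitespace, or basic punctuation.
    return ' '.join(re.sub(r'[^a-zA-Z\s.,!?\'\";:]', '', text).split())

print(clean_text("Café – 50% off!"))     # -> "Caf off!"
print(clean_text("  Hello,\n  world "))  # -> "Hello, world"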
def extract_text(self, element) -> str:
Returns the element's cleaned text if present, otherwise falls back to the alt, title, and aria-label attributes, and finally to an empty string.