webleaf.model.WebGraphAutoEncoder
"""WebLeaf graph auto-encoder: embeds the elements of an HTML page with a pretrained GAE."""
import os
import re
from collections import deque
from pathlib import Path

import torch
import torch.nn.functional as F
from lxml import etree
from torch.nn import Linear
from torch_geometric.nn import GAE, GCN2Conv
from torch_geometric.utils import subgraph  # noqa: F401 - not referenced in this module

from .TagModel import TagEmbeddingModel, exclude_html_tags, html_tags, TAG_DIMS
from .TextModel import TextEmbeddingModel, TEXT_DIMS

# Pretrained GAE weights, stored next to this module.
MODEL_PATH = os.path.join(Path(__file__).parent.absolute(), "product_page_model_4_80.torch")
# Width of each node embedding produced by the encoder.
EMBEDDING_DIMENSIONS = 32

# Embedding models shared by every WebGraphAutoEncoder instance; created on first use.
tag_embedding_model = None
text_embedding_model = None


class GCNEncoder(torch.nn.Module):
    """Node encoder: input projection -> ``num_layers`` GCN2Conv layers -> output projection.

    Every GCN2Conv layer also receives the projected input features (``x_0``)
    as its initial-residual input.

    Args:
        input_channels: Size of each input node feature vector.
        hidden_channels: Width of the projected features and of every conv layer.
        output_channels: Size of each output node embedding.
        num_layers: Number of stacked GCN2Conv layers.
        alpha: Initial-residual strength passed to every GCN2Conv layer.
        theta: Identity-mapping hyperparameter passed to every GCN2Conv layer.
        shared_weights: Forwarded to GCN2Conv.
        dropout: Dropout probability applied before every linear/conv layer.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, num_layers, alpha, theta, shared_weights=True, dropout=0.0):
        super().__init__()

        # The attribute names "lins"/"convs" determine the state_dict keys of saved checkpoints.
        self.lins = torch.nn.ModuleList()
        self.lins.append(Linear(input_channels, hidden_channels))
        self.lins.append(Linear(hidden_channels, output_channels))

        self.convs = torch.nn.ModuleList()
        for layer in range(num_layers):
            # GCN2Conv's `layer` argument is 1-based. With normalize=False the layer
            # adds no self-loops and applies no symmetric normalization itself.
            self.convs.append(
                GCN2Conv(hidden_channels, alpha, theta, layer + 1,
                         shared_weights, normalize=False))

        self.dropout = dropout

    def forward(self, x, edge_index):
        """Encode node features over the graph.

        Args:
            x: Node features; the last dimension must equal ``input_channels``
                (presumably shaped (num_nodes, input_channels)).
            edge_index: Connectivity in PyG's (2, num_edges) format.

        Returns:
            Node embeddings with ``output_channels`` columns, one row per row of ``x``.
        """
        x = F.dropout(x, self.dropout, training=self.training)
        x = x_0 = self.lins[0](x).relu()

        for conv in self.convs:
            x = F.dropout(x, self.dropout, training=self.training)
            x = conv(x, x_0, edge_index)
            x = x.relu()

        x = F.dropout(x, self.dropout, training=self.training)
        x = self.lins[1](x)

        return x


class WebGraphAutoEncoder:
    """Embed the elements of an HTML page (an lxml tree) with a pretrained graph auto-encoder.

    Construction builds the encoder, loads the weights at ``MODEL_PATH`` onto
    CUDA when available (CPU otherwise) and switches the model to eval mode.
    The tag and text embedding models are created once and shared between
    instances.
    """

    # Everything except ASCII letters, whitespace and basic punctuation is removed by clean_text.
    _TEXT_FILTER = re.compile(r'[^a-zA-Z\s.,!?\'\";:]')
    # Attributes consulted, in order, when an element has no usable inner text.
    _TEXT_FALLBACK_ATTRIBUTES = ("alt", "title", "aria-label")

    def __init__(self):
        global tag_embedding_model, text_embedding_model
        if tag_embedding_model is None:
            tag_embedding_model = TagEmbeddingModel()
        if text_embedding_model is None:
            text_embedding_model = TextEmbeddingModel()

        # Layer sizes/count must match the checkpoint: load_state_dict below is strict.
        encoder = GCNEncoder(
            input_channels=TAG_DIMS + TEXT_DIMS,
            hidden_channels=128,
            output_channels=EMBEDDING_DIMENSIONS,
            num_layers=6,
            alpha=0.1,
            theta=0.5,
            shared_weights=True,
            dropout=0.1,
        )
        self.model = GAE(encoder)

        assert os.path.exists(MODEL_PATH), f"Could not find WebLeaf model at [{MODEL_PATH}]"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # torch.load unpickles the file, so MODEL_PATH must only ever point at trusted weights.
        self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
        self.model.eval()
        self.model = self.model.to(self.device)

    def extract(self, tree):
        """Encode the DOM of an lxml tree into one embedding per retained element.

        The tree is walked breadth-first from its root. Every kept element
        becomes a graph node with an edge from its kept parent. Node features
        are the element's text embedding concatenated with its tag embedding.
        The resulting graph is run through the pretrained encoder.

        NOTE: formatting tags (b, i, u, strong, ...) are stripped from ``tree`` in place.

        Args:
            tree: lxml ``ElementTree`` of the page.

        Returns:
            ``(features, paths)``. ``features`` is a CPU tensor of node embeddings,
            presumably one row per node in BFS order (assuming the embedding
            models return one vector per input, in order). ``paths[k]`` is the
            XPath (``tree.getpath``) of node ``k``; node 0 is the root.
        """
        assert tree, "Could not create tree"
        root = tree.getroot()

        # Remove presentational wrappers; lxml merges their text and children into the parent.
        formatting_tags = ('b', 'i', 'u', 'strong', 'em', 'mark', 'small', 'del', 'ins')
        etree.strip_tags(root, *formatting_tags)

        exclude_tag_lookup = set(exclude_html_tags)
        tag_lookup = set(html_tags)

        # Node 0 is the root; it carries no text.
        texts = [""]
        tags = [root.tag]
        paths = [tree.getpath(root)]
        edge_index = []

        # deque gives O(1) pops from the front of the BFS queue (list.pop(0) is O(n)).
        queue = deque([(root, 0)])
        while queue:
            element, parent_id = queue.popleft()

            for child in element:
                # Comments and processing instructions have a non-string tag.
                if not isinstance(child.tag, str):
                    continue

                # Collapse chains of single-child <div> wrappers onto the innermost
                # element, but never descend into a non-element (e.g. a lone comment).
                while child.tag == "div" and len(child) == 1 and isinstance(child[0].tag, str):
                    child = child[0]

                # Excluded elements are dropped together with their whole subtree.
                if child.tag in exclude_tag_lookup:
                    continue

                node_id = len(tags)
                # Tags outside the known vocabulary are embedded as generic <div>s.
                tags.append(child.tag if child.tag in tag_lookup else "div")
                # Node text is capped at 256 characters before embedding.
                texts.append(self.extract_text(child)[:256])
                paths.append(tree.getpath(child))
                edge_index.append([parent_id, node_id])
                queue.append((child, node_id))

        text_embeddings = text_embedding_model.get_text_embeddings(texts)
        tag_embeddings = tag_embedding_model.get_tag_embedding(tags)
        # Each node's feature vector is [text embedding | tag embedding].
        rows = [
            torch.cat((torch.from_numpy(text_vector), tag_embeddings[i]))
            for i, text_vector in enumerate(text_embeddings)
        ]
        input_features = torch.stack(rows).to(self.device)

        # Shape (2, num_edges). reshape(-1, 2) keeps this valid when no edges were
        # collected (root without kept children), where permute(1, 0) on the
        # empty 1-D tensor would raise.
        input_edge_index = torch.tensor(edge_index, dtype=torch.int64).reshape(-1, 2).t()
        input_edge_index = input_edge_index.to(self.device)

        with torch.no_grad():
            features = self.model.encode(input_features, edge_index=input_edge_index).cpu()
        # Drop the device-side inputs first so empty_cache can actually release their memory.
        del input_features, input_edge_index
        torch.cuda.empty_cache()
        return features, paths

    def clean_text(self, text):
        """Remove disallowed characters (see ``_TEXT_FILTER``) and collapse whitespace runs.

        Returns None when ``text`` is falsy (None or empty).
        """
        if not text:
            return None
        return ' '.join(self._TEXT_FILTER.sub('', text).split())

    def extract_text(self, element) -> str:
        """Return the cleaned text of ``element``.

        Falls back to the ``alt``, ``title`` and ``aria-label`` attributes, in
        that order. Returns "" when none of them yields any text.
        """
        text = self.clean_text(element.text)
        if text:
            return text

        for attribute in self._TEXT_FALLBACK_ATTRIBUTES:
            text = self.clean_text(element.get(attribute))
            if text:
                return text
        return ""
class GCNEncoder(torch.nn.Module):
    """Node encoder: input projection -> ``num_layers`` GCN2Conv layers -> output projection.

    Every GCN2Conv layer also receives the projected input features as its
    initial-residual input.

    Args:
        input_channels: Size of each input node feature vector.
        hidden_channels: Width of the projected features and of every conv layer.
        output_channels: Size of each output node embedding.
        num_layers: Number of stacked GCN2Conv layers.
        alpha: Initial-residual strength passed to every GCN2Conv layer.
        theta: Identity-mapping hyperparameter passed to every GCN2Conv layer.
        shared_weights: Forwarded to GCN2Conv.
        dropout: Dropout probability applied before every linear/conv layer.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, num_layers, alpha, theta, shared_weights=True, dropout=0.0):
        super().__init__()
        # "lins"/"convs" name the state_dict keys, so checkpoints depend on them.
        self.lins = torch.nn.ModuleList([
            Linear(input_channels, hidden_channels),
            Linear(hidden_channels, output_channels),
        ])
        # GCN2Conv's `layer` argument is 1-based; normalize=False means the layer
        # neither adds self-loops nor normalizes the edges itself.
        self.convs = torch.nn.ModuleList(
            GCN2Conv(hidden_channels, alpha, theta, depth, shared_weights, normalize=False)
            for depth in range(1, num_layers + 1)
        )
        self.dropout = dropout

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the (2, num_edges) ``edge_index`` graph."""
        def drop(tensor):
            return F.dropout(tensor, self.dropout, training=self.training)

        project_in, project_out = self.lins
        hidden = project_in(drop(x)).relu()
        initial = hidden
        for conv in self.convs:
            hidden = conv(drop(hidden), initial, edge_index).relu()
        return project_out(drop(hidden))
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Example Module that registers two Conv2d submodules as plain attributes."""

    def __init__(self):
        # The parent constructor must run before any submodule is assigned.
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 and conv2, each followed by ReLU."""
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will have their parameters converted too when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
Attribute `training` (`bool`): indicates whether this module is in training or evaluation mode.
def __init__(self, input_channels, hidden_channels, output_channels, num_layers, alpha, theta, shared_weights=True, dropout=0.0):
    """Build the input/output projections and the stack of GCN2Conv layers.

    Args:
        input_channels: Size of each input node feature vector.
        hidden_channels: Width of the projected features and of every conv layer.
        output_channels: Size of each output node embedding.
        num_layers: Number of stacked GCN2Conv layers.
        alpha: Initial-residual strength passed to every GCN2Conv layer.
        theta: Identity-mapping hyperparameter passed to every GCN2Conv layer.
        shared_weights: Forwarded to GCN2Conv.
        dropout: Dropout probability stored for use in forward().
    """
    super().__init__()
    # "lins"/"convs" name the state_dict keys, so checkpoints depend on them.
    self.lins = torch.nn.ModuleList([
        Linear(input_channels, hidden_channels),
        Linear(hidden_channels, output_channels),
    ])
    # GCN2Conv's `layer` argument is 1-based; normalize=False leaves edges un-normalized.
    self.convs = torch.nn.ModuleList(
        GCN2Conv(hidden_channels, alpha, theta, depth, shared_weights, normalize=False)
        for depth in range(1, num_layers + 1)
    )
    self.dropout = dropout
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x, edge_index):
    """Encode node features over the graph.

    Pipeline: dropout -> input projection + ReLU, then for each conv layer
    dropout -> conv (with the projected input as initial residual) -> ReLU,
    and finally dropout -> output projection.

    Args:
        x: Node features; the last dimension must equal the first projection's input size.
        edge_index: Connectivity in PyG's (2, num_edges) format.

    Returns:
        Node embeddings, one row per row of ``x``.
    """
    def drop(tensor):
        return F.dropout(tensor, self.dropout, training=self.training)

    project_in, project_out = self.lins
    hidden = project_in(drop(x)).relu()
    initial = hidden
    for conv in self.convs:
        hidden = conv(drop(hidden), initial, edge_index).relu()
    return project_out(drop(hidden))
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class WebGraphAutoEncoder:
    """Embed the elements of an HTML page (an lxml tree) with a pretrained graph auto-encoder.

    Construction builds the encoder, loads the weights at ``MODEL_PATH`` onto
    CUDA when available (CPU otherwise) and switches the model to eval mode.
    The tag and text embedding models are created once and shared between
    instances.
    """

    # Everything except ASCII letters, whitespace and basic punctuation is removed by clean_text.
    _TEXT_FILTER = re.compile(r'[^a-zA-Z\s.,!?\'\";:]')
    # Attributes consulted, in order, when an element has no usable inner text.
    _TEXT_FALLBACK_ATTRIBUTES = ("alt", "title", "aria-label")

    def __init__(self):
        global tag_embedding_model, text_embedding_model
        if tag_embedding_model is None:
            tag_embedding_model = TagEmbeddingModel()
        if text_embedding_model is None:
            text_embedding_model = TextEmbeddingModel()

        # Layer sizes/count must match the checkpoint: load_state_dict below is strict.
        encoder = GCNEncoder(
            input_channels=TAG_DIMS + TEXT_DIMS,
            hidden_channels=128,
            output_channels=EMBEDDING_DIMENSIONS,
            num_layers=6,
            alpha=0.1,
            theta=0.5,
            shared_weights=True,
            dropout=0.1,
        )
        self.model = GAE(encoder)

        assert os.path.exists(MODEL_PATH), f"Could not find WebLeaf model at [{MODEL_PATH}]"
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # torch.load unpickles the file, so MODEL_PATH must only ever point at trusted weights.
        self.model.load_state_dict(torch.load(MODEL_PATH, map_location=self.device))
        self.model.eval()
        self.model = self.model.to(self.device)

    def extract(self, tree):
        """Encode the DOM of an lxml tree into one embedding per retained element.

        The tree is walked breadth-first from its root. Every kept element
        becomes a graph node with an edge from its kept parent. Node features
        are the element's text embedding concatenated with its tag embedding.
        The resulting graph is run through the pretrained encoder.

        NOTE: formatting tags (b, i, u, strong, ...) are stripped from ``tree`` in place.

        Args:
            tree: lxml ``ElementTree`` of the page.

        Returns:
            ``(features, paths)``. ``features`` is a CPU tensor of node embeddings,
            presumably one row per node in BFS order (assuming the embedding
            models return one vector per input, in order). ``paths[k]`` is the
            XPath (``tree.getpath``) of node ``k``; node 0 is the root.
        """
        from collections import deque

        assert tree, "Could not create tree"
        root = tree.getroot()

        # Remove presentational wrappers; lxml merges their text and children into the parent.
        formatting_tags = ('b', 'i', 'u', 'strong', 'em', 'mark', 'small', 'del', 'ins')
        etree.strip_tags(root, *formatting_tags)

        exclude_tag_lookup = set(exclude_html_tags)
        tag_lookup = set(html_tags)

        # Node 0 is the root; it carries no text.
        texts = [""]
        tags = [root.tag]
        paths = [tree.getpath(root)]
        edge_index = []

        # deque gives O(1) pops from the front of the BFS queue (list.pop(0) is O(n)).
        queue = deque([(root, 0)])
        while queue:
            element, parent_id = queue.popleft()

            for child in element:
                # Comments and processing instructions have a non-string tag.
                if not isinstance(child.tag, str):
                    continue

                # Collapse chains of single-child <div> wrappers onto the innermost
                # element, but never descend into a non-element (e.g. a lone comment).
                while child.tag == "div" and len(child) == 1 and isinstance(child[0].tag, str):
                    child = child[0]

                # Excluded elements are dropped together with their whole subtree.
                if child.tag in exclude_tag_lookup:
                    continue

                node_id = len(tags)
                # Tags outside the known vocabulary are embedded as generic <div>s.
                tags.append(child.tag if child.tag in tag_lookup else "div")
                # Node text is capped at 256 characters before embedding.
                texts.append(self.extract_text(child)[:256])
                paths.append(tree.getpath(child))
                edge_index.append([parent_id, node_id])
                queue.append((child, node_id))

        text_embeddings = text_embedding_model.get_text_embeddings(texts)
        tag_embeddings = tag_embedding_model.get_tag_embedding(tags)
        # Each node's feature vector is [text embedding | tag embedding].
        rows = [
            torch.cat((torch.from_numpy(text_vector), tag_embeddings[i]))
            for i, text_vector in enumerate(text_embeddings)
        ]
        input_features = torch.stack(rows).to(self.device)

        # Shape (2, num_edges). reshape(-1, 2) keeps this valid when no edges were
        # collected (root without kept children), where permute(1, 0) on the
        # empty 1-D tensor would raise.
        input_edge_index = torch.tensor(edge_index, dtype=torch.int64).reshape(-1, 2).t()
        input_edge_index = input_edge_index.to(self.device)

        with torch.no_grad():
            features = self.model.encode(input_features, edge_index=input_edge_index).cpu()
        # Drop the device-side inputs first so empty_cache can actually release their memory.
        del input_features, input_edge_index
        torch.cuda.empty_cache()
        return features, paths

    def clean_text(self, text):
        """Remove disallowed characters (see ``_TEXT_FILTER``) and collapse whitespace runs.

        Returns None when ``text`` is falsy (None or empty).
        """
        if not text:
            return None
        return ' '.join(self._TEXT_FILTER.sub('', text).split())

    def extract_text(self, element) -> str:
        """Return the cleaned text of ``element``.

        Falls back to the ``alt``, ``title`` and ``aria-label`` attributes, in
        that order. Returns "" when none of them yields any text.
        """
        text = self.clean_text(element.text)
        if text:
            return text

        for attribute in self._TEXT_FALLBACK_ATTRIBUTES:
            text = self.clean_text(element.get(attribute))
            if text:
                return text
        return ""
def extract(self, tree):
    """Encode the DOM of an lxml tree into one embedding per retained element.

    The tree is walked breadth-first from its root. Every kept element becomes
    a graph node with an edge from its kept parent. Node features are the
    element's text embedding concatenated with its tag embedding. The
    resulting graph is run through the pretrained encoder.

    NOTE: formatting tags (b, i, u, strong, ...) are stripped from ``tree`` in place.

    Args:
        tree: lxml ``ElementTree`` of the page.

    Returns:
        ``(features, paths)``. ``features`` is a CPU tensor of node embeddings,
        presumably one row per node in BFS order (assuming the embedding models
        return one vector per input, in order). ``paths[k]`` is the XPath
        (``tree.getpath``) of node ``k``; node 0 is the root.
    """
    from collections import deque

    assert tree, "Could not create tree"
    root = tree.getroot()

    # Remove presentational wrappers; lxml merges their text and children into the parent.
    formatting_tags = ('b', 'i', 'u', 'strong', 'em', 'mark', 'small', 'del', 'ins')
    etree.strip_tags(root, *formatting_tags)

    exclude_tag_lookup = set(exclude_html_tags)
    tag_lookup = set(html_tags)

    # Node 0 is the root; it carries no text.
    texts = [""]
    tags = [root.tag]
    paths = [tree.getpath(root)]
    edge_index = []

    # deque gives O(1) pops from the front of the BFS queue (list.pop(0) is O(n)).
    queue = deque([(root, 0)])
    while queue:
        element, parent_id = queue.popleft()

        for child in element:
            # Comments and processing instructions have a non-string tag.
            if not isinstance(child.tag, str):
                continue

            # Collapse chains of single-child <div> wrappers onto the innermost
            # element, but never descend into a non-element (e.g. a lone comment).
            while child.tag == "div" and len(child) == 1 and isinstance(child[0].tag, str):
                child = child[0]

            # Excluded elements are dropped together with their whole subtree.
            if child.tag in exclude_tag_lookup:
                continue

            node_id = len(tags)
            # Tags outside the known vocabulary are embedded as generic <div>s.
            tags.append(child.tag if child.tag in tag_lookup else "div")
            # Node text is capped at 256 characters before embedding.
            texts.append(self.extract_text(child)[:256])
            paths.append(tree.getpath(child))
            edge_index.append([parent_id, node_id])
            queue.append((child, node_id))

    text_embeddings = text_embedding_model.get_text_embeddings(texts)
    tag_embeddings = tag_embedding_model.get_tag_embedding(tags)
    # Each node's feature vector is [text embedding | tag embedding].
    rows = [
        torch.cat((torch.from_numpy(text_vector), tag_embeddings[i]))
        for i, text_vector in enumerate(text_embeddings)
    ]
    input_features = torch.stack(rows).to(self.device)

    # Shape (2, num_edges). reshape(-1, 2) keeps this valid when no edges were
    # collected (root without kept children), where permute(1, 0) on the empty
    # 1-D tensor would raise.
    input_edge_index = torch.tensor(edge_index, dtype=torch.int64).reshape(-1, 2).t()
    input_edge_index = input_edge_index.to(self.device)

    with torch.no_grad():
        features = self.model.encode(input_features, edge_index=input_edge_index).cpu()
    # Drop the device-side inputs first so empty_cache can actually release their memory.
    del input_features, input_edge_index
    torch.cuda.empty_cache()
    return features, paths