webleaf.model.TagModel
import os
from pathlib import Path

import torch
import torch.nn as nn

# Absolute path of the pre-trained tag-embedding weights, stored next to this module.
TAG_PATH = os.path.join(Path(__file__).parent.absolute(), "tag_embeddings.torch")
# The dimensionality of the tag embeddings
TAG_DIMS = 8

# List of HTML tags that will be embedded. A tag's position in this list is the row
# it selects in the embedding table, so the order must match the saved weights.
html_tags = [
    'a', 'abbr', 'address', 'area', 'article', 'aside', 'audio', 'b', 'base', 'bdi', 'bdo', 'blockquote',
    'body', 'br', 'button', 'canvas', 'caption', 'cite', 'code', 'col', 'colgroup', 'data', 'datalist', 'dd',
    'del', 'details', 'dfn', 'dialog', 'div', 'dl', 'dt', 'em', 'embed', 'fieldset', 'figcaption', 'figure',
    'footer', 'form', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'head', 'header', 'hr', 'html', 'i', 'iframe', 'img',
    'input', 'ins', 'kbd', 'label', 'legend', 'li', 'link', 'main', 'map', 'mark', 'meter', 'nav', 'noscript',
    'object', 'ol', 'optgroup', 'option', 'output', 'p', 'param', 'picture', 'pre', 'progress', 'q', 'rp', 'rt', 'ruby',
    's', 'samp', 'section', 'select', 'small', 'source', 'span', 'strong', 'sub', 'summary', 'sup',
    'table', 'tbody', 'td', 'template', 'textarea', 'tfoot', 'th', 'thead', 'time', 'title', 'tr', 'track', 'u', 'ul',
    'var', 'video', 'wbr'
]

exclude_html_tags = [
    "script", "meta", "style", "svg",
]


class NormalizedEmbedding(nn.Module):
    """
    A PyTorch module that creates an embedding layer and scales each output vector to unit (L2) length.

    Attributes:
    -----------
    embedding : torch.nn.Embedding
        The embedding layer that learns to map each HTML tag to a vector in a fixed-dimensional space.

    Methods:
    --------
    forward(x):
        Performs the forward pass by computing embeddings for the input and normalizing them.
    """
    def __init__(self, n_classes, m_dimensions):
        """
        Initializes the NormalizedEmbedding class.

        Parameters:
        -----------
        n_classes : int
            The number of classes (i.e., HTML tags) to embed.
        m_dimensions : int
            The number of dimensions in the embedding space.
        """
        super(NormalizedEmbedding, self).__init__()
        # Create the embedding layer
        self.embedding = nn.Embedding(n_classes, m_dimensions)

        # Re-initialize the freshly created weights with Xavier-uniform values
        nn.init.xavier_uniform_(self.embedding.weight)

    def forward(self, x):
        """
        Forward pass of the embedding layer with normalization.

        Parameters:
        -----------
        x : torch.Tensor
            Integer tensor of indices into the embedding table. Any shape is accepted,
            including a scalar or a batch of index sequences.

        Returns:
        --------
        torch.Tensor
            Embeddings of shape ``x.shape + (m_dimensions,)``. Every vector along the
            last axis has unit L2 norm; an all-zero row stays zero.
        """
        embed = self.embedding(x)

        # Normalize along the embedding axis (always the last one) so batched and scalar
        # inputs work, not only 1-D index lists. Clamp the norm so an all-zero row
        # yields zeros instead of NaNs.
        norms = embed.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        return embed / norms


class TagEmbeddingModel:
    """
    A class for managing tag embeddings for HTML tags. It loads a pre-trained embedding model and
    provides a method to retrieve embeddings for specific HTML tags.

    Attributes:
    -----------
    embedding_model : NormalizedEmbedding
        The model that handles tag embeddings and their normalization.

    Methods:
    --------
    get_tag_embedding(tags):
        Retrieves the embeddings for a list of HTML tags.
    """
    def __init__(self):
        """
        Initializes the TagEmbeddingModel class by loading a pre-trained embedding model.

        Raises:
        -------
        AssertionError
            If the pre-trained weights file does not exist at ``TAG_PATH``.
        """
        self.embedding_model = NormalizedEmbedding(len(html_tags), TAG_DIMS)
        # Raise explicitly rather than via ``assert`` so the check survives ``python -O``.
        # AssertionError is kept because it is the documented error type callers may catch.
        if not os.path.exists(TAG_PATH):
            raise AssertionError(f"Could not find tag model at [{TAG_PATH}]")
        # map_location="cpu": the module above is on the CPU, so load tensors there even if
        # they were saved from a GPU. weights_only=True: the file is consumed as a plain
        # state_dict, so refuse to unpickle arbitrary objects.
        state_dict = torch.load(TAG_PATH, map_location="cpu", weights_only=True)
        self.embedding_model.load_state_dict(state_dict)

    def get_tag_embedding(self, tags):
        """
        Retrieves the embeddings for the provided HTML tags.

        Parameters:
        -----------
        tags : list of str
            A list of HTML tags (each must appear in ``html_tags``) for which to retrieve
            the embeddings. May be empty.

        Returns:
        --------
        torch.Tensor
            A tensor of shape ``(len(tags), TAG_DIMS)`` containing the embeddings for the
            specified HTML tags.

        Raises:
        -------
        ValueError
            If a tag is not in ``html_tags``.
        """
        indices = []
        for tag in tags:
            try:
                indices.append(html_tags.index(tag))
            except ValueError:
                raise ValueError(f"Unknown HTML tag {tag!r}; it is not in html_tags") from None
        # dtype=torch.long so an empty ``tags`` list still yields a valid index tensor
        # (torch.tensor([]) defaults to float, which nn.Embedding rejects).
        return self.embedding_model(torch.tensor(indices, dtype=torch.long))
class NormalizedEmbedding(nn.Module):
    """
    A PyTorch module that creates an embedding layer and scales each output vector to unit (L2) length.

    Attributes:
    -----------
    embedding : torch.nn.Embedding
        The embedding layer that learns to map each HTML tag to a vector in a fixed-dimensional space.

    Methods:
    --------
    forward(x):
        Performs the forward pass by computing embeddings for the input and normalizing them.
    """
    def __init__(self, n_classes, m_dimensions):
        """
        Initializes the NormalizedEmbedding class.

        Parameters:
        -----------
        n_classes : int
            The number of classes (i.e., HTML tags) to embed.
        m_dimensions : int
            The number of dimensions in the embedding space.
        """
        super(NormalizedEmbedding, self).__init__()
        # Create the embedding layer
        self.embedding = nn.Embedding(n_classes, m_dimensions)

        # Re-initialize the freshly created weights with Xavier-uniform values
        nn.init.xavier_uniform_(self.embedding.weight)

    def forward(self, x):
        """
        Forward pass of the embedding layer with normalization.

        Parameters:
        -----------
        x : torch.Tensor
            Integer tensor of indices into the embedding table. Any shape is accepted,
            including a scalar or a batch of index sequences.

        Returns:
        --------
        torch.Tensor
            Embeddings of shape ``x.shape + (m_dimensions,)``. Every vector along the
            last axis has unit L2 norm; an all-zero row stays zero.
        """
        embed = self.embedding(x)

        # Normalize along the embedding axis (always the last one) so batched and scalar
        # inputs work, not only 1-D index lists. Clamp the norm so an all-zero row
        # yields zeros instead of NaNs.
        norms = embed.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        return embed / norms
A PyTorch module that creates an embedding layer and scales each output vector to unit (L2) length.
Attributes:
embedding : torch.nn.Embedding The embedding layer that learns to map each HTML tag to a vector in a fixed-dimensional space.
Methods:
forward(x): Performs the forward pass by computing embeddings for the input and normalizing them.
def __init__(self, n_classes, m_dimensions):
    """
    Build the embedding table for ``NormalizedEmbedding``.

    Parameters:
    -----------
    n_classes : int
        How many distinct HTML tags (table rows) the embedding holds.
    m_dimensions : int
        Length of each embedding vector (table columns).
    """
    super(NormalizedEmbedding, self).__init__()
    # Build the lookup table and overwrite its starting weights with Xavier-uniform
    # values before attaching it to the module.
    table = nn.Embedding(n_classes, m_dimensions)
    nn.init.xavier_uniform_(table.weight)
    self.embedding = table
Initializes the NormalizedEmbedding class.
Parameters:
n_classes : int
    The number of classes (i.e., HTML tags) to embed.
m_dimensions : int
    The number of dimensions in the embedding space.
def forward(self, x):
    """
    Forward pass of the embedding layer with normalization.

    Parameters:
    -----------
    x : torch.Tensor
        Integer tensor of indices into the embedding table. Any shape is accepted,
        including a scalar or a batch of index sequences.

    Returns:
    --------
    torch.Tensor
        Embeddings of shape ``x.shape + (m_dimensions,)``. Every vector along the
        last axis has unit L2 norm; an all-zero row stays zero.
    """
    embed = self.embedding(x)

    # Normalize along the embedding axis (always the last one) so batched and scalar
    # inputs work, not only 1-D index lists. Clamp the norm so an all-zero row
    # yields zeros instead of NaNs.
    norms = embed.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    return embed / norms
Forward pass of the embedding layer with normalization.
Parameters:
x : torch.Tensor A tensor containing indices of HTML tags to be embedded.
Returns:
torch.Tensor A tensor of normalized embeddings for the input tags.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- training
- call_super_init
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_full_backward_pre_hook
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- register_state_dict_pre_hook
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
- compile
class TagEmbeddingModel:
    """
    A class for managing tag embeddings for HTML tags. It loads a pre-trained embedding model and
    provides a method to retrieve embeddings for specific HTML tags.

    Attributes:
    -----------
    embedding_model : NormalizedEmbedding
        The model that handles tag embeddings and their normalization.

    Methods:
    --------
    get_tag_embedding(tags):
        Retrieves the embeddings for a list of HTML tags.
    """
    def __init__(self):
        """
        Initializes the TagEmbeddingModel class by loading a pre-trained embedding model.

        Raises:
        -------
        AssertionError
            If the pre-trained weights file does not exist at ``TAG_PATH``.
        """
        self.embedding_model = NormalizedEmbedding(len(html_tags), TAG_DIMS)
        # Raise explicitly rather than via ``assert`` so the check survives ``python -O``.
        # AssertionError is kept because it is the documented error type callers may catch.
        if not os.path.exists(TAG_PATH):
            raise AssertionError(f"Could not find tag model at [{TAG_PATH}]")
        # map_location="cpu": the module above is on the CPU, so load tensors there even if
        # they were saved from a GPU. weights_only=True: the file is consumed as a plain
        # state_dict, so refuse to unpickle arbitrary objects.
        state_dict = torch.load(TAG_PATH, map_location="cpu", weights_only=True)
        self.embedding_model.load_state_dict(state_dict)

    def get_tag_embedding(self, tags):
        """
        Retrieves the embeddings for the provided HTML tags.

        Parameters:
        -----------
        tags : list of str
            A list of HTML tags (each must appear in ``html_tags``) for which to retrieve
            the embeddings. May be empty.

        Returns:
        --------
        torch.Tensor
            A tensor of shape ``(len(tags), TAG_DIMS)`` containing the embeddings for the
            specified HTML tags.

        Raises:
        -------
        ValueError
            If a tag is not in ``html_tags``.
        """
        indices = []
        for tag in tags:
            try:
                indices.append(html_tags.index(tag))
            except ValueError:
                raise ValueError(f"Unknown HTML tag {tag!r}; it is not in html_tags") from None
        # dtype=torch.long so an empty ``tags`` list still yields a valid index tensor
        # (torch.tensor([]) defaults to float, which nn.Embedding rejects).
        return self.embedding_model(torch.tensor(indices, dtype=torch.long))
A class for managing tag embeddings for HTML tags. It loads a pre-trained embedding model and provides a method to retrieve embeddings for specific HTML tags.
Attributes:
embedding_model : NormalizedEmbedding The model that handles tag embeddings and their normalization.
Methods:
get_tag_embedding(tags): Retrieves the embeddings for a list of HTML tags.
def __init__(self):
    """
    Initializes the TagEmbeddingModel class by loading a pre-trained embedding model.

    Raises:
    -------
    AssertionError
        If the pre-trained weights file does not exist at ``TAG_PATH``.
    """
    self.embedding_model = NormalizedEmbedding(len(html_tags), TAG_DIMS)
    # Raise explicitly rather than via ``assert`` so the check survives ``python -O``.
    # AssertionError is kept because it is the documented error type callers may catch.
    if not os.path.exists(TAG_PATH):
        raise AssertionError(f"Could not find tag model at [{TAG_PATH}]")
    # map_location="cpu": the module above is on the CPU, so load tensors there even if
    # they were saved from a GPU. weights_only=True: the file is consumed as a plain
    # state_dict, so refuse to unpickle arbitrary objects.
    state_dict = torch.load(TAG_PATH, map_location="cpu", weights_only=True)
    self.embedding_model.load_state_dict(state_dict)
Initializes the TagEmbeddingModel class by loading a pre-trained embedding model.
Raises:
AssertionError if the pre-trained model cannot be found at the specified path.
def get_tag_embedding(self, tags):
    """
    Retrieves the embeddings for the provided HTML tags.

    Parameters:
    -----------
    tags : list of str
        A list of HTML tags (each must appear in ``html_tags``) for which to retrieve
        the embeddings. May be empty.

    Returns:
    --------
    torch.Tensor
        A tensor of shape ``(len(tags), TAG_DIMS)`` containing the embeddings for the
        specified HTML tags.

    Raises:
    -------
    ValueError
        If a tag is not in ``html_tags``.
    """
    indices = []
    for tag in tags:
        try:
            indices.append(html_tags.index(tag))
        except ValueError:
            raise ValueError(f"Unknown HTML tag {tag!r}; it is not in html_tags") from None
    # dtype=torch.long so an empty ``tags`` list still yields a valid index tensor
    # (torch.tensor([]) defaults to float, which nn.Embedding rejects).
    return self.embedding_model(torch.tensor(indices, dtype=torch.long))
Retrieves the embeddings for the provided HTML tags.
Parameters:
tags : list of str A list of HTML tags for which to retrieve the embeddings.
Returns:
torch.Tensor A tensor containing the embeddings for the specified HTML tags.