webleaf.model.TagModel

  1import torch
  2import torch.nn as nn
  3import os
  4from pathlib import Path
  5
  # Location of the pre-trained tag embedding weights, stored next to this module.
TAG_PATH = os.path.join(Path(__file__).parent.absolute(), "tag_embeddings.torch")
# Dimensionality of each tag embedding vector.
TAG_DIMS = 8

# Vocabulary of HTML tags that can be embedded.
# NOTE: order matters. A tag's position in this list is the row index used to look
# up its vector (see TagEmbeddingModel.get_tag_embedding), and the list length sets
# the size of the embedding table whose weights are loaded from TAG_PATH.
# Reordering, inserting or removing entries invalidates the saved weights.
html_tags = [
    'a', 'abbr', 'address', 'area', 'article', 'aside', 'audio', 'b', 'base', 'bdi', 'bdo', 'blockquote',
    'body', 'br', 'button', 'canvas', 'caption', 'cite', 'code', 'col', 'colgroup', 'data', 'datalist', 'dd',
    'del', 'details', 'dfn', 'dialog', 'div', 'dl', 'dt', 'em', 'embed', 'fieldset', 'figcaption', 'figure',
    'footer', 'form', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'head', 'header', 'hr', 'html', 'i', 'iframe', 'img',
    'input', 'ins', 'kbd', 'label', 'legend', 'li', 'link', 'main', 'map', 'mark', 'meter', 'nav', 'noscript',
    'object', 'ol', 'optgroup', 'option', 'output', 'p', 'param', 'picture', 'pre', 'progress', 'q', 'rp', 'rt', 'ruby',
    's', 'samp', 'section', 'select', 'small', 'source', 'span', 'strong', 'sub', 'summary', 'sup',
    'table', 'tbody', 'td', 'template', 'textarea', 'tfoot', 'th', 'thead', 'time', 'title', 'tr', 'track', 'u', 'ul',
    'var', 'video', 'wbr',
]

# Tags deliberately absent from html_tags; they therefore have no embedding.
exclude_html_tags = [
    "script", "meta", "style", "svg",
]
 28
 # Embedding layer whose output vectors are scaled to unit length.
class NormalizedEmbedding(nn.Module):
    """
    A PyTorch module that creates an embedding layer and L2-normalizes its output.

    Attributes:
    -----------
    embedding : torch.nn.Embedding
        The embedding layer that learns to map each HTML tag to a vector in a fixed-dimensional space.

    Methods:
    --------
    forward(x):
        Looks up the embeddings for the input indices and scales each vector to unit length.
    """
    def __init__(self, n_classes, m_dimensions):
        """
        Initializes a NormalizedEmbedding instance.

        Parameters:
        -----------
        n_classes : int
            The number of classes (i.e., HTML tags) to embed.
        m_dimensions : int
            The number of dimensions in the embedding space.
        """
        super(NormalizedEmbedding, self).__init__()
        # One learnable row of size m_dimensions per class.
        self.embedding = nn.Embedding(n_classes, m_dimensions)

        # Re-initialize the table with Xavier-uniform random values.
        nn.init.xavier_uniform_(self.embedding.weight)

    def forward(self, x):
        """
        Forward pass of the embedding layer with normalization.

        Parameters:
        -----------
        x : torch.Tensor
            Integer tensor of class (HTML tag) indices of any shape,
            e.g. ``(N,)`` for a list of tags or ``(B, N)`` for a batch.

        Returns:
        --------
        torch.Tensor
            Tensor of shape ``x.shape + (m_dimensions,)`` in which every vector along
            the last dimension has unit L2 norm. An all-zero embedding row is
            returned as zeros rather than NaN.
        """
        embed = self.embedding(x)

        # Normalize along the last (embedding) dimension. For 1-D input this is the
        # same axis as dim=1, but dim=-1 also stays correct for batched or scalar
        # input. normalize() clamps the norm at a tiny eps, so a zero vector does
        # not produce 0/0 = NaN.
        return nn.functional.normalize(embed, p=2, dim=-1)
 81
 82
 # Wrapper that loads the pre-trained tag embeddings and looks tags up by name.
class TagEmbeddingModel:
    """
    A class for managing tag embeddings for HTML tags. It loads a pre-trained embedding model and
    provides a method to retrieve embeddings for specific HTML tags.

    Attributes:
    -----------
    embedding_model : NormalizedEmbedding
        The model that handles tag embeddings and their normalization; its weights
        are loaded from TAG_PATH.

    Methods:
    --------
    get_tag_embedding(tags):
        Retrieves the embeddings for a list of HTML tags.
    """
    def __init__(self):
        """
        Initializes a TagEmbeddingModel instance by loading the pre-trained embedding weights from TAG_PATH.

        Raises:
        -------
        AssertionError
            If no file exists at TAG_PATH.
        """
        self.embedding_model = NormalizedEmbedding(len(html_tags), TAG_DIMS)
        # Explicit raise instead of `assert`, so the check is not stripped under
        # `python -O`; AssertionError is kept for callers that already catch it.
        if not os.path.exists(TAG_PATH):
            raise AssertionError(f"Could not find tag model at [{TAG_PATH}]")
        # The module's parameters live on the CPU and load_state_dict copies into
        # them, so mapping to CPU changes nothing where loading already worked, but
        # lets weights saved from a GPU load on CPU-only machines.
        # NOTE(review): torch.load unpickles the file; only point TAG_PATH at trusted
        # files (consider weights_only=True if the minimum supported torch allows it).
        state_dict = torch.load(TAG_PATH, map_location="cpu")
        self.embedding_model.load_state_dict(state_dict)

    def get_tag_embedding(self, tags):
        """
        Retrieves the embeddings for the provided HTML tags.

        Parameters:
        -----------
        tags : iterable of str
            HTML tags to embed; each must appear in ``html_tags``. May be empty.

        Returns:
        --------
        torch.Tensor
            Tensor of shape ``(len(tags), TAG_DIMS)`` holding one unit-length
            embedding per tag, in input order.

        Raises:
        -------
        ValueError
            If a tag is not in ``html_tags``.
        """
        indices = []
        for tag in tags:
            try:
                indices.append(html_tags.index(tag))
            except ValueError:
                raise ValueError(f"Unknown HTML tag {tag!r}: it is not in html_tags") from None
        # nn.Embedding requires integer indices; without an explicit dtype,
        # torch.tensor([]) would be float32 and an empty tag list would crash.
        return self.embedding_model(torch.tensor(indices, dtype=torch.long))
# Location of the pre-trained tag embedding weights, resolved relative to this
# module rather than a machine-specific absolute path.
TAG_PATH = os.path.join(Path(__file__).parent.absolute(), "tag_embeddings.torch")
# Dimensionality of each tag embedding vector.
TAG_DIMS = 8
# Embeddable HTML tags. Order matters: a tag's position is its row in the saved
# embedding table, so entries must not be reordered, inserted or removed.
html_tags = [
    'a', 'abbr', 'address', 'area', 'article', 'aside', 'audio', 'b', 'base', 'bdi', 'bdo', 'blockquote',
    'body', 'br', 'button', 'canvas', 'caption', 'cite', 'code', 'col', 'colgroup', 'data', 'datalist', 'dd',
    'del', 'details', 'dfn', 'dialog', 'div', 'dl', 'dt', 'em', 'embed', 'fieldset', 'figcaption', 'figure',
    'footer', 'form', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'head', 'header', 'hr', 'html', 'i', 'iframe', 'img',
    'input', 'ins', 'kbd', 'label', 'legend', 'li', 'link', 'main', 'map', 'mark', 'meter', 'nav', 'noscript',
    'object', 'ol', 'optgroup', 'option', 'output', 'p', 'param', 'picture', 'pre', 'progress', 'q', 'rp', 'rt', 'ruby',
    's', 'samp', 'section', 'select', 'small', 'source', 'span', 'strong', 'sub', 'summary', 'sup',
    'table', 'tbody', 'td', 'template', 'textarea', 'tfoot', 'th', 'thead', 'time', 'title', 'tr', 'track', 'u', 'ul',
    'var', 'video', 'wbr',
]
# Tags deliberately absent from html_tags; they therefore have no embedding.
exclude_html_tags = ['script', 'meta', 'style', 'svg']
class NormalizedEmbedding(nn.Module):
    """
    A PyTorch module that creates an embedding layer and L2-normalizes its output.

    Attributes:
    -----------
    embedding : torch.nn.Embedding
        The embedding layer that learns to map each HTML tag to a vector in a fixed-dimensional space.

    Methods:
    --------
    forward(x):
        Looks up the embeddings for the input indices and scales each vector to unit length.
    """
    def __init__(self, n_classes, m_dimensions):
        """
        Initializes a NormalizedEmbedding instance.

        Parameters:
        -----------
        n_classes : int
            The number of classes (i.e., HTML tags) to embed.
        m_dimensions : int
            The number of dimensions in the embedding space.
        """
        super(NormalizedEmbedding, self).__init__()
        # One learnable row of size m_dimensions per class.
        self.embedding = nn.Embedding(n_classes, m_dimensions)

        # Re-initialize the table with Xavier-uniform random values.
        nn.init.xavier_uniform_(self.embedding.weight)

    def forward(self, x):
        """
        Forward pass of the embedding layer with normalization.

        Parameters:
        -----------
        x : torch.Tensor
            Integer tensor of class (HTML tag) indices of any shape,
            e.g. ``(N,)`` for a list of tags or ``(B, N)`` for a batch.

        Returns:
        --------
        torch.Tensor
            Tensor of shape ``x.shape + (m_dimensions,)`` in which every vector along
            the last dimension has unit L2 norm. An all-zero embedding row is
            returned as zeros rather than NaN.
        """
        embed = self.embedding(x)

        # Normalize along the last (embedding) dimension. For 1-D input this is the
        # same axis as dim=1, but dim=-1 also stays correct for batched or scalar
        # input. normalize() clamps the norm at a tiny eps, so a zero vector does
        # not produce 0/0 = NaN.
        return nn.functional.normalize(embed, p=2, dim=-1)

A PyTorch module that creates an embedding layer and normalizes the output.

Attributes:

embedding : torch.nn.Embedding The embedding layer that learns to map each HTML tag to a vector in a fixed-dimensional space.

Methods:

forward(x): Performs the forward pass by computing embeddings for the input and normalizing them.

NormalizedEmbedding(n_classes, m_dimensions)
44    def __init__(self, n_classes, m_dimensions):
45        """
46        Initializes the NormalizedEmbedding class.
47
48        Parameters:
49        -----------
50        n_classes : int
51            The number of classes (i.e., HTML tags) to embed.
52        m_dimensions : int
53            The number of dimensions in the embedding space.
54        """
55        super(NormalizedEmbedding, self).__init__()
56        # Create the embedding layer
57        self.embedding = nn.Embedding(n_classes, m_dimensions)
58
59        # Initialize the embedding weights randomly
60        nn.init.xavier_uniform_(self.embedding.weight)

Initializes a NormalizedEmbedding instance.

Parameters:

n_classes : int The number of classes (i.e., HTML tags) to embed. m_dimensions : int The number of dimensions in the embedding space.

embedding
def forward(self, x):
62    def forward(self, x):
63        """
64        Forward pass of the embedding layer with normalization.
65
66        Parameters:
67        -----------
68        x : torch.Tensor
69            A tensor containing indices of HTML tags to be embedded.
70
71        Returns:
72        --------
73        torch.Tensor
74            A tensor of normalized embeddings for the input tags.
75        """
76        # Get the embeddings
77        embed = self.embedding(x)
78
79        # Normalize the embeddings to have unit length
80        normalized_embed = embed / embed.norm(dim=1, keepdim=True)
81        return normalized_embed

Forward pass of the embedding layer with normalization.

Parameters:

x : torch.Tensor A tensor containing indices of HTML tags to be embedded.

Returns:

torch.Tensor A tensor of normalized embeddings for the input tags.

Inherited Members
torch.nn.modules.module.Module
dump_patches
training
call_super_init
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_full_backward_pre_hook
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
register_state_dict_pre_hook
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
compile
class TagEmbeddingModel:
    """
    A class for managing tag embeddings for HTML tags. It loads a pre-trained embedding model and
    provides a method to retrieve embeddings for specific HTML tags.

    Attributes:
    -----------
    embedding_model : NormalizedEmbedding
        The model that handles tag embeddings and their normalization; its weights
        are loaded from TAG_PATH.

    Methods:
    --------
    get_tag_embedding(tags):
        Retrieves the embeddings for a list of HTML tags.
    """
    def __init__(self):
        """
        Initializes a TagEmbeddingModel instance by loading the pre-trained embedding weights from TAG_PATH.

        Raises:
        -------
        AssertionError
            If no file exists at TAG_PATH.
        """
        self.embedding_model = NormalizedEmbedding(len(html_tags), TAG_DIMS)
        # Explicit raise instead of `assert`, so the check is not stripped under
        # `python -O`; AssertionError is kept for callers that already catch it.
        if not os.path.exists(TAG_PATH):
            raise AssertionError(f"Could not find tag model at [{TAG_PATH}]")
        # The module's parameters live on the CPU and load_state_dict copies into
        # them, so mapping to CPU changes nothing where loading already worked, but
        # lets weights saved from a GPU load on CPU-only machines.
        # NOTE(review): torch.load unpickles the file; only point TAG_PATH at trusted
        # files (consider weights_only=True if the minimum supported torch allows it).
        state_dict = torch.load(TAG_PATH, map_location="cpu")
        self.embedding_model.load_state_dict(state_dict)

    def get_tag_embedding(self, tags):
        """
        Retrieves the embeddings for the provided HTML tags.

        Parameters:
        -----------
        tags : iterable of str
            HTML tags to embed; each must appear in ``html_tags``. May be empty.

        Returns:
        --------
        torch.Tensor
            Tensor of shape ``(len(tags), TAG_DIMS)`` holding one unit-length
            embedding per tag, in input order.

        Raises:
        -------
        ValueError
            If a tag is not in ``html_tags``.
        """
        indices = []
        for tag in tags:
            try:
                indices.append(html_tags.index(tag))
            except ValueError:
                raise ValueError(f"Unknown HTML tag {tag!r}: it is not in html_tags") from None
        # nn.Embedding requires integer indices; without an explicit dtype,
        # torch.tensor([]) would be float32 and an empty tag list would crash.
        return self.embedding_model(torch.tensor(indices, dtype=torch.long))

A class for managing tag embeddings for HTML tags. It loads a pre-trained embedding model and provides a method to retrieve embeddings for specific HTML tags.

Attributes:

embedding_model : NormalizedEmbedding The model that handles tag embeddings and their normalization.

Methods:

get_tag_embedding(tags): Retrieves the embeddings for a list of HTML tags.

TagEmbeddingModel()
 99    def __init__(self):
100        """
101         Initializes the TagEmbeddingModel class by loading a pre-trained embedding model.
102
103         Raises:
104         -------
105         AssertionError if the pre-trained model cannot be found at the specified path.
106         """
107        self.embedding_model = NormalizedEmbedding(len(html_tags), TAG_DIMS)
108        assert os.path.exists(TAG_PATH), f"Could not find tag model at [{TAG_PATH}]"
109        self.embedding_model.load_state_dict(torch.load(TAG_PATH))

Initializes a TagEmbeddingModel instance by loading the pre-trained embedding weights from TAG_PATH.

Raises:

AssertionError if the pre-trained model cannot be found at the specified path.

embedding_model
def get_tag_embedding(self, tags):
111    def get_tag_embedding(self, tags):
112        """
113         Retrieves the embeddings for the provided HTML tags.
114
115         Parameters:
116         -----------
117         tags : list of str
118             A list of HTML tags for which to retrieve the embeddings.
119
120         Returns:
121         --------
122         torch.Tensor
123             A tensor containing the embeddings for the specified HTML tags.
124         """
125        return self.embedding_model(torch.tensor([html_tags.index(tag) for tag in tags]))

Retrieves the embeddings for the provided HTML tags.

Parameters:

tags : list of str A list of HTML tags for which to retrieve the embeddings.

Returns:

torch.Tensor A tensor containing the embeddings for the specified HTML tags.