In this blog post, we’ll dive into an interesting new approach to Deep Learning (DL) called Relational Deep Learning (RDL). We will also gain some hands-on experience by doing some RDL on a real-world database (not a dataset!) of an e-commerce company.
In the real world, we usually have a relational database against which we want to run some ML task. But especially when the database is highly normalized, this implies lots of time-consuming feature engineering and a loss of granularity, as we have to do many aggregations. What’s more, there is a myriad of possible combinations of features we could construct, each of which might yield good performance [2]. That means we are likely to leave some information relevant to the ML task on the table.
This is similar to the early days of computer vision, before the advent of deep neural networks, when features were hand-crafted from the pixel values. Nowadays, models work directly with the raw pixels instead of relying on this intermediate layer.
RDL promises to do the same for tabular learning. That is, it removes the extra step of constructing a feature matrix by learning directly on top of your relational database. It does so by transforming the database with its relations into a graph, where a row in a table becomes a node and relations between tables become edges. The row values are stored inside the nodes as node features.
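To make the idea concrete, here is a tiny, hypothetical sketch (not relbench code, and the toy tables and keys are made up for illustration) of how a foreign-key column turns into edges between two node types:

import pandas as pd

# Two toy tables: each row will become a node of its table's node type.
customers = pd.DataFrame({"customer_key": [1, 2]})
transactions = pd.DataFrame({"t_id": [10, 11, 12], "customer_key": [1, 1, 2]})

# Each foreign-key reference becomes an edge from a transaction node to the
# customer node it points to; the remaining row values become node features.
edges = list(zip(transactions["t_id"], transactions["customer_key"]))
print(edges)  # [(10, 1), (11, 1), (12, 2)]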
In this blog post, we will be using this e-commerce dataset from Kaggle, which contains transactional data about an e-commerce platform in a star schema with a central fact table (transactions) and some dimension tables. The full code can be found in this notebook.
Throughout this blog post, we will be using the relbench library to do RDL. The first thing we have to do in relbench is to specify the schema of our relational database. Below is an example of how we can do so for the ‘transactions’ table in the database. We pass the table as a pandas dataframe and specify the primary key and the timestamp column. The primary key column is used to uniquely identify the entity. The timestamp ensures that we can only learn from past transactions when we want to forecast future transactions. In the graph, this means that information can only flow from nodes with a lower timestamp (i.e. in the past) to ones with a higher timestamp. Additionally, we specify the foreign keys that exist in the relation. In this case, the transactions table has the column ‘customer_key’, which is a foreign key that points to the ‘customer_dim’ table.
tables['transactions'] = Table(
    df=pd.DataFrame(t),
    pkey_col='t_id',
    fkey_col_to_pkey_table={
        'customer_key': 'customers',
        'item_key': 'products',
        'store_key': 'stores'
    },
    time_col='date'
)
The rest of the tables need to be defined in the same way. Note that this could also be automated if you already have a database schema. Since the dataset is from Kaggle, I needed to create the schema manually. We also need to convert the date columns to actual pandas datetime objects and remove any NaN values.
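The code below uses a small load_csv_to_db helper that is not shown in this post; a minimal sketch, assuming it simply reads a CSV and drops rows with NaN values, could look like this (the exact implementation in the notebook may differ):

import pandas as pd

def load_csv_to_db(path: str) -> pd.DataFrame:
    # Read the raw CSV and drop rows with missing values (hypothetical helper body).
    return pd.read_csv(path).dropna()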
class EcommerceDataBase(Dataset):
    # example of creating your own dataset: https://github.com/snap-stanford/relbench/blob/main/tutorials/custom_dataset.ipynb

    val_timestamp = pd.Timestamp(year=2018, month=1, day=1)
    test_timestamp = pd.Timestamp(year=2020, month=1, day=1)

    def make_db(self) -> Database:
        tables = {}

        customers = load_csv_to_db(BASE_DIR + '/customer_dim.csv').drop(columns=['contact_no', 'nid']).rename(columns={'coustomer_key': 'customer_key'})
        stores = load_csv_to_db(BASE_DIR + '/store_dim.csv').drop(columns=['upazila'])
        products = load_csv_to_db(BASE_DIR + '/item_dim.csv')
        transactions = load_csv_to_db(BASE_DIR + '/fact_table.csv').rename(columns={'coustomer_key': 'customer_key'})
        times = load_csv_to_db(BASE_DIR + '/time_dim.csv')

        t = transactions.merge(times[['time_key', 'date']], on='time_key').drop(columns=['payment_key', 'time_key', 'unit'])
        t['date'] = pd.to_datetime(t.date)
        t = t.reset_index().rename(columns={'index': 't_id'})
        t['quantity'] = t.quantity.astype(int)
        t['unit_price'] = t.unit_price.astype(float)
        products['unit_price'] = products.unit_price.astype(float)
        t['total_price'] = t.total_price.astype(float)

        print(t.isna().sum(axis=0))
        print(products.isna().sum(axis=0))
        print(stores.isna().sum(axis=0))
        print(customers.isna().sum(axis=0))

        tables['products'] = Table(
            df=pd.DataFrame(products),
            pkey_col='item_key',
            fkey_col_to_pkey_table={},
            time_col=None
        )
        tables['customers'] = Table(
            df=pd.DataFrame(customers),
            pkey_col='customer_key',
            fkey_col_to_pkey_table={},
            time_col=None
        )
        tables['transactions'] = Table(
            df=pd.DataFrame(t),
            pkey_col='t_id',
            fkey_col_to_pkey_table={
                'customer_key': 'customers',
                'item_key': 'products',
                'store_key': 'stores'
            },
            time_col='date'
        )
        tables['stores'] = Table(
            df=pd.DataFrame(stores),
            pkey_col='store_key',
            fkey_col_to_pkey_table={}
        )

        return Database(tables)
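As a quick sanity check, we can instantiate the dataset and materialize the database. This is a hedged sketch assuming relbench’s Dataset exposes get_db() as in the custom-dataset tutorial:

dataset = EcommerceDataBase()
db = dataset.get_db()  # calls make_db() under the hood
print(db.table_dict.keys())  # expected: products, customers, transactions, stores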
Crucially, the authors introduce the idea of a training table. This training table essentially defines the ML task. The idea here is that we want to predict the future state (i.e. a future value) of some entity in the database. We do this by specifying a table where each row has a timestamp, the identifier of the entity, and some value we want to predict. The id serves to specify the entity, and the timestamp specifies at which point in time we need to predict the entity. This also limits the data that can be used to infer the value of this entity (i.e. only past data). The value itself is what we want to predict (i.e. the ground truth).
In our case, we have an online platform with customers. We want to predict a customer’s revenue in the next 30 days. We can create the training table with a SQL statement executed with DuckDB. This is the big advantage of RDL, as we could create any kind of ML task with just SQL. For example, we could define a query to select the number of purchases of buyers in the next 30 days to make a churn prediction (a sketch of such a query appears after the task class below).
df = duckdb.sql(f"""
    select
        timestamp,
        customer_key,
        sum(total_price) as revenue
    from
        timestamp_df t
    left join
        transactions ta
    on
        ta.date <= t.timestamp + INTERVAL '{self.timedelta}'
        and ta.date > t.timestamp
    group by timestamp, customer_key
""").df().dropna()
The result will be a table that has the customer_key as the key of the entity we want to predict for, the revenue as the target, and the timestamp as the time at which we need to make the prediction (i.e. we can only use data up until this point to make the prediction).
Below is the complete code for creating the ‘customer_revenue’ task.
class CustomerRevenueTask(EntityTask):
    # example of a custom task: https://github.com/snap-stanford/relbench/blob/main/tutorials/custom_task.ipynb

    task_type = TaskType.REGRESSION
    entity_col = "customer_key"
    entity_table = "customers"
    time_col = "timestamp"
    target_col = "revenue"
    timedelta = pd.Timedelta(days=30)  # how far we want to predict revenue into the future.
    metrics = [r2, mae]
    num_eval_timestamps = 40

    def make_table(self, db: Database, timestamps: "pd.Series[pd.Timestamp]") -> Table:
        timestamp_df = pd.DataFrame({"timestamp": timestamps})

        transactions = db.table_dict["transactions"].df

        df = duckdb.sql(f"""
            select
                timestamp,
                customer_key,
                sum(total_price) as revenue
            from
                timestamp_df t
            left join
                transactions ta
            on
                ta.date <= t.timestamp + INTERVAL '{self.timedelta}'
                and ta.date > t.timestamp
            group by timestamp, customer_key
        """).df().dropna()

        print(df)

        return Table(
            df=df,
            fkey_col_to_pkey_table={self.entity_col: self.entity_table},
            pkey_col=None,
            time_col=self.time_col,
        )
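As mentioned above, other tasks only require a different SQL query. Below is a hedged sketch of the churn-style variant: counting each customer’s purchases in the next 30 days instead of summing revenue. The column and table names follow our schema; the exact form of a real churn task may differ.

# Inside a make_table just like the one above, with timestamp_df and transactions
# already defined, the aggregation simply changes to a count:
df = duckdb.sql(f"""
    select
        timestamp,
        customer_key,
        count(*) as num_purchases
    from
        timestamp_df t
    left join
        transactions ta
    on
        ta.date <= t.timestamp + INTERVAL '{self.timedelta}'
        and ta.date > t.timestamp
    group by timestamp, customer_key
""").df().dropna()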
With that, we have done the bulk of the work. The rest of the workflow will be similar, independent of the ML task. I was able to copy most of the code from the example notebook that relbench provides.
For example, we need to encode the node features. Here, we can use GloVe embeddings to encode all the text features, such as the product descriptions and the product names.
from typing import List, Optional

import torch
from sentence_transformers import SentenceTransformer
from torch import Tensor


class GloveTextEmbedding:
    def __init__(self, device: Optional[torch.device] = None):
        self.model = SentenceTransformer(
            "sentence-transformers/average_word_embeddings_glove.6B.300d",
            device=device,
        )

    def __call__(self, sentences: List[str]) -> Tensor:
        return torch.from_numpy(self.model.encode(sentences))
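A quick usage sketch (the example sentences are made up): this GloVe model maps each string to a 300-dimensional vector, so a list of n strings yields an n × 300 tensor.

embedder = GloveTextEmbedding(device=torch.device("cpu"))
vectors = embedder(["hand-woven cotton scarf", "stainless steel water bottle"])
print(vectors.shape)  # torch.Size([2, 300])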
After that, we can apply these transformations to our data and build out the graph.
from torch_frame.config.text_embedder import TextEmbedderConfig
from relbench.modeling.graph import make_pkey_fkey_graph

text_embedder_cfg = TextEmbedderConfig(
    text_embedder=GloveTextEmbedding(device=device), batch_size=256
)

data, col_stats_dict = make_pkey_fkey_graph(
    db,
    col_to_stype_dict=col_to_stype_dict,  # specified column types
    text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder
    cache_dir=os.path.join(
        root_dir, "rel-ecomm_materialized_cache"
    ),  # store materialized graph for convenience
)
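From here, the relbench example notebook wires the task to the dataset and materializes the train/validation/test tables. A hedged sketch, assuming the task constructor takes the dataset as in the custom-task tutorial:

task = CustomerRevenueTask(dataset)
train_table = task.get_table("train")  # rows: (timestamp, customer_key, revenue)
val_table = task.get_table("val")
print(train_table.df.head())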
The rest of the code will be building the GNN from standard layers, coding the training loop, and doing some evaluations. I will leave this code out of this blog post for brevity, since it is very standard and will be the same across tasks. You can check out the notebook here.
As a result, we can train this GNN to reach an r2 of around 0.3 and an MAE of 500. This means that it predicts the customer’s revenue in the next 30 days with an average error of ±$500. Of course, we can’t know if this is good or not; maybe we could have gotten an r2 of 80% with a combination of classical ML and feature engineering.
Relational Deep Learning is an interesting new approach to ML, especially when we have a complex relational schema where manual feature engineering would be too laborious. It gives us the ability to define an ML task with just SQL, which can be especially useful for people who aren’t deep into data science but know some SQL. It also means that we can iterate quickly and experiment a lot with different tasks.
At the same time, this approach presents its own problems, such as the difficulty of training GNNs and of constructing the graph from the relational schema. Furthermore, the question is to what extent RDL can compete in terms of performance with classical ML models. In the past, we have seen that models such as XGBoost have proven to be better than neural networks on tabular prediction problems.
- [1] Robinson, Joshua, et al. “RelBench: A Benchmark for Deep Learning on Relational Databases.” arXiv, 2024, https://arxiv.org/abs/2407.20060.
- [2] Fey, Matthias, et al. “Relational deep learning: Graph representation learning on relational databases.” arXiv preprint arXiv:2312.04615 (2023).
- [3] Schlichtkrull, Michael, et al. “Modeling relational data with graph convolutional networks.” The Semantic Web: 15th International Conference, ESWC 2018, Heraklion, Crete, Greece, June 3–7, 2018, Proceedings 15. Springer International Publishing, 2018.