Customize Model Template
Here, we present how to develop a new Model Template and apply it in EduStudio. EduStudio defines the Model Protocol in `BaseModelTPL` (`EduStudio.edustudio.model.basemodeltpl.BaseModelTPL`).
Protocol
BaseModelTPL
The protocols in `BaseModelTPL` are listed as follows.
| name | description | type | note |
|---|---|---|---|
| `default_cfg` | the default configuration | class variable | |
| `add_extra_data` | add extra data in addition to the data loaders | function interface | implemented by subclass |
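The two protocol members can be pictured with the following minimal sketch. It is not EduStudio's implementation: `ToyBaseModelTPL`, `ToyModelTPL`, `ToyChildTPL`, `collect_default_cfg`, and the `Q_mat` keyword are hypothetical names, and the rule that a subclass's `default_cfg` overrides its parents' entries along the class hierarchy is our assumption.

```python
from typing import Any, Dict


class ToyBaseModelTPL:
    """Hypothetical stand-in for BaseModelTPL; not EduStudio's implementation."""

    # class variable holding this template's default hyperparameters
    default_cfg: Dict[str, Any] = {}

    @classmethod
    def collect_default_cfg(cls) -> Dict[str, Any]:
        # Walk the MRO from the most generic class to the most specific one,
        # so a subclass's default_cfg entries override those of its parents.
        merged: Dict[str, Any] = {}
        for klass in reversed(cls.__mro__):
            merged.update(klass.__dict__.get("default_cfg", {}))
        return merged

    def add_extra_data(self, **kwargs: Any) -> None:
        # Hook for data that does not arrive through the data loaders.
        # The base class ignores it; subclasses override it when needed.
        pass


class ToyModelTPL(ToyBaseModelTPL):
    default_cfg = {"emb_dim": 32, "dropout_rate": 0.5}

    def add_extra_data(self, **kwargs: Any) -> None:
        # keep a Q-matrix handed over by the pipeline (hypothetical keyword)
        self.Q_mat = kwargs.get("Q_mat")


class ToyChildTPL(ToyModelTPL):
    default_cfg = {"dropout_rate": 0.2}  # overrides only this key


print(ToyChildTPL.collect_default_cfg())  # {'emb_dim': 32, 'dropout_rate': 0.2}
```

Keeping defaults in a class variable means a subclass only has to list the keys it changes.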
BaseProxyModelTPL
This protocol is intended for works that propose general frameworks in which the backbone model can be replaced. We implement it as the `BaseProxyModelTPL` protocol: users only need to inherit this protocol and set the `backbone_model_cls` parameter in `default_cfg`.
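The proxy idea can be illustrated with a small plain-Python sketch, assuming that the proxy resolves the class named in `backbone_model_cls` and forwards calls to it. `ToyIRT`, `BACKBONE_REGISTRY`, and `ToyProxyModelTPL` are hypothetical names; EduStudio's actual lookup mechanism may differ.

```python
from typing import Any, Dict, Type


class ToyIRT:
    """Hypothetical backbone: 1PL-style logit theta[stu] - b[exer]."""

    def __init__(self, cfg: Dict[str, Any]) -> None:
        self.theta = cfg["theta"]  # student abilities
        self.b = cfg["b"]  # exercise difficulties

    def __call__(self, stu_id: int, exer_id: int) -> float:
        return self.theta[stu_id] - self.b[exer_id]


# hypothetical name -> class registry used to resolve backbone_model_cls
BACKBONE_REGISTRY: Dict[str, Type[Any]] = {"IRT": ToyIRT}


class ToyProxyModelTPL:
    """Hypothetical proxy: builds the backbone named in its config and delegates to it."""

    default_cfg: Dict[str, Any] = {"backbone_model_cls": "IRT"}

    def __init__(self, cfg: Dict[str, Any]) -> None:
        # user-supplied values override the template defaults
        self.cfg = {**self.default_cfg, **cfg}
        self.build_model()

    def build_model(self) -> None:
        # resolve the backbone class by name and instantiate it
        backbone_cls = BACKBONE_REGISTRY[self.cfg["backbone_model_cls"]]
        self.backbone = backbone_cls(self.cfg)

    def __call__(self, **kwargs: Any) -> float:
        # the proxy adds nothing to the forward pass; it simply delegates
        return self.backbone(**kwargs)


proxy = ToyProxyModelTPL({"theta": [0.5, 1.0], "b": [0.25]})
print(proxy(stu_id=1, exer_id=0))  # 0.75
```

Under this design, swapping the backbone only means changing one configuration value, while the framework's own additions (for example, an extra loss) live in the proxy subclass.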
ModelTPL
EduStudio provides the model template `GDBaseModelTPL` (`EduStudio.edustudio.model.gd_basemodeltpl.GDBaseModelTPL`). `GDBaseModelTPL` inherits from `BaseModelTPL`, and its members are listed as follows.
| name | description | type | note |
|---|---|---|---|
| `default_cfg` | the default configuration | class variable | |
| `build_cfg` | construct the model configuration | abstract method | implemented by subclass |
| `build_model` | construct the model components | abstract method | implemented by subclass |
| `predict` | prediction function | function interface | implemented by subclass |
| `get_loss_dict` | return the losses as a dictionary | function interface | implemented by subclass |
| `_init_params` | initialize parameters | function interface | implemented by subclass |
| `_load_params_from_pretrained` | load pretrained parameters from a dict | function interface | implemented by subclass |
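How these members fit together is a design decision of the framework. The sketch below shows one plausible template-method arrangement, assuming that the constructor runs `build_cfg()`, then `build_model()`, then either `_init_params()` or `_load_params_from_pretrained()`. `ToyGDBaseModelTPL`, `ToyModel`, the `pretrained` argument, and the `calls` list are hypothetical and are not part of EduStudio.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class ToyGDBaseModelTPL(ABC):
    """Hypothetical template-method base class; not EduStudio's GDBaseModelTPL."""

    default_cfg: Dict[str, Any] = {}

    def __init__(self, cfg: Dict[str, Any],
                 pretrained: Optional[Dict[str, Any]] = None) -> None:
        self.model_cfg = {**self.default_cfg, **cfg}
        self.calls: List[str] = []  # records hook order, for illustration only
        self.build_cfg()
        self.build_model()
        if pretrained is None:
            self._init_params()
        else:
            self._load_params_from_pretrained(pretrained)

    @abstractmethod
    def build_cfg(self) -> None:
        """Derive model settings from the configuration."""

    @abstractmethod
    def build_model(self) -> None:
        """Create the model components."""

    def _init_params(self) -> None:
        # default: keep whatever build_model() created
        self.calls.append("_init_params")

    def _load_params_from_pretrained(self, state: Dict[str, Any]) -> None:
        # default: overwrite the parameters with a pretrained dict
        self.calls.append("_load_params_from_pretrained")
        self.params = dict(state)


class ToyModel(ToyGDBaseModelTPL):
    default_cfg = {"dim": 4}

    def build_cfg(self) -> None:
        self.calls.append("build_cfg")
        self.dim = self.model_cfg["dim"]

    def build_model(self) -> None:
        self.calls.append("build_model")
        self.params = {"w": [0.0] * self.dim}


model = ToyModel({"dim": 2})
print(model.calls)  # ['build_cfg', 'build_model', '_init_params']
```

Because `build_cfg` and `build_model` are declared abstract, a subclass that forgets either one fails at instantiation with a `TypeError` instead of failing later during training.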
Develop a New ModelTPL in EduStudio
To develop a new model in EduStudio, inherit from `GDBaseModelTPL` and implement the abstract methods `build_cfg()` and `build_model()`. Then override `predict()` and `get_loss_dict()`. You can also define the configuration of the new model template in the `default_cfg` dictionary.
To develop a new ModelTPL for a backbone-style model, inherit from `BaseProxyModelTPL` and set the `backbone_model_cls` parameter in `default_cfg`. Then implement the methods you need, such as `build_model()`, `get_loss_dict()`, and `predict()`.
Example 1: Develop a traditional ModelTPL
```python
import torch
import torch.nn.functional as F

from ..gd_basemodeltpl import GDBaseModelTPL
from ..utils.components import PosMLP  # assumed location of PosMLP


class NewModelTPL(GDBaseModelTPL):
    # default hyperparameters, read below through self.model_cfg
    default_cfg = {
        'dnn_units': [512, 256],
        'dropout_rate': 0.5,
        'activation': 'sigmoid',
        'disc_scale': 10
    }

    def __init__(self, cfg):
        super().__init__(cfg)

    def build_cfg(self):
        # read dataset statistics (e.g. the number of students)
        self.n_user = self.datafmt_cfg['dt_info']['stu_count']
        ...

    def build_model(self):
        ...
        # prediction network configured from default_cfg
        self.pd_net = PosMLP(
            input_dim=self.n_cpt, output_dim=1, activation=self.model_cfg['activation'],
            dnn_units=self.model_cfg['dnn_units'], dropout_rate=self.model_cfg['dropout_rate']
        )

    def forward(self, stu_id, exer_id, Q_mat, **kwargs):
        ...  # elided: computes input_x from stu_id, exer_id and Q_mat
        pd = self.pd_net(input_x).sigmoid()
        return pd

    @torch.no_grad()
    def predict(self, stu_id, exer_id, Q_mat, **kwargs):
        # inference only: no gradients are tracked
        return {
            'y_pd': self(stu_id, exer_id, Q_mat).flatten(),
        }

    def get_main_loss(self, **kwargs):
        ...  # elided: unpacks stu_id, exer_id, label and Q_mat from kwargs
        pd = self(stu_id, exer_id, Q_mat).flatten()
        loss = F.binary_cross_entropy(input=pd, target=label)
        return {
            'loss_main': loss
        }

    def get_loss_dict(self, **kwargs):
        return self.get_main_loss(**kwargs)
```
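Example 1 returns its loss as a dictionary rather than a single value. One reason such a design is convenient is that a training loop can sum every entry for back-propagation while logging each term under its own key. The sketch below assumes that summing behavior (we have not confirmed it against EduStudio's trainer) and uses plain floats instead of tensors; `total_loss`, `format_losses`, the `loss_reg` key, and the skip-`None` rule are hypothetical.

```python
from typing import Dict, Optional


def total_loss(loss_dict: Dict[str, Optional[float]]) -> float:
    """Sum every loss term, skipping entries that are None (hypothetical convention)."""
    return sum((v for v in loss_dict.values() if v is not None), 0.0)


def format_losses(loss_dict: Dict[str, Optional[float]]) -> str:
    # one "name=value" pair per term, e.g. for a progress log
    return ", ".join(f"{k}={v:.3f}" for k, v in loss_dict.items() if v is not None)


losses = {"loss_main": 0.5, "loss_reg": 0.25}
print(total_loss(losses))     # 0.75
print(format_losses(losses))  # loss_main=0.500, loss_reg=0.250
```

With this convention, adding a regularization term only requires adding another key to the returned dictionary.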
Example 2: Develop a backbone-style ModelTPL
```python
import torch.nn as nn

from ..basemodeltpl import BaseProxyModelTPL


class NewBackboneModelTPL(BaseProxyModelTPL):
    default_cfg = {
        # name of the backbone model class to wrap
        "backbone_model_cls": "IRT",
    }

    def __init__(self, cfg):
        super().__init__(cfg)

    def build_model(self):
        # call the parent build_model() first, then add the pairwise loss
        super().build_model()
        self.irr_pair_loss = PairSCELoss()

    def get_main_loss(self, **kwargs):
        pair_exer = kwargs['pair_exer']
        pair_pos_stu = kwargs['pair_pos_stu']
        pair_neg_stu = kwargs['pair_neg_stu']
        # score each exercise for the positive student ...
        kwargs['exer_id'] = pair_exer
        kwargs['stu_id'] = pair_pos_stu
        pos_pd = self(**kwargs).flatten()
        # ... and for the negative student on the same exercise
        kwargs['stu_id'] = pair_neg_stu
        neg_pd = self(**kwargs).flatten()
        return {
            'loss_main': self.irr_pair_loss(pos_pd, neg_pd)
        }


class PairSCELoss(nn.Module):
    ...
```
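The body of `PairSCELoss` is elided above. For intuition only, here is a generic pairwise softmax cross-entropy in plain Python: each pair of scores (positive, negative) is treated as two-class logits whose correct class is the positive student, which reduces to log(1 + exp(neg - pos)). This is an assumption about what such a loss computes, not EduStudio's `PairSCELoss`; `pair_sce_loss` is a hypothetical name.

```python
import math
from typing import Sequence


def pair_sce_loss(pos_pd: Sequence[float], neg_pd: Sequence[float]) -> float:
    """Mean two-class softmax cross-entropy with the positive score as the target class."""
    if len(pos_pd) != len(neg_pd) or not pos_pd:
        raise ValueError("pos_pd and neg_pd must be non-empty and of equal length")
    total = 0.0
    for pos, neg in zip(pos_pd, neg_pd):
        # -log(exp(pos) / (exp(pos) + exp(neg))) == log(1 + exp(neg - pos)),
        # evaluated as a numerically stable softplus
        margin = neg - pos
        total += max(margin, 0.0) + math.log1p(math.exp(-abs(margin)))
    return total / len(pos_pd)
```

The loss shrinks as the positive student's score rises above the negative student's, which is the ranking behavior the pairwise setup in Example 2 relies on.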