Transformer-based Chinese Couplet Generator

Introduction

This project is a Transformer-based Chinese couplet generator: the model is built with PyTorch and the web UI with Gradio.

Dataset: https://www.kaggle.com/datasets/marquis03/chinese-couplets-dataset

GitHub repository: https://github.com/Marquis03/Chinese-Couplets-Generator-based-on-Transformer

Gitee repository: https://gitee.com/marquis03/Chinese-Couplets-Generator-based-on-Transformer

Project Structure

.
├── config
│   ├── __init__.py
│   └── config.py
├── data
│   ├── fixed_couplets_in.txt
│   └── fixed_couplets_out.txt
├── dataset
│   ├── __init__.py
│   └── dataset.py
├── img
│   ├── history.png
│   ├── lr_schedule.png
│   └── webui.gif
├── model
│   ├── __init__.py
│   └── model.py
├── trained
│   ├── vocab.pkl
│   └── CoupletsTransformer_best.pth
├── utils
│   ├── __init__.py
│   └── EarlyStopping.py
├── LICENSE
├── README.md
├── requirements.txt
├── train.py
└── webui.py

Deployment

Clone the Project

git clone https://github.com/Marquis03/Chinese-Couplets-Generator-based-on-Transformer.git
cd Chinese-Couplets-Generator-based-on-Transformer

Install Requirements

pip install -r requirements.txt

Train the Model

python train.py

Kaggle Notebook: https://www.kaggle.com/code/marquis03/chinese-couplets-generator-based-on-transformer

Start the Web UI

python webui.py

Demo

Web UI

(Demo animation: img/webui.gif)

Learning Rate Schedule

(Learning rate schedule plot: img/lr_schedule.png)

Training History

(Training history plot: img/history.png)

Code

Configuration (Config)

This part configures the project's parameters, including global, path, model, training, and logging settings.

The corresponding project file is config/config.py.

import os
import sys
import time
import torch
from loguru import logger


class Config:
    def __init__(self):
        # global
        self.seed = 0
        self.cuDNN = True
        self.debug = False
        self.num_workers = 0
        self.str_time = time.strftime("%Y-%m-%dT%H%M%S", time.localtime(time.time()))
        # path
        self.project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.dataset_dir = os.path.join(self.project_dir, "data")
        self.in_path = os.path.join(self.dataset_dir, "fixed_couplets_in.txt")
        self.out_path = os.path.join(self.dataset_dir, "fixed_couplets_out.txt")
        self.log_dir = os.path.join(self.project_dir, "logs")
        self.save_dir = os.path.join(self.log_dir, self.str_time)
        self.img_save_dir = os.path.join(self.save_dir, "images")
        self.model_save_dir = os.path.join(self.save_dir, "checkpoints")
        for path in (
            self.log_dir,
            self.save_dir,
            self.img_save_dir,
            self.model_save_dir,
        ):
            if not os.path.exists(path):
                os.makedirs(path)
        # model
        self.d_model = 256
        self.num_head = 8
        self.num_encoder_layers = 2
        self.num_decoder_layers = 2
        self.dim_feedforward = 1024
        self.dropout = 0.1
        # train
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 128
        self.val_ratio = 0.1
        self.epochs = 20
        self.warmup_ratio = 0.12
        self.lr_max = 1e-3
        self.lr_min = 1e-4
        self.beta1 = 0.9
        self.beta2 = 0.98
        self.epsilon = 10e-9
        self.weight_decay = 0.01
        self.early_stop = True
        self.patience = 4
        self.delta = 0
        # log
        logger.remove()
        level_std = "DEBUG" if self.debug else "INFO"
        logger.add(
            sys.stdout,
            colorize=True,
            format="[<green>{time:YYYY-MM-DD HH:mm:ss,SSS}</green>|<level>{level: <8}</level>|<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan>] >>> <level>{message}</level>",
            level=level_std,
        )
        logger.add(
            os.path.join(self.save_dir, f"{self.str_time}.log"),
            format="[{time:YYYY-MM-DD HH:mm:ss,SSS}|{level: <8}|{name}:{function}:{line}] >>> {message}",
            level="INFO",
        )
        logger.info("### Config:")
        for key, value in self.__dict__.items():
            logger.info(f"### {key:20} = {value}")

Dataset

This part defines the vocabulary, the dataset, and related helpers, including data loading, vocabulary construction, the dataset wrapper, and the data loader.

The corresponding project file is dataset/dataset.py.

from collections import Counter

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


# Read the paired input/output files and return a list of (in_tokens, out_tokens) pairs.
def load_data(filepaths, tokenizer=lambda s: s.strip().split()):
    raw_in_iter = iter(open(filepaths[0], encoding="utf8"))
    raw_out_iter = iter(open(filepaths[1], encoding="utf8"))
    return list(zip(map(tokenizer, raw_in_iter), map(tokenizer, raw_out_iter)))


class Vocab(object):
    UNK = "<unk>"  # 0
    PAD = "<pad>"  # 1
    BOS = "<bos>"  # 2
    EOS = "<eos>"  # 3

    def __init__(self, data=None, min_freq=1):
        counter = Counter()
        for lines in data:
            counter.update(lines[0])
            counter.update(lines[1])
        self.word2idx = {Vocab.UNK: 0, Vocab.PAD: 1, Vocab.BOS: 2, Vocab.EOS: 3}
        self.idx2word = {0: Vocab.UNK, 1: Vocab.PAD, 2: Vocab.BOS, 3: Vocab.EOS}
        idx = 4
        for word, freq in counter.items():
            if freq >= min_freq:
                self.word2idx[word] = idx
                self.idx2word[idx] = word
                idx += 1

    def __len__(self):
        return len(self.word2idx)

    def __getitem__(self, word):
        return self.word2idx.get(word, 0)

    def __call__(self, word):
        if not isinstance(word, (list, tuple)):
            return self[word]
        return [self[w] for w in word]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple, np.ndarray, torch.Tensor)):
            return self.idx2word[int(indices)]
        return [self.idx2word[int(i)] for i in indices]


# Pad a list of variable-length 1-D tensors to the length of the longest one.
def pad_sequence(sequences, batch_first=False, padding_value=0):
    max_len = max([s.size(0) for s in sequences])
    out_tensors = []
    for tensor in sequences:
        padding_content = [padding_value] * (max_len - tensor.size(0))
        tensor = torch.cat([tensor, torch.tensor(padding_content)], dim=0)
        out_tensors.append(tensor)
    out_tensors = torch.stack(out_tensors, dim=1)
    if batch_first:
        out_tensors = out_tensors.transpose(0, 1)
    return out_tensors.long()


class CoupletsDataset(Dataset):
    def __init__(self, data, vocab):
        self.data = data
        self.vocab = vocab
        self.PAD_IDX = self.vocab[self.vocab.PAD]
        self.BOS_IDX = self.vocab[self.vocab.BOS]
        self.EOS_IDX = self.vocab[self.vocab.EOS]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        raw_in, raw_out = self.data[index]
        in_tensor_ = torch.LongTensor(self.vocab(raw_in))
        out_tensor_ = torch.LongTensor(self.vocab(raw_out))
        return in_tensor_, out_tensor_

    # Wrap each target with <bos>/<eos>, then pad inputs and targets to the batch maximum.
    def collate_fn(self, batch):
        in_batch, out_batch = [], []
        for in_, out_ in batch:
            in_batch.append(in_)
            out_ = torch.cat(
                [
                    torch.LongTensor([self.BOS_IDX]),
                    out_,
                    torch.LongTensor([self.EOS_IDX]),
                ],
                dim=0,
            )
            out_batch.append(out_)
        in_batch = pad_sequence(in_batch, True, self.PAD_IDX)
        out_batch = pad_sequence(out_batch, True, self.PAD_IDX)
        return in_batch, out_batch

    def get_loader(self, batch_size, shuffle=False, num_workers=0):
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self.collate_fn,
            pin_memory=True,
        )
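
To see how these pieces fit together, here is a small sketch that builds a vocabulary and a loader from two in-memory toy couplets instead of the data files; the (input_tokens, output_tokens) pair format matches what load_data returns:

# Toy example (not part of the repository), run from the project root.
from dataset import Vocab, CoupletsDataset

toy_data = [
    (list("海内存知己"), list("天涯若比邻")),
    (list("春风得意"), list("秋月扬辉")),
]

vocab = Vocab(toy_data)                  # 4 special tokens + unique characters
dataset = CoupletsDataset(toy_data, vocab)
loader = dataset.get_loader(batch_size=2, shuffle=False)

src, tgt = next(iter(loader))
print(src.shape)  # [2, 5]: inputs padded to the longest input with <pad>
print(tgt.shape)  # [2, 7]: targets wrapped with <bos>/<eos>, then padded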

Model

This part defines the model, including TokenEmbedding, PositionalEncoding, and CoupletsTransformer.

The corresponding project file is model/model.py.

import math
import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens):
        return self.embedding(tokens) * math.sqrt(self.emb_size)


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)


class CoupletsTransformer(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
    ):
        super(CoupletsTransformer, self).__init__()
        self.name = "CoupletsTransformer"
        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.pos_embedding = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(d_model, vocab_size)
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, tgt, padding_value=0):
        src_embed = self.token_embedding(src)  # [batch_size, src_len, embed_dim]
        src_embed = self.pos_embedding(src_embed)  # [batch_size, src_len, embed_dim]
        tgt_embed = self.token_embedding(tgt)  # [batch_size, tgt_len, embed_dim]
        tgt_embed = self.pos_embedding(tgt_embed)  # [batch_size, tgt_len, embed_dim]

        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(-1)).to(
            tgt.device
        )
        src_key_padding_mask = src == padding_value  # [batch_size, src_len]
        tgt_key_padding_mask = tgt == padding_value  # [batch_size, tgt_len]

        outs = self.transformer(
            src=src_embed,
            tgt=tgt_embed,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
        )  # [batch_size, tgt_len, embed_dim]
        logits = self.fc(outs)  # [batch_size, tgt_len, vocab_size]
        return logits

    def encoder(self, src):
        src_embed = self.token_embedding(src)
        src_embed = self.pos_embedding(src_embed)
        memory = self.transformer.encoder(src_embed)
        return memory

    def decoder(self, tgt, memory):
        tgt_embed = self.token_embedding(tgt)
        tgt_embed = self.pos_embedding(tgt_embed)
        outs = self.transformer.decoder(tgt_embed, memory=memory)
        return outs

    # Greedy decoding: encode the source once, then repeatedly append the argmax
    # token until <eos> appears or the output reaches the input length.
    def generate(self, text, vocab):
        self.eval()
        device = next(self.parameters()).device
        max_len = len(text)
        src = torch.LongTensor(vocab(list(text))).unsqueeze(0).to(device)
        memory = self.encoder(src)
        l_out = [vocab.BOS]
        for i in range(max_len):
            tgt = torch.LongTensor(vocab(l_out)).unsqueeze(0).to(device)
            outs = self.decoder(tgt, memory)
            prob = self.fc(outs[:, -1, :])
            next_token = vocab.to_tokens(prob.argmax(1).item())
            if next_token == vocab.EOS:
                break
            l_out.append(next_token)
        return "".join(l_out[1:])

Utilities (Utils)

This part defines utility classes, including EarlyStopping.

The corresponding project file is utils/EarlyStopping.py.

class EarlyStopping(object):
    def __init__(self, patience=7, delta=0):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float("inf")
        self.delta = delta

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0
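
EarlyStopping only tracks whether the validation loss keeps improving; here is a short sketch with made-up loss values (the model argument is accepted for interface compatibility but unused):

# Illustrative example with fabricated validation losses.
from utils import EarlyStopping

early_stopping = EarlyStopping(patience=2, delta=0)
fake_val_losses = [1.0, 0.8, 0.85, 0.9, 0.95]  # improvement stops after the second value

for epoch, val_loss in enumerate(fake_val_losses, start=1):
    early_stopping(val_loss, model=None)
    if early_stopping.early_stop:
        print(f"Would stop at epoch {epoch}")  # triggers at epoch 4 with patience=2
        break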

Training

This part defines the training pipeline, including training, validation, and saving the model.

The corresponding project file is train.py.

import os
import gc
import time
import math
import random
import joblib
import warnings

warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import seaborn as sns
from loguru import logger
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

sns.set_theme(
    style="darkgrid", font_scale=1.2, font="SimHei", rc={"axes.unicode_minus": False}
)

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR


from config import Config
from model import CoupletsTransformer
from dataset import load_data, Vocab, CoupletsDataset
from utils import EarlyStopping


def train_model(
    config, model, train_loader, val_loader, optimizer, criterion, scheduler
):
    model = model.to(config.device)
    best_loss = float("inf")
    history = []
    model_path = os.path.join(config.model_save_dir, f"{model.name}_best.pth")
    if config.early_stop:
        early_stopping = EarlyStopping(patience=config.patience, delta=config.delta)
    for epoch in tqdm(range(1, config.epochs + 1), desc=f"All"):
        train_loss = train_one_epoch(
            config, model, train_loader, optimizer, criterion, scheduler
        )
        val_loss = evaluate(config, model, val_loader, criterion)

        perplexity = math.exp(val_loss)
        history.append((epoch, train_loss, val_loss))
        msg = f"Epoch {epoch}/{config.epochs}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, Perplexity: {perplexity:.4f}"
        logger.info(msg)
        if val_loss < best_loss:
            logger.info(
                f"Val loss decrease from {best_loss:>10.6f} to {val_loss:>10.6f}"
            )
            torch.save(model.state_dict(), model_path)
            best_loss = val_loss
        if config.early_stop:
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                logger.info(f"Early stopping at epoch {epoch}")
                break
    logger.info(f"Save best model with val loss {best_loss:.6f} to {model_path}")

    model_path = os.path.join(config.model_save_dir, f"{model.name}_last.pth")
    torch.save(model.state_dict(), model_path)
    logger.info(f"Save last model with val loss {val_loss:.6f} to {model_path}")

    history = pd.DataFrame(
        history, columns=["Epoch", "Train Loss", "Val Loss"]
    ).set_index("Epoch")
    history.plot(
        subplots=True, layout=(1, 2), sharey="row", figsize=(14, 6), marker="o", lw=2
    )
    history_path = os.path.join(config.img_save_dir, "history.png")
    plt.savefig(history_path, dpi=300)
    logger.info(f"Save history to {history_path}")


def train_one_epoch(config, model, train_loader, optimizer, criterion, scheduler):
    model.train()
    train_loss = 0
    for src, tgt in tqdm(train_loader, desc=f"Epoch", leave=False):
        src, tgt = src.to(config.device), tgt.to(config.device)
        output = model(src, tgt[:, :-1], config.PAD_IDX)
        output = output.contiguous().view(-1, output.size(-1))
        tgt = tgt[:, 1:].contiguous().view(-1)
        loss = criterion(output, tgt)
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
    return train_loss / len(train_loader)


def evaluate(config, model, val_loader, criterion):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt in tqdm(val_loader, desc=f"Val", leave=False):
            src, tgt = src.to(config.device), tgt.to(config.device)
            output = model(src, tgt[:, :-1], config.PAD_IDX)
            output = output.contiguous().view(-1, output.size(-1))
            tgt = tgt[:, 1:].contiguous().view(-1)
            loss = criterion(output, tgt)
            val_loss += loss.item()
    return val_loss / len(val_loader)


def test_model(model, data, vocab):
    model.eval()
    for src_text, tgt_text in data:
        src_text, tgt_text = "".join(src_text), "".join(tgt_text)
        out_text = model.generate(src_text, vocab)
        logger.info(f"\nInput: {src_text}\nTarget: {tgt_text}\nOutput: {out_text}")


def seed_everything(seed):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def main():
    config = Config()

    # Set random seed
    seed_everything(config.seed)
    logger.info(f"Set random seed to {config.seed}")

    # Set cuDNN
    if config.cuDNN:
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True

    # Load data
    data = load_data([config.in_path, config.out_path])
    if config.debug:
        data = data[:1000]
    logger.info(f"Load {len(data)} couplets")

    # Build vocab
    vocab = Vocab(data)
    vocab_size = len(vocab)
    logger.info(f"Build vocab with {vocab_size} tokens")
    vocab_path = os.path.join(config.model_save_dir, "vocab.pkl")
    joblib.dump(vocab, vocab_path)
    logger.info(f"Save vocab to {vocab_path}")

    # Build dataset
    data_train, data_val = train_test_split(
        data, test_size=config.val_ratio, random_state=config.seed, shuffle=True
    )
    train_dataset = CoupletsDataset(data_train, vocab)
    val_dataset = CoupletsDataset(data_val, vocab)

    config.PAD_IDX = train_dataset.PAD_IDX

    logger.info(f"Build train dataset with {len(train_dataset)} samples")
    logger.info(f"Build val dataset with {len(val_dataset)} samples")

    # Build dataloader
    train_loader = train_dataset.get_loader(
        config.batch_size, shuffle=True, num_workers=config.num_workers
    )
    val_loader = val_dataset.get_loader(
        config.batch_size, shuffle=False, num_workers=config.num_workers
    )
    logger.info(f"Build train dataloader with {len(train_loader)} batches")
    logger.info(f"Build val dataloader with {len(val_loader)} batches")

    # Build model
    model = CoupletsTransformer(
        vocab_size=vocab_size,
        d_model=config.d_model,
        nhead=config.num_head,
        num_encoder_layers=config.num_encoder_layers,
        num_decoder_layers=config.num_decoder_layers,
        dim_feedforward=config.dim_feedforward,
        dropout=config.dropout,
    )
    logger.info(f"Build model with {model.name}")

    # Build optimizer
    optimizer = optim.AdamW(
        model.parameters(),
        lr=1,
        betas=(config.beta1, config.beta2),
        eps=config.epsilon,
        weight_decay=config.weight_decay,
    )

    # Build criterion
    criterion = nn.CrossEntropyLoss(ignore_index=config.PAD_IDX, reduction="mean")

    # Build scheduler
    lr_max, lr_min = config.lr_max, config.lr_min
    T_max = config.epochs * len(train_loader)
    warm_up_iter = int(T_max * config.warmup_ratio)

    def WarmupExponentialLR(cur_iter):
        gamma = math.exp(math.log(lr_min / lr_max) / (T_max - warm_up_iter))
        if cur_iter < warm_up_iter:
            return (lr_max - lr_min) * (cur_iter / warm_up_iter) + lr_min
        else:
            return lr_max * gamma ** (cur_iter - warm_up_iter)

    scheduler = LambdaLR(optimizer, lr_lambda=WarmupExponentialLR)

    df_lr = pd.DataFrame(
        [WarmupExponentialLR(i) for i in range(T_max)],
        columns=["Learning Rate"],
    )
    plt.figure(figsize=(10, 6))
    sns.lineplot(data=df_lr, linewidth=2)
    plt.title("Learning Rate Schedule")
    plt.xlabel("Iteration")
    plt.ylabel("Learning Rate")
    lr_img_path = os.path.join(config.img_save_dir, "lr_schedule.png")
    plt.savefig(lr_img_path, dpi=300)
    logger.info(f"Save learning rate schedule to {lr_img_path}")

    # Garbage collect
    gc.collect()
    torch.cuda.empty_cache()

    # Train model
    train_model(
        config, model, train_loader, val_loader, optimizer, criterion, scheduler
    )

    # Test model
    test_model(model, data_val[:10], vocab)


if __name__ == "__main__":
    main()
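
The scheduler above multiplies the optimizer's base lr of 1 by WarmupExponentialLR(cur_iter): a linear warm-up from lr_min to lr_max over the first warmup_ratio of all iterations, followed by exponential decay back to lr_min at the final iteration. Here is a standalone sketch of the same schedule with made-up step counts (T_max and warm_up_iter are illustrative, not the project's actual values):

# Standalone restatement of the WarmupExponentialLR lambda from train.py.
import math

lr_max, lr_min = 1e-3, 1e-4
T_max, warm_up_iter = 1000, 120  # e.g. 20 epochs x 50 batches, 12% warm-up

def warmup_exponential_lr(cur_iter):
    gamma = math.exp(math.log(lr_min / lr_max) / (T_max - warm_up_iter))
    if cur_iter < warm_up_iter:
        return (lr_max - lr_min) * (cur_iter / warm_up_iter) + lr_min
    return lr_max * gamma ** (cur_iter - warm_up_iter)

print(warmup_exponential_lr(0))             # 1e-4 at the start of warm-up
print(warmup_exponential_lr(warm_up_iter))  # 1e-3 at the peak
print(warmup_exponential_lr(T_max))         # ~1e-4 after the decay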

Web UI

This part defines the web UI, including the input and output components and launching the interface.

The corresponding project file is webui.py.

import random
import joblib

import torch
import gradio as gr

from dataset import Vocab
from model import CoupletsTransformer

data_path = "./data/fixed_couplets_in.txt"
vocab_path = "./trained/vocab.pkl"
model_path = "./trained/CoupletsTransformer_best.pth"


vocab = joblib.load(vocab_path)
vocab_size = len(vocab)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CoupletsTransformer(
    vocab_size,
    d_model=256,
    nhead=8,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=1024,
    dropout=0.1,
).to(device)
# map_location lets a GPU-trained checkpoint load on CPU-only machines
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

example = (
    line.replace(" ", "").strip() for line in iter(open(data_path, encoding="utf8"))
)
example = [line for line in example if len(line) > 5]

example = random.sample(example, 300)


def generate_couplet(vocab, model, src_text):
    if not src_text:
        return "上联不能为空"
    out_text = model.generate(src_text, vocab)
    return out_text


input_text = gr.Textbox(
    label="上联",
    placeholder="在这里输入上联",
    max_lines=1,
    lines=1,
    show_copy_button=True,
    autofocus=True,
)

output_text = gr.Textbox(
    label="下联",
    placeholder="在这里生成下联",
    max_lines=1,
    lines=1,
    show_copy_button=True,
)

demo = gr.Interface(
    fn=lambda x: generate_couplet(vocab, model, x),
    inputs=input_text,
    outputs=output_text,
    title="中文对联生成器",
    description="输入上联,生成下联",
    allow_flagging="never",
    submit_btn="生成下联",
    clear_btn="清空",
    examples=example,
    examples_per_page=50,
)

demo.launch()
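
By default, demo.launch() serves the interface locally and Gradio prints the URL (typically http://127.0.0.1:7860). If you need to reach it from another machine or want a temporary public link, the standard Gradio launch options can be passed instead, for example:

# Optional alternative to the plain demo.launch() above.
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)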