 import numpy as np
 import sys

+from dataclasses import dataclass, field
+
+@dataclass
+class CustomTrainingArguments(TrainingArguments):
+    save_steps: int = field(default=100)
+
 # model_path = sys.argv[1]

 def setup_seed(seed):
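
The new CustomTrainingArguments works because transformers' TrainingArguments is itself a dataclass: a subclass must be re-decorated with @dataclass for a redeclared field to take part in the generated __init__, and redeclaring save_steps this way changes only its default (stock TrainingArguments defaults it to 500). A minimal standalone sketch of the pattern, assuming a stock transformers install:

from dataclasses import dataclass, field
from transformers import TrainingArguments

@dataclass  # required again: TrainingArguments is itself a dataclass
class CustomTrainingArguments(TrainingArguments):
    # Redeclaring an inherited field overrides only its default value.
    save_steps: int = field(default=100)
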
@@ -49,33 +55,33 @@ def read_data(dataset_name, wav_split):


 # for debug
-read_data('Shanghai_Dialect_Dict', 1)
-read_data('Shanghai_Dialect_Dict', 2)
-random.shuffle(train_data)
-eval_ratio = 0.05
-index = int(len(train_data) * eval_ratio)
-eval_data = train_data[:10]
-train_data = train_data[10:20]
-batch_size = 1
-eval_steps = 100
-fp16 = False
-
-# for train
-# read_data('Shanghai_Dialect_Conversational_Speech_Corpus', 1)
-# read_data('Shanghai_Dialect_Conversational_Speech_Corpus', 2)
-# read_data('Shanghai_Dialect_Scripted_Speech_Corpus_Daily_Use_Sentence', 1)
-# read_data('Shanghai_Dialect_Scripted_Speech_Corpus_Daily_Use_Sentence', 2)
 # read_data('Shanghai_Dialect_Dict', 1)
 # read_data('Shanghai_Dialect_Dict', 2)
-# read_data('Shanghai_Dialect_Zhongguoyuyan', 1)
-
+# random.shuffle(train_data)
 # eval_ratio = 0.05
 # index = int(len(train_data) * eval_ratio)
-# eval_data = train_data[:index]
-# train_data = train_data[index:]
-# batch_size = 32
+# eval_data = train_data[:10]
+# train_data = train_data[10:20]
+# batch_size = 1
 # eval_steps = 100
-# fp16 = True
+# fp16 = False
+
+# for train
+read_data('Shanghai_Dialect_Conversational_Speech_Corpus', 1)
+read_data('Shanghai_Dialect_Conversational_Speech_Corpus', 2)
+read_data('Shanghai_Dialect_Scripted_Speech_Corpus_Daily_Use_Sentence', 1)
+read_data('Shanghai_Dialect_Scripted_Speech_Corpus_Daily_Use_Sentence', 2)
+read_data('Shanghai_Dialect_Dict', 1)
+read_data('Shanghai_Dialect_Dict', 2)
+read_data('Shanghai_Dialect_Zhongguoyuyan', 1)
+
+eval_ratio = 0.05
+index = int(len(train_data) * eval_ratio)
+eval_data = train_data[:index]
+train_data = train_data[index:]
+batch_size = 32
+eval_steps = 100
+fp16 = True


 print('eval_data_len:', len(eval_data))
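
One consequence of this hunk: random.shuffle(train_data) stays commented out in the training path, so the 5% eval slice is simply the first examples in load order, which (assuming read_data appends in call order) come mostly from Shanghai_Dialect_Conversational_Speech_Corpus. If a random but reproducible split were wanted instead, a seeded shuffle before slicing would do it; a sketch with a hypothetical helper name, independent of this repo's setup_seed:

import random

def shuffled_split(data, eval_ratio=0.05, seed=42):
    """Return (eval_data, train_data) after a seeded shuffle."""
    rng = random.Random(seed)  # local RNG: reproducible, no global state
    data = data[:]             # copy so the caller's list is untouched
    rng.shuffle(data)          # mix corpora before taking the eval slice
    index = int(len(data) * eval_ratio)
    return data[:index], data[index:]
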
@@ -84,7 +90,7 @@ def read_data(dataset_name, wav_split):
 # gradient_checkpointing=True,
 # gradient_accumulation_steps=2,

-training_args = TrainingArguments(
+training_args = CustomTrainingArguments(
     save_steps=eval_steps,
     group_by_length=True,
     num_train_epochs=200,
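
Worth noting: save_steps=eval_steps is still passed explicitly at this call site, so the default of 100 baked into CustomTrainingArguments never takes effect here; the subclass only changes what happens when the argument is omitted. A quick illustration (output_dir="outputs" is a placeholder, not a path from this repo):

args = CustomTrainingArguments(output_dir="outputs")                  # save_steps == 100, subclass default
args = CustomTrainingArguments(output_dir="outputs", save_steps=500)  # explicit call-site value wins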