The RNN hidden state
```py
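import torch


class rnn_(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
        super().__init__()
        # Keep the sizes that forward() needs to build the initial hidden state.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = torch.nn.RNN(input_size, hidden_size,
                                num_layers, batch_first=True)

    def forward(self, x):
        # Build h here from the actual batch: (num_layers, batch, hidden_size).
        h = torch.zeros(self.num_layers, x.size(0), self.hidden_size, device=x.device)
        out, h = self.rnn(x, h)
        return out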
```
If the hidden state is not created at that spot, inside `forward`, you will often run into baffling errors!
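One plausible cause, offered as an assumption rather than something this note states: the last batch of an epoch is often smaller than the others, so a hidden state built ahead of time for a fixed batch size no longer matches the input. A minimal sketch with made-up sizes:

```python
import torch

# Illustrative sizes only; none of these numbers come from the note above.
input_size, hidden_size, num_layers = 8, 16, 2
rnn = torch.nn.RNN(input_size, hidden_size, num_layers, batch_first=True)


def run(x):
    # h_0 is (num_layers, batch, hidden_size), even when batch_first=True.
    h0 = torch.zeros(num_layers, x.size(0), hidden_size)
    out, hn = rnn(x, h0)
    return out, hn


full_out, full_h = run(torch.randn(4, 5, input_size))  # a full batch of 4
last_out, last_h = run(torch.randn(3, 5, input_size))  # a smaller final batch of 3
print(full_out.shape, full_h.shape)
print(last_out.shape, last_h.shape)

# A hidden state fixed to batch size 4 is rejected for a batch of 3.
stale_h0 = torch.zeros(num_layers, 4, hidden_size)
try:
    rnn(torch.randn(3, 5, input_size), stale_h0)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)
```

Sizing the state from `x.size(0)` on every call keeps both batches working without any special handling.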
*****
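Issues when writing a custom dataset
Build the input and target tensors in `__init__`, and do nothing but indexing in `__getitem__`. The data has shape **(total count, other dims)**, where the other dims are, for example, (channel, height, width) for images, (28, 28) for MNIST, or (sequence length) for text.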
```py
import numpy as np
import torch
from torch.utils import data


class qohdataset(data.Dataset):
    """
    Dataset must define __getitem__ and __len__
    """

    def __init__(self, qoh):
        def padding(ele, num):
            # Append all-zero rows (width 47) until ele has num rows.
            difference = num - len(ele)
            for _ in range(difference):
                ele.append(np.zeros((47,)))

        self.qoh = qoh
        # Pad every sequence to 13 steps so the data stacks into one array.
        for i in self.qoh:
            if len(i) < 13:
                padding(i, 13)
        self.qoh = np.array(self.qoh, dtype=int)
        print(self.qoh.shape)
        self.qoh = torch.from_numpy(self.qoh)
        # Input is steps 0..11; target is the same sequence shifted by one step.
        self.seq = self.qoh[:, 0:12, :]
        self.tar = self.qoh[:, 1:13, :]

    def __getitem__(self, index):
        """
        Returns the (x, y) pair at position index; x and y are both tensors.
        """
        x = self.seq[index, ...]
        y = self.tar[index, ...]
        return x, y

    def __len__(self):
        # 0 <= index < lens; the sample count is dimension 0, not dimension 1.
        lens = self.qoh.shape[0]
        return lens
```
*****
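DataLoader issues
The DataLoader yields batches of shape (batch, other dims), where the other dims match the dataset. In general, you only need to define a `collate_fn` when the input sequences have different lengths; otherwise just use the DataLoader directly.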
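For the unequal-length case, here is a minimal sketch of a `collate_fn` that pads each batch to its longest sequence. `ToySeqs` and `pad_collate` are hypothetical names invented for this illustration, and padding with `pad_sequence` is one common choice, not necessarily the one these notes had in mind:

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset


class ToySeqs(Dataset):
    """Hypothetical dataset of 1-D sequences with different lengths."""

    def __init__(self):
        self.seqs = [torch.arange(n, dtype=torch.float32) for n in (2, 5, 3, 4)]

    def __getitem__(self, index):
        return self.seqs[index]

    def __len__(self):
        return len(self.seqs)


def pad_collate(batch):
    # batch is a list of tensors whose lengths differ.
    lengths = torch.tensor([len(s) for s in batch])
    # Pad with zeros up to the longest sequence in this batch only.
    padded = pad_sequence(batch, batch_first=True, padding_value=0.0)
    return padded, lengths


loader = DataLoader(ToySeqs(), batch_size=2, shuffle=False, collate_fn=pad_collate)
batches = list(loader)
for padded, lengths in batches:
    print(padded.shape, lengths.tolist())
```

Returning the original lengths next to the padded tensor lets later code mask or pack the padding instead of treating it as data.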
*****
Data types are subject to requirements; the common tensor types are:
float, double, half, short(int16), int(int32), long(int64)
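The sketch below shows the conversion method for each type and one typical mismatch: the dataset above is built with `dtype=int`, while `torch.nn.RNN` weights are float32 by default, so the batch has to be cast before it is fed in. The hidden size of 8 and the batch size of 2 are arbitrary values chosen for illustration.

```python
import torch

x = torch.tensor([1, 2, 3])  # Python ints become int64 (long) by default

# Conversion method name -> dtype of the resulting tensor.
conversions = {
    "float": x.float().dtype,    # torch.float32
    "double": x.double().dtype,  # torch.float64
    "half": x.half().dtype,      # torch.float16
    "short": x.short().dtype,    # torch.int16
    "int": x.int().dtype,        # torch.int32
    "long": x.long().dtype,      # torch.int64
}
for name, dtype in conversions.items():
    print(name, dtype)

# An integer batch shaped like the dataset's x: (batch, 12 steps, 47 features).
int_batch = torch.zeros(2, 12, 47, dtype=torch.long)
rnn = torch.nn.RNN(input_size=47, hidden_size=8, batch_first=True)
out, _ = rnn(int_batch.float())  # cast to float32 to match the RNN weights
print(out.dtype, out.shape)

# Class-index targets for CrossEntropyLoss must be long (int64).
logits = torch.randn(2, 3)
targets = torch.tensor([0, 2])
loss = torch.nn.CrossEntropyLoss()(logits, targets)
print(targets.dtype, loss.dim())
```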