
I have an unbatched TensorFlow dataset that looks like this:

ds = ...
for record in ds.take(3):
    print('data shape={}'.format(record['data'].shape))

-> data shape=(512, 512, 87)
-> data shape=(512, 512, 277)
-> data shape=(512, 512, 133)

I want to feed the data to my network in chunks of depth 5. In the example above, the tensor of shape (512, 512, 87) would be divided into 17 tensors of shape (512, 512, 5). The final 2 slices along the depth axis (tensor[:, :, 85:87]) should be discarded.
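The chunk arithmetic can be sketched in plain Python: a record of depth D yields D // 5 full chunks, and the last D % 5 slices are dropped. chunk_bounds is a hypothetical helper used only to illustrate the index ranges.

```python
def chunk_bounds(depth, chunk_depth=5):
    """Return the (start, stop) depth indices of every full chunk.

    Trailing slices that cannot fill a whole chunk are dropped.
    """
    num_chunks = depth // chunk_depth
    return [(i * chunk_depth, (i + 1) * chunk_depth) for i in range(num_chunks)]


bounds = chunk_bounds(87)
print(len(bounds), bounds[0], bounds[-1])  # 17 (0, 5) (80, 85)
```

For depth 87 the last kept chunk is 80:85, so 85:87 is discarded, as described above.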

For example:

chunked_ds = ...
for record in chunked_ds.take(1):
    print('chunked data shape={}'.format(record['data'].shape))

-> chunked data shape=(512, 512, 5)

How can I get from ds to chunked_ds? tf.data.Dataset.window() looks like what I need, but I cannot get it working.
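One possible approach, sketched under stated assumptions rather than offered as a definitive answer: tf.data.Dataset.window() groups consecutive elements of a dataset, not slices inside a single element, so on its own it does not split one (512, 512, 87) tensor. The sketch below swaps in Dataset.flat_map instead, turning each record into a small dataset of chunks with tf.data.Dataset.from_tensor_slices. It assumes TensorFlow 2.4 or later (for the output_signature argument used in the demo), and that each record is a dict whose only needed key is 'data' (any other keys are dropped). The names split_into_chunks and CHUNK_DEPTH are hypothetical, and the ds built here is a small synthetic stand-in with 4x4 frames, not the real dataset.

```python
import numpy as np
import tensorflow as tf

CHUNK_DEPTH = 5  # hypothetical name: depth of each chunk fed to the network


def split_into_chunks(record):
    """Turn one record of shape (H, W, D) into a dataset of D // 5 chunks of shape (H, W, 5).

    Assumes the record is a dict with a 'data' key; other keys are dropped.
    """
    data = record['data']
    num_chunks = tf.shape(data)[-1] // CHUNK_DEPTH
    # Drop the trailing D % 5 slices that cannot fill a whole chunk.
    trimmed = data[:, :, :num_chunks * CHUNK_DEPTH]
    # (H, W, n*5) -> (H, W, n, 5): chunk i holds depth slices i*5 .. i*5+4.
    shape = tf.shape(trimmed)
    chunks = tf.reshape(trimmed, [shape[0], shape[1], num_chunks, CHUNK_DEPTH])
    # Move the chunk index to the front: (n, H, W, 5).
    chunks = tf.transpose(chunks, [2, 0, 1, 3])
    # Each slice along axis 0 becomes one element of the new dataset.
    return tf.data.Dataset.from_tensor_slices({'data': chunks})


# Synthetic stand-in for the real ds: 4x4 frames with the depths from the question.
depths = [87, 277, 133]
ds = tf.data.Dataset.from_generator(
    lambda: ({'data': np.random.rand(4, 4, d).astype(np.float32)} for d in depths),
    output_signature={'data': tf.TensorSpec(shape=(4, 4, None), dtype=tf.float32)},
)

chunked_ds = ds.flat_map(split_into_chunks)
for record in chunked_ds.take(1):
    print('chunked data shape={}'.format(record['data'].shape))
# -> chunked data shape=(4, 4, 5)
```

Reshaping to (H, W, n, 5) and transposing keeps the split inside TensorFlow ops, so it handles records of any depth without a Python loop; a record shallower than 5 simply produces no chunks.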

Asked by Anonymous on May 14, 2021