Source code for paddlets.models.classify.dl.adapter.data_adapter

# !/usr/bin/env python3
# -*- coding:utf-8 -*-
import numpy as np

from paddlets.models.classify.dl.adapter.paddle_dataset_impl import ClassifyPaddleDatasetImpl
from paddlets.datasets import TSDataset
from paddlets.logger import Logger

from paddle.io import DataLoader as PaddleDataLoader
from typing import List, Callable, Tuple, Optional

logger = Logger(__name__)


[docs]class ClassifyDataAdapter(object): """ Data adapter, converts `TSDataset` to `paddle.io.Dataset` and `paddle.io.DataLoader`. """ def __init__(self): pass
[docs] def to_paddle_dataset( self, rawdatasets: List[TSDataset], labels: np.ndarray ) -> ClassifyPaddleDatasetImpl: """ Converts :class:`TSDataset` to :class:`paddle.io.Dataset`. Args: rawdataset(TSDataset): Raw TSDataset for converting to :class:`paddle.io.Dataset`. labels:(np.ndarray) : The data class labels Returns: PaddleDatasetImpl: A built PaddleDatasetImpl. """ return ClassifyPaddleDatasetImpl( rawdatasets=rawdatasets, labels = labels )
[docs] def to_paddle_dataloader( self, paddle_dataset: ClassifyPaddleDatasetImpl, batch_size: int, collate_fn: Callable = None, shuffle: bool = True ) -> PaddleDataLoader: """ Converts :class:`paddle.io.Dataset` to :class:`paddle.io.DataLoader`. Args: paddle_dataset(PaddleDatasetImpl): Raw :class:`TSDataset` for building :class:`paddle.io.DataLoader`. batch_size(int): The number of samples for a single batch. collate_fn(Callable, optional): User-defined collate function for each batch, optional. shuffle(bool, optional): Whether to shuffle indices order before generating batch indices, default True. Returns: PaddleDataLoader: A built paddle DataLoader. """ return PaddleDataLoader(dataset=paddle_dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle)