diff --git a/examples/basic_tutorials/cifar10_cnn.py b/examples/basic_tutorials/cifar10_cnn.py index 35061da..089a41b 100644 --- a/examples/basic_tutorials/cifar10_cnn.py +++ b/examples/basic_tutorials/cifar10_cnn.py @@ -1,30 +1,38 @@ #! /usr/bin/python # -*- coding: utf-8 -*- -################################ TensorLayerX and Jittor. ################################# + + +################################ TensorLayerX and Torch can be mixed programming. ################################# import os +# os.environ['TL_BACKEND'] = 'paddle' +# os.environ['TL_BACKEND'] = 'tensorflow' +# os.environ['TL_BACKEND'] = 'mindspore' +os.environ['TL_BACKEND'] = 'torch' + + import time -import tensorlayerx as tlx from tensorlayerx.dataflow import Dataset, DataLoader from tensorlayerx.vision.transforms import ( Compose, Resize, RandomFlipHorizontal, RandomContrast, RandomBrightness, StandardizePerImage, RandomCrop ) -from tensorlayerx.nn import Conv2d, Linear, Flatten, Module, MaxPool2d, BatchNorm2d -from tensorlayerx.optimizers import Adam -from tqdm import tqdm - -# Enable debug logging +from tensorlayerx.model import TrainOneStep +from tensorlayerx.nn import Module +import tensorlayerx as tlx +from tensorlayerx.nn import (Conv2d, Linear, Flatten, MaxPool2d, BatchNorm2d) +# enable debug logging tlx.logging.set_verbosity(tlx.logging.DEBUG) -os.environ['TL_BACKEND'] = 'jittor' - -# Download and prepare the CIFAR10 dataset -print("Downloading CIFAR10 dataset...") +# ################## Download and prepare the CIFAR10 dataset ################## +# This is just some way of getting the CIFAR10 dataset from an online location +# and loading it into numpy arrays with shape [32,32,3] X_train, y_train, X_test, y_test = tlx.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False) -# Define the CIFAR10 dataset -class CIFAR10Dataset(Dataset): +# ################## CIFAR10 dataset ################## +# We define a Dataset class for Loading CIFAR10 images and labels. +class make_dataset(Dataset): + def __init__(self, data, label, transforms): self.data = data self.label = label @@ -34,109 +42,161 @@ def __getitem__(self, idx): x = self.data[idx].astype('uint8') y = self.label[idx].astype('int64') x = self.transforms(x) + return x, y def __len__(self): + return len(self.label) -# Define the CIFAR10 images preprocessing pipeline -train_transforms = Compose([ - RandomCrop(size=[24, 24]), - RandomFlipHorizontal(), - RandomBrightness(brightness_factor=(0.5, 1.5)), - RandomContrast(contrast_factor=(0.5, 1.5)), - StandardizePerImage() -]) +# We define the CIFAR10 iamges preprocessing pipeline. 
+train_transforms = Compose( # Combining multiple operations sequentially + [ + RandomCrop(size=[24, 24]), #random crop from images to shape [24, 24] + RandomFlipHorizontal(), # random invert each image horizontally by probability + RandomBrightness(brightness_factor=(0.5, 1.5)), # Within the range of values (0.5, 1.5), adjust brightness randomly + RandomContrast(contrast_factor=(0.5, 1.5)), # Within the range of values (0.5, 1.5), adjust contrast randomly + StandardizePerImage() #Normalize the values of each image to [-1, 1] + ] +) test_transforms = Compose([Resize(size=(24, 24)), StandardizePerImage()]) -# Create DataLoaders for training and testing -print("Processing CIFAR10 dataset...") -train_dataset = CIFAR10Dataset(data=X_train, label=y_train, transforms=train_transforms) -test_dataset = CIFAR10Dataset(data=X_test, label=y_test, transforms=test_transforms) +# We use DataLoader to batch and shuffle data, and make data into iterators. +train_dataset = make_dataset(data=X_train, label=y_train, transforms=train_transforms) +test_dataset = make_dataset(data=X_test, label=y_test, transforms=test_transforms) -train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True) -test_dataloader = DataLoader(test_dataset, batch_size=128) +train_dataset = DataLoader(train_dataset, batch_size=128, shuffle=True) +test_dataset = DataLoader(test_dataset, batch_size=128) +# ################## CNN network ################## +class CNN(Module): -class SimpleCNN(Module): def __init__(self): - super(SimpleCNN, self).__init__() - self.conv1 = Conv2d(16, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=3) - self.conv2 = Conv2d(32, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=16) - self.maxpool1 = MaxPool2d((2, 2), (2, 2), padding='SAME') - self.conv3 = Conv2d(64, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=32) - self.bn1 = BatchNorm2d(num_features=64, act=tlx.nn.ReLU) - self.conv4 = Conv2d(128, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=64) - self.maxpool2 = MaxPool2d((2, 2), (2, 2), padding='SAME') - self.flatten = Flatten() - self.fc1 = Linear(out_features=128, act=tlx.nn.ReLU, in_features=128 * 6 * 6) - self.fc2 = Linear(out_features=64, act=tlx.nn.ReLU, in_features=128) - self.fc3 = Linear(out_features=10, act=None, in_features=64) - + super(CNN, self).__init__() + # Parameter initialization method + W_init = tlx.nn.initializers.truncated_normal(stddev=5e-2) + W_init2 = tlx.nn.initializers.truncated_normal(stddev=0.04) + b_init2 = tlx.nn.initializers.constant(value=0.1) + + # 2D Convolutional Neural Network, Set padding method "SAME", convolutional kernel size [5,5], stride [1,1], in channels, out channels + self.conv1 = Conv2d(64, (5, 5), (1, 1), padding='SAME', W_init=W_init, b_init=None, name='conv1', in_channels=3) + # Add 2D BatchNormalize, using ReLU for output. + self.bn = BatchNorm2d(num_features=64, act=tlx.nn.ReLU) + # Add 2D Max pooling layer. + self.maxpool1 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1') + + self.conv2 = Conv2d( + 64, (5, 5), (1, 1), padding='SAME', act=tlx.nn.ReLU, W_init=W_init, name='conv2', in_channels=64 + ) + self.maxpool2 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2') + # Flatten 2D data to 1D data + self.flatten = Flatten(name='flatten') + # Linear layer with 384 units, using ReLU for output. 
+ self.linear1 = Linear(384, act=tlx.nn.ReLU, W_init=W_init2, b_init=b_init2, name='linear1relu', in_features=2304) + self.linear2 = Linear(192, act=tlx.nn.ReLU, W_init=W_init2, b_init=b_init2, name='linear2relu', in_features=384) + self.linear3 = Linear(10, act=None, W_init=W_init2, name='output', in_features=192) + + # We define the forward computation process. def forward(self, x): z = self.conv1(x) - z = self.conv2(z) + z = self.bn(z) z = self.maxpool1(z) - z = self.conv3(z) - z = self.bn1(z) - z = self.conv4(z) + z = self.conv2(z) z = self.maxpool2(z) z = self.flatten(z) - z = self.fc1(z) - z = self.fc2(z) - z = self.fc3(z) + z = self.linear1(z) + z = self.linear2(z) + z = self.linear3(z) return z -# Instantiate the model -model = SimpleCNN() - -# Define the optimizer -optimizer = Adam(lr=0.001) -# optimizer = Adam(lr=0.001, params=model.trainable_weights ) - -# Define the loss function -loss_fn = tlx.losses.softmax_cross_entropy_with_logits - -# Use the built-in training method -metric = tlx.metrics.Recall() -tlx_model = tlx.model.Model(network=model, loss_fn=loss_fn, optimizer=optimizer, metrics=metric) -tlx_model.train(n_epoch=2, train_dataset=train_dataloader, print_freq=1, print_train_batch=True) - - -################################ TensorLayerX and Torch. ################################# - +# get the network +net = CNN() + +# training settings +n_epoch = 500 +learning_rate = 0.0001 +print_freq = 5 +n_step_epoch = int(len(y_train) / 128) +n_step = n_epoch * n_step_epoch +shuffle_buffer_size = 128 +# Get training parameters +train_weights = net.trainable_weights +# Define the optimizer, use the Adam optimizer. +optimizer = tlx.optimizers.Adam(learning_rate) +# Define evaluation metrics. +metrics = tlx.metrics.Accuracy() + +# Define the loss calculation process +class WithLoss(Module): + + def __init__(self, net, loss_fn): + super(WithLoss, self).__init__() + self._net = net + self._loss_fn = loss_fn + + def forward(self, data, label): + out = self._net(data) + loss = self._loss_fn(out, label) + return loss + + +net_with_loss = WithLoss(net, loss_fn=tlx.losses.softmax_cross_entropy_with_logits) +# Initialize one-step training +net_with_train = TrainOneStep(net_with_loss, optimizer, train_weights) + +# Custom training loops +for epoch in range(n_epoch): + start_time = time.time() + # Set the network to training state + net.set_train() + train_loss, train_acc, n_iter = 0, 0, 0 + # Get training data and labels + for X_batch, y_batch in train_dataset: + # Calculate the loss value, and automatically complete the gradient update + _loss_ce = net_with_train(X_batch, y_batch) + train_loss += _loss_ce + + n_iter += 1 + _logits = net(X_batch) + # Calculate accuracy + metrics.update(_logits, y_batch) + train_acc += metrics.result() + metrics.reset() + print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time)) + print(" train loss: {}".format(train_loss / n_iter)) + print(" train acc: {}".format(train_acc / n_iter)) + + +################################ TensorLayerX and Jittor can be mixed programming. 
################################# # import os -# # os.environ['TL_BACKEND'] = 'paddle' -# # os.environ['TL_BACKEND'] = 'tensorflow' -# # os.environ['TL_BACKEND'] = 'mindspore' -# os.environ['TL_BACKEND'] = 'torch' - - # import time +# import numpy as np +# import tensorlayerx as tlx # from tensorlayerx.dataflow import Dataset, DataLoader # from tensorlayerx.vision.transforms import ( # Compose, Resize, RandomFlipHorizontal, RandomContrast, RandomBrightness, StandardizePerImage, RandomCrop # ) -# from tensorlayerx.model import TrainOneStep -# from tensorlayerx.nn import Module -# import tensorlayerx as tlx -# from tensorlayerx.nn import (Conv2d, Linear, Flatten, MaxPool2d, BatchNorm2d) -# # enable debug logging +# from tensorlayerx.nn import Conv2d, Linear, Flatten, Module +# from tensorlayerx.optimizers import Adam +# from tqdm import tqdm + +# # Enable debug logging # tlx.logging.set_verbosity(tlx.logging.DEBUG) -# # ################## Download and prepare the CIFAR10 dataset ################## -# # This is just some way of getting the CIFAR10 dataset from an online location -# # and loading it into numpy arrays with shape [32,32,3] -# X_train, y_train, X_test, y_test = tlx.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False) +# os.environ['TL_BACKEND'] = 'jittor' -# # ################## CIFAR10 dataset ################## -# # We define a Dataset class for Loading CIFAR10 images and labels. -# class make_dataset(Dataset): + +# # Download and prepare the CIFAR10 dataset with progress bar +# print("Downloading CIFAR10 dataset...") +# X_train, y_train, X_test, y_test = tlx.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False) + +# # Define the CIFAR10 dataset +# class CIFAR10Dataset(Dataset): # def __init__(self, data, label, transforms): # self.data = data # self.label = label @@ -146,131 +206,79 @@ def forward(self, x): # x = self.data[idx].astype('uint8') # y = self.label[idx].astype('int64') # x = self.transforms(x) - # return x, y # def __len__(self): - # return len(self.label) -# # We define the CIFAR10 iamges preprocessing pipeline. -# train_transforms = Compose( # Combining multiple operations sequentially -# [ -# RandomCrop(size=[24, 24]), #random crop from images to shape [24, 24] -# RandomFlipHorizontal(), # random invert each image horizontally by probability -# RandomBrightness(brightness_factor=(0.5, 1.5)), # Within the range of values (0.5, 1.5), adjust brightness randomly -# RandomContrast(contrast_factor=(0.5, 1.5)), # Within the range of values (0.5, 1.5), adjust contrast randomly -# StandardizePerImage() #Normalize the values of each image to [-1, 1] -# ] -# ) +# # Define the CIFAR10 images preprocessing pipeline +# train_transforms = Compose([ +# RandomCrop(size=[24, 24]), +# RandomFlipHorizontal(), +# RandomBrightness(brightness_factor=(0.5, 1.5)), +# RandomContrast(contrast_factor=(0.5, 1.5)), +# StandardizePerImage() +# ]) # test_transforms = Compose([Resize(size=(24, 24)), StandardizePerImage()]) -# # We use DataLoader to batch and shuffle data, and make data into iterators. 
-# train_dataset = make_dataset(data=X_train, label=y_train, transforms=train_transforms) -# test_dataset = make_dataset(data=X_test, label=y_test, transforms=test_transforms) - -# train_dataset = DataLoader(train_dataset, batch_size=128, shuffle=True) -# test_dataset = DataLoader(test_dataset, batch_size=128) +# # Create DataLoaders for training and testing +# print("Processing CIFAR10 dataset...") +# train_dataset = CIFAR10Dataset(data=X_train, label=y_train, transforms=train_transforms) +# test_dataset = CIFAR10Dataset(data=X_test, label=y_test, transforms=test_transforms) -# # ################## CNN network ################## -# class CNN(Module): +# train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True) +# test_dataloader = DataLoader(test_dataset, batch_size=128) +# # Define a simple CNN model +# class SimpleCNN(Module): # def __init__(self): -# super(CNN, self).__init__() -# # Parameter initialization method -# W_init = tlx.nn.initializers.truncated_normal(stddev=5e-2) -# W_init2 = tlx.nn.initializers.truncated_normal(stddev=0.04) -# b_init2 = tlx.nn.initializers.constant(value=0.1) - -# # 2D Convolutional Neural Network, Set padding method "SAME", convolutional kernel size [5,5], stride [1,1], in channels, out channels -# self.conv1 = Conv2d(64, (5, 5), (1, 1), padding='SAME', W_init=W_init, b_init=None, name='conv1', in_channels=3) -# # Add 2D BatchNormalize, using ReLU for output. -# self.bn = BatchNorm2d(num_features=64, act=tlx.nn.ReLU) -# # Add 2D Max pooling layer. -# self.maxpool1 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1') +# super(SimpleCNN, self).__init__() +# self.conv1 = Conv2d(16, (3, 3), (1, 1), padding='SAME', act=tlx.nn.ReLU, in_channels=3) +# self.flatten = Flatten() +# self.fc1 = Linear(out_features=64, act=tlx.nn.ReLU, in_features=16 * 24 * 24) +# self.fc2 = Linear(out_features=10, act=None, in_features=64) -# self.conv2 = Conv2d( -# 64, (5, 5), (1, 1), padding='SAME', act=tlx.nn.ReLU, W_init=W_init, name='conv2', in_channels=64 -# ) -# self.maxpool2 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2') -# # Flatten 2D data to 1D data -# self.flatten = Flatten(name='flatten') -# # Linear layer with 384 units, using ReLU for output. -# self.linear1 = Linear(384, act=tlx.nn.ReLU, W_init=W_init2, b_init=b_init2, name='linear1relu', in_features=2304) -# self.linear2 = Linear(192, act=tlx.nn.ReLU, W_init=W_init2, b_init=b_init2, name='linear2relu', in_features=384) -# self.linear3 = Linear(10, act=None, W_init=W_init2, name='output', in_features=192) - -# # We define the forward computation process. # def forward(self, x): # z = self.conv1(x) -# z = self.bn(z) -# z = self.maxpool1(z) -# z = self.conv2(z) -# z = self.maxpool2(z) # z = self.flatten(z) -# z = self.linear1(z) -# z = self.linear2(z) -# z = self.linear3(z) +# z = self.fc1(z) +# z = self.fc2(z) # return z +# # Instantiate the model +# model = SimpleCNN() -# # get the network -# net = CNN() - -# # training settings -# n_epoch = 500 -# learning_rate = 0.0001 -# print_freq = 5 -# n_step_epoch = int(len(y_train) / 128) -# n_step = n_epoch * n_step_epoch -# shuffle_buffer_size = 128 -# # Get training parameters -# train_weights = net.trainable_weights -# # Define the optimizer, use the Adam optimizer. -# optimizer = tlx.optimizers.Adam(learning_rate) -# # Define evaluation metrics. 
-# metrics = tlx.metrics.Accuracy() - -# # Define the loss calculation process -# class WithLoss(Module): - -# def __init__(self, net, loss_fn): -# super(WithLoss, self).__init__() -# self._net = net -# self._loss_fn = loss_fn +# # Define the optimizer +# optimizer = Adam(model.trainable_weights, lr=0.001) -# def forward(self, data, label): -# out = self._net(data) -# loss = self._loss_fn(out, label) -# return loss +# # Define the loss function +# loss_fn = tlx.losses.softmax_cross_entropy_with_logits - -# net_with_loss = WithLoss(net, loss_fn=tlx.losses.softmax_cross_entropy_with_logits) -# # Initialize one-step training -# net_with_train = TrainOneStep(net_with_loss, optimizer, train_weights) - -# # Custom training loops +# # Training loop +# n_epoch = 2 # for epoch in range(n_epoch): # start_time = time.time() -# # Set the network to training state -# net.set_train() -# train_loss, train_acc, n_iter = 0, 0, 0 -# # Get training data and labels -# for X_batch, y_batch in train_dataset: -# # Calculate the loss value, and automatically complete the gradient update -# _loss_ce = net_with_train(X_batch, y_batch) -# train_loss += _loss_ce +# model.set_train() +# train_loss, n_iter = 0, 0 + +# with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{n_epoch}", unit="batch") as pbar: +# for X_batch, y_batch in train_dataloader: +# X_batch = tlx.convert_to_tensor(X_batch) +# y_batch = tlx.convert_to_tensor(y_batch) +# _logits = model(X_batch) +# loss = loss_fn(_logits, y_batch) + +# optimizer.zero_grad() +# optimizer.step(loss) + +# train_loss += loss.item() +# n_iter += 1 +# pbar.update(1) + +# print(f"Epoch {epoch + 1} of {n_epoch} took {time.time() - start_time:.2f}s") +# print(f" train loss: {train_loss / n_iter:.4f}") -# n_iter += 1 -# _logits = net(X_batch) -# # Calculate accuracy -# metrics.update(_logits, y_batch) -# train_acc += metrics.result() -# metrics.reset() -# print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time)) -# print(" train loss: {}".format(train_loss / n_iter)) -# print(" train acc: {}".format(train_acc / n_iter)) ################################ TensorLayerX and TensorFlow can be mixed programming. 
################################# diff --git a/examples/basic_tutorials/cifar10_cnn_dist.py b/examples/basic_tutorials/cifar10_cnn_dist.py index bff9efb..c72c704 100644 --- a/examples/basic_tutorials/cifar10_cnn_dist.py +++ b/examples/basic_tutorials/cifar10_cnn_dist.py @@ -3,10 +3,10 @@ import os # os.environ['TL_BACKEND'] = 'paddle' -os.environ['TL_BACKEND'] = 'jittor' +# os.environ['TL_BACKEND'] = 'jittor' # os.environ['TL_BACKEND'] = 'tensorflow' # os.environ['TL_BACKEND'] = 'mindspore' -# os.environ['TL_BACKEND'] = 'torch' +os.environ['TL_BACKEND'] = 'torch' import paddle from paddle.distributed import fleet diff --git a/examples/basic_tutorials/cifar10_cnn_train.py b/examples/basic_tutorials/cifar10_cnn_train.py index e98294d..1a549cc 100644 --- a/examples/basic_tutorials/cifar10_cnn_train.py +++ b/examples/basic_tutorials/cifar10_cnn_train.py @@ -5,11 +5,11 @@ import os # os.environ['TL_BACKEND'] = 'paddle' - -os.environ['TL_BACKEND'] = 'jittor' # os.environ['TL_BACKEND'] = 'tensorflow' # os.environ['TL_BACKEND'] = 'mindspore' -# os.environ['TL_BACKEND'] = 'torch' +# os.environ['TL_BACKEND'] = 'jittor' + +os.environ['TL_BACKEND'] = 'torch' @@ -75,7 +75,7 @@ def forward(self, x): # 定义损失函数、优化器等 loss_fn=tlx.losses.softmax_cross_entropy_with_logits -optimizer = tlx.optimizers.Adam(lr=learning_rate) +optimizer = tlx.optimizers.Adam(learning_rate) metrics = tlx.metrics.Accuracy() diff --git a/examples/basic_tutorials/gradient_clip_mixed_tensorflow.py b/examples/basic_tutorials/gradient_clip_mixed_tensorflow.py index d72b0c7..baf54a8 100644 --- a/examples/basic_tutorials/gradient_clip_mixed_tensorflow.py +++ b/examples/basic_tutorials/gradient_clip_mixed_tensorflow.py @@ -4,8 +4,8 @@ import os # os.environ['TL_BACKEND'] = 'tensorflow' # os.environ['TL_BACKEND'] = 'paddle' -# os.environ['TL_BACKEND'] = 'torch' -os.environ['TL_BACKEND'] = 'jittor' +os.environ['TL_BACKEND'] = 'torch' +# os.environ['TL_BACKEND'] = 'jittor' import time diff --git a/examples/basic_tutorials/mnist_dataflow.py b/examples/basic_tutorials/mnist_dataflow.py index 18af70a..f5fb3de 100644 --- a/examples/basic_tutorials/mnist_dataflow.py +++ b/examples/basic_tutorials/mnist_dataflow.py @@ -3,10 +3,11 @@ import os # os.environ['TL_BACKEND'] = 'tensorflow' -os.environ['TL_BACKEND'] = 'jittor' - # os.environ['TL_BACKEND'] = 'mindspore' # os.environ['TL_BACKEND'] = 'paddle' +# os.environ['TL_BACKEND'] = 'jittor' +os.environ['TL_BACKEND'] = 'torch' + import tensorlayerx as tlx from tensorlayerx.nn import Module diff --git a/examples/basic_tutorials/mnist_gan.py b/examples/basic_tutorials/mnist_gan.py index a40dd7a..5700dd6 100644 --- a/examples/basic_tutorials/mnist_gan.py +++ b/examples/basic_tutorials/mnist_gan.py @@ -5,8 +5,7 @@ # os.environ['TL_BACKEND'] = 'paddle' # os.environ['TL_BACKEND'] = 'tensorflow' # os.environ['TL_BACKEND'] = 'mindspore' -# os.environ['TL_BACKEND'] = 'torch' -os.environ['TL_BACKEND'] = 'jittor' +os.environ['TL_BACKEND'] = 'torch' import time import numpy as np diff --git a/examples/basic_tutorials/mnist_mlp_custom_train.py b/examples/basic_tutorials/mnist_mlp_custom_train.py index fe66f50..514098e 100644 --- a/examples/basic_tutorials/mnist_mlp_custom_train.py +++ b/examples/basic_tutorials/mnist_mlp_custom_train.py @@ -6,9 +6,9 @@ # os.environ['TL_BACKEND'] = 'tensorflow' # os.environ['TL_BACKEND'] = 'mindspore' # os.environ['TL_BACKEND'] = 'paddle' -os.environ['TL_BACKEND'] = 'jittor' # os.environ['TL_BACKEND'] = 'oneflow' -# os.environ['TL_BACKEND'] = 'torch' +# os.environ['TL_BACKEND'] = 
'jittor' +os.environ['TL_BACKEND'] = 'torch' import time import tensorlayerx as tlx diff --git a/examples/basic_tutorials/mnist_mlp_mix_programming.py b/examples/basic_tutorials/mnist_mlp_mix_programming.py index d05619e..9602605 100644 --- a/examples/basic_tutorials/mnist_mlp_mix_programming.py +++ b/examples/basic_tutorials/mnist_mlp_mix_programming.py @@ -1,27 +1,22 @@ -################################ TensorLayerX and Jittor can be mixed programming. ################################# - +################################## TensorLayerX and Torch can be mixed programming. ################################## import os -import time -import numpy as np -import tensorlayerx as tlx -import jittor as jt -from jittor import nn, optim +os.environ['TL_BACKEND'] = 'torch' + +import torch from tensorlayerx.nn import Module, Linear, Dropout +import tensorlayerx as tlx from tensorlayerx.dataflow import Dataset, DataLoader -from tqdm import tqdm -# Enable debug logging -tlx.logging.set_verbosity(tlx.logging.DEBUG) +# Get cpu or gpu device for training. +device = "cuda" if torch.cuda.is_available() else "cpu" +print("Using {} device".format(device)) -# Set the backend environment variable -os.environ['TL_BACKEND'] = 'jittor' - -# Load MNIST data by TensorLayerX +# Load MNIST data and make Dataset by TensorLayerX X_train, y_train, X_val, y_val, X_test, y_test = tlx.files.load_mnist_dataset(shape=(-1, 784)) -# Define the MNIST dataset using TensorLayerX -class MNISTDataset(Dataset): - def __init__(self, data, label): +class mnistdataset(Dataset): + + def __init__(self, data=X_train, label=y_train): self.data = data self.label = label @@ -33,16 +28,16 @@ def __getitem__(self, index): def __len__(self): return len(self.data) -# Create DataLoaders for training and testing -train_dataset = MNISTDataset(data=X_train, label=y_train) -train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True) +train_dataset = mnistdataset(data=X_train, label=y_train) +train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) -# Define a simple MLP model using TensorLayerX +# Define the network through TensorLayerX class MLP(Module): + def __init__(self): super(MLP, self).__init__() self.dropout1 = Dropout(p=0.2) - self.linear1 = Linear(out_features=800, in_features=784) + self.linear1 = Linear(out_features=800, act=tlx.nn.ReLU, in_features=784) self.dropout2 = Dropout(p=0.2) self.linear2 = Linear(out_features=800, act=tlx.nn.ReLU, in_features=800) self.dropout3 = Dropout(p=0.2) @@ -57,64 +52,35 @@ def forward(self, x): out = self.linear3(z) return out -# Instantiate the model -model = MLP() -# Define the loss function -loss_fn = tlx.losses.softmax_cross_entropy_with_logits +model = MLP().to(device) -# Define the optimizer using Jittor -optimizer = optim.Adam(model.trainable_weights, lr=0.0001) +# Define the loss fucntion through TensorLayerX +loss_fn = tlx.losses.softmax_cross_entropy_with_logits +# Define the optimizer through torch +optimizer = torch.optim.SGD(lr=0.05, momentum=0.9, params=model.trainable_weights) -# Custom training loop n_epoch = 50 -print_freq = 1 +size = len(train_loader.dataset) +model.train() +# We use tlx's Model, loss function, Dataset and torch's optimizer to train the network for epoch in range(n_epoch): - start_time = time.time() - model.set_train() - train_loss, train_acc, n_iter = 0, 0, 0 - - with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{n_epoch}", unit="batch") as pbar: - for X_batch, y_batch in train_dataloader: - X_batch = 
tlx.convert_to_tensor(X_batch) - y_batch = tlx.convert_to_tensor(y_batch) - - # Forward pass - _logits = model(X_batch) - # Compute loss - _loss = loss_fn(_logits, y_batch) - # Backward pass and optimization - optimizer.step(_loss) - - train_loss += _loss.item() - train_acc += np.mean(np.equal(np.argmax(_logits, axis=1), y_batch)) - n_iter += 1 - - pbar.set_postfix({'loss': train_loss / n_iter, 'acc': train_acc / n_iter}) - pbar.update(1) - - # Print training progress - print("Epoch {} of {} took {:.2f}s".format(epoch + 1, n_epoch, time.time() - start_time)) - print(" train loss: {:.6f}".format(train_loss / n_iter)) - print(" train acc: {:.6f}".format(train_acc / n_iter)) - - # Validation (optional, using training data as a placeholder for validation) - val_loss, val_acc, n_iter = 0, 0, 0 - with tqdm(total=len(train_dataloader), desc="Validation", unit="batch") as pbar: - for X_batch, y_batch in train_dataloader: - X_batch = tlx.convert_to_tensor(X_batch) - y_batch = tlx.convert_to_tensor(y_batch) - _logits = model(X_batch) - val_loss += loss_fn(_logits, y_batch).item() - val_acc += np.mean(np.equal(np.argmax(_logits, axis=1), y_batch)) - n_iter += 1 - - pbar.set_postfix({'val_loss': val_loss / n_iter, 'val_acc': val_acc / n_iter}) - pbar.update(1) - print(" val loss: {:.6f}".format(val_loss / n_iter)) - print(" val acc: {:.6f}".format(val_acc / n_iter)) + for batch, (X, y) in enumerate(train_loader): + X, y = X.to(device), y.to(device) + + # Compute prediction error + pred = model(X) + loss = loss_fn(pred, y) + acc = tlx.metrics.acc(pred, y) + # Backpropagation + optimizer.zero_grad() + loss.backward() + optimizer.step() + if batch % 100 == 0: + loss, current = loss.item(), batch * len(X) + print(f"loss: {loss:>7f} acc: {acc:>7f} [{current:>5d}/{size:>5d}] [{epoch} / {n_epoch}epoch]") ################################ TensorLayerX and TensorFlow can be mixed programming. ################################# @@ -379,86 +345,3 @@ def forward(self, x): # print(" train acc: {}".format(acc.numpy())) -################################## TensorLayerX and Torch can be mixed programming. ################################## -# import os -# os.environ['TL_BACKEND'] = 'torch' -# -# import torch -# from tensorlayerx.nn import Module, Linear, Dropout -# import tensorlayerx as tlx -# from tensorlayerx.dataflow import Dataset, DataLoader -# -# # Get cpu or gpu device for training. 
-# device = "cuda" if torch.cuda.is_available() else "cpu" -# print("Using {} device".format(device)) -# -# # Load MNIST data and make Dataset by TensorLayerX -# X_train, y_train, X_val, y_val, X_test, y_test = tlx.files.load_mnist_dataset(shape=(-1, 784)) -# -# class mnistdataset(Dataset): -# -# def __init__(self, data=X_train, label=y_train): -# self.data = data -# self.label = label -# -# def __getitem__(self, index): -# data = self.data[index].astype('float32') -# label = self.label[index].astype('int64') -# return data, label -# -# def __len__(self): -# return len(self.data) -# -# train_dataset = mnistdataset(data=X_train, label=y_train) -# train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) -# -# # Define the network through TensorLayerX -# class MLP(Module): -# -# def __init__(self): -# super(MLP, self).__init__() -# self.dropout1 = Dropout(p=0.2) -# self.linear1 = Linear(out_features=800, act=tlx.nn.ReLU, in_features=784) -# self.dropout2 = Dropout(p=0.2) -# self.linear2 = Linear(out_features=800, act=tlx.nn.ReLU, in_features=800) -# self.dropout3 = Dropout(p=0.2) -# self.linear3 = Linear(out_features=10, act=tlx.nn.ReLU, in_features=800) -# -# def forward(self, x): -# z = self.dropout1(x) -# z = self.linear1(z) -# z = self.dropout2(z) -# z = self.linear2(z) -# z = self.dropout3(z) -# out = self.linear3(z) -# return out -# -# -# model = MLP().to(device) -# -# # Define the loss fucntion through TensorLayerX -# loss_fn = tlx.losses.softmax_cross_entropy_with_logits -# # Define the optimizer through torch -# optimizer = torch.optim.SGD(lr=0.05, momentum=0.9, params=model.trainable_weights) -# -# n_epoch = 50 -# size = len(train_loader.dataset) -# model.train() -# -# # We use tlx's Model, loss function, Dataset and torch's optimizer to train the network -# for epoch in range(n_epoch): -# for batch, (X, y) in enumerate(train_loader): -# X, y = X.to(device), y.to(device) -# -# # Compute prediction error -# pred = model(X) -# loss = loss_fn(pred, y) -# acc = tlx.metrics.acc(pred, y) -# # Backpropagation -# optimizer.zero_grad() -# loss.backward() -# optimizer.step() -# -# if batch % 100 == 0: -# loss, current = loss.item(), batch * len(X) -# print(f"loss: {loss:>7f} acc: {acc:>7f} [{current:>5d}/{size:>5d}] [{epoch} / {n_epoch}epoch]") diff --git a/examples/basic_tutorials/mnist_mlp_simple_train.py b/examples/basic_tutorials/mnist_mlp_simple_train.py index e169c08..f1ff42a 100644 --- a/examples/basic_tutorials/mnist_mlp_simple_train.py +++ b/examples/basic_tutorials/mnist_mlp_simple_train.py @@ -6,9 +6,9 @@ # os.environ['TL_BACKEND'] = 'tensorflow' # os.environ['TL_BACKEND'] = 'mindspore' # os.environ['TL_BACKEND'] = 'paddle' -os.environ['TL_BACKEND'] = 'jittor' +# os.environ['TL_BACKEND'] = 'jittor' # os.environ['TL_BACKEND'] = 'oneflow' -# os.environ['TL_BACKEND'] = 'torch' +os.environ['TL_BACKEND'] = 'torch' import tensorlayerx as tlx from tensorlayerx.nn import Module diff --git a/examples/basic_tutorials/mnist_sequential.py b/examples/basic_tutorials/mnist_sequential.py index 16f500b..2e28bbb 100644 --- a/examples/basic_tutorials/mnist_sequential.py +++ b/examples/basic_tutorials/mnist_sequential.py @@ -4,7 +4,8 @@ # os.environ['TL_BACKEND'] = 'tensorflow' # os.environ['TL_BACKEND'] = 'mindspore' # os.environ['TL_BACKEND'] = 'paddle' -os.environ['TL_BACKEND'] = 'jittor' +# os.environ['TL_BACKEND'] = 'jittor' +os.environ['TL_BACKEND'] = 'torch' from tensorlayerx.nn import Sequential from tensorlayerx.nn import Linear @@ -52,4 +53,4 @@ def __len__(self): ) 
model.train(n_epoch=n_epoch, train_dataset=train_loader, print_freq=print_freq, print_train_batch=False) model.save_weights('./model.npz', format='npz_dict') -model.load_weights('./model.npz', format='npz_dict', skip = True) +model.load_weights('./model.npz', format='npz_dict') diff --git a/examples/basic_tutorials/module_container.py b/examples/basic_tutorials/module_container.py index 5b7cffc..bd929af 100644 --- a/examples/basic_tutorials/module_container.py +++ b/examples/basic_tutorials/module_container.py @@ -3,10 +3,12 @@ import os # os.environ['TL_BACKEND'] = 'tensorflow' -os.environ['TL_BACKEND'] = 'jittor' # os.environ['TL_BACKEND'] = 'mindspore' +# os.environ['TL_BACKEND'] = 'jittor' # os.environ['TL_BACKEND'] = 'paddle' -# os.environ['TL_BACKEND'] = 'torch' +os.environ['TL_BACKEND'] = 'torch' + + import numpy as np from tensorlayerx.nn import Module, ModuleList, Linear, ModuleDict import tensorlayerx as tlx diff --git a/examples/basic_tutorials/parameter_container.py b/examples/basic_tutorials/parameter_container.py index d45ab13..f780dc3 100644 --- a/examples/basic_tutorials/parameter_container.py +++ b/examples/basic_tutorials/parameter_container.py @@ -2,8 +2,7 @@ # os.environ['TL_BACKEND'] = 'tensorflow' # os.environ['TL_BACKEND'] = 'mindspore' # os.environ['TL_BACKEND'] = 'paddle' -os.environ['TL_BACKEND'] = 'jittor' -# os.environ['TL_BACKEND'] = 'torch' +os.environ['TL_BACKEND'] = 'torch' import tensorlayerx as tlx from tensorlayerx.nn import Module, Parameter, ParameterList, ParameterDict diff --git a/examples/basic_tutorials/quick_start.py b/examples/basic_tutorials/quick_start.py index d428882..70f2615 100644 --- a/examples/basic_tutorials/quick_start.py +++ b/examples/basic_tutorials/quick_start.py @@ -1,10 +1,10 @@ # TensorlayerX目前支持包括TensorFlow、Pytorch、PaddlePaddle、MindSpore作为计算后端,指定计算后端的方法也非常简单,只需要设置环境变量即可 import os # os.environ['TL_BACKEND'] = 'tensorflow' -os.environ['TL_BACKEND'] = 'jittor' -# os.environ['TL_BACKEND'] = 'torch' # os.environ['TL_BACKEND'] = 'mindspore' # os.environ['TL_BACKEND'] = 'paddle' +# os.environ['TL_BACKEND'] = 'jittor' +os.environ['TL_BACKEND'] = 'torch' import tensorlayerx as tlx diff --git a/examples/basic_tutorials/tensorlayerx_graph.py b/examples/basic_tutorials/tensorlayerx_graph.py index f76eefd..f36bba4 100644 --- a/examples/basic_tutorials/tensorlayerx_graph.py +++ b/examples/basic_tutorials/tensorlayerx_graph.py @@ -4,8 +4,8 @@ import os # os.environ['TL_BACKEND'] = 'tensorflow' # os.environ['TL_BACKEND'] = 'mindspore' -os.environ['TL_BACKEND'] = 'jittor' -# os.environ['TL_BACKEND'] = 'torch' +# os.environ['TL_BACKEND'] = 'jittor' +os.environ['TL_BACKEND'] = 'torch' import tensorlayerx as tlx from tensorlayerx.nn import Module diff --git a/examples/basic_tutorials/tensorlayerx_model_load.py b/examples/basic_tutorials/tensorlayerx_model_load.py index 567ecd7..4f9a16e 100644 --- a/examples/basic_tutorials/tensorlayerx_model_load.py +++ b/examples/basic_tutorials/tensorlayerx_model_load.py @@ -3,10 +3,9 @@ import os # os.environ['TL_BACKEND'] = 'tensorflow' -os.environ['TL_BACKEND'] = 'jittor' # os.environ['TL_BACKEND'] = 'paddle' # os.environ['TL_BACKEND'] = 'mindspore' -# os.environ['TL_BACKEND'] = 'torch' +os.environ['TL_BACKEND'] = 'torch' import tensorlayerx as tlx from tensorlayerx.nn import Module @@ -62,11 +61,11 @@ def forward(self, x): z = self.conv1(x) print("conv1 outputs:", z[1, :, :, 1]) z = self.maxpool1(z) - # print("maxpool outputs:", z[1, :, :, 1]) + print("maxpool outputs:", z[1, :, :, 1]) z = self.conv2(z) - # 
print("conv2 outputs:", z[1, :, :, 1]) + print("conv2 outputs:", z[1, :, :, 1]) z = self.maxpool2(z) - # print("max2 outputs:", z[1, :, :, 1]) + print("max2 outputs:", z[1, :, :, 1]) z = self.flatten(z) z = self.linear1(z) z = self.linear2(z) @@ -88,7 +87,7 @@ def forward(self, x): # and imported into TensorFlow/PyTorch/PaddlePaddle/MindSpore. cnn = CNN() # cnn.save_standard_weights('./cnn.npz') -cnn.load_standard_weights('./cnn.npz', weights_from='torch', weights_to='tensorflow', skip= True) +cnn.load_standard_weights('./cnn.npz', weights_from='torch', weights_to='tensorflow') cnn.set_eval() inputs = tlx.nn.Input(shape=(10, 28, 28, 3), dtype=tlx.float32) diff --git a/tensorlayerx/__init__.py b/tensorlayerx/__init__.py index 481c043..b174e13 100644 --- a/tensorlayerx/__init__.py +++ b/tensorlayerx/__init__.py @@ -40,6 +40,7 @@ 'paddle': '2.2.0', 'torch': '1.10.0', 'jittor': '1.3.8.5', + 'oneflow':'0.9.0' } if BACKEND_VERSION != backend_v[BACKEND]: diff --git a/tensorlayerx/backend/ops/jittor_backend.py b/tensorlayerx/backend/ops/jittor_backend.py index bbce67f..e74acb7 100644 --- a/tensorlayerx/backend/ops/jittor_backend.py +++ b/tensorlayerx/backend/ops/jittor_backend.py @@ -1179,7 +1179,7 @@ def __call__(self, x, y): class CountNonzero(object): - def __init__(self, keepdims=None, dtype=None): + def __init__(self, keepdims=None, dtype="float32"): self.keepdims = keepdims self.dtype = dtype @@ -1354,7 +1354,7 @@ def angle(x): def argmax(x, axis=None, keepdim=False, dtype='int64'): - return jt.argmax(x, dim=axis, keepdim=keepdim) + return jt.argmax(x, dim=axis, keepdims=keepdim) def argmin(x, axis=None, dtype='int64'): @@ -1646,8 +1646,8 @@ def where(condition, x, y): return jt.where(condition,x, y) -def ones_like(x, dtype=None): - return jt.ones_like(x, dtype=dtype) +def ones_like(x): + return jt.ones_like(x) def zeros_like(x, dtype=None): @@ -1734,7 +1734,7 @@ def set_seed(seed): def is_tensor(x): - return isinstance(x, jt.Tensor) + return isinstance(x, jt.Var) def tensor_scatter_nd_update(tensor, indices, updates): tensor = jt.array(tensor) @@ -1765,10 +1765,10 @@ def mask_select(x, mask, axis = 0): elif axis == 3: return x[:,:,:, mask] -def eye(n, m=None, dtype=None): +def eye(n, m=None, dtype="float32"): if m is None: m = n - return jt.init.eye((n,m), dtype =dtype) + return jt.init.eye((n,m), dtype=dtype) def einsum(equation, *operands): diff --git a/tensorlayerx/backend/ops/jittor_nn.py b/tensorlayerx/backend/ops/jittor_nn.py index 985f3f0..40e2438 100644 --- a/tensorlayerx/backend/ops/jittor_nn.py +++ b/tensorlayerx/backend/ops/jittor_nn.py @@ -88,11 +88,17 @@ def preprocess_1d_format(data_format, padding): data_format = "NLC" elif data_format in ["channels_first", "NCW", "NCL"]: data_format = "NCL" - elif data_format == None: + elif data_format is None: data_format = None else: raise Exception("Unsupported data format: " + str(data_format)) + padding = padding_format(padding) + # Convert padding to numerical representation for arithmetic operations + if padding == "same": + padding = 1 + elif padding == "valid": + padding = 0 return data_format, padding @@ -634,18 +640,25 @@ def same_padding(input, weight, strides, dilations): return rows_odd, cols_odd, depth_odd, padding_rows, padding_cols, padding_depth + class Conv2D(object): def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_channel=None, k_size=None, groups=1): self.data_format, self.padding = preprocess_2d_format(data_format, padding) + + # Ensure strides is a tuple/list of length 2 with non-zero 
values + if len(strides) != 2 or strides[0] == 0 or strides[1] == 0: + raise ValueError("Stride values must be greater than zero and of length 2") + + # Adjust the strides and dilations for the data format if self.data_format == 'NHWC': - self.strides = (strides[1], strides[2]) + self.strides = (strides[0], strides[1]) self.dilations = (dilations[0], dilations[1]) elif self.data_format == 'NCHW': - self.strides = (strides[1], strides[2]) - self.dilations = (dilations[1], dilations[2]) - self.groups = groups + self.strides = (strides[0], strides[1]) + self.dilations = (dilations[0], dilations[1]) + self.groups = groups def __call__(self, input, filters): @@ -847,8 +860,7 @@ def __call__(self, *args, **kwargs): class MaxPool(object): - - def __init__(self, ksize, strides, padding, return_mask = False, data_format=None): + def __init__(self, ksize, strides, padding, return_mask=False, data_format=None): self.ksize = ksize self.strides = strides self.return_mask = return_mask @@ -863,6 +875,7 @@ def __init__(self, ksize, strides, padding, return_mask = False, data_format=Non def __call__(self, inputs): if self.data_format == 'channels_last': inputs = nhwc_to_nchw(inputs) + if len(inputs.shape) == 2 or len(inputs.shape) == 3: raise NotImplementedError @@ -872,6 +885,7 @@ def __call__(self, inputs): else: out = nn.max_pool2d(inputs, self.ksize, self.strides, padding=self.padding, return_indices=self.return_mask) + if len(inputs.shape) == 5: if self.padding in ['SAME', 'same']: out = self.maxpool3d_same_padding(inputs) @@ -879,6 +893,7 @@ def __call__(self, inputs): out = nn.max_pool3d(inputs, self.ksize, self.strides, padding=self.padding, return_indices=self.return_mask) + if self.data_format == 'channels_last': if self.return_mask: outputs = [None, None] @@ -891,6 +906,7 @@ def __call__(self, inputs): return out + def maxpool2d_same_padding(self, input): rows_odd, cols_odd, padding_rows, padding_cols = same_padding(input, self.ksize, self.strides, (1, 1)) if rows_odd or cols_odd: @@ -965,7 +981,6 @@ def __call__(self, *args, **kwargs): raise NotImplementedError("AvgPool1d is not implemented in Jittor backend") - class AvgPool(object): def __init__(self, ksize, strides, padding, data_format=None): @@ -994,7 +1009,7 @@ def __call__(self, inputs): if self.padding in ['SAME', 'same']: out = self.avgpool3d_same_padding(inputs) else: - out = nn.AvgPool2d(inputs, self.ksize, self.strides, padding=self.padding) + out = nn.AvgPool3d(inputs, self.ksize, self.strides, padding=self.padding) if self.data_format == 'channels_last': return nchw_to_nhwc(out) @@ -1002,6 +1017,7 @@ def __call__(self, inputs): return out + def avgpool2d_same_padding(self, input): rows_odd, cols_odd, padding_rows, padding_cols = same_padding(input, self.ksize, self.strides, (1, 1)) if rows_odd or cols_odd: @@ -1065,7 +1081,7 @@ def avg_pool2d(input, kernel_size, stride=None, padding=0, data_format='NCHW'): def avg_pool3d(input, kernel_size, stride=None, padding=0, data_format='NCDHW'): data_format, padding = preprocess_3d_format(data_format, padding) - avg_pool_obj = AvgPool(kernel_size, stride, padding, data_format) + avg_pool_obj = AvgPool3d(kernel_size, stride, padding) return avg_pool_obj(input) class MaxPool3d(object): @@ -1149,7 +1165,6 @@ def __call__(self, inputs): # avg_pool_obj = AvgPool(ksize, strides, padding, data_format) # return avg_pool_obj(input) - def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_format=None, dilations=None, name=None): """ Performs an N-D pooling operation. 
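The pooling hunks above (MaxPool, AvgPool, avg_pool3d) all lean on the same_padding helpers, whose job is to pad the input just enough that the output spatial size becomes ceil(input / stride). A minimal sketch of that arithmetic in plain Python (standalone, not part of the patch; the function name and example numbers are illustrative):

import math

def same_padding_2d(in_h, in_w, ksize, strides):
    # Output size under SAME padding is ceil(input / stride).
    out_h = math.ceil(in_h / strides[0])
    out_w = math.ceil(in_w / strides[1])
    # Total padding needed so a kernel/stride sweep reaches that output size.
    pad_h = max((out_h - 1) * strides[0] + ksize[0] - in_h, 0)
    pad_w = max((out_w - 1) * strides[1] + ksize[1] - in_w, 0)
    # Split between the two sides; an odd total puts the extra pixel on the
    # bottom/right edge, which is what the rows_odd/cols_odd checks handle.
    return (pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2)

# A 24x24 feature map pooled with a 3x3 window and stride 2, as in the CNN
# example above, needs one extra row and column of padding to reach 12x12.
print(same_padding_2d(24, 24, (3, 3), (2, 2)))  # (0, 1, 0, 1)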
@@ -1158,8 +1173,6 @@ def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_ ---------- input : tensor Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels] - if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape - if data_format starts with "NC". Pooling happens over the spatial dimensions only. window_shape : int Sequence of N ints >= 1. pooling_type : string @@ -1168,12 +1181,9 @@ def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_ Sequence of N ints >= 1. Defaults to [1]*N. If any value of strides is > 1, then all values of dilation_rate must be 1. padding : string The padding algorithm, must be "SAME" or "VALID". Defaults to "SAME". - See the "returns" section of tf.ops.convolution for details. data_format : string Specifies whether the channel dimension of the input and output is the last dimension (default, or if data_format does not start with "NC"), or the second dimension (if data_format starts with "NC"). - For N=1, the valid values are "NWC" (default) and "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". - For N=3, the valid values are "NDHWC" (default) and "NCDHW". dilations : list of ints Dilation rate. List of N ints >= 1. Defaults to [1]*N. If any value of dilation_rate is > 1, then all values of strides must be 1. name : string @@ -1193,7 +1203,6 @@ def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_ return pool_obj(input) - class DepthwiseConv2d(object): def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1, in_channels=None): @@ -2023,58 +2032,97 @@ def __call__(self, input, hx=None): -class lstmcell(object): - def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh): - self.weight_ih = weight_ih - self.weight_hh = weight_hh - self.bias_ih = bias_ih - self.bias_hh = bias_hh - def __call__(self, input, h, c): - gates = jt.matmul(input, jt.transpose(self.weight_ih)) + jt.matmul(h, jt.transpose(self.weight_hh)) +class lstmcell(Module): + def __init__(self, weight_ih, weight_hh, bias_ih=None, bias_hh=None): + super(lstmcell, self).__init__() + + self.weight_ih = weight_ih # Shape: [input_size, 4 * hidden_size] + self.weight_hh = weight_hh # Shape: [hidden_size, 4 * hidden_size] + self.bias_ih = bias_ih if bias_ih is not None else jt.ones(4 * weight_ih.shape[1]) # Bias for input-to-hidden + self.bias_hh = bias_hh if bias_hh is not None else jt.ones(4 * weight_hh.shape[1]) # Bias for hidden-to-hidden + + # Extract input_size and hidden_size from the weight shapes + self.input_size = weight_ih.shape[0] + self.hidden_size = weight_hh.shape[0] + + def execute(self, input, h, c): + + gates_input = jt.matmul(input, self.weight_ih) # [batch_size, 4 * hidden_size] + gates_hidden = jt.matmul(h, self.weight_hh) # [batch_size, 4 * hidden_size] + + gates = gates_input + gates_hidden + + # Add bias terms if provided if self.bias_ih is not None: gates += self.bias_ih + self.bias_hh i, f, g, o = jt.chunk(gates, 4, dim=1) - i = jt.sigmoid(i) - f = jt.sigmoid(f) - g = jt.tanh(g) - o = jt.sigmoid(o) - c_new = f * c + i * g - h_new = o * jt.tanh(c_new) - return h_new, h_new, c_new + # Apply activations to the gates + i = jt.sigmoid(i) # Input gate + f = jt.sigmoid(f) # Forget gate + g = jt.tanh(g) # Cell gate (candidate cell state) + o = jt.sigmoid(o) # Output gate + + # Compute new cell state + c_new = f * c + i * g # Cell state update + + # Compute new hidden state + 
h_new = o * jt.tanh(c_new) # Hidden state + + return h_new, h_new, c_new # Return hidden state and cell state + + + class grucell(Module): def __init__(self, weight_ih, weight_hh, bias_ih=None, bias_hh=None): super(grucell, self).__init__() - self.weight_ih = weight_ih - self.weight_hh = weight_hh - self.bias_ih = bias_ih - self.bias_hh = bias_hh - self.hidden_size = weight_hh.shape[1] - def execute(self, inputs, states): - hx = states[0] if isinstance(states, (tuple, list)) else states - gates = jt.matmul(inputs, self.weight_ih.t()) + jt.matmul(hx, self.weight_hh.t()) - if self.bias_ih is not None and self.bias_hh is not None: - gates += self.bias_ih + self.bias_hh + # Initialize the weights and biases + self.weight_ih = weight_ih # Shape: [input_size, 3 * hidden_size] + self.weight_hh = weight_hh # Shape: [hidden_size, 3 * hidden_size] + self.bias_ih = bias_ih if bias_ih is not None else jt.ones(weight_ih.shape[1]) # Bias for input-to-hidden + self.bias_hh = bias_hh if bias_hh is not None else jt.ones(weight_hh.shape[1]) # Bias for hidden-to-hidden + + # Extract input_size and hidden_size from weight shapes + self.input_size = weight_ih.shape[0] + self.hidden_size = weight_hh.shape[0] + + def execute(self, inputs, hx): + """ + Args: + - inputs: Input tensor [batch_size, input_size] + - hx: Previous hidden state [batch_size, hidden_size] - # Separate the gates - r, z, n = jt.chunk(gates, 3, dim=1) - - r = jt.sigmoid(r) - z = jt.sigmoid(z) - n = jt.tanh(n + r * (jt.matmul(hx, self.weight_hh[2 * self.hidden_size:].t()) + (self.bias_hh[2 * self.hidden_size:] if self.bias_hh is not None else 0))) - hy = (1 - z) * n + z * hx + Returns: + - hy: New hidden state [batch_size, hidden_size] + - hy_new: New hidden state (same as hy) for consistency + """ + + # Split the weights for the gates (GRU uses 3 * hidden_size) + weight_ih_r, weight_ih_z, weight_ih_h = jt.split(self.weight_ih, 3, dim=1) + weight_hh_r, weight_hh_z, weight_hh_h = jt.split(self.weight_hh, 3, dim=1) - return hy, hy + # Bias terms for reset, update, and candidate hidden states + bias_ih_r, bias_ih_z, bias_ih_h = jt.split(self.bias_ih, 3) + bias_hh_r, bias_hh_z, bias_hh_h = jt.split(self.bias_hh, 3) + # 1. Compute the reset gate (r) + r = jt.sigmoid(jt.matmul(inputs, weight_ih_r) + bias_ih_r + jt.matmul(hx, weight_hh_r) + bias_hh_r) + # 2. Compute the update gate (z) + z = jt.sigmoid(jt.matmul(inputs, weight_ih_z) + bias_ih_z + jt.matmul(hx, weight_hh_z) + bias_hh_z) + # 3. Compute the candidate hidden state (h') + h_hat = jt.tanh(jt.matmul(inputs, weight_ih_h) + bias_ih_h + r * (jt.matmul(hx, weight_hh_h) + bias_hh_h)) + # 4. 
Compute the new hidden state (h) + hy = (1 - z) * hx + z * h_hat + return hy, hy # Return the new hidden state as both outputs (for consistency) class rnnbase(Module): @@ -2746,9 +2794,22 @@ def swish(input): return NotImplementedError -def linear(input, weight, bias = None): +def linear(input, weight, bias=None): + ''' Custom Linear Layer Implementation ''' + + # Perform matrix multiplication (input * weight^T) + x = jt.matmul(input, weight) # input is of shape [batch, in_features], weight is of shape [in_features, out_features] + + if bias is not None: + # Ensure the bias is correctly reshaped for broadcasting + if bias.ndim == 1: + # Bias should be broadcasted across the batch dimension + bias = bias.reshape(1, -1) # Shape: [1, out_features] + + # Add bias to the result + x = x + bias # Broadcasting bias to match the result shape - return nn.linear(input, weight, bias) + return x def unfold(input, kernel_size, dilation = 1, padding = 0, stride = 1): diff --git a/tensorlayerx/backend/ops/oneflow_backend.py b/tensorlayerx/backend/ops/oneflow_backend.py index b4990b7..47e0b3e 100644 --- a/tensorlayerx/backend/ops/oneflow_backend.py +++ b/tensorlayerx/backend/ops/oneflow_backend.py @@ -179,7 +179,7 @@ def random_uniform(shape, minval=0, maxval=1, dtype=None, seed=None): if seed is not None: flow.manual_seed(seed) else: - flow.manual_seed(flow.random.gen_seed()) + flow.manual_seed(flow.initial_seed()) w = flow.randn(shape, dtype=_dtypeDict[dtype]) out = w.uniform_(minval, maxval) @@ -211,7 +211,7 @@ def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): if seed is not None: flow.manual_seed(seed) else: - flow.manual_seed(flow.random.gen_seed()) + flow.manual_seed(flow.initial_seed()) return flow.normal(shape, mean=mean, std=stddev, dtype=_dtypeDict[dtype]) @@ -241,7 +241,7 @@ def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None): if seed is not None: flow.manual_seed(seed) else: - flow.manual_seed(flow.random.gen_seed()) + flow.manual_seed(flow.initial_seed()) w = flow.empty(shape, dtype=_dtypeDict[dtype]) out = nn.init.truncated_normal_(w, mean=mean, std=stddev) @@ -271,7 +271,7 @@ def he_normal(shape, dtype=None, seed=None): if seed is not None: flow.manual_seed(seed) else: - flow.manual_seed(flow.random.gen_seed()) + flow.manual_seed(flow.initial_seed()) w = flow.empty(shape, dtype=_dtypeDict[dtype]) out = nn.init.kaiming_normal_(w) @@ -301,7 +301,7 @@ def he_uniform(shape, dtype=None, seed=None): if seed is not None: flow.manual_seed(seed) else: - flow.manual_seed(flow.random.gen_seed()) + flow.manual_seed(flow.initial_seed()) w = flow.empty(shape, dtype=_dtypeDict[dtype]) out = nn.init.kaiming_uniform_(w) @@ -331,7 +331,7 @@ def xavier_normal(shape, dtype=None, seed=None): if seed is not None: flow.manual_seed(seed) else: - flow.manual_seed(flow.random.gen_seed()) + flow.manual_seed(flow.initial_seed()) w = flow.empty(shape, dtype=_dtypeDict[dtype]) out = nn.init.xavier_normal_(w) @@ -363,7 +363,7 @@ def xavier_uniform(shape, gain=1.0, dtype=None, seed=None): if seed is not None: flow.manual_seed(seed) else: - flow.manual_seed(flow.random.gen_seed()) + flow.manual_seed(flow.initial_seed()) w = flow.empty(shape, dtype=_dtypeDict[dtype]) out = nn.init.xavier_uniform_(w, gain=gain) @@ -674,7 +674,7 @@ def reduce_mean(input_tensor, axis=None, keepdims=False): if axis is not None: return flow.mean(input_tensor, dim=axis, keepdim=keepdims) else: - return flow.mean(input_tensor, keepdim=keepdims) + return flow.mean(input_tensor) class ReduceMax(object): @@ 
-718,7 +718,7 @@ def reduce_max(input_tensor, axis=None, keepdims=False): if axis is not None: return flow.max(input_tensor, dim=axis, keepdim=keepdims) else: - return flow.max(input_tensor, keepdim=keepdims) + return flow.max(input_tensor) def reduce_min(input_tensor, axis=None, keepdims=False): @@ -1582,11 +1582,11 @@ def count_nonzero(x, axis=None, keepdims=None, dtype="int64"): return convert_to_tensor(non_zero) -def cumprod(x, axis=None, dtype=None, out=None): +def cumprod(x, axis=0, dtype=None, out=None): return flow.cumprod(x, dim=axis) -def cumsum(x, axis=None, dtype=None, out=None): +def cumsum(x, axis=0, dtype=None, out=None): return flow.cumsum(x, dim=axis) def equal(x, y): @@ -1892,7 +1892,7 @@ def mask_select(x, mask, axis = 0): elif axis == 3: return x[:,:,:, mask] -def eye(n, m=None, dtype=None): +def eye(n, m=None, dtype=flow.float32): if m is None: m = n return flow.eye(n, m, dtype=dtype) @@ -2014,5 +2014,4 @@ def flip(x, axis): def mv(x, vec): - raise NotImplementedError - + raise NotImplementedError \ No newline at end of file
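The example scripts touched by this patch all select their compute backend the same way: TL_BACKEND must be set before tensorlayerx is imported, and the rest of the code stays backend-agnostic. A minimal, self-contained sketch of that pattern (the TinyMLP network and the shapes here are illustrative, not taken from the patch):

import os
os.environ['TL_BACKEND'] = 'torch'  # or 'tensorflow', 'mindspore', 'paddle', 'jittor', 'oneflow'

import tensorlayerx as tlx
from tensorlayerx.nn import Module, Linear

class TinyMLP(Module):
    def __init__(self):
        super(TinyMLP, self).__init__()
        # A single linear layer is enough to show the backend-agnostic API.
        self.fc = Linear(out_features=10, in_features=784)

    def forward(self, x):
        return self.fc(x)

net = TinyMLP()
net.set_eval()
x = tlx.nn.Input(shape=(4, 784), dtype=tlx.float32)
print(net(x).shape)  # (4, 10), whichever backend was selected above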