{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# **Introduction**\n", "\n", "This guide follows the same format as [quickstart](00-quickstart.ipynb) but explores further functionality provided by twinLab. In this Jupyter notebook we will:\n", "\n", "1. Upload a dataset to twinLab.\n", "2. List, view and summarise uploaded datasets.\n", "3. Use `Emulator.train` to create a surrogate model.\n", "4. List, view and summarise trained emulators.\n", "5. Use the model to make a prediction with `Emulator.predict`.\n", "6. Visualise the results and their uncertainty.\n", "7. Verify the model using `Emulator.sample`.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " ====== TwinLab Client Initialisation ======\n", " Version : 2.0.0\n", " Server : https://twinlab.digilab.co.uk\n", " Environment : /Users/mead/digiLab/twinLab-Demos/.env\n", "\n" ] } ], "source": [ "# Standard imports\n", "from pprint import pprint\n", "\n", "# Third-party imports\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "\n", "# Project imports\n", "import twinlab as tl" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **Your twinLab information**\n", "\n", "Confirm your twinLab version.\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'cloud': '2.0.0',\n", " 'modal': '0.2.0',\n", " 'library': '1.3.0',\n", " 'image': 'twinlab-prod'}" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tl.versions()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And view your user information, including how many credits you have.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'username': 'alexander', 'credits': 1000}" ] }, "execution_count": 3, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "tl.user_information()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **Upload a dataset**\n", "\n", "Datasets must be provided either as a `pandas.DataFrame` object or as a filepath pointing to a CSV file that can be parsed into a `pandas.DataFrame` object. **In either case, the data must have clearly labelled columns.** Here, we label the input (predictor) variable `x` and the output variable `y`. twinLab expects data in column-feature format: each row represents a single data sample, and each column represents a data feature.\n", "\n", "twinLab provides a `Dataset` class with attributes and methods to process, view and summarise datasets. Each `Dataset` must be created with an `id`, which is used to access it later. Once created, the data can be uploaded with the `Dataset.upload` method.\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " x y\n", "0 0.696469 -0.817374\n", "1 0.286139 0.887656\n", "2 0.226851 0.921553\n", "3 0.551315 -0.326334\n", "4 0.719469 -0.832518\n", "5 0.423106 0.400669\n", "6 0.980764 -0.164966\n", "7 0.684830 -0.960764\n", "8 0.480932 0.340115\n", "9 0.392118 0.845795" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Dataframe is uploading.\n", "Processing dataset\n", "Dataset example_data was processed.\n" ] } ], "source": [ "x = [\n", "    0.6964691855978616,\n", "    0.28613933495037946,\n", "    0.2268514535642031,\n", "    0.5513147690828912,\n", "    0.7194689697855631,\n", "    0.42310646012446096,\n", "    0.9807641983846155,\n", "    0.6848297385848633,\n", "    0.48093190148436094,\n", "    0.3921175181941505,\n", "]\n", "\n", "y = [\n", "    -0.8173739564129022,\n", "    0.8876561174050408,\n", "    0.921552660721474,\n", "    -0.3263338765412979,\n", "    -0.8325176123242133,\n", "    0.4006686354731812,\n", "    -0.16496626502368078,\n", "    -0.9607643657025954,\n", "    0.3401149876855609,\n", "    0.8457949914442409,\n", "]\n", "\n", "# Create the dataframe from the lists above\n", "df = pd.DataFrame({\"x\": x, \"y\": y})\n", "\n", "# View the dataset before uploading\n", "display(df)\n", "\n", "# Define the name of the dataset\n", "dataset_id = \"example_data\"\n", "\n", "# Initialise a Dataset object\n", "dataset = tl.Dataset(id=dataset_id)\n", "\n", "# Upload the dataset\n", "dataset.upload(df, verbose=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **View datasets**\n", "\n", "Once a dataset has been uploaded, it can be accessed using built-in twinLab functions: all uploaded datasets can be listed, and individual datasets can be viewed and summarised. The summary contains basic statistics of the data.\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['2DActive_Data',\n", " 'Excel-test',\n", " 'Falmouth-Mikey',\n", " 'New_Points',\n", " 'biscuits',\n", " 'eval_data',\n", " 'example_data',\n", " 'functional-data',\n", " 'functional-test-data',\n", " 'fusion',\n", " 'sampled-data',\n", " 'twinLab-logo']" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# List all datasets in the cloud\n", "tl.list_datasets()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " x y\n", "0 0.696469 -0.817374\n", "1 0.286139 0.887656\n", "2 0.226851 0.921553\n", "3 0.551315 -0.326334\n", "4 0.719469 -0.832518\n", "5 0.423106 0.400669\n", "6 0.980764 -0.164966\n", "7 0.684830 -0.960764\n", "8 0.480932 0.340115\n", "9 0.392118 0.845795" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# View the dataset\n", "dataset.view()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " x y\n", "count 10.000000 10.000000\n", "mean 0.544199 0.029383\n", "std 0.229352 0.748191\n", "min 0.226851 -0.960764\n", "25% 0.399865 -0.694614\n", "50% 0.516123 0.087574\n", "75% 0.693559 0.734513\n", "max 0.980764 0.921553" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get a statistical summary of the dataset\n", "dataset.summarise()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **Train an emulator**\n", "\n", "The `Emulator` class is used to train and use your surrogate models. As with datasets, each emulator is created with an `id`, which is the name the model is saved under in the cloud. Training parameters are passed using a `TrainParams` object, a class that holds the parameters used to train your model. To train the model, call the `Emulator.train` method, passing the dataset, the names of the input and output columns, and the `TrainParams` object as arguments.\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model example_emulator has begun training.\n", "Training complete!\n", "\n" ] } ], "source": [ "# Initialise emulator\n", "emulator_id = \"example_emulator\"\n", "\n", "emulator = tl.Emulator(id=emulator_id)\n", "\n", "# Define the training parameters for your emulator\n", "params = tl.TrainParams(train_test_ratio=1.0)\n", "\n", "# Train the emulator using the train method\n", "emulator.train(dataset=dataset, inputs=[\"x\"], outputs=[\"y\"], params=params, verbose=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **View emulators**\n", "\n", "Just as with datasets, all saved emulators can be listed, viewed and summarised.\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['2DActiveGP',\n", " 'Example_emulator',\n", " 'Excel-emulator',\n", " 'Excel-model',\n", " 'Hello',\n", 
'backward-model',\n", " 'biscuits',\n", " 'campaign',\n", " 'decoder',\n", " 'example_emulator',\n", " 'fusion',\n", " 'gardening',\n", " 'my_emulator',\n", " 'new-campaign',\n", " 'twinLab-logo',\n", " 'universe']" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# List emulators\n", "tl.list_emulators()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'model_id': 'example_emulator',\n", " 'fidelity': None,\n", " 'estimator': 'gaussian_process_regression',\n", " 'estimator_kwargs': {'detrend': False,\n", " 'device': 'cpu',\n", " 'covar_module': None,\n", " 'estimator_type': None},\n", " 'decompose_input': False,\n", " 'input_explained_variance': 0.99,\n", " 'decompose_output': False,\n", " 'output_explained_variance': 0.99,\n", " 'train_test_ratio': 1.0,\n", " 'model_selection': False,\n", " 'model_selection_kwargs': {'seed': None,\n", " 'evaluation_metric': 'MSLL',\n", " 'val_ratio': 0.2,\n", " 'base_kernels': 'restricted',\n", " 'depth': 1,\n", " 'beam': None,\n", " 'resource_per_trial': {'cpu': 1, 'gpu': 0}},\n", " 'seed': None,\n", " 'inputs': ['x'],\n", " 'outputs': ['y'],\n", " 'dataset_id': 'example_data',\n", " 'modal_handle': 'fc-ktRp3Nc749lshsKZvgJH4y'}" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# View an emulator's parameters\n", "emulator.view()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'model_summary': {'data_diagnostics': {'inputs': {'x': {'25%': 0.39986475367672814,\n", " '50%': 0.5161233352836261,\n", " '75%': 0.693559323844612,\n", " 'count': 10.0,\n", " 'max': 0.9807641983846156,\n", " 'mean': 0.544199352975335,\n", " 'min': 0.2268514535642031,\n", " 'std': 0.22935216613691597}},\n", " 'outputs': {'y': {'25%': -0.6946139364450011,\n", " '50%': 0.0875743613309401,\n", " '75%': 
0.7345134024514759,\n", " 'count': 10.0,\n", " 'max': 0.921552660721474,\n", " 'mean': 0.029383131672480845,\n", " 'min': -0.9607643657025954,\n", " 'std': 0.7481906564998719}}},\n", " 'estimator_diagnostics': {'covar_module': 'ScaleKernel(\\n'\n", " ' (base_kernel): '\n", " 'MaternKernel(\\n'\n", " ' '\n", " '(lengthscale_prior): '\n", " 'GammaPrior()\\n'\n", " ' '\n", " '(raw_lengthscale_constraint): '\n", " 'Positive()\\n'\n", " ' )\\n'\n", " ' '\n", " '(outputscale_prior): '\n", " 'GammaPrior()\\n'\n", " ' '\n", " '(raw_outputscale_constraint): '\n", " 'Positive()\\n'\n", " ')',\n", " 'covar_module.base_kernel.lengthscale_prior.concentration': 3.0,\n", " 'covar_module.base_kernel.lengthscale_prior.rate': 6.0,\n", " 'covar_module.base_kernel.original_lengthscale': [[0.4232063885665337]],\n", " 'covar_module.base_kernel.raw_lengthscale': [[-0.6408405631160488]],\n", " 'covar_module.base_kernel.raw_lengthscale_constraint.lower_bound': 0.0,\n", " 'covar_module.base_kernel.raw_lengthscale_constraint.upper_bound': inf,\n", " 'covar_module.original_outputscale': 1.7130960752713094,\n", " 'covar_module.outputscale_prior.concentration': 2.0,\n", " 'covar_module.outputscale_prior.rate': 0.15000000596046448,\n", " 'covar_module.raw_outputscale': 1.514271061131159,\n", " 'covar_module.raw_outputscale_constraint.lower_bound': 0.0,\n", " 'covar_module.raw_outputscale_constraint.upper_bound': inf,\n", " 'input_transform._coefficient': [[0.7539127448204125]],\n", " 'input_transform._offset': [[0.2268514535642031]],\n", " 'likelihood.noise_covar.noise_prior.concentration': 1.100000023841858,\n", " 'likelihood.noise_covar.noise_prior.rate': 0.05000000074505806,\n", " 'likelihood.noise_covar.original_noise': [0.031576525703137535],\n", " 'likelihood.noise_covar.raw_noise': [0.031576525703137535],\n", " 'likelihood.noise_covar.raw_noise_constraint.lower_bound': 9.999999747378752e-05,\n", " 'likelihood.noise_covar.raw_noise_constraint.upper_bound': inf,\n", " 'mean_module': 
'ConstantMean()',\n", " 'mean_module.original_constant': 0.21052500685316855,\n", " 'mean_module.raw_constant': 0.21052500685316855,\n", " 'outcome_transform._stdvs_sq': [[0.5597892584737093]],\n", " 'outcome_transform.means': [[0.029383131672480856]],\n", " 'outcome_transform.stdvs': [[0.7481906564998719]]},\n", " 'transformer_diagnostics': []}}\n" ] } ], "source": [ "# View a summary of the trained emulator\n", "pprint(emulator.summarise())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **Prediction using the trained emulator**\n", "\n", "The surrogate model is now trained and saved to the cloud under `emulator_id`, so it can be used to make predictions. First, define the inputs for which you want predictions as a `pandas.DataFrame` object. Then call `Emulator.predict`, passing this evaluation dataset as the argument.\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " x\n", "0 0.000000\n", "1 0.007874\n", "2 0.015748\n", "3 0.023622\n", "4 0.031496\n", ".. ...\n", "123 0.968504\n", "124 0.976378\n", "125 0.984252\n", "126 0.992126\n", "127 1.000000\n", "\n", "[128 rows x 1 columns]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ " y y\n", "0 0.617689 0.656265\n", "1 0.629105 0.640576\n", "2 0.640630 0.624421\n", "3 0.652252 0.607809\n", "4 0.663957 0.590755\n" ] } ], "source": [ "# Define the inputs at which to make predictions\n", "x_eval = np.linspace(0, 1, 128)\n", "\n", "# Convert to a dataframe\n", "df_eval = pd.DataFrame({\"x\": x_eval})\n", "display(df_eval)\n", "\n", "# Predict the results: a dataframe of means and a dataframe of standard deviations\n", "predictions = emulator.predict(df_eval)\n", "result_df = pd.concat([predictions[0], predictions[1]], axis=1)\n", "df_mean, df_stdev = result_df.iloc[:, 0], result_df.iloc[:, 1]\n", "df_mean, df_stdev = df_mean.values, df_stdev.values\n", "print(result_df.head())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **Viewing the results**\n", "\n", "`Emulator.predict` returns the mean prediction for each input together with its standard deviation, which makes it straightforward to visualise the uncertainty in the results. The plot below shades bands of one and two standard deviations around the mean.\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot parameters\n", "nsigs = [1, 2]\n", "color = \"blue\"\n", "alpha = 0.5\n", "plot_training_data = True\n", "plot_model_mean = True\n", "plot_model_bands = True\n", "\n", "# Plot results\n", "grid = df_eval[\"x\"]\n", "mean = df_mean\n", "err = df_stdev\n", "if plot_model_bands:\n", "    # Empty fill that only creates a single legend entry for the bands\n", "    label = r\"Model prediction\"\n", "    plt.fill_between(grid, np.nan, np.nan, lw=0, color=color, alpha=alpha, label=label)\n", "    # Shade mean +/- nsig standard deviations, fainter for the wider bands\n", "    for isig, nsig in enumerate(nsigs):\n", "        plt.fill_between(\n", "            grid,\n", "            mean - nsig * err,\n", "            mean + nsig * err,\n", "            lw=0,\n", "            color=color,\n", "            alpha=alpha / (isig + 1),\n", "        )\n", "if plot_model_mean:\n", "    label = r\"Model prediction\" if not plot_model_bands else None\n", "    plt.plot(grid, mean, color=color, alpha=alpha, label=label)\n", "if plot_training_data:\n", "    plt.plot(df[\"x\"], df[\"y\"], \".\", color=\"black\", label=\"Training data\")\n", "plt.xlim((0.0, 1.0))\n", "plt.xlabel(r\"$x$\")\n", "plt.ylabel(r\"$y$\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **Sampling from an emulator**\n", "\n", "The `Emulator.sample` method draws samples from the trained model. It requires the inputs at which to evaluate the samples and the number of samples to draw for each input. In the returned dataframe, each row corresponds to an input and each column to one sample.\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " y \\\n", " 0 1 2 3 4 5 6 \n", "0 1.264127 1.259498 0.664856 1.290327 -1.061115 0.063356 1.326666 \n", "1 1.263028 1.274143 0.643489 1.300232 -0.994251 0.108671 1.342242 \n", "2 1.260150 1.283785 0.621806 1.308611 -0.921023 0.157252 1.354430 \n", "3 1.256019 1.290381 0.602094 1.315111 -0.842858 0.207179 1.363444 \n", "4 1.249843 1.294567 0.584412 1.319314 -0.762038 0.258642 1.367987 \n", ".. ... ... ... ... ... ... ... \n", "123 -0.243466 -0.376759 -0.222920 -0.164613 -0.123029 -0.107526 -0.491095 \n", "124 -0.237385 -0.361383 -0.176733 -0.150867 -0.107936 -0.042114 -0.411948 \n", "125 -0.229653 -0.344955 -0.129858 -0.138228 -0.093384 0.024250 -0.334807 \n", "126 -0.220961 -0.326984 -0.082629 -0.126160 -0.080773 0.088925 -0.260510 \n", "127 -0.211589 -0.307667 -0.034203 -0.113222 -0.070149 0.152874 -0.191118 \n", "\n", " ... \\\n", " 7 8 9 ... 90 91 92 \n", "0 1.789418 -0.009024 1.259869 ... 0.939294 0.362207 0.011521 \n", "1 1.764897 0.008574 1.246439 ... 0.978143 0.343955 0.050809 \n", "2 1.740139 0.031373 1.232385 ... 1.014755 0.328073 0.091420 \n", "3 1.714233 0.056631 1.218001 ... 1.048197 0.315876 0.130009 \n", "4 1.686284 0.081911 1.204734 ... 1.077572 0.307563 0.166610 \n", ".. ... ... ... ... ... ... ... \n", "123 -0.200305 -0.075683 -0.080205 ... -0.130376 -0.088158 0.005756 \n", "124 -0.214185 -0.082167 -0.053210 ... -0.099505 -0.143001 -0.020947 \n", "125 -0.228415 -0.090190 -0.028539 ... -0.063867 -0.199541 -0.047168 \n", "126 -0.242954 -0.099340 -0.005383 ... -0.024045 -0.256805 -0.072422 \n", "127 -0.259736 -0.108993 0.016120 ... 
0.019210 -0.314823 -0.096437 \n", "\n", " \n", " 93 94 95 96 97 98 99 \n", "0 1.053759 -0.409518 1.014478 0.751929 1.641282 0.826082 0.469116 \n", "1 1.024005 -0.418823 1.016179 0.779650 1.598208 0.826613 0.468406 \n", "2 0.993893 -0.423898 1.014795 0.809190 1.552288 0.830950 0.468252 \n", "3 0.965887 -0.424134 1.010628 0.840416 1.503579 0.839256 0.469101 \n", "4 0.942370 -0.419536 1.003131 0.871847 1.452200 0.851765 0.470634 \n", ".. ... ... ... ... ... ... ... \n", "123 -0.193393 -0.067758 -0.131498 -0.122326 -0.211535 -0.343941 -0.117467 \n", "124 -0.184918 -0.064551 -0.152278 -0.089314 -0.181028 -0.336784 -0.087687 \n", "125 -0.171532 -0.062122 -0.173057 -0.057060 -0.150006 -0.333920 -0.053370 \n", "126 -0.153905 -0.059526 -0.193384 -0.026572 -0.118515 -0.335199 -0.014476 \n", "127 -0.133574 -0.056565 -0.213358 0.001427 -0.087783 -0.339115 0.027192 \n", "\n", "[128 rows x 100 columns]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Define the sample inputs\n", "sample_inputs = pd.DataFrame({\"x\": np.linspace(0, 1, 128)})\n", "\n", "# Define the number of samples to draw for each input\n", "num_samples = 100\n", "\n", "# Draw the samples using twinLab\n", "sample_result = emulator.sample(sample_inputs, num_samples)\n", "\n", "# View the results in the form of a dataframe\n", "display(sample_result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **Viewing the samples**\n", "\n", "The samples can be plotted alongside the training data. Each curve is one sample drawn from the model, so the spread of the curves gives a visual impression of the model's uncertainty.\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot parameters\n", "color_curve = \"blue\"\n", "alpha_curve = 0.10\n", "color_data = \"black\"\n", "plot_training_data = True\n", "\n", "# Plot the training data and the samples drawn from the model (one curve per sample)\n", "if plot_training_data:\n", "    plt.plot(df[\"x\"], df[\"y\"], \".\", color=color_data, label=\"Training data\")\n", "plt.plot(sample_inputs[\"x\"], sample_result[\"y\"], color=color_curve, alpha=alpha_curve)\n", "plt.xlim((0.0, 1.0))\n", "plt.xlabel(r\"$x$\")\n", "plt.ylabel(r\"$y$\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### **Deleting datasets and emulators**\n", "\n", "To keep your cloud storage tidy, you should delete your datasets and emulators when you have finished with them. `Dataset.delete` and `Emulator.delete` remove the dataset and the emulator from the cloud storage, respectively.\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Delete dataset\n", "dataset.delete()\n", "\n", "# Delete emulator\n", "emulator.delete()" ] } ], "metadata": { "kernelspec": { "display_name": "twinlab-demos-5eekO54e-py3.11", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 2 }