  • Language
    Python
  • License
    Apache License 2.0
  • Created almost 6 years ago
  • Updated 4 months ago



Petastorm


Petastorm is an open source data access library developed at Uber ATG. This library enables single machine or distributed training and evaluation of deep learning models directly from datasets in Apache Parquet format. Petastorm supports popular Python-based machine learning (ML) frameworks such as TensorFlow, PyTorch, and PySpark. It can also be used from pure Python code.

Documentation web site: https://petastorm.readthedocs.io

Installation

pip install petastorm

The petastorm package defines several extras whose dependencies are not installed automatically: tf, tf_gpu, torch, opencv, docs, and test.

For example, to install the GPU version of TensorFlow together with OpenCV, use the following pip command:

pip install petastorm[opencv,tf_gpu]

Generating a dataset

A dataset created using Petastorm is stored in Apache Parquet format. On top of the Parquet schema, Petastorm also stores higher-level schema information that makes multidimensional arrays a native part of a Petastorm dataset.

Petastorm supports extensible data codecs. These let a user pick one of the standard data compressions (jpeg, png) or implement their own.
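As a minimal, stdlib-only sketch of what a custom codec does, the hypothetical class below encodes a list of int32 values into compressed bytes and decodes them back. Petastorm's own codecs expose encode and decode methods that receive the Unischema field along with the value; this sketch only mimics that shape, takes a plain field name instead of a field object, and is not a drop-in petastorm codec (a real one would also have to declare how its output is stored in Spark).

```python
import struct
import zlib


class ZlibInt32ListCodec:
    """Hypothetical codec: stores a list of int32 values as zlib-compressed bytes.

    Mimics the encode/decode round trip of a codec; NOT a drop-in petastorm codec.
    """

    def encode(self, field_name, values):
        # Pack the values as little-endian int32, then compress the raw bytes.
        raw = struct.pack('<%di' % len(values), *values)
        return zlib.compress(raw)

    def decode(self, field_name, blob):
        # Decompress, then unpack 4 bytes per int32 value.
        raw = zlib.decompress(blob)
        count = len(raw) // 4
        return list(struct.unpack('<%di' % count, raw))


codec = ZlibInt32ListCodec()
blob = codec.encode('my_field', [1, 2, 3, -4])
print(codec.decode('my_field', blob))  # [1, 2, 3, -4]
```

The important property is that decode(encode(x)) returns x; the storage format in between is entirely up to the codec.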

Datasets are generated using PySpark. PySpark natively supports the Parquet format, which makes it easy to run on a single machine or on a Spark compute cluster. Here is a minimal example that writes out a table of random data.

import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec
from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.unischema import dict_to_spark_row, Unischema, UnischemaField

# The schema defines what the dataset looks like
HelloWorldSchema = Unischema('HelloWorldSchema', [
    UnischemaField('id', np.int32, (), ScalarCodec(IntegerType()), False),
    UnischemaField('image1', np.uint8, (128, 256, 3), CompressedImageCodec('png'), False),
    UnischemaField('array_4d', np.uint8, (None, 128, 30, None), NdarrayCodec(), False),
])


def row_generator(x):
    """Returns a single entry in the generated dataset. Return a bunch of random values as an example."""
    return {'id': x,
            'image1': np.random.randint(0, 255, dtype=np.uint8, size=(128, 256, 3)),
            'array_4d': np.random.randint(0, 255, dtype=np.uint8, size=(4, 128, 30, 3))}


def generate_petastorm_dataset(output_url='file:///tmp/hello_world_dataset'):
    rowgroup_size_mb = 256

    spark = SparkSession.builder.config('spark.driver.memory', '2g').master('local[2]').getOrCreate()
    sc = spark.sparkContext

    # Wrap dataset materialization portion. Will take care of setting up spark environment variables as
    # well as save petastorm specific metadata
    rows_count = 10
    with materialize_dataset(spark, output_url, HelloWorldSchema, rowgroup_size_mb):

        rows_rdd = sc.parallelize(range(rows_count))\
            .map(row_generator)\
            .map(lambda x: dict_to_spark_row(HelloWorldSchema, x))

        spark.createDataFrame(rows_rdd, HelloWorldSchema.as_spark_schema()) \
            .coalesce(10) \
            .write \
            .mode('overwrite') \
            .parquet(output_url)


if __name__ == '__main__':
    generate_petastorm_dataset()
  • HelloWorldSchema is an instance of a Unischema object. Unischema can render the types of its fields into different framework-specific formats, such as a Spark StructType, a TensorFlow tf.DType, and a NumPy numpy.dtype.
  • For each field of the Unischema, you specify a type, a shape, a codec instance, and whether the field is nullable.
  • We use PySpark to write the output Parquet files. In this example, PySpark runs on the local machine (.master('local[2]')); larger-scale dataset generation would need a real compute cluster.
  • We wrap the Spark dataset generation code in the materialize_dataset context manager. The context manager configures the row group size at the beginning and writes out Petastorm-specific metadata at the end.
  • The row generating code is expected to return a Python dictionary keyed by field name. We use the row_generator function for that.
  • dict_to_spark_row converts the dictionary into a pyspark.Row object while ensuring compliance with HelloWorldSchema (shape, type, and nullability are checked).
  • Once we have a pyspark.DataFrame, we write it out to Parquet storage. The Parquet schema is automatically derived from HelloWorldSchema.
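The shape check performed while converting rows can be pictured as follows. This is a minimal sketch, not petastorm's implementation: the hypothetical shape_compliant helper treats None in a declared shape as a wildcard dimension, as in the (None, 128, 30, None) shape of array_4d, while the number of dimensions must still match.

```python
def shape_compliant(expected, actual):
    """Return True if `actual` matches `expected`.

    A None entry in `expected` accepts any size for that dimension;
    the number of dimensions (rank) must match exactly.
    """
    if len(expected) != len(actual):
        return False
    return all(e is None or e == a for e, a in zip(expected, actual))


# Declared shape of the 'array_4d' field in HelloWorldSchema.
expected = (None, 128, 30, None)
print(shape_compliant(expected, (4, 128, 30, 3)))  # True
print(shape_compliant(expected, (4, 128, 31, 3)))  # False: third dimension differs
print(shape_compliant(expected, (128, 30, 3)))     # False: wrong rank
```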

Plain Python API

The petastorm.reader.Reader class is the main entry point for user code that accesses the data from an ML framework such as TensorFlow or PyTorch. The reader offers multiple features, such as:

  • Selective column readout
  • Multiple parallelism strategies: thread, process, single-threaded (for debug)
  • N-grams readout support
  • Row filtering (row predicates)
  • Shuffling
  • Partitioning for multi-GPU training
  • Local caching
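Partitioning for multi-GPU training means each worker reads a disjoint slice of the dataset; make_reader exposes this through its cur_shard and shard_count arguments. The sketch below assumes a simple round-robin split of row groups purely for illustration; shard_row_groups is a hypothetical helper, and petastorm's actual assignment may differ.

```python
def shard_row_groups(row_group_ids, cur_shard, shard_count):
    """Round-robin split: shard k gets every shard_count-th row group, starting at k."""
    if not 0 <= cur_shard < shard_count:
        raise ValueError('cur_shard must be in [0, shard_count)')
    return row_group_ids[cur_shard::shard_count]


row_groups = list(range(10))
for rank in range(4):
    print(rank, shard_row_groups(row_groups, rank, 4))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]
```

Every row group lands in exactly one shard, so the workers together cover the dataset once per epoch without overlap.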

Reading a dataset is simple with the petastorm.reader.Reader class, which can be created using the petastorm.make_reader factory method:

from petastorm import make_reader

with make_reader('hdfs://myhadoop/some_dataset') as reader:
    for row in reader:
        print(row)

hdfs://... and file://... are supported URL protocols.

Once a Reader is instantiated, you can use it as an iterator.
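The with/for pattern above relies on the reader being both a context manager and an iterator that yields rows as namedtuples, which is why fields can be read as attributes (for example sample.id in the TensorFlow example). The ToyReader class below is a hypothetical, in-memory stand-in that shows only that protocol; it does no I/O and is not petastorm's Reader.

```python
from collections import namedtuple

# Rows are namedtuples, so fields are accessed as attributes (row.id).
Row = namedtuple('Row', ['id', 'label'])


class ToyReader:
    """Hypothetical in-memory stand-in for a reader: no I/O, protocol only."""

    def __init__(self, records):
        self._records = iter(records)
        self.stopped = False

    def __iter__(self):
        return self

    def __next__(self):
        # next() raises StopIteration when records run out, ending the for-loop.
        return Row(*next(self._records))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # A real reader would stop its worker pool here.
        self.stopped = True
        return False  # let exceptions from the with-block propagate


with ToyReader([(0, 'cat'), (1, 'dog')]) as reader:
    for row in reader:
        print(row.id, row.label)
# 0 cat
# 1 dog
print(reader.stopped)  # True
```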

Tensorflow API

To hook the reader into a TensorFlow graph, you can use the tf_tensors function:

import tensorflow as tf  # TensorFlow 1.x API (tf.Session)
from petastorm import make_reader
from petastorm.tf_utils import tf_tensors

with make_reader('file:///some/localpath/a_dataset') as reader:
    row_tensors = tf_tensors(reader)
    with tf.Session() as session:
        for _ in range(3):
            print(session.run(row_tensors))

Alternatively, you can use the tf.data.Dataset API:

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)
from petastorm import make_reader
from petastorm.tf_utils import make_petastorm_dataset

with make_reader('file:///some/localpath/a_dataset') as reader:
    dataset = make_petastorm_dataset(reader)
    iterator = dataset.make_one_shot_iterator()
    tensor = iterator.get_next()
    with tf.Session() as sess:
        sample = sess.run(tensor)
        print(sample.id)

PyTorch API

As illustrated in pytorch_example.py, a Petastorm dataset can be read from PyTorch via the adapter class petastorm.pytorch.DataLoader, which accepts a custom PyTorch collate function and transforms.

Be sure you have torch and torchvision installed:

pip install torchvision

The minimal example below assumes the Net class and the train and test functions defined in pytorch_example.py:

import torch
from torchvision import transforms

from petastorm import make_reader
from petastorm.pytorch import DataLoader
from petastorm.transform import TransformSpec

torch.manual_seed(1)
device = torch.device('cpu')
model = Net().to(device)  # Net, train and test are defined in pytorch_example.py
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

# Build the torchvision transform once instead of on every row.
_mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])


def _transform_row(mnist_row):
    # The TransformSpec function receives a row dict and returns a row dict.
    return {'image': _mnist_transform(mnist_row['image']),
            'digit': mnist_row['digit']}


transform = TransformSpec(_transform_row, removed_fields=['idx'])

with DataLoader(make_reader('file:///localpath/mnist/train', num_epochs=10,
                            transform_spec=transform, seed=1, shuffle_rows=True), batch_size=64) as train_loader:
    train(model, device, train_loader, 10, optimizer, 1)
with DataLoader(make_reader('file:///localpath/mnist/test', num_epochs=10,
                            transform_spec=transform), batch_size=1000) as test_loader:
    test(model, device, test_loader)

If you are working with very large batch sizes and do not need support for Decimal or string fields, petastorm.pytorch.BatchedDataLoader buffers data in Torch tensors (CPU or CUDA) and offers significantly higher throughput.

If your dataset fits into system memory, you can use the in-memory dataloader petastorm.pytorch.InMemBatchedDataLoader. This dataloader reads the dataset only once and caches the data in memory to avoid additional I/O across multiple epochs.
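The read-once, cache-in-memory idea behind an in-memory loader can be sketched without PyTorch. The InMemoryBatcher class and its read_rows callback are hypothetical; this sketch reshuffles the cached rows each epoch and lets the last batch of an epoch be smaller, which may not match the real loader's behavior.

```python
import random


class InMemoryBatcher:
    """Hypothetical sketch: read the source once, then serve every epoch from memory."""

    def __init__(self, read_rows, batch_size, num_epochs, seed=0):
        self._cache = list(read_rows())  # the only pass over the underlying data
        self._batch_size = batch_size
        self._num_epochs = num_epochs
        self._rng = random.Random(seed)

    def __iter__(self):
        for _ in range(self._num_epochs):
            order = list(range(len(self._cache)))
            self._rng.shuffle(order)  # fresh order each epoch, same cached rows
            for start in range(0, len(order), self._batch_size):
                yield [self._cache[i] for i in order[start:start + self._batch_size]]


read_calls = []


def read_rows():
    read_calls.append(1)  # record each read of the (pretend) dataset
    return range(10)


loader = InMemoryBatcher(read_rows, batch_size=4, num_epochs=3)
batches = list(loader)
print(len(read_calls), len(batches))  # 1 9
```

Ten rows in batches of four give three batches per epoch (4, 4 and 2 rows), so three epochs yield nine batches from a single read.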

Spark Dataset Converter API

Spark converter API simplifies the data conversion from Spark to TensorFlow or PyTorch. The input Spark DataFrame is first materialized in the parquet format and then loaded as a tf.data.Dataset or torch.utils.data.DataLoader.

The minimal example below assumes a compiled tf.keras model, an active SparkSession named spark, and a Spark DataFrame containing a feature column followed by a label column.

from petastorm.spark import SparkDatasetConverter, make_spark_converter
import tensorflow.compat.v1 as tf  # pylint: disable=import-error

# specify a cache dir first.
# the dir is used to save materialized spark dataframe files
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, 'hdfs:/...')

df = ... # `df` is a spark dataframe

# create a converter from `df`
# it will materialize `df` to cache dir.
converter = make_spark_converter(df)

# make a tensorflow dataset from `converter`
with converter.make_tf_dataset() as dataset:
    # the `dataset` is `tf.data.Dataset` object
    # dataset transformation can be done if needed
    dataset = dataset.map(...)
    # we can train/evaluate model on the `dataset`
    model.fit(dataset)
    # when exiting the context, the reader of the dataset will be closed

# delete the cached files of the dataframe.
converter.delete()

The minimal example below assumes the Net class and the train and test functions defined in pytorch_example.py, an active SparkSession named spark, and Spark DataFrames containing a feature column followed by a label column.

from petastorm.spark import SparkDatasetConverter, make_spark_converter

# specify a cache dir first.
# the dir is used to save materialized spark dataframe files
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, 'hdfs:/...')

df_train, df_test = ... # `df_train` and `df_test` are spark dataframes
model = Net()

# create a converter_train from `df_train`
# it will materialize `df_train` to cache dir. (the same for df_test)
converter_train = make_spark_converter(df_train)
converter_test = make_spark_converter(df_test)

# make a pytorch dataloader from `converter_train`
with converter_train.make_torch_dataloader() as dataloader_train:
    # the `dataloader_train` is `torch.utils.data.DataLoader` object
    # we can train model using the `dataloader_train`
    train(model, dataloader_train, ...)
    # when exiting the context, the reader of the dataset will be closed

# the same for `converter_test`
with converter_test.make_torch_dataloader() as dataloader_test:
    test(model, dataloader_test, ...)

# delete the cached files of the dataframes.
converter_train.delete()
converter_test.delete()

Analyzing petastorm datasets using PySpark and SQL

A Petastorm dataset can be read into a Spark DataFrame using PySpark, where you can use a wide range of Spark tools to analyze and manipulate the dataset.

# Create a dataframe object from a parquet file
dataframe = spark.read.parquet(dataset_url)

# Show a schema
dataframe.printSchema()

# Count all
dataframe.count()

# Show a single column
dataframe.select('id').show()

SQL can be used to query a Petastorm dataset:

spark.sql(
   'SELECT count(id) '
   'from parquet.`file:///tmp/hello_world_dataset`').collect()

You can find a full code sample in pyspark_hello_world.py.

Non Petastorm Parquet Stores

Petastorm can also be used to read data directly from Apache Parquet stores. To achieve that, use make_batch_reader (and not make_reader). The differences between the make_reader and make_batch_reader functions are:

  • Supported stores: make_reader reads only Petastorm datasets (created using materialize_dataset); make_batch_reader reads any Parquet store (some native Parquet column types are not supported yet).
  • Records: make_reader returns one record at a time; make_batch_reader returns batches of records, whose size is not fixed and is defined by the Parquet row-group size.
  • Predicates: predicates passed to make_reader are evaluated per single row; predicates passed to make_batch_reader are evaluated per batch.
  • Filters: both functions can filter Parquet files based on the filters argument.
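Because make_batch_reader yields batches whose size follows the Parquet row-group size, training code that wants a fixed batch size has to regroup them. The sketch below shows that regrouping with plain lists; rebatch is a hypothetical helper, not a petastorm API, and real code would typically buffer columns or tensors rather than Python lists.

```python
def rebatch(chunks, batch_size):
    """Regroup variable-size chunks (e.g. one per Parquet row group) into
    fixed-size batches; the final batch may be smaller."""
    buffer = []
    for chunk in chunks:
        buffer.extend(chunk)
        # Emit full batches as soon as enough rows have accumulated.
        while len(buffer) >= batch_size:
            yield buffer[:batch_size]
            buffer = buffer[batch_size:]
    if buffer:
        yield buffer  # leftover rows at the end of the stream


# Row groups of 3, 5 and 2 rows regrouped into batches of 4.
print(list(rebatch([[0, 1, 2], [3, 4, 5, 6, 7], [8, 9]], 4)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```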

Troubleshooting

See the Troubleshooting page and please submit a ticket if you can't find an answer.

See also

  1. Gruener, R., Cheng, O., and Litvin, Y. (2018) Introducing Petastorm: Uber ATG's Data Access Library for Deep Learning. URL: https://eng.uber.com/petastorm/
  2. QCon.ai 2019: "Petastorm: A Light-Weight Approach to Building ML Pipelines".

How to Contribute

We prefer to receive contributions in the form of GitHub pull requests. Please send pull requests against the github.com/uber/petastorm repository.

  • If you are looking for ideas on what to contribute, check out the GitHub issues and comment on an issue.
  • If you have an idea for an improvement, or you'd like to report a bug but don't have time to fix it, please create a GitHub issue.

To contribute a patch:

  • Break your work into small, single-purpose patches if possible. It's much harder to merge in a large change with a lot of disjoint features.
  • Submit the patch as a GitHub pull request against the master branch. For a tutorial, see the GitHub guides on forking a repo and sending a pull request.
  • Include a detailed description of the proposed change in the pull request.
  • Make sure that your code passes the unit tests. You can find instructions on how to run the unit tests here.
  • Add new unit tests for your code.

Thank you in advance for your contributions!

See the Development page for development-related information.
