• This repository was archived on 02/Nov/2021
• Stars: 129
• Rank: 279,262 (Top 6%)
• Language: Lua
• License: Apache License 2.0
• Created: over 7 years ago
• Updated: over 6 years ago


Torch decision tree library

local dt = require 'decisiontree'

This project implements random forests and gradient boosted decision trees (GBDT). The latter uses gradient tree boosting. Both use ensemble learning to produce ensembles of decision trees (that is, forests).

nn.DFD

One practical application for decision forests is to discretize an input feature space into a richer output feature space. The nn.DFD Module can be used as a decision forest discretizer (DFD):

local dfd = nn.DFD(df, onlyLastNode)

where df is a dt.DecisionForest instance or the table returned by the getReconstructionInfo() method of another nn.DFD module, and onlyLastNode is a boolean indicating that the module should return only the ID of the last node visited in each tree (by default it outputs all traversed nodes except the roots). The nn.DFD module requires dense input tensors; sparse input tensors (tables of tensors) are not supported. The output returned by a call to updateOutput is a batch of sparse tensors, where output[1] and output[2] are respectively a list of key tensors and a list of value tensors:

{
  { [torch.LongTensor], ... , [torch.LongTensor] },
  { [torch.Tensor], ... , [torch.Tensor] }
}

This module doesn't support CUDA.
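
Since the constructor also accepts reconstruction info, a second discretizer can be derived from an existing one. A minimal sketch based on the constructor description above (passing true as the second argument restricts output to the last node visited per tree):

-- build a second DFD from the first one's extracted attributes,
-- returning only the last node visited per tree (onlyLastNode = true)
local dfd2 = nn.DFD(dfd:getReconstructionInfo(), true)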

Example

As a concrete example, let us first train a Random Forest on a dummy dense dataset:

local nExample = 100
local batchsize = 2
local inputsize = 10

-- train Random Forest
local trainSet = dt.getDenseDummyData(nExample, nil, inputsize)
local opt = {
   activeRatio=0.5,
   featureBaggingSize=5,
   nTree=4,
   maxLeafNodes=nExample/2,
   minLeafSize=nExample/10,
}
local trainer = dt.RandomForestTrainer(opt)
local df = trainer:train(trainSet, trainSet.featureIds)
assert(#df.trees == opt.nTree) -- one CartTree per requested tree

Now that we have df, a dt.DecisionForest instance, we can use it to initialize nn.DFD:

local dfd = nn.DFD(df)

The dfd instance holds no reference to df; instead, it extracts the relevant attributes from df. These attributes are stored in tensors for batching and efficiency.

We can discretize a hypothetical input by calling forward:

local input = trainSet.input:sub(1,batchsize)
local output = dfd:forward(input)

The resulting output is a table consisting of two tables: keys and values. The keys and values tables each contain batchsize tensors:

print(output)
{
  1 :
    {
      1 : LongTensor - size: 14
      2 : LongTensor - size: 16
      3 : LongTensor - size: 15
      4 : LongTensor - size: 13
    }
  2 :
    {
      1 : DoubleTensor - size: 14
      2 : DoubleTensor - size: 16
      3 : DoubleTensor - size: 15
      4 : DoubleTensor - size: 13
    }
}

An example's feature keys (LongTensor) and corresponding values (DoubleTensor) have the same number of elements. Examples have a variable number of key-value pairs representing the nodes traversed in each tree. The output feature space has as many dimensions (that is, possible feature keys) as there are nodes in the forest.
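
To make the output layout concrete, here is one way to densify a single example's sparse output. This is a hedged sketch: nNodes, the total number of nodes in the forest, is an assumed quantity, and the keys are taken to be node IDs in the range 1..nNodes:

local nNodes = 64 -- hypothetical total node count for the forest
local keys, vals = output[1][1], output[2][1] -- first example in the batch
local dense = torch.zeros(nNodes)
dense:indexCopy(1, keys, vals) -- scatter the sparse pairs into a dense vector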

torch.SparseTensor

Suppose you have a set of keys mapped to values:

local keys = torch.LongTensor{1,3,4,7,2}
local values = torch.Tensor{0.1,0.3,0.4,0.7,0.2}

You can use a SparseTensor to encapsulate these into a read-only tensor:

local st = torch.SparseTensor(keys, values)

The decisiontree library uses SparseTensors to simulate the __index method of torch.Tensor. For example, one can obtain the value associated with key 3 of the above st instance:

local value = st[3]
assert(value == 0.3)

When a key is missing, nil is returned instead (note that key 5 is absent from the keys defined above):

local value = st[5]
assert(value == nil)

The default implementation of this kind of indexing is slow (it uses a sequential scan of the keys). To speed up indexing, one can call the buildIndex() method beforehand:

st:buildIndex()

The buildIndex() method creates a hash map (a Lua table) from keys to their corresponding indices in the values tensor.
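
For illustration, the two lookup strategies amount to something like the following sketch (an illustration of the idea, not the library's actual internals):

-- sequential scan: O(n) per lookup
local function scanLookup(keys, values, k)
   for i = 1, keys:size(1) do
      if keys[i] == k then return values[i] end
   end
   return nil
end

-- hash index: O(1) per lookup after a one-time O(n) build
local function buildIndex(keys)
   local index = {}
   for i = 1, keys:size(1) do index[keys[i]] = i end
   return index
end

local index = buildIndex(keys)
assert(values[index[3]] == 0.3) -- same result as st[3], without a scan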

dt.DataSet

The CartTrainer, RandomForestTrainer and GradientBoostTrainer require that data sets be encapsulated into a DataSet. Suppose you have a dataset of dense inputs and targets:

local nExample = 10
local nFeature = 5
local input = torch.randn(nExample, nFeature)
local target = torch.Tensor(nExample):random(0,1)

these can be encapsulated into a DataSet object:

local dataset = dt.DataSet(input, target)

Now suppose you have a dataset where the input is a table of SparseTensor instances:

local input = {}
for i=1,nExample do
   local nKeyVal = math.random(2,nFeature)
   local keys = torch.LongTensor(nKeyVal):random(1,nFeature)
   local values = torch.randn(nKeyVal)
   input[i] = torch.SparseTensor(keys, values)
end

You can still use a DataSet to encapsulate the sparse dataset:

local dataset = dt.DataSet(input, target)

The main purpose of the DataSet class is to sort each feature by value. This is captured by the sortFeatureValues(input) method, which is called in the constructor:

local sortedFeatureValues, featureIds = self:sortFeatureValues(input)

The featureIds is a torch.LongTensor of all available feature IDs. For a dense input tensor, this is just torch.LongTensor():range(1,input:size(2)). But for a sparse input tensor, the featureIds tensor only contains the feature IDs present in the dataset.

The resulting sortedFeatureValues is a table mapping featureIds to exampleIds sorted by featureValues. For each featureId, examples are sorted by featureValue in ascending order. For example, the table might look like {featureId=exampleIds}, where exampleIds={1,3,2}.

The CartTrainer accesses the sortedFeatureValues table by calling getSortedFeature(featureId):

local exampleIdsWithFeature = dataset:getSortedFeature(featureId)

The ability to access example IDs sorted by feature value, given a feature ID, is the main purpose of the DataSet. The CartTrainer relies on these sorted lists to find the best way to split a set of examples between two tree nodes.
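
As a concrete illustration, consider a tiny dense dataset (a small sketch; the expected ordering follows from the ascending sort described above):

local input = torch.Tensor{{0.9, 0.5}, {0.8, 0.1}, {0.7, 0.3}}
local target = torch.Tensor{0, 1, 0}
local dataset = dt.DataSet(input, target)
-- feature 2 has values 0.5, 0.1, 0.3 for examples 1, 2, 3, so the
-- ascending order by value is example 2, then 3, then 1:
print(dataset:getSortedFeature(2)) -- expected: 2, 3, 1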

dt.CartTrainer

local trainer = dt.CartTrainer(dataset, minLeafSize, maxLeafNodes)

The CartTrainer is used by the RandomForestTrainer and GradientBoostTrainer to train individual trees. CART stands for classification and regression trees. However, only binary classifiers are unit tested.

The constructor takes the following arguments:

  • dataset is a dt.DataSet instance representing the training set.
  • minLeafSize is the minimum number of examples per leaf node in a tree. The larger the value, the more regularization.
  • maxLeafNodes is the maximum number of leaf nodes in the tree. The lower the value, the more regularization.

Training is initiated by calling the train() method:

local trainSet = dt.DataSet(input, target)
local rootTreeState = dt.GiniState(trainSet:getExampleIds())
local activeFeatures = trainSet.featureIds
local tree = trainer:train(rootTreeState, activeFeatures)

The resulting tree is a CartTree instance. The rootTreeState is a TreeState instance like GiniState (used by RandomForestTrainer) or GradientBoostState (used by GradientBoostTrainer). The activeFeatures is a LongTensor of the feature IDs used to build the tree. Every other feature ID is ignored during training. This is useful for feature bagging.

By default, the CartTrainer runs in a single thread. The featureParallel(nThread) method can be called before train() to parallelize training using nThread workers:

local nThread = 3
trainer:featureParallel(nThread)
trainer:train(rootTreeState, activeFeatures)

Feature parallelization assigns a set of feature IDs to each thread, as sketched below.
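
A hedged sketch of that assignment (an illustration of the idea, not the library's internals):

-- partition feature IDs round-robin across nThread workers
local nThread = 3
local shards = {}
for t = 1, nThread do shards[t] = {} end
for i = 1, activeFeatures:size(1) do
   local t = ((i - 1) % nThread) + 1
   table.insert(shards[t], activeFeatures[i])
end
-- each worker evaluates candidate splits only for its own shard of features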

The CartTrainer can be used as a stand-alone tree trainer. But it is recommended to use it within the context of a RandomForestTrainer or GradientBoostTrainer instead. The latter typically generalize better.

RandomForestTrainer

The RandomForestTrainer is used to train a random forest:

local nExample = trainSet:size()
local opt = {
   activeRatio=0.5,
   featureBaggingSize=5,
   nTree=14,
   maxLeafNodes=nExample/2,
   minLeafSize=nExample/10,
}
local trainer = dt.RandomForestTrainer(opt)
local forest = trainer:train(trainSet, trainSet.featureIds)

The returned forest is a DecisionForest instance. A DecisionForest has a similar interface to the CartTree. Indeed, they both sub-class the DecisionTree abstract class.
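
Because both implement the DecisionTree interface, a trained forest can score an example just like a single tree does (a small sketch; score(input) is described under CartTree below):

local input = trainSet.input[1] -- any dense example
local score = forest:score(input)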

The constructor takes a single opt table argument, which contains the actual arguments:

  • activeRatio is the ratio of active examples per tree. This is used for bootstrap sampling.
  • featureBaggingSize is the number of features per tree. This is used for feature bagging.
  • nTree is the number of trees to be trained.
  • maxLeafNodes and minLeafSize are passed to the underlying CartTrainer constructor (controls regularization).

Internally, the RandomForestTrainer passes a GiniState to the CartTrainer:train() method.

Training can be parallelized by calling treeParallel(nThread):

local nThread = 3
trainer:treeParallel(nThread)
local forest = trainer:train(trainSet, trainSet.featureIds)

Training is then parallelized by building each tree in its own worker thread.

GradientBoostTrainer

References:

Gradient boosted decision trees (GBDT) can be trained as follows:

local nExample = trainSet:size()
local maxLeafNode, minLeafSize = nExample/2, nExample/10
local cartTrainer = dt.CartTrainer(trainSet, minLeafSize, maxLeafNode)

local opt = {
  lossFunction=nn.LogitBoostCriterion(false),
  treeTrainer=cartTrainer,
  shrinkage=0.1,
  downsampleRatio=0.8,
  featureBaggingSize=-1,
  nTree=14,
  evalFreq=8,
  earlyStop=0
}

local trainer = dt.GradientBoostTrainer(opt)
-- validSet is assumed to be a held-out dt.DataSet used for validation and early-stopping
local forest = trainer:train(trainSet, trainSet.featureIds, validSet)

The above code snippet uses the LogitBoostCriterion outlined in reference A. It is used for training binary classification trees.

The returned forest is a DecisionForest instance. A DecisionForest has a similar interface to the CartTree. Indeed, they both sub-class the DecisionTree abstract class.

The constructor takes a single opt table argument, which contains the actual arguments:

  • lossFunction is an nn.Criterion instance extended to include the updateHessInput(input, target) and backward2(input, target) methods, which return the hessian with respect to the input (see the sketch after this list).
  • treeTrainer is a CartTrainer instance. Its featureParallel() method can be called to implement feature parallelization.
  • shrinkage is the weight of each additional tree.
  • downsampleRatio is the ratio of examples to be sampled for each tree. Used for bootstrap sampling.
  • featureBaggingSize is the number of features to sample per tree. Used for feature bagging. A value of -1 defaults to torch.round(math.sqrt(featureIds:size(1))).
  • nTree is the maximum number of trees.
  • evalFreq is the number of epochs between calls to validate() for cross-validation and early-stopping.
  • earlyStop is the maximum number of epochs to wait for early-stopping.
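
A minimal sketch of this extended criterion contract, using squared error for concreteness (the method names come from the list above; the implementation is an assumption for illustration, not library code):

require 'nn'

local MSEWithHessian = torch.class('nn.MSEWithHessian', 'nn.Criterion')

function MSEWithHessian:updateOutput(input, target)
   self.output = (input - target):pow(2):sum()
   return self.output
end

function MSEWithHessian:updateGradInput(input, target)
   self.gradInput = (input - target):mul(2)
   return self.gradInput
end

-- the two extra methods required by the GBDT splitting algorithm:
function MSEWithHessian:updateHessInput(input, target)
   -- the second derivative of (x - y)^2 with respect to x is the constant 2
   self.hessInput = input.new(input:size()):fill(2)
   return self.hessInput
end

function MSEWithHessian:backward2(input, target)
   return self:updateHessInput(input, target)
end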

Internally, the GradientBoostTrainer passes a GradientBoostState to the CartTrainer:train() method.
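
For intuition, shrinkage weights each tree's contribution in the usual additive GBDT update (a standard formulation sketched here, not this library's code):

-- F_m(x) = F_{m-1}(x) + shrinkage * tree_m:score(x)
local function boostedScore(trees, shrinkage, x)
   local score = 0
   for _, tree in ipairs(trees) do
      score = score + shrinkage * tree:score(x)
   end
   return score
end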

TreeState

An abstract class that holds the state of a subtree during decision tree training. It also manages the state of candidate splits.

local treeState = dt.TreeState(exampleIds)

The exampleIds argument is a LongTensor containing the example IDs that make up the sub-tree.

GiniState

A TreeState subclass used internally by the RandomForestTrainer. Uses Gini impurity to determine the best splits.

local treeState = dt.GiniState(exampleIds)

The exampleIds argument is a LongTensor containing the example IDs that make up the sub-tree.
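
For reference, the standard binary Gini impurity that such a criterion evaluates (a worked illustration, not library code):

-- gini(p) = 1 - p^2 - (1 - p)^2, where p is the fraction of positive examples
local function gini(p) return 1 - p*p - (1 - p)*(1 - p) end
assert(gini(0.0) == 0)   -- a pure node has zero impurity
assert(gini(0.5) == 0.5) -- a 50/50 node is maximally impure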

GradientBoostState

A TreeState subclass used internally by the GradientBoostTrainer. It implements the GBDT splitting algorithm, which uses a loss function.

local treeState = dt.GradientBoostState(exampleIds, lossFunction)

The exampleIds argument is a LongTensor containing the example IDs that make up the sub-tree. The lossFunction is an nn.Criterion instance (see GradientBoostTrainer).

WorkPool

Utility class that simplifies construction of a pool of daemon threads with which to execute tasks in parallel.

local workpool = dt.WorkPool(nThread)

CartTree

Implements a trained CART decision tree:

local tree = nn.CartTree(rootNode)

The rootNode is a CartNode instance. Each CartNode contains pointers to left and right branches, which are themselves CartNode instances.

For inference, use the score(input) method:

local score = tree:score(input)
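
Conceptually, score(input) descends from the root to a leaf by comparing split values (a hedged sketch; the field names leftChild, rightChild, splitFeatureId, splitFeatureValue, and score are assumptions for illustration):

local function scoreNode(node, input)
   if not node.leftChild and not node.rightChild then
      return node.score -- a leaf holds the prediction
   elseif input[node.splitFeatureId] < node.splitFeatureValue then
      return scoreNode(node.leftChild, input)
   else
      return scoreNode(node.rightChild, input)
   end
end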
