Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ class AdjointGenerator
bool used =
unnecessaryInstructions.find(&I) == unnecessaryInstructions.end();
if (!used) {
// if decided to cache a value, preserve it here for later
// replacement in EnzymeLogic
auto found = gutils->knownRecomputeHeuristic.find(&I);
if (found != gutils->knownRecomputeHeuristic.end() &&
!gutils->unnecessaryIntermediates.count(&I))
if (found != gutils->knownRecomputeHeuristic.end() && !found->second)
used = true;
}
auto iload = gutils->getNewFromOriginal((Value *)&I);
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/CacheUtility.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class CacheUtility {
public:
virtual ~CacheUtility();

private:
protected:
/// Map of Loop to requisite loop information needed for AD (forward/reverse
/// induction/etc)
std::map<llvm::Loop *, LoopContext> loopContexts;
Expand Down
36 changes: 36 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4724,6 +4724,42 @@ void GradientUtils::computeMinCache(
}
}

SmallPtrSet<Instruction *, 3> NewLoopBoundReq;
{
std::deque<Instruction *> LoopBoundRequirements;

for (auto &context : loopContexts) {
for (auto val : {context.second.maxLimit, context.second.trueLimit}) {
if (auto inst = dyn_cast_or_null<Instruction>(val)) {
LoopBoundRequirements.push_back(inst);
}
}
}
SmallPtrSet<Instruction *, 3> Seen;
while (LoopBoundRequirements.size()) {
Instruction *val = LoopBoundRequirements.front();
LoopBoundRequirements.pop_front();
if (NewLoopBoundReq.count(val))
continue;
if (Seen.count(val))
continue;
Seen.insert(val);
if (auto orig = isOriginal(val)) {
NewLoopBoundReq.insert(orig);
} else {
for (auto &op : val->operands()) {
if (auto inst = dyn_cast<Instruction>(op)) {
LoopBoundRequirements.push_back(inst);
}
}
}
}
for (auto inst : NewLoopBoundReq) {
OneLevelSeen[UsageKey(inst, ValueType::Primal)] = true;
FullSeen[UsageKey(inst, ValueType::Primal)] = true;
}
}

for (BasicBlock &BB : *oldFunc) {
if (guaranteedUnreachable.count(&BB))
continue;
Expand Down
52 changes: 52 additions & 0 deletions enzyme/test/Integration/ReverseMode/metarwr.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -
// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme --enzyme-inline=1 -S | %lli -

#include <stdio.h>
#include <math.h>
#include <assert.h>

#include "test_utils.h"

void __enzyme_autodiff(void*, ...);

void call(double* __restrict__ a, long** data) {
long* segment = data[0];
long size = segment[1] - segment[0];
printf("seg[1]=%d seg[0]=%d\n", segment[1], segment[0]);
for (size_t i=0; i<size; i++)
a[i] *= 2;
data[0] = 0;
}

void alldiv(double* __restrict__ a, long** meta) {
call(a, meta);
a[0] = 0;
}

int main(int argc, char** argv) {

long meta[2] = { 198, 200 };
long* mmeta = (long*)meta;
double *val = malloc(200*sizeof(double));
val[1] = 7;
double *dval = malloc(200*sizeof(double));
dval[1] = 1;
double* a = (double*)val;
double* da = (double*)dval;

__enzyme_autodiff((void*)alldiv, (double*)val, (double*)dval, &mmeta);

printf("a = %p, da=%p\n", a, da);
printf("val=%f dval=%f\n", val[0], dval[0]);
printf("meta=%d\n", meta);
fflush(0);

APPROX_EQ(dval[1], 2.0, 1e-8);
return 0;
}